tidb conn_stmt 源码

  • 2022-09-19
  • 浏览 (311)

tidb conn_stmt 代码

文件路径:/server/conn_stmt.go

// Copyright 2015 PingCAP, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

// Copyright 2013 The Go-MySQL-Driver Authors. All rights reserved.
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this file,
// You can obtain one at http://mozilla.org/MPL/2.0/.

// The MIT License (MIT)
//
// Copyright (c) 2014 wandoulabs
// Copyright (c) 2014 siddontang
//
// Permission is hereby granted, free of charge, to any person obtaining a copy of
// this software and associated documentation files (the "Software"), to deal in
// the Software without restriction, including without limitation the rights to
// use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of
// the Software, and to permit persons to whom the Software is furnished to do so,
// subject to the following conditions:
//
// The above copyright notice and this permission notice shall be included in all
// copies or substantial portions of the Software.

package server

import (
	"context"
	"encoding/binary"
	"fmt"
	"math"
	"runtime/trace"
	"strconv"
	"time"

	"github.com/pingcap/errors"
	"github.com/pingcap/tidb/expression"
	"github.com/pingcap/tidb/kv"
	"github.com/pingcap/tidb/parser/ast"
	"github.com/pingcap/tidb/parser/charset"
	"github.com/pingcap/tidb/parser/mysql"
	"github.com/pingcap/tidb/parser/terror"
	plannercore "github.com/pingcap/tidb/planner/core"
	"github.com/pingcap/tidb/sessionctx/stmtctx"
	"github.com/pingcap/tidb/sessiontxn"
	storeerr "github.com/pingcap/tidb/store/driver/error"
	"github.com/pingcap/tidb/types"
	"github.com/pingcap/tidb/util/execdetails"
	"github.com/pingcap/tidb/util/hack"
	"github.com/pingcap/tidb/util/topsql"
	topsqlstate "github.com/pingcap/tidb/util/topsql/state"
	"github.com/tikv/client-go/v2/util"
)

// handleStmtPrepare handles COM_STMT_PREPARE. It prepares sql in the session,
// then replies with the prepare-OK header packet, one definition packet per
// parameter and one per result column, and finally flushes the connection.
// Each non-empty definition list is followed by an EOF packet unless the
// client negotiated ClientDeprecateEOF.
func (cc *clientConn) handleStmtPrepare(ctx context.Context, sql string) error {
	stmt, columns, params, err := cc.ctx.Prepare(sql)
	if err != nil {
		return err
	}

	// The first 4 bytes are left for the packet header (presumably filled in by
	// writePacket, since every packet below is rebuilt from pkt[:4]).
	pkt := make([]byte, 4, 128)
	// status: OK
	pkt = append(pkt, 0)
	// statement id
	pkt = dumpUint32(pkt, uint32(stmt.ID()))
	// number of result columns
	pkt = dumpUint16(pkt, uint16(len(columns)))
	// number of parameters
	pkt = dumpUint16(pkt, uint16(len(params)))
	// filler [00] followed by a 2-byte warning count that is always zero.
	pkt = append(pkt, 0, 0, 0) // TODO support warning count
	if err := cc.writePacket(pkt); err != nil {
		return err
	}

	cc.initResultEncoder(ctx)
	defer cc.rsEncoder.clean()

	// Metadata only needs an EOF marker for old clients without ClientDeprecateEOF.
	wantEOF := cc.capability&mysql.ClientDeprecateEOF == 0

	if len(params) > 0 {
		for i := range params {
			pkt = params[i].Dump(pkt[:4], cc.rsEncoder)
			if err := cc.writePacket(pkt); err != nil {
				return err
			}
		}
		if wantEOF {
			if err := cc.writeEOF(ctx, cc.ctx.Status()); err != nil {
				return err
			}
		}
	}

	if len(columns) > 0 {
		for i := range columns {
			pkt = columns[i].Dump(pkt[:4], cc.rsEncoder)
			if err := cc.writePacket(pkt); err != nil {
				return err
			}
		}
		if wantEOF {
			if err := cc.writeEOF(ctx, cc.ctx.Status()); err != nil {
				return err
			}
		}
	}

	return cc.flush(ctx)
}

// handleStmtExecute handles a COM_STMT_EXECUTE packet.
//
// data is parsed in this order: statement id (4 bytes, little-endian),
// flags (1 byte), iteration count (4 bytes, skipped), and, only when the
// prepared statement has parameters, a NULL bitmap of (numParams+7)/8 bytes,
// a "new params bound" byte, the parameter types (2 bytes per parameter,
// present only when that byte is 1) and finally the parameter values.
// Parameter types saved by an earlier execution are reused when the byte is
// not 1. Of the cursor flags, only read-only is supported; the others are
// rejected with an error.
func (cc *clientConn) handleStmtExecute(ctx context.Context, data []byte) (err error) {
	defer trace.StartRegion(ctx, "HandleStmtExecute").End()
	// 9 = statement id (4) + flags (1) + iteration count (4).
	if len(data) < 9 {
		return mysql.ErrMalformPacket
	}
	pos := 0
	stmtID := binary.LittleEndian.Uint32(data[0:4])
	pos += 4

	stmt := cc.ctx.GetStatement(int(stmtID))
	if stmt == nil {
		return mysql.NewErr(mysql.ErrUnknownStmtHandler,
			strconv.FormatUint(uint64(stmtID), 10), "stmt_execute")
	}

	flag := data[pos]
	pos++
	// Please refer to https://dev.mysql.com/doc/dev/mysql-server/latest/page_protocol_com_stmt_execute.html
	// The client indicates that it wants to use cursor by setting this flag.
	// Now we only support forward-only, read-only cursor.
	useCursor := false
	if flag&mysql.CursorTypeReadOnly > 0 {
		useCursor = true
	}
	if flag&mysql.CursorTypeForUpdate > 0 {
		return mysql.NewErrf(mysql.ErrUnknown, "unsupported flag: CursorTypeForUpdate", nil)
	}
	if flag&mysql.CursorTypeScrollable > 0 {
		return mysql.NewErrf(mysql.ErrUnknown, "unsupported flag: CursorTypeScrollable", nil)
	}

	// skip iteration-count, always 1
	pos += 4

	var (
		nullBitmaps []byte
		paramTypes  []byte
		paramValues []byte
	)
	cc.initInputEncoder(ctx)
	numParams := stmt.NumParams()
	args := make([]expression.Expression, numParams)
	if numParams > 0 {
		nullBitmapLen := (numParams + 7) >> 3
		// The NULL bitmap plus the following "new params bound" byte must be present.
		if len(data) < (pos + nullBitmapLen + 1) {
			return mysql.ErrMalformPacket
		}
		nullBitmaps = data[pos : pos+nullBitmapLen]
		pos += nullBitmapLen

		// new param bound flag
		if data[pos] == 1 {
			pos++
			// Two type bytes per parameter follow the flag.
			if len(data) < (pos + (numParams << 1)) {
				return mysql.ErrMalformPacket
			}

			paramTypes = data[pos : pos+(numParams<<1)]
			pos += numParams << 1
			paramValues = data[pos:]
			// Just the first StmtExecute packet contain parameters type,
			// we need save it for further use.
			stmt.SetParamsType(paramTypes)
		} else {
			// Types were not resent: skip the flag byte; the types saved by a
			// previous execution are read via stmt.GetParamsType() below.
			paramValues = data[pos+1:]
		}

		err = parseExecArgs(cc.ctx.GetSessionVars().StmtCtx, args, stmt.BoundParams(), nullBitmaps, stmt.GetParamsType(), paramValues, cc.inputDecoder)
		// Reset runs whether or not decoding succeeded. NOTE(review): presumably it
		// clears data appended via COM_STMT_SEND_LONG_DATA — confirm.
		stmt.Reset()
		if err != nil {
			return errors.Annotate(err, cc.preparedStmt2String(stmtID))
		}
	}
	return cc.executePlanCacheStmt(ctx, stmt, args, useCursor)
}

// executePlanCacheStmt executes a prepared statement and writes its result to
// the client. stmt must implement PreparedStatement; the type assertion panics
// otherwise. When the first attempt fails, at most one extra attempt is made:
//   - if the attempt is retryable and the transaction manager answers
//     StmtActionRetryReady, the statement is re-run with RetryInfo.Retrying set;
//   - otherwise, if falling back from TiFlash to TiKV is allowed, the error is a
//     TiFlash server timeout and the attempt is retryable, the statement is
//     re-run with TiFlash temporarily removed from the isolation read engines,
//     and the original error is appended to the statement context as a warning.
func (cc *clientConn) executePlanCacheStmt(ctx context.Context, stmt interface{}, args []expression.Expression, useCursor bool) (err error) {
	ctx = context.WithValue(ctx, execdetails.StmtExecDetailKey, &execdetails.StmtExecDetails{})
	ctx = context.WithValue(ctx, util.ExecDetailsKey, &util.ExecDetails{})
	prepared := stmt.(PreparedStatement)

	retryable, err := cc.executePreparedStmtAndWriteResult(ctx, prepared, args, useCursor)
	if err == nil {
		return nil
	}

	action, txnErr := sessiontxn.GetTxnManager(&cc.ctx).OnStmtErrorForNextAction(sessiontxn.StmtErrAfterQuery, err)
	if txnErr != nil {
		return txnErr
	}
	if retryable && action == sessiontxn.StmtActionRetryReady {
		cc.ctx.GetSessionVars().RetryInfo.Retrying = true
		_, err = cc.executePreparedStmtAndWriteResult(ctx, prepared, args, useCursor)
		cc.ctx.GetSessionVars().RetryInfo.Retrying = false
		return err
	}

	_, allowTiFlashFallback := cc.ctx.GetSessionVars().AllowFallbackToTiKV[kv.TiFlash]
	if !allowTiFlashFallback || !errors.ErrorEqual(err, storeerr.ErrTiFlashServerTimeout) || !retryable {
		return err
	}

	// The TiFlash server seems down: retry on TiKV, and afterwards warn the user
	// to check the status of the TiFlash server.
	firstErr := err
	delete(cc.ctx.GetSessionVars().IsolationReadEngines, kv.TiFlash)
	defer func() {
		cc.ctx.GetSessionVars().IsolationReadEngines[kv.TiFlash] = struct{}{}
	}()
	_, err = cc.executePreparedStmtAndWriteResult(ctx, prepared, args, useCursor)
	// The warning is added only after the retry, because `ResetContextOfStmt`
	// may be called during the retry, which clears warnings.
	cc.ctx.GetSessionVars().StmtCtx.AppendError(firstErr)
	return err
}

// executePreparedStmtAndWriteResult executes the prepared statement stmt with
// the decoded args and writes the outcome to the client: an OK packet when the
// statement produces no result set, the column definitions only when useCursor
// is set (rows are then sent by later COM_STMT_FETCH calls), or the full result
// set otherwise.
//
// The first return value indicates whether the call of executePreparedStmtAndWriteResult has no side effect and can be retried.
// Currently the first return value is used to fallback to TiKV when TiFlash is down.
func (cc *clientConn) executePreparedStmtAndWriteResult(ctx context.Context, stmt PreparedStatement, args []expression.Expression, useCursor bool) (bool, error) {
	prepStmt, err := (&cc.ctx).GetSessionVars().GetPreparedStmtByID(uint32(stmt.ID()))
	if err != nil {
		return true, errors.Annotate(err, cc.preparedStmt2String(uint32(stmt.ID())))
	}
	execStmt := &ast.ExecuteStmt{
		BinaryArgs: args,
		PrepStmt:   prepStmt,
	}
	rs, err := (&cc.ctx).ExecuteStmt(ctx, execStmt)
	if err != nil {
		return true, errors.Annotate(err, cc.preparedStmt2String(uint32(stmt.ID())))
	}
	// From here on something may already have been executed or written, so the
	// remaining early returns report the call as not retryable.
	if rs == nil {
		return false, cc.writeOK(ctx)
	}
	// Attach the plan-cache statement to the result set when both have the
	// expected concrete types.
	if result, ok := rs.(*tidbResultSet); ok {
		if planCacheStmt, ok := prepStmt.(*plannercore.PlanCacheStmt); ok {
			result.preparedStmt = planCacheStmt
		}
	}

	// if the client wants to use cursor
	// we should hold the ResultSet in PreparedStatement for next stmt_fetch, and only send back ColumnInfo.
	// Tell the client cursor exists in server by setting proper serverStatus.
	if useCursor {
		cc.initResultEncoder(ctx)
		defer cc.rsEncoder.clean()
		// The result set is handed to stmt and is not closed here; handleStmtFetch
		// reads it back via stmt.GetResultSet().
		stmt.StoreResultSet(rs)
		if err = cc.writeColumnInfo(rs.Columns()); err != nil {
			return false, err
		}
		if cl, ok := rs.(fetchNotifier); ok {
			cl.OnFetchReturned()
		}
		// explicitly flush columnInfo to client.
		err = cc.writeEOF(ctx, cc.ctx.Status()|mysql.ServerStatusCursorExists)
		if err != nil {
			return false, err
		}
		return false, cc.flush(ctx)
	}
	defer terror.Call(rs.Close)
	// writeResultset itself reports whether its failure is retryable.
	retryable, err := cc.writeResultset(ctx, rs, true, cc.ctx.Status(), 0)
	if err != nil {
		return retryable, errors.Annotate(err, cc.preparedStmt2String(uint32(stmt.ID())))
	}
	return false, nil
}

// maxFetchSize caps the number of rows a single COM_STMT_FETCH may request;
// larger requests are clamped to this value.
const maxFetchSize = 1024

// handleStmtFetch handles COM_STMT_FETCH. It sends up to fetchSize rows
// (clamped to maxFetchSize by parseStmtFetchCmd) from the result set that a
// previous cursor-mode COM_STMT_EXECUTE stored on the statement. Responses
// carry the ServerStatusCursorExists status flag.
//
// An error is returned when the packet is malformed, the statement id is
// unknown, or no result set is stored for the statement.
func (cc *clientConn) handleStmtFetch(ctx context.Context, data []byte) (err error) {
	cc.ctx.GetSessionVars().StartTime = time.Now()

	stmtID, fetchSize, err := parseStmtFetchCmd(data)
	if err != nil {
		return err
	}

	stmt := cc.ctx.GetStatement(int(stmtID))
	if stmt == nil {
		return errors.Annotate(mysql.NewErr(mysql.ErrUnknownStmtHandler,
			strconv.FormatUint(uint64(stmtID), 10), "stmt_fetch"), cc.preparedStmt2String(stmtID))
	}
	if topsqlstate.TopSQLEnabled() {
		prepareObj, _ := cc.preparedStmtID2CachePreparedStmt(stmtID)
		if prepareObj != nil && prepareObj.SQLDigest != nil {
			ctx = topsql.AttachAndRegisterSQLInfo(ctx, prepareObj.NormalizedSQL, prepareObj.SQLDigest, false)
		}
	}

	// Reuse the statement looked up above instead of querying the session again.
	sql := ""
	if prepared, ok := stmt.(*TiDBStatement); ok {
		sql = prepared.sql
	}
	// NOTE(review): the process info reports ComStmtExecute rather than a fetch
	// command — confirm this is intended.
	cc.ctx.SetProcessInfo(sql, time.Now(), mysql.ComStmtExecute, 0)

	rs := stmt.GetResultSet()
	if rs == nil {
		return errors.Annotate(mysql.NewErr(mysql.ErrUnknownStmtHandler,
			strconv.FormatUint(uint64(stmtID), 10), "stmt_fetch_rs"), cc.preparedStmt2String(stmtID))
	}

	_, err = cc.writeResultset(ctx, rs, true, cc.ctx.Status()|mysql.ServerStatusCursorExists, int(fetchSize))
	if err != nil {
		return errors.Annotate(err, cc.preparedStmt2String(stmtID))
	}
	return nil
}

// parseStmtFetchCmd decodes a COM_STMT_FETCH payload, which must be exactly
// 8 bytes: the statement id followed by the number of rows to fetch, both
// little-endian uint32. The row count is clamped to maxFetchSize.
// Please refer to https://dev.mysql.com/doc/internals/en/com-stmt-fetch.html
func parseStmtFetchCmd(data []byte) (stmtID uint32, fetchSize uint32, err error) {
	const payloadLen = 8
	if len(data) != payloadLen {
		return 0, 0, mysql.ErrMalformPacket
	}
	id := binary.LittleEndian.Uint32(data[:4])
	rows := binary.LittleEndian.Uint32(data[4:payloadLen])
	if rows > maxFetchSize {
		rows = maxFetchSize
	}
	return id, rows, nil
}

// parseExecArgs decodes the binary-protocol parameter values of a
// COM_STMT_EXECUTE packet and stores one *expression.Constant per parameter in
// params. len(params) decides how many parameters are decoded.
//
// Each parameter i is decoded by the first rule that applies:
//   - if boundParams[i] is non-nil (data received via COM_STMT_SEND_LONG_DATA),
//     it is passed through enc.decodeInput and used as a bytes value;
//   - if its bit in nullBitmap is set, it is NULL;
//   - otherwise paramTypes[2*i] selects the wire type (the 0x80 bit of
//     paramTypes[2*i+1] marks unsigned integers), and the value is consumed
//     from paramValues.
//
// nullBitmap and boundParams are indexed without bounds checks, so callers must
// size them for len(params). sc handles truncation while parsing DECIMAL
// values. When enc is nil, a UTF-8 input decoder is used.
//
// It returns mysql.ErrMalformPacket when paramTypes or paramValues are shorter
// than the declared types/lengths require, and errUnknownFieldType for an
// unsupported parameter type. On error, params is left unmodified.
func parseExecArgs(sc *stmtctx.StatementContext, params []expression.Expression, boundParams [][]byte,
	nullBitmap, paramTypes, paramValues []byte, enc *inputDecoder) (err error) {
	pos := 0
	var (
		tmp    interface{}
		v      []byte
		n      int
		isNull bool
	)
	if enc == nil {
		enc = newInputDecoder(charset.CharsetUTF8)
	}

	args := make([]types.Datum, len(params))
	for i := 0; i < len(args); i++ {
		// if params had received via ComStmtSendLongData, use them directly.
		// ref https://dev.mysql.com/doc/internals/en/com-stmt-send-long-data.html
		// see clientConn#handleStmtSendLongData
		if boundParams[i] != nil {
			args[i] = types.NewBytesDatum(enc.decodeInput(boundParams[i]))
			continue
		}

		// check nullBitMap to determine the NULL arguments.
		// ref https://dev.mysql.com/doc/internals/en/com-stmt-execute.html
		// notice: some client(e.g. mariadb) will set nullBitMap even if data had be sent via ComStmtSendLongData,
		// so this check need place after boundParam's check.
		if nullBitmap[i>>3]&(1<<(uint(i)%8)) > 0 {
			var nilDatum types.Datum
			nilDatum.SetNull()
			args[i] = nilDatum
			continue
		}

		if (i<<1)+1 >= len(paramTypes) {
			return mysql.ErrMalformPacket
		}

		tp := paramTypes[i<<1]
		isUnsigned := (paramTypes[(i<<1)+1] & 0x80) > 0

		switch tp {
		case mysql.TypeNull:
			var nilDatum types.Datum
			nilDatum.SetNull()
			args[i] = nilDatum
			continue

		case mysql.TypeTiny:
			if len(paramValues) < (pos + 1) {
				err = mysql.ErrMalformPacket
				return
			}

			if isUnsigned {
				args[i] = types.NewUintDatum(uint64(paramValues[pos]))
			} else {
				args[i] = types.NewIntDatum(int64(int8(paramValues[pos])))
			}

			pos++
			continue

		case mysql.TypeShort, mysql.TypeYear:
			if len(paramValues) < (pos + 2) {
				err = mysql.ErrMalformPacket
				return
			}
			valU16 := binary.LittleEndian.Uint16(paramValues[pos : pos+2])
			if isUnsigned {
				args[i] = types.NewUintDatum(uint64(valU16))
			} else {
				args[i] = types.NewIntDatum(int64(int16(valU16)))
			}
			pos += 2
			continue

		case mysql.TypeInt24, mysql.TypeLong:
			if len(paramValues) < (pos + 4) {
				err = mysql.ErrMalformPacket
				return
			}
			valU32 := binary.LittleEndian.Uint32(paramValues[pos : pos+4])
			if isUnsigned {
				args[i] = types.NewUintDatum(uint64(valU32))
			} else {
				args[i] = types.NewIntDatum(int64(int32(valU32)))
			}
			pos += 4
			continue

		case mysql.TypeLonglong:
			if len(paramValues) < (pos + 8) {
				err = mysql.ErrMalformPacket
				return
			}
			valU64 := binary.LittleEndian.Uint64(paramValues[pos : pos+8])
			if isUnsigned {
				args[i] = types.NewUintDatum(valU64)
			} else {
				args[i] = types.NewIntDatum(int64(valU64))
			}
			pos += 8
			continue

		case mysql.TypeFloat:
			if len(paramValues) < (pos + 4) {
				err = mysql.ErrMalformPacket
				return
			}

			args[i] = types.NewFloat32Datum(math.Float32frombits(binary.LittleEndian.Uint32(paramValues[pos : pos+4])))
			pos += 4
			continue

		case mysql.TypeDouble:
			if len(paramValues) < (pos + 8) {
				err = mysql.ErrMalformPacket
				return
			}

			args[i] = types.NewFloat64Datum(math.Float64frombits(binary.LittleEndian.Uint64(paramValues[pos : pos+8])))
			pos += 8
			continue

		case mysql.TypeDate, mysql.TypeTimestamp, mysql.TypeDatetime:
			if len(paramValues) < (pos + 1) {
				err = mysql.ErrMalformPacket
				return
			}
			// See https://dev.mysql.com/doc/internals/en/binary-protocol-value.html
			// for more details.
			length := paramValues[pos]
			pos++
			// The parseBinary* helpers index paramValues without bounds checks, so
			// make sure the advertised payload is really present. Otherwise a
			// truncated packet from the client would panic with an index out of
			// range instead of being rejected.
			if len(paramValues) < pos+int(length) {
				err = mysql.ErrMalformPacket
				return
			}
			switch length {
			case 0:
				tmp = types.ZeroDatetimeStr
			case 4:
				pos, tmp = parseBinaryDate(pos, paramValues)
			case 7:
				pos, tmp = parseBinaryDateTime(pos, paramValues)
			case 11:
				pos, tmp = parseBinaryTimestamp(pos, paramValues)
			default:
				err = mysql.ErrMalformPacket
				return
			}
			args[i] = types.NewDatum(tmp) // FIXME: After check works!!!!!!
			continue

		case mysql.TypeDuration:
			if len(paramValues) < (pos + 1) {
				err = mysql.ErrMalformPacket
				return
			}
			// See https://dev.mysql.com/doc/internals/en/binary-protocol-value.html
			// for more details.
			length := paramValues[pos]
			pos++
			// As for dates: the sign byte and all duration fields must be present
			// before they are read without bounds checks below.
			if len(paramValues) < pos+int(length) {
				err = mysql.ErrMalformPacket
				return
			}
			switch length {
			case 0:
				tmp = "0"
			case 8:
				isNegative := paramValues[pos]
				if isNegative > 1 {
					err = mysql.ErrMalformPacket
					return
				}
				pos++
				pos, tmp = parseBinaryDuration(pos, paramValues, isNegative)
			case 12:
				isNegative := paramValues[pos]
				if isNegative > 1 {
					err = mysql.ErrMalformPacket
					return
				}
				pos++
				pos, tmp = parseBinaryDurationWithMS(pos, paramValues, isNegative)
			default:
				err = mysql.ErrMalformPacket
				return
			}
			args[i] = types.NewDatum(tmp)
			continue
		case mysql.TypeNewDecimal:
			if len(paramValues) < (pos + 1) {
				err = mysql.ErrMalformPacket
				return
			}

			v, isNull, n, err = parseLengthEncodedBytes(paramValues[pos:])
			pos += n
			if err != nil {
				return
			}

			if isNull {
				args[i] = types.NewDecimalDatum(nil)
			} else {
				var dec types.MyDecimal
				err = sc.HandleTruncate(dec.FromString(v))
				if err != nil {
					return err
				}
				args[i] = types.NewDecimalDatum(&dec)
			}
			continue
		case mysql.TypeBlob, mysql.TypeTinyBlob, mysql.TypeMediumBlob, mysql.TypeLongBlob:
			if len(paramValues) < (pos + 1) {
				err = mysql.ErrMalformPacket
				return
			}
			v, isNull, n, err = parseLengthEncodedBytes(paramValues[pos:])
			pos += n
			if err != nil {
				return
			}

			if isNull {
				args[i] = types.NewBytesDatum(nil)
			} else {
				args[i] = types.NewBytesDatum(v)
			}
			continue
		case mysql.TypeUnspecified, mysql.TypeVarchar, mysql.TypeVarString, mysql.TypeString,
			mysql.TypeEnum, mysql.TypeSet, mysql.TypeGeometry, mysql.TypeBit:
			if len(paramValues) < (pos + 1) {
				err = mysql.ErrMalformPacket
				return
			}

			v, isNull, n, err = parseLengthEncodedBytes(paramValues[pos:])
			pos += n
			if err != nil {
				return
			}

			if !isNull {
				v = enc.decodeInput(v)
				tmp = string(hack.String(v))
			} else {
				tmp = nil
			}
			args[i] = types.NewDatum(tmp)
			continue
		default:
			err = errUnknownFieldType.GenWithStack("stmt unknown field type %d", tp)
			return
		}
	}

	// All arguments decoded successfully; publish them as typed constants.
	for i := range params {
		ft := new(types.FieldType)
		types.DefaultParamTypeForValue(args[i].GetValue(), ft)
		params[i] = &expression.Constant{Value: args[i], RetType: ft}
	}
	return
}

// parseBinaryDate decodes a binary-protocol DATE payload starting at pos:
// year (uint16, little-endian), month (1 byte), day (1 byte). It returns the
// position just past the 4 consumed bytes and the date as "YYYY-MM-DD".
// The caller must ensure that 4 bytes are available; this function panics
// otherwise.
func parseBinaryDate(pos int, paramValues []byte) (int, string) {
	year := binary.LittleEndian.Uint16(paramValues[pos : pos+2])
	month, day := paramValues[pos+2], paramValues[pos+3]
	return pos + 4, fmt.Sprintf("%04d-%02d-%02d", year, month, day)
}

// parseBinaryDateTime decodes a binary-protocol DATETIME payload starting at
// pos: the 4-byte date part followed by hour, minute and second (1 byte each).
// It returns the position just past the 7 consumed bytes and the value as
// "YYYY-MM-DD hh:mm:ss". The caller must ensure the bytes are available.
func parseBinaryDateTime(pos int, paramValues []byte) (int, string) {
	next, date := parseBinaryDate(pos, paramValues)
	hour, minute, second := paramValues[next], paramValues[next+1], paramValues[next+2]
	return next + 3, fmt.Sprintf("%s %02d:%02d:%02d", date, hour, minute, second)
}

// parseBinaryTimestamp decodes an 11-byte binary-protocol DATETIME payload with
// fractional seconds, starting at pos: the 7-byte date-time part followed by
// microseconds (uint32, little-endian). It returns the position just past the
// consumed bytes and the value as "YYYY-MM-DD hh:mm:ss.ffffff".
func parseBinaryTimestamp(pos int, paramValues []byte) (int, string) {
	const microsLen = 4
	next, dateTime := parseBinaryDateTime(pos, paramValues)
	micros := binary.LittleEndian.Uint32(paramValues[next : next+microsLen])
	return next + microsLen, fmt.Sprintf("%s.%06d", dateTime, micros)
}

// parseBinaryDuration decodes the body of a binary-protocol TIME payload
// (after the sign byte) starting at pos: days (uint32, little-endian), then
// hours, minutes and seconds (1 byte each). isNegative == 1 prefixes the
// result with "-". It returns the position just past the 7 consumed bytes and
// the value as "[-]D hh:mm:ss".
func parseBinaryDuration(pos int, paramValues []byte, isNegative uint8) (int, string) {
	var sign string
	if isNegative == 1 {
		sign = "-"
	}
	days := binary.LittleEndian.Uint32(paramValues[pos : pos+4])
	h, m, s := paramValues[pos+4], paramValues[pos+5], paramValues[pos+6]
	return pos + 7, fmt.Sprintf("%s%d %02d:%02d:%02d", sign, days, h, m, s)
}

// parseBinaryDurationWithMS decodes the body of a 12-byte binary-protocol TIME
// payload (after the sign byte): the 7-byte duration part followed by
// microseconds (uint32, little-endian). It returns the position just past the
// consumed bytes and the value as "[-]D hh:mm:ss.ffffff".
func parseBinaryDurationWithMS(pos int, paramValues []byte, isNegative uint8) (int, string) {
	next, dur := parseBinaryDuration(pos, paramValues, isNegative)
	micros := binary.LittleEndian.Uint32(paramValues[next : next+4])
	return next + 4, fmt.Sprintf("%s.%06d", dur, micros)
}

// handleStmtClose handles COM_STMT_CLOSE. The first 4 bytes of data carry the
// statement id (little-endian). A payload shorter than that, or an unknown
// statement id, is ignored rather than reported; otherwise the error from
// closing the statement is returned.
func (cc *clientConn) handleStmtClose(data []byte) (err error) {
	if len(data) >= 4 {
		id := int(binary.LittleEndian.Uint32(data[:4]))
		if stmt := cc.ctx.GetStatement(id); stmt != nil {
			return stmt.Close()
		}
	}
	return nil
}

// handleStmtSendLongData handles COM_STMT_SEND_LONG_DATA. data holds the
// statement id (uint32, little-endian), the parameter index (uint16,
// little-endian) and then the chunk to append to that parameter. A short
// packet or an unknown statement id is reported as an error.
func (cc *clientConn) handleStmtSendLongData(data []byte) (err error) {
	const headerLen = 6
	if len(data) < headerLen {
		return mysql.ErrMalformPacket
	}

	id := int(binary.LittleEndian.Uint32(data[:4]))
	stmt := cc.ctx.GetStatement(id)
	if stmt == nil {
		return mysql.NewErr(mysql.ErrUnknownStmtHandler,
			strconv.Itoa(id), "stmt_send_longdata")
	}

	param := int(binary.LittleEndian.Uint16(data[4:headerLen]))
	return stmt.AppendParam(param, data[headerLen:])
}

// handleStmtReset handles COM_STMT_RESET. The first 4 bytes of data carry the
// statement id (little-endian). The statement is reset, any stored cursor
// result set is cleared, and an OK packet is written. A short packet or an
// unknown statement id is reported as an error.
func (cc *clientConn) handleStmtReset(ctx context.Context, data []byte) (err error) {
	if len(data) < 4 {
		return mysql.ErrMalformPacket
	}

	id := int(binary.LittleEndian.Uint32(data[:4]))
	if stmt := cc.ctx.GetStatement(id); stmt != nil {
		stmt.Reset()
		stmt.StoreResultSet(nil)
		return cc.writeOK(ctx)
	}
	return mysql.NewErr(mysql.ErrUnknownStmtHandler,
		strconv.Itoa(id), "stmt_reset")
}

// handleSetOption handles COM_SET_OPTION, refer to https://dev.mysql.com/doc/internals/en/com-set-option.html
// The 2-byte little-endian option value 0 enables and 1 disables
// multi-statement support for this connection. The updated capability set is
// pushed to the session and acknowledged with an EOF packet. Any other value,
// or a packet shorter than 2 bytes, is rejected as malformed.
func (cc *clientConn) handleSetOption(ctx context.Context, data []byte) (err error) {
	const (
		multiStatementsOn  = 0
		multiStatementsOff = 1
	)
	if len(data) < 2 {
		return mysql.ErrMalformPacket
	}

	switch binary.LittleEndian.Uint16(data[:2]) {
	case multiStatementsOn:
		cc.capability |= mysql.ClientMultiStatements
	case multiStatementsOff:
		cc.capability &^= mysql.ClientMultiStatements
	default:
		return mysql.ErrMalformPacket
	}
	cc.ctx.SetClientCapability(cc.capability)

	if err = cc.writeEOF(ctx, cc.ctx.Status()); err != nil {
		return err
	}
	return cc.flush(ctx)
}

// preparedStmt2String describes prepared statement stmtID for error
// annotations. It returns the statement text followed by the session's current
// PreparedParams; the parameters are omitted when log redaction is enabled.
// It returns "" when the session has no variables.
func (cc *clientConn) preparedStmt2String(stmtID uint32) string {
	vars := cc.ctx.GetSessionVars()
	switch {
	case vars == nil:
		return ""
	case vars.EnableRedactLog:
		return cc.preparedStmt2StringNoArgs(stmtID)
	default:
		return cc.preparedStmt2StringNoArgs(stmtID) + vars.PreparedParams.String()
	}
}

// preparedStmt2StringNoArgs returns the SQL text of prepared statement stmtID,
// without parameter values. It returns a diagnostic message instead when the
// cached entry has an unexpected type or no entry exists, and "" when the
// session has no variables.
func (cc *clientConn) preparedStmt2StringNoArgs(stmtID uint32) string {
	if cc.ctx.GetSessionVars() == nil {
		return ""
	}
	obj, invalid := cc.preparedStmtID2CachePreparedStmt(stmtID)
	switch {
	case invalid:
		return "invalidate PlanCacheStmt type, ID: " + strconv.FormatUint(uint64(stmtID), 10)
	case obj == nil:
		return "prepared statement not found, ID: " + strconv.FormatUint(uint64(stmtID), 10)
	}
	return obj.PreparedAst.Stmt.Text()
}

// preparedStmtID2CachePreparedStmt looks up stmtID in the session's
// PreparedStmts. It returns (nil, false) when the session has no variables or
// no entry exists, and (nil, true) when an entry exists but is not a
// *plannercore.PlanCacheStmt (which should never happen).
func (cc *clientConn) preparedStmtID2CachePreparedStmt(stmtID uint32) (_ *plannercore.PlanCacheStmt, invalid bool) {
	vars := cc.ctx.GetSessionVars()
	if vars == nil {
		return nil, false
	}
	cached, found := vars.PreparedStmts[stmtID]
	if !found {
		return nil, false
	}
	if obj, ok := cached.(*plannercore.PlanCacheStmt); ok {
		return obj, false
	}
	// invalid cache. should never happen.
	return nil, true
}

相关信息

tidb 源码目录

相关文章

tidb buffered_read_conn 源码

tidb column 源码

tidb conn 源码

tidb driver 源码

tidb driver_tidb 源码

tidb http_handler 源码

tidb http_status 源码

tidb mock_conn 源码

tidb optimize_trace 源码

tidb packetio 源码

0  赞