Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

server: fixed a bug that params description will be sent when no params in handleStmtDescription #61

Merged
merged 1 commit into from
Sep 8, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions server/conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -1025,6 +1025,7 @@ func (cc *clientConn) dispatch(ctx context.Context, data []byte) error {
}
return cc.handleStmtDescription(ctx, desc)
case 'H': /* flush */
return cc.flush(ctx)
case 'S': /* sync */
return cc.writeReadyForQuery(ctx, cc.ctx.Status())
case 'X':
Expand Down
75 changes: 42 additions & 33 deletions server/conn_stmt.go
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ import (
"github.com/pingcap/tidb/util/hack"
)

// handleStmtPrepare 预处理语句
// handleStmtPrepare handle prepare message in pgsql's extended query.
// PgSQL Modified
func (cc *clientConn) handleStmtPrepare(ctx context.Context, parser pgproto3.Parse) error {
//stmt, columns, params, err := cc.ctx.Prepare(parser.Query)
Expand All @@ -78,7 +78,7 @@ func (cc *clientConn) handleStmtPrepare(ctx context.Context, parser pgproto3.Par

vars := cc.ctx.GetSessionVars()

// 将在 Prepare 阶段解析传来的参数类型在这里获取,并保留在 stmt 中
// Get param types in sqllan, and save it in `stmt`.
var paramTypes []byte
if cachedStmt, ok := vars.PreparedStmts[uint32(stmt.ID())].(*plannercore.CachedPrepareStmt); ok {
cachedParams := cachedStmt.PreparedAst.Params
Expand All @@ -96,17 +96,17 @@ func (cc *clientConn) handleStmtPrepare(ctx context.Context, parser pgproto3.Par
return cc.flush(ctx)
}

// handleStmtBind 处理pgsql拓展查询协议过程中的bind过程
// handleStmtBind handle bind messages in pgsql's extended query.
// PGSQL Modified
func (cc *clientConn) handleStmtBind(ctx context.Context, bind pgproto3.Bind) (err error) {
vars := cc.ctx.GetSessionVars()

// 当为临时预处理查询,默认设置 Name 为 0
// When it is a temporary prepared stmt, the default name setting is 0
if bind.PreparedStatement == "" {
bind.PreparedStatement = "0"
}

// 获取stmtID 通过ID获取到预处理语句
// Get stmtID through stmt name.
stmtID, ok := vars.PreparedStmtNameToID[bind.PreparedStatement]
if !ok {
return mysql.NewErr(mysql.ErrUnknownStmtHandler,
Expand Down Expand Up @@ -137,19 +137,20 @@ func (cc *clientConn) handleStmtBind(ctx context.Context, bind pgproto3.Bind) (e
stmt.SetArgs(args)
}

// 当ResultFormatCode长度为1时,表示一次设定整行数据格式,如果长度大于1小于column,代表客户端传参存在问题
// When the length of `ResultFormatCodes` equ that the data format of the whole row is set at one time.
// If the length is greater than 1 and less than column length,
// it means that there is a problem in the parameter transfer of the client.
if len(bind.ResultFormatCodes) > 1 && len(bind.ResultFormatCodes) < len(stmt.GetColumnInfo()) {
return errors.New("the result format code parameter in the bind message is wrong")
}

// 对返回数据的格式进行设置, 由Bind信息中获取
// 当resultFormat 为空代表默认为Text
// 当resultFormat 只有一个值代表指定一整行的格式

stmt.SetResultFormat(bind.ResultFormatCodes)

// Bind 完成后创建 Portal,Portal Name 由客户端传送过来
// 果如 Portal Name 为空则表示为临时门户,用"0"作为默认 Name
// 这个位置没有进行真正的 Portal 生成,只是绑定 StmtID 在后面的阶段可以通过 Portal Name 找到 Statement ID
// When create `Portal`, clients will send the portal name.
// If portal name is empty, it will be set to "0" by default.
// Notice: there is not a real portal by created,
// we just put portal name and stmtID in map, then you can get stmtID.
if bind.DestinationPortal != "" {
vars.Portal[bind.DestinationPortal] = stmtID
} else {
Expand All @@ -163,24 +164,27 @@ func (cc *clientConn) handleStmtBind(ctx context.Context, bind pgproto3.Bind) (e
return cc.flush(ctx)
}

// handleStmtDescription 处理 Description 请求,通过 stmtName 找到相应的预处理语句
// 返回参数的类型信息和返回值的表结构信息,如果没有返回值,则返回 NoData
// handleStmtDescription handle Description messages in pgsql's extended query,
// find prepared stmt through `stmtName` or `portal`.
// Return `writeParameterDescription` and `WriteRowDescription` when columnInfo is not empty,
// otherwise return `writeNoData`.
func (cc *clientConn) handleStmtDescription(ctx context.Context, desc pgproto3.Describe) error {
vars := cc.ctx.GetSessionVars()

// 无论是 Stmt Name 还是 Portal Name 当为临时语句的时候,默认 Name 为 "0"
// Whether stmt name or portal name, when it is a temporary statement, the default name is "0".
if desc.Name == "" {
desc.Name = "0"
}

var stmtID uint32
var ok bool

// 如果通过 Portal 来指定运行语句,则直接通过 Portal 找到对应 StmtID 来运行即可
// If it specify the prepared statement through portal,
// here can directly find the corresponding stmtID through portal.
if desc.ObjectType == 'P' {
stmtID, ok = vars.Portal[desc.Name]
} else {
// 获取stmtID 通过ID获取到预处理语句
// Or get prepared stmtID through stmtName.
stmtID, ok = vars.PreparedStmtNameToID[desc.Name]
}

Expand All @@ -189,33 +193,38 @@ func (cc *clientConn) handleStmtDescription(ctx context.Context, desc pgproto3.D
strconv.FormatUint(uint64(stmtID), 10), "stmt_description")
}

stmt := cc.ctx.GetStatement(int(stmtID))
// Get prepared stmt through stmtID.
stmt :=cc.ctx.GetStatement(int(stmtID))
if stmt == nil {
return mysql.NewErr(mysql.ErrUnknownStmtHandler,
strconv.FormatUint(uint64(stmtID), 10), "stmt_description")
}
numParams := stmt.NumParams()

// 将解析阶段解析出的参数类型获取到,并转换为PgSQL数据类型传回到客户端
paramsType := stmt.GetParamsType()
pgType := make([]uint32, numParams)
for i := range paramsType {
pgType[i] = convertMySQLDataTypeToPgSQLDataType(paramsType[i])
}
numParams := stmt.NumParams()
if numParams > 0 {
// Get param types that analyzed in `handleStmtBind`,
// And convert it to PgSQL data type and return it to the client
paramsType := stmt.GetParamsType()
pgType := make([]uint32, numParams)
for i := range paramsType {
pgType[i] = convertMySQLDataTypeToPgSQLDataType(paramsType[i])
}

if err := cc.writeParameterDescription(pgType); err != nil {
return err
if err := cc.writeParameterDescription(pgType); err != nil {
return err
}
}

// columnInfo 有数据则返回 WriteRowDescription 没有数据则需要返回 NoData
// Return `WriteRowDescription` when columnInfo is not empty,
// otherwise return `writeNoData`.
columnInfo := stmt.GetColumnInfo()
if columnInfo == nil || len(columnInfo) > 0 {
if err := cc.WriteRowDescription(columnInfo); err != nil {
return err
}

// 如果在这个位子已经输出的行描述信息,则在后面输出结果时不再输出行描述信息

// If the row description information has been output here,
// it will not be output when `writeResultset` later.
} else {
if err := cc.writeNoData(); err != nil {
return err
Expand All @@ -224,12 +233,12 @@ func (cc *clientConn) handleStmtDescription(ctx context.Context, desc pgproto3.D
return cc.flush(ctx)
}

// handleStmtExecute 处理 Execute 请求
// handleStmtExecute handle execute messages in pgsql's extended query.
// PGSQL Modified
func (cc *clientConn) handleStmtExecute(ctx context.Context, execute pgproto3.Execute) error {
defer trace.StartRegion(ctx, "HandleStmtExecute").End()

// 当为临时预处理查询,默认设置 Name 为 0
// When it is a temporary prepared stmt, the default name setting is "0".
if execute.Portal == "" {
execute.Portal = "0"
}
Expand Down Expand Up @@ -261,7 +270,7 @@ func (cc *clientConn) handleStmtExecute(ctx context.Context, execute pgproto3.Ex
return nil
}

// handleStmtClose 处理 Close 请求
// handleStmtClose handle close messages in pgsql's extended query.
func (cc *clientConn) handleStmtClose(ctx context.Context, close pgproto3.Close) error {
vars := cc.ctx.GetSessionVars()
var stmtID uint32
Expand Down