Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Server: Restructure Parse OID handling #77

Merged
merged 5 commits into from
Sep 28, 2021
Merged
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
150 changes: 122 additions & 28 deletions server/conn_stmt.go
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ import (
"time"

"github.com/DigitalChinaOpenSource/DCParser/mysql"
pgOID "github.com/lib/pq/oid"
"github.com/pingcap/errors"
plannercore "github.com/pingcap/tidb/planner/core"
"github.com/pingcap/tidb/sessionctx/stmtctx"
Expand All @@ -81,22 +82,68 @@ func (cc *clientConn) handleStmtPrepare(ctx context.Context, parse pgproto3.Pars

// Get param types in sql plan, and save it in `stmt`.
var paramTypes []byte
if cachedStmt, ok := vars.PreparedStmts[uint32(stmt.ID())].(*plannercore.CachedPrepareStmt); ok {
cachedParams := cachedStmt.PreparedAst.Params
for i := range cachedParams {
paramTypes = append(paramTypes, cachedParams[i].GetType().Tp)
numberOfParams := stmt.NumParams()

// parameters' oid sent in by frontend
oids := parse.ParameterOIDs

// here we get cacheParams from our prepared plan.
// if oid is not available, we use our prepared plan.
cacheStmt := vars.PreparedStmts[uint32(stmt.ID())].(*plannercore.CachedPrepareStmt)
cacheParams := cacheStmt.PreparedAst.Params

// If frontend send in OIDs, we save them into stmt
if len(oids) > 0 {
stmt.SetOIDs(oids)
}

for i := 0; i < numberOfParams; i++ {
// check if oids are available, there are three situations in which oid is not available:
// 1. Frontend didn't send them, since it's optional
// 2. Frontend didn't send enough of them
// 3. Frontend send in oid of 0, indicate unspecified
if len(oids) > i && oids[i] != 0 {
// If frontend gives param OID, we convert it to paramTypes directly
paramTypes = append(paramTypes, pgOIDToMySQLType(oids[i]))
} else {
// If OID is not available, we get it from our prepared statement
paramTypes = append(paramTypes, cacheParams[i].GetType().Tp)
}
}

stmt.SetParamsType(paramTypes)

if len(parse.ParameterOIDs) > 0 {
stmt.SetOIDs(parse.ParameterOIDs)
}

return cc.writeParseComplete()
}

// pgOIDToMySQLType converts postgres OID into mysql type
func pgOIDToMySQLType(oid uint32) byte {
switch pgOID.Oid(oid) {
case pgOID.T_int8:
return mysql.TypeLonglong
case pgOID.T_int4:
return mysql.TypeLong
case pgOID.T_int2:
return mysql.TypeShort
case pgOID.T_float4:
return mysql.TypeFloat
case pgOID.T_float8:
return mysql.TypeDouble
case pgOID.T_timestamp:
return mysql.TypeTimestamp
case pgOID.T_date:
return mysql.TypeNewDate
case pgOID.T_numeric:
return mysql.TypeNewDecimal
case pgOID.T_bytea:
return mysql.TypeBlob
case pgOID.T_text, pgOID.T_varchar:
return mysql.TypeVarchar
default:
return mysql.TypeUnspecified
}
}

// handleStmtBind handle bind messages in pgsql's extended query.
// PGSQL Modified
func (cc *clientConn) handleStmtBind(ctx context.Context, bind pgproto3.Bind) (err error) {
Expand Down Expand Up @@ -335,15 +382,56 @@ func parseStmtFetchCmd(data []byte) (uint32, uint32, error) {
return stmtID, fetchSize, nil
}

// getFormatCode decode the formatCodes passed in from Bind struct
// it will return 0 if the format is Text, 1 if Binary
// Note that this will handle empty formatCodes and single format Codes gracefully
func getFormatCode(formatCodes []int16, index int) int16 {
// default format is Text
if len(formatCodes) == 0 {
return 0
}

// if length is one, use that for all arguments
if len(formatCodes) == 1 {
return formatCodes[0]
}

return formatCodes[index]
}

// getOID will return the postgres OID for the specific parameter
// If OID is unset, it will return 0 to indicate unspecified
func getOID(pgOIDs []uint32, index int) uint32 {
// default case where frontend didn't send OID
if len(pgOIDs) == 0 {
return 0
}

// case where oid list is shorter than the parameter we ask
if len(pgOIDs) <= index {
return 0
}

return pgOIDs[index]
}

// parseBindArgs 将客户端传来的参数值解析为 Datum 结构
// PgSQL Modified
func parseBindArgs(sc *stmtctx.StatementContext, args []types.Datum, paramTypes []byte, bind pgproto3.Bind, boundParams [][]byte, pgOIDs []uint32) error {
// todo 传参为文本 text 格式时候的处理

hasOID := len(pgOIDs) > 0
var (
oid uint32
formatCode int16
isUnsigned bool
)

for i := 0; i < len(args); i++ {

// todo 使用boundParams
// todo BoundParams
oid = getOID(pgOIDs, i)
formatCode = getFormatCode(bind.ParameterFormatCodes, i)
// todo Check If variable should be signed, currently we are assuming all signed
isUnsigned = false

if bind.Parameters[i] == nil {
var nilDatum types.Datum
Expand All @@ -352,10 +440,6 @@ func parseBindArgs(sc *stmtctx.StatementContext, args []types.Datum, paramTypes
continue
}

// todo isUnsigned
// isUnsigned 暂时Pg无法判断, 默认为有符号
isUnsigned := false

switch paramTypes[i] {
case mysql.TypeNull:
var nilDatum types.Datum
Expand Down Expand Up @@ -385,7 +469,7 @@ func parseBindArgs(sc *stmtctx.StatementContext, args []types.Datum, paramTypes
continue

case mysql.TypeInt24, mysql.TypeLong:
if bind.ParameterFormatCodes[i] == 1 { // The data passed in is in binary format
if formatCode == 1 { // The data passed in is in binary format
var b [8]byte
copy(b[8-len(bind.Parameters[i]):], bind.Parameters[i])
val := binary.BigEndian.Uint64(b[:])
Expand Down Expand Up @@ -416,6 +500,12 @@ func parseBindArgs(sc *stmtctx.StatementContext, args []types.Datum, paramTypes
continue

case mysql.TypeFloat:
if formatCode == 1 {
bits := binary.BigEndian.Uint32(bind.Parameters[i])
f32 := math.Float32frombits(bits)
args[i] = types.NewFloat32Datum(f32)
continue
}
valFloat, err := strconv.ParseFloat(string(bind.Parameters[i]), 32)
if err != nil {
return err
Expand All @@ -424,6 +514,12 @@ func parseBindArgs(sc *stmtctx.StatementContext, args []types.Datum, paramTypes
continue

case mysql.TypeDouble:
if formatCode == 1 {
bits := binary.BigEndian.Uint64(bind.Parameters[i])
f64 := math.Float64frombits(bits)
args[i] = types.NewFloat64Datum(f64)
continue
}
valFloat, err := strconv.ParseFloat(string(bind.Parameters[i]), 64)
if err != nil {
return err
Expand Down Expand Up @@ -453,7 +549,7 @@ func parseBindArgs(sc *stmtctx.StatementContext, args []types.Datum, paramTypes
case mysql.TypeNewDecimal:
// fixme decimal 待测试 待修复
var dec types.MyDecimal
if bind.ParameterFormatCodes[i] == 1 {
if formatCode == 1 {
bits := binary.BigEndian.Uint64(bind.Parameters[i])
f64 := math.Float64frombits(bits)
args[i] = types.NewFloat64Datum(f64)
Expand All @@ -477,18 +573,16 @@ func parseBindArgs(sc *stmtctx.StatementContext, args []types.Datum, paramTypes
args[i] = types.NewDatum(tmp)
continue
case mysql.TypeUnspecified:
if hasOID {
if bind.ParameterFormatCodes[i] == 1 && pgOIDs[i] == 23 { // The data passed in is in binary format
args[i] = types.NewBinaryLiteralDatum(bind.Parameters[i])
continue
}
if formatCode == 1 && oid == 23 {
args[i] = types.NewBinaryLiteralDatum(bind.Parameters[i])
continue
}

if bind.ParameterFormatCodes[i] == 1 && pgOIDs[i] == 701 { // The data passed in is in binary format
bits := binary.BigEndian.Uint64(bind.Parameters[i])
f64 := math.Float64frombits(bits)
args[i] = types.NewFloat64Datum(f64)
continue
}
if formatCode == 1 && oid == 701 {
bits := binary.BigEndian.Uint64(bind.Parameters[i])
f64 := math.Float64frombits(bits)
args[i] = types.NewFloat64Datum(f64)
continue
}
tmp := string(hack.String(bind.Parameters[i]))
args[i] = types.NewDatum(tmp)
Expand Down