Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Server: Handle OID and Format Code #76

Closed
wants to merge 1 commit into from
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
75 changes: 55 additions & 20 deletions server/conn_stmt.go
Original file line number Diff line number Diff line change
Expand Up @@ -335,15 +335,56 @@ func parseStmtFetchCmd(data []byte) (uint32, uint32, error) {
return stmtID, fetchSize, nil
}

// getFormatCode decode the formatCodes passed in from Bind struct
// it will return 0 if the format is Text, 1 if Binary
// Note that this will handle empty formatCodes and single format Codes gracefully
func getFormatCode(formatCodes []int16, index int) int16 {
// default format is Text
if len(formatCodes) == 0 {
return 0
}

// if length is one, use that for all arguments
if len(formatCodes) == 1 {
return formatCodes[0]
}

return formatCodes[index]
}

// getOID will return the postgres OID for the specific parameter
// If OID is unset, it will return 0 to indicate unspecified
func getOID(pgOIDs []uint32, index int) uint32 {
// default case where frontend didn't send OID
if len(pgOIDs) == 0 {
return 0
}

// case where oid list is shorter than the parameter we ask
if len(pgOIDs) <= index {
return 0
}

return pgOIDs[index]
}

// parseBindArgs 将客户端传来的参数值解析为 Datum 结构
// PgSQL Modified
func parseBindArgs(sc *stmtctx.StatementContext, args []types.Datum, paramTypes []byte, bind pgproto3.Bind, boundParams [][]byte, pgOIDs []uint32) error {
// todo 传参为文本 text 格式时候的处理

hasOID := len(pgOIDs) > 0
var (
oid uint32
formatCode int16
isUnsigned bool
)

for i := 0; i < len(args); i++ {

// todo 使用boundParams
// todo BoundParams
oid = getOID(pgOIDs, i)
formatCode = getFormatCode(bind.ParameterFormatCodes, i)
// todo Check If variable should be signed, currently we are assuming all signed
isUnsigned = false

if bind.Parameters[i] == nil {
var nilDatum types.Datum
Expand All @@ -352,10 +393,6 @@ func parseBindArgs(sc *stmtctx.StatementContext, args []types.Datum, paramTypes
continue
}

// todo isUnsigned
// isUnsigned 暂时Pg无法判断, 默认为有符号
isUnsigned := false

switch paramTypes[i] {
case mysql.TypeNull:
var nilDatum types.Datum
Expand Down Expand Up @@ -385,7 +422,7 @@ func parseBindArgs(sc *stmtctx.StatementContext, args []types.Datum, paramTypes
continue

case mysql.TypeInt24, mysql.TypeLong:
if bind.ParameterFormatCodes[i] == 1 { // The data passed in is in binary format
if formatCode == 1 { // The data passed in is in binary format
var b [8]byte
copy(b[8-len(bind.Parameters[i]):], bind.Parameters[i])
val := binary.BigEndian.Uint64(b[:])
Expand Down Expand Up @@ -453,7 +490,7 @@ func parseBindArgs(sc *stmtctx.StatementContext, args []types.Datum, paramTypes
case mysql.TypeNewDecimal:
// fixme decimal 待测试 待修复
var dec types.MyDecimal
if bind.ParameterFormatCodes[i] == 1 {
if formatCode == 1 {
bits := binary.BigEndian.Uint64(bind.Parameters[i])
f64 := math.Float64frombits(bits)
args[i] = types.NewFloat64Datum(f64)
Expand All @@ -477,18 +514,16 @@ func parseBindArgs(sc *stmtctx.StatementContext, args []types.Datum, paramTypes
args[i] = types.NewDatum(tmp)
continue
case mysql.TypeUnspecified:
if hasOID {
if bind.ParameterFormatCodes[i] == 1 && pgOIDs[i] == 23 { // The data passed in is in binary format
args[i] = types.NewBinaryLiteralDatum(bind.Parameters[i])
continue
}
if formatCode == 1 && oid == 23 {
args[i] = types.NewBinaryLiteralDatum(bind.Parameters[i])
continue
}

if bind.ParameterFormatCodes[i] == 1 && pgOIDs[i] == 701 { // The data passed in is in binary format
bits := binary.BigEndian.Uint64(bind.Parameters[i])
f64 := math.Float64frombits(bits)
args[i] = types.NewFloat64Datum(f64)
continue
}
if formatCode == 1 && oid == 701 {
bits := binary.BigEndian.Uint64(bind.Parameters[i])
f64 := math.Float64frombits(bits)
args[i] = types.NewFloat64Datum(f64)
continue
}
tmp := string(hack.String(bind.Parameters[i]))
args[i] = types.NewDatum(tmp)
Expand Down