Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Server: Improved Extended Query Protocol #74

Merged
merged 3 commits into from
Sep 23, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
56 changes: 45 additions & 11 deletions server/conn_stmt.go
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ import (
"math"
"runtime/trace"
"strconv"
"strings"
"time"

"github.com/DigitalChinaOpenSource/DCParser/mysql"
Expand All @@ -68,17 +69,17 @@ import (

// handleStmtPrepare handle prepare message in pgsql's extended query.
// PgSQL Modified
func (cc *clientConn) handleStmtPrepare(ctx context.Context, parser pgproto3.Parse) error {
func (cc *clientConn) handleStmtPrepare(ctx context.Context, parse pgproto3.Parse) error {
//stmt, columns, params, err := cc.ctx.Prepare(parser.Query)
stmt, _, _, err := cc.ctx.Prepare(parser.Query, parser.Name)
stmt, _, _, err := cc.ctx.Prepare(parse.Query, parse.Name)

if err != nil {
return err
}

vars := cc.ctx.GetSessionVars()

// Get param types in sqllan, and save it in `stmt`.
// Get param types in sql plan, and save it in `stmt`.
var paramTypes []byte
if cachedStmt, ok := vars.PreparedStmts[uint32(stmt.ID())].(*plannercore.CachedPrepareStmt); ok {
cachedParams := cachedStmt.PreparedAst.Params
Expand All @@ -89,6 +90,10 @@ func (cc *clientConn) handleStmtPrepare(ctx context.Context, parser pgproto3.Par

stmt.SetParamsType(paramTypes)

if len(parse.ParameterOIDs) > 0 {
stmt.SetOIDs(parse.ParameterOIDs)
}

return cc.writeParseComplete()
}

Expand Down Expand Up @@ -198,14 +203,17 @@ func (cc *clientConn) handleStmtDescription(ctx context.Context, desc pgproto3.D
// https://www.postgresql.org/docs/current/protocol-flow.html#PROTOCOL-FLOW-EXT-QUERY
if !isPortal {
// Get param types that analyzed in `handleStmtBind`,
// And convert it to PgSQL data type and return it to the client
paramsType := stmt.GetParamsType()
pgType := make([]uint32, stmt.NumParams())
for i := range paramsType {
pgType[i] = convertMySQLDataTypeToPgSQLDataType(paramsType[i])

pgOIDs := stmt.GetOIDs()
if len(pgOIDs) == 0 {
paramsType := stmt.GetParamsType()
pgOIDs = make([]uint32, stmt.NumParams())
for i := range paramsType {
pgOIDs[i] = convertMySQLDataTypeToPgSQLDataType(paramsType[i])
}
}

if err := cc.writeParameterDescription(pgType); err != nil {
if err := cc.writeParameterDescription(pgOIDs); err != nil {
return err
}
}
Expand Down Expand Up @@ -376,6 +384,13 @@ func parseBindArgs(sc *stmtctx.StatementContext, args []types.Datum, paramTypes
continue

case mysql.TypeInt24, mysql.TypeLong:
if bind.ParameterFormatCodes[i] == 1 { // The data passed in is in binary format
var b [8]byte
copy(b[8-len(bind.Parameters[i]):], bind.Parameters[i])
val := binary.BigEndian.Uint64(b[:])
args[i] = types.NewUintDatum(val)
continue
}
valInt, err := strconv.Atoi(string(bind.Parameters[i]))
if err != nil {
return err
Expand Down Expand Up @@ -414,8 +429,18 @@ func parseBindArgs(sc *stmtctx.StatementContext, args []types.Datum, paramTypes
}
args[i] = types.NewFloat64Datum(valFloat)
continue

case mysql.TypeDate, mysql.TypeTimestamp, mysql.TypeDatetime:
case mysql.TypeTimestamp:
// we ignore timezone here
timeStr := string(bind.Parameters[i])
tzIndex := strings.Index(timeStr, " +")
if tzIndex == -1 {
args[i] = types.NewDatum(timeStr)
continue
}
noTzStr := timeStr[:tzIndex]
args[i] = types.NewDatum(noTzStr)
continue
case mysql.TypeDate, mysql.TypeDatetime:
// fixme 日期待测试 待修复
args[i] = types.NewDatum(string(bind.Parameters[i]))
continue
Expand All @@ -427,6 +452,15 @@ func parseBindArgs(sc *stmtctx.StatementContext, args []types.Datum, paramTypes
case mysql.TypeNewDecimal:
// fixme decimal 待测试 待修复
var dec types.MyDecimal
if bind.ParameterFormatCodes[i] == 1 {
bits := binary.BigEndian.Uint64(bind.Parameters[i])
f64 := math.Float64frombits(bits)
err := sc.HandleTruncate(dec.FromFloat64(f64))
if err != nil {
return err
}
continue
}
err := sc.HandleTruncate(dec.FromString(bind.Parameters[i]))
if err != nil {
return err
Expand Down
6 changes: 6 additions & 0 deletions server/driver.go
Original file line number Diff line number Diff line change
Expand Up @@ -173,6 +173,12 @@ type PreparedStatement interface {

// SetResultFormat 设置结果返回的格式 0 为 Text, 1 为 Binary
SetResultFormat(rf []int16)

// GetOIDs returns the postgres OIDs
GetOIDs() []uint32

// SetOIDs set the postgres OIDs
SetOIDs(pgOIDs []uint32)
}

// ResultSet is the result set of an query.
Expand Down
11 changes: 11 additions & 0 deletions server/driver_tidb.go
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,7 @@ type TiDBStatement struct {
columnInfo []*ColumnInfo
args []types.Datum
resultFormat []int16
paramOIDs []uint32
}

// ID implements PreparedStatement ID method.
Expand Down Expand Up @@ -231,6 +232,16 @@ func (ts *TiDBStatement) SetResultFormat(rf []int16) {
ts.resultFormat = rf
}

// GetOIDs return OIDs for the current statement
func (ts *TiDBStatement) GetOIDs() []uint32 {
return ts.paramOIDs
}

// SetOIDs set OIDs for the current statement
func (ts *TiDBStatement) SetOIDs(pgOIDs []uint32) {
ts.paramOIDs = pgOIDs
}

// OpenCtx implements IDriver.
func (qd *TiDBDriver) OpenCtx(connID uint64, capability uint32, collation uint8, dbname string, tlsState *tls.ConnectionState) (QueryCtx, error) {
se, err := session.CreateSession(qd.store)
Expand Down