Skip to content

Commit

Permalink
Server: Improved Extended Query Protocol (#74)
Browse files Browse the repository at this point in the history
* binary long input

Signed-off-by: AmoebaProtozoa <8039876+AmoebaProtozoa@users.noreply.github.com>

* binary float support

Signed-off-by: AmoebaProtozoa <8039876+AmoebaProtozoa@users.noreply.github.com>

* new field to save pg oid

Signed-off-by: AmoebaProtozoa <8039876+AmoebaProtozoa@users.noreply.github.com>
  • Loading branch information
AmoebaProtozoa authored Sep 23, 2021
1 parent 08c0517 commit 3bc9cb5
Show file tree
Hide file tree
Showing 3 changed files with 62 additions and 11 deletions.
56 changes: 45 additions & 11 deletions server/conn_stmt.go
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ import (
"math"
"runtime/trace"
"strconv"
"strings"
"time"

"github.com/DigitalChinaOpenSource/DCParser/mysql"
Expand All @@ -68,17 +69,17 @@ import (

// handleStmtPrepare handle prepare message in pgsql's extended query.
// PgSQL Modified
func (cc *clientConn) handleStmtPrepare(ctx context.Context, parser pgproto3.Parse) error {
func (cc *clientConn) handleStmtPrepare(ctx context.Context, parse pgproto3.Parse) error {
//stmt, columns, params, err := cc.ctx.Prepare(parser.Query)
stmt, _, _, err := cc.ctx.Prepare(parser.Query, parser.Name)
stmt, _, _, err := cc.ctx.Prepare(parse.Query, parse.Name)

if err != nil {
return err
}

vars := cc.ctx.GetSessionVars()

// Get param types in sqllan, and save it in `stmt`.
// Get param types in sql plan, and save it in `stmt`.
var paramTypes []byte
if cachedStmt, ok := vars.PreparedStmts[uint32(stmt.ID())].(*plannercore.CachedPrepareStmt); ok {
cachedParams := cachedStmt.PreparedAst.Params
Expand All @@ -89,6 +90,10 @@ func (cc *clientConn) handleStmtPrepare(ctx context.Context, parser pgproto3.Par

stmt.SetParamsType(paramTypes)

if len(parse.ParameterOIDs) > 0 {
stmt.SetOIDs(parse.ParameterOIDs)
}

return cc.writeParseComplete()
}

Expand Down Expand Up @@ -198,14 +203,17 @@ func (cc *clientConn) handleStmtDescription(ctx context.Context, desc pgproto3.D
// https://www.postgresql.org/docs/current/protocol-flow.html#PROTOCOL-FLOW-EXT-QUERY
if !isPortal {
// Get param types that analyzed in `handleStmtBind`,
// And convert it to PgSQL data type and return it to the client
paramsType := stmt.GetParamsType()
pgType := make([]uint32, stmt.NumParams())
for i := range paramsType {
pgType[i] = convertMySQLDataTypeToPgSQLDataType(paramsType[i])

pgOIDs := stmt.GetOIDs()
if len(pgOIDs) == 0 {
paramsType := stmt.GetParamsType()
pgOIDs = make([]uint32, stmt.NumParams())
for i := range paramsType {
pgOIDs[i] = convertMySQLDataTypeToPgSQLDataType(paramsType[i])
}
}

if err := cc.writeParameterDescription(pgType); err != nil {
if err := cc.writeParameterDescription(pgOIDs); err != nil {
return err
}
}
Expand Down Expand Up @@ -376,6 +384,13 @@ func parseBindArgs(sc *stmtctx.StatementContext, args []types.Datum, paramTypes
continue

case mysql.TypeInt24, mysql.TypeLong:
if bind.ParameterFormatCodes[i] == 1 { // The data passed in is in binary format
var b [8]byte
copy(b[8-len(bind.Parameters[i]):], bind.Parameters[i])
val := binary.BigEndian.Uint64(b[:])
args[i] = types.NewUintDatum(val)
continue
}
valInt, err := strconv.Atoi(string(bind.Parameters[i]))
if err != nil {
return err
Expand Down Expand Up @@ -414,8 +429,18 @@ func parseBindArgs(sc *stmtctx.StatementContext, args []types.Datum, paramTypes
}
args[i] = types.NewFloat64Datum(valFloat)
continue

case mysql.TypeDate, mysql.TypeTimestamp, mysql.TypeDatetime:
case mysql.TypeTimestamp:
// we ignore timezone here
timeStr := string(bind.Parameters[i])
tzIndex := strings.Index(timeStr, " +")
if tzIndex == -1 {
args[i] = types.NewDatum(timeStr)
continue
}
noTzStr := timeStr[:tzIndex]
args[i] = types.NewDatum(noTzStr)
continue
case mysql.TypeDate, mysql.TypeDatetime:
// fixme 日期待测试 待修复
args[i] = types.NewDatum(string(bind.Parameters[i]))
continue
Expand All @@ -427,6 +452,15 @@ func parseBindArgs(sc *stmtctx.StatementContext, args []types.Datum, paramTypes
case mysql.TypeNewDecimal:
// fixme decimal 待测试 待修复
var dec types.MyDecimal
if bind.ParameterFormatCodes[i] == 1 {
bits := binary.BigEndian.Uint64(bind.Parameters[i])
f64 := math.Float64frombits(bits)
err := sc.HandleTruncate(dec.FromFloat64(f64))
if err != nil {
return err
}
continue
}
err := sc.HandleTruncate(dec.FromString(bind.Parameters[i]))
if err != nil {
return err
Expand Down
6 changes: 6 additions & 0 deletions server/driver.go
Original file line number Diff line number Diff line change
Expand Up @@ -173,6 +173,12 @@ type PreparedStatement interface {

// SetResultFormat 设置结果返回的格式 0 为 Text, 1 为 Binary
SetResultFormat(rf []int16)

// GetOIDs returns the postgres OIDs
GetOIDs() []uint32

// SetOIDs set the postgres OIDs
SetOIDs(pgOIDs []uint32)
}

// ResultSet is the result set of an query.
Expand Down
11 changes: 11 additions & 0 deletions server/driver_tidb.go
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,7 @@ type TiDBStatement struct {
columnInfo []*ColumnInfo
args []types.Datum
resultFormat []int16
paramOIDs []uint32
}

// ID implements PreparedStatement ID method.
Expand Down Expand Up @@ -231,6 +232,16 @@ func (ts *TiDBStatement) SetResultFormat(rf []int16) {
ts.resultFormat = rf
}

// GetOIDs return OIDs for the current statement
func (ts *TiDBStatement) GetOIDs() []uint32 {
return ts.paramOIDs
}

// SetOIDs set OIDs for the current statement
func (ts *TiDBStatement) SetOIDs(pgOIDs []uint32) {
ts.paramOIDs = pgOIDs
}

// OpenCtx implements IDriver.
func (qd *TiDBDriver) OpenCtx(connID uint64, capability uint32, collation uint8, dbname string, tlsState *tls.ConnectionState) (QueryCtx, error) {
se, err := session.CreateSession(qd.store)
Expand Down

0 comments on commit 3bc9cb5

Please sign in to comment.