diff --git a/server/conn_stmt.go b/server/conn_stmt.go index 5cb3ee0bb4..03d1f88096 100644 --- a/server/conn_stmt.go +++ b/server/conn_stmt.go @@ -56,6 +56,7 @@ import ( "math" "runtime/trace" "strconv" + "strings" "time" "github.com/DigitalChinaOpenSource/DCParser/mysql" @@ -68,9 +69,9 @@ import ( // handleStmtPrepare handle prepare message in pgsql's extended query. // PgSQL Modified -func (cc *clientConn) handleStmtPrepare(ctx context.Context, parser pgproto3.Parse) error { +func (cc *clientConn) handleStmtPrepare(ctx context.Context, parse pgproto3.Parse) error { //stmt, columns, params, err := cc.ctx.Prepare(parser.Query) - stmt, _, _, err := cc.ctx.Prepare(parser.Query, parser.Name) + stmt, _, _, err := cc.ctx.Prepare(parse.Query, parse.Name) if err != nil { return err @@ -78,7 +79,7 @@ func (cc *clientConn) handleStmtPrepare(ctx context.Context, parser pgproto3.Par vars := cc.ctx.GetSessionVars() - // Get param types in sqllan, and save it in `stmt`. + // Get param types in sql plan, and save it in `stmt`. var paramTypes []byte if cachedStmt, ok := vars.PreparedStmts[uint32(stmt.ID())].(*plannercore.CachedPrepareStmt); ok { cachedParams := cachedStmt.PreparedAst.Params @@ -89,6 +90,10 @@ func (cc *clientConn) handleStmtPrepare(ctx context.Context, parser pgproto3.Par stmt.SetParamsType(paramTypes) + if len(parse.ParameterOIDs) > 0 { + stmt.SetOIDs(parse.ParameterOIDs) + } + return cc.writeParseComplete() } @@ -198,14 +203,17 @@ func (cc *clientConn) handleStmtDescription(ctx context.Context, desc pgproto3.D // https://www.postgresql.org/docs/current/protocol-flow.html#PROTOCOL-FLOW-EXT-QUERY if !isPortal { // Get param types that analyzed in `handleStmtBind`, - // And convert it to PgSQL data type and return it to the client - paramsType := stmt.GetParamsType() - pgType := make([]uint32, stmt.NumParams()) - for i := range paramsType { - pgType[i] = convertMySQLDataTypeToPgSQLDataType(paramsType[i]) + + pgOIDs := stmt.GetOIDs() + if len(pgOIDs) == 0 { + paramsType := stmt.GetParamsType() + pgOIDs = make([]uint32, stmt.NumParams()) + for i := range paramsType { + pgOIDs[i] = convertMySQLDataTypeToPgSQLDataType(paramsType[i]) + } } - if err := cc.writeParameterDescription(pgType); err != nil { + if err := cc.writeParameterDescription(pgOIDs); err != nil { return err } } @@ -376,6 +384,13 @@ func parseBindArgs(sc *stmtctx.StatementContext, args []types.Datum, paramTypes continue case mysql.TypeInt24, mysql.TypeLong: + if bind.ParameterFormatCodes[i] == 1 { // The data passed in is in binary format + var b [8]byte + copy(b[8-len(bind.Parameters[i]):], bind.Parameters[i]) + val := binary.BigEndian.Uint64(b[:]) + args[i] = types.NewUintDatum(val) + continue + } valInt, err := strconv.Atoi(string(bind.Parameters[i])) if err != nil { return err @@ -414,8 +429,18 @@ func parseBindArgs(sc *stmtctx.StatementContext, args []types.Datum, paramTypes } args[i] = types.NewFloat64Datum(valFloat) continue - - case mysql.TypeDate, mysql.TypeTimestamp, mysql.TypeDatetime: + case mysql.TypeTimestamp: + // we ignore timezone here + timeStr := string(bind.Parameters[i]) + tzIndex := strings.Index(timeStr, " +") + if tzIndex == -1 { + args[i] = types.NewDatum(timeStr) + continue + } + noTzStr := timeStr[:tzIndex] + args[i] = types.NewDatum(noTzStr) + continue + case mysql.TypeDate, mysql.TypeDatetime: // fixme 日期待测试 待修复 args[i] = types.NewDatum(string(bind.Parameters[i])) continue @@ -427,6 +452,15 @@ func parseBindArgs(sc *stmtctx.StatementContext, args []types.Datum, paramTypes case mysql.TypeNewDecimal: // fixme decimal 待测试 待修复 var dec types.MyDecimal + if bind.ParameterFormatCodes[i] == 1 { + bits := binary.BigEndian.Uint64(bind.Parameters[i]) + f64 := math.Float64frombits(bits) + err := sc.HandleTruncate(dec.FromFloat64(f64)) + if err != nil { + return err + } + continue + } err := sc.HandleTruncate(dec.FromString(bind.Parameters[i])) if err != nil { return err diff --git a/server/driver.go b/server/driver.go index f8d1be1522..446751368e 100644 --- a/server/driver.go +++ b/server/driver.go @@ -173,6 +173,12 @@ type PreparedStatement interface { // SetResultFormat 设置结果返回的格式 0 为 Text, 1 为 Binary SetResultFormat(rf []int16) + + // GetOIDs returns the postgres OIDs + GetOIDs() []uint32 + + // SetOIDs set the postgres OIDs + SetOIDs(pgOIDs []uint32) } // ResultSet is the result set of an query. diff --git a/server/driver_tidb.go b/server/driver_tidb.go index 6cf55bf498..9e38f330dd 100644 --- a/server/driver_tidb.go +++ b/server/driver_tidb.go @@ -82,6 +82,7 @@ type TiDBStatement struct { columnInfo []*ColumnInfo args []types.Datum resultFormat []int16 + paramOIDs []uint32 } // ID implements PreparedStatement ID method. @@ -231,6 +232,16 @@ func (ts *TiDBStatement) SetResultFormat(rf []int16) { ts.resultFormat = rf } +// GetOIDs return OIDs for the current statement +func (ts *TiDBStatement) GetOIDs() []uint32 { + return ts.paramOIDs +} + +// SetOIDs set OIDs for the current statement +func (ts *TiDBStatement) SetOIDs(pgOIDs []uint32) { + ts.paramOIDs = pgOIDs +} + // OpenCtx implements IDriver. func (qd *TiDBDriver) OpenCtx(connID uint64, capability uint32, collation uint8, dbname string, tlsState *tls.ConnectionState) (QueryCtx, error) { se, err := session.CreateSession(qd.store)