Skip to content

Commit

Permalink
Conn: Improved extended query handling (#73)
Browse files Browse the repository at this point in the history
* removed unnecessary write back to client

Signed-off-by: AmoebaProtozoa <8039876+AmoebaProtozoa@users.noreply.github.com>

* fixed extended query describe

Signed-off-by: AmoebaProtozoa <8039876+AmoebaProtozoa@users.noreply.github.com>
  • Loading branch information
AmoebaProtozoa authored Sep 18, 2021
1 parent b4a2c60 commit 08c0517
Showing 1 changed file with 14 additions and 25 deletions.
39 changes: 14 additions & 25 deletions server/conn_stmt.go
Original file line number Diff line number Diff line change
Expand Up @@ -89,11 +89,7 @@ func (cc *clientConn) handleStmtPrepare(ctx context.Context, parser pgproto3.Par

stmt.SetParamsType(paramTypes)

err = cc.writeParseComplete()
if err != nil {
return err
}
return cc.flush(ctx)
return cc.writeParseComplete()
}

// handleStmtBind handle bind messages in pgsql's extended query.
Expand Down Expand Up @@ -156,11 +152,7 @@ func (cc *clientConn) handleStmtBind(ctx context.Context, bind pgproto3.Bind) (e
vars.Portal["0"] = stmtID
}

err = cc.writeBindComplete()
if err != nil {
return err
}
return cc.flush(ctx)
return cc.writeBindComplete()
}

// handleStmtDescription handle Description messages in pgsql's extended query,
Expand All @@ -177,14 +169,17 @@ func (cc *clientConn) handleStmtDescription(ctx context.Context, desc pgproto3.D

var stmtID uint32
var ok bool
var isPortal bool

// If it specify the prepared statement through portal,
// If it specifies the prepared statement through portal,
// here can directly find the corresponding stmtID through portal.
if desc.ObjectType == 'P' {
stmtID, ok = vars.Portal[desc.Name]
isPortal = true
} else {
// Or get prepared stmtID through stmtName.
stmtID, ok = vars.PreparedStmtNameToID[desc.Name]
isPortal = false
}

if !ok {
Expand All @@ -199,12 +194,13 @@ func (cc *clientConn) handleStmtDescription(ctx context.Context, desc pgproto3.D
strconv.FormatUint(uint64(stmtID), 10), "stmt_description")
}

numParams := stmt.NumParams()
if numParams > 0 {
// we send parameter description only if this is statement and not a portal
// https://www.postgresql.org/docs/current/protocol-flow.html#PROTOCOL-FLOW-EXT-QUERY
if !isPortal {
// Get param types that analyzed in `handleStmtBind`,
// And convert it to PgSQL data type and return it to the client
paramsType := stmt.GetParamsType()
pgType := make([]uint32, numParams)
pgType := make([]uint32, stmt.NumParams())
for i := range paramsType {
pgType[i] = convertMySQLDataTypeToPgSQLDataType(paramsType[i])
}
Expand All @@ -218,18 +214,11 @@ func (cc *clientConn) handleStmtDescription(ctx context.Context, desc pgproto3.D
// otherwise return `writeNoData`.
columnInfo := stmt.GetColumnInfo()
if columnInfo == nil || len(columnInfo) > 0 {
if err := cc.WriteRowDescription(columnInfo); err != nil {
return err
}

// If the row description information has been output here,
// it will not be output when `writeResultset` later.
} else {
if err := cc.writeNoData(); err != nil {
return err
}
return cc.WriteRowDescription(columnInfo)
}
return cc.flush(ctx)
// If the row description information has been output here,
// it will not be output when `writeResultset` later.
return cc.writeNoData()
}

// handleStmtExecute handle execute messages in pgsql's extended query.
Expand Down

0 comments on commit 08c0517

Please sign in to comment.