From 9542bdd9491b36ade6b4248c62a0d9a63c621d2f Mon Sep 17 00:00:00 2001 From: Idhor Date: Sun, 12 Apr 2015 16:38:45 +0200 Subject: [PATCH] Enable Multi Results support and discard additional results MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - packets.go: flag clientMultiResults, update status when receiving an EOF packet, discard additional results on readRow when EOF is reached - statement.go: currently a nil rows.mc is used as an eof, don’t set it if there are no columns to avoid that Next() waits indefinitely - rows.go: discard additional results on close and avoid panic on Columns() --- packets.go | 45 +++++++++++++++++++++++++++++++++++++++++++-- rows.go | 8 +++++++- statement.go | 2 +- 3 files changed, 51 insertions(+), 4 deletions(-) diff --git a/packets.go b/packets.go index 290a3887a..fd3022554 100644 --- a/packets.go +++ b/packets.go @@ -214,6 +214,7 @@ func (mc *mysqlConn) writeAuthPacket(cipher []byte) error { clientLongPassword | clientTransactions | clientLocalFiles | + clientMultiResults | mc.flags&clientLongFlag if mc.cfg.clientFoundRows { @@ -470,6 +471,10 @@ func (mc *mysqlConn) handleErrorPacket(data []byte) error { } } +func readStatus(b []byte) statusFlag { + return statusFlag(b[0]) | statusFlag(b[1])<<8 +} + // Ok Packet // http://dev.mysql.com/doc/internals/en/generic-response-packets.html#packet-OK_Packet func (mc *mysqlConn) handleOkPacket(data []byte) error { @@ -484,7 +489,7 @@ func (mc *mysqlConn) handleOkPacket(data []byte) error { mc.insertId, _, m = readLengthEncodedInteger(data[1+n:]) // server_status [2 bytes] - mc.status = statusFlag(data[1+n+m]) | statusFlag(data[1+n+m+1])<<8 + mc.status = readStatus(data[1+n+m : 1+n+m+2]) // warning count [2 bytes] if !mc.strict { @@ -603,6 +608,11 @@ func (rows *textRows) readRow(dest []driver.Value) error { // EOF Packet if data[0] == iEOF && len(data) == 5 { + // server_status [2 bytes] + rows.mc.status = readStatus(data[3:]) + if err := rows.mc.discardMoreResultsIfExists(); err != nil { + return err + } rows.mc = nil return io.EOF } @@ -660,6 +670,10 @@ func (mc *mysqlConn) readUntilEOF() error { if err == nil && data[0] != iEOF { continue } + if err == nil && data[0] == iEOF && len(data) == 5 { + mc.status = readStatus(data[3:]) + } + return err // Err or EOF } } @@ -964,6 +978,28 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error { return mc.writePacket(data) } +func (mc *mysqlConn) discardMoreResultsIfExists() error { + for mc.status&statusMoreResultsExists != 0 { + resLen, err := mc.readResultSetHeaderPacket() + if err != nil { + return err + } + if resLen > 0 { + // columns + if err := mc.readUntilEOF(); err != nil { + return err + } + // rows + if err := mc.readUntilEOF(); err != nil { + return err + } + } else { + mc.status &^= statusMoreResultsExists + } + } + return nil +} + // http://dev.mysql.com/doc/internals/en/binary-protocol-resultset-row.html func (rows *binaryRows) readRow(dest []driver.Value) error { data, err := rows.mc.readPacket() @@ -973,11 +1009,16 @@ func (rows *binaryRows) readRow(dest []driver.Value) error { // packet indicator [1 byte] if data[0] != iOK { - rows.mc = nil // EOF Packet if data[0] == iEOF && len(data) == 5 { + rows.mc.status = readStatus(data[3:]) + if err := rows.mc.discardMoreResultsIfExists(); err != nil { + return err + } + rows.mc = nil return io.EOF } + rows.mc = nil // Error otherwise return rows.mc.handleErrorPacket(data) diff --git a/rows.go b/rows.go index 9d97d6d4f..7f281e8c0 100644 --- a/rows.go +++ b/rows.go @@ -38,7 +38,7 @@ type emptyRows struct{} func (rows *mysqlRows) Columns() []string { columns := make([]string, len(rows.columns)) - if rows.mc.cfg.columnsWithAlias { + if rows.mc != nil && rows.mc.cfg.columnsWithAlias { for i := range columns { columns[i] = rows.columns[i].tableName + "." + rows.columns[i].name } @@ -61,6 +61,12 @@ func (rows *mysqlRows) Close() error { // Remove unread packets from stream err := mc.readUntilEOF() + if err == nil { + if err = mc.discardMoreResultsIfExists(); err != nil { + return err + } + } + rows.mc = nil return err } diff --git a/statement.go b/statement.go index 142ef5416..73758360a 100644 --- a/statement.go +++ b/statement.go @@ -94,9 +94,9 @@ func (stmt *mysqlStmt) Query(args []driver.Value) (driver.Rows, error) { } rows := new(binaryRows) - rows.mc = mc if resLen > 0 { + rows.mc = mc // Columns // If not cached, read them and cache them if stmt.columns == nil {