diff --git a/driver.go b/driver.go index c851b4af..f1aee6e2 100644 --- a/driver.go +++ b/driver.go @@ -34,6 +34,7 @@ import ( adminpb "cloud.google.com/go/spanner/admin/database/apiv1/databasepb" "cloud.google.com/go/spanner/apiv1/spannerpb" "github.com/googleapis/gax-go/v2" + "google.golang.org/api/iterator" "google.golang.org/api/option" "google.golang.org/grpc" "google.golang.org/grpc/codes" @@ -323,14 +324,15 @@ func openDriverConn(ctx context.Context, c *connector) (driver.Conn, error) { } return &conn{ - connector: c, - client: c.client, - adminClient: c.adminClient, - database: databaseName, - retryAborts: c.retryAbortsInternally, - execSingleQuery: queryInSingleUse, - execSingleDMLTransactional: execInNewRWTransaction, - execSingleDMLPartitioned: execAsPartitionedDML, + connector: c, + client: c.client, + adminClient: c.adminClient, + database: databaseName, + retryAborts: c.retryAbortsInternally, + execSingleQuery: queryInSingleUse, + execSingleQueryTransactional: queryInNewRWTransaction, + execSingleDMLTransactional: execInNewRWTransaction, + execSingleDMLPartitioned: execAsPartitionedDML, }, nil } @@ -638,9 +640,10 @@ type conn struct { database string retryAborts bool - execSingleQuery func(ctx context.Context, c *spanner.Client, statement spanner.Statement, bound spanner.TimestampBound) *spanner.RowIterator - execSingleDMLTransactional func(ctx context.Context, c *spanner.Client, statement spanner.Statement, transactionOptions spanner.TransactionOptions) (int64, time.Time, error) - execSingleDMLPartitioned func(ctx context.Context, c *spanner.Client, statement spanner.Statement, options spanner.QueryOptions) (int64, error) + execSingleQuery func(ctx context.Context, c *spanner.Client, statement spanner.Statement, bound spanner.TimestampBound) *spanner.RowIterator + execSingleQueryTransactional func(ctx context.Context, c *spanner.Client, statement spanner.Statement, transactionOptions spanner.TransactionOptions) (rowIterator, time.Time, error) + execSingleDMLTransactional func(ctx context.Context, c *spanner.Client, statement spanner.Statement, transactionOptions spanner.TransactionOptions) (int64, time.Time, error) + execSingleDMLPartitioned func(ctx context.Context, c *spanner.Client, statement spanner.Statement, options spanner.QueryOptions) (int64, error) // batch is the currently active DDL or DML batch on this connection. batch *batch @@ -1130,7 +1133,21 @@ func (c *conn) QueryContext(ctx context.Context, query string, args []driver.Nam } var iter rowIterator if c.tx == nil { - iter = &readOnlyRowIterator{c.execSingleQuery(ctx, c.client, stmt, c.readOnlyStaleness)} + statementType := detectStatementType(query) + if statementType == statementTypeDml { + // Use a read/write transaction to execute the statement. + var commitTs time.Time + iter, commitTs, err = c.execSingleQueryTransactional(ctx, c.client, stmt, c.createTransactionOptions()) + if err != nil { + return nil, err + } + c.commitTs = &commitTs + } else { + // The statement was either detected as being a query, or potentially not recognized at all. + // In that case, just default to using a single-use read-only transaction and let Spanner + // return an error if the statement is not suited for that type of transaction. + iter = &readOnlyRowIterator{c.execSingleQuery(ctx, c.client, stmt, c.readOnlyStaleness)} + } } else { iter = c.tx.Query(ctx, stmt) } @@ -1149,12 +1166,9 @@ func (c *conn) ExecContext(ctx context.Context, query string, args []driver.Name // Clear the commit timestamp of this connection before we execute the statement. c.commitTs = nil + statementType := detectStatementType(query) // Use admin API if DDL statement is provided. - isDDL, err := isDDL(query) - if err != nil { - return nil, err - } - if isDDL { + if statementType == statementTypeDdl { // Spanner does not support DDL in transactions, and although it is technically possible to execute DDL // statements while a transaction is active, we return an error to avoid any confusion whether the DDL // statement is executed as part of the active transaction or not. @@ -1312,6 +1326,60 @@ func queryInSingleUse(ctx context.Context, c *spanner.Client, statement spanner. return c.Single().WithTimestampBound(tb).Query(ctx, statement) } +type wrappedRowIterator struct { + *spanner.RowIterator + + noRows bool + firstRow *spanner.Row +} + +func (ri *wrappedRowIterator) Next() (*spanner.Row, error) { + if ri.noRows { + return nil, iterator.Done + } + if ri.firstRow != nil { + defer func() { ri.firstRow = nil }() + return ri.firstRow, nil + } + return ri.RowIterator.Next() +} + +func (ri *wrappedRowIterator) Stop() { + ri.RowIterator.Stop() +} + +func (ri *wrappedRowIterator) Metadata() *spannerpb.ResultSetMetadata { + return ri.RowIterator.Metadata +} + +func queryInNewRWTransaction(ctx context.Context, c *spanner.Client, statement spanner.Statement, options spanner.TransactionOptions) (rowIterator, time.Time, error) { + var result *wrappedRowIterator + fn := func(ctx context.Context, tx *spanner.ReadWriteTransaction) error { + it := tx.Query(ctx, statement) + row, err := it.Next() + if err == iterator.Done { + result = &wrappedRowIterator{ + RowIterator: it, + noRows: true, + } + } else if err != nil { + it.Stop() + return err + } else { + result = &wrappedRowIterator{ + RowIterator: it, + firstRow: row, + } + } + return nil + } + resp, err := c.ReadWriteTransactionWithOptions(ctx, fn, options) + if err != nil { + return nil, time.Time{}, err + } + return result, resp.CommitTs, nil +} + func execInNewRWTransaction(ctx context.Context, c *spanner.Client, statement spanner.Statement, options spanner.TransactionOptions) (int64, time.Time, error) { var rowsAffected int64 fn := func(ctx context.Context, tx *spanner.ReadWriteTransaction) error { diff --git a/driver_with_mockserver_test.go b/driver_with_mockserver_test.go index 1817c426..c3137f7f 100644 --- a/driver_with_mockserver_test.go +++ b/driver_with_mockserver_test.go @@ -1222,6 +1222,64 @@ func TestQueryWithMissingPositionalParameter(t *testing.T) { } } +func TestDmlReturningInAutocommit(t *testing.T) { + t.Parallel() + + ctx := context.Background() + db, server, teardown := setupTestDBConnection(t) + defer teardown() + + s := "insert into users (id, name) values (@id, @name) then return id" + _ = server.TestSpanner.PutStatementResult( + s, + &testutil.StatementResult{ + Type: testutil.StatementResultResultSet, + ResultSet: testutil.CreateSelect1ResultSet(), + }, + ) + + for _, prepare := range []bool{false, true} { + var rows *sql.Rows + var err error + if prepare { + stmt, err := db.PrepareContext(ctx, s) + if err != nil { + t.Fatal(err) + } + rows, err = stmt.QueryContext(ctx, sql.Named("id", 1), sql.Named("name", "bar")) + } else { + rows, err = db.QueryContext(ctx, s, sql.Named("id", 1), sql.Named("name", "bar")) + } + if err != nil { + t.Fatal(err) + } + if !rows.Next() { + t.Fatal("missing row") + } + var id int + if err := rows.Scan(&id); err != nil { + t.Fatal(err) + } + if g, w := id, 1; g != w { + t.Fatalf("id mismatch\n Got: %v\nWant: %v", g, w) + } + if rows.Next() { + t.Fatal("got more rows than expected") + } + + // Verify that a read/write transaction was used. + requests := drainRequestsFromServer(server.TestSpanner) + sqlRequests := requestsOfType(requests, reflect.TypeOf(&sppb.ExecuteSqlRequest{})) + if g, w := len(sqlRequests), 1; g != w { + t.Fatalf("sql requests count mismatch\nGot: %v\nWant: %v", g, w) + } + commitRequests := requestsOfType(requests, reflect.TypeOf(&sppb.CommitRequest{})) + if g, w := len(commitRequests), 1; g != w { + t.Fatalf("commit requests count mismatch\nGot: %v\nWant: %v", g, w) + } + } +} + func TestDdlInAutocommit(t *testing.T) { t.Parallel() diff --git a/integration_test.go b/integration_test.go index 1c8192a4..6ae58bfc 100644 --- a/integration_test.go +++ b/integration_test.go @@ -876,9 +876,29 @@ func TestExecContextDml(t *testing.T) { } // Clear table for next test. - _, err = db.ExecContext(ctx, `DELETE FROM TestExecContextDml WHERE true`) - if (err != nil) && (!tc.then.wantError) { - t.Errorf("%s: unexpected query error: %v", tc.name, err) + it, err := db.QueryContext(ctx, `DELETE FROM TestExecContextDml WHERE true THEN RETURN *`) + if err != nil { + if !tc.then.wantError { + t.Errorf("%s: unexpected query error: %v", tc.name, err) + } + } else { + if cols, err := it.Columns(); err != nil { + t.Errorf("%s: unexpected Columns() error: %v", tc.name, err) + } else { + if g, w := len(cols), 6; g != w { + t.Errorf("%s: column number mismatch\n Got: %v\nWant: %v", tc.name, g, w) + } + if !cmp.Equal(cols, []string{"key", "testString", "testBytes", "testInt", "testFloat", "testBool"}) { + t.Errorf("%s: column names mismatch: %v", tc.name, cols) + } + } + for it.Next() { + if err := it.Err(); err != nil { + t.Errorf("%s: unexpected iterator error: %v", tc.name, err) + break + } + } + it.Close() } } } diff --git a/statement_parser.go b/statement_parser.go index 3443a6e7..d2ffa8eb 100644 --- a/statement_parser.go +++ b/statement_parser.go @@ -161,6 +161,11 @@ func removeCommentsAndTrim(sql string) (string, error) { // Removes any statement hints at the beginning of the statement. // It assumes that any comments have already been removed. func removeStatementHint(sql string) string { + // Return quickly if the statement does not start with a hint. + if len(sql) < 2 || sql[0] != '@' { + return sql + } + // Valid statement hints at the beginning of a query statement can only contain a fixed set of // possible values. Although it is possible to add a @{FORCE_INDEX=...} as a statement hint, the // only allowed value is _BASE_TABLE. This means that we can safely assume that the statement @@ -303,20 +308,36 @@ func findParams(positionalParamChar rune, sql string) (string, []string, error) } // isDDL returns true if the given sql string is a DDL statement. -func isDDL(query string) (bool, error) { - query, err := removeCommentsAndTrim(query) - if err != nil { - return false, err - } +// This function assumes that any comments and hints at the start +// of the sql string have been removed. +func isDDL(query string) bool { + return isStatementType(query, ddlStatements) +} + +// isDml returns true if the given sql string is a Dml statement. +// This function assumes that any comments and hints at the start +// of the sql string have been removed. +func isDml(query string) bool { + return isStatementType(query, dmlStatements) +} + +// isQuery returns true if the given sql string is a SELECT statement. +// This function assumes that any comments and hints at the start +// of the sql string have been removed. +func isQuery(query string) bool { + return isStatementType(query, selectStatements) +} + +func isStatementType(query string, keywords map[string]bool) bool { // We can safely check if the string starts with a specific string, as we // have already removed all leading spaces, and there are no keywords that - // start with the same substring as one of the DDL keywords. - for ddl := range ddlStatements { - if len(query) >= len(ddl) && strings.EqualFold(query[:len(ddl)], ddl) { - return true, nil + // start with the same substring as one of the keywords. + for keyword := range keywords { + if len(query) >= len(keyword) && strings.EqualFold(query[:len(keyword)], keyword) { + return true } } - return false, nil + return false } // clientSideStatements are loaded from the client_side_statements.json file. @@ -431,3 +452,30 @@ func parseClientSideStatement(c *conn, query string) (*executableClientSideState } return nil, nil } + +type statementType int + +const ( + statementTypeUnknown statementType = iota + statementTypeQuery + statementTypeDml + statementTypeDdl +) + +// detectStatementType returns the type of SQL statement based on the first +// keyword that is found in the SQL statement. +func detectStatementType(sql string) statementType { + sql, err := removeCommentsAndTrim(sql) + if err != nil { + return statementTypeUnknown + } + sql = removeStatementHint(sql) + if isQuery(sql) { + return statementTypeQuery + } else if isDml(sql) { + return statementTypeDml + } else if isDDL(sql) { + return statementTypeDdl + } + return statementTypeUnknown +} diff --git a/statement_parser_test.go b/statement_parser_test.go index 9aa0381a..e592748f 100644 --- a/statement_parser_test.go +++ b/statement_parser_test.go @@ -700,10 +700,10 @@ func FuzzFindParams(f *testing.F) { }) } -// note: isDDL function does not check validity of statement -// just that the statement begins with a DDL instruction. -// Other checking performed by database. -func TestIsDdl(t *testing.T) { +// Note: The detectStatementType function does not check validity of a statement, +// only whether the statement begins with a DDL instruction. +// Actual validity checks are performed by the database. +func TestStatementIsDdl(t *testing.T) { tests := []struct { name string input string @@ -872,10 +872,7 @@ func TestIsDdl(t *testing.T) { } for _, tc := range tests { - got, err := isDDL(tc.input) - if err != nil { - t.Error(err) - } + got := detectStatementType(tc.input) == statementTypeDdl if got != tc.want { t.Errorf("isDDL test failed, %s: wanted %t got %t.", tc.name, tc.want, got) } @@ -888,7 +885,7 @@ func FuzzIsDdl(f *testing.F) { } f.Fuzz(func(t *testing.T, input string) { - _, _ = isDDL(input) + _ = isDDL(input) }) } @@ -1017,6 +1014,68 @@ func TestFindParams_Errors(t *testing.T) { } } +func TestDetectStatementType(t *testing.T) { + tests := []struct { + input string + want statementType + }{ + { + input: "select 1", + want: statementTypeQuery, + }, + { + input: "from test", + want: statementTypeQuery, + }, + { + input: "with t as (select 1) select * from t", + want: statementTypeQuery, + }, + { + input: "GRAPH FinGraph\nMATCH (n)\nRETURN LABELS(n) AS label, n.id", + want: statementTypeQuery, + }, + { + input: "/* this is a comment */ -- this is also a comment\n @ { statement_hint_key=value } select 1", + want: statementTypeQuery, + }, + { + input: "update foo set bar=1 where true", + want: statementTypeDml, + }, + { + input: "insert into foo (id, value) select 1, 'test'", + want: statementTypeDml, + }, + { + input: "delete from foo where true", + want: statementTypeDml, + }, + { + input: "delete from foo where true then return *", + want: statementTypeDml, + }, + { + input: "create table foo (id int64) primary key (id)", + want: statementTypeDdl, + }, + { + input: "drop table if exists foo", + want: statementTypeDdl, + }, + { + input: "input from borkisland", + want: statementTypeUnknown, + }, + } + + for _, test := range tests { + if g, w := detectStatementType(test.input), test.want; g != w { + t.Errorf("statement type mismatch for %q\n Got: %v\nWant: %v", test.input, g, w) + } + } +} + var fuzzQuerySamples = []string{"", "SELECT 1;", "RUN BATCH", "ABORT BATCH", "Show variable Retry_Aborts_Internally", "@{JOIN_METHOD=HASH_JOIN SELECT * FROM PersonsTable"} func init() { diff --git a/stmt.go b/stmt.go index f4e7881c..d89a93a6 100644 --- a/stmt.go +++ b/stmt.go @@ -24,9 +24,10 @@ import ( ) type stmt struct { - conn *conn - numArgs int - query string + conn *conn + numArgs int + query string + statementType statementType } func (s *stmt) Close() error { @@ -59,7 +60,21 @@ func (s *stmt) QueryContext(ctx context.Context, args []driver.NamedValue) (driv if s.conn.tx != nil { it = s.conn.tx.Query(ctx, ss) } else { - it = &readOnlyRowIterator{s.conn.client.Single().WithTimestampBound(s.conn.readOnlyStaleness).Query(ctx, ss)} + if s.statementType == statementTypeUnknown { + s.statementType = detectStatementType(s.query) + } + if s.statementType == statementTypeDml { + // Use a read/write transaction to execute the statement. + it, _, err = s.conn.execSingleQueryTransactional(ctx, s.conn.client, ss, s.conn.createTransactionOptions()) + if err != nil { + return nil, err + } + } else { + // The statement was either detected as being a query, or potentially not recognized at all. + // In that case, just default to using a single-use read-only transaction and let Spanner + // return an error if the statement is not suited for that type of transaction. + it = &readOnlyRowIterator{s.conn.client.Single().WithTimestampBound(s.conn.readOnlyStaleness).Query(ctx, ss)} + } } return &rows{it: it}, nil }