Skip to content

Commit

Permalink
fix: DML with THEN RETURN used read-only transaction (#339)
Browse files Browse the repository at this point in the history
DML statements with a THEN RETURN clause that were executed outside an explicit
transaction would use a read-only transaction if the application called QueryContext.

Fixes #235
  • Loading branch information
olavloite authored Jan 15, 2025
1 parent 255d8e0 commit ae36d4c
Show file tree
Hide file tree
Showing 6 changed files with 313 additions and 44 deletions.
102 changes: 85 additions & 17 deletions driver.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ import (
adminpb "cloud.google.com/go/spanner/admin/database/apiv1/databasepb"
"cloud.google.com/go/spanner/apiv1/spannerpb"
"github.com/googleapis/gax-go/v2"
"google.golang.org/api/iterator"
"google.golang.org/api/option"
"google.golang.org/grpc"
"google.golang.org/grpc/codes"
Expand Down Expand Up @@ -347,14 +348,15 @@ func openDriverConn(ctx context.Context, c *connector) (driver.Conn, error) {
}

return &conn{
connector: c,
client: c.client,
adminClient: c.adminClient,
database: databaseName,
retryAborts: c.retryAbortsInternally,
execSingleQuery: queryInSingleUse,
execSingleDMLTransactional: execInNewRWTransaction,
execSingleDMLPartitioned: execAsPartitionedDML,
connector: c,
client: c.client,
adminClient: c.adminClient,
database: databaseName,
retryAborts: c.retryAbortsInternally,
execSingleQuery: queryInSingleUse,
execSingleQueryTransactional: queryInNewRWTransaction,
execSingleDMLTransactional: execInNewRWTransaction,
execSingleDMLPartitioned: execAsPartitionedDML,
}, nil
}

Expand Down Expand Up @@ -662,9 +664,10 @@ type conn struct {
database string
retryAborts bool

execSingleQuery func(ctx context.Context, c *spanner.Client, statement spanner.Statement, bound spanner.TimestampBound) *spanner.RowIterator
execSingleDMLTransactional func(ctx context.Context, c *spanner.Client, statement spanner.Statement, transactionOptions spanner.TransactionOptions) (int64, time.Time, error)
execSingleDMLPartitioned func(ctx context.Context, c *spanner.Client, statement spanner.Statement, options spanner.QueryOptions) (int64, error)
execSingleQuery func(ctx context.Context, c *spanner.Client, statement spanner.Statement, bound spanner.TimestampBound) *spanner.RowIterator
execSingleQueryTransactional func(ctx context.Context, c *spanner.Client, statement spanner.Statement, transactionOptions spanner.TransactionOptions) (rowIterator, time.Time, error)
execSingleDMLTransactional func(ctx context.Context, c *spanner.Client, statement spanner.Statement, transactionOptions spanner.TransactionOptions) (int64, time.Time, error)
execSingleDMLPartitioned func(ctx context.Context, c *spanner.Client, statement spanner.Statement, options spanner.QueryOptions) (int64, error)

// batch is the currently active DDL or DML batch on this connection.
batch *batch
Expand Down Expand Up @@ -1166,7 +1169,21 @@ func (c *conn) QueryContext(ctx context.Context, query string, args []driver.Nam
}
var iter rowIterator
if c.tx == nil {
iter = &readOnlyRowIterator{c.execSingleQuery(ctx, c.client, stmt, c.readOnlyStaleness)}
statementType := detectStatementType(query)
if statementType == statementTypeDml {
// Use a read/write transaction to execute the statement.
var commitTs time.Time
iter, commitTs, err = c.execSingleQueryTransactional(ctx, c.client, stmt, c.createTransactionOptions())
if err != nil {
return nil, err
}
c.commitTs = &commitTs
} else {
// The statement was either detected as being a query, or potentially not recognized at all.
// In that case, just default to using a single-use read-only transaction and let Spanner
// return an error if the statement is not suited for that type of transaction.
iter = &readOnlyRowIterator{c.execSingleQuery(ctx, c.client, stmt, c.readOnlyStaleness)}
}
} else {
iter = c.tx.Query(ctx, stmt)
}
Expand All @@ -1188,12 +1205,9 @@ func (c *conn) ExecContext(ctx context.Context, query string, args []driver.Name
// Clear the commit timestamp of this connection before we execute the statement.
c.commitTs = nil

statementType := detectStatementType(query)
// Use admin API if DDL statement is provided.
isDDL, err := isDDL(query)
if err != nil {
return nil, err
}
if isDDL {
if statementType == statementTypeDdl {
// Spanner does not support DDL in transactions, and although it is technically possible to execute DDL
// statements while a transaction is active, we return an error to avoid any confusion whether the DDL
// statement is executed as part of the active transaction or not.
Expand Down Expand Up @@ -1357,6 +1371,60 @@ func queryInSingleUse(ctx context.Context, c *spanner.Client, statement spanner.
return c.Single().WithTimestampBound(tb).Query(ctx, statement)
}

// wrappedRowIterator wraps a *spanner.RowIterator on which the first call to
// Next has already been made (see queryInNewRWTransaction). It remembers the
// outcome of that first call so that it can be replayed to the consumer.
type wrappedRowIterator struct {
*spanner.RowIterator

// noRows is set when the first call to Next on the wrapped iterator returned
// iterator.Done. Next then always returns iterator.Done.
noRows bool
// firstRow holds the row that was returned by the first call to Next on the
// wrapped iterator. Next returns it once and then resets it to nil.
firstRow *spanner.Row
}

// Next returns the next row of the result. An iterator that was created
// without any rows always returns iterator.Done. Otherwise the buffered first
// row is handed out exactly once, after which all calls are forwarded to the
// wrapped RowIterator.
func (ri *wrappedRowIterator) Next() (*spanner.Row, error) {
	switch {
	case ri.noRows:
		return nil, iterator.Done
	case ri.firstRow != nil:
		// Hand out the buffered row and forget it, so the following call
		// continues with the wrapped iterator.
		buffered := ri.firstRow
		ri.firstRow = nil
		return buffered, nil
	default:
		return ri.RowIterator.Next()
	}
}

// Stop calls Stop on the wrapped RowIterator.
func (ri *wrappedRowIterator) Stop() {
ri.RowIterator.Stop()
}

// Metadata returns the value of the Metadata field of the wrapped RowIterator.
func (ri *wrappedRowIterator) Metadata() *spannerpb.ResultSetMetadata {
return ri.RowIterator.Metadata
}

// queryInNewRWTransaction executes statement as a query in a new read/write
// transaction on c, using the given transaction options.
//
// The first row is fetched inside the transaction function, so that an error
// from executing the statement is returned from that function instead of
// being discovered by the caller after the transaction has finished.
//
// On success it returns an iterator that first yields the already fetched row
// (if any) and then the remaining rows, together with the commit timestamp of
// the transaction. On failure the returned iterator is nil and any iterator
// that was created by the transaction function has been stopped.
func queryInNewRWTransaction(ctx context.Context, c *spanner.Client, statement spanner.Statement, options spanner.TransactionOptions) (rowIterator, time.Time, error) {
	var result *wrappedRowIterator
	// stopResult stops and forgets the iterator of an earlier attempt, if any.
	stopResult := func() {
		if result != nil {
			result.Stop()
			result = nil
		}
	}
	fn := func(ctx context.Context, tx *spanner.ReadWriteTransaction) error {
		// The Spanner client re-runs this function when the transaction is
		// aborted. Release the iterator of the previous attempt so it is not
		// leaked when result is overwritten below.
		stopResult()
		it := tx.Query(ctx, statement)
		row, err := it.Next()
		switch {
		case err == iterator.Done:
			result = &wrappedRowIterator{
				RowIterator: it,
				noRows:      true,
			}
		case err != nil:
			it.Stop()
			return err
		default:
			result = &wrappedRowIterator{
				RowIterator: it,
				firstRow:    row,
			}
		}
		return nil
	}
	resp, err := c.ReadWriteTransactionWithOptions(ctx, fn, options)
	if err != nil {
		// The transaction function may have succeeded before the transaction
		// itself failed (e.g. during commit). The iterator is not returned to
		// the caller in that case, so it must be stopped here.
		stopResult()
		return nil, time.Time{}, err
	}
	return result, resp.CommitTs, nil
}

func execInNewRWTransaction(ctx context.Context, c *spanner.Client, statement spanner.Statement, options spanner.TransactionOptions) (int64, time.Time, error) {
var rowsAffected int64
fn := func(ctx context.Context, tx *spanner.ReadWriteTransaction) error {
Expand Down
59 changes: 59 additions & 0 deletions driver_with_mockserver_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1319,6 +1319,65 @@ func TestQueryWithMissingPositionalParameter(t *testing.T) {
}
}

// TestDmlReturningInAutocommit verifies that a DML statement with a THEN RETURN
// clause that is executed through QueryContext outside of an explicit
// transaction uses a read/write transaction: exactly one ExecuteSqlRequest and
// one CommitRequest must reach the mock server. The statement is executed both
// directly on the database and through a prepared statement.
func TestDmlReturningInAutocommit(t *testing.T) {
	t.Parallel()

	ctx := context.Background()
	db, server, teardown := setupTestDBConnection(t)
	defer teardown()

	s := "insert into users (id, name) values (@id, @name) then return id"
	if err := server.TestSpanner.PutStatementResult(
		s,
		&testutil.StatementResult{
			Type:      testutil.StatementResultResultSet,
			ResultSet: testutil.CreateSelect1ResultSet(),
		},
	); err != nil {
		t.Fatal(err)
	}

	for _, prepare := range []bool{false, true} {
		var (
			stmt *sql.Stmt
			rows *sql.Rows
			err  error
		)
		if prepare {
			stmt, err = db.PrepareContext(ctx, s)
			if err != nil {
				t.Fatal(err)
			}
			rows, err = stmt.QueryContext(ctx, sql.Named("id", 1), sql.Named("name", "bar"))
		} else {
			rows, err = db.QueryContext(ctx, s, sql.Named("id", 1), sql.Named("name", "bar"))
		}
		if err != nil {
			t.Fatal(err)
		}
		if !rows.Next() {
			// Next also returns false when iteration failed, so include the
			// iteration error (if any) in the failure message.
			t.Fatalf("missing row (rows.Err: %v)", rows.Err())
		}
		var id int
		if err := rows.Scan(&id); err != nil {
			t.Fatal(err)
		}
		if g, w := id, 1; g != w {
			t.Fatalf("id mismatch\n Got: %v\nWant: %v", g, w)
		}
		if rows.Next() {
			t.Fatal("got more rows than expected")
		}
		// A false return from Next can hide an iteration error.
		if err := rows.Err(); err != nil {
			t.Fatal(err)
		}
		if err := rows.Close(); err != nil {
			t.Fatal(err)
		}
		// Close the prepared statement explicitly; a defer would only run
		// when the whole test function returns.
		if stmt != nil {
			if err := stmt.Close(); err != nil {
				t.Fatal(err)
			}
		}

		// Verify that a read/write transaction was used.
		requests := drainRequestsFromServer(server.TestSpanner)
		sqlRequests := requestsOfType(requests, reflect.TypeOf(&sppb.ExecuteSqlRequest{}))
		if g, w := len(sqlRequests), 1; g != w {
			t.Fatalf("sql requests count mismatch\nGot: %v\nWant: %v", g, w)
		}
		commitRequests := requestsOfType(requests, reflect.TypeOf(&sppb.CommitRequest{}))
		if g, w := len(commitRequests), 1; g != w {
			t.Fatalf("commit requests count mismatch\nGot: %v\nWant: %v", g, w)
		}
	}
}

func TestDdlInAutocommit(t *testing.T) {
t.Parallel()

Expand Down
26 changes: 23 additions & 3 deletions integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -876,9 +876,29 @@ func TestExecContextDml(t *testing.T) {
}

// Clear table for next test.
_, err = db.ExecContext(ctx, `DELETE FROM TestExecContextDml WHERE true`)
if (err != nil) && (!tc.then.wantError) {
t.Errorf("%s: unexpected query error: %v", tc.name, err)
it, err := db.QueryContext(ctx, `DELETE FROM TestExecContextDml WHERE true THEN RETURN *`)
if err != nil {
if !tc.then.wantError {
t.Errorf("%s: unexpected query error: %v", tc.name, err)
}
} else {
if cols, err := it.Columns(); err != nil {
t.Errorf("%s: unexpected Columns() error: %v", tc.name, err)
} else {
if g, w := len(cols), 6; g != w {
t.Errorf("%s: column number mismatch\n Got: %v\nWant: %v", tc.name, g, w)
}
if !cmp.Equal(cols, []string{"key", "testString", "testBytes", "testInt", "testFloat", "testBool"}) {
t.Errorf("%s: column names mismatch: %v", tc.name, cols)
}
}
for it.Next() {
if err := it.Err(); err != nil {
t.Errorf("%s: unexpected iterator error: %v", tc.name, err)
break
}
}
it.Close()
}
}
}
Expand Down
68 changes: 58 additions & 10 deletions statement_parser.go
Original file line number Diff line number Diff line change
Expand Up @@ -161,6 +161,11 @@ func removeCommentsAndTrim(sql string) (string, error) {
// Removes any statement hints at the beginning of the statement.
// It assumes that any comments have already been removed.
func removeStatementHint(sql string) string {
// Return quickly if the statement does not start with a hint.
if len(sql) < 2 || sql[0] != '@' {
return sql
}

// Valid statement hints at the beginning of a query statement can only contain a fixed set of
// possible values. Although it is possible to add a @{FORCE_INDEX=...} as a statement hint, the
// only allowed value is _BASE_TABLE. This means that we can safely assume that the statement
Expand Down Expand Up @@ -303,20 +308,36 @@ func findParams(positionalParamChar rune, sql string) (string, []string, error)
}

// isDDL returns true if the given sql string is a DDL statement.
func isDDL(query string) (bool, error) {
query, err := removeCommentsAndTrim(query)
if err != nil {
return false, err
}
// This function assumes that any comments and hints at the start
// of the sql string have been removed.
// The check is delegated to isStatementType, which compares the start of
// the query case-insensitively with the keywords in ddlStatements.
func isDDL(query string) bool {
return isStatementType(query, ddlStatements)
}

// isDml reports whether the given sql string is a DML statement, i.e. whether
// it starts with one of the keywords in dmlStatements. The comparison is done
// by isStatementType and is case-insensitive.
// This function assumes that any comments and hints at the start
// of the sql string have been removed.
func isDml(query string) bool {
return isStatementType(query, dmlStatements)
}

// isQuery reports whether the given sql string is a query, i.e. whether it
// starts with one of the keywords in selectStatements. The comparison is done
// by isStatementType and is case-insensitive.
// This function assumes that any comments and hints at the start
// of the sql string have been removed.
func isQuery(query string) bool {
return isStatementType(query, selectStatements)
}

func isStatementType(query string, keywords map[string]bool) bool {
// We can safely check if the string starts with a specific string, as we
// have already removed all leading spaces, and there are no keywords that
// start with the same substring as one of the DDL keywords.
for ddl := range ddlStatements {
if len(query) >= len(ddl) && strings.EqualFold(query[:len(ddl)], ddl) {
return true, nil
// start with the same substring as one of the keywords.
for keyword := range keywords {
if len(query) >= len(keyword) && strings.EqualFold(query[:len(keyword)], keyword) {
return true
}
}
return false, nil
return false
}

// clientSideStatements are loaded from the client_side_statements.json file.
Expand Down Expand Up @@ -431,3 +452,30 @@ func parseClientSideStatement(c *conn, query string) (*executableClientSideState
}
return nil, nil
}

// statementType is the kind of SQL statement that detectStatementType
// recognized in a SQL string.
type statementType int

const (
// statementTypeUnknown is the zero value. detectStatementType returns it when
// the comments could not be removed from the SQL string, or when the string
// matched none of the query, DML or DDL keywords.
statementTypeUnknown statementType = iota
// statementTypeQuery is returned by detectStatementType when isQuery matches.
statementTypeQuery
// statementTypeDml is returned by detectStatementType when isDml matches.
statementTypeDml
// statementTypeDdl is returned by detectStatementType when isDDL matches.
statementTypeDdl
)

// detectStatementType classifies a SQL string as a query, DML or DDL statement
// by looking at the keyword it starts with, after comments and a leading
// statement hint have been stripped. It returns statementTypeUnknown when the
// comments cannot be removed or when no known keyword matches.
func detectStatementType(sql string) statementType {
	withoutComments, err := removeCommentsAndTrim(sql)
	if err != nil {
		return statementTypeUnknown
	}
	stripped := removeStatementHint(withoutComments)
	// The order of the checks is significant: queries first, then DML, then DDL.
	switch {
	case isQuery(stripped):
		return statementTypeQuery
	case isDml(stripped):
		return statementTypeDml
	case isDDL(stripped):
		return statementTypeDdl
	default:
		return statementTypeUnknown
	}
}
Loading

0 comments on commit ae36d4c

Please sign in to comment.