Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: DML with THEN RETURN used read-only transaction #339

Merged
merged 2 commits into from
Jan 15, 2025
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
102 changes: 85 additions & 17 deletions driver.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ import (
adminpb "cloud.google.com/go/spanner/admin/database/apiv1/databasepb"
"cloud.google.com/go/spanner/apiv1/spannerpb"
"github.com/googleapis/gax-go/v2"
"google.golang.org/api/iterator"
"google.golang.org/api/option"
"google.golang.org/grpc"
"google.golang.org/grpc/codes"
Expand Down Expand Up @@ -323,14 +324,15 @@ func openDriverConn(ctx context.Context, c *connector) (driver.Conn, error) {
}

return &conn{
connector: c,
client: c.client,
adminClient: c.adminClient,
database: databaseName,
retryAborts: c.retryAbortsInternally,
execSingleQuery: queryInSingleUse,
execSingleDMLTransactional: execInNewRWTransaction,
execSingleDMLPartitioned: execAsPartitionedDML,
connector: c,
client: c.client,
adminClient: c.adminClient,
database: databaseName,
retryAborts: c.retryAbortsInternally,
execSingleQuery: queryInSingleUse,
execSingleQueryTransactional: queryInNewRWTransaction,
execSingleDMLTransactional: execInNewRWTransaction,
execSingleDMLPartitioned: execAsPartitionedDML,
}, nil
}

Expand Down Expand Up @@ -638,9 +640,10 @@ type conn struct {
database string
retryAborts bool

execSingleQuery func(ctx context.Context, c *spanner.Client, statement spanner.Statement, bound spanner.TimestampBound) *spanner.RowIterator
execSingleDMLTransactional func(ctx context.Context, c *spanner.Client, statement spanner.Statement, transactionOptions spanner.TransactionOptions) (int64, time.Time, error)
execSingleDMLPartitioned func(ctx context.Context, c *spanner.Client, statement spanner.Statement, options spanner.QueryOptions) (int64, error)
execSingleQuery func(ctx context.Context, c *spanner.Client, statement spanner.Statement, bound spanner.TimestampBound) *spanner.RowIterator
execSingleQueryTransactional func(ctx context.Context, c *spanner.Client, statement spanner.Statement, transactionOptions spanner.TransactionOptions) (rowIterator, time.Time, error)
execSingleDMLTransactional func(ctx context.Context, c *spanner.Client, statement spanner.Statement, transactionOptions spanner.TransactionOptions) (int64, time.Time, error)
execSingleDMLPartitioned func(ctx context.Context, c *spanner.Client, statement spanner.Statement, options spanner.QueryOptions) (int64, error)

// batch is the currently active DDL or DML batch on this connection.
batch *batch
Expand Down Expand Up @@ -1130,7 +1133,21 @@ func (c *conn) QueryContext(ctx context.Context, query string, args []driver.Nam
}
var iter rowIterator
if c.tx == nil {
iter = &readOnlyRowIterator{c.execSingleQuery(ctx, c.client, stmt, c.readOnlyStaleness)}
statementType := detectStatementType(query)
if statementType == statementTypeDml {
// Use a read/write transaction to execute the statement.
var commitTs time.Time
iter, commitTs, err = c.execSingleQueryTransactional(ctx, c.client, stmt, c.createTransactionOptions())
if err != nil {
return nil, err
}
c.commitTs = &commitTs
} else {
// The statement was either detected as being a query, or potentially not recognized at all.
// In that case, just default to using a single-use read-only transaction and let Spanner
// return an error if the statement is not suited for that type of transaction.
iter = &readOnlyRowIterator{c.execSingleQuery(ctx, c.client, stmt, c.readOnlyStaleness)}
}
} else {
iter = c.tx.Query(ctx, stmt)
}
Expand All @@ -1149,12 +1166,9 @@ func (c *conn) ExecContext(ctx context.Context, query string, args []driver.Name
// Clear the commit timestamp of this connection before we execute the statement.
c.commitTs = nil

statementType := detectStatementType(query)
// Use admin API if DDL statement is provided.
isDDL, err := isDDL(query)
if err != nil {
return nil, err
}
if isDDL {
if statementType == statementTypeDdl {
// Spanner does not support DDL in transactions, and although it is technically possible to execute DDL
// statements while a transaction is active, we return an error to avoid any confusion whether the DDL
// statement is executed as part of the active transaction or not.
Expand Down Expand Up @@ -1312,6 +1326,60 @@ func queryInSingleUse(ctx context.Context, c *spanner.Client, statement spanner.
return c.Single().WithTimestampBound(tb).Query(ctx, statement)
}

type wrappedRowIterator struct {
*spanner.RowIterator

noRows bool
firstRow *spanner.Row
}

func (ri *wrappedRowIterator) Next() (*spanner.Row, error) {
if ri.noRows {
return nil, iterator.Done
}
if ri.firstRow != nil {
defer func() { ri.firstRow = nil }()
return ri.firstRow, nil
}
return ri.RowIterator.Next()
}

func (ri *wrappedRowIterator) Stop() {
ri.RowIterator.Stop()
}

func (ri *wrappedRowIterator) Metadata() *spannerpb.ResultSetMetadata {
return ri.RowIterator.Metadata
}

func queryInNewRWTransaction(ctx context.Context, c *spanner.Client, statement spanner.Statement, options spanner.TransactionOptions) (rowIterator, time.Time, error) {
var result *wrappedRowIterator
fn := func(ctx context.Context, tx *spanner.ReadWriteTransaction) error {
it := tx.Query(ctx, statement)
row, err := it.Next()
if err == iterator.Done {
result = &wrappedRowIterator{
RowIterator: it,
noRows: true,
}
} else if err != nil {
it.Stop()
return err
} else {
result = &wrappedRowIterator{
RowIterator: it,
firstRow: row,
}
}
return nil
}
resp, err := c.ReadWriteTransactionWithOptions(ctx, fn, options)
if err != nil {
return nil, time.Time{}, err
}
return result, resp.CommitTs, nil
}

func execInNewRWTransaction(ctx context.Context, c *spanner.Client, statement spanner.Statement, options spanner.TransactionOptions) (int64, time.Time, error) {
var rowsAffected int64
fn := func(ctx context.Context, tx *spanner.ReadWriteTransaction) error {
Expand Down
59 changes: 59 additions & 0 deletions driver_with_mockserver_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1222,6 +1222,65 @@ func TestQueryWithMissingPositionalParameter(t *testing.T) {
}
}

func TestDmlReturningInAutocommit(t *testing.T) {
t.Parallel()

ctx := context.Background()
db, server, teardown := setupTestDBConnection(t)
defer teardown()

s := "insert into users (id, name) values (@id, @name) then return id"
_ = server.TestSpanner.PutStatementResult(
s,
&testutil.StatementResult{
Type: testutil.StatementResultResultSet,
ResultSet: testutil.CreateSelect1ResultSet(),
},
)

for _, prepare := range []bool{false, true} {
var rows *sql.Rows
var err error
if prepare {
var stmt *sql.Stmt
stmt, err = db.PrepareContext(ctx, s)
if err != nil {
t.Fatal(err)
}
rows, err = stmt.QueryContext(ctx, sql.Named("id", 1), sql.Named("name", "bar"))
} else {
rows, err = db.QueryContext(ctx, s, sql.Named("id", 1), sql.Named("name", "bar"))
}
if err != nil {
t.Fatal(err)
}
if !rows.Next() {
t.Fatal("missing row")
}
var id int
if err := rows.Scan(&id); err != nil {
t.Fatal(err)
}
if g, w := id, 1; g != w {
t.Fatalf("id mismatch\n Got: %v\nWant: %v", g, w)
}
if rows.Next() {
t.Fatal("got more rows than expected")
}

// Verify that a read/write transaction was used.
requests := drainRequestsFromServer(server.TestSpanner)
sqlRequests := requestsOfType(requests, reflect.TypeOf(&sppb.ExecuteSqlRequest{}))
if g, w := len(sqlRequests), 1; g != w {
t.Fatalf("sql requests count mismatch\nGot: %v\nWant: %v", g, w)
}
commitRequests := requestsOfType(requests, reflect.TypeOf(&sppb.CommitRequest{}))
if g, w := len(commitRequests), 1; g != w {
t.Fatalf("commit requests count mismatch\nGot: %v\nWant: %v", g, w)
}
}
}

func TestDdlInAutocommit(t *testing.T) {
t.Parallel()

Expand Down
26 changes: 23 additions & 3 deletions integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -876,9 +876,29 @@ func TestExecContextDml(t *testing.T) {
}

// Clear table for next test.
_, err = db.ExecContext(ctx, `DELETE FROM TestExecContextDml WHERE true`)
if (err != nil) && (!tc.then.wantError) {
t.Errorf("%s: unexpected query error: %v", tc.name, err)
it, err := db.QueryContext(ctx, `DELETE FROM TestExecContextDml WHERE true THEN RETURN *`)
if err != nil {
if !tc.then.wantError {
t.Errorf("%s: unexpected query error: %v", tc.name, err)
}
} else {
if cols, err := it.Columns(); err != nil {
t.Errorf("%s: unexpected Columns() error: %v", tc.name, err)
} else {
if g, w := len(cols), 6; g != w {
t.Errorf("%s: column number mismatch\n Got: %v\nWant: %v", tc.name, g, w)
}
if !cmp.Equal(cols, []string{"key", "testString", "testBytes", "testInt", "testFloat", "testBool"}) {
t.Errorf("%s: column names mismatch: %v", tc.name, cols)
}
}
for it.Next() {
if err := it.Err(); err != nil {
t.Errorf("%s: unexpected iterator error: %v", tc.name, err)
break
}
}
it.Close()
}
}
}
Expand Down
68 changes: 58 additions & 10 deletions statement_parser.go
Original file line number Diff line number Diff line change
Expand Up @@ -161,6 +161,11 @@ func removeCommentsAndTrim(sql string) (string, error) {
// Removes any statement hints at the beginning of the statement.
// It assumes that any comments have already been removed.
func removeStatementHint(sql string) string {
// Return quickly if the statement does not start with a hint.
if len(sql) < 2 || sql[0] != '@' {
return sql
}

// Valid statement hints at the beginning of a query statement can only contain a fixed set of
// possible values. Although it is possible to add a @{FORCE_INDEX=...} as a statement hint, the
// only allowed value is _BASE_TABLE. This means that we can safely assume that the statement
Expand Down Expand Up @@ -303,20 +308,36 @@ func findParams(positionalParamChar rune, sql string) (string, []string, error)
}

// isDDL returns true if the given sql string is a DDL statement.
func isDDL(query string) (bool, error) {
query, err := removeCommentsAndTrim(query)
if err != nil {
return false, err
}
// This function assumes that any comments and hints at the start
// of the sql string have been removed.
func isDDL(query string) bool {
return isStatementType(query, ddlStatements)
}

// isDml returns true if the given sql string is a Dml statement.
// This function assumes that any comments and hints at the start
// of the sql string have been removed.
func isDml(query string) bool {
return isStatementType(query, dmlStatements)
}

// isQuery returns true if the given sql string is a SELECT statement.
// This function assumes that any comments and hints at the start
// of the sql string have been removed.
func isQuery(query string) bool {
return isStatementType(query, selectStatements)
}

func isStatementType(query string, keywords map[string]bool) bool {
// We can safely check if the string starts with a specific string, as we
// have already removed all leading spaces, and there are no keywords that
// start with the same substring as one of the DDL keywords.
for ddl := range ddlStatements {
if len(query) >= len(ddl) && strings.EqualFold(query[:len(ddl)], ddl) {
return true, nil
// start with the same substring as one of the keywords.
for keyword := range keywords {
if len(query) >= len(keyword) && strings.EqualFold(query[:len(keyword)], keyword) {
return true
}
}
return false, nil
return false
}

// clientSideStatements are loaded from the client_side_statements.json file.
Expand Down Expand Up @@ -431,3 +452,30 @@ func parseClientSideStatement(c *conn, query string) (*executableClientSideState
}
return nil, nil
}

type statementType int

const (
statementTypeUnknown statementType = iota
statementTypeQuery
statementTypeDml
statementTypeDdl
)

// detectStatementType returns the type of SQL statement based on the first
// keyword that is found in the SQL statement.
func detectStatementType(sql string) statementType {
sql, err := removeCommentsAndTrim(sql)
if err != nil {
return statementTypeUnknown
}
sql = removeStatementHint(sql)
if isQuery(sql) {
return statementTypeQuery
} else if isDml(sql) {
return statementTypeDml
} else if isDDL(sql) {
return statementTypeDdl
}
return statementTypeUnknown
}
Loading
Loading