diff --git a/README.md b/README.md index b3a62d6f..f5f07771 100644 --- a/README.md +++ b/README.md @@ -86,6 +86,38 @@ tx, err := conn.BeginTx(ctx, &sql.TxOptions{ }) ``` +## Transaction Runner (Retry Transactions) + +Spanner can abort a read/write transaction if concurrent modifications are detected +that would violate the transaction consistency. When this happens, the driver will +return the `ErrAbortedDueToConcurrentModification` error. You can use the +`RunTransaction` function to let the driver automatically retry transactions that +are aborted by Spanner. + +```go +package sample + +import ( + "context" + "database/sql" + "fmt" + + _ "github.com/googleapis/go-sql-spanner" + spannerdriver "github.com/googleapis/go-sql-spanner" +) + +spannerdriver.RunTransaction(ctx, db, &sql.TxOptions{}, func(ctx context.Context, tx *sql.Tx) error { + row := tx.QueryRowContext(ctx, "select Name from Singers where SingerId=@id", 123) + var name string + if err := row.Scan(&name); err != nil { + return err + } + return nil +}) +``` + +See also the [transaction runner sample](./examples/run-transaction/main.go). + ## DDL Statements [DDL statements](https://cloud.google.com/spanner/docs/data-definition-language) diff --git a/driver.go b/driver.go index 5cd012d4..647dfc20 100644 --- a/driver.go +++ b/driver.go @@ -32,6 +32,7 @@ import ( adminapi "cloud.google.com/go/spanner/admin/database/apiv1" adminpb "cloud.google.com/go/spanner/admin/database/apiv1/databasepb" "cloud.google.com/go/spanner/apiv1/spannerpb" + "github.com/googleapis/gax-go/v2" "google.golang.org/api/option" "google.golang.org/grpc" "google.golang.org/grpc/codes" @@ -403,6 +404,103 @@ func (c *connector) closeClients() (err error) { return err } +// RunTransaction runs the given function in a transaction on the given database. +// If the connection is a connection to a Spanner database, the transaction will +// automatically be retried if the transaction is aborted by Spanner. Any other +// errors will be propagated to the caller and the transaction will be rolled +// back. The transaction will be committed if the supplied function did not +// return an error. +// +// If the connection is to a non-Spanner database, no retries will be attempted, +// and any error that occurs during the transaction will be propagated to the +// caller. +// +// The application should *NOT* call tx.Commit() or tx.Rollback(). This is done +// automatically by this function, depending on whether the transaction function +// returned an error or not. +// +// This function will never return ErrAbortedDueToConcurrentModification. +func RunTransaction(ctx context.Context, db *sql.DB, opts *sql.TxOptions, f func(ctx context.Context, tx *sql.Tx) error) error { + // Get a connection from the pool that we can use to run a transaction. + // Getting a connection here already makes sure that we can reserve this + // connection exclusively for the duration of this method. That again + // allows us to temporarily change the state of the connection (e.g. set + // the retryAborts flag to false). + conn, err := db.Conn(ctx) + if err != nil { + return err + } + defer conn.Close() + + // We don't need to keep track of a running checksum for retries when using + // this method, so we disable internal retries. + // Retries will instead be handled by the loop below. + origRetryAborts := false + var spannerConn SpannerConn + if err := conn.Raw(func(driverConn any) error { + var ok bool + spannerConn, ok = driverConn.(SpannerConn) + if !ok { + // It is not a Spanner connection, so just ignore and continue without any special handling. + return nil + } + origRetryAborts = spannerConn.RetryAbortsInternally() + return spannerConn.SetRetryAbortsInternally(false) + }); err != nil { + return err + } + // Reset the flag for internal retries after the transaction (if applicable). + if origRetryAborts { + defer func() { _ = spannerConn.SetRetryAbortsInternally(origRetryAborts) }() + } + + tx, err := conn.BeginTx(ctx, opts) + if err != nil { + return err + } + for { + err = f(ctx, tx) + if err == nil { + err = tx.Commit() + if err == nil { + return nil + } + } + // Rollback and return the error if: + // 1. The connection is not a Spanner connection. + // 2. Or the error code is not Aborted. + if spannerConn == nil || spanner.ErrCode(err) != codes.Aborted { + // We don't really need to call Rollback here if the error happened + // during the Commit. However, the SQL package treats this as a no-op + // and just returns an ErrTxDone if we do, so this is simpler than + // keeping track of where the error happened. + _ = tx.Rollback() + return err + } + + // The transaction was aborted by Spanner. + // Back off and retry the entire transaction. + if delay, ok := spanner.ExtractRetryDelay(err); ok { + err = gax.Sleep(ctx, delay) + if err != nil { + // We need to 'roll back' the transaction here to tell the sql + // package that there is no active transaction on the connection + // anymore. It does not actually roll back the transaction, as it + // has already been aborted by Spanner. + _ = tx.Rollback() + return err + } + } + + // TODO: Reset the existing transaction for retry instead of creating a new one. + _ = tx.Rollback() + tx, err = conn.BeginTx(ctx, opts) + if err != nil { + return err + } + } +} + // SpannerConn is the public interface for the raw Spanner connection for the // sql driver. This interface can be used with the db.Conn().Raw() method. type SpannerConn interface { diff --git a/driver_with_mockserver_test.go b/driver_with_mockserver_test.go index 1acf1d47..9b234ddd 100644 --- a/driver_with_mockserver_test.go +++ b/driver_with_mockserver_test.go @@ -1109,7 +1109,7 @@ func TestQueryWithDuplicateNamedParameter(t *testing.T) { defer teardown() s := "insert into users (id, name) values (@name, @name)" - server.TestSpanner.PutStatementResult(s, &testutil.StatementResult{ + _ = server.TestSpanner.PutStatementResult(s, &testutil.StatementResult{ Type: testutil.StatementResultUpdateCount, UpdateCount: 1, }) @@ -1139,7 +1139,7 @@ func TestQueryWithReusedNamedParameter(t *testing.T) { defer teardown() s := "insert into users (id, name) values (@name, @name)" - server.TestSpanner.PutStatementResult(s, &testutil.StatementResult{ + _ = server.TestSpanner.PutStatementResult(s, &testutil.StatementResult{ Type: testutil.StatementResultUpdateCount, UpdateCount: 1, }) @@ -1169,7 +1169,7 @@ func TestQueryWithReusedPositionalParameter(t *testing.T) { defer teardown() s := "insert into users (id, name) values (@name, @name)" - server.TestSpanner.PutStatementResult(s, &testutil.StatementResult{ + _ = server.TestSpanner.PutStatementResult(s, &testutil.StatementResult{ Type: testutil.StatementResultUpdateCount, UpdateCount: 1, }) @@ -1199,7 +1199,7 @@ func TestQueryWithMissingPositionalParameter(t *testing.T) { defer teardown() s := "insert into users (id, name) values (@name, @name)" - server.TestSpanner.PutStatementResult(s, &testutil.StatementResult{ + _ = server.TestSpanner.PutStatementResult(s, &testutil.StatementResult{ Type: testutil.StatementResultUpdateCount, UpdateCount: 1, }) @@ -1229,11 +1229,11 @@ func TestDdlInAutocommit(t *testing.T) { defer teardown() var expectedResponse = &emptypb.Empty{} - any, _ := anypb.New(expectedResponse) + anyMsg, _ := anypb.New(expectedResponse) server.TestDatabaseAdmin.SetResps([]proto.Message{ &longrunningpb.Operation{ Done: true, - Result: &longrunningpb.Operation_Response{Response: any}, + Result: &longrunningpb.Operation_Response{Response: anyMsg}, Name: "test-operation", }, }) @@ -1491,11 +1491,11 @@ func TestDdlBatch(t *testing.T) { defer teardown() var expectedResponse = &emptypb.Empty{} - any, _ := anypb.New(expectedResponse) + anyMsg, _ := anypb.New(expectedResponse) server.TestDatabaseAdmin.SetResps([]proto.Message{ &longrunningpb.Operation{ Done: true, - Result: &longrunningpb.Operation_Response{Response: any}, + Result: &longrunningpb.Operation_Response{Response: anyMsg}, Name: "test-operation", }, }) @@ -1667,7 +1667,7 @@ func TestPartitionedDml(t *testing.T) { t.Fatalf("could not set autocommit dml mode: %v", err) } - server.TestSpanner.PutStatementResult("DELETE FROM Foo WHERE TRUE", &testutil.StatementResult{ + _ = server.TestSpanner.PutStatementResult("DELETE FROM Foo WHERE TRUE", &testutil.StatementResult{ Type: testutil.StatementResultUpdateCount, UpdateCount: 200, }) @@ -1721,11 +1721,11 @@ func TestAutocommitBatchDml(t *testing.T) { if _, err := c.ExecContext(ctx, "START BATCH DML"); err != nil { t.Fatalf("could not start a DML batch: %v", err) } - server.TestSpanner.PutStatementResult("INSERT INTO Foo (Id, Val) VALUES (1, 'One')", &testutil.StatementResult{ + _ = server.TestSpanner.PutStatementResult("INSERT INTO Foo (Id, Val) VALUES (1, 'One')", &testutil.StatementResult{ Type: testutil.StatementResultUpdateCount, UpdateCount: 1, }) - server.TestSpanner.PutStatementResult("INSERT INTO Foo (Id, Val) VALUES (2, 'Two')", &testutil.StatementResult{ + _ = server.TestSpanner.PutStatementResult("INSERT INTO Foo (Id, Val) VALUES (2, 'Two')", &testutil.StatementResult{ Type: testutil.StatementResultUpdateCount, UpdateCount: 1, }) @@ -1805,11 +1805,11 @@ func TestTransactionBatchDml(t *testing.T) { if _, err := tx.ExecContext(ctx, "START BATCH DML"); err != nil { t.Fatalf("could not start a DML batch: %v", err) } - server.TestSpanner.PutStatementResult("INSERT INTO Foo (Id, Val) VALUES (1, 'One')", &testutil.StatementResult{ + _ = server.TestSpanner.PutStatementResult("INSERT INTO Foo (Id, Val) VALUES (1, 'One')", &testutil.StatementResult{ Type: testutil.StatementResultUpdateCount, UpdateCount: 1, }) - server.TestSpanner.PutStatementResult("INSERT INTO Foo (Id, Val) VALUES (2, 'Two')", &testutil.StatementResult{ + _ = server.TestSpanner.PutStatementResult("INSERT INTO Foo (Id, Val) VALUES (2, 'Two')", &testutil.StatementResult{ Type: testutil.StatementResultUpdateCount, UpdateCount: 1, }) @@ -1875,7 +1875,7 @@ func TestTransactionBatchDml(t *testing.T) { // Executing another DML statement on the same transaction now that the batch has been // executed should cause the statement to be sent to Spanner. - server.TestSpanner.PutStatementResult("INSERT INTO Foo (Id, Val) VALUES (3, 'Three')", &testutil.StatementResult{ + _ = server.TestSpanner.PutStatementResult("INSERT INTO Foo (Id, Val) VALUES (3, 'Three')", &testutil.StatementResult{ Type: testutil.StatementResultUpdateCount, UpdateCount: 1, }) @@ -2311,10 +2311,10 @@ func TestExcludeTxnFromChangeStreams_AutoCommitBatchDml(t *testing.T) { t.Fatalf("failed to get a connection: %v", err) } - conn.ExecContext(ctx, "set exclude_txn_from_change_streams = true") - conn.ExecContext(ctx, "start batch dml") - conn.ExecContext(ctx, testutil.UpdateBarSetFoo) - conn.ExecContext(ctx, "run batch") + _, _ = conn.ExecContext(ctx, "set exclude_txn_from_change_streams = true") + _, _ = conn.ExecContext(ctx, "start batch dml") + _, _ = conn.ExecContext(ctx, testutil.UpdateBarSetFoo) + _, _ = conn.ExecContext(ctx, "run batch") requests := drainRequestsFromServer(server.TestSpanner) batchRequests := requestsOfType(requests, reflect.TypeOf(&sppb.ExecuteBatchDmlRequest{})) if g, w := len(batchRequests), 1; g != w { @@ -2380,13 +2380,13 @@ func TestExcludeTxnFromChangeStreams_Transaction(t *testing.T) { if g, w := exclude, false; g != w { t.Fatalf("exclude_txn_from_change_streams mismatch\n Got: %v\nWant: %v", g, w) } - conn.ExecContext(ctx, "set exclude_txn_from_change_streams = true") + _, _ = conn.ExecContext(ctx, "set exclude_txn_from_change_streams = true") tx, err := conn.BeginTx(ctx, &sql.TxOptions{}) if err != nil { t.Fatal(err) } - conn.ExecContext(ctx, testutil.UpdateBarSetFoo) - tx.Commit() + _, _ = conn.ExecContext(ctx, testutil.UpdateBarSetFoo) + _ = tx.Commit() requests := drainRequestsFromServer(server.TestSpanner) beginRequests := requestsOfType(requests, reflect.TypeOf(&sppb.BeginTransactionRequest{})) @@ -2519,6 +2519,364 @@ func TestCannotReuseClosedConnector(t *testing.T) { } } +func TestRunTransaction(t *testing.T) { + t.Parallel() + + ctx := context.Background() + db, server, teardown := setupTestDBConnection(t) + defer teardown() + + err := RunTransaction(ctx, db, nil, func(ctx context.Context, tx *sql.Tx) error { + rows, err := tx.Query(testutil.SelectFooFromBar) + if err != nil { + return err + } + defer rows.Close() + // Verify that internal retries are disabled during RunTransaction + row := tx.QueryRow("show variable retry_aborts_internally") + var retry bool + if err := row.Scan(&retry); err != nil { + return err + } + if retry { + return fmt.Errorf("internal retries should be disabled during RunTransaction") + } + + for want := int64(1); rows.Next(); want++ { + cols, err := rows.Columns() + if err != nil { + return err + } + if !cmp.Equal(cols, []string{"FOO"}) { + return fmt.Errorf("cols mismatch\nGot: %v\nWant: %v", cols, []string{"FOO"}) + } + var got int64 + err = rows.Scan(&got) + if err != nil { + return err + } + if got != want { + return fmt.Errorf("value mismatch\nGot: %v\nWant: %v", got, want) + } + } + if err := rows.Err(); err != nil { + return err + } + return nil + }) + if err != nil { + t.Fatal(err) + } + // Verify that internal retries are enabled again after RunTransaction + row := db.QueryRow("show variable retry_aborts_internally") + var retry bool + if err := row.Scan(&retry); err != nil { + t.Fatal(err) + } + if !retry { + t.Fatal("internal retries should be enabled after RunTransaction") + } + + requests := drainRequestsFromServer(server.TestSpanner) + sqlRequests := requestsOfType(requests, reflect.TypeOf(&sppb.ExecuteSqlRequest{})) + if g, w := len(sqlRequests), 1; g != w { + t.Fatalf("ExecuteSqlRequests count mismatch\nGot: %v\nWant: %v", g, w) + } + req := sqlRequests[0].(*sppb.ExecuteSqlRequest) + if req.Transaction == nil { + t.Fatalf("missing transaction for ExecuteSqlRequest") + } + if req.Transaction.GetId() == nil { + t.Fatalf("missing id selector for ExecuteSqlRequest") + } + commitRequests := requestsOfType(requests, reflect.TypeOf(&sppb.CommitRequest{})) + if g, w := len(commitRequests), 1; g != w { + t.Fatalf("commit requests count mismatch\nGot: %v\nWant: %v", g, w) + } + commitReq := commitRequests[0].(*sppb.CommitRequest) + if c, e := commitReq.GetTransactionId(), req.Transaction.GetId(); !cmp.Equal(c, e) { + t.Fatalf("transaction id mismatch\nCommit: %c\nExecute: %v", c, e) + } +} + +func TestRunTransactionCommitAborted(t *testing.T) { + t.Parallel() + + ctx := context.Background() + db, server, teardown := setupTestDBConnection(t) + defer teardown() + + attempts := 0 + err := RunTransaction(ctx, db, nil, func(ctx context.Context, tx *sql.Tx) error { + attempts++ + rows, err := tx.Query(testutil.SelectFooFromBar) + if err != nil { + return err + } + defer rows.Close() + + for want := int64(1); rows.Next(); want++ { + cols, err := rows.Columns() + if err != nil { + return err + } + if !cmp.Equal(cols, []string{"FOO"}) { + return fmt.Errorf("cols mismatch\nGot: %v\nWant: %v", cols, []string{"FOO"}) + } + var got int64 + err = rows.Scan(&got) + if err != nil { + return err + } + if got != want { + return fmt.Errorf("value mismatch\nGot: %v\nWant: %v", got, want) + } + } + if err := rows.Err(); err != nil { + return err + } + // Instruct the mock server to abort the transaction. + if attempts == 1 { + server.TestSpanner.PutExecutionTime(testutil.MethodCommitTransaction, testutil.SimulatedExecutionTime{ + Errors: []error{gstatus.Error(codes.Aborted, "Aborted")}, + }) + } + return nil + }) + if err != nil { + t.Fatal(err) + } + + requests := drainRequestsFromServer(server.TestSpanner) + sqlRequests := requestsOfType(requests, reflect.TypeOf(&sppb.ExecuteSqlRequest{})) + // There should be two requests, as the transaction is aborted and retried. + if g, w := len(sqlRequests), 2; g != w { + t.Fatalf("ExecuteSqlRequests count mismatch\nGot: %v\nWant: %v", g, w) + } + commitRequests := requestsOfType(requests, reflect.TypeOf(&sppb.CommitRequest{})) + if g, w := len(commitRequests), 2; g != w { + t.Fatalf("commit requests count mismatch\nGot: %v\nWant: %v", g, w) + } + for i := 0; i < 2; i++ { + req := sqlRequests[i].(*sppb.ExecuteSqlRequest) + if req.Transaction == nil { + t.Fatalf("missing transaction for ExecuteSqlRequest") + } + if req.Transaction.GetId() == nil { + t.Fatalf("missing id selector for ExecuteSqlRequest") + } + commitReq := commitRequests[i].(*sppb.CommitRequest) + if c, e := commitReq.GetTransactionId(), req.Transaction.GetId(); !cmp.Equal(c, e) { + t.Fatalf("transaction id mismatch\nCommit: %c\nExecute: %v", c, e) + } + } +} + +func TestRunTransactionQueryAborted(t *testing.T) { + t.Parallel() + + ctx := context.Background() + db, server, teardown := setupTestDBConnection(t) + defer teardown() + + attempts := 0 + err := RunTransaction(ctx, db, nil, func(ctx context.Context, tx *sql.Tx) error { + attempts++ + // Instruct the mock server to abort the transaction. + if attempts == 1 { + server.TestSpanner.PutExecutionTime(testutil.MethodExecuteStreamingSql, testutil.SimulatedExecutionTime{ + Errors: []error{gstatus.Error(codes.Aborted, "Aborted")}, + }) + } + rows, err := tx.Query(testutil.SelectFooFromBar) + if err != nil { + return err + } + defer rows.Close() + + for want := int64(1); rows.Next(); want++ { + cols, err := rows.Columns() + if err != nil { + return err + } + if !cmp.Equal(cols, []string{"FOO"}) { + return fmt.Errorf("cols mismatch\nGot: %v\nWant: %v", cols, []string{"FOO"}) + } + var got int64 + err = rows.Scan(&got) + if err != nil { + return err + } + if got != want { + return fmt.Errorf("value mismatch\nGot: %v\nWant: %v", got, want) + } + } + if err := rows.Err(); err != nil { + return err + } + return nil + }) + if err != nil { + t.Fatal(err) + } + + requests := drainRequestsFromServer(server.TestSpanner) + sqlRequests := requestsOfType(requests, reflect.TypeOf(&sppb.ExecuteSqlRequest{})) + // There should be two ExecuteSql requests, as the transaction is aborted and retried. + if g, w := len(sqlRequests), 2; g != w { + t.Fatalf("ExecuteSqlRequests count mismatch\nGot: %v\nWant: %v", g, w) + } + commitRequests := requestsOfType(requests, reflect.TypeOf(&sppb.CommitRequest{})) + // There should be only 1 CommitRequest, as the transaction is aborted before + // the first commit attempt. + if g, w := len(commitRequests), 1; g != w { + t.Fatalf("commit requests count mismatch\nGot: %v\nWant: %v", g, w) + } + req := sqlRequests[1].(*sppb.ExecuteSqlRequest) + if req.Transaction == nil { + t.Fatalf("missing transaction for ExecuteSqlRequest") + } + if req.Transaction.GetId() == nil { + t.Fatalf("missing id selector for ExecuteSqlRequest") + } + commitReq := commitRequests[0].(*sppb.CommitRequest) + if c, e := commitReq.GetTransactionId(), req.Transaction.GetId(); !cmp.Equal(c, e) { + t.Fatalf("transaction id mismatch\nCommit: %c\nExecute: %v", c, e) + } +} + +func TestRunTransactionQueryError(t *testing.T) { + t.Parallel() + + ctx := context.Background() + db, server, teardown := setupTestDBConnection(t) + defer teardown() + + err := RunTransaction(ctx, db, nil, func(ctx context.Context, tx *sql.Tx) error { + server.TestSpanner.PutExecutionTime(testutil.MethodExecuteStreamingSql, testutil.SimulatedExecutionTime{ + Errors: []error{gstatus.Error(codes.NotFound, "Table not found")}, + }) + rows, err := tx.Query(testutil.SelectFooFromBar) + if err != nil { + return err + } + defer rows.Close() + + for want := int64(1); rows.Next(); want++ { + cols, err := rows.Columns() + if err != nil { + return err + } + if !cmp.Equal(cols, []string{"FOO"}) { + return fmt.Errorf("cols mismatch\nGot: %v\nWant: %v", cols, []string{"FOO"}) + } + var got int64 + err = rows.Scan(&got) + if err != nil { + return err + } + if got != want { + return fmt.Errorf("value mismatch\nGot: %v\nWant: %v", got, want) + } + } + if err := rows.Err(); err != nil { + return err + } + return nil + }) + if err == nil { + t.Fatal("missing transaction error") + } + if g, w := spanner.ErrCode(err), codes.NotFound; g != w { + t.Fatalf("error code mismatch\n Got: %v\nWant: %v", g, w) + } + + requests := drainRequestsFromServer(server.TestSpanner) + sqlRequests := requestsOfType(requests, reflect.TypeOf(&sppb.ExecuteSqlRequest{})) + if g, w := len(sqlRequests), 1; g != w { + t.Fatalf("ExecuteSqlRequests count mismatch\nGot: %v\nWant: %v", g, w) + } + commitRequests := requestsOfType(requests, reflect.TypeOf(&sppb.CommitRequest{})) + // There should be no CommitRequest, as the transaction failed + if g, w := len(commitRequests), 0; g != w { + t.Fatalf("commit requests count mismatch\nGot: %v\nWant: %v", g, w) + } + // There should be a RollbackRequest, as the transaction failed. + rollbackRequests := requestsOfType(requests, reflect.TypeOf(&sppb.RollbackRequest{})) + if g, w := len(rollbackRequests), 1; g != w { + t.Fatalf("rollback requests count mismatch\nGot: %v\nWant: %v", g, w) + } +} + +func TestRunTransactionCommitError(t *testing.T) { + t.Parallel() + + ctx := context.Background() + db, server, teardown := setupTestDBConnection(t) + defer teardown() + + err := RunTransaction(ctx, db, nil, func(ctx context.Context, tx *sql.Tx) error { + rows, err := tx.Query(testutil.SelectFooFromBar) + if err != nil { + return err + } + defer rows.Close() + + for want := int64(1); rows.Next(); want++ { + cols, err := rows.Columns() + if err != nil { + return err + } + if !cmp.Equal(cols, []string{"FOO"}) { + return fmt.Errorf("cols mismatch\nGot: %v\nWant: %v", cols, []string{"FOO"}) + } + var got int64 + err = rows.Scan(&got) + if err != nil { + return err + } + if got != want { + return fmt.Errorf("value mismatch\nGot: %v\nWant: %v", got, want) + } + } + if err := rows.Err(); err != nil { + return err + } + // Add an error for the Commit RPC. This will make the transaction fail, + // as the commit fails. + server.TestSpanner.PutExecutionTime(testutil.MethodCommitTransaction, testutil.SimulatedExecutionTime{ + Errors: []error{gstatus.Error(codes.FailedPrecondition, "Unique key constraint violation")}, + }) + return nil + }) + if err == nil { + t.Fatal("missing transaction error") + } + if g, w := spanner.ErrCode(err), codes.FailedPrecondition; g != w { + t.Fatalf("error code mismatch\n Got: %v\nWant: %v", g, w) + } + + requests := drainRequestsFromServer(server.TestSpanner) + sqlRequests := requestsOfType(requests, reflect.TypeOf(&sppb.ExecuteSqlRequest{})) + if g, w := len(sqlRequests), 1; g != w { + t.Fatalf("ExecuteSqlRequests count mismatch\nGot: %v\nWant: %v", g, w) + } + commitRequests := requestsOfType(requests, reflect.TypeOf(&sppb.CommitRequest{})) + // There should be no CommitRequest, as the transaction failed + if g, w := len(commitRequests), 1; g != w { + t.Fatalf("commit requests count mismatch\nGot: %v\nWant: %v", g, w) + } + // A Rollback request should normally not be necessary, as the Commit RPC + // already closed the transaction. However, the Spanner client also sends + // a RollbackRequest if a Commit fails. + // TODO: Revisit once the client library has been checked whether it is really + // necessary to send a Rollback after a failed Commit. + rollbackRequests := requestsOfType(requests, reflect.TypeOf(&sppb.RollbackRequest{})) + if g, w := len(rollbackRequests), 1; g != w { + t.Fatalf("rollback requests count mismatch\nGot: %v\nWant: %v", g, w) + } +} + func numeric(v string) big.Rat { res, _ := big.NewRat(1, 1).SetString(v) return *res diff --git a/examples/run-transaction/main.go b/examples/run-transaction/main.go new file mode 100644 index 00000000..4065e754 --- /dev/null +++ b/examples/run-transaction/main.go @@ -0,0 +1,109 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package main + +import ( + "context" + "database/sql" + "fmt" + "sync" + + _ "github.com/googleapis/go-sql-spanner" + spannerdriver "github.com/googleapis/go-sql-spanner" + "github.com/googleapis/go-sql-spanner/examples" +) + +var createTableStatement = "CREATE TABLE Singers (SingerId INT64, Name STRING(MAX)) PRIMARY KEY (SingerId)" + +// Example for running a read/write transaction in a retry loop on a Spanner database. +// The RunTransaction function automatically retries Aborted transactions using a +// retry loop. This guarantees that the transaction will not fail with an +// ErrAbortedDueToConcurrentModification. +// +// Execute the sample with the command `go run main.go` from this directory. +func runTransaction(projectId, instanceId, databaseId string) error { + ctx := context.Background() + db, err := sql.Open("spanner", fmt.Sprintf("projects/%s/instances/%s/databases/%s", projectId, instanceId, databaseId)) + if err != nil { + return err + } + defer db.Close() + + // Insert a new record that will be updated by multiple different transactions at the same time. + _, err = db.ExecContext(ctx, "INSERT INTO Singers (SingerId, Name) VALUES (@id, @name)", 123, "Bruce Allison") + if err != nil { + return err + } + + numTransactions := 10 + errors := make([]error, numTransactions) + wg := sync.WaitGroup{} + for i := 0; i < numTransactions; i++ { + index := i + wg.Add(1) + go func() { + defer wg.Done() + // Run a transaction that adds an index to the name of the singer. + // As we are doing this multiple times in parallel, these transactions + // will be aborted and retried by Spanner multiple times. The end result + // will still be that all transactions succeed and the name contains all + // indexes in an undefined order. + errors[index] = spannerdriver.RunTransaction(ctx, db, &sql.TxOptions{}, func(ctx context.Context, tx *sql.Tx) error { + // Query the singer in the transaction. This will take a lock on the row and guarantee that + // the value that we read is still the same when the transaction is committed. If not, Spanner + // will abort the transaction, and the transaction will be retried. + row := tx.QueryRowContext(ctx, "select Name from Singers where SingerId=@id", 123) + var name string + if err := row.Scan(&name); err != nil { + return err + } + // Update the name with the transaction index. + name = fmt.Sprintf("%s %d", name, index) + res, err := tx.ExecContext(ctx, "update Singers set Name=@name where SingerId=@id", name, 123) + if err != nil { + return err + } + affected, err := res.RowsAffected() + if err != nil { + return err + } + if affected != 1 { + return fmt.Errorf("unexpected affected row count: %d", affected) + } + return nil + }) + }() + } + wg.Wait() + + // The name of the singer should now contain all the indexes that were added in the + // transactions above in arbitrary order. + row := db.QueryRowContext(ctx, "SELECT SingerId, Name FROM Singers WHERE SingerId = ?", 123) + if err := row.Err(); err != nil { + return err + } + var id int64 + var name string + if err := row.Scan(&id, &name); err != nil { + return err + } + fmt.Printf("Singer after %d transactions: %v %v\n", numTransactions, id, name) + + return nil +} + +func main() { + examples.RunSampleOnEmulator(runTransaction, createTableStatement) +} diff --git a/transaction.go b/transaction.go index 6e92b6a5..67894133 100644 --- a/transaction.go +++ b/transaction.go @@ -117,6 +117,9 @@ func (tx *readOnlyTransaction) BufferWrite([]*spanner.Mutation) error { // that was aborted by Cloud Spanner, and where the internal retry attempt // failed because it detected that the results during the retry were different // from the initial attempt. +// +// Use the RunTransaction function to execute a read/write transaction in a +// retry loop. This function will never return ErrAbortedDueToConcurrentModification. var ErrAbortedDueToConcurrentModification = status.Error(codes.Aborted, "Transaction was aborted due to a concurrent modification") // readWriteTransaction is the internal structure for go/sql read/write