Skip to content

Commit

Permalink
feat: support LastInsertId in transactions (#372)
Browse files Browse the repository at this point in the history
Adds support for LastInsertId in read/write transactions.

Fixes #346
  • Loading branch information
olavloite authored Feb 7, 2025
1 parent 3b3eefb commit eb9a4ef
Show file tree
Hide file tree
Showing 5 changed files with 288 additions and 93 deletions.
91 changes: 87 additions & 4 deletions aborted_transactions_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ func TestUpdateAborted(t *testing.T) {
if err != nil {
t.Fatalf("begin failed: %v", err)
}
server.TestSpanner.PutExecutionTime(testutil.MethodExecuteSql, testutil.SimulatedExecutionTime{
server.TestSpanner.PutExecutionTime(testutil.MethodExecuteStreamingSql, testutil.SimulatedExecutionTime{
Errors: []error{status.Error(codes.Aborted, "Aborted")},
})
res, err := tx.ExecContext(ctx, testutil.UpdateBarSetFoo)
Expand Down Expand Up @@ -588,7 +588,7 @@ func TestSecondUpdateAborted(t *testing.T) {
t.Fatalf("update singers failed: %v", err)
}

server.TestSpanner.PutExecutionTime(testutil.MethodExecuteSql, testutil.SimulatedExecutionTime{
server.TestSpanner.PutExecutionTime(testutil.MethodExecuteStreamingSql, testutil.SimulatedExecutionTime{
Errors: []error{status.Error(codes.Aborted, "Aborted")},
})
// This statement will return Aborted, the transaction will be retried internally and the statement is
Expand Down Expand Up @@ -695,7 +695,7 @@ func TestSecondUpdateAborted_FirstStatementWithSameError(t *testing.T) {
t.Fatalf("error code mismatch\nGot: %v\nWant: %v", spanner.ErrCode(err), codes.NotFound)
}

server.TestSpanner.PutExecutionTime(testutil.MethodExecuteSql, testutil.SimulatedExecutionTime{
server.TestSpanner.PutExecutionTime(testutil.MethodExecuteStreamingSql, testutil.SimulatedExecutionTime{
Errors: []error{status.Error(codes.Aborted, "Aborted")},
})
// This statement will return Aborted, the transaction will be retried internally and the statement is
Expand Down Expand Up @@ -786,7 +786,7 @@ func testSecondUpdateAborted_FirstResultChanged(t *testing.T, firstResult *testu
// Update the result to simulate a different result during the retry.
server.TestSpanner.PutStatementResult(testutil.UpdateSingersSetLastName, secondResult)

server.TestSpanner.PutExecutionTime(testutil.MethodExecuteSql, testutil.SimulatedExecutionTime{
server.TestSpanner.PutExecutionTime(testutil.MethodExecuteStreamingSql, testutil.SimulatedExecutionTime{
Errors: []error{status.Error(codes.Aborted, "Aborted")},
})
// This statement will return Aborted and the transaction will be retried internally. That
Expand Down Expand Up @@ -970,6 +970,89 @@ func TestBatchUpdateAbortedWithError_DifferentErrorDuringRetry(t *testing.T) {
}
}

func TestLastInsertId(t *testing.T) {
t.Parallel()

db, server, teardown := setupTestDBConnection(t)
defer teardown()

query := "insert into singers (name) values ('foo') then return id"
_ = server.TestSpanner.PutStatementResult(query, &testutil.StatementResult{
Type: testutil.StatementResultResultSet,
ResultSet: testutil.CreateSingleColumnInt64ResultSet([]int64{1}, "id"),
UpdateCount: 1,
})

ctx := context.Background()
tx, err := db.BeginTx(ctx, &sql.TxOptions{})
if err != nil {
t.Fatalf("begin failed: %v", err)
}
if res, err := tx.ExecContext(ctx, query); err != nil {
t.Fatalf("failed to execute statement: %v", err)
} else {
if id, err := res.LastInsertId(); err != nil {
t.Fatalf("failed to get last insert id: %v", err)
} else if g, w := id, int64(1); g != w {
t.Fatalf("last insert id mismatch\n Got: %v\nWant: %v", g, w)
}
if c, err := res.RowsAffected(); err != nil {
t.Fatalf("failed to get update count: %v", err)
} else if g, w := c, int64(1); g != w {
t.Fatalf("update count mismatch\n Got: %v\nWant: %v", g, w)
}
}
if err := tx.Commit(); err != nil {
t.Fatalf("failed to commit: %v", err)
}
}

func TestLastInsertIdChanged(t *testing.T) {
t.Parallel()

db, server, teardown := setupTestDBConnection(t)
defer teardown()

query := "insert into singers (name) values ('foo') then return id"
_ = server.TestSpanner.PutStatementResult(query, &testutil.StatementResult{
Type: testutil.StatementResultResultSet,
ResultSet: testutil.CreateSingleColumnInt64ResultSet([]int64{1}, "id"),
UpdateCount: 1,
})

ctx := context.Background()
tx, err := db.BeginTx(ctx, &sql.TxOptions{})
if err != nil {
t.Fatalf("begin failed: %v", err)
}
if res, err := tx.ExecContext(ctx, query); err != nil {
t.Fatalf("failed to execute statement: %v", err)
} else {
if id, err := res.LastInsertId(); err != nil {
t.Fatalf("failed to get last insert id: %v", err)
} else if g, w := id, int64(1); g != w {
t.Fatalf("last insert id mismatch\n Got: %v\nWant: %v", g, w)
}
if c, err := res.RowsAffected(); err != nil {
t.Fatalf("failed to get update count: %v", err)
} else if g, w := c, int64(1); g != w {
t.Fatalf("update count mismatch\n Got: %v\nWant: %v", g, w)
}
}
// Abort the transaction and change the returned ID.
server.TestSpanner.PutExecutionTime(testutil.MethodCommitTransaction, testutil.SimulatedExecutionTime{
Errors: []error{status.Error(codes.Aborted, "Aborted")},
})
_ = server.TestSpanner.PutStatementResult(query, &testutil.StatementResult{
Type: testutil.StatementResultResultSet,
ResultSet: testutil.CreateSingleColumnInt64ResultSet([]int64{2}, "id"),
UpdateCount: 1,
})
if err := tx.Commit(); err != ErrAbortedDueToConcurrentModification {
t.Fatalf("commit error mismatch\n Got: %v\nWant: %v", err, ErrAbortedDueToConcurrentModification)
}
}

func firstNonZero(values ...int) int {
for _, v := range values {
if v > 0 {
Expand Down
68 changes: 38 additions & 30 deletions driver.go
Original file line number Diff line number Diff line change
Expand Up @@ -1791,9 +1791,7 @@ func (c *conn) execContext(ctx context.Context, query string, execOptions ExecOp
}
}
} else {
var rowsAffected int64
rowsAffected, err = c.tx.ExecContext(ctx, ss, execOptions.QueryOptions)
res = &result{rowsAffected: rowsAffected}
res, err = c.tx.ExecContext(ctx, ss, statementInfo, execOptions.QueryOptions)
}
if err != nil {
return nil, err
Expand Down Expand Up @@ -2099,42 +2097,52 @@ func queryInNewRWTransaction(ctx context.Context, c *spanner.Client, statement s
var errInvalidDmlForExecContext = spanner.ToSpannerError(status.Error(codes.FailedPrecondition, "Exec and ExecContext can only be used with INSERT statements with a THEN RETURN clause that return exactly one row with one column of type INT64. Use Query or QueryContext for DML statements other than INSERT and/or with THEN RETURN clauses that return other/more data."))

func execInNewRWTransaction(ctx context.Context, c *spanner.Client, statement spanner.Statement, statementInfo *statementInfo, options ExecOptions) (*result, time.Time, error) {
var rowsAffected int64
var lastInsertId int64
var hasLastInsertId bool
var res *result
fn := func(ctx context.Context, tx *spanner.ReadWriteTransaction) error {
it := tx.QueryWithOptions(ctx, statement, options.QueryOptions)
defer it.Stop()
row, err := it.Next()
if err != nil && err != iterator.Done {
var err error
res, err = execTransactionalDML(ctx, tx, statement, statementInfo, options.QueryOptions)
if err != nil {
return err
}
if len(it.Metadata.RowType.Fields) != 0 && !(len(it.Metadata.RowType.Fields) == 1 &&
it.Metadata.RowType.Fields[0].Type.Code == spannerpb.TypeCode_INT64 &&
statementInfo.dmlType == dmlTypeInsert) {
return errInvalidDmlForExecContext
}
if err != iterator.Done {
if err := row.Column(0, &lastInsertId); err != nil {
return err
}
// Verify that the result set only contains one row.
_, err = it.Next()
if err == iterator.Done {
hasLastInsertId = true
} else {
// Statement returned more than one row.
return errInvalidDmlForExecContext
}
}
rowsAffected = it.RowCount
return nil
}
resp, err := c.ReadWriteTransactionWithOptions(ctx, fn, options.TransactionOptions)
if err != nil {
return &result{}, time.Time{}, err
}
return &result{rowsAffected: rowsAffected, lastInsertId: lastInsertId, hasLastInsertId: hasLastInsertId}, resp.CommitTs, nil
return res, resp.CommitTs, nil
}

func execTransactionalDML(ctx context.Context, tx spannerTransaction, statement spanner.Statement, statementInfo *statementInfo, options spanner.QueryOptions) (*result, error) {
var rowsAffected int64
var lastInsertId int64
var hasLastInsertId bool
it := tx.QueryWithOptions(ctx, statement, options)
defer it.Stop()
row, err := it.Next()
if err != nil && err != iterator.Done {
return nil, err
}
if len(it.Metadata.RowType.Fields) != 0 && !(len(it.Metadata.RowType.Fields) == 1 &&
it.Metadata.RowType.Fields[0].Type.Code == spannerpb.TypeCode_INT64 &&
statementInfo.dmlType == dmlTypeInsert) {
return nil, errInvalidDmlForExecContext
}
if err != iterator.Done {
if err := row.Column(0, &lastInsertId); err != nil {
return nil, err
}
// Verify that the result set only contains one row.
_, err = it.Next()
if err == iterator.Done {
hasLastInsertId = true
} else {
// Statement returned more than one row.
return nil, errInvalidDmlForExecContext
}
}
rowsAffected = it.RowCount
return &result{rowsAffected: rowsAffected, lastInsertId: lastInsertId, hasLastInsertId: hasLastInsertId}, nil
}

func execAsPartitionedDML(ctx context.Context, c *spanner.Client, statement spanner.Statement, options ExecOptions) (int64, error) {
Expand Down
Loading

0 comments on commit eb9a4ef

Please sign in to comment.