diff --git a/driver.go b/driver.go index 11cf2e87..0f892fa5 100644 --- a/driver.go +++ b/driver.go @@ -352,10 +352,12 @@ func (c *conn) ExecContext(ctx context.Context, query string, args []driver.Name } if isDdl { - // TODO: Determine whether we want to return an error if a transaction - // is active. Cloud Spanner does not support DDL in transactions, but - // this makes it seem like the DDL statement is executed on the - // transaction on this connection if it has a transaction. + // Spanner does not support DDL in transactions, and although it is technically possible to execute DDL + // statements while a transaction is active, we return an error to avoid any confusion whether the DDL + // statement is executed as part of the active transaction or not. + if c.inTransaction() { + return nil, spanner.ToSpannerError(status.Errorf(codes.FailedPrecondition, "cannot execute DDL as part of a transaction")) + } op, err := c.adminClient.UpdateDatabaseDdl(ctx, &adminpb.UpdateDatabaseDdlRequest{ Database: c.database, Statements: []string{query}, diff --git a/driver_with_mockserver_test.go b/driver_with_mockserver_test.go index 1b1be7f4..6c979576 100644 --- a/driver_with_mockserver_test.go +++ b/driver_with_mockserver_test.go @@ -953,45 +953,18 @@ func TestDdlInTransaction(t *testing.T) { db, server, teardown := setupTestDBConnection(t) defer teardown() - var expectedResponse = &emptypb.Empty{} - any, _ := ptypes.MarshalAny(expectedResponse) - server.TestDatabaseAdmin.SetResps([]proto.Message{ - &longrunningpb.Operation{ - Done: true, - Result: &longrunningpb.Operation_Response{Response: any}, - Name: "test-operation", - }, - }) query := "CREATE TABLE Singers (SingerId INT64, FirstName STRING(100), LastName STRING(100)) PRIMARY KEY (SingerId)" tx, err := db.BeginTx(context.Background(), &sql.TxOptions{}) if err != nil { t.Fatal(err) } - res, err := tx.ExecContext(context.Background(), query) - if err != nil { - t.Fatal(err) - } - affected, err := res.RowsAffected() - if err != nil { - t.Fatal(err) - } - if affected != 0 { - t.Fatalf("affected rows count mismatch\nGot: %v\nWant: %v", affected, 0) + if _, err := tx.ExecContext(context.Background(), query); spanner.ErrCode(err) != codes.FailedPrecondition { + t.Fatalf("error mismatch\nGot: %v\nWant: %v", spanner.ErrCode(err), codes.FailedPrecondition) } requests := server.TestDatabaseAdmin.Reqs() - if g, w := len(requests), 1; g != w { + if g, w := len(requests), 0; g != w { t.Fatalf("requests count mismatch\nGot: %v\nWant: %v", g, w) } - if req, ok := requests[0].(*databasepb.UpdateDatabaseDdlRequest); ok { - if g, w := len(req.Statements), 1; g != w { - t.Fatalf("statement count mismatch\nGot: %v\nWant: %v", g, w) - } - if g, w := req.Statements[0], query; g != w { - t.Fatalf("statement mismatch\nGot: %v\nWant: %v", g, w) - } - } else { - t.Fatalf("request type mismatch, got %v", requests[0]) - } } func TestBegin(t *testing.T) {