diff --git a/driver.go b/driver.go index 057751ba..a359701b 100644 --- a/driver.go +++ b/driver.go @@ -22,6 +22,7 @@ import ( "math/big" "reflect" "regexp" + "slices" "strconv" "strings" "sync" @@ -535,6 +536,10 @@ type SpannerConn interface { InDDLBatch() bool // InDMLBatch returns true if the connection is currently in a DML batch. InDMLBatch() bool + // GetBatchedStatements returns a copy of the statements that are currently + // buffered to be executed as a DML or DDL batch. It returns an empty slice + // if no batch is active, or if there are no statements buffered. + GetBatchedStatements() []spanner.Statement // RetryAbortsInternally returns true if the connection automatically // retries all aborted transactions. @@ -755,6 +760,13 @@ func (c *conn) InDMLBatch() bool { return (c.batch != nil && c.batch.tp == dml) || (c.inReadWriteTransaction() && c.tx.(*readWriteTransaction).batch != nil) } +func (c *conn) GetBatchedStatements() []spanner.Statement { + if c.batch == nil || c.batch.statements == nil { + return []spanner.Statement{} + } + return slices.Clone(c.batch.statements) +} + func (c *conn) inBatch() bool { return c.InDDLBatch() || c.InDMLBatch() } diff --git a/driver_test.go b/driver_test.go index 8a34f1b2..a5a6f6a5 100644 --- a/driver_test.go +++ b/driver_test.go @@ -18,6 +18,7 @@ import ( "context" "database/sql/driver" "net" + "reflect" "testing" "time" @@ -425,6 +426,52 @@ func TestConn_NonDmlStatementsInDmlBatch(t *testing.T) { } } +func TestConn_GetBatchedStatements(t *testing.T) { + t.Parallel() + + ctx := context.Background() + c := &conn{} + if !reflect.DeepEqual(c.GetBatchedStatements(), []spanner.Statement{}) { + t.Fatal("conn should return an empty slice when no batch is active") + } + if err := c.StartBatchDDL(); err != nil { + t.Fatal(err) + } + if !reflect.DeepEqual(c.GetBatchedStatements(), []spanner.Statement{}) { + t.Fatal("conn should return an empty slice when a batch contains no statements") + } + if _, err := c.ExecContext(ctx, "create table table1", []driver.NamedValue{}); err != nil { + t.Fatal(err) + } + if _, err := c.ExecContext(ctx, "create table table2", []driver.NamedValue{}); err != nil { + t.Fatal(err) + } + batchedStatements := c.GetBatchedStatements() + if !reflect.DeepEqual([]spanner.Statement{ + {SQL: "create table table1", Params: map[string]interface{}{}}, + {SQL: "create table table2", Params: map[string]interface{}{}}, + }, batchedStatements) { + t.Errorf("unexpected batched statements: %v", batchedStatements) + } + + // Changing the returned slice does not change the batched statements. + batchedStatements[0] = spanner.Statement{SQL: "drop table table1"} + batchedStatements2 := c.GetBatchedStatements() + if !reflect.DeepEqual([]spanner.Statement{ + {SQL: "create table table1", Params: map[string]interface{}{}}, + {SQL: "create table table2", Params: map[string]interface{}{}}, + }, batchedStatements2) { + t.Errorf("unexpected batched statements: %v", batchedStatements2) + } + + if err := c.AbortBatch(); err != nil { + t.Fatal(err) + } + if !reflect.DeepEqual(c.GetBatchedStatements(), []spanner.Statement{}) { + t.Fatal("conn should return an empty slice when no batch is active") + } +} + func TestConn_GetCommitTimestampAfterAutocommitDml(t *testing.T) { want := time.Now() c := &conn{