diff --git a/spanner/integration_test.go b/spanner/integration_test.go index 160fe071d5c1..2b567a62defd 100644 --- a/spanner/integration_test.go +++ b/spanner/integration_test.go @@ -56,6 +56,12 @@ import ( const ( directPathIPV6Prefix = "[2001:4860:8040" directPathIPV4Prefix = "34.126" + + singerDDLStatements = "SINGER_DDL_STATEMENTS" + simpleDDLStatements = "SIMPLE_DDL_STATEMENTS" + readDDLStatements = "READ_DDL_STATEMENTS" + backupDDLStatements = "BACKUP_DDL_STATEMENTS" + testTableDDLStatements = "TEST_TABLE_DDL_STATEMENTS" ) var ( @@ -63,6 +69,9 @@ var ( // by setting environment variable GCLOUD_TESTS_GOLANG_PROJECT_ID. testProjectID = testutil.ProjID() + // testDialect specifies the dialect used for testing. + testDialect adminpb.DatabaseDialect + // spannerHost specifies the spanner API host used for testing. It can be changed // by setting the environment variable GCLOUD_TESTS_GOLANG_SPANNER_HOST spannerHost = getSpannerHost() @@ -87,6 +96,26 @@ var ( dpConfig directPathTestConfig peerInfo *peer.Peer + singerDBPGStatements = []string{ + `CREATE TABLE Singers ( + SingerId INT8 NOT NULL, + FirstName VARCHAR(1024), + LastName VARCHAR(1024), + SingerInfo BYTEA, + numeric NUMERIC, + float8 FLOAT8, + PRIMARY KEY(SingerId) + )`, + `CREATE INDEX SingerByName ON Singers(FirstName, LastName)`, + `CREATE TABLE Accounts ( + AccountId BIGINT NOT NULL, + Nickname VARCHAR(100), + Balance BIGINT NOT NULL, + PRIMARY KEY(AccountId) + )`, + `CREATE INDEX AccountByNickname ON Accounts(Nickname)`, + } + singerDBStatements = []string{ `CREATE TABLE Singers ( SingerId INT64 NOT NULL, @@ -131,12 +160,27 @@ var ( `CREATE INDEX TestTableByValueDesc ON TestTable(StringValue DESC)`, } + readDBPGStatements = []string{ + `CREATE TABLE TestTable ( + Key VARCHAR PRIMARY KEY, + StringValue VARCHAR + )`, + `CREATE INDEX TestTableByValue ON TestTable(StringValue)`, + `CREATE INDEX TestTableByValueDesc ON TestTable(StringValue DESC)`, + } + simpleDBStatements = []string{ `CREATE TABLE test ( a STRING(1024), b STRING(1024), ) PRIMARY KEY (a)`, } + simpleDBPGStatements = []string{ + `CREATE TABLE test ( + a VARCHAR(1024) PRIMARY KEY, + b VARCHAR(1024) + )`, + } simpleDBTableColumns = []string{"a", "b"} ctsDBStatements = []string{ @@ -145,7 +189,7 @@ var ( Ts TIMESTAMP OPTIONS (allow_commit_timestamp = true), ) PRIMARY KEY (Key)`, } - backuDBStatements = []string{ + backupDBStatements = []string{ `CREATE TABLE Singers ( SingerId INT64 NOT NULL, FirstName STRING(1024), @@ -180,6 +224,48 @@ var ( ) PRIMARY KEY (RowID)`, } + backupDBPGStatements = []string{ + `CREATE TABLE Singers ( + SingerId BIGINT PRIMARY KEY, + FirstName VARCHAR(1024), + LastName VARCHAR(1024), + SingerInfo BYTEA + )`, + `CREATE INDEX SingerByName ON Singers(FirstName, LastName)`, + `CREATE TABLE Accounts ( + AccountId BIGINT PRIMARY KEY, + Nickname VARCHAR(100), + Balance BIGINT NOT NULL + )`, + `CREATE INDEX AccountByNickname ON Accounts(Nickname)`, + `CREATE TABLE Types ( + RowID BIGINT PRIMARY KEY, + String VARCHAR, + Bytes BYTEA, + Int64a BIGINT, + Bool BOOL, + Float64 DOUBLE PRECISION, + Numeric NUMERIC + )`, + } + + statements = map[adminpb.DatabaseDialect]map[string][]string{ + adminpb.DatabaseDialect_GOOGLE_STANDARD_SQL: { + singerDDLStatements: singerDBStatements, + simpleDDLStatements: simpleDBStatements, + readDDLStatements: readDBStatements, + backupDDLStatements: backupDBStatements, + testTableDDLStatements: readDBStatements, + }, + adminpb.DatabaseDialect_POSTGRESQL: { + singerDDLStatements: singerDBPGStatements, + simpleDDLStatements: simpleDBPGStatements, + readDDLStatements: readDBPGStatements, + backupDDLStatements: backupDBPGStatements, + testTableDDLStatements: readDBPGStatements, + }, + } + validInstancePattern = regexp.MustCompile("^projects/(?P[^/]+)/instances/(?P[^/]+)$") blackholeDpv6Cmd string @@ -228,9 +314,19 @@ const ( func TestMain(m *testing.M) { cleanup := initIntegrationTests() - res := m.Run() - cleanup() - os.Exit(res) + defer cleanup() + for _, dialect := range []adminpb.DatabaseDialect{adminpb.DatabaseDialect_GOOGLE_STANDARD_SQL, adminpb.DatabaseDialect_POSTGRESQL} { + if isEmulatorEnvSet() && dialect == adminpb.DatabaseDialect_POSTGRESQL { + // PG tests are not supported in emulator + continue + } + testDialect = dialect + res := m.Run() + if res != 0 { + cleanup() + os.Exit(res) + } + } } var grpcHeaderChecker = testutil.DefaultHeadersEnforcer() @@ -386,11 +482,12 @@ loop: // Test SingleUse transaction. func TestIntegration_SingleUse(t *testing.T) { t.Parallel() + skipEmulatorTestForPG(t) ctx, cancel := context.WithTimeout(context.Background(), 5*time.Minute) defer cancel() // Set up testing environment. - client, _, cleanup := prepareIntegrationTest(ctx, t, DefaultSessionPoolConfig, singerDBStatements) + client, _, cleanup := prepareIntegrationTest(ctx, t, DefaultSessionPoolConfig, statements[testDialect][singerDDLStatements]) defer cleanup() writes := []struct { @@ -419,10 +516,11 @@ func TestIntegration_SingleUse(t *testing.T) { // Test reading rows with different timestamp bounds. for i, test := range []struct { - name string - want [][]interface{} - tb TimestampBound - checkTs func(time.Time) error + name string + want [][]interface{} + tb TimestampBound + checkTs func(time.Time) error + skipForPG bool }{ { name: "strong", @@ -472,7 +570,10 @@ func TestIntegration_SingleUse(t *testing.T) { }, { name: "exact_staleness", - want: nil, + // PG query with exact_staleness returns error code + // "InvalidArgument", desc = "[ERROR] relation \"singers\" does not exist + skipForPG: true, + want: nil, // Specify a staleness which should be already before this test. tb: ExactStaleness(time.Now().Sub(writes[0].ts) + timeDiff + 30*time.Second), checkTs: func(ts time.Time) error { @@ -484,13 +585,20 @@ func TestIntegration_SingleUse(t *testing.T) { }, } { t.Run(test.name, func(t *testing.T) { + singersQuery := "SELECT SingerId, FirstName, LastName FROM Singers WHERE SingerId IN (@p1, @p2, @p3) ORDER BY SingerId" + if testDialect == adminpb.DatabaseDialect_POSTGRESQL { + if test.skipForPG { + t.Skip("Skipping testing of unsupported tests in Postgres dialect.") + } + singersQuery = "SELECT SingerId, FirstName, LastName FROM Singers WHERE SingerId = $1 OR SingerId = $2 OR SingerId = $3 ORDER BY SingerId" + } // SingleUse.Query su := client.Single().WithTimestampBound(test.tb) got, err := readAll(su.Query( ctx, Statement{ - "SELECT SingerId, FirstName, LastName FROM Singers WHERE SingerId IN (@id1, @id3, @id4) ORDER BY SingerId", - map[string]interface{}{"id1": int64(1), "id3": int64(3), "id4": int64(4)}, + singersQuery, + map[string]interface{}{"p1": int64(1), "p2": int64(3), "p3": int64(4)}, })) if err != nil { t.Fatalf("%d: SingleUse.Query returns error %v, want nil", i, err) @@ -622,7 +730,7 @@ func TestIntegration_SingleUse_WithQueryOptions(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), 5*time.Minute) defer cancel() // Set up testing environment. - client, _, cleanup := prepareIntegrationTest(ctx, t, DefaultSessionPoolConfig, singerDBStatements) + client, _, cleanup := prepareIntegrationTest(ctx, t, DefaultSessionPoolConfig, statements[testDialect][singerDDLStatements]) defer cleanup() writes := []struct { @@ -644,13 +752,17 @@ func TestIntegration_SingleUse_WithQueryOptions(t *testing.T) { t.Fatal(err) } } + singersQuery := "SELECT SingerId, FirstName, LastName FROM Singers WHERE SingerId IN (@p1, @p2, @p3)" + if testDialect == adminpb.DatabaseDialect_POSTGRESQL { + singersQuery = "SELECT SingerId, FirstName, LastName FROM Singers WHERE SingerId = $1 OR SingerId = $2 OR SingerId = $3" + } qo := QueryOptions{Options: &sppb.ExecuteSqlRequest_QueryOptions{ OptimizerVersion: "1", OptimizerStatisticsPackage: "latest", }} got, err := readAll(client.Single().QueryWithOptions(ctx, Statement{ - "SELECT SingerId, FirstName, LastName FROM Singers WHERE SingerId IN (@id1, @id3, @id4)", - map[string]interface{}{"id1": int64(1), "id3": int64(3), "id4": int64(4)}, + singersQuery, + map[string]interface{}{"p1": int64(1), "p2": int64(3), "p3": int64(4)}, }, qo)) if err != nil { @@ -669,7 +781,7 @@ func TestIntegration_SingleUse_ReadingWithLimit(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), 5*time.Minute) defer cancel() // Set up testing environment. - client, _, cleanup := prepareIntegrationTest(ctx, t, DefaultSessionPoolConfig, singerDBStatements) + client, _, cleanup := prepareIntegrationTest(ctx, t, DefaultSessionPoolConfig, statements[testDialect][singerDDLStatements]) defer cleanup() writes := []struct { @@ -713,7 +825,7 @@ func TestIntegration_ReadOnlyTransaction(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), ctxTimeout) defer cancel() // Set up testing environment. - client, _, cleanup := prepareIntegrationTest(ctx, t, DefaultSessionPoolConfig, singerDBStatements) + client, _, cleanup := prepareIntegrationTest(ctx, t, DefaultSessionPoolConfig, statements[testDialect][singerDDLStatements]) defer cleanup() writes := []struct { @@ -741,9 +853,10 @@ func TestIntegration_ReadOnlyTransaction(t *testing.T) { // Test reading rows with different timestamp bounds. for i, test := range []struct { - want [][]interface{} - tb TimestampBound - checkTs func(time.Time) error + want [][]interface{} + tb TimestampBound + checkTs func(time.Time) error + skipForPG bool }{ // Note: min_read_timestamp and max_staleness are not supported by // ReadOnlyTransaction. See API document for more details. @@ -757,6 +870,7 @@ func TestIntegration_ReadOnlyTransaction(t *testing.T) { } return nil }, + false, }, { // read_timestamp @@ -768,6 +882,7 @@ func TestIntegration_ReadOnlyTransaction(t *testing.T) { } return nil }, + false, }, { // exact_staleness @@ -780,15 +895,24 @@ func TestIntegration_ReadOnlyTransaction(t *testing.T) { } return nil }, + // PG query with exact_staleness returns Table not found error + true, }, } { // ReadOnlyTransaction.Query ro := client.ReadOnlyTransaction().WithTimestampBound(test.tb) + singersQuery := "SELECT SingerId, FirstName, LastName FROM Singers WHERE SingerId IN (@p1, @p2, @p3) ORDER BY SingerId" + if testDialect == adminpb.DatabaseDialect_POSTGRESQL { + if test.skipForPG { + t.Skip("Skipping testing of unsupported tests in Postgres dialect.") + } + singersQuery = "SELECT SingerId, FirstName, LastName FROM Singers WHERE SingerId = $1 OR SingerId = $2 OR SingerId = $3 ORDER BY SingerId" + } got, err := readAll(ro.Query( ctx, Statement{ - "SELECT SingerId, FirstName, LastName FROM Singers WHERE SingerId IN (@id1, @id3, @id4) ORDER BY SingerId", - map[string]interface{}{"id1": int64(1), "id3": int64(3), "id4": int64(4)}, + singersQuery, + map[string]interface{}{"p1": int64(1), "p2": int64(3), "p3": int64(4)}, })) if err != nil { t.Errorf("%d: ReadOnlyTransaction.Query returns error %v, want nil", i, err) @@ -924,7 +1048,7 @@ func TestIntegration_UpdateDuringRead(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), 5*time.Minute) defer cancel() - client, _, cleanup := prepareIntegrationTest(ctx, t, DefaultSessionPoolConfig, singerDBStatements) + client, _, cleanup := prepareIntegrationTest(ctx, t, DefaultSessionPoolConfig, statements[testDialect][singerDDLStatements]) defer cleanup() for i, tb := range []TimestampBound{ @@ -957,7 +1081,7 @@ func TestIntegration_ReadWriteTransaction(t *testing.T) { // Give a longer deadline because of transaction backoffs. ctx, cancel := context.WithTimeout(context.Background(), 5*time.Minute) defer cancel() - client, _, cleanup := prepareIntegrationTest(ctx, t, DefaultSessionPoolConfig, singerDBStatements) + client, _, cleanup := prepareIntegrationTest(ctx, t, DefaultSessionPoolConfig, statements[testDialect][singerDDLStatements]) defer cleanup() // Set up two accounts @@ -987,6 +1111,10 @@ func TestIntegration_ReadWriteTransaction(t *testing.T) { } } + queryAccountByID := "SELECT Balance FROM Accounts WHERE AccountId = @p1" + if testDialect == adminpb.DatabaseDialect_POSTGRESQL { + queryAccountByID = "SELECT Balance FROM Accounts WHERE AccountId = $1" + } for i := 0; i < 20; i++ { wg.Add(1) go func(iter int) { @@ -994,7 +1122,7 @@ func TestIntegration_ReadWriteTransaction(t *testing.T) { _, err := client.ReadWriteTransaction(ctx, func(ctx context.Context, tx *ReadWriteTransaction) error { // Query Foo's balance and Bar's balance. bf, e := readBalance(tx.Query(ctx, - Statement{"SELECT Balance FROM Accounts WHERE AccountId = @id", map[string]interface{}{"id": int64(1)}})) + Statement{queryAccountByID, map[string]interface{}{"p1": int64(1)}})) if e != nil { return e } @@ -1032,9 +1160,17 @@ func TestIntegration_ReadWriteTransaction(t *testing.T) { if ce := r.Column(0, &bf); ce != nil { return ce } - bb, e = readBalance(tx.ReadUsingIndex(ctx, "Accounts", "AccountByNickname", KeySets(Key{"Bar"}), []string{"Balance"})) - if e != nil { - return e + // reading non-indexed column from index is not supported in PG + if testDialect == adminpb.DatabaseDialect_POSTGRESQL { + bb, e = readBalance(tx.Read(ctx, "Accounts", Key{int64(2)}, []string{"Balance"})) + if e != nil { + return e + } + } else { + bb, e = readBalance(tx.ReadUsingIndex(ctx, "Accounts", "AccountByNickname", KeySets(Key{"Bar"}), []string{"Balance"})) + if e != nil { + return e + } } verifyDirectPathRemoteAddress(t) if bf != 30 || bb != 21 { @@ -1056,7 +1192,7 @@ func TestIntegration_ReadWriteTransactionWithOptions(t *testing.T) { // Give a longer deadline because of transaction backoffs. ctx, cancel := context.WithTimeout(context.Background(), 5*time.Minute) defer cancel() - client, _, cleanup := prepareIntegrationTest(ctx, t, DefaultSessionPoolConfig, singerDBStatements) + client, _, cleanup := prepareIntegrationTest(ctx, t, DefaultSessionPoolConfig, statements[testDialect][singerDDLStatements]) defer cleanup() // Set up two accounts @@ -1088,8 +1224,12 @@ func TestIntegration_ReadWriteTransactionWithOptions(t *testing.T) { txOpts := TransactionOptions{CommitOptions: CommitOptions{ReturnCommitStats: true}} resp, err := client.ReadWriteTransactionWithOptions(ctx, func(ctx context.Context, tx *ReadWriteTransaction) error { // Query Foo's balance and Bar's balance. + queryAccountByID := "SELECT Balance FROM Accounts WHERE AccountId = @p1" + if testDialect == adminpb.DatabaseDialect_POSTGRESQL { + queryAccountByID = "SELECT Balance FROM Accounts WHERE AccountId = $1" + } bf, e := readBalance(tx.Query(ctx, - Statement{"SELECT Balance FROM Accounts WHERE AccountId = @id", map[string]interface{}{"id": int64(1)}})) + Statement{queryAccountByID, map[string]interface{}{"p1": int64(1)}})) if e != nil { return e } @@ -1113,7 +1253,12 @@ func TestIntegration_ReadWriteTransactionWithOptions(t *testing.T) { if resp.CommitStats == nil { t.Fatal("Missing commit stats in commit response") } - if got, want := resp.CommitStats.MutationCount, int64(8); got != want { + expectedMutationCount := int64(8) + if testDialect == adminpb.DatabaseDialect_POSTGRESQL { + // for PG mutation count for Balance column is not added for AccountName index + expectedMutationCount = int64(4) + } + if got, want := resp.CommitStats.MutationCount, expectedMutationCount; got != want { t.Errorf("Mismatch mutation count - got: %v, want: %v", got, want) } } @@ -1124,7 +1269,7 @@ func TestIntegration_ReadWriteTransaction_StatementBased(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), 5*time.Minute) defer cancel() - client, _, cleanup := prepareIntegrationTest(ctx, t, DefaultSessionPoolConfig, singerDBStatements) + client, _, cleanup := prepareIntegrationTest(ctx, t, DefaultSessionPoolConfig, statements[testDialect][singerDDLStatements]) defer cleanup() // Set up two accounts @@ -1217,7 +1362,7 @@ func TestIntegration_ReadWriteTransaction_StatementBasedWithOptions(t *testing.T ctx, cancel := context.WithTimeout(context.Background(), 5*time.Minute) defer cancel() - client, _, cleanup := prepareIntegrationTest(ctx, t, DefaultSessionPoolConfig, singerDBStatements) + client, _, cleanup := prepareIntegrationTest(ctx, t, DefaultSessionPoolConfig, statements[testDialect][singerDDLStatements]) defer cleanup() // Set up two accounts @@ -1292,7 +1437,12 @@ func TestIntegration_ReadWriteTransaction_StatementBasedWithOptions(t *testing.T if resp.CommitStats == nil { t.Fatal("Missing commit stats in commit response") } - if got, want := resp.CommitStats.MutationCount, int64(8); got != want { + expectedMutationCount := int64(8) + if testDialect == adminpb.DatabaseDialect_POSTGRESQL { + // for PG mutation count for Balance column is not added for AccountName index + expectedMutationCount = int64(4) + } + if got, want := resp.CommitStats.MutationCount, expectedMutationCount; got != want { t.Errorf("Mismatch mutation count - got: %v, want: %v", got, want) } } @@ -1303,7 +1453,7 @@ func TestIntegration_Reads(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), 5*time.Minute) defer cancel() // Set up testing environment. - client, _, cleanup := prepareIntegrationTest(ctx, t, DefaultSessionPoolConfig, readDBStatements) + client, _, cleanup := prepareIntegrationTest(ctx, t, DefaultSessionPoolConfig, statements[testDialect][testTableDDLStatements]) defer cleanup() // Includes k0..k14. Strings sort lexically, eg "k1" < "k10" < "k2". @@ -1393,7 +1543,7 @@ func TestIntegration_EarlyTimestamp(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), 5*time.Minute) defer cancel() // Set up testing environment. - client, _, cleanup := prepareIntegrationTest(ctx, t, DefaultSessionPoolConfig, readDBStatements) + client, _, cleanup := prepareIntegrationTest(ctx, t, DefaultSessionPoolConfig, statements[testDialect][testTableDDLStatements]) defer cleanup() var ms []*Mutation @@ -1439,7 +1589,7 @@ func TestIntegration_NestedTransaction(t *testing.T) { // You cannot use a transaction from inside a read-write transaction. ctx, cancel := context.WithTimeout(context.Background(), 5*time.Minute) defer cancel() - client, _, cleanup := prepareIntegrationTest(ctx, t, DefaultSessionPoolConfig, singerDBStatements) + client, _, cleanup := prepareIntegrationTest(ctx, t, DefaultSessionPoolConfig, statements[testDialect][singerDDLStatements]) defer cleanup() _, err := client.ReadWriteTransaction(ctx, func(ctx context.Context, tx *ReadWriteTransaction) error { @@ -1467,6 +1617,7 @@ func TestIntegration_NestedTransaction(t *testing.T) { func TestIntegration_CreateDBRetry(t *testing.T) { t.Parallel() + skipUnsupportedPGTest(t) if databaseAdmin == nil { t.Skip("Integration tests skipped") @@ -1495,6 +1646,7 @@ func TestIntegration_CreateDBRetry(t *testing.T) { op, err := dbAdmin.CreateDatabaseWithRetry(ctx, &adminpb.CreateDatabaseRequest{ Parent: fmt.Sprintf("projects/%v/instances/%v", testProjectID, testInstanceID), CreateStatement: "CREATE DATABASE " + dbName, + DatabaseDialect: testDialect, }) if err != nil { t.Fatalf("failed to create database: %v", err) @@ -1516,7 +1668,7 @@ func TestIntegration_DbRemovalRecovery(t *testing.T) { defer cancel() // Create a client with MinOpened=0 to prevent the session pool maintainer // from repeatedly trying to create sessions for the invalid database. - client, dbPath, cleanup := prepareIntegrationTest(ctx, t, SessionPoolConfig{}, singerDBStatements) + client, dbPath, cleanup := prepareIntegrationTest(ctx, t, SessionPoolConfig{}, statements[testDialect][singerDDLStatements]) defer cleanup() // Drop the testing database. @@ -1534,7 +1686,7 @@ func TestIntegration_DbRemovalRecovery(t *testing.T) { // Recreate database and table. dbName := dbPath[strings.LastIndex(dbPath, "/")+1:] - op, err := databaseAdmin.CreateDatabaseWithRetry(ctx, &adminpb.CreateDatabaseRequest{ + req := &adminpb.CreateDatabaseRequest{ Parent: fmt.Sprintf("projects/%v/instances/%v", testProjectID, testInstanceID), CreateStatement: "CREATE DATABASE " + dbName, ExtraStatements: []string{ @@ -1545,13 +1697,37 @@ func TestIntegration_DbRemovalRecovery(t *testing.T) { SingerInfo BYTES(MAX) ) PRIMARY KEY (SingerId)`, }, - }) + DatabaseDialect: testDialect, + } + if testDialect == adminpb.DatabaseDialect_POSTGRESQL { + req.ExtraStatements = []string{} + } + op, err := databaseAdmin.CreateDatabaseWithRetry(ctx, req) if err != nil { t.Fatalf("cannot recreate testing DB %v: %v", dbPath, err) } if _, err := op.Wait(ctx); err != nil { t.Fatalf("cannot recreate testing DB %v: %v", dbPath, err) } + if testDialect == adminpb.DatabaseDialect_POSTGRESQL { + op, err := databaseAdmin.UpdateDatabaseDdl(ctx, &adminpb.UpdateDatabaseDdlRequest{ + Database: dbPath, + Statements: []string{` + CREATE TABLE Singers ( + SingerId INT8 NOT NULL, + FirstName VARCHAR(1024), + LastName VARCHAR(1024), + SingerInfo BYTEA, + PRIMARY KEY(SingerId) + )`}, + }) + if err != nil { + t.Fatalf("cannot create singers table %v: %v", dbPath, err) + } + if err := op.Wait(ctx); err != nil { + t.Fatalf("timeout creating singers table %v: %v", dbPath, err) + } + } // Now, send the query again. iter = client.Single().Query(ctx, Statement{SQL: "SELECT SingerId FROM Singers"}) @@ -1566,6 +1742,7 @@ func TestIntegration_DbRemovalRecovery(t *testing.T) { // Test encoding/decoding non-struct Cloud Spanner types. func TestIntegration_BasicTypes(t *testing.T) { t.Parallel() + skipUnsupportedPGTest(t) ctx, cancel := context.WithTimeout(context.Background(), 5*time.Minute) defer cancel() @@ -1840,6 +2017,7 @@ func TestIntegration_BasicTypes(t *testing.T) { // Test decoding Cloud Spanner STRUCT type. func TestIntegration_StructTypes(t *testing.T) { t.Parallel() + skipUnsupportedPGTest(t) ctx, cancel := context.WithTimeout(context.Background(), 5*time.Minute) defer cancel() @@ -1928,6 +2106,7 @@ func TestIntegration_StructTypes(t *testing.T) { func TestIntegration_StructParametersUnsupported(t *testing.T) { skipEmulatorTest(t) t.Parallel() + skipUnsupportedPGTest(t) ctx, cancel := context.WithTimeout(context.Background(), 5*time.Minute) defer cancel() @@ -1973,6 +2152,7 @@ func TestIntegration_StructParametersUnsupported(t *testing.T) { // Test queries of the form "SELECT expr". func TestIntegration_QueryExpressions(t *testing.T) { t.Parallel() + skipUnsupportedPGTest(t) ctx, cancel := context.WithTimeout(context.Background(), 5*time.Minute) defer cancel() @@ -2033,7 +2213,7 @@ func TestIntegration_QueryStats(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), 5*time.Minute) defer cancel() - client, _, cleanup := prepareIntegrationTest(ctx, t, DefaultSessionPoolConfig, singerDBStatements) + client, _, cleanup := prepareIntegrationTest(ctx, t, DefaultSessionPoolConfig, statements[testDialect][singerDDLStatements]) defer cleanup() accounts := []*Mutation{ @@ -2096,7 +2276,7 @@ func TestIntegration_ReadErrors(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), 5*time.Minute) defer cancel() - client, _, cleanup := prepareIntegrationTest(ctx, t, DefaultSessionPoolConfig, readDBStatements) + client, _, cleanup := prepareIntegrationTest(ctx, t, DefaultSessionPoolConfig, statements[testDialect][readDDLStatements]) defer cleanup() var ms []*Mutation @@ -2124,7 +2304,11 @@ func TestIntegration_ReadErrors(t *testing.T) { iter := client.Single().Query(ctx, Statement{SQL: "SELECT Apples AND Oranges"}) defer iter.Stop() _, err = iter.Next() - if msg, ok := matchError(err, codes.InvalidArgument, "unrecognized name"); !ok { + errorMessage := "unrecognized name" + if testDialect == adminpb.DatabaseDialect_POSTGRESQL { + errorMessage = "does not exist" + } + if msg, ok := matchError(err, codes.InvalidArgument, errorMessage); !ok { t.Error(msg) } @@ -2159,7 +2343,7 @@ func TestIntegration_TransactionRunner(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), 5*time.Minute) defer cancel() - client, _, cleanup := prepareIntegrationTest(ctx, t, DefaultSessionPoolConfig, singerDBStatements) + client, _, cleanup := prepareIntegrationTest(ctx, t, DefaultSessionPoolConfig, statements[testDialect][singerDDLStatements]) defer cleanup() // Test 1: User error should abort the transaction. @@ -2305,7 +2489,7 @@ func TestIntegration_BatchQuery(t *testing.T) { ) ctx, cancel := context.WithTimeout(context.Background(), 5*time.Minute) defer cancel() - client, dbPath, cleanup := prepareIntegrationTest(ctx, t, DefaultSessionPoolConfig, simpleDBStatements) + client, dbPath, cleanup := prepareIntegrationTest(ctx, t, DefaultSessionPoolConfig, statements[testDialect][simpleDDLStatements]) defer cleanup() if err = populate(ctx, client); err != nil { @@ -2383,7 +2567,8 @@ func TestIntegration_BatchQuery(t *testing.T) { // Test PartitionRead of BatchReadOnlyTransaction, similar to TestBatchQuery func TestIntegration_BatchRead(t *testing.T) { t.Parallel() - + // skipping PG test because of rpc error: code = InvalidArgument desc = [ERROR] syntax error at or near "." in PartitionRead + skipUnsupportedPGTest(t) // Set up testing environment. var ( client2 *Client @@ -2391,7 +2576,7 @@ func TestIntegration_BatchRead(t *testing.T) { ) ctx, cancel := context.WithTimeout(context.Background(), 5*time.Minute) defer cancel() - client, dbPath, cleanup := prepareIntegrationTest(ctx, t, DefaultSessionPoolConfig, simpleDBStatements) + client, dbPath, cleanup := prepareIntegrationTest(ctx, t, DefaultSessionPoolConfig, statements[testDialect][simpleDDLStatements]) defer cleanup() if err = populate(ctx, client); err != nil { @@ -2468,7 +2653,8 @@ func TestIntegration_BatchRead(t *testing.T) { // Test normal txReadEnv method on BatchReadOnlyTransaction. func TestIntegration_BROTNormal(t *testing.T) { t.Parallel() - + // skipping PG test because of rpc error: code = InvalidArgument desc = [ERROR] syntax error at or near "." in PartitionRead + skipUnsupportedPGTest(t) // Set up testing environment and create txn. var ( txn *BatchReadOnlyTransaction @@ -2478,7 +2664,7 @@ func TestIntegration_BROTNormal(t *testing.T) { ) ctx, cancel := context.WithTimeout(context.Background(), 5*time.Minute) defer cancel() - client, _, cleanup := prepareIntegrationTest(ctx, t, DefaultSessionPoolConfig, simpleDBStatements) + client, _, cleanup := prepareIntegrationTest(ctx, t, DefaultSessionPoolConfig, statements[testDialect][simpleDDLStatements]) defer cleanup() if txn, err = client.BatchReadOnlyTransaction(ctx, StrongRead()); err != nil { @@ -2504,6 +2690,7 @@ func TestIntegration_BROTNormal(t *testing.T) { func TestIntegration_CommitTimestamp(t *testing.T) { t.Parallel() + skipUnsupportedPGTest(t) ctx, cancel := context.WithTimeout(context.Background(), 5*time.Minute) defer cancel() @@ -2577,7 +2764,7 @@ func TestIntegration_DML(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), 5*time.Minute) defer cancel() - client, _, cleanup := prepareIntegrationTest(ctx, t, DefaultSessionPoolConfig, singerDBStatements) + client, _, cleanup := prepareIntegrationTest(ctx, t, DefaultSessionPoolConfig, statements[testDialect][singerDDLStatements]) defer cleanup() // Function that reads a single row's first name from within a transaction. @@ -2616,10 +2803,29 @@ func TestIntegration_DML(t *testing.T) { return got } + singersQuery := []string{`INSERT INTO Singers (SingerId, FirstName, LastName) VALUES (1, "Umm", "Kulthum")`, + `INSERT INTO Singers (SingerId, FirstName, LastName) VALUES (2, "Eduard", "Khil")`, + `INSERT INTO Singers (SingerId, FirstName, LastName) VALUES (3, "Audra", "McDonald")`, + `UPDATE Singers SET FirstName = "Oum" WHERE SingerId = 1`, + `UPDATE Singers SET FirstName = "Eddie" WHERE SingerId = 2`, + `INSERT INTO Singers (SingerId, FirstName, LastName) VALUES (3, "Audra", "McDonald")`, + `SELECT FirstName FROM Singers`, + } + + if testDialect == adminpb.DatabaseDialect_POSTGRESQL { + singersQuery = []string{`INSERT INTO Singers (SingerId, FirstName, LastName) VALUES (1, 'Umm', 'Kulthum')`, + `INSERT INTO Singers (SingerId, FirstName, LastName) VALUES (2, 'Eduard', 'Khil')`, + `INSERT INTO Singers (SingerId, FirstName, LastName) VALUES (3, 'Audra', 'McDonald')`, + `UPDATE Singers SET FirstName = 'Oum' WHERE SingerId = 1`, + `UPDATE Singers SET FirstName = 'Eddie' WHERE SingerId = 2`, + `INSERT INTO Singers (SingerId, FirstName, LastName) VALUES (3, 'Audra', 'McDonald')`, + `SELECT FirstName FROM Singers`, + } + } // Use ReadWriteTransaction.Query to execute a DML statement. _, err := client.ReadWriteTransaction(ctx, func(ctx context.Context, tx *ReadWriteTransaction) error { iter := tx.Query(ctx, Statement{ - SQL: `INSERT INTO Singers (SingerId, FirstName, LastName) VALUES (1, "Umm", "Kulthum")`, + SQL: singersQuery[0], }) defer iter.Stop() row, err := iter.Next() @@ -2649,7 +2855,7 @@ func TestIntegration_DML(t *testing.T) { // Use ReadWriteTransaction.Update to execute a DML statement. _, err = client.ReadWriteTransaction(ctx, func(ctx context.Context, tx *ReadWriteTransaction) error { count, err := tx.Update(ctx, Statement{ - SQL: `Insert INTO Singers (SingerId, FirstName, LastName) VALUES (2, "Eduard", "Khil")`, + SQL: singersQuery[1], }) if err != nil { return err @@ -2675,7 +2881,7 @@ func TestIntegration_DML(t *testing.T) { var fail = errors.New("fail") _, err = client.ReadWriteTransaction(ctx, func(ctx context.Context, tx *ReadWriteTransaction) error { _, err := tx.Update(ctx, Statement{ - SQL: `INSERT INTO Singers (SingerId, FirstName, LastName) VALUES (3, "Audra", "McDonald")`, + SQL: singersQuery[2], }) if err != nil { return err @@ -2692,11 +2898,11 @@ func TestIntegration_DML(t *testing.T) { // Run two DML statements in the same transaction. _, err = client.ReadWriteTransaction(ctx, func(ctx context.Context, tx *ReadWriteTransaction) error { - _, err := tx.Update(ctx, Statement{SQL: `UPDATE Singers SET FirstName = "Oum" WHERE SingerId = 1`}) + _, err := tx.Update(ctx, Statement{SQL: singersQuery[3]}) if err != nil { return err } - _, err = tx.Update(ctx, Statement{SQL: `UPDATE Singers SET FirstName = "Eddie" WHERE SingerId = 2`}) + _, err = tx.Update(ctx, Statement{SQL: singersQuery[4]}) if err != nil { return err } @@ -2714,7 +2920,7 @@ func TestIntegration_DML(t *testing.T) { // Run a DML statement and an ordinary mutation in the same transaction. _, err = client.ReadWriteTransaction(ctx, func(ctx context.Context, tx *ReadWriteTransaction) error { _, err := tx.Update(ctx, Statement{ - SQL: `INSERT INTO Singers (SingerId, FirstName, LastName) VALUES (3, "Audra", "McDonald")`, + SQL: singersQuery[5], }) if err != nil { return err @@ -2735,7 +2941,7 @@ func TestIntegration_DML(t *testing.T) { // Attempt to run a query using update. _, err = client.ReadWriteTransaction(ctx, func(ctx context.Context, tx *ReadWriteTransaction) error { - _, err := tx.Update(ctx, Statement{SQL: `SELECT FirstName from Singers`}) + _, err := tx.Update(ctx, Statement{SQL: singersQuery[6]}) return err }) if got, want := ErrCode(err), codes.InvalidArgument; got != want { @@ -2745,6 +2951,7 @@ func TestIntegration_DML(t *testing.T) { func TestIntegration_StructParametersBind(t *testing.T) { t.Parallel() + skipUnsupportedPGTest(t) ctx, cancel := context.WithTimeout(context.Background(), 5*time.Minute) defer cancel() @@ -2918,7 +3125,7 @@ func TestIntegration_PDML(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), 5*time.Minute) defer cancel() - client, _, cleanup := prepareIntegrationTest(ctx, t, DefaultSessionPoolConfig, singerDBStatements) + client, _, cleanup := prepareIntegrationTest(ctx, t, DefaultSessionPoolConfig, statements[testDialect][singerDDLStatements]) defer cleanup() columns := []string{"SingerId", "FirstName", "LastName"} @@ -2937,12 +3144,16 @@ func TestIntegration_PDML(t *testing.T) { if _, err := client.Apply(ctx, muts); err != nil { t.Fatal(err) } + query := `UPDATE Singers SET Singers.FirstName = "changed" WHERE Singers.SingerId >= 1 AND Singers.SingerId <= @p1` + if testDialect == adminpb.DatabaseDialect_POSTGRESQL { + query = `UPDATE Singers SET FirstName = 'changed' WHERE SingerId >= 1 AND SingerId <= $1` + } // Identifiers in PDML statements must be fully qualified. // TODO(jba): revisit the above. count, err := client.PartitionedUpdate(ctx, Statement{ - SQL: `UPDATE Singers SET Singers.FirstName = "changed" WHERE Singers.SingerId >= 1 AND Singers.SingerId <= @end`, + SQL: query, Params: map[string]interface{}{ - "end": 3, + "p1": 3, }, }) if err != nil { @@ -2972,7 +3183,7 @@ func TestIntegration_BatchDML(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), 5*time.Minute) defer cancel() - client, _, cleanup := prepareIntegrationTest(ctx, t, DefaultSessionPoolConfig, singerDBStatements) + client, _, cleanup := prepareIntegrationTest(ctx, t, DefaultSessionPoolConfig, statements[testDialect][singerDDLStatements]) defer cleanup() columns := []string{"SingerId", "FirstName", "LastName"} @@ -2990,12 +3201,22 @@ func TestIntegration_BatchDML(t *testing.T) { t.Fatal(err) } + singersQuery := []string{`UPDATE Singers SET Singers.FirstName = "changed 1" WHERE Singers.SingerId = 1`, + `UPDATE Singers SET Singers.FirstName = "changed 2" WHERE Singers.SingerId = 2`, + `UPDATE Singers SET Singers.FirstName = "changed 3" WHERE Singers.SingerId = 3`, + } + if testDialect == adminpb.DatabaseDialect_POSTGRESQL { + singersQuery = []string{`UPDATE Singers SET FirstName = 'changed 1' WHERE SingerId = 1`, + `UPDATE Singers SET FirstName = 'changed 2' WHERE SingerId = 2`, + `UPDATE Singers SET FirstName = 'changed 3' WHERE SingerId = 3`, + } + } var counts []int64 _, err := client.ReadWriteTransaction(ctx, func(ctx context.Context, tx *ReadWriteTransaction) (err error) { counts, err = tx.BatchUpdate(ctx, []Statement{ - {SQL: `UPDATE Singers SET Singers.FirstName = "changed 1" WHERE Singers.SingerId = 1`}, - {SQL: `UPDATE Singers SET Singers.FirstName = "changed 2" WHERE Singers.SingerId = 2`}, - {SQL: `UPDATE Singers SET Singers.FirstName = "changed 3" WHERE Singers.SingerId = 3`}, + {SQL: singersQuery[0]}, + {SQL: singersQuery[1]}, + {SQL: singersQuery[2]}, }) return err }) @@ -3025,7 +3246,7 @@ func TestIntegration_BatchDML_NoStatements(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), 5*time.Minute) defer cancel() - client, _, cleanup := prepareIntegrationTest(ctx, t, DefaultSessionPoolConfig, singerDBStatements) + client, _, cleanup := prepareIntegrationTest(ctx, t, DefaultSessionPoolConfig, statements[testDialect][singerDDLStatements]) defer cleanup() _, err := client.ReadWriteTransaction(ctx, func(ctx context.Context, tx *ReadWriteTransaction) (err error) { @@ -3049,7 +3270,7 @@ func TestIntegration_BatchDML_TwoStatements(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), 5*time.Minute) defer cancel() - client, _, cleanup := prepareIntegrationTest(ctx, t, DefaultSessionPoolConfig, singerDBStatements) + client, _, cleanup := prepareIntegrationTest(ctx, t, DefaultSessionPoolConfig, statements[testDialect][singerDDLStatements]) defer cleanup() columns := []string{"SingerId", "FirstName", "LastName"} @@ -3067,19 +3288,30 @@ func TestIntegration_BatchDML_TwoStatements(t *testing.T) { t.Fatal(err) } + singersQuery := []string{`UPDATE Singers SET Singers.FirstName = "changed 1" WHERE Singers.SingerId = 1`, + `UPDATE Singers SET Singers.FirstName = "changed 2" WHERE Singers.SingerId = 2`, + `UPDATE Singers SET Singers.FirstName = "changed 3" WHERE Singers.SingerId = 3`, + `UPDATE Singers SET Singers.FirstName = "changed 1" WHERE Singers.SingerId = 1`} + + if testDialect == adminpb.DatabaseDialect_POSTGRESQL { + singersQuery = []string{`UPDATE Singers SET FirstName = 'changed 1' WHERE SingerId = 1`, + `UPDATE Singers SET FirstName = 'changed 2' WHERE SingerId = 2`, + `UPDATE Singers SET FirstName = 'changed 3' WHERE SingerId = 3`, + `UPDATE Singers SET FirstName = 'changed 1' WHERE SingerId = 1`} + } var updateCount int64 var batchCounts []int64 _, err := client.ReadWriteTransaction(ctx, func(ctx context.Context, tx *ReadWriteTransaction) (err error) { batchCounts, err = tx.BatchUpdate(ctx, []Statement{ - {SQL: `UPDATE Singers SET Singers.FirstName = "changed 1" WHERE Singers.SingerId = 1`}, - {SQL: `UPDATE Singers SET Singers.FirstName = "changed 2" WHERE Singers.SingerId = 2`}, - {SQL: `UPDATE Singers SET Singers.FirstName = "changed 3" WHERE Singers.SingerId = 3`}, + {SQL: singersQuery[0]}, + {SQL: singersQuery[1]}, + {SQL: singersQuery[2]}, }) if err != nil { return err } - updateCount, err = tx.Update(ctx, Statement{SQL: `UPDATE Singers SET Singers.FirstName = "changed 1" WHERE Singers.SingerId = 1`}) + updateCount, err = tx.Update(ctx, Statement{SQL: singersQuery[3]}) return err }) if err != nil { @@ -3098,7 +3330,7 @@ func TestIntegration_BatchDML_Error(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), 5*time.Minute) defer cancel() - client, _, cleanup := prepareIntegrationTest(ctx, t, DefaultSessionPoolConfig, singerDBStatements) + client, _, cleanup := prepareIntegrationTest(ctx, t, DefaultSessionPoolConfig, statements[testDialect][singerDDLStatements]) defer cleanup() columns := []string{"SingerId", "FirstName", "LastName"} @@ -3116,11 +3348,20 @@ func TestIntegration_BatchDML_Error(t *testing.T) { t.Fatal(err) } + singersQuery := []string{`UPDATE Singers SET Singers.FirstName = "changed 1" WHERE Singers.SingerId = 1`, + `some illegal statement`, + `UPDATE Singers SET Singers.FirstName = "changed 3" WHERE Singers.SingerId = 3`} + + if testDialect == adminpb.DatabaseDialect_POSTGRESQL { + singersQuery = []string{`UPDATE Singers SET FirstName = 'changed 1' WHERE SingerId = 1`, + `some illegal statement`, + `UPDATE Singers SET FirstName = 'changed 3' WHERE SingerId = 3`} + } _, err := client.ReadWriteTransaction(ctx, func(ctx context.Context, tx *ReadWriteTransaction) (err error) { counts, err := tx.BatchUpdate(ctx, []Statement{ - {SQL: `UPDATE Singers SET Singers.FirstName = "changed 1" WHERE Singers.SingerId = 1`}, - {SQL: `some illegal statement`}, - {SQL: `UPDATE Singers SET Singers.FirstName = "changed 3" WHERE Singers.SingerId = 3`}, + {SQL: singersQuery[0]}, + {SQL: singersQuery[1]}, + {SQL: singersQuery[2]}, }) if err == nil { t.Fatal("expected err, got nil") @@ -3158,6 +3399,88 @@ func TestIntegration_BatchDML_Error(t *testing.T) { } } +func TestIntegration_PGNumeric(t *testing.T) { + onlyRunForPGTest(t) + skipEmulatorTest(t) + t.Parallel() + + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Minute) + defer cancel() + client, _, cleanup := prepareIntegrationTestForPG(ctx, t, DefaultSessionPoolConfig, singerDBPGStatements) + defer cleanup() + + _, err := client.ReadWriteTransaction(ctx, func(ctx context.Context, tx *ReadWriteTransaction) error { + count, err := tx.Update(ctx, Statement{ + SQL: `INSERT INTO Singers (SingerId, numeric, float8) VALUES ($1, $2, $3)`, + Params: map[string]interface{}{ + "p1": int64(123), + "p2": PGNumeric{"123.456789", true}, + "p3": float64(123.456), + }, + }) + if err != nil { + return err + } + if count != 1 { + t.Errorf("row count: got %d, want 1", count) + } + + count, err = tx.Update(ctx, Statement{ + SQL: `INSERT INTO Singers (SingerId, numeric, float8) VALUES ($1, $2, $3)`, + Params: map[string]interface{}{ + "p1": int64(456), + "p2": PGNumeric{"NaN", true}, + "p3": float64(12345.6), + }, + }) + if err != nil { + return err + } + if count != 1 { + t.Errorf("row count: got %d, want 1", count) + } + return nil + }) + if err != nil { + t.Fatal(err) + } + + iter := client.Single().Query(ctx, Statement{SQL: "SELECT SingerId, numeric, float8 FROM Singers"}) + got, err := readPGSingerTable(iter) + if err != nil { + t.Fatalf("failed to read data: %v", err) + } + want := [][]interface{}{ + {int64(123), PGNumeric{"123.456789", true}, float64(123.456)}, + {int64(456), PGNumeric{"NaN", true}, float64(12345.6)}, + } + if !testEqual(got, want) { + t.Errorf("\ngot %v\nwant%v", got, want) + } +} + +func readPGSingerTable(iter *RowIterator) ([][]interface{}, error) { + defer iter.Stop() + var vals [][]interface{} + for { + row, err := iter.Next() + if err == iterator.Done { + return vals, nil + } + if err != nil { + return nil, err + } + var id int64 + var numeric PGNumeric + var float8 float64 + err = row.Columns(&id, &numeric, &float8) + if err != nil { + return nil, err + } + vals = append(vals, []interface{}{id, numeric, float8}) + } +} + func TestIntegration_StartBackupOperation(t *testing.T) { skipEmulatorTest(t) t.Parallel() @@ -3166,7 +3489,7 @@ func TestIntegration_StartBackupOperation(t *testing.T) { // Backups can be slow, so use 1 hour timeout. ctx, cancel := context.WithTimeout(context.Background(), 1*time.Hour) defer cancel() - _, testDatabaseName, cleanup := prepareIntegrationTest(ctx, t, DefaultSessionPoolConfig, backuDBStatements) + _, testDatabaseName, cleanup := prepareIntegrationTest(ctx, t, DefaultSessionPoolConfig, statements[testDialect][backupDDLStatements]) defer cleanup() backupID := backupIDSpace.New() @@ -3214,7 +3537,7 @@ func TestIntegration_DirectPathFallback(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), 5*time.Minute) defer cancel() // Set up testing environment. - client, _, cleanup := prepareIntegrationTest(ctx, t, DefaultSessionPoolConfig, readDBStatements) + client, _, cleanup := prepareIntegrationTest(ctx, t, DefaultSessionPoolConfig, statements[testDialect][readDDLStatements]) defer cleanup() if !dpConfig.attemptDirectPath { @@ -3255,14 +3578,13 @@ func TestIntegration_DirectPathFallback(t *testing.T) { } func TestIntegration_GFE_Latency(t *testing.T) { - t.Parallel() ctx, cancel := context.WithTimeout(context.Background(), 5*time.Minute) defer cancel() te := testutil.NewTestExporter(GFEHeaderMissingCountView, GFELatencyView) setGFELatencyMetricsFlag(true) - client, _, cleanup := prepareIntegrationTest(ctx, t, DefaultSessionPoolConfig, singerDBStatements) + client, _, cleanup := prepareIntegrationTest(ctx, t, DefaultSessionPoolConfig, statements[testDialect][singerDDLStatements]) defer cleanup() singerColumns := []string{"SingerId", "FirstName", "LastName"} @@ -3335,6 +3657,14 @@ func TestIntegration_GFE_Latency(t *testing.T) { // Prepare initializes Cloud Spanner testing DB and clients. func prepareIntegrationTest(ctx context.Context, t *testing.T, spc SessionPoolConfig, statements []string) (*Client, string, func()) { + return prepareDBAndClient(ctx, t, spc, statements, testDialect) +} + +func prepareIntegrationTestForPG(ctx context.Context, t *testing.T, spc SessionPoolConfig, statements []string) (*Client, string, func()) { + return prepareDBAndClient(ctx, t, spc, statements, adminpb.DatabaseDialect_POSTGRESQL) +} + +func prepareDBAndClient(ctx context.Context, t *testing.T, spc SessionPoolConfig, statements []string, dbDialect adminpb.DatabaseDialect) (*Client, string, func()) { if databaseAdmin == nil { t.Skip("Integration tests skipped") } @@ -3343,17 +3673,34 @@ func prepareIntegrationTest(ctx context.Context, t *testing.T, spc SessionPoolCo dbPath := fmt.Sprintf("projects/%v/instances/%v/databases/%v", testProjectID, testInstanceID, dbName) // Create database and tables. - op, err := databaseAdmin.CreateDatabaseWithRetry(ctx, &adminpb.CreateDatabaseRequest{ + req := &adminpb.CreateDatabaseRequest{ Parent: fmt.Sprintf("projects/%v/instances/%v", testProjectID, testInstanceID), CreateStatement: "CREATE DATABASE " + dbName, ExtraStatements: statements, - }) + DatabaseDialect: dbDialect, + } + if dbDialect == adminpb.DatabaseDialect_POSTGRESQL { + req.ExtraStatements = []string{} + } + op, err := databaseAdmin.CreateDatabaseWithRetry(ctx, req) if err != nil { t.Fatalf("cannot create testing DB %v: %v", dbPath, err) } if _, err := op.Wait(ctx); err != nil { t.Fatalf("cannot create testing DB %v: %v", dbPath, err) } + if dbDialect == adminpb.DatabaseDialect_POSTGRESQL && len(statements) > 0 { + op, err := databaseAdmin.UpdateDatabaseDdl(ctx, &adminpb.UpdateDatabaseDdlRequest{ + Database: dbPath, + Statements: statements, + }) + if err != nil { + t.Fatalf("cannot create testing table %v: %v", dbPath, err) + } + if err := op.Wait(ctx); err != nil { + t.Fatalf("timeout creating testing table %v: %v", dbPath, err) + } + } client, err := createClient(ctx, dbPath, spc) if err != nil { t.Fatalf("cannot create data client on DB %v: %v", dbPath, err) @@ -3654,6 +4001,24 @@ func skipEmulatorTest(t *testing.T) { } } +func skipEmulatorTestForPG(t *testing.T) { + if isEmulatorEnvSet() && testDialect == adminpb.DatabaseDialect_POSTGRESQL { + t.Skip("Skipping PG testing against the emulator.") + } +} + +func skipUnsupportedPGTest(t *testing.T) { + if testDialect == adminpb.DatabaseDialect_POSTGRESQL { + t.Skip("Skipping testing of unsupported tests in Postgres dialect.") + } +} + +func onlyRunForPGTest(t *testing.T) { + if testDialect != adminpb.DatabaseDialect_POSTGRESQL { + t.Skip("Skipping tests supported only in Postgres dialect.") + } +} + func verifyDirectPathRemoteAddress(t *testing.T) { t.Helper() if !dpConfig.attemptDirectPath { diff --git a/spanner/oc_test.go b/spanner/oc_test.go index a7ea7ffe2e13..711458b85cc7 100644 --- a/spanner/oc_test.go +++ b/spanner/oc_test.go @@ -50,6 +50,7 @@ func TestOCStats(t *testing.T) { } func TestOCStats_SessionPool(t *testing.T) { + DisableGfeLatencyAndHeaderMissingCountViews() for _, test := range []struct { name string view *view.View @@ -144,6 +145,7 @@ func testSimpleMetric(t *testing.T, v *view.View, measure, value string) { } func TestOCStats_SessionPool_SessionsCount(t *testing.T) { + DisableGfeLatencyAndHeaderMissingCountViews() te := testutil.NewTestExporter(SessionsCountView) defer te.Unregister() @@ -216,6 +218,7 @@ func TestOCStats_SessionPool_SessionsCount(t *testing.T) { } func TestOCStats_SessionPool_GetSessionTimeoutsCount(t *testing.T) { + DisableGfeLatencyAndHeaderMissingCountViews() te := testutil.NewTestExporter(GetSessionTimeoutsCountView) defer te.Unregister() diff --git a/spanner/protoutils.go b/spanner/protoutils.go index 66eea4b5c294..ebe03c1b0cf7 100644 --- a/spanner/protoutils.go +++ b/spanner/protoutils.go @@ -73,6 +73,10 @@ func numericType() *sppb.Type { return &sppb.Type{Code: sppb.TypeCode_NUMERIC} } +func pgNumericType() *sppb.Type { + return &sppb.Type{Code: sppb.TypeCode_NUMERIC, TypeAnnotation: sppb.TypeAnnotationCode_PG_NUMERIC} +} + func jsonType() *sppb.Type { return &sppb.Type{Code: sppb.TypeCode_JSON} } diff --git a/spanner/value.go b/spanner/value.go index 6ed5cb09c6c8..a9f82ae81c79 100644 --- a/spanner/value.go +++ b/spanner/value.go @@ -830,6 +830,52 @@ func (n NullJSON) GormDataType() string { return "JSON" } +// PGNumeric represents a Cloud Spanner PG Numeric that may be NULL. +type PGNumeric struct { + Numeric string // Numeric contains the value when it is non-NULL, and an empty string when NULL. + Valid bool // Valid is true if Numeric is not NULL. +} + +// IsNull implements NullableValue.IsNull for PGNumeric. +func (n PGNumeric) IsNull() bool { + return !n.Valid +} + +// String implements Stringer.String for PGNumeric +func (n PGNumeric) String() string { + if !n.Valid { + return nullString + } + return n.Numeric +} + +// MarshalJSON implements json.Marshaler.MarshalJSON for PGNumeric. +func (n PGNumeric) MarshalJSON() ([]byte, error) { + if n.Valid { + return []byte(fmt.Sprintf("%q", n.Numeric)), nil + } + return jsonNullBytes, nil +} + +// UnmarshalJSON implements json.Unmarshaler.UnmarshalJSON for PGNumeric. +func (n *PGNumeric) UnmarshalJSON(payload []byte) error { + if payload == nil { + return fmt.Errorf("payload should not be nil") + } + if bytes.Equal(payload, jsonNullBytes) { + n.Numeric = "" + n.Valid = false + return nil + } + payload, err := trimDoubleQuotes(payload) + if err != nil { + return err + } + n.Numeric = string(payload) + n.Valid = true + return nil +} + // NullRow represents a Cloud Spanner STRUCT that may be NULL. // See also the document for Row. // Note that NullRow is not a valid Cloud Spanner column Type. @@ -948,12 +994,15 @@ func decodeValue(v *proto3.Value, t *sppb.Type, ptr interface{}) error { return errNilSpannerType() } code := t.Code + typeAnnotation := t.TypeAnnotation acode := sppb.TypeCode_TYPE_CODE_UNSPECIFIED + atypeAnnotation := sppb.TypeAnnotationCode_TYPE_ANNOTATION_CODE_UNSPECIFIED if code == sppb.TypeCode_ARRAY { if t.ArrayElementType == nil { return errNilArrElemType(t) } acode = t.ArrayElementType.Code + atypeAnnotation = t.ArrayElementType.TypeAnnotation } _, isNull := v.Kind.(*proto3.Value_NullValue) @@ -1557,6 +1606,38 @@ func decodeValue(v *proto3.Value, t *sppb.Type, ptr interface{}) error { return err } *p = y + case *PGNumeric: + if p == nil { + return errNilDst(p) + } + if code != sppb.TypeCode_NUMERIC || typeAnnotation != sppb.TypeAnnotationCode_PG_NUMERIC { + return errTypeMismatch(code, acode, ptr) + } + if isNull { + *p = PGNumeric{} + break + } + *p = PGNumeric{v.GetStringValue(), true} + case *[]PGNumeric: + if p == nil { + return errNilDst(p) + } + if acode != sppb.TypeCode_NUMERIC || atypeAnnotation != sppb.TypeAnnotationCode_PG_NUMERIC { + return errTypeMismatch(code, acode, ptr) + } + if isNull { + *p = nil + break + } + x, err := getListValue(v) + if err != nil { + return err + } + y, err := decodePGNumericArray(x) + if err != nil { + return err + } + *p = y case *time.Time: var nt NullTime if isNull { @@ -1780,7 +1861,7 @@ func decodeValue(v *proto3.Value, t *sppb.Type, ptr interface{}) error { if isNull && !decodableType.supportsNull() { return errDstNotForNull(ptr) } - return decodableType.decodeValueToCustomType(v, t, acode, ptr) + return decodableType.decodeValueToCustomType(v, t, acode, atypeAnnotation, ptr) } // Check if the proto encoding is for an array of structs. @@ -1840,6 +1921,7 @@ const ( spannerTypeNullDate spannerTypeNullNumeric spannerTypeNullJSON + spannerTypePGNumeric spannerTypeArrayOfNonNullString spannerTypeArrayOfByteArray spannerTypeArrayOfNonNullInt64 @@ -1856,6 +1938,7 @@ const ( spannerTypeArrayOfNullJSON spannerTypeArrayOfNullTime spannerTypeArrayOfNullDate + spannerTypeArrayOfPGNumeric ) // supportsNull returns true for the Go types that can hold a null value from @@ -1887,6 +1970,7 @@ var typeOfNullTime = reflect.TypeOf(NullTime{}) var typeOfNullDate = reflect.TypeOf(NullDate{}) var typeOfNullNumeric = reflect.TypeOf(NullNumeric{}) var typeOfNullJSON = reflect.TypeOf(NullJSON{}) +var typeOfPGNumeric = reflect.TypeOf(PGNumeric{}) // getDecodableSpannerType returns the corresponding decodableSpannerType of // the given pointer. @@ -1956,6 +2040,9 @@ func getDecodableSpannerType(ptr interface{}, isPtr bool) decodableSpannerType { if t.ConvertibleTo(typeOfNullJSON) { return spannerTypeNullJSON } + if t.ConvertibleTo(typeOfPGNumeric) { + return spannerTypePGNumeric + } case reflect.Slice: kind := val.Type().Elem().Kind() switch kind { @@ -2011,6 +2098,9 @@ func getDecodableSpannerType(ptr interface{}, isPtr bool) decodableSpannerType { if t.ConvertibleTo(typeOfNullJSON) { return spannerTypeArrayOfNullJSON } + if t.ConvertibleTo(typeOfPGNumeric) { + return spannerTypeArrayOfPGNumeric + } case reflect.Slice: // The only array-of-array type that is supported is [][]byte. kind := val.Type().Elem().Elem().Kind() @@ -2027,8 +2117,9 @@ func getDecodableSpannerType(ptr interface{}, isPtr bool) decodableSpannerType { // decodeValueToCustomType decodes a protobuf Value into a pointer to a Go // value. It must be possible to convert the value to the type pointed to by // the pointer. -func (dsc decodableSpannerType) decodeValueToCustomType(v *proto3.Value, t *sppb.Type, acode sppb.TypeCode, ptr interface{}) error { +func (dsc decodableSpannerType) decodeValueToCustomType(v *proto3.Value, t *sppb.Type, acode sppb.TypeCode, atypeAnnotation sppb.TypeAnnotationCode, ptr interface{}) error { code := t.Code + typeAnnotation := t.TypeAnnotation _, isNull := v.Kind.(*proto3.Value_NullValue) if dsc == spannerTypeInvalid { return errNilDst(ptr) @@ -2146,6 +2237,15 @@ func (dsc decodableSpannerType) decodeValueToCustomType(v *proto3.Value, t *sppb } else { result = &NullNumeric{*y, true} } + case spannerTypePGNumeric: + if code != sppb.TypeCode_NUMERIC || typeAnnotation != sppb.TypeAnnotationCode_PG_NUMERIC { + return errTypeMismatch(code, acode, ptr) + } + if isNull { + result = &PGNumeric{} + break + } + result = &PGNumeric{v.GetStringValue(), true} case spannerTypeNullJSON: if code != sppb.TypeCode_JSON { return errTypeMismatch(code, acode, ptr) @@ -2295,6 +2395,23 @@ func (dsc decodableSpannerType) decodeValueToCustomType(v *proto3.Value, t *sppb return err } result = y + case spannerTypeArrayOfPGNumeric: + if acode != sppb.TypeCode_NUMERIC || atypeAnnotation != sppb.TypeAnnotationCode_PG_NUMERIC { + return errTypeMismatch(code, acode, ptr) + } + if isNull { + ptr = nil + return nil + } + x, err := getListValue(v) + if err != nil { + return err + } + y, err := decodeGenericArray(reflect.TypeOf(ptr).Elem(), x, pgNumericType(), "PGNUMERIC") + if err != nil { + return err + } + result = y case spannerTypeArrayOfNullJSON: if acode != sppb.TypeCode_JSON { return errTypeMismatch(code, acode, ptr) @@ -2739,6 +2856,20 @@ func decodeNumericArray(pb *proto3.ListValue) ([]big.Rat, error) { return a, nil } +// decodePGNumericArray decodes proto3.ListValue pb into a PGNumeric slice. +func decodePGNumericArray(pb *proto3.ListValue) ([]PGNumeric, error) { + if pb == nil { + return nil, errNilListValue("PGNUMERIC") + } + a := make([]PGNumeric, len(pb.Values)) + for i, v := range pb.Values { + if err := decodeValue(v, pgNumericType(), &a[i]); err != nil { + return nil, errDecodeArrayElement(i, v, "PGNUMERIC", err) + } + } + return a, nil +} + // decodeByteArray decodes proto3.ListValue pb into a slice of byte slice. func decodeByteArray(pb *proto3.ListValue) ([][]byte, error) { if pb == nil { @@ -3273,6 +3404,19 @@ func encodeValue(v interface{}) (*proto3.Value, *sppb.Type, error) { } } pt = listType(numericType()) + case PGNumeric: + if v.Valid { + pb.Kind = stringKind(v.Numeric) + } + return pb, pgNumericType(), nil + case []PGNumeric: + if v != nil { + pb, err = encodeArray(len(v), func(i int) interface{} { return v[i] }) + if err != nil { + return nil, nil, err + } + } + pt = listType(pgNumericType()) case NullJSON: if v.Valid { b, err := json.Marshal(v.Value) @@ -3481,6 +3625,8 @@ func convertCustomTypeValue(sourceType decodableSpannerType, v interface{}) (int destination = reflect.Indirect(reflect.New(reflect.TypeOf(NullNumeric{}))) case spannerTypeNullJSON: destination = reflect.Indirect(reflect.New(reflect.TypeOf(NullJSON{}))) + case spannerTypePGNumeric: + destination = reflect.Indirect(reflect.New(reflect.TypeOf(PGNumeric{}))) case spannerTypeArrayOfNonNullString: if reflect.ValueOf(v).IsNil() { return []string(nil), nil @@ -3561,6 +3707,11 @@ func convertCustomTypeValue(sourceType decodableSpannerType, v interface{}) (int return []NullJSON(nil), nil } destination = reflect.MakeSlice(reflect.TypeOf([]NullJSON{}), reflect.ValueOf(v).Len(), reflect.ValueOf(v).Cap()) + case spannerTypeArrayOfPGNumeric: + if reflect.ValueOf(v).IsNil() { + return []PGNumeric(nil), nil + } + destination = reflect.MakeSlice(reflect.TypeOf([]PGNumeric{}), reflect.ValueOf(v).Len(), reflect.ValueOf(v).Cap()) default: // This should not be possible. return nil, fmt.Errorf("unknown decodable type found: %v", sourceType) diff --git a/spanner/value_test.go b/spanner/value_test.go index e90d3bc4d183..9b43aa1b456c 100644 --- a/spanner/value_test.go +++ b/spanner/value_test.go @@ -204,6 +204,7 @@ func TestEncodeValue(t *testing.T) { type CustomTime time.Time type CustomDate civil.Date type CustomNumeric big.Rat + type CustomPGNumeric PGNumeric type CustomNullString NullString type CustomNullInt64 NullInt64 @@ -247,15 +248,16 @@ func TestEncodeValue(t *testing.T) { minNumValuePtr, _ := (&big.Rat{}).SetString("-99999999999999999999999999999.999999999") var ( - tString = stringType() - tInt = intType() - tBool = boolType() - tFloat = floatType() - tBytes = bytesType() - tTime = timeType() - tDate = dateType() - tNumeric = numericType() - tJSON = jsonType() + tString = stringType() + tInt = intType() + tBool = boolType() + tFloat = floatType() + tBytes = bytesType() + tTime = timeType() + tDate = dateType() + tNumeric = numericType() + tJSON = jsonType() + tPGNumeric = pgNumericType() ) for i, test := range []struct { in interface{} @@ -331,6 +333,11 @@ func TestEncodeValue(t *testing.T) { {[]NullJSON{{msg, true}, {msg, false}}, listProto(stringProto(jsonStr), nullProto()), listType(tJSON), "[]NullJSON"}, {NullJSON{[]Message{}, true}, stringProto(emptyArrayJSONStr), tJSON, "a json string with empty array to NullJSON"}, {NullJSON{ptrMsg, true}, stringProto(nullValueJSONStr), tJSON, "a json string with null value to NullJSON"}, + // PG NUMERIC + {PGNumeric{"123.456", true}, stringProto("123.456"), tPGNumeric, "PG Numeric"}, + {PGNumeric{Valid: false}, nullProto(), tPGNumeric, "PG Numeric with a null value"}, + {[]PGNumeric(nil), nullProto(), listType(tPGNumeric), "null []PGNumeric"}, + {[]PGNumeric{{"123.456", true}, {Valid: false}}, listProto(stringProto("123.456"), nullProto()), listType(tPGNumeric), "[]PGNumeric"}, // TIMESTAMP / TIMESTAMP ARRAY {t1, timeProto(t1), tTime, "time"}, {NullTime{t1, true}, timeProto(t1), tTime, "NullTime with value"}, @@ -447,6 +454,11 @@ func TestEncodeValue(t *testing.T) { {CustomNullJSON{msg, false}, nullProto(), tJSON, "CustomNullJSON with null"}, {[]CustomNullJSON(nil), nullProto(), listType(tJSON), "null []CustomNullJSON"}, {[]CustomNullJSON{{msg, true}, {msg, false}}, listProto(stringProto(jsonStr), nullProto()), listType(tJSON), "[]CustomNullJSON"}, + // CUSTOM PG NUMERIC + {CustomPGNumeric{"123.456", true}, stringProto("123.456"), tPGNumeric, "PG Numeric"}, + {CustomPGNumeric{Valid: false}, nullProto(), tPGNumeric, "PG Numeric with a null value"}, + {[]CustomPGNumeric(nil), nullProto(), listType(tPGNumeric), "null []PGNumeric"}, + {[]CustomPGNumeric{{"123.456", true}, {Valid: false}}, listProto(stringProto("123.456"), nullProto()), listType(tPGNumeric), "[]PGNumeric"}, } { got, gotType, err := encodeValue(test.in) if err != nil { @@ -469,6 +481,11 @@ func TestEncodeInvalidValues(t *testing.T) { invalidNumPtr2, _ := (&big.Rat{}).SetString("199999999999999999999999999999.999999999") // Enable error mode. + oldValue := LossOfPrecisionHandling + defer func() { + // Reset the value to pre-test value + LossOfPrecisionHandling = oldValue + }() LossOfPrecisionHandling = NumericError for i, test := range []struct { @@ -1330,6 +1347,7 @@ func TestDecodeValue(t *testing.T) { type CustomNullDate NullDate type CustomNullNumeric NullNumeric type CustomNullJSON NullJSON + type CustomPGNumeric PGNumeric jsonStr := `{"Name":"Alice","Body":"Hello","Time":1294706395881547000}` var unmarshalledJSONStruct interface{} @@ -1472,6 +1490,13 @@ func TestDecodeValue(t *testing.T) { {desc: "decode ARRAY to []NullJSON", proto: listProto(stringProto(jsonStr), stringProto(jsonStr), nullProto()), protoType: listType(jsonType()), want: []NullJSON{{unmarshalledJSONStruct, true}, {unmarshalledJSONStruct, true}, {}}}, {desc: "decode ARRAY to NullJSON", proto: listProto(stringProto(jsonStr), nullProto(), stringProto("true")), protoType: listType(jsonType()), want: NullJSON{unmarshalledJSONArray, true}}, {desc: "decode NULL to []NullJSON", proto: nullProto(), protoType: listType(jsonType()), want: []NullJSON(nil)}, + // PG NUMERIC + {desc: "decode PG NUMERIC to PGNumeric", proto: stringProto("123.456"), protoType: pgNumericType(), want: PGNumeric{"123.456", true}}, + {desc: "decode NaN to PGNumeric", proto: stringProto("NaN"), protoType: pgNumericType(), want: PGNumeric{"NaN", true}}, + {desc: "decode NULL to PGNumeric", proto: nullProto(), protoType: pgNumericType(), want: PGNumeric{}}, + // PG NUMERIC ARRAY with []PGNumeric + {desc: "decode ARRAY to []PGNumeric", proto: listProto(stringProto("123.456"), stringProto("NaN"), nullProto()), protoType: listType(pgNumericType()), want: []PGNumeric{{"123.456", true}, {"NaN", true}, {}}}, + {desc: "decode NULL to []PGNumeric", proto: nullProto(), protoType: listType(pgNumericType()), want: []PGNumeric(nil)}, // TIMESTAMP {desc: "decode TIMESTAMP to time.Time", proto: timeProto(t1), protoType: timeType(), want: t1}, {desc: "decode TIMESTAMP to NullTime", proto: timeProto(t1), protoType: timeType(), want: NullTime{t1, true}}, @@ -1685,6 +1710,7 @@ func TestDecodeValue(t *testing.T) { {desc: "decode JSON to CustomNullJSON", proto: stringProto(jsonStr), protoType: jsonType(), want: CustomNullJSON{unmarshalledJSONStruct, true}}, {desc: "decode TIMESTAMP to CustomNullTime", proto: timeProto(t1), protoType: timeType(), want: CustomNullTime{t1, true}}, {desc: "decode DATE to CustomNullDate", proto: dateProto(d1), protoType: dateType(), want: CustomNullDate{d1, true}}, + {desc: "decode PG NUMERIC to CustomPGNumeric", proto: stringProto("123.456"), protoType: pgNumericType(), want: CustomPGNumeric{"123.456", true}}, {desc: "decode NULL to CustomNullString", proto: nullProto(), protoType: stringType(), want: CustomNullString{}}, {desc: "decode NULL to CustomNullInt64", proto: nullProto(), protoType: intType(), want: CustomNullInt64{}}, @@ -1694,6 +1720,7 @@ func TestDecodeValue(t *testing.T) { {desc: "decode NULL to CustomNullJSON", proto: nullProto(), protoType: jsonType(), want: CustomNullJSON{}}, {desc: "decode NULL to CustomNullTime", proto: nullProto(), protoType: timeType(), want: CustomNullTime{}}, {desc: "decode NULL to CustomNullDate", proto: nullProto(), protoType: dateType(), want: CustomNullDate{}}, + {desc: "decode NULL to CustomPGNumeric", proto: nullProto(), protoType: pgNumericType(), want: CustomPGNumeric{}}, // STRING ARRAY {desc: "decode NULL to []CustomString", proto: nullProto(), protoType: listType(stringType()), want: []CustomString(nil)}, @@ -1731,6 +1758,9 @@ func TestDecodeValue(t *testing.T) { // JSON ARRAY {desc: "decode NULL to []CustomNullJSON", proto: nullProto(), protoType: listType(jsonType()), want: []CustomNullJSON(nil)}, {desc: "decode ARRAY to []CustomNullJSON", proto: listProto(stringProto(jsonStr), stringProto(jsonStr), nullProto()), protoType: listType(jsonType()), want: []CustomNullJSON{{unmarshalledJSONStruct, true}, {unmarshalledJSONStruct, true}, {}}}, + // PG NUMERIC ARRAY + {desc: "decode NULL to []CustomPGNumeric", proto: nullProto(), protoType: listType(pgNumericType()), want: []CustomPGNumeric(nil)}, + {desc: "decode ARRAY to []CustomPGNumeric", proto: listProto(stringProto("123.456"), nullProto(), stringProto("1.23456")), protoType: listType(pgNumericType()), want: []CustomPGNumeric{{"123.456", true}, {}, {"1.23456", true}}}, // TIME ARRAY {desc: "decode NULL to []CustomTime", proto: nullProto(), protoType: listType(timeType()), want: []CustomTime(nil)}, {desc: "decode ARRAY with NULL values to []CustomTime", proto: listProto(timeProto(t1), nullProto(), timeProto(t2)), protoType: listType(timeType()), want: []CustomTime{}, wantErr: true}, @@ -2612,6 +2642,16 @@ func TestJSONMarshal_NullTypes(t *testing.T) { {input: NullJSON{}, expect: "null"}, }, }, + { + "PGNumeric", + []testcase{ + {input: PGNumeric{"123.456", true}, expect: `"123.456"`}, + {input: PGNumeric{"NaN", true}, expect: `"NaN"`}, + {input: &PGNumeric{"123.456", true}, expect: `"123.456"`}, + {input: &PGNumeric{"123.456", false}, expect: "null"}, + {input: PGNumeric{}, expect: "null"}, + }, + }, } { t.Run(test.name, func(t *testing.T) { for _, tc := range test.cases { @@ -2723,6 +2763,17 @@ func TestJSONUnmarshal_NullTypes(t *testing.T) { {input: []byte(`{invalid_json_string}`), got: NullJSON{}, isNull: true, expect: nullString, expectError: true}, }, }, + { + "PGNumeric", + []testcase{ + {input: []byte(`"123.456"`), got: PGNumeric{}, isNull: false, expect: "123.456", expectError: false}, + {input: []byte(`"NaN"`), got: PGNumeric{}, isNull: false, expect: "NaN", expectError: false}, + {input: []byte("null"), got: PGNumeric{}, isNull: true, expect: nullString, expectError: false}, + {input: nil, got: PGNumeric{}, isNull: true, expect: nullString, expectError: true}, + {input: []byte(""), got: PGNumeric{}, isNull: true, expect: nullString, expectError: true}, + {input: []byte(`"123.456`), got: PGNumeric{}, isNull: true, expect: nullString, expectError: true}, + }, + }, } { t.Run(test.name, func(t *testing.T) { for _, tc := range test.cases { @@ -2751,6 +2802,9 @@ func TestJSONUnmarshal_NullTypes(t *testing.T) { case NullJSON: err := json.Unmarshal(tc.input, &v) expectUnmarshalNullableTypes(t, err, v, tc.isNull, tc.expect, tc.expectError) + case PGNumeric: + err := json.Unmarshal(tc.input, &v) + expectUnmarshalNullableTypes(t, err, v, tc.isNull, tc.expect, tc.expectError) default: t.Fatalf("Unknown type: %T", v) }