From 134b5151084e74d6661dd3cc072d297446d04dd3 Mon Sep 17 00:00:00 2001 From: Lyubo Kamenov Date: Fri, 10 Jan 2025 15:21:19 -0500 Subject: [PATCH] Handle `json` and `jsonb` types as bytes. (#240) * Handle `json` and `jsonb` types as bytes. By default pgx will use encoding/json and return a map of the json. This behaviour does not play well with Avro and can result in each record to emit different schema (since JSON can be anything). Thus any JSON data will be returned as bytes and allow the user to serder into whatever type they desire. --- source/cpool/cpool.go | 20 +++++++++ source/cpool/json.go | 37 +++++++++++++++++ source/cpool/json_test.go | 45 +++++++++++++++++++++ source/logrepl/cdc_test.go | 26 ++++++++---- source/logrepl/combined_test.go | 20 +++++++-- source/logrepl/internal/relationset.go | 6 ++- source/logrepl/internal/relationset_test.go | 12 +++--- source/schema/avro.go | 2 + source/snapshot/fetch_worker_test.go | 15 +++++-- test/helper.go | 24 +++++++---- 10 files changed, 176 insertions(+), 31 deletions(-) create mode 100644 source/cpool/json.go create mode 100644 source/cpool/json_test.go diff --git a/source/cpool/cpool.go b/source/cpool/cpool.go index ed07741..c6c187d 100644 --- a/source/cpool/cpool.go +++ b/source/cpool/cpool.go @@ -16,9 +16,11 @@ package cpool import ( "context" + "encoding/json" "fmt" "github.com/jackc/pgx/v5" + "github.com/jackc/pgx/v5/pgtype" "github.com/jackc/pgx/v5/pgxpool" ) @@ -37,6 +39,7 @@ func New(ctx context.Context, conninfo string) (*pgxpool.Pool, error) { config.BeforeAcquire = beforeAcquireHook config.BeforeConnect = beforeConnectHook + config.AfterConnect = afterConnectHook config.AfterRelease = afterReleaseHook pool, err := pgxpool.NewWithConfig(ctx, config) @@ -47,6 +50,23 @@ func New(ctx context.Context, conninfo string) (*pgxpool.Pool, error) { return pool, nil } +func afterConnectHook(_ context.Context, conn *pgx.Conn) error { + // Override the JSON and JSONB codec to return bytes rather than the + // unmarshalled representation of map. + conn.TypeMap().RegisterType(&pgtype.Type{ + Name: "json", + OID: pgtype.JSONOID, + Codec: &pgtype.JSONCodec{Marshal: json.Marshal, Unmarshal: jsonNoopUnmarshal}, + }) + conn.TypeMap().RegisterType(&pgtype.Type{ + Name: "jsonb", + OID: pgtype.JSONBOID, + Codec: &pgtype.JSONBCodec{Marshal: json.Marshal, Unmarshal: jsonNoopUnmarshal}, + }) + + return nil +} + // beforeAcquireHook ensures purpose specific connections are returned: // * If a replication connection is requested, ensure the connection has replication enabled. // * If a regular connection is requested, return non-replication connections. diff --git a/source/cpool/json.go b/source/cpool/json.go new file mode 100644 index 0000000..5c532ca --- /dev/null +++ b/source/cpool/json.go @@ -0,0 +1,37 @@ +// Copyright © 2024 Meroxa, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package cpool + +import ( + "encoding/json" + "reflect" +) + +// noopUnmarshal will copy source into dst. +// this is to be used with the pgtype JSON codec +func jsonNoopUnmarshal(src []byte, dst any) error { + dstptr, ok := (dst.(*any)) + if dst == nil || !ok { + return &json.InvalidUnmarshalError{Type: reflect.TypeOf(dst)} + } + + v := make([]byte, len(src)) + copy(v, src) + + // set the slice to the value of the ptr. + *dstptr = v + + return nil +} diff --git a/source/cpool/json_test.go b/source/cpool/json_test.go new file mode 100644 index 0000000..29b82ac --- /dev/null +++ b/source/cpool/json_test.go @@ -0,0 +1,45 @@ +// Copyright © 2024 Meroxa, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package cpool + +import ( + "testing" + + "github.com/matryer/is" +) + +func Test_jsonNoopUnmarshal(t *testing.T) { + is := is.New(t) + + var dst any + data := []byte(`{"foo":"bar"}`) + + is.NoErr(jsonNoopUnmarshal(data, &dst)) + is.Equal(data, dst.([]byte)) + + var err error + + err = jsonNoopUnmarshal(data, dst) + is.True(err != nil) + if err != nil { + is.Equal(err.Error(), "json: Unmarshal(non-pointer []uint8)") + } + + err = jsonNoopUnmarshal(data, nil) + is.True(err != nil) + if err != nil { + is.Equal(err.Error(), "json: Unmarshal(nil)") + } +} diff --git a/source/logrepl/cdc_test.go b/source/logrepl/cdc_test.go index d96ec93..7e30c67 100644 --- a/source/logrepl/cdc_test.go +++ b/source/logrepl/cdc_test.go @@ -140,8 +140,8 @@ func TestCDCIterator_Next(t *testing.T) { name: "should detect insert", setup: func(t *testing.T) { is := is.New(t) - query := fmt.Sprintf(`INSERT INTO %s (id, column1, column2, column3, column4, column5) - VALUES (6, 'bizz', 456, false, 12.3, 14)`, table) + query := fmt.Sprintf(`INSERT INTO %s (id, column1, column2, column3, column4, column5, column6, column7) + VALUES (6, 'bizz', 456, false, 12.3, 14, '{"foo2": "bar2"}', '{"foo2": "baz2"}')`, table) _, err := pool.Exec(ctx, query) is.NoErr(err) }, @@ -165,6 +165,8 @@ func TestCDCIterator_Next(t *testing.T) { "column3": false, "column4": 12.3, "column5": int64(14), + "column6": []byte(`{"foo2": "bar2"}`), + "column7": []byte(`{"foo2": "baz2"}`), "key": nil, }, }, @@ -197,6 +199,8 @@ func TestCDCIterator_Next(t *testing.T) { "column3": false, "column4": 12.2, "column5": int64(4), + "column6": []byte(`{"foo": "bar"}`), + "column7": []byte(`{"foo": "baz"}`), "key": []uint8("1"), }, }, @@ -231,6 +235,8 @@ func TestCDCIterator_Next(t *testing.T) { "column3": false, "column4": 12.2, "column5": int64(4), + "column6": []byte(`{"foo": "bar"}`), + "column7": []byte(`{"foo": "baz"}`), "key": []uint8("1"), }, After: opencdc.StructuredData{ @@ -240,6 +246,8 @@ func TestCDCIterator_Next(t *testing.T) { "column3": false, "column4": 12.2, "column5": int64(4), + "column6": []byte(`{"foo": "bar"}`), + "column7": []byte(`{"foo": "baz"}`), "key": []uint8("1"), }, }, @@ -274,6 +282,8 @@ func TestCDCIterator_Next(t *testing.T) { "column3": nil, "column4": nil, "column5": nil, + "column6": nil, + "column7": nil, "key": nil, }, }, @@ -309,6 +319,8 @@ func TestCDCIterator_Next(t *testing.T) { "column3": false, "column4": nil, "column5": int64(9), + "column6": []byte(`{"foo": "bar"}`), + "column7": []byte(`{"foo": "baz"}`), }, }, }, @@ -552,13 +564,13 @@ func TestCDCIterator_Schema(t *testing.T) { t.Run("column added", func(t *testing.T) { is := is.New(t) - _, err := pool.Exec(ctx, fmt.Sprintf(`ALTER TABLE %s ADD COLUMN column6 timestamp;`, table)) + _, err := pool.Exec(ctx, fmt.Sprintf(`ALTER TABLE %s ADD COLUMN column101 timestamp;`, table)) is.NoErr(err) _, err = pool.Exec( ctx, - fmt.Sprintf(`INSERT INTO %s (id, key, column1, column2, column3, column4, column5, column6) - VALUES (7, decode('aabbcc', 'hex'), 'example data 1', 100, true, 12345.678, 12345, '2023-09-09 10:00:00');`, table), + fmt.Sprintf(`INSERT INTO %s (id, key, column1, column2, column3, column4, column5, column6, column7, column101) + VALUES (7, decode('aabbcc', 'hex'), 'example data 1', 100, true, 12345.678, 12345, '{"foo":"bar"}', '{"foo2":"baz2"}', '2023-09-09 10:00:00');`, table), ) is.NoErr(err) @@ -577,8 +589,8 @@ func TestCDCIterator_Schema(t *testing.T) { _, err = pool.Exec( ctx, - fmt.Sprintf(`INSERT INTO %s (id, key, column1, column2, column3, column6) - VALUES (8, decode('aabbcc', 'hex'), 'example data 1', 100, true, '2023-09-09 10:00:00');`, table), + fmt.Sprintf(`INSERT INTO %s (id, key, column1, column2, column3, column6, column7, column101) + VALUES (8, decode('aabbcc', 'hex'), 'example data 1', 100, true, '{"foo":"bar"}', '{"foo2":"baz2"}', '2023-09-09 10:00:00');`, table), ) is.NoErr(err) diff --git a/source/logrepl/combined_test.go b/source/logrepl/combined_test.go index 0ab4c3a..7f54c70 100644 --- a/source/logrepl/combined_test.go +++ b/source/logrepl/combined_test.go @@ -156,8 +156,8 @@ func TestCombinedIterator_Next(t *testing.T) { is.NoErr(err) _, err = pool.Exec(ctx, fmt.Sprintf( - `INSERT INTO %s (id, column1, column2, column3, column4, column5) - VALUES (6, 'bizz', 1010, false, 872.2, 101)`, + `INSERT INTO %s (id, column1, column2, column3, column4, column5, column6, column7) + VALUES (6, 'bizz', 1010, false, 872.2, 101, '{"foo12": "bar12"}', '{"foo13": "bar13"}')`, table, )) is.NoErr(err) @@ -233,8 +233,8 @@ func TestCombinedIterator_Next(t *testing.T) { is.NoErr(err) _, err = pool.Exec(ctx, fmt.Sprintf( - `INSERT INTO %s (id, column1, column2, column3, column4, column5) - VALUES (7, 'buzz', 10101, true, 121.9, 51)`, + `INSERT INTO %s (id, column1, column2, column3, column4, column5, column6, column7) + VALUES (7, 'buzz', 10101, true, 121.9, 51, '{"foo7": "bar7"}', '{"foo8": "bar8"}')`, table, )) is.NoErr(err) @@ -277,6 +277,8 @@ func testRecords() []opencdc.StructuredData { "column3": false, "column4": 12.2, "column5": int64(4), + "column6": []byte(`{"foo": "bar"}`), + "column7": []byte(`{"foo": "baz"}`), }, { "id": int64(2), @@ -286,6 +288,8 @@ func testRecords() []opencdc.StructuredData { "column3": true, "column4": 13.42, "column5": int64(8), + "column6": []byte(`{"foo": "bar"}`), + "column7": []byte(`{"foo": "baz"}`), }, { "id": int64(3), @@ -295,6 +299,8 @@ func testRecords() []opencdc.StructuredData { "column3": false, "column4": nil, "column5": int64(9), + "column6": []byte(`{"foo": "bar"}`), + "column7": []byte(`{"foo": "baz"}`), }, { "id": int64(4), @@ -304,6 +310,8 @@ func testRecords() []opencdc.StructuredData { "column3": nil, "column4": 91.1, "column5": nil, + "column6": nil, + "column7": nil, }, { "id": int64(6), @@ -313,6 +321,8 @@ func testRecords() []opencdc.StructuredData { "column3": false, "column4": 872.2, "column5": int64(101), + "column6": []byte(`{"foo12": "bar12"}`), + "column7": []byte(`{"foo13": "bar13"}`), }, { "id": int64(7), @@ -322,6 +332,8 @@ func testRecords() []opencdc.StructuredData { "column3": true, "column4": 121.9, "column5": int64(51), + "column6": []byte(`{"foo7": "bar7"}`), + "column7": []byte(`{"foo8": "bar8"}`), }, } } diff --git a/source/logrepl/internal/relationset.go b/source/logrepl/internal/relationset.go index f87336d..ccc718c 100644 --- a/source/logrepl/internal/relationset.go +++ b/source/logrepl/internal/relationset.go @@ -93,9 +93,11 @@ func (rs *RelationSet) decodeValue(col *pglogrepl.RelationMessageColumn, data [] // (see: https://github.com/jackc/pgx/pull/2083#discussion_r1755768269). var val any var err error - if col.DataType == pgtype.XMLOID || col.DataType == pgtype.XMLArrayOID { + + switch col.DataType { + case pgtype.XMLOID, pgtype.XMLArrayOID, pgtype.JSONBOID, pgtype.JSONOID: val, err = decoder.DecodeDatabaseSQLValue(rs.connInfo, col.DataType, pgtype.TextFormatCode, data) - } else { + default: val, err = decoder.DecodeValue(rs.connInfo, col.DataType, pgtype.TextFormatCode, data) } diff --git a/source/logrepl/internal/relationset_test.go b/source/logrepl/internal/relationset_test.go index 04a3cbd..614a6e1 100644 --- a/source/logrepl/internal/relationset_test.go +++ b/source/logrepl/internal/relationset_test.go @@ -227,8 +227,8 @@ func insertRowAllTypes(ctx context.Context, t *testing.T, conn test.Querier, tab 2147483647, -- col_int4 9223372036854775807, -- col_int8 '18 seconds', -- col_interval - '{"foo":"bar"}', -- col_json - '{"foo":"baz"}', -- col_jsonb + '{"foo": "bar"}', -- col_json + '{"foo": "baz"}', -- col_jsonb '{19,20,21}', -- col_line '((22,23),(24,25))', -- col_lseg '08:00:2b:01:02:26', -- col_macaddr @@ -297,8 +297,8 @@ func isValuesAllTypes(is *is.I, got map[string]any) { Months: 0, Valid: true, }, - "col_json": map[string]any{"foo": "bar"}, - "col_jsonb": map[string]any{"foo": "baz"}, + "col_json": []byte(`{"foo": "bar"}`), + "col_jsonb": []byte(`{"foo": "baz"}`), "col_line": pgtype.Line{ A: 19, B: 20, @@ -393,8 +393,8 @@ func isValuesAllTypesStandalone(is *is.I, got map[string]any) { Months: 0, Valid: true, }, - "col_json": map[string]any{"foo": "bar"}, - "col_jsonb": map[string]any{"foo": "baz"}, + "col_json": []byte(`{"foo": "bar"}`), + "col_jsonb": []byte(`{"foo": "baz"}`), "col_line": pgtype.Line{ A: 19, B: 20, diff --git a/source/schema/avro.go b/source/schema/avro.go index 2350615..0026a2b 100644 --- a/source/schema/avro.go +++ b/source/schema/avro.go @@ -37,6 +37,8 @@ var Avro = &avroExtractor{ "int2": avro.NewPrimitiveSchema(avro.Int, nil), "text": avro.NewPrimitiveSchema(avro.String, nil), "varchar": avro.NewPrimitiveSchema(avro.String, nil), + "jsonb": avro.NewPrimitiveSchema(avro.Bytes, nil), + "json": avro.NewPrimitiveSchema(avro.Bytes, nil), "timestamptz": avro.NewPrimitiveSchema( avro.Long, avro.NewPrimitiveLogicalSchema(avro.TimestampMicros), diff --git a/source/snapshot/fetch_worker_test.go b/source/snapshot/fetch_worker_test.go index a29fcd7..786dd83 100644 --- a/source/snapshot/fetch_worker_test.go +++ b/source/snapshot/fetch_worker_test.go @@ -268,11 +268,16 @@ func Test_FetcherRun_Initial(t *testing.T) { is.NoErr(tt.Err()) is.True(len(gotFetchData) == 4) + var ( + value6 = []byte(`{"foo": "bar"}`) + value7 = []byte(`{"foo": "baz"}`) + ) + expectedMatch := []opencdc.StructuredData{ - {"id": int64(1), "key": []uint8{49}, "column1": "foo", "column2": int32(123), "column3": false, "column4": 12.2, "column5": int64(4)}, - {"id": int64(2), "key": []uint8{50}, "column1": "bar", "column2": int32(456), "column3": true, "column4": 13.42, "column5": int64(8)}, - {"id": int64(3), "key": []uint8{51}, "column1": "baz", "column2": int32(789), "column3": false, "column4": nil, "column5": int64(9)}, - {"id": int64(4), "key": []uint8{52}, "column1": nil, "column2": nil, "column3": nil, "column4": 91.1, "column5": nil}, + {"id": int64(1), "key": []uint8{49}, "column1": "foo", "column2": int32(123), "column3": false, "column4": 12.2, "column5": int64(4), "column6": value6, "column7": value7}, + {"id": int64(2), "key": []uint8{50}, "column1": "bar", "column2": int32(456), "column3": true, "column4": 13.42, "column5": int64(8), "column6": value6, "column7": value7}, + {"id": int64(3), "key": []uint8{51}, "column1": "baz", "column2": int32(789), "column3": false, "column4": nil, "column5": int64(9), "column6": value6, "column7": value7}, + {"id": int64(4), "key": []uint8{52}, "column1": nil, "column2": nil, "column3": nil, "column4": 91.1, "column5": nil, "column6": nil, "column7": nil}, } for i, got := range gotFetchData { @@ -342,6 +347,8 @@ func Test_FetcherRun_Resume(t *testing.T) { "column3": false, "column4": nil, "column5": int64(9), + "column6": []byte(`{"foo": "bar"}`), + "column7": []byte(`{"foo": "baz"}`), })) is.Equal(dd[0].Position, position.SnapshotPosition{ diff --git a/test/helper.go b/test/helper.go index 41d9385..113fe44 100644 --- a/test/helper.go +++ b/test/helper.go @@ -67,6 +67,8 @@ const TestTableAvroSchemaV1 = `{ "precision": 5 } }, + {"name":"column6","type":"bytes"}, + {"name":"column7","type":"bytes"}, {"name":"id","type":"long"}, {"name":"key","type":"bytes"} ] @@ -79,6 +81,7 @@ const TestTableAvroSchemaV2 = `{ "fields": [ {"name":"column1","type":"string"}, + {"name":"column101","type":{"type":"long","logicalType":"local-timestamp-micros"}}, {"name":"column2","type":"int"}, {"name":"column3","type":"boolean"}, { @@ -100,7 +103,8 @@ const TestTableAvroSchemaV2 = `{ "precision": 5 } }, - {"name":"column6","type":{"type":"long","logicalType":"local-timestamp-micros"}}, + {"name":"column6","type":"bytes"}, + {"name":"column7","type":"bytes"}, {"name":"id","type":"long"}, {"name":"key","type":"bytes"} ] @@ -113,9 +117,11 @@ const TestTableAvroSchemaV3 = `{ "fields": [ {"name":"column1","type":"string"}, + {"name":"column101","type":{"type":"long","logicalType":"local-timestamp-micros"}}, {"name":"column2","type":"int"}, {"name":"column3","type":"boolean"}, - {"name":"column6","type":{"type":"long","logicalType":"local-timestamp-micros"}}, + {"name":"column6","type":"bytes"}, + {"name":"column7","type":"bytes"}, {"name":"id","type":"long"}, {"name":"key","type":"bytes"} ] @@ -140,7 +146,9 @@ const testTableCreateQuery = ` column2 integer, column3 boolean, column4 numeric(16,3), - column5 numeric(5) + column5 numeric(5), + column6 jsonb, + column7 json )` type Querier interface { @@ -198,11 +206,11 @@ func SetupTestTable(ctx context.Context, t *testing.T, conn Querier) string { table := SetupEmptyTestTable(ctx, t, conn) query := ` - INSERT INTO %s (key, column1, column2, column3, column4, column5) - VALUES ('1', 'foo', 123, false, 12.2, 4), - ('2', 'bar', 456, true, 13.42, 8), - ('3', 'baz', 789, false, null, 9), - ('4', null, null, null, 91.1, null)` + INSERT INTO %s (key, column1, column2, column3, column4, column5, column6, column7) + VALUES ('1', 'foo', 123, false, 12.2, 4, '{"foo": "bar"}', '{"foo": "baz"}'), + ('2', 'bar', 456, true, 13.42, 8, '{"foo": "bar"}', '{"foo": "baz"}'), + ('3', 'baz', 789, false, null, 9, '{"foo": "bar"}', '{"foo": "baz"}'), + ('4', null, null, null, 91.1, null, null, null)` query = fmt.Sprintf(query, table) _, err := conn.Exec(ctx, query) is.NoErr(err)