Handle json and jsonb types as bytes. (#240)
* Handle `json` and `jsonb` types as bytes.

By default pgx uses encoding/json and returns the JSON as a map.
This behaviour does not play well with Avro and can result in each record
emitting a different schema (since JSON can be anything).

Thus any JSON data is returned as bytes, allowing the user to deserialize it
into whatever type they desire.
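
For example, a consumer-side helper might look like this (an illustrative sketch, not part of this change; the `column6` field name is hypothetical, and the `encoding/json`, `fmt`, and `opencdc` imports are assumed):

// decodeColumn6 decodes the raw JSON bytes the connector now emits
// into a caller-chosen type (a map here, but any struct works too).
func decodeColumn6(payload opencdc.StructuredData) (map[string]any, error) {
	raw, ok := payload["column6"].([]byte)
	if !ok {
		return nil, fmt.Errorf("column6: expected raw JSON bytes, got %T", payload["column6"])
	}
	var doc map[string]any
	if err := json.Unmarshal(raw, &doc); err != nil {
		return nil, fmt.Errorf("decoding column6: %w", err)
	}
	return doc, nil
}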
lyuboxa authored Jan 10, 2025
1 parent ee46c21 commit 134b515
Showing 10 changed files with 176 additions and 31 deletions.
20 changes: 20 additions & 0 deletions source/cpool/cpool.go
@@ -16,9 +16,11 @@ package cpool

import (
"context"
"encoding/json"
"fmt"

"github.com/jackc/pgx/v5"
"github.com/jackc/pgx/v5/pgtype"
"github.com/jackc/pgx/v5/pgxpool"
)

@@ -37,6 +39,7 @@ func New(ctx context.Context, conninfo string) (*pgxpool.Pool, error) {

config.BeforeAcquire = beforeAcquireHook
config.BeforeConnect = beforeConnectHook
+ config.AfterConnect = afterConnectHook
config.AfterRelease = afterReleaseHook

pool, err := pgxpool.NewWithConfig(ctx, config)
@@ -47,6 +50,23 @@ func New(ctx context.Context, conninfo string) (*pgxpool.Pool, error) {
return pool, nil
}

+ func afterConnectHook(_ context.Context, conn *pgx.Conn) error {
+ 	// Override the JSON and JSONB codecs to return the raw bytes rather
+ 	// than the unmarshalled map representation.
+ 	conn.TypeMap().RegisterType(&pgtype.Type{
+ 		Name:  "json",
+ 		OID:   pgtype.JSONOID,
+ 		Codec: &pgtype.JSONCodec{Marshal: json.Marshal, Unmarshal: jsonNoopUnmarshal},
+ 	})
+ 	conn.TypeMap().RegisterType(&pgtype.Type{
+ 		Name:  "jsonb",
+ 		OID:   pgtype.JSONBOID,
+ 		Codec: &pgtype.JSONBCodec{Marshal: json.Marshal, Unmarshal: jsonNoopUnmarshal},
+ 	})
+
+ 	return nil
+ }
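
With the hook registered, scanning a json or jsonb column into an `any` yields the raw bytes. A minimal sketch (assuming a pool created by `New` and a hypothetical table `t` with a jsonb column `doc`):

var raw any
err := pool.QueryRow(ctx, "SELECT doc FROM t LIMIT 1").Scan(&raw)
// With pgx's default codec, raw would be a map[string]any; with the
// override it is a []byte holding the verbatim JSON text.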

// beforeAcquireHook ensures purpose specific connections are returned:
// * If a replication connection is requested, ensure the connection has replication enabled.
// * If a regular connection is requested, return non-replication connections.
37 changes: 37 additions & 0 deletions source/cpool/json.go
@@ -0,0 +1,37 @@
// Copyright © 2024 Meroxa, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package cpool

import (
"encoding/json"
"reflect"
)

// jsonNoopUnmarshal copies src into dst without decoding it.
// It is meant to be used with the pgtype JSON and JSONB codecs.
func jsonNoopUnmarshal(src []byte, dst any) error {
	dstptr, ok := dst.(*any)
	if dst == nil || !ok {
		return &json.InvalidUnmarshalError{Type: reflect.TypeOf(dst)}
	}

	// Copy src so the caller's value does not alias pgx's buffer.
	v := make([]byte, len(src))
	copy(v, src)

	// Set the copy as the value pointed to by dst.
	*dstptr = v

	return nil
}
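
The defensive copy matters: src may point into pgx's internal read buffer, which can be reused for subsequent messages, so handing it to the caller directly could let later reads overwrite earlier values.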
45 changes: 45 additions & 0 deletions source/cpool/json_test.go
@@ -0,0 +1,45 @@
// Copyright © 2024 Meroxa, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package cpool

import (
"testing"

"github.com/matryer/is"
)

func Test_jsonNoopUnmarshal(t *testing.T) {
is := is.New(t)

var dst any
data := []byte(`{"foo":"bar"}`)

is.NoErr(jsonNoopUnmarshal(data, &dst))
is.Equal(data, dst.([]byte))

var err error

err = jsonNoopUnmarshal(data, dst)
is.True(err != nil)
if err != nil {
is.Equal(err.Error(), "json: Unmarshal(non-pointer []uint8)")
}

err = jsonNoopUnmarshal(data, nil)
is.True(err != nil)
if err != nil {
is.Equal(err.Error(), "json: Unmarshal(nil)")
}
}
26 changes: 19 additions & 7 deletions source/logrepl/cdc_test.go
@@ -140,8 +140,8 @@ func TestCDCIterator_Next(t *testing.T) {
name: "should detect insert",
setup: func(t *testing.T) {
is := is.New(t)
- query := fmt.Sprintf(`INSERT INTO %s (id, column1, column2, column3, column4, column5)
- VALUES (6, 'bizz', 456, false, 12.3, 14)`, table)
+ query := fmt.Sprintf(`INSERT INTO %s (id, column1, column2, column3, column4, column5, column6, column7)
+ VALUES (6, 'bizz', 456, false, 12.3, 14, '{"foo2": "bar2"}', '{"foo2": "baz2"}')`, table)
_, err := pool.Exec(ctx, query)
is.NoErr(err)
},
@@ -165,6 +165,8 @@ func TestCDCIterator_Next(t *testing.T) {
"column3": false,
"column4": 12.3,
"column5": int64(14),
"column6": []byte(`{"foo2": "bar2"}`),
"column7": []byte(`{"foo2": "baz2"}`),
"key": nil,
},
},
@@ -197,6 +199,8 @@ func TestCDCIterator_Next(t *testing.T) {
"column3": false,
"column4": 12.2,
"column5": int64(4),
"column6": []byte(`{"foo": "bar"}`),
"column7": []byte(`{"foo": "baz"}`),
"key": []uint8("1"),
},
},
@@ -231,6 +235,8 @@ func TestCDCIterator_Next(t *testing.T) {
"column3": false,
"column4": 12.2,
"column5": int64(4),
"column6": []byte(`{"foo": "bar"}`),
"column7": []byte(`{"foo": "baz"}`),
"key": []uint8("1"),
},
After: opencdc.StructuredData{
Expand All @@ -240,6 +246,8 @@ func TestCDCIterator_Next(t *testing.T) {
"column3": false,
"column4": 12.2,
"column5": int64(4),
"column6": []byte(`{"foo": "bar"}`),
"column7": []byte(`{"foo": "baz"}`),
"key": []uint8("1"),
},
},
@@ -274,6 +282,8 @@ func TestCDCIterator_Next(t *testing.T) {
"column3": nil,
"column4": nil,
"column5": nil,
"column6": nil,
"column7": nil,
"key": nil,
},
},
@@ -309,6 +319,8 @@ func TestCDCIterator_Next(t *testing.T) {
"column3": false,
"column4": nil,
"column5": int64(9),
"column6": []byte(`{"foo": "bar"}`),
"column7": []byte(`{"foo": "baz"}`),
},
},
},
@@ -552,13 +564,13 @@ func TestCDCIterator_Schema(t *testing.T) {
t.Run("column added", func(t *testing.T) {
is := is.New(t)

- _, err := pool.Exec(ctx, fmt.Sprintf(`ALTER TABLE %s ADD COLUMN column6 timestamp;`, table))
+ _, err := pool.Exec(ctx, fmt.Sprintf(`ALTER TABLE %s ADD COLUMN column101 timestamp;`, table))
is.NoErr(err)

_, err = pool.Exec(
ctx,
- fmt.Sprintf(`INSERT INTO %s (id, key, column1, column2, column3, column4, column5, column6)
- VALUES (7, decode('aabbcc', 'hex'), 'example data 1', 100, true, 12345.678, 12345, '2023-09-09 10:00:00');`, table),
+ fmt.Sprintf(`INSERT INTO %s (id, key, column1, column2, column3, column4, column5, column6, column7, column101)
+ VALUES (7, decode('aabbcc', 'hex'), 'example data 1', 100, true, 12345.678, 12345, '{"foo":"bar"}', '{"foo2":"baz2"}', '2023-09-09 10:00:00');`, table),
)
is.NoErr(err)

@@ -577,8 +589,8 @@ func TestCDCIterator_Schema(t *testing.T) {

_, err = pool.Exec(
ctx,
- fmt.Sprintf(`INSERT INTO %s (id, key, column1, column2, column3, column6)
- VALUES (8, decode('aabbcc', 'hex'), 'example data 1', 100, true, '2023-09-09 10:00:00');`, table),
+ fmt.Sprintf(`INSERT INTO %s (id, key, column1, column2, column3, column6, column7, column101)
+ VALUES (8, decode('aabbcc', 'hex'), 'example data 1', 100, true, '{"foo":"bar"}', '{"foo2":"baz2"}', '2023-09-09 10:00:00');`, table),
)
is.NoErr(err)

20 changes: 16 additions & 4 deletions source/logrepl/combined_test.go
@@ -156,8 +156,8 @@ func TestCombinedIterator_Next(t *testing.T) {
is.NoErr(err)

_, err = pool.Exec(ctx, fmt.Sprintf(
- `INSERT INTO %s (id, column1, column2, column3, column4, column5)
- VALUES (6, 'bizz', 1010, false, 872.2, 101)`,
+ `INSERT INTO %s (id, column1, column2, column3, column4, column5, column6, column7)
+ VALUES (6, 'bizz', 1010, false, 872.2, 101, '{"foo12": "bar12"}', '{"foo13": "bar13"}')`,
table,
))
is.NoErr(err)
@@ -233,8 +233,8 @@ func TestCombinedIterator_Next(t *testing.T) {
is.NoErr(err)

_, err = pool.Exec(ctx, fmt.Sprintf(
- `INSERT INTO %s (id, column1, column2, column3, column4, column5)
- VALUES (7, 'buzz', 10101, true, 121.9, 51)`,
+ `INSERT INTO %s (id, column1, column2, column3, column4, column5, column6, column7)
+ VALUES (7, 'buzz', 10101, true, 121.9, 51, '{"foo7": "bar7"}', '{"foo8": "bar8"}')`,
table,
))
is.NoErr(err)
@@ -277,6 +277,8 @@ func testRecords() []opencdc.StructuredData {
"column3": false,
"column4": 12.2,
"column5": int64(4),
"column6": []byte(`{"foo": "bar"}`),
"column7": []byte(`{"foo": "baz"}`),
},
{
"id": int64(2),
@@ -286,6 +288,8 @@ func testRecords() []opencdc.StructuredData {
"column3": true,
"column4": 13.42,
"column5": int64(8),
"column6": []byte(`{"foo": "bar"}`),
"column7": []byte(`{"foo": "baz"}`),
},
{
"id": int64(3),
@@ -295,6 +299,8 @@ func testRecords() []opencdc.StructuredData {
"column3": false,
"column4": nil,
"column5": int64(9),
"column6": []byte(`{"foo": "bar"}`),
"column7": []byte(`{"foo": "baz"}`),
},
{
"id": int64(4),
@@ -304,6 +310,8 @@ func testRecords() []opencdc.StructuredData {
"column3": nil,
"column4": 91.1,
"column5": nil,
"column6": nil,
"column7": nil,
},
{
"id": int64(6),
@@ -313,6 +321,8 @@ func testRecords() []opencdc.StructuredData {
"column3": false,
"column4": 872.2,
"column5": int64(101),
"column6": []byte(`{"foo12": "bar12"}`),
"column7": []byte(`{"foo13": "bar13"}`),
},
{
"id": int64(7),
@@ -322,6 +332,8 @@ func testRecords() []opencdc.StructuredData {
"column3": true,
"column4": 121.9,
"column5": int64(51),
"column6": []byte(`{"foo7": "bar7"}`),
"column7": []byte(`{"foo8": "bar8"}`),
},
}
}
6 changes: 4 additions & 2 deletions source/logrepl/internal/relationset.go
@@ -93,9 +93,11 @@ func (rs *RelationSet) decodeValue(col *pglogrepl.RelationMessageColumn, data []
// (see: https://github.com/jackc/pgx/pull/2083#discussion_r1755768269).
var val any
var err error
- if col.DataType == pgtype.XMLOID || col.DataType == pgtype.XMLArrayOID {
+
+ switch col.DataType {
+ case pgtype.XMLOID, pgtype.XMLArrayOID, pgtype.JSONBOID, pgtype.JSONOID:
val, err = decoder.DecodeDatabaseSQLValue(rs.connInfo, col.DataType, pgtype.TextFormatCode, data)
- } else {
+ default:
val, err = decoder.DecodeValue(rs.connInfo, col.DataType, pgtype.TextFormatCode, data)
}
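
Routing json and jsonb through DecodeDatabaseSQLValue mirrors the codec override in cpool: for text-format data it hands back the raw column bytes rather than pgtype's fully decoded Go value, so the logical replication path emits the same []byte payloads as regular queries.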

12 changes: 6 additions & 6 deletions source/logrepl/internal/relationset_test.go
@@ -227,8 +227,8 @@ func insertRowAllTypes(ctx context.Context, t *testing.T, conn test.Querier, tab
2147483647, -- col_int4
9223372036854775807, -- col_int8
'18 seconds', -- col_interval
- '{"foo":"bar"}', -- col_json
- '{"foo":"baz"}', -- col_jsonb
+ '{"foo": "bar"}', -- col_json
+ '{"foo": "baz"}', -- col_jsonb
'{19,20,21}', -- col_line
'((22,23),(24,25))', -- col_lseg
'08:00:2b:01:02:26', -- col_macaddr
@@ -297,8 +297,8 @@ func isValuesAllTypes(is *is.I, got map[string]any) {
Months: 0,
Valid: true,
},
"col_json": map[string]any{"foo": "bar"},
"col_jsonb": map[string]any{"foo": "baz"},
"col_json": []byte(`{"foo": "bar"}`),
"col_jsonb": []byte(`{"foo": "baz"}`),
"col_line": pgtype.Line{
A: 19,
B: 20,
@@ -393,8 +393,8 @@ func isValuesAllTypesStandalone(is *is.I, got map[string]any) {
Months: 0,
Valid: true,
},
"col_json": map[string]any{"foo": "bar"},
"col_jsonb": map[string]any{"foo": "baz"},
"col_json": []byte(`{"foo": "bar"}`),
"col_jsonb": []byte(`{"foo": "baz"}`),
"col_line": pgtype.Line{
A: 19,
B: 20,
2 changes: 2 additions & 0 deletions source/schema/avro.go
@@ -37,6 +37,8 @@ var Avro = &avroExtractor{
"int2": avro.NewPrimitiveSchema(avro.Int, nil),
"text": avro.NewPrimitiveSchema(avro.String, nil),
"varchar": avro.NewPrimitiveSchema(avro.String, nil),
"jsonb": avro.NewPrimitiveSchema(avro.Bytes, nil),
"json": avro.NewPrimitiveSchema(avro.Bytes, nil),
"timestamptz": avro.NewPrimitiveSchema(
avro.Long,
avro.NewPrimitiveLogicalSchema(avro.TimestampMicros),
15 changes: 11 additions & 4 deletions source/snapshot/fetch_worker_test.go
@@ -268,11 +268,16 @@ func Test_FetcherRun_Initial(t *testing.T) {
is.NoErr(tt.Err())
is.True(len(gotFetchData) == 4)

+ var (
+ 	value6 = []byte(`{"foo": "bar"}`)
+ 	value7 = []byte(`{"foo": "baz"}`)
+ )

expectedMatch := []opencdc.StructuredData{
{"id": int64(1), "key": []uint8{49}, "column1": "foo", "column2": int32(123), "column3": false, "column4": 12.2, "column5": int64(4)},
{"id": int64(2), "key": []uint8{50}, "column1": "bar", "column2": int32(456), "column3": true, "column4": 13.42, "column5": int64(8)},
{"id": int64(3), "key": []uint8{51}, "column1": "baz", "column2": int32(789), "column3": false, "column4": nil, "column5": int64(9)},
{"id": int64(4), "key": []uint8{52}, "column1": nil, "column2": nil, "column3": nil, "column4": 91.1, "column5": nil},
{"id": int64(1), "key": []uint8{49}, "column1": "foo", "column2": int32(123), "column3": false, "column4": 12.2, "column5": int64(4), "column6": value6, "column7": value7},
{"id": int64(2), "key": []uint8{50}, "column1": "bar", "column2": int32(456), "column3": true, "column4": 13.42, "column5": int64(8), "column6": value6, "column7": value7},
{"id": int64(3), "key": []uint8{51}, "column1": "baz", "column2": int32(789), "column3": false, "column4": nil, "column5": int64(9), "column6": value6, "column7": value7},
{"id": int64(4), "key": []uint8{52}, "column1": nil, "column2": nil, "column3": nil, "column4": 91.1, "column5": nil, "column6": nil, "column7": nil},
}

for i, got := range gotFetchData {
@@ -342,6 +347,8 @@ func Test_FetcherRun_Resume(t *testing.T) {
"column3": false,
"column4": nil,
"column5": int64(9),
"column6": []byte(`{"foo": "bar"}`),
"column7": []byte(`{"foo": "baz"}`),
}))

is.Equal(dd[0].Position, position.SnapshotPosition{
