From 3adad2b18ecd3aed7a6937547d5af3f14645b786 Mon Sep 17 00:00:00 2001 From: Preston Vasquez Date: Tue, 5 Dec 2023 09:58:23 -0700 Subject: [PATCH] GODRIVER-3009 Fix concurrent panic in struct codec. (#1477) (#1489) Co-authored-by: Qingyang Hu <103950869+qingyang-hu@users.noreply.github.com> --- bson/bsoncodec/registry.go | 9 +++------ bson/bsoncodec/registry_test.go | 30 ++++++++++++++++++++++++++++++ bson/marshal_test.go | 17 +++++++++++++++++ bson/unmarshal_test.go | 19 +++++++++++++++++++ 4 files changed, 69 insertions(+), 6 deletions(-) diff --git a/bson/bsoncodec/registry.go b/bson/bsoncodec/registry.go index f309ee2b39..196c491bbb 100644 --- a/bson/bsoncodec/registry.go +++ b/bson/bsoncodec/registry.go @@ -388,6 +388,9 @@ func (r *Registry) RegisterTypeMapEntry(bt bsontype.Type, rt reflect.Type) { // If no encoder is found, an error of type ErrNoEncoder is returned. LookupEncoder is safe for // concurrent use by multiple goroutines after all codecs and encoders are registered. func (r *Registry) LookupEncoder(valueType reflect.Type) (ValueEncoder, error) { + if valueType == nil { + return nil, ErrNoEncoder{Type: valueType} + } enc, found := r.lookupTypeEncoder(valueType) if found { if enc == nil { @@ -400,15 +403,10 @@ func (r *Registry) LookupEncoder(valueType reflect.Type) (ValueEncoder, error) { if found { return r.typeEncoders.LoadOrStore(valueType, enc), nil } - if valueType == nil { - r.storeTypeEncoder(valueType, nil) - return nil, ErrNoEncoder{Type: valueType} - } if v, ok := r.kindEncoders.Load(valueType.Kind()); ok { return r.storeTypeEncoder(valueType, v), nil } - r.storeTypeEncoder(valueType, nil) return nil, ErrNoEncoder{Type: valueType} } @@ -474,7 +472,6 @@ func (r *Registry) LookupDecoder(valueType reflect.Type) (ValueDecoder, error) { if v, ok := r.kindDecoders.Load(valueType.Kind()); ok { return r.storeTypeDecoder(valueType, v), nil } - r.storeTypeDecoder(valueType, nil) return nil, ErrNoDecoder{Type: valueType} } diff --git a/bson/bsoncodec/registry_test.go b/bson/bsoncodec/registry_test.go index acc24a6e4d..2a7d50a719 100644 --- a/bson/bsoncodec/registry_test.go +++ b/bson/bsoncodec/registry_test.go @@ -792,6 +792,36 @@ func TestRegistry(t *testing.T) { }) }) } + t.Run("nil type", func(t *testing.T) { + t.Parallel() + + t.Run("Encoder", func(t *testing.T) { + t.Parallel() + + wanterr := ErrNoEncoder{Type: reflect.TypeOf(nil)} + + gotcodec, goterr := reg.LookupEncoder(nil) + if !cmp.Equal(goterr, wanterr, cmp.Comparer(compareErrors)) { + t.Errorf("errors did not match: got %#v, want %#v", goterr, wanterr) + } + if !cmp.Equal(gotcodec, nil, allowunexported, cmp.Comparer(comparepc)) { + t.Errorf("codecs did not match: got %#v, want nil", gotcodec) + } + }) + t.Run("Decoder", func(t *testing.T) { + t.Parallel() + + wanterr := ErrNilType + + gotcodec, goterr := reg.LookupDecoder(nil) + if !cmp.Equal(goterr, wanterr, cmp.Comparer(compareErrors)) { + t.Errorf("errors did not match: got %#v, want %#v", goterr, wanterr) + } + if !cmp.Equal(gotcodec, nil, allowunexported, cmp.Comparer(comparepc)) { + t.Errorf("codecs did not match: got %v: want nil", gotcodec) + } + }) + }) // lookup a type whose pointer implements an interface and expect that the registered hook is // returned t.Run("interface implementation with hook (pointer)", func(t *testing.T) { diff --git a/bson/marshal_test.go b/bson/marshal_test.go index 54b27dfcf1..99a3bba67e 100644 --- a/bson/marshal_test.go +++ b/bson/marshal_test.go @@ -11,6 +11,7 @@ import ( "errors" "fmt" "reflect" + "sync" "testing" "time" @@ -380,3 +381,19 @@ func TestMarshalExtJSONIndent(t *testing.T) { }) } } + +func TestMarshalConcurrently(t *testing.T) { + t.Parallel() + + const size = 10_000 + + wg := sync.WaitGroup{} + wg.Add(size) + for i := 0; i < size; i++ { + go func() { + defer wg.Done() + _, _ = Marshal(struct{ LastError error }{}) + }() + } + wg.Wait() +} diff --git a/bson/unmarshal_test.go b/bson/unmarshal_test.go index 11452a895c..2283b96771 100644 --- a/bson/unmarshal_test.go +++ b/bson/unmarshal_test.go @@ -9,6 +9,7 @@ package bson import ( "math/rand" "reflect" + "sync" "testing" "go.mongodb.org/mongo-driver/bson/bsoncodec" @@ -773,3 +774,21 @@ func TestUnmarshalByteSlicesUseDistinctArrays(t *testing.T) { }) } } + +func TestUnmarshalConcurrently(t *testing.T) { + t.Parallel() + + const size = 10_000 + + data := []byte{16, 0, 0, 0, 10, 108, 97, 115, 116, 101, 114, 114, 111, 114, 0, 0} + wg := sync.WaitGroup{} + wg.Add(size) + for i := 0; i < size; i++ { + go func() { + defer wg.Done() + var res struct{ LastError error } + _ = Unmarshal(data, &res) + }() + } + wg.Wait() +}