From a48d90c56126bdb54259406645e008099e942cdb Mon Sep 17 00:00:00 2001 From: Chu Ka-cheong Date: Tue, 20 Sep 2022 12:34:20 +0800 Subject: [PATCH] Fix unintented initialization of pointer fields - Invalid sql.Null* values should be copied to pointer field as nil - Embedded struct should be initialized if its member is copied directly - Add test cases to cover the above cases --- copier.go | 3 ++- copier_test.go | 60 ++++++++++++++++++++++++++++++++++++++++---------- 2 files changed, 50 insertions(+), 13 deletions(-) diff --git a/copier.go b/copier.go index 84efd25..e7b5410 100644 --- a/copier.go +++ b/copier.go @@ -263,7 +263,8 @@ func copier(toValue interface{}, fromValue interface{}, opt Option) (err error) // process for nested anonymous field destFieldNotSet := false if f, ok := dest.Type().FieldByName(destFieldName); ok { - for idx := range f.Index { + // only initialize parent embedded struct pointer in the path + for idx := range f.Index[:len(f.Index)-1] { destField := dest.FieldByIndex(f.Index[:idx+1]) if destField.Kind() != reflect.Ptr { diff --git a/copier_test.go b/copier_test.go index ec097cc..08db7f3 100644 --- a/copier_test.go +++ b/copier_test.go @@ -1089,6 +1089,36 @@ func TestAnonymousFields(t *testing.T) { } }) + t.Run("Should work with exported ptr fields with same name src field", func(t *testing.T) { + type Nested struct { + A string + } + type parentA struct { + A string + } + type parentB struct { + *Nested + } + + fieldValue := "a" + from := parentA{A: fieldValue} + to := parentB{} + + err := copier.CopyWithOption(&to, &from, copier.Option{ + DeepCopy: true, + }) + if err != nil { + t.Errorf("Unexpected error: %v", err) + return + } + + from.A = "b" + + if to.Nested.A != fieldValue { + t.Errorf("should not change") + } + }) + t.Run("Should work with exported fields", func(t *testing.T) { type Nested struct { A string @@ -1356,19 +1386,21 @@ func TestScanner(t *testing.T) { func TestScanFromPtrToSqlNullable(t *testing.T) { var ( from struct { - S string - Sptr *string - T1 sql.NullTime - T2 sql.NullTime - T3 *time.Time + S string + Sptr *string + Snull sql.NullString + T1 sql.NullTime + T2 sql.NullTime + T3 *time.Time } to struct { - S sql.NullString - Sptr sql.NullString - T1 time.Time - T2 *time.Time - T3 sql.NullTime + S sql.NullString + Sptr sql.NullString + Snull *string + T1 time.Time + T2 *time.Time + T3 sql.NullTime } s string @@ -1393,8 +1425,12 @@ func TestScanFromPtrToSqlNullable(t *testing.T) { t.Errorf("to.T1 should be Zero but %v", to.T1) } - if to.T2 != nil && !to.T2.IsZero() { - t.Errorf("to.T2 should be Zero but %v", to.T2) + if to.T2 != nil { + t.Errorf("to.T2 should be nil but %v", to.T2) + } + + if to.Snull != nil { + t.Errorf("to.Snull should be nil but %v", to.Snull) } now := time.Now()