Skip to content

Commit

Permalink
Merge pull request #166 from kcchu/master
Browse files Browse the repository at this point in the history
Fix unintented initialization of pointer fields
  • Loading branch information
jinzhu authored Oct 21, 2022
2 parents 0e264e9 + a48d90c commit 155f7ce
Show file tree
Hide file tree
Showing 2 changed files with 50 additions and 13 deletions.
3 changes: 2 additions & 1 deletion copier.go
Original file line number Diff line number Diff line change
Expand Up @@ -263,7 +263,8 @@ func copier(toValue interface{}, fromValue interface{}, opt Option) (err error)
// process for nested anonymous field
destFieldNotSet := false
if f, ok := dest.Type().FieldByName(destFieldName); ok {
for idx := range f.Index {
// only initialize parent embedded struct pointer in the path
for idx := range f.Index[:len(f.Index)-1] {
destField := dest.FieldByIndex(f.Index[:idx+1])

if destField.Kind() != reflect.Ptr {
Expand Down
60 changes: 48 additions & 12 deletions copier_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1089,6 +1089,36 @@ func TestAnonymousFields(t *testing.T) {
}
})

t.Run("Should work with exported ptr fields with same name src field", func(t *testing.T) {
type Nested struct {
A string
}
type parentA struct {
A string
}
type parentB struct {
*Nested
}

fieldValue := "a"
from := parentA{A: fieldValue}
to := parentB{}

err := copier.CopyWithOption(&to, &from, copier.Option{
DeepCopy: true,
})
if err != nil {
t.Errorf("Unexpected error: %v", err)
return
}

from.A = "b"

if to.Nested.A != fieldValue {
t.Errorf("should not change")
}
})

t.Run("Should work with exported fields", func(t *testing.T) {
type Nested struct {
A string
Expand Down Expand Up @@ -1356,19 +1386,21 @@ func TestScanner(t *testing.T) {
func TestScanFromPtrToSqlNullable(t *testing.T) {
var (
from struct {
S string
Sptr *string
T1 sql.NullTime
T2 sql.NullTime
T3 *time.Time
S string
Sptr *string
Snull sql.NullString
T1 sql.NullTime
T2 sql.NullTime
T3 *time.Time
}

to struct {
S sql.NullString
Sptr sql.NullString
T1 time.Time
T2 *time.Time
T3 sql.NullTime
S sql.NullString
Sptr sql.NullString
Snull *string
T1 time.Time
T2 *time.Time
T3 sql.NullTime
}

s string
Expand All @@ -1393,8 +1425,12 @@ func TestScanFromPtrToSqlNullable(t *testing.T) {
t.Errorf("to.T1 should be Zero but %v", to.T1)
}

if to.T2 != nil && !to.T2.IsZero() {
t.Errorf("to.T2 should be Zero but %v", to.T2)
if to.T2 != nil {
t.Errorf("to.T2 should be nil but %v", to.T2)
}

if to.Snull != nil {
t.Errorf("to.Snull should be nil but %v", to.Snull)
}

now := time.Now()
Expand Down

0 comments on commit 155f7ce

Please sign in to comment.