From 78c6dfd712aca47ab1945921cdb37d1185aea7f3 Mon Sep 17 00:00:00 2001 From: Sergei Sadov <32732298+SergeiSadov@users.noreply.github.com> Date: Wed, 12 Jun 2024 13:49:45 +0300 Subject: [PATCH] Fix association replace non-addressable panic (#7012) * Fix association replace non-addressable panic * Fix tests * Add has one panic test --------- Co-authored-by: sgsv <-> --- association.go | 14 ++++++++++++++ tests/associations_has_many_test.go | 12 ++++++++++++ tests/associations_has_one_test.go | 12 ++++++++++++ 3 files changed, 38 insertions(+) diff --git a/association.go b/association.go index 7c93ebea0d..e3f51d173b 100644 --- a/association.go +++ b/association.go @@ -396,6 +396,10 @@ func (association *Association) saveAssociation(clear bool, values ...interface{ } } case reflect.Struct: + if !rv.CanAddr() { + association.Error = ErrInvalidValue + return + } association.Error = association.Relationship.Field.Set(association.DB.Statement.Context, source, rv.Addr().Interface()) if association.Relationship.Field.FieldType.Kind() == reflect.Struct { @@ -433,6 +437,10 @@ func (association *Association) saveAssociation(clear bool, values ...interface{ appendToFieldValues(reflect.Indirect(rv.Index(i)).Addr()) } case reflect.Struct: + if !rv.CanAddr() { + association.Error = ErrInvalidValue + return + } appendToFieldValues(rv.Addr()) } @@ -510,6 +518,9 @@ func (association *Association) saveAssociation(clear bool, values ...interface{ for i := 0; i < reflectValue.Len(); i++ { appendToRelations(reflectValue.Index(i), reflect.Indirect(reflect.ValueOf(values[i])), clear) + if association.Error != nil { + return + } // TODO support save slice data, sql with case? association.Error = associationDB.Updates(reflectValue.Index(i).Addr().Interface()).Error @@ -531,6 +542,9 @@ func (association *Association) saveAssociation(clear bool, values ...interface{ for idx, value := range values { rv := reflect.Indirect(reflect.ValueOf(value)) appendToRelations(reflectValue, rv, clear && idx == 0) + if association.Error != nil { + return + } } if len(values) > 0 { diff --git a/tests/associations_has_many_test.go b/tests/associations_has_many_test.go index b8e8ff5efc..db397eb78f 100644 --- a/tests/associations_has_many_test.go +++ b/tests/associations_has_many_test.go @@ -554,3 +554,15 @@ func TestHasManyAssociationUnscoped(t *testing.T) { t.Errorf("expected %d contents, got %d", 0, len(contents)) } } + +func TestHasManyAssociationReplaceWithNonValidValue(t *testing.T) { + user := User{Name: "jinzhu", Languages: []Language{{Name: "EN"}}} + + if err := DB.Create(&user).Error; err != nil { + t.Fatalf("errors happened when create: %v", err) + } + + if err := DB.Model(&user).Association("Languages").Replace(Language{Name: "DE"}, Language{Name: "FR"}); err == nil { + t.Error("expected association error to be not nil") + } +} diff --git a/tests/associations_has_one_test.go b/tests/associations_has_one_test.go index a2c0750904..78290ce90b 100644 --- a/tests/associations_has_one_test.go +++ b/tests/associations_has_one_test.go @@ -255,3 +255,15 @@ func TestPolymorphicHasOneAssociationForSlice(t *testing.T) { DB.Model(&pets).Association("Toy").Clear() AssertAssociationCount(t, pets, "Toy", 0, "After Clear") } + +func TestHasOneAssociationReplaceWithNonValidValue(t *testing.T) { + user := User{Name: "jinzhu", Account: Account{Number: "1"}} + + if err := DB.Create(&user).Error; err != nil { + t.Fatalf("errors happened when create: %v", err) + } + + if err := DB.Model(&user).Association("Languages").Replace(Account{Number: "2"}); err == nil { + t.Error("expected association error to be not nil") + } +}