diff --git a/dataset_insert.go b/dataset_insert.go index 75a5b1b0..753927b4 100644 --- a/dataset_insert.go +++ b/dataset_insert.go @@ -3,6 +3,7 @@ package goqu import ( "reflect" "sort" + "time" ) //Generates the default INSERT statement. If Prepared has been called with true then the statement will not be interpolated. See examples. @@ -94,16 +95,7 @@ func (me *Dataset) getInsertColsAndVals(rows ...interface{}) (columns ColumnList rowCols []interface{} rowVals []interface{} ) - for j := 0; j < newRowValue.NumField(); j++ { - f := newRowValue.Field(j) - t := newRowValue.Type().Field(j) - if me.canInsertField(t) { - if columns == nil { - rowCols = append(rowCols, t.Tag.Get("db")) - } - rowVals = append(rowVals, f.Interface()) - } - } + rowCols, rowVals = me.getFieldsValues(newRowValue) if columns == nil { columns = cols(rowCols...) } @@ -115,6 +107,28 @@ func (me *Dataset) getInsertColsAndVals(rows ...interface{}) (columns ColumnList return columns, vals, nil } +func (me *Dataset) getFieldsValues(value reflect.Value) (rowCols []interface{}, rowVals []interface{}) { + for i := 0; i < value.NumField(); i++ { + v := value.Field(i) + + kind := v.Kind() + if (reflect.TypeOf(v.Interface()).Name() == reflect.TypeOf((*time.Time)(nil)).Elem().Name()) || (kind != reflect.Struct) { + t := value.Type().Field(i) + if me.canInsertField(t) { + rowCols = append(rowCols, t.Tag.Get("db")) + rowVals = append(rowVals, v.Interface()) + } + } else { + cols, vals := me.getFieldsValues(reflect.Indirect(reflect.ValueOf(v.Interface()))) + rowCols = append(rowCols, cols...) + rowVals = append(rowVals, vals...) + } + + } + + return rowCols, rowVals +} + //Creates an INSERT statement with the columns and values passed in func (me *Dataset) insertSql(cols ColumnList, values [][]interface{}, prepared bool) (string, []interface{}, error) { buf := NewSqlBuilder(prepared) diff --git a/dataset_insert_test.go b/dataset_insert_test.go index bc16a166..d3c7c3c7 100644 --- a/dataset_insert_test.go +++ b/dataset_insert_test.go @@ -3,6 +3,8 @@ package goqu import ( "github.com/DATA-DOG/go-sqlmock" "github.com/technotronicoz/testify/assert" + + "time" ) func (me *datasetTest) TestInsertSqlNoReturning() { @@ -36,21 +38,49 @@ func (me *datasetTest) TestInsertSqlWithStructs() { t := me.T() ds1 := From("items") type item struct { + Address string `db:"address"` + Name string `db:"name"` + Created time.Time `db:"created"` + } + created, _ := time.Parse("2006-01-02", "2015-01-01") + sql, _, err := ds1.ToInsertSql(item{Name: "Test", Address: "111 Test Addr", Created: created}) + assert.NoError(t, err) + assert.Equal(t, sql, `INSERT INTO "items" ("address", "name", "created") VALUES ('111 Test Addr', 'Test', '`+created.Format(time.RFC3339Nano)+`')`) + + sql, _, err = ds1.ToInsertSql( + item{Address: "111 Test Addr", Name: "Test1", Created: created}, + item{Address: "211 Test Addr", Name: "Test2", Created: created}, + item{Address: "311 Test Addr", Name: "Test3", Created: created}, + item{Address: "411 Test Addr", Name: "Test4", Created: created}, + ) + assert.NoError(t, err) + assert.Equal(t, sql, `INSERT INTO "items" ("address", "name", "created") VALUES ('111 Test Addr', 'Test1', '`+created.Format(time.RFC3339Nano)+`'), ('211 Test Addr', 'Test2', '`+created.Format(time.RFC3339Nano)+`'), ('311 Test Addr', 'Test3', '`+created.Format(time.RFC3339Nano)+`'), ('411 Test Addr', 'Test4', '`+created.Format(time.RFC3339Nano)+`')`) +} + +func (me *datasetTest) TestInsertSqlWithEmbeddedStruct() { + t := me.T() + ds1 := From("items") + type phone struct { + Primary string `db:"primary_phone"` + Home string `db:"home_phone"` + } + type item struct { + phone Address string `db:"address"` Name string `db:"name"` } - sql, _, err := ds1.ToInsertSql(item{Name: "Test", Address: "111 Test Addr"}) + sql, _, err := ds1.ToInsertSql(item{Name: "Test", Address: "111 Test Addr", phone: phone{Home: "123123", Primary: "456456"}}) assert.NoError(t, err) - assert.Equal(t, sql, `INSERT INTO "items" ("address", "name") VALUES ('111 Test Addr', 'Test')`) + assert.Equal(t, sql, `INSERT INTO "items" ("primary_phone", "home_phone", "address", "name") VALUES ('456456', '123123', '111 Test Addr', 'Test')`) sql, _, err = ds1.ToInsertSql( - item{Address: "111 Test Addr", Name: "Test1"}, - item{Address: "211 Test Addr", Name: "Test2"}, + item{Address: "111 Test Addr", Name: "Test1", phone: phone{Home: "123123", Primary: "456456"}}, + item{Address: "211 Test Addr", Name: "Test2", phone: phone{Home: "123123", Primary: "456456"}}, item{Address: "311 Test Addr", Name: "Test3"}, item{Address: "411 Test Addr", Name: "Test4"}, ) assert.NoError(t, err) - assert.Equal(t, sql, `INSERT INTO "items" ("address", "name") VALUES ('111 Test Addr', 'Test1'), ('211 Test Addr', 'Test2'), ('311 Test Addr', 'Test3'), ('411 Test Addr', 'Test4')`) + assert.Equal(t, sql, `INSERT INTO "items" ("primary_phone", "home_phone", "address", "name") VALUES ('456456', '123123', '111 Test Addr', 'Test1'), ('456456', '123123', '211 Test Addr', 'Test2'), ('', '', '311 Test Addr', 'Test3'), ('', '', '411 Test Addr', 'Test4')`) } func (me *datasetTest) TestInsertSqlWithMaps() { diff --git a/dataset_update.go b/dataset_update.go index 84444d19..b92b5dec 100644 --- a/dataset_update.go +++ b/dataset_update.go @@ -3,6 +3,7 @@ package goqu import ( "reflect" "sort" + "time" ) func (me *Dataset) canUpdateField(field reflect.StructField) bool { @@ -38,13 +39,7 @@ func (me *Dataset) ToUpdateSql(update interface{}) (string, []interface{}, error updates = append(updates, I(key.String()).Set(updateValue.MapIndex(key).Interface())) } case reflect.Struct: - for j := 0; j < updateValue.NumField(); j++ { - f := updateValue.Field(j) - t := updateValue.Type().Field(j) - if me.canUpdateField(t) { - updates = append(updates, I(t.Tag.Get("db")).Set(f.Interface())) - } - } + updates = me.getUpdateExpression(updateValue) default: return "", nil, NewGoquError("Unsupported update interface type %+v", updateValue.Type()) } @@ -81,3 +76,20 @@ func (me *Dataset) ToUpdateSql(update interface{}) (string, []interface{}, error sql, args := buf.ToSql() return sql, args, nil } + +func (me *Dataset) getUpdateExpression(value reflect.Value) (updates []UpdateExpression) { + for i := 0; i < value.NumField(); i++ { + v := value.Field(i) + kind := v.Kind() + if reflect.TypeOf(v.Interface()).Name() == reflect.TypeOf((*time.Time)(nil)).Elem().Name() || kind != reflect.Struct { + t := value.Type().Field(i) + if me.canUpdateField(t) { + updates = append(updates, I(t.Tag.Get("db")).Set(v.Interface())) + } + } else { + updates = append(updates, me.getUpdateExpression(reflect.Indirect(reflect.ValueOf(v.Interface())))...) + } + } + + return updates +} diff --git a/dataset_update_test.go b/dataset_update_test.go index 95706bee..c8c7ecd9 100644 --- a/dataset_update_test.go +++ b/dataset_update_test.go @@ -3,6 +3,7 @@ package goqu import ( "database/sql/driver" "fmt" + "time" "github.com/DATA-DOG/go-sqlmock" "github.com/technotronicoz/testify/assert" @@ -221,6 +222,32 @@ func (me *datasetTest) TestPreparedUpdateSqlWithSkipupdateTag() { assert.Equal(t, sql, `UPDATE "items" SET "name"=?`) } +func (me *datasetTest) TestPreparedUpdateSqlWithEmbeddedStruct() { + t := me.T() + ds1 := From("items") + type phone struct { + Primary string `db:"primary_phone"` + Home string `db:"home_phone"` + Created time.Time `db:"phone_created"` + } + type item struct { + phone + Address string `db:"address" goqu:"skipupdate"` + Name string `db:"name"` + Created time.Time `db:"created"` + } + created, _ := time.Parse("2006-01-02", "2015-01-01") + + sql, args, err := ds1.Prepared(true).ToUpdateSql(item{Name: "Test", Address: "111 Test Addr", Created: created, phone: phone{ + Home: "123123", + Primary: "456456", + Created: created, + }}) + assert.NoError(t, err) + assert.Equal(t, args, []interface{}{"456456", "123123", created, "Test", created}) + assert.Equal(t, sql, `UPDATE "items" SET "primary_phone"=?,"home_phone"=?,"phone_created"=?,"name"=?,"created"=?`) +} + func (me *datasetTest) TestPreparedUpdateSqlWithWhere() { t := me.T() ds1 := From("items") diff --git a/errors.go b/errors.go index 22c9c424..285a0848 100644 --- a/errors.go +++ b/errors.go @@ -25,4 +25,4 @@ type GoquError struct { func (me GoquError) Error() string { return me.err -} \ No newline at end of file +}