Skip to content

Commit

Permalink
Added support for embedded structs when inserting or updating.
Browse files Browse the repository at this point in the history
  • Loading branch information
andymoon committed Aug 13, 2015
1 parent 79f264d commit dd90d79
Show file tree
Hide file tree
Showing 5 changed files with 106 additions and 23 deletions.
34 changes: 24 additions & 10 deletions dataset_insert.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package goqu
import (
"reflect"
"sort"
"time"
)

//Generates the default INSERT statement. If Prepared has been called with true then the statement will not be interpolated. See examples.
Expand Down Expand Up @@ -94,16 +95,7 @@ func (me *Dataset) getInsertColsAndVals(rows ...interface{}) (columns ColumnList
rowCols []interface{}
rowVals []interface{}
)
for j := 0; j < newRowValue.NumField(); j++ {
f := newRowValue.Field(j)
t := newRowValue.Type().Field(j)
if me.canInsertField(t) {
if columns == nil {
rowCols = append(rowCols, t.Tag.Get("db"))
}
rowVals = append(rowVals, f.Interface())
}
}
rowCols, rowVals = me.getFieldsValues(newRowValue)
if columns == nil {
columns = cols(rowCols...)
}
Expand All @@ -115,6 +107,28 @@ func (me *Dataset) getInsertColsAndVals(rows ...interface{}) (columns ColumnList
return columns, vals, nil
}

func (me *Dataset) getFieldsValues(value reflect.Value) (rowCols []interface{}, rowVals []interface{}) {
for i := 0; i < value.NumField(); i++ {
v := value.Field(i)

kind := v.Kind()
if (reflect.TypeOf(v.Interface()).Name() == reflect.TypeOf((*time.Time)(nil)).Elem().Name()) || (kind != reflect.Struct) {
t := value.Type().Field(i)
if me.canInsertField(t) {
rowCols = append(rowCols, t.Tag.Get("db"))
rowVals = append(rowVals, v.Interface())
}
} else {
cols, vals := me.getFieldsValues(reflect.Indirect(reflect.ValueOf(v.Interface())))
rowCols = append(rowCols, cols...)
rowVals = append(rowVals, vals...)
}

}

return rowCols, rowVals
}

//Creates an INSERT statement with the columns and values passed in
func (me *Dataset) insertSql(cols ColumnList, values [][]interface{}, prepared bool) (string, []interface{}, error) {
buf := NewSqlBuilder(prepared)
Expand Down
40 changes: 35 additions & 5 deletions dataset_insert_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@ package goqu
import (
"github.com/DATA-DOG/go-sqlmock"
"github.com/technotronicoz/testify/assert"

"time"
)

func (me *datasetTest) TestInsertSqlNoReturning() {
Expand Down Expand Up @@ -36,21 +38,49 @@ func (me *datasetTest) TestInsertSqlWithStructs() {
t := me.T()
ds1 := From("items")
type item struct {
Address string `db:"address"`
Name string `db:"name"`
Created time.Time `db:"created"`
}
created, _ := time.Parse("2006-01-02", "2015-01-01")
sql, _, err := ds1.ToInsertSql(item{Name: "Test", Address: "111 Test Addr", Created: created})
assert.NoError(t, err)
assert.Equal(t, sql, `INSERT INTO "items" ("address", "name", "created") VALUES ('111 Test Addr', 'Test', '`+created.Format(time.RFC3339Nano)+`')`)

sql, _, err = ds1.ToInsertSql(
item{Address: "111 Test Addr", Name: "Test1", Created: created},
item{Address: "211 Test Addr", Name: "Test2", Created: created},
item{Address: "311 Test Addr", Name: "Test3", Created: created},
item{Address: "411 Test Addr", Name: "Test4", Created: created},
)
assert.NoError(t, err)
assert.Equal(t, sql, `INSERT INTO "items" ("address", "name", "created") VALUES ('111 Test Addr', 'Test1', '`+created.Format(time.RFC3339Nano)+`'), ('211 Test Addr', 'Test2', '`+created.Format(time.RFC3339Nano)+`'), ('311 Test Addr', 'Test3', '`+created.Format(time.RFC3339Nano)+`'), ('411 Test Addr', 'Test4', '`+created.Format(time.RFC3339Nano)+`')`)
}

func (me *datasetTest) TestInsertSqlWithEmbeddedStruct() {
t := me.T()
ds1 := From("items")
type phone struct {
Primary string `db:"primary_phone"`
Home string `db:"home_phone"`
}
type item struct {
phone
Address string `db:"address"`
Name string `db:"name"`
}
sql, _, err := ds1.ToInsertSql(item{Name: "Test", Address: "111 Test Addr"})
sql, _, err := ds1.ToInsertSql(item{Name: "Test", Address: "111 Test Addr", phone: phone{Home: "123123", Primary: "456456"}})
assert.NoError(t, err)
assert.Equal(t, sql, `INSERT INTO "items" ("address", "name") VALUES ('111 Test Addr', 'Test')`)
assert.Equal(t, sql, `INSERT INTO "items" ("primary_phone", "home_phone", "address", "name") VALUES ('456456', '123123', '111 Test Addr', 'Test')`)

sql, _, err = ds1.ToInsertSql(
item{Address: "111 Test Addr", Name: "Test1"},
item{Address: "211 Test Addr", Name: "Test2"},
item{Address: "111 Test Addr", Name: "Test1", phone: phone{Home: "123123", Primary: "456456"}},
item{Address: "211 Test Addr", Name: "Test2", phone: phone{Home: "123123", Primary: "456456"}},
item{Address: "311 Test Addr", Name: "Test3"},
item{Address: "411 Test Addr", Name: "Test4"},
)
assert.NoError(t, err)
assert.Equal(t, sql, `INSERT INTO "items" ("address", "name") VALUES ('111 Test Addr', 'Test1'), ('211 Test Addr', 'Test2'), ('311 Test Addr', 'Test3'), ('411 Test Addr', 'Test4')`)
assert.Equal(t, sql, `INSERT INTO "items" ("primary_phone", "home_phone", "address", "name") VALUES ('456456', '123123', '111 Test Addr', 'Test1'), ('456456', '123123', '211 Test Addr', 'Test2'), ('', '', '311 Test Addr', 'Test3'), ('', '', '411 Test Addr', 'Test4')`)
}

func (me *datasetTest) TestInsertSqlWithMaps() {
Expand Down
26 changes: 19 additions & 7 deletions dataset_update.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package goqu
import (
"reflect"
"sort"
"time"
)

func (me *Dataset) canUpdateField(field reflect.StructField) bool {
Expand Down Expand Up @@ -38,13 +39,7 @@ func (me *Dataset) ToUpdateSql(update interface{}) (string, []interface{}, error
updates = append(updates, I(key.String()).Set(updateValue.MapIndex(key).Interface()))
}
case reflect.Struct:
for j := 0; j < updateValue.NumField(); j++ {
f := updateValue.Field(j)
t := updateValue.Type().Field(j)
if me.canUpdateField(t) {
updates = append(updates, I(t.Tag.Get("db")).Set(f.Interface()))
}
}
updates = me.getUpdateExpression(updateValue)
default:
return "", nil, NewGoquError("Unsupported update interface type %+v", updateValue.Type())
}
Expand Down Expand Up @@ -81,3 +76,20 @@ func (me *Dataset) ToUpdateSql(update interface{}) (string, []interface{}, error
sql, args := buf.ToSql()
return sql, args, nil
}

func (me *Dataset) getUpdateExpression(value reflect.Value) (updates []UpdateExpression) {
for i := 0; i < value.NumField(); i++ {
v := value.Field(i)
kind := v.Kind()
if reflect.TypeOf(v.Interface()).Name() == reflect.TypeOf((*time.Time)(nil)).Elem().Name() || kind != reflect.Struct {
t := value.Type().Field(i)
if me.canUpdateField(t) {
updates = append(updates, I(t.Tag.Get("db")).Set(v.Interface()))
}
} else {
updates = append(updates, me.getUpdateExpression(reflect.Indirect(reflect.ValueOf(v.Interface())))...)
}
}

return updates
}
27 changes: 27 additions & 0 deletions dataset_update_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package goqu
import (
"database/sql/driver"
"fmt"
"time"

"github.com/DATA-DOG/go-sqlmock"
"github.com/technotronicoz/testify/assert"
Expand Down Expand Up @@ -221,6 +222,32 @@ func (me *datasetTest) TestPreparedUpdateSqlWithSkipupdateTag() {
assert.Equal(t, sql, `UPDATE "items" SET "name"=?`)
}

func (me *datasetTest) TestPreparedUpdateSqlWithEmbeddedStruct() {
t := me.T()
ds1 := From("items")
type phone struct {
Primary string `db:"primary_phone"`
Home string `db:"home_phone"`
Created time.Time `db:"phone_created"`
}
type item struct {
phone
Address string `db:"address" goqu:"skipupdate"`
Name string `db:"name"`
Created time.Time `db:"created"`
}
created, _ := time.Parse("2006-01-02", "2015-01-01")

sql, args, err := ds1.Prepared(true).ToUpdateSql(item{Name: "Test", Address: "111 Test Addr", Created: created, phone: phone{
Home: "123123",
Primary: "456456",
Created: created,
}})
assert.NoError(t, err)
assert.Equal(t, args, []interface{}{"456456", "123123", created, "Test", created})
assert.Equal(t, sql, `UPDATE "items" SET "primary_phone"=?,"home_phone"=?,"phone_created"=?,"name"=?,"created"=?`)
}

func (me *datasetTest) TestPreparedUpdateSqlWithWhere() {
t := me.T()
ds1 := From("items")
Expand Down
2 changes: 1 addition & 1 deletion errors.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,4 +25,4 @@ type GoquError struct {

func (me GoquError) Error() string {
return me.err
}
}

0 comments on commit dd90d79

Please sign in to comment.