Skip to content

Commit

Permalink
Add mods to fields, specifically for "lower(field)" in SQL
Browse files Browse the repository at this point in the history
Signed-off-by: Peter Broadhurst <peter.broadhurst@kaleido.io>
  • Loading branch information
peterbroadhurst committed Jul 25, 2023
1 parent 992357b commit 8c28b1d
Show file tree
Hide file tree
Showing 6 changed files with 108 additions and 28 deletions.
62 changes: 36 additions & 26 deletions pkg/dbsql/filter_sql.go
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,7 @@ func (s *Database) filterSelectFinalized(ctx context.Context, tableName string,
if len(fi.GroupBy) > 0 {
groupByWithResolvedFieldName := make([]string, len(fi.GroupBy))
for i, gb := range fi.GroupBy {
groupByWithResolvedFieldName[i] = s.mapField(tableName, gb, typeMap)
groupByWithResolvedFieldName[i] = s.mapFieldName(tableName, gb, typeMap)
}
groupByString := strings.Join(groupByWithResolvedFieldName, ",")
sel = sel.GroupBy(groupByString)
Expand All @@ -140,7 +140,7 @@ func (s *Database) filterSelectFinalized(ctx context.Context, tableName string,
} else if sf.Nulls == ffapi.NullsLast {
nulls = " NULLS LAST"
}
sort[i] = fmt.Sprintf("%s%s%s", s.mapField(tableName, sf.Field, typeMap), direction, nulls)
sort[i] = fmt.Sprintf("%s%s%s", s.mapFieldName(tableName, sf.Field, typeMap), direction, nulls)
}
sortString = strings.Join(sort, ", ")
sel = sel.OrderBy(sortString)
Expand Down Expand Up @@ -174,7 +174,7 @@ func (s *Database) BuildUpdate(sel sq.UpdateBuilder, update ffapi.Update, typeMa
}
for _, so := range ui.SetOperations {

sel = sel.Set(s.mapField("", so.Field, typeMap), so.Value)
sel = sel.Set(s.mapFieldName("", so.Field, typeMap), so.Value)
}
return sel, nil
}
Expand All @@ -191,7 +191,7 @@ func (s *Database) FilterUpdate(ctx context.Context, update sq.UpdateBuilder, fi
return update.Where(fop), nil
}

func (s *Database) mapField(tableName, fieldName string, tm map[string]string) string {
func (s *Database) mapFieldName(tableName, fieldName string, tm map[string]string) string {
if fieldName == "sequence" {
if tableName == "" {
return s.sequenceColumn
Expand All @@ -210,6 +210,16 @@ func (s *Database) mapField(tableName, fieldName string, tm map[string]string) s
return field
}

func (s *Database) mapField(tableName string, op *ffapi.FilterInfo, tm map[string]string) string {
fieldName := s.mapFieldName(tableName, op.Field, tm)
for _, m := range op.FieldMods {
if m == ffapi.FieldModLower {
fieldName = fmt.Sprintf("lower(%s)", fieldName)
}
}
return fieldName
}

// newILike uses ILIKE if supported by DB, otherwise the "lower" approach
func (s *Database) newILike(field, value string) sq.Sqlizer {
if s.features.UseILIKE {
Expand All @@ -233,49 +243,49 @@ func (s *Database) filterOp(ctx context.Context, tableName string, op *ffapi.Fil
case ffapi.FilterOpAnd:
return s.filterAnd(ctx, tableName, op, tm)
case ffapi.FilterOpEq:
return sq.Eq{s.mapField(tableName, op.Field, tm): op.Value}, nil
return sq.Eq{s.mapField(tableName, op, tm): op.Value}, nil
case ffapi.FilterOpIEq:
return s.newILike(s.mapField(tableName, op.Field, tm), s.escapeLike(op.Value)), nil
return s.newILike(s.mapField(tableName, op, tm), s.escapeLike(op.Value)), nil
case ffapi.FilterOpIn:
return sq.Eq{s.mapField(tableName, op.Field, tm): op.Values}, nil
return sq.Eq{s.mapField(tableName, op, tm): op.Values}, nil
case ffapi.FilterOpNeq:
return sq.NotEq{s.mapField(tableName, op.Field, tm): op.Value}, nil
return sq.NotEq{s.mapField(tableName, op, tm): op.Value}, nil
case ffapi.FilterOpNIeq:
return s.newNotILike(s.mapField(tableName, op.Field, tm), s.escapeLike(op.Value)), nil
return s.newNotILike(s.mapField(tableName, op, tm), s.escapeLike(op.Value)), nil
case ffapi.FilterOpNotIn:
return sq.NotEq{s.mapField(tableName, op.Field, tm): op.Values}, nil
return sq.NotEq{s.mapField(tableName, op, tm): op.Values}, nil
case ffapi.FilterOpCont:
return LikeEscape{s.mapField(tableName, op.Field, tm): fmt.Sprintf("%%%s%%", s.escapeLike(op.Value))}, nil
return LikeEscape{s.mapField(tableName, op, tm): fmt.Sprintf("%%%s%%", s.escapeLike(op.Value))}, nil
case ffapi.FilterOpNotCont:
return NotLikeEscape{s.mapField(tableName, op.Field, tm): fmt.Sprintf("%%%s%%", s.escapeLike(op.Value))}, nil
return NotLikeEscape{s.mapField(tableName, op, tm): fmt.Sprintf("%%%s%%", s.escapeLike(op.Value))}, nil
case ffapi.FilterOpICont:
return s.newILike(s.mapField(tableName, op.Field, tm), fmt.Sprintf("%%%s%%", s.escapeLike(op.Value))), nil
return s.newILike(s.mapField(tableName, op, tm), fmt.Sprintf("%%%s%%", s.escapeLike(op.Value))), nil
case ffapi.FilterOpNotICont:
return s.newNotILike(s.mapField(tableName, op.Field, tm), fmt.Sprintf("%s%%", s.escapeLike(op.Value))), nil
return s.newNotILike(s.mapField(tableName, op, tm), fmt.Sprintf("%s%%", s.escapeLike(op.Value))), nil
case ffapi.FilterOpStartsWith:
return LikeEscape{s.mapField(tableName, op.Field, tm): fmt.Sprintf("%s%%", s.escapeLike(op.Value))}, nil
return LikeEscape{s.mapField(tableName, op, tm): fmt.Sprintf("%s%%", s.escapeLike(op.Value))}, nil
case ffapi.FilterOpNotStartsWith:
return NotLikeEscape{s.mapField(tableName, op.Field, tm): fmt.Sprintf("%s%%", s.escapeLike(op.Value))}, nil
return NotLikeEscape{s.mapField(tableName, op, tm): fmt.Sprintf("%s%%", s.escapeLike(op.Value))}, nil
case ffapi.FilterOpIStartsWith:
return s.newILike(s.mapField(tableName, op.Field, tm), fmt.Sprintf("%s%%", s.escapeLike(op.Value))), nil
return s.newILike(s.mapField(tableName, op, tm), fmt.Sprintf("%s%%", s.escapeLike(op.Value))), nil
case ffapi.FilterOpNotIStartsWith:
return s.newNotILike(s.mapField(tableName, op.Field, tm), fmt.Sprintf("%s%%", s.escapeLike(op.Value))), nil
return s.newNotILike(s.mapField(tableName, op, tm), fmt.Sprintf("%s%%", s.escapeLike(op.Value))), nil
case ffapi.FilterOpEndsWith:
return LikeEscape{s.mapField(tableName, op.Field, tm): fmt.Sprintf("%%%s", s.escapeLike(op.Value))}, nil
return LikeEscape{s.mapField(tableName, op, tm): fmt.Sprintf("%%%s", s.escapeLike(op.Value))}, nil
case ffapi.FilterOpNotEndsWith:
return NotLikeEscape{s.mapField(tableName, op.Field, tm): fmt.Sprintf("%%%s", s.escapeLike(op.Value))}, nil
return NotLikeEscape{s.mapField(tableName, op, tm): fmt.Sprintf("%%%s", s.escapeLike(op.Value))}, nil
case ffapi.FilterOpIEndsWith:
return s.newILike(s.mapField(tableName, op.Field, tm), fmt.Sprintf("%%%s", s.escapeLike(op.Value))), nil
return s.newILike(s.mapField(tableName, op, tm), fmt.Sprintf("%%%s", s.escapeLike(op.Value))), nil
case ffapi.FilterOpNotIEndsWith:
return s.newNotILike(s.mapField(tableName, op.Field, tm), fmt.Sprintf("%%%s", s.escapeLike(op.Value))), nil
return s.newNotILike(s.mapField(tableName, op, tm), fmt.Sprintf("%%%s", s.escapeLike(op.Value))), nil
case ffapi.FilterOpGt:
return sq.Gt{s.mapField(tableName, op.Field, tm): op.Value}, nil
return sq.Gt{s.mapField(tableName, op, tm): op.Value}, nil
case ffapi.FilterOpGte:
return sq.GtOrEq{s.mapField(tableName, op.Field, tm): op.Value}, nil
return sq.GtOrEq{s.mapField(tableName, op, tm): op.Value}, nil
case ffapi.FilterOpLt:
return sq.Lt{s.mapField(tableName, op.Field, tm): op.Value}, nil
return sq.Lt{s.mapField(tableName, op, tm): op.Value}, nil
case ffapi.FilterOpLte:
return sq.LtOrEq{s.mapField(tableName, op.Field, tm): op.Value}, nil
return sq.LtOrEq{s.mapField(tableName, op, tm): op.Value}, nil
default:
return nil, i18n.NewError(ctx, i18n.MsgUnsupportedSQLOpInFilter, op.Op)
}
Expand Down
21 changes: 21 additions & 0 deletions pkg/dbsql/filter_sql_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ var TestQueryFactory = &ffapi.QueryFields{
"tag": &ffapi.StringField{},
"topics": &ffapi.FFStringArrayField{},
"type": &ffapi.StringField{},
"address": &ffapi.StringFieldLower{},
}

func TestSQLQueryFactoryIgnoreInvalidFilterFields(t *testing.T) {
Expand Down Expand Up @@ -172,6 +173,26 @@ func TestSQLQueryFactoryEvenMoreOps(t *testing.T) {
}, args)
}

func TestSQLQueryFactoryLowerCaseIndexSearch(t *testing.T) {

s, _ := NewMockProvider().UTInit()
fb := TestQueryFactory.NewFilter(context.Background())
addr1 := "0xf698D78272a0bCD63A3feb097B24a866f6b8a5a0"
addr2 := "0xb9B919763dBC54D4D634150446Bf3991A9ef5eD7"
f := fb.And(
fb.IEq("address", addr1),
fb.In("address", []driver.Value{addr1, addr2}),
)

sel := squirrel.Select("*").From("mytable AS mt")
sel, _, _, err := s.FilterSelect(context.Background(), "mt", sel, f, nil, []interface{}{"sequence"})
assert.NoError(t, err)

sqlFilter, _, err := sel.ToSql()
assert.NoError(t, err)
assert.Equal(t, "SELECT * FROM mytable AS mt WHERE (lower(mt.address) ILIKE ? ESCAPE '[' AND lower(mt.address) IN (?,?)) ORDER BY mt.seq DESC", sqlFilter)
}

func TestSQLQueryFactoryEscapeLike(t *testing.T) {

sel := squirrel.Select("*").From("mytable AS mt").
Expand Down
22 changes: 20 additions & 2 deletions pkg/ffapi/filter.go
Original file line number Diff line number Diff line change
Expand Up @@ -237,6 +237,7 @@ type FilterInfo struct {
Count bool
CountExpr string
Field string
FieldMods []FieldMod
Op FilterOp
Values []FieldSerialization
Value FieldSerialization
Expand Down Expand Up @@ -265,6 +266,12 @@ func ValueString(f FieldSerialization) string {
}

func (f *FilterInfo) filterString() string {
fieldName := f.Field
for _, fm := range f.FieldMods {
if fm == FieldModLower {
fieldName = fmt.Sprintf("lower(%s)", fieldName)
}
}
switch f.Op {
case FilterOpAnd, FilterOpOr:
cs := make([]string, len(f.Children))
Expand All @@ -277,9 +284,9 @@ func (f *FilterInfo) filterString() string {
for i, v := range f.Values {
strValues[i] = ValueString(v)
}
return fmt.Sprintf("%s %s [%s]", f.Field, f.Op, strings.Join(strValues, ","))
return fmt.Sprintf("%s %s [%s]", fieldName, f.Op, strings.Join(strValues, ","))
default:
return fmt.Sprintf("%s %s %s", f.Field, f.Op, ValueString(f.Value))
return fmt.Sprintf("%s %s %s", fieldName, f.Op, ValueString(f.Value))
}
}

Expand Down Expand Up @@ -353,10 +360,18 @@ func (f *baseFilter) Builder() FilterBuilder {
return f.fb
}

func fieldMods(f Field) []FieldMod {
if hfm, ok := f.(HasFieldMods); ok {
return hfm.FieldMods()
}
return nil
}

func (f *baseFilter) Finalize() (fi *FilterInfo, err error) {
var children []*FilterInfo
var value FieldSerialization
var values []FieldSerialization
var mods []FieldMod

switch f.op {
case FilterOpAnd, FilterOpOr:
Expand All @@ -374,6 +389,7 @@ func (f *baseFilter) Finalize() (fi *FilterInfo, err error) {
if !ok {
return nil, i18n.NewError(f.fb.ctx, i18n.MsgInvalidFilterField, name)
}
mods = fieldMods(field)
for i, fv := range fValues {
values[i] = field.GetSerialization()
if err = values[i].Scan(fv); err != nil {
Expand All @@ -386,6 +402,7 @@ func (f *baseFilter) Finalize() (fi *FilterInfo, err error) {
if !ok {
return nil, i18n.NewError(f.fb.ctx, i18n.MsgInvalidFilterField, name)
}
mods = fieldMods(field)
skipScan := false
switch f.value.(type) {
case nil:
Expand Down Expand Up @@ -427,6 +444,7 @@ func (f *baseFilter) Finalize() (fi *FilterInfo, err error) {
Children: children,
Op: f.op,
Field: f.field,
FieldMods: mods,
Values: values,
Value: value,
Sort: f.fb.sort,
Expand Down
12 changes: 12 additions & 0 deletions pkg/ffapi/filter_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -180,6 +180,18 @@ func TestBuildMessageTimeConvert(t *testing.T) {
assert.Equal(t, "( created >> 1621112824000000000 ) && ( created >> 0 ) && ( created == 1621112874123456789 ) && ( created == null ) && ( created << 1621112824000000000 ) && ( created << 1621112824000000000 )", f.String())
}

func TestLowerField(t *testing.T) {
fb := TestQueryFactory.NewFilter(context.Background())
addr1 := "0xf698D78272a0bCD63A3feb097B24a866f6b8a5a0"
addr2 := "0xb9B919763dBC54D4D634150446Bf3991A9ef5eD7"
f, err := fb.And(
fb.Eq("address", addr1),
fb.In("address", []driver.Value{addr1, addr2}),
).Finalize()
assert.NoError(t, err)
assert.Equal(t, "( lower(address) == '0xf698D78272a0bCD63A3feb097B24a866f6b8a5a0' ) && ( lower(address) IN ['0xf698D78272a0bCD63A3feb097B24a866f6b8a5a0','0xb9B919763dBC54D4D634150446Bf3991A9ef5eD7'] )", f.String())
}

func TestBuildMessageStringConvert(t *testing.T) {
fb := TestQueryFactory.NewFilter(context.Background())
u := fftypes.MustParseUUID("3f96e0d5-a10e-47c6-87a0-f2e7604af179")
Expand Down
18 changes: 18 additions & 0 deletions pkg/ffapi/query_fields.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,17 @@ type QueryFactory interface {
NewUpdate(ctx context.Context) UpdateBuilder
}

type FieldMod int

const (
FieldModLower FieldMod = iota
)

// HasFieldMods can be set on a QueryField to do special things that a DB might support - like lowercase index filtering
type HasFieldMods interface {
FieldMods() []FieldMod
}

type QueryFields map[string]Field

func (qf *QueryFields) NewFilterLimit(ctx context.Context, defLimit uint64) FilterBuilder {
Expand Down Expand Up @@ -131,6 +142,13 @@ func (f *StringField) GetSerialization() FieldSerialization { return &stringFiel
func (f *StringField) FilterAsString() bool { return true }
func (f *StringField) Description() string { return "String" }

type StringFieldLower struct {
StringField
}

func (f *StringFieldLower) GetSerialization() FieldSerialization { return &stringField{} }
func (f *StringFieldLower) FieldMods() []FieldMod { return []FieldMod{FieldModLower} }

type UUIDField struct{}
type uuidField struct{ u *fftypes.UUID }

Expand Down
1 change: 1 addition & 0 deletions pkg/ffapi/update_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ var TestQueryFactory = &QueryFields{
"tag": &StringField{},
"topics": &FFStringArrayField{},
"type": &StringField{},
"address": &StringFieldLower{},
}

func TestUpdateBuilderOK(t *testing.T) {
Expand Down

0 comments on commit 8c28b1d

Please sign in to comment.