Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add mods to fields, specifically for "lower(field)" in SQL #92

Merged
merged 1 commit into from
Jul 26, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
62 changes: 36 additions & 26 deletions pkg/dbsql/filter_sql.go
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,7 @@ func (s *Database) filterSelectFinalized(ctx context.Context, tableName string,
if len(fi.GroupBy) > 0 {
groupByWithResolvedFieldName := make([]string, len(fi.GroupBy))
for i, gb := range fi.GroupBy {
groupByWithResolvedFieldName[i] = s.mapField(tableName, gb, typeMap)
groupByWithResolvedFieldName[i] = s.mapFieldName(tableName, gb, typeMap)
}
groupByString := strings.Join(groupByWithResolvedFieldName, ",")
sel = sel.GroupBy(groupByString)
Expand All @@ -140,7 +140,7 @@ func (s *Database) filterSelectFinalized(ctx context.Context, tableName string,
} else if sf.Nulls == ffapi.NullsLast {
nulls = " NULLS LAST"
}
sort[i] = fmt.Sprintf("%s%s%s", s.mapField(tableName, sf.Field, typeMap), direction, nulls)
sort[i] = fmt.Sprintf("%s%s%s", s.mapFieldName(tableName, sf.Field, typeMap), direction, nulls)
}
sortString = strings.Join(sort, ", ")
sel = sel.OrderBy(sortString)
Expand Down Expand Up @@ -174,7 +174,7 @@ func (s *Database) BuildUpdate(sel sq.UpdateBuilder, update ffapi.Update, typeMa
}
for _, so := range ui.SetOperations {

sel = sel.Set(s.mapField("", so.Field, typeMap), so.Value)
sel = sel.Set(s.mapFieldName("", so.Field, typeMap), so.Value)
}
return sel, nil
}
Expand All @@ -191,7 +191,7 @@ func (s *Database) FilterUpdate(ctx context.Context, update sq.UpdateBuilder, fi
return update.Where(fop), nil
}

func (s *Database) mapField(tableName, fieldName string, tm map[string]string) string {
func (s *Database) mapFieldName(tableName, fieldName string, tm map[string]string) string {
if fieldName == "sequence" {
if tableName == "" {
return s.sequenceColumn
Expand All @@ -210,6 +210,16 @@ func (s *Database) mapField(tableName, fieldName string, tm map[string]string) s
return field
}

func (s *Database) mapField(tableName string, op *ffapi.FilterInfo, tm map[string]string) string {
fieldName := s.mapFieldName(tableName, op.Field, tm)
for _, m := range op.FieldMods {
if m == ffapi.FieldModLower {
fieldName = fmt.Sprintf("lower(%s)", fieldName)
}
}
return fieldName
}

// newILike uses ILIKE if supported by DB, otherwise the "lower" approach
func (s *Database) newILike(field, value string) sq.Sqlizer {
if s.features.UseILIKE {
Expand All @@ -233,49 +243,49 @@ func (s *Database) filterOp(ctx context.Context, tableName string, op *ffapi.Fil
case ffapi.FilterOpAnd:
return s.filterAnd(ctx, tableName, op, tm)
case ffapi.FilterOpEq:
return sq.Eq{s.mapField(tableName, op.Field, tm): op.Value}, nil
return sq.Eq{s.mapField(tableName, op, tm): op.Value}, nil
case ffapi.FilterOpIEq:
return s.newILike(s.mapField(tableName, op.Field, tm), s.escapeLike(op.Value)), nil
return s.newILike(s.mapField(tableName, op, tm), s.escapeLike(op.Value)), nil
case ffapi.FilterOpIn:
return sq.Eq{s.mapField(tableName, op.Field, tm): op.Values}, nil
return sq.Eq{s.mapField(tableName, op, tm): op.Values}, nil
case ffapi.FilterOpNeq:
return sq.NotEq{s.mapField(tableName, op.Field, tm): op.Value}, nil
return sq.NotEq{s.mapField(tableName, op, tm): op.Value}, nil
case ffapi.FilterOpNIeq:
return s.newNotILike(s.mapField(tableName, op.Field, tm), s.escapeLike(op.Value)), nil
return s.newNotILike(s.mapField(tableName, op, tm), s.escapeLike(op.Value)), nil
case ffapi.FilterOpNotIn:
return sq.NotEq{s.mapField(tableName, op.Field, tm): op.Values}, nil
return sq.NotEq{s.mapField(tableName, op, tm): op.Values}, nil
case ffapi.FilterOpCont:
return LikeEscape{s.mapField(tableName, op.Field, tm): fmt.Sprintf("%%%s%%", s.escapeLike(op.Value))}, nil
return LikeEscape{s.mapField(tableName, op, tm): fmt.Sprintf("%%%s%%", s.escapeLike(op.Value))}, nil
case ffapi.FilterOpNotCont:
return NotLikeEscape{s.mapField(tableName, op.Field, tm): fmt.Sprintf("%%%s%%", s.escapeLike(op.Value))}, nil
return NotLikeEscape{s.mapField(tableName, op, tm): fmt.Sprintf("%%%s%%", s.escapeLike(op.Value))}, nil
case ffapi.FilterOpICont:
return s.newILike(s.mapField(tableName, op.Field, tm), fmt.Sprintf("%%%s%%", s.escapeLike(op.Value))), nil
return s.newILike(s.mapField(tableName, op, tm), fmt.Sprintf("%%%s%%", s.escapeLike(op.Value))), nil
case ffapi.FilterOpNotICont:
return s.newNotILike(s.mapField(tableName, op.Field, tm), fmt.Sprintf("%s%%", s.escapeLike(op.Value))), nil
return s.newNotILike(s.mapField(tableName, op, tm), fmt.Sprintf("%s%%", s.escapeLike(op.Value))), nil
case ffapi.FilterOpStartsWith:
return LikeEscape{s.mapField(tableName, op.Field, tm): fmt.Sprintf("%s%%", s.escapeLike(op.Value))}, nil
return LikeEscape{s.mapField(tableName, op, tm): fmt.Sprintf("%s%%", s.escapeLike(op.Value))}, nil
case ffapi.FilterOpNotStartsWith:
return NotLikeEscape{s.mapField(tableName, op.Field, tm): fmt.Sprintf("%s%%", s.escapeLike(op.Value))}, nil
return NotLikeEscape{s.mapField(tableName, op, tm): fmt.Sprintf("%s%%", s.escapeLike(op.Value))}, nil
case ffapi.FilterOpIStartsWith:
return s.newILike(s.mapField(tableName, op.Field, tm), fmt.Sprintf("%s%%", s.escapeLike(op.Value))), nil
return s.newILike(s.mapField(tableName, op, tm), fmt.Sprintf("%s%%", s.escapeLike(op.Value))), nil
case ffapi.FilterOpNotIStartsWith:
return s.newNotILike(s.mapField(tableName, op.Field, tm), fmt.Sprintf("%s%%", s.escapeLike(op.Value))), nil
return s.newNotILike(s.mapField(tableName, op, tm), fmt.Sprintf("%s%%", s.escapeLike(op.Value))), nil
case ffapi.FilterOpEndsWith:
return LikeEscape{s.mapField(tableName, op.Field, tm): fmt.Sprintf("%%%s", s.escapeLike(op.Value))}, nil
return LikeEscape{s.mapField(tableName, op, tm): fmt.Sprintf("%%%s", s.escapeLike(op.Value))}, nil
case ffapi.FilterOpNotEndsWith:
return NotLikeEscape{s.mapField(tableName, op.Field, tm): fmt.Sprintf("%%%s", s.escapeLike(op.Value))}, nil
return NotLikeEscape{s.mapField(tableName, op, tm): fmt.Sprintf("%%%s", s.escapeLike(op.Value))}, nil
case ffapi.FilterOpIEndsWith:
return s.newILike(s.mapField(tableName, op.Field, tm), fmt.Sprintf("%%%s", s.escapeLike(op.Value))), nil
return s.newILike(s.mapField(tableName, op, tm), fmt.Sprintf("%%%s", s.escapeLike(op.Value))), nil
case ffapi.FilterOpNotIEndsWith:
return s.newNotILike(s.mapField(tableName, op.Field, tm), fmt.Sprintf("%%%s", s.escapeLike(op.Value))), nil
return s.newNotILike(s.mapField(tableName, op, tm), fmt.Sprintf("%%%s", s.escapeLike(op.Value))), nil
case ffapi.FilterOpGt:
return sq.Gt{s.mapField(tableName, op.Field, tm): op.Value}, nil
return sq.Gt{s.mapField(tableName, op, tm): op.Value}, nil
case ffapi.FilterOpGte:
return sq.GtOrEq{s.mapField(tableName, op.Field, tm): op.Value}, nil
return sq.GtOrEq{s.mapField(tableName, op, tm): op.Value}, nil
case ffapi.FilterOpLt:
return sq.Lt{s.mapField(tableName, op.Field, tm): op.Value}, nil
return sq.Lt{s.mapField(tableName, op, tm): op.Value}, nil
case ffapi.FilterOpLte:
return sq.LtOrEq{s.mapField(tableName, op.Field, tm): op.Value}, nil
return sq.LtOrEq{s.mapField(tableName, op, tm): op.Value}, nil
default:
return nil, i18n.NewError(ctx, i18n.MsgUnsupportedSQLOpInFilter, op.Op)
}
Expand Down
21 changes: 21 additions & 0 deletions pkg/dbsql/filter_sql_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ var TestQueryFactory = &ffapi.QueryFields{
"tag": &ffapi.StringField{},
"topics": &ffapi.FFStringArrayField{},
"type": &ffapi.StringField{},
"address": &ffapi.StringFieldLower{},
}

func TestSQLQueryFactoryIgnoreInvalidFilterFields(t *testing.T) {
Expand Down Expand Up @@ -172,6 +173,26 @@ func TestSQLQueryFactoryEvenMoreOps(t *testing.T) {
}, args)
}

func TestSQLQueryFactoryLowerCaseIndexSearch(t *testing.T) {

s, _ := NewMockProvider().UTInit()
fb := TestQueryFactory.NewFilter(context.Background())
addr1 := "0xf698D78272a0bCD63A3feb097B24a866f6b8a5a0"
addr2 := "0xb9B919763dBC54D4D634150446Bf3991A9ef5eD7"
f := fb.And(
fb.IEq("address", addr1),
fb.In("address", []driver.Value{addr1, addr2}),
)

sel := squirrel.Select("*").From("mytable AS mt")
sel, _, _, err := s.FilterSelect(context.Background(), "mt", sel, f, nil, []interface{}{"sequence"})
assert.NoError(t, err)

sqlFilter, _, err := sel.ToSql()
assert.NoError(t, err)
assert.Equal(t, "SELECT * FROM mytable AS mt WHERE (lower(mt.address) ILIKE ? ESCAPE '[' AND lower(mt.address) IN (?,?)) ORDER BY mt.seq DESC", sqlFilter)
}

func TestSQLQueryFactoryEscapeLike(t *testing.T) {

sel := squirrel.Select("*").From("mytable AS mt").
Expand Down
22 changes: 20 additions & 2 deletions pkg/ffapi/filter.go
Original file line number Diff line number Diff line change
Expand Up @@ -237,6 +237,7 @@ type FilterInfo struct {
Count bool
CountExpr string
Field string
FieldMods []FieldMod
Op FilterOp
Values []FieldSerialization
Value FieldSerialization
Expand Down Expand Up @@ -265,6 +266,12 @@ func ValueString(f FieldSerialization) string {
}

func (f *FilterInfo) filterString() string {
fieldName := f.Field
for _, fm := range f.FieldMods {
if fm == FieldModLower {
fieldName = fmt.Sprintf("lower(%s)", fieldName)
}
}
switch f.Op {
case FilterOpAnd, FilterOpOr:
cs := make([]string, len(f.Children))
Expand All @@ -277,9 +284,9 @@ func (f *FilterInfo) filterString() string {
for i, v := range f.Values {
strValues[i] = ValueString(v)
}
return fmt.Sprintf("%s %s [%s]", f.Field, f.Op, strings.Join(strValues, ","))
return fmt.Sprintf("%s %s [%s]", fieldName, f.Op, strings.Join(strValues, ","))
default:
return fmt.Sprintf("%s %s %s", f.Field, f.Op, ValueString(f.Value))
return fmt.Sprintf("%s %s %s", fieldName, f.Op, ValueString(f.Value))
}
}

Expand Down Expand Up @@ -353,10 +360,18 @@ func (f *baseFilter) Builder() FilterBuilder {
return f.fb
}

func fieldMods(f Field) []FieldMod {
if hfm, ok := f.(HasFieldMods); ok {
return hfm.FieldMods()
}
return nil
}

func (f *baseFilter) Finalize() (fi *FilterInfo, err error) {
var children []*FilterInfo
var value FieldSerialization
var values []FieldSerialization
var mods []FieldMod

switch f.op {
case FilterOpAnd, FilterOpOr:
Expand All @@ -374,6 +389,7 @@ func (f *baseFilter) Finalize() (fi *FilterInfo, err error) {
if !ok {
return nil, i18n.NewError(f.fb.ctx, i18n.MsgInvalidFilterField, name)
}
mods = fieldMods(field)
for i, fv := range fValues {
values[i] = field.GetSerialization()
if err = values[i].Scan(fv); err != nil {
Expand All @@ -386,6 +402,7 @@ func (f *baseFilter) Finalize() (fi *FilterInfo, err error) {
if !ok {
return nil, i18n.NewError(f.fb.ctx, i18n.MsgInvalidFilterField, name)
}
mods = fieldMods(field)
skipScan := false
switch f.value.(type) {
case nil:
Expand Down Expand Up @@ -427,6 +444,7 @@ func (f *baseFilter) Finalize() (fi *FilterInfo, err error) {
Children: children,
Op: f.op,
Field: f.field,
FieldMods: mods,
Values: values,
Value: value,
Sort: f.fb.sort,
Expand Down
12 changes: 12 additions & 0 deletions pkg/ffapi/filter_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -180,6 +180,18 @@ func TestBuildMessageTimeConvert(t *testing.T) {
assert.Equal(t, "( created >> 1621112824000000000 ) && ( created >> 0 ) && ( created == 1621112874123456789 ) && ( created == null ) && ( created << 1621112824000000000 ) && ( created << 1621112824000000000 )", f.String())
}

func TestLowerField(t *testing.T) {
fb := TestQueryFactory.NewFilter(context.Background())
addr1 := "0xf698D78272a0bCD63A3feb097B24a866f6b8a5a0"
addr2 := "0xb9B919763dBC54D4D634150446Bf3991A9ef5eD7"
f, err := fb.And(
fb.Eq("address", addr1),
fb.In("address", []driver.Value{addr1, addr2}),
).Finalize()
assert.NoError(t, err)
assert.Equal(t, "( lower(address) == '0xf698D78272a0bCD63A3feb097B24a866f6b8a5a0' ) && ( lower(address) IN ['0xf698D78272a0bCD63A3feb097B24a866f6b8a5a0','0xb9B919763dBC54D4D634150446Bf3991A9ef5eD7'] )", f.String())
}

func TestBuildMessageStringConvert(t *testing.T) {
fb := TestQueryFactory.NewFilter(context.Background())
u := fftypes.MustParseUUID("3f96e0d5-a10e-47c6-87a0-f2e7604af179")
Expand Down
18 changes: 18 additions & 0 deletions pkg/ffapi/query_fields.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,17 @@ type QueryFactory interface {
NewUpdate(ctx context.Context) UpdateBuilder
}

type FieldMod int

const (
FieldModLower FieldMod = iota
)

// HasFieldMods can be set on a QueryField to do special things that a DB might support - like lowercase index filtering
type HasFieldMods interface {
FieldMods() []FieldMod
}

type QueryFields map[string]Field

func (qf *QueryFields) NewFilterLimit(ctx context.Context, defLimit uint64) FilterBuilder {
Expand Down Expand Up @@ -131,6 +142,13 @@ func (f *StringField) GetSerialization() FieldSerialization { return &stringFiel
func (f *StringField) FilterAsString() bool { return true }
func (f *StringField) Description() string { return "String" }

type StringFieldLower struct {
StringField
}

func (f *StringFieldLower) GetSerialization() FieldSerialization { return &stringField{} }
func (f *StringFieldLower) FieldMods() []FieldMod { return []FieldMod{FieldModLower} }

type UUIDField struct{}
type uuidField struct{ u *fftypes.UUID }

Expand Down
1 change: 1 addition & 0 deletions pkg/ffapi/update_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ var TestQueryFactory = &QueryFields{
"tag": &StringField{},
"topics": &FFStringArrayField{},
"type": &StringField{},
"address": &StringFieldLower{},
}

func TestUpdateBuilderOK(t *testing.T) {
Expand Down