Skip to content

Commit

Permalink
Fix escaping in LIKE expressions
Browse files Browse the repository at this point in the history
Explicitly specify '[' as the escape character in all cases, and use it to escape
the wildcards '%' and '_' in LIKE expressions. Works for Postgres and SQLite.

Fixes #83

Signed-off-by: Andrew Richardson <andrew.richardson@kaleido.io>
  • Loading branch information
awrichar committed Jul 6, 2023
1 parent da6668c commit c241a79
Show file tree
Hide file tree
Showing 3 changed files with 67 additions and 21 deletions.
3 changes: 3 additions & 0 deletions .golangci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -54,3 +54,6 @@ linters:
- unconvert
- unparam
- unused
issues:
exclude:
- "method ToSql should be ToSQL"
65 changes: 46 additions & 19 deletions pkg/dbsql/filter_sql.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,42 @@ import (
"github.com/hyperledger/firefly-common/pkg/i18n"
)

const escapeChar = "["

type LikeEscape sq.Like
type NotLikeEscape sq.NotLike
type ILikeEscape sq.ILike
type NotILikeEscape sq.NotILike

func (lke LikeEscape) ToSql() (sql string, args []interface{}, err error) {
sql, args, err = sq.Like(lke).ToSql()
return fmt.Sprintf("%s ESCAPE '%s'", sql, escapeChar), args, err
}

func (lke NotLikeEscape) ToSql() (sql string, args []interface{}, err error) {
sql, args, err = sq.NotLike(lke).ToSql()
return fmt.Sprintf("%s ESCAPE '%s'", sql, escapeChar), args, err
}

func (lke ILikeEscape) ToSql() (sql string, args []interface{}, err error) {
sql, args, err = sq.ILike(lke).ToSql()
return fmt.Sprintf("%s ESCAPE '%s'", sql, escapeChar), args, err
}

func (lke NotILikeEscape) ToSql() (sql string, args []interface{}, err error) {
sql, args, err = sq.NotILike(lke).ToSql()
return fmt.Sprintf("%s ESCAPE '%s'", sql, escapeChar), args, err
}

func (s *Database) escapeLike(value ffapi.FieldSerialization) string {
v, _ := value.Value()
vs, _ := v.(string)
vs = strings.ReplaceAll(vs, escapeChar, escapeChar+escapeChar)
vs = strings.ReplaceAll(vs, "%", escapeChar+"%")
vs = strings.ReplaceAll(vs, "_", escapeChar+"_")
return vs
}

func (s *Database) FilterSelect(ctx context.Context, tableName string, sel sq.SelectBuilder, filter ffapi.Filter, typeMap map[string]string, defaultSort []interface{}, preconditions ...sq.Sqlizer) (sq.SelectBuilder, sq.Sqlizer, *ffapi.FilterInfo, error) {
fi, err := filter.Finalize()
if err != nil {
Expand Down Expand Up @@ -122,15 +158,6 @@ func (s *Database) FilterUpdate(ctx context.Context, update sq.UpdateBuilder, fi
return update.Where(fop), nil
}

func (s *Database) escapeLike(value ffapi.FieldSerialization) string {
v, _ := value.Value()
vs, _ := v.(string)
vs = strings.ReplaceAll(vs, "[", "[[]")
vs = strings.ReplaceAll(vs, "%", "[%]")
vs = strings.ReplaceAll(vs, "_", "[_]")
return vs
}

func (s *Database) mapField(tableName, fieldName string, tm map[string]string) string {
if fieldName == "sequence" {
if tableName == "" {
Expand All @@ -153,17 +180,17 @@ func (s *Database) mapField(tableName, fieldName string, tm map[string]string) s
// newILike uses ILIKE if supported by DB, otherwise the "lower" approach
func (s *Database) newILike(field, value string) sq.Sqlizer {
if s.features.UseILIKE {
return sq.ILike{field: value}
return ILikeEscape{field: value}
}
return sq.Like{fmt.Sprintf("lower(%s)", field): strings.ToLower(value)}
return LikeEscape{fmt.Sprintf("lower(%s)", field): strings.ToLower(value)}
}

// newNotILike uses ILIKE if supported by DB, otherwise the "lower" approach
func (s *Database) newNotILike(field, value string) sq.Sqlizer {
if s.features.UseILIKE {
return sq.NotILike{field: value}
return NotILikeEscape{field: value}
}
return sq.NotLike{fmt.Sprintf("lower(%s)", field): strings.ToLower(value)}
return NotLikeEscape{fmt.Sprintf("lower(%s)", field): strings.ToLower(value)}
}

func (s *Database) filterOp(ctx context.Context, tableName string, op *ffapi.FilterInfo, tm map[string]string) (sq.Sqlizer, error) {
Expand All @@ -185,25 +212,25 @@ func (s *Database) filterOp(ctx context.Context, tableName string, op *ffapi.Fil
case ffapi.FilterOpNotIn:
return sq.NotEq{s.mapField(tableName, op.Field, tm): op.Values}, nil
case ffapi.FilterOpCont:
return sq.Like{s.mapField(tableName, op.Field, tm): fmt.Sprintf("%%%s%%", s.escapeLike(op.Value))}, nil
return LikeEscape{s.mapField(tableName, op.Field, tm): fmt.Sprintf("%%%s%%", s.escapeLike(op.Value))}, nil
case ffapi.FilterOpNotCont:
return sq.NotLike{s.mapField(tableName, op.Field, tm): fmt.Sprintf("%%%s%%", s.escapeLike(op.Value))}, nil
return NotLikeEscape{s.mapField(tableName, op.Field, tm): fmt.Sprintf("%%%s%%", s.escapeLike(op.Value))}, nil
case ffapi.FilterOpICont:
return s.newILike(s.mapField(tableName, op.Field, tm), fmt.Sprintf("%%%s%%", s.escapeLike(op.Value))), nil
case ffapi.FilterOpNotICont:
return s.newNotILike(s.mapField(tableName, op.Field, tm), fmt.Sprintf("%s%%", s.escapeLike(op.Value))), nil
case ffapi.FilterOpStartsWith:
return sq.Like{s.mapField(tableName, op.Field, tm): fmt.Sprintf("%s%%", s.escapeLike(op.Value))}, nil
return LikeEscape{s.mapField(tableName, op.Field, tm): fmt.Sprintf("%s%%", s.escapeLike(op.Value))}, nil
case ffapi.FilterOpNotStartsWith:
return sq.NotLike{s.mapField(tableName, op.Field, tm): fmt.Sprintf("%s%%", s.escapeLike(op.Value))}, nil
return NotLikeEscape{s.mapField(tableName, op.Field, tm): fmt.Sprintf("%s%%", s.escapeLike(op.Value))}, nil
case ffapi.FilterOpIStartsWith:
return s.newILike(s.mapField(tableName, op.Field, tm), fmt.Sprintf("%s%%", s.escapeLike(op.Value))), nil
case ffapi.FilterOpNotIStartsWith:
return s.newNotILike(s.mapField(tableName, op.Field, tm), fmt.Sprintf("%s%%", s.escapeLike(op.Value))), nil
case ffapi.FilterOpEndsWith:
return sq.Like{s.mapField(tableName, op.Field, tm): fmt.Sprintf("%%%s", s.escapeLike(op.Value))}, nil
return LikeEscape{s.mapField(tableName, op.Field, tm): fmt.Sprintf("%%%s", s.escapeLike(op.Value))}, nil
case ffapi.FilterOpNotEndsWith:
return sq.NotLike{s.mapField(tableName, op.Field, tm): fmt.Sprintf("%%%s", s.escapeLike(op.Value))}, nil
return NotLikeEscape{s.mapField(tableName, op.Field, tm): fmt.Sprintf("%%%s", s.escapeLike(op.Value))}, nil
case ffapi.FilterOpIEndsWith:
return s.newILike(s.mapField(tableName, op.Field, tm), fmt.Sprintf("%%%s", s.escapeLike(op.Value))), nil
case ffapi.FilterOpNotIEndsWith:
Expand Down
20 changes: 18 additions & 2 deletions pkg/dbsql/filter_sql_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,7 @@ func TestSQLQueryFactoryExtraOps(t *testing.T) {

sqlFilter, _, err := sel.ToSql()
assert.NoError(t, err)
assert.Equal(t, "SELECT * FROM mytable AS mt WHERE (mt.created IN (?,?,?) AND mt.created NOT IN (?,?,?) AND mt.id = ? AND mt.id IN (?) AND mt.id IS NOT NULL AND mt.created < ? AND mt.created <= ? AND mt.created >= ? AND mt.created <> ? AND mt.seq > ? AND mt.topics LIKE ? AND mt.topics NOT LIKE ? AND mt.topics ILIKE ? AND mt.topics NOT ILIKE ?) ORDER BY mt.seq DESC", sqlFilter)
assert.Equal(t, "SELECT * FROM mytable AS mt WHERE (mt.created IN (?,?,?) AND mt.created NOT IN (?,?,?) AND mt.id = ? AND mt.id IN (?) AND mt.id IS NOT NULL AND mt.created < ? AND mt.created <= ? AND mt.created >= ? AND mt.created <> ? AND mt.seq > ? AND mt.topics LIKE ? ESCAPE '[' AND mt.topics NOT LIKE ? ESCAPE '[' AND mt.topics ILIKE ? ESCAPE '[' AND mt.topics NOT ILIKE ? ESCAPE '[') ORDER BY mt.seq DESC", sqlFilter)
}

func TestSQLQueryFactoryEvenMoreOps(t *testing.T) {
Expand Down Expand Up @@ -156,7 +156,23 @@ func TestSQLQueryFactoryEvenMoreOps(t *testing.T) {

sqlFilter, _, err := sel.ToSql()
assert.NoError(t, err)
assert.Equal(t, "SELECT * FROM mytable AS mt WHERE (mt.id ILIKE ? AND mt.id NOT ILIKE ? AND mt.topics LIKE ? AND mt.topics NOT LIKE ? AND mt.topics ILIKE ? AND mt.topics NOT ILIKE ? AND mt.topics LIKE ? AND mt.topics NOT LIKE ? AND mt.topics ILIKE ? AND mt.topics NOT ILIKE ?) ORDER BY mt.seq DESC", sqlFilter)
assert.Equal(t, "SELECT * FROM mytable AS mt WHERE (mt.id ILIKE ? ESCAPE '[' AND mt.id NOT ILIKE ? ESCAPE '[' AND mt.topics LIKE ? ESCAPE '[' AND mt.topics NOT LIKE ? ESCAPE '[' AND mt.topics ILIKE ? ESCAPE '[' AND mt.topics NOT ILIKE ? ESCAPE '[' AND mt.topics LIKE ? ESCAPE '[' AND mt.topics NOT LIKE ? ESCAPE '[' AND mt.topics ILIKE ? ESCAPE '[' AND mt.topics NOT ILIKE ? ESCAPE '[') ORDER BY mt.seq DESC", sqlFilter)
}

func TestSQLQueryFactoryEscapeLike(t *testing.T) {

s, _ := NewMockProvider().UTInit()
fb := TestQueryFactory.NewFilter(context.Background())
f := fb.And(fb.Contains("topics", "[%test_topic%]"))

sel := squirrel.Select("*").From("mytable AS mt")
sel, _, _, err := s.FilterSelect(context.Background(), "mt", sel, f, nil, []interface{}{"sequence"})
assert.NoError(t, err)

sqlFilter, args, err := sel.ToSql()
assert.NoError(t, err)
assert.Equal(t, "SELECT * FROM mytable AS mt WHERE (mt.topics LIKE ? ESCAPE '[') ORDER BY mt.seq DESC", sqlFilter)
assert.Equal(t, []interface{}{"%[[[%test[_topic[%]%"}, args)
}

func TestSQLQueryFactoryFinalizeFail(t *testing.T) {
Expand Down

0 comments on commit c241a79

Please sign in to comment.