diff --git a/pkg/dbsql/filter_sql.go b/pkg/dbsql/filter_sql.go index 8e65d25..e8fc10d 100644 --- a/pkg/dbsql/filter_sql.go +++ b/pkg/dbsql/filter_sql.go @@ -33,24 +33,54 @@ type NotLikeEscape sq.NotLike type ILikeEscape sq.ILike type NotILikeEscape sq.NotILike -func (lke LikeEscape) ToSql() (sql string, args []interface{}, err error) { - sql, args, err = sq.Like(lke).ToSql() - return fmt.Sprintf("%s ESCAPE '%s'", sql, escapeChar), args, err +// Split a map into a list of maps with a single entry each +func splitMap[T ~map[string]interface{}](m T) []T { + var exprs []T + for key, val := range m { + exprs = append(exprs, T{key: val}) + } + return exprs +} + +// Convert a list of Sqlizer operations to sq.And +func toAnd[T sq.Sqlizer](ops []T) sq.And { + var and sq.And + for _, op := range ops { + and = append(and, op) + } + return and +} + +func (lk LikeEscape) ToSql() (sql string, args []interface{}, err error) { + if len(lk) == 1 { + sql, args, err = sq.Like(lk).ToSql() + return fmt.Sprintf("%s ESCAPE '%s'", sql, escapeChar), args, err + } + return toAnd(splitMap(lk)).ToSql() } -func (lke NotLikeEscape) ToSql() (sql string, args []interface{}, err error) { - sql, args, err = sq.NotLike(lke).ToSql() - return fmt.Sprintf("%s ESCAPE '%s'", sql, escapeChar), args, err +func (lk NotLikeEscape) ToSql() (sql string, args []interface{}, err error) { + if len(lk) == 1 { + sql, args, err = sq.NotLike(lk).ToSql() + return fmt.Sprintf("%s ESCAPE '%s'", sql, escapeChar), args, err + } + return toAnd(splitMap(lk)).ToSql() } -func (lke ILikeEscape) ToSql() (sql string, args []interface{}, err error) { - sql, args, err = sq.ILike(lke).ToSql() - return fmt.Sprintf("%s ESCAPE '%s'", sql, escapeChar), args, err +func (lk ILikeEscape) ToSql() (sql string, args []interface{}, err error) { + if len(lk) == 1 { + sql, args, err = sq.ILike(lk).ToSql() + return fmt.Sprintf("%s ESCAPE '%s'", sql, escapeChar), args, err + } + return toAnd(splitMap(lk)).ToSql() } -func (lke NotILikeEscape) ToSql() (sql string, args []interface{}, err error) { - sql, args, err = sq.NotILike(lke).ToSql() - return fmt.Sprintf("%s ESCAPE '%s'", sql, escapeChar), args, err +func (lk NotILikeEscape) ToSql() (sql string, args []interface{}, err error) { + if len(lk) == 1 { + sql, args, err = sq.NotILike(lk).ToSql() + return fmt.Sprintf("%s ESCAPE '%s'", sql, escapeChar), args, err + } + return toAnd(splitMap(lk)).ToSql() } func (s *Database) escapeLike(value ffapi.FieldSerialization) string { diff --git a/pkg/dbsql/filter_sql_test.go b/pkg/dbsql/filter_sql_test.go index 43e2c07..e20d5c5 100644 --- a/pkg/dbsql/filter_sql_test.go +++ b/pkg/dbsql/filter_sql_test.go @@ -161,18 +161,14 @@ func TestSQLQueryFactoryEvenMoreOps(t *testing.T) { func TestSQLQueryFactoryEscapeLike(t *testing.T) { - s, _ := NewMockProvider().UTInit() - fb := TestQueryFactory.NewFilter(context.Background()) - f := fb.And(fb.Contains("topics", "[%test_topic%]")) + sel := squirrel.Select("*").From("mytable AS mt"). + Where(LikeEscape{"a": 1, "b": 2}). + Where(NotLikeEscape{"a": 1, "b": 2}). + Where(ILikeEscape{"a": 1, "b": 2}). + Where(NotILikeEscape{"a": 1, "b": 2}) - sel := squirrel.Select("*").From("mytable AS mt") - sel, _, _, err := s.FilterSelect(context.Background(), "mt", sel, f, nil, []interface{}{"sequence"}) - assert.NoError(t, err) - - sqlFilter, args, err := sel.ToSql() + _, _, err := sel.ToSql() assert.NoError(t, err) - assert.Equal(t, "SELECT * FROM mytable AS mt WHERE (mt.topics LIKE ? ESCAPE '[') ORDER BY mt.seq DESC", sqlFilter) - assert.Equal(t, []interface{}{"%[[[%test[_topic[%]%"}, args) } func TestSQLQueryFactoryFinalizeFail(t *testing.T) {