Skip to content

Commit

Permalink
Improve tests for lifted args
Browse files Browse the repository at this point in the history
  • Loading branch information
tonyhb committed Jan 5, 2024
1 parent 28d5616 commit be344fd
Show file tree
Hide file tree
Showing 3 changed files with 21 additions and 2 deletions.
16 changes: 16 additions & 0 deletions lift.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,12 @@ import (
"strings"
)

// LiftedArgs represents a set of variables that have been lifted from expressions and
// replaced with identifiers, eg `id == "foo"` becomes `id == vars.a`, with "foo" lifted
// as "vars.a".
type LiftedArgs interface {
Get(val string) (any, bool)
Map() map[string]any
}

// liftLiterals lifts quoted literals into variables, allowing us to normalize
Expand Down Expand Up @@ -128,6 +132,14 @@ type pointerArgMap struct {
vars map[string]argMapValue
}

func (p pointerArgMap) Map() map[string]any {
res := map[string]any{}
for k, v := range p.vars {
res[k] = v.get(p.expr)
}
return res
}

func (p pointerArgMap) Get(key string) (any, bool) {
val, ok := p.vars[key]
if !ok {
Expand All @@ -151,3 +163,7 @@ func (p regularArgMap) Get(key string) (any, bool) {
val, ok := p[key]
return val, ok
}

func (p regularArgMap) Map() map[string]any {
return p
}
4 changes: 2 additions & 2 deletions parser.go
Original file line number Diff line number Diff line change
Expand Up @@ -545,7 +545,7 @@ func callToPredicate(item celast.Expr, negated bool, vars LiftedArgs) *Predicate
}
}

if aIsVar {
if aIsVar && vars != nil {
if val, ok := vars.Get(strings.TrimPrefix(identA, VarPrefix)); ok {
// Normalize.
literal = val
Expand All @@ -554,7 +554,7 @@ func callToPredicate(item celast.Expr, negated bool, vars LiftedArgs) *Predicate
}
}

if bIsVar {
if bIsVar && vars != nil {
if val, ok := vars.Get(strings.TrimPrefix(identB, VarPrefix)); ok {
// Normalize.
literal = val
Expand Down
3 changes: 3 additions & 0 deletions parser_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -963,6 +963,9 @@ func TestParse_LiftedVars(t *testing.T) {
test.expected.Evaluable = eval
}

// Convert the lifted arg interfaces to the same map values
actual.Vars = regularArgMap(actual.Vars.Map())

require.NoError(t, err)
require.NotNil(t, actual)

Expand Down

0 comments on commit be344fd

Please sign in to comment.