Skip to content

Commit

Permalink
Update PruneAst to support constants of optional type (#1109)
Browse files Browse the repository at this point in the history
* Update PruneAst and unparser to support optional types constants.

Signed-off-by: Dennis Buduev <dbuduev@users.noreply.github.com>

* recursive definitions

Signed-off-by: Dennis Buduev <dbuduev@users.noreply.github.com>

* add prune tests for <list>.last()

Signed-off-by: Dennis Buduev <dbuduev@users.noreply.github.com>

* clean-up

Signed-off-by: Dennis Buduev <dbuduev@users.noreply.github.com>

---------

Signed-off-by: Dennis Buduev <dbuduev@users.noreply.github.com>
Co-authored-by: Dennis Buduev <dbuduev@users.noreply.github.com>
  • Loading branch information
dbuduev and dbuduev authored Jan 22, 2025
1 parent 33a7f97 commit 91fb306
Show file tree
Hide file tree
Showing 3 changed files with 71 additions and 3 deletions.
2 changes: 1 addition & 1 deletion interpreter/prune.go
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ func PruneAst(expr ast.Expr, macroCalls map[int64]ast.Expr, state EvalState) *as

func (p *astPruner) maybeCreateLiteral(id int64, val ref.Val) (ast.Expr, bool) {
switch v := val.(type) {
case types.Bool, types.Bytes, types.Double, types.Int, types.Null, types.String, types.Uint:
case types.Bool, types.Bytes, types.Double, types.Int, types.Null, types.String, types.Uint, *types.Optional:
p.state.SetValue(id, val)
return p.NewLiteral(id, val), true
case types.Duration:
Expand Down
44 changes: 44 additions & 0 deletions interpreter/prune_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ import (
"github.com/google/cel-go/common/decls"
"github.com/google/cel-go/common/types"
"github.com/google/cel-go/common/types/ref"
"github.com/google/cel-go/common/types/traits"
"github.com/google/cel-go/parser"
"github.com/google/cel-go/test"

Expand Down Expand Up @@ -216,6 +217,36 @@ var testCases = []testInfo{
expr: `a.?b`,
out: `a.?b`,
},
{
in: partialActivation(map[string]any{"a": map[string]any{"b": 10}}),
expr: `a.?b`,
out: `optional.of(10)`,
},
{
in: partialActivation(map[string]any{"a": map[string]any{"b": 10}}),
expr: `a[?"b"]`,
out: `optional.of(10)`,
},
{
in: unknownActivation(),
expr: `{'b': optional.of(10)}.?b`,
out: `optional.of(optional.of(10))`,
},
{
in: partialActivation(map[string]any{"a": map[string]any{}}),
expr: `a.?b`,
out: `optional.none()`,
},
{
in: unknownActivation(),
expr: `[10].last()`,
out: "optional.of(10)",
},
{
in: unknownActivation(),
expr: `[].last()`,
out: "optional.none()",
},
{
in: unknownActivation("a"),
expr: `a[?"b"]`,
Expand Down Expand Up @@ -561,5 +592,18 @@ func optionalDecls(t *testing.T) []*decls.FunctionDecl {
types.NewTypeParamType("K"),
}, optionalType),
),
funcDecl(t, "last", decls.Overload("list_last", []*types.Type{paramType}, optionalType,
decls.UnaryBinding(func(v ref.Val) ref.Val {
list := v.(traits.Lister)
sz := list.Size().Value().(int64)

if sz == 0 {
return types.OptionalNone
}

return types.OptionalOf(list.Get(types.Int(sz - 1)))
}),
),
),
}
}
28 changes: 26 additions & 2 deletions parser/unparser.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ import (
"github.com/google/cel-go/common/ast"
"github.com/google/cel-go/common/operators"
"github.com/google/cel-go/common/types"
"github.com/google/cel-go/common/types/ref"
)

// Unparse takes an input expression and source position information and generates a human-readable
Expand Down Expand Up @@ -273,8 +274,17 @@ func (un *unparser) visitCallUnary(expr ast.Expr) error {
return un.visitMaybeNested(args[0], nested)
}

func (un *unparser) visitConst(expr ast.Expr) error {
val := expr.AsLiteral()
func (un *unparser) visitConstVal(val ref.Val) error {
optional := false
if optVal, ok := val.(*types.Optional); ok {
if !optVal.HasValue() {
un.str.WriteString("optional.none()")
return nil
}
optional = true
un.str.WriteString("optional.of(")
val = optVal.GetValue()
}
switch val := val.(type) {
case types.Bool:
un.str.WriteString(strconv.FormatBool(bool(val)))
Expand Down Expand Up @@ -303,7 +313,21 @@ func (un *unparser) visitConst(expr ast.Expr) error {
ui := strconv.FormatUint(uint64(val), 10)
un.str.WriteString(ui)
un.str.WriteString("u")
case *types.Optional:
if err := un.visitConstVal(val); err != nil {
return err
}
default:
return errors.New("unsupported constant")
}
if optional {
un.str.WriteString(")")
}
return nil
}
func (un *unparser) visitConst(expr ast.Expr) error {
val := expr.AsLiteral()
if err := un.visitConstVal(val); err != nil {
return fmt.Errorf("unsupported constant: %v", expr)
}
return nil
Expand Down

0 comments on commit 91fb306

Please sign in to comment.