From 7cd3fecae72d42bd77318dd9498f77e0d71cc18a Mon Sep 17 00:00:00 2001 From: Johan Fylling Date: Fri, 6 Sep 2024 15:09:02 +0200 Subject: [PATCH] ast: Make `Module.String()` include `if`/`contains` for v1 modules (#7000) Fixes: #6973 Signed-off-by: Johan Fylling --- ast/policy.go | 47 ++++++++++++++---- ast/policy_test.go | 106 ++++++++++++++++++++++++++++++++++++++++ compile/compile_test.go | 6 ++- rego/rego_test.go | 6 +-- 4 files changed, 151 insertions(+), 14 deletions(-) diff --git a/ast/policy.go b/ast/policy.go index f07cf7b376..29963a09a4 100644 --- a/ast/policy.go +++ b/ast/policy.go @@ -405,7 +405,7 @@ func (mod *Module) String() string { buf = append(buf, "") for _, rule := range mod.Rules { buf = appendAnnotationStrings(buf, rule) - buf = append(buf, rule.String()) + buf = append(buf, rule.stringWithOpts(toStringOpts{regoVersion: mod.regoVersion})) } } return strings.Join(buf, "\n") @@ -770,18 +770,30 @@ func (rule *Rule) Ref() Ref { } func (rule *Rule) String() string { + return rule.stringWithOpts(toStringOpts{}) +} + +type toStringOpts struct { + regoVersion RegoVersion +} + +func (rule *Rule) stringWithOpts(opts toStringOpts) string { buf := []string{} if rule.Default { buf = append(buf, "default") } - buf = append(buf, rule.Head.String()) + buf = append(buf, rule.Head.stringWithOpts(opts)) if !rule.Default { + switch opts.regoVersion { + case RegoV1, RegoV0CompatV1: + buf = append(buf, "if") + } buf = append(buf, "{") buf = append(buf, rule.Body.String()) buf = append(buf, "}") } if rule.Else != nil { - buf = append(buf, rule.Else.elseString()) + buf = append(buf, rule.Else.elseString(opts)) } return strings.Join(buf, " ") } @@ -824,7 +836,7 @@ func (rule *Rule) MarshalJSON() ([]byte, error) { return json.Marshal(data) } -func (rule *Rule) elseString() string { +func (rule *Rule) elseString(opts toStringOpts) string { var buf []string buf = append(buf, "else") @@ -835,12 +847,17 @@ func (rule *Rule) elseString() string { buf = append(buf, value.String()) } + switch opts.regoVersion { + case RegoV1, RegoV0CompatV1: + buf = append(buf, "if") + } + buf = append(buf, "{") buf = append(buf, rule.Body.String()) buf = append(buf, "}") if rule.Else != nil { - buf = append(buf, rule.Else.elseString()) + buf = append(buf, rule.Else.elseString(opts)) } return strings.Join(buf, " ") @@ -1000,16 +1017,28 @@ func (head *Head) Equal(other *Head) bool { } func (head *Head) String() string { + return head.stringWithOpts(toStringOpts{}) +} + +func (head *Head) stringWithOpts(opts toStringOpts) string { buf := strings.Builder{} buf.WriteString(head.Ref().String()) + containsAdded := false switch { case len(head.Args) != 0: buf.WriteString(head.Args.String()) case len(head.Reference) == 1 && head.Key != nil: - buf.WriteRune('[') - buf.WriteString(head.Key.String()) - buf.WriteRune(']') + switch opts.regoVersion { + case RegoV0: + buf.WriteRune('[') + buf.WriteString(head.Key.String()) + buf.WriteRune(']') + default: + containsAdded = true + buf.WriteString(" contains ") + buf.WriteString(head.Key.String()) + } } if head.Value != nil { if head.Assign { @@ -1018,7 +1047,7 @@ func (head *Head) String() string { buf.WriteString(" = ") } buf.WriteString(head.Value.String()) - } else if head.Name == "" && head.Key != nil { + } else if !containsAdded && head.Name == "" && head.Key != nil { buf.WriteString(" contains ") buf.WriteString(head.Key.String()) } diff --git a/ast/policy_test.go b/ast/policy_test.go index 9752dd2520..0248780c18 100644 --- a/ast/policy_test.go +++ b/ast/policy_test.go @@ -842,6 +842,112 @@ p := 7 { true }` } } +func TestModuleStringWithRegoVersion(t *testing.T) { + tests := []struct { + note string + regoVersion RegoVersion + module string + exp string + }{ + { + note: "v0, basic", + regoVersion: RegoV0, + module: `package test +a := 1 +b[1] +c[1] := 2 +d.e.f := 3 +e.f.g[1] +f.g.h[1] := 4 + +g := 5 { + false +} else := 6 { + false +} else := 7`, + exp: `package test + +a := 1 { true } +b[1] { true } +c[1] := 2 { true } +d.e.f := 3 { true } +e.f.g[1] = true { true } +f.g.h[1] := 4 { true } +g := 5 { false } else = 6 { false } else = 7 { true }`, + }, + { + note: "v0, rego.v1 import", + regoVersion: RegoV0, + module: `package test + +import rego.v1 + +a := 1 +b contains 1 +c[1] := 2 +d.e.f := 3 +e.f.g contains 1 +f.g.h[1] := 4 + +g := 5 if { + false +} else := 6 if { + false +} else := 7`, + exp: `package test + +import rego.v1 + +a := 1 if { true } +b contains 1 if { true } +c[1] := 2 if { true } +d.e.f := 3 if { true } +e.f.g contains 1 if { true } +f.g.h[1] := 4 if { true } +g := 5 if { false } else = 6 if { false } else = 7 if { true }`, + }, + { + note: "v1, basic", + regoVersion: RegoV1, + module: `package test + +a := 1 +b contains 1 +c[1] := 2 +d.e.f := 3 +e.f.g contains 1 +f.g.h[1] := 4 + +g := 5 if { + false +} else := 6 if { + false +} else := 7`, + exp: `package test + +a := 1 if { true } +b contains 1 if { true } +c[1] := 2 if { true } +d.e.f := 3 if { true } +e.f.g contains 1 if { true } +f.g.h[1] := 4 if { true } +g := 5 if { false } else = 6 if { false } else = 7 if { true }`, + }, + } + + for _, tc := range tests { + t.Run(tc.note, func(t *testing.T) { + module, err := ParseModuleWithOpts("test.rego", tc.module, ParserOptions{RegoVersion: tc.regoVersion}) + if err != nil { + t.Fatal(err) + } + if act := module.String(); act != tc.exp { + t.Errorf("expected:\n\n%s\n\ngot:\n\n%s", tc.exp, act) + } + }) + } +} + func TestCommentCopy(t *testing.T) { comment := &Comment{ Text: []byte("foo bar baz"), diff --git a/compile/compile_test.go b/compile/compile_test.go index c0093412d4..572e0063e1 100644 --- a/compile/compile_test.go +++ b/compile/compile_test.go @@ -1246,9 +1246,11 @@ func TestCompilerOptimizationWithConfiguredNamespace(t *testing.T) { t.Fatalf("expected two modules but got: %v", len(compiler.bundle.Modules)) } + // The compiler will strip the rego.v1 import from the module, so we need to compare it + // to a pure v1 module that doesn't require the import. optimizedExp := ast.MustParseModuleWithOpts(`package custom __not1_0_2__ = true if { data.test.q = _; _ }`, - ast.ParserOptions{AllFutureKeywords: true}) + ast.ParserOptions{RegoVersion: ast.RegoV1}) if optimizedExp.String() != compiler.bundle.Modules[0].Parsed.String() { t.Fatalf("expected optimized module to be:\n\n%v\n\ngot:\n\n%v", optimizedExp, compiler.bundle.Modules[0]) @@ -1258,7 +1260,7 @@ func TestCompilerOptimizationWithConfiguredNamespace(t *testing.T) { k = {1, 2, 3} if { true } p = true if { not data.custom.__not1_0_2__ } q = true if { __local0__3 = input.a; data.test.k[__local0__3] = _; _; __local1__3 = input.b; data.test.k[__local1__3] = _; _ }`, - ast.ParserOptions{AllFutureKeywords: true}) + ast.ParserOptions{RegoVersion: ast.RegoV1}) if expected.String() != compiler.bundle.Modules[1].Parsed.String() { t.Fatalf("expected module to be:\n\n%v\n\ngot:\n\n%v", expected, compiler.bundle.Modules[1]) diff --git a/rego/rego_test.go b/rego/rego_test.go index c69e978b97..ce0b1c79f2 100644 --- a/rego/rego_test.go +++ b/rego/rego_test.go @@ -1184,7 +1184,7 @@ func TestPartialWithRegoV1(t *testing.T) { expQuery: `data.partial.test.p = x`, expSupport: `package partial.test.p -foo[__local1__1] { __local1__1 = input.v }`, +foo contains __local1__1 if { __local1__1 = input.v }`, }, { note: "rego.v1 imported", @@ -1197,7 +1197,7 @@ foo[__local1__1] { __local1__1 = input.v }`, expQuery: `data.partial.test.p = x`, expSupport: `package partial.test.p -foo[__local1__1] { __local1__1 = input.v }`, +foo contains __local1__1 if { __local1__1 = input.v }`, }, { note: "future.keywords imported", @@ -1210,7 +1210,7 @@ foo[__local1__1] { __local1__1 = input.v }`, expQuery: `data.partial.test.p = x`, expSupport: `package partial.test.p -foo[__local1__1] { __local1__1 = input.v }`, +foo contains __local1__1 if { __local1__1 = input.v }`, }, }