Skip to content

Commit

Permalink
ast: Make Module.String() include if/contains for v1 modules (#…
Browse files Browse the repository at this point in the history
…7000)

Fixes: #6973
Signed-off-by: Johan Fylling <johan.dev@fylling.se>
  • Loading branch information
johanfylling authored Sep 6, 2024
1 parent 3587ccf commit 7cd3fec
Show file tree
Hide file tree
Showing 4 changed files with 151 additions and 14 deletions.
47 changes: 38 additions & 9 deletions ast/policy.go
Original file line number Diff line number Diff line change
Expand Up @@ -405,7 +405,7 @@ func (mod *Module) String() string {
buf = append(buf, "")
for _, rule := range mod.Rules {
buf = appendAnnotationStrings(buf, rule)
buf = append(buf, rule.String())
buf = append(buf, rule.stringWithOpts(toStringOpts{regoVersion: mod.regoVersion}))
}
}
return strings.Join(buf, "\n")
Expand Down Expand Up @@ -770,18 +770,30 @@ func (rule *Rule) Ref() Ref {
}

func (rule *Rule) String() string {
return rule.stringWithOpts(toStringOpts{})
}

type toStringOpts struct {
regoVersion RegoVersion
}

func (rule *Rule) stringWithOpts(opts toStringOpts) string {
buf := []string{}
if rule.Default {
buf = append(buf, "default")
}
buf = append(buf, rule.Head.String())
buf = append(buf, rule.Head.stringWithOpts(opts))
if !rule.Default {
switch opts.regoVersion {
case RegoV1, RegoV0CompatV1:
buf = append(buf, "if")
}
buf = append(buf, "{")
buf = append(buf, rule.Body.String())
buf = append(buf, "}")
}
if rule.Else != nil {
buf = append(buf, rule.Else.elseString())
buf = append(buf, rule.Else.elseString(opts))
}
return strings.Join(buf, " ")
}
Expand Down Expand Up @@ -824,7 +836,7 @@ func (rule *Rule) MarshalJSON() ([]byte, error) {
return json.Marshal(data)
}

func (rule *Rule) elseString() string {
func (rule *Rule) elseString(opts toStringOpts) string {
var buf []string

buf = append(buf, "else")
Expand All @@ -835,12 +847,17 @@ func (rule *Rule) elseString() string {
buf = append(buf, value.String())
}

switch opts.regoVersion {
case RegoV1, RegoV0CompatV1:
buf = append(buf, "if")
}

buf = append(buf, "{")
buf = append(buf, rule.Body.String())
buf = append(buf, "}")

if rule.Else != nil {
buf = append(buf, rule.Else.elseString())
buf = append(buf, rule.Else.elseString(opts))
}

return strings.Join(buf, " ")
Expand Down Expand Up @@ -1000,16 +1017,28 @@ func (head *Head) Equal(other *Head) bool {
}

func (head *Head) String() string {
return head.stringWithOpts(toStringOpts{})
}

func (head *Head) stringWithOpts(opts toStringOpts) string {
buf := strings.Builder{}
buf.WriteString(head.Ref().String())
containsAdded := false

switch {
case len(head.Args) != 0:
buf.WriteString(head.Args.String())
case len(head.Reference) == 1 && head.Key != nil:
buf.WriteRune('[')
buf.WriteString(head.Key.String())
buf.WriteRune(']')
switch opts.regoVersion {
case RegoV0:
buf.WriteRune('[')
buf.WriteString(head.Key.String())
buf.WriteRune(']')
default:
containsAdded = true
buf.WriteString(" contains ")
buf.WriteString(head.Key.String())
}
}
if head.Value != nil {
if head.Assign {
Expand All @@ -1018,7 +1047,7 @@ func (head *Head) String() string {
buf.WriteString(" = ")
}
buf.WriteString(head.Value.String())
} else if head.Name == "" && head.Key != nil {
} else if !containsAdded && head.Name == "" && head.Key != nil {
buf.WriteString(" contains ")
buf.WriteString(head.Key.String())
}
Expand Down
106 changes: 106 additions & 0 deletions ast/policy_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -842,6 +842,112 @@ p := 7 { true }`
}
}

func TestModuleStringWithRegoVersion(t *testing.T) {
tests := []struct {
note string
regoVersion RegoVersion
module string
exp string
}{
{
note: "v0, basic",
regoVersion: RegoV0,
module: `package test
a := 1
b[1]
c[1] := 2
d.e.f := 3
e.f.g[1]
f.g.h[1] := 4
g := 5 {
false
} else := 6 {
false
} else := 7`,
exp: `package test
a := 1 { true }
b[1] { true }
c[1] := 2 { true }
d.e.f := 3 { true }
e.f.g[1] = true { true }
f.g.h[1] := 4 { true }
g := 5 { false } else = 6 { false } else = 7 { true }`,
},
{
note: "v0, rego.v1 import",
regoVersion: RegoV0,
module: `package test
import rego.v1
a := 1
b contains 1
c[1] := 2
d.e.f := 3
e.f.g contains 1
f.g.h[1] := 4
g := 5 if {
false
} else := 6 if {
false
} else := 7`,
exp: `package test
import rego.v1
a := 1 if { true }
b contains 1 if { true }
c[1] := 2 if { true }
d.e.f := 3 if { true }
e.f.g contains 1 if { true }
f.g.h[1] := 4 if { true }
g := 5 if { false } else = 6 if { false } else = 7 if { true }`,
},
{
note: "v1, basic",
regoVersion: RegoV1,
module: `package test
a := 1
b contains 1
c[1] := 2
d.e.f := 3
e.f.g contains 1
f.g.h[1] := 4
g := 5 if {
false
} else := 6 if {
false
} else := 7`,
exp: `package test
a := 1 if { true }
b contains 1 if { true }
c[1] := 2 if { true }
d.e.f := 3 if { true }
e.f.g contains 1 if { true }
f.g.h[1] := 4 if { true }
g := 5 if { false } else = 6 if { false } else = 7 if { true }`,
},
}

for _, tc := range tests {
t.Run(tc.note, func(t *testing.T) {
module, err := ParseModuleWithOpts("test.rego", tc.module, ParserOptions{RegoVersion: tc.regoVersion})
if err != nil {
t.Fatal(err)
}
if act := module.String(); act != tc.exp {
t.Errorf("expected:\n\n%s\n\ngot:\n\n%s", tc.exp, act)
}
})
}
}

func TestCommentCopy(t *testing.T) {
comment := &Comment{
Text: []byte("foo bar baz"),
Expand Down
6 changes: 4 additions & 2 deletions compile/compile_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1246,9 +1246,11 @@ func TestCompilerOptimizationWithConfiguredNamespace(t *testing.T) {
t.Fatalf("expected two modules but got: %v", len(compiler.bundle.Modules))
}

// The compiler will strip the rego.v1 import from the module, so we need to compare it
// to a pure v1 module that doesn't require the import.
optimizedExp := ast.MustParseModuleWithOpts(`package custom
__not1_0_2__ = true if { data.test.q = _; _ }`,
ast.ParserOptions{AllFutureKeywords: true})
ast.ParserOptions{RegoVersion: ast.RegoV1})

if optimizedExp.String() != compiler.bundle.Modules[0].Parsed.String() {
t.Fatalf("expected optimized module to be:\n\n%v\n\ngot:\n\n%v", optimizedExp, compiler.bundle.Modules[0])
Expand All @@ -1258,7 +1260,7 @@ func TestCompilerOptimizationWithConfiguredNamespace(t *testing.T) {
k = {1, 2, 3} if { true }
p = true if { not data.custom.__not1_0_2__ }
q = true if { __local0__3 = input.a; data.test.k[__local0__3] = _; _; __local1__3 = input.b; data.test.k[__local1__3] = _; _ }`,
ast.ParserOptions{AllFutureKeywords: true})
ast.ParserOptions{RegoVersion: ast.RegoV1})

if expected.String() != compiler.bundle.Modules[1].Parsed.String() {
t.Fatalf("expected module to be:\n\n%v\n\ngot:\n\n%v", expected, compiler.bundle.Modules[1])
Expand Down
6 changes: 3 additions & 3 deletions rego/rego_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1184,7 +1184,7 @@ func TestPartialWithRegoV1(t *testing.T) {
expQuery: `data.partial.test.p = x`,
expSupport: `package partial.test.p
foo[__local1__1] { __local1__1 = input.v }`,
foo contains __local1__1 if { __local1__1 = input.v }`,
},
{
note: "rego.v1 imported",
Expand All @@ -1197,7 +1197,7 @@ foo[__local1__1] { __local1__1 = input.v }`,
expQuery: `data.partial.test.p = x`,
expSupport: `package partial.test.p
foo[__local1__1] { __local1__1 = input.v }`,
foo contains __local1__1 if { __local1__1 = input.v }`,
},
{
note: "future.keywords imported",
Expand All @@ -1210,7 +1210,7 @@ foo[__local1__1] { __local1__1 = input.v }`,
expQuery: `data.partial.test.p = x`,
expSupport: `package partial.test.p
foo[__local1__1] { __local1__1 = input.v }`,
foo contains __local1__1 if { __local1__1 = input.v }`,
},
}

Expand Down

0 comments on commit 7cd3fec

Please sign in to comment.