Skip to content

Commit

Permalink
Adding --v1-compatible flag to all previously unsupported command l…
Browse files Browse the repository at this point in the history
…ine commands (#6521)

In addition to those commands already supported:

* build
* check
* eval
* fmt
* test

support has been added to the following commands:

* `bench`
* `deps`
* `exec`
* `inspect`
* `parse`
* `run` (command `server` and `REPL`)

Fixes: #6520

Signed-off-by: Johan Fylling <johan.dev@fylling.se>
  • Loading branch information
johanfylling authored Jan 24, 2024
1 parent b1261ba commit b36151d
Show file tree
Hide file tree
Showing 45 changed files with 3,724 additions and 164 deletions.
17 changes: 1 addition & 16 deletions ast/parser_ext.go
Original file line number Diff line number Diff line change
Expand Up @@ -698,22 +698,7 @@ func parseModule(filename string, stmts []Statement, comments []*Comment, regoCo
if mod.regoVersion == RegoV0CompatV1 || mod.regoVersion == RegoV1 {
for _, rule := range mod.Rules {
for r := rule; r != nil; r = r.Else {
var t string
if r.isFunction() {
t = "function"
} else {
t = "rule"
}

if r.generatedBody && r.Head.generatedValue {
errs = append(errs, NewError(ParseErr, r.Location, "%s must have value assignment and/or body declaration", t))
}
if r.Body != nil && !r.generatedBody && !ruleDeclarationHasKeyword(r, tokens.If) && !r.Default {
errs = append(errs, NewError(ParseErr, r.Location, "`if` keyword is required before %s body", t))
}
if r.Head.RuleKind() == MultiValue && !ruleDeclarationHasKeyword(r, tokens.Contains) {
errs = append(errs, NewError(ParseErr, r.Location, "`contains` keyword is required for partial set rules"))
}
errs = append(errs, CheckRegoV1(r)...)
}
}
}
Expand Down
42 changes: 40 additions & 2 deletions ast/rego_v1.go
Original file line number Diff line number Diff line change
@@ -1,5 +1,11 @@
package ast

import (
"fmt"

"github.com/open-policy-agent/opa/ast/internal/tokens"
)

func checkDuplicateImports(modules []*Module) (errors Errors) {
for _, module := range modules {
processedImports := map[Var]*Import{}
Expand Down Expand Up @@ -116,11 +122,43 @@ func checkDeprecatedBuiltinsForCurrentVersion(node interface{}) Errors {
return checkDeprecatedBuiltins(deprecatedBuiltins, node)
}

// CheckRegoV1 checks the given module for errors that are specific to Rego v1
func CheckRegoV1(module *Module) Errors {
// CheckRegoV1 checks the given module or rule for errors that are specific to Rego v1.
// Passing something other than an *ast.Rule or *ast.Module is considered a programming error, and will cause a panic.
func CheckRegoV1(x interface{}) Errors {
switch x := x.(type) {
case *Module:
return checkRegoV1Module(x)
case *Rule:
return checkRegoV1Rule(x)
}
panic(fmt.Sprintf("cannot check rego-v1 compatibility on type %T", x))
}

func checkRegoV1Module(module *Module) Errors {
var errors Errors
errors = append(errors, checkDuplicateImports([]*Module{module})...)
errors = append(errors, checkRootDocumentOverrides(module)...)
errors = append(errors, checkDeprecatedBuiltinsForCurrentVersion(module)...)
return errors
}

func checkRegoV1Rule(rule *Rule) Errors {
t := "rule"
if rule.isFunction() {
t = "function"
}

var errs Errors

if rule.generatedBody && rule.Head.generatedValue {
errs = append(errs, NewError(ParseErr, rule.Location, "%s must have value assignment and/or body declaration", t))
}
if rule.Body != nil && !rule.generatedBody && !ruleDeclarationHasKeyword(rule, tokens.If) && !rule.Default {
errs = append(errs, NewError(ParseErr, rule.Location, "`if` keyword is required before %s body", t))
}
if rule.Head.RuleKind() == MultiValue && !ruleDeclarationHasKeyword(rule, tokens.Contains) {
errs = append(errs, NewError(ParseErr, rule.Location, "`contains` keyword is required for partial set rules"))
}

return errs
}
2 changes: 2 additions & 0 deletions cmd/bench.go
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,7 @@ The optional "gobench" output format conforms to the Go Benchmark Data Format.
addIgnoreFlag(benchCommand.Flags(), &params.ignore)
addSchemaFlags(benchCommand.Flags(), params.schema)
addTargetFlag(benchCommand.Flags(), params.target)
addV1CompatibleFlag(benchCommand.Flags(), &params.v1Compatible, false)

// Shared benchmark flags
addCountFlag(benchCommand.Flags(), &params.count, "benchmark")
Expand Down Expand Up @@ -299,6 +300,7 @@ func benchE2E(ctx context.Context, args []string, params benchmarkCommandParams,
GracefulShutdownPeriod: params.gracefulShutdownPeriod,
ShutdownWaitPeriod: params.shutdownWaitPeriod,
ConfigFile: params.configFile,
V1Compatible: params.v1Compatible,
}

rt, err := runtime.NewRuntime(ctx, rtParams)
Expand Down
117 changes: 117 additions & 0 deletions cmd/bench_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (
"bytes"
"context"
"errors"
"fmt"
"os"
"path/filepath"
"strings"
Expand Down Expand Up @@ -691,6 +692,122 @@ func TestBenchMainBadQueryE2E(t *testing.T) {
}
}

func TestBenchMainV1Compatible(t *testing.T) {
tests := []struct {
note string
v1Compatible bool
module string
query string
expErrs []string
}{
// These tests are slow, so we're not being completely exhaustive here.
{
note: "v0.x, keywords not used",
module: `package test
a[4] {
1 == 1
}`,
query: `data.test.a`,
},
{
note: "v0.x, no keywords imported",
module: `package test
a contains 4 if {
1 == 1
}`,
query: `data.test.a`,
expErrs: []string{
"rego_parse_error: var cannot be used for rule name",
"rego_parse_error: number cannot be used for rule name",
},
},
{
note: "v1.0, keywords not used",
v1Compatible: true,
module: `package test
a[4] {
1 == 1
}`,
query: `data.test.a`,
expErrs: []string{
"rego_parse_error: `if` keyword is required before rule body",
"rego_parse_error: `contains` keyword is required for partial set rules",
},
},
{
note: "v1.0, no keywords imported",
v1Compatible: true,
module: `package test
a contains 4 if {
1 == 1
}`,
query: `data.test.a`,
},
}

modes := []struct {
name string
e2e bool
}{
{
name: "run",
},
{
name: "e2e",
e2e: true,
},
}

for _, mode := range modes {
for _, tc := range tests {
t.Run(fmt.Sprintf("%s, %s", tc.note, mode.name), func(t *testing.T) {
files := map[string]string{
"mod.rego": tc.module,
}

test.WithTempFS(files, func(path string) {
params := testBenchParams()
_ = params.outputFormat.Set(evalPrettyOutput)
params.v1Compatible = tc.v1Compatible
params.e2e = mode.e2e

for n := range files {
err := params.dataPaths.Set(filepath.Join(path, n))
if err != nil {
t.Fatalf("Unexpected error: %s", err)
}
}

args := []string{tc.query}

var buf bytes.Buffer
rc, err := benchMain(args, params, &buf, &goBenchRunner{})

if len(tc.expErrs) > 0 {
if rc == 0 {
t.Fatalf("Expected non-zero return code")
}

output := buf.String()
for _, expErr := range tc.expErrs {
if !strings.Contains(output, expErr) {
t.Fatalf("Expected error:\n\n%s\n\ngot:\n\n%s", expErr, output)
}
}
} else {
if err != nil {
t.Fatalf("Unexpected error: %s", err)
}
if rc != 0 {
t.Fatalf("Unexpected return code %d, expected 0", rc)
}
}
})
})
}
}
}

func TestRenderBenchmarkResultJSONOutput(t *testing.T) {
params := testBenchParams()
err := params.outputFormat.Set(evalJSONOutput)
Expand Down
32 changes: 25 additions & 7 deletions cmd/deps.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ package cmd
import (
"errors"
"fmt"
"io"
"os"

"github.com/open-policy-agent/opa/dependencies"
Expand All @@ -24,21 +25,35 @@ type depsCommandParams struct {
outputFormat *util.EnumFlag
ignore []string
bundlePaths repeatedStringFlag
v1Compatible bool
}

func (p *depsCommandParams) regoVersion() ast.RegoVersion {
if p.v1Compatible {
return ast.RegoV1
}
return ast.RegoV0
}

const (
depsFormatPretty = "pretty"
depsFormatJSON = "json"
)

func init() {

func newDepsCommandParams() depsCommandParams {
var params depsCommandParams

params.outputFormat = util.NewEnumFlag(depsFormatPretty, []string{
depsFormatPretty, depsFormatJSON,
})

return params
}

func init() {

params := newDepsCommandParams()

depsCommand := &cobra.Command{
Use: "deps <query>",
Short: "Analyze Rego query dependencies",
Expand Down Expand Up @@ -81,7 +96,7 @@ data.policy.is_admin.
return nil
},
Run: func(cmd *cobra.Command, args []string) {
if err := deps(args, params); err != nil {
if err := deps(args, params, os.Stdout); err != nil {
fmt.Fprintln(os.Stderr, err)
os.Exit(1)
}
Expand All @@ -92,11 +107,12 @@ data.policy.is_admin.
addDataFlag(depsCommand.Flags(), &params.dataPaths)
addBundleFlag(depsCommand.Flags(), &params.bundlePaths)
addOutputFormat(depsCommand.Flags(), params.outputFormat)
addV1CompatibleFlag(depsCommand.Flags(), &params.v1Compatible, false)

RootCommand.AddCommand(depsCommand)
}

func deps(args []string, params depsCommandParams) error {
func deps(args []string, params depsCommandParams, w io.Writer) error {

query, err := ast.ParseBody(args[0])
if err != nil {
Expand All @@ -110,7 +126,9 @@ func deps(args []string, params depsCommandParams) error {
Ignore: params.ignore,
}

result, err := loader.NewFileLoader().Filtered(params.dataPaths.v, f.Apply)
result, err := loader.NewFileLoader().
WithRegoVersion(params.regoVersion()).
Filtered(params.dataPaths.v, f.Apply)
if err != nil {
return err
}
Expand Down Expand Up @@ -157,8 +175,8 @@ func deps(args []string, params depsCommandParams) error {

switch params.outputFormat.String() {
case depsFormatJSON:
return presentation.JSON(os.Stdout, output)
return presentation.JSON(w, output)
default:
return output.Pretty(os.Stdout)
return output.Pretty(w)
}
}
Loading

0 comments on commit b36151d

Please sign in to comment.