diff --git a/internal/edittree/edittree_test.go b/internal/edittree/edittree_test.go index 9956930377..ec354eac0a 100644 --- a/internal/edittree/edittree_test.go +++ b/internal/edittree/edittree_test.go @@ -885,7 +885,7 @@ func parsePath(path *ast.Term) (ast.Ref, error) { pathSegments = append(pathSegments, term) }) default: - return nil, builtins.NewOperandErr(2, "must be one of {set, array} containing string paths or array of path segments but got %v", ast.TypeName(p)) + return nil, builtins.NewOperandErr(2, "must be one of {set, array} containing string paths or array of path segments but got %v", ast.ValueName(p)) } return pathSegments, nil diff --git a/internal/planner/planner.go b/internal/planner/planner.go index d6b3020413..160775c0e9 100644 --- a/internal/planner/planner.go +++ b/internal/planner/planner.go @@ -1519,7 +1519,7 @@ func (p *Planner) planValue(t ast.Value, loc *ast.Location, iter planiter) error p.loc = loc return p.planObjectComprehension(v, iter) default: - return fmt.Errorf("%v term not implemented", ast.TypeName(v)) + return fmt.Errorf("%v term not implemented", ast.ValueName(v)) } } diff --git a/v1/ast/compare.go b/v1/ast/compare.go index 3bb6f2a75d..24e61712e7 100644 --- a/v1/ast/compare.go +++ b/v1/ast/compare.go @@ -151,14 +151,7 @@ func Compare(a, b interface{}) int { } return 1 case Var: - b := b.(Var) - if a.Equal(b) { - return 0 - } - if a < b { - return -1 - } - return 1 + return VarCompare(a, b.(Var)) case Ref: b := b.(Ref) return termSliceCompare(a, b) @@ -181,7 +174,7 @@ func Compare(a, b interface{}) int { if cmp := Compare(a.Term, b.Term); cmp != 0 { return cmp } - return Compare(a.Body, b.Body) + return a.Body.Compare(b.Body) case *ObjectComprehension: b := b.(*ObjectComprehension) if cmp := Compare(a.Key, b.Key); cmp != 0 { @@ -190,13 +183,13 @@ func Compare(a, b interface{}) int { if cmp := Compare(a.Value, b.Value); cmp != 0 { return cmp } - return Compare(a.Body, b.Body) + return a.Body.Compare(b.Body) case 
*SetComprehension: b := b.(*SetComprehension) if cmp := Compare(a.Term, b.Term); cmp != 0 { return cmp } - return Compare(a.Body, b.Body) + return a.Body.Compare(b.Body) case Call: b := b.(Call) return termSliceCompare(a, b) @@ -394,3 +387,54 @@ func withSliceCompare(a, b []*With) int { } return 0 } + +func VarCompare(a, b Var) int { + if a == b { + return 0 + } + if a < b { + return -1 + } + return 1 +} + +func TermValueCompare(a, b *Term) int { + return a.Value.Compare(b.Value) +} + +func ValueEqual(a, b Value) bool { + // TODO(ae): why doesn't this work the same? + // + // case interface{ Equal(Value) bool }: + // return v.Equal(b) + // + // When put on top, golangci-lint even flags the other cases as unreachable.. + // but TestTopdownVirtualCache will have failing test cases when we replace + // the other cases with the above one.. 🤔 + switch v := a.(type) { + case Null: + return v.Equal(b) + case Boolean: + return v.Equal(b) + case Number: + return v.Equal(b) + case String: + return v.Equal(b) + case Var: + return v.Equal(b) + case Ref: + return v.Equal(b) + case *Array: + return v.Equal(b) + } + + return a.Compare(b) == 0 +} + +func RefCompare(a, b Ref) int { + return termSliceCompare(a, b) +} + +func RefEqual(a, b Ref) bool { + return termSliceEqual(a, b) +} diff --git a/v1/ast/compile.go b/v1/ast/compile.go index a238d454af..76b3c51bda 100644 --- a/v1/ast/compile.go +++ b/v1/ast/compile.go @@ -5500,7 +5500,7 @@ func rewriteDeclaredAssignment(g *localVarGenerator, stack *localDeclaredVars, e return true } } - errs = append(errs, NewError(CompileErr, t.Location, "cannot assign to %v", TypeName(t.Value))) + errs = append(errs, NewError(CompileErr, t.Location, "cannot assign to %v", ValueName(t.Value))) return true } diff --git a/v1/ast/interning.go b/v1/ast/interning.go index f521af9661..17b10231b7 100644 --- a/v1/ast/interning.go +++ b/v1/ast/interning.go @@ -15,6 +15,8 @@ var ( // since this is by far the most common negative number minusOneTerm = 
&Term{Value: Number("-1")} + + InternedNullTerm = &Term{Value: Null{}} ) // InternedBooleanTerm returns an interned term with the given boolean value. @@ -1090,3 +1092,7 @@ var intNumberTerms = [...]*Term{ {Value: Number("511")}, {Value: Number("512")}, } + +var InternedEmptyString = StringTerm("") + +var InternedEmptyObject = ObjectTerm() diff --git a/v1/ast/map.go b/v1/ast/map.go index c22d279a68..5a64f32505 100644 --- a/v1/ast/map.go +++ b/v1/ast/map.go @@ -31,7 +31,7 @@ func (vs *ValueMap) MarshalJSON() ([]byte, error) { vs.Iter(func(k Value, v Value) bool { tmp = append(tmp, map[string]interface{}{ "name": k.String(), - "type": TypeName(v), + "type": ValueName(v), "value": v, }) return false diff --git a/v1/ast/parser.go b/v1/ast/parser.go index 6639ca990b..fef9575132 100644 --- a/v1/ast/parser.go +++ b/v1/ast/parser.go @@ -591,7 +591,7 @@ func (p *Parser) parsePackage() *Package { pkg.Path[0] = DefaultRootDocument.Copy().SetLocation(v[0].Location) first, ok := v[0].Value.(Var) if !ok { - p.errorf(v[0].Location, "unexpected %v token: expecting var", TypeName(v[0].Value)) + p.errorf(v[0].Location, "unexpected %v token: expecting var", ValueName(v[0].Value)) return nil } pkg.Path[1] = StringTerm(string(first)).SetLocation(v[0].Location) @@ -600,7 +600,7 @@ func (p *Parser) parsePackage() *Package { case String: pkg.Path[i] = v[i-1] default: - p.errorf(v[i-1].Location, "unexpected %v token: expecting string", TypeName(v[i-1].Value)) + p.errorf(v[i-1].Location, "unexpected %v token: expecting string", ValueName(v[i-1].Value)) return nil } } @@ -643,7 +643,7 @@ func (p *Parser) parseImport() *Import { case Ref: for i := 1; i < len(v); i++ { if _, ok := v[i].Value.(String); !ok { - p.errorf(v[i].Location, "unexpected %v token: expecting string", TypeName(v[i].Value)) + p.errorf(v[i].Location, "unexpected %v token: expecting string", ValueName(v[i].Value)) return nil } } @@ -1717,7 +1717,7 @@ func (p *Parser) parseRef(head *Term, offset int) (term *Term) { case Var, 
*Array, Object, Set, *ArrayComprehension, *ObjectComprehension, *SetComprehension, Call: // ok default: - p.errorf(loc, "illegal ref (head cannot be %v)", TypeName(h)) + p.errorf(loc, "illegal ref (head cannot be %v)", ValueName(h)) } ref := []*Term{head} @@ -2318,7 +2318,7 @@ func (p *Parser) validateDefaultRuleArgs(rule *Rule) bool { switch v := x.Value.(type) { case Var: // do nothing default: - p.error(rule.Loc(), fmt.Sprintf("illegal default rule (arguments cannot contain %v)", TypeName(v))) + p.error(rule.Loc(), fmt.Sprintf("illegal default rule (arguments cannot contain %v)", ValueName(v))) valid = false return true } diff --git a/v1/ast/parser_ext.go b/v1/ast/parser_ext.go index f08c112a72..db1c3caedc 100644 --- a/v1/ast/parser_ext.go +++ b/v1/ast/parser_ext.go @@ -186,7 +186,7 @@ func ParseRuleFromExpr(module *Module, expr *Expr) (*Rule, error) { } return ParsePartialSetDocRuleFromTerm(module, term) default: - return nil, fmt.Errorf("%v cannot be used for rule name", TypeName(v)) + return nil, fmt.Errorf("%v cannot be used for rule name", ValueName(v)) } } @@ -277,7 +277,7 @@ func ParseCompleteDocRuleFromEqExpr(module *Module, lhs, rhs *Term) (*Rule, erro return nil, fmt.Errorf("ref not ground") } } else { - return nil, fmt.Errorf("%v cannot be used for rule name", TypeName(lhs.Value)) + return nil, fmt.Errorf("%v cannot be used for rule name", ValueName(lhs.Value)) } head.Value = rhs head.Location = lhs.Location @@ -299,7 +299,7 @@ func ParseCompleteDocRuleFromEqExpr(module *Module, lhs, rhs *Term) (*Rule, erro func ParseCompleteDocRuleWithDotsFromTerm(module *Module, term *Term) (*Rule, error) { ref, ok := term.Value.(Ref) if !ok { - return nil, fmt.Errorf("%v cannot be used for rule name", TypeName(term.Value)) + return nil, fmt.Errorf("%v cannot be used for rule name", ValueName(term.Value)) } if _, ok := ref[0].Value.(Var); !ok { @@ -328,7 +328,7 @@ func ParseCompleteDocRuleWithDotsFromTerm(module *Module, term *Term) (*Rule, er func 
ParsePartialObjectDocRuleFromEqExpr(module *Module, lhs, rhs *Term) (*Rule, error) { ref, ok := lhs.Value.(Ref) if !ok { - return nil, fmt.Errorf("%v cannot be used as rule name", TypeName(lhs.Value)) + return nil, fmt.Errorf("%v cannot be used as rule name", ValueName(lhs.Value)) } if _, ok := ref[0].Value.(Var); !ok { @@ -363,7 +363,7 @@ func ParsePartialSetDocRuleFromTerm(module *Module, term *Term) (*Rule, error) { ref, ok := term.Value.(Ref) if !ok || len(ref) == 1 { - return nil, fmt.Errorf("%vs cannot be used for rule head", TypeName(term.Value)) + return nil, fmt.Errorf("%vs cannot be used for rule head", ValueName(term.Value)) } if _, ok := ref[0].Value.(Var); !ok { return nil, fmt.Errorf("invalid rule head: %v", ref) @@ -373,7 +373,7 @@ func ParsePartialSetDocRuleFromTerm(module *Module, term *Term) (*Rule, error) { if len(ref) == 2 { v, ok := ref[0].Value.(Var) if !ok { - return nil, fmt.Errorf("%vs cannot be used for rule head", TypeName(term.Value)) + return nil, fmt.Errorf("%vs cannot be used for rule head", ValueName(term.Value)) } // Modify the code to add the location to the head ref // and set the head ref's jsonOptions. @@ -408,7 +408,7 @@ func ParseRuleFromCallEqExpr(module *Module, lhs, rhs *Term) (*Rule, error) { ref, ok := call[0].Value.(Ref) if !ok { - return nil, fmt.Errorf("%vs cannot be used in function signature", TypeName(call[0].Value)) + return nil, fmt.Errorf("%vs cannot be used in function signature", ValueName(call[0].Value)) } if _, ok := ref[0].Value.(Var); !ok { return nil, fmt.Errorf("invalid rule head: %v", ref) diff --git a/v1/ast/strings.go b/v1/ast/strings.go index e489f6977c..40d66753f5 100644 --- a/v1/ast/strings.go +++ b/v1/ast/strings.go @@ -16,3 +16,39 @@ func TypeName(x interface{}) string { } return strings.ToLower(reflect.Indirect(reflect.ValueOf(x)).Type().Name()) } + +// ValueName returns a human readable name for the AST Value type. 
+// This is preferable over calling TypeName when the argument is known to be +// a Value, as this doesn't require reflection (= heap allocations). +func ValueName(x Value) string { + switch x.(type) { + case String: + return "string" + case Boolean: + return "boolean" + case Number: + return "number" + case Null: + return "null" + case Var: + return "var" + case Object: + return "object" + case Set: + return "set" + case Ref: + return "ref" + case Call: + return "call" + case *Array: + return "array" + case *ArrayComprehension: + return "arraycomprehension" + case *ObjectComprehension: + return "objectcomprehension" + case *SetComprehension: + return "setcomprehension" + } + + return TypeName(x) +} diff --git a/v1/ast/strings_bench_test.go b/v1/ast/strings_bench_test.go new file mode 100644 index 0000000000..c7cce82bc2 --- /dev/null +++ b/v1/ast/strings_bench_test.go @@ -0,0 +1,29 @@ +package ast + +import "testing" + +// BenchmarkTypeName-10 32207775 38.93 ns/op 8 B/op 1 allocs/op +func BenchmarkTypeName(b *testing.B) { + term := StringTerm("foo") + b.ResetTimer() + + for i := 0; i < b.N; i++ { + name := TypeName(term.Value) + if name != "string" { + b.Fatalf("expected string but got %v", name) + } + } +} + +// BenchmarkValueName-10 508312227 2.374 ns/op 0 B/op 0 allocs/op +func BenchmarkValueName(b *testing.B) { + term := StringTerm("foo") + b.ResetTimer() + + for i := 0; i < b.N; i++ { + name := ValueName(term.Value) + if name != "string" { + b.Fatalf("expected string but got %v", name) + } + } +} diff --git a/v1/ast/term.go b/v1/ast/term.go index d79f4418bd..1350150f1a 100644 --- a/v1/ast/term.go +++ b/v1/ast/term.go @@ -14,7 +14,7 @@ import ( "math/big" "net/url" "regexp" - "sort" + "slices" "strconv" "strings" "sync" @@ -56,10 +56,16 @@ type Value interface { func InterfaceToValue(x interface{}) (Value, error) { switch x := x.(type) { case nil: - return Null{}, nil + return NullValue, nil case bool: - return Boolean(x), nil + if x { + return 
InternedBooleanTerm(true).Value, nil + } + return InternedBooleanTerm(false).Value, nil case json.Number: + if interned := InternedIntNumberTermFromString(string(x)); interned != nil { + return interned.Value, nil + } return Number(x), nil case int64: return int64Number(x), nil @@ -85,11 +91,7 @@ func InterfaceToValue(x interface{}) (Value, error) { kvs := util.NewPtrSlice[Term](len(x) * 2) idx := 0 for k, v := range x { - k, err := InterfaceToValue(k) - if err != nil { - return nil, err - } - kvs[idx].Value = k + kvs[idx].Value = String(k) v, err := InterfaceToValue(v) if err != nil { return nil, err @@ -105,15 +107,7 @@ func InterfaceToValue(x interface{}) (Value, error) { case map[string]string: r := newobject(len(x)) for k, v := range x { - k, err := InterfaceToValue(k) - if err != nil { - return nil, err - } - v, err := InterfaceToValue(v) - if err != nil { - return nil, err - } - r.Insert(NewTerm(k), NewTerm(v)) + r.Insert(StringTerm(k), StringTerm(v)) } return r, nil default: @@ -136,7 +130,7 @@ func ValueFromReader(r io.Reader) (Value, error) { // As converts v into a Go native type referred to by x. func As(v Value, x interface{}) error { - return util.NewJSONDecoder(bytes.NewBufferString(v.String())).Decode(x) + return util.NewJSONDecoder(strings.NewReader(v.String())).Decode(x) } // Resolver defines the interface for resolving references to native Go values. @@ -363,7 +357,7 @@ func (term *Term) Copy() *Term { } // Equal returns true if this term equals the other term. Equality is -// defined for each kind of term. +// defined for each kind of term, and does not compare the Location. func (term *Term) Equal(other *Term) bool { if term == nil && other != nil { return false @@ -375,28 +369,7 @@ func (term *Term) Equal(other *Term) bool { return true } - // TODO(tsandall): This early-exit avoids allocations for types that have - // Equal() functions that just use == underneath. 
We should revisit the - // other types and implement Equal() functions that do not require - // allocations. - switch v := term.Value.(type) { - case Null: - return v.Equal(other.Value) - case Boolean: - return v.Equal(other.Value) - case Number: - return v.Equal(other.Value) - case String: - return v.Equal(other.Value) - case Var: - return v.Equal(other.Value) - case Ref: - return v.Equal(other.Value) - case *Array: - return v.Equal(other.Value) - } - - return term.Value.Compare(other.Value) == 0 + return ValueEqual(term.Value, other.Value) } // Get returns a value referred to by name from the term. @@ -441,7 +414,7 @@ func (term *Term) setJSONOptions(opts astJSON.Options) { // Specialized marshalling logic is required to include a type hint for Value. func (term *Term) MarshalJSON() ([]byte, error) { d := map[string]interface{}{ - "type": TypeName(term.Value), + "type": ValueName(term.Value), "value": term.Value, } if term.jsonOptions.MarshalOptions.IncludeLocation.Term { @@ -553,13 +526,7 @@ func ContainsClosures(v interface{}) bool { // IsScalar returns true if the AST value is a scalar. func IsScalar(v Value) bool { switch v.(type) { - case String: - return true - case Number: - return true - case Boolean: - return true - case Null: + case String, Number, Boolean, Null: return true } return false @@ -568,9 +535,11 @@ func IsScalar(v Value) bool { // Null represents the null value defined by JSON. type Null struct{} +var NullValue Value = Null{} + // NullTerm creates a new Term with a Null value. func NullTerm() *Term { - return &Term{Value: Null{}} + return &Term{Value: NullValue} } // Equal returns true if the other term Value is also Null. @@ -586,13 +555,16 @@ func (null Null) Equal(other Value) bool { // Compare compares null to other, return <0, 0, or >0 if it is less than, equal to, // or greater than other. 
func (null Null) Compare(other Value) int { - return Compare(null, other) + if _, ok := other.(Null); ok { + return 0 + } + return -1 } // Find returns the current value or a not found error. func (null Null) Find(path Ref) (Value, error) { if len(path) == 0 { - return null, nil + return NullValue, nil } return nil, errFindNotFound } @@ -616,7 +588,10 @@ type Boolean bool // BooleanTerm creates a new Term with a Boolean value. func BooleanTerm(b bool) *Term { - return &Term{Value: Boolean(b)} + if b { + return &Term{Value: InternedBooleanTerm(true).Value} + } + return &Term{Value: InternedBooleanTerm(false).Value} } // Equal returns true if the other Value is a Boolean and is equal. @@ -632,13 +607,29 @@ func (bol Boolean) Equal(other Value) bool { // Compare compares bol to other, return <0, 0, or >0 if it is less than, equal to, // or greater than other. func (bol Boolean) Compare(other Value) int { - return Compare(bol, other) + switch other := other.(type) { + case Boolean: + if bol == other { + return 0 + } + if !bol { + return -1 + } + return 1 + case Null: + return 1 + } + + return -1 } // Find returns the current value or a not found error. 
func (bol Boolean) Find(path Ref) (Value, error) { if len(path) == 0 { - return bol, nil + if bol { + return InternedBooleanTerm(true).Value, nil + } + return InternedBooleanTerm(false).Value, nil } return nil, errFindNotFound } @@ -688,13 +679,14 @@ func FloatNumberTerm(f float64) *Term { func (num Number) Equal(other Value) bool { switch other := other.(type) { case Number: - n1, ok1 := num.Int64() - n2, ok2 := other.Int64() - if ok1 && ok2 && n1 == n2 { - return true + if n1, ok1 := num.Int64(); ok1 { + n2, ok2 := other.Int64() + if ok1 && ok2 && n1 == n2 { + return true + } } - return Compare(num, other) == 0 + return num.Compare(other) == 0 default: return false } @@ -703,6 +695,21 @@ func (num Number) Equal(other Value) bool { // Compare compares num to other, return <0, 0, or >0 if it is less than, equal to, // or greater than other. func (num Number) Compare(other Value) int { + // Optimize for the common case, as calling Compare allocates on heap. + if otherNum, yes := other.(Number); yes { + if ai, ok := num.Int64(); ok { + if bi, ok := otherNum.Int64(); ok { + if ai == bi { + return 0 + } + if ai < bi { + return -1 + } + return 1 + } + } + } + return Compare(num, other) } @@ -800,6 +807,19 @@ func (str String) Equal(other Value) bool { // Compare compares str to other, return <0, 0, or >0 if it is less than, equal to, // or greater than other. func (str String) Compare(other Value) int { + // Optimize for the common case of one string being compared to another by + // using a direct comparison of values. This avoids the allocation performed + // when calling Compare and its interface{} argument conversion. + if otherStr, ok := other.(String); ok { + if str == otherStr { + return 0 + } + if str < otherStr { + return -1 + } + return 1 + } + return Compare(str, other) } @@ -848,6 +868,9 @@ func (v Var) Equal(other Value) bool { // Compare compares v to other, return <0, 0, or >0 if it is less than, equal to, // or greater than other. 
func (v Var) Compare(other Value) int { + if otherVar, ok := other.(Var); ok { + return strings.Compare(string(v), string(otherVar)) + } return Compare(v, other) } @@ -1020,6 +1043,10 @@ func (ref Ref) Equal(other Value) bool { // Compare compares ref to other, return <0, 0, or >0 if it is less than, equal to, // or greater than other. func (ref Ref) Compare(other Value) int { + if o, ok := other.(Ref); ok { + return termSliceCompare(ref, o) + } + return Compare(ref, other) } @@ -1051,32 +1078,32 @@ func (ref Ref) HasPrefix(other Ref) bool { // ConstantPrefix returns the constant portion of the ref starting from the head. func (ref Ref) ConstantPrefix() Ref { - ref = ref.Copy() - i := ref.Dynamic() if i < 0 { - return ref + return ref.Copy() } - return ref[:i] + return ref[:i].Copy() } func (ref Ref) StringPrefix() Ref { - r := ref.Copy() - for i := 1; i < len(ref); i++ { - switch r[i].Value.(type) { + switch ref[i].Value.(type) { case String: // pass default: // cut off - return r[:i] + return ref[:i].Copy() } } - return r + return ref.Copy() } // GroundPrefix returns the ground portion of the ref starting from the head. By // definition, the head of the reference is always ground. func (ref Ref) GroundPrefix() Ref { + if ref.IsGround() { + return ref + } + prefix := make(Ref, 0, len(ref)) for i, x := range ref { @@ -1260,6 +1287,19 @@ func (arr *Array) Equal(other Value) bool { // Compare compares arr to other, return <0, 0, or >0 if it is less than, equal to, // or greater than other. 
func (arr *Array) Compare(other Value) int { + if b, ok := other.(*Array); ok { + return termSliceCompare(arr.elems, b.elems) + } + + sortA := sortOrder(arr) + sortB := sortOrder(other) + + if sortA < sortB { + return -1 + } else if sortB < sortA { + return 1 + } + return Compare(arr, other) } @@ -1307,7 +1347,9 @@ func (arr *Array) Sorted() *Array { for i := range cpy { cpy[i] = arr.elems[i] } - sort.Sort(termSlice(cpy)) + + slices.SortFunc(cpy, TermValueCompare) + a := NewArray(cpy...) a.hashs = arr.hashs return a @@ -1480,7 +1522,7 @@ func newset(n int) *set { keys: keys, hash: 0, ground: true, - sortGuard: new(sync.Once), + sortGuard: sync.Once{}, } } @@ -1493,11 +1535,15 @@ func SetTerm(t ...*Term) *Term { } type set struct { - elems map[int]*Term - keys []*Term - hash int - ground bool - sortGuard *sync.Once // Prevents race condition around sorting. + elems map[int]*Term + keys []*Term + hash int + ground bool + // Prevents race condition around sorting. + // We can avoid (the allocation cost of) using a pointer here as all + // methods of `set` use a pointer receiver, and the `sync.Once` value + // is never copied. + sortGuard sync.Once } // Copy returns a deep copy of s. @@ -1547,7 +1593,7 @@ func (s *set) String() string { func (s *set) sortedKeys() []*Term { s.sortGuard.Do(func() { - sort.Sort(termSlice(s.keys)) + slices.SortFunc(s.keys, TermValueCompare) }) return s.keys } @@ -1717,7 +1763,7 @@ func (s *set) clear() { s.keys = s.keys[:0] s.hash = 0 s.ground = true - s.sortGuard = new(sync.Once) + s.sortGuard = sync.Once{} } func (s *set) insertNoGuard(x *Term) { @@ -1825,7 +1871,7 @@ func (s *set) insert(x *Term, resetSortGuard bool) { // Note that this will always be the case when external code calls insert via // Add, or otherwise. Internal code may however benefit from not having to // re-create this pointer when it's known not to be needed. 
- s.sortGuard = new(sync.Once) + s.sortGuard = sync.Once{} } s.hash += hash @@ -2094,7 +2140,8 @@ func (l *lazyObj) Keys() []*Term { for k := range l.native { ret = append(ret, StringTerm(k)) } - sort.Sort(termSlice(ret)) + slices.SortFunc(ret, TermValueCompare) + return ret } @@ -2148,7 +2195,7 @@ type object struct { ground int // number of key and value grounds. Counting is // required to support insert's key-value replace. hash int - sortGuard *sync.Once // Prevents race condition around sorting. + sortGuard sync.Once // Prevents race condition around sorting. } func newobject(n int) *object { @@ -2161,7 +2208,7 @@ func newobject(n int) *object { keys: keys, ground: 0, hash: 0, - sortGuard: new(sync.Once), + sortGuard: sync.Once{}, } } @@ -2185,7 +2232,9 @@ func Item(key, value *Term) [2]*Term { func (obj *object) sortedKeys() objectElemSlice { obj.sortGuard.Do(func() { - sort.Sort(obj.keys) + slices.SortFunc(obj.keys, func(a, b *objectElem) int { + return a.key.Value.Compare(b.key.Value) + }) }) return obj.keys } @@ -2376,7 +2425,7 @@ func (obj *object) MarshalJSON() ([]byte, error) { // overlapping keys between obj and other, the values of associated with the keys are merged. Only // objects can be merged with other objects. If the values cannot be merged, the second turn value // will be false. -func (obj object) Merge(other Object) (Object, bool) { +func (obj *object) Merge(other Object) (Object, bool) { return obj.MergeWith(other, func(v1, v2 *Term) (*Term, bool) { obj1, ok1 := v1.Value.(Object) obj2, ok2 := v2.Value.(Object) @@ -2395,7 +2444,7 @@ func (obj object) Merge(other Object) (Object, bool) { // If there are overlapping keys between obj and other, the conflictResolver // is called. The conflictResolver can return a merged value and a boolean // indicating if the merge has failed and should stop. 
-func (obj object) MergeWith(other Object, conflictResolver func(v1, v2 *Term) (*Term, bool)) (Object, bool) { +func (obj *object) MergeWith(other Object, conflictResolver func(v1, v2 *Term) (*Term, bool)) (Object, bool) { result := NewObject() stop := obj.Until(func(k, v *Term) bool { v2 := other.Get(k) @@ -2438,11 +2487,11 @@ func (obj *object) Filter(filter Object) (Object, error) { } // Len returns the number of elements in the object. -func (obj object) Len() int { +func (obj *object) Len() int { return len(obj.keys) } -func (obj object) String() string { +func (obj *object) String() string { sb := sbPool.Get().(*strings.Builder) sb.Reset() sb.Grow(obj.Len() * 32) @@ -2667,8 +2716,8 @@ func (obj *object) insert(k, v *Term, resetSortGuard bool) { // See https://github.com/golang/go/issues/25955 for why we do it this way. // Note that this will always be the case when external code calls insert via // Add, or otherwise. Internal code may however benefit from not having to - // re-create this pointer when it's known not to be needed. - obj.sortGuard = new(sync.Once) + // re-create this when it's known not to be needed. 
+ obj.sortGuard = sync.Once{} } obj.hash += hash + v.Hash() @@ -2695,7 +2744,7 @@ func (obj *object) rehash() { } func filterObject(o Value, filter Value) (Value, error) { - if filter.Compare(Null{}) == 0 { + if (Null{}).Equal(filter) { return o, nil } @@ -3013,12 +3062,16 @@ func (c Call) String() string { func termSliceCopy(a []*Term) []*Term { cpy := make([]*Term, len(a)) - for i := range a { - cpy[i] = a[i].Copy() - } + termSliceCopyTo(a, cpy) return cpy } +func termSliceCopyTo(src, dst []*Term) { + for i := range src { + dst[i] = src[i].Copy() + } +} + func termSliceEqual(a, b []*Term) bool { if len(a) == len(b) { for i := range a { @@ -3243,7 +3296,7 @@ func unmarshalValue(d map[string]interface{}) (Value, error) { v := d["value"] switch d["type"] { case "null": - return Null{}, nil + return NullValue, nil case "boolean": if b, ok := v.(bool); ok { return Boolean(b), nil diff --git a/v1/ast/term_test.go b/v1/ast/term_test.go index 63d4e7e6ca..2d28a05895 100644 --- a/v1/ast/term_test.go +++ b/v1/ast/term_test.go @@ -277,7 +277,7 @@ func TestTermBadJSON(t *testing.T) { term := Term{} err := util.UnmarshalJSON([]byte(input), &term) expected := fmt.Errorf("ast: unable to unmarshal term") - if !reflect.DeepEqual(expected, err) { + if expected.Error() != err.Error() { t.Errorf("Expected %v but got: %v", expected, err) } } @@ -756,7 +756,7 @@ func TestSetMap(t *testing.T) { return nil, fmt.Errorf("oops") }) - if !reflect.DeepEqual(err, fmt.Errorf("oops")) { + if err.Error() != "oops" { t.Fatalf("Expected oops to be returned but got: %v, %v", result, err) } } @@ -1418,7 +1418,7 @@ func TestLazyObjectKeys(t *testing.T) { }) act := x.Keys() exp := []*Term{StringTerm("a"), StringTerm("b"), StringTerm("c")} - if !reflect.DeepEqual(exp, act) { + if !termSliceEqual(exp, act) { t.Errorf("expected Keys() %v, got %v", exp, act) } assertForced(t, x, false) @@ -1436,7 +1436,7 @@ func TestLazyObjectKeysIterator(t *testing.T) { act = append(act, k) } exp := 
[]*Term{StringTerm("a"), StringTerm("b"), StringTerm("c")} - if !reflect.DeepEqual(exp, act) { + if !termSliceEqual(exp, act) { t.Errorf("expected Keys() %v, got %v", exp, act) } assertForced(t, x, false) diff --git a/v1/format/format.go b/v1/format/format.go index 56c30171dd..e86964d1b4 100644 --- a/v1/format/format.go +++ b/v1/format/format.go @@ -1637,7 +1637,7 @@ func ArityFormatMismatchError(operands []*ast.Term, operator string, loc *ast.Lo have := make([]string, len(operands)) for i := 0; i < len(operands); i++ { - have[i] = ast.TypeName(operands[i].Value) + have[i] = ast.ValueName(operands[i].Value) } err := ast.NewError(ast.TypeErr, loc, "%s: %s", operator, "arity mismatch") err.Details = &ArityFormatErrDetail{ diff --git a/v1/rego/rego.go b/v1/rego/rego.go index 1b7ea47bdd..ede02439dd 100644 --- a/v1/rego/rego.go +++ b/v1/rego/rego.go @@ -2598,7 +2598,7 @@ func (r *Rego) rewriteQueryForPartialEval(_ ast.QueryCompiler, query ast.Body) ( ref, ok := term.Value.(ast.Ref) if !ok { - return nil, fmt.Errorf("partial evaluation requires ref (not %v)", ast.TypeName(term.Value)) + return nil, fmt.Errorf("partial evaluation requires ref (not %v)", ast.ValueName(term.Value)) } if !ref.IsGround() { diff --git a/v1/topdown/aggregates.go b/v1/topdown/aggregates.go index e7d0578224..02425d2411 100644 --- a/v1/topdown/aggregates.go +++ b/v1/topdown/aggregates.go @@ -99,7 +99,7 @@ func builtinMax(_ BuiltinContext, operands []*ast.Term, iter func(*ast.Term) err if a.Len() == 0 { return nil } - var max = ast.Value(ast.Null{}) + max := ast.InternedNullTerm.Value a.Foreach(func(x *ast.Term) { if ast.Compare(max, x.Value) <= 0 { max = x.Value @@ -110,7 +110,7 @@ func builtinMax(_ BuiltinContext, operands []*ast.Term, iter func(*ast.Term) err if a.Len() == 0 { return nil } - max, err := a.Reduce(ast.NullTerm(), func(max *ast.Term, elem *ast.Term) (*ast.Term, error) { + max, err := a.Reduce(ast.InternedNullTerm, func(max *ast.Term, elem *ast.Term) (*ast.Term, error) { if 
ast.Compare(max, elem) <= 0 { return elem, nil } @@ -142,11 +142,11 @@ func builtinMin(_ BuiltinContext, operands []*ast.Term, iter func(*ast.Term) err if a.Len() == 0 { return nil } - min, err := a.Reduce(ast.NullTerm(), func(min *ast.Term, elem *ast.Term) (*ast.Term, error) { + min, err := a.Reduce(ast.InternedNullTerm, func(min *ast.Term, elem *ast.Term) (*ast.Term, error) { // The null term is considered to be less than any other term, // so in order for min of a set to make sense, we need to check // for it. - if min.Value.Compare(ast.Null{}) == 0 { + if min.Value.Compare(ast.InternedNullTerm.Value) == 0 { return elem, nil } diff --git a/v1/topdown/array.go b/v1/topdown/array.go index d37204bef0..4a2a2ed148 100644 --- a/v1/topdown/array.go +++ b/v1/topdown/array.go @@ -20,6 +20,13 @@ func builtinArrayConcat(_ BuiltinContext, operands []*ast.Term, iter func(*ast.T return err } + if arrA.Len() == 0 { + return iter(operands[1]) + } + if arrB.Len() == 0 { + return iter(operands[0]) + } + arrC := make([]*ast.Term, arrA.Len()+arrB.Len()) i := 0 @@ -68,6 +75,10 @@ func builtinArraySlice(_ BuiltinContext, operands []*ast.Term, iter func(*ast.Te startIndex = stopIndex } + if startIndex == 0 && stopIndex >= arr.Len() { + return iter(operands[0]) + } + return iter(ast.NewTerm(arr.Slice(startIndex, stopIndex))) } diff --git a/v1/topdown/builtins/builtins.go b/v1/topdown/builtins/builtins.go index c788cf2536..45a0b88408 100644 --- a/v1/topdown/builtins/builtins.go +++ b/v1/topdown/builtins/builtins.go @@ -128,23 +128,23 @@ func NewOperandErr(pos int, f string, a ...interface{}) error { func NewOperandTypeErr(pos int, got ast.Value, expected ...string) error { if len(expected) == 1 { - return NewOperandErr(pos, "must be %v but got %v", expected[0], ast.TypeName(got)) + return NewOperandErr(pos, "must be %v but got %v", expected[0], ast.ValueName(got)) } - return NewOperandErr(pos, "must be one of {%v} but got %v", strings.Join(expected, ", "), ast.TypeName(got)) + return 
NewOperandErr(pos, "must be one of {%v} but got %v", strings.Join(expected, ", "), ast.ValueName(got)) } // NewOperandElementErr returns an operand error indicating an element in the // composite operand was wrong. func NewOperandElementErr(pos int, composite ast.Value, got ast.Value, expected ...string) error { - tpe := ast.TypeName(composite) + tpe := ast.ValueName(composite) if len(expected) == 1 { - return NewOperandErr(pos, "must be %v of %vs but got %v containing %v", tpe, expected[0], tpe, ast.TypeName(got)) + return NewOperandErr(pos, "must be %v of %vs but got %v containing %v", tpe, expected[0], tpe, ast.ValueName(got)) } - return NewOperandErr(pos, "must be %v of (any of) {%v} but got %v containing %v", tpe, strings.Join(expected, ", "), tpe, ast.TypeName(got)) + return NewOperandErr(pos, "must be %v of (any of) {%v} but got %v containing %v", tpe, strings.Join(expected, ", "), tpe, ast.ValueName(got)) } // NewOperandEnumErr returns an operand error indicating a value was wrong. 
@@ -233,7 +233,7 @@ func ObjectOperand(x ast.Value, pos int) (ast.Object, error) { func ArrayOperand(x ast.Value, pos int) (*ast.Array, error) { a, ok := x.(*ast.Array) if !ok { - return ast.NewArray(), NewOperandTypeErr(pos, x, "array") + return nil, NewOperandTypeErr(pos, x, "array") } return a, nil } diff --git a/v1/topdown/crypto.go b/v1/topdown/crypto.go index ff53550748..ab499e3e8f 100644 --- a/v1/topdown/crypto.go +++ b/v1/topdown/crypto.go @@ -15,6 +15,7 @@ import ( "crypto/tls" "crypto/x509" "encoding/base64" + "encoding/hex" "encoding/json" "encoding/pem" "fmt" @@ -373,7 +374,7 @@ func builtinCryptoJWKFromPrivateKey(_ BuiltinContext, operands []*ast.Term, iter } if len(rawKeys) == 0 { - return iter(ast.NullTerm()) + return iter(ast.InternedNullTerm) } key, err := jwk.New(rawKeys[0]) @@ -407,7 +408,7 @@ func builtinCryptoParsePrivateKeys(_ BuiltinContext, operands []*ast.Term, iter } if string(input) == "" { - return iter(ast.NullTerm()) + return iter(ast.InternedNullTerm) } // get the raw private key @@ -417,7 +418,7 @@ func builtinCryptoParsePrivateKeys(_ BuiltinContext, operands []*ast.Term, iter } if len(rawKeys) == 0 { - return iter(ast.NewTerm(ast.NewArray())) + return iter(emptyArr) } bs, err := json.Marshal(rawKeys) @@ -438,36 +439,43 @@ func builtinCryptoParsePrivateKeys(_ BuiltinContext, operands []*ast.Term, iter return iter(ast.NewTerm(value)) } -func hashHelper(a ast.Value, h func(ast.String) string) (ast.Value, error) { - s, err := builtins.StringOperand(a, 1) - if err != nil { - return nil, err - } - return ast.String(h(s)), nil +func toHexEncodedString(src []byte) string { + dst := make([]byte, hex.EncodedLen(len(src))) + hex.Encode(dst, src) + return util.ByteSliceToString(dst) } func builtinCryptoMd5(_ BuiltinContext, operands []*ast.Term, iter func(*ast.Term) error) error { - res, err := hashHelper(operands[0].Value, func(s ast.String) string { return fmt.Sprintf("%x", md5.Sum([]byte(s))) }) + s, err := 
builtins.StringOperand(operands[0].Value, 1) if err != nil { return err } - return iter(ast.NewTerm(res)) + + md5sum := md5.Sum([]byte(s)) + + return iter(ast.StringTerm(toHexEncodedString(md5sum[:]))) } func builtinCryptoSha1(_ BuiltinContext, operands []*ast.Term, iter func(*ast.Term) error) error { - res, err := hashHelper(operands[0].Value, func(s ast.String) string { return fmt.Sprintf("%x", sha1.Sum([]byte(s))) }) + s, err := builtins.StringOperand(operands[0].Value, 1) if err != nil { return err } - return iter(ast.NewTerm(res)) + + sha1sum := sha1.Sum([]byte(s)) + + return iter(ast.StringTerm(toHexEncodedString(sha1sum[:]))) } func builtinCryptoSha256(_ BuiltinContext, operands []*ast.Term, iter func(*ast.Term) error) error { - res, err := hashHelper(operands[0].Value, func(s ast.String) string { return fmt.Sprintf("%x", sha256.Sum256([]byte(s))) }) + s, err := builtins.StringOperand(operands[0].Value, 1) if err != nil { return err } - return iter(ast.NewTerm(res)) + + sha256sum := sha256.Sum256([]byte(s)) + + return iter(ast.StringTerm(toHexEncodedString(sha256sum[:]))) } func hmacHelper(operands []*ast.Term, iter func(*ast.Term) error, h func() hash.Hash) error { @@ -724,9 +732,11 @@ func readCertFromFile(localCertFile string) ([]byte, error) { return certPEM, nil } +var beginPrefix = []byte("-----BEGIN ") + func getTLSx509KeyPairFromString(certPemBlock []byte, keyPemBlock []byte) (*tls.Certificate, error) { - if !strings.HasPrefix(string(certPemBlock), "-----BEGIN") { + if !bytes.HasPrefix(certPemBlock, beginPrefix) { s, err := base64.StdEncoding.DecodeString(string(certPemBlock)) if err != nil { return nil, err @@ -734,7 +744,7 @@ func getTLSx509KeyPairFromString(certPemBlock []byte, keyPemBlock []byte) (*tls. 
certPemBlock = s } - if !strings.HasPrefix(string(keyPemBlock), "-----BEGIN") { + if !bytes.HasPrefix(keyPemBlock, beginPrefix) { s, err := base64.StdEncoding.DecodeString(string(keyPemBlock)) if err != nil { return nil, err @@ -743,7 +753,7 @@ func getTLSx509KeyPairFromString(certPemBlock []byte, keyPemBlock []byte) (*tls. } // we assume it a DER certificate and try to convert it to a PEM. - if !bytes.HasPrefix(certPemBlock, []byte("-----BEGIN")) { + if !bytes.HasPrefix(certPemBlock, beginPrefix) { pemBlock := &pem.Block{ Type: "CERTIFICATE", diff --git a/v1/topdown/crypto_test.go b/v1/topdown/crypto_test.go index 839f03cf95..b5bd82b1a0 100644 --- a/v1/topdown/crypto_test.go +++ b/v1/topdown/crypto_test.go @@ -865,3 +865,29 @@ func TestExtractX509VerifyOptions(t *testing.T) { } } } + +// Before/after replacing sprintf("%x", ...) with hex.EncodeToString(...), and using +// util.ByteSliceToString to convert the resulting byte slice: +// BenchmarkMd5-10 3294998 435.2 ns/op 128 B/op 5 allocs/op +// BenchmarkMd5-10 6193455 180.9 ns/op 96 B/op 3 allocs/op +// ... 
+func BenchmarkMd5(b *testing.B) { + bctx := BuiltinContext{} + operands := []*ast.Term{ast.StringTerm("hello")} + expect := ast.String("5d41402abc4b2a76b9719d911017c592") + iter := func(result *ast.Term) error { + if !expect.Equal(result.Value) { + return fmt.Errorf("unexpected result: %v", result.Value) + } + return nil + } + + b.ResetTimer() + + for i := 0; i < b.N; i++ { + err := builtinCryptoMd5(bctx, operands, iter) + if err != nil { + b.Fatalf("unexpected error: %v", err) + } + } +} diff --git a/v1/topdown/glob.go b/v1/topdown/glob.go index cda17f3827..4de17d06c5 100644 --- a/v1/topdown/glob.go +++ b/v1/topdown/glob.go @@ -13,8 +13,10 @@ import ( const globCacheMaxSize = 100 const globInterQueryValueCacheHits = "rego_builtin_glob_interquery_value_cache_hits" -var globCacheLock = sync.Mutex{} -var globCache map[string]glob.Glob +var noDelimiters = []rune{} +var dotDelimiters = []rune{'.'} +var globCacheLock = sync.RWMutex{} +var globCache = map[string]glob.Glob{} func builtinGlobMatch(bctx BuiltinContext, operands []*ast.Term, iter func(*ast.Term) error) error { pattern, err := builtins.StringOperand(operands[0].Value, 1) @@ -25,14 +27,14 @@ func builtinGlobMatch(bctx BuiltinContext, operands []*ast.Term, iter func(*ast. 
var delimiters []rune switch operands[1].Value.(type) { case ast.Null: - delimiters = []rune{} + delimiters = noDelimiters case *ast.Array: delimiters, err = builtins.RuneSliceOperand(operands[1].Value, 2) if err != nil { return err } if len(delimiters) == 0 { - delimiters = []rune{'.'} + delimiters = dotDelimiters } default: return builtins.NewOperandTypeErr(2, operands[1].Value, "array", "null") @@ -86,14 +88,15 @@ func globCompileAndMatch(bctx BuiltinContext, id, pattern, match string, delimit return res.Match(match), nil } - globCacheLock.Lock() - defer globCacheLock.Unlock() + globCacheLock.RLock() p, ok := globCache[id] + globCacheLock.RUnlock() if !ok { var err error if p, err = glob.Compile(pattern, delimiters...); err != nil { return false, err } + globCacheLock.Lock() if len(globCache) >= globCacheMaxSize { // Delete a (semi-)random key to make room for the new one. for k := range globCache { @@ -102,9 +105,10 @@ func globCompileAndMatch(bctx BuiltinContext, id, pattern, match string, delimit } } globCache[id] = p + globCacheLock.Unlock() } - out := p.Match(match) - return out, nil + + return p.Match(match), nil } func builtinGlobQuoteMeta(_ BuiltinContext, operands []*ast.Term, iter func(*ast.Term) error) error { @@ -117,7 +121,6 @@ func builtinGlobQuoteMeta(_ BuiltinContext, operands []*ast.Term, iter func(*ast } func init() { - globCache = map[string]glob.Glob{} RegisterBuiltinFunc(ast.GlobMatch.Name, builtinGlobMatch) RegisterBuiltinFunc(ast.GlobQuoteMeta.Name, builtinGlobQuoteMeta) } diff --git a/v1/topdown/http.go b/v1/topdown/http.go index 20fea2d7a6..71c7c7d9eb 100644 --- a/v1/topdown/http.go +++ b/v1/topdown/http.go @@ -86,11 +86,24 @@ var cacheableHTTPStatusCodes = [...]int{ http.StatusNotImplemented, } +var ( + codeTerm = ast.StringTerm("code") + messageTerm = ast.StringTerm("message") + statusCodeTerm = ast.StringTerm("status_code") + errorTerm = ast.StringTerm("error") + methodTerm = ast.StringTerm("method") + urlTerm = ast.StringTerm("url") 
+ + httpSendNetworkErrTerm = ast.StringTerm(HTTPSendNetworkErr) + httpSendInternalErrTerm = ast.StringTerm(HTTPSendInternalErr) +) + var ( allowedKeys = ast.NewSet() + keyCache = make(map[string]*ast.Term, len(allowedKeyNames)) cacheableCodes = ast.NewSet() - requiredKeys = ast.NewSet(ast.StringTerm("method"), ast.StringTerm("url")) - httpSendLatencyMetricKey = "rego_builtin_" + strings.ReplaceAll(ast.HTTPSend.Name, ".", "_") + requiredKeys = ast.NewSet(methodTerm, urlTerm) + httpSendLatencyMetricKey = "rego_builtin_http_send" httpSendInterQueryCacheHits = httpSendLatencyMetricKey + "_interquery_cache_hits" ) @@ -151,22 +164,24 @@ func builtinHTTPSend(bctx BuiltinContext, operands []*ast.Term, iter func(*ast.T } func generateRaiseErrorResult(err error) *ast.Term { - obj := ast.NewObject() - obj.Insert(ast.StringTerm("status_code"), ast.InternedIntNumberTerm(0)) - - errObj := ast.NewObject() - + var errObj ast.Object switch err.(type) { case *url.Error: - errObj.Insert(ast.StringTerm("code"), ast.StringTerm(HTTPSendNetworkErr)) + errObj = ast.NewObject( + ast.Item(codeTerm, httpSendNetworkErrTerm), + ast.Item(messageTerm, ast.StringTerm(err.Error())), + ) default: - errObj.Insert(ast.StringTerm("code"), ast.StringTerm(HTTPSendInternalErr)) + errObj = ast.NewObject( + ast.Item(codeTerm, httpSendInternalErrTerm), + ast.Item(messageTerm, ast.StringTerm(err.Error())), + ) } - errObj.Insert(ast.StringTerm("message"), ast.StringTerm(err.Error())) - obj.Insert(ast.StringTerm("error"), ast.NewTerm(errObj)) - - return ast.NewTerm(obj) + return ast.NewTerm(ast.NewObject( + ast.Item(statusCodeTerm, ast.InternedIntNumberTerm(0)), + ast.Item(errorTerm, ast.NewTerm(errObj)), + )) } func getHTTPResponse(bctx BuiltinContext, req ast.Object) (*ast.Term, error) { @@ -212,21 +227,21 @@ func getHTTPResponse(bctx BuiltinContext, req ast.Object) (*ast.Term, error) { func getKeyFromRequest(req ast.Object) (ast.Object, error) { // deep copy so changes to key do not reflect in the request 
object key := req.Copy() - cacheIgnoredHeadersTerm := req.Get(ast.StringTerm("cache_ignored_headers")) + cacheIgnoredHeadersTerm := req.Get(keyCache["cache_ignored_headers"]) allHeadersTerm := req.Get(ast.StringTerm("headers")) // skip because no headers to delete if cacheIgnoredHeadersTerm == nil || allHeadersTerm == nil { // need to explicitly set cache_ignored_headers to null // equivalent requests might have different sets of exclusion lists - key.Insert(ast.StringTerm("cache_ignored_headers"), ast.NullTerm()) + key.Insert(ast.StringTerm("cache_ignored_headers"), ast.InternedNullTerm) return key, nil } var cacheIgnoredHeaders []string - var allHeaders map[string]interface{} err := ast.As(cacheIgnoredHeadersTerm.Value, &cacheIgnoredHeaders) if err != nil { return nil, err } + var allHeaders map[string]interface{} err = ast.As(allHeadersTerm.Value, &allHeaders) if err != nil { return nil, err @@ -238,14 +253,14 @@ func getKeyFromRequest(req ast.Object) (ast.Object, error) { if err != nil { return nil, err } - key.Insert(ast.StringTerm("headers"), ast.NewTerm(val)) + key.Insert(keyCache["headers"], ast.NewTerm(val)) // remove cache_ignored_headers key - key.Insert(ast.StringTerm("cache_ignored_headers"), ast.NullTerm()) + key.Insert(keyCache["cache_ignored_headers"], ast.InternedNullTerm) return key, nil } func init() { - createAllowedKeys() + createKeys() createCacheableHTTPStatusCodes() initDefaults() RegisterBuiltinFunc(ast.HTTPSend.Name, builtinHTTPSend) @@ -389,33 +404,24 @@ func verifyURLHost(bctx BuiltinContext, unverifiedURL string) error { } func createHTTPRequest(bctx BuiltinContext, obj ast.Object) (*http.Request, *http.Client, error) { - var url string - var method string - - // Additional CA certificates loading options. - var tlsCaCert []byte - var tlsCaCertEnvVar string - var tlsCaCertFile string - - // Client TLS certificate and key options. Each input source - // comes in a matched pair. 
- var tlsClientCert []byte - var tlsClientKey []byte - - var tlsClientCertEnvVar string - var tlsClientKeyEnvVar string - - var tlsClientCertFile string - var tlsClientKeyFile string - - var tlsServerName string - var body *bytes.Buffer - var rawBody *bytes.Buffer - var enableRedirect bool - var tlsUseSystemCerts *bool - var tlsConfig tls.Config - var customHeaders map[string]interface{} - var tlsInsecureSkipVerify bool + var ( + url, method string + // Additional CA certificates loading options. + tlsCaCert []byte + tlsCaCertEnvVar, tlsCaCertFile string + // Client TLS certificate and key options. Each input source + // comes in a matched pair. + tlsClientCert, tlsClientKey []byte + tlsClientCertEnvVar, tlsClientKeyEnvVar string + tlsClientCertFile, tlsClientKeyFile, tlsServerName string + + body, rawBody *bytes.Buffer + enableRedirect, tlsInsecureSkipVerify bool + tlsUseSystemCerts *bool + tlsConfig tls.Config + customHeaders map[string]interface{} + ) + timeout := defaultHTTPRequestTimeout for _, val := range obj.Keys() { @@ -724,7 +730,7 @@ func executeHTTPRequest(req *http.Request, client *http.Client, inputReqObj ast. 
var err error var retry int - retry, err = getNumberValFromReqObj(inputReqObj, ast.StringTerm("max_retry_attempts")) + retry, err = getNumberValFromReqObj(inputReqObj, keyCache["max_retry_attempts"]) if err != nil { return nil, err } @@ -1009,9 +1015,12 @@ func insertIntoHTTPSendInterQueryCache(bctx BuiltinContext, key ast.Value, resp return nil } -func createAllowedKeys() { +func createKeys() { for _, element := range allowedKeyNames { - allowedKeys.Add(ast.StringTerm(element)) + term := ast.StringTerm(element) + + allowedKeys.Add(term) + keyCache[element] = term } } @@ -1045,7 +1054,7 @@ func parseTimeout(timeoutVal ast.Value) (time.Duration, error) { } return timeout, nil default: - return timeout, builtins.NewOperandErr(1, "'timeout' must be one of {string, number} but got %s", ast.TypeName(t)) + return timeout, builtins.NewOperandErr(1, "'timeout' must be one of {string, number} but got %s", ast.ValueName(t)) } } @@ -1078,7 +1087,7 @@ func getNumberValFromReqObj(req ast.Object, key *ast.Term) (int, error) { } func getCachingMode(req ast.Object) (cachingMode, error) { - key := ast.StringTerm("caching_mode") + key := keyCache["caching_mode"] var s ast.String var ok bool if v := req.Get(key); v != nil { @@ -1477,11 +1486,11 @@ func (c *interQueryCache) CheckCache() (ast.Value, error) { return resp, nil } - c.forceJSONDecode, err = getBoolValFromReqObj(c.key, ast.StringTerm("force_json_decode")) + c.forceJSONDecode, err = getBoolValFromReqObj(c.key, keyCache["force_json_decode"]) if err != nil { return nil, handleHTTPSendErr(c.bctx, err) } - c.forceYAMLDecode, err = getBoolValFromReqObj(c.key, ast.StringTerm("force_yaml_decode")) + c.forceYAMLDecode, err = getBoolValFromReqObj(c.key, keyCache["force_yaml_decode"]) if err != nil { return nil, handleHTTPSendErr(c.bctx, err) } @@ -1545,11 +1554,11 @@ func (c *intraQueryCache) CheckCache() (ast.Value, error) { // InsertIntoCache inserts the key set on this object into the cache with the given value func (c 
*intraQueryCache) InsertIntoCache(value *http.Response) (ast.Value, error) { - forceJSONDecode, err := getBoolValFromReqObj(c.key, ast.StringTerm("force_json_decode")) + forceJSONDecode, err := getBoolValFromReqObj(c.key, keyCache["force_json_decode"]) if err != nil { return nil, handleHTTPSendErr(c.bctx, err) } - forceYAMLDecode, err := getBoolValFromReqObj(c.key, ast.StringTerm("force_yaml_decode")) + forceYAMLDecode, err := getBoolValFromReqObj(c.key, keyCache["force_yaml_decode"]) if err != nil { return nil, handleHTTPSendErr(c.bctx, err) } @@ -1580,12 +1589,12 @@ func (c *intraQueryCache) ExecuteHTTPRequest() (*http.Response, error) { } func useInterQueryCache(req ast.Object) (bool, *forceCacheParams, error) { - value, err := getBoolValFromReqObj(req, ast.StringTerm("cache")) + value, err := getBoolValFromReqObj(req, keyCache["cache"]) if err != nil { return false, nil, err } - valueForceCache, err := getBoolValFromReqObj(req, ast.StringTerm("force_cache")) + valueForceCache, err := getBoolValFromReqObj(req, keyCache["force_cache"]) if err != nil { return false, nil, err } @@ -1603,7 +1612,7 @@ type forceCacheParams struct { } func newForceCacheParams(req ast.Object) (*forceCacheParams, error) { - term := req.Get(ast.StringTerm("force_cache_duration_seconds")) + term := req.Get(keyCache["force_cache_duration_seconds"]) if term == nil { return nil, fmt.Errorf("'force_cache' set but 'force_cache_duration_seconds' parameter is missing") } @@ -1621,7 +1630,7 @@ func newForceCacheParams(req ast.Object) (*forceCacheParams, error) { func getRaiseErrorValue(req ast.Object) (bool, error) { result := ast.Boolean(true) var ok bool - if v := req.Get(ast.StringTerm("raise_error")); v != nil { + if v := req.Get(keyCache["raise_error"]); v != nil { if result, ok = v.Value.(ast.Boolean); !ok { return false, fmt.Errorf("invalid value for raise_error field") } diff --git a/v1/topdown/json.go b/v1/topdown/json.go index 57e079d2e0..5b7c414e40 100644 --- a/v1/topdown/json.go +++ 
b/v1/topdown/json.go @@ -189,7 +189,7 @@ func parsePath(path *ast.Term) (ast.Ref, error) { pathSegments = append(pathSegments, term) }) default: - return nil, builtins.NewOperandErr(2, "must be one of {set, array} containing string paths or array of path segments but got %v", ast.TypeName(p)) + return nil, builtins.NewOperandErr(2, "must be one of {set, array} containing string paths or array of path segments but got %v", ast.ValueName(p)) } return pathSegments, nil @@ -231,7 +231,7 @@ func pathsToObject(paths []ast.Ref) ast.Object { } if !done { - node.Insert(path[len(path)-1], ast.NullTerm()) + node.Insert(path[len(path)-1], ast.InternedNullTerm) } } diff --git a/v1/topdown/jsonschema.go b/v1/topdown/jsonschema.go index 588b7ec4ce..b1609fb044 100644 --- a/v1/topdown/jsonschema.go +++ b/v1/topdown/jsonschema.go @@ -61,7 +61,7 @@ func builtinJSONSchemaVerify(_ BuiltinContext, operands []*ast.Term, iter func(* return iter(newResultTerm(false, ast.StringTerm("jsonschema: "+err.Error()))) } - return iter(newResultTerm(true, ast.NullTerm())) + return iter(newResultTerm(true, ast.InternedNullTerm)) } // builtinJSONMatchSchema accepts 2 arguments both can be string or object and verifies if the document matches the JSON schema. diff --git a/v1/topdown/object.go b/v1/topdown/object.go index 11671da5f3..4db8fa8272 100644 --- a/v1/topdown/object.go +++ b/v1/topdown/object.go @@ -92,7 +92,7 @@ func builtinObjectFilter(_ BuiltinContext, operands []*ast.Term, iter func(*ast. 
filterObj := ast.NewObject() keys.Foreach(func(key *ast.Term) { - filterObj.Insert(key, ast.NullTerm()) + filterObj.Insert(key, ast.InternedNullTerm) }) // Actually do the filtering diff --git a/v1/topdown/print.go b/v1/topdown/print.go index 2d16c2baab..f852f3e320 100644 --- a/v1/topdown/print.go +++ b/v1/topdown/print.go @@ -62,7 +62,7 @@ func builtinPrintCrossProductOperands(bctx BuiltinContext, buf []string, operand xs, ok := operands.Elem(i).Value.(ast.Set) if !ok { - return Halt{Err: internalErr(bctx.Location, fmt.Sprintf("illegal argument type: %v", ast.TypeName(operands.Elem(i).Value)))} + return Halt{Err: internalErr(bctx.Location, fmt.Sprintf("illegal argument type: %v", ast.ValueName(operands.Elem(i).Value)))} } if xs.Len() == 0 { diff --git a/v1/topdown/runtime.go b/v1/topdown/runtime.go index f892f1751e..9323225832 100644 --- a/v1/topdown/runtime.go +++ b/v1/topdown/runtime.go @@ -12,14 +12,16 @@ import ( var configStringTerm = ast.StringTerm("config") +var nothingResolver ast.Resolver = illegalResolver{} + func builtinOPARuntime(bctx BuiltinContext, _ []*ast.Term, iter func(*ast.Term) error) error { if bctx.Runtime == nil { - return iter(ast.ObjectTerm()) + return iter(ast.InternedEmptyObject) } if bctx.Runtime.Get(configStringTerm) != nil { - iface, err := ast.ValueToInterface(bctx.Runtime.Value, illegalResolver{}) + iface, err := ast.ValueToInterface(bctx.Runtime.Value, nothingResolver) if err != nil { return err } diff --git a/v1/topdown/strings.go b/v1/topdown/strings.go index 8d6c753e6d..929a18ea0a 100644 --- a/v1/topdown/strings.go +++ b/v1/topdown/strings.go @@ -10,6 +10,8 @@ import ( "sort" "strconv" "strings" + "unicode" + "unicode/utf8" "github.com/tchap/go-patricia/v2/patricia" @@ -153,33 +155,48 @@ func builtinConcat(_ BuiltinContext, operands []*ast.Term, iter func(*ast.Term) return err } - strs := []string{} + var strs []string switch b := operands[1].Value.(type) { case *ast.Array: - err := b.Iter(func(x *ast.Term) error { - s, ok := 
x.Value.(ast.String) + var l int + for i := 0; i < b.Len(); i++ { + s, ok := b.Elem(i).Value.(ast.String) if !ok { - return builtins.NewOperandElementErr(2, operands[1].Value, x.Value, "string") + return builtins.NewOperandElementErr(2, operands[1].Value, b.Elem(i).Value, "string") } - strs = append(strs, string(s)) - return nil - }) - if err != nil { - return err + l += len(string(s)) + } + + if b.Len() == 1 { + return iter(b.Elem(0)) } + + strs = make([]string, 0, l) + for i := 0; i < b.Len(); i++ { + strs = append(strs, string(b.Elem(i).Value.(ast.String))) + } + case ast.Set: - err := b.Iter(func(x *ast.Term) error { - s, ok := x.Value.(ast.String) + var l int + terms := b.Slice() + for i := 0; i < len(terms); i++ { + s, ok := terms[i].Value.(ast.String) if !ok { - return builtins.NewOperandElementErr(2, operands[1].Value, x.Value, "string") + return builtins.NewOperandElementErr(2, operands[1].Value, terms[i].Value, "string") } - strs = append(strs, string(s)) - return nil - }) - if err != nil { - return err + l += len(string(s)) + } + + if b.Len() == 1 { + return iter(b.Slice()[0]) + } + + strs = make([]string, 0, l) + for i := 0; i < b.Len(); i++ { + strs = append(strs, string(terms[i].Value.(ast.String))) } + default: return builtins.NewOperandTypeErr(2, operands[1].Value, "set", "array") } @@ -213,6 +230,10 @@ func builtinIndexOf(_ BuiltinContext, operands []*ast.Term, iter func(*ast.Term) return fmt.Errorf("empty search character") } + if isASCII(string(base)) && isASCII(string(search)) { + return iter(ast.InternedIntNumberTerm(strings.Index(string(base), string(search)))) + } + baseRunes := []rune(string(base)) searchRunes := []rune(string(search)) searchLen := len(searchRunes) @@ -268,15 +289,10 @@ func builtinSubstring(_ BuiltinContext, operands []*ast.Term, iter func(*ast.Ter if err != nil { return err } - runes := []rune(base) startIndex, err := builtins.IntOperand(operands[1].Value, 2) if err != nil { return err - } else if startIndex >= len(runes) 
{ - return iter(ast.StringTerm("")) - } else if startIndex < 0 { - return fmt.Errorf("negative offset") } length, err := builtins.IntOperand(operands[2].Value, 3) @@ -284,18 +300,60 @@ func builtinSubstring(_ BuiltinContext, operands []*ast.Term, iter func(*ast.Ter return err } - var s ast.String + if startIndex < 0 { + return fmt.Errorf("negative offset") + } + + sbase := string(base) + if sbase == "" { + return iter(ast.InternedEmptyString) + } + + // Optimized path for the likely common case of ASCII strings. + // This allocates less memory and runs in about 1/3 the time. + if isASCII(sbase) { + if startIndex >= len(sbase) { + return iter(ast.InternedEmptyString) + } + + if length < 0 { + return iter(ast.StringTerm(sbase[startIndex:])) + } + + upto := startIndex + length + if len(sbase) < upto { + upto = len(sbase) + } + return iter(ast.StringTerm(sbase[startIndex:upto])) + } + + runes := []rune(base) + + if startIndex >= len(runes) { + return iter(ast.InternedEmptyString) + } + + var s string if length < 0 { - s = ast.String(runes[startIndex:]) + s = string(runes[startIndex:]) } else { upto := startIndex + length if len(runes) < upto { upto = len(runes) } - s = ast.String(runes[startIndex:upto]) + s = string(runes[startIndex:upto]) } - return iter(ast.NewTerm(s)) + return iter(ast.StringTerm(s)) +} + +func isASCII(s string) bool { + for i := 0; i < len(s); i++ { + if s[i] > unicode.MaxASCII { + return false + } + } + return true } func builtinContains(_ BuiltinContext, operands []*ast.Term, iter func(*ast.Term) error) error { @@ -325,7 +383,6 @@ func builtinStringCount(_ BuiltinContext, operands []*ast.Term, iter func(*ast.T baseTerm := string(s) searchTerm := string(substr) - count := strings.Count(baseTerm, searchTerm) return iter(ast.InternedIntNumberTerm(count)) @@ -382,15 +439,22 @@ func builtinSplit(_ BuiltinContext, operands []*ast.Term, iter func(*ast.Term) e if err != nil { return err } + d, err := builtins.StringOperand(operands[1].Value, 2) if err != 
nil { return err } + + if !strings.Contains(string(s), string(d)) { + return iter(ast.ArrayTerm(operands[0])) + } + elems := strings.Split(string(s), string(d)) arr := util.NewPtrSlice[ast.Term](len(elems)) for i := range elems { arr[i].Value = ast.String(elems[i]) } + return iter(ast.ArrayTerm(arr...)) } @@ -410,7 +474,12 @@ func builtinReplace(_ BuiltinContext, operands []*ast.Term, iter func(*ast.Term) return err } - return iter(ast.StringTerm(strings.Replace(string(s), string(old), string(n), -1))) + replaced := strings.Replace(string(s), string(old), string(n), -1) + if replaced == string(s) { + return iter(operands[0]) + } + + return iter(ast.StringTerm(replaced)) } func builtinReplaceN(_ BuiltinContext, operands []*ast.Term, iter func(*ast.Term) error) error { @@ -454,6 +523,11 @@ func builtinTrim(_ BuiltinContext, operands []*ast.Term, iter func(*ast.Term) er return err } + trimmed := strings.Trim(string(s), string(c)) + if trimmed == string(s) { + return iter(operands[0]) + } + return iter(ast.StringTerm(strings.Trim(string(s), string(c)))) } @@ -468,7 +542,12 @@ func builtinTrimLeft(_ BuiltinContext, operands []*ast.Term, iter func(*ast.Term return err } - return iter(ast.StringTerm(strings.TrimLeft(string(s), string(c)))) + trimmed := strings.TrimLeft(string(s), string(c)) + if trimmed == string(s) { + return iter(operands[0]) + } + + return iter(ast.StringTerm(trimmed)) } func builtinTrimPrefix(_ BuiltinContext, operands []*ast.Term, iter func(*ast.Term) error) error { @@ -482,7 +561,12 @@ func builtinTrimPrefix(_ BuiltinContext, operands []*ast.Term, iter func(*ast.Te return err } - return iter(ast.StringTerm(strings.TrimPrefix(string(s), string(pre)))) + trimmed := strings.TrimPrefix(string(s), string(pre)) + if trimmed == string(s) { + return iter(operands[0]) + } + + return iter(ast.StringTerm(trimmed)) } func builtinTrimRight(_ BuiltinContext, operands []*ast.Term, iter func(*ast.Term) error) error { @@ -496,7 +580,12 @@ func builtinTrimRight(_ 
BuiltinContext, operands []*ast.Term, iter func(*ast.Ter return err } - return iter(ast.StringTerm(strings.TrimRight(string(s), string(c)))) + trimmed := strings.TrimRight(string(s), string(c)) + if trimmed == string(s) { + return iter(operands[0]) + } + + return iter(ast.StringTerm(trimmed)) } func builtinTrimSuffix(_ BuiltinContext, operands []*ast.Term, iter func(*ast.Term) error) error { @@ -510,7 +599,12 @@ func builtinTrimSuffix(_ BuiltinContext, operands []*ast.Term, iter func(*ast.Te return err } - return iter(ast.StringTerm(strings.TrimSuffix(string(s), string(suf)))) + trimmed := strings.TrimSuffix(string(s), string(suf)) + if trimmed == string(s) { + return iter(operands[0]) + } + + return iter(ast.StringTerm(trimmed)) } func builtinTrimSpace(_ BuiltinContext, operands []*ast.Term, iter func(*ast.Term) error) error { @@ -519,7 +613,12 @@ func builtinTrimSpace(_ BuiltinContext, operands []*ast.Term, iter func(*ast.Ter return err } - return iter(ast.StringTerm(strings.TrimSpace(string(s)))) + trimmed := strings.TrimSpace(string(s)) + if trimmed == string(s) { + return iter(operands[0]) + } + + return iter(ast.StringTerm(trimmed)) } func builtinSprintf(_ BuiltinContext, operands []*ast.Term, iter func(*ast.Term) error) error { @@ -577,15 +676,23 @@ func builtinReverse(_ BuiltinContext, operands []*ast.Term, iter func(*ast.Term) } func reverseString(str string) string { - sRunes := []rune(str) - length := len(sRunes) - reversedRunes := make([]rune, length) + var buf []byte + var arr [255]byte + size := len(str) + + if size < 255 { + buf = arr[:size:size] + } else { + buf = make([]byte, size) + } - for index, r := range sRunes { - reversedRunes[length-index-1] = r + for start := 0; start < size; { + r, n := utf8.DecodeRuneInString(str[start:]) + start += n + utf8.EncodeRune(buf[size-start:], r) } - return string(reversedRunes) + return string(buf) } func init() { diff --git a/v1/topdown/subset.go b/v1/topdown/subset.go index 08bdc8db45..29354d9730 100644 --- 
a/v1/topdown/subset.go +++ b/v1/topdown/subset.go @@ -88,9 +88,8 @@ func arraySet(t1, t2 *ast.Term) (bool, *ast.Array, ast.Set) { // associated with a key. func objectSubset(super ast.Object, sub ast.Object) bool { var superTerm *ast.Term - isSubset := true - sub.Until(func(key, subTerm *ast.Term) bool { + notSubset := sub.Until(func(key, subTerm *ast.Term) bool { // This really wants to be a for loop, hence the somewhat // weird internal structure. However, using Until() in this // was is a performance optimization, as it avoids performing @@ -98,10 +97,9 @@ func objectSubset(super ast.Object, sub ast.Object) bool { superTerm = super.Get(key) - // subTerm is can't be nil because we got it from Until(), so + // subTerm can't be nil because we got it from Until(), so // we only need to verify that super is non-nil. if superTerm == nil { - isSubset = false return true // break, not a subset } @@ -114,58 +112,39 @@ func objectSubset(super ast.Object, sub ast.Object) bool { // them normally. If only one term is an object, then we // do a normal comparison which will come up false. if ok, superObj, subObj := bothObjects(superTerm, subTerm); ok { - if !objectSubset(superObj, subObj) { - isSubset = false - return true // break, not a subset - } - - return false // continue + return !objectSubset(superObj, subObj) } if ok, superSet, subSet := bothSets(superTerm, subTerm); ok { - if !setSubset(superSet, subSet) { - isSubset = false - return true // break, not a subset - } - - return false // continue + return !setSubset(superSet, subSet) } if ok, superArray, subArray := bothArrays(superTerm, subTerm); ok { - if !arraySubset(superArray, subArray) { - isSubset = false - return true // break, not a subset - } - - return false // continue + return !arraySubset(superArray, subArray) } // We have already checked for exact equality, as well as for // all of the types of nested subsets we care about, so if we // get here it means this isn't a subset. 
- isSubset = false return true // break, not a subset }) - return isSubset + return !notSubset } // setSubset implements the subset operation on sets. // // Unlike in the object case, this is not recursive, we just compare values -// using ast.Set.Contains() because we have no well defined way to "match up" +// using ast.Set.Contains() because we have no well-defined way to "match up" // objects that are in different sets. func setSubset(super ast.Set, sub ast.Set) bool { - isSubset := true - sub.Until(func(t *ast.Term) bool { - if !super.Contains(t) { - isSubset = false - return true + for _, elem := range sub.Slice() { + if !super.Contains(elem) { + return false } - return false - }) + } - return isSubset + return true } // arraySubset implements the subset operation on arrays. @@ -197,12 +176,12 @@ func arraySubset(super, sub *ast.Array) bool { return false } - subElem := sub.Elem(subCursor) superElem := super.Elem(superCursor + subCursor) if superElem == nil { return false } + subElem := sub.Elem(subCursor) if superElem.Value.Compare(subElem.Value) == 0 { subCursor++ } else { diff --git a/v1/util/performance.go b/v1/util/performance.go index 03dc7d0601..b7222b23cb 100644 --- a/v1/util/performance.go +++ b/v1/util/performance.go @@ -1,6 +1,9 @@ package util -import "slices" +import ( + "slices" + "unsafe" +) // NewPtrSlice returns a slice of pointers to T with length n, // with only 2 allocations performed no matter the size of n. 
@@ -22,3 +25,15 @@ func GrowPtrSlice[T any](s []*T, n int) []*T { } return s } + +// Allocation free conversion from []byte to string (unsafe) +// Note that the byte slice must not be modified after conversion +func ByteSliceToString(bs []byte) string { + return unsafe.String(unsafe.SliceData(bs), len(bs)) +} + +// Allocation free conversion from ~string to []byte (unsafe) +// Note that the byte slice must not be modified after conversion +func StringToByteSlice[T ~string](s T) []byte { + return unsafe.Slice(unsafe.StringData(string(s)), len(s)) +}