diff --git a/checkers/bool.go b/checkers/bool.go index 02e3eec..7e42a76 100644 --- a/checkers/bool.go +++ b/checkers/bool.go @@ -115,3 +115,16 @@ func (checker *deepEqualsChecker) Check(params []interface{}, names []string) (r } return true, "" } + +type ignoreChecker struct { + *gc.CheckerInfo +} + +// Ignore always succeeds. +var Ignore gc.Checker = &ignoreChecker{ + &gc.CheckerInfo{Name: "Ignore", Params: []string{"obtained"}}, +} + +func (checker *ignoreChecker) Check(params []interface{}, names []string) (result bool, error string) { + return true, "" +} diff --git a/checkers/deepequal.go b/checkers/deepequal.go index 43567fa..e01dc10 100644 --- a/checkers/deepequal.go +++ b/checkers/deepequal.go @@ -53,7 +53,7 @@ func printable(v reflect.Value) interface{} { // Tests for deep equality using reflected types. The map argument tracks // comparisons that have already been seen, which allows short circuiting on // recursive types. -func deepValueEqual(path string, v1, v2 reflect.Value, visited map[visit]bool, depth int) (ok bool, err error) { +func deepValueEqual(path string, v1, v2 reflect.Value, visited map[visit]bool, depth int, customCheckFunc CustomCheckFunc) (ok bool, err error) { errorf := func(f string, a ...interface{}) error { return &mismatchError{ v1: v1, @@ -105,6 +105,13 @@ func deepValueEqual(path string, v1, v2 reflect.Value, visited map[visit]bool, d visited[v] = true } + if customCheckFunc != nil && v1.CanInterface() && v2.CanInterface() { + useDefault, equal, err := customCheckFunc(path, v1.Interface(), v2.Interface()) + if !useDefault { + return equal, err + } + } + switch v1.Kind() { case reflect.Array: if v1.Len() != v2.Len() { @@ -114,7 +121,7 @@ func deepValueEqual(path string, v1, v2 reflect.Value, visited map[visit]bool, d for i := 0; i < v1.Len(); i++ { if ok, err := deepValueEqual( fmt.Sprintf("%s[%d]", path, i), - v1.Index(i), v2.Index(i), visited, depth+1); !ok { + v1.Index(i), v2.Index(i), visited, depth+1, customCheckFunc); !ok { return false, err } } @@ -130,7 +137,7 @@ func deepValueEqual(path string, v1, v2 reflect.Value, visited map[visit]bool, d for i := 0; i < v1.Len(); i++ { if ok, err := deepValueEqual( fmt.Sprintf("%s[%d]", path, i), - v1.Index(i), v2.Index(i), visited, depth+1); !ok { + v1.Index(i), v2.Index(i), visited, depth+1, customCheckFunc); !ok { return false, err } } @@ -142,9 +149,9 @@ func deepValueEqual(path string, v1, v2 reflect.Value, visited map[visit]bool, d } return true, nil } - return deepValueEqual(path, v1.Elem(), v2.Elem(), visited, depth+1) + return deepValueEqual(path, v1.Elem(), v2.Elem(), visited, depth+1, customCheckFunc) case reflect.Ptr: - return deepValueEqual("(*"+path+")", v1.Elem(), v2.Elem(), visited, depth+1) + return deepValueEqual("(*"+path+")", v1.Elem(), v2.Elem(), visited, depth+1, customCheckFunc) case reflect.Struct: if v1.Type() == timeType { // Special case for time - we ignore the time zone. @@ -157,7 +164,7 @@ func deepValueEqual(path string, v1, v2 reflect.Value, visited map[visit]bool, d } for i, n := 0, v1.NumField(); i < n; i++ { path := path + "." + v1.Type().Field(i).Name - if ok, err := deepValueEqual(path, v1.Field(i), v2.Field(i), visited, depth+1); !ok { + if ok, err := deepValueEqual(path, v1.Field(i), v2.Field(i), visited, depth+1, customCheckFunc); !ok { return false, err } } @@ -179,7 +186,7 @@ func deepValueEqual(path string, v1, v2 reflect.Value, visited map[visit]bool, d } else { p = path + "[someKey]" } - if ok, err := deepValueEqual(p, v1.MapIndex(k), v2.MapIndex(k), visited, depth+1); !ok { + if ok, err := deepValueEqual(p, v1.MapIndex(k), v2.MapIndex(k), visited, depth+1, customCheckFunc); !ok { return false, err } } @@ -263,9 +270,53 @@ func DeepEqual(a1, a2 interface{}) (bool, error) { if v1.Type() != v2.Type() { return false, errorf("type mismatch %s vs %s", v1.Type(), v2.Type()) } - return deepValueEqual("", v1, v2, make(map[visit]bool), 0) + return deepValueEqual("", v1, v2, make(map[visit]bool), 0, nil) +} + +// DeepEqualWithCustomCheck tests for deep equality. It uses normal == equality where +// possible but will scan elements of arrays, slices, maps, and fields +// of structs. In maps, keys are compared with == but elements use deep +// equality. DeepEqual correctly handles recursive types. Functions are +// equal only if they are both nil. +// +// DeepEqual differs from reflect.DeepEqual in two ways: +// - an empty slice is considered equal to a nil slice. +// - two time.Time values that represent the same instant +// but with different time zones are considered equal. +// +// If the two values compare unequal, the resulting error holds the +// first difference encountered. +// +// If both values are interface-able and customCheckFunc is non nil, +// customCheckFunc will be invoked. If it returns useDefault as true, the +// DeepEqual continues, otherwise the result of the customCheckFunc is used. +func DeepEqualWithCustomCheck(a1 interface{}, a2 interface{}, customCheckFunc CustomCheckFunc) (bool, error) { + errorf := func(f string, a ...interface{}) error { + return &mismatchError{ + v1: reflect.ValueOf(a1), + v2: reflect.ValueOf(a2), + path: "", + how: fmt.Sprintf(f, a...), + } + } + if a1 == nil || a2 == nil { + if a1 == a2 { + return true, nil + } + return false, errorf("nil vs non-nil mismatch") + } + v1 := reflect.ValueOf(a1) + v2 := reflect.ValueOf(a2) + if v1.Type() != v2.Type() { + return false, errorf("type mismatch %s vs %s", v1.Type(), v2.Type()) + } + return deepValueEqual("", v1, v2, make(map[visit]bool), 0, customCheckFunc) } +// CustomCheckFunc should return true for useDefault if DeepEqualWithCustomCheck should behave like DeepEqual. +// Otherwise the result of the CustomCheckFunc is used. +type CustomCheckFunc func(path string, a1 interface{}, a2 interface{}) (useDefault bool, equal bool, err error) + // interfaceOf returns v.Interface() even if v.CanInterface() == false. // This enables us to call fmt.Printf on a value even if it's derived // from inside an unexported field. diff --git a/checkers/multichecker.go b/checkers/multichecker.go new file mode 100644 index 0000000..ab3947a --- /dev/null +++ b/checkers/multichecker.go @@ -0,0 +1,113 @@ +// Copyright 2020 Canonical Ltd. +// Licensed under the LGPLv3, see LICENCE file for details. + +package checkers + +import ( + "fmt" + "regexp" + + gc "gopkg.in/check.v1" +) + +// MultiChecker is a deep checker that by default matches for equality. +// But checks can be overriden based on path (either explicit match or regexp) +type MultiChecker struct { + *gc.CheckerInfo + checks map[string]multiCheck + regexChecks []regexCheck +} + +type multiCheck struct { + checker gc.Checker + args []interface{} +} + +type regexCheck struct { + multiCheck + regex *regexp.Regexp +} + +// NewMultiChecker creates a MultiChecker which is a deep checker that by default matches for equality. +// But checks can be overriden based on path (either explicit match or regexp) +func NewMultiChecker() *MultiChecker { + return &MultiChecker{ + CheckerInfo: &gc.CheckerInfo{Name: "MultiChecker", Params: []string{"obtained", "expected"}}, + checks: make(map[string]multiCheck), + } +} + +// Add an explict checker by path. +func (checker *MultiChecker) Add(path string, c gc.Checker, args ...interface{}) *MultiChecker { + checker.checks[path] = multiCheck{ + checker: c, + args: args, + } + return checker +} + +// AddRegex exception which matches path with regex. +func (checker *MultiChecker) AddRegex(pathRegex string, c gc.Checker, args ...interface{}) *MultiChecker { + checker.regexChecks = append(checker.regexChecks, regexCheck{ + multiCheck: multiCheck{ + checker: c, + args: args, + }, + regex: regexp.MustCompile("^" + pathRegex + "$"), + }) + return checker +} + +// Check for go check Checker interface. +func (checker *MultiChecker) Check(params []interface{}, names []string) (result bool, errStr string) { + customCheckFunc := func(path string, a1 interface{}, a2 interface{}) (useDefault bool, equal bool, err error) { + var mc *multiCheck + if c, ok := checker.checks[path]; ok { + mc = &c + } else { + for _, v := range checker.regexChecks { + if v.regex.MatchString(path) { + mc = &v.multiCheck + break + } + } + } + if mc == nil { + return true, false, nil + } + + params := append([]interface{}{a1}, mc.args...) + info := mc.checker.Info() + if len(params) < len(info.Params) { + return false, false, fmt.Errorf("Wrong number of parameters for %s: want %d, got %d", info.Name, len(info.Params), len(params)+1) + } + // Copy since it may be mutated by Check. + names := append([]string{}, info.Params...) + + // Trim to the expected params len. + params = params[:len(info.Params)] + + // Perform substitution + for i, v := range params { + if v == ExpectedValue { + params[i] = a2 + } + } + + result, errStr := mc.checker.Check(params, names) + if result { + return false, true, nil + } + if path == "" { + path = "top level" + } + return false, false, fmt.Errorf("mismatch at %s: %s", path, errStr) + } + if ok, err := DeepEqualWithCustomCheck(params[0], params[1], customCheckFunc); !ok { + return false, err.Error() + } + return true, "" +} + +// ExpectedValue if passed to MultiChecker.Add or MultiChecker.AddRegex, will be substituded with the expected value. +var ExpectedValue = &struct{}{} diff --git a/checkers/multichecker_test.go b/checkers/multichecker_test.go new file mode 100644 index 0000000..55fb73d --- /dev/null +++ b/checkers/multichecker_test.go @@ -0,0 +1,71 @@ +package checkers_test + +import ( + jc "github.com/juju/testing/checkers" + gc "gopkg.in/check.v1" +) + +type MultiCheckerSuite struct{} + +var _ = gc.Suite(&MultiCheckerSuite{}) + +func (s *MultiCheckerSuite) TestDeepEquals(c *gc.C) { + for i, test := range deepEqualTests { + c.Logf("test %d. %v == %v is %v", i, test.a, test.b, test.eq) + result, msg := jc.NewMultiChecker().Check([]interface{}{test.a, test.b}, nil) + c.Check(result, gc.Equals, test.eq) + if test.eq { + c.Check(msg, gc.Equals, "") + } else { + c.Check(msg, gc.Not(gc.Equals), "") + } + } +} + +func (s *MultiCheckerSuite) TestArray(c *gc.C) { + a1 := []string{"a", "b", "c"} + a2 := []string{"a", "bbb", "c"} + + checker := jc.NewMultiChecker().Add("[1]", jc.Ignore) + c.Check(a1, checker, a2) +} + +func (s *MultiCheckerSuite) TestMap(c *gc.C) { + a1 := map[string]string{"a": "a", "b": "b", "c": "c"} + a2 := map[string]string{"a": "a", "b": "bbbb", "c": "c"} + + checker := jc.NewMultiChecker().Add(`["b"]`, jc.Ignore) + c.Check(a1, checker, a2) +} + +func (s *MultiCheckerSuite) TestRegexArray(c *gc.C) { + a1 := []string{"a", "b", "c"} + a2 := []string{"a", "bbb", "ccc"} + + checker := jc.NewMultiChecker().AddRegex("\\[[1-2]\\]", jc.Ignore) + c.Check(a1, checker, a2) +} + +func (s *MultiCheckerSuite) TestRegexMap(c *gc.C) { + a1 := map[string]string{"a": "a", "b": "b", "c": "c"} + a2 := map[string]string{"a": "aaaa", "b": "bbbb", "c": "cccc"} + + checker := jc.NewMultiChecker().AddRegex(`\[".*"\]`, jc.Ignore) + c.Check(a1, checker, a2) +} + +func (s *MultiCheckerSuite) TestArrayArraysUnordered(c *gc.C) { + a1 := [][]string{{"a", "b", "c"}, {"c", "d", "e"}} + a2 := [][]string{{"a", "b", "c"}, {}} + + checker := jc.NewMultiChecker().Add("[1]", jc.SameContents, []string{"e", "c", "d"}) + c.Check(a1, checker, a2) +} + +func (s *MultiCheckerSuite) TestArrayArraysUnorderedWithExpected(c *gc.C) { + a1 := [][]string{{"a", "b", "c"}, {"c", "d", "e"}} + a2 := [][]string{{"a", "b", "c"}, {"e", "c", "d"}} + + checker := jc.NewMultiChecker().Add("[1]", jc.SameContents, jc.ExpectedValue) + c.Check(a1, checker, a2) +}