Skip to content

Commit e4eedbc

Browse files
authored
Merge pull request #151 from hpidcock/multichecker
#151 MultiChecker allows you to perform a DeepEquals but have bespoke checkers based on path matching. For example, this allows the value at "b" to be ignored. ``` a1 := map[string]string{"a": "a", "b": "b", "c": "c"} a2 := map[string]string{"a": "a", "b": "bbbb", "c": "c"} checker := jc.NewMultiChecker().Add(`["b"]`, jc.Ignore) c.Check(a1, checker, a2) ``` This allows the second element to have the SameContents check applied, ignoring order of elements: ``` a1 := [][]string{{"a", "b", "c"}, {"c", "d", "e"}} a2 := [][]string{{"a", "b", "c"}, {"e", "c", "d"}} checker := jc.NewMultiChecker().Add("[1]", jc.SameContents, jc.ExpectedValue) c.Check(a1, checker, a2) ```
2 parents 6c8c298 + 543482a commit e4eedbc

File tree

4 files changed

+256
-8
lines changed

4 files changed

+256
-8
lines changed

checkers/bool.go

+13
Original file line numberDiff line numberDiff line change
@@ -115,3 +115,16 @@ func (checker *deepEqualsChecker) Check(params []interface{}, names []string) (r
115115
}
116116
return true, ""
117117
}
118+
119+
type ignoreChecker struct {
120+
*gc.CheckerInfo
121+
}
122+
123+
// Ignore always succeeds.
124+
var Ignore gc.Checker = &ignoreChecker{
125+
&gc.CheckerInfo{Name: "Ignore", Params: []string{"obtained"}},
126+
}
127+
128+
func (checker *ignoreChecker) Check(params []interface{}, names []string) (result bool, error string) {
129+
return true, ""
130+
}

checkers/deepequal.go

+59-8
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@ func printable(v reflect.Value) interface{} {
5353
// Tests for deep equality using reflected types. The map argument tracks
5454
// comparisons that have already been seen, which allows short circuiting on
5555
// recursive types.
56-
func deepValueEqual(path string, v1, v2 reflect.Value, visited map[visit]bool, depth int) (ok bool, err error) {
56+
func deepValueEqual(path string, v1, v2 reflect.Value, visited map[visit]bool, depth int, customCheckFunc CustomCheckFunc) (ok bool, err error) {
5757
errorf := func(f string, a ...interface{}) error {
5858
return &mismatchError{
5959
v1: v1,
@@ -105,6 +105,13 @@ func deepValueEqual(path string, v1, v2 reflect.Value, visited map[visit]bool, d
105105
visited[v] = true
106106
}
107107

108+
if customCheckFunc != nil && v1.CanInterface() && v2.CanInterface() {
109+
useDefault, equal, err := customCheckFunc(path, v1.Interface(), v2.Interface())
110+
if !useDefault {
111+
return equal, err
112+
}
113+
}
114+
108115
switch v1.Kind() {
109116
case reflect.Array:
110117
if v1.Len() != v2.Len() {
@@ -114,7 +121,7 @@ func deepValueEqual(path string, v1, v2 reflect.Value, visited map[visit]bool, d
114121
for i := 0; i < v1.Len(); i++ {
115122
if ok, err := deepValueEqual(
116123
fmt.Sprintf("%s[%d]", path, i),
117-
v1.Index(i), v2.Index(i), visited, depth+1); !ok {
124+
v1.Index(i), v2.Index(i), visited, depth+1, customCheckFunc); !ok {
118125
return false, err
119126
}
120127
}
@@ -130,7 +137,7 @@ func deepValueEqual(path string, v1, v2 reflect.Value, visited map[visit]bool, d
130137
for i := 0; i < v1.Len(); i++ {
131138
if ok, err := deepValueEqual(
132139
fmt.Sprintf("%s[%d]", path, i),
133-
v1.Index(i), v2.Index(i), visited, depth+1); !ok {
140+
v1.Index(i), v2.Index(i), visited, depth+1, customCheckFunc); !ok {
134141
return false, err
135142
}
136143
}
@@ -142,9 +149,9 @@ func deepValueEqual(path string, v1, v2 reflect.Value, visited map[visit]bool, d
142149
}
143150
return true, nil
144151
}
145-
return deepValueEqual(path, v1.Elem(), v2.Elem(), visited, depth+1)
152+
return deepValueEqual(path, v1.Elem(), v2.Elem(), visited, depth+1, customCheckFunc)
146153
case reflect.Ptr:
147-
return deepValueEqual("(*"+path+")", v1.Elem(), v2.Elem(), visited, depth+1)
154+
return deepValueEqual("(*"+path+")", v1.Elem(), v2.Elem(), visited, depth+1, customCheckFunc)
148155
case reflect.Struct:
149156
if v1.Type() == timeType {
150157
// Special case for time - we ignore the time zone.
@@ -157,7 +164,7 @@ func deepValueEqual(path string, v1, v2 reflect.Value, visited map[visit]bool, d
157164
}
158165
for i, n := 0, v1.NumField(); i < n; i++ {
159166
path := path + "." + v1.Type().Field(i).Name
160-
if ok, err := deepValueEqual(path, v1.Field(i), v2.Field(i), visited, depth+1); !ok {
167+
if ok, err := deepValueEqual(path, v1.Field(i), v2.Field(i), visited, depth+1, customCheckFunc); !ok {
161168
return false, err
162169
}
163170
}
@@ -179,7 +186,7 @@ func deepValueEqual(path string, v1, v2 reflect.Value, visited map[visit]bool, d
179186
} else {
180187
p = path + "[someKey]"
181188
}
182-
if ok, err := deepValueEqual(p, v1.MapIndex(k), v2.MapIndex(k), visited, depth+1); !ok {
189+
if ok, err := deepValueEqual(p, v1.MapIndex(k), v2.MapIndex(k), visited, depth+1, customCheckFunc); !ok {
183190
return false, err
184191
}
185192
}
@@ -263,9 +270,53 @@ func DeepEqual(a1, a2 interface{}) (bool, error) {
263270
if v1.Type() != v2.Type() {
264271
return false, errorf("type mismatch %s vs %s", v1.Type(), v2.Type())
265272
}
266-
return deepValueEqual("", v1, v2, make(map[visit]bool), 0)
273+
return deepValueEqual("", v1, v2, make(map[visit]bool), 0, nil)
274+
}
275+
276+
// DeepEqualWithCustomCheck tests for deep equality. It uses normal == equality where
277+
// possible but will scan elements of arrays, slices, maps, and fields
278+
// of structs. In maps, keys are compared with == but elements use deep
279+
// equality. DeepEqual correctly handles recursive types. Functions are
280+
// equal only if they are both nil.
281+
//
282+
// DeepEqual differs from reflect.DeepEqual in two ways:
283+
// - an empty slice is considered equal to a nil slice.
284+
// - two time.Time values that represent the same instant
285+
// but with different time zones are considered equal.
286+
//
287+
// If the two values compare unequal, the resulting error holds the
288+
// first difference encountered.
289+
//
290+
// If both values are interface-able and customCheckFunc is non nil,
291+
// customCheckFunc will be invoked. If it returns useDefault as true, the
292+
// DeepEqual continues, otherwise the result of the customCheckFunc is used.
293+
func DeepEqualWithCustomCheck(a1 interface{}, a2 interface{}, customCheckFunc CustomCheckFunc) (bool, error) {
294+
errorf := func(f string, a ...interface{}) error {
295+
return &mismatchError{
296+
v1: reflect.ValueOf(a1),
297+
v2: reflect.ValueOf(a2),
298+
path: "",
299+
how: fmt.Sprintf(f, a...),
300+
}
301+
}
302+
if a1 == nil || a2 == nil {
303+
if a1 == a2 {
304+
return true, nil
305+
}
306+
return false, errorf("nil vs non-nil mismatch")
307+
}
308+
v1 := reflect.ValueOf(a1)
309+
v2 := reflect.ValueOf(a2)
310+
if v1.Type() != v2.Type() {
311+
return false, errorf("type mismatch %s vs %s", v1.Type(), v2.Type())
312+
}
313+
return deepValueEqual("", v1, v2, make(map[visit]bool), 0, customCheckFunc)
267314
}
268315

316+
// CustomCheckFunc should return true for useDefault if DeepEqualWithCustomCheck should behave like DeepEqual.
317+
// Otherwise the result of the CustomCheckFunc is used.
318+
type CustomCheckFunc func(path string, a1 interface{}, a2 interface{}) (useDefault bool, equal bool, err error)
319+
269320
// interfaceOf returns v.Interface() even if v.CanInterface() == false.
270321
// This enables us to call fmt.Printf on a value even if it's derived
271322
// from inside an unexported field.

checkers/multichecker.go

+113
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,113 @@
1+
// Copyright 2020 Canonical Ltd.
2+
// Licensed under the LGPLv3, see LICENCE file for details.
3+
4+
package checkers
5+
6+
import (
7+
"fmt"
8+
"regexp"
9+
10+
gc "gopkg.in/check.v1"
11+
)
12+
13+
// MultiChecker is a deep checker that by default matches for equality.
14+
// But checks can be overriden based on path (either explicit match or regexp)
15+
type MultiChecker struct {
16+
*gc.CheckerInfo
17+
checks map[string]multiCheck
18+
regexChecks []regexCheck
19+
}
20+
21+
type multiCheck struct {
22+
checker gc.Checker
23+
args []interface{}
24+
}
25+
26+
type regexCheck struct {
27+
multiCheck
28+
regex *regexp.Regexp
29+
}
30+
31+
// NewMultiChecker creates a MultiChecker which is a deep checker that by default matches for equality.
32+
// But checks can be overriden based on path (either explicit match or regexp)
33+
func NewMultiChecker() *MultiChecker {
34+
return &MultiChecker{
35+
CheckerInfo: &gc.CheckerInfo{Name: "MultiChecker", Params: []string{"obtained", "expected"}},
36+
checks: make(map[string]multiCheck),
37+
}
38+
}
39+
40+
// Add an explict checker by path.
41+
func (checker *MultiChecker) Add(path string, c gc.Checker, args ...interface{}) *MultiChecker {
42+
checker.checks[path] = multiCheck{
43+
checker: c,
44+
args: args,
45+
}
46+
return checker
47+
}
48+
49+
// AddRegex exception which matches path with regex.
50+
func (checker *MultiChecker) AddRegex(pathRegex string, c gc.Checker, args ...interface{}) *MultiChecker {
51+
checker.regexChecks = append(checker.regexChecks, regexCheck{
52+
multiCheck: multiCheck{
53+
checker: c,
54+
args: args,
55+
},
56+
regex: regexp.MustCompile("^" + pathRegex + "$"),
57+
})
58+
return checker
59+
}
60+
61+
// Check for go check Checker interface.
62+
func (checker *MultiChecker) Check(params []interface{}, names []string) (result bool, errStr string) {
63+
customCheckFunc := func(path string, a1 interface{}, a2 interface{}) (useDefault bool, equal bool, err error) {
64+
var mc *multiCheck
65+
if c, ok := checker.checks[path]; ok {
66+
mc = &c
67+
} else {
68+
for _, v := range checker.regexChecks {
69+
if v.regex.MatchString(path) {
70+
mc = &v.multiCheck
71+
break
72+
}
73+
}
74+
}
75+
if mc == nil {
76+
return true, false, nil
77+
}
78+
79+
params := append([]interface{}{a1}, mc.args...)
80+
info := mc.checker.Info()
81+
if len(params) < len(info.Params) {
82+
return false, false, fmt.Errorf("Wrong number of parameters for %s: want %d, got %d", info.Name, len(info.Params), len(params)+1)
83+
}
84+
// Copy since it may be mutated by Check.
85+
names := append([]string{}, info.Params...)
86+
87+
// Trim to the expected params len.
88+
params = params[:len(info.Params)]
89+
90+
// Perform substitution
91+
for i, v := range params {
92+
if v == ExpectedValue {
93+
params[i] = a2
94+
}
95+
}
96+
97+
result, errStr := mc.checker.Check(params, names)
98+
if result {
99+
return false, true, nil
100+
}
101+
if path == "" {
102+
path = "top level"
103+
}
104+
return false, false, fmt.Errorf("mismatch at %s: %s", path, errStr)
105+
}
106+
if ok, err := DeepEqualWithCustomCheck(params[0], params[1], customCheckFunc); !ok {
107+
return false, err.Error()
108+
}
109+
return true, ""
110+
}
111+
112+
// ExpectedValue if passed to MultiChecker.Add or MultiChecker.AddRegex, will be substituded with the expected value.
113+
var ExpectedValue = &struct{}{}

checkers/multichecker_test.go

+71
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,71 @@
1+
package checkers_test
2+
3+
import (
4+
jc "github.com/juju/testing/checkers"
5+
gc "gopkg.in/check.v1"
6+
)
7+
8+
type MultiCheckerSuite struct{}
9+
10+
var _ = gc.Suite(&MultiCheckerSuite{})
11+
12+
func (s *MultiCheckerSuite) TestDeepEquals(c *gc.C) {
13+
for i, test := range deepEqualTests {
14+
c.Logf("test %d. %v == %v is %v", i, test.a, test.b, test.eq)
15+
result, msg := jc.NewMultiChecker().Check([]interface{}{test.a, test.b}, nil)
16+
c.Check(result, gc.Equals, test.eq)
17+
if test.eq {
18+
c.Check(msg, gc.Equals, "")
19+
} else {
20+
c.Check(msg, gc.Not(gc.Equals), "")
21+
}
22+
}
23+
}
24+
25+
func (s *MultiCheckerSuite) TestArray(c *gc.C) {
26+
a1 := []string{"a", "b", "c"}
27+
a2 := []string{"a", "bbb", "c"}
28+
29+
checker := jc.NewMultiChecker().Add("[1]", jc.Ignore)
30+
c.Check(a1, checker, a2)
31+
}
32+
33+
func (s *MultiCheckerSuite) TestMap(c *gc.C) {
34+
a1 := map[string]string{"a": "a", "b": "b", "c": "c"}
35+
a2 := map[string]string{"a": "a", "b": "bbbb", "c": "c"}
36+
37+
checker := jc.NewMultiChecker().Add(`["b"]`, jc.Ignore)
38+
c.Check(a1, checker, a2)
39+
}
40+
41+
func (s *MultiCheckerSuite) TestRegexArray(c *gc.C) {
42+
a1 := []string{"a", "b", "c"}
43+
a2 := []string{"a", "bbb", "ccc"}
44+
45+
checker := jc.NewMultiChecker().AddRegex("\\[[1-2]\\]", jc.Ignore)
46+
c.Check(a1, checker, a2)
47+
}
48+
49+
func (s *MultiCheckerSuite) TestRegexMap(c *gc.C) {
50+
a1 := map[string]string{"a": "a", "b": "b", "c": "c"}
51+
a2 := map[string]string{"a": "aaaa", "b": "bbbb", "c": "cccc"}
52+
53+
checker := jc.NewMultiChecker().AddRegex(`\[".*"\]`, jc.Ignore)
54+
c.Check(a1, checker, a2)
55+
}
56+
57+
func (s *MultiCheckerSuite) TestArrayArraysUnordered(c *gc.C) {
58+
a1 := [][]string{{"a", "b", "c"}, {"c", "d", "e"}}
59+
a2 := [][]string{{"a", "b", "c"}, {}}
60+
61+
checker := jc.NewMultiChecker().Add("[1]", jc.SameContents, []string{"e", "c", "d"})
62+
c.Check(a1, checker, a2)
63+
}
64+
65+
func (s *MultiCheckerSuite) TestArrayArraysUnorderedWithExpected(c *gc.C) {
66+
a1 := [][]string{{"a", "b", "c"}, {"c", "d", "e"}}
67+
a2 := [][]string{{"a", "b", "c"}, {"e", "c", "d"}}
68+
69+
checker := jc.NewMultiChecker().Add("[1]", jc.SameContents, jc.ExpectedValue)
70+
c.Check(a1, checker, a2)
71+
}

0 commit comments

Comments
 (0)