Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add multi-checker #151

Merged
merged 1 commit into from
Jun 8, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 13 additions & 0 deletions checkers/bool.go
Original file line number Diff line number Diff line change
Expand Up @@ -115,3 +115,16 @@ func (checker *deepEqualsChecker) Check(params []interface{}, names []string) (r
}
return true, ""
}

type ignoreChecker struct {
*gc.CheckerInfo
}

// Ignore always succeeds.
var Ignore gc.Checker = &ignoreChecker{
&gc.CheckerInfo{Name: "Ignore", Params: []string{"obtained"}},
}

func (checker *ignoreChecker) Check(params []interface{}, names []string) (result bool, error string) {
return true, ""
}
67 changes: 59 additions & 8 deletions checkers/deepequal.go
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ func printable(v reflect.Value) interface{} {
// Tests for deep equality using reflected types. The map argument tracks
// comparisons that have already been seen, which allows short circuiting on
// recursive types.
func deepValueEqual(path string, v1, v2 reflect.Value, visited map[visit]bool, depth int) (ok bool, err error) {
func deepValueEqual(path string, v1, v2 reflect.Value, visited map[visit]bool, depth int, customCheckFunc CustomCheckFunc) (ok bool, err error) {
errorf := func(f string, a ...interface{}) error {
return &mismatchError{
v1: v1,
Expand Down Expand Up @@ -105,6 +105,13 @@ func deepValueEqual(path string, v1, v2 reflect.Value, visited map[visit]bool, d
visited[v] = true
}

if customCheckFunc != nil && v1.CanInterface() && v2.CanInterface() {
useDefault, equal, err := customCheckFunc(path, v1.Interface(), v2.Interface())
if !useDefault {
return equal, err
}
}

switch v1.Kind() {
case reflect.Array:
if v1.Len() != v2.Len() {
Expand All @@ -114,7 +121,7 @@ func deepValueEqual(path string, v1, v2 reflect.Value, visited map[visit]bool, d
for i := 0; i < v1.Len(); i++ {
if ok, err := deepValueEqual(
fmt.Sprintf("%s[%d]", path, i),
v1.Index(i), v2.Index(i), visited, depth+1); !ok {
v1.Index(i), v2.Index(i), visited, depth+1, customCheckFunc); !ok {
return false, err
}
}
Expand All @@ -130,7 +137,7 @@ func deepValueEqual(path string, v1, v2 reflect.Value, visited map[visit]bool, d
for i := 0; i < v1.Len(); i++ {
if ok, err := deepValueEqual(
fmt.Sprintf("%s[%d]", path, i),
v1.Index(i), v2.Index(i), visited, depth+1); !ok {
v1.Index(i), v2.Index(i), visited, depth+1, customCheckFunc); !ok {
return false, err
}
}
Expand All @@ -142,9 +149,9 @@ func deepValueEqual(path string, v1, v2 reflect.Value, visited map[visit]bool, d
}
return true, nil
}
return deepValueEqual(path, v1.Elem(), v2.Elem(), visited, depth+1)
return deepValueEqual(path, v1.Elem(), v2.Elem(), visited, depth+1, customCheckFunc)
case reflect.Ptr:
return deepValueEqual("(*"+path+")", v1.Elem(), v2.Elem(), visited, depth+1)
return deepValueEqual("(*"+path+")", v1.Elem(), v2.Elem(), visited, depth+1, customCheckFunc)
case reflect.Struct:
if v1.Type() == timeType {
// Special case for time - we ignore the time zone.
Expand All @@ -157,7 +164,7 @@ func deepValueEqual(path string, v1, v2 reflect.Value, visited map[visit]bool, d
}
for i, n := 0, v1.NumField(); i < n; i++ {
path := path + "." + v1.Type().Field(i).Name
if ok, err := deepValueEqual(path, v1.Field(i), v2.Field(i), visited, depth+1); !ok {
if ok, err := deepValueEqual(path, v1.Field(i), v2.Field(i), visited, depth+1, customCheckFunc); !ok {
return false, err
}
}
Expand All @@ -179,7 +186,7 @@ func deepValueEqual(path string, v1, v2 reflect.Value, visited map[visit]bool, d
} else {
p = path + "[someKey]"
}
if ok, err := deepValueEqual(p, v1.MapIndex(k), v2.MapIndex(k), visited, depth+1); !ok {
if ok, err := deepValueEqual(p, v1.MapIndex(k), v2.MapIndex(k), visited, depth+1, customCheckFunc); !ok {
return false, err
}
}
Expand Down Expand Up @@ -263,9 +270,53 @@ func DeepEqual(a1, a2 interface{}) (bool, error) {
if v1.Type() != v2.Type() {
return false, errorf("type mismatch %s vs %s", v1.Type(), v2.Type())
}
return deepValueEqual("", v1, v2, make(map[visit]bool), 0)
return deepValueEqual("", v1, v2, make(map[visit]bool), 0, nil)
}

// DeepEqualWithCustomCheck tests for deep equality. It uses normal == equality where
// possible but will scan elements of arrays, slices, maps, and fields
// of structs. In maps, keys are compared with == but elements use deep
// equality. DeepEqual correctly handles recursive types. Functions are
// equal only if they are both nil.
//
// DeepEqual differs from reflect.DeepEqual in two ways:
// - an empty slice is considered equal to a nil slice.
// - two time.Time values that represent the same instant
// but with different time zones are considered equal.
//
// If the two values compare unequal, the resulting error holds the
// first difference encountered.
//
// If both values are interface-able and customCheckFunc is non nil,
// customCheckFunc will be invoked. If it returns useDefault as true, the
// DeepEqual continues, otherwise the result of the customCheckFunc is used.
func DeepEqualWithCustomCheck(a1 interface{}, a2 interface{}, customCheckFunc CustomCheckFunc) (bool, error) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

does this need to be exported?

errorf := func(f string, a ...interface{}) error {
return &mismatchError{
v1: reflect.ValueOf(a1),
v2: reflect.ValueOf(a2),
path: "",
how: fmt.Sprintf(f, a...),
}
}
if a1 == nil || a2 == nil {
if a1 == a2 {
return true, nil
}
return false, errorf("nil vs non-nil mismatch")
}
v1 := reflect.ValueOf(a1)
v2 := reflect.ValueOf(a2)
if v1.Type() != v2.Type() {
return false, errorf("type mismatch %s vs %s", v1.Type(), v2.Type())
}
return deepValueEqual("", v1, v2, make(map[visit]bool), 0, customCheckFunc)
}

// CustomCheckFunc should return true for useDefault if DeepEqualWithCustomCheck should behave like DeepEqual.
// Otherwise the result of the CustomCheckFunc is used.
type CustomCheckFunc func(path string, a1 interface{}, a2 interface{}) (useDefault bool, equal bool, err error)

// interfaceOf returns v.Interface() even if v.CanInterface() == false.
// This enables us to call fmt.Printf on a value even if it's derived
// from inside an unexported field.
Expand Down
113 changes: 113 additions & 0 deletions checkers/multichecker.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,113 @@
// Copyright 2020 Canonical Ltd.
// Licensed under the LGPLv3, see LICENCE file for details.

package checkers

import (
"fmt"
"regexp"

gc "gopkg.in/check.v1"
)

// MultiChecker is a deep checker that by default matches for equality.
// But checks can be overriden based on path (either explicit match or regexp)
type MultiChecker struct {
*gc.CheckerInfo
checks map[string]multiCheck
regexChecks []regexCheck
}

type multiCheck struct {
checker gc.Checker
args []interface{}
}

type regexCheck struct {
multiCheck
regex *regexp.Regexp
}

// NewMultiChecker creates a MultiChecker which is a deep checker that by default matches for equality.
// But checks can be overriden based on path (either explicit match or regexp)
func NewMultiChecker() *MultiChecker {
return &MultiChecker{
CheckerInfo: &gc.CheckerInfo{Name: "MultiChecker", Params: []string{"obtained", "expected"}},
checks: make(map[string]multiCheck),
}
}

// Add an explict checker by path.
func (checker *MultiChecker) Add(path string, c gc.Checker, args ...interface{}) *MultiChecker {
checker.checks[path] = multiCheck{
checker: c,
args: args,
}
return checker
}

// AddRegex exception which matches path with regex.
func (checker *MultiChecker) AddRegex(pathRegex string, c gc.Checker, args ...interface{}) *MultiChecker {
checker.regexChecks = append(checker.regexChecks, regexCheck{
multiCheck: multiCheck{
checker: c,
args: args,
},
regex: regexp.MustCompile("^" + pathRegex + "$"),
})
return checker
}

// Check for go check Checker interface.
func (checker *MultiChecker) Check(params []interface{}, names []string) (result bool, errStr string) {
customCheckFunc := func(path string, a1 interface{}, a2 interface{}) (useDefault bool, equal bool, err error) {
var mc *multiCheck
if c, ok := checker.checks[path]; ok {
mc = &c
} else {
for _, v := range checker.regexChecks {
if v.regex.MatchString(path) {
mc = &v.multiCheck
break
}
}
}
if mc == nil {
return true, false, nil
}

params := append([]interface{}{a1}, mc.args...)
info := mc.checker.Info()
if len(params) < len(info.Params) {
return false, false, fmt.Errorf("Wrong number of parameters for %s: want %d, got %d", info.Name, len(info.Params), len(params)+1)
}
// Copy since it may be mutated by Check.
names := append([]string{}, info.Params...)

// Trim to the expected params len.
params = params[:len(info.Params)]

// Perform substitution
for i, v := range params {
if v == ExpectedValue {
params[i] = a2
}
}

result, errStr := mc.checker.Check(params, names)
if result {
return false, true, nil
}
if path == "" {
path = "top level"
}
return false, false, fmt.Errorf("mismatch at %s: %s", path, errStr)
}
if ok, err := DeepEqualWithCustomCheck(params[0], params[1], customCheckFunc); !ok {
return false, err.Error()
}
return true, ""
}

// ExpectedValue if passed to MultiChecker.Add or MultiChecker.AddRegex, will be substituded with the expected value.
var ExpectedValue = &struct{}{}
71 changes: 71 additions & 0 deletions checkers/multichecker_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
package checkers_test

import (
jc "github.com/juju/testing/checkers"
gc "gopkg.in/check.v1"
)

type MultiCheckerSuite struct{}

var _ = gc.Suite(&MultiCheckerSuite{})

func (s *MultiCheckerSuite) TestDeepEquals(c *gc.C) {
for i, test := range deepEqualTests {
c.Logf("test %d. %v == %v is %v", i, test.a, test.b, test.eq)
result, msg := jc.NewMultiChecker().Check([]interface{}{test.a, test.b}, nil)
c.Check(result, gc.Equals, test.eq)
if test.eq {
c.Check(msg, gc.Equals, "")
} else {
c.Check(msg, gc.Not(gc.Equals), "")
}
}
}

func (s *MultiCheckerSuite) TestArray(c *gc.C) {
a1 := []string{"a", "b", "c"}
a2 := []string{"a", "bbb", "c"}

checker := jc.NewMultiChecker().Add("[1]", jc.Ignore)
c.Check(a1, checker, a2)
}

func (s *MultiCheckerSuite) TestMap(c *gc.C) {
a1 := map[string]string{"a": "a", "b": "b", "c": "c"}
a2 := map[string]string{"a": "a", "b": "bbbb", "c": "c"}

checker := jc.NewMultiChecker().Add(`["b"]`, jc.Ignore)
c.Check(a1, checker, a2)
}

func (s *MultiCheckerSuite) TestRegexArray(c *gc.C) {
a1 := []string{"a", "b", "c"}
a2 := []string{"a", "bbb", "ccc"}

checker := jc.NewMultiChecker().AddRegex("\\[[1-2]\\]", jc.Ignore)
c.Check(a1, checker, a2)
}

func (s *MultiCheckerSuite) TestRegexMap(c *gc.C) {
a1 := map[string]string{"a": "a", "b": "b", "c": "c"}
a2 := map[string]string{"a": "aaaa", "b": "bbbb", "c": "cccc"}

checker := jc.NewMultiChecker().AddRegex(`\[".*"\]`, jc.Ignore)
c.Check(a1, checker, a2)
}

func (s *MultiCheckerSuite) TestArrayArraysUnordered(c *gc.C) {
a1 := [][]string{{"a", "b", "c"}, {"c", "d", "e"}}
a2 := [][]string{{"a", "b", "c"}, {}}

checker := jc.NewMultiChecker().Add("[1]", jc.SameContents, []string{"e", "c", "d"})
c.Check(a1, checker, a2)
}

func (s *MultiCheckerSuite) TestArrayArraysUnorderedWithExpected(c *gc.C) {
a1 := [][]string{{"a", "b", "c"}, {"c", "d", "e"}}
a2 := [][]string{{"a", "b", "c"}, {"e", "c", "d"}}

checker := jc.NewMultiChecker().Add("[1]", jc.SameContents, jc.ExpectedValue)
c.Check(a1, checker, a2)
}