@@ -53,7 +53,7 @@ func printable(v reflect.Value) interface{} {
53
53
// Tests for deep equality using reflected types. The map argument tracks
54
54
// comparisons that have already been seen, which allows short circuiting on
55
55
// recursive types.
56
- func deepValueEqual (path string , v1 , v2 reflect.Value , visited map [visit ]bool , depth int ) (ok bool , err error ) {
56
+ func deepValueEqual (path string , v1 , v2 reflect.Value , visited map [visit ]bool , depth int , customCheckFunc CustomCheckFunc ) (ok bool , err error ) {
57
57
errorf := func (f string , a ... interface {}) error {
58
58
return & mismatchError {
59
59
v1 : v1 ,
@@ -105,6 +105,13 @@ func deepValueEqual(path string, v1, v2 reflect.Value, visited map[visit]bool, d
105
105
visited [v ] = true
106
106
}
107
107
108
+ if customCheckFunc != nil && v1 .CanInterface () && v2 .CanInterface () {
109
+ useDefault , equal , err := customCheckFunc (path , v1 .Interface (), v2 .Interface ())
110
+ if ! useDefault {
111
+ return equal , err
112
+ }
113
+ }
114
+
108
115
switch v1 .Kind () {
109
116
case reflect .Array :
110
117
if v1 .Len () != v2 .Len () {
@@ -114,7 +121,7 @@ func deepValueEqual(path string, v1, v2 reflect.Value, visited map[visit]bool, d
114
121
for i := 0 ; i < v1 .Len (); i ++ {
115
122
if ok , err := deepValueEqual (
116
123
fmt .Sprintf ("%s[%d]" , path , i ),
117
- v1 .Index (i ), v2 .Index (i ), visited , depth + 1 ); ! ok {
124
+ v1 .Index (i ), v2 .Index (i ), visited , depth + 1 , customCheckFunc ); ! ok {
118
125
return false , err
119
126
}
120
127
}
@@ -130,7 +137,7 @@ func deepValueEqual(path string, v1, v2 reflect.Value, visited map[visit]bool, d
130
137
for i := 0 ; i < v1 .Len (); i ++ {
131
138
if ok , err := deepValueEqual (
132
139
fmt .Sprintf ("%s[%d]" , path , i ),
133
- v1 .Index (i ), v2 .Index (i ), visited , depth + 1 ); ! ok {
140
+ v1 .Index (i ), v2 .Index (i ), visited , depth + 1 , customCheckFunc ); ! ok {
134
141
return false , err
135
142
}
136
143
}
@@ -142,9 +149,9 @@ func deepValueEqual(path string, v1, v2 reflect.Value, visited map[visit]bool, d
142
149
}
143
150
return true , nil
144
151
}
145
- return deepValueEqual (path , v1 .Elem (), v2 .Elem (), visited , depth + 1 )
152
+ return deepValueEqual (path , v1 .Elem (), v2 .Elem (), visited , depth + 1 , customCheckFunc )
146
153
case reflect .Ptr :
147
- return deepValueEqual ("(*" + path + ")" , v1 .Elem (), v2 .Elem (), visited , depth + 1 )
154
+ return deepValueEqual ("(*" + path + ")" , v1 .Elem (), v2 .Elem (), visited , depth + 1 , customCheckFunc )
148
155
case reflect .Struct :
149
156
if v1 .Type () == timeType {
150
157
// Special case for time - we ignore the time zone.
@@ -157,7 +164,7 @@ func deepValueEqual(path string, v1, v2 reflect.Value, visited map[visit]bool, d
157
164
}
158
165
for i , n := 0 , v1 .NumField (); i < n ; i ++ {
159
166
path := path + "." + v1 .Type ().Field (i ).Name
160
- if ok , err := deepValueEqual (path , v1 .Field (i ), v2 .Field (i ), visited , depth + 1 ); ! ok {
167
+ if ok , err := deepValueEqual (path , v1 .Field (i ), v2 .Field (i ), visited , depth + 1 , customCheckFunc ); ! ok {
161
168
return false , err
162
169
}
163
170
}
@@ -179,7 +186,7 @@ func deepValueEqual(path string, v1, v2 reflect.Value, visited map[visit]bool, d
179
186
} else {
180
187
p = path + "[someKey]"
181
188
}
182
- if ok , err := deepValueEqual (p , v1 .MapIndex (k ), v2 .MapIndex (k ), visited , depth + 1 ); ! ok {
189
+ if ok , err := deepValueEqual (p , v1 .MapIndex (k ), v2 .MapIndex (k ), visited , depth + 1 , customCheckFunc ); ! ok {
183
190
return false , err
184
191
}
185
192
}
@@ -263,9 +270,53 @@ func DeepEqual(a1, a2 interface{}) (bool, error) {
263
270
if v1 .Type () != v2 .Type () {
264
271
return false , errorf ("type mismatch %s vs %s" , v1 .Type (), v2 .Type ())
265
272
}
266
- return deepValueEqual ("" , v1 , v2 , make (map [visit ]bool ), 0 )
273
+ return deepValueEqual ("" , v1 , v2 , make (map [visit ]bool ), 0 , nil )
274
+ }
275
+
276
+ // DeepEqualWithCustomCheck tests for deep equality. It uses normal == equality where
277
+ // possible but will scan elements of arrays, slices, maps, and fields
278
+ // of structs. In maps, keys are compared with == but elements use deep
279
+ // equality. DeepEqual correctly handles recursive types. Functions are
280
+ // equal only if they are both nil.
281
+ //
282
+ // DeepEqual differs from reflect.DeepEqual in two ways:
283
+ // - an empty slice is considered equal to a nil slice.
284
+ // - two time.Time values that represent the same instant
285
+ // but with different time zones are considered equal.
286
+ //
287
+ // If the two values compare unequal, the resulting error holds the
288
+ // first difference encountered.
289
+ //
290
+ // If both values are interface-able and customCheckFunc is non nil,
291
+ // customCheckFunc will be invoked. If it returns useDefault as true, the
292
+ // DeepEqual continues, otherwise the result of the customCheckFunc is used.
293
+ func DeepEqualWithCustomCheck (a1 interface {}, a2 interface {}, customCheckFunc CustomCheckFunc ) (bool , error ) {
294
+ errorf := func (f string , a ... interface {}) error {
295
+ return & mismatchError {
296
+ v1 : reflect .ValueOf (a1 ),
297
+ v2 : reflect .ValueOf (a2 ),
298
+ path : "" ,
299
+ how : fmt .Sprintf (f , a ... ),
300
+ }
301
+ }
302
+ if a1 == nil || a2 == nil {
303
+ if a1 == a2 {
304
+ return true , nil
305
+ }
306
+ return false , errorf ("nil vs non-nil mismatch" )
307
+ }
308
+ v1 := reflect .ValueOf (a1 )
309
+ v2 := reflect .ValueOf (a2 )
310
+ if v1 .Type () != v2 .Type () {
311
+ return false , errorf ("type mismatch %s vs %s" , v1 .Type (), v2 .Type ())
312
+ }
313
+ return deepValueEqual ("" , v1 , v2 , make (map [visit ]bool ), 0 , customCheckFunc )
267
314
}
268
315
316
+ // CustomCheckFunc should return true for useDefault if DeepEqualWithCustomCheck should behave like DeepEqual.
317
+ // Otherwise the result of the CustomCheckFunc is used.
318
+ type CustomCheckFunc func (path string , a1 interface {}, a2 interface {}) (useDefault bool , equal bool , err error )
319
+
269
320
// interfaceOf returns v.Interface() even if v.CanInterface() == false.
270
321
// This enables us to call fmt.Printf on a value even if it's derived
271
322
// from inside an unexported field.
0 commit comments