/* Copyright 2018 The KubeSphere Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ // Package deep provides function deep.Equal which is like reflect.DeepEqual but // returns a list of differences. This is helpful when comparing complex types // like structures and maps. package reflectutils import ( "errors" "fmt" "log" "reflect" "strings" ) var ( // FloatPrecision is the number of decimal places to round float values // to when comparing. FloatPrecision = 10 // MaxDiff specifies the maximum number of differences to return. MaxDiff = 10 // MaxDepth specifies the maximum levels of a struct to recurse into, // if greater than zero. If zero, there is no limit. MaxDepth = 0 // LogErrors causes errors to be logged to STDERR when true. LogErrors = false // CompareUnexportedFields causes unexported struct fields, like s in // T{s int}, to be compared when true. CompareUnexportedFields = false ) var ( // ErrMaxRecursion is logged when MaxDepth is reached. ErrMaxRecursion = errors.New("recursed to MaxDepth") // ErrTypeMismatch is logged when Equal passed two different types of values. ErrTypeMismatch = errors.New("variables are different reflect.Type") // ErrNotHandled is logged when a primitive Go kind is not handled. ErrNotHandled = errors.New("cannot compare the reflect.Kind") ) type cmp struct { diff []string buff []string floatFormat string } var errorType = reflect.TypeOf((*error)(nil)).Elem() // Equal compares variables a and b, recursing into their structure up to // MaxDepth levels deep (if greater than zero), and returns a list of differences, // or nil if there are none. Some differences may not be found if an error is // also returned. // // If a type has an Equal method, like time.Equal, it is called to check for // equality. // // When comparing a struct, if a field has the tag `deep:"-"` then it will be // ignored. func Equal(a, b interface{}) []string { aVal := reflect.ValueOf(a) bVal := reflect.ValueOf(b) c := &cmp{ diff: []string{}, buff: []string{}, floatFormat: fmt.Sprintf("%%.%df", FloatPrecision), } if a == nil && b == nil { return nil } else if a == nil && b != nil { c.saveDiff("", b) } else if a != nil && b == nil { c.saveDiff(a, "") } if len(c.diff) > 0 { return c.diff } c.equals(aVal, bVal, 0) if len(c.diff) > 0 { return c.diff // diffs } return nil // no diffs } func (c *cmp) equals(a, b reflect.Value, level int) { if MaxDepth > 0 && level > MaxDepth { logError(ErrMaxRecursion) return } // Check if one value is nil, e.g. T{x: *X} and T.x is nil if !a.IsValid() || !b.IsValid() { if a.IsValid() && !b.IsValid() { c.saveDiff(a.Type(), "") } else if !a.IsValid() && b.IsValid() { c.saveDiff("", b.Type()) } return } // If different types, they can't be equal aType := a.Type() bType := b.Type() if aType != bType { c.saveDiff(aType, bType) logError(ErrTypeMismatch) return } // Primitive https://golang.org/pkg/reflect/#Kind aKind := a.Kind() bKind := b.Kind() // Do a and b have underlying elements? Yes if they're ptr or interface. aElem := aKind == reflect.Ptr || aKind == reflect.Interface bElem := bKind == reflect.Ptr || bKind == reflect.Interface // If both types implement the error interface, compare the error strings. // This must be done before dereferencing because the interface is on a // pointer receiver. Re https://github.com/go-test/deep/issues/31, a/b might // be primitive kinds; see TestErrorPrimitiveKind. if aType.Implements(errorType) && bType.Implements(errorType) { if (!aElem || !a.IsNil()) && (!bElem || !b.IsNil()) { aString := a.MethodByName("Error").Call(nil)[0].String() bString := b.MethodByName("Error").Call(nil)[0].String() if aString != bString { c.saveDiff(aString, bString) return } } } // Dereference pointers and interface{} if aElem || bElem { if aElem { a = a.Elem() } if bElem { b = b.Elem() } c.equals(a, b, level+1) return } switch aKind { ///////////////////////////////////////////////////////////////////// // Iterable kinds ///////////////////////////////////////////////////////////////////// case reflect.Struct: /* The variables are structs like: type T struct { FirstName string LastName string } Type = .T, Kind = reflect.Struct Iterate through the fields (FirstName, LastName), recurse into their values. */ // Types with an Equal() method, like time.Time, only if struct field // is exported (CanInterface) if eqFunc := a.MethodByName("Equal"); eqFunc.IsValid() && eqFunc.CanInterface() { // Handle https://github.com/go-test/deep/issues/15: // Don't call T.Equal if the method is from an embedded struct, like: // type Foo struct { time.Time } // First, we'll encounter Equal(Ttime, time.Time) but if we pass b // as the 2nd arg we'll panic: "Call using pkg.Foo as type time.Time" // As far as I can tell, there's no way to see that the method is from // time.Time not Foo. So we check the type of the 1st (0) arg and skip // unless it's b type. Later, we'll encounter the time.Time anonymous/ // embedded field and then we'll have Equal(time.Time, time.Time). funcType := eqFunc.Type() if funcType.NumIn() == 1 && funcType.In(0) == bType { retVals := eqFunc.Call([]reflect.Value{b}) if !retVals[0].Bool() { c.saveDiff(a, b) } return } } for i := 0; i < a.NumField(); i++ { if aType.Field(i).PkgPath != "" && !CompareUnexportedFields { continue // skip unexported field, e.g. s in type T struct {s string} } if aType.Field(i).Tag.Get("deep") == "-" { continue // field wants to be ignored } c.push(aType.Field(i).Name) // push field name to buff // Get the Value for each field, e.g. FirstName has Type = string, // Kind = reflect.String. af := a.Field(i) bf := b.Field(i) // Recurse to compare the field values c.equals(af, bf, level+1) c.pop() // pop field name from buff if len(c.diff) >= MaxDiff { break } } case reflect.Map: /* The variables are maps like: map[string]int{ "foo": 1, "bar": 2, } Type = map[string]int, Kind = reflect.Map Or: type T map[string]int{} Type = .T, Kind = reflect.Map Iterate through the map keys (foo, bar), recurse into their values. */ if a.IsNil() || b.IsNil() { if a.IsNil() && !b.IsNil() { c.saveDiff("", b) } else if !a.IsNil() && b.IsNil() { c.saveDiff(a, "") } return } if a.Pointer() == b.Pointer() { return } for _, key := range a.MapKeys() { c.push(fmt.Sprintf("map[%s]", key)) aVal := a.MapIndex(key) bVal := b.MapIndex(key) if bVal.IsValid() { c.equals(aVal, bVal, level+1) } else { c.saveDiff(aVal, "") } c.pop() if len(c.diff) >= MaxDiff { return } } for _, key := range b.MapKeys() { if aVal := a.MapIndex(key); aVal.IsValid() { continue } c.push(fmt.Sprintf("map[%s]", key)) c.saveDiff("", b.MapIndex(key)) c.pop() if len(c.diff) >= MaxDiff { return } } case reflect.Array: n := a.Len() for i := 0; i < n; i++ { c.push(fmt.Sprintf("array[%d]", i)) c.equals(a.Index(i), b.Index(i), level+1) c.pop() if len(c.diff) >= MaxDiff { break } } case reflect.Slice: if a.IsNil() || b.IsNil() { if a.IsNil() && !b.IsNil() { c.saveDiff("", b) } else if !a.IsNil() && b.IsNil() { c.saveDiff(a, "") } return } aLen := a.Len() bLen := b.Len() if a.Pointer() == b.Pointer() && aLen == bLen { return } n := aLen if bLen > aLen { n = bLen } for i := 0; i < n; i++ { c.push(fmt.Sprintf("slice[%d]", i)) if i < aLen && i < bLen { c.equals(a.Index(i), b.Index(i), level+1) } else if i < aLen { c.saveDiff(a.Index(i), "") } else { c.saveDiff("", b.Index(i)) } c.pop() if len(c.diff) >= MaxDiff { break } } ///////////////////////////////////////////////////////////////////// // Primitive kinds ///////////////////////////////////////////////////////////////////// case reflect.Float32, reflect.Float64: // Avoid 0.04147685731961082 != 0.041476857319611 // 6 decimal places is close enough aval := fmt.Sprintf(c.floatFormat, a.Float()) bval := fmt.Sprintf(c.floatFormat, b.Float()) if aval != bval { c.saveDiff(a.Float(), b.Float()) } case reflect.Bool: if a.Bool() != b.Bool() { c.saveDiff(a.Bool(), b.Bool()) } case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: if a.Int() != b.Int() { c.saveDiff(a.Int(), b.Int()) } case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: if a.Uint() != b.Uint() { c.saveDiff(a.Uint(), b.Uint()) } case reflect.String: if a.String() != b.String() { c.saveDiff(a.String(), b.String()) } default: logError(ErrNotHandled) } } func (c *cmp) push(name string) { c.buff = append(c.buff, name) } func (c *cmp) pop() { if len(c.buff) > 0 { c.buff = c.buff[0 : len(c.buff)-1] } } func (c *cmp) saveDiff(aval, bval interface{}) { if len(c.buff) > 0 { varName := strings.Join(c.buff, ".") c.diff = append(c.diff, fmt.Sprintf("%s: %v != %v", varName, aval, bval)) } else { c.diff = append(c.diff, fmt.Sprintf("%v != %v", aval, bval)) } } func logError(err error) { if LogErrors { log.Println(err) } }