diff --git a/structs/structs.go b/structs/structs.go index 47496ae..f71b028 100644 --- a/structs/structs.go +++ b/structs/structs.go @@ -43,14 +43,15 @@ func Walk(s interface{}, callback CallbackFunc) { // - input: the original struct. // - includeFields: list of fields to include (if empty, includes all). // - excludeFields: list of fields to exclude (processed after include). -func FilterStruct(input interface{}, includeFields, excludeFields []string) (interface{}, error) { +func FilterStruct[T any](input T, includeFields, excludeFields []string) (T, error) { + var zeroValue T val := reflect.ValueOf(input) if val.Kind() == reflect.Ptr { val = val.Elem() } if val.Kind() != reflect.Struct { - return nil, errors.New("input must be a struct") + return zeroValue, errors.New("input must be a struct") } includeMap := make(map[string]bool) @@ -76,13 +77,13 @@ func FilterStruct(input interface{}, includeFields, excludeFields []string) (int } } - return filteredStruct.Interface(), nil + return filteredStruct.Interface().(T), nil } // GetStructFields returns all the top-level field names from the given struct. // - input: the original struct. // Returns a slice of field names or an error if the input is not a struct. -func GetStructFields(input interface{}) ([]string, error) { +func GetStructFields[T any](input T) ([]string, error) { val := reflect.ValueOf(input) if val.Kind() == reflect.Ptr { val = val.Elem()