diff --git a/cmd/validate/main.go b/cmd/validate/main.go index a0580967b..d867c767f 100644 --- a/cmd/validate/main.go +++ b/cmd/validate/main.go @@ -59,7 +59,6 @@ func main() { switch { case vd.OpenAPI == "3" || strings.HasPrefix(vd.OpenAPI, "3."): - openapi3.CircularReferenceCounter = *circular loader := openapi3.NewLoader() loader.IsExternalRefsAllowed = *ext diff --git a/openapi3/issue570_test.go b/openapi3/issue570_test.go index 75afb7e3e..1575e5599 100644 --- a/openapi3/issue570_test.go +++ b/openapi3/issue570_test.go @@ -9,5 +9,5 @@ import ( func TestIssue570(t *testing.T) { loader := NewLoader() _, err := loader.LoadFromFile("testdata/issue570.json") - require.ErrorContains(t, err, CircularReferenceError) + require.NoError(t, err) } diff --git a/openapi3/issue615_test.go b/openapi3/issue615_test.go index 67144e9e1..02532bb5a 100644 --- a/openapi3/issue615_test.go +++ b/openapi3/issue615_test.go @@ -9,21 +9,6 @@ import ( ) func TestIssue615(t *testing.T) { - { - var old int - old, openapi3.CircularReferenceCounter = openapi3.CircularReferenceCounter, 1 - defer func() { openapi3.CircularReferenceCounter = old }() - - loader := openapi3.NewLoader() - loader.IsExternalRefsAllowed = true - _, err := loader.LoadFromFile("testdata/recursiveRef/issue615.yml") - require.ErrorContains(t, err, openapi3.CircularReferenceError) - } - - var old int - old, openapi3.CircularReferenceCounter = openapi3.CircularReferenceCounter, 4 - defer func() { openapi3.CircularReferenceCounter = old }() - loader := openapi3.NewLoader() loader.IsExternalRefsAllowed = true doc, err := loader.LoadFromFile("testdata/recursiveRef/issue615.yml") diff --git a/openapi3/issue796_test.go b/openapi3/issue796_test.go index 0900ee5b9..9c8be17f2 100644 --- a/openapi3/issue796_test.go +++ b/openapi3/issue796_test.go @@ -7,11 +7,6 @@ import ( ) func TestIssue796(t *testing.T) { - var old int - // Need to set CircularReferenceCounter to > 10 - old, CircularReferenceCounter = CircularReferenceCounter, 20 - defer func() { CircularReferenceCounter = old }() - loader := NewLoader() doc, err := loader.LoadFromFile("testdata/issue796.yml") require.NoError(t, err) diff --git a/openapi3/load_cicular_ref_with_external_file_test.go b/openapi3/load_cicular_ref_with_external_file_test.go index 7a99e7600..cff5c01fe 100644 --- a/openapi3/load_cicular_ref_with_external_file_test.go +++ b/openapi3/load_cicular_ref_with_external_file_test.go @@ -47,6 +47,19 @@ func TestLoadCircularRefFromFile(t *testing.T) { bar.Value.Properties["foo"] = &openapi3.SchemaRef{Ref: "#/components/schemas/Foo", Value: foo.Value} foo.Value.Properties["bar"] = &openapi3.SchemaRef{Ref: "#/components/schemas/Bar", Value: bar.Value} + bazNestedRef := &openapi3.SchemaRef{Ref: "./baz.yml#/BazNested"} + array := openapi3.NewArraySchema() + array.Items = bazNestedRef + bazNested := &openapi3.Schema{Properties: map[string]*openapi3.SchemaRef{ + "bazArray": { + Value: &openapi3.Schema{ + Items: bazNestedRef, + }, + }, + "baz": bazNestedRef, + }} + bazNestedRef.Value = bazNested + want := &openapi3.T{ OpenAPI: "3.0.3", Info: &openapi3.Info{ @@ -57,6 +70,7 @@ func TestLoadCircularRefFromFile(t *testing.T) { Schemas: openapi3.Schemas{ "Foo": foo, "Bar": bar, + "Baz": bazNestedRef, }, }, } diff --git a/openapi3/loader.go b/openapi3/loader.go index 452d02ef9..993d1cc98 100644 --- a/openapi3/loader.go +++ b/openapi3/loader.go @@ -16,7 +16,14 @@ import ( "strings" ) +// CircularReferenceError is deprecated. +// kin-openapi will never throw this error anymore, and it's kept for compatibility reasons +// Deprecated: CircularReferenceError is deprecated. var CircularReferenceError = "kin-openapi bug found: circular schema reference not handled" + +// CircularReferenceCounter is deprecated. +// kin-openapi does not use this counter anymore, and it's kept for compatibility reasons +// Deprecated: CircularReferenceCounter is deprecated. var CircularReferenceCounter = 3 func foundUnresolvedRef(ref string) error { @@ -44,15 +51,9 @@ type Loader struct { visitedDocuments map[string]*T - visitedCallback map[*Callback]struct{} - visitedExample map[*Example]struct{} - visitedHeader map[*Header]struct{} - visitedLink map[*Link]struct{} - visitedParameter map[*Parameter]struct{} - visitedRequestBody map[*RequestBody]struct{} - visitedResponse map[*Response]struct{} - visitedSchema map[*Schema]struct{} - visitedSecurityScheme map[*SecurityScheme]struct{} + visitedRefs map[string]struct{} + visitedPath []string + backtrack map[string][]func(value any) } // NewLoader returns an empty Loader @@ -299,6 +300,35 @@ func isSingleRefElement(ref string) bool { return !strings.Contains(ref, "#") } +func (loader *Loader) visitRef(ref string) { + if loader.visitedRefs == nil { + loader.visitedRefs = make(map[string]struct{}) + loader.backtrack = make(map[string][]func(value any)) + } + loader.visitedPath = append(loader.visitedPath, ref) + loader.visitedRefs[ref] = struct{}{} +} + +func (loader *Loader) unvisitRef(ref string, value any) { + if value != nil { + for _, fn := range loader.backtrack[ref] { + fn(value) + } + } + delete(loader.visitedRefs, ref) + delete(loader.backtrack, ref) + loader.visitedPath = loader.visitedPath[:len(loader.visitedPath)-1] +} + +func (loader *Loader) shouldVisitRef(ref string, fn func(value any)) bool { + _, ok := loader.visitedRefs[ref] + if ok { + loader.backtrack[ref] = append(loader.backtrack[ref], fn) + return false + } + return true +} + func (loader *Loader) resolveComponent(doc *T, ref string, path *url.URL, resolved any) ( componentDoc *T, componentPath *url.URL, @@ -494,6 +524,8 @@ func drillIntoField(cursor any, fieldName string) (any, error) { } } +var errAlreadyResolved = errors.New("already resolved") + func (loader *Loader) resolveRef(doc *T, ref string, path *url.URL) (*T, string, *url.URL, error) { if ref != "" && ref[0] == '#' { return doc, ref, path, nil @@ -535,17 +567,10 @@ func (loader *Loader) resolveHeaderRef(doc *T, component *HeaderRef, documentPat return errMUSTHeader } - if component.Value != nil { - if loader.visitedHeader == nil { - loader.visitedHeader = make(map[*Header]struct{}) - } - if _, ok := loader.visitedHeader[component.Value]; ok { + if ref := component.Ref; ref != "" { + if component.Value != nil { return nil } - loader.visitedHeader[component.Value] = struct{}{} - } - - if ref := component.Ref; ref != "" { if isSingleRefElement(ref) { var header Header if documentPath, err = loader.loadSingleElementFromURI(ref, documentPath, &header); err != nil { @@ -554,8 +579,15 @@ func (loader *Loader) resolveHeaderRef(doc *T, component *HeaderRef, documentPat component.Value = &header component.refPath = *documentPath } else { + if !loader.shouldVisitRef(ref, func(value any) { + component.Value = value.(*Header) + }) { + return nil + } var resolved HeaderRef + loader.visitRef(ref) doc, componentPath, err := loader.resolveComponent(doc, ref, documentPath, &resolved) + defer loader.unvisitRef(ref, resolved.Value) if err != nil { return err } @@ -587,17 +619,10 @@ func (loader *Loader) resolveParameterRef(doc *T, component *ParameterRef, docum return errMUSTParameter } - if component.Value != nil { - if loader.visitedParameter == nil { - loader.visitedParameter = make(map[*Parameter]struct{}) - } - if _, ok := loader.visitedParameter[component.Value]; ok { + if ref := component.Ref; ref != "" { + if component.Value != nil { return nil } - loader.visitedParameter[component.Value] = struct{}{} - } - - if ref := component.Ref; ref != "" { if isSingleRefElement(ref) { var param Parameter if documentPath, err = loader.loadSingleElementFromURI(ref, documentPath, ¶m); err != nil { @@ -606,8 +631,15 @@ func (loader *Loader) resolveParameterRef(doc *T, component *ParameterRef, docum component.Value = ¶m component.refPath = *documentPath } else { + if !loader.shouldVisitRef(ref, func(value any) { + component.Value = value.(*Parameter) + }) { + return nil + } var resolved ParameterRef + loader.visitRef(ref) doc, componentPath, err := loader.resolveComponent(doc, ref, documentPath, &resolved) + defer loader.unvisitRef(ref, resolved.Value) if err != nil { return err } @@ -649,17 +681,10 @@ func (loader *Loader) resolveRequestBodyRef(doc *T, component *RequestBodyRef, d return errMUSTRequestBody } - if component.Value != nil { - if loader.visitedRequestBody == nil { - loader.visitedRequestBody = make(map[*RequestBody]struct{}) - } - if _, ok := loader.visitedRequestBody[component.Value]; ok { + if ref := component.Ref; ref != "" { + if component.Value != nil { return nil } - loader.visitedRequestBody[component.Value] = struct{}{} - } - - if ref := component.Ref; ref != "" { if isSingleRefElement(ref) { var requestBody RequestBody if documentPath, err = loader.loadSingleElementFromURI(ref, documentPath, &requestBody); err != nil { @@ -668,8 +693,15 @@ func (loader *Loader) resolveRequestBodyRef(doc *T, component *RequestBodyRef, d component.Value = &requestBody component.refPath = *documentPath } else { + if !loader.shouldVisitRef(ref, func(value any) { + component.Value = value.(*RequestBody) + }) { + return nil + } var resolved RequestBodyRef + loader.visitRef(ref) doc, componentPath, err := loader.resolveComponent(doc, ref, documentPath, &resolved) + defer loader.unvisitRef(ref, resolved.Value) if err != nil { return err } @@ -718,17 +750,10 @@ func (loader *Loader) resolveResponseRef(doc *T, component *ResponseRef, documen return errMUSTResponse } - if component.Value != nil { - if loader.visitedResponse == nil { - loader.visitedResponse = make(map[*Response]struct{}) - } - if _, ok := loader.visitedResponse[component.Value]; ok { + if ref := component.Ref; ref != "" { + if component.Value != nil { return nil } - loader.visitedResponse[component.Value] = struct{}{} - } - - if ref := component.Ref; ref != "" { if isSingleRefElement(ref) { var resp Response if documentPath, err = loader.loadSingleElementFromURI(ref, documentPath, &resp); err != nil { @@ -737,8 +762,15 @@ func (loader *Loader) resolveResponseRef(doc *T, component *ResponseRef, documen component.Value = &resp component.refPath = *documentPath } else { + if !loader.shouldVisitRef(ref, func(value any) { + component.Value = value.(*Response) + }) { + return nil + } var resolved ResponseRef + loader.visitRef(ref) doc, componentPath, err := loader.resolveComponent(doc, ref, documentPath, &resolved) + defer loader.unvisitRef(ref, resolved.Value) if err != nil { return err } @@ -798,17 +830,10 @@ func (loader *Loader) resolveSchemaRef(doc *T, component *SchemaRef, documentPat return errMUSTSchema } - if component.Value != nil { - if loader.visitedSchema == nil { - loader.visitedSchema = make(map[*Schema]struct{}) - } - if _, ok := loader.visitedSchema[component.Value]; ok { + if ref := component.Ref; ref != "" { + if component.Value != nil { return nil } - loader.visitedSchema[component.Value] = struct{}{} - } - - if ref := component.Ref; ref != "" { if isSingleRefElement(ref) { var schema Schema if documentPath, err = loader.loadSingleElementFromURI(ref, documentPath, &schema); err != nil { @@ -817,14 +842,15 @@ func (loader *Loader) resolveSchemaRef(doc *T, component *SchemaRef, documentPat component.Value = &schema component.refPath = *documentPath } else { - if visitedLimit(visited, ref) { - visited = append(visited, ref) - return fmt.Errorf("%s with length %d - %s", CircularReferenceError, len(visited), strings.Join(visited, " -> ")) + if !loader.shouldVisitRef(ref, func(value any) { + component.Value = value.(*Schema) + }) { + return nil } - visited = append(visited, ref) - + loader.visitRef(ref) var resolved SchemaRef doc, componentPath, err := loader.resolveComponent(doc, ref, documentPath, &resolved) + defer loader.unvisitRef(ref, resolved.Value) if err != nil { return err } @@ -837,10 +863,6 @@ func (loader *Loader) resolveSchemaRef(doc *T, component *SchemaRef, documentPat component.Value = resolved.Value component.refPath = resolved.refPath } - if loader.visitedSchema == nil { - loader.visitedSchema = make(map[*Schema]struct{}) - } - loader.visitedSchema[component.Value] = struct{}{} } value := component.Value if value == nil { @@ -891,17 +913,10 @@ func (loader *Loader) resolveSecuritySchemeRef(doc *T, component *SecurityScheme return errMUSTSecurityScheme } - if component.Value != nil { - if loader.visitedSecurityScheme == nil { - loader.visitedSecurityScheme = make(map[*SecurityScheme]struct{}) - } - if _, ok := loader.visitedSecurityScheme[component.Value]; ok { + if ref := component.Ref; ref != "" { + if component.Value != nil { return nil } - loader.visitedSecurityScheme[component.Value] = struct{}{} - } - - if ref := component.Ref; ref != "" { if isSingleRefElement(ref) { var scheme SecurityScheme if _, err = loader.loadSingleElementFromURI(ref, documentPath, &scheme); err != nil { @@ -911,7 +926,14 @@ func (loader *Loader) resolveSecuritySchemeRef(doc *T, component *SecurityScheme component.refPath = *documentPath } else { var resolved SecuritySchemeRef + if !loader.shouldVisitRef(ref, func(value any) { + component.Value = value.(*SecurityScheme) + }) { + return nil + } + loader.visitRef(ref) doc, componentPath, err := loader.resolveComponent(doc, ref, documentPath, &resolved) + defer loader.unvisitRef(ref, resolved.Value) if err != nil { return err } @@ -929,21 +951,10 @@ func (loader *Loader) resolveSecuritySchemeRef(doc *T, component *SecurityScheme } func (loader *Loader) resolveExampleRef(doc *T, component *ExampleRef, documentPath *url.URL) (err error) { - if component.isEmpty() { - return errMUSTExample - } - - if component.Value != nil { - if loader.visitedExample == nil { - loader.visitedExample = make(map[*Example]struct{}) - } - if _, ok := loader.visitedExample[component.Value]; ok { + if ref := component.Ref; ref != "" { + if component.Value != nil { return nil } - loader.visitedExample[component.Value] = struct{}{} - } - - if ref := component.Ref; ref != "" { if isSingleRefElement(ref) { var example Example if _, err = loader.loadSingleElementFromURI(ref, documentPath, &example); err != nil { @@ -952,8 +963,15 @@ func (loader *Loader) resolveExampleRef(doc *T, component *ExampleRef, documentP component.Value = &example component.refPath = *documentPath } else { + if !loader.shouldVisitRef(ref, func(value any) { + component.Value = value.(*Example) + }) { + return nil + } var resolved ExampleRef + loader.visitRef(ref) doc, componentPath, err := loader.resolveComponent(doc, ref, documentPath, &resolved) + defer loader.unvisitRef(ref, resolved.Value) if err != nil { return err } @@ -975,17 +993,10 @@ func (loader *Loader) resolveCallbackRef(doc *T, component *CallbackRef, documen return errMUSTCallback } - if component.Value != nil { - if loader.visitedCallback == nil { - loader.visitedCallback = make(map[*Callback]struct{}) - } - if _, ok := loader.visitedCallback[component.Value]; ok { + if ref := component.Ref; ref != "" { + if component.Value != nil { return nil } - loader.visitedCallback[component.Value] = struct{}{} - } - - if ref := component.Ref; ref != "" { if isSingleRefElement(ref) { var resolved Callback if documentPath, err = loader.loadSingleElementFromURI(ref, documentPath, &resolved); err != nil { @@ -994,8 +1005,15 @@ func (loader *Loader) resolveCallbackRef(doc *T, component *CallbackRef, documen component.Value = &resolved component.refPath = *documentPath } else { + if !loader.shouldVisitRef(ref, func(value any) { + component.Value = value.(*Callback) + }) { + return nil + } var resolved CallbackRef + loader.visitRef(ref) doc, componentPath, err := loader.resolveComponent(doc, ref, documentPath, &resolved) + defer loader.unvisitRef(ref, resolved.Value) if err != nil { return err } @@ -1027,17 +1045,10 @@ func (loader *Loader) resolveLinkRef(doc *T, component *LinkRef, documentPath *u return errMUSTLink } - if component.Value != nil { - if loader.visitedLink == nil { - loader.visitedLink = make(map[*Link]struct{}) - } - if _, ok := loader.visitedLink[component.Value]; ok { + if ref := component.Ref; ref != "" { + if component.Value != nil { return nil } - loader.visitedLink[component.Value] = struct{}{} - } - - if ref := component.Ref; ref != "" { if isSingleRefElement(ref) { var link Link if _, err = loader.loadSingleElementFromURI(ref, documentPath, &link); err != nil { @@ -1126,16 +1137,3 @@ func (loader *Loader) resolvePathItemRef(doc *T, pathItem *PathItem, documentPat func unescapeRefString(ref string) string { return strings.Replace(strings.Replace(ref, "~1", "/", -1), "~0", "~", -1) } - -func visitedLimit(visited []string, ref string) bool { - visitedCount := 0 - for _, v := range visited { - if v == ref { - visitedCount++ - if visitedCount >= CircularReferenceCounter { - return true - } - } - } - return false -} diff --git a/openapi3/testdata/circularRef/base.yml b/openapi3/testdata/circularRef/base.yml index ff8240eb0..897a45f37 100644 --- a/openapi3/testdata/circularRef/base.yml +++ b/openapi3/testdata/circularRef/base.yml @@ -14,3 +14,5 @@ components: properties: foo: $ref: "#/components/schemas/Foo" + Baz: + $ref: "./baz.yml#/BazNested" diff --git a/openapi3/testdata/circularRef/baz.yml b/openapi3/testdata/circularRef/baz.yml new file mode 100644 index 000000000..fb8c85420 --- /dev/null +++ b/openapi3/testdata/circularRef/baz.yml @@ -0,0 +1,9 @@ +BazNested: + type: object + properties: + baz: + $ref: "#/BazNested" + bazArray: + type: array + items: + $ref: "#/BazNested"