diff --git a/openapi3gen/openapi3gen.go b/openapi3gen/openapi3gen.go index f2a45c4e0..4387727f7 100644 --- a/openapi3gen/openapi3gen.go +++ b/openapi3gen/openapi3gen.go @@ -18,6 +18,11 @@ type CycleError struct{} func (err *CycleError) Error() string { return "detected cycle" } +// ExcludeSchemaSentinel indicates that the schema for a specific field should not be included in the final output. +type ExcludeSchemaSentinel struct{} + +func (err *ExcludeSchemaSentinel) Error() string { return "schema excluded" } + // Option allows tweaking SchemaRef generation type Option func(*generatorOpt) @@ -25,7 +30,10 @@ type Option func(*generatorOpt) // the OpenAPI schema definition to be updated with additional // properties during the generation process, based on the // name of the field, the Go type, and the struct tags. -// name will be "_root" for the top level object, and tag will be "" +// name will be "_root" for the top level object, and tag will be "". +// A SchemaCustomizerFn can return an ExcludeSchemaSentinel error to +// indicate that the schema for this field should not be included in +// the final output type SchemaCustomizerFn func(name string, t reflect.Type, tag reflect.StructTag, schema *openapi3.Schema) error type generatorOpt struct { @@ -117,6 +125,10 @@ func (g *Generator) generateSchemaRefFor(parents []*jsoninfo.TypeInfo, t reflect return ref, nil } ref, err := g.generateWithoutSaving(parents, t, name, tag) + if _, ok := err.(*ExcludeSchemaSentinel); ok { + // This schema should not be included in the final output + return nil, nil + } if err != nil { return nil, err } diff --git a/openapi3gen/openapi3gen_test.go b/openapi3gen/openapi3gen_test.go index 3640d4c0f..a85f5da4a 100644 --- a/openapi3gen/openapi3gen_test.go +++ b/openapi3gen/openapi3gen_test.go @@ -468,6 +468,30 @@ func TestSchemaCustomizerError(t *testing.T) { require.EqualError(t, err, "test error") } +func TestSchemaCustomizerExcludeSchema(t *testing.T) { + type Bla struct { + Str string + } + + customizer := openapi3gen.SchemaCustomizer(func(name string, ft reflect.Type, tag reflect.StructTag, schema *openapi3.Schema) error { + return nil + }) + schema, err := openapi3gen.NewSchemaRefForValue(&Bla{}, nil, openapi3gen.UseAllExportedFields(), customizer) + require.NoError(t, err) + require.Equal(t, &openapi3.SchemaRef{Value: &openapi3.Schema{ + Type: "object", + Properties: map[string]*openapi3.SchemaRef{ + "Str": {Value: &openapi3.Schema{Type: "string"}}, + }}}, schema) + + customizer = openapi3gen.SchemaCustomizer(func(name string, ft reflect.Type, tag reflect.StructTag, schema *openapi3.Schema) error { + return &openapi3gen.ExcludeSchemaSentinel{} + }) + schema, err = openapi3gen.NewSchemaRefForValue(&Bla{}, nil, openapi3gen.UseAllExportedFields(), customizer) + require.NoError(t, err) + require.Nil(t, schema) +} + func ExampleNewSchemaRefForValue_recursive() { type RecursiveType struct { Field1 string `json:"field1"`