diff --git a/examples/internal/clients/echo/api/swagger.yaml b/examples/internal/clients/echo/api/swagger.yaml index 63642265c75..487d856826e 100644 --- a/examples/internal/clients/echo/api/swagger.yaml +++ b/examples/internal/clients/echo/api/swagger.yaml @@ -449,13 +449,6 @@ paths: type: "string" x-exportParamName: "StatusNote" x-optionalDataType: "String" - - name: "en" - in: "query" - required: false - type: "string" - format: "int64" - x-exportParamName: "En" - x-optionalDataType: "String" responses: 200: description: "A successful response." diff --git a/examples/internal/clients/echo/api_echo_service.go b/examples/internal/clients/echo/api_echo_service.go index c600199ccc2..0b69393825b 100644 --- a/examples/internal/clients/echo/api_echo_service.go +++ b/examples/internal/clients/echo/api_echo_service.go @@ -837,7 +837,6 @@ EchoServiceApiService EchoBody method receives a simple message and returns it. * @param "Lang" (optional.String) - * @param "StatusProgress" (optional.String) - * @param "StatusNote" (optional.String) - - * @param "En" (optional.String) - @return ExamplepbSimpleMessage */ @@ -848,7 +847,6 @@ type EchoServiceEchoBody2Opts struct { Lang optional.String StatusProgress optional.String StatusNote optional.String - En optional.String } func (a *EchoServiceApiService) EchoServiceEchoBody2(ctx context.Context, id string, no ExamplepbEmbedded, localVarOptionals *EchoServiceEchoBody2Opts) (ExamplepbSimpleMessage, *http.Response, error) { @@ -883,9 +881,6 @@ func (a *EchoServiceApiService) EchoServiceEchoBody2(ctx context.Context, id str if localVarOptionals != nil && localVarOptionals.StatusNote.IsSet() { localVarQueryParams.Add("status.note", parameterToString(localVarOptionals.StatusNote.Value(), "")) } - if localVarOptionals != nil && localVarOptionals.En.IsSet() { - localVarQueryParams.Add("en", parameterToString(localVarOptionals.En.Value(), "")) - } // to determine the Content-Type header localVarHttpContentTypes := []string{"application/json"} diff --git a/examples/internal/proto/examplepb/echo_service.swagger.json b/examples/internal/proto/examplepb/echo_service.swagger.json index fa22d28b47d..2f9fba21225 100644 --- a/examples/internal/proto/examplepb/echo_service.swagger.json +++ b/examples/internal/proto/examplepb/echo_service.swagger.json @@ -535,13 +535,6 @@ "in": "query", "required": false, "type": "string" - }, - { - "name": "en", - "in": "query", - "required": false, - "type": "string", - "format": "int64" } ], "tags": [ diff --git a/protoc-gen-openapiv2/internal/genopenapi/generator_test.go b/protoc-gen-openapiv2/internal/genopenapi/generator_test.go index 559259ca2ea..70af1944643 100644 --- a/protoc-gen-openapiv2/internal/genopenapi/generator_test.go +++ b/protoc-gen-openapiv2/internal/genopenapi/generator_test.go @@ -1182,6 +1182,108 @@ func TestGenerateRPCOrderPreservedAdditionalBindings(t *testing.T) { } } +func TestGenerateRPCOneOfFieldBodyAdditionalBindings(t *testing.T) { + t.Parallel() + + const in = ` + file_to_generate: "exampleproto/v1/example.proto" + parameter: "output_format=yaml,allow_delete_body=true" + proto_file: { + name: "exampleproto/v1/example.proto" + package: "example.v1" + message_type: { + name: "Foo" + oneof_decl: { + name: "foo" + } + field: { + name: "bar" + number: 1 + label: LABEL_OPTIONAL + type: TYPE_STRING + json_name: "bar" + oneof_index: 0 + } + field: { + name: "baz" + number: 2 + label: LABEL_OPTIONAL + type: TYPE_STRING + json_name: "bar" + oneof_index: 0 + } + } + service: { + name: "TestService" + method: { + name: "Test1" + input_type: ".example.v1.Foo" + output_type: ".example.v1.Foo" + options: { + [google.api.http]: { + post: "/b/foo" + body: "*" + additional_bindings { + post: "/b/foo/bar" + body: "bar" + } + additional_bindings { + post: "/b/foo/baz" + body: "baz" + } + } + } + } + } + options: { + go_package: "exampleproto/v1;exampleproto" + } + }` + + var req pluginpb.CodeGeneratorRequest + if err := prototext.Unmarshal([]byte(in), &req); err != nil { + t.Fatalf("failed to marshall yaml: %s", err) + } + + formats := [...]genopenapi.Format{ + genopenapi.FormatJSON, + genopenapi.FormatYAML, + } + + for _, format := range formats { + format := format + t.Run(string(format), func(t *testing.T) { + t.Parallel() + + resp := requireGenerate(t, &req, format, true, false) + if len(resp) != 1 { + t.Fatalf("invalid count, expected: 1, actual: %d", len(resp)) + } + + content := resp[0].GetContent() + + t.Log(content) + + contentsSlice := strings.Fields(content) + expectedPaths := []string{"/b/foo", "/b/foo/bar", "/b/foo/baz"} + + foundPaths := []string{} + for _, contentValue := range contentsSlice { + findExpectedPaths(&foundPaths, expectedPaths, contentValue) + } + + if allPresent := reflect.DeepEqual(foundPaths, expectedPaths); !allPresent { + t.Fatalf("Found paths differed from expected paths. Got: %#v, want %#v", foundPaths, expectedPaths) + } + + // The input message only contains oneof fields, so no other fields should be mapped to the query. + if strings.Contains(content, "query") { + t.Fatalf("Found query in content, expected not to find any") + } + }) + } +} + func TestGenerateRPCOrderNotPreservedAdditionalBindings(t *testing.T) { t.Parallel() @@ -1694,6 +1796,12 @@ func TestGenerateRPCOrderNotPreservedMergeFilesAdditionalBindingsMultipleService func findExpectedPaths(foundPaths *[]string, expectedPaths []string, potentialPath string) { seenPaths := map[string]struct{}{} + // foundPaths may not be empty when this function is called multiple times, + // so we add them to seenPaths map to avoid duplicates. + for _, path := range *foundPaths { + seenPaths[path] = struct{}{} + } + for _, path := range expectedPaths { _, pathAlreadySeen := seenPaths[path] if strings.Contains(potentialPath, path) && !pathAlreadySeen { diff --git a/protoc-gen-openapiv2/internal/genopenapi/template.go b/protoc-gen-openapiv2/internal/genopenapi/template.go index 7b0ac902e5d..888aeaf2d08 100644 --- a/protoc-gen-openapiv2/internal/genopenapi/template.go +++ b/protoc-gen-openapiv2/internal/genopenapi/template.go @@ -167,6 +167,11 @@ func getEnumDefaultNumber(reg *descriptor.Registry, enum *descriptor.Enum) inter // messageToQueryParameters converts a message to a list of OpenAPI query parameters. func messageToQueryParameters(message *descriptor.Message, reg *descriptor.Registry, pathParams []descriptor.Parameter, body *descriptor.Body, httpMethod string) (params []openapiParameterObject, err error) { for _, field := range message.Fields { + // When body is set to oneof field, we want to skip other fields in the oneof group. + if isBodySameOneOf(body, field) { + continue + } + if !isVisible(getFieldVisibilityOption(field), reg) { continue } @@ -183,6 +188,22 @@ func messageToQueryParameters(message *descriptor.Message, reg *descriptor.Regis return params, nil } +func isBodySameOneOf(body *descriptor.Body, field *descriptor.Field) bool { + if field.OneofIndex == nil { + return false + } + + if body == nil || len(body.FieldPath) == 0 { + return false + } + + if body.FieldPath[0].Target.OneofIndex == nil { + return false + } + + return *body.FieldPath[0].Target.OneofIndex == *field.OneofIndex +} + // queryParams converts a field to a list of OpenAPI query parameters recursively through the use of nestedQueryParams. func queryParams(message *descriptor.Message, field *descriptor.Field, prefix string, reg *descriptor.Registry, pathParams []descriptor.Parameter, body *descriptor.Body, recursiveCount int) (params []openapiParameterObject, err error) { return nestedQueryParams(message, field, prefix, reg, pathParams, body, newCycleChecker(recursiveCount))