diff --git a/protoc-gen-grpc-gateway/internal/gengateway/generator.go b/protoc-gen-grpc-gateway/internal/gengateway/generator.go index 7646f6ae566..c1ea1189a5f 100644 --- a/protoc-gen-grpc-gateway/internal/gengateway/generator.go +++ b/protoc-gen-grpc-gateway/internal/gengateway/generator.go @@ -32,12 +32,13 @@ type generator struct { useRequestContext bool registerFuncSuffix string pathType pathType + modulePath string allowPatchFeature bool standalone bool } // New returns a new generator which generates grpc gateway files. -func New(reg *descriptor.Registry, useRequestContext bool, registerFuncSuffix, pathTypeString string, +func New(reg *descriptor.Registry, useRequestContext bool, registerFuncSuffix, pathTypeString, modulePathString string, allowPatchFeature, standalone bool) gen.Generator { var imports []descriptor.GoPackage for _, pkgpath := range []string{ @@ -85,6 +86,7 @@ func New(reg *descriptor.Registry, useRequestContext bool, registerFuncSuffix, p useRequestContext: useRequestContext, registerFuncSuffix: registerFuncSuffix, pathType: pathType, + modulePath: modulePathString, allowPatchFeature: allowPatchFeature, standalone: standalone, } @@ -107,9 +109,10 @@ func (g *generator) Generate(targets []*descriptor.File) ([]*pluginpb.CodeGenera glog.Errorf("%v: %s", err, code) return nil, err } - name := file.GetName() - if g.pathType == pathTypeImport && file.GoPkg.Path != "" { - name = fmt.Sprintf("%s/%s", file.GoPkg.Path, filepath.Base(name)) + name, err := g.getFilePath(file) + if err != nil { + glog.Errorf("%v: %s", err, code) + return nil, err } ext := filepath.Ext(name) base := strings.TrimSuffix(name, ext) @@ -123,6 +126,27 @@ func (g *generator) Generate(targets []*descriptor.File) ([]*pluginpb.CodeGenera return files, nil } +func (g *generator) getFilePath(file *descriptor.File) (string, error) { + name := file.GetName() + switch { + case g.modulePath != "" && g.pathType != pathTypeImport: + return "", errors.New("cannot use module= with paths=") + + case g.modulePath != "": + trimPath, pkgPath := g.modulePath+"/", file.GoPkg.Path+"/" + if !strings.HasPrefix(pkgPath, trimPath) { + return "", fmt.Errorf("%v: file go path does not match module prefix: %v", file.GoPkg.Path, trimPath) + } + return filepath.Join(strings.TrimPrefix(pkgPath, trimPath), filepath.Base(name)), nil + + case g.pathType == pathTypeImport && file.GoPkg.Path != "": + return fmt.Sprintf("%s/%s", file.GoPkg.Path, filepath.Base(name)), nil + + default: + return name, nil + } +} + func (g *generator) generate(file *descriptor.File) (string, error) { pkgSeen := make(map[string]bool) var imports []descriptor.GoPackage diff --git a/protoc-gen-grpc-gateway/internal/gengateway/generator_test.go b/protoc-gen-grpc-gateway/internal/gengateway/generator_test.go index 052cd07cc8a..e4816508ba3 100644 --- a/protoc-gen-grpc-gateway/internal/gengateway/generator_test.go +++ b/protoc-gen-grpc-gateway/internal/gengateway/generator_test.go @@ -1,6 +1,7 @@ package gengateway import ( + "errors" "path/filepath" "strings" "testing" @@ -100,9 +101,11 @@ func TestGenerateServiceWithoutBindings(t *testing.T) { func TestGenerateOutputPath(t *testing.T) { cases := []struct { - file *descriptor.File - pathType pathType - expected string + file *descriptor.File + pathType pathType + modulePath string + expected string + expectedError error }{ { file: newExampleFileDescriptorWithGoPkg( @@ -142,13 +145,79 @@ func TestGenerateOutputPath(t *testing.T) { pathType: pathTypeSourceRelative, expected: ".", }, + { + file: newExampleFileDescriptorWithGoPkg( + &descriptor.GoPackage{ + Path: "example.com/path/root", + Name: "example_pb", + }, + ), + modulePath: "example.com/path/root", + expected: ".", + }, + { + file: newExampleFileDescriptorWithGoPkg( + &descriptor.GoPackage{ + Path: "example.com/path/to/example", + Name: "example_pb", + }, + ), + modulePath: "example.com/path/to", + expected: "example", + }, + { + file: newExampleFileDescriptorWithGoPkg( + &descriptor.GoPackage{ + Path: "example.com/path/to/example/with/many/nested/paths", + Name: "example_pb", + }, + ), + modulePath: "example.com/path/to", + expected: "example/with/many/nested/paths", + }, + + // Error cases + { + file: newExampleFileDescriptorWithGoPkg( + &descriptor.GoPackage{ + Path: "example.com/path/root", + Name: "example_pb", + }, + ), + modulePath: "example.com/path/root", + pathType: pathTypeSourceRelative, // Not allowed + expectedError: errors.New("cannot use module= with paths="), + }, + { + file: newExampleFileDescriptorWithGoPkg( + &descriptor.GoPackage{ + Path: "example.com/path/rootextra", + Name: "example_pb", + }, + ), + modulePath: "example.com/path/root", + expectedError: errors.New("example.com/path/rootextra: file go path does not match module prefix: example.com/path/root/"), + }, } for _, c := range cases { - g := &generator{pathType: c.pathType} + g := &generator{ + pathType: c.pathType, + modulePath: c.modulePath, + } file := c.file gots, err := g.Generate([]*descriptor.File{crossLinkFixture(file)}) + + // If we expect an error response, check it matches what we want + if c.expectedError != nil { + if err == nil || err.Error() != c.expectedError.Error() { + t.Errorf("Generate(%#v) failed with %v; wants error of: %v", file, err, c.expectedError) + } + return + } + + // Handle case where we don't expect an error if err != nil { t.Errorf("Generate(%#v) failed with %v; wants success", file, err) return diff --git a/protoc-gen-grpc-gateway/main.go b/protoc-gen-grpc-gateway/main.go index 64ca6dd1fa4..c7fd89233f7 100644 --- a/protoc-gen-grpc-gateway/main.go +++ b/protoc-gen-grpc-gateway/main.go @@ -30,6 +30,7 @@ var ( allowDeleteBody = flag.Bool("allow_delete_body", false, "unless set, HTTP DELETE methods may not have a body") grpcAPIConfiguration = flag.String("grpc_api_configuration", "", "path to gRPC API Configuration in YAML format") pathType = flag.String("paths", "", "specifies how the paths of generated files are structured") + modulePath = flag.String("module", "", "specifies a module prefix that will be stripped from the go package to determine the output directory") allowRepeatedFieldsInBody = flag.Bool("allow_repeated_fields_in_body", false, "allows to use repeated field in `body` and `response_body` field of `google.api.http` annotation option") repeatedPathParamSeparator = flag.String("repeated_path_param_separator", "csv", "configures how repeated fields should be split. Allowed values are `csv`, `pipes`, `ssv` and `tsv`.") allowPatchFeature = flag.Bool("allow_patch_feature", true, "determines whether to use PATCH feature involving update masks (using google.protobuf.FieldMask).") @@ -81,7 +82,7 @@ func main() { } } - g := gengateway.New(reg, *useRequestContext, *registerFuncSuffix, *pathType, *allowPatchFeature, *standalone) + g := gengateway.New(reg, *useRequestContext, *registerFuncSuffix, *pathType, *modulePath, *allowPatchFeature, *standalone) if *grpcAPIConfiguration != "" { if err := reg.LoadGrpcAPIServiceFromYAML(*grpcAPIConfiguration); err != nil {