From f709571e38c297cd1ad86b16984688fbcb236489 Mon Sep 17 00:00:00 2001 From: Yuki Yugui Sonoda Date: Mon, 4 May 2015 10:47:57 +0900 Subject: [PATCH] Add intermediate representation of services and messages --- protoc-gen-grpc-gateway/descriptor/name.go | 18 + .../descriptor/registry.go | 204 +++++ .../descriptor/registry_test.go | 346 ++++++++ .../descriptor/services.go | 255 ++++++ .../descriptor/services_test.go | 803 ++++++++++++++++++ protoc-gen-grpc-gateway/descriptor/types.go | 274 ++++++ .../descriptor/types_test.go | 198 +++++ 7 files changed, 2098 insertions(+) create mode 100644 protoc-gen-grpc-gateway/descriptor/name.go create mode 100644 protoc-gen-grpc-gateway/descriptor/registry.go create mode 100644 protoc-gen-grpc-gateway/descriptor/registry_test.go create mode 100644 protoc-gen-grpc-gateway/descriptor/services.go create mode 100644 protoc-gen-grpc-gateway/descriptor/services_test.go create mode 100644 protoc-gen-grpc-gateway/descriptor/types.go create mode 100644 protoc-gen-grpc-gateway/descriptor/types_test.go diff --git a/protoc-gen-grpc-gateway/descriptor/name.go b/protoc-gen-grpc-gateway/descriptor/name.go new file mode 100644 index 00000000000..78234461a73 --- /dev/null +++ b/protoc-gen-grpc-gateway/descriptor/name.go @@ -0,0 +1,18 @@ +package descriptor + +import ( + "regexp" + "strings" +) + +var ( + upperPattern = regexp.MustCompile("[A-Z]") +) + +func toCamel(str string) string { + var components []string + for _, c := range strings.Split(str, "_") { + components = append(components, strings.Title(strings.ToLower(c))) + } + return strings.Join(components, "") +} diff --git a/protoc-gen-grpc-gateway/descriptor/registry.go b/protoc-gen-grpc-gateway/descriptor/registry.go new file mode 100644 index 00000000000..27a67338934 --- /dev/null +++ b/protoc-gen-grpc-gateway/descriptor/registry.go @@ -0,0 +1,204 @@ +package descriptor + +import ( + "fmt" + "path" + "path/filepath" + "strings" + + "github.com/golang/glog" + descriptor "github.com/golang/protobuf/protoc-gen-go/descriptor" + plugin "github.com/golang/protobuf/protoc-gen-go/plugin" +) + +// Registry is a registry of information extracted from plugin.CodeGeneratorRequest. +type Registry struct { + // msgs is a mapping from fully-qualified message name to descriptor + msgs map[string]*Message + + // files is a mapping from file path to descriptor + files map[string]*File + + // prefix is a prefix to be inserted to golang pacakge paths generated from proto package names. + prefix string + + // pkgMap is a user-specified mapping from file path to proto package. + pkgMap map[string]string + + // pkgAliases is a mapping from package aliases to package paths in go which are already taken. + pkgAliases map[string]string +} + +// NewRegistry returns a new Registry. +func NewRegistry() *Registry { + return &Registry{ + msgs: make(map[string]*Message), + files: make(map[string]*File), + pkgMap: make(map[string]string), + pkgAliases: map[string]string{ + // TODO(yugui) Move this initialization to generators. + "json": "encoding/json", + "io": "io", + "http": "net/http", + "runtime": "runtime", + "glog": "github.com/golang/glog", + "proto": "github.com/golang/protobuf/proto", + "context": "golang.org/x/net/context", + "grpc": "google.golang.org/grpc", + "codes": "google.golang.org/grpc/codes", + }, + } +} + +// Load loads definitions of services, methods, messages and fields from "req". +func (r *Registry) Load(req *plugin.CodeGeneratorRequest) error { + for _, file := range req.GetProtoFile() { + r.loadFile(file) + } + for _, target := range req.FileToGenerate { + if err := r.loadServices(target); err != nil { + return err + } + } + return nil +} + +// loadFile loads messages and fiels from "file". +// It does not loads services and methods in "file". You need to call +// loadServices after loadFiles is called for all files to load services and methods. +func (r *Registry) loadFile(file *descriptor.FileDescriptorProto) { + pkg := GoPackage{ + Path: r.goPackagePath(file), + Name: defaultGoPackageName(file), + } + if err := r.ReserveGoPackageAlias(pkg.Name, pkg.Path); err != nil { + for i := 0; ; i++ { + alias := fmt.Sprintf("%s_%d", pkg.Name, i) + if err := r.ReserveGoPackageAlias(alias, pkg.Path); err == nil { + pkg.Alias = alias + break + } + } + } + f := &File{ + FileDescriptorProto: file, + GoPkg: pkg, + } + + r.files[file.GetName()] = f + r.registerMsg(f, nil, file.GetMessageType()) +} + +func (r *Registry) registerMsg(file *File, outerPath []string, msgs []*descriptor.DescriptorProto) { + for _, md := range msgs { + m := &Message{ + File: file, + Outers: outerPath, + DescriptorProto: md, + } + for _, fd := range md.GetField() { + m.Fields = append(m.Fields, &Field{ + Message: m, + FieldDescriptorProto: fd, + }) + } + file.Messages = append(file.Messages, m) + r.msgs[m.FQMN()] = m + glog.Infof("register name: %s", m.FQMN()) + + var outers []string + outers = append(outers, outerPath...) + outers = append(outers, m.GetName()) + r.registerMsg(file, outers, m.GetNestedType()) + } +} + +// LookupMsg looks up a message type by "name". +// It tries to resolve "name" from "location" if "name" is a relative message name. +func (r *Registry) LookupMsg(location, name string) (*Message, error) { + glog.Infof("lookup %s from %s", name, location) + if strings.HasPrefix(name, ".") { + m, ok := r.msgs[name] + if !ok { + return nil, fmt.Errorf("no message found: %s", name) + } + return m, nil + } + + if !strings.HasPrefix(location, ".") { + location = fmt.Sprintf(".%s", location) + } + components := strings.Split(location, ".") + for len(components) > 0 { + fqmn := strings.Join(append(components, name), ".") + if m, ok := r.msgs[fqmn]; ok { + return m, nil + } + components = components[:len(components)-1] + } + return nil, fmt.Errorf("no message found: %s", name) +} + +func (r *Registry) LookupFile(name string) (*File, error) { + f, ok := r.files[name] + if !ok { + return nil, fmt.Errorf("no such file given: %s", name) + } + return f, nil +} + +// AddPkgMap adds a mapping from a .proto file to proto package name. +func (r *Registry) AddPkgMap(file, protoPkg string) { + r.pkgMap[file] = protoPkg +} + +// SetPrefix registeres the perfix to be added to go package paths generated from proto package names. +func (r *Registry) SetPrefix(prefix string) { + r.prefix = prefix +} + +// ReserveGoPackageAlias reserves the unique alias of go package. +// If succeeded, the alias will be never used for other packages in generated go files. +// If failed, the alias is already taken by another package, so you need to use another +// alias for the package in your go files. +func (r *Registry) ReserveGoPackageAlias(alias, pkgpath string) error { + if taken, ok := r.pkgAliases[alias]; ok { + if taken == pkgpath { + return nil + } + return fmt.Errorf("package name %s is already taken. Use another alias", alias) + } + r.pkgAliases[alias] = pkgpath + return nil +} + +// goPackagePath returns the go package path which go files generated from "f" should have. +// It respects the mapping registered by AddPkgMap if exists. Or it generates a path from +// the file name of "f" if otherwise. +func (r *Registry) goPackagePath(f *descriptor.FileDescriptorProto) string { + name := f.GetName() + if pkg, ok := r.pkgMap[name]; ok { + return path.Join(r.prefix, pkg) + } + + ext := filepath.Ext(name) + if ext == ".protodevel" || ext == ".proto" { + name = strings.TrimSuffix(name, ext) + } + return path.Join(r.prefix, fmt.Sprintf("%s.pb", name)) +} + +// defaultGoPackageName returns the default go package name to be used for go files generated from "f". +// You might need to use an unique alias for the package when you import it. Use ReserveGoPackageAlias to get a unique alias. +func defaultGoPackageName(f *descriptor.FileDescriptorProto) string { + if f.Options != nil && f.Options.GoPackage != nil { + return f.Options.GetGoPackage() + } + + if f.Package == nil { + base := filepath.Base(f.GetName()) + ext := filepath.Ext(base) + return strings.TrimSuffix(base, ext) + } + return strings.Replace(f.GetPackage(), ".", "_", -1) +} diff --git a/protoc-gen-grpc-gateway/descriptor/registry_test.go b/protoc-gen-grpc-gateway/descriptor/registry_test.go new file mode 100644 index 00000000000..5261f53f2d6 --- /dev/null +++ b/protoc-gen-grpc-gateway/descriptor/registry_test.go @@ -0,0 +1,346 @@ +package descriptor + +import ( + "testing" + + "github.com/golang/protobuf/proto" + descriptor "github.com/golang/protobuf/protoc-gen-go/descriptor" +) + +func load(t *testing.T, reg *Registry, src string) *descriptor.FileDescriptorProto { + var file descriptor.FileDescriptorProto + if err := proto.UnmarshalText(src, &file); err != nil { + t.Fatalf("proto.UnmarshalText(%s, &file) failed with %v; want success", err) + } + reg.loadFile(&file) + return &file +} + +func TestLoadFile(t *testing.T) { + reg := NewRegistry() + fd := load(t, reg, ` + name: 'example.proto' + package: 'example' + message_type < + name: 'ExampleMessage' + field < + name: 'str' + label: LABEL_OPTIONAL + type: TYPE_STRING + number: 1 + > + > + `) + + file := reg.files["example.proto"] + if file == nil { + t.Errorf("reg.files[%q] = nil; want non-nil", "example.proto") + return + } + wantPkg := GoPackage{Path: "example.pb", Name: "example"} + if got, want := file.GoPkg, wantPkg; got != want { + t.Errorf("file.GoPkg = %#v; want %#v", got, want) + } + + msg, err := reg.LookupMsg("", ".example.ExampleMessage") + if err != nil { + t.Errorf("reg.LookupMsg(%q, %q)) failed with %v; want success", "", ".example.ExampleMessage") + return + } + if got, want := msg.DescriptorProto, fd.MessageType[0]; got != want { + t.Errorf("reg.lookupMsg(%q, %q).DescriptorProto = %#v; want %#v", got, want) + } + if got, want := msg.File, file; got != want { + t.Errorf("msg.File = %v; want %v", got, want) + } + if got := msg.Outers; got != nil { + t.Errorf("msg.Outers = %v; want %v", got, nil) + } + if got, want := len(msg.Fields), 1; got != want { + t.Errorf("len(msg.Fields) = %d; want %d", got, want) + } else if got, want := msg.Fields[0].FieldDescriptorProto, fd.MessageType[0].Field[0]; got != want { + t.Errorf("msg.Fields[0].FieldDescriptorProto = %v; want %v", got, want) + } else if got, want := msg.Fields[0].Message, msg; got != want { + t.Errorf("msg.Fields[0].Message = %v; want %v", got, want) + } + + if got, want := len(file.Messages), 1; got != want { + t.Errorf("file.Meeesages = %#v; want %#v", file.Messages, []*Message{msg}) + } + if got, want := file.Messages[0], msg; got != want { + t.Errorf("file.Meeesages[0] = %v; want %v", got, want) + } +} + +func TestLoadFileNestedPackage(t *testing.T) { + reg := NewRegistry() + load(t, reg, ` + name: 'example.proto' + package: 'example.nested.nested2' + `) + + file := reg.files["example.proto"] + if file == nil { + t.Errorf("reg.files[%q] = nil; want non-nil", "example.proto") + return + } + wantPkg := GoPackage{Path: "example.pb", Name: "example_nested_nested2"} + if got, want := file.GoPkg, wantPkg; got != want { + t.Errorf("file.GoPkg = %#v; want %#v", got, want) + } +} + +func TestLoadFileWithDir(t *testing.T) { + reg := NewRegistry() + load(t, reg, ` + name: 'path/to/example.proto' + package: 'example' + `) + + file := reg.files["path/to/example.proto"] + if file == nil { + t.Errorf("reg.files[%q] = nil; want non-nil", "example.proto") + return + } + wantPkg := GoPackage{Path: "path/to/example.pb", Name: "example"} + if got, want := file.GoPkg, wantPkg; got != want { + t.Errorf("file.GoPkg = %#v; want %#v", got, want) + } +} + +func TestLoadFileWithoutPackage(t *testing.T) { + reg := NewRegistry() + load(t, reg, ` + name: 'path/to/example_file.proto' + `) + + file := reg.files["path/to/example_file.proto"] + if file == nil { + t.Errorf("reg.files[%q] = nil; want non-nil", "example.proto") + return + } + wantPkg := GoPackage{Path: "path/to/example_file.pb", Name: "example_file"} + if got, want := file.GoPkg, wantPkg; got != want { + t.Errorf("file.GoPkg = %#v; want %#v", got, want) + } +} + +func TestLoadFileWithMapping(t *testing.T) { + reg := NewRegistry() + reg.AddPkgMap("path/to/example.proto", "example.com/proj/example/proto") + load(t, reg, ` + name: 'path/to/example.proto' + package: 'example' + `) + + file := reg.files["path/to/example.proto"] + if file == nil { + t.Errorf("reg.files[%q] = nil; want non-nil", "example.proto") + return + } + wantPkg := GoPackage{Path: "example.com/proj/example/proto", Name: "example"} + if got, want := file.GoPkg, wantPkg; got != want { + t.Errorf("file.GoPkg = %#v; want %#v", got, want) + } +} + +func TestLoadFileWithPackageNameCollision(t *testing.T) { + reg := NewRegistry() + load(t, reg, ` + name: 'path/to/another.proto' + package: 'example' + `) + load(t, reg, ` + name: 'path/to/example.proto' + package: 'example' + `) + if err := reg.ReserveGoPackageAlias("ioutil", "io/ioutil"); err != nil { + t.Fatalf("reg.ReserveGoPackageAlias(%q) failed with %v; want success", "ioutil", err) + } + load(t, reg, ` + name: 'path/to/ioutil.proto' + package: 'ioutil' + `) + + file := reg.files["path/to/another.proto"] + if file == nil { + t.Errorf("reg.files[%q] = nil; want non-nil", "path/to/another.proto") + return + } + wantPkg := GoPackage{Path: "path/to/another.pb", Name: "example"} + if got, want := file.GoPkg, wantPkg; got != want { + t.Errorf("file.GoPkg = %#v; want %#v", got, want) + } + + file = reg.files["path/to/example.proto"] + if file == nil { + t.Errorf("reg.files[%q] = nil; want non-nil", "path/to/example.proto") + return + } + wantPkg = GoPackage{Path: "path/to/example.pb", Name: "example", Alias: "example_0"} + if got, want := file.GoPkg, wantPkg; got != want { + t.Errorf("file.GoPkg = %#v; want %#v", got, want) + } + + file = reg.files["path/to/ioutil.proto"] + if file == nil { + t.Errorf("reg.files[%q] = nil; want non-nil", "path/to/ioutil.proto") + return + } + wantPkg = GoPackage{Path: "path/to/ioutil.pb", Name: "ioutil", Alias: "ioutil_0"} + if got, want := file.GoPkg, wantPkg; got != want { + t.Errorf("file.GoPkg = %#v; want %#v", got, want) + } +} + +func TestLoadFileWithIdenticalGoPkg(t *testing.T) { + reg := NewRegistry() + reg.AddPkgMap("path/to/another.proto", "example.com/example") + reg.AddPkgMap("path/to/example.proto", "example.com/example") + load(t, reg, ` + name: 'path/to/another.proto' + package: 'example' + `) + load(t, reg, ` + name: 'path/to/example.proto' + package: 'example' + `) + + file := reg.files["path/to/example.proto"] + if file == nil { + t.Errorf("reg.files[%q] = nil; want non-nil", "example.proto") + return + } + wantPkg := GoPackage{Path: "example.com/example", Name: "example"} + if got, want := file.GoPkg, wantPkg; got != want { + t.Errorf("file.GoPkg = %#v; want %#v", got, want) + } + + file = reg.files["path/to/another.proto"] + if file == nil { + t.Errorf("reg.files[%q] = nil; want non-nil", "example.proto") + return + } + wantPkg = GoPackage{Path: "example.com/example", Name: "example"} + if got, want := file.GoPkg, wantPkg; got != want { + t.Errorf("file.GoPkg = %#v; want %#v", got, want) + } +} + +func TestLoadFileWithPrefix(t *testing.T) { + reg := NewRegistry() + reg.SetPrefix("third_party") + load(t, reg, ` + name: 'path/to/example.proto' + package: 'example' + `) + + file := reg.files["path/to/example.proto"] + if file == nil { + t.Errorf("reg.files[%q] = nil; want non-nil", "example.proto") + return + } + wantPkg := GoPackage{Path: "third_party/path/to/example.pb", Name: "example"} + if got, want := file.GoPkg, wantPkg; got != want { + t.Errorf("file.GoPkg = %#v; want %#v", got, want) + } +} + +func TestLookupMsgWithoutPackage(t *testing.T) { + reg := NewRegistry() + fd := load(t, reg, ` + name: 'example.proto' + message_type < + name: 'ExampleMessage' + field < + name: 'str' + label: LABEL_OPTIONAL + type: TYPE_STRING + number: 1 + > + > + `) + + msg, err := reg.LookupMsg("", ".ExampleMessage") + if err != nil { + t.Errorf("reg.LookupMsg(%q, %q)) failed with %v; want success", "", ".ExampleMessage") + return + } + if got, want := msg.DescriptorProto, fd.MessageType[0]; got != want { + t.Errorf("reg.lookupMsg(%q, %q).DescriptorProto = %#v; want %#v", got, want) + } +} + +func TestLookupMsgWithNestedPackage(t *testing.T) { + reg := NewRegistry() + fd := load(t, reg, ` + name: 'example.proto' + package: 'nested.nested2.mypackage' + message_type < + name: 'ExampleMessage' + field < + name: 'str' + label: LABEL_OPTIONAL + type: TYPE_STRING + number: 1 + > + > + `) + + for _, name := range []string{ + "nested.nested2.mypackage.ExampleMessage", + "nested2.mypackage.ExampleMessage", + "mypackage.ExampleMessage", + "ExampleMessage", + } { + msg, err := reg.LookupMsg("nested.nested2.mypackage", name) + if err != nil { + t.Errorf("reg.LookupMsg(%q, %q)) failed with %v; want success", ".nested.nested2.mypackage", name, err) + return + } + if got, want := msg.DescriptorProto, fd.MessageType[0]; got != want { + t.Errorf("reg.lookupMsg(%q, %q).DescriptorProto = %#v; want %#v", ".nested.nested2.mypackage", name, got, want) + } + } + + for _, loc := range []string{ + ".nested.nested2.mypackage", + "nested.nested2.mypackage", + ".nested.nested2", + "nested.nested2", + ".nested", + "nested", + ".", + "", + "somewhere.else", + } { + name := "nested.nested2.mypackage.ExampleMessage" + msg, err := reg.LookupMsg(loc, name) + if err != nil { + t.Errorf("reg.LookupMsg(%q, %q)) failed with %v; want success", loc, name, err) + return + } + if got, want := msg.DescriptorProto, fd.MessageType[0]; got != want { + t.Errorf("reg.lookupMsg(%q, %q).DescriptorProto = %#v; want %#v", loc, name, got, want) + } + } + + for _, loc := range []string{ + ".nested.nested2.mypackage", + "nested.nested2.mypackage", + ".nested.nested2", + "nested.nested2", + ".nested", + "nested", + } { + name := "nested2.mypackage.ExampleMessage" + msg, err := reg.LookupMsg(loc, name) + if err != nil { + t.Errorf("reg.LookupMsg(%q, %q)) failed with %v; want success", loc, name, err) + return + } + if got, want := msg.DescriptorProto, fd.MessageType[0]; got != want { + t.Errorf("reg.lookupMsg(%q, %q).DescriptorProto = %#v; want %#v", loc, name, got, want) + } + } +} diff --git a/protoc-gen-grpc-gateway/descriptor/services.go b/protoc-gen-grpc-gateway/descriptor/services.go new file mode 100644 index 00000000000..76cf2f39efc --- /dev/null +++ b/protoc-gen-grpc-gateway/descriptor/services.go @@ -0,0 +1,255 @@ +package descriptor + +import ( + "fmt" + "strings" + + "github.com/gengo/grpc-gateway/protoc-gen-grpc-gateway/httprule" + options "github.com/gengo/grpc-gateway/third_party/googleapis/google/api" + "github.com/golang/glog" + "github.com/golang/protobuf/proto" + descriptor "github.com/golang/protobuf/protoc-gen-go/descriptor" +) + +// loadServices registers services and their methods from "targetFile" to "r". +// It must be called after loadFile is called for all files so that loadServices +// can resolve names of message types and their fields. +func (r *Registry) loadServices(targetFile string) error { + file := r.files[targetFile] + if file == nil { + return fmt.Errorf("no such file: %s", targetFile) + } + var svcs []*Service + for _, sd := range file.GetService() { + svc := &Service{ + File: file, + ServiceDescriptorProto: sd, + } + for _, md := range sd.GetMethod() { + opts, err := extractAPIOptions(md) + if err != nil { + glog.Errorf("Failed to extract ApiMethodOptions from %s.%s: %v", svc.GetName(), md.GetName(), err) + return err + } + if opts == nil { + glog.V(1).Infof("Skip non-target method: %s.%s", svc.GetName(), md.GetName()) + continue + } + meth, err := r.newMethod(svc, md, opts) + if err != nil { + return err + } + svc.Methods = append(svc.Methods, meth) + } + if len(svc.Methods) == 0 { + continue + } + svcs = append(svcs, svc) + } + file.Services = svcs + return nil +} + +func (r *Registry) newMethod(svc *Service, md *descriptor.MethodDescriptorProto, opts *options.HttpRule) (*Method, error) { + var ( + httpMethod string + pathTemplate string + ) + switch { + case opts.Get != "": + httpMethod = "GET" + pathTemplate = opts.Get + if opts.Body != "" { + return nil, fmt.Errorf("needs request body even though http method is GET: %s", md.GetName()) + } + + case opts.Put != "": + httpMethod = "PUT" + pathTemplate = opts.Put + + case opts.Post != "": + httpMethod = "POST" + pathTemplate = opts.Post + + case opts.Delete != "": + httpMethod = "DELETE" + pathTemplate = opts.Delete + if opts.Body != "" { + return nil, fmt.Errorf("needs request body even though http method is DELETE: %s", md.GetName()) + } + + case opts.Patch != "": + httpMethod = "PATCH" + pathTemplate = opts.Patch + + case opts.Custom != nil: + httpMethod = opts.Custom.Kind + pathTemplate = opts.Custom.Path + + default: + glog.Errorf("No pattern specified in google.api.HttpRule: %s", md.GetName()) + return nil, fmt.Errorf("none of pattern specified") + } + + parsed, err := httprule.Parse(pathTemplate) + if err != nil { + return nil, err + } + tmpl := parsed.Compile() + + if md.GetClientStreaming() && len(tmpl.Fields) > 0 { + return nil, fmt.Errorf("cannot use path parameter in client streaming") + } + + requestType, err := r.LookupMsg(svc.File.GetPackage(), md.GetInputType()) + if err != nil { + return nil, err + } + responseType, err := r.LookupMsg(svc.File.GetPackage(), md.GetOutputType()) + if err != nil { + return nil, err + } + + meth := &Method{ + Service: svc, + MethodDescriptorProto: md, + PathTmpl: tmpl, + HTTPMethod: httpMethod, + RequestType: requestType, + ResponseType: responseType, + } + + for _, f := range tmpl.Fields { + param, err := r.newParam(meth, f) + if err != nil { + return nil, err + } + meth.PathParams = append(meth.PathParams, param) + } + + // TODO(yugui) Handle query params + + meth.Body, err = r.newBody(meth, opts.Body) + if err != nil { + return nil, err + } + + return meth, nil +} + +func extractAPIOptions(meth *descriptor.MethodDescriptorProto) (*options.HttpRule, error) { + if meth.Options == nil { + return nil, nil + } + if !proto.HasExtension(meth.Options, options.E_Http) { + return nil, nil + } + ext, err := proto.GetExtension(meth.Options, options.E_Http) + if err != nil { + return nil, err + } + opts, ok := ext.(*options.HttpRule) + if !ok { + return nil, fmt.Errorf("extension is %T; want an HttpRule", ext) + } + return opts, nil +} + +func (r *Registry) newParam(meth *Method, path string) (Parameter, error) { + msg := meth.RequestType + fields, err := r.resolveFiledPath(msg, path) + if err != nil { + return Parameter{}, err + } + l := len(fields) + if l == 0 { + return Parameter{}, fmt.Errorf("invalid field access list for %s", path) + } + target := fields[l-1].Target + switch target.GetType() { + case descriptor.FieldDescriptorProto_TYPE_MESSAGE, descriptor.FieldDescriptorProto_TYPE_GROUP: + return Parameter{}, fmt.Errorf("aggregate type %s in parameter of %s.%s: %s", target.Type, meth.Service.GetName(), meth.GetName(), path) + } + return Parameter{ + FieldPath: FieldPath(fields), + Method: meth, + Target: fields[l-1].Target, + }, nil +} + +func (r *Registry) newBody(meth *Method, path string) (*Body, error) { + msg := meth.RequestType + switch path { + case "": + return nil, nil + case "*": + return &Body{ + DecoderFactoryExpr: "json.NewDecoder", + DecoderImports: []GoPackage{ + { + Path: "encoding/json", + Name: "json", + }, + }, + FieldPath: nil, + }, nil + } + fields, err := r.resolveFiledPath(msg, path) + if err != nil { + return nil, err + } + return &Body{ + DecoderFactoryExpr: "json.NewDecoder", + DecoderImports: []GoPackage{ + { + Path: "encoding/json", + Name: "json", + }, + }, + FieldPath: FieldPath(fields), + }, nil +} + +// lookupField looks up a field named "name" within "msg". +// It returns nil if no such field found. +func lookupField(msg *Message, name string) *Field { + for _, f := range msg.Fields { + if f.GetName() == name { + return f + } + } + return nil +} + +// resolveFieldPath resolves "path" into a list of fieldDescriptor, starting from "msg". +func (r *Registry) resolveFiledPath(msg *Message, path string) ([]FieldPathComponent, error) { + if path == "" { + return nil, nil + } + + root := msg + var result []FieldPathComponent + for i, c := range strings.Split(path, ".") { + if i > 0 { + f := result[i-1].Target + switch f.GetType() { + case descriptor.FieldDescriptorProto_TYPE_MESSAGE, descriptor.FieldDescriptorProto_TYPE_GROUP: + var err error + msg, err = r.LookupMsg(msg.FQMN(), f.GetTypeName()) + if err != nil { + return nil, err + } + default: + return nil, fmt.Errorf("not an aggregate type: %s in %s", f.GetName(), path) + } + } + + glog.Infof("Lookup %s in %s", c, msg.FQMN()) + f := lookupField(msg, c) + if f == nil { + return nil, fmt.Errorf("no field %q found in %s", path, root.GetName()) + } + result = append(result, FieldPathComponent{Name: c, Target: f}) + } + return result, nil +} diff --git a/protoc-gen-grpc-gateway/descriptor/services_test.go b/protoc-gen-grpc-gateway/descriptor/services_test.go new file mode 100644 index 00000000000..e06ae61e22e --- /dev/null +++ b/protoc-gen-grpc-gateway/descriptor/services_test.go @@ -0,0 +1,803 @@ +package descriptor + +import ( + "reflect" + "testing" + + "github.com/gengo/grpc-gateway/protoc-gen-grpc-gateway/httprule" + "github.com/golang/protobuf/proto" + descriptor "github.com/golang/protobuf/protoc-gen-go/descriptor" +) + +func compilePath(t *testing.T, path string) httprule.Template { + parsed, err := httprule.Parse(path) + if err != nil { + t.Fatalf("httprule.Parse(%q) failed with %v; want success", path, err) + } + return parsed.Compile() +} + +func testExtractServices(t *testing.T, input []*descriptor.FileDescriptorProto, target string, wantSvcs []*Service) { + reg := NewRegistry() + for _, file := range input { + reg.loadFile(file) + } + err := reg.loadServices(target) + if err != nil { + t.Errorf("loadServices(%q) failed with %v; want success; files=%v", target, err, input) + } + + file := reg.files[target] + svcs := file.Services + var i int + for i = 0; i < len(svcs) && i < len(wantSvcs); i++ { + svc, wantSvc := svcs[i], wantSvcs[i] + if got, want := svc.ServiceDescriptorProto, wantSvc.ServiceDescriptorProto; !proto.Equal(got, want) { + t.Errorf("svcs[%d].ServiceDescriptorProto = %v; want %v; input = %v", i, got, want, input) + continue + } + var j int + for j = 0; j < len(svc.Methods) && j < len(wantSvc.Methods); j++ { + meth, wantMeth := svc.Methods[j], wantSvc.Methods[j] + if got, want := meth.MethodDescriptorProto, wantMeth.MethodDescriptorProto; !proto.Equal(got, want) { + t.Errorf("svcs[%d].Methods[%d].MethodDescriptorProto = %v; want %v; input = %v", i, j, got, want, input) + continue + } + if got, want := meth.PathTmpl, wantMeth.PathTmpl; !reflect.DeepEqual(got, want) { + t.Errorf("svcs[%d].Methods[%d].PathTmpl = %#v; want %#v; input = %v", i, j, got, want, input) + } + if got, want := meth.HTTPMethod, wantMeth.HTTPMethod; got != want { + t.Errorf("svcs[%d].Methods[%d].HTTPMethod = %q; want %q; input = %v", i, j, got, want, input) + } + if got, want := meth.RequestType, wantMeth.RequestType; got.FQMN() != want.FQMN() { + t.Errorf("svcs[%d].Methods[%d].RequestType = %s; want %s; input = %v", i, j, got.FQMN(), want.FQMN(), input) + } + if got, want := meth.ResponseType, wantMeth.ResponseType; got.FQMN() != want.FQMN() { + t.Errorf("svcs[%d].Methods[%d].ResponseType = %s; want %s; input = %v", i, j, got.FQMN(), want.FQMN(), input) + } + + var k int + for k = 0; k < len(meth.PathParams) && k < len(wantMeth.PathParams); k++ { + param, wantParam := meth.PathParams[k], wantMeth.PathParams[k] + if got, want := param.FieldPath.String(), wantParam.FieldPath.String(); got != want { + t.Errorf("svcs[%d].Methods[%d].PathParams[%d].FieldPath.String() = %q; want %q; input = %v", i, j, k, got, want, input) + continue + } + for l := 0; l < len(param.FieldPath) && l < len(wantParam.FieldPath); l++ { + field, wantField := param.FieldPath[l].Target, wantParam.FieldPath[l].Target + if got, want := field.FieldDescriptorProto, wantField.FieldDescriptorProto; !proto.Equal(got, want) { + t.Errorf("svcs[%d].Methods[%d].PathParams[%d].FieldPath[%d].Target.FieldDescriptorProto = %v; want %v; input = %v", i, j, k, l, got, want, input) + } + } + } + for ; k < len(meth.PathParams); k++ { + got := meth.PathParams[k].FieldPath.String() + t.Errorf("svcs[%d].Methods[%d].PathParams[%d] = %q; want it to be missing; input = %v", i, j, k, got, input) + } + for ; k < len(wantMeth.PathParams); k++ { + want := wantMeth.PathParams[k].FieldPath.String() + t.Errorf("svcs[%d].Methods[%d].PathParams[%d] missing; want %q; input = %v", i, j, k, want, input) + } + + if got, want := (meth.Body != nil), (wantMeth.Body != nil); got != want { + if got { + t.Errorf("svcs[%d].Methods[%d].Body = %q; want it to be missing; input = %v", i, j, meth.Body.FieldPath.String(), input) + } else { + t.Errorf("svcs[%d].Methods[%d].Body missing; want %q; input = %v", i, j, wantMeth.Body.FieldPath.String(), input) + } + } + } + for ; j < len(svc.Methods); j++ { + got := svc.Methods[j].MethodDescriptorProto + t.Errorf("svcs[%d].Methods[%d] = %v; want it to be missing; input = %v", i, j, got, input) + } + for ; j < len(wantSvc.Methods); j++ { + want := wantSvc.Methods[j].MethodDescriptorProto + t.Errorf("svcs[%d].Methods[%d] missing; want %v; input = %v", i, j, want, input) + } + } + for ; i < len(svcs); i++ { + got := svcs[i].ServiceDescriptorProto + t.Errorf("svcs[%d] = %v; want it to be missing; input = %v", i, got, input) + } + for ; i < len(wantSvcs); i++ { + want := wantSvcs[i].ServiceDescriptorProto + t.Errorf("svcs[%d] missing; want %v; input = %v", i, want, input) + } +} + +func crossLinkFixture(f *File) *File { + for _, m := range f.Messages { + m.File = f + for _, f := range m.Fields { + f.Message = m + } + } + for _, svc := range f.Services { + svc.File = f + for _, m := range svc.Methods { + m.Service = svc + for _, param := range m.PathParams { + param.Method = m + } + for _, param := range m.QueryParams { + param.Method = m + } + } + } + return f +} + +func TestExtractServicesSimple(t *testing.T) { + src := ` + name: "path/to/example.proto", + package: "example" + message_type < + name: "StringMessage" + field < + name: "string" + number: 1 + label: LABEL_OPTIONAL + type: TYPE_STRING + > + > + service < + name: "ExampleService" + method < + name: "Echo" + input_type: "StringMessage" + output_type: "StringMessage" + options < + [google.api.http] < + post: "/v1/example/echo" + body: "*" + > + > + > + > + ` + var fd descriptor.FileDescriptorProto + if err := proto.UnmarshalText(src, &fd); err != nil { + t.Fatalf("proto.UnmarshalText(%s, &fd) failed with %v; want success", src, err) + } + msg := &Message{ + DescriptorProto: fd.MessageType[0], + Fields: []*Field{ + { + FieldDescriptorProto: fd.MessageType[0].Field[0], + }, + }, + } + file := &File{ + FileDescriptorProto: &fd, + GoPkg: GoPackage{ + Path: "path/to/example.pb", + Name: "example_pb", + }, + Messages: []*Message{msg}, + Services: []*Service{ + { + ServiceDescriptorProto: fd.Service[0], + Methods: []*Method{ + { + MethodDescriptorProto: fd.Service[0].Method[0], + PathTmpl: compilePath(t, "/v1/example/echo"), + HTTPMethod: "POST", + RequestType: msg, + ResponseType: msg, + Body: &Body{ + DecoderFactoryExpr: "json.NewDecoder", + DecoderImports: []GoPackage{ + { + Path: "encoding/json", + Name: "json", + }, + }, + FieldPath: nil, + }, + }, + }, + }, + }, + } + + crossLinkFixture(file) + testExtractServices(t, []*descriptor.FileDescriptorProto{&fd}, "path/to/example.proto", file.Services) +} + +func TestExtractServicesCrossPackage(t *testing.T) { + srcs := []string{ + ` + name: "path/to/example.proto", + package: "example" + message_type < + name: "StringMessage" + field < + name: "string" + number: 1 + label: LABEL_OPTIONAL + type: TYPE_STRING + > + > + service < + name: "ExampleService" + method < + name: "ToString" + input_type: ".another.example.BoolMessage" + output_type: "StringMessage" + options < + [google.api.http] < + post: "/v1/example/to_s" + body: "*" + > + > + > + > + `, ` + name: "path/to/another/example.proto", + package: "another.example" + message_type < + name: "BoolMessage" + field < + name: "bool" + number: 1 + label: LABEL_OPTIONAL + type: TYPE_BOOL + > + > + `, + } + var fds []*descriptor.FileDescriptorProto + for _, src := range srcs { + var fd descriptor.FileDescriptorProto + if err := proto.UnmarshalText(src, &fd); err != nil { + t.Fatalf("proto.UnmarshalText(%s, &fd) failed with %v; want success", src, err) + } + fds = append(fds, &fd) + } + stringMsg := &Message{ + DescriptorProto: fds[0].MessageType[0], + Fields: []*Field{ + { + FieldDescriptorProto: fds[0].MessageType[0].Field[0], + }, + }, + } + boolMsg := &Message{ + DescriptorProto: fds[1].MessageType[0], + Fields: []*Field{ + { + FieldDescriptorProto: fds[1].MessageType[0].Field[0], + }, + }, + } + files := []*File{ + { + FileDescriptorProto: fds[0], + GoPkg: GoPackage{ + Path: "path/to/example.pb", + Name: "example_pb", + }, + Messages: []*Message{stringMsg}, + Services: []*Service{ + { + ServiceDescriptorProto: fds[0].Service[0], + Methods: []*Method{ + { + MethodDescriptorProto: fds[0].Service[0].Method[0], + PathTmpl: compilePath(t, "/v1/example/to_s"), + HTTPMethod: "POST", + RequestType: boolMsg, + ResponseType: stringMsg, + Body: &Body{ + DecoderFactoryExpr: "json.NewDecoder", + DecoderImports: []GoPackage{ + { + Path: "encoding/json", + Name: "json", + }, + }, + FieldPath: nil, + }, + }, + }, + }, + }, + }, + { + FileDescriptorProto: fds[1], + GoPkg: GoPackage{ + Path: "path/to/another/example.pb", + Name: "example_pb", + }, + Messages: []*Message{boolMsg}, + }, + } + + for _, file := range files { + crossLinkFixture(file) + } + testExtractServices(t, fds, "path/to/example.proto", files[0].Services) +} + +func TestExtractServicesWithBodyPath(t *testing.T) { + src := ` + name: "path/to/example.proto", + package: "example" + message_type < + name: "OuterMessage" + nested_type < + name: "StringMessage" + field < + name: "string" + number: 1 + label: LABEL_OPTIONAL + type: TYPE_STRING + > + > + field < + name: "nested" + number: 1 + label: LABEL_OPTIONAL + type: TYPE_MESSAGE + type_name: "StringMessage" + > + > + service < + name: "ExampleService" + method < + name: "Echo" + input_type: "OuterMessage" + output_type: "OuterMessage" + options < + [google.api.http] < + post: "/v1/example/echo" + body: "nested" + > + > + > + > + ` + var fd descriptor.FileDescriptorProto + if err := proto.UnmarshalText(src, &fd); err != nil { + t.Fatalf("proto.UnmarshalText(%s, &fd) failed with %v; want success", src, err) + } + msg := &Message{ + DescriptorProto: fd.MessageType[0], + Fields: []*Field{ + { + FieldDescriptorProto: fd.MessageType[0].Field[0], + }, + }, + } + file := &File{ + FileDescriptorProto: &fd, + GoPkg: GoPackage{ + Path: "path/to/example.pb", + Name: "example_pb", + }, + Messages: []*Message{msg}, + Services: []*Service{ + { + ServiceDescriptorProto: fd.Service[0], + Methods: []*Method{ + { + MethodDescriptorProto: fd.Service[0].Method[0], + PathTmpl: compilePath(t, "/v1/example/echo"), + HTTPMethod: "POST", + RequestType: msg, + ResponseType: msg, + Body: &Body{ + DecoderFactoryExpr: "json.NewDecoder", + DecoderImports: []GoPackage{ + { + Path: "encoding/json", + Name: "json", + }, + }, + FieldPath: FieldPath{ + { + Name: "nested", + Target: msg.Fields[0], + }, + }, + }, + }, + }, + }, + }, + } + + crossLinkFixture(file) + testExtractServices(t, []*descriptor.FileDescriptorProto{&fd}, "path/to/example.proto", file.Services) +} + +func TestExtractServicesWithPathParam(t *testing.T) { + src := ` + name: "path/to/example.proto", + package: "example" + message_type < + name: "StringMessage" + field < + name: "string" + number: 1 + label: LABEL_OPTIONAL + type: TYPE_STRING + > + > + service < + name: "ExampleService" + method < + name: "Echo" + input_type: "StringMessage" + output_type: "StringMessage" + options < + [google.api.http] < + get: "/v1/example/echo/{string=*}" + > + > + > + > + ` + var fd descriptor.FileDescriptorProto + if err := proto.UnmarshalText(src, &fd); err != nil { + t.Fatalf("proto.UnmarshalText(%s, &fd) failed with %v; want success", src, err) + } + msg := &Message{ + DescriptorProto: fd.MessageType[0], + Fields: []*Field{ + { + FieldDescriptorProto: fd.MessageType[0].Field[0], + }, + }, + } + file := &File{ + FileDescriptorProto: &fd, + GoPkg: GoPackage{ + Path: "path/to/example.pb", + Name: "example_pb", + }, + Messages: []*Message{msg}, + Services: []*Service{ + { + ServiceDescriptorProto: fd.Service[0], + Methods: []*Method{ + { + MethodDescriptorProto: fd.Service[0].Method[0], + PathTmpl: compilePath(t, "/v1/example/echo/{string=*}"), + HTTPMethod: "GET", + RequestType: msg, + ResponseType: msg, + PathParams: []Parameter{ + { + FieldPath: FieldPath{ + { + Name: "string", + Target: msg.Fields[0], + }, + }, + Target: msg.Fields[0], + }, + }, + }, + }, + }, + }, + } + + crossLinkFixture(file) + testExtractServices(t, []*descriptor.FileDescriptorProto{&fd}, "path/to/example.proto", file.Services) +} + +func TestExtractServicesWithError(t *testing.T) { + for _, spec := range []struct { + target string + srcs []string + }{ + { + target: "path/to/example.proto", + srcs: []string{ + // message not found + ` + name: "path/to/example.proto", + package: "example" + service < + name: "ExampleService" + method < + name: "Echo" + input_type: "StringMessage" + output_type: "StringMessage" + options < + [google.api.http] < + post: "/v1/example/echo" + body: "*" + > + > + > + > + `, + }, + }, + // body field path not resolved + { + target: "path/to/example.proto", + srcs: []string{` + name: "path/to/example.proto", + package: "example" + message_type < + name: "StringMessage" + field < + name: "string" + number: 1 + label: LABEL_OPTIONAL + type: TYPE_STRING + > + > + service < + name: "ExampleService" + method < + name: "Echo" + input_type: "StringMessage" + output_type: "StringMessage" + options < + [google.api.http] < + post: "/v1/example/echo" + body: "bool" + > + > + > + >`, + }, + }, + // param field path not resolved + { + target: "path/to/example.proto", + srcs: []string{ + ` + name: "path/to/example.proto", + package: "example" + message_type < + name: "StringMessage" + field < + name: "string" + number: 1 + label: LABEL_OPTIONAL + type: TYPE_STRING + > + > + service < + name: "ExampleService" + method < + name: "Echo" + input_type: "StringMessage" + output_type: "StringMessage" + options < + [google.api.http] < + post: "/v1/example/echo/{bool=*}" + > + > + > + > + `, + }, + }, + // non aggregate type on field path + { + target: "path/to/example.proto", + srcs: []string{ + ` + name: "path/to/example.proto", + package: "example" + message_type < + name: "OuterMessage" + field < + name: "mid" + number: 1 + label: LABEL_OPTIONAL + type: TYPE_STRING + > + field < + name: "bool" + number: 2 + label: LABEL_OPTIONAL + type: TYPE_BOOL + > + > + service < + name: "ExampleService" + method < + name: "Echo" + input_type: "OuterMessage" + output_type: "OuterMessage" + options < + [google.api.http] < + post: "/v1/example/echo/{mid.bool=*}" + > + > + > + > + `, + }, + }, + // path param in client streaming + { + target: "path/to/example.proto", + srcs: []string{ + ` + name: "path/to/example.proto", + package: "example" + message_type < + name: "StringMessage" + field < + name: "string" + number: 1 + label: LABEL_OPTIONAL + type: TYPE_STRING + > + > + service < + name: "ExampleService" + method < + name: "Echo" + input_type: "StringMessage" + output_type: "StringMessage" + options < + [google.api.http] < + post: "/v1/example/echo/{bool=*}" + > + > + client_streaming: true + > + > + `, + }, + }, + // body for GET + { + target: "path/to/example.proto", + srcs: []string{ + ` + name: "path/to/example.proto", + package: "example" + message_type < + name: "StringMessage" + field < + name: "string" + number: 1 + label: LABEL_OPTIONAL + type: TYPE_STRING + > + > + service < + name: "ExampleService" + method < + name: "Echo" + input_type: "StringMessage" + output_type: "StringMessage" + options < + [google.api.http] < + get: "/v1/example/echo" + body: "string" + > + > + > + > + `, + }, + }, + // body for DELETE + { + target: "path/to/example.proto", + srcs: []string{ + ` + name: "path/to/example.proto", + package: "example" + message_type < + name: "StringMessage" + field < + name: "string" + number: 1 + label: LABEL_OPTIONAL + type: TYPE_STRING + > + > + service < + name: "ExampleService" + method < + name: "RemoveResource" + input_type: "StringMessage" + output_type: "StringMessage" + options < + [google.api.http] < + delete: "/v1/example/resource" + body: "string" + > + > + > + > + `, + }, + }, + // no pattern specified + { + target: "path/to/example.proto", + srcs: []string{ + ` + name: "path/to/example.proto", + package: "example" + service < + name: "ExampleService" + method < + name: "RemoveResource" + input_type: "StringMessage" + output_type: "StringMessage" + options < + [google.api.http] < + body: "string" + > + > + > + > + `, + }, + }, + // unsupported path parameter type + { + target: "path/to/example.proto", + srcs: []string{` + name: "path/to/example.proto", + package: "example" + message_type < + name: "OuterMessage" + nested_type < + name: "StringMessage" + field < + name: "value" + number: 1 + label: LABEL_OPTIONAL + type: TYPE_STRING + > + > + field < + name: "string" + number: 1 + label: LABEL_OPTIONAL + type: TYPE_MESSAGE + type_name: "StringMessage" + > + > + service < + name: "ExampleService" + method < + name: "Echo" + input_type: "OuterMessage" + output_type: "OuterMessage" + options < + [google.api.http] < + get: "/v1/example/echo/{string=*}" + > + > + > + > + `, + }, + }, + } { + reg := NewRegistry() + + var fds []*descriptor.FileDescriptorProto + for _, src := range spec.srcs { + var fd descriptor.FileDescriptorProto + if err := proto.UnmarshalText(src, &fd); err != nil { + t.Fatalf("proto.UnmarshalText(%s, &fd) failed with %v; want success", src, err) + } + reg.loadFile(&fd) + fds = append(fds, &fd) + } + err := reg.loadServices(spec.target) + if err == nil { + t.Errorf("loadServices(%q) succeeded; want an error; files=%v", spec.target, spec.srcs) + } + t.Log(err) + } +} diff --git a/protoc-gen-grpc-gateway/descriptor/types.go b/protoc-gen-grpc-gateway/descriptor/types.go new file mode 100644 index 00000000000..e44f821fa42 --- /dev/null +++ b/protoc-gen-grpc-gateway/descriptor/types.go @@ -0,0 +1,274 @@ +package descriptor + +import ( + "fmt" + "strings" + + "github.com/gengo/grpc-gateway/protoc-gen-grpc-gateway/httprule" + descriptor "github.com/golang/protobuf/protoc-gen-go/descriptor" +) + +// GoPackage represents a golang package +type GoPackage struct { + // Path is the package path to the package. + Path string + // Name is the package name of the package + Name string + // Alias is an alias of the package unique within the current invokation of grpc-gateway generator. + Alias string +} + +// Standard returns whether the import is a golang standard pacakge. +func (p GoPackage) Standard() bool { + return !strings.Contains(p.Path, ".") +} + +// String returns a string representation of this package in the form of import line in golang. +func (p GoPackage) String() string { + if p.Alias == "" { + return fmt.Sprintf("%q", p.Path) + } + return fmt.Sprintf("%s %q", p.Alias, p.Path) +} + +// File wraps descriptor.FileDescriptorProto for richer features. +type File struct { + *descriptor.FileDescriptorProto + // GoPkg is the go package of the go file generated from this file.. + GoPkg GoPackage + // Messages is the list of messages defined in this file. + Messages []*Message + // Services is the list of services defined in this file. + Services []*Service +} + +// proto2 determines if the syntax of the file is proto2. +func (f *File) proto2() bool { + return f.Syntax == nil || f.GetSyntax() == "proto2" +} + +// Message describes a protocol buffer message types +type Message struct { + // File is the file where the message is defined + File *File + // Outers is a list of outer messages if this message is a nested type. + Outers []string + *descriptor.DescriptorProto + Fields []*Field +} + +// FQMN returns a fully qualified message name of this message. +func (m *Message) FQMN() string { + components := []string{""} + if m.File.Package != nil { + components = append(components, m.File.GetPackage()) + } + components = append(components, m.Outers...) + components = append(components, m.GetName()) + return strings.Join(components, ".") +} + +// GoType returns a go type name for the message type. +// It prefixes the type name with the package alias if +// its belonging package is not "currentPackage". +func (m *Message) GoType(currentPackage string) string { + var components []string + components = append(components, m.Outers...) + components = append(components, m.GetName()) + + name := strings.Join(components, "_") + if m.File.GoPkg.Path == currentPackage { + return name + } + pkg := m.File.GoPkg.Name + if alias := m.File.GoPkg.Alias; alias != "" { + pkg = alias + } + return fmt.Sprintf("%s.%s", pkg, name) +} + +// Service wraps descriptor.ServiceDescriptorProto for richer features. +type Service struct { + // File is the file where this service is defined. + File *File + *descriptor.ServiceDescriptorProto + // Methods is the list of methods defined in this service. + Methods []*Method +} + +// Method wraps descriptor.MethodDescriptorProto for richer features. +type Method struct { + // Service is the service which this method belongs to. + Service *Service + *descriptor.MethodDescriptorProto + + // PathTmpl is path template where this method is mapped to. + PathTmpl httprule.Template + // HTTPMethod is the HTTP method which this method is mapped to. + HTTPMethod string + // RequestType is the message type of requests to this method. + RequestType *Message + // ResponseType is the message type of responses from this method. + ResponseType *Message + // PathParams is the list of parameters provided in HTTP request paths. + PathParams []Parameter + // QueryParam is the list of parameters provided in HTTP query strings. + QueryParams []Parameter + // Body describes parameters provided in HTTP request body. + Body *Body +} + +// Field wraps descriptor.FieldDescriptorProto for richer features. +type Field struct { + // Message is the message type which this field belongs to. + Message *Message + *descriptor.FieldDescriptorProto +} + +// Parameter is a parameter provided in http requests +type Parameter struct { + // FieldPath is a path to a proto field which this parameter is mapped to. + FieldPath + // Target is the proto field which this parameter is mapped to. + Target *Field + // Method is the method which this parameter is used for. + Method *Method +} + +// ConvertFuncExpr returns a go expression of a converter function. +// The converter function converts a string into a value for the parameter. +func (p Parameter) ConvertFuncExpr() (string, error) { + tbl := proto3ConvertFuncs + if p.Target.Message.File.proto2() { + tbl = proto2ConvertFuncs + } + typ := p.Target.GetType() + conv, ok := tbl[typ] + if !ok { + return "", fmt.Errorf("unsupported field type %s of parameter %s in %s.%s", typ, p.FieldPath, p.Method.Service.GetName(), p.Method.GetName()) + } + return conv, nil +} + +// Body describes a http requtest body to be sent to the method. +type Body struct { + // DecoderFactoryExpr is a go expression of a factory function + // which takes a io.Reader and returns a Decoder (unmarshaller). + // TODO(yugui) Extract this to a flag. + DecoderFactoryExpr string + + // DecoderImports is a list of packages to be imported from the + // generated go files so that DecoderFactoryExpr is valid. + // TODO(yugui) Extract this to a flag. + DecoderImports []GoPackage + + // FieldPath is a path to a proto field which the request body is mapped to. + // The request body is mapped to the request type itself if FieldPath is empty. + FieldPath FieldPath +} + +// RHS returns a right-hand-side expression in go to be used to initialize method request object. +// It starts with "msgExpr", which is the go expression of the method request object. +func (b Body) RHS(msgExpr string) string { + return b.FieldPath.RHS(msgExpr) +} + +// FieldPath is a path to a field from a request message. +type FieldPath []FieldPathComponent + +// String returns a string representation of the field path. +func (p FieldPath) String() string { + var components []string + for _, c := range p { + components = append(components, c.Name) + } + return strings.Join(components, ".") +} + +// RHS is a right-hand-side expression in go to be used to assign a value to the target field. +// It starts with "msgExpr", which is the go expression of the method request object. +func (p FieldPath) RHS(msgExpr string) string { + l := len(p) + if l == 0 { + return msgExpr + } + components := []string{msgExpr} + for i, c := range p { + if i == l-1 { + components = append(components, c.RHS()) + continue + } + components = append(components, c.LHS()) + } + return strings.Join(components, ".") +} + +// FieldPathComponent is a path component in FieldPath +type FieldPathComponent struct { + // Name is a name of the proto field which this component corresponds to. + // TODO(yugui) is this necessary? + Name string + // Target is the proto field which this component corresponds to. + Target *Field +} + +// RHS returns a right-hand-side expression in go for this field. +func (c FieldPathComponent) RHS() string { + return toCamel(c.Name) +} + +// LHS returns a left-hand-side expression in go for this field. +func (c FieldPathComponent) LHS() string { + if c.Target.Message.File.proto2() { + return fmt.Sprintf("Get%s()", toCamel(c.Name)) + } + return toCamel(c.Name) +} + +var ( + proto3ConvertFuncs = map[descriptor.FieldDescriptorProto_Type]string{ + descriptor.FieldDescriptorProto_TYPE_DOUBLE: "runtime.Float64", + descriptor.FieldDescriptorProto_TYPE_FLOAT: "runtime.Float32", + descriptor.FieldDescriptorProto_TYPE_INT64: "runtime.Int64", + descriptor.FieldDescriptorProto_TYPE_UINT64: "runtime.Uint64", + descriptor.FieldDescriptorProto_TYPE_INT32: "runtime.Int32", + descriptor.FieldDescriptorProto_TYPE_FIXED64: "runtime.Uint64", + descriptor.FieldDescriptorProto_TYPE_FIXED32: "runtime.Uint32", + descriptor.FieldDescriptorProto_TYPE_BOOL: "runtime.Bool", + descriptor.FieldDescriptorProto_TYPE_STRING: "runtime.String", + // FieldDescriptorProto_TYPE_GROUP + // FieldDescriptorProto_TYPE_MESSAGE + // FieldDescriptorProto_TYPE_BYTES + // TODO(yugui) Handle bytes + descriptor.FieldDescriptorProto_TYPE_UINT32: "runtime.Uint32", + // FieldDescriptorProto_TYPE_ENUM + // TODO(yugui) Handle Enum + descriptor.FieldDescriptorProto_TYPE_SFIXED32: "runtime.Int32", + descriptor.FieldDescriptorProto_TYPE_SFIXED64: "runtime.Int64", + descriptor.FieldDescriptorProto_TYPE_SINT32: "runtime.Int32", + descriptor.FieldDescriptorProto_TYPE_SINT64: "runtime.Int64", + } + + proto2ConvertFuncs = map[descriptor.FieldDescriptorProto_Type]string{ + descriptor.FieldDescriptorProto_TYPE_DOUBLE: "runtime.Float64P", + descriptor.FieldDescriptorProto_TYPE_FLOAT: "runtime.Float32P", + descriptor.FieldDescriptorProto_TYPE_INT64: "runtime.Int64P", + descriptor.FieldDescriptorProto_TYPE_UINT64: "runtime.Uint64P", + descriptor.FieldDescriptorProto_TYPE_INT32: "runtime.Int32P", + descriptor.FieldDescriptorProto_TYPE_FIXED64: "runtime.Uint64P", + descriptor.FieldDescriptorProto_TYPE_FIXED32: "runtime.Uint32P", + descriptor.FieldDescriptorProto_TYPE_BOOL: "runtime.BoolP", + descriptor.FieldDescriptorProto_TYPE_STRING: "runtime.StringP", + // FieldDescriptorProto_TYPE_GROUP + // FieldDescriptorProto_TYPE_MESSAGE + // FieldDescriptorProto_TYPE_BYTES + // TODO(yugui) Handle bytes + descriptor.FieldDescriptorProto_TYPE_UINT32: "runtime.Uint32P", + // FieldDescriptorProto_TYPE_ENUM + // TODO(yugui) Handle Enum + descriptor.FieldDescriptorProto_TYPE_SFIXED32: "runtime.Int32P", + descriptor.FieldDescriptorProto_TYPE_SFIXED64: "runtime.Int64P", + descriptor.FieldDescriptorProto_TYPE_SINT32: "runtime.Int32P", + descriptor.FieldDescriptorProto_TYPE_SINT64: "runtime.Int64P", + } +) diff --git a/protoc-gen-grpc-gateway/descriptor/types_test.go b/protoc-gen-grpc-gateway/descriptor/types_test.go new file mode 100644 index 00000000000..8642ff23b19 --- /dev/null +++ b/protoc-gen-grpc-gateway/descriptor/types_test.go @@ -0,0 +1,198 @@ +package descriptor + +import ( + "testing" + + "github.com/golang/protobuf/proto" + descriptor "github.com/golang/protobuf/protoc-gen-go/descriptor" +) + +func TestGoPackageStandard(t *testing.T) { + for _, spec := range []struct { + pkg GoPackage + want bool + }{ + { + pkg: GoPackage{Path: "fmt", Name: "fmt"}, + want: true, + }, + { + pkg: GoPackage{Path: "encoding/json", Name: "json"}, + want: true, + }, + { + pkg: GoPackage{Path: "golang.org/x/net/context", Name: "context"}, + want: false, + }, + { + pkg: GoPackage{Path: "github.com/gengo/grpc-gateway", Name: "main"}, + want: false, + }, + { + pkg: GoPackage{Path: "github.com/google/googleapis/google/api/http.pb", Name: "http_pb", Alias: "htpb"}, + want: false, + }, + } { + if got, want := spec.pkg.Standard(), spec.want; got != want { + t.Errorf("%#v.Standard() = %v; want %v", spec.pkg, got, want) + } + } +} + +func TestGoPackageString(t *testing.T) { + for _, spec := range []struct { + pkg GoPackage + want string + }{ + { + pkg: GoPackage{Path: "fmt", Name: "fmt"}, + want: `"fmt"`, + }, + { + pkg: GoPackage{Path: "encoding/json", Name: "json"}, + want: `"encoding/json"`, + }, + { + pkg: GoPackage{Path: "golang.org/x/net/context", Name: "context"}, + want: `"golang.org/x/net/context"`, + }, + { + pkg: GoPackage{Path: "github.com/gengo/grpc-gateway", Name: "main"}, + want: `"github.com/gengo/grpc-gateway"`, + }, + { + pkg: GoPackage{Path: "github.com/google/googleapis/google/api/http.pb", Name: "http_pb", Alias: "htpb"}, + want: `htpb "github.com/google/googleapis/google/api/http.pb"`, + }, + } { + if got, want := spec.pkg.String(), spec.want; got != want { + t.Errorf("%#v.String() = %q; want %q", spec.pkg, got, want) + } + } +} + +func TestFieldPath(t *testing.T) { + var fds []*descriptor.FileDescriptorProto + for _, src := range []string{ + ` + name: 'example.proto' + package: 'example' + message_type < + name: 'Nest' + field < + name: 'nest2_field' + label: LABEL_OPTIONAL + type: TYPE_MESSAGE + type_name: 'Nest2' + number: 1 + > + field < + name: 'terminal_field' + label: LABEL_OPTIONAL + type: TYPE_STRING + number: 2 + > + > + syntax: "proto3" + `, ` + name: 'another.proto' + package: 'example' + message_type < + name: 'Nest2' + field < + name: 'nest_field' + label: LABEL_OPTIONAL + type: TYPE_MESSAGE + type_name: 'Nest' + number: 1 + > + field < + name: 'terminal_field' + label: LABEL_OPTIONAL + type: TYPE_STRING + number: 2 + > + > + syntax: "proto2" + `, + } { + var fd descriptor.FileDescriptorProto + if err := proto.UnmarshalText(src, &fd); err != nil { + t.Fatalf("proto.UnmarshalText(%s, &fd) failed with %v; want success", src, err) + } + fds = append(fds, &fd) + } + nest := &Message{ + DescriptorProto: fds[0].MessageType[0], + Fields: []*Field{ + {FieldDescriptorProto: fds[0].MessageType[0].Field[0]}, + {FieldDescriptorProto: fds[0].MessageType[0].Field[1]}, + }, + } + nest2 := &Message{ + DescriptorProto: fds[1].MessageType[0], + Fields: []*Field{ + {FieldDescriptorProto: fds[1].MessageType[0].Field[0]}, + {FieldDescriptorProto: fds[1].MessageType[0].Field[1]}, + }, + } + file1 := &File{ + FileDescriptorProto: fds[0], + GoPkg: GoPackage{Path: "example", Name: "example"}, + Messages: []*Message{nest}, + } + file2 := &File{ + FileDescriptorProto: fds[1], + GoPkg: GoPackage{Path: "example", Name: "example"}, + Messages: []*Message{nest2}, + } + crossLinkFixture(file1) + crossLinkFixture(file2) + + c1 := FieldPathComponent{ + Name: "nest_field", + Target: nest2.Fields[0], + } + if got, want := c1.LHS(), "GetNestField()"; got != want { + t.Errorf("c1.LHS() = %q; want %q", got, want) + } + if got, want := c1.RHS(), "NestField"; got != want { + t.Errorf("c1.RHS() = %q; want %q", got, want) + } + + c2 := FieldPathComponent{ + Name: "nest2_field", + Target: nest.Fields[0], + } + if got, want := c2.LHS(), "Nest2Field"; got != want { + t.Errorf("c2.LHS() = %q; want %q", got, want) + } + if got, want := c2.LHS(), "Nest2Field"; got != want { + t.Errorf("c2.LHS() = %q; want %q", got, want) + } + + fp := FieldPath{ + c1, c2, c1, FieldPathComponent{ + Name: "terminal_field", + Target: nest.Fields[1], + }, + } + if got, want := fp.RHS("resp"), "resp.GetNestField().Nest2Field.GetNestField().TerminalField"; got != want { + t.Errorf("fp.RHS(%q) = %q; want %q", "resp", got, want) + } + + fp2 := FieldPath{ + c2, c1, c2, FieldPathComponent{ + Name: "terminal_field", + Target: nest2.Fields[1], + }, + } + if got, want := fp2.RHS("resp"), "resp.Nest2Field.GetNestField().Nest2Field.TerminalField"; got != want { + t.Errorf("fp2.RHS(%q) = %q; want %q", "resp", got, want) + } + + var fpEmpty FieldPath + if got, want := fpEmpty.RHS("resp"), "resp"; got != want { + t.Errorf("fpEmpty.RHS(%q) = %q; want %q", "resp", got, want) + } +}