From 819a804525564c57a91742e31c4846b120a11a01 Mon Sep 17 00:00:00 2001 From: Joshua Humphries <2035234+jhump@users.noreply.github.com> Date: Tue, 14 Nov 2023 17:04:21 -0500 Subject: [PATCH] Allow buf curl to use multiple schemas when resolving elements (#2587) This enables the use of multiple `--schema` flags, and it also enables using both `--reflect` and `--schema`. By default, if a `--schema` flag is present, reflection is not used. But if one also explicitly specifies `--reflect`, then both will be used. This allows resolving elements in an RPC that may fall outside the module that defines the RPC method. This can happen if responses use extensions or google.protobuf.Any messages (which includes the use of error details). --- CHANGELOG.md | 7 +- private/buf/cmd/buf/command/curl/curl.go | 51 ++++++---- private/pkg/protoencoding/protoencoding.go | 7 ++ private/pkg/protoencoding/resolver.go | 107 +++++++++++++++++++++ 4 files changed, 152 insertions(+), 20 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 1f7be1cb23..fcdffaa16f 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -2,7 +2,12 @@ ## [Unreleased] -- No changes yet. +- The `buf curl` command has been updated to support the use of multiple schemas. + This allows users to specify multiple `--schema` flags and/or to use both `--schema` + and `--reflect` flags at the same time. The result is that additional sources can + be consulted to resolve an element. This can be useful when the result of an RPC + contains extensions or values in `google.protobuf.Any` messages that are not defined + in the same schema that defines the RPC service. ## [v1.28.0] - 2023-11-10 diff --git a/private/buf/cmd/buf/command/curl/curl.go b/private/buf/cmd/buf/command/curl/curl.go index 6e37a3f0f3..4cf42473d9 100644 --- a/private/buf/cmd/buf/command/curl/curl.go +++ b/private/buf/cmd/buf/command/curl/curl.go @@ -197,7 +197,7 @@ exit code that is the gRPC code, shifted three bits to the left. type flags struct { // Flags for defining input schema - Schema string + Schemas []string // Flags for server reflection Reflect bool @@ -244,18 +244,22 @@ func newFlags() *flags { func (f *flags) Bind(flagSet *pflag.FlagSet) { f.flagSet = flagSet - flagSet.StringVar( - &f.Schema, + flagSet.StringSliceVar( + &f.Schemas, schemaFlagName, - "", + nil, fmt.Sprintf( `The module to use for the RPC schema. This is necessary if the server does not support server reflection. The format of this argument is the same as for the arguments to other buf sub-commands such as build and generate. It can indicate a directory, a file, a remote module in the Buf Schema Registry, or even standard in ("-") for feeding an image or file descriptor set to the command in a shell pipeline. -Setting this flags implies --%s=false`, - reflectFlagName, +If multiple %s flags are present, they will be consulted in order to resolve service and type +names. Setting this flags implies --%s=false unless a %s flag is explicitly present. If both +%s and %s flags are in use, reflection will be used first and the schemas will be consulted +in order thereafter if reflection fails to resolve a schema element.`, + schemaFlagName, reflectFlagName, reflectFlagName, + schemaFlagName, reflectFlagName, ), ) flagSet.BoolVar( @@ -504,18 +508,24 @@ func (f *flags) validate(isSecure bool) error { return fmt.Errorf("--%s and --%s flags are mutually exclusive; they may not both be specified", netrcFlagName, netrcFileFlagName) } - if f.Schema != "" && f.Reflect { - if f.flagSet.Changed(reflectFlagName) { - // explicitly enabled both - return fmt.Errorf("cannot specify both --%s and --%s", schemaFlagName, reflectFlagName) - } + if len(f.Schemas) > 0 && f.Reflect && !f.flagSet.Changed(reflectFlagName) { // Reflect just has default value; unset it since we're going to use --schema instead f.Reflect = false } - if !f.Reflect && f.Schema == "" { + if !f.Reflect && len(f.Schemas) == 0 { return fmt.Errorf("must specify --%s if --%s is false", schemaFlagName, reflectFlagName) } - schemaIsStdin := strings.HasPrefix(f.Schema, "-") + var schemaIsStdin bool + for _, schema := range f.Schemas { + isStdin := strings.HasPrefix(schema, "-") + if isStdin && schemaIsStdin { + // more than one schema argument wants to use stdin + return fmt.Errorf("multiple --%s flags indicate the use of stdin which is not allowed", schemaFlagName) + } + if isStdin { + schemaIsStdin = true + } + } if (len(f.ReflectHeaders) > 0 || f.flagSet.Changed(reflectProtocolFlagName)) && !f.Reflect { return fmt.Errorf( "reflection flags (--%s, --%s) should not be used if --%s is false", @@ -866,7 +876,7 @@ func run(ctx context.Context, container appflag.Container, f *flags) (err error) } } - var res protoencoding.Resolver + resolvers := make([]protoencoding.Resolver, 0, len(f.Schemas)+1) if f.Reflect { reflectHeaders, _, err := bufcurl.LoadHeaders(f.ReflectHeaders, "", requestHeaders) if err != nil { @@ -892,11 +902,12 @@ func run(ctx context.Context, container appflag.Container, f *flags) (err error) if err != nil { return err } - var closeRes func() - res, closeRes = bufcurl.NewServerReflectionResolver(ctx, transport, clientOptions, baseURL, reflectProtocol, reflectHeaders, container.VerbosePrinter()) + res, closeRes := bufcurl.NewServerReflectionResolver(ctx, transport, clientOptions, baseURL, reflectProtocol, reflectHeaders, container.VerbosePrinter()) defer closeRes() - } else { - ref, err := buffetch.NewRefParser(container.Logger()).GetRef(ctx, f.Schema) + resolvers = append(resolvers, res) + } + for _, schema := range f.Schemas { + ref, err := buffetch.NewRefParser(container.Logger()).GetRef(ctx, schema) if err != nil { return err } @@ -943,11 +954,13 @@ func run(ctx context.Context, container appflag.Container, f *flags) (err error) if err != nil { return err } - res, err = protoencoding.NewResolver(bufimage.ImageToFileDescriptorProtos(image)...) + res, err := protoencoding.NewResolver(bufimage.ImageToFileDescriptorProtos(image)...) if err != nil { return err } + resolvers = append(resolvers, res) } + res := protoencoding.CombineResolvers(resolvers...) methodDescriptor, err := bufcurl.ResolveMethodDescriptor(res, service, method) if err != nil { diff --git a/private/pkg/protoencoding/protoencoding.go b/private/pkg/protoencoding/protoencoding.go index 5686c50154..8e9fc99037 100644 --- a/private/pkg/protoencoding/protoencoding.go +++ b/private/pkg/protoencoding/protoencoding.go @@ -50,6 +50,13 @@ func NewLazyResolver[F protodescriptor.FileDescriptor](fileDescriptors ...F) Res }} } +// CombineResolvers returns a resolver that uses all of the given resolvers. It +// will use the first resolver, and if it returns an error, the second will be +// tried, and so on. +func CombineResolvers(resolvers ...Resolver) Resolver { + return combinedResolver(resolvers) +} + // Marshaler marshals Messages. type Marshaler interface { Marshal(message proto.Message) ([]byte, error) diff --git a/private/pkg/protoencoding/resolver.go b/private/pkg/protoencoding/resolver.go index 34fc523f2d..af8a35a484 100644 --- a/private/pkg/protoencoding/resolver.go +++ b/private/pkg/protoencoding/resolver.go @@ -157,3 +157,110 @@ func (l *lazyResolver) FindMessageByURL(url string) (protoreflect.MessageType, e } return l.resolver.FindMessageByURL(url) } + +type combinedResolver []Resolver + +func (c combinedResolver) FindFileByPath(s string) (protoreflect.FileDescriptor, error) { + var lastErr error + for _, res := range c { + file, err := res.FindFileByPath(s) + if err == nil { + return file, nil + } + lastErr = err + } + if lastErr != nil { + return nil, lastErr + } + return nil, protoregistry.NotFound +} + +func (c combinedResolver) FindDescriptorByName(name protoreflect.FullName) (protoreflect.Descriptor, error) { + var lastErr error + for _, res := range c { + desc, err := res.FindDescriptorByName(name) + if err == nil { + return desc, nil + } + lastErr = err + } + if lastErr != nil { + return nil, lastErr + } + return nil, protoregistry.NotFound +} + +func (c combinedResolver) FindExtensionByName(field protoreflect.FullName) (protoreflect.ExtensionType, error) { + var lastErr error + for _, res := range c { + extension, err := res.FindExtensionByName(field) + if err == nil { + return extension, nil + } + lastErr = err + } + if lastErr != nil { + return nil, lastErr + } + return nil, protoregistry.NotFound +} + +func (c combinedResolver) FindExtensionByNumber(message protoreflect.FullName, field protoreflect.FieldNumber) (protoreflect.ExtensionType, error) { + var lastErr error + for _, res := range c { + extension, err := res.FindExtensionByNumber(message, field) + if err == nil { + return extension, nil + } + lastErr = err + } + if lastErr != nil { + return nil, lastErr + } + return nil, protoregistry.NotFound +} + +func (c combinedResolver) FindMessageByName(message protoreflect.FullName) (protoreflect.MessageType, error) { + var lastErr error + for _, res := range c { + msg, err := res.FindMessageByName(message) + if err == nil { + return msg, nil + } + lastErr = err + } + if lastErr != nil { + return nil, lastErr + } + return nil, protoregistry.NotFound +} + +func (c combinedResolver) FindMessageByURL(url string) (protoreflect.MessageType, error) { + var lastErr error + for _, res := range c { + msg, err := res.FindMessageByURL(url) + if err == nil { + return msg, nil + } + lastErr = err + } + if lastErr != nil { + return nil, lastErr + } + return nil, protoregistry.NotFound +} + +func (c combinedResolver) FindEnumByName(enum protoreflect.FullName) (protoreflect.EnumType, error) { + var lastErr error + for _, res := range c { + msg, err := res.FindEnumByName(enum) + if err == nil { + return msg, nil + } + lastErr = err + } + if lastErr != nil { + return nil, lastErr + } + return nil, protoregistry.NotFound +}