Skip to content

Commit

Permalink
Allow buf curl to use multiple schemas when resolving elements (#2587)
Browse files Browse the repository at this point in the history
This enables the use of multiple `--schema` flags, and it also enables
using both `--reflect` and `--schema`. By default, if a `--schema` flag
is present, reflection is not used. But if one also explicitly specifies
`--reflect`, then both will be used.

This allows resolving elements in an RPC that may fall outside the
module that defines the RPC method. This can happen if responses
use extensions or google.protobuf.Any messages (which includes the
use of error details).
  • Loading branch information
jhump authored Nov 14, 2023
1 parent aaf2d2c commit 819a804
Show file tree
Hide file tree
Showing 4 changed files with 152 additions and 20 deletions.
7 changes: 6 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,12 @@

## [Unreleased]

- No changes yet.
- The `buf curl` command has been updated to support the use of multiple schemas.
This allows users to specify multiple `--schema` flags and/or to use both `--schema`
and `--reflect` flags at the same time. The result is that additional sources can
be consulted to resolve an element. This can be useful when the result of an RPC
contains extensions or values in `google.protobuf.Any` messages that are not defined
in the same schema that defines the RPC service.

## [v1.28.0] - 2023-11-10

Expand Down
51 changes: 32 additions & 19 deletions private/buf/cmd/buf/command/curl/curl.go
Original file line number Diff line number Diff line change
Expand Up @@ -197,7 +197,7 @@ exit code that is the gRPC code, shifted three bits to the left.

type flags struct {
// Flags for defining input schema
Schema string
Schemas []string

// Flags for server reflection
Reflect bool
Expand Down Expand Up @@ -244,18 +244,22 @@ func newFlags() *flags {
func (f *flags) Bind(flagSet *pflag.FlagSet) {
f.flagSet = flagSet

flagSet.StringVar(
&f.Schema,
flagSet.StringSliceVar(
&f.Schemas,
schemaFlagName,
"",
nil,
fmt.Sprintf(
`The module to use for the RPC schema. This is necessary if the server does not support
server reflection. The format of this argument is the same as for the <input> arguments to
other buf sub-commands such as build and generate. It can indicate a directory, a file, a
remote module in the Buf Schema Registry, or even standard in ("-") for feeding an image or
file descriptor set to the command in a shell pipeline.
Setting this flags implies --%s=false`,
reflectFlagName,
If multiple %s flags are present, they will be consulted in order to resolve service and type
names. Setting this flags implies --%s=false unless a %s flag is explicitly present. If both
%s and %s flags are in use, reflection will be used first and the schemas will be consulted
in order thereafter if reflection fails to resolve a schema element.`,
schemaFlagName, reflectFlagName, reflectFlagName,
schemaFlagName, reflectFlagName,
),
)
flagSet.BoolVar(
Expand Down Expand Up @@ -504,18 +508,24 @@ func (f *flags) validate(isSecure bool) error {
return fmt.Errorf("--%s and --%s flags are mutually exclusive; they may not both be specified", netrcFlagName, netrcFileFlagName)
}

if f.Schema != "" && f.Reflect {
if f.flagSet.Changed(reflectFlagName) {
// explicitly enabled both
return fmt.Errorf("cannot specify both --%s and --%s", schemaFlagName, reflectFlagName)
}
if len(f.Schemas) > 0 && f.Reflect && !f.flagSet.Changed(reflectFlagName) {
// Reflect just has default value; unset it since we're going to use --schema instead
f.Reflect = false
}
if !f.Reflect && f.Schema == "" {
if !f.Reflect && len(f.Schemas) == 0 {
return fmt.Errorf("must specify --%s if --%s is false", schemaFlagName, reflectFlagName)
}
schemaIsStdin := strings.HasPrefix(f.Schema, "-")
var schemaIsStdin bool
for _, schema := range f.Schemas {
isStdin := strings.HasPrefix(schema, "-")
if isStdin && schemaIsStdin {
// more than one schema argument wants to use stdin
return fmt.Errorf("multiple --%s flags indicate the use of stdin which is not allowed", schemaFlagName)
}
if isStdin {
schemaIsStdin = true
}
}
if (len(f.ReflectHeaders) > 0 || f.flagSet.Changed(reflectProtocolFlagName)) && !f.Reflect {
return fmt.Errorf(
"reflection flags (--%s, --%s) should not be used if --%s is false",
Expand Down Expand Up @@ -866,7 +876,7 @@ func run(ctx context.Context, container appflag.Container, f *flags) (err error)
}
}

var res protoencoding.Resolver
resolvers := make([]protoencoding.Resolver, 0, len(f.Schemas)+1)
if f.Reflect {
reflectHeaders, _, err := bufcurl.LoadHeaders(f.ReflectHeaders, "", requestHeaders)
if err != nil {
Expand All @@ -892,11 +902,12 @@ func run(ctx context.Context, container appflag.Container, f *flags) (err error)
if err != nil {
return err
}
var closeRes func()
res, closeRes = bufcurl.NewServerReflectionResolver(ctx, transport, clientOptions, baseURL, reflectProtocol, reflectHeaders, container.VerbosePrinter())
res, closeRes := bufcurl.NewServerReflectionResolver(ctx, transport, clientOptions, baseURL, reflectProtocol, reflectHeaders, container.VerbosePrinter())
defer closeRes()
} else {
ref, err := buffetch.NewRefParser(container.Logger()).GetRef(ctx, f.Schema)
resolvers = append(resolvers, res)
}
for _, schema := range f.Schemas {
ref, err := buffetch.NewRefParser(container.Logger()).GetRef(ctx, schema)
if err != nil {
return err
}
Expand Down Expand Up @@ -943,11 +954,13 @@ func run(ctx context.Context, container appflag.Container, f *flags) (err error)
if err != nil {
return err
}
res, err = protoencoding.NewResolver(bufimage.ImageToFileDescriptorProtos(image)...)
res, err := protoencoding.NewResolver(bufimage.ImageToFileDescriptorProtos(image)...)
if err != nil {
return err
}
resolvers = append(resolvers, res)
}
res := protoencoding.CombineResolvers(resolvers...)

methodDescriptor, err := bufcurl.ResolveMethodDescriptor(res, service, method)
if err != nil {
Expand Down
7 changes: 7 additions & 0 deletions private/pkg/protoencoding/protoencoding.go
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,13 @@ func NewLazyResolver[F protodescriptor.FileDescriptor](fileDescriptors ...F) Res
}}
}

// CombineResolvers returns a resolver that uses all of the given resolvers. It
// will use the first resolver, and if it returns an error, the second will be
// tried, and so on.
func CombineResolvers(resolvers ...Resolver) Resolver {
return combinedResolver(resolvers)
}

// Marshaler marshals Messages.
type Marshaler interface {
Marshal(message proto.Message) ([]byte, error)
Expand Down
107 changes: 107 additions & 0 deletions private/pkg/protoencoding/resolver.go
Original file line number Diff line number Diff line change
Expand Up @@ -157,3 +157,110 @@ func (l *lazyResolver) FindMessageByURL(url string) (protoreflect.MessageType, e
}
return l.resolver.FindMessageByURL(url)
}

type combinedResolver []Resolver

func (c combinedResolver) FindFileByPath(s string) (protoreflect.FileDescriptor, error) {
var lastErr error
for _, res := range c {
file, err := res.FindFileByPath(s)
if err == nil {
return file, nil
}
lastErr = err
}
if lastErr != nil {
return nil, lastErr
}
return nil, protoregistry.NotFound
}

func (c combinedResolver) FindDescriptorByName(name protoreflect.FullName) (protoreflect.Descriptor, error) {
var lastErr error
for _, res := range c {
desc, err := res.FindDescriptorByName(name)
if err == nil {
return desc, nil
}
lastErr = err
}
if lastErr != nil {
return nil, lastErr
}
return nil, protoregistry.NotFound
}

func (c combinedResolver) FindExtensionByName(field protoreflect.FullName) (protoreflect.ExtensionType, error) {
var lastErr error
for _, res := range c {
extension, err := res.FindExtensionByName(field)
if err == nil {
return extension, nil
}
lastErr = err
}
if lastErr != nil {
return nil, lastErr
}
return nil, protoregistry.NotFound
}

func (c combinedResolver) FindExtensionByNumber(message protoreflect.FullName, field protoreflect.FieldNumber) (protoreflect.ExtensionType, error) {
var lastErr error
for _, res := range c {
extension, err := res.FindExtensionByNumber(message, field)
if err == nil {
return extension, nil
}
lastErr = err
}
if lastErr != nil {
return nil, lastErr
}
return nil, protoregistry.NotFound
}

func (c combinedResolver) FindMessageByName(message protoreflect.FullName) (protoreflect.MessageType, error) {
var lastErr error
for _, res := range c {
msg, err := res.FindMessageByName(message)
if err == nil {
return msg, nil
}
lastErr = err
}
if lastErr != nil {
return nil, lastErr
}
return nil, protoregistry.NotFound
}

func (c combinedResolver) FindMessageByURL(url string) (protoreflect.MessageType, error) {
var lastErr error
for _, res := range c {
msg, err := res.FindMessageByURL(url)
if err == nil {
return msg, nil
}
lastErr = err
}
if lastErr != nil {
return nil, lastErr
}
return nil, protoregistry.NotFound
}

func (c combinedResolver) FindEnumByName(enum protoreflect.FullName) (protoreflect.EnumType, error) {
var lastErr error
for _, res := range c {
msg, err := res.FindEnumByName(enum)
if err == nil {
return msg, nil
}
lastErr = err
}
if lastErr != nil {
return nil, lastErr
}
return nil, protoregistry.NotFound
}

0 comments on commit 819a804

Please sign in to comment.