Skip to content

Commit

Permalink
finalize
Browse files Browse the repository at this point in the history
  • Loading branch information
technicallyty committed Jan 15, 2025
1 parent c5854cb commit 6135523
Show file tree
Hide file tree
Showing 2 changed files with 82 additions and 9 deletions.
57 changes: 48 additions & 9 deletions server/v2/api/grpcgateway/handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,11 @@ import (
"errors"
"fmt"
"io"
"maps"
"net/http"
"reflect"
"regexp"
"slices"
"strconv"
"strings"

Expand Down Expand Up @@ -62,9 +64,25 @@ func registerMethods[T transaction.Tx](mux *http.ServeMux, am appmanager.AppMana
gateway.ServeHTTP(w, r)
})

// register the dynamic handlers.
for uri, queryMD := range annotationToMetadata {
mux.Handle(uri, &protoHandler[T]{msg: queryMD.msg, gateway: gateway, wildcardKeyNames: queryMD.wildcardKeyNames})
// register in deterministic order. we do this because of the problem mentioned below, and different nodes could
// end up with one version of the handler or the other.
uris := slices.Sorted(maps.Keys(annotationToMetadata))

for _, uri := range uris {
queryMD := annotationToMetadata[uri]
// we need to wrap this in a panic handler because cosmos SDK proto stubs contains a duplicate annotation
// that causes the registration to panic.
func(u string, qMD queryMetadata) {
defer func() {
_ = recover()
}()
mux.Handle(u, &protoHandler[T]{
msg: qMD.msg,
gateway: gateway,
appManager: am,
wildcardKeyNames: qMD.wildcardKeyNames,
})
}(uri, queryMD)
}
}

Expand Down Expand Up @@ -107,8 +125,13 @@ func (p *protoHandler[T]) ServeHTTP(writer http.ResponseWriter, request *http.Re

responseMsg, err := p.appManager.Query(request.Context(), height, inputMsg)
if err != nil {
// for all other errors, we just return the error.
runtime.HTTPError(request.Context(), p.gateway, out, writer, request, err)
// if we couldn't find a handler for this request, just fall back to the gateway mux.
if strings.Contains(err.Error(), "no handler") {
p.gateway.ServeHTTP(writer, request)
} else {
// for all other errors, we just return the error.
runtime.HTTPError(request.Context(), p.gateway, out, writer, request, err)
}
return
}

Expand All @@ -119,7 +142,9 @@ func (p *protoHandler[T]) populateMessage(req *http.Request, marshaler runtime.M
// see if we have path params to populate the message with.
if len(pathParams) > 0 {
for pathKey, pathValue := range pathParams {
runtime.PopulateFieldFromPath(input, pathKey, pathValue)
if err := runtime.PopulateFieldFromPath(input, pathKey, pathValue); err != nil {
return nil, status.Error(codes.InvalidArgument, fmt.Errorf("failed to populate field %s with value %s: %w", pathKey, pathValue, err).Error())
}
}
}

Expand Down Expand Up @@ -195,10 +220,10 @@ func newHTTPAnnotationMapping() (map[string]string, error) {
httpRules := append(httpRule.GetAdditionalBindings(), httpRule)
for _, rule := range httpRules {
if httpAnnotation := rule.GetGet(); httpAnnotation != "" {
annotationToQueryInputName[httpAnnotation] = queryInputName
annotationToQueryInputName[fixCatchAll(httpAnnotation)] = queryInputName
}
if httpAnnotation := rule.GetPost(); httpAnnotation != "" {
annotationToQueryInputName[httpAnnotation] = queryInputName
annotationToQueryInputName[fixCatchAll(httpAnnotation)] = queryInputName
}
}
}
Expand All @@ -208,12 +233,24 @@ func newHTTPAnnotationMapping() (map[string]string, error) {
return annotationToQueryInputName, nil
}

var catchAllRegex = regexp.MustCompile(`\{([^=]+)=\*\*\}`)

// fixCatchAll replaces grpc gateway catch all syntax with net/http syntax.
//
// {foo=**} -> {foo...}
func fixCatchAll(uri string) string {
return catchAllRegex.ReplaceAllString(uri, `{$1...}`)
}

// annotationsToQueryMetadata takes annotations and creates a mapping of URIs to queryMetadata.
func annotationsToQueryMetadata(annotations map[string]string) (map[string]queryMetadata, error) {
annotationToMetadata := make(map[string]queryMetadata)
for uri, queryInputName := range annotations {
// extract the proto message type.
msgType := gogoproto.MessageType(queryInputName)
if msgType == nil {
continue
}
msg, ok := reflect.New(msgType.Elem()).Interface().(gogoproto.Message)
if !ok {
return nil, fmt.Errorf("query input type %q does not implement gogoproto.Message", queryInputName)
Expand All @@ -232,7 +269,9 @@ func extractWildcardKeyNames(uri string) []string {
for _, match := range matches {
// match[0] is the full string including braces (i.e. "{bar}")
// match[1] is the captured group (i.e. "bar")
extracted = append(extracted, match[1])
// we also need to handle the catch-all case with URI's like "bar..." and
// transform them to just "bar".
extracted = append(extracted, strings.TrimRight(match[1], "."))
}
return extracted
}
34 changes: 34 additions & 0 deletions server/v2/api/grpcgateway/handler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,35 @@ import (
"cosmossdk.io/core/transaction"
)

func Test_fixCatchAll(t *testing.T) {
tests := []struct {
name string
uri string
want string
}{
{
name: "replaces catch all",
uri: "/foo/bar/{baz=**}",
want: "/foo/bar/{baz...}",
},
{
name: "returns original",
uri: "/foo/bar/baz",
want: "/foo/bar/baz",
},
{
name: "doesn't tamper with normal wildcard",
uri: "/foo/{baz}",
want: "/foo/{baz}",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
require.Equal(t, tt.want, fixCatchAll(tt.uri))
})
}
}

func Test_extractWildcardKeyNames(t *testing.T) {
tests := []struct {
name string
Expand All @@ -31,6 +60,11 @@ func Test_extractWildcardKeyNames(t *testing.T) {
uri: "/foo/{bar}/baz/{buzz}",
want: []string{"bar", "buzz"},
},
{
name: "catch-all wildcard",
uri: "/foo/{buzz...}",
want: []string{"buzz"},
},
{
name: "none",
uri: "/foo/bar",
Expand Down

0 comments on commit 6135523

Please sign in to comment.