Skip to content

Commit

Permalink
fix(server/v2): post request fallback (#23361)
Browse files Browse the repository at this point in the history
  • Loading branch information
technicallyty authored Jan 14, 2025
1 parent b461a31 commit 265cb94
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 12 deletions.
28 changes: 17 additions & 11 deletions server/v2/api/grpcgateway/interceptor.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package grpcgateway

import (
"bytes"
"errors"
"io"
"net/http"
Expand Down Expand Up @@ -89,7 +90,7 @@ func (g *gatewayInterceptor[T]) ServeHTTP(writer http.ResponseWriter, request *h
msgType := gogoproto.MessageType(match.QueryInputName)
msg, ok := reflect.New(msgType.Elem()).Interface().(gogoproto.Message)
if !ok {
runtime.DefaultHTTPProtoErrorHandler(request.Context(), g.gateway, out, writer, request, status.Errorf(codes.Internal, "unable to to create gogoproto message from query input name %s", match.QueryInputName))
runtime.HTTPError(request.Context(), g.gateway, out, writer, request, status.Errorf(codes.Internal, "unable to to create gogoproto message from query input name %s", match.QueryInputName))
return
}

Expand All @@ -102,12 +103,12 @@ func (g *gatewayInterceptor[T]) ServeHTTP(writer http.ResponseWriter, request *h
case http.MethodPost:
inputMsg, err = g.createMessageFromPostRequest(in, request, msg)
default:
runtime.DefaultHTTPProtoErrorHandler(request.Context(), g.gateway, out, writer, request, status.Error(codes.InvalidArgument, "HTTP method was not POST or GET"))
runtime.HTTPError(request.Context(), g.gateway, out, writer, request, status.Error(codes.InvalidArgument, "HTTP method was not POST or GET"))
return
}
if err != nil {
// the errors returned from the message creation methods return status errors. no need to make one here.
runtime.DefaultHTTPProtoErrorHandler(request.Context(), g.gateway, out, writer, request, err)
runtime.HTTPError(request.Context(), g.gateway, out, writer, request, err)
return
}

Expand All @@ -118,7 +119,7 @@ func (g *gatewayInterceptor[T]) ServeHTTP(writer http.ResponseWriter, request *h
if heightStr != "" && heightStr != "latest" {
height, err = strconv.ParseUint(heightStr, 10, 64)
if err != nil {
runtime.DefaultHTTPProtoErrorHandler(request.Context(), g.gateway, out, writer, request, status.Errorf(codes.InvalidArgument, "invalid height in header: %s", heightStr))
runtime.HTTPError(request.Context(), g.gateway, out, writer, request, status.Errorf(codes.InvalidArgument, "invalid height in header: %s", heightStr))
return
}
}
Expand All @@ -130,7 +131,7 @@ func (g *gatewayInterceptor[T]) ServeHTTP(writer http.ResponseWriter, request *h
g.gateway.ServeHTTP(writer, request)
} else {
// for all other errors, we just return the error.
runtime.DefaultHTTPProtoErrorHandler(request.Context(), g.gateway, out, writer, request, err)
runtime.HTTPError(request.Context(), g.gateway, out, writer, request, err)
}
return
}
Expand All @@ -143,12 +144,17 @@ func (g *gatewayInterceptor[T]) createMessageFromPostRequest(marshaler runtime.M
if req.ContentLength > MaxBodySize {
return nil, status.Errorf(codes.InvalidArgument, "request body too large: %d bytes, max=%d", req.ContentLength, MaxBodySize)
}
newReader, err := utilities.IOReaderFactory(req.Body)

// this block of code ensures that the body can be re-read. this is needed as if the query fails in the
// app's query handler, we need to pass the request back to the canonical gateway, which needs to be able to
// read the body again.
bodyBytes, err := io.ReadAll(req.Body)
if err != nil {
return nil, status.Errorf(codes.InvalidArgument, "%v", err)
}
req.Body = io.NopCloser(bytes.NewReader(bodyBytes))

if err = marshaler.NewDecoder(newReader()).Decode(input); err != nil && !errors.Is(err, io.EOF) {
if err = marshaler.NewDecoder(bytes.NewReader(bodyBytes)).Decode(input); err != nil && !errors.Is(err, io.EOF) {
return nil, status.Errorf(codes.InvalidArgument, "%v", err)
}

Expand Down Expand Up @@ -217,12 +223,12 @@ func getHTTPGetAnnotationMapping() (map[string]string, error) {
continue
}
queryInputName := string(methodDesc.Input().FullName())
annotations := append(httpRule.GetAdditionalBindings(), httpRule)
for _, a := range annotations {
if httpAnnotation := a.GetGet(); httpAnnotation != "" {
httpRules := append(httpRule.GetAdditionalBindings(), httpRule)
for _, rule := range httpRules {
if httpAnnotation := rule.GetGet(); httpAnnotation != "" {
annotationToQueryInputName[httpAnnotation] = queryInputName
}
if httpAnnotation := a.GetPost(); httpAnnotation != "" {
if httpAnnotation := rule.GetPost(); httpAnnotation != "" {
annotationToQueryInputName[httpAnnotation] = queryInputName
}
}
Expand Down
2 changes: 1 addition & 1 deletion tests/systemtests/distribution_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -179,7 +179,7 @@ func TestDistrValidatorGRPCQueries(t *testing.T) {

// test validator slashes grpc endpoint
slashURL := baseurl + `/cosmos/distribution/v1beta1/validators/%s/slashes`
invalidHeightOutput := `{"code":"NUMBER", "details":[]interface {}{}, "message":"strconv.ParseUint: parsing \"NUMBER\": invalid syntax"}`
invalidHeightOutput := `{"code":"NUMBER", "details":[], "message":"strconv.ParseUint: parsing \"NUMBER\": invalid syntax"}`

slashTestCases := []systest.RestTestCase{
{
Expand Down

0 comments on commit 265cb94

Please sign in to comment.