Skip to content

Commit

Permalink
get rid of hybrid handlers
Browse files Browse the repository at this point in the history
  • Loading branch information
testinginprod committed Dec 11, 2024
1 parent cef0f84 commit 6dfc518
Show file tree
Hide file tree
Showing 7 changed files with 54 additions and 174 deletions.
23 changes: 12 additions & 11 deletions baseapp/grpcrouter.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ import (
)

type QueryRouter interface {
HybridHandlerByRequestName(name string) []func(ctx context.Context, req, resp protoiface.MessageV1) error
HandlerByRequestName(name string) []func(ctx context.Context, req, resp protoiface.MessageV1) error
RegisterService(sd *grpc.ServiceDesc, handler interface{})
ResponseNameByRequestName(requestName string) string
Route(path string) GRPCQueryHandler
Expand All @@ -29,9 +29,10 @@ type QueryRouter interface {
type GRPCQueryRouter struct {
// routes maps query handlers used in ABCIQuery.
routes map[string]GRPCQueryHandler
// hybridHandlers maps the request name to the handler. It is a hybrid handler which seamlessly
// handles both gogo and protov2 messages.
hybridHandlers map[string][]func(ctx context.Context, req, resp protoiface.MessageV1) error
// handlers maps the request name to the handler. This is the same as routes
// but it handles everything in a non ABCI oriented fashion but more generic
// to the consensus layer.
handlers map[string][]func(ctx context.Context, req, resp protoiface.MessageV1) error
// responseByRequestName maps the request name to the response name.
responseByRequestName map[string]string
// binaryCodec is used to encode/decode binary protobuf messages.
Expand All @@ -57,7 +58,7 @@ var (
func NewGRPCQueryRouter() *GRPCQueryRouter {
return &GRPCQueryRouter{
routes: map[string]GRPCQueryHandler{},
hybridHandlers: map[string][]func(ctx context.Context, req, resp protoiface.MessageV1) error{},
handlers: map[string][]func(ctx context.Context, req, resp protoiface.MessageV1) error{},
responseByRequestName: map[string]string{},
}
}
Expand Down Expand Up @@ -88,7 +89,7 @@ func (qrt *GRPCQueryRouter) RegisterService(sd *grpc.ServiceDesc, handler interf
if err != nil {
panic(err)
}
err = qrt.registerHybridHandler(sd, method, handler)
err = qrt.registerHandler(sd, method, handler)
if err != nil {
panic(err)
}
Expand Down Expand Up @@ -143,15 +144,15 @@ func (qrt *GRPCQueryRouter) registerABCIQueryHandler(sd *grpc.ServiceDesc, metho
return nil
}

func (qrt *GRPCQueryRouter) HybridHandlerByRequestName(name string) []func(ctx context.Context, req, resp protoiface.MessageV1) error {
return qrt.hybridHandlers[name]
func (qrt *GRPCQueryRouter) HandlerByRequestName(name string) []func(ctx context.Context, req, resp protoiface.MessageV1) error {
return qrt.handlers[name]
}

func (qrt *GRPCQueryRouter) ResponseNameByRequestName(requestName string) string {
return qrt.responseByRequestName[requestName]
}

func (qrt *GRPCQueryRouter) registerHybridHandler(sd *grpc.ServiceDesc, method grpc.MethodDesc, handler interface{}) error {
func (qrt *GRPCQueryRouter) registerHandler(sd *grpc.ServiceDesc, method grpc.MethodDesc, handler interface{}) error {
// extract message name from method descriptor
inputName, err := protocompat.RequestFullNameFromMethodDesc(sd, method)
if err != nil {
Expand All @@ -161,13 +162,13 @@ func (qrt *GRPCQueryRouter) registerHybridHandler(sd *grpc.ServiceDesc, method g
if err != nil {
return err
}
methodHandler, err := protocompat.MakeHybridHandler(qrt.binaryCodec, sd, method, handler)
methodHandler, err := protocompat.MakeHandler(sd, method, handler)
if err != nil {
return err
}
// map input name to output name
qrt.responseByRequestName[string(inputName)] = string(outputName)
qrt.hybridHandlers[string(inputName)] = append(qrt.hybridHandlers[string(inputName)], methodHandler)
qrt.handlers[string(inputName)] = append(qrt.handlers[string(inputName)], methodHandler)
return nil
}

Expand Down
2 changes: 1 addition & 1 deletion baseapp/grpcrouter_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ func TestGRPCQueryRouter(t *testing.T) {
func TestGRPCRouterHybridHandlers(t *testing.T) {
assertRouterBehaviour := func(helper *baseapp.QueryServiceTestHelper) {
// test getting the handler by name
handlers := helper.GRPCQueryRouter.HybridHandlerByRequestName("testpb.EchoRequest")
handlers := helper.GRPCQueryRouter.HandlerByRequestName("testpb.EchoRequest")
require.NotNil(t, handlers)
require.Len(t, handlers, 1)
handler := handlers[0]
Expand Down
171 changes: 26 additions & 145 deletions baseapp/internal/protocompat/protocompat.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,27 +6,22 @@ import (
"reflect"

gogoproto "github.com/cosmos/gogoproto/proto"
"github.com/golang/protobuf/proto" //nolint: staticcheck // needed because gogoproto.Merge does not work consistently. See NOTE: comments.
"google.golang.org/grpc"
proto2 "google.golang.org/protobuf/proto"
"google.golang.org/protobuf/reflect/protoreflect"
"google.golang.org/protobuf/reflect/protoregistry"
"google.golang.org/protobuf/runtime/protoiface"

"github.com/cosmos/cosmos-sdk/codec"
)

var (
gogoType = reflect.TypeOf((*gogoproto.Message)(nil)).Elem()
protov2Type = reflect.TypeOf((*proto2.Message)(nil)).Elem()
protov2MarshalOpts = proto2.MarshalOptions{Deterministic: true}
gogoType = reflect.TypeOf((*gogoproto.Message)(nil)).Elem()
protov2Type = reflect.TypeOf((*proto2.Message)(nil)).Elem()
)

type Handler = func(ctx context.Context, request, response protoiface.MessageV1) error

// MakeHybridHandler returns a handler that can handle both gogo and protov2 messages, no matter
// MakeHandler returns a handler that can handle both gogo and protov2 messages, no matter
// if the handler is a gogo or protov2 handler.
func MakeHybridHandler(cdc codec.BinaryCodec, sd *grpc.ServiceDesc, method grpc.MethodDesc, handler interface{}) (Handler, error) {
func MakeHandler(sd *grpc.ServiceDesc, method grpc.MethodDesc, handler interface{}) (Handler, error) {
methodFullName := protoreflect.FullName(fmt.Sprintf("%s.%s", sd.ServiceName, method.MethodName))
desc, err := gogoproto.HybridResolver.FindDescriptorByName(methodFullName)
if err != nil {
Expand All @@ -42,150 +37,36 @@ func MakeHybridHandler(cdc codec.BinaryCodec, sd *grpc.ServiceDesc, method grpc.
return nil, err
}
if isProtov2Handler {
return makeProtoV2HybridHandler(methodDesc, cdc, method, handler)
return nil, fmt.Errorf("protov2 handlers are not allowed %s", methodFullName)
}
return makeGogoHybridHandler(methodDesc, cdc, method, handler)
return makeGogoHandler(methodDesc, method, handler)
}

// makeProtoV2HybridHandler returns a handler that can handle both gogo and protov2 messages.
func makeProtoV2HybridHandler(prefMethod protoreflect.MethodDescriptor, cdc codec.BinaryCodec, method grpc.MethodDesc, handler any) (Handler, error) {
// it's a protov2 handler, if a gogo counterparty is not found we cannot handle gogo messages.
gogoExists := gogoproto.MessageType(string(prefMethod.Output().FullName())) != nil
if !gogoExists {
return func(ctx context.Context, inReq, outResp protoiface.MessageV1) error {
protov2Request, ok := inReq.(proto2.Message)
if !ok {
return fmt.Errorf("invalid request type %T, method %s does not accept gogoproto messages", inReq, prefMethod.FullName())
}
resp, err := method.Handler(handler, ctx, func(msg any) error {
proto2.Merge(msg.(proto2.Message), protov2Request)
return nil
}, nil)
if err != nil {
return err
}
// merge on the resp
proto2.Merge(outResp.(proto2.Message), resp.(proto2.Message))
return nil
}, nil
}
func makeGogoHandler(prefMethod protoreflect.MethodDescriptor, method grpc.MethodDesc, handler any) (Handler, error) {
return func(ctx context.Context, inReq, outResp protoiface.MessageV1) error {
// we check if the request is a protov2 message.
switch m := inReq.(type) {
case proto2.Message:
// we can just call the handler after making a copy of the message, for safety reasons.
resp, err := method.Handler(handler, ctx, func(msg any) error {
proto2.Merge(msg.(proto2.Message), m)
return nil
}, nil)
if err != nil {
return err
}
// merge on the resp
proto2.Merge(outResp.(proto2.Message), resp.(proto2.Message))
return nil
case gogoproto.Message:
// we need to marshal and unmarshal the request.
requestBytes, err := cdc.Marshal(m)
if err != nil {
return err
}
resp, err := method.Handler(handler, ctx, func(msg any) error {
// unmarshal request into the message.
return proto2.Unmarshal(requestBytes, msg.(proto2.Message))
}, nil)
if err != nil {
return err
}
// the response is a protov2 message, so we cannot just return it.
// since the request came as gogoproto, we expect the response
// to also be gogoproto.
respBytes, err := protov2MarshalOpts.Marshal(resp.(proto2.Message))
if err != nil {
return err
}

// unmarshal response into a gogo message.
return cdc.Unmarshal(respBytes, outResp.(gogoproto.Message))
default:
panic("unreachable")
// we do not handle protov2
_, ok := inReq.(proto2.Message)
if ok {
return fmt.Errorf("invalid request type %T, method %s does not accept protov2 messages", inReq, prefMethod.FullName())
}
}, nil
}

func makeGogoHybridHandler(prefMethod protoreflect.MethodDescriptor, cdc codec.BinaryCodec, method grpc.MethodDesc, handler any) (Handler, error) {
// it's a gogo handler, we check if the existing protov2 counterparty exists.
_, err := protoregistry.GlobalTypes.FindMessageByName(prefMethod.Output().FullName())
if err != nil {
// this can only be a gogo message.
return func(ctx context.Context, inReq, outResp protoiface.MessageV1) error {
_, ok := inReq.(proto2.Message)
if ok {
return fmt.Errorf("invalid request type %T, method %s does not accept protov2 messages", inReq, prefMethod.FullName())
}
resp, err := method.Handler(handler, ctx, func(msg any) error {
// merge! ref: https://github.com/cosmos/cosmos-sdk/issues/18003
// NOTE: using gogoproto.Merge will fail for some reason unknown to me, but
// using proto.Merge with gogo messages seems to work fine.
proto.Merge(msg.(gogoproto.Message), inReq)
return nil
}, nil)
if err != nil {
return err
}
// merge resp, ref: https://github.com/cosmos/cosmos-sdk/issues/18003
// NOTE: using gogoproto.Merge will fail for some reason unknown to me, but
// using proto.Merge with gogo messages seems to work fine.
proto.Merge(outResp.(gogoproto.Message), resp.(gogoproto.Message))
return nil
}, nil
}
// this is a gogo handler, and we have a protov2 counterparty.
return func(ctx context.Context, inReq, outResp protoiface.MessageV1) error {
switch m := inReq.(type) {
case proto2.Message:
// we need to marshal and unmarshal the request.
requestBytes, err := protov2MarshalOpts.Marshal(m)
if err != nil {
return err
}
resp, err := method.Handler(handler, ctx, func(msg any) error {
// unmarshal request into the message.
return cdc.Unmarshal(requestBytes, msg.(gogoproto.Message))
}, nil)
if err != nil {
return err
}
// the response is a gogo message, so we cannot just return it.
// since the request came as protov2, we expect the response
// to also be protov2.
respBytes, err := cdc.Marshal(resp.(gogoproto.Message))
if err != nil {
return err
}
// now we unmarshal back into a protov2 message.
return proto2.Unmarshal(respBytes, outResp.(proto2.Message))
case gogoproto.Message:
// we can just call the handler after making a copy of the message, for safety reasons.
resp, err := method.Handler(handler, ctx, func(msg any) error {
// ref: https://github.com/cosmos/cosmos-sdk/issues/18003
asGogoProto := msg.(gogoproto.Message)
// NOTE: using gogoproto.Merge will fail for some reason unknown to me, but
// using proto.Merge with gogo messages seems to work fine.
proto.Merge(asGogoProto, m)
return nil
}, nil)
if err != nil {
return err
}
// merge on the resp, ref: https://github.com/cosmos/cosmos-sdk/issues/18003
// NOTE: using gogoproto.Merge will fail for some reason unknown to me, but
// using proto.Merge with gogo messages seems to work fine.
proto.Merge(outResp.(gogoproto.Message), resp.(gogoproto.Message))
resp, err := method.Handler(handler, ctx, func(msg any) error {
// reflection to copy from inReq to msg
dstVal := reflect.ValueOf(msg).Elem()
srcVal := reflect.ValueOf(inReq).Elem()
dstVal.Set(srcVal)
return nil
default:
panic("unreachable")
}, nil)
if err != nil {
return err
}

// reflection to copy from resp to outResp
dstVal := reflect.ValueOf(outResp).Elem()
srcVal := reflect.ValueOf(resp).Elem()
dstVal.Set(srcVal)

return nil
}, nil
}

Expand Down
18 changes: 8 additions & 10 deletions baseapp/msg_service_router.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@ import (
errorsmod "cosmossdk.io/errors"

"github.com/cosmos/cosmos-sdk/baseapp/internal/protocompat"
"github.com/cosmos/cosmos-sdk/codec"
codectypes "github.com/cosmos/cosmos-sdk/codec/types"
sdk "github.com/cosmos/cosmos-sdk/types"
sdkerrors "github.com/cosmos/cosmos-sdk/types/errors"
Expand All @@ -27,14 +26,14 @@ type MessageRouter interface {
HandlerByTypeURL(typeURL string) MsgServiceHandler

ResponseNameByMsgName(msgName string) string
HybridHandlerByMsgName(msgName string) func(ctx context.Context, req, resp protoiface.MessageV1) error
HandlerByMsgName(msgName string) func(ctx context.Context, req, resp protoiface.MessageV1) error
}

// MsgServiceRouter routes fully-qualified Msg service methods to their handler.
type MsgServiceRouter struct {
interfaceRegistry codectypes.InterfaceRegistry
routes map[string]MsgServiceHandler
hybridHandlers map[string]func(ctx context.Context, req, resp protoiface.MessageV1) error
handlers map[string]func(ctx context.Context, req, resp protoiface.MessageV1) error
responseByMsgName map[string]string
circuitBreaker CircuitBreaker
}
Expand All @@ -45,7 +44,7 @@ var _ gogogrpc.Server = &MsgServiceRouter{}
func NewMsgServiceRouter() *MsgServiceRouter {
return &MsgServiceRouter{
routes: map[string]MsgServiceHandler{},
hybridHandlers: map[string]func(ctx context.Context, req, resp protoiface.MessageV1) error{},
handlers: map[string]func(ctx context.Context, req, resp protoiface.MessageV1) error{},
responseByMsgName: map[string]string{},
}
}
Expand Down Expand Up @@ -89,8 +88,8 @@ func (msr *MsgServiceRouter) RegisterService(sd *grpc.ServiceDesc, handler inter
}
}

func (msr *MsgServiceRouter) HybridHandlerByMsgName(msgName string) func(ctx context.Context, req, resp protoiface.MessageV1) error {
return msr.hybridHandlers[msgName]
func (msr *MsgServiceRouter) HandlerByMsgName(msgName string) func(ctx context.Context, req, resp protoiface.MessageV1) error {
return msr.handlers[msgName]
}

func (msr *MsgServiceRouter) ResponseNameByMsgName(msgName string) string {
Expand All @@ -106,16 +105,15 @@ func (msr *MsgServiceRouter) registerHybridHandler(sd *grpc.ServiceDesc, method
if err != nil {
return err
}
cdc := codec.NewProtoCodec(msr.interfaceRegistry)
hybridHandler, err := protocompat.MakeHybridHandler(cdc, sd, method, handler)
hybridHandler, err := protocompat.MakeHandler(sd, method, handler)
if err != nil {
return err
}
// map input name to output name
msr.responseByMsgName[string(inputName)] = string(outputName)
// if circuit breaker is not nil, then we decorate the hybrid handler with the circuit breaker
if msr.circuitBreaker == nil {
msr.hybridHandlers[string(inputName)] = hybridHandler
msr.handlers[string(inputName)] = hybridHandler
return nil
}
// decorate the hybrid handler with the circuit breaker
Expand All @@ -130,7 +128,7 @@ func (msr *MsgServiceRouter) registerHybridHandler(sd *grpc.ServiceDesc, method
}
return hybridHandler(ctx, req, resp)
}
msr.hybridHandlers[string(inputName)] = circuitBreakerHybridHandler
msr.handlers[string(inputName)] = circuitBreakerHybridHandler
return nil
}

Expand Down
2 changes: 1 addition & 1 deletion baseapp/msg_service_router_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ func TestHybridHandlerByMsgName(t *testing.T) {
testdata.MsgServerImpl{},
)

handler := app.MsgServiceRouter().HybridHandlerByMsgName("testpb.MsgCreateDog")
handler := app.MsgServiceRouter().HandlerByMsgName("testpb.MsgCreateDog")

require.NotNil(t, handler)
require.NoError(t, app.Init())
Expand Down
4 changes: 2 additions & 2 deletions runtime/environment.go
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ func (failingMsgRouter) ResponseNameByMsgName(msgName string) string {
panic("message router not set")
}

func (failingMsgRouter) HybridHandlerByMsgName(msgName string) func(ctx context.Context, req, resp protoiface.MessageV1) error {
func (failingMsgRouter) HandlerByMsgName(msgName string) func(ctx context.Context, req, resp protoiface.MessageV1) error {
panic("message router not set")
}

Expand All @@ -91,7 +91,7 @@ type failingQueryRouter struct {
baseapp.QueryRouter
}

func (failingQueryRouter) HybridHandlerByRequestName(name string) []func(ctx context.Context, req, resp protoiface.MessageV1) error {
func (failingQueryRouter) HandlerByRequestName(name string) []func(ctx context.Context, req, resp protoiface.MessageV1) error {
panic("query router not set")
}

Expand Down
Loading

0 comments on commit 6dfc518

Please sign in to comment.