From cb23d418945f274f605ea4eae7c0394b52546e45 Mon Sep 17 00:00:00 2001 From: Travis Cline Date: Thu, 26 May 2016 17:39:14 -0700 Subject: [PATCH] Add bidirectional streaming support by interleaving Send() and Recv() calls. --- .../examplepb/a_bit_of_everything.pb.gw.go | 61 ++++++++++++------ examples/examplepb/flow_combination.pb.gw.go | 61 ++++++++++++------ .../gengateway/template.go | 64 ++++++++++++++++++- 3 files changed, 143 insertions(+), 43 deletions(-) diff --git a/examples/examplepb/a_bit_of_everything.pb.gw.go b/examples/examplepb/a_bit_of_everything.pb.gw.go index 022e65f8be6..1451c0852ae 100644 --- a/examples/examplepb/a_bit_of_everything.pb.gw.go +++ b/examples/examplepb/a_bit_of_everything.pb.gw.go @@ -439,35 +439,54 @@ func request_ABitOfEverythingService_BulkEcho_0(ctx context.Context, marshaler r return nil, metadata, err } dec := marshaler.NewDecoder(req.Body) - for { - var protoReq sub.StringMessage - err = dec.Decode(&protoReq) - if err == io.EOF { - break - } - if err != nil { - grpclog.Printf("Failed to decode request: %v", err) - return nil, metadata, grpc.Errorf(codes.InvalidArgument, "%v", err) + sendErrs := make(chan error, 1) + go func(errs chan<- error) { + for { + var protoReq sub.StringMessage + err = dec.Decode(&protoReq) + if err == nil { + select { + case errs <- err: + default: + } + } + if err == io.EOF { + select { + case errs <- err: + default: + } + return + } + if err != nil { + grpclog.Printf("Failed to decode request: %v", err) + select { + case errs <- grpc.Errorf(codes.InvalidArgument, "%v", err): + default: + } + } + if err = stream.Send(&protoReq); err != nil { + grpclog.Printf("Failed to send request: %v", err) + select { + case errs <- err: + default: + } + } } - if err = stream.Send(&protoReq); err != nil { - grpclog.Printf("Failed to send request: %v", err) - return nil, metadata, err + if err := stream.CloseSend(); err != nil { + grpclog.Printf("Failed to terminate client stream: %v", err) + select { + case errs <- err: + default: + } } - } - - if err := stream.CloseSend(); err != nil { - grpclog.Printf("Failed to terminate client stream: %v", err) - return nil, metadata, err - } + }(sendErrs) header, err := stream.Header() if err != nil { grpclog.Printf("Failed to get header from client: %v", err) return nil, metadata, err } metadata.HeaderMD = header - - return stream, metadata, nil - + return stream, metadata, <-sendErrs } func request_ABitOfEverythingService_DeepPathEcho_0(ctx context.Context, marshaler runtime.Marshaler, client ABitOfEverythingServiceClient, req *http.Request, pathParams map[string]string) (proto.Message, runtime.ServerMetadata, error) { diff --git a/examples/examplepb/flow_combination.pb.gw.go b/examples/examplepb/flow_combination.pb.gw.go index 8b575104f05..ed79b37d4df 100644 --- a/examples/examplepb/flow_combination.pb.gw.go +++ b/examples/examplepb/flow_combination.pb.gw.go @@ -102,35 +102,54 @@ func request_FlowCombination_StreamEmptyStream_0(ctx context.Context, marshaler return nil, metadata, err } dec := marshaler.NewDecoder(req.Body) - for { - var protoReq EmptyProto - err = dec.Decode(&protoReq) - if err == io.EOF { - break - } - if err != nil { - grpclog.Printf("Failed to decode request: %v", err) - return nil, metadata, grpc.Errorf(codes.InvalidArgument, "%v", err) + sendErrs := make(chan error, 1) + go func(errs chan<- error) { + for { + var protoReq EmptyProto + err = dec.Decode(&protoReq) + if err == nil { + select { + case errs <- err: + default: + } + } + if err == io.EOF { + select { + case errs <- err: + default: + } + return + } + if err != nil { + grpclog.Printf("Failed to decode request: %v", err) + select { + case errs <- grpc.Errorf(codes.InvalidArgument, "%v", err): + default: + } + } + if err = stream.Send(&protoReq); err != nil { + grpclog.Printf("Failed to send request: %v", err) + select { + case errs <- err: + default: + } + } } - if err = stream.Send(&protoReq); err != nil { - grpclog.Printf("Failed to send request: %v", err) - return nil, metadata, err + if err := stream.CloseSend(); err != nil { + grpclog.Printf("Failed to terminate client stream: %v", err) + select { + case errs <- err: + default: + } } - } - - if err := stream.CloseSend(); err != nil { - grpclog.Printf("Failed to terminate client stream: %v", err) - return nil, metadata, err - } + }(sendErrs) header, err := stream.Header() if err != nil { grpclog.Printf("Failed to get header from client: %v", err) return nil, metadata, err } metadata.HeaderMD = header - - return stream, metadata, nil - + return stream, metadata, <-sendErrs } func request_FlowCombination_RpcBodyRpc_0(ctx context.Context, marshaler runtime.Marshaler, client FlowCombinationClient, req *http.Request, pathParams map[string]string) (proto.Message, runtime.ServerMetadata, error) { diff --git a/protoc-gen-grpc-gateway/gengateway/template.go b/protoc-gen-grpc-gateway/gengateway/template.go index 416c8c4e6cc..64ea0fea18c 100644 --- a/protoc-gen-grpc-gateway/gengateway/template.go +++ b/protoc-gen-grpc-gateway/gengateway/template.go @@ -117,7 +117,9 @@ var _ = utilities.NewDoubleArray `)) handlerTemplate = template.Must(template.New("handler").Parse(` -{{if .Method.GetClientStreaming}} +{{if and .Method.GetClientStreaming .Method.GetServerStreaming}} +{{template "bidi-streaming-request-func" .}} +{{else if .Method.GetClientStreaming}} {{template "client-streaming-request-func" .}} {{else}} {{template "client-rpc-request-func" .}} @@ -234,6 +236,66 @@ var ( {{end}} }`)) + _ = template.Must(handlerTemplate.New("bidi-streaming-request-func").Parse(` +{{template "request-func-signature" .}} { + var metadata runtime.ServerMetadata + stream, err := client.{{.Method.GetName}}(ctx) + if err != nil { + grpclog.Printf("Failed to start streaming: %v", err) + return nil, metadata, err + } + dec := marshaler.NewDecoder(req.Body) + sendErrs := make(chan error, 1) + go func(errs chan<- error) { + for { + var protoReq {{.Method.RequestType.GoType .Method.Service.File.GoPkg.Path}} + err = dec.Decode(&protoReq) + if err == nil { + select { + case errs <- err: + default: + } + } + if err == io.EOF { + select { + case errs <- err: + default: + } + return + } + if err != nil { + grpclog.Printf("Failed to decode request: %v", err) + select { + case errs <- grpc.Errorf(codes.InvalidArgument, "%v", err): + default: + } + } + if err = stream.Send(&protoReq); err != nil { + grpclog.Printf("Failed to send request: %v", err) + select { + case errs <- err: + default: + } + } + } + if err := stream.CloseSend(); err != nil { + grpclog.Printf("Failed to terminate client stream: %v", err) + select { + case errs <- err: + default: + } + } + }(sendErrs) + header, err := stream.Header() + if err != nil { + grpclog.Printf("Failed to get header from client: %v", err) + return nil, metadata, err + } + metadata.HeaderMD = header + return stream, metadata, <-sendErrs +} +`)) + trailerTemplate = template.Must(template.New("trailer").Parse(` {{range $svc := .}} // Register{{$svc.GetName}}HandlerFromEndpoint is same as Register{{$svc.GetName}}Handler but