From 57f77b90e2089e8aa5c8a8179cfdd499ecb3004c Mon Sep 17 00:00:00 2001 From: Yuki Yugui Sonoda Date: Wed, 8 Jul 2015 13:39:03 +0900 Subject: [PATCH] Format error message in JSON Make the error format customizable. Fixes #25. --- examples/a_bit_of_everything.pb.gw.go | 40 +++++++------- examples/echo_service.pb.gw.go | 8 +-- .../gengateway/template.go | 4 +- runtime/errors.go | 39 ++++++++++++- runtime/errors_test.go | 55 +++++++++++++++++++ runtime/handler.go | 5 +- 6 files changed, 120 insertions(+), 31 deletions(-) create mode 100644 runtime/errors_test.go diff --git a/examples/a_bit_of_everything.pb.gw.go b/examples/a_bit_of_everything.pb.gw.go index 901f1e4feba..3c009a51093 100644 --- a/examples/a_bit_of_everything.pb.gw.go +++ b/examples/a_bit_of_everything.pb.gw.go @@ -380,51 +380,51 @@ func RegisterABitOfEverythingServiceHandler(ctx context.Context, mux *runtime.Se mux.Handle("POST", pattern_ABitOfEverythingService_Create_0, func(w http.ResponseWriter, req *http.Request, pathParams map[string]string) { resp, err := request_ABitOfEverythingService_Create_0(runtime.AnnotateContext(ctx, req), client, req, pathParams) if err != nil { - runtime.HTTPError(w, err) + runtime.HTTPError(ctx, w, err) return } - runtime.ForwardResponseMessage(w, resp) + runtime.ForwardResponseMessage(ctx, w, resp) }) mux.Handle("POST", pattern_ABitOfEverythingService_CreateBody_0, func(w http.ResponseWriter, req *http.Request, pathParams map[string]string) { resp, err := request_ABitOfEverythingService_CreateBody_0(runtime.AnnotateContext(ctx, req), client, req, pathParams) if err != nil { - runtime.HTTPError(w, err) + runtime.HTTPError(ctx, w, err) return } - runtime.ForwardResponseMessage(w, resp) + runtime.ForwardResponseMessage(ctx, w, resp) }) mux.Handle("POST", pattern_ABitOfEverythingService_BulkCreate_0, func(w http.ResponseWriter, req *http.Request, pathParams map[string]string) { resp, err := request_ABitOfEverythingService_BulkCreate_0(runtime.AnnotateContext(ctx, req), client, req, pathParams) if err != nil { - runtime.HTTPError(w, err) + runtime.HTTPError(ctx, w, err) return } - runtime.ForwardResponseMessage(w, resp) + runtime.ForwardResponseMessage(ctx, w, resp) }) mux.Handle("GET", pattern_ABitOfEverythingService_Lookup_0, func(w http.ResponseWriter, req *http.Request, pathParams map[string]string) { resp, err := request_ABitOfEverythingService_Lookup_0(runtime.AnnotateContext(ctx, req), client, req, pathParams) if err != nil { - runtime.HTTPError(w, err) + runtime.HTTPError(ctx, w, err) return } - runtime.ForwardResponseMessage(w, resp) + runtime.ForwardResponseMessage(ctx, w, resp) }) mux.Handle("GET", pattern_ABitOfEverythingService_List_0, func(w http.ResponseWriter, req *http.Request, pathParams map[string]string) { resp, err := request_ABitOfEverythingService_List_0(runtime.AnnotateContext(ctx, req), client, req, pathParams) if err != nil { - runtime.HTTPError(w, err) + runtime.HTTPError(ctx, w, err) return } @@ -435,62 +435,62 @@ func RegisterABitOfEverythingServiceHandler(ctx context.Context, mux *runtime.Se mux.Handle("PUT", pattern_ABitOfEverythingService_Update_0, func(w http.ResponseWriter, req *http.Request, pathParams map[string]string) { resp, err := request_ABitOfEverythingService_Update_0(runtime.AnnotateContext(ctx, req), client, req, pathParams) if err != nil { - runtime.HTTPError(w, err) + runtime.HTTPError(ctx, w, err) return } - runtime.ForwardResponseMessage(w, resp) + runtime.ForwardResponseMessage(ctx, w, resp) }) mux.Handle("DELETE", pattern_ABitOfEverythingService_Delete_0, func(w http.ResponseWriter, req *http.Request, pathParams map[string]string) { resp, err := request_ABitOfEverythingService_Delete_0(runtime.AnnotateContext(ctx, req), client, req, pathParams) if err != nil { - runtime.HTTPError(w, err) + runtime.HTTPError(ctx, w, err) return } - runtime.ForwardResponseMessage(w, resp) + runtime.ForwardResponseMessage(ctx, w, resp) }) mux.Handle("GET", pattern_ABitOfEverythingService_Echo_0, func(w http.ResponseWriter, req *http.Request, pathParams map[string]string) { resp, err := request_ABitOfEverythingService_Echo_0(runtime.AnnotateContext(ctx, req), client, req, pathParams) if err != nil { - runtime.HTTPError(w, err) + runtime.HTTPError(ctx, w, err) return } - runtime.ForwardResponseMessage(w, resp) + runtime.ForwardResponseMessage(ctx, w, resp) }) mux.Handle("POST", pattern_ABitOfEverythingService_Echo_1, func(w http.ResponseWriter, req *http.Request, pathParams map[string]string) { resp, err := request_ABitOfEverythingService_Echo_1(runtime.AnnotateContext(ctx, req), client, req, pathParams) if err != nil { - runtime.HTTPError(w, err) + runtime.HTTPError(ctx, w, err) return } - runtime.ForwardResponseMessage(w, resp) + runtime.ForwardResponseMessage(ctx, w, resp) }) mux.Handle("GET", pattern_ABitOfEverythingService_Echo_2, func(w http.ResponseWriter, req *http.Request, pathParams map[string]string) { resp, err := request_ABitOfEverythingService_Echo_2(runtime.AnnotateContext(ctx, req), client, req, pathParams) if err != nil { - runtime.HTTPError(w, err) + runtime.HTTPError(ctx, w, err) return } - runtime.ForwardResponseMessage(w, resp) + runtime.ForwardResponseMessage(ctx, w, resp) }) mux.Handle("POST", pattern_ABitOfEverythingService_BulkEcho_0, func(w http.ResponseWriter, req *http.Request, pathParams map[string]string) { resp, err := request_ABitOfEverythingService_BulkEcho_0(runtime.AnnotateContext(ctx, req), client, req, pathParams) if err != nil { - runtime.HTTPError(w, err) + runtime.HTTPError(ctx, w, err) return } diff --git a/examples/echo_service.pb.gw.go b/examples/echo_service.pb.gw.go index dfa909fb180..23835caae61 100644 --- a/examples/echo_service.pb.gw.go +++ b/examples/echo_service.pb.gw.go @@ -90,22 +90,22 @@ func RegisterEchoServiceHandler(ctx context.Context, mux *runtime.ServeMux, conn mux.Handle("POST", pattern_EchoService_Echo_0, func(w http.ResponseWriter, req *http.Request, pathParams map[string]string) { resp, err := request_EchoService_Echo_0(runtime.AnnotateContext(ctx, req), client, req, pathParams) if err != nil { - runtime.HTTPError(w, err) + runtime.HTTPError(ctx, w, err) return } - runtime.ForwardResponseMessage(w, resp) + runtime.ForwardResponseMessage(ctx, w, resp) }) mux.Handle("POST", pattern_EchoService_EchoBody_0, func(w http.ResponseWriter, req *http.Request, pathParams map[string]string) { resp, err := request_EchoService_EchoBody_0(runtime.AnnotateContext(ctx, req), client, req, pathParams) if err != nil { - runtime.HTTPError(w, err) + runtime.HTTPError(ctx, w, err) return } - runtime.ForwardResponseMessage(w, resp) + runtime.ForwardResponseMessage(ctx, w, resp) }) diff --git a/protoc-gen-grpc-gateway/gengateway/template.go b/protoc-gen-grpc-gateway/gengateway/template.go index 79259e5ed7c..2820648da93 100644 --- a/protoc-gen-grpc-gateway/gengateway/template.go +++ b/protoc-gen-grpc-gateway/gengateway/template.go @@ -224,13 +224,13 @@ func Register{{$svc.GetName}}Handler(ctx context.Context, mux *runtime.ServeMux, mux.Handle({{$b.HTTPMethod | printf "%q"}}, pattern_{{$svc.GetName}}_{{$m.GetName}}_{{$b.Index}}, func(w http.ResponseWriter, req *http.Request, pathParams map[string]string) { resp, err := request_{{$svc.GetName}}_{{$m.GetName}}_{{$b.Index}}(runtime.AnnotateContext(ctx, req), client, req, pathParams) if err != nil { - runtime.HTTPError(w, err) + runtime.HTTPError(ctx, w, err) return } {{if $m.GetServerStreaming}} runtime.ForwardResponseStream(w, func() (proto.Message, error) { return resp.Recv() }) {{else}} - runtime.ForwardResponseMessage(w, resp) + runtime.ForwardResponseMessage(ctx, w, resp) {{end}} }) {{end}} diff --git a/runtime/errors.go b/runtime/errors.go index ebb0c877057..dc38373c8b1 100644 --- a/runtime/errors.go +++ b/runtime/errors.go @@ -1,9 +1,12 @@ package runtime import ( + "encoding/json" + "io" "net/http" "github.com/golang/glog" + "golang.org/x/net/context" "google.golang.org/grpc" "google.golang.org/grpc/codes" ) @@ -51,10 +54,40 @@ func HTTPStatusFromCode(code codes.Code) int { return http.StatusInternalServerError } -// HTTPError replies to the request with the error. +var ( + // HTTPError replies to the request with the error. + // You can set a custom function to this variable to customize error format. + HTTPError = DefaultHTTPError +) + +type errorBody struct { + Error string `json:"error"` +} + +// DefaultHTTPError is the default implementation of HTTPError. // If "err" is an error from gRPC system, the function replies with the status code mapped by HTTPStatusFromCode. // If otherwise, it replies with http.StatusInternalServerError. -func HTTPError(w http.ResponseWriter, err error) { +// +// The response body returned by this function is a JSON object, +// which contains a member whose key is "error" and whose value is err.Error(). +func DefaultHTTPError(ctx context.Context, w http.ResponseWriter, err error) { + const fallback = `{"error": "failed to marshal error message"}` + + w.Header().Set("Content-Type", "application/json") + body := errorBody{Error: err.Error()} + buf, merr := json.Marshal(body) + if merr != nil { + glog.Errorf("Failed to marshal error message %q: %v", body, merr) + w.WriteHeader(http.StatusInternalServerError) + if _, err := io.WriteString(w, fallback); err != nil { + glog.Errorf("Failed to write response: %v", err) + } + return + } + st := HTTPStatusFromCode(grpc.Code(err)) - http.Error(w, err.Error(), st) + w.WriteHeader(st) + if _, err := w.Write(buf); err != nil { + glog.Errorf("Failed to write response: %v", err) + } } diff --git a/runtime/errors_test.go b/runtime/errors_test.go new file mode 100644 index 00000000000..5f6edaa57dd --- /dev/null +++ b/runtime/errors_test.go @@ -0,0 +1,55 @@ +package runtime_test + +import ( + "encoding/json" + "fmt" + "net/http" + "net/http/httptest" + "strings" + "testing" + + "github.com/gengo/grpc-gateway/runtime" + "golang.org/x/net/context" + "google.golang.org/grpc" + "google.golang.org/grpc/codes" +) + +func TestDefaultHTTPError(t *testing.T) { + ctx := context.Background() + + for _, spec := range []struct { + err error + status int + msg string + }{ + { + err: fmt.Errorf("example error"), + status: http.StatusInternalServerError, + msg: "example error", + }, + { + err: grpc.Errorf(codes.NotFound, "no such resource"), + status: http.StatusNotFound, + msg: "no such resource", + }, + } { + w := httptest.NewRecorder() + runtime.DefaultHTTPError(ctx, w, spec.err) + + if got, want := w.Header().Get("Content-Type"), "application/json"; got != want { + t.Errorf(`w.Header().Get("Content-Type") = %q; want %q; on spec.err=%v`, got, want, spec.err) + } + if got, want := w.Code, spec.status; got != want { + t.Errorf("w.Code = %d; want %d", got, want) + } + + body := make(map[string]string) + if err := json.Unmarshal(w.Body.Bytes(), &body); err != nil { + t.Errorf("json.Unmarshal(%q, &body) failed with %v; want success", w.Body.Bytes(), err) + continue + } + if got, want := body["error"], spec.msg; !strings.Contains(got, want) { + t.Errorf(`body["error"] = %q; want %q; on spec.err=%v`, got, want, spec.err) + } + } +} diff --git a/runtime/handler.go b/runtime/handler.go index 69884d608a3..eba21bd6032 100644 --- a/runtime/handler.go +++ b/runtime/handler.go @@ -8,6 +8,7 @@ import ( "github.com/golang/glog" "github.com/golang/protobuf/proto" + "golang.org/x/net/context" ) type responseStreamChunk struct { @@ -58,11 +59,11 @@ func ForwardResponseStream(w http.ResponseWriter, recv func() (proto.Message, er } // ForwardResponseMessage forwards the message from gRPC server to REST client. -func ForwardResponseMessage(w http.ResponseWriter, resp proto.Message) { +func ForwardResponseMessage(ctx context.Context, w http.ResponseWriter, resp proto.Message) { buf, err := json.Marshal(resp) if err != nil { glog.Errorf("Marshal error: %v", err) - HTTPError(w, err) + HTTPError(ctx, w, err) return }