diff --git a/go/grpcweb/options.go b/go/grpcweb/options.go index 961a76ce..085c5c91 100644 --- a/go/grpcweb/options.go +++ b/go/grpcweb/options.go @@ -24,6 +24,7 @@ type options struct { enableWebsockets bool websocketPingInterval time.Duration websocketOriginFunc func(req *http.Request) bool + websocketReadLimit int64 allowNonRootResources bool endpointsFunc *func() []string } @@ -132,6 +133,15 @@ func WithWebsocketOriginFunc(websocketOriginFunc func(req *http.Request) bool) O } } +// WithWebsocketsMessageReadLimit sets the maximum message read limit on the underlying websocket. +// +// The default message read limit is 32769 bytes +func WithWebsocketsMessageReadLimit(websocketReadLimit int64) Option { + return func(o *options) { + o.websocketReadLimit = websocketReadLimit + } +} + // WithAllowNonRootResource enables the gRPC wrapper to serve requests that have a path prefix // added to the URL, before the service name and method placeholders. // diff --git a/go/grpcweb/wrapper.go b/go/grpcweb/wrapper.go index eb5cd2dc..6e163e2d 100644 --- a/go/grpcweb/wrapper.go +++ b/go/grpcweb/wrapper.go @@ -35,6 +35,7 @@ type WrappedGrpcServer struct { originFunc func(origin string) bool enableWebsockets bool websocketOriginFunc func(req *http.Request) bool + websocketReadLimit int64 allowedHeaders []string endpointFunc func(req *http.Request) string endpointsFunc func() []string @@ -97,6 +98,7 @@ func wrapGrpc(options []Option, handler http.Handler, endpointsFunc func() []str originFunc: opts.originFunc, enableWebsockets: opts.enableWebsockets, websocketOriginFunc: websocketOriginFunc, + websocketReadLimit: opts.websocketReadLimit, allowedHeaders: allowedHeaders, endpointFunc: endpointFunc, endpointsFunc: endpointsFunc, @@ -160,6 +162,11 @@ func (w *WrappedGrpcServer) HandleGrpcWebsocketRequest(resp http.ResponseWriter, grpclog.Errorf("Unable to upgrade websocket request: %v", err) return } + + if w.websocketReadLimit > 0 { + wsConn.SetReadLimit(w.websocketReadLimit) + } + headers := make(http.Header) for _, name := range w.allowedHeaders { if values, exist := req.Header[name]; exist { diff --git a/go/grpcwebproxy/main.go b/go/grpcwebproxy/main.go index 37b1af27..d9ad9f10 100644 --- a/go/grpcwebproxy/main.go +++ b/go/grpcwebproxy/main.go @@ -41,6 +41,7 @@ var ( useWebsockets = pflag.Bool("use_websockets", false, "whether to use beta websocket transport layer") websocketPingInterval = pflag.Duration("websocket_ping_interval", 0, "whether to use websocket keepalive pinging. Only used when using websockets. Configured interval must be >= 1s.") + websocketReadLimit = pflag.Int64("websocket_read_limit", 0, "sets the maximum message read limit on the underlying websocket. The default message read limit is 32769 bytes.") flagHttpMaxWriteTimeout = pflag.Duration("server_http_max_write_timeout", 10*time.Second, "HTTP server config, max write duration.") flagHttpMaxReadTimeout = pflag.Duration("server_http_max_read_timeout", 10*time.Second, "HTTP server config, max read duration.") @@ -84,6 +85,10 @@ func main() { if *websocketPingInterval >= time.Second { logrus.Infof("websocket keepalive pinging enabled, the timeout interval is %s", websocketPingInterval.String()) } + if *websocketReadLimit > 0 { + options = append(options, grpcweb.WithWebsocketsMessageReadLimit(*websocketReadLimit)) + } + options = append( options, grpcweb.WithWebsocketPingInterval(*websocketPingInterval),