From 6fdf626ad816b8b824fe3ab8fff1f9212ebb3e20 Mon Sep 17 00:00:00 2001 From: Maximilian Moehl Date: Fri, 8 Nov 2024 15:58:43 +0100 Subject: [PATCH] feat: configurable limits for header counts This commit adds support to limit the amount of request and response headers gorouter will accept and process. Resolves: https://github.com/cloudfoundry/routing-release/issues/309 --- config/config.go | 2 + handlers/max_request_size.go | 19 ++-- handlers/max_request_size_test.go | 12 +++ proxy/round_tripper/proxy_round_tripper.go | 86 +++++++++++++++---- .../round_tripper/proxy_round_tripper_test.go | 43 ++++++++++ 5 files changed, 136 insertions(+), 26 deletions(-) diff --git a/config/config.go b/config/config.go index 062bc825..63198163 100644 --- a/config/config.go +++ b/config/config.go @@ -462,6 +462,8 @@ type Config struct { MaxIdleConnsPerHost int `yaml:"max_idle_conns_per_host,omitempty"` MaxRequestHeaderBytes int `yaml:"max_header_bytes"` MaxResponseHeaderBytes int `yaml:"max_response_header_bytes"` + MaxRequestHeaders int `yaml:"max_request_headers"` + MaxResponseHeaders int `yaml:"max_response_headers"` KeepAlive100ContinueRequests bool `yaml:"keep_alive_100_continue_requests"` HTTPRewrite HTTPRewrite `yaml:"http_rewrite,omitempty"` diff --git a/handlers/max_request_size.go b/handlers/max_request_size.go index a64680bf..d164b5e8 100644 --- a/handlers/max_request_size.go +++ b/handlers/max_request_size.go @@ -11,9 +11,10 @@ import ( ) type MaxRequestSize struct { - cfg *config.Config - MaxSize int - logger *slog.Logger + cfg *config.Config + MaxSize int + MaxCount int + logger *slog.Logger } const ONE_MB = 1024 * 1024 // bytes * kb @@ -33,9 +34,10 @@ func NewMaxRequestSize(cfg *config.Config, logger *slog.Logger) *MaxRequestSize } return &MaxRequestSize{ - MaxSize: maxSize, - logger: logger, - cfg: cfg, + MaxSize: maxSize, + MaxCount: cfg.MaxRequestHeaders, + logger: logger, + cfg: cfg, } } @@ -49,6 +51,8 @@ func (m *MaxRequestSize) ServeHTTP(rw http.ResponseWriter, r *http.Request, next // Host header which is not passed on to us, plus eight bytes for 'Host: ' and \r\n reqSize += len(r.Host) + 8 + hdrCount := 0 + // Go doesn't split header values on commas, instead it only splits the value when it's // provided via repeated header keys. Therefore we have to account for each value of a repeated // header as well as its key. @@ -56,10 +60,11 @@ func (m *MaxRequestSize) ServeHTTP(rw http.ResponseWriter, r *http.Request, next for _, v := range vv { // Four additional bytes for the colon and space after the header key and \r\n. reqSize += len(k) + len(v) + 4 + hdrCount++ } } - if reqSize >= m.MaxSize { + if reqSize >= m.MaxSize || (m.MaxCount > 0 && hdrCount > m.MaxCount) { reqInfo, err := ContextRequestInfo(r) if err != nil { logger.Error("request-info-err", log.ErrAttr(err)) diff --git a/handlers/max_request_size_test.go b/handlers/max_request_size_test.go index 534d63dd..6755bca1 100644 --- a/handlers/max_request_size_test.go +++ b/handlers/max_request_size_test.go @@ -64,6 +64,7 @@ var _ = Describe("MaxRequestSize", func() { BeforeEach(func() { cfg = &config.Config{ MaxRequestHeaderBytes: 89, + MaxRequestHeaders: 15, LoadBalance: config.LOAD_BALANCE_RR, StickySessionCookieNames: config.StringSet{"blarg": struct{}{}}, } @@ -177,6 +178,17 @@ var _ = Describe("MaxRequestSize", func() { Expect(result.StatusCode).To(Equal(http.StatusRequestHeaderFieldsTooLarge)) }) }) + Context("when there are too many headers", func() { + BeforeEach(func() { + for i := 0; i < 16; i++ { + header.Add("f", "m") + } + }) + It("throws an http 431", func() { + handleRequest() + Expect(result.StatusCode).To(Equal(http.StatusRequestHeaderFieldsTooLarge)) + }) + }) Context("when enough normally-sized headers put the request over the limit", func() { BeforeEach(func() { header.Add("header1", "smallRequest") diff --git a/proxy/round_tripper/proxy_round_tripper.go b/proxy/round_tripper/proxy_round_tripper.go index 75801a60..48b4544f 100644 --- a/proxy/round_tripper/proxy_round_tripper.go +++ b/proxy/round_tripper/proxy_round_tripper.go @@ -38,7 +38,10 @@ const ( AuthNegotiateHeaderCookieMaxAgeInSeconds = 60 ) -var NoEndpointsAvailable = errors.New("No endpoints available") +var ( + NoEndpointsAvailable = errors.New("No endpoints available") + TooManyResponseHeaders = errors.New("too many response headers") +) //go:generate counterfeiter -o fakes/fake_proxy_round_tripper.go . ProxyRoundTripper type ProxyRoundTripper interface { @@ -178,6 +181,18 @@ func (rt *roundTripper) RoundTrip(originalRequest *http.Request) (*http.Response } res, err = rt.backendRoundTrip(request, endpoint, iter, logger) + logger = logger.With( + slog.Int("attempt", attempt), + slog.String("vcap_request_id", request.Header.Get(handlers.VcapRequestIdHeader)), + slog.Int("num-endpoints", numberOfEndpoints), + slog.Bool("got-connection", trace.GotConn()), + slog.Bool("wrote-headers", trace.WroteHeaders()), + slog.Bool("conn-reused", trace.ConnReused()), + slog.Float64("dns-lookup-time", trace.DnsTime()), + slog.Float64("dial-time", trace.DialTime()), + slog.Float64("tls-handshake-time", trace.TlsTime()), + ) + if err != nil { reqInfo.FailedAttempts++ reqInfo.LastFailedAttemptFinishedAt = time.Now() @@ -185,16 +200,7 @@ func (rt *roundTripper) RoundTrip(originalRequest *http.Request) (*http.Response logger.Error("backend-endpoint-failed", log.ErrAttr(err), - slog.Int("attempt", attempt), - slog.String("vcap_request_id", request.Header.Get(handlers.VcapRequestIdHeader)), slog.Bool("retriable", retriable), - slog.Int("num-endpoints", numberOfEndpoints), - slog.Bool("got-connection", trace.GotConn()), - slog.Bool("wrote-headers", trace.WroteHeaders()), - slog.Bool("conn-reused", trace.ConnReused()), - slog.Float64("dns-lookup-time", trace.DnsTime()), - slog.Float64("dial-time", trace.DialTime()), - slog.Float64("tls-handshake-time", trace.TlsTime()), ) iter.EndpointFailed(err) @@ -204,6 +210,17 @@ func (rt *roundTripper) RoundTrip(originalRequest *http.Request) (*http.Response } } + if res != nil && err == nil { + err = checkResponseHeaders(rt.config.MaxResponseHeaders, res.Header) + if err != nil { + logger.Error("backend-too-many-response-headers", + log.ErrAttr(err), + slog.Bool("retriable", false), + ) + break + } + } + break } else { logger.Debug( @@ -227,6 +244,19 @@ func (rt *roundTripper) RoundTrip(originalRequest *http.Request) (*http.Response } res, err = rt.timedRoundTrip(roundTripper, request, logger) + + logger = logger.With( + slog.Int("attempt", attempt), + slog.String("vcap_request_id", request.Header.Get(handlers.VcapRequestIdHeader)), + slog.Int("num-endpoints", numberOfEndpoints), + slog.Bool("got-connection", trace.GotConn()), + slog.Bool("wrote-headers", trace.WroteHeaders()), + slog.Bool("conn-reused", trace.ConnReused()), + slog.Float64("dns-lookup-time", trace.DnsTime()), + slog.Float64("dial-time", trace.DialTime()), + slog.Float64("tls-handshake-time", trace.TlsTime()), + ) + if err != nil { reqInfo.FailedAttempts++ reqInfo.LastFailedAttemptFinishedAt = time.Now() @@ -236,16 +266,7 @@ func (rt *roundTripper) RoundTrip(originalRequest *http.Request) (*http.Response "route-service-connection-failed", slog.String("route-service-endpoint", request.URL.String()), log.ErrAttr(err), - slog.Int("attempt", attempt), - slog.String("vcap_request_id", request.Header.Get(handlers.VcapRequestIdHeader)), slog.Bool("retriable", retriable), - slog.Int("num-endpoints", numberOfEndpoints), - slog.Bool("got-connection", trace.GotConn()), - slog.Bool("wrote-headers", trace.WroteHeaders()), - slog.Bool("conn-reused", trace.ConnReused()), - slog.Float64("dns-lookup-time", trace.DnsTime()), - slog.Float64("dial-time", trace.DialTime()), - slog.Float64("tls-handshake-time", trace.TlsTime()), ) if retriable { @@ -253,6 +274,18 @@ func (rt *roundTripper) RoundTrip(originalRequest *http.Request) (*http.Response } } + if res != nil && err == nil { + err = checkResponseHeaders(rt.config.MaxResponseHeaders, res.Header) + if err != nil { + logger.Error("route-service-too-many-response-headers", + log.ErrAttr(err), + slog.Bool("retriable", false), + ) + break + } + + } + if res != nil && (res.StatusCode < 200 || res.StatusCode >= 300) { logger.Info( "route-service-response", @@ -391,6 +424,21 @@ func (rt *roundTripper) selectEndpoint(iter route.EndpointIterator, request *htt return endpoint, nil } +func checkResponseHeaders(maxCount int, headers http.Header) error { + if maxCount > 0 { + hdrCount := 0 + for _, vv := range headers { + hdrCount += len(vv) + } + + if hdrCount > maxCount { + return TooManyResponseHeaders + } + } + + return nil +} + func setRequestXCfInstanceId(request *http.Request, endpoint *route.Endpoint) { value := endpoint.PrivateInstanceId if value == "" { diff --git a/proxy/round_tripper/proxy_round_tripper_test.go b/proxy/round_tripper/proxy_round_tripper_test.go index f6a9159d..416a7f54 100644 --- a/proxy/round_tripper/proxy_round_tripper_test.go +++ b/proxy/round_tripper/proxy_round_tripper_test.go @@ -1753,6 +1753,49 @@ var _ = Describe("ProxyRoundTripper", func() { Expect(transport.CancelRequestArgsForCall(0)).To(Equal(req)) }) }) + Context("when response headers are limited in count", func() { + // Note: we can only test the header count as the limit on header bytes is + // implemented in the http.Transport which we fake for these tests. + BeforeEach(func() { + cfg.MaxResponseHeaders = 20 + }) + It("returns an error when the response exceeds it", func() { + transport.RoundTripStub = func(r *http.Request) (*http.Response, error) { + header := http.Header{} + for i := 0; i < 21; i++ { + header[fmt.Sprintf("header-%d", i)] = []string{"foobar"} + } + + return &http.Response{ + StatusCode: http.StatusTeapot, + Header: header, + }, nil + } + + _, err := proxyRoundTripper.RoundTrip(req) + + Expect(err).To(HaveOccurred()) + Expect(err).To(Equal(round_tripper.TooManyResponseHeaders)) + }) + It("doesn't return an error when the response does not exceed it", func() { + transport.RoundTripStub = func(r *http.Request) (*http.Response, error) { + header := http.Header{} + for i := 0; i < 10; i++ { + header[fmt.Sprintf("header-%d", i)] = []string{"foobar"} + } + + return &http.Response{ + StatusCode: http.StatusTeapot, + Header: header, + }, nil + } + + res, err := proxyRoundTripper.RoundTrip(req) + + Expect(err).NotTo(HaveOccurred()) + Expect(res.StatusCode).To(Equal(http.StatusTeapot)) + }) + }) }) }) })