diff --git a/handlers/max_request_size.go b/handlers/max_request_size.go index 939c7dde..9700ca67 100644 --- a/handlers/max_request_size.go +++ b/handlers/max_request_size.go @@ -41,18 +41,24 @@ func NewMaxRequestSize(cfg *config.Config, logger *slog.Logger) *MaxRequestSize func (m *MaxRequestSize) ServeHTTP(rw http.ResponseWriter, r *http.Request, next http.HandlerFunc) { logger := LoggerWithTraceInfo(m.logger, r) - reqSize := len(r.Method) + len(r.URL.RequestURI()) + len(r.Proto) + 5 // add 5 bytes for space-separation of method, URI, protocol, and /r/n - for k, v := range r.Header { - valueLen := 0 - for _, value := range r.Header.Values(k) { - valueLen += len(value) + // Four additional bytes for the two spaces and \r\n: + // GET / HTTP/1.1\r\n + reqSize := len(r.Method) + len(r.URL.RequestURI()) + len(r.Proto) + 4 + + // Host header which is not passed on to us, plus eight bytes for 'Host: ' and \r\n + reqSize += len(r.Host) + 8 + + // Go doesn't split header values on commas, instead it only splits the value when it's + // provided via repeated header keys. Therefore we have to account for each value of a repeated + // header as well as its key. + for k, vv := range r.Header { + for _, v := range vv { + // Four additional bytes for the colon and space after the header key and \r\n. + reqSize += len(k) + len(v) + 4 } - reqSize += len(k) + valueLen + 4 + len(v) - 1 // add padding for ': ' and newlines and comma delimiting of multiple values } - reqSize += len(r.Host) + 8 // add padding for "Host: " and newlines - if reqSize >= m.MaxSize { reqInfo, err := ContextRequestInfo(r) if err != nil { diff --git a/handlers/max_request_size_test.go b/handlers/max_request_size_test.go index 79406a84..863cce50 100644 --- a/handlers/max_request_size_test.go +++ b/handlers/max_request_size_test.go @@ -166,6 +166,17 @@ var _ = Describe("MaxRequestSize", func() { Expect(result.StatusCode).To(Equal(http.StatusRequestHeaderFieldsTooLarge)) }) }) + Context("when a repeated header has a short value and long key taking it over the limit", func() { + BeforeEach(func() { + for i := 0; i < 10; i++ { + header.Add("foobar", "m") + } + }) + It("throws an http 431", func() { + handleRequest() + Expect(result.StatusCode).To(Equal(http.StatusRequestHeaderFieldsTooLarge)) + }) + }) Context("when enough normally-sized headers put the request over the limit", func() { BeforeEach(func() { header.Add("header1", "smallRequest")