From e9ba2e154549b46e4b38ef47ad27a81caa4c6cad Mon Sep 17 00:00:00 2001 From: Maximilian Moehl Date: Thu, 7 Nov 2024 10:08:43 +0100 Subject: [PATCH] fix: count repeated header keys for header limits When counting they bytes sent by the client in the request line and headers, repeated header keys are not counted. This commit adjusts the logic to account for repeated keys. --- handlers/max_request_size.go | 22 ++++++++++++++-------- handlers/max_request_size_test.go | 11 +++++++++++ 2 files changed, 25 insertions(+), 8 deletions(-) diff --git a/handlers/max_request_size.go b/handlers/max_request_size.go index 939c7dde..9700ca67 100644 --- a/handlers/max_request_size.go +++ b/handlers/max_request_size.go @@ -41,18 +41,24 @@ func NewMaxRequestSize(cfg *config.Config, logger *slog.Logger) *MaxRequestSize func (m *MaxRequestSize) ServeHTTP(rw http.ResponseWriter, r *http.Request, next http.HandlerFunc) { logger := LoggerWithTraceInfo(m.logger, r) - reqSize := len(r.Method) + len(r.URL.RequestURI()) + len(r.Proto) + 5 // add 5 bytes for space-separation of method, URI, protocol, and /r/n - for k, v := range r.Header { - valueLen := 0 - for _, value := range r.Header.Values(k) { - valueLen += len(value) + // Four additional bytes for the two spaces and \r\n: + // GET / HTTP/1.1\r\n + reqSize := len(r.Method) + len(r.URL.RequestURI()) + len(r.Proto) + 4 + + // Host header which is not passed on to us, plus eight bytes for 'Host: ' and \r\n + reqSize += len(r.Host) + 8 + + // Go doesn't split header values on commas, instead it only splits the value when it's + // provided via repeated header keys. Therefore we have to account for each value of a repeated + // header as well as its key. + for k, vv := range r.Header { + for _, v := range vv { + // Four additional bytes for the colon and space after the header key and \r\n. + reqSize += len(k) + len(v) + 4 } - reqSize += len(k) + valueLen + 4 + len(v) - 1 // add padding for ': ' and newlines and comma delimiting of multiple values } - reqSize += len(r.Host) + 8 // add padding for "Host: " and newlines - if reqSize >= m.MaxSize { reqInfo, err := ContextRequestInfo(r) if err != nil { diff --git a/handlers/max_request_size_test.go b/handlers/max_request_size_test.go index 79406a84..863cce50 100644 --- a/handlers/max_request_size_test.go +++ b/handlers/max_request_size_test.go @@ -166,6 +166,17 @@ var _ = Describe("MaxRequestSize", func() { Expect(result.StatusCode).To(Equal(http.StatusRequestHeaderFieldsTooLarge)) }) }) + Context("when a repeated header has a short value and long key taking it over the limit", func() { + BeforeEach(func() { + for i := 0; i < 10; i++ { + header.Add("foobar", "m") + } + }) + It("throws an http 431", func() { + handleRequest() + Expect(result.StatusCode).To(Equal(http.StatusRequestHeaderFieldsTooLarge)) + }) + }) Context("when enough normally-sized headers put the request over the limit", func() { BeforeEach(func() { header.Add("header1", "smallRequest")