Skip to content

Commit

Permalink
feat: configurable limits for header counts
Browse files Browse the repository at this point in the history
This commit adds support to limit the amount of request and response
headers gorouter will accept and process.

Resolves: cloudfoundry/routing-release#309
  • Loading branch information
maxmoehl committed Nov 12, 2024
1 parent 7b766d6 commit f4d4095
Show file tree
Hide file tree
Showing 5 changed files with 139 additions and 26 deletions.
2 changes: 2 additions & 0 deletions config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -462,6 +462,8 @@ type Config struct {
MaxIdleConnsPerHost int `yaml:"max_idle_conns_per_host,omitempty"`
MaxRequestHeaderBytes int `yaml:"max_header_bytes"`
MaxResponseHeaderBytes int `yaml:"max_response_header_bytes"`
MaxRequestHeaders int `yaml:"max_request_headers"`
MaxResponseHeaders int `yaml:"max_response_headers"`
KeepAlive100ContinueRequests bool `yaml:"keep_alive_100_continue_requests"`

HTTPRewrite HTTPRewrite `yaml:"http_rewrite,omitempty"`
Expand Down
19 changes: 12 additions & 7 deletions handlers/max_request_size.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,10 @@ import (
)

type MaxRequestSize struct {
cfg *config.Config
MaxSize int
logger *slog.Logger
cfg *config.Config
MaxSize int
MaxCount int
logger *slog.Logger
}

const ONE_MB = 1024 * 1024 // bytes * kb
Expand All @@ -33,9 +34,10 @@ func NewMaxRequestSize(cfg *config.Config, logger *slog.Logger) *MaxRequestSize
}

return &MaxRequestSize{
MaxSize: maxSize,
logger: logger,
cfg: cfg,
MaxSize: maxSize,
MaxCount: cfg.MaxRequestHeaders,
logger: logger,
cfg: cfg,
}
}

Expand All @@ -49,17 +51,20 @@ func (m *MaxRequestSize) ServeHTTP(rw http.ResponseWriter, r *http.Request, next
// Host header which is not passed on to us, plus eight bytes for 'Host: ' and \r\n
reqSize += len(r.Host) + 8

hdrCount := 0

// Go doesn't split header values on commas, instead it only splits the value when it's
// provided via repeated header keys. Therefore we have to account for each value of a repeated
// header as well as its key.
for k, vv := range r.Header {
for _, v := range vv {
// Four additional bytes for the colon and space after the header key and \r\n.
reqSize += len(k) + len(v) + 4
hdrCount++
}
}

if reqSize >= m.MaxSize {
if reqSize >= m.MaxSize || (m.MaxCount > 0 && hdrCount > m.MaxCount) {
reqInfo, err := ContextRequestInfo(r)
if err != nil {
logger.Error("request-info-err", log.ErrAttr(err))
Expand Down
12 changes: 12 additions & 0 deletions handlers/max_request_size_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@ var _ = Describe("MaxRequestSize", func() {
BeforeEach(func() {
cfg = &config.Config{
MaxRequestHeaderBytes: 89,
MaxRequestHeaders: 15,
LoadBalance: config.LOAD_BALANCE_RR,
StickySessionCookieNames: config.StringSet{"blarg": struct{}{}},
}
Expand Down Expand Up @@ -177,6 +178,17 @@ var _ = Describe("MaxRequestSize", func() {
Expect(result.StatusCode).To(Equal(http.StatusRequestHeaderFieldsTooLarge))
})
})
Context("when there are too many headers", func() {
BeforeEach(func() {
for i := 0; i < 16; i++ {
header.Add("f", "m")
}
})
It("throws an http 431", func() {
handleRequest()
Expect(result.StatusCode).To(Equal(http.StatusRequestHeaderFieldsTooLarge))
})
})
Context("when enough normally-sized headers put the request over the limit", func() {
BeforeEach(func() {
header.Add("header1", "smallRequest")
Expand Down
89 changes: 70 additions & 19 deletions proxy/round_tripper/proxy_round_tripper.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,10 @@ const (
AuthNegotiateHeaderCookieMaxAgeInSeconds = 60
)

var NoEndpointsAvailable = errors.New("No endpoints available")
var (
NoEndpointsAvailable = errors.New("No endpoints available")
TooManyResponseHeaders = errors.New("too many response headers")
)

//go:generate counterfeiter -o fakes/fake_proxy_round_tripper.go . ProxyRoundTripper
type ProxyRoundTripper interface {
Expand Down Expand Up @@ -178,23 +181,26 @@ func (rt *roundTripper) RoundTrip(originalRequest *http.Request) (*http.Response
}
res, err = rt.backendRoundTrip(request, endpoint, iter, logger)

logger = logger.With(
slog.Int("attempt", attempt),
slog.String("vcap_request_id", request.Header.Get(handlers.VcapRequestIdHeader)),
slog.Int("num-endpoints", numberOfEndpoints),
slog.Bool("got-connection", trace.GotConn()),
slog.Bool("wrote-headers", trace.WroteHeaders()),
slog.Bool("conn-reused", trace.ConnReused()),
slog.Float64("dns-lookup-time", trace.DnsTime()),
slog.Float64("dial-time", trace.DialTime()),
slog.Float64("tls-handshake-time", trace.TlsTime()),
)

if err != nil {
reqInfo.FailedAttempts++
reqInfo.LastFailedAttemptFinishedAt = time.Now()
retriable, err := rt.isRetriable(request, err, trace)

logger.Error("backend-endpoint-failed",
log.ErrAttr(err),
slog.Int("attempt", attempt),
slog.String("vcap_request_id", request.Header.Get(handlers.VcapRequestIdHeader)),
slog.Bool("retriable", retriable),
slog.Int("num-endpoints", numberOfEndpoints),
slog.Bool("got-connection", trace.GotConn()),
slog.Bool("wrote-headers", trace.WroteHeaders()),
slog.Bool("conn-reused", trace.ConnReused()),
slog.Float64("dns-lookup-time", trace.DnsTime()),
slog.Float64("dial-time", trace.DialTime()),
slog.Float64("tls-handshake-time", trace.TlsTime()),
)

iter.EndpointFailed(err)
Expand All @@ -204,6 +210,17 @@ func (rt *roundTripper) RoundTrip(originalRequest *http.Request) (*http.Response
}
}

if res != nil && err == nil {
err = checkResponseHeaders(rt.config.MaxResponseHeaders, res.Header)
if err != nil {
logger.Error("backend-too-many-response-headers",
log.ErrAttr(err),
slog.Bool("retriable", false),
)
break
}
}

break
} else {
logger.Debug(
Expand All @@ -227,6 +244,19 @@ func (rt *roundTripper) RoundTrip(originalRequest *http.Request) (*http.Response
}

res, err = rt.timedRoundTrip(roundTripper, request, logger)

logger = logger.With(
slog.Int("attempt", attempt),
slog.String("vcap_request_id", request.Header.Get(handlers.VcapRequestIdHeader)),
slog.Int("num-endpoints", numberOfEndpoints),
slog.Bool("got-connection", trace.GotConn()),
slog.Bool("wrote-headers", trace.WroteHeaders()),
slog.Bool("conn-reused", trace.ConnReused()),
slog.Float64("dns-lookup-time", trace.DnsTime()),
slog.Float64("dial-time", trace.DialTime()),
slog.Float64("tls-handshake-time", trace.TlsTime()),
)

if err != nil {
reqInfo.FailedAttempts++
reqInfo.LastFailedAttemptFinishedAt = time.Now()
Expand All @@ -236,23 +266,26 @@ func (rt *roundTripper) RoundTrip(originalRequest *http.Request) (*http.Response
"route-service-connection-failed",
slog.String("route-service-endpoint", request.URL.String()),
log.ErrAttr(err),
slog.Int("attempt", attempt),
slog.String("vcap_request_id", request.Header.Get(handlers.VcapRequestIdHeader)),
slog.Bool("retriable", retriable),
slog.Int("num-endpoints", numberOfEndpoints),
slog.Bool("got-connection", trace.GotConn()),
slog.Bool("wrote-headers", trace.WroteHeaders()),
slog.Bool("conn-reused", trace.ConnReused()),
slog.Float64("dns-lookup-time", trace.DnsTime()),
slog.Float64("dial-time", trace.DialTime()),
slog.Float64("tls-handshake-time", trace.TlsTime()),
)

if retriable {
continue
}
}

if res != nil && err == nil {
err = checkResponseHeaders(rt.config.MaxResponseHeaders, res.Header)
if err != nil {
logger.Error("route-service-too-many-response-headers",
log.ErrAttr(err),
slog.Bool("retriable", false),
)
break
}

}

if res != nil && (res.StatusCode < 200 || res.StatusCode >= 300) {
logger.Info(
"route-service-response",
Expand Down Expand Up @@ -391,6 +424,24 @@ func (rt *roundTripper) selectEndpoint(iter route.EndpointIterator, request *htt
return endpoint, nil
}

func checkResponseHeaders(maxCount int, headers http.Header) error {
if maxCount > 0 {
// Go doesn't split header values on commas, instead it only splits the value when it's
// provided via repeated header keys. We can therefore get the number of header lines by
// checking how many values are in the map.
hdrCount := 0
for _, vv := range headers {
hdrCount += len(vv)
}

if hdrCount > maxCount {
return TooManyResponseHeaders
}
}

return nil
}

func setRequestXCfInstanceId(request *http.Request, endpoint *route.Endpoint) {
value := endpoint.PrivateInstanceId
if value == "" {
Expand Down
43 changes: 43 additions & 0 deletions proxy/round_tripper/proxy_round_tripper_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1753,6 +1753,49 @@ var _ = Describe("ProxyRoundTripper", func() {
Expect(transport.CancelRequestArgsForCall(0)).To(Equal(req))
})
})
Context("when response headers are limited in count", func() {
// Note: we can only test the header count as the limit on header bytes is
// implemented in the http.Transport which we fake for these tests.
BeforeEach(func() {
cfg.MaxResponseHeaders = 20
})
It("returns an error when the response exceeds it", func() {
transport.RoundTripStub = func(r *http.Request) (*http.Response, error) {
header := http.Header{}
for i := 0; i < 21; i++ {
header[fmt.Sprintf("header-%d", i)] = []string{"foobar"}
}

return &http.Response{
StatusCode: http.StatusTeapot,
Header: header,
}, nil
}

_, err := proxyRoundTripper.RoundTrip(req)

Expect(err).To(HaveOccurred())
Expect(err).To(Equal(round_tripper.TooManyResponseHeaders))
})
It("doesn't return an error when the response does not exceed it", func() {
transport.RoundTripStub = func(r *http.Request) (*http.Response, error) {
header := http.Header{}
for i := 0; i < 10; i++ {
header[fmt.Sprintf("header-%d", i)] = []string{"foobar"}
}

return &http.Response{
StatusCode: http.StatusTeapot,
Header: header,
}, nil
}

res, err := proxyRoundTripper.RoundTrip(req)

Expect(err).NotTo(HaveOccurred())
Expect(res.StatusCode).To(Equal(http.StatusTeapot))
})
})
})
})
})

0 comments on commit f4d4095

Please sign in to comment.