Skip to content

Commit

Permalink
777 duplicate CORS response headers (#804)
Browse files Browse the repository at this point in the history
* added max_age to test; added CORS response headers to test as if coming from 'backend'

* added registering of header modifiers that run before or after header modifiers configured by (add|remove|set)_response_headers attributes

* register deleting all existing CORS response headers before setting/adding specific ones, to be run _before_ header modifiers configured by (add|remove|set)_response_headers attributes

* added changelog entry

* fix: setting eval context just once per response writer obj

* rename shadowed var (interface)

* refactor: naming and simplify since we have no after usecase atm

---------

Co-authored-by: Marcel Ludwig <1841067+malud@users.noreply.github.com>
Co-authored-by: Marcel Ludwig <marcel.ludwig@milecrew.com>
  • Loading branch information
3 people authored Apr 8, 2024
1 parent 033ccc3 commit ec5f39d
Show file tree
Hide file tree
Showing 12 changed files with 111 additions and 35 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ Unreleased changes are available as `coupergateway/couper:edge` container.
* Selecting of appropriate [error handler](https://docs.couper.io/configuration/block/error_handler) in two cases ([#753](https://github.com/coupergateway/couper/pull/753))
* Storing of digit-starting string object keys in [request context](https://docs.couper.io/configuration/variables#request) and of digit-starting string header field names in [request](https://docs.couper.io/configuration/variables#request) variable ([#799](https://github.com/coupergateway/couper/pull/799))
* Use of boolean values for the `headers` attribute or [modifiers](https://docs.couper.io/configuration/modifiers) ([#805](https://github.com/coupergateway/couper/pull/805))
* Duplicate [CORS](https://docs.couper.io/configuration/block/cors) response headers (with backend sending CORS response headers, too) ([#804](https://github.com/coupergateway/couper/pull/804))

* **Dependencies**
* build with go 1.21 ([#800](https://github.com/coupergateway/couper/pull/800))
Expand Down
2 changes: 1 addition & 1 deletion handler/endpoint.go
Original file line number Diff line number Diff line change
Expand Up @@ -163,7 +163,7 @@ func (e *Endpoint) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
return
}

w.AddModifier(httpCtx, e.modifier...)
w.AddModifier(e.modifier...)
rw = w
}

Expand Down
10 changes: 3 additions & 7 deletions handler/file.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@ import (

"github.com/coupergateway/couper/config/runtime/server"
"github.com/coupergateway/couper/errors"
"github.com/coupergateway/couper/eval"
"github.com/coupergateway/couper/server/writer"
"github.com/coupergateway/couper/utils"
)
Expand Down Expand Up @@ -82,8 +81,7 @@ func (f *File) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
}

if r, ok := rw.(*writer.Response); ok {
evalContext := eval.ContextFromRequest(req)
r.AddModifier(evalContext.HCLContext(), f.modifier...)
r.AddModifier(f.modifier...)
}

http.ServeContent(rw, req, reqPath, info.ModTime(), file)
Expand All @@ -97,8 +95,7 @@ func (f *File) serveDirectory(reqPath string, rw http.ResponseWriter, req *http.

if !strings.HasSuffix(reqPath, "/") {
if r, ok := rw.(*writer.Response); ok {
evalContext := eval.ContextFromRequest(req)
r.AddModifier(evalContext.HCLContext(), f.modifier...)
r.AddModifier(f.modifier...)
}

rw.Header().Set("Location", utils.JoinPath(req.URL.Path, "/"))
Expand All @@ -116,8 +113,7 @@ func (f *File) serveDirectory(reqPath string, rw http.ResponseWriter, req *http.
defer file.Close()

if r, ok := rw.(*writer.Response); ok {
evalContext := eval.ContextFromRequest(req)
r.AddModifier(evalContext.HCLContext(), f.modifier...)
r.AddModifier(f.modifier...)
}

http.ServeContent(rw, req, reqPath, info.ModTime(), file)
Expand Down
3 changes: 1 addition & 2 deletions handler/health_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,8 @@ import (
"reflect"
"testing"

"github.com/coupergateway/couper/server/writer"

"github.com/coupergateway/couper/handler"
"github.com/coupergateway/couper/server/writer"
)

func TestHealth_ServeHTTP(t *testing.T) {
Expand Down
12 changes: 11 additions & 1 deletion handler/middleware/cors.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import (
"github.com/coupergateway/couper/config"
"github.com/coupergateway/couper/errors"
"github.com/coupergateway/couper/internal/seetie"
"github.com/coupergateway/couper/server/writer"
)

var _ http.Handler = &CORS{}
Expand Down Expand Up @@ -79,7 +80,11 @@ func NewCORSHandler(opts *CORSOptions, nextHandler http.Handler) http.Handler {
}

func (c *CORS) ServeNextHTTP(rw http.ResponseWriter, nextHandler http.Handler, req *http.Request) {
c.setCorsRespHeaders(rw.Header(), req)
if response, ok := rw.(*writer.Response); ok {
response.AddHeaderModifier(func(header http.Header) {
c.setCorsRespHeaders(header, req)
})
}

if c.isCorsPreflightRequest(req) {
rw.WriteHeader(http.StatusNoContent)
Expand All @@ -100,6 +105,11 @@ func (c *CORS) isCorsPreflightRequest(req *http.Request) bool {
}

func (c *CORS) setCorsRespHeaders(headers http.Header, req *http.Request) {
headers.Del("Access-Control-Allow-Origin")
headers.Del("Access-Control-Allow-Credentials")
headers.Del("Access-Control-Allow-Headers")
headers.Del("Access-Control-Allow-Methods")
headers.Del("Access-Control-Max-Age")
// see https://fetch.spec.whatwg.org/#http-responses
allowSpecificOrigin := false
if c.options.AllowsOrigin("*") && !c.options.AllowCredentials {
Expand Down
9 changes: 6 additions & 3 deletions handler/middleware/cors_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@ import (
"net/http"
"net/http/httptest"
"testing"

"github.com/coupergateway/couper/server/writer"
)

func TestCORSOptions_AllowsOrigin(t *testing.T) {
Expand Down Expand Up @@ -330,7 +332,8 @@ func TestCORS_ServeHTTP(t *testing.T) {
}

rec := httptest.NewRecorder()
corsHandler.ServeHTTP(rec, req)
r := writer.NewResponseWriter(rec, "")
corsHandler.ServeHTTP(r, req)

if !rec.Flushed {
rec.Flush()
Expand Down Expand Up @@ -547,8 +550,8 @@ func TestProxy_ServeHTTP_CORS_PFC(t *testing.T) {
}

rec := httptest.NewRecorder()

corsHandler.ServeHTTP(rec, req)
r := writer.NewResponseWriter(rec, "")
corsHandler.ServeHTTP(r, req)

if !rec.Flushed {
rec.Flush()
Expand Down
3 changes: 1 addition & 2 deletions handler/proxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -195,10 +195,9 @@ func (p *Proxy) registerWebsocketsResponse(req *http.Request) error {
}

wsBody := p.getWebsocketsBody()
evalCtx := eval.ContextFromRequest(req)

if rw, ok := req.Context().Value(request.ResponseWriter).(*writer.Response); ok {
rw.AddModifier(evalCtx.HCLContextSync(), wsBody, p.context)
rw.AddModifier(wsBody, p.context)
}

return nil
Expand Down
4 changes: 1 addition & 3 deletions handler/spa.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@ import (
"github.com/coupergateway/couper/config"
"github.com/coupergateway/couper/config/runtime/server"
"github.com/coupergateway/couper/errors"
"github.com/coupergateway/couper/eval"
"github.com/coupergateway/couper/server/writer"
)

Expand Down Expand Up @@ -78,9 +77,8 @@ func (s *Spa) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
var content io.ReadSeeker
var modTime time.Time

evalContext := eval.ContextFromRequest(req)
if r, ok := rw.(*writer.Response); ok {
r.AddModifier(evalContext.HCLContext(), s.modifier...)
r.AddModifier(s.modifier...)
}

if l := len(s.bootstrapContent); l > 0 {
Expand Down
7 changes: 6 additions & 1 deletion server/http.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ import (
"github.com/coupergateway/couper/handler"
"github.com/coupergateway/couper/handler/middleware"
"github.com/coupergateway/couper/logging"
"github.com/coupergateway/couper/server/writer"
"github.com/coupergateway/couper/telemetry/instrumentation"
"github.com/coupergateway/couper/telemetry/provider"
)
Expand Down Expand Up @@ -309,7 +310,11 @@ func (s *HTTPServer) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
// due to the middleware callee stack we have to update the 'req' value.
*req = *req.WithContext(s.evalCtx.WithClientRequest(req.WithContext(ctx)))

h.ServeHTTP(rw, req)
w := rw
if respW, is := rw.(*writer.Response); is {
w = respW.WithEvalContext(eval.ContextFromRequest(req))
}
h.ServeHTTP(w, req)
}

func (s *HTTPServer) setGetBody(h http.Handler, req *http.Request) (opt buffer.Option, err error) {
Expand Down
31 changes: 31 additions & 0 deletions server/http_integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4611,6 +4611,7 @@ func TestCORS_Configuration(t *testing.T) {
acam, acamExists := res.Header["Access-Control-Allow-Methods"]
acah, acahExists := res.Header["Access-Control-Allow-Headers"]
acac, acacExists := res.Header["Access-Control-Allow-Credentials"]
acax, acaxExists := res.Header["Access-Control-Max-Age"]
if tc.expAllowed {
if !acaoExists || acao[0] != tc.origin {
subT.Errorf("Expected allowed origin, got: %v", acao)
Expand All @@ -4624,6 +4625,9 @@ func TestCORS_Configuration(t *testing.T) {
if !acacExists || acac[0] != "true" {
subT.Errorf("Expected allowed credentials, got: %v", acac)
}
if !acaxExists || acax[0] != "200" {
subT.Errorf("Expected max-age 200, got: %v", acax)
}
} else {
if acaoExists {
subT.Errorf("Expected not allowed origin, got: %v", acao)
Expand All @@ -4634,6 +4638,9 @@ func TestCORS_Configuration(t *testing.T) {
if acahExists {
subT.Errorf("Expected not allowed headers, got: %v", acah)
}
if acaxExists {
subT.Errorf("Expected not max-age, got: %v", acax)
}
if acacExists {
subT.Errorf("Expected not allowed credentials, got: %v", acac)
}
Expand All @@ -4660,6 +4667,9 @@ func TestCORS_Configuration(t *testing.T) {

acao, acaoExists = res.Header["Access-Control-Allow-Origin"]
acac, acacExists = res.Header["Access-Control-Allow-Credentials"]
acam, acamExists = res.Header["Access-Control-Allow-Methods"]
acah, acahExists = res.Header["Access-Control-Allow-Headers"]
acax, acaxExists = res.Header["Access-Control-Max-Age"]
if tc.expAllowed {
if !acaoExists || acao[0] != tc.origin {
subT.Errorf("Expected allowed origin, got: %v", acao)
Expand All @@ -4675,6 +4685,15 @@ func TestCORS_Configuration(t *testing.T) {
subT.Errorf("Expected not allowed credentials, got: %v", acac)
}
}
if acamExists {
subT.Errorf("Expected not allowed methods, got: %v", acam)
}
if acahExists {
subT.Errorf("Expected not allowed headers, got: %v", acah)
}
if acaxExists {
subT.Errorf("Expected not max-age, got: %v", acax)
}
vary, varyExists = res.Header["Vary"]
if !varyExists || strings.Join(vary, ",") != tc.expVary {
subT.Errorf("Expected vary %q, got: %q", tc.expVary, strings.Join(vary, ","))
Expand All @@ -4698,6 +4717,9 @@ func TestCORS_Configuration(t *testing.T) {

acao, acaoExists = res.Header["Access-Control-Allow-Origin"]
acac, acacExists = res.Header["Access-Control-Allow-Credentials"]
acam, acamExists = res.Header["Access-Control-Allow-Methods"]
acah, acahExists = res.Header["Access-Control-Allow-Headers"]
acax, acaxExists = res.Header["Access-Control-Max-Age"]
if tc.expAllowed {
if !acaoExists || acao[0] != tc.origin {
subT.Errorf("Expected allowed origin, got: %v", acao)
Expand All @@ -4713,6 +4735,15 @@ func TestCORS_Configuration(t *testing.T) {
subT.Errorf("Expected not allowed credentials, got: %v", acac)
}
}
if acamExists {
subT.Errorf("Expected not allowed methods, got: %v", acam)
}
if acahExists {
subT.Errorf("Expected not allowed headers, got: %v", acah)
}
if acaxExists {
subT.Errorf("Expected not max-age, got: %v", acax)
}
vary, varyExists = res.Header["Vary"]
if !varyExists || strings.Join(vary, ",") != tc.expVaryCred {
subT.Errorf("Expected vary %q, got: %q", tc.expVaryCred, strings.Join(vary, ","))
Expand Down
13 changes: 12 additions & 1 deletion server/testdata/integration/config/06_couper.hcl
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ server "cors" {
cors {
allowed_origins = "a.com"
allow_credentials = true
max_age = "200s"
}
}

Expand All @@ -15,6 +16,7 @@ server "cors" {
cors {
allowed_origins = "b.com"
allow_credentials = true
max_age = "200s"
}
}

Expand All @@ -23,9 +25,18 @@ server "cors" {
cors {
allowed_origins = "c.com"
allow_credentials = true
max_age = "200s"
}
endpoint "/" {
response {}
response {
headers = {
access-control-allow-origin = "foo"
access-control-allow-credentials = "bar"
access-control-allow-methods = "BREW"
access-control-allow-headers = "Auth"
access-control-max-age = 300
}
}
}
}
}
Expand Down
51 changes: 37 additions & 14 deletions server/writer/response.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,8 @@ type writer interface {
}

type modifier interface {
AddModifier(*hcl.EvalContext, ...hcl.Body)
AddModifier(...hcl.Body)
AddHeaderModifier(HeaderModifier)
}

var (
Expand All @@ -34,6 +35,8 @@ var (
endOfLine = []byte("\r\n")
)

type HeaderModifier func(header http.Header)

// Response wraps the http.ResponseWriter.
type Response struct {
hijackedConn net.Conn
Expand All @@ -45,21 +48,28 @@ type Response struct {
statusCode int
rawBytesWritten int
bytesWritten int
// modifier
evalCtx *hcl.EvalContext
modifier []hcl.Body
// modifiers
evalCtx *eval.Context
modifiers []hcl.Body
headerModifiers []HeaderModifier
// security
addPrivateCC bool
}

// NewResponseWriter creates a new Response object.
// NewResponseWriter creates a new ResponseWriter. It wraps the http.ResponseWriter.
func NewResponseWriter(rw http.ResponseWriter, secureCookies string) *Response {
return &Response{
rw: rw,
secureCookies: secureCookies,
}
}

// WithEvalContext sets the eval context for the response modifiers.
func (r *Response) WithEvalContext(ctx *eval.Context) *Response {
r.evalCtx = ctx
return r
}

// Header wraps the Header method of the <http.ResponseWriter>.
func (r *Response) Header() http.Header {
return r.rw.Header()
Expand Down Expand Up @@ -135,9 +145,10 @@ func (r *Response) WriteHeader(statusCode int) {
}

r.configureHeader()
r.applyModifier()
r.applyHeaderModifiers()
r.applyModifiers() // hcl body modifiers

// !!! Execute after modifier !!!
// execute after modifiers
if r.addPrivateCC {
r.Header().Add("Cache-Control", "private")
}
Expand Down Expand Up @@ -193,17 +204,29 @@ func (r *Response) AddPrivateCC() {
r.addPrivateCC = true
}

func (r *Response) AddModifier(evalCtx *hcl.EvalContext, modifier ...hcl.Body) {
r.evalCtx = evalCtx
r.modifier = append(r.modifier, modifier...)
func (r *Response) AddModifier(modifier ...hcl.Body) {
r.modifiers = append(r.modifiers, modifier...)
}

func (r *Response) applyModifier() {
if r.evalCtx == nil || r.modifier == nil {
// applyModifiers applies the hcl body modifiers to the response.
func (r *Response) applyModifiers() {
if r.evalCtx == nil || r.modifiers == nil {
return
}

for _, body := range r.modifier {
_ = eval.ApplyResponseHeaderOps(r.evalCtx, body, r.Header())
hctx := r.evalCtx.HCLContextSync()
for _, body := range r.modifiers {
_ = eval.ApplyResponseHeaderOps(hctx, body, r.Header())
}
}

func (r *Response) AddHeaderModifier(headerModifier HeaderModifier) {
r.headerModifiers = append(r.headerModifiers, headerModifier)
}

// applyHeaderModifiers applies the http.Header modifiers to the response.
func (r *Response) applyHeaderModifiers() {
for _, modifierFn := range r.headerModifiers {
modifierFn(r.Header())
}
}

0 comments on commit ec5f39d

Please sign in to comment.