Skip to content

Commit

Permalink
fix(middleware/cors): Categorize requests correctly (#2921)
Browse files Browse the repository at this point in the history
* fix(middleware/cors): categorise requests correctly

* test(middleware/cors): improve test coverage for request types

* test(middleware/cors): Add subdomain matching tests

* test(middleware/cors): parallel tests for CORS headers based on request type

* test(middleware/cors): Add benchmark for CORS subdomain matching

* test(middleware/cors): cover additiona test cases

* refactor(middleware/cors): origin validation and normalization
  • Loading branch information
sixcolors authored Mar 20, 2024
1 parent 1aac6f6 commit 1607d87
Show file tree
Hide file tree
Showing 3 changed files with 238 additions and 32 deletions.
40 changes: 16 additions & 24 deletions middleware/cors/cors.go
Original file line number Diff line number Diff line change
Expand Up @@ -119,33 +119,23 @@ func New(config ...Config) fiber.Handler {
allowSOrigins := []subdomain{}
allowAllOrigins := false

// processOrigin processes an origin string, normalizes it and checks its validity
// it will panic if the origin is invalid
processOrigin := func(origin string) (string, bool) {
trimmedOrigin := strings.TrimSpace(origin)
isValid, normalizedOrigin := normalizeOrigin(trimmedOrigin)
if !isValid {
log.Warnf("[CORS] Invalid origin format in configuration: %s", trimmedOrigin)
panic("[CORS] Invalid origin provided in configuration")
}
return normalizedOrigin, true
}

// Validate and normalize static AllowOrigins
if cfg.AllowOrigins != "" && cfg.AllowOrigins != "*" {
origins := strings.Split(cfg.AllowOrigins, ",")
for _, origin := range origins {
if i := strings.Index(origin, "://*."); i != -1 {
normalizedOrigin, isValid := processOrigin(origin[:i+3] + origin[i+4:])
trimmedOrigin := strings.TrimSpace(origin[:i+3] + origin[i+4:])
isValid, normalizedOrigin := normalizeOrigin(trimmedOrigin)
if !isValid {
continue
panic("[CORS] Invalid origin format in configuration: " + trimmedOrigin)
}
sd := subdomain{prefix: normalizedOrigin[:i+3], suffix: normalizedOrigin[i+3:]}
allowSOrigins = append(allowSOrigins, sd)
} else {
normalizedOrigin, isValid := processOrigin(origin)
trimmedOrigin := strings.TrimSpace(origin)
isValid, normalizedOrigin := normalizeOrigin(trimmedOrigin)
if !isValid {
continue
panic("[CORS] Invalid origin format in configuration: " + trimmedOrigin)
}
allowOrigins = append(allowOrigins, normalizedOrigin)
}
Expand All @@ -172,8 +162,9 @@ func New(config ...Config) fiber.Handler {
// Get originHeader header
originHeader := strings.ToLower(c.Get(fiber.HeaderOrigin))

// If the request does not have an Origin header, the request is outside the scope of CORS
if originHeader == "" {
// If the request does not have Origin and Access-Control-Request-Method
// headers, the request is outside the scope of CORS
if originHeader == "" || c.Get(fiber.HeaderAccessControlRequestMethod) == "" {
return c.Next()
}

Expand Down Expand Up @@ -211,8 +202,9 @@ func New(config ...Config) fiber.Handler {
}

// Simple request
// Ommit allowMethods and allowHeaders, only used for pre-flight requests
if c.Method() != fiber.MethodOptions {
setCORSHeaders(c, allowOrigin, allowMethods, allowHeaders, exposeHeaders, maxAge, cfg)
setCORSHeaders(c, allowOrigin, "", "", exposeHeaders, maxAge, cfg)
return c.Next()
}

Expand All @@ -233,14 +225,14 @@ func setCORSHeaders(c *fiber.Ctx, allowOrigin, allowMethods, allowHeaders, expos

if cfg.AllowCredentials {
// When AllowCredentials is true, set the Access-Control-Allow-Origin to the specific origin instead of '*'
if allowOrigin != "*" && allowOrigin != "" {
c.Set(fiber.HeaderAccessControlAllowOrigin, allowOrigin)
c.Set(fiber.HeaderAccessControlAllowCredentials, "true")
} else if allowOrigin == "*" {
if allowOrigin == "*" {
c.Set(fiber.HeaderAccessControlAllowOrigin, allowOrigin)
log.Warn("[CORS] 'AllowCredentials' is true, but 'AllowOrigins' cannot be set to '*'.")
} else if allowOrigin != "" {
c.Set(fiber.HeaderAccessControlAllowOrigin, allowOrigin)
c.Set(fiber.HeaderAccessControlAllowCredentials, "true")
}
} else if len(allowOrigin) > 0 {
} else if allowOrigin != "" {
// For non-credential requests, it's safe to set to '*' or specific origins
c.Set(fiber.HeaderAccessControlAllowOrigin, allowOrigin)
}
Expand Down
Loading

1 comment on commit 1607d87

@ReneWerner87
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Performance Alert ⚠️

Possible performance regression was detected for benchmark.
Benchmark result of this commit is worse than the previous benchmark result exceeding threshold 2.

Benchmark suite Current: 1607d87 Previous: d2b19e2 Ratio
Benchmark_CORS_NewHandlerWildcard 457.4 ns/op 0 B/op 0 allocs/op 206.8 ns/op 0 B/op 0 allocs/op 2.21
Benchmark_CORS_NewHandlerWildcardParallel 205.8 ns/op 0 B/op 0 allocs/op 93.06 ns/op 0 B/op 0 allocs/op 2.21

This comment was automatically generated by workflow using github-action-benchmark.

Please sign in to comment.