Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

🔥 Feature: Add support for multiple keys in the KeyAuth middleware #3027

Closed
Closed
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/api/middleware/keyauth.md
Original file line number Diff line number Diff line change
Expand Up @@ -219,6 +219,7 @@ curl --header "Authorization: Bearer my-super-secret-key" http://localhost:3000
| SuccessHandler | `fiber.Handler` | SuccessHandler defines a function which is executed for a valid key. | `nil` |
| ErrorHandler | `fiber.ErrorHandler` | ErrorHandler defines a function which is executed for an invalid key. | `401 Invalid or expired key` |
| KeyLookup | `string` | KeyLookup is a string in the form of "`<source>:<name>`" that is used to extract key from the request. | "header:Authorization" |
| AdditionalKeyLookups | `[]string` | If additional fallback sources of keys are required, they can be specified here in order of precedence | []string{} (empty) |
| AuthScheme | `string` | AuthScheme to be used in the Authorization header. | "Bearer" |
| Validator | `func(*fiber.Ctx, string) (bool, error)` | Validator is a function to validate the key. | A function for key validation |
| ContextKey | `interface{}` | Context key to store the bearer token from the token into context. | "token" |
Expand Down
7 changes: 7 additions & 0 deletions middleware/keyauth/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,10 @@ type Config struct {
// - "cookie:<name>"
KeyLookup string

// AdditionalKeyLookups is a slice of strings, containing secondary sources of keys if KeyLookup does not find one
// Each element should be a value used in KeyLookup
AdditionalKeyLookups []string

// AuthScheme to be used in the Authorization header.
// Optional. Default value "Bearer".
AuthScheme string
Expand Down Expand Up @@ -84,6 +88,9 @@ func configDefault(config ...Config) Config {
cfg.AuthScheme = ConfigDefault.AuthScheme
}
}
if cfg.AdditionalKeyLookups == nil {
cfg.AdditionalKeyLookups = []string{}
}
if cfg.Validator == nil {
panic("fiber: keyauth middleware requires a validator function")
}
Expand Down
69 changes: 53 additions & 16 deletions middleware/keyauth/keyauth.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package keyauth

import (
"errors"
"fmt"
"net/url"
"strings"

Expand All @@ -19,23 +20,40 @@ const (
cookie = "cookie"
)

type extractorFunc func(c *fiber.Ctx) (string, error)

// New creates a new middleware handler
func New(config ...Config) fiber.Handler {
// Init config
cfg := configDefault(config...)

// Initialize
parts := strings.Split(cfg.KeyLookup, ":")
extractor := keyFromHeader(parts[1], cfg.AuthScheme)
switch parts[0] {
case query:
extractor = keyFromQuery(parts[1])
case form:
extractor = keyFromForm(parts[1])
case param:
extractor = keyFromParam(parts[1])
case cookie:
extractor = keyFromCookie(parts[1])

var extractor extractorFunc
extractor, err := parseSingleExtractor(cfg.KeyLookup, cfg.AuthScheme)
if err != nil {
panic(fmt.Errorf("error creating middleware: invalid keyauth Config.KeyLookup: %w", err))
}
if len(cfg.AdditionalKeyLookups) > 0 {
subExtractors := map[string]extractorFunc{cfg.KeyLookup: extractor}
for _, keyLookup := range cfg.AdditionalKeyLookups {
subExtractors[keyLookup], err = parseSingleExtractor(keyLookup, cfg.AuthScheme)
if err != nil {
panic(fmt.Errorf("error creating middleware: invalid keyauth Config.AdditionalKeyLookups[%s]: %w", keyLookup, err))
}
}
extractor = func(c *fiber.Ctx) (string, error) {
for keyLookup, subExtractor := range subExtractors {
res, err := subExtractor(c)
if err == nil && res != "" {
return res, nil
}
if !errors.Is(err, ErrMissingOrMalformedAPIKey) {
return "", fmt.Errorf("[%s] %w", keyLookup, err)
}
}
return "", ErrMissingOrMalformedAPIKey
}
}

// Return middleware handler
Expand All @@ -61,8 +79,27 @@ func New(config ...Config) fiber.Handler {
}
}

func parseSingleExtractor(keyLookup, authScheme string) (extractorFunc, error) {
parts := strings.Split(keyLookup, ":")
if len(parts) <= 1 {
return nil, fmt.Errorf("invalid keyLookup")
}
extractor := keyFromHeader(parts[1], authScheme) // in the event of an invalid prefix, it is interpreted as header:
switch parts[0] {
case query:
extractor = keyFromQuery(parts[1])
case form:
extractor = keyFromForm(parts[1])
case param:
extractor = keyFromParam(parts[1])
case cookie:
extractor = keyFromCookie(parts[1])
}
return extractor, nil
}

// keyFromHeader returns a function that extracts api key from the request header.
func keyFromHeader(header, authScheme string) func(c *fiber.Ctx) (string, error) {
func keyFromHeader(header, authScheme string) extractorFunc {
return func(c *fiber.Ctx) (string, error) {
auth := c.Get(header)
l := len(authScheme)
Expand All @@ -77,7 +114,7 @@ func keyFromHeader(header, authScheme string) func(c *fiber.Ctx) (string, error)
}

// keyFromQuery returns a function that extracts api key from the query string.
func keyFromQuery(param string) func(c *fiber.Ctx) (string, error) {
func keyFromQuery(param string) extractorFunc {
return func(c *fiber.Ctx) (string, error) {
key := c.Query(param)
if key == "" {
Expand All @@ -88,7 +125,7 @@ func keyFromQuery(param string) func(c *fiber.Ctx) (string, error) {
}

// keyFromForm returns a function that extracts api key from the form.
func keyFromForm(param string) func(c *fiber.Ctx) (string, error) {
func keyFromForm(param string) extractorFunc {
return func(c *fiber.Ctx) (string, error) {
key := c.FormValue(param)
if key == "" {
Expand All @@ -99,7 +136,7 @@ func keyFromForm(param string) func(c *fiber.Ctx) (string, error) {
}

// keyFromParam returns a function that extracts api key from the url param string.
func keyFromParam(param string) func(c *fiber.Ctx) (string, error) {
func keyFromParam(param string) extractorFunc {
return func(c *fiber.Ctx) (string, error) {
key, err := url.PathUnescape(c.Params(param))
if err != nil {
Expand All @@ -110,7 +147,7 @@ func keyFromParam(param string) func(c *fiber.Ctx) (string, error) {
}

// keyFromCookie returns a function that extracts api key from the named cookie.
func keyFromCookie(name string) func(c *fiber.Ctx) (string, error) {
func keyFromCookie(name string) extractorFunc {
return func(c *fiber.Ctx) (string, error) {
key := c.Cookies(name)
if key == "" {
Expand Down
46 changes: 46 additions & 0 deletions middleware/keyauth/keyauth_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,52 @@ func TestAuthSources(t *testing.T) {
}
}

func TestMultipleKeyLookup(t *testing.T) {
const (
desc = "auth with correct key"
success = "Success!"
)

// setup the fiber endpoint
app := fiber.New()
authMiddleware := New(Config{
KeyLookup: "header:key",
AdditionalKeyLookups: []string{"cookie:key", "query:key"},
Validator: func(c *fiber.Ctx, key string) (bool, error) {
if key == CorrectKey {
return true, nil
}
return false, ErrMissingOrMalformedAPIKey
},
})
app.Use(authMiddleware)
app.Get("/foo", func(c *fiber.Ctx) error {
return c.SendString(success)
})

// construct the test HTTP request
var req *http.Request
req, err := http.NewRequestWithContext(context.Background(), fiber.MethodGet, "/foo", nil)
utils.AssertEqual(t, err, nil)
q := req.URL.Query()
q.Add("key", CorrectKey)
req.URL.RawQuery = q.Encode()

res, err := app.Test(req, -1)

utils.AssertEqual(t, nil, err, desc)

// test the body of the request
body, err := io.ReadAll(res.Body)
utils.AssertEqual(t, 200, res.StatusCode, desc)
// body
utils.AssertEqual(t, nil, err, desc)
utils.AssertEqual(t, success, string(body), desc)

err = res.Body.Close()
utils.AssertEqual(t, err, nil)
}

func TestMultipleKeyAuth(t *testing.T) {
// setup the fiber endpoint
app := fiber.New()
Expand Down
Loading