Skip to content

Commit

Permalink
major revision: instead of FallbackKeyLookups, expose CustomKeyLookup…
Browse files Browse the repository at this point in the history
… as function, with utility functions to make creating these easy

Signed-off-by: Dave Lee <dave@gray101.com>
  • Loading branch information
dave-gray101 committed Jun 10, 2024
1 parent a432b80 commit 4e061aa
Show file tree
Hide file tree
Showing 4 changed files with 61 additions and 51 deletions.
26 changes: 16 additions & 10 deletions docs/middleware/keyauth.md
Original file line number Diff line number Diff line change
Expand Up @@ -214,15 +214,15 @@ curl --header "Authorization: Bearer my-super-secret-key" http://localhost:3000

## Config

| Property | Type | Description | Default |
|:---------------|:-----------------------------------------|:-------------------------------------------------------------------------------------------------------|:------------------------------|
| Next | `func(fiber.Ctx) bool` | Next defines a function to skip this middleware when returned true. | `nil` |
| SuccessHandler | `fiber.Handler` | SuccessHandler defines a function which is executed for a valid key. | `nil` |
| ErrorHandler | `fiber.ErrorHandler` | ErrorHandler defines a function which is executed for an invalid key. | `401 Invalid or expired key` |
| KeyLookup | `string` | KeyLookup is a string in the form of "`<source>:<name>`" that is used to extract key from the request. | "header:Authorization" |
| FallbackKeyLookups | `[]string` | If additional fallback sources of keys are required, they can be specified here in order of precedence | []string{} (empty) |
| AuthScheme | `string` | AuthScheme to be used in the Authorization header. | "Bearer" |
| Validator | `func(fiber.Ctx, string) (bool, error)` | Validator is a function to validate the key. | A function for key validation |
| Property | Type | Description | Default |
|:----------------|:-----------------------------------------|:-------------------------------------------------------------------------------------------------------|:------------------------------|
| Next | `func(fiber.Ctx) bool` | Next defines a function to skip this middleware when returned true. | `nil` |
| SuccessHandler | `fiber.Handler` | SuccessHandler defines a function which is executed for a valid key. | `nil` |
| ErrorHandler | `fiber.ErrorHandler` | ErrorHandler defines a function which is executed for an invalid key. | `401 Invalid or expired key` |
| KeyLookup | `string` | KeyLookup is a string in the form of "`<source>:<name>`" that is used to extract the key from the request. | "header:Authorization" |
| CustomKeyLookup | `KeyauthKeyLookupFunc` aka `func(c fiber.Ctx) (string, error)` | If more complex logic is required to extract the key from the request, an arbitrary function to extract it can be specified here. Utility helper functions are described below. | `nil` |
| AuthScheme | `string` | AuthScheme to be used in the Authorization header. | "Bearer" |
| Validator | `func(fiber.Ctx, string) (bool, error)` | Validator is a function to validate the key. | A function for key validation |

## Default Config

Expand All @@ -238,7 +238,13 @@ var ConfigDefault = Config{
return c.Status(fiber.StatusUnauthorized).SendString("Invalid or expired API Key")
},
KeyLookup: "header:" + fiber.HeaderAuthorization,
FallbackKeyLookups: []string{},
CustomKeyLookup: nil,
AuthScheme: "Bearer",
}
```

## CustomKeyLookup

Two public utility functions are provided that may be useful when creating custom extraction:
* `SingleKeyLookup(keyLookup string, authScheme string)`: This is the function that implements the default `KeyLookup` behavior, exposed to be used as a component of custom parsing logic
* `MultipleKeySourceLookup(keyLookups []string, authScheme string)`: Creates a CustomKeyLookup function that checks each listed source using the above function until a key is found or the options are all exhausted
15 changes: 6 additions & 9 deletions middleware/keyauth/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@ import (
"github.com/gofiber/fiber/v3"
)

type KeyauthKeyLookupFunc func(c fiber.Ctx) (string, error)

// Config defines the config for middleware.
type Config struct {
// Next defines a function to skip middleware.
Expand All @@ -32,9 +34,7 @@ type Config struct {
// - "cookie:<name>"
KeyLookup string

// FallbackKeyLookups is a slice of strings, containing secondary sources of keys if KeyLookup does not find one
// Each element should be a value used in KeyLookup
FallbackKeyLookups []string
CustomKeyLookup KeyauthKeyLookupFunc

// AuthScheme to be used in the Authorization header.
// Optional. Default value "Bearer".
Expand All @@ -55,9 +55,9 @@ var ConfigDefault = Config{
}
return c.Status(fiber.StatusUnauthorized).SendString("Invalid or expired API Key")
},
KeyLookup: "header:" + fiber.HeaderAuthorization,
FallbackKeyLookups: []string{},
AuthScheme: "Bearer",
KeyLookup: "header:" + fiber.HeaderAuthorization,
CustomKeyLookup: nil,
AuthScheme: "Bearer",
}

// Helper function to set default values
Expand All @@ -84,9 +84,6 @@ func configDefault(config ...Config) Config {
cfg.AuthScheme = ConfigDefault.AuthScheme
}
}
if cfg.FallbackKeyLookups == nil {
cfg.FallbackKeyLookups = []string{}
}
if cfg.Validator == nil {
panic("fiber: keyauth middleware requires a validator function")
}
Expand Down
61 changes: 32 additions & 29 deletions middleware/keyauth/keyauth.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,39 +29,17 @@ const (
cookie = "cookie"
)

type extractorFunc func(c fiber.Ctx) (string, error)

// New creates a new middleware handler
func New(config ...Config) fiber.Handler {
// Init config
cfg := configDefault(config...)

// Initialize

var extractor extractorFunc
extractor, err := parseSingleExtractor(cfg.KeyLookup, cfg.AuthScheme)
if err != nil {
panic(fmt.Errorf("error creating middleware: invalid keyauth Config.KeyLookup: %w", err))
}
if len(cfg.FallbackKeyLookups) > 0 {
subExtractors := map[string]extractorFunc{cfg.KeyLookup: extractor}
for _, keyLookup := range cfg.FallbackKeyLookups {
subExtractors[keyLookup], err = parseSingleExtractor(keyLookup, cfg.AuthScheme)
if err != nil {
panic(fmt.Errorf("error creating middleware: invalid keyauth Config.FallbackKeyLookups[%s]: %w", keyLookup, err))
}
}
extractor = func(c fiber.Ctx) (string, error) {
for keyLookup, subExtractor := range subExtractors {
res, err := subExtractor(c)
if err == nil && res != "" {
return res, nil
}
if !errors.Is(err, ErrMissingOrMalformedAPIKey) {
return "", fmt.Errorf("[%s] %w", keyLookup, err)
}
}
return "", ErrMissingOrMalformedAPIKey
if cfg.CustomKeyLookup == nil {
var err error
cfg.CustomKeyLookup, err = SingleKeyLookup(cfg.KeyLookup, cfg.AuthScheme)
if err != nil {
panic(fmt.Errorf("unable to create lookup function: %w", err))

Check warning on line 42 in middleware/keyauth/keyauth.go

View check run for this annotation

Codecov / codecov/patch

middleware/keyauth/keyauth.go#L42

Added line #L42 was not covered by tests
}
}

Expand All @@ -73,7 +51,7 @@ func New(config ...Config) fiber.Handler {
}

// Extract and verify key
key, err := extractor(c)
key, err := cfg.CustomKeyLookup(c)
if err != nil {
return cfg.ErrorHandler(c, err)
}
Expand All @@ -98,7 +76,32 @@ func TokenFromContext(c fiber.Ctx) string {
return token
}

func parseSingleExtractor(keyLookup, authScheme string) (extractorFunc, error) {
// MultipleKeySourceLookup creates a CustomKeyLookup function that checks multiple sources until one is found
// Each element should be specified according to the format used in KeyLookup
func MultipleKeySourceLookup(keyLookups []string, authScheme string) (KeyauthKeyLookupFunc, error) {
subExtractors := map[string]KeyauthKeyLookupFunc{}
var err error
for _, keyLookup := range keyLookups {
subExtractors[keyLookup], err = SingleKeyLookup(keyLookup, authScheme)
if err != nil {
return nil, err

Check warning on line 87 in middleware/keyauth/keyauth.go

View check run for this annotation

Codecov / codecov/patch

middleware/keyauth/keyauth.go#L87

Added line #L87 was not covered by tests
}
}
return func(c fiber.Ctx) (string, error) {
for keyLookup, subExtractor := range subExtractors {
res, err := subExtractor(c)
if err == nil && res != "" {
return res, nil
}
if !errors.Is(err, ErrMissingOrMalformedAPIKey) {
return "", fmt.Errorf("[%s] %w", keyLookup, err)

Check warning on line 97 in middleware/keyauth/keyauth.go

View check run for this annotation

Codecov / codecov/patch

middleware/keyauth/keyauth.go#L97

Added line #L97 was not covered by tests
}
}
return "", ErrMissingOrMalformedAPIKey

Check warning on line 100 in middleware/keyauth/keyauth.go

View check run for this annotation

Codecov / codecov/patch

middleware/keyauth/keyauth.go#L100

Added line #L100 was not covered by tests
}, nil
}

func SingleKeyLookup(keyLookup, authScheme string) (KeyauthKeyLookupFunc, error) {
parts := strings.Split(keyLookup, ":")
if len(parts) <= 1 {
return nil, fmt.Errorf("invalid keyLookup: %s", keyLookup)

Check warning on line 107 in middleware/keyauth/keyauth.go

View check run for this annotation

Codecov / codecov/patch

middleware/keyauth/keyauth.go#L107

Added line #L107 was not covered by tests
Expand Down
10 changes: 7 additions & 3 deletions middleware/keyauth/keyauth_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -134,13 +134,17 @@ func TestMultipleKeyLookup(t *testing.T) {
const (
desc = "auth with correct key"
success = "Success!"
scheme = "Bearer"
)

// setup the fiber endpoint
app := fiber.New()

customKeyLookup, err := MultipleKeySourceLookup([]string{"header:key", "cookie:key", "query:key"}, scheme)
require.NoError(t, err)

authMiddleware := New(Config{
KeyLookup: "header:key",
FallbackKeyLookups: []string{"cookie:key", "query:key"},
CustomKeyLookup: customKeyLookup,
Validator: func(_ fiber.Ctx, key string) (bool, error) {
if key == CorrectKey {
return true, nil
Expand All @@ -155,7 +159,7 @@ func TestMultipleKeyLookup(t *testing.T) {

// construct the test HTTP request
var req *http.Request
req, err := http.NewRequestWithContext(context.Background(), fiber.MethodGet, "/foo", nil)
req, err = http.NewRequestWithContext(context.Background(), fiber.MethodGet, "/foo", nil)
require.NoError(t, err)
q := req.URL.Query()
q.Add("key", CorrectKey)
Expand Down

0 comments on commit 4e061aa

Please sign in to comment.