Skip to content

Commit

Permalink
port over FallbackKeyLookups from v2 middleware to v3
Browse files Browse the repository at this point in the history
Signed-off-by: Dave Lee <dave@gray101.com>
  • Loading branch information
dave-gray101 committed Jun 9, 2024
1 parent e561026 commit 0b6b5e9
Show file tree
Hide file tree
Showing 4 changed files with 106 additions and 13 deletions.
2 changes: 2 additions & 0 deletions docs/middleware/keyauth.md
Original file line number Diff line number Diff line change
Expand Up @@ -220,6 +220,7 @@ curl --header "Authorization: Bearer my-super-secret-key" http://localhost:3000
| SuccessHandler | `fiber.Handler` | SuccessHandler defines a function which is executed for a valid key. | `nil` |
| ErrorHandler | `fiber.ErrorHandler` | ErrorHandler defines a function which is executed for an invalid key. | `401 Invalid or expired key` |
| KeyLookup | `string` | KeyLookup is a string in the form of "`<source>:<name>`" that is used to extract key from the request. | "header:Authorization" |
| FallbackKeyLookups | `[]string` | If additional fallback sources of keys are required, they can be specified here in order of precedence | []string{} (empty) |
| AuthScheme | `string` | AuthScheme to be used in the Authorization header. | "Bearer" |
| Validator | `func(fiber.Ctx, string) (bool, error)` | Validator is a function to validate the key. | A function for key validation |

Expand All @@ -237,6 +238,7 @@ var ConfigDefault = Config{
return c.Status(fiber.StatusUnauthorized).SendString("Invalid or expired API Key")
},
KeyLookup: "header:" + fiber.HeaderAuthorization,
FallbackKeyLookups: []string{},
AuthScheme: "Bearer",
}
```
12 changes: 10 additions & 2 deletions middleware/keyauth/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,10 @@ type Config struct {
// - "cookie:<name>"
KeyLookup string

// FallbackKeyLookups is a slice of strings, containing secondary sources of keys if KeyLookup does not find one
// Each element should be a value used in KeyLookup
FallbackKeyLookups []string

// AuthScheme to be used in the Authorization header.
// Optional. Default value "Bearer".
AuthScheme string
Expand All @@ -51,8 +55,9 @@ var ConfigDefault = Config{
}
return c.Status(fiber.StatusUnauthorized).SendString("Invalid or expired API Key")
},
KeyLookup: "header:" + fiber.HeaderAuthorization,
AuthScheme: "Bearer",
KeyLookup: "header:" + fiber.HeaderAuthorization,
FallbackKeyLookups: []string{},
AuthScheme: "Bearer",
}

// Helper function to set default values
Expand All @@ -79,6 +84,9 @@ func configDefault(config ...Config) Config {
cfg.AuthScheme = ConfigDefault.AuthScheme
}
}
if cfg.FallbackKeyLookups == nil {
cfg.FallbackKeyLookups = []string{}
}
if cfg.Validator == nil {
panic("fiber: keyauth middleware requires a validator function")
}
Expand Down
59 changes: 48 additions & 11 deletions middleware/keyauth/keyauth.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package keyauth

import (
"errors"
"fmt"
"net/url"
"strings"

Expand All @@ -28,23 +29,40 @@ const (
cookie = "cookie"
)

type extractorFunc func(c fiber.Ctx) (string, error)

// New creates a new middleware handler
func New(config ...Config) fiber.Handler {
// Init config
cfg := configDefault(config...)

// Initialize
parts := strings.Split(cfg.KeyLookup, ":")
extractor := keyFromHeader(parts[1], cfg.AuthScheme)
switch parts[0] {
case query:
extractor = keyFromQuery(parts[1])
case form:
extractor = keyFromForm(parts[1])
case param:
extractor = keyFromParam(parts[1])
case cookie:
extractor = keyFromCookie(parts[1])

var extractor extractorFunc
extractor, err := parseSingleExtractor(cfg.KeyLookup, cfg.AuthScheme)
if err != nil {
panic(fmt.Errorf("error creating middleware: invalid keyauth Config.KeyLookup: %w", err))
}
if len(cfg.FallbackKeyLookups) > 0 {
subExtractors := map[string]extractorFunc{cfg.KeyLookup: extractor}
for _, keyLookup := range cfg.FallbackKeyLookups {
subExtractors[keyLookup], err = parseSingleExtractor(keyLookup, cfg.AuthScheme)
if err != nil {
panic(fmt.Errorf("error creating middleware: invalid keyauth Config.FallbackKeyLookups[%s]: %w", keyLookup, err))
}
}
extractor = func(c fiber.Ctx) (string, error) {
for keyLookup, subExtractor := range subExtractors {
res, err := subExtractor(c)
if err == nil && res != "" {
return res, nil
}
if !errors.Is(err, ErrMissingOrMalformedAPIKey) {
return "", fmt.Errorf("[%s] %w", keyLookup, err)
}
}
return "", ErrMissingOrMalformedAPIKey
}
}

// Return middleware handler
Expand Down Expand Up @@ -80,6 +98,25 @@ func TokenFromContext(c fiber.Ctx) string {
return token
}

func parseSingleExtractor(keyLookup string, authScheme string) (extractorFunc, error) {
parts := strings.Split(keyLookup, ":")
if len(parts) <= 1 {
return nil, fmt.Errorf("invalid keyLookup")
}
extractor := keyFromHeader(parts[1], authScheme) // in the event of an invalid prefix, it is interpreted as header:
switch parts[0] {
case query:
extractor = keyFromQuery(parts[1])
case form:
extractor = keyFromForm(parts[1])
case param:
extractor = keyFromParam(parts[1])
case cookie:
extractor = keyFromCookie(parts[1])
}
return extractor, nil
}

// keyFromHeader returns a function that extracts api key from the request header.
func keyFromHeader(header, authScheme string) func(c fiber.Ctx) (string, error) {
return func(c fiber.Ctx) (string, error) {
Expand Down
46 changes: 46 additions & 0 deletions middleware/keyauth/keyauth_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,52 @@ func Test_AuthSources(t *testing.T) {
}
}

func TestMultipleKeyLookup(t *testing.T) {
const (
desc = "auth with correct key"
success = "Success!"
)

// setup the fiber endpoint
app := fiber.New()
authMiddleware := New(Config{
KeyLookup: "header:key",
FallbackKeyLookups: []string{"cookie:key", "query:key"},
Validator: func(c fiber.Ctx, key string) (bool, error) {
if key == CorrectKey {
return true, nil
}
return false, ErrMissingOrMalformedAPIKey
},
})
app.Use(authMiddleware)
app.Get("/foo", func(c fiber.Ctx) error {
return c.SendString(success)
})

// construct the test HTTP request
var req *http.Request
req, err := http.NewRequestWithContext(context.Background(), fiber.MethodGet, "/foo", nil)
require.Equal(t, err, nil)
q := req.URL.Query()
q.Add("key", CorrectKey)
req.URL.RawQuery = q.Encode()

res, err := app.Test(req, -1)

require.Equal(t, nil, err, desc)

// test the body of the request
body, err := io.ReadAll(res.Body)
require.Equal(t, 200, res.StatusCode, desc)
// body
require.Equal(t, nil, err, desc)
require.Equal(t, success, string(body), desc)

err = res.Body.Close()
require.Equal(t, err, nil)
}

func Test_MultipleKeyAuth(t *testing.T) {
// setup the fiber endpoint
app := fiber.New()
Expand Down

0 comments on commit 0b6b5e9

Please sign in to comment.