diff --git a/middleware/keyauth/keyauth_test.go b/middleware/keyauth/keyauth_test.go index 9d9b3395da..fb9e54de93 100644 --- a/middleware/keyauth/keyauth_test.go +++ b/middleware/keyauth/keyauth_test.go @@ -131,6 +131,51 @@ func TestAuthSources(t *testing.T) { } } +func TestMultipleKeyLookup(t *testing.T) { + const ( + desc = "auth with correct key" + success = "Success!" + ) + + // setup the fiber endpoint + app := fiber.New() + authMiddleware := New(Config{ + KeyLookup: "header:key|query:key", + Validator: func(c *fiber.Ctx, key string) (bool, error) { + if key == CorrectKey { + return true, nil + } + return false, ErrMissingOrMalformedAPIKey + }, + }) + app.Use(authMiddleware) + app.Get("/foo", func(c *fiber.Ctx) error { + return c.SendString(success) + }) + + // construct the test HTTP request + var req *http.Request + req, err := http.NewRequestWithContext(context.Background(), fiber.MethodGet, "/foo", nil) + utils.AssertEqual(t, err, nil) + q := req.URL.Query() + q.Add("key", CorrectKey) + req.URL.RawQuery = q.Encode() + + res, err := app.Test(req, -1) + + utils.AssertEqual(t, nil, err, desc) + + // test the body of the request + body, err := io.ReadAll(res.Body) + utils.AssertEqual(t, 200, res.StatusCode, desc) + // body + utils.AssertEqual(t, nil, err, desc) + utils.AssertEqual(t, success, string(body), desc) + + err = res.Body.Close() + utils.AssertEqual(t, err, nil) +} + func TestMultipleKeyAuth(t *testing.T) { // setup the fiber endpoint app := fiber.New()