From 96146578d4b7fab976b68b7c43ea93b164259c81 Mon Sep 17 00:00:00 2001 From: Juan Calderon-Perez Date: Tue, 17 Jan 2023 20:17:17 -0800 Subject: [PATCH] Add unit-test for rewrite middleware --- .github/workflows/test.yml | 1 + main.go | 19 ++-- main_test.go | 175 +++++++++++++++++++++++++++++++++++++ 3 files changed, 183 insertions(+), 12 deletions(-) diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index be87d45..edc9827 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -13,6 +13,7 @@ jobs: go-version: - 1.16.x - 1.18.x + - 1.19.x platform: - ubuntu-latest - windows-latest diff --git a/main.go b/main.go index 83d77f2..d4df196 100644 --- a/main.go +++ b/main.go @@ -17,6 +17,7 @@ type Config struct { // Filter defines a function to skip middleware. // Optional. Default: nil Filter func(*fiber.Ctx) bool + // Rules defines the URL path rewrite rules. The values captured in asterisk can be // retrieved by index e.g. $1, $2 and so on. // Required. Example: @@ -25,14 +26,7 @@ type Config struct { // "/js/*": "/public/javascripts/$1", // "/users/*/orders/*": "/user/$1/order/$2", Rules map[string]string - // // Redirect determns if the client should be redirected - // // By default this is disabled and urls are rewritten on the server - // // Optional. Default: false - // Redirect bool - // // The status code when redirecting - // // This is ignored if Redirect is disabled - // // Optional. Default: 302 Temporary Redirect - // StatusCode int + rulesRegex map[*regexp.Regexp]string } @@ -40,14 +34,15 @@ type Config struct { func New(config ...Config) fiber.Handler { // Init config var cfg Config + if len(config) > 0 { cfg = config[0] + } else { + cfg = Config{} } - // if cfg.StatusCode == 0 { - // cfg.StatusCode = 302 // Temporary Redirect - // } - cfg = config[0] + cfg.rulesRegex = map[*regexp.Regexp]string{} + // Initialize for k, v := range cfg.Rules { k = strings.Replace(k, "*", "(.*)", -1) diff --git a/main_test.go b/main_test.go index 7370432..c757483 100644 --- a/main_test.go +++ b/main_test.go @@ -3,3 +3,178 @@ // 📝 Github Repository: https://github.com/gofiber/fiber package rewrite + +import ( + "fmt" + "io" + "net/http" + "testing" + + "github.com/gofiber/fiber/v2" + "github.com/gofiber/fiber/v2/utils" +) + +func Test_New(t *testing.T) { + // Test with no config + m := New() + + if m == nil { + t.Error("Expected middleware to be returned, got nil") + } + + // Test with config + m = New(Config{ + Rules: map[string]string{ + "/old": "/new", + }, + }) + + if m == nil { + t.Error("Expected middleware to be returned, got nil") + } + + // Test with full config + m = New(Config{ + Filter: func(*fiber.Ctx) bool { + return true + }, + Rules: map[string]string{ + "/old": "/new", + }, + }) + + if m == nil { + t.Error("Expected middleware to be returned, got nil") + } +} + +func Test_Rewrite(t *testing.T) { + // Case 1: filter function always returns true + app := fiber.New() + app.Use(New(Config{ + Filter: func(*fiber.Ctx) bool { + return true + }, + Rules: map[string]string{ + "/old": "/new", + }, + })) + + app.Get("/old", func(c *fiber.Ctx) error { + return c.SendString("Rewrite Successful") + }) + + req, _ := http.NewRequest("GET", "/old", nil) + resp, err := app.Test(req) + body, _ := io.ReadAll(resp.Body) + bodyString := string(body) + + if err != nil { + t.Error(err) + } + + utils.AssertEqual(t, fiber.StatusOK, resp.StatusCode) + utils.AssertEqual(t, "Rewrite Successful", bodyString) + + // Case 2: filter function always returns false + app = fiber.New() + app.Use(New(Config{ + Filter: func(*fiber.Ctx) bool { + return false + }, + Rules: map[string]string{ + "/old": "/new", + }, + })) + + app.Get("/new", func(c *fiber.Ctx) error { + return c.SendString("Rewrite Successful") + }) + + req, _ = http.NewRequest("GET", "/old", nil) + resp, err = app.Test(req) + body, _ = io.ReadAll(resp.Body) + bodyString = string(body) + + if err != nil { + t.Error(err) + } + + utils.AssertEqual(t, fiber.StatusOK, resp.StatusCode) + utils.AssertEqual(t, "Rewrite Successful", bodyString) + + // Case 3: check for captured tokens in rewrite rule + app = fiber.New() + app.Use(New(Config{ + Rules: map[string]string{ + "/users/*/orders/*": "/user/$1/order/$2", + }, + })) + + app.Get("/user/:userID/order/:orderID", func(c *fiber.Ctx) error { + return c.SendString(fmt.Sprintf("User ID: %s, Order ID: %s", c.Params("userID"), c.Params("orderID"))) + }) + + req, _ = http.NewRequest("GET", "/users/123/orders/456", nil) + resp, err = app.Test(req) + body, _ = io.ReadAll(resp.Body) + bodyString = string(body) + + if err != nil { + t.Error(err) + } + + utils.AssertEqual(t, fiber.StatusOK, resp.StatusCode) + utils.AssertEqual(t, "User ID: 123, Order ID: 456", bodyString) + + // Case 4: Send non-matching request, handled by default route + app = fiber.New() + app.Use(New(Config{ + Rules: map[string]string{ + "/users/*/orders/*": "/user/$1/order/$2", + }, + })) + + app.Get("/user/:userID/order/:orderID", func(c *fiber.Ctx) error { + return c.SendString(fmt.Sprintf("User ID: %s, Order ID: %s", c.Params("userID"), c.Params("orderID"))) + }) + + app.Use(func(c *fiber.Ctx) error { + return c.SendStatus(fiber.StatusOK) + }) + + req, _ = http.NewRequest("GET", "/not-matching-any-rule", nil) + resp, err = app.Test(req) + body, _ = io.ReadAll(resp.Body) + bodyString = string(body) + + if err != nil { + t.Error(err) + } + + utils.AssertEqual(t, fiber.StatusOK, resp.StatusCode) + utils.AssertEqual(t, "OK", bodyString) + + // Case 4: Send non-matching request, with no default route + app = fiber.New() + app.Use(New(Config{ + Rules: map[string]string{ + "/users/*/orders/*": "/user/$1/order/$2", + }, + })) + + app.Get("/user/:userID/order/:orderID", func(c *fiber.Ctx) error { + return c.SendString(fmt.Sprintf("User ID: %s, Order ID: %s", c.Params("userID"), c.Params("orderID"))) + }) + + req, _ = http.NewRequest("GET", "/not-matching-any-rule", nil) + resp, err = app.Test(req) + body, _ = io.ReadAll(resp.Body) + bodyString = string(body) + + if err != nil { + t.Error(err) + } + + utils.AssertEqual(t, fiber.StatusNotFound, resp.StatusCode) +}