Skip to content
This repository has been archived by the owner on May 19, 2023. It is now read-only.

Add unit-tests for rewrite middleware #78

Merged
merged 1 commit into from
Jan 18, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .github/workflows/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ jobs:
go-version:
- 1.16.x
- 1.18.x
- 1.19.x
platform:
- ubuntu-latest
- windows-latest
Expand Down
19 changes: 7 additions & 12 deletions main.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ type Config struct {
// Filter defines a function to skip middleware.
// Optional. Default: nil
Filter func(*fiber.Ctx) bool

// Rules defines the URL path rewrite rules. The values captured in asterisk can be
// retrieved by index e.g. $1, $2 and so on.
// Required. Example:
Expand All @@ -25,29 +26,23 @@ type Config struct {
// "/js/*": "/public/javascripts/$1",
// "/users/*/orders/*": "/user/$1/order/$2",
Rules map[string]string
// // Redirect determns if the client should be redirected
// // By default this is disabled and urls are rewritten on the server
// // Optional. Default: false
// Redirect bool
// // The status code when redirecting
// // This is ignored if Redirect is disabled
// // Optional. Default: 302 Temporary Redirect
// StatusCode int

rulesRegex map[*regexp.Regexp]string
}

// New ...
func New(config ...Config) fiber.Handler {
// Init config
var cfg Config

if len(config) > 0 {
cfg = config[0]
} else {
cfg = Config{}
}
// if cfg.StatusCode == 0 {
// cfg.StatusCode = 302 // Temporary Redirect
// }
cfg = config[0]

cfg.rulesRegex = map[*regexp.Regexp]string{}

// Initialize
for k, v := range cfg.Rules {
k = strings.Replace(k, "*", "(.*)", -1)
Expand Down
175 changes: 175 additions & 0 deletions main_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,3 +3,178 @@
// 📝 Github Repository: https://github.com/gofiber/fiber

package rewrite

import (
"fmt"
"io"
"net/http"
"testing"

"github.com/gofiber/fiber/v2"
"github.com/gofiber/fiber/v2/utils"
)

func Test_New(t *testing.T) {
// Test with no config
m := New()

if m == nil {
t.Error("Expected middleware to be returned, got nil")
}

// Test with config
m = New(Config{
Rules: map[string]string{
"/old": "/new",
},
})

if m == nil {
t.Error("Expected middleware to be returned, got nil")
}

// Test with full config
m = New(Config{
Filter: func(*fiber.Ctx) bool {
return true
},
Rules: map[string]string{
"/old": "/new",
},
})

if m == nil {
t.Error("Expected middleware to be returned, got nil")
}
}

func Test_Rewrite(t *testing.T) {
// Case 1: filter function always returns true
app := fiber.New()
app.Use(New(Config{
Filter: func(*fiber.Ctx) bool {
return true
},
Rules: map[string]string{
"/old": "/new",
},
}))

app.Get("/old", func(c *fiber.Ctx) error {
return c.SendString("Rewrite Successful")
})

req, _ := http.NewRequest("GET", "/old", nil)
resp, err := app.Test(req)
body, _ := io.ReadAll(resp.Body)
bodyString := string(body)

if err != nil {
t.Error(err)
}

utils.AssertEqual(t, fiber.StatusOK, resp.StatusCode)
utils.AssertEqual(t, "Rewrite Successful", bodyString)

// Case 2: filter function always returns false
app = fiber.New()
app.Use(New(Config{
Filter: func(*fiber.Ctx) bool {
return false
},
Rules: map[string]string{
"/old": "/new",
},
}))

app.Get("/new", func(c *fiber.Ctx) error {
return c.SendString("Rewrite Successful")
})

req, _ = http.NewRequest("GET", "/old", nil)
resp, err = app.Test(req)
body, _ = io.ReadAll(resp.Body)
bodyString = string(body)

if err != nil {
t.Error(err)
}

utils.AssertEqual(t, fiber.StatusOK, resp.StatusCode)
utils.AssertEqual(t, "Rewrite Successful", bodyString)

// Case 3: check for captured tokens in rewrite rule
app = fiber.New()
app.Use(New(Config{
Rules: map[string]string{
"/users/*/orders/*": "/user/$1/order/$2",
},
}))

app.Get("/user/:userID/order/:orderID", func(c *fiber.Ctx) error {
return c.SendString(fmt.Sprintf("User ID: %s, Order ID: %s", c.Params("userID"), c.Params("orderID")))
})

req, _ = http.NewRequest("GET", "/users/123/orders/456", nil)
resp, err = app.Test(req)
body, _ = io.ReadAll(resp.Body)
bodyString = string(body)

if err != nil {
t.Error(err)
}

utils.AssertEqual(t, fiber.StatusOK, resp.StatusCode)
utils.AssertEqual(t, "User ID: 123, Order ID: 456", bodyString)

// Case 4: Send non-matching request, handled by default route
app = fiber.New()
app.Use(New(Config{
Rules: map[string]string{
"/users/*/orders/*": "/user/$1/order/$2",
},
}))

app.Get("/user/:userID/order/:orderID", func(c *fiber.Ctx) error {
return c.SendString(fmt.Sprintf("User ID: %s, Order ID: %s", c.Params("userID"), c.Params("orderID")))
})

app.Use(func(c *fiber.Ctx) error {
return c.SendStatus(fiber.StatusOK)
})

req, _ = http.NewRequest("GET", "/not-matching-any-rule", nil)
resp, err = app.Test(req)
body, _ = io.ReadAll(resp.Body)
bodyString = string(body)

if err != nil {
t.Error(err)
}

utils.AssertEqual(t, fiber.StatusOK, resp.StatusCode)
utils.AssertEqual(t, "OK", bodyString)

// Case 4: Send non-matching request, with no default route
app = fiber.New()
app.Use(New(Config{
Rules: map[string]string{
"/users/*/orders/*": "/user/$1/order/$2",
},
}))

app.Get("/user/:userID/order/:orderID", func(c *fiber.Ctx) error {
return c.SendString(fmt.Sprintf("User ID: %s, Order ID: %s", c.Params("userID"), c.Params("orderID")))
})

req, _ = http.NewRequest("GET", "/not-matching-any-rule", nil)
resp, err = app.Test(req)
body, _ = io.ReadAll(resp.Body)
bodyString = string(body)

if err != nil {
t.Error(err)
}

utils.AssertEqual(t, fiber.StatusNotFound, resp.StatusCode)
}