Skip to content

Commit

Permalink
add io/fs support to static mw
Browse files Browse the repository at this point in the history
  • Loading branch information
efectn committed May 18, 2024
1 parent 4284ebc commit a9d7f22
Show file tree
Hide file tree
Showing 4 changed files with 98 additions and 12 deletions.
1 change: 1 addition & 0 deletions docs/middleware/static.md
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,7 @@ If you want to define static routes using `Get`, you need to use wildcard (`*`)
| Property | Type | Description | Default |
|:-----------|:------------------------|:---------------------------------------------------------------------------------------------------------------------------|:-----------------------|
| Next | `func(fiber.Ctx) bool` | Next defines a function to skip this middleware when returned true. | `nil` |
| FS | `fs.FS` | FS is the file system to serve the static files from.<br /><br />You can use interfaces compatible with fs.FS like embed.FS, os.DirFS etc. | `nil` |
| Compress | `bool` | When set to true, the server tries minimizing CPU usage by caching compressed files.<br /><br />This works differently than the github.com/gofiber/compression middleware. | `false` |
| ByteRange | `bool` | When set to true, enables byte range requests. | `false` |
| Browse | `bool` | When set to true, enables directory browsing. | `false` |
Expand Down
1 change: 1 addition & 0 deletions middleware/static/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ type Config struct {
Next func(c fiber.Ctx) bool

// FS is the file system to serve the static files from.
// You can use interfaces compatible with fs.FS like embed.FS, os.DirFS etc.
//
// Optional. Default: nil
FS fs.FS
Expand Down
32 changes: 21 additions & 11 deletions middleware/static/static.go
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
package static

import (
"io/fs"
"os"
"strconv"
"strings"
"sync"

"github.com/gofiber/fiber/v3"
"github.com/gofiber/utils/v2"
"github.com/valyala/fasthttp"
)

Expand All @@ -22,14 +24,10 @@ func New(root string, cfg ...Config) fiber.Handler {
var cacheControlValue string

// adjustments for io/fs compatibility
if config.FS != nil && root != "" {
if config.FS != nil && root == "" {
root = "."
}

if root != "." && !strings.HasPrefix(root, "/") {
root = "./" + root
}

return func(c fiber.Ctx) error {
// Don't execute middleware if Next returns true
if config.Next != nil && config.Next(c) {
Expand Down Expand Up @@ -77,14 +75,16 @@ func New(root string, cfg ...Config) fiber.Handler {
path := fctx.Path()

if len(path) >= prefixLen {
checkFile, err := isFile(root)
checkFile, err := isFile(root, fs.FS)
if err != nil {
return path
}

// If the root is a file, we need to reset the path to "/" always.
if checkFile {
if checkFile && fs.FS == nil {
path = append(path[0:0], '/')
} else if checkFile && fs.FS != nil {
path = utils.UnsafeBytes(root)
} else {
path = path[prefixLen:]
if len(path) == 0 || path[len(path)-1] != '/' {
Expand Down Expand Up @@ -145,10 +145,20 @@ func New(root string, cfg ...Config) fiber.Handler {
}

// isFile checks if the root is a file.
func isFile(root string) (bool, error) {
file, err := os.Open(root)
if err != nil {
return false, err
func isFile(root string, filesystem fs.FS) (bool, error) {
var file fs.File
var err error

if filesystem != nil {
file, err = filesystem.Open(root)
if err != nil {
return false, err
}
} else {
file, err = os.Open(root)
if err != nil {
return false, err
}
}

stat, err := file.Stat()
Expand Down
76 changes: 75 additions & 1 deletion middleware/static/static_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package static
import (
"embed"
"io"
"io/fs"
"net/http/httptest"
"os"
"strings"
Expand Down Expand Up @@ -527,7 +528,8 @@ func Test_Static_FS_Prefix_Wildcard(t *testing.T) {
app := fiber.New()

app.Get("/test*", New("index.html", Config{
FS: os.DirFS("../../.github"),
FS: os.DirFS("../../.github"),
Index: "not_index.html",
}))

req := httptest.NewRequest(fiber.MethodGet, "/test/john/doe", nil)
Expand All @@ -541,3 +543,75 @@ func Test_Static_FS_Prefix_Wildcard(t *testing.T) {
require.NoError(t, err)
require.Contains(t, string(body), "Test file")
}

func Test_isFile(t *testing.T) {
t.Parallel()

cases := []struct {
name string
path string
filesystem fs.FS
expected bool
gotError error
}{
{
name: "file",
path: "index.html",
filesystem: os.DirFS("../../.github"),
expected: true,
},
{
name: "file",
path: "index2.html",
filesystem: os.DirFS("../../.github"),
expected: false,
gotError: fs.ErrNotExist,
},
{
name: "directory",
path: ".",
filesystem: os.DirFS("../../.github"),
expected: false,
},
{
name: "directory",
path: "not_exists",
filesystem: os.DirFS("../../.github"),
expected: false,
gotError: fs.ErrNotExist,
},
{
name: "directory",
path: ".",
filesystem: os.DirFS("../../.github/testdata/fs/css"),
expected: false,
},
{
name: "file",
path: "../../.github/testdata/fs/css/style.css",
filesystem: nil,
expected: true,
},
{
name: "file",
path: "../../.github/testdata/fs/css/style2.css",
filesystem: nil,
expected: false,
gotError: fs.ErrNotExist,
},
{
name: "directory",
path: "../../.github/testdata/fs/css",
filesystem: nil,
expected: false,
},
}

for _, c := range cases {
t.Run(c.name, func(t *testing.T) {
actual, err := isFile(c.path, c.filesystem)
require.ErrorIs(t, err, c.gotError)
require.Equal(t, c.expected, actual)
})
}
}

0 comments on commit a9d7f22

Please sign in to comment.