Skip to content

Commit

Permalink
🔥 Feature: Add Session Access via Go Context
Browse files Browse the repository at this point in the history
  • Loading branch information
JIeJaitt committed Mar 6, 2025
1 parent a5c7b77 commit 740d04f
Show file tree
Hide file tree
Showing 5 changed files with 141 additions and 1 deletion.
5 changes: 5 additions & 0 deletions middleware/session/middleware.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
package session

import (
"context"
"errors"
"sync"

Expand Down Expand Up @@ -90,6 +91,10 @@ func NewWithStore(config ...Config) (fiber.Handler, *Store) {
m := acquireMiddleware()
m.initialize(c, cfg)

// Add session to Go context
ctx := context.WithValue(c.Context(), sessionContextKey, m.Session)
c.SetContext(ctx)

stackErr := c.Next()

m.mu.RLock()
Expand Down
49 changes: 49 additions & 0 deletions middleware/session/middleware_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
package session

import (
"io"
"net/http/httptest"
"strings"
"sync"
"testing"
Expand Down Expand Up @@ -451,3 +453,50 @@ func Test_Session_Middleware_Store(t *testing.T) {
h(ctx)
require.Equal(t, fiber.StatusOK, ctx.Response.StatusCode())
}

func Test_Session_GoContext(t *testing.T) {
t.Parallel()
app := fiber.New()

// Create session store
_ = NewStore()

// Setup session middleware
app.Use(New())

// Setup test route
app.Get("/test", func(c fiber.Ctx) error {
// Get session from Fiber context
sess := FromContext(c)

// Set a value
sess.Set("test_key", "test_value")

// Get the session from Go context
goCtxSess := FromGoContext(c.Context())

// Verify both sessions are the same
if goCtxSess == nil {
return c.Status(fiber.StatusInternalServerError).SendString("Session not found in Go context")
}

// Get value from Go context session
val := goCtxSess.Get("test_key")

// Verify value is correct
if val != "test_value" {
return c.Status(fiber.StatusInternalServerError).SendString("Wrong value in Go context session")
}

return c.SendString("success")
})

// Make request
resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/test", nil))
require.NoError(t, err)
require.Equal(t, fiber.StatusOK, resp.StatusCode)

body, err := io.ReadAll(resp.Body)
require.NoError(t, err)
require.Equal(t, "success", string(body))
}
22 changes: 22 additions & 0 deletions middleware/session/session.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package session

import (
"bytes"
"context"
"encoding/gob"
"fmt"
"sync"
Expand Down Expand Up @@ -30,6 +31,15 @@ const (
absExpirationKey absExpirationKeyType = iota
)

// The contextKey type is unexported to prevent collisions with context keys defined in
// other packages.
type contextKey int

const (
// sessionContextKey is the key used to store the *Session in the Go context.
sessionContextKey contextKey = iota
)

// Session pool for reusing byte buffers.
var byteBufferPool = sync.Pool{
New: func() any {
Expand Down Expand Up @@ -511,3 +521,15 @@ func (s *Session) isAbsExpired() bool {
func (s *Session) setAbsExpiration(absExpiration time.Time) {
s.Set(absExpirationKey, absExpiration)
}

// FromGoContext returns the Session from the go context.
// If there is no session, nil is returned.
func FromGoContext(ctx context.Context) *Session {
if ctx == nil {
return nil
}

Check warning on line 530 in middleware/session/session.go

View check run for this annotation

Codecov / codecov/patch

middleware/session/session.go#L529-L530

Added lines #L529 - L530 were not covered by tests
if sess, ok := ctx.Value(sessionContextKey).(*Session); ok {
return sess
}
return nil

Check warning on line 534 in middleware/session/session.go

View check run for this annotation

Codecov / codecov/patch

middleware/session/session.go#L534

Added line #L534 was not covered by tests
}
12 changes: 11 additions & 1 deletion middleware/session/store.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package session

import (
"context"
"encoding/gob"
"errors"
"fmt"
Expand Down Expand Up @@ -100,7 +101,16 @@ func (s *Store) Get(c fiber.Ctx) (*Session, error) {
return nil, ErrSessionAlreadyLoadedByMiddleware
}

return s.getSession(c)
sess, err := s.getSession(c)
if err != nil {
return nil, err
}

// Add session to Go context
ctx := context.WithValue(c.Context(), sessionContextKey, sess)
c.SetContext(ctx)

return sess, nil
}

// getSession retrieves a session based on the context.
Expand Down
54 changes: 54 additions & 0 deletions middleware/session/store_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@ package session

import (
"fmt"
"io"
"net/http/httptest"
"testing"

"github.com/gofiber/fiber/v3"
Expand Down Expand Up @@ -223,3 +225,55 @@ func Test_Store_GetByID(t *testing.T) {
})
})
}

func Test_Store_Get_GoContext(t *testing.T) {
t.Parallel()
app := fiber.New()
store := NewStore()

app.Get("/direct-get", func(c fiber.Ctx) error {
// Get session directly from store (not using middleware)
sess, err := store.Get(c)
if err != nil {
return err
}
defer sess.Release()

// Set some data
sess.Set("name", "JIeJaitt")

// Now get session from Go context
goCtxSess := FromGoContext(c.Context())

// Verify the session exists in Go context
if goCtxSess == nil {
return c.Status(fiber.StatusInternalServerError).SendString("Session not found in Go context")
}

// Verify it's the same session
if goCtxSess.ID() != sess.ID() {
return c.Status(fiber.StatusInternalServerError).SendString("Session IDs don't match")
}

// Verify data is accessible from Go context session
if goCtxSess.Get("name") != "JIeJaitt" {
return c.Status(fiber.StatusInternalServerError).SendString("Wrong value in Go context session")
}

// Save session
if err := sess.Save(); err != nil {
return err
}

return c.SendString("success")
})

// Make request
resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/direct-get", nil))
require.NoError(t, err)
require.Equal(t, fiber.StatusOK, resp.StatusCode)

body, err := io.ReadAll(resp.Body)
require.NoError(t, err)
require.Equal(t, "success", string(body))
}

0 comments on commit 740d04f

Please sign in to comment.