From dbba6cfa6979d2784cbe6cbaefbee217d87266ef Mon Sep 17 00:00:00 2001 From: RW Date: Fri, 28 Jun 2024 15:51:26 +0200 Subject: [PATCH] fixes #3038 "v3 Flash Message with redirect is not working" (#3046) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * fixes #3038 🐞 [Bug]: v3 Flash Message with redirect is not working #3038 * fixes #3038 🐞 [Bug]: v3 Flash Message with redirect is not working #3038 --- redirect.go | 74 +++++++++++++++++++++------------------- redirect_test.go | 88 +++++++++++++++++++++++++++++++++++++++++------- router.go | 2 +- 3 files changed, 116 insertions(+), 48 deletions(-) diff --git a/redirect.go b/redirect.go index f03981b818..741a2e12a5 100644 --- a/redirect.go +++ b/redirect.go @@ -182,6 +182,8 @@ func (r *Redirect) To(location string) error { r.c.setCanonical(HeaderLocation, location) r.c.Status(r.status) + r.processFlashMessages() + return nil } @@ -200,38 +202,6 @@ func (r *Redirect) Route(name string, config ...RedirectConfig) error { return err } - // Flash messages - if len(r.messages) > 0 || len(r.oldInput) > 0 { - messageText := bytebufferpool.Get() - defer bytebufferpool.Put(messageText) - - // flash messages - for i, message := range r.messages { - messageText.WriteString(message) - // when there are more messages or oldInput -> add a comma - if len(r.messages)-1 != i || (len(r.messages)-1 == i && len(r.oldInput) > 0) { - messageText.WriteString(CookieDataSeparator) - } - } - r.messages = r.messages[:0] - - // old input data - i := 1 - for k, v := range r.oldInput { - messageText.WriteString(OldInputDataPrefix + k + CookieDataAssigner + v) - if len(r.oldInput) != i { - messageText.WriteString(CookieDataSeparator) - } - i++ - } - - r.c.Cookie(&Cookie{ - Name: FlashCookieName, - Value: r.c.app.getString(messageText.Bytes()), - SessionOnly: true, - }) - } - // Check queries if len(cfg.Queries) > 0 { queryText := bytebufferpool.Get() @@ -270,8 +240,8 @@ func (r *Redirect) Back(fallback ...string) error { return r.To(location) } -// setFlash is a method to get flash messages before removing them -func (r *Redirect) setFlash() { +// parseAndClearFlashMessages is a method to get flash messages before removing them +func (r *Redirect) parseAndClearFlashMessages() { // parse flash messages cookieValue := r.c.Cookies(FlashCookieName) @@ -289,6 +259,42 @@ func (r *Redirect) setFlash() { r.c.ClearCookie(FlashCookieName) } +// processFlashMessages is a helper function to process flash messages and old input data +// and set them as cookies +func (r *Redirect) processFlashMessages() { + // Flash messages + if len(r.messages) > 0 || len(r.oldInput) > 0 { + messageText := bytebufferpool.Get() + defer bytebufferpool.Put(messageText) + + // flash messages + for i, message := range r.messages { + messageText.WriteString(message) + // when there are more messages or oldInput -> add a comma + if len(r.messages)-1 != i || (len(r.messages)-1 == i && len(r.oldInput) > 0) { + messageText.WriteString(CookieDataSeparator) + } + } + r.messages = r.messages[:0] + + // old input data + i := 1 + for k, v := range r.oldInput { + messageText.WriteString(OldInputDataPrefix + k + CookieDataAssigner + v) + if len(r.oldInput) != i { + messageText.WriteString(CookieDataSeparator) + } + i++ + } + + r.c.Cookie(&Cookie{ + Name: FlashCookieName, + Value: r.c.app.getString(messageText.Bytes()), + SessionOnly: true, + }) + } +} + // parseMessage is a helper function to parse flash messages and old input data func parseMessage(raw string) (string, string) { //nolint: revive // not necessary if i := findNextNonEscapedCharsetPosition(raw, []byte(CookieDataAssigner)); i != -1 { diff --git a/redirect_test.go b/redirect_test.go index dd5e4b2715..70f583ea9f 100644 --- a/redirect_test.go +++ b/redirect_test.go @@ -34,6 +34,24 @@ func Test_Redirect_To(t *testing.T) { require.Equal(t, "http://example.com", string(c.Response().Header.Peek(HeaderLocation))) } +// go test -run Test_Redirect_To_WithFlashMessages +func Test_Redirect_To_WithFlashMessages(t *testing.T) { + t.Parallel() + app := New() + c := app.AcquireCtx(&fasthttp.RequestCtx{}) + + err := c.Redirect().With("success", "1").With("message", "test").To("http://example.com") + require.NoError(t, err) + require.Equal(t, 302, c.Response().StatusCode()) + require.Equal(t, "http://example.com", string(c.Response().Header.Peek(HeaderLocation))) + + equal := c.GetRespHeader(HeaderSetCookie) == "fiber_flash=success:1,message:test; path=/; SameSite=Lax" || c.GetRespHeader(HeaderSetCookie) == "fiber_flash=message:test,success:1; path=/; SameSite=Lax" + require.True(t, equal) + + c.Redirect().parseAndClearFlashMessages() + require.Equal(t, "fiber_flash=; expires=Tue, 10 Nov 2009 23:00:00 GMT", c.GetRespHeader(HeaderSetCookie)) +} + // go test -run Test_Redirect_Route_WithParams func Test_Redirect_Route_WithParams(t *testing.T) { t.Parallel() @@ -149,6 +167,29 @@ func Test_Redirect_Back(t *testing.T) { require.ErrorAs(t, err, &ErrRedirectBackNoFallback) } +// go test -run Test_Redirect_Back_WithFlashMessages +func Test_Redirect_Back_WithFlashMessages(t *testing.T) { + t.Parallel() + + app := New() + app.Get("/user", func(c Ctx) error { + return c.SendString("user") + }).Name("user") + + c := app.AcquireCtx(&fasthttp.RequestCtx{}).(*DefaultCtx) //nolint:errcheck, forcetypeassert // not needed + + err := c.Redirect().With("success", "1").With("message", "test").Back("/") + require.NoError(t, err) + require.Equal(t, 302, c.Response().StatusCode()) + require.Equal(t, "/", string(c.Response().Header.Peek(HeaderLocation))) + + equal := c.GetRespHeader(HeaderSetCookie) == "fiber_flash=success:1,message:test; path=/; SameSite=Lax" || c.GetRespHeader(HeaderSetCookie) == "fiber_flash=message:test,success:1; path=/; SameSite=Lax" + require.True(t, equal) + + c.Redirect().parseAndClearFlashMessages() + require.Equal(t, "fiber_flash=; expires=Tue, 10 Nov 2009 23:00:00 GMT", c.GetRespHeader(HeaderSetCookie)) +} + // go test -run Test_Redirect_Back_WithReferer func Test_Redirect_Back_WithReferer(t *testing.T) { t.Parallel() @@ -188,7 +229,7 @@ func Test_Redirect_Route_WithFlashMessages(t *testing.T) { equal := c.GetRespHeader(HeaderSetCookie) == "fiber_flash=success:1,message:test; path=/; SameSite=Lax" || c.GetRespHeader(HeaderSetCookie) == "fiber_flash=message:test,success:1; path=/; SameSite=Lax" require.True(t, equal) - c.Redirect().setFlash() + c.Redirect().parseAndClearFlashMessages() require.Equal(t, "fiber_flash=; expires=Tue, 10 Nov 2009 23:00:00 GMT", c.GetRespHeader(HeaderSetCookie)) } @@ -216,12 +257,12 @@ func Test_Redirect_Route_WithOldInput(t *testing.T) { require.Contains(t, c.GetRespHeader(HeaderSetCookie), ",old_input_data_id:1") require.Contains(t, c.GetRespHeader(HeaderSetCookie), ",old_input_data_name:tom") - c.Redirect().setFlash() + c.Redirect().parseAndClearFlashMessages() require.Equal(t, "fiber_flash=; expires=Tue, 10 Nov 2009 23:00:00 GMT", c.GetRespHeader(HeaderSetCookie)) } -// go test -run Test_Redirect_setFlash -func Test_Redirect_setFlash(t *testing.T) { +// go test -run Test_Redirect_parseAndClearFlashMessages +func Test_Redirect_parseAndClearFlashMessages(t *testing.T) { t.Parallel() app := New() @@ -233,7 +274,7 @@ func Test_Redirect_setFlash(t *testing.T) { c.Request().Header.Set(HeaderCookie, "fiber_flash=success:1,message:test,old_input_data_name:tom,old_input_data_id:1") - c.Redirect().setFlash() + c.Redirect().parseAndClearFlashMessages() require.Equal(t, "fiber_flash=; expires=Tue, 10 Nov 2009 23:00:00 GMT", c.GetRespHeader(HeaderSetCookie)) @@ -416,12 +457,12 @@ func Benchmark_Redirect_Route_WithFlashMessages(b *testing.B) { equal := c.GetRespHeader(HeaderSetCookie) == "fiber_flash=success:1,message:test; path=/; SameSite=Lax" || c.GetRespHeader(HeaderSetCookie) == "fiber_flash=message:test,success:1; path=/; SameSite=Lax" require.True(b, equal) - c.Redirect().setFlash() + c.Redirect().parseAndClearFlashMessages() require.Equal(b, "fiber_flash=; expires=Tue, 10 Nov 2009 23:00:00 GMT", c.GetRespHeader(HeaderSetCookie)) } -// go test -v -run=^$ -bench=Benchmark_Redirect_setFlash -benchmem -count=4 -func Benchmark_Redirect_setFlash(b *testing.B) { +// go test -v -run=^$ -bench=Benchmark_Redirect_parseAndClearFlashMessages -benchmem -count=4 +func Benchmark_Redirect_parseAndClearFlashMessages(b *testing.B) { app := New() app.Get("/user", func(c Ctx) error { return c.SendString("user") @@ -435,7 +476,7 @@ func Benchmark_Redirect_setFlash(b *testing.B) { b.ResetTimer() for n := 0; n < b.N; n++ { - c.Redirect().setFlash() + c.Redirect().parseAndClearFlashMessages() } require.Equal(b, "fiber_flash=; expires=Tue, 10 Nov 2009 23:00:00 GMT", c.GetRespHeader(HeaderSetCookie)) @@ -449,6 +490,27 @@ func Benchmark_Redirect_setFlash(b *testing.B) { require.Equal(b, map[string]string{"id": "1", "name": "tom"}, c.Redirect().OldInputs()) } +// go test -v -run=^$ -bench=Benchmark_Redirect_processFlashMessages -benchmem -count=4 +func Benchmark_Redirect_processFlashMessages(b *testing.B) { + app := New() + app.Get("/user", func(c Ctx) error { + return c.SendString("user") + }).Name("user") + + c := app.AcquireCtx(&fasthttp.RequestCtx{}).(*DefaultCtx) //nolint:errcheck, forcetypeassert // not needed + + c.Redirect().With("success", "1").With("message", "test") + + b.ReportAllocs() + b.ResetTimer() + + for n := 0; n < b.N; n++ { + c.Redirect().processFlashMessages() + } + + require.Equal(b, "fiber_flash=success:1,message:test; path=/; SameSite=Lax", c.GetRespHeader(HeaderSetCookie)) +} + // go test -v -run=^$ -bench=Benchmark_Redirect_Messages -benchmem -count=4 func Benchmark_Redirect_Messages(b *testing.B) { app := New() @@ -459,7 +521,7 @@ func Benchmark_Redirect_Messages(b *testing.B) { c := app.AcquireCtx(&fasthttp.RequestCtx{}).(*DefaultCtx) //nolint:errcheck, forcetypeassert // not needed c.Request().Header.Set(HeaderCookie, "fiber_flash=success:1,message:test,old_input_data_name:tom,old_input_data_id:1") - c.Redirect().setFlash() + c.Redirect().parseAndClearFlashMessages() var msgs map[string]string @@ -484,7 +546,7 @@ func Benchmark_Redirect_OldInputs(b *testing.B) { c := app.AcquireCtx(&fasthttp.RequestCtx{}).(*DefaultCtx) //nolint:errcheck, forcetypeassert // not needed c.Request().Header.Set(HeaderCookie, "fiber_flash=success:1,message:test,old_input_data_name:tom,old_input_data_id:1") - c.Redirect().setFlash() + c.Redirect().parseAndClearFlashMessages() var oldInputs map[string]string @@ -509,7 +571,7 @@ func Benchmark_Redirect_Message(b *testing.B) { c := app.AcquireCtx(&fasthttp.RequestCtx{}).(*DefaultCtx) //nolint:errcheck, forcetypeassert // not needed c.Request().Header.Set(HeaderCookie, "fiber_flash=success:1,message:test,old_input_data_name:tom,old_input_data_id:1") - c.Redirect().setFlash() + c.Redirect().parseAndClearFlashMessages() var msg string @@ -534,7 +596,7 @@ func Benchmark_Redirect_OldInput(b *testing.B) { c := app.AcquireCtx(&fasthttp.RequestCtx{}).(*DefaultCtx) //nolint:errcheck, forcetypeassert // not needed c.Request().Header.Set(HeaderCookie, "fiber_flash=success:1,message:test,old_input_data_name:tom,old_input_data_id:1") - c.Redirect().setFlash() + c.Redirect().parseAndClearFlashMessages() var input string diff --git a/router.go b/router.go index 26a2483f09..eae9adef70 100644 --- a/router.go +++ b/router.go @@ -225,7 +225,7 @@ func (app *App) requestHandler(rctx *fasthttp.RequestCtx) { // check flash messages if strings.Contains(utils.UnsafeString(c.Request().Header.RawHeaders()), FlashCookieName) { - c.Redirect().setFlash() + c.Redirect().parseAndClearFlashMessages() } // Find match in stack