diff --git a/ctx.go b/ctx.go index 64deaad9..45743590 100644 --- a/ctx.go +++ b/ctx.go @@ -8,6 +8,7 @@ import ( "io/fs" "net/http" "net/url" + "strconv" "strings" "time" @@ -44,6 +45,9 @@ type ContextWithBody[B any] interface { // ... // }) PathParam(name string) string + // If the path parameter is not provided or is not an int, it returns 0. Use [Ctx.PathParamIntErr] if you want to know if the path parameter is erroneous. + PathParamInt(name string) int + PathParamIntErr(name string) (int, error) QueryParam(name string) string QueryParamArr(name string) []string @@ -220,6 +224,71 @@ func (c netHttpContext[B]) PathParam(name string) string { return c.Req.PathValue(name) } +type PathParamNotFoundError struct { + ParamName string +} + +func (e PathParamNotFoundError) Error() string { + return fmt.Errorf("param %s not found", e.ParamName).Error() +} + +func (e PathParamNotFoundError) StatusCode() int { return 404 } + +type PathParamInvalidTypeError struct { + Err error + ParamName string + ParamValue string + ExpectedType string +} + +func (e PathParamInvalidTypeError) Error() string { + return fmt.Errorf("param %s=%s is not of type %s: %w", e.ParamName, e.ParamValue, e.ExpectedType, e.Err).Error() +} + +func (e PathParamInvalidTypeError) StatusCode() int { return 422 } + +type ContextWithPathParam interface { + PathParam(name string) string +} + +func PathParamIntErr(c ContextWithPathParam, name string) (int, error) { + param := c.PathParam(name) + if param == "" { + return 0, PathParamNotFoundError{ParamName: name} + } + + i, err := strconv.Atoi(param) + if err != nil { + return 0, PathParamInvalidTypeError{ + ParamName: name, + ParamValue: param, + ExpectedType: "int", + Err: err, + } + } + + return i, nil +} + +func (c netHttpContext[B]) PathParamIntErr(name string) (int, error) { + return PathParamIntErr(c, name) +} + +func PathParamInt(c ContextWithPathParam, name string) int { + param, err := PathParamIntErr(c, name) + if err != nil { + return 0 + } + + return param +} + +// PathParamInt returns the path parameter with the given name as an int. +// If the query parameter does not exist, or if it is not an int, it returns 0. +func (c netHttpContext[B]) PathParamInt(name string) int { + return PathParamInt(c, name) +} + func (c netHttpContext[B]) MainLang() string { return strings.Split(c.MainLocale(), "-")[0] } diff --git a/ctx_test.go b/ctx_test.go index 8d971e1b..e9bf1070 100644 --- a/ctx_test.go +++ b/ctx_test.go @@ -5,6 +5,7 @@ import ( "context" "encoding/xml" "errors" + "fmt" "net/http/httptest" "strings" "testing" @@ -27,6 +28,66 @@ func TestContext_PathParam(t *testing.T) { require.Equal(t, crlf(`{"ans":"123"}`), w.Body.String()) }) + t.Run("can read one path param to int", func(t *testing.T) { + s := NewServer() + Get(s, "/foo/{id}", func(c ContextNoBody) (ans, error) { + return ans{Ans: fmt.Sprintf("%d", c.PathParamInt("id"))}, nil + }) + + r := httptest.NewRequest("GET", "/foo/123", nil) + w := httptest.NewRecorder() + + s.Mux.ServeHTTP(w, r) + + require.Equal(t, crlf(`{"ans":"123"}`), w.Body.String()) + }) + + t.Run("reading non-int path param to int defaults to 0", func(t *testing.T) { + s := NewServer() + Get(s, "/foo/{id}", func(c ContextNoBody) (ans, error) { + return ans{Ans: fmt.Sprintf("%d", c.PathParamInt("id"))}, nil + }) + + r := httptest.NewRequest("GET", "/foo/abc", nil) + w := httptest.NewRecorder() + + s.Mux.ServeHTTP(w, r) + + require.Equal(t, crlf(`{"ans":"0"}`), w.Body.String()) + }) + + t.Run("reading missing path param to int defaults to 0", func(t *testing.T) { + s := NewServer() + Get(s, "/foo/", func(c ContextNoBody) (ans, error) { + return ans{Ans: fmt.Sprintf("%d", c.PathParamInt("id"))}, nil + }) + + r := httptest.NewRequest("GET", "/foo/", nil) + w := httptest.NewRecorder() + + s.Mux.ServeHTTP(w, r) + + require.Equal(t, crlf(`{"ans":"0"}`), w.Body.String()) + }) + + t.Run("reading non-int path param to int sends an error", func(t *testing.T) { + s := NewServer() + Get(s, "/foo/{id}", func(c ContextNoBody) (ans, error) { + id, err := c.PathParamIntErr("id") + if err != nil { + return ans{}, err + } + return ans{Ans: fmt.Sprintf("%d", id)}, nil + }) + + r := httptest.NewRequest("GET", "/foo/abc", nil) + w := httptest.NewRecorder() + + s.Mux.ServeHTTP(w, r) + + require.Equal(t, 422, w.Code) + }) + t.Run("path param invalid", func(t *testing.T) { s := NewServer() Get(s, "/foo/", func(c ContextNoBody) (ans, error) { diff --git a/extra/fuegoecho/context.go b/extra/fuegoecho/context.go index a21862fa..87c8500e 100644 --- a/extra/fuegoecho/context.go +++ b/extra/fuegoecho/context.go @@ -55,6 +55,14 @@ func (c echoContext[B]) PathParam(name string) string { return c.echoCtx.Param(name) } +func (c echoContext[B]) PathParamIntErr(name string) (int, error) { + return fuego.PathParamIntErr(c, name) +} + +func (c echoContext[B]) PathParamInt(name string) int { + return fuego.PathParamInt(c, name) +} + func (c echoContext[B]) MainLang() string { return strings.Split(c.MainLocale(), "-")[0] } diff --git a/extra/fuegogin/context.go b/extra/fuegogin/context.go index c4e8afb1..2f9d1b7f 100644 --- a/extra/fuegogin/context.go +++ b/extra/fuegogin/context.go @@ -55,6 +55,14 @@ func (c ginContext[B]) PathParam(name string) string { return c.ginCtx.Param(name) } +func (c ginContext[B]) PathParamIntErr(name string) (int, error) { + return fuego.PathParamIntErr(c, name) +} + +func (c ginContext[B]) PathParamInt(name string) int { + return fuego.PathParamInt(c, name) +} + func (c ginContext[B]) MainLang() string { return strings.Split(c.MainLocale(), "-")[0] } diff --git a/go.work b/go.work index 37b2f679..7fa00701 100644 --- a/go.work +++ b/go.work @@ -17,6 +17,7 @@ use ( ./examples/petstore ./examples/with-listener ./extra/fuegogin + ./extra/fuegoecho ./extra/markdown ./middleware/basicauth ./middleware/cache diff --git a/mock_context.go b/mock_context.go index 2cf5cabf..6b362059 100644 --- a/mock_context.go +++ b/mock_context.go @@ -5,6 +5,7 @@ import ( "fmt" "net/http" "net/url" + "strconv" "strings" "github.com/go-fuego/fuego/internal" @@ -84,6 +85,17 @@ func (m *MockContext[B]) PathParam(name string) string { return m.PathParams[name] } +func (m *MockContext[B]) PathParamIntErr(name string) (int, error) { + return strconv.Atoi(m.PathParams[name]) +} + +func (m *MockContext[B]) PathParamInt(name string) int { + if i, err := m.PathParamIntErr(name); err == nil { + return i + } + return 0 +} + // Request returns the mock request func (m *MockContext[B]) Request() *http.Request { return m.request