From 387a5b56b53ccfe0637a0f44c0ddbec8e991cc39 Mon Sep 17 00:00:00 2001 From: Kurtis Van Gent <31518063+kurtisvg@users.noreply.github.com> Date: Fri, 3 Jan 2025 10:09:45 -0700 Subject: [PATCH] fix: correct parsing of floats/ints from json (#180) Corrects an issue caused by Go defaulting to parsing JSON Numbers as float64s. This caused some numbers to be incorrectly parsed as floats when they were integers. This defaults to parsing using json.Number, which allows us to parse between Int/Float more accurately. --- internal/server/api.go | 14 ++++++++++- internal/tools/parameters.go | 39 +++++++++++++++++++++++++------ internal/tools/parameters_test.go | 33 +++++++++++++++++++++----- 3 files changed, 72 insertions(+), 14 deletions(-) diff --git a/internal/server/api.go b/internal/server/api.go index 93e80aa30..6c556b9c7 100644 --- a/internal/server/api.go +++ b/internal/server/api.go @@ -15,7 +15,9 @@ package server import ( + "encoding/json" "fmt" + "io" "net/http" "github.com/go-chi/chi/v5" @@ -118,7 +120,7 @@ func toolInvokeHandler(s *Server, w http.ResponseWriter, r *http.Request) { } var data map[string]any - if err := render.DecodeJSON(r.Body, &data); err != nil { + if err := decodeJSON(r.Body, &data); err != nil { render.Status(r, http.StatusBadRequest) err := fmt.Errorf("request body was invalid JSON: %w", err) _ = render.Render(w, r, newErrResponse(err, http.StatusBadRequest)) @@ -181,3 +183,13 @@ func (e *errResponse) Render(w http.ResponseWriter, r *http.Request) error { render.Status(r, e.HTTPStatusCode) return nil } + +// decodeJSON decodes a given reader into an interface using the json decoder. +func decodeJSON(r io.Reader, v interface{}) error { + defer io.Copy(io.Discard, r) //nolint:errcheck + d := json.NewDecoder(r) + // specify JSON numbers should get parsed to json.Number instead of float64 by default. + // This prevents loss between floats/ints. + d.UseNumber() + return d.Decode(v) +} diff --git a/internal/tools/parameters.go b/internal/tools/parameters.go index bc7f3a616..24eb32cef 100644 --- a/internal/tools/parameters.go +++ b/internal/tools/parameters.go @@ -15,6 +15,7 @@ package tools import ( + "encoding/json" "fmt" "gopkg.in/yaml.v3" @@ -105,7 +106,7 @@ func ParseParams(ps Parameters, data map[string]any, claimsMap map[string]map[st var err error v, err = parseFromAuthSource(paramAuthSources, claimsMap) if err != nil { - return nil, fmt.Errorf("error parsing anthenticated parameter %q: %w", name, err) + return nil, fmt.Errorf("error parsing authenticated parameter %q: %w", name, err) } } newV, err := p.Parse(v) @@ -331,11 +332,24 @@ type IntParameter struct { } func (p *IntParameter) Parse(v any) (any, error) { - newV, ok := v.(int) - if !ok { + var out int + switch newV := v.(type) { + default: return nil, &ParseTypeError{p.Name, p.Type, v} + case int: + out = int(newV) + case int32: + out = int(newV) + case int64: + out = int(newV) + case json.Number: + newI, err := newV.Int64() + if err != nil { + return nil, &ParseTypeError{p.Name, p.Type, v} + } + out = int(newI) } - return newV, nil + return out, nil } func (p *IntParameter) GetAuthSources() []ParamAuthSource { @@ -374,11 +388,22 @@ type FloatParameter struct { } func (p *FloatParameter) Parse(v any) (any, error) { - newV, ok := v.(float64) - if !ok { + var out float64 + switch newV := v.(type) { + default: return nil, &ParseTypeError{p.Name, p.Type, v} + case float32: + out = float64(newV) + case float64: + out = newV + case json.Number: + newI, err := newV.Float64() + if err != nil { + return nil, &ParseTypeError{p.Name, p.Type, v} + } + out = float64(newI) } - return newV, nil + return out, nil } func (p *FloatParameter) GetAuthSources() []ParamAuthSource { diff --git a/internal/tools/parameters_test.go b/internal/tools/parameters_test.go index 0b3e406cc..72f904c3f 100644 --- a/internal/tools/parameters_test.go +++ b/internal/tools/parameters_test.go @@ -15,6 +15,9 @@ package tools_test import ( + "bytes" + "encoding/json" + "math" "reflect" "testing" @@ -351,6 +354,16 @@ func TestParametersParse(t *testing.T) { "my_int": 14.5, }, }, + { + name: "not int (big)", + params: tools.Parameters{ + tools.NewIntParameter("my_int", "this param is an int"), + }, + in: map[string]any{ + "my_int": math.MaxInt64, + }, + want: tools.ParamValues{tools.ParamValue{Name: "my_int", Value: math.MaxInt64}}, + }, { name: "float", params: tools.Parameters{ @@ -393,25 +406,31 @@ func TestParametersParse(t *testing.T) { for _, tc := range tcs { t.Run(tc.name, func(t *testing.T) { // parse map to bytes - data, err := yaml.Marshal(tc.in) + data, err := json.Marshal(tc.in) if err != nil { t.Fatalf("unable to marshal input to yaml: %s", err) } // parse bytes to object var m map[string]any - err = yaml.Unmarshal(data, &m) + + d := json.NewDecoder(bytes.NewReader(data)) + d.UseNumber() + err = d.Decode(&m) if err != nil { t.Fatalf("unable to unmarshal: %s", err) } + wantErr := len(tc.want) == 0 // error is expected if no items in want gotAll, err := tools.ParseParams(tc.params, m, make(map[string]map[string]any)) if err != nil { - if len(tc.want) == 0 { - // error is expected if no items in want + if wantErr { return } t.Fatalf("unexpected error from ParseParams: %s", err) } + if wantErr { + t.Fatalf("expected error but Param parsed successfully: %s", gotAll) + } for i, got := range gotAll { want := tc.want[i] if got != want { @@ -552,13 +571,15 @@ func TestAuthParametersParse(t *testing.T) { for _, tc := range tcs { t.Run(tc.name, func(t *testing.T) { // parse map to bytes - data, err := yaml.Marshal(tc.in) + data, err := json.Marshal(tc.in) if err != nil { t.Fatalf("unable to marshal input to yaml: %s", err) } // parse bytes to object var m map[string]any - err = yaml.Unmarshal(data, &m) + d := json.NewDecoder(bytes.NewReader(data)) + d.UseNumber() + err = d.Decode(&m) if err != nil { t.Fatalf("unable to unmarshal: %s", err) }