Skip to content

Commit

Permalink
fix: correct parsing of floats/ints from json (#180)
Browse files Browse the repository at this point in the history
Corrects an issue caused by Go defaulting to parsing JSON Numbers as
float64s. This caused some numbers to be incorrectly parsed as floats
when they were integers. This defaults to parsing using json.Number,
which allows us to parse between Int/Float more accurately.
  • Loading branch information
kurtisvg authored Jan 3, 2025
1 parent 66ab70f commit 387a5b5
Show file tree
Hide file tree
Showing 3 changed files with 72 additions and 14 deletions.
14 changes: 13 additions & 1 deletion internal/server/api.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,9 @@
package server

import (
"encoding/json"
"fmt"
"io"
"net/http"

"github.com/go-chi/chi/v5"
Expand Down Expand Up @@ -118,7 +120,7 @@ func toolInvokeHandler(s *Server, w http.ResponseWriter, r *http.Request) {
}

var data map[string]any
if err := render.DecodeJSON(r.Body, &data); err != nil {
if err := decodeJSON(r.Body, &data); err != nil {
render.Status(r, http.StatusBadRequest)
err := fmt.Errorf("request body was invalid JSON: %w", err)
_ = render.Render(w, r, newErrResponse(err, http.StatusBadRequest))
Expand Down Expand Up @@ -181,3 +183,13 @@ func (e *errResponse) Render(w http.ResponseWriter, r *http.Request) error {
render.Status(r, e.HTTPStatusCode)
return nil
}

// decodeJSON decodes a given reader into an interface using the json decoder.
func decodeJSON(r io.Reader, v interface{}) error {
defer io.Copy(io.Discard, r) //nolint:errcheck
d := json.NewDecoder(r)
// specify JSON numbers should get parsed to json.Number instead of float64 by default.
// This prevents loss between floats/ints.
d.UseNumber()
return d.Decode(v)
}
39 changes: 32 additions & 7 deletions internal/tools/parameters.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
package tools

import (
"encoding/json"
"fmt"

"gopkg.in/yaml.v3"
Expand Down Expand Up @@ -105,7 +106,7 @@ func ParseParams(ps Parameters, data map[string]any, claimsMap map[string]map[st
var err error
v, err = parseFromAuthSource(paramAuthSources, claimsMap)
if err != nil {
return nil, fmt.Errorf("error parsing anthenticated parameter %q: %w", name, err)
return nil, fmt.Errorf("error parsing authenticated parameter %q: %w", name, err)
}
}
newV, err := p.Parse(v)
Expand Down Expand Up @@ -331,11 +332,24 @@ type IntParameter struct {
}

func (p *IntParameter) Parse(v any) (any, error) {
newV, ok := v.(int)
if !ok {
var out int
switch newV := v.(type) {
default:
return nil, &ParseTypeError{p.Name, p.Type, v}
case int:
out = int(newV)
case int32:
out = int(newV)
case int64:
out = int(newV)
case json.Number:
newI, err := newV.Int64()
if err != nil {
return nil, &ParseTypeError{p.Name, p.Type, v}
}
out = int(newI)
}
return newV, nil
return out, nil
}

func (p *IntParameter) GetAuthSources() []ParamAuthSource {
Expand Down Expand Up @@ -374,11 +388,22 @@ type FloatParameter struct {
}

func (p *FloatParameter) Parse(v any) (any, error) {
newV, ok := v.(float64)
if !ok {
var out float64
switch newV := v.(type) {
default:
return nil, &ParseTypeError{p.Name, p.Type, v}
case float32:
out = float64(newV)
case float64:
out = newV
case json.Number:
newI, err := newV.Float64()
if err != nil {
return nil, &ParseTypeError{p.Name, p.Type, v}
}
out = float64(newI)
}
return newV, nil
return out, nil
}

func (p *FloatParameter) GetAuthSources() []ParamAuthSource {
Expand Down
33 changes: 27 additions & 6 deletions internal/tools/parameters_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,9 @@
package tools_test

import (
"bytes"
"encoding/json"
"math"
"reflect"
"testing"

Expand Down Expand Up @@ -351,6 +354,16 @@ func TestParametersParse(t *testing.T) {
"my_int": 14.5,
},
},
{
name: "not int (big)",
params: tools.Parameters{
tools.NewIntParameter("my_int", "this param is an int"),
},
in: map[string]any{
"my_int": math.MaxInt64,
},
want: tools.ParamValues{tools.ParamValue{Name: "my_int", Value: math.MaxInt64}},
},
{
name: "float",
params: tools.Parameters{
Expand Down Expand Up @@ -393,25 +406,31 @@ func TestParametersParse(t *testing.T) {
for _, tc := range tcs {
t.Run(tc.name, func(t *testing.T) {
// parse map to bytes
data, err := yaml.Marshal(tc.in)
data, err := json.Marshal(tc.in)
if err != nil {
t.Fatalf("unable to marshal input to yaml: %s", err)
}
// parse bytes to object
var m map[string]any
err = yaml.Unmarshal(data, &m)

d := json.NewDecoder(bytes.NewReader(data))
d.UseNumber()
err = d.Decode(&m)
if err != nil {
t.Fatalf("unable to unmarshal: %s", err)
}

wantErr := len(tc.want) == 0 // error is expected if no items in want
gotAll, err := tools.ParseParams(tc.params, m, make(map[string]map[string]any))
if err != nil {
if len(tc.want) == 0 {
// error is expected if no items in want
if wantErr {
return
}
t.Fatalf("unexpected error from ParseParams: %s", err)
}
if wantErr {
t.Fatalf("expected error but Param parsed successfully: %s", gotAll)
}
for i, got := range gotAll {
want := tc.want[i]
if got != want {
Expand Down Expand Up @@ -552,13 +571,15 @@ func TestAuthParametersParse(t *testing.T) {
for _, tc := range tcs {
t.Run(tc.name, func(t *testing.T) {
// parse map to bytes
data, err := yaml.Marshal(tc.in)
data, err := json.Marshal(tc.in)
if err != nil {
t.Fatalf("unable to marshal input to yaml: %s", err)
}
// parse bytes to object
var m map[string]any
err = yaml.Unmarshal(data, &m)
d := json.NewDecoder(bytes.NewReader(data))
d.UseNumber()
err = d.Decode(&m)
if err != nil {
t.Fatalf("unable to unmarshal: %s", err)
}
Expand Down

0 comments on commit 387a5b5

Please sign in to comment.