Skip to content

Commit

Permalink
Fuzzing setup and fixes (#755)
Browse files Browse the repository at this point in the history
* encode: fix localdate formatting
* encode: fix empty key marshaling
* encode: fix invalid quotation of time.Time
* encode: ensure control chars are escaped
* decode: always use UTC for zero tz
* encode: check for invalid characters in keys
* encode: always construct map for empty array tables
* fuzz: add go 1.18 fuzz test
* encode: handle NaNs
* encode: allow new lines in quoted keys
* encode: never emit table inside array
* encode: don't capitalize inf
  • Loading branch information
pelletier authored Apr 11, 2022
1 parent 2377ac4 commit 8bbb673
Show file tree
Hide file tree
Showing 19 changed files with 230 additions and 62 deletions.
1 change: 1 addition & 0 deletions .gitattributes
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
* text=auto

benchmark/benchmark.toml text eol=lf
testdata/** text eol=lf
18 changes: 13 additions & 5 deletions ci.sh
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,8 @@ cover() {
fi

pushd "$dir"
go test -covermode=atomic -coverprofile=coverage.out ./...
go test -covermode=atomic -coverpkg=./... -coverprofile=coverage.out.tmp ./...
cat coverage.out.tmp | grep -v testsuite | grep -v tomltestgen | grep -v gotoml-test-decoder > coverage.out
go tool cover -func=coverage.out
popd

Expand All @@ -103,16 +104,23 @@ coverage() {

echo ""

target_pct="$(cat ${target_out} |sed -E 's/.*total.*\t([0-9.]+)%/\1/;t;d')"
head_pct="$(cat ${head_out} |sed -E 's/.*total.*\t([0-9.]+)%/\1/;t;d')"
target_pct="$(tail -n2 ${target_out} | head -n1 | sed -E 's/.*total.*\t([0-9.]+)%.*/\1/')"
head_pct="$(tail -n2 ${head_out} | head -n1 | sed -E 's/.*total.*\t([0-9.]+)%/\1/')"
echo "Results: ${target} ${target_pct}% HEAD ${head_pct}%"

delta_pct=$(echo "$head_pct - $target_pct" | bc -l)
echo "Delta: ${delta_pct}"

if [[ $delta_pct = \-* ]]; then
echo "Regression!";
return 1
echo "Regression!";

target_diff="${output_dir}/target.diff.txt"
head_diff="${output_dir}/head.diff.txt"
cat "${target_out}" | grep -E '^github.com/pelletier/go-toml' | tr -s "\t " | cut -f 2,3 | sort > "${target_diff}"
cat "${head_out}" | grep -E '^github.com/pelletier/go-toml' | tr -s "\t " | cut -f 2,3 | sort > "${head_diff}"

diff --side-by-side --suppress-common-lines "${target_diff}" "${head_diff}"
return 1
fi
return 0
;;
Expand Down
6 changes: 5 additions & 1 deletion decode.go
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,11 @@ func parseDateTime(b []byte) (time.Time, error) {
}

seconds := direction * (hours*3600 + minutes*60)
zone = time.FixedZone("", seconds)
if seconds == 0 {
zone = time.UTC
} else {
zone = time.FixedZone("", seconds)
}
b = b[dateTimeByteLen:]
}

Expand Down
56 changes: 56 additions & 0 deletions fuzz_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
//go:build go1.18
// +build go1.18

package toml_test

import (
"io/ioutil"
"strings"
"testing"

"github.com/pelletier/go-toml/v2"
"github.com/stretchr/testify/require"
)

func FuzzUnmarshal(f *testing.F) {
file, err := ioutil.ReadFile("benchmark/benchmark.toml")
if err != nil {
panic(err)
}
f.Add(file)

f.Fuzz(func(t *testing.T, b []byte) {
if strings.Contains(string(b), "nan") {
// Current limitation of testify.
// https://github.com/stretchr/testify/issues/624
t.Skip("can't compare NaNs")
}

t.Log("INITIAL DOCUMENT ===========================")
t.Log(string(b))

var v interface{}
err := toml.Unmarshal(b, &v)
if err != nil {
return
}

t.Log("DECODED VALUE ===========================")
t.Logf("%#+v", v)

encoded, err := toml.Marshal(v)
if err != nil {
t.Fatalf("cannot marshal unmarshaled document: %s", err)
}

t.Log("ENCODED DOCUMENT ===========================")
t.Log(string(encoded))

var v2 interface{}
err = toml.Unmarshal(encoded, &v2)
if err != nil {
t.Fatalf("failed round trip: %s", err)
}
require.Equal(t, v, v2)
})
}
108 changes: 62 additions & 46 deletions marshaler.go
Original file line number Diff line number Diff line change
Expand Up @@ -208,11 +208,20 @@ func (ctx *encoderCtx) isRoot() bool {
}

func (enc *Encoder) encode(b []byte, ctx encoderCtx, v reflect.Value) ([]byte, error) {
if !v.IsZero() {
i, ok := v.Interface().(time.Time)
if ok {
return i.AppendFormat(b, time.RFC3339), nil
i := v.Interface()

switch x := i.(type) {
case time.Time:
if x.Nanosecond() > 0 {
return x.AppendFormat(b, time.RFC3339Nano), nil
}
return x.AppendFormat(b, time.RFC3339), nil
case LocalTime:
return append(b, x.String()...), nil
case LocalDate:
return append(b, x.String()...), nil
case LocalDateTime:
return append(b, x.String()...), nil
}

hasTextMarshaler := v.Type().Implements(textMarshalerType)
Expand Down Expand Up @@ -260,16 +269,31 @@ func (enc *Encoder) encode(b []byte, ctx encoderCtx, v reflect.Value) ([]byte, e
case reflect.String:
b = enc.encodeString(b, v.String(), ctx.options)
case reflect.Float32:
if math.Trunc(v.Float()) == v.Float() {
b = strconv.AppendFloat(b, v.Float(), 'f', 1, 32)
f := v.Float()

if math.IsNaN(f) {
b = append(b, "nan"...)
} else if f > math.MaxFloat32 {
b = append(b, "inf"...)
} else if f < -math.MaxFloat32 {
b = append(b, "-inf"...)
} else if math.Trunc(f) == f {
b = strconv.AppendFloat(b, f, 'f', 1, 32)
} else {
b = strconv.AppendFloat(b, v.Float(), 'f', -1, 32)
b = strconv.AppendFloat(b, f, 'f', -1, 32)
}
case reflect.Float64:
if math.Trunc(v.Float()) == v.Float() {
b = strconv.AppendFloat(b, v.Float(), 'f', 1, 64)
f := v.Float()
if math.IsNaN(f) {
b = append(b, "nan"...)
} else if f > math.MaxFloat64 {
b = append(b, "inf"...)
} else if f < -math.MaxFloat64 {
b = append(b, "-inf"...)
} else if math.Trunc(f) == f {
b = strconv.AppendFloat(b, f, 'f', 1, 64)
} else {
b = strconv.AppendFloat(b, v.Float(), 'f', -1, 64)
b = strconv.AppendFloat(b, f, 'f', -1, 64)
}
case reflect.Bool:
if v.Bool() {
Expand Down Expand Up @@ -300,10 +324,6 @@ func isNil(v reflect.Value) bool {
func (enc *Encoder) encodeKv(b []byte, ctx encoderCtx, options valueOptions, v reflect.Value) ([]byte, error) {
var err error

if !ctx.hasKey {
panic("caller of encodeKv should have set the key in the context")
}

if (ctx.options.omitempty || options.omitempty) && isEmptyValue(v) {
return b, nil
}
Expand All @@ -313,12 +333,7 @@ func (enc *Encoder) encodeKv(b []byte, ctx encoderCtx, options valueOptions, v r
}

b = enc.indent(ctx.indent, b)

b, err = enc.encodeKey(b, ctx.key)
if err != nil {
return nil, err
}

b = enc.encodeKey(b, ctx.key)
b = append(b, " = "...)

// create a copy of the context because the value of a KV shouldn't
Expand Down Expand Up @@ -365,7 +380,13 @@ func (enc *Encoder) encodeString(b []byte, v string, options valueOptions) []byt
}

func needsQuoting(v string) bool {
return strings.ContainsAny(v, "'\b\f\n\r\t")
// TODO: vectorize
for _, b := range []byte(v) {
if b == '\'' || b == '\r' || b == '\n' || invalidAscii(b) {
return true
}
}
return false
}

// caller should have checked that the string does not contain new lines or ' .
Expand Down Expand Up @@ -437,7 +458,7 @@ func (enc *Encoder) encodeQuotedString(multiline bool, b []byte, v string) []byt
return b
}

// called should have checked that the string is in A-Z / a-z / 0-9 / - / _ .
// caller should have checked that the string is in A-Z / a-z / 0-9 / - / _ .
func (enc *Encoder) encodeUnquotedKey(b []byte, v string) []byte {
return append(b, v...)
}
Expand All @@ -453,20 +474,11 @@ func (enc *Encoder) encodeTableHeader(ctx encoderCtx, b []byte) ([]byte, error)

b = append(b, '[')

var err error

b, err = enc.encodeKey(b, ctx.parentKey[0])
if err != nil {
return nil, err
}
b = enc.encodeKey(b, ctx.parentKey[0])

for _, k := range ctx.parentKey[1:] {
b = append(b, '.')

b, err = enc.encodeKey(b, k)
if err != nil {
return nil, err
}
b = enc.encodeKey(b, k)
}

b = append(b, "]\n"...)
Expand All @@ -475,33 +487,37 @@ func (enc *Encoder) encodeTableHeader(ctx encoderCtx, b []byte) ([]byte, error)
}

//nolint:cyclop
func (enc *Encoder) encodeKey(b []byte, k string) ([]byte, error) {
func (enc *Encoder) encodeKey(b []byte, k string) []byte {
needsQuotation := false
cannotUseLiteral := false

if len(k) == 0 {
return append(b, "''"...)
}

for _, c := range k {
if (c >= 'A' && c <= 'Z') || (c >= 'a' && c <= 'z') || (c >= '0' && c <= '9') || c == '-' || c == '_' {
continue
}

if c == '\n' {
return nil, fmt.Errorf("toml: new line characters in keys are not supported")
}

if c == literalQuote {
cannotUseLiteral = true
}

needsQuotation = true
}

if needsQuotation && needsQuoting(k) {
cannotUseLiteral = true
}

switch {
case cannotUseLiteral:
return enc.encodeQuotedString(false, b, k), nil
return enc.encodeQuotedString(false, b, k)
case needsQuotation:
return enc.encodeLiteralString(b, k), nil
return enc.encodeLiteralString(b, k)
default:
return enc.encodeUnquotedKey(b, k), nil
return enc.encodeUnquotedKey(b, k)
}
}

Expand Down Expand Up @@ -803,6 +819,9 @@ func willConvertToTable(ctx encoderCtx, v reflect.Value) bool {
}

func willConvertToTableOrArrayTable(ctx encoderCtx, v reflect.Value) bool {
if ctx.insideKv {
return false
}
t := v.Type()

if t.Kind() == reflect.Interface {
Expand Down Expand Up @@ -848,7 +867,6 @@ func (enc *Encoder) encodeSlice(b []byte, ctx encoderCtx, v reflect.Value) ([]by
func (enc *Encoder) encodeSliceAsArrayTable(b []byte, ctx encoderCtx, v reflect.Value) ([]byte, error) {
ctx.shiftKey()

var err error
scratch := make([]byte, 0, 64)
scratch = append(scratch, "[["...)

Expand All @@ -857,10 +875,7 @@ func (enc *Encoder) encodeSliceAsArrayTable(b []byte, ctx encoderCtx, v reflect.
scratch = append(scratch, '.')
}

scratch, err = enc.encodeKey(scratch, k)
if err != nil {
return nil, err
}
scratch = enc.encodeKey(scratch, k)
}

scratch = append(scratch, "]]\n"...)
Expand All @@ -869,6 +884,7 @@ func (enc *Encoder) encodeSliceAsArrayTable(b []byte, ctx encoderCtx, v reflect.
for i := 0; i < v.Len(); i++ {
b = append(b, scratch...)

var err error
b, err = enc.encode(b, ctx, v.Index(i))
if err != nil {
return nil, err
Expand Down
Loading

0 comments on commit 8bbb673

Please sign in to comment.