Skip to content

Commit

Permalink
Support text Un/Marshaller for map keys (#863)
Browse files Browse the repository at this point in the history
  • Loading branch information
gordon-klotho authored May 9, 2023
1 parent 2aa0836 commit d34104d
Show file tree
Hide file tree
Showing 4 changed files with 247 additions and 17 deletions.
24 changes: 20 additions & 4 deletions marshaler.go
Original file line number Diff line number Diff line change
Expand Up @@ -577,25 +577,41 @@ func (enc *Encoder) encodeKey(b []byte, k string) []byte {
}
}

func (enc *Encoder) encodeMap(b []byte, ctx encoderCtx, v reflect.Value) ([]byte, error) {
if v.Type().Key().Kind() != reflect.String {
return nil, fmt.Errorf("toml: type %s is not supported as a map key", v.Type().Key().Kind())
func (enc *Encoder) keyToString(k reflect.Value) (string, error) {
keyType := k.Type()
switch {
case keyType.Kind() == reflect.String:
return k.String(), nil

case keyType.Implements(textMarshalerType):
keyB, err := k.Interface().(encoding.TextMarshaler).MarshalText()
if err != nil {
return "", fmt.Errorf("toml: error marshalling key %v from text: %w", k, err)
}
return string(keyB), nil
}
return "", fmt.Errorf("toml: type %s is not supported as a map key", keyType.Kind())
}

func (enc *Encoder) encodeMap(b []byte, ctx encoderCtx, v reflect.Value) ([]byte, error) {
var (
t table
emptyValueOptions valueOptions
)

iter := v.MapRange()
for iter.Next() {
k := iter.Key().String()
v := iter.Value()

if isNil(v) {
continue
}

k, err := enc.keyToString(iter.Key())
if err != nil {
return nil, err
}

if willConvertToTableOrArrayTable(ctx, v) {
t.pushTable(k, v, emptyValueOptions)
} else {
Expand Down
69 changes: 68 additions & 1 deletion marshaler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,21 @@ import (
"github.com/stretchr/testify/require"
)

type marshalTextKey struct {
A string
B string
}

func (k marshalTextKey) MarshalText() ([]byte, error) {
return []byte(k.A + "-" + k.B), nil
}

type marshalBadTextKey struct{}

func (k marshalBadTextKey) MarshalText() ([]byte, error) {
return nil, fmt.Errorf("error")
}

func TestMarshal(t *testing.T) {
someInt := 42

Expand Down Expand Up @@ -97,6 +112,53 @@ also = 'that'
a = 'test'
`,
},
{
desc: `map with text key`,
v: map[marshalTextKey]string{
{A: "a", B: "1"}: "value 1",
{A: "a", B: "2"}: "value 2",
{A: "b", B: "1"}: "value 3",
},
expected: `a-1 = 'value 1'
a-2 = 'value 2'
b-1 = 'value 3'
`,
},
{
desc: `table with text key`,
v: map[marshalTextKey]map[string]string{
{A: "a", B: "1"}: {"value": "foo"},
},
expected: `[a-1]
value = 'foo'
`,
},
{
desc: `map with ptr text key`,
v: map[*marshalTextKey]string{
{A: "a", B: "1"}: "value 1",
{A: "a", B: "2"}: "value 2",
{A: "b", B: "1"}: "value 3",
},
expected: `a-1 = 'value 1'
a-2 = 'value 2'
b-1 = 'value 3'
`,
},
{
desc: `map with bad text key`,
v: map[marshalBadTextKey]string{
{}: "value 1",
},
err: true,
},
{
desc: `map with bad ptr text key`,
v: map[*marshalBadTextKey]string{
{}: "value 1",
},
err: true,
},
{
desc: "simple string array",
v: map[string][]string{
Expand Down Expand Up @@ -487,9 +549,14 @@ foo = 42
},
{
desc: "invalid map key",
v: map[int]interface{}{},
v: map[int]interface{}{1: "a"},
err: true,
},
{
desc: "invalid map key but empty",
v: map[int]interface{}{},
expected: "",
},
{
desc: "unhandled type",
v: struct {
Expand Down
43 changes: 32 additions & 11 deletions unmarshaler.go
Original file line number Diff line number Diff line change
Expand Up @@ -417,7 +417,10 @@ func (d *decoder) handleKeyPart(key unstable.Iterator, v reflect.Value, nextFn h
vt := v.Type()

// Create the key for the map element. Convert to key type.
mk := reflect.ValueOf(string(key.Node().Data)).Convert(vt.Key())
mk, err := d.keyFromData(vt.Key(), key.Node().Data)
if err != nil {
return reflect.Value{}, err
}

// If the map does not exist, create it.
if v.IsNil() {
Expand Down Expand Up @@ -1009,6 +1012,31 @@ func (d *decoder) handleKeyValueInner(key unstable.Iterator, value *unstable.Nod
return reflect.Value{}, d.handleValue(value, v)
}

func (d *decoder) keyFromData(keyType reflect.Type, data []byte) (reflect.Value, error) {
switch {
case stringType.AssignableTo(keyType):
return reflect.ValueOf(string(data)), nil

case stringType.ConvertibleTo(keyType):
return reflect.ValueOf(string(data)).Convert(keyType), nil

case keyType.Implements(textUnmarshalerType):
mk := reflect.New(keyType.Elem())
if err := mk.Interface().(encoding.TextUnmarshaler).UnmarshalText(data); err != nil {
return reflect.Value{}, fmt.Errorf("toml: error unmarshalling key type %s from text: %w", stringType, err)
}
return mk, nil

case reflect.PointerTo(keyType).Implements(textUnmarshalerType):
mk := reflect.New(keyType)
if err := mk.Interface().(encoding.TextUnmarshaler).UnmarshalText(data); err != nil {
return reflect.Value{}, fmt.Errorf("toml: error unmarshalling key type %s from text: %w", stringType, err)
}
return mk.Elem(), nil
}
return reflect.Value{}, fmt.Errorf("toml: cannot convert map key of type %s to expected type %s", stringType, keyType)
}

func (d *decoder) handleKeyValuePart(key unstable.Iterator, value *unstable.Node, v reflect.Value) (reflect.Value, error) {
// contains the replacement for v
var rv reflect.Value
Expand All @@ -1019,16 +1047,9 @@ func (d *decoder) handleKeyValuePart(key unstable.Iterator, value *unstable.Node
case reflect.Map:
vt := v.Type()

mk := reflect.ValueOf(string(key.Node().Data))
mkt := stringType

keyType := vt.Key()
if !mkt.AssignableTo(keyType) {
if !mkt.ConvertibleTo(keyType) {
return reflect.Value{}, fmt.Errorf("toml: cannot convert map key of type %s to expected type %s", mkt, keyType)
}

mk = mk.Convert(keyType)
mk, err := d.keyFromData(vt.Key(), key.Node().Data)
if err != nil {
return reflect.Value{}, err
}

// If the map does not exist, create it.
Expand Down
128 changes: 127 additions & 1 deletion unmarshaler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,27 @@ import (
"github.com/stretchr/testify/require"
)

type unmarshalTextKey struct {
A string
B string
}

func (k *unmarshalTextKey) UnmarshalText(text []byte) error {
parts := strings.Split(string(text), "-")
if len(parts) != 2 {
return fmt.Errorf("invalid text key: %s", text)
}
k.A = parts[0]
k.B = parts[1]
return nil
}

type unmarshalBadTextKey struct{}

func (k *unmarshalBadTextKey) UnmarshalText(text []byte) error {
return fmt.Errorf("error")
}

func ExampleDecoder_DisallowUnknownFields() {
type S struct {
Key1 string
Expand Down Expand Up @@ -315,6 +336,7 @@ func TestUnmarshal(t *testing.T) {
target interface{}
expected interface{}
err bool
assert func(t *testing.T, test test)
}
examples := []struct {
skip bool
Expand Down Expand Up @@ -350,6 +372,96 @@ func TestUnmarshal(t *testing.T) {
}
},
},
{
desc: "kv text key",
input: `a-1 = "foo"`,
gen: func() test {
type doc = map[unmarshalTextKey]string

return test{
target: &doc{},
expected: &doc{{A: "a", B: "1"}: "foo"},
}
},
},
{
desc: "table text key",
input: `["a-1"]
foo = "bar"`,
gen: func() test {
type doc = map[unmarshalTextKey]map[string]string

return test{
target: &doc{},
expected: &doc{{A: "a", B: "1"}: map[string]string{"foo": "bar"}},
}
},
},
{
desc: "kv ptr text key",
input: `a-1 = "foo"`,
gen: func() test {
type doc = map[*unmarshalTextKey]string

return test{
target: &doc{},
expected: &doc{{A: "a", B: "1"}: "foo"},
assert: func(t *testing.T, test test) {
// Despite the documentation:
// Pointer variable equality is determined based on the equality of the
// referenced values (as opposed to the memory addresses).
// assert.Equal does not work properly with maps with pointer keys
// https://github.com/stretchr/testify/issues/1143
expected := make(map[unmarshalTextKey]string)
for k, v := range *(test.expected.(*doc)) {
expected[*k] = v
}
got := make(map[unmarshalTextKey]string)
for k, v := range *(test.target.(*doc)) {
got[*k] = v
}
assert.Equal(t, expected, got)
},
}
},
},
{
desc: "kv bad text key",
input: `a-1 = "foo"`,
gen: func() test {
type doc = map[unmarshalBadTextKey]string

return test{
target: &doc{},
err: true,
}
},
},
{
desc: "kv bad ptr text key",
input: `a-1 = "foo"`,
gen: func() test {
type doc = map[*unmarshalBadTextKey]string

return test{
target: &doc{},
err: true,
}
},
},
{
desc: "table bad text key",
input: `["a-1"]
foo = "bar"`,
gen: func() test {
type doc = map[unmarshalBadTextKey]map[string]string

return test{
target: &doc{},
err: true,
}
},
},
{
desc: "time.time with negative zone",
input: `a = 1979-05-27T00:32:00-07:00 `, // space intentional
Expand Down Expand Up @@ -1521,6 +1633,16 @@ B = "data"`,
}
},
},
{
desc: "empty map into map with invalid key type",
input: ``,
gen: func() test {
return test{
target: &map[int]string{},
expected: &map[int]string{},
}
},
},
{
desc: "into map with convertible key type",
input: `A = "hello"`,
Expand Down Expand Up @@ -1777,7 +1899,11 @@ B = "data"`,
require.Error(t, err)
} else {
require.NoError(t, err)
assert.Equal(t, test.expected, test.target)
if test.assert != nil {
test.assert(t, test)
} else {
assert.Equal(t, test.expected, test.target)
}
}
})
}
Expand Down

0 comments on commit d34104d

Please sign in to comment.