From 3157125ac0ad5b52993e19edca3616c30ac54bb3 Mon Sep 17 00:00:00 2001 From: Carolyn Van Slyck Date: Sun, 2 Apr 2017 16:14:56 -0500 Subject: [PATCH 1/2] Use MarshalTOML() if struct has defined a custom marshaler --- marshal.go | 20 ++++++++++++++++++++ marshal_test.go | 43 +++++++++++++++++++++++++++++++++++++++++++ tomltree_write.go | 3 +++ 3 files changed, 66 insertions(+) diff --git a/marshal.go b/marshal.go index 4301a451..641440ce 100644 --- a/marshal.go +++ b/marshal.go @@ -33,6 +33,7 @@ type tomlOpts struct { } var timeType = reflect.TypeOf(time.Time{}) +var marshalerType = reflect.TypeOf(new(Marshaler)).Elem() // Check if the given marshall type maps to a TomlTree primitive func isPrimitive(mtype reflect.Type) bool { @@ -90,6 +91,20 @@ func isTree(mtype reflect.Type) bool { } } +func isCustomMarshaler(mtype reflect.Type) bool { + return mtype.Implements(marshalerType) +} + +func callCustomMarshaler(mval reflect.Value) ([]byte, error) { + return mval.Interface().(Marshaler).MarshalTOML() +} + +// Marshaler is the interface implemented by types that +// can marshal themselves into valid TOML. +type Marshaler interface { + MarshalTOML() ([]byte, error) +} + /* Marshal returns the TOML encoding of v. Behavior is similar to the Go json encoder, except that there is no concept of a Marshaler interface or MarshalTOML @@ -106,6 +121,9 @@ func Marshal(v interface{}) ([]byte, error) { return []byte{}, errors.New("Only a struct can be marshaled to TOML") } sval := reflect.ValueOf(v) + if isCustomMarshaler(mtype) { + return callCustomMarshaler(sval) + } t, err := valueToTree(mtype, sval) if err != nil { return []byte{}, err @@ -178,6 +196,8 @@ func valueToToml(mtype reflect.Type, mval reflect.Value) (interface{}, error) { return valueToToml(mtype.Elem(), mval.Elem()) } switch { + case isCustomMarshaler(mtype): + return callCustomMarshaler(mval) case isTree(mtype): return valueToTree(mtype, mval) case isTreeSlice(mtype): diff --git a/marshal_test.go b/marshal_test.go index c8dee94d..b8d0259e 100644 --- a/marshal_test.go +++ b/marshal_test.go @@ -3,6 +3,7 @@ package toml import ( "bytes" "encoding/json" + "fmt" "io/ioutil" "reflect" "testing" @@ -533,3 +534,45 @@ func TestNestedUnmarshal(t *testing.T) { t.Errorf("Bad nested unmarshal: expected %v, got %v", expected, result) } } + +type customMarshalerParent struct { + Person customMarshaler `toml:"name"` +} + +type customMarshaler struct { + FirsName string + LastName string +} + +func (c customMarshaler) MarshalTOML() ([]byte, error) { + fullName := fmt.Sprintf("%s %s", c.FirsName, c.LastName) + return []byte(fullName), nil +} + +var customMarshalerData = customMarshaler{FirsName: "Sally", LastName: "Fields"} +var customMarshalerToml = []byte(`Sally Fields`) +var nestedCustomMarshalerData = customMarshalerParent{Person: customMarshalerData} +var nestedCustomMarshalerToml = []byte(`name = "Sally Fields" +`) + +func TestCustomMarshaler(t *testing.T) { + result, err := Marshal(customMarshalerData) + if err != nil { + t.Fatal(err) + } + expected := customMarshalerToml + if !bytes.Equal(result, expected) { + t.Errorf("Bad custom marshal: expected\n-----\n%s\n-----\ngot\n-----\n%s\n-----\n", expected, result) + } +} + +func TestNestedCustomMarshaler(t *testing.T) { + result, err := Marshal(nestedCustomMarshalerData) + if err != nil { + t.Fatal(err) + } + expected := nestedCustomMarshalerToml + if !bytes.Equal(result, expected) { + t.Errorf("Bad nested custom marshal: expected\n-----\n%s\n-----\ngot\n-----\n%s\n-----\n", expected, result) + } +} diff --git a/tomltree_write.go b/tomltree_write.go index 89c3c422..6a7fa174 100644 --- a/tomltree_write.go +++ b/tomltree_write.go @@ -52,6 +52,9 @@ func tomlValueStringRepresentation(v interface{}) (string, error) { return strconv.FormatFloat(value, 'f', -1, 32), nil case string: return "\"" + encodeTomlString(value) + "\"", nil + case []byte: + b, _ := v.([]byte) + return tomlValueStringRepresentation(string(b)) case bool: if value { return "true", nil From bba6acb9f63613170e012fd7c2d3d5297915a7ef Mon Sep 17 00:00:00 2001 From: Carolyn Van Slyck Date: Tue, 4 Apr 2017 19:24:48 -0500 Subject: [PATCH 2/2] Handle array items with a custom marshaler --- marshal.go | 2 +- marshal_test.go | 15 ++++++++++----- 2 files changed, 11 insertions(+), 6 deletions(-) diff --git a/marshal.go b/marshal.go index 641440ce..a1d70104 100644 --- a/marshal.go +++ b/marshal.go @@ -51,7 +51,7 @@ func isPrimitive(mtype reflect.Type) bool { case reflect.String: return true case reflect.Struct: - return mtype == timeType + return mtype == timeType || isCustomMarshaler(mtype) default: return false } diff --git a/marshal_test.go b/marshal_test.go index b8d0259e..891222e9 100644 --- a/marshal_test.go +++ b/marshal_test.go @@ -536,7 +536,8 @@ func TestNestedUnmarshal(t *testing.T) { } type customMarshalerParent struct { - Person customMarshaler `toml:"name"` + Self customMarshaler `toml:"me"` + Friends []customMarshaler `toml:"friends"` } type customMarshaler struct { @@ -551,8 +552,12 @@ func (c customMarshaler) MarshalTOML() ([]byte, error) { var customMarshalerData = customMarshaler{FirsName: "Sally", LastName: "Fields"} var customMarshalerToml = []byte(`Sally Fields`) -var nestedCustomMarshalerData = customMarshalerParent{Person: customMarshalerData} -var nestedCustomMarshalerToml = []byte(`name = "Sally Fields" +var nestedCustomMarshalerData = customMarshalerParent{ + Self: customMarshaler{FirsName: "Maiku", LastName: "Suteda"}, + Friends: []customMarshaler{customMarshalerData}, +} +var nestedCustomMarshalerToml = []byte(`friends = ["Sally Fields"] +me = "Maiku Suteda" `) func TestCustomMarshaler(t *testing.T) { @@ -562,7 +567,7 @@ func TestCustomMarshaler(t *testing.T) { } expected := customMarshalerToml if !bytes.Equal(result, expected) { - t.Errorf("Bad custom marshal: expected\n-----\n%s\n-----\ngot\n-----\n%s\n-----\n", expected, result) + t.Errorf("Bad custom marshaler: expected\n-----\n%s\n-----\ngot\n-----\n%s\n-----\n", expected, result) } } @@ -573,6 +578,6 @@ func TestNestedCustomMarshaler(t *testing.T) { } expected := nestedCustomMarshalerToml if !bytes.Equal(result, expected) { - t.Errorf("Bad nested custom marshal: expected\n-----\n%s\n-----\ngot\n-----\n%s\n-----\n", expected, result) + t.Errorf("Bad nested custom marshaler: expected\n-----\n%s\n-----\ngot\n-----\n%s\n-----\n", expected, result) } }