Skip to content

Commit

Permalink
Use MarshalTOML() if struct has defined a custom marshaler
Browse files Browse the repository at this point in the history
  • Loading branch information
carolynvs committed Apr 4, 2017
1 parent e32a2e0 commit 3157125
Show file tree
Hide file tree
Showing 3 changed files with 66 additions and 0 deletions.
20 changes: 20 additions & 0 deletions marshal.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ type tomlOpts struct {
}

var timeType = reflect.TypeOf(time.Time{})
var marshalerType = reflect.TypeOf(new(Marshaler)).Elem()

// Check if the given marshall type maps to a TomlTree primitive
func isPrimitive(mtype reflect.Type) bool {
Expand Down Expand Up @@ -90,6 +91,20 @@ func isTree(mtype reflect.Type) bool {
}
}

func isCustomMarshaler(mtype reflect.Type) bool {
return mtype.Implements(marshalerType)
}

func callCustomMarshaler(mval reflect.Value) ([]byte, error) {
return mval.Interface().(Marshaler).MarshalTOML()
}

// Marshaler is the interface implemented by types that
// can marshal themselves into valid TOML.
type Marshaler interface {
MarshalTOML() ([]byte, error)
}

/*
Marshal returns the TOML encoding of v. Behavior is similar to the Go json
encoder, except that there is no concept of a Marshaler interface or MarshalTOML
Expand All @@ -106,6 +121,9 @@ func Marshal(v interface{}) ([]byte, error) {
return []byte{}, errors.New("Only a struct can be marshaled to TOML")
}
sval := reflect.ValueOf(v)
if isCustomMarshaler(mtype) {
return callCustomMarshaler(sval)
}
t, err := valueToTree(mtype, sval)
if err != nil {
return []byte{}, err
Expand Down Expand Up @@ -178,6 +196,8 @@ func valueToToml(mtype reflect.Type, mval reflect.Value) (interface{}, error) {
return valueToToml(mtype.Elem(), mval.Elem())
}
switch {
case isCustomMarshaler(mtype):
return callCustomMarshaler(mval)
case isTree(mtype):
return valueToTree(mtype, mval)
case isTreeSlice(mtype):
Expand Down
43 changes: 43 additions & 0 deletions marshal_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package toml
import (
"bytes"
"encoding/json"
"fmt"
"io/ioutil"
"reflect"
"testing"
Expand Down Expand Up @@ -533,3 +534,45 @@ func TestNestedUnmarshal(t *testing.T) {
t.Errorf("Bad nested unmarshal: expected %v, got %v", expected, result)
}
}

type customMarshalerParent struct {
Person customMarshaler `toml:"name"`
}

type customMarshaler struct {
FirsName string
LastName string
}

func (c customMarshaler) MarshalTOML() ([]byte, error) {
fullName := fmt.Sprintf("%s %s", c.FirsName, c.LastName)
return []byte(fullName), nil
}

var customMarshalerData = customMarshaler{FirsName: "Sally", LastName: "Fields"}
var customMarshalerToml = []byte(`Sally Fields`)
var nestedCustomMarshalerData = customMarshalerParent{Person: customMarshalerData}
var nestedCustomMarshalerToml = []byte(`name = "Sally Fields"
`)

func TestCustomMarshaler(t *testing.T) {
result, err := Marshal(customMarshalerData)
if err != nil {
t.Fatal(err)
}
expected := customMarshalerToml
if !bytes.Equal(result, expected) {
t.Errorf("Bad custom marshal: expected\n-----\n%s\n-----\ngot\n-----\n%s\n-----\n", expected, result)
}
}

func TestNestedCustomMarshaler(t *testing.T) {
result, err := Marshal(nestedCustomMarshalerData)
if err != nil {
t.Fatal(err)
}
expected := nestedCustomMarshalerToml
if !bytes.Equal(result, expected) {
t.Errorf("Bad nested custom marshal: expected\n-----\n%s\n-----\ngot\n-----\n%s\n-----\n", expected, result)
}
}
3 changes: 3 additions & 0 deletions tomltree_write.go
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,9 @@ func tomlValueStringRepresentation(v interface{}) (string, error) {
return strconv.FormatFloat(value, 'f', -1, 32), nil
case string:
return "\"" + encodeTomlString(value) + "\"", nil
case []byte:
b, _ := v.([]byte)
return tomlValueStringRepresentation(string(b))
case bool:
if value {
return "true", nil
Expand Down

0 comments on commit 3157125

Please sign in to comment.