Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use MarshalTOML() if struct has defined a custom marshaler #151

Merged
merged 2 commits into from
Apr 5, 2017
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 21 additions & 1 deletion marshal.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ type tomlOpts struct {
}

var timeType = reflect.TypeOf(time.Time{})
var marshalerType = reflect.TypeOf(new(Marshaler)).Elem()

// Check if the given marshall type maps to a TomlTree primitive
func isPrimitive(mtype reflect.Type) bool {
Expand All @@ -50,7 +51,7 @@ func isPrimitive(mtype reflect.Type) bool {
case reflect.String:
return true
case reflect.Struct:
return mtype == timeType
return mtype == timeType || isCustomMarshaler(mtype)
default:
return false
}
Expand Down Expand Up @@ -90,6 +91,20 @@ func isTree(mtype reflect.Type) bool {
}
}

func isCustomMarshaler(mtype reflect.Type) bool {
return mtype.Implements(marshalerType)
}

func callCustomMarshaler(mval reflect.Value) ([]byte, error) {
return mval.Interface().(Marshaler).MarshalTOML()
}

// Marshaler is the interface implemented by types that
// can marshal themselves into valid TOML.
type Marshaler interface {
MarshalTOML() ([]byte, error)
}

/*
Marshal returns the TOML encoding of v. Behavior is similar to the Go json
encoder, except that there is no concept of a Marshaler interface or MarshalTOML
Expand All @@ -106,6 +121,9 @@ func Marshal(v interface{}) ([]byte, error) {
return []byte{}, errors.New("Only a struct can be marshaled to TOML")
}
sval := reflect.ValueOf(v)
if isCustomMarshaler(mtype) {
return callCustomMarshaler(sval)
}
t, err := valueToTree(mtype, sval)
if err != nil {
return []byte{}, err
Expand Down Expand Up @@ -178,6 +196,8 @@ func valueToToml(mtype reflect.Type, mval reflect.Value) (interface{}, error) {
return valueToToml(mtype.Elem(), mval.Elem())
}
switch {
case isCustomMarshaler(mtype):
return callCustomMarshaler(mval)
case isTree(mtype):
return valueToTree(mtype, mval)
case isTreeSlice(mtype):
Expand Down
48 changes: 48 additions & 0 deletions marshal_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package toml
import (
"bytes"
"encoding/json"
"fmt"
"io/ioutil"
"reflect"
"testing"
Expand Down Expand Up @@ -533,3 +534,50 @@ func TestNestedUnmarshal(t *testing.T) {
t.Errorf("Bad nested unmarshal: expected %v, got %v", expected, result)
}
}

type customMarshalerParent struct {
Self customMarshaler `toml:"me"`
Friends []customMarshaler `toml:"friends"`
}

type customMarshaler struct {
FirsName string
LastName string
}

func (c customMarshaler) MarshalTOML() ([]byte, error) {
fullName := fmt.Sprintf("%s %s", c.FirsName, c.LastName)
return []byte(fullName), nil
}

var customMarshalerData = customMarshaler{FirsName: "Sally", LastName: "Fields"}
var customMarshalerToml = []byte(`Sally Fields`)
var nestedCustomMarshalerData = customMarshalerParent{
Self: customMarshaler{FirsName: "Maiku", LastName: "Suteda"},
Friends: []customMarshaler{customMarshalerData},
}
var nestedCustomMarshalerToml = []byte(`friends = ["Sally Fields"]
me = "Maiku Suteda"
`)

func TestCustomMarshaler(t *testing.T) {
result, err := Marshal(customMarshalerData)
if err != nil {
t.Fatal(err)
}
expected := customMarshalerToml
if !bytes.Equal(result, expected) {
t.Errorf("Bad custom marshaler: expected\n-----\n%s\n-----\ngot\n-----\n%s\n-----\n", expected, result)
}
}

func TestNestedCustomMarshaler(t *testing.T) {
result, err := Marshal(nestedCustomMarshalerData)
if err != nil {
t.Fatal(err)
}
expected := nestedCustomMarshalerToml
if !bytes.Equal(result, expected) {
t.Errorf("Bad nested custom marshaler: expected\n-----\n%s\n-----\ngot\n-----\n%s\n-----\n", expected, result)
}
}
3 changes: 3 additions & 0 deletions tomltree_write.go
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,9 @@ func tomlValueStringRepresentation(v interface{}) (string, error) {
return strconv.FormatFloat(value, 'f', -1, 32), nil
case string:
return "\"" + encodeTomlString(value) + "\"", nil
case []byte:
b, _ := v.([]byte)
return tomlValueStringRepresentation(string(b))
case bool:
if value {
return "true", nil
Expand Down