diff --git a/marshal.go b/marshal.go index 641440ce..a1d70104 100644 --- a/marshal.go +++ b/marshal.go @@ -51,7 +51,7 @@ func isPrimitive(mtype reflect.Type) bool { case reflect.String: return true case reflect.Struct: - return mtype == timeType + return mtype == timeType || isCustomMarshaler(mtype) default: return false } diff --git a/marshal_test.go b/marshal_test.go index b8d0259e..6cd736cf 100644 --- a/marshal_test.go +++ b/marshal_test.go @@ -536,7 +536,8 @@ func TestNestedUnmarshal(t *testing.T) { } type customMarshalerParent struct { - Person customMarshaler `toml:"name"` + Self customMarshaler `toml:"me""` + Friends []customMarshaler `toml:"friends""` } type customMarshaler struct { @@ -551,8 +552,12 @@ func (c customMarshaler) MarshalTOML() ([]byte, error) { var customMarshalerData = customMarshaler{FirsName: "Sally", LastName: "Fields"} var customMarshalerToml = []byte(`Sally Fields`) -var nestedCustomMarshalerData = customMarshalerParent{Person: customMarshalerData} -var nestedCustomMarshalerToml = []byte(`name = "Sally Fields" +var nestedCustomMarshalerData = customMarshalerParent{ + Self: customMarshaler{FirsName: "Maiku", LastName: "Suteda"}, + Friends: []customMarshaler{customMarshalerData}, +} +var nestedCustomMarshalerToml = []byte(`friends = ["Sally Fields"] +me = "Maiku Suteda" `) func TestCustomMarshaler(t *testing.T) { @@ -562,7 +567,7 @@ func TestCustomMarshaler(t *testing.T) { } expected := customMarshalerToml if !bytes.Equal(result, expected) { - t.Errorf("Bad custom marshal: expected\n-----\n%s\n-----\ngot\n-----\n%s\n-----\n", expected, result) + t.Errorf("Bad custom marshaler: expected\n-----\n%s\n-----\ngot\n-----\n%s\n-----\n", expected, result) } } @@ -573,6 +578,6 @@ func TestNestedCustomMarshaler(t *testing.T) { } expected := nestedCustomMarshalerToml if !bytes.Equal(result, expected) { - t.Errorf("Bad nested custom marshal: expected\n-----\n%s\n-----\ngot\n-----\n%s\n-----\n", expected, result) + t.Errorf("Bad nested custom marshaler: expected\n-----\n%s\n-----\ngot\n-----\n%s\n-----\n", expected, result) } }