diff --git a/marshal.go b/marshal.go index 032e0ffc..f8eba8a6 100644 --- a/marshal.go +++ b/marshal.go @@ -1004,8 +1004,18 @@ func (d *Decoder) valueFromToml(mtype reflect.Type, tval interface{}, mval1 *ref return reflect.ValueOf(nil), fmt.Errorf("Can't convert %v(%T) to a slice", tval, tval) default: d.visitor.visit() + mvalPtr := reflect.New(mtype) + + // Check if pointer to value implements the Unmarshaler interface. + if isCustomUnmarshaler(mvalPtr.Type()) { + if err := callCustomUnmarshaler(mvalPtr, tval); err != nil { + return reflect.ValueOf(nil), fmt.Errorf("unmarshal toml: %v", err) + } + return mvalPtr.Elem(), nil + } + // Check if pointer to value implements the encoding.TextUnmarshaler. - if mvalPtr := reflect.New(mtype); isTextUnmarshaler(mvalPtr.Type()) && !isTimeType(mtype) { + if isTextUnmarshaler(mvalPtr.Type()) && !isTimeType(mtype) { if err := d.unmarshalText(tval, mvalPtr); err != nil { return reflect.ValueOf(nil), fmt.Errorf("unmarshal text: %v", err) } diff --git a/marshal_test.go b/marshal_test.go index d72926bf..c61955d5 100644 --- a/marshal_test.go +++ b/marshal_test.go @@ -3,6 +3,7 @@ package toml import ( "bytes" "encoding/json" + "errors" "fmt" "io/ioutil" "os" @@ -3976,3 +3977,59 @@ func TestGithubIssue431(t *testing.T) { t.Errorf("UnmarshalTOML should not have been called") } } + +type durationString struct { + time.Duration +} + +func (d *durationString) UnmarshalTOML(v interface{}) error { + d.Duration = 10 * time.Second + return nil +} + +type config437Error struct { +} + +func (e *config437Error) UnmarshalTOML(v interface{}) error { + return errors.New("expected") +} + +type config437 struct { + HTTP struct { + PingTimeout durationString `toml:"PingTimeout"` + ErrorField config437Error + } `toml:"HTTP"` +} + +func TestGithubIssue437(t *testing.T) { + src := ` +[HTTP] +PingTimeout = "32m" +` + cfg := &config437{} + cfg.HTTP.PingTimeout = durationString{time.Second} + + r := strings.NewReader(src) + err := NewDecoder(r).Decode(cfg) + if err != nil { + t.Fatalf("unexpected errors %s", err) + } + expected := durationString{10 * time.Second} + if cfg.HTTP.PingTimeout != expected { + t.Fatalf("expected '%s', got '%s'", expected, cfg.HTTP.PingTimeout) + } +} + +func TestLeafUnmarshalerError(t *testing.T) { + src := ` +[HTTP] +ErrorField = "foo" +` + cfg := &config437{} + + r := strings.NewReader(src) + err := NewDecoder(r).Decode(cfg) + if err == nil { + t.Fatalf("error was expected") + } +}