From 48abf1061bdf6023cd1df56dcf66dc4b78978892 Mon Sep 17 00:00:00 2001 From: Thomas Pelletier Date: Sat, 12 Sep 2020 18:39:07 -0400 Subject: [PATCH 1/2] toml.Unmarshaler supports leaf nodes --- marshal.go | 12 +++++++++++- marshal_test.go | 34 ++++++++++++++++++++++++++++++++++ 2 files changed, 45 insertions(+), 1 deletion(-) diff --git a/marshal.go b/marshal.go index 032e0ffc..f8eba8a6 100644 --- a/marshal.go +++ b/marshal.go @@ -1004,8 +1004,18 @@ func (d *Decoder) valueFromToml(mtype reflect.Type, tval interface{}, mval1 *ref return reflect.ValueOf(nil), fmt.Errorf("Can't convert %v(%T) to a slice", tval, tval) default: d.visitor.visit() + mvalPtr := reflect.New(mtype) + + // Check if pointer to value implements the Unmarshaler interface. + if isCustomUnmarshaler(mvalPtr.Type()) { + if err := callCustomUnmarshaler(mvalPtr, tval); err != nil { + return reflect.ValueOf(nil), fmt.Errorf("unmarshal toml: %v", err) + } + return mvalPtr.Elem(), nil + } + // Check if pointer to value implements the encoding.TextUnmarshaler. - if mvalPtr := reflect.New(mtype); isTextUnmarshaler(mvalPtr.Type()) && !isTimeType(mtype) { + if isTextUnmarshaler(mvalPtr.Type()) && !isTimeType(mtype) { if err := d.unmarshalText(tval, mvalPtr); err != nil { return reflect.ValueOf(nil), fmt.Errorf("unmarshal text: %v", err) } diff --git a/marshal_test.go b/marshal_test.go index d72926bf..b77cc79b 100644 --- a/marshal_test.go +++ b/marshal_test.go @@ -3976,3 +3976,37 @@ func TestGithubIssue431(t *testing.T) { t.Errorf("UnmarshalTOML should not have been called") } } + +type DurationString struct { + time.Duration +} + +func (d *DurationString) UnmarshalTOML(v interface{}) error { + d.Duration = 10 * time.Second + return nil +} + +type Config437 struct { + HTTP struct { + PingTimeout DurationString `toml:"PingTimeout"` + } `toml:"HTTP"` +} + +func TestGithubIssue437(t *testing.T) { + src := ` +[HTTP] +PingTimeout = "32m" +` + cfg := &Config437{} + cfg.HTTP.PingTimeout = DurationString{time.Second} + + r := strings.NewReader(src) + err := NewDecoder(r).Decode(cfg) + if err != nil { + t.Fatalf("unexpected errors %s", err) + } + expected := DurationString{10 * time.Second} + if cfg.HTTP.PingTimeout != expected { + t.Fatalf("expected '%s', got '%s'", expected, cfg.HTTP.PingTimeout) + } +} From dcc96feae2447cfff215076d04f3f23a6f16b783 Mon Sep 17 00:00:00 2001 From: Thomas Pelletier Date: Sat, 12 Sep 2020 18:54:15 -0400 Subject: [PATCH 2/2] Add test for error --- marshal_test.go | 37 ++++++++++++++++++++++++++++++------- 1 file changed, 30 insertions(+), 7 deletions(-) diff --git a/marshal_test.go b/marshal_test.go index b77cc79b..c61955d5 100644 --- a/marshal_test.go +++ b/marshal_test.go @@ -3,6 +3,7 @@ package toml import ( "bytes" "encoding/json" + "errors" "fmt" "io/ioutil" "os" @@ -3977,18 +3978,26 @@ func TestGithubIssue431(t *testing.T) { } } -type DurationString struct { +type durationString struct { time.Duration } -func (d *DurationString) UnmarshalTOML(v interface{}) error { +func (d *durationString) UnmarshalTOML(v interface{}) error { d.Duration = 10 * time.Second return nil } -type Config437 struct { +type config437Error struct { +} + +func (e *config437Error) UnmarshalTOML(v interface{}) error { + return errors.New("expected") +} + +type config437 struct { HTTP struct { - PingTimeout DurationString `toml:"PingTimeout"` + PingTimeout durationString `toml:"PingTimeout"` + ErrorField config437Error } `toml:"HTTP"` } @@ -3997,16 +4006,30 @@ func TestGithubIssue437(t *testing.T) { [HTTP] PingTimeout = "32m" ` - cfg := &Config437{} - cfg.HTTP.PingTimeout = DurationString{time.Second} + cfg := &config437{} + cfg.HTTP.PingTimeout = durationString{time.Second} r := strings.NewReader(src) err := NewDecoder(r).Decode(cfg) if err != nil { t.Fatalf("unexpected errors %s", err) } - expected := DurationString{10 * time.Second} + expected := durationString{10 * time.Second} if cfg.HTTP.PingTimeout != expected { t.Fatalf("expected '%s', got '%s'", expected, cfg.HTTP.PingTimeout) } } + +func TestLeafUnmarshalerError(t *testing.T) { + src := ` +[HTTP] +ErrorField = "foo" +` + cfg := &config437{} + + r := strings.NewReader(src) + err := NewDecoder(r).Decode(cfg) + if err == nil { + t.Fatalf("error was expected") + } +}