diff --git a/marshal.go b/marshal.go index 73056cd..032e0ff 100644 --- a/marshal.go +++ b/marshal.go @@ -742,6 +742,10 @@ func (d *Decoder) valueFromTree(mtype reflect.Type, tval *Tree, mval1 *reflect.V if mvalPtr := reflect.New(mtype); isCustomUnmarshaler(mvalPtr.Type()) { d.visitor.visitAll() + if tval == nil { + return mvalPtr.Elem(), nil + } + if err := callCustomUnmarshaler(mvalPtr, tval.ToMap()); err != nil { return reflect.ValueOf(nil), fmt.Errorf("unmarshal toml: %v", err) } diff --git a/marshal_test.go b/marshal_test.go index b91b286..d72926b 100644 --- a/marshal_test.go +++ b/marshal_test.go @@ -3941,3 +3941,38 @@ bar = 42 reflect.DeepEqual(x, expected) } + +type Config struct { + Key string `toml:"key"` + Obj Custom `toml:"obj"` +} + +type Custom struct { + v string +} + +func (c *Custom) UnmarshalTOML(v interface{}) error { + c.v = "called" + return nil +} + +func TestGithubIssue431(t *testing.T) { + doc := `key = "value"` + tree, err := LoadBytes([]byte(doc)) + if err != nil { + t.Fatalf("unexpected error: %s", err) + } + + var c Config + if err := tree.Unmarshal(&c); err != nil { + t.Fatalf("unexpected error: %s", err) + } + + if c.Key != "value" { + t.Errorf("expected c.Key='value', not '%s'", c.Key) + } + + if c.Obj.v == "called" { + t.Errorf("UnmarshalTOML should not have been called") + } +}