From ee9b902222c33ad156960573001c3aae9bcc963c Mon Sep 17 00:00:00 2001 From: Thomas Pelletier Date: Thu, 9 Sep 2021 21:25:14 -0400 Subject: [PATCH] unmarshal: convert ints if target type is compatible (#594) This is required to support custom types. Fixes #590 --- marshaler_test.go | 23 ++++++++++++++++------- unmarshaler.go | 34 +++++++++++++++++++++------------- 2 files changed, 37 insertions(+), 20 deletions(-) diff --git a/marshaler_test.go b/marshaler_test.go index 8667a81..dcc976d 100644 --- a/marshaler_test.go +++ b/marshaler_test.go @@ -782,6 +782,22 @@ func TestIssue424(t *testing.T) { require.Equal(t, msg2, msg2parsed) } +func TestIssue567(t *testing.T) { + var m map[string]interface{} + err := toml.Unmarshal([]byte("A = 12:08:05"), &m) + require.NoError(t, err) + require.IsType(t, m["A"], toml.LocalTime{}) +} + +func TestIssue590(t *testing.T) { + type CustomType int + var cfg struct { + Option CustomType `toml:"option"` + } + err := toml.Unmarshal([]byte("option = 42"), &cfg) + require.NoError(t, err) +} + func ExampleMarshal() { type MyConfig struct { Version int @@ -806,10 +822,3 @@ func ExampleMarshal() { // Name = 'go-toml' // Tags = ['go', 'toml'] } - -func TestIssue567(t *testing.T) { - var m map[string]interface{} - err := toml.Unmarshal([]byte("A = 12:08:05"), &m) - require.NoError(t, err) - require.IsType(t, m["A"], toml.LocalTime{}) -} diff --git a/unmarshaler.go b/unmarshaler.go index 5e30a20..1fdd686 100644 --- a/unmarshaler.go +++ b/unmarshaler.go @@ -823,71 +823,79 @@ func (d *decoder) unmarshalInteger(value *ast.Node, v reflect.Value) error { return err } + var r reflect.Value + switch v.Kind() { case reflect.Int64: v.SetInt(i) + return nil case reflect.Int32: if i < math.MinInt32 || i > math.MaxInt32 { return fmt.Errorf("toml: number %d does not fit in an int32", i) } - v.Set(reflect.ValueOf(int32(i))) - return nil + r = reflect.ValueOf(int32(i)) case reflect.Int16: if i < math.MinInt16 || i > math.MaxInt16 { return fmt.Errorf("toml: number %d does not fit in an int16", i) } - v.Set(reflect.ValueOf(int16(i))) + r = reflect.ValueOf(int16(i)) case reflect.Int8: if i < math.MinInt8 || i > math.MaxInt8 { return fmt.Errorf("toml: number %d does not fit in an int8", i) } - v.Set(reflect.ValueOf(int8(i))) + r = reflect.ValueOf(int8(i)) case reflect.Int: if i < minInt || i > maxInt { return fmt.Errorf("toml: number %d does not fit in an int", i) } - v.Set(reflect.ValueOf(int(i))) + r = reflect.ValueOf(int(i)) case reflect.Uint64: if i < 0 { return fmt.Errorf("toml: negative number %d does not fit in an uint64", i) } - v.Set(reflect.ValueOf(uint64(i))) + r = reflect.ValueOf(uint64(i)) case reflect.Uint32: if i < 0 || i > math.MaxUint32 { return fmt.Errorf("toml: negative number %d does not fit in an uint32", i) } - v.Set(reflect.ValueOf(uint32(i))) + r = reflect.ValueOf(uint32(i)) case reflect.Uint16: if i < 0 || i > math.MaxUint16 { return fmt.Errorf("toml: negative number %d does not fit in an uint16", i) } - v.Set(reflect.ValueOf(uint16(i))) + r = reflect.ValueOf(uint16(i)) case reflect.Uint8: if i < 0 || i > math.MaxUint8 { return fmt.Errorf("toml: negative number %d does not fit in an uint8", i) } - v.Set(reflect.ValueOf(uint8(i))) + r = reflect.ValueOf(uint8(i)) case reflect.Uint: if i < 0 { return fmt.Errorf("toml: negative number %d does not fit in an uint", i) } - v.Set(reflect.ValueOf(uint(i))) + r = reflect.ValueOf(uint(i)) case reflect.Interface: - v.Set(reflect.ValueOf(i)) + r = reflect.ValueOf(i) default: - err = fmt.Errorf("toml: cannot store TOML integer into a Go %s", v.Kind()) + return fmt.Errorf("toml: cannot store TOML integer into a Go %s", v.Kind()) } - return err + if !r.Type().AssignableTo(v.Type()) { + r = r.Convert(v.Type()) + } + + v.Set(r) + + return nil } func (d *decoder) unmarshalString(value *ast.Node, v reflect.Value) error {