From a1e8a8d7023b03aada384c1aa6f7d643e8d3cebf Mon Sep 17 00:00:00 2001 From: Jelte Fennema Date: Thu, 18 Jan 2018 23:08:34 +0100 Subject: [PATCH] Unmarshal into custom types and error on overflows (#209) This fixes two unmarshalling issues: 1. Unmarshalling into a custom integer/float type (e.g. `time.Duration`). 2. Checks for overflows happen before unmarshalling, erroring if an overflow would happen. Apart from this it also reduces code duplication a bit. --- marshal.go | 131 ++++++++++++++++++----------------------------------- 1 file changed, 45 insertions(+), 86 deletions(-) diff --git a/marshal.go b/marshal.go index 6280225..b433db4 100644 --- a/marshal.go +++ b/marshal.go @@ -486,96 +486,55 @@ func (d *Decoder) valueFromToml(mtype reflect.Type, tval interface{}) (reflect.V return reflect.ValueOf(nil), fmt.Errorf("Can't convert %v(%T) to a slice", tval, tval) default: switch mtype.Kind() { - case reflect.Bool: - val, ok := tval.(bool) - if !ok { - return reflect.ValueOf(nil), fmt.Errorf("Can't convert %v(%T) to bool", tval, tval) + case reflect.Bool, reflect.Struct: + val := reflect.ValueOf(tval) + // if this passes for when mtype is reflect.Struct, tval is a time.Time + if !val.Type().ConvertibleTo(mtype) { + return reflect.ValueOf(nil), fmt.Errorf("Can't convert %v(%T) to %v", tval, tval, mtype.String()) } - return reflect.ValueOf(val), nil - case reflect.Int: - val, ok := tval.(int64) - if !ok { - return reflect.ValueOf(nil), fmt.Errorf("Can't convert %v(%T) to int", tval, tval) - } - return reflect.ValueOf(int(val)), nil - case reflect.Int8: - val, ok := tval.(int64) - if !ok { - return reflect.ValueOf(nil), fmt.Errorf("Can't convert %v(%T) to int", tval, tval) - } - return reflect.ValueOf(int8(val)), nil - case reflect.Int16: - val, ok := tval.(int64) - if !ok { - return reflect.ValueOf(nil), fmt.Errorf("Can't convert %v(%T) to int", tval, tval) - } - return reflect.ValueOf(int16(val)), nil - case reflect.Int32: - val, ok := tval.(int64) - if !ok { - return reflect.ValueOf(nil), fmt.Errorf("Can't convert %v(%T) to int", tval, tval) - } - return reflect.ValueOf(int32(val)), nil - case reflect.Int64: - val, ok := tval.(int64) - if !ok { - return reflect.ValueOf(nil), fmt.Errorf("Can't convert %v(%T) to int", tval, tval) - } - return reflect.ValueOf(val), nil - case reflect.Uint: - val, ok := tval.(int64) - if !ok { - return reflect.ValueOf(nil), fmt.Errorf("Can't convert %v(%T) to uint", tval, tval) - } - return reflect.ValueOf(uint(val)), nil - case reflect.Uint8: - val, ok := tval.(int64) - if !ok { - return reflect.ValueOf(nil), fmt.Errorf("Can't convert %v(%T) to uint", tval, tval) - } - return reflect.ValueOf(uint8(val)), nil - case reflect.Uint16: - val, ok := tval.(int64) - if !ok { - return reflect.ValueOf(nil), fmt.Errorf("Can't convert %v(%T) to uint", tval, tval) - } - return reflect.ValueOf(uint16(val)), nil - case reflect.Uint32: - val, ok := tval.(int64) - if !ok { - return reflect.ValueOf(nil), fmt.Errorf("Can't convert %v(%T) to uint", tval, tval) - } - return reflect.ValueOf(uint32(val)), nil - case reflect.Uint64: - val, ok := tval.(int64) - if !ok { - return reflect.ValueOf(nil), fmt.Errorf("Can't convert %v(%T) to uint", tval, tval) - } - return reflect.ValueOf(uint64(val)), nil - case reflect.Float32: - val, ok := tval.(float64) - if !ok { - return reflect.ValueOf(nil), fmt.Errorf("Can't convert %v(%T) to float", tval, tval) - } - return reflect.ValueOf(float32(val)), nil - case reflect.Float64: - val, ok := tval.(float64) - if !ok { - return reflect.ValueOf(nil), fmt.Errorf("Can't convert %v(%T) to float", tval, tval) - } - return reflect.ValueOf(val), nil + + return val.Convert(mtype), nil case reflect.String: - val, ok := tval.(string) - if !ok { - return reflect.ValueOf(nil), fmt.Errorf("Can't convert %v(%T) to string", tval, tval) + val := reflect.ValueOf(tval) + // stupidly, int64 is convertible to string. So special case this. + if !val.Type().ConvertibleTo(mtype) || val.Kind() == reflect.Int64 { + return reflect.ValueOf(nil), fmt.Errorf("Can't convert %v(%T) to %v", tval, tval, mtype.String()) } - return reflect.ValueOf(val), nil - case reflect.Struct: - val, ok := tval.(time.Time) - if !ok { - return reflect.ValueOf(nil), fmt.Errorf("Can't convert %v(%T) to time", tval, tval) + + return val.Convert(mtype), nil + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + val := reflect.ValueOf(tval) + if !val.Type().ConvertibleTo(mtype) { + return reflect.ValueOf(nil), fmt.Errorf("Can't convert %v(%T) to %v", tval, tval, mtype.String()) } - return reflect.ValueOf(val), nil + if reflect.Indirect(reflect.New(mtype)).OverflowInt(val.Int()) { + return reflect.ValueOf(nil), fmt.Errorf("%v(%T) would overflow %v", tval, tval, mtype.String()) + } + + return val.Convert(mtype), nil + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr: + val := reflect.ValueOf(tval) + if !val.Type().ConvertibleTo(mtype) { + return reflect.ValueOf(nil), fmt.Errorf("Can't convert %v(%T) to %v", tval, tval, mtype.String()) + } + if val.Int() < 0 { + return reflect.ValueOf(nil), fmt.Errorf("%v(%T) is negative so does not fit in %v", tval, tval, mtype.String()) + } + if reflect.Indirect(reflect.New(mtype)).OverflowUint(uint64(val.Int())) { + return reflect.ValueOf(nil), fmt.Errorf("%v(%T) would overflow %v", tval, tval, mtype.String()) + } + + return val.Convert(mtype), nil + case reflect.Float32, reflect.Float64: + val := reflect.ValueOf(tval) + if !val.Type().ConvertibleTo(mtype) { + return reflect.ValueOf(nil), fmt.Errorf("Can't convert %v(%T) to %v", tval, tval, mtype.String()) + } + if reflect.Indirect(reflect.New(mtype)).OverflowFloat(val.Float()) { + return reflect.ValueOf(nil), fmt.Errorf("%v(%T) would overflow %v", tval, tval, mtype.String()) + } + + return val.Convert(mtype), nil default: return reflect.ValueOf(nil), fmt.Errorf("Can't convert %v(%T) to %v(%v)", tval, tval, mtype, mtype.Kind()) }