From 3f2bb0b36386a476cc920330c97abc7451b0bac4 Mon Sep 17 00:00:00 2001 From: Vincent Serpoul Date: Fri, 7 May 2021 10:29:21 +0800 Subject: [PATCH] golangci-lint (#530) --- .golangci.toml | 2 +- localtime_test.go | 11 + parser.go | 4 +- targets.go | 408 +++++++++++++++++++++++++++-------- targets_test.go | 29 ++- toml_testgen_support_test.go | 52 +++-- toml_testgen_test.go | 152 ++++++++++++- unmarshaler.go | 56 ++++- unmarshaler_test.go | 15 ++ 9 files changed, 611 insertions(+), 118 deletions(-) diff --git a/.golangci.toml b/.golangci.toml index a38c60b..0e71b20 100644 --- a/.golangci.toml +++ b/.golangci.toml @@ -79,6 +79,6 @@ enable = [ "varcheck", "wastedassign", "whitespace", - "wrapcheck", + # "wrapcheck", "wsl" ] diff --git a/localtime_test.go b/localtime_test.go index a103bc0..6741504 100644 --- a/localtime_test.go +++ b/localtime_test.go @@ -56,6 +56,7 @@ func TestDates(t *testing.T) { if got := test.date.String(); got != test.wantStr { t.Errorf("%#v.String() = %q, want %q", test.date, got, test.wantStr) } + if got := test.date.In(test.loc); !got.Equal(test.wantTime) { t.Errorf("%#v.In(%v) = %v, want %v", test.date, test.loc, got, test.wantTime) } @@ -109,6 +110,7 @@ func TestParseDate(t *testing.T) { if got != test.want { t.Errorf("ParseLocalDate(%q) = %+v, want %+v", test.str, got, test.want) } + if err != nil && test.want != (emptyDate) { t.Errorf("Unexpected error %v from ParseLocalDate(%q)", err, test.str) } @@ -170,6 +172,7 @@ func TestDateArithmetic(t *testing.T) { if got := test.start.AddDays(test.days); got != test.end { t.Errorf("[%s] %#v.AddDays(%v) = %#v, want %#v", test.desc, test.start, test.days, got, test.end) } + if got := test.end.DaysSince(test.start); got != test.days { t.Errorf("[%s] %#v.Sub(%#v) = %v, want %v", test.desc, test.end, test.start, got, test.days) } @@ -231,9 +234,11 @@ func TestTimeToString(t *testing.T) { continue } + if gotTime != test.time { t.Errorf("ParseLocalTime(%q) = %+v, want %+v", test.str, gotTime, test.time) } + if test.roundTrip { gotStr := test.time.String() if gotStr != test.str { @@ -303,9 +308,11 @@ func TestDateTimeToString(t *testing.T) { continue } + if gotDateTime != test.dateTime { t.Errorf("ParseLocalDateTime(%q) = %+v, want %+v", test.str, gotDateTime, test.dateTime) } + if test.roundTrip { gotStr := test.dateTime.String() if gotStr != test.str { @@ -444,6 +451,7 @@ func TestMarshalJSON(t *testing.T) { if err != nil { t.Fatal(err) } + if got := string(bgot); got != test.want { t.Errorf("%#v: got %s, want %s", test.value, got, test.want) } @@ -472,6 +480,7 @@ func TestUnmarshalJSON(t *testing.T) { if err := json.Unmarshal([]byte(test.data), test.ptr); err != nil { t.Fatalf("%s: %v", test.data, err) } + if !cmpEqual(test.ptr, test.want) { t.Errorf("%s: got %#v, want %#v", test.data, test.ptr, test.want) } @@ -486,9 +495,11 @@ func TestUnmarshalJSON(t *testing.T) { if json.Unmarshal([]byte(bad), &d) == nil { t.Errorf("%q, LocalDate: got nil, want error", bad) } + if json.Unmarshal([]byte(bad), &tm) == nil { t.Errorf("%q, LocalTime: got nil, want error", bad) } + if json.Unmarshal([]byte(bad), &dt) == nil { t.Errorf("%q, LocalDateTime: got nil, want error", bad) } diff --git a/parser.go b/parser.go index 1ff2272..9724190 100644 --- a/parser.go +++ b/parser.go @@ -370,7 +370,7 @@ func (p *parser) parseInlineTable(b []byte) (ast.Reference, []byte, error) { return parent, rest, err } -var errArrayCanNotStartWithComma = errors.New("array cannot start with comma") +var errArrayCannotStartWithComma = errors.New("array cannot start with comma") //nolint:funlen,cyclop func (p *parser) parseValArray(b []byte) (ast.Reference, []byte, error) { @@ -409,7 +409,7 @@ func (p *parser) parseValArray(b []byte) (ast.Reference, []byte, error) { if b[0] == ',' { if first { - return parent, nil, errArrayCanNotStartWithComma + return parent, nil, errArrayCannotStartWithComma } b = b[1:] diff --git a/targets.go b/targets.go index b25b203..68f9428 100644 --- a/targets.go +++ b/targets.go @@ -1,6 +1,7 @@ package toml import ( + "errors" "fmt" "math" "reflect" @@ -38,26 +39,31 @@ func (t valueTarget) get() reflect.Value { func (t valueTarget) set(v reflect.Value) error { reflect.Value(t).Set(v) + return nil } func (t valueTarget) setString(v string) error { t.get().SetString(v) + return nil } func (t valueTarget) setBool(v bool) error { t.get().SetBool(v) + return nil } func (t valueTarget) setInt64(v int64) error { t.get().SetInt(v) + return nil } func (t valueTarget) setFloat64(v float64) error { t.get().SetFloat(v) + return nil } @@ -71,23 +77,48 @@ func (t interfaceTarget) get() reflect.Value { } func (t interfaceTarget) set(v reflect.Value) error { - return t.x.set(v) + err := t.x.set(v) + if err != nil { + return fmt.Errorf("interfaceTarget set: %w", err) + } + + return nil } func (t interfaceTarget) setString(v string) error { - return t.x.setString(v) + err := t.x.setString(v) + if err != nil { + return fmt.Errorf("interfaceTarget setString: %w", err) + } + + return nil } func (t interfaceTarget) setBool(v bool) error { - return t.x.setBool(v) + err := t.x.setBool(v) + if err != nil { + return fmt.Errorf("interfaceTarget setBool: %w", err) + } + + return nil } func (t interfaceTarget) setInt64(v int64) error { - return t.x.setInt64(v) + err := t.x.setInt64(v) + if err != nil { + return fmt.Errorf("interfaceTarget setInt64: %w", err) + } + + return nil } func (t interfaceTarget) setFloat64(v float64) error { - return t.x.setFloat64(v) + err := t.x.setFloat64(v) + if err != nil { + return fmt.Errorf("interfaceTarget setFloat64: %w", err) + } + + return nil } // mapTarget targets a specific key of a map. @@ -102,6 +133,7 @@ func (t mapTarget) get() reflect.Value { func (t mapTarget) set(v reflect.Value) error { t.v.SetMapIndex(t.k, v) + return nil } @@ -121,6 +153,12 @@ func (t mapTarget) setFloat64(v float64) error { return t.set(reflect.ValueOf(v)) } +var ( + errValIndexExpectingSlice = errors.New("expecting a slice") + errValIndexCannotInitSlice = errors.New("cannot initialize a slice") +) + +//nolint:cyclop // makes sure that the value pointed at by t is indexable (Slice, Array), or // dereferences to an indexable (Ptr, Interface). func ensureValueIndexable(t target) error { @@ -129,159 +167,319 @@ func ensureValueIndexable(t target) error { switch f.Type().Kind() { case reflect.Slice: if f.IsNil() { - return t.set(reflect.MakeSlice(f.Type(), 0, 0)) + err := t.set(reflect.MakeSlice(f.Type(), 0, 0)) + if err != nil { + return fmt.Errorf("ensureValueIndexable: %w", err) + } + + return nil } case reflect.Interface: if f.IsNil() || f.Elem().Type() != sliceInterfaceType { - return t.set(reflect.MakeSlice(sliceInterfaceType, 0, 0)) + err := t.set(reflect.MakeSlice(sliceInterfaceType, 0, 0)) + if err != nil { + return fmt.Errorf("ensureValueIndexable: %w", err) + } + + return nil } + if f.Elem().Type().Kind() != reflect.Slice { - return fmt.Errorf("interface is pointing to a %s, not a slice", f.Kind()) + return fmt.Errorf("ensureValueIndexable: %w, not a %s", errValIndexExpectingSlice, f.Kind()) } case reflect.Ptr: if f.IsNil() { ptr := reflect.New(f.Type().Elem()) + err := t.set(ptr) if err != nil { - return err + return fmt.Errorf("ensureValueIndexable: %w", err) } + f = t.get() } + return ensureValueIndexable(valueTarget(f.Elem())) case reflect.Array: // arrays are always initialized. default: - return fmt.Errorf("cannot initialize a slice in %s", f.Kind()) + return fmt.Errorf("ensureValueIndexable: %w with %s", errValIndexCannotInitSlice, f.Kind()) } + return nil } -var sliceInterfaceType = reflect.TypeOf([]interface{}{}) -var mapStringInterfaceType = reflect.TypeOf(map[string]interface{}{}) +var ( + sliceInterfaceType = reflect.TypeOf([]interface{}{}) + mapStringInterfaceType = reflect.TypeOf(map[string]interface{}{}) +) -func ensureMapIfInterface(x target) { +func ensureMapIfInterface(x target) error { v := x.get() + if v.Kind() == reflect.Interface && v.IsNil() { newElement := reflect.MakeMap(mapStringInterfaceType) - x.set(newElement) + + err := x.set(newElement) + if err != nil { + return fmt.Errorf("ensureMapIfInterface: %w", err) + } } + + return nil } +var errSetStringCannotAssignString = errors.New("cannot assign string") + func setString(t target, v string) error { f := t.get() switch f.Kind() { case reflect.String: - return t.setString(v) + err := t.setString(v) + if err != nil { + return fmt.Errorf("setString: %w", err) + } + + return nil case reflect.Interface: - return t.set(reflect.ValueOf(v)) + err := t.set(reflect.ValueOf(v)) + if err != nil { + return fmt.Errorf("setString: %w", err) + } + + return nil default: - return fmt.Errorf("cannot assign string to a %s", f.Kind()) + return fmt.Errorf("setString: %w to a %s", errSetStringCannotAssignString, f.Kind()) } } +var errSetBoolCannotAssignBool = errors.New("cannot assign bool") + func setBool(t target, v bool) error { f := t.get() switch f.Kind() { case reflect.Bool: - return t.setBool(v) + err := t.setBool(v) + if err != nil { + return fmt.Errorf("setBool: %w", err) + } + + return nil case reflect.Interface: - return t.set(reflect.ValueOf(v)) + err := t.set(reflect.ValueOf(v)) + if err != nil { + return fmt.Errorf("setBool: %w", err) + } + + return nil default: - return fmt.Errorf("cannot assign bool to a %s", f.String()) + return fmt.Errorf("setBool: %w to a %s", errSetBoolCannotAssignBool, f.String()) } } -const maxInt = int64(^uint(0) >> 1) -const minInt = -maxInt - 1 +const ( + maxInt = int64(^uint(0) >> 1) + minInt = -maxInt - 1 +) +var ( + errSetInt64InInt32 = errors.New("does not fit in an int32") + errSetInt64InInt16 = errors.New("does not fit in an int16") + errSetInt64InInt8 = errors.New("does not fit in an int8") + errSetInt64InInt = errors.New("does not fit in an int") + errSetInt64InUint64 = errors.New("negative integer does not fit in an uint64") + errSetInt64InUint32 = errors.New("negative integer does not fit in an uint32") + errSetInt64InUint32Max = errors.New("integer does not fit in an uint32") + errSetInt64InUint16 = errors.New("negative integer does not fit in an uint16") + errSetInt64InUint16Max = errors.New("integer does not fit in an uint16") + errSetInt64InUint8 = errors.New("negative integer does not fit in an uint8") + errSetInt64InUint8Max = errors.New("integer does not fit in an uint8") + errSetInt64InUint = errors.New("negative integer does not fit in an uint") + errSetInt64Unknown = errors.New("does not fit in an uint") +) + +//nolint:funlen,gocognit,cyclop,gocyclo func setInt64(t target, v int64) error { f := t.get() switch f.Kind() { case reflect.Int64: - return t.setInt64(v) + err := t.setInt64(v) + if err != nil { + return fmt.Errorf("setInt64: %w", err) + } + + return nil case reflect.Int32: if v < math.MinInt32 || v > math.MaxInt32 { - return fmt.Errorf("integer %d does not fit in an int32", v) + return fmt.Errorf("setInt64: integer %d %w", v, errSetInt64InInt32) } - return t.set(reflect.ValueOf(int32(v))) + + err := t.set(reflect.ValueOf(int32(v))) + if err != nil { + return fmt.Errorf("setInt64: %w", err) + } + + return nil case reflect.Int16: if v < math.MinInt16 || v > math.MaxInt16 { - return fmt.Errorf("integer %d does not fit in an int16", v) + return fmt.Errorf("setInt64: integer %d %w", v, errSetInt64InInt16) } - return t.set(reflect.ValueOf(int16(v))) + + err := t.set(reflect.ValueOf(int16(v))) + if err != nil { + return fmt.Errorf("setInt64: %w", err) + } + + return nil case reflect.Int8: if v < math.MinInt8 || v > math.MaxInt8 { - return fmt.Errorf("integer %d does not fit in an int8", v) + return fmt.Errorf("setInt64: integer %d %w", v, errSetInt64InInt8) } - return t.set(reflect.ValueOf(int8(v))) + + err := t.set(reflect.ValueOf(int8(v))) + if err != nil { + return fmt.Errorf("setInt64: %w", err) + } + + return nil case reflect.Int: if v < minInt || v > maxInt { - return fmt.Errorf("integer %d does not fit in an int", v) + return fmt.Errorf("setInt64: integer %d %w", v, errSetInt64InInt) } - return t.set(reflect.ValueOf(int(v))) + err := t.set(reflect.ValueOf(int(v))) + if err != nil { + return fmt.Errorf("setInt64: %w", err) + } + + return nil case reflect.Uint64: if v < 0 { - return fmt.Errorf("negative integer %d cannot be stored in an uint64", v) + return fmt.Errorf("setInt64: %d, %w", v, errSetInt64InUint64) } - return t.set(reflect.ValueOf(uint64(v))) + + err := t.set(reflect.ValueOf(uint64(v))) + if err != nil { + return fmt.Errorf("setInt64: %w", err) + } + + return nil case reflect.Uint32: if v < 0 { - return fmt.Errorf("negative integer %d cannot be stored in an uint32", v) + return fmt.Errorf("setInt64: %d, %w", v, errSetInt64InUint32) } + if v > math.MaxUint32 { - return fmt.Errorf("integer %d cannot be stored in an uint32", v) + return fmt.Errorf("setInt64: %d, %w", v, errSetInt64InUint32Max) } - return t.set(reflect.ValueOf(uint32(v))) + + err := t.set(reflect.ValueOf(uint32(v))) + if err != nil { + return fmt.Errorf("setInt64: %w", err) + } + + return nil case reflect.Uint16: if v < 0 { - return fmt.Errorf("negative integer %d cannot be stored in an uint16", v) + return fmt.Errorf("setInt64: %d, %w", v, errSetInt64InUint16) } + if v > math.MaxUint16 { - return fmt.Errorf("integer %d cannot be stored in an uint16", v) + return fmt.Errorf("setInt64: %d, %w", v, errSetInt64InUint16Max) } - return t.set(reflect.ValueOf(uint16(v))) + + err := t.set(reflect.ValueOf(uint16(v))) + if err != nil { + return fmt.Errorf("setInt64: %w", err) + } + + return nil case reflect.Uint8: if v < 0 { - return fmt.Errorf("negative integer %d cannot be stored in an uint8", v) + return fmt.Errorf("setInt64: %d, %w", v, errSetInt64InUint8) } + if v > math.MaxUint8 { - return fmt.Errorf("integer %d cannot be stored in an uint8", v) + return fmt.Errorf("setInt64: %d, %w", v, errSetInt64InUint8Max) } - return t.set(reflect.ValueOf(uint8(v))) + + err := t.set(reflect.ValueOf(uint8(v))) + if err != nil { + return fmt.Errorf("setInt64: %w", err) + } + + return nil case reflect.Uint: if v < 0 { - return fmt.Errorf("negative integer %d cannot be stored in an uint", v) + return fmt.Errorf("setInt64: %d, %w", v, errSetInt64InUint) } - return t.set(reflect.ValueOf(uint(v))) + + err := t.set(reflect.ValueOf(uint(v))) + if err != nil { + return fmt.Errorf("setInt64: %w", err) + } + + return nil case reflect.Interface: - return t.set(reflect.ValueOf(v)) + err := t.set(reflect.ValueOf(v)) + if err != nil { + return fmt.Errorf("setInt64: %w", err) + } + + return nil default: - return fmt.Errorf("cannot assign int64 to a %s", f.String()) + return fmt.Errorf("setInt64: %s, %w", f.String(), errSetInt64Unknown) } } +var ( + errSetFloat64InFloat32Max = errors.New("does not fit in an float32") + errSetFloat64Unknown = errors.New("does not fit in an float32") +) + func setFloat64(t target, v float64) error { f := t.get() switch f.Kind() { case reflect.Float64: - return t.setFloat64(v) + err := t.setFloat64(v) + if err != nil { + return fmt.Errorf("setFloat64: %w", err) + } + + return nil case reflect.Float32: if v > math.MaxFloat32 { - return fmt.Errorf("float %f cannot be stored in a float32", v) + return fmt.Errorf("setFloat64: %f %w", v, errSetFloat64InFloat32Max) } - return t.set(reflect.ValueOf(float32(v))) + + err := t.set(reflect.ValueOf(float32(v))) + if err != nil { + return fmt.Errorf("setFloat64: %w", err) + } + + return nil case reflect.Interface: - return t.set(reflect.ValueOf(v)) + err := t.set(reflect.ValueOf(v)) + if err != nil { + return fmt.Errorf("setFloat64: %w", err) + } + + return nil default: - return fmt.Errorf("cannot assign float64 to a %s", f.String()) + return fmt.Errorf("setFloat64: %s %w", f.String(), errSetFloat64Unknown) } } +var ( + errElementAtCannotOn = errors.New("cannot elementAt") + errElementAtCannotOnUnknown = errors.New("cannot elementAt") +) + +//nolint:cyclop // Returns the element at idx of the value pointed at by target, or an error if // t does not point to an indexable. // If the target points to an Array and idx is out of bounds, it returns @@ -291,95 +489,111 @@ func elementAt(t target, idx int) (target, error) { switch f.Kind() { case reflect.Slice: + //nolint:godox // TODO: use the idx function argument and avoid alloc if possible. idx := f.Len() + err := t.set(reflect.Append(f, reflect.New(f.Type().Elem()).Elem())) if err != nil { - return nil, err + return nil, fmt.Errorf("elementAt: %w", err) } + return valueTarget(t.get().Index(idx)), nil case reflect.Array: if idx >= f.Len() { return nil, nil } + return valueTarget(f.Index(idx)), nil case reflect.Interface: if f.IsNil() { panic("interface should have been initialized") } + ifaceElem := f.Elem() if ifaceElem.Kind() != reflect.Slice { - return nil, fmt.Errorf("cannot elementAt on a %s", f.Kind()) + return nil, fmt.Errorf("elementAt: %w on a %s", errElementAtCannotOn, f.Kind()) } + idx := ifaceElem.Len() newElem := reflect.New(ifaceElem.Type().Elem()).Elem() newSlice := reflect.Append(ifaceElem, newElem) + err := t.set(newSlice) if err != nil { - return nil, err + return nil, fmt.Errorf("elementAt: %w", err) } + return valueTarget(t.get().Elem().Index(idx)), nil case reflect.Ptr: return elementAt(valueTarget(f.Elem()), idx) default: - return nil, fmt.Errorf("cannot elementAt on a %s", f.Kind()) + return nil, fmt.Errorf("elementAt: %w on a %s", errElementAtCannotOnUnknown, f.Kind()) } } -func (d *decoder) scopeTableTarget(append bool, t target, name string) (target, bool, error) { +//nolint:cyclop +func (d *decoder) scopeTableTarget(shouldAppend bool, t target, name string) (target, bool, error) { x := t.get() switch x.Kind() { // Kinds that need to recurse - case reflect.Interface: - t, err := scopeInterface(append, t) + t, err := scopeInterface(shouldAppend, t) if err != nil { - return t, false, err + return t, false, fmt.Errorf("scopeTableTarget: %w", err) } - return d.scopeTableTarget(append, t, name) + + return d.scopeTableTarget(shouldAppend, t, name) case reflect.Ptr: t, err := scopePtr(t) if err != nil { - return t, false, err + return t, false, fmt.Errorf("scopeTableTarget: %w", err) } - return d.scopeTableTarget(append, t, name) + + return d.scopeTableTarget(shouldAppend, t, name) case reflect.Slice: - t, err := scopeSlice(append, t) + t, err := scopeSlice(shouldAppend, t) if err != nil { - return t, false, err + return t, false, fmt.Errorf("scopeTableTarget: %w", err) } - append = false - return d.scopeTableTarget(append, t, name) + shouldAppend = false + + return d.scopeTableTarget(shouldAppend, t, name) case reflect.Array: - t, err := d.scopeArray(append, t) + t, err := d.scopeArray(shouldAppend, t) if err != nil { - return t, false, err + return t, false, fmt.Errorf("scopeTableTarget: %w", err) } - append = false - return d.scopeTableTarget(append, t, name) + shouldAppend = false + + return d.scopeTableTarget(shouldAppend, t, name) // Terminal kinds - case reflect.Struct: return scopeStruct(x, name) case reflect.Map: if x.IsNil() { - t.set(reflect.MakeMap(x.Type())) + err := t.set(reflect.MakeMap(x.Type())) + if err != nil { + return t, false, fmt.Errorf("scopeTableTarget: %w", err) + } + x = t.get() } return scopeMap(x, name) default: - panic(fmt.Errorf("can't scope on a %s", x.Kind())) + panic(fmt.Sprintf("can't scope on a %s", x.Kind())) } } -func scopeInterface(append bool, t target) (target, error) { - err := initInterface(append, t) +func scopeInterface(shouldAppend bool, t target) (target, error) { + err := initInterface(shouldAppend, t) if err != nil { return t, err } + return interfaceTarget{t}, nil } @@ -388,6 +602,7 @@ func scopePtr(t target) (target, error) { if err != nil { return t, err } + return valueTarget(t.get().Elem()), nil } @@ -396,13 +611,19 @@ func initPtr(t target) error { if !x.IsNil() { return nil } - return t.set(reflect.New(x.Type().Elem())) + + err := t.set(reflect.New(x.Type().Elem())) + if err != nil { + return fmt.Errorf("initPtr: %w", err) + } + + return nil } // initInterface makes sure that the interface pointed at by the target is not // nil. // Returns the target to the initialized value of the target. -func initInterface(append bool, t target) error { +func initInterface(shouldAppend bool, t target) error { x := t.get() if x.Kind() != reflect.Interface { @@ -414,54 +635,63 @@ func initInterface(append bool, t target) error { } var newElement reflect.Value - if append { + if shouldAppend { newElement = reflect.MakeSlice(sliceInterfaceType, 0, 0) } else { newElement = reflect.MakeMap(mapStringInterfaceType) } + err := t.set(newElement) if err != nil { - return err + return fmt.Errorf("initInterface: %w", err) } return nil } -func scopeSlice(append bool, t target) (target, error) { +func scopeSlice(shouldAppend bool, t target) (target, error) { v := t.get() - if append { + if shouldAppend { newElem := reflect.New(v.Type().Elem()) newSlice := reflect.Append(v, newElem.Elem()) + err := t.set(newSlice) if err != nil { - return t, err + return t, fmt.Errorf("scopeSlice: %w", err) } + v = t.get() } + return valueTarget(v.Index(v.Len() - 1)), nil } -func (d *decoder) scopeArray(append bool, t target) (target, error) { +var errScopeArrayNotEnoughSpace = errors.New("not enough space in the array") + +func (d *decoder) scopeArray(shouldAppend bool, t target) (target, error) { v := t.get() - idx := d.arrayIndex(append, v) + idx := d.arrayIndex(shouldAppend, v) if idx >= v.Len() { - return nil, fmt.Errorf("not enough space in the array") + return nil, errScopeArrayNotEnoughSpace } return valueTarget(v.Index(idx)), nil } +var errScopeMapCannotConvertStringToKey = errors.New("cannot convert string into map key type") + func scopeMap(v reflect.Value, name string) (target, bool, error) { k := reflect.ValueOf(name) keyType := v.Type().Key() if !k.Type().AssignableTo(keyType) { if !k.Type().ConvertibleTo(keyType) { - return nil, false, fmt.Errorf("cannot convert string into map key type %s", keyType) + return nil, false, fmt.Errorf("scopeMap: %w %s", errScopeMapCannotConvertStringToKey, keyType) } + k = k.Convert(keyType) } @@ -487,6 +717,7 @@ func (c *fieldPathsCache) get(t reflect.Type) (fieldPathsMap, bool) { c.l.RLock() paths, ok := c.m[t] c.l.RUnlock() + return paths, ok } @@ -502,13 +733,14 @@ var globalFieldPathsCache = fieldPathsCache{ } func scopeStruct(v reflect.Value, name string) (target, bool, error) { + //nolint:godox // TODO: cache this, and reduce allocations - fieldPaths, ok := globalFieldPathsCache.get(v.Type()) if !ok { fieldPaths = map[string][]int{} path := make([]int, 0, 16) + var walk func(reflect.Value) walk = func(v reflect.Value) { t := v.Type() @@ -516,6 +748,7 @@ func scopeStruct(v reflect.Value, name string) (target, bool, error) { l := len(path) path = append(path, i) f := t.Field(i) + if f.Anonymous { walk(v.Field(i)) } else if f.PkgPath == "" { @@ -545,6 +778,7 @@ func scopeStruct(v reflect.Value, name string) (target, bool, error) { if !ok { path, ok = fieldPaths[strings.ToLower(name)] } + if !ok { return nil, false, nil } diff --git a/targets_test.go b/targets_test.go index 86aab96..7b57fe0 100644 --- a/targets_test.go +++ b/targets_test.go @@ -9,6 +9,8 @@ import ( ) func TestStructTarget_Ensure(t *testing.T) { + t.Parallel() + examples := []struct { desc string input reflect.Value @@ -31,14 +33,23 @@ func TestStructTarget_Ensure(t *testing.T) { test: func(v reflect.Value, err error) { assert.NoError(t, err) require.False(t, v.IsNil()) - s := v.Interface().([]string) + + s, ok := v.Interface().([]string) + if !ok { + t.Errorf("interface %v should be castable into []string", s) + return + } + assert.Equal(t, []string{"foo"}, s) }, }, } for _, e := range examples { + e := e t.Run(e.desc, func(t *testing.T) { + t.Parallel() + d := decoder{} target, _, err := d.scopeTableTarget(false, valueTarget(e.input), e.name) require.NoError(t, err) @@ -50,6 +61,8 @@ func TestStructTarget_Ensure(t *testing.T) { } func TestStructTarget_SetString(t *testing.T) { + t.Parallel() + str := "value" examples := []struct { @@ -86,7 +99,10 @@ func TestStructTarget_SetString(t *testing.T) { } for _, e := range examples { + e := e t.Run(e.desc, func(t *testing.T) { + t.Parallel() + d := decoder{} target, _, err := d.scopeTableTarget(false, valueTarget(e.input), e.name) require.NoError(t, err) @@ -98,7 +114,11 @@ func TestStructTarget_SetString(t *testing.T) { } func TestPushNew(t *testing.T) { + t.Parallel() + t.Run("slice of strings", func(t *testing.T) { + t.Parallel() + type Doc struct { A []string } @@ -120,6 +140,8 @@ func TestPushNew(t *testing.T) { }) t.Run("slice of interfaces", func(t *testing.T) { + t.Parallel() + type Doc struct { A []interface{} } @@ -142,6 +164,8 @@ func TestPushNew(t *testing.T) { } func TestScope_Struct(t *testing.T) { + t.Parallel() + examples := []struct { desc string input reflect.Value @@ -167,7 +191,10 @@ func TestScope_Struct(t *testing.T) { } for _, e := range examples { + e := e t.Run(e.desc, func(t *testing.T) { + t.Parallel() + dec := decoder{} x, found, err := dec.scopeTableTarget(false, valueTarget(e.input), e.name) assert.Equal(t, e.found, found) diff --git a/toml_testgen_support_test.go b/toml_testgen_support_test.go index 5edd25b..e2617e6 100644 --- a/toml_testgen_support_test.go +++ b/toml_testgen_support_test.go @@ -30,6 +30,7 @@ func testgenValid(t *testing.T, input string, jsonRef string) { t.Logf("Input TOML:\n%s", input) doc := map[string]interface{}{} + err := toml.Unmarshal([]byte(input), &doc) if err != nil { t.Fatalf("failed parsing toml: %s", err) @@ -49,25 +50,23 @@ func testgenValid(t *testing.T, input string, jsonRef string) { require.Equal(t, refDoc, doc2) } -type testGenDescNode struct { - Type string - Value interface{} -} - func testgenBuildRefDoc(jsonRef string) map[string]interface{} { descTree := map[string]interface{}{} + err := json.Unmarshal([]byte(jsonRef), &descTree) if err != nil { - panic(fmt.Errorf("reference doc should be valid JSON: %s", err)) + panic(fmt.Sprintf("reference doc should be valid JSON: %s", err)) } doc := testGenTranslateDesc(descTree) if doc == nil { return map[string]interface{}{} } + return doc.(map[string]interface{}) } +//nolint:funlen,gocognit,cyclop func testGenTranslateDesc(input interface{}) interface{} { a, ok := input.([]interface{}) if ok { @@ -75,48 +74,69 @@ func testGenTranslateDesc(input interface{}) interface{} { for i, v := range a { xs[i] = testGenTranslateDesc(v) } + return xs } - d := input.(map[string]interface{}) + d, ok := input.(map[string]interface{}) + if !ok { + panic(fmt.Sprintf("input should be valid map[string]: %v", input)) + } - var dtype string - var dvalue interface{} + var ( + dtype string + dvalue interface{} + ) + //nolint:nestif if len(d) == 2 { dtypeiface, ok := d["type"] if ok { dvalue, ok = d["value"] if ok { - dtype = dtypeiface.(string) + var okdt bool + + dtype, okdt = dtypeiface.(string) + if !okdt { + panic(fmt.Sprintf("dtypeiface should be valid string: %v", dtypeiface)) + } + switch dtype { case "string": return dvalue.(string) case "float": v, err := strconv.ParseFloat(dvalue.(string), 64) if err != nil { - panic(fmt.Errorf("invalid float '%s': %s", dvalue, err)) + panic(fmt.Sprintf("invalid float '%s': %s", dvalue, err)) } + return v case "integer": v, err := strconv.ParseInt(dvalue.(string), 10, 64) if err != nil { - panic(fmt.Errorf("invalid int '%s': %s", dvalue, err)) + panic(fmt.Sprintf("invalid int '%s': %s", dvalue, err)) } + return v case "bool": return dvalue.(string) == "true" case "datetime": dt, err := time.Parse("2006-01-02T15:04:05Z", dvalue.(string)) if err != nil { - panic(fmt.Errorf("invalid datetime '%s': %s", dvalue, err)) + panic(fmt.Sprintf("invalid datetime '%s': %s", dvalue, err)) } + return dt case "array": if dvalue == nil { return nil } - a := dvalue.([]interface{}) + + a, oka := dvalue.([]interface{}) + if !oka { + panic(fmt.Sprintf("a should be valid []interface{}: %v", a)) + } + xs := make([]interface{}, len(a)) for i, v := range a { @@ -125,7 +145,8 @@ func testGenTranslateDesc(input interface{}) interface{} { return xs } - panic(fmt.Errorf("unknown type: %s", dtype)) + + panic(fmt.Sprintf("unknown type: %s", dtype)) } } } @@ -134,5 +155,6 @@ func testGenTranslateDesc(input interface{}) interface{} { for k, v := range d { dest[k] = testGenTranslateDesc(v) } + return dest } diff --git a/toml_testgen_test.go b/toml_testgen_test.go index c850b0a..b0d82cd 100644 --- a/toml_testgen_test.go +++ b/toml_testgen_test.go @@ -6,26 +6,36 @@ import ( ) func TestInvalidDatetimeMalformedNoLeads(t *testing.T) { + t.Parallel() + input := `no-leads = 1987-7-05T17:45:00Z` testgenInvalid(t, input) } func TestInvalidDatetimeMalformedNoSecs(t *testing.T) { + t.Parallel() + input := `no-secs = 1987-07-05T17:45Z` testgenInvalid(t, input) } func TestInvalidDatetimeMalformedNoT(t *testing.T) { + t.Parallel() + input := `no-t = 1987-07-0517:45:00Z` testgenInvalid(t, input) } func TestInvalidDatetimeMalformedWithMilli(t *testing.T) { + t.Parallel() + input := `with-milli = 1987-07-5T17:45:00.12Z` testgenInvalid(t, input) } func TestInvalidDuplicateKeyTable(t *testing.T) { + t.Parallel() + input := `[fruit] type = "apple" @@ -35,71 +45,97 @@ apple = "yes"` } func TestInvalidDuplicateKeys(t *testing.T) { + t.Parallel() + input := `dupe = false dupe = true` testgenInvalid(t, input) } func TestInvalidDuplicateTables(t *testing.T) { + t.Parallel() + input := `[a] [a]` testgenInvalid(t, input) } func TestInvalidEmptyImplicitTable(t *testing.T) { + t.Parallel() + input := `[naughty..naughty]` testgenInvalid(t, input) } func TestInvalidEmptyTable(t *testing.T) { + t.Parallel() + input := `[]` testgenInvalid(t, input) } func TestInvalidFloatNoLeadingZero(t *testing.T) { + t.Parallel() + input := `answer = .12345 neganswer = -.12345` testgenInvalid(t, input) } func TestInvalidFloatNoTrailingDigits(t *testing.T) { + t.Parallel() + input := `answer = 1. neganswer = -1.` testgenInvalid(t, input) } func TestInvalidKeyEmpty(t *testing.T) { + t.Parallel() + input := ` = 1` testgenInvalid(t, input) } func TestInvalidKeyHash(t *testing.T) { + t.Parallel() + input := `a# = 1` testgenInvalid(t, input) } func TestInvalidKeyNewline(t *testing.T) { + t.Parallel() + input := `a = 1` testgenInvalid(t, input) } func TestInvalidKeyOpenBracket(t *testing.T) { + t.Parallel() + input := `[abc = 1` testgenInvalid(t, input) } func TestInvalidKeySingleOpenBracket(t *testing.T) { + t.Parallel() + input := `[` testgenInvalid(t, input) } func TestInvalidKeySpace(t *testing.T) { + t.Parallel() + input := `a b = 1` testgenInvalid(t, input) } func TestInvalidKeyStartBracket(t *testing.T) { + t.Parallel() + input := `[a] [xyz = 5 [b]` @@ -107,31 +143,43 @@ func TestInvalidKeyStartBracket(t *testing.T) { } func TestInvalidKeyTwoEquals(t *testing.T) { + t.Parallel() + input := `key= = 1` testgenInvalid(t, input) } func TestInvalidStringBadByteEscape(t *testing.T) { + t.Parallel() + input := `naughty = "\xAg"` testgenInvalid(t, input) } func TestInvalidStringBadEscape(t *testing.T) { + t.Parallel() + input := `invalid-escape = "This string has a bad \a escape character."` testgenInvalid(t, input) } func TestInvalidStringByteEscapes(t *testing.T) { + t.Parallel() + input := `answer = "\x33"` testgenInvalid(t, input) } func TestInvalidStringNoClose(t *testing.T) { + t.Parallel() + input := `no-ending-quote = "One time, at band camp` testgenInvalid(t, input) } func TestInvalidTableArrayImplicit(t *testing.T) { + t.Parallel() + input := "# This test is a bit tricky. It should fail because the first use of\n" + "# `[[albums.songs]]` without first declaring `albums` implies that `albums`\n" + "# must be a table. The alternative would be quite weird. Namely, it wouldn't\n" + @@ -150,46 +198,62 @@ func TestInvalidTableArrayImplicit(t *testing.T) { } func TestInvalidTableArrayMalformedBracket(t *testing.T) { + t.Parallel() + input := `[[albums] name = "Born to Run"` testgenInvalid(t, input) } func TestInvalidTableArrayMalformedEmpty(t *testing.T) { + t.Parallel() + input := `[[]] name = "Born to Run"` testgenInvalid(t, input) } func TestInvalidTableEmpty(t *testing.T) { + t.Parallel() + input := `[]` testgenInvalid(t, input) } func TestInvalidTableNestedBracketsClose(t *testing.T) { + t.Parallel() + input := `[a]b] zyx = 42` testgenInvalid(t, input) } func TestInvalidTableNestedBracketsOpen(t *testing.T) { + t.Parallel() + input := `[a[b] zyx = 42` testgenInvalid(t, input) } func TestInvalidTableWhitespace(t *testing.T) { + t.Parallel() + input := `[invalid key]` testgenInvalid(t, input) } func TestInvalidTableWithPound(t *testing.T) { + t.Parallel() + input := `[key#group] answer = 42` testgenInvalid(t, input) } func TestInvalidTextAfterArrayEntries(t *testing.T) { + t.Parallel() + input := `array = [ "Is there life after an array separator?", No "Entry" @@ -198,21 +262,29 @@ func TestInvalidTextAfterArrayEntries(t *testing.T) { } func TestInvalidTextAfterInteger(t *testing.T) { + t.Parallel() + input := `answer = 42 the ultimate answer?` testgenInvalid(t, input) } func TestInvalidTextAfterString(t *testing.T) { + t.Parallel() + input := `string = "Is there life after strings?" No.` testgenInvalid(t, input) } func TestInvalidTextAfterTable(t *testing.T) { + t.Parallel() + input := `[error] this shouldn't be here` testgenInvalid(t, input) } func TestInvalidTextBeforeArraySeparator(t *testing.T) { + t.Parallel() + input := `array = [ "Is there life before an array separator?" No, "Entry" @@ -221,6 +293,8 @@ func TestInvalidTextBeforeArraySeparator(t *testing.T) { } func TestInvalidTextInArray(t *testing.T) { + t.Parallel() + input := `array = [ "Entry 1", I don't belong, @@ -230,6 +304,8 @@ func TestInvalidTextInArray(t *testing.T) { } func TestValidArrayEmpty(t *testing.T) { + t.Parallel() + input := `thevoid = [[[[[]]]]]` jsonRef := `{ "thevoid": { "type": "array", "value": [ @@ -246,6 +322,8 @@ func TestValidArrayEmpty(t *testing.T) { } func TestValidArrayNospaces(t *testing.T) { + t.Parallel() + input := `ints = [1,2,3]` jsonRef := `{ "ints": { @@ -261,6 +339,8 @@ func TestValidArrayNospaces(t *testing.T) { } func TestValidArraysHetergeneous(t *testing.T) { + t.Parallel() + input := `mixed = [[1, 2], ["a", "b"], [1.1, 2.1]]` jsonRef := `{ "mixed": { @@ -285,6 +365,8 @@ func TestValidArraysHetergeneous(t *testing.T) { } func TestValidArraysNested(t *testing.T) { + t.Parallel() + input := `nest = [["a"], ["b"]]` jsonRef := `{ "nest": { @@ -303,6 +385,8 @@ func TestValidArraysNested(t *testing.T) { } func TestValidArrays(t *testing.T) { + t.Parallel() + input := `ints = [1, 2, 3] floats = [1.1, 2.1, 3.1] strings = ["a", "b", "c"] @@ -349,6 +433,8 @@ dates = [ } func TestValidBool(t *testing.T) { + t.Parallel() + input := `t = true f = false` jsonRef := `{ @@ -359,6 +445,8 @@ f = false` } func TestValidCommentsEverywhere(t *testing.T) { + t.Parallel() + input := `# Top comment. # Top comment. # Top comment. @@ -368,7 +456,7 @@ func TestValidCommentsEverywhere(t *testing.T) { [group] # Comment answer = 42 # Comment # no-extraneous-keys-please = 999 -# Inbetween comment. +# In between comment. more = [ # Comment # What about multiple # comments? # Can you handle it? @@ -399,6 +487,8 @@ more = [ # Comment } func TestValidDatetime(t *testing.T) { + t.Parallel() + input := `bestdayever = 1987-07-05T17:45:00Z` jsonRef := `{ "bestdayever": {"type": "datetime", "value": "1987-07-05T17:45:00Z"} @@ -407,12 +497,16 @@ func TestValidDatetime(t *testing.T) { } func TestValidEmpty(t *testing.T) { + t.Parallel() + input := `` jsonRef := `{}` testgenValid(t, input, jsonRef) } func TestValidExample(t *testing.T) { + t.Parallel() + input := `best-day-ever = 1987-07-05T17:45:00Z [numtheory] @@ -436,6 +530,8 @@ perfection = [6, 28, 496]` } func TestValidFloat(t *testing.T) { + t.Parallel() + input := `pi = 3.14 negpi = -3.14` jsonRef := `{ @@ -446,6 +542,8 @@ negpi = -3.14` } func TestValidImplicitAndExplicitAfter(t *testing.T) { + t.Parallel() + input := `[a.b.c] answer = 42 @@ -465,6 +563,8 @@ better = 43` } func TestValidImplicitAndExplicitBefore(t *testing.T) { + t.Parallel() + input := `[a] better = 43 @@ -484,6 +584,8 @@ answer = 42` } func TestValidImplicitGroups(t *testing.T) { + t.Parallel() + input := `[a.b.c] answer = 42` jsonRef := `{ @@ -499,6 +601,8 @@ answer = 42` } func TestValidInteger(t *testing.T) { + t.Parallel() + input := `answer = 42 neganswer = -42` jsonRef := `{ @@ -509,6 +613,8 @@ neganswer = -42` } func TestValidKeyEqualsNospace(t *testing.T) { + t.Parallel() + input := `answer=42` jsonRef := `{ "answer": {"type": "integer", "value": "42"} @@ -517,6 +623,8 @@ func TestValidKeyEqualsNospace(t *testing.T) { } func TestValidKeySpace(t *testing.T) { + t.Parallel() + input := `"a b" = 1` jsonRef := `{ "a b": {"type": "integer", "value": "1"} @@ -525,6 +633,8 @@ func TestValidKeySpace(t *testing.T) { } func TestValidKeySpecialChars(t *testing.T) { + t.Parallel() + input := "\"~!@$^&*()_+-`1234567890[]|/?><.,;:'\" = 1\n" jsonRef := "{\n" + " \"~!@$^&*()_+-`1234567890[]|/?><.,;:'\": {\n" + @@ -535,6 +645,8 @@ func TestValidKeySpecialChars(t *testing.T) { } func TestValidLongFloat(t *testing.T) { + t.Parallel() + input := `longpi = 3.141592653589793 neglongpi = -3.141592653589793` jsonRef := `{ @@ -545,6 +657,8 @@ neglongpi = -3.141592653589793` } func TestValidLongInteger(t *testing.T) { + t.Parallel() + input := `answer = 9223372036854775807 neganswer = -9223372036854775808` jsonRef := `{ @@ -555,6 +669,8 @@ neganswer = -9223372036854775808` } func TestValidMultilineString(t *testing.T) { + t.Parallel() + input := `multiline_empty_one = """""" multiline_empty_two = """ """ @@ -612,6 +728,8 @@ equivalent_three = """\ } func TestValidRawMultilineString(t *testing.T) { + t.Parallel() + input := `oneline = '''This string has a ' quote character.''' firstnl = ''' This string has a ' quote character.''' @@ -639,6 +757,8 @@ in it.'''` } func TestValidRawString(t *testing.T) { + t.Parallel() + input := `backspace = 'This string has a \b backspace character.' tab = 'This string has a \t tab character.' newline = 'This string has a \n new line character.' @@ -680,6 +800,8 @@ backslash = 'This string has a \\ backslash character.'` } func TestValidStringEmpty(t *testing.T) { + t.Parallel() + input := `answer = ""` jsonRef := `{ "answer": { @@ -691,6 +813,8 @@ func TestValidStringEmpty(t *testing.T) { } func TestValidStringEscapes(t *testing.T) { + t.Parallel() + input := `backspace = "This string has a \b backspace character." tab = "This string has a \t tab character." newline = "This string has a \n new line character." @@ -752,6 +876,8 @@ notunicode4 = "This string does not have a unicode \\\u0075 escape."` } func TestValidStringSimple(t *testing.T) { + t.Parallel() + input := `answer = "You are not drinking enough whisky."` jsonRef := `{ "answer": { @@ -763,6 +889,8 @@ func TestValidStringSimple(t *testing.T) { } func TestValidStringWithPound(t *testing.T) { + t.Parallel() + input := `pound = "We see no # comments here." poundcomment = "But there are # some comments here." # Did I # mess you up?` jsonRef := `{ @@ -776,6 +904,8 @@ poundcomment = "But there are # some comments here." # Did I # mess you up?` } func TestValidTableArrayImplicit(t *testing.T) { + t.Parallel() + input := `[[albums.songs]] name = "Glory Days"` jsonRef := `{ @@ -789,6 +919,8 @@ name = "Glory Days"` } func TestValidTableArrayMany(t *testing.T) { + t.Parallel() + input := `[[people]] first_name = "Bruce" last_name = "Springsteen" @@ -820,6 +952,8 @@ last_name = "Seger"` } func TestValidTableArrayNest(t *testing.T) { + t.Parallel() + input := `[[albums]] name = "Born to Run" @@ -831,7 +965,7 @@ name = "Born to Run" [[albums]] name = "Born in the USA" - + [[albums.songs]] name = "Glory Days" @@ -859,6 +993,8 @@ name = "Born in the USA" } func TestValidTableArrayOne(t *testing.T) { + t.Parallel() + input := `[[people]] first_name = "Bruce" last_name = "Springsteen"` @@ -874,6 +1010,8 @@ last_name = "Springsteen"` } func TestValidTableEmpty(t *testing.T) { + t.Parallel() + input := `[a]` jsonRef := `{ "a": {} @@ -882,6 +1020,8 @@ func TestValidTableEmpty(t *testing.T) { } func TestValidTableSubEmpty(t *testing.T) { + t.Parallel() + input := `[a] [a.b]` jsonRef := `{ @@ -891,6 +1031,8 @@ func TestValidTableSubEmpty(t *testing.T) { } func TestValidTableWhitespace(t *testing.T) { + t.Parallel() + input := `["valid key"]` jsonRef := `{ "valid key": {} @@ -899,6 +1041,8 @@ func TestValidTableWhitespace(t *testing.T) { } func TestValidTableWithPound(t *testing.T) { + t.Parallel() + input := `["key#group"] answer = 42` jsonRef := `{ @@ -910,6 +1054,8 @@ answer = 42` } func TestValidUnicodeEscape(t *testing.T) { + t.Parallel() + input := `answer4 = "\u03B4" answer8 = "\U000003B4"` jsonRef := `{ @@ -920,6 +1066,8 @@ answer8 = "\U000003B4"` } func TestValidUnicodeLiteral(t *testing.T) { + t.Parallel() + input := `answer = "δ"` jsonRef := `{ "answer": {"type": "string", "value": "δ"} diff --git a/unmarshaler.go b/unmarshaler.go index 5460457..47a9aec 100644 --- a/unmarshaler.go +++ b/unmarshaler.go @@ -81,7 +81,7 @@ type decoder struct { strict strict } -func (d *decoder) arrayIndex(append bool, v reflect.Value) int { +func (d *decoder) arrayIndex(shouldAppend bool, v reflect.Value) int { if d.arrayIndexes == nil { d.arrayIndexes = make(map[reflect.Value]int, 1) } @@ -90,7 +90,7 @@ func (d *decoder) arrayIndex(append bool, v reflect.Value) int { if !ok { d.arrayIndexes[v] = 0 - } else if append { + } else if shouldAppend { idx++ d.arrayIndexes[v] = idx } @@ -173,6 +173,7 @@ func (d *decoder) fromParser(p *parser, v interface{}) error { found = true case ast.Table: d.strict.EnterTable(node) + current, found, err = d.scopeWithKey(root, node.Key()) if err == nil && found { // In case this table points to an interface, @@ -180,7 +181,10 @@ func (d *decoder) fromParser(p *parser, v interface{}) error { // looks like a table. Otherwise the information // of a table is lost, and marshal cannot do the // round trip. - ensureMapIfInterface(current) + err := ensureMapIfInterface(current) + if err != nil { + panic(fmt.Sprintf("ensureMapIfInterface: %s", err)) + } } case ast.ArrayTable: d.strict.EnterArrayTable(node) @@ -305,6 +309,7 @@ func (d *decoder) unmarshalKeyValue(x target, node ast.Node) error { // A struct in the path was not found. Skip this value. if !found { d.strict.MissingField(node) + return nil } @@ -327,11 +332,21 @@ func tryTextUnmarshaler(x target, node ast.Node) (bool, error) { } if v.Type().Implements(textUnmarshalerType) { - return true, v.Interface().(encoding.TextUnmarshaler).UnmarshalText(node.Data) + err := v.Interface().(encoding.TextUnmarshaler).UnmarshalText(node.Data) + if err != nil { + return false, fmt.Errorf("tryTextUnmarshaler: %w", err) + } + + return true, nil } if v.CanAddr() && v.Addr().Type().Implements(textUnmarshalerType) { - return true, v.Addr().Interface().(encoding.TextUnmarshaler).UnmarshalText(node.Data) + err := v.Addr().Interface().(encoding.TextUnmarshaler).UnmarshalText(node.Data) + if err != nil { + return false, fmt.Errorf("tryTextUnmarshaler: %w", err) + } + + return true, nil } return false, nil @@ -345,7 +360,7 @@ func (d *decoder) unmarshalValue(x target, node ast.Node) error { if !v.Elem().IsValid() { err := x.set(reflect.New(v.Type().Elem())) if err != nil { - return err + return fmt.Errorf("unmarshalValue: %w", err) } v = x.get() @@ -423,14 +438,25 @@ func unmarshalDateTime(x target, node ast.Node) error { func setLocalDateTime(x target, v LocalDateTime) error { if x.get().Type() == timeType { cast := v.In(time.Local) + return setDateTime(x, cast) } - return x.set(reflect.ValueOf(v)) + err := x.set(reflect.ValueOf(v)) + if err != nil { + return fmt.Errorf("setLocalDateTime: %w", err) + } + + return nil } func setDateTime(x target, v time.Time) error { - return x.set(reflect.ValueOf(v)) + err := x.set(reflect.ValueOf(v)) + if err != nil { + return fmt.Errorf("setDateTime: %w", err) + } + + return nil } var timeType = reflect.TypeOf(time.Time{}) @@ -438,10 +464,16 @@ var timeType = reflect.TypeOf(time.Time{}) func setDate(x target, v LocalDate) error { if x.get().Type() == timeType { cast := v.In(time.Local) + return setDateTime(x, cast) } - return x.set(reflect.ValueOf(v)) + err := x.set(reflect.ValueOf(v)) + if err != nil { + return fmt.Errorf("setDate: %w", err) + } + + return nil } func unmarshalString(x target, node ast.Node) error { @@ -470,6 +502,7 @@ func unmarshalInteger(x target, node ast.Node) error { func unmarshalFloat(x target, node ast.Node) error { assertNode(ast.Float, node) + v, err := parseFloat(node.Data) if err != nil { return err @@ -481,7 +514,10 @@ func unmarshalFloat(x target, node ast.Node) error { func (d *decoder) unmarshalInlineTable(x target, node ast.Node) error { assertNode(ast.InlineTable, node) - ensureMapIfInterface(x) + err := ensureMapIfInterface(x) + if err != nil { + return fmt.Errorf("unmarshalInlineTable: %w", err) + } it := node.Children() for it.Next() { diff --git a/unmarshaler_test.go b/unmarshaler_test.go index 45f6e58..6caaadb 100644 --- a/unmarshaler_test.go +++ b/unmarshaler_test.go @@ -838,6 +838,7 @@ func (i *Integer484) UnmarshalText(data []byte) error { return fmt.Errorf("UnmarshalText: %w", err) } i.Value = conv + return nil } @@ -846,6 +847,8 @@ type Config484 struct { } func TestIssue484(t *testing.T) { + t.Parallel() + raw := []byte(`integers = ["1","2","3","100"]`) var cfg Config484 @@ -866,6 +869,8 @@ func (m Map458) A(s string) Slice458 { } func TestIssue458(t *testing.T) { + t.Parallel() + s := []byte(`[[package]] dependencies = ["regex"] name = "decode" @@ -885,6 +890,8 @@ version = "0.1.0"`) } func TestIssue252(t *testing.T) { + t.Parallel() + type config struct { Val1 string `toml:"val1"` Val2 string `toml:"val2"` @@ -905,6 +912,8 @@ val1 = "test1" } func TestIssue494(t *testing.T) { + t.Parallel() + data := ` foo = 2021-04-08 bar = 2021-04-08 @@ -920,6 +929,8 @@ bar = 2021-04-08 } func TestIssue507(t *testing.T) { + t.Parallel() + data := []byte{'0', '=', '\n', '0', 'a', 'm', 'e'} m := map[string]interface{}{} err := toml.Unmarshal(data, &m) @@ -1094,6 +1105,8 @@ func TestLocalDateTime(t *testing.T) { } func TestIssue287(t *testing.T) { + t.Parallel() + b := `y=[[{}]]` v := map[string]interface{}{} err := toml.Unmarshal([]byte(b), &v) @@ -1110,6 +1123,8 @@ func TestIssue287(t *testing.T) { } func TestIssue508(t *testing.T) { + t.Parallel() + type head struct { Title string `toml:"title"` }