From b226db6a2950c2b2d89884d09cbca2d8d3119348 Mon Sep 17 00:00:00 2001 From: Thomas Pelletier Date: Wed, 24 Nov 2021 20:43:56 -0500 Subject: [PATCH] Decoder: show struct field in type mismatch errors (#684) The goal is to provide some context as to why the type were mismatched. This change only works for that case, on structs. This is the same a encoding/json. A more general solution would be great, but this would require a broader change in the decoder, which I don't think is necessary at the moment. Fixes #628 --- unmarshaler.go | 54 ++++++++++++++++++++++++++++++++++++--------- unmarshaler_test.go | 14 ++++++++++++ 2 files changed, 57 insertions(+), 11 deletions(-) diff --git a/unmarshaler.go b/unmarshaler.go index 46b4460..a0237bc 100644 --- a/unmarshaler.go +++ b/unmarshaler.go @@ -131,6 +131,23 @@ type decoder struct { // Strict mode strict strict + + // Current context for the error. + errorContext *errorContext +} + +type errorContext struct { + Struct reflect.Type + Field []int +} + +func (d *decoder) typeMismatchError(toml string, target reflect.Type) error { + if d.errorContext != nil && d.errorContext.Struct != nil { + ctx := d.errorContext + f := ctx.Struct.FieldByIndex(ctx.Field) + return fmt.Errorf("toml: cannot decode TOML %s into struct field %s.%s of type %s", toml, ctx.Struct, f.Name, f.Type) + } + return fmt.Errorf("toml: cannot decode TOML %s into a Go value of type %s", toml, target) } func (d *decoder) expr() *ast.Node { @@ -444,12 +461,20 @@ func (d *decoder) handleKeyPart(key ast.Iterator, v reflect.Value, nextFn handle v.SetMapIndex(mk, mv) } case reflect.Struct: - f, found := structField(v, string(key.Node().Data)) + path, found := structFieldPath(v, string(key.Node().Data)) if !found { d.skipUntilTable = true return reflect.Value{}, nil } + if d.errorContext == nil { + d.errorContext = new(errorContext) + } + t := v.Type() + d.errorContext.Struct = t + d.errorContext.Field = path + + f := v.FieldByIndex(path) x, err := nextFn(key, f) if err != nil || d.skipUntilTable { return reflect.Value{}, err @@ -457,6 +482,8 @@ func (d *decoder) handleKeyPart(key ast.Iterator, v reflect.Value, nextFn handle if x.IsValid() { f.Set(x) } + d.errorContext.Field = nil + d.errorContext.Struct = nil case reflect.Interface: if v.Elem().IsValid() { v = v.Elem() @@ -657,7 +684,7 @@ func (d *decoder) unmarshalArray(array *ast.Node, v reflect.Value) error { default: // TODO: use newDecodeError, but first the parser needs to fill // array.Data. - return fmt.Errorf("toml: cannot store array in Go type %s", v.Kind()) + return d.typeMismatchError("array", v.Type()) } elemType := v.Type().Elem() @@ -904,7 +931,7 @@ func (d *decoder) unmarshalInteger(value *ast.Node, v reflect.Value) error { case reflect.Interface: r = reflect.ValueOf(i) default: - return fmt.Errorf("toml: cannot store TOML integer into a Go %s", v.Kind()) + return d.typeMismatchError("integer", v.Type()) } if !r.Type().AssignableTo(v.Type()) { @@ -1007,12 +1034,20 @@ func (d *decoder) handleKeyValuePart(key ast.Iterator, value *ast.Node, v reflec v.SetMapIndex(mk, mv) } case reflect.Struct: - f, found := structField(v, string(key.Node().Data)) + path, found := structFieldPath(v, string(key.Node().Data)) if !found { d.skipUntilTable = true break } + if d.errorContext == nil { + d.errorContext = new(errorContext) + } + t := v.Type() + d.errorContext.Struct = t + d.errorContext.Field = path + + f := v.FieldByIndex(path) x, err := d.handleKeyValueInner(key, value, f) if err != nil { return reflect.Value{}, err @@ -1021,6 +1056,8 @@ func (d *decoder) handleKeyValuePart(key ast.Iterator, value *ast.Node, v reflec if x.IsValid() { f.Set(x) } + d.errorContext.Struct = nil + d.errorContext.Field = nil case reflect.Interface: v = v.Elem() @@ -1078,7 +1115,7 @@ type fieldPathsMap = map[string][]int var globalFieldPathsCache atomic.Value // map[danger.TypeID]fieldPathsMap -func structField(v reflect.Value, name string) (reflect.Value, bool) { +func structFieldPath(v reflect.Value, name string) ([]int, bool) { t := v.Type() cache, _ := globalFieldPathsCache.Load().(map[danger.TypeID]fieldPathsMap) @@ -1105,12 +1142,7 @@ func structField(v reflect.Value, name string) (reflect.Value, bool) { if !ok { path, ok = fieldPaths[strings.ToLower(name)] } - - if !ok { - return reflect.Value{}, false - } - - return v.FieldByIndex(path), true + return path, ok } func forEachField(t reflect.Type, path []int, do func(name string, path []int)) { diff --git a/unmarshaler_test.go b/unmarshaler_test.go index 8823f8c..b10d703 100644 --- a/unmarshaler_test.go +++ b/unmarshaler_test.go @@ -1790,6 +1790,20 @@ func TestUnmarshalOverflows(t *testing.T) { } } +func TestUnmarshalErrors(t *testing.T) { + type mystruct struct { + Bar string + } + + data := `bar = 42` + + s := mystruct{} + err := toml.Unmarshal([]byte(data), &s) + require.Error(t, err) + + require.Equal(t, "toml: cannot decode TOML integer into struct field toml_test.mystruct.Bar of type string", err.Error()) +} + func TestUnmarshalInvalidTarget(t *testing.T) { x := "foo" err := toml.Unmarshal([]byte{}, x)