diff --git a/internal/imported_tests/unmarshal_imported_test.go b/internal/imported_tests/unmarshal_imported_test.go index 9c92491..d970dbb 100644 --- a/internal/imported_tests/unmarshal_imported_test.go +++ b/internal/imported_tests/unmarshal_imported_test.go @@ -328,13 +328,8 @@ shouldntBeHere = 2 func TestUnexportedUnmarshal(t *testing.T) { result := unexportedMarshalTestStruct{} err := toml.Unmarshal(unexportedTestToml, &result) - expected := unexportedTestData - if err != nil { - t.Fatal(err) - } - if !reflect.DeepEqual(result, expected) { - t.Errorf("Bad unexported unmarshal: expected %v, got %v", expected, result) - } + require.NoError(t, err) + assert.Equal(t, unexportedTestData, result) } type errStruct struct { diff --git a/targets.go b/targets.go index f7c0b02..216701d 100644 --- a/targets.go +++ b/targets.go @@ -225,7 +225,7 @@ func pushNew(t target) (target, error) { } } -func scopeTableTarget(append bool, t target, name string) (target, error) { +func scopeTableTarget(append bool, t target, name string) (target, bool, error) { x := t.get() switch x.Kind() { @@ -234,19 +234,19 @@ func scopeTableTarget(append bool, t target, name string) (target, error) { case reflect.Interface: t, err := scopeInterface(append, t) if err != nil { - return t, err + return t, false, err } return scopeTableTarget(append, t, name) case reflect.Ptr: t, err := scopePtr(t) if err != nil { - return t, err + return t, false, err } return scopeTableTarget(append, t, name) case reflect.Slice: t, err := scopeSlice(append, t) if err != nil { - return t, err + return t, false, err } append = false return scopeTableTarget(append, t, name) @@ -260,7 +260,6 @@ func scopeTableTarget(append bool, t target, name string) (target, error) { default: panic(fmt.Errorf("can't scope on a %s", x.Kind())) } - return t, nil } func scopeInterface(append bool, t target) (target, error) { @@ -330,7 +329,7 @@ func scopeSlice(append bool, t target) (target, error) { return valueTarget(v.Index(v.Len() - 1)), nil } -func scopeMap(v reflect.Value, name string) (target, error) { +func scopeMap(v reflect.Value, name string) (target, bool, error) { if v.IsNil() { v.Set(reflect.MakeMap(v.Type())) } @@ -344,10 +343,10 @@ func scopeMap(v reflect.Value, name string) (target, error) { return mapTarget{ v: v, k: k, - }, nil + }, true, nil } -func scopeStruct(v reflect.Value, name string) (target, error) { +func scopeStruct(v reflect.Value, name string) (target, bool, error) { // TODO: cache this t := v.Type() for i := 0; i < t.NumField(); i++ { @@ -361,9 +360,9 @@ func scopeStruct(v reflect.Value, name string) (target, error) { } else { // TODO: handle names variations if f.Name == name { - return valueTarget(v.Field(i)), nil + return valueTarget(v.Field(i)), true, nil } } } - return nil, fmt.Errorf("field '%s' not found on %s", name, v.Type()) + return nil, false, nil } diff --git a/targets_test.go b/targets_test.go index 2fd5708..7316994 100644 --- a/targets_test.go +++ b/targets_test.go @@ -39,7 +39,7 @@ func TestStructTarget_Ensure(t *testing.T) { for _, e := range examples { t.Run(e.desc, func(t *testing.T) { - target, err := scopeTableTarget(false, valueTarget(e.input), e.name) + target, _, err := scopeTableTarget(false, valueTarget(e.input), e.name) require.NoError(t, err) err = ensureSlice(target) v := target.get() @@ -86,7 +86,7 @@ func TestStructTarget_SetString(t *testing.T) { for _, e := range examples { t.Run(e.desc, func(t *testing.T) { - target, err := scopeTableTarget(false, valueTarget(e.input), e.name) + target, _, err := scopeTableTarget(false, valueTarget(e.input), e.name) require.NoError(t, err) err = setString(target, str) v := target.get() @@ -102,7 +102,7 @@ func TestPushNew(t *testing.T) { } d := Doc{} - x, err := scopeTableTarget(false, valueTarget(reflect.ValueOf(&d).Elem()), "A") + x, _, err := scopeTableTarget(false, valueTarget(reflect.ValueOf(&d).Elem()), "A") require.NoError(t, err) n, err := pushNew(x) @@ -122,7 +122,7 @@ func TestPushNew(t *testing.T) { } d := Doc{} - x, err := scopeTableTarget(false, valueTarget(reflect.ValueOf(&d).Elem()), "A") + x, _, err := scopeTableTarget(false, valueTarget(reflect.ValueOf(&d).Elem()), "A") require.NoError(t, err) n, err := pushNew(x) @@ -143,6 +143,7 @@ func TestScope_Struct(t *testing.T) { input reflect.Value name string err bool + found bool idx []int }{ { @@ -150,21 +151,25 @@ func TestScope_Struct(t *testing.T) { input: reflect.ValueOf(&struct{ A string }{}).Elem(), name: "A", idx: []int{0}, + found: true, }, { desc: "fails not-exported field", input: reflect.ValueOf(&struct{ a string }{}).Elem(), name: "a", - err: true, + err: false, + found: false, }, } for _, e := range examples { t.Run(e.desc, func(t *testing.T) { - x, err := scopeTableTarget(false, valueTarget(e.input), e.name) + x, found, err := scopeTableTarget(false, valueTarget(e.input), e.name) + assert.Equal(t, e.found, found) if e.err { - require.Error(t, err) - } else { + assert.Error(t, err) + } + if found { x2, ok := x.(valueTarget) require.True(t, ok) x2.get() diff --git a/unmarshaler.go b/unmarshaler.go index 7f1e45b..1fee384 100644 --- a/unmarshaler.go +++ b/unmarshaler.go @@ -26,33 +26,38 @@ func fromAst(tree ast.Root, v interface{}) error { } var err error + var skipUntilTable bool var root target = valueTarget(r.Elem()) current := root for _, node := range tree { - current, err = unmarshalTopLevelNode(root, current, &node) + var found bool + switch node.Kind { + case ast.KeyValue: + if skipUntilTable { + continue + } + err = unmarshalKeyValue(current, &node) + found = true + case ast.Table: + current, found, err = scopeWithKey(root, node.Key()) + case ast.ArrayTable: + current, found, err = scopeWithArrayTable(root, node.Key()) + default: + panic(fmt.Errorf("this should not be a top level node type: %s", node.Kind)) + } + if err != nil { return err } + + if !found { + skipUntilTable = true + } } return nil } -// The target return value is the target for the next top-level node. Mostly -// unchanged, except by table and array table. -func unmarshalTopLevelNode(root target, x target, node *ast.Node) (target, error) { - switch node.Kind { - case ast.KeyValue: - return x, unmarshalKeyValue(x, node) - case ast.Table: - return scopeWithKey(root, node.Key()) - case ast.ArrayTable: - return scopeWithArrayTable(root, node.Key()) - default: - panic(fmt.Errorf("this should not be a top level node type: %s", node.Kind)) - } -} - // scopeWithKey performs target scoping when unmarshaling an ast.KeyValue node. // // The goal is to hop from target to target recursively using the names in key. @@ -61,15 +66,16 @@ func unmarshalTopLevelNode(root target, x target, node *ast.Node) (target, error // // When encountering slices, it should always use its last element, and error // if the slice does not have any. -func scopeWithKey(x target, key []ast.Node) (target, error) { +func scopeWithKey(x target, key []ast.Node) (target, bool, error) { var err error + found := true for _, n := range key { - x, err = scopeTableTarget(false, x, string(n.Data)) - if err != nil { - return nil, err + x, found, err = scopeTableTarget(false, x, string(n.Data)) + if err != nil || !found { + return nil, found, err } } - return x, nil + return x, true, nil } // scopeWithArrayTable performs target scoping when unmarshaling an @@ -77,19 +83,20 @@ func scopeWithKey(x target, key []ast.Node) (target, error) { // // It is the same as scopeWithKey, but when scoping the last part of the key // it creates a new element in the array instead of using the last one. -func scopeWithArrayTable(x target, key []ast.Node) (target, error) { +func scopeWithArrayTable(x target, key []ast.Node) (target, bool, error) { var err error + found := true if len(key) > 1 { for _, n := range key[:len(key)-1] { - x, err = scopeTableTarget(false, x, string(n.Data)) - if err != nil { - return nil, err + x, found, err = scopeTableTarget(false, x, string(n.Data)) + if err != nil || !found { + return nil, found, err } } } - x, err = scopeTableTarget(false, x, string(key[len(key)-1].Data)) - if err != nil { - return x, err + x, found, err = scopeTableTarget(false, x, string(key[len(key)-1].Data)) + if err != nil || !found { + return x, found, err } v := x.get() @@ -97,26 +104,31 @@ func scopeWithArrayTable(x target, key []ast.Node) (target, error) { if v.Kind() == reflect.Interface { x, err = scopeInterface(true, x) if err != nil { - return x, err + return x, found, err } v = x.get() } if v.Kind() == reflect.Slice { - return scopeSlice(true, x) + x, err = scopeSlice(true, x) } - return x, err + return x, found, err } func unmarshalKeyValue(x target, node *ast.Node) error { assertNode(ast.KeyValue, node) - x, err := scopeWithKey(x, node.Key()) + x, found, err := scopeWithKey(x, node.Key()) if err != nil { return err } + // A struct in the path was not found. Skip this value. + if !found { + return nil + } + return unmarshalValue(x, node.Value()) }