Skip AST branches that don't exist in the target

This commit is contained in:
Thomas Pelletier
2021-03-18 20:30:51 -04:00
parent 3e8b8db786
commit 93a7b0d77d
4 changed files with 68 additions and 57 deletions
@@ -328,13 +328,8 @@ shouldntBeHere = 2
func TestUnexportedUnmarshal(t *testing.T) { func TestUnexportedUnmarshal(t *testing.T) {
result := unexportedMarshalTestStruct{} result := unexportedMarshalTestStruct{}
err := toml.Unmarshal(unexportedTestToml, &result) err := toml.Unmarshal(unexportedTestToml, &result)
expected := unexportedTestData require.NoError(t, err)
if err != nil { assert.Equal(t, unexportedTestData, result)
t.Fatal(err)
}
if !reflect.DeepEqual(result, expected) {
t.Errorf("Bad unexported unmarshal: expected %v, got %v", expected, result)
}
} }
type errStruct struct { type errStruct struct {
+9 -10
View File
@@ -225,7 +225,7 @@ func pushNew(t target) (target, error) {
} }
} }
func scopeTableTarget(append bool, t target, name string) (target, error) { func scopeTableTarget(append bool, t target, name string) (target, bool, error) {
x := t.get() x := t.get()
switch x.Kind() { switch x.Kind() {
@@ -234,19 +234,19 @@ func scopeTableTarget(append bool, t target, name string) (target, error) {
case reflect.Interface: case reflect.Interface:
t, err := scopeInterface(append, t) t, err := scopeInterface(append, t)
if err != nil { if err != nil {
return t, err return t, false, err
} }
return scopeTableTarget(append, t, name) return scopeTableTarget(append, t, name)
case reflect.Ptr: case reflect.Ptr:
t, err := scopePtr(t) t, err := scopePtr(t)
if err != nil { if err != nil {
return t, err return t, false, err
} }
return scopeTableTarget(append, t, name) return scopeTableTarget(append, t, name)
case reflect.Slice: case reflect.Slice:
t, err := scopeSlice(append, t) t, err := scopeSlice(append, t)
if err != nil { if err != nil {
return t, err return t, false, err
} }
append = false append = false
return scopeTableTarget(append, t, name) return scopeTableTarget(append, t, name)
@@ -260,7 +260,6 @@ func scopeTableTarget(append bool, t target, name string) (target, error) {
default: default:
panic(fmt.Errorf("can't scope on a %s", x.Kind())) panic(fmt.Errorf("can't scope on a %s", x.Kind()))
} }
return t, nil
} }
func scopeInterface(append bool, t target) (target, error) { func scopeInterface(append bool, t target) (target, error) {
@@ -330,7 +329,7 @@ func scopeSlice(append bool, t target) (target, error) {
return valueTarget(v.Index(v.Len() - 1)), nil return valueTarget(v.Index(v.Len() - 1)), nil
} }
func scopeMap(v reflect.Value, name string) (target, error) { func scopeMap(v reflect.Value, name string) (target, bool, error) {
if v.IsNil() { if v.IsNil() {
v.Set(reflect.MakeMap(v.Type())) v.Set(reflect.MakeMap(v.Type()))
} }
@@ -344,10 +343,10 @@ func scopeMap(v reflect.Value, name string) (target, error) {
return mapTarget{ return mapTarget{
v: v, v: v,
k: k, k: k,
}, nil }, true, nil
} }
func scopeStruct(v reflect.Value, name string) (target, error) { func scopeStruct(v reflect.Value, name string) (target, bool, error) {
// TODO: cache this // TODO: cache this
t := v.Type() t := v.Type()
for i := 0; i < t.NumField(); i++ { for i := 0; i < t.NumField(); i++ {
@@ -361,9 +360,9 @@ func scopeStruct(v reflect.Value, name string) (target, error) {
} else { } else {
// TODO: handle names variations // TODO: handle names variations
if f.Name == name { if f.Name == name {
return valueTarget(v.Field(i)), nil return valueTarget(v.Field(i)), true, nil
} }
} }
} }
return nil, fmt.Errorf("field '%s' not found on %s", name, v.Type()) return nil, false, nil
} }
+13 -8
View File
@@ -39,7 +39,7 @@ func TestStructTarget_Ensure(t *testing.T) {
for _, e := range examples { for _, e := range examples {
t.Run(e.desc, func(t *testing.T) { t.Run(e.desc, func(t *testing.T) {
target, err := scopeTableTarget(false, valueTarget(e.input), e.name) target, _, err := scopeTableTarget(false, valueTarget(e.input), e.name)
require.NoError(t, err) require.NoError(t, err)
err = ensureSlice(target) err = ensureSlice(target)
v := target.get() v := target.get()
@@ -86,7 +86,7 @@ func TestStructTarget_SetString(t *testing.T) {
for _, e := range examples { for _, e := range examples {
t.Run(e.desc, func(t *testing.T) { t.Run(e.desc, func(t *testing.T) {
target, err := scopeTableTarget(false, valueTarget(e.input), e.name) target, _, err := scopeTableTarget(false, valueTarget(e.input), e.name)
require.NoError(t, err) require.NoError(t, err)
err = setString(target, str) err = setString(target, str)
v := target.get() v := target.get()
@@ -102,7 +102,7 @@ func TestPushNew(t *testing.T) {
} }
d := Doc{} d := Doc{}
x, err := scopeTableTarget(false, valueTarget(reflect.ValueOf(&d).Elem()), "A") x, _, err := scopeTableTarget(false, valueTarget(reflect.ValueOf(&d).Elem()), "A")
require.NoError(t, err) require.NoError(t, err)
n, err := pushNew(x) n, err := pushNew(x)
@@ -122,7 +122,7 @@ func TestPushNew(t *testing.T) {
} }
d := Doc{} d := Doc{}
x, err := scopeTableTarget(false, valueTarget(reflect.ValueOf(&d).Elem()), "A") x, _, err := scopeTableTarget(false, valueTarget(reflect.ValueOf(&d).Elem()), "A")
require.NoError(t, err) require.NoError(t, err)
n, err := pushNew(x) n, err := pushNew(x)
@@ -143,6 +143,7 @@ func TestScope_Struct(t *testing.T) {
input reflect.Value input reflect.Value
name string name string
err bool err bool
found bool
idx []int idx []int
}{ }{
{ {
@@ -150,21 +151,25 @@ func TestScope_Struct(t *testing.T) {
input: reflect.ValueOf(&struct{ A string }{}).Elem(), input: reflect.ValueOf(&struct{ A string }{}).Elem(),
name: "A", name: "A",
idx: []int{0}, idx: []int{0},
found: true,
}, },
{ {
desc: "fails not-exported field", desc: "fails not-exported field",
input: reflect.ValueOf(&struct{ a string }{}).Elem(), input: reflect.ValueOf(&struct{ a string }{}).Elem(),
name: "a", name: "a",
err: true, err: false,
found: false,
}, },
} }
for _, e := range examples { for _, e := range examples {
t.Run(e.desc, func(t *testing.T) { t.Run(e.desc, func(t *testing.T) {
x, err := scopeTableTarget(false, valueTarget(e.input), e.name) x, found, err := scopeTableTarget(false, valueTarget(e.input), e.name)
assert.Equal(t, e.found, found)
if e.err { if e.err {
require.Error(t, err) assert.Error(t, err)
} else { }
if found {
x2, ok := x.(valueTarget) x2, ok := x.(valueTarget)
require.True(t, ok) require.True(t, ok)
x2.get() x2.get()
+44 -32
View File
@@ -26,33 +26,38 @@ func fromAst(tree ast.Root, v interface{}) error {
} }
var err error var err error
var skipUntilTable bool
var root target = valueTarget(r.Elem()) var root target = valueTarget(r.Elem())
current := root current := root
for _, node := range tree { for _, node := range tree {
current, err = unmarshalTopLevelNode(root, current, &node) var found bool
switch node.Kind {
case ast.KeyValue:
if skipUntilTable {
continue
}
err = unmarshalKeyValue(current, &node)
found = true
case ast.Table:
current, found, err = scopeWithKey(root, node.Key())
case ast.ArrayTable:
current, found, err = scopeWithArrayTable(root, node.Key())
default:
panic(fmt.Errorf("this should not be a top level node type: %s", node.Kind))
}
if err != nil { if err != nil {
return err return err
} }
if !found {
skipUntilTable = true
}
} }
return nil return nil
} }
// The target return value is the target for the next top-level node. Mostly
// unchanged, except by table and array table.
func unmarshalTopLevelNode(root target, x target, node *ast.Node) (target, error) {
switch node.Kind {
case ast.KeyValue:
return x, unmarshalKeyValue(x, node)
case ast.Table:
return scopeWithKey(root, node.Key())
case ast.ArrayTable:
return scopeWithArrayTable(root, node.Key())
default:
panic(fmt.Errorf("this should not be a top level node type: %s", node.Kind))
}
}
// scopeWithKey performs target scoping when unmarshaling an ast.KeyValue node. // scopeWithKey performs target scoping when unmarshaling an ast.KeyValue node.
// //
// The goal is to hop from target to target recursively using the names in key. // The goal is to hop from target to target recursively using the names in key.
@@ -61,15 +66,16 @@ func unmarshalTopLevelNode(root target, x target, node *ast.Node) (target, error
// //
// When encountering slices, it should always use its last element, and error // When encountering slices, it should always use its last element, and error
// if the slice does not have any. // if the slice does not have any.
func scopeWithKey(x target, key []ast.Node) (target, error) { func scopeWithKey(x target, key []ast.Node) (target, bool, error) {
var err error var err error
found := true
for _, n := range key { for _, n := range key {
x, err = scopeTableTarget(false, x, string(n.Data)) x, found, err = scopeTableTarget(false, x, string(n.Data))
if err != nil { if err != nil || !found {
return nil, err return nil, found, err
} }
} }
return x, nil return x, true, nil
} }
// scopeWithArrayTable performs target scoping when unmarshaling an // scopeWithArrayTable performs target scoping when unmarshaling an
@@ -77,19 +83,20 @@ func scopeWithKey(x target, key []ast.Node) (target, error) {
// //
// It is the same as scopeWithKey, but when scoping the last part of the key // It is the same as scopeWithKey, but when scoping the last part of the key
// it creates a new element in the array instead of using the last one. // it creates a new element in the array instead of using the last one.
func scopeWithArrayTable(x target, key []ast.Node) (target, error) { func scopeWithArrayTable(x target, key []ast.Node) (target, bool, error) {
var err error var err error
found := true
if len(key) > 1 { if len(key) > 1 {
for _, n := range key[:len(key)-1] { for _, n := range key[:len(key)-1] {
x, err = scopeTableTarget(false, x, string(n.Data)) x, found, err = scopeTableTarget(false, x, string(n.Data))
if err != nil { if err != nil || !found {
return nil, err return nil, found, err
} }
} }
} }
x, err = scopeTableTarget(false, x, string(key[len(key)-1].Data)) x, found, err = scopeTableTarget(false, x, string(key[len(key)-1].Data))
if err != nil { if err != nil || !found {
return x, err return x, found, err
} }
v := x.get() v := x.get()
@@ -97,26 +104,31 @@ func scopeWithArrayTable(x target, key []ast.Node) (target, error) {
if v.Kind() == reflect.Interface { if v.Kind() == reflect.Interface {
x, err = scopeInterface(true, x) x, err = scopeInterface(true, x)
if err != nil { if err != nil {
return x, err return x, found, err
} }
v = x.get() v = x.get()
} }
if v.Kind() == reflect.Slice { if v.Kind() == reflect.Slice {
return scopeSlice(true, x) x, err = scopeSlice(true, x)
} }
return x, err return x, found, err
} }
func unmarshalKeyValue(x target, node *ast.Node) error { func unmarshalKeyValue(x target, node *ast.Node) error {
assertNode(ast.KeyValue, node) assertNode(ast.KeyValue, node)
x, err := scopeWithKey(x, node.Key()) x, found, err := scopeWithKey(x, node.Key())
if err != nil { if err != nil {
return err return err
} }
// A struct in the path was not found. Skip this value.
if !found {
return nil
}
return unmarshalValue(x, node.Value()) return unmarshalValue(x, node.Value())
} }