Skip AST branches that don't exist in the target
This commit is contained in:
@@ -328,13 +328,8 @@ shouldntBeHere = 2
|
|||||||
func TestUnexportedUnmarshal(t *testing.T) {
|
func TestUnexportedUnmarshal(t *testing.T) {
|
||||||
result := unexportedMarshalTestStruct{}
|
result := unexportedMarshalTestStruct{}
|
||||||
err := toml.Unmarshal(unexportedTestToml, &result)
|
err := toml.Unmarshal(unexportedTestToml, &result)
|
||||||
expected := unexportedTestData
|
require.NoError(t, err)
|
||||||
if err != nil {
|
assert.Equal(t, unexportedTestData, result)
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
if !reflect.DeepEqual(result, expected) {
|
|
||||||
t.Errorf("Bad unexported unmarshal: expected %v, got %v", expected, result)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
type errStruct struct {
|
type errStruct struct {
|
||||||
|
|||||||
+9
-10
@@ -225,7 +225,7 @@ func pushNew(t target) (target, error) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func scopeTableTarget(append bool, t target, name string) (target, error) {
|
func scopeTableTarget(append bool, t target, name string) (target, bool, error) {
|
||||||
x := t.get()
|
x := t.get()
|
||||||
|
|
||||||
switch x.Kind() {
|
switch x.Kind() {
|
||||||
@@ -234,19 +234,19 @@ func scopeTableTarget(append bool, t target, name string) (target, error) {
|
|||||||
case reflect.Interface:
|
case reflect.Interface:
|
||||||
t, err := scopeInterface(append, t)
|
t, err := scopeInterface(append, t)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return t, err
|
return t, false, err
|
||||||
}
|
}
|
||||||
return scopeTableTarget(append, t, name)
|
return scopeTableTarget(append, t, name)
|
||||||
case reflect.Ptr:
|
case reflect.Ptr:
|
||||||
t, err := scopePtr(t)
|
t, err := scopePtr(t)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return t, err
|
return t, false, err
|
||||||
}
|
}
|
||||||
return scopeTableTarget(append, t, name)
|
return scopeTableTarget(append, t, name)
|
||||||
case reflect.Slice:
|
case reflect.Slice:
|
||||||
t, err := scopeSlice(append, t)
|
t, err := scopeSlice(append, t)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return t, err
|
return t, false, err
|
||||||
}
|
}
|
||||||
append = false
|
append = false
|
||||||
return scopeTableTarget(append, t, name)
|
return scopeTableTarget(append, t, name)
|
||||||
@@ -260,7 +260,6 @@ func scopeTableTarget(append bool, t target, name string) (target, error) {
|
|||||||
default:
|
default:
|
||||||
panic(fmt.Errorf("can't scope on a %s", x.Kind()))
|
panic(fmt.Errorf("can't scope on a %s", x.Kind()))
|
||||||
}
|
}
|
||||||
return t, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func scopeInterface(append bool, t target) (target, error) {
|
func scopeInterface(append bool, t target) (target, error) {
|
||||||
@@ -330,7 +329,7 @@ func scopeSlice(append bool, t target) (target, error) {
|
|||||||
return valueTarget(v.Index(v.Len() - 1)), nil
|
return valueTarget(v.Index(v.Len() - 1)), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func scopeMap(v reflect.Value, name string) (target, error) {
|
func scopeMap(v reflect.Value, name string) (target, bool, error) {
|
||||||
if v.IsNil() {
|
if v.IsNil() {
|
||||||
v.Set(reflect.MakeMap(v.Type()))
|
v.Set(reflect.MakeMap(v.Type()))
|
||||||
}
|
}
|
||||||
@@ -344,10 +343,10 @@ func scopeMap(v reflect.Value, name string) (target, error) {
|
|||||||
return mapTarget{
|
return mapTarget{
|
||||||
v: v,
|
v: v,
|
||||||
k: k,
|
k: k,
|
||||||
}, nil
|
}, true, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func scopeStruct(v reflect.Value, name string) (target, error) {
|
func scopeStruct(v reflect.Value, name string) (target, bool, error) {
|
||||||
// TODO: cache this
|
// TODO: cache this
|
||||||
t := v.Type()
|
t := v.Type()
|
||||||
for i := 0; i < t.NumField(); i++ {
|
for i := 0; i < t.NumField(); i++ {
|
||||||
@@ -361,9 +360,9 @@ func scopeStruct(v reflect.Value, name string) (target, error) {
|
|||||||
} else {
|
} else {
|
||||||
// TODO: handle names variations
|
// TODO: handle names variations
|
||||||
if f.Name == name {
|
if f.Name == name {
|
||||||
return valueTarget(v.Field(i)), nil
|
return valueTarget(v.Field(i)), true, nil
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return nil, fmt.Errorf("field '%s' not found on %s", name, v.Type())
|
return nil, false, nil
|
||||||
}
|
}
|
||||||
|
|||||||
+13
-8
@@ -39,7 +39,7 @@ func TestStructTarget_Ensure(t *testing.T) {
|
|||||||
|
|
||||||
for _, e := range examples {
|
for _, e := range examples {
|
||||||
t.Run(e.desc, func(t *testing.T) {
|
t.Run(e.desc, func(t *testing.T) {
|
||||||
target, err := scopeTableTarget(false, valueTarget(e.input), e.name)
|
target, _, err := scopeTableTarget(false, valueTarget(e.input), e.name)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
err = ensureSlice(target)
|
err = ensureSlice(target)
|
||||||
v := target.get()
|
v := target.get()
|
||||||
@@ -86,7 +86,7 @@ func TestStructTarget_SetString(t *testing.T) {
|
|||||||
|
|
||||||
for _, e := range examples {
|
for _, e := range examples {
|
||||||
t.Run(e.desc, func(t *testing.T) {
|
t.Run(e.desc, func(t *testing.T) {
|
||||||
target, err := scopeTableTarget(false, valueTarget(e.input), e.name)
|
target, _, err := scopeTableTarget(false, valueTarget(e.input), e.name)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
err = setString(target, str)
|
err = setString(target, str)
|
||||||
v := target.get()
|
v := target.get()
|
||||||
@@ -102,7 +102,7 @@ func TestPushNew(t *testing.T) {
|
|||||||
}
|
}
|
||||||
d := Doc{}
|
d := Doc{}
|
||||||
|
|
||||||
x, err := scopeTableTarget(false, valueTarget(reflect.ValueOf(&d).Elem()), "A")
|
x, _, err := scopeTableTarget(false, valueTarget(reflect.ValueOf(&d).Elem()), "A")
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
n, err := pushNew(x)
|
n, err := pushNew(x)
|
||||||
@@ -122,7 +122,7 @@ func TestPushNew(t *testing.T) {
|
|||||||
}
|
}
|
||||||
d := Doc{}
|
d := Doc{}
|
||||||
|
|
||||||
x, err := scopeTableTarget(false, valueTarget(reflect.ValueOf(&d).Elem()), "A")
|
x, _, err := scopeTableTarget(false, valueTarget(reflect.ValueOf(&d).Elem()), "A")
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
n, err := pushNew(x)
|
n, err := pushNew(x)
|
||||||
@@ -143,6 +143,7 @@ func TestScope_Struct(t *testing.T) {
|
|||||||
input reflect.Value
|
input reflect.Value
|
||||||
name string
|
name string
|
||||||
err bool
|
err bool
|
||||||
|
found bool
|
||||||
idx []int
|
idx []int
|
||||||
}{
|
}{
|
||||||
{
|
{
|
||||||
@@ -150,21 +151,25 @@ func TestScope_Struct(t *testing.T) {
|
|||||||
input: reflect.ValueOf(&struct{ A string }{}).Elem(),
|
input: reflect.ValueOf(&struct{ A string }{}).Elem(),
|
||||||
name: "A",
|
name: "A",
|
||||||
idx: []int{0},
|
idx: []int{0},
|
||||||
|
found: true,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
desc: "fails not-exported field",
|
desc: "fails not-exported field",
|
||||||
input: reflect.ValueOf(&struct{ a string }{}).Elem(),
|
input: reflect.ValueOf(&struct{ a string }{}).Elem(),
|
||||||
name: "a",
|
name: "a",
|
||||||
err: true,
|
err: false,
|
||||||
|
found: false,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, e := range examples {
|
for _, e := range examples {
|
||||||
t.Run(e.desc, func(t *testing.T) {
|
t.Run(e.desc, func(t *testing.T) {
|
||||||
x, err := scopeTableTarget(false, valueTarget(e.input), e.name)
|
x, found, err := scopeTableTarget(false, valueTarget(e.input), e.name)
|
||||||
|
assert.Equal(t, e.found, found)
|
||||||
if e.err {
|
if e.err {
|
||||||
require.Error(t, err)
|
assert.Error(t, err)
|
||||||
} else {
|
}
|
||||||
|
if found {
|
||||||
x2, ok := x.(valueTarget)
|
x2, ok := x.(valueTarget)
|
||||||
require.True(t, ok)
|
require.True(t, ok)
|
||||||
x2.get()
|
x2.get()
|
||||||
|
|||||||
+44
-32
@@ -26,33 +26,38 @@ func fromAst(tree ast.Root, v interface{}) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
var err error
|
var err error
|
||||||
|
var skipUntilTable bool
|
||||||
var root target = valueTarget(r.Elem())
|
var root target = valueTarget(r.Elem())
|
||||||
current := root
|
current := root
|
||||||
for _, node := range tree {
|
for _, node := range tree {
|
||||||
current, err = unmarshalTopLevelNode(root, current, &node)
|
var found bool
|
||||||
|
switch node.Kind {
|
||||||
|
case ast.KeyValue:
|
||||||
|
if skipUntilTable {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
err = unmarshalKeyValue(current, &node)
|
||||||
|
found = true
|
||||||
|
case ast.Table:
|
||||||
|
current, found, err = scopeWithKey(root, node.Key())
|
||||||
|
case ast.ArrayTable:
|
||||||
|
current, found, err = scopeWithArrayTable(root, node.Key())
|
||||||
|
default:
|
||||||
|
panic(fmt.Errorf("this should not be a top level node type: %s", node.Kind))
|
||||||
|
}
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if !found {
|
||||||
|
skipUntilTable = true
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// The target return value is the target for the next top-level node. Mostly
|
|
||||||
// unchanged, except by table and array table.
|
|
||||||
func unmarshalTopLevelNode(root target, x target, node *ast.Node) (target, error) {
|
|
||||||
switch node.Kind {
|
|
||||||
case ast.KeyValue:
|
|
||||||
return x, unmarshalKeyValue(x, node)
|
|
||||||
case ast.Table:
|
|
||||||
return scopeWithKey(root, node.Key())
|
|
||||||
case ast.ArrayTable:
|
|
||||||
return scopeWithArrayTable(root, node.Key())
|
|
||||||
default:
|
|
||||||
panic(fmt.Errorf("this should not be a top level node type: %s", node.Kind))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// scopeWithKey performs target scoping when unmarshaling an ast.KeyValue node.
|
// scopeWithKey performs target scoping when unmarshaling an ast.KeyValue node.
|
||||||
//
|
//
|
||||||
// The goal is to hop from target to target recursively using the names in key.
|
// The goal is to hop from target to target recursively using the names in key.
|
||||||
@@ -61,15 +66,16 @@ func unmarshalTopLevelNode(root target, x target, node *ast.Node) (target, error
|
|||||||
//
|
//
|
||||||
// When encountering slices, it should always use its last element, and error
|
// When encountering slices, it should always use its last element, and error
|
||||||
// if the slice does not have any.
|
// if the slice does not have any.
|
||||||
func scopeWithKey(x target, key []ast.Node) (target, error) {
|
func scopeWithKey(x target, key []ast.Node) (target, bool, error) {
|
||||||
var err error
|
var err error
|
||||||
|
found := true
|
||||||
for _, n := range key {
|
for _, n := range key {
|
||||||
x, err = scopeTableTarget(false, x, string(n.Data))
|
x, found, err = scopeTableTarget(false, x, string(n.Data))
|
||||||
if err != nil {
|
if err != nil || !found {
|
||||||
return nil, err
|
return nil, found, err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return x, nil
|
return x, true, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// scopeWithArrayTable performs target scoping when unmarshaling an
|
// scopeWithArrayTable performs target scoping when unmarshaling an
|
||||||
@@ -77,19 +83,20 @@ func scopeWithKey(x target, key []ast.Node) (target, error) {
|
|||||||
//
|
//
|
||||||
// It is the same as scopeWithKey, but when scoping the last part of the key
|
// It is the same as scopeWithKey, but when scoping the last part of the key
|
||||||
// it creates a new element in the array instead of using the last one.
|
// it creates a new element in the array instead of using the last one.
|
||||||
func scopeWithArrayTable(x target, key []ast.Node) (target, error) {
|
func scopeWithArrayTable(x target, key []ast.Node) (target, bool, error) {
|
||||||
var err error
|
var err error
|
||||||
|
found := true
|
||||||
if len(key) > 1 {
|
if len(key) > 1 {
|
||||||
for _, n := range key[:len(key)-1] {
|
for _, n := range key[:len(key)-1] {
|
||||||
x, err = scopeTableTarget(false, x, string(n.Data))
|
x, found, err = scopeTableTarget(false, x, string(n.Data))
|
||||||
if err != nil {
|
if err != nil || !found {
|
||||||
return nil, err
|
return nil, found, err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
x, err = scopeTableTarget(false, x, string(key[len(key)-1].Data))
|
x, found, err = scopeTableTarget(false, x, string(key[len(key)-1].Data))
|
||||||
if err != nil {
|
if err != nil || !found {
|
||||||
return x, err
|
return x, found, err
|
||||||
}
|
}
|
||||||
|
|
||||||
v := x.get()
|
v := x.get()
|
||||||
@@ -97,26 +104,31 @@ func scopeWithArrayTable(x target, key []ast.Node) (target, error) {
|
|||||||
if v.Kind() == reflect.Interface {
|
if v.Kind() == reflect.Interface {
|
||||||
x, err = scopeInterface(true, x)
|
x, err = scopeInterface(true, x)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return x, err
|
return x, found, err
|
||||||
}
|
}
|
||||||
v = x.get()
|
v = x.get()
|
||||||
}
|
}
|
||||||
|
|
||||||
if v.Kind() == reflect.Slice {
|
if v.Kind() == reflect.Slice {
|
||||||
return scopeSlice(true, x)
|
x, err = scopeSlice(true, x)
|
||||||
}
|
}
|
||||||
|
|
||||||
return x, err
|
return x, found, err
|
||||||
}
|
}
|
||||||
|
|
||||||
func unmarshalKeyValue(x target, node *ast.Node) error {
|
func unmarshalKeyValue(x target, node *ast.Node) error {
|
||||||
assertNode(ast.KeyValue, node)
|
assertNode(ast.KeyValue, node)
|
||||||
|
|
||||||
x, err := scopeWithKey(x, node.Key())
|
x, found, err := scopeWithKey(x, node.Key())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// A struct in the path was not found. Skip this value.
|
||||||
|
if !found {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
return unmarshalValue(x, node.Value())
|
return unmarshalValue(x, node.Value())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user