Decode: fix reuse of slice for array tables (#934)
When decoding into a non-empty slice, it needs to be emptied so that only the tables contained in the document are present in the resulting value. Arrays are not impacted because their unmarshal offset is tracked separately. Fixes #931
This commit is contained in:
+33
-29
@@ -149,8 +149,9 @@ func (s *SeenTracker) setExplicitFlag(parentIdx int) {
|
||||
|
||||
// CheckExpression takes a top-level node and checks that it does not contain
|
||||
// keys that have been seen in previous calls, and validates that types are
|
||||
// consistent.
|
||||
func (s *SeenTracker) CheckExpression(node *unstable.Node) error {
|
||||
// consistent. It returns true if it is the first time this node's key is seen.
|
||||
// Useful to clear array tables on first use.
|
||||
func (s *SeenTracker) CheckExpression(node *unstable.Node) (bool, error) {
|
||||
if s.entries == nil {
|
||||
s.reset()
|
||||
}
|
||||
@@ -166,7 +167,7 @@ func (s *SeenTracker) CheckExpression(node *unstable.Node) error {
|
||||
}
|
||||
}
|
||||
|
||||
func (s *SeenTracker) checkTable(node *unstable.Node) error {
|
||||
func (s *SeenTracker) checkTable(node *unstable.Node) (bool, error) {
|
||||
if s.currentIdx >= 0 {
|
||||
s.setExplicitFlag(s.currentIdx)
|
||||
}
|
||||
@@ -192,7 +193,7 @@ func (s *SeenTracker) checkTable(node *unstable.Node) error {
|
||||
} else {
|
||||
entry := s.entries[idx]
|
||||
if entry.kind == valueKind {
|
||||
return fmt.Errorf("toml: expected %s to be a table, not a %s", string(k), entry.kind)
|
||||
return false, fmt.Errorf("toml: expected %s to be a table, not a %s", string(k), entry.kind)
|
||||
}
|
||||
}
|
||||
parentIdx = idx
|
||||
@@ -201,25 +202,27 @@ func (s *SeenTracker) checkTable(node *unstable.Node) error {
|
||||
k := it.Node().Data
|
||||
idx := s.find(parentIdx, k)
|
||||
|
||||
first := false
|
||||
if idx >= 0 {
|
||||
kind := s.entries[idx].kind
|
||||
if kind != tableKind {
|
||||
return fmt.Errorf("toml: key %s should be a table, not a %s", string(k), kind)
|
||||
return false, fmt.Errorf("toml: key %s should be a table, not a %s", string(k), kind)
|
||||
}
|
||||
if s.entries[idx].explicit {
|
||||
return fmt.Errorf("toml: table %s already exists", string(k))
|
||||
return false, fmt.Errorf("toml: table %s already exists", string(k))
|
||||
}
|
||||
s.entries[idx].explicit = true
|
||||
} else {
|
||||
idx = s.create(parentIdx, k, tableKind, true, false)
|
||||
first = true
|
||||
}
|
||||
|
||||
s.currentIdx = idx
|
||||
|
||||
return nil
|
||||
return first, nil
|
||||
}
|
||||
|
||||
func (s *SeenTracker) checkArrayTable(node *unstable.Node) error {
|
||||
func (s *SeenTracker) checkArrayTable(node *unstable.Node) (bool, error) {
|
||||
if s.currentIdx >= 0 {
|
||||
s.setExplicitFlag(s.currentIdx)
|
||||
}
|
||||
@@ -242,7 +245,7 @@ func (s *SeenTracker) checkArrayTable(node *unstable.Node) error {
|
||||
} else {
|
||||
entry := s.entries[idx]
|
||||
if entry.kind == valueKind {
|
||||
return fmt.Errorf("toml: expected %s to be a table, not a %s", string(k), entry.kind)
|
||||
return false, fmt.Errorf("toml: expected %s to be a table, not a %s", string(k), entry.kind)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -252,22 +255,23 @@ func (s *SeenTracker) checkArrayTable(node *unstable.Node) error {
|
||||
k := it.Node().Data
|
||||
idx := s.find(parentIdx, k)
|
||||
|
||||
if idx >= 0 {
|
||||
firstTime := idx < 0
|
||||
if firstTime {
|
||||
idx = s.create(parentIdx, k, arrayTableKind, true, false)
|
||||
} else {
|
||||
kind := s.entries[idx].kind
|
||||
if kind != arrayTableKind {
|
||||
return fmt.Errorf("toml: key %s already exists as a %s, but should be an array table", kind, string(k))
|
||||
return false, fmt.Errorf("toml: key %s already exists as a %s, but should be an array table", kind, string(k))
|
||||
}
|
||||
s.clear(idx)
|
||||
} else {
|
||||
idx = s.create(parentIdx, k, arrayTableKind, true, false)
|
||||
}
|
||||
|
||||
s.currentIdx = idx
|
||||
|
||||
return nil
|
||||
return firstTime, nil
|
||||
}
|
||||
|
||||
func (s *SeenTracker) checkKeyValue(node *unstable.Node) error {
|
||||
func (s *SeenTracker) checkKeyValue(node *unstable.Node) (bool, error) {
|
||||
parentIdx := s.currentIdx
|
||||
it := node.Key()
|
||||
|
||||
@@ -281,11 +285,11 @@ func (s *SeenTracker) checkKeyValue(node *unstable.Node) error {
|
||||
} else {
|
||||
entry := s.entries[idx]
|
||||
if it.IsLast() {
|
||||
return fmt.Errorf("toml: key %s is already defined", string(k))
|
||||
return false, fmt.Errorf("toml: key %s is already defined", string(k))
|
||||
} else if entry.kind != tableKind {
|
||||
return fmt.Errorf("toml: expected %s to be a table, not a %s", string(k), entry.kind)
|
||||
return false, fmt.Errorf("toml: expected %s to be a table, not a %s", string(k), entry.kind)
|
||||
} else if entry.explicit {
|
||||
return fmt.Errorf("toml: cannot redefine table %s that has already been explicitly defined", string(k))
|
||||
return false, fmt.Errorf("toml: cannot redefine table %s that has already been explicitly defined", string(k))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -303,30 +307,30 @@ func (s *SeenTracker) checkKeyValue(node *unstable.Node) error {
|
||||
return s.checkArray(value)
|
||||
}
|
||||
|
||||
return nil
|
||||
return false, nil
|
||||
}
|
||||
|
||||
func (s *SeenTracker) checkArray(node *unstable.Node) error {
|
||||
func (s *SeenTracker) checkArray(node *unstable.Node) (first bool, err error) {
|
||||
it := node.Children()
|
||||
for it.Next() {
|
||||
n := it.Node()
|
||||
switch n.Kind {
|
||||
case unstable.InlineTable:
|
||||
err := s.checkInlineTable(n)
|
||||
first, err = s.checkInlineTable(n)
|
||||
if err != nil {
|
||||
return err
|
||||
return false, err
|
||||
}
|
||||
case unstable.Array:
|
||||
err := s.checkArray(n)
|
||||
first, err = s.checkArray(n)
|
||||
if err != nil {
|
||||
return err
|
||||
return false, err
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
return first, nil
|
||||
}
|
||||
|
||||
func (s *SeenTracker) checkInlineTable(node *unstable.Node) error {
|
||||
func (s *SeenTracker) checkInlineTable(node *unstable.Node) (first bool, err error) {
|
||||
if pool.New == nil {
|
||||
pool.New = func() interface{} {
|
||||
return &SeenTracker{}
|
||||
@@ -339,9 +343,9 @@ func (s *SeenTracker) checkInlineTable(node *unstable.Node) error {
|
||||
it := node.Children()
|
||||
for it.Next() {
|
||||
n := it.Node()
|
||||
err := s.checkKeyValue(n)
|
||||
first, err = s.checkKeyValue(n)
|
||||
if err != nil {
|
||||
return err
|
||||
return false, err
|
||||
}
|
||||
}
|
||||
|
||||
@@ -352,5 +356,5 @@ func (s *SeenTracker) checkInlineTable(node *unstable.Node) error {
|
||||
// redefinition of its keys: check* functions cannot walk into
|
||||
// a value.
|
||||
pool.Put(s)
|
||||
return nil
|
||||
return first, nil
|
||||
}
|
||||
|
||||
+16
-2
@@ -127,6 +127,10 @@ type decoder struct {
|
||||
// need to be skipped.
|
||||
skipUntilTable bool
|
||||
|
||||
// Flag indicating that the current array/slice table should be cleared because
|
||||
// it is the first encounter of an array table.
|
||||
clearArrayTable bool
|
||||
|
||||
// Tracks position in Go arrays.
|
||||
// This is used when decoding [[array tables]] into Go arrays. Given array
|
||||
// tables are separate TOML expression, we need to keep track of where we
|
||||
@@ -246,9 +250,10 @@ Rules for the unmarshal code:
|
||||
func (d *decoder) handleRootExpression(expr *unstable.Node, v reflect.Value) error {
|
||||
var x reflect.Value
|
||||
var err error
|
||||
var first bool // used for to clear array tables on first use
|
||||
|
||||
if !(d.skipUntilTable && expr.Kind == unstable.KeyValue) {
|
||||
err = d.seen.CheckExpression(expr)
|
||||
first, err = d.seen.CheckExpression(expr)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -267,6 +272,7 @@ func (d *decoder) handleRootExpression(expr *unstable.Node, v reflect.Value) err
|
||||
case unstable.ArrayTable:
|
||||
d.skipUntilTable = false
|
||||
d.strict.EnterArrayTable(expr)
|
||||
d.clearArrayTable = first
|
||||
x, err = d.handleArrayTable(expr.Key(), v)
|
||||
default:
|
||||
panic(fmt.Errorf("parser should not permit expression of kind %s at document root", expr.Kind))
|
||||
@@ -307,6 +313,10 @@ func (d *decoder) handleArrayTableCollectionLast(key unstable.Iterator, v reflec
|
||||
reflect.Copy(nelem, elem)
|
||||
elem = nelem
|
||||
}
|
||||
if d.clearArrayTable && elem.Len() > 0 {
|
||||
elem.SetLen(0)
|
||||
d.clearArrayTable = false
|
||||
}
|
||||
}
|
||||
return d.handleArrayTableCollectionLast(key, elem)
|
||||
case reflect.Ptr:
|
||||
@@ -325,6 +335,10 @@ func (d *decoder) handleArrayTableCollectionLast(key unstable.Iterator, v reflec
|
||||
|
||||
return v, nil
|
||||
case reflect.Slice:
|
||||
if d.clearArrayTable && v.Len() > 0 {
|
||||
v.SetLen(0)
|
||||
d.clearArrayTable = false
|
||||
}
|
||||
elemType := v.Type().Elem()
|
||||
var elem reflect.Value
|
||||
if elemType.Kind() == reflect.Interface {
|
||||
@@ -576,7 +590,7 @@ func (d *decoder) handleKeyValues(v reflect.Value) (reflect.Value, error) {
|
||||
break
|
||||
}
|
||||
|
||||
err := d.seen.CheckExpression(expr)
|
||||
_, err := d.seen.CheckExpression(expr)
|
||||
if err != nil {
|
||||
return reflect.Value{}, err
|
||||
}
|
||||
|
||||
@@ -2823,6 +2823,76 @@ blah.a = "def"`)
|
||||
require.Equal(t, "def", cfg.A)
|
||||
}
|
||||
|
||||
func TestIssue931(t *testing.T) {
|
||||
type item struct {
|
||||
Name string
|
||||
}
|
||||
|
||||
type items struct {
|
||||
Slice []item
|
||||
}
|
||||
|
||||
its := items{[]item{{"a"}, {"b"}}}
|
||||
|
||||
b := []byte(`
|
||||
[[Slice]]
|
||||
Name = 'c'
|
||||
|
||||
[[Slice]]
|
||||
Name = 'd'
|
||||
`)
|
||||
|
||||
toml.Unmarshal(b, &its)
|
||||
require.Equal(t, items{[]item{{"c"}, {"d"}}}, its)
|
||||
}
|
||||
|
||||
func TestIssue931Interface(t *testing.T) {
|
||||
type items struct {
|
||||
Slice interface{}
|
||||
}
|
||||
|
||||
type item = map[string]interface{}
|
||||
|
||||
its := items{[]interface{}{item{"Name": "a"}, item{"Name": "b"}}}
|
||||
|
||||
b := []byte(`
|
||||
[[Slice]]
|
||||
Name = 'c'
|
||||
|
||||
[[Slice]]
|
||||
Name = 'd'
|
||||
`)
|
||||
|
||||
toml.Unmarshal(b, &its)
|
||||
require.Equal(t, items{[]interface{}{item{"Name": "c"}, item{"Name": "d"}}}, its)
|
||||
}
|
||||
|
||||
func TestIssue931SliceInterface(t *testing.T) {
|
||||
type items struct {
|
||||
Slice []interface{}
|
||||
}
|
||||
|
||||
type item = map[string]interface{}
|
||||
|
||||
its := items{
|
||||
[]interface{}{
|
||||
item{"Name": "a"},
|
||||
item{"Name": "b"},
|
||||
},
|
||||
}
|
||||
|
||||
b := []byte(`
|
||||
[[Slice]]
|
||||
Name = 'c'
|
||||
|
||||
[[Slice]]
|
||||
Name = 'd'
|
||||
`)
|
||||
|
||||
toml.Unmarshal(b, &its)
|
||||
require.Equal(t, items{[]interface{}{item{"Name": "c"}, item{"Name": "d"}}}, its)
|
||||
}
|
||||
|
||||
func TestUnmarshalDecodeErrors(t *testing.T) {
|
||||
examples := []struct {
|
||||
desc string
|
||||
|
||||
Reference in New Issue
Block a user