Decode: fix reuse of slice for array tables (#934)

When decoding into a non-empty slice, it needs to be emptied so that only the
tables contained in the document are present in the resulting value.

Arrays are not impacted because their unmarshal offset is tracked separately.

Fixes #931
This commit is contained in:
Thomas Pelletier
2024-02-27 15:28:49 -05:00
committed by GitHub
parent 2e087bdf5f
commit 06fb30bf2e
3 changed files with 119 additions and 31 deletions
+33 -29
View File
@@ -149,8 +149,9 @@ func (s *SeenTracker) setExplicitFlag(parentIdx int) {
// CheckExpression takes a top-level node and checks that it does not contain // CheckExpression takes a top-level node and checks that it does not contain
// keys that have been seen in previous calls, and validates that types are // keys that have been seen in previous calls, and validates that types are
// consistent. // consistent. It returns true if it is the first time this node's key is seen.
func (s *SeenTracker) CheckExpression(node *unstable.Node) error { // Useful to clear array tables on first use.
func (s *SeenTracker) CheckExpression(node *unstable.Node) (bool, error) {
if s.entries == nil { if s.entries == nil {
s.reset() s.reset()
} }
@@ -166,7 +167,7 @@ func (s *SeenTracker) CheckExpression(node *unstable.Node) error {
} }
} }
func (s *SeenTracker) checkTable(node *unstable.Node) error { func (s *SeenTracker) checkTable(node *unstable.Node) (bool, error) {
if s.currentIdx >= 0 { if s.currentIdx >= 0 {
s.setExplicitFlag(s.currentIdx) s.setExplicitFlag(s.currentIdx)
} }
@@ -192,7 +193,7 @@ func (s *SeenTracker) checkTable(node *unstable.Node) error {
} else { } else {
entry := s.entries[idx] entry := s.entries[idx]
if entry.kind == valueKind { if entry.kind == valueKind {
return fmt.Errorf("toml: expected %s to be a table, not a %s", string(k), entry.kind) return false, fmt.Errorf("toml: expected %s to be a table, not a %s", string(k), entry.kind)
} }
} }
parentIdx = idx parentIdx = idx
@@ -201,25 +202,27 @@ func (s *SeenTracker) checkTable(node *unstable.Node) error {
k := it.Node().Data k := it.Node().Data
idx := s.find(parentIdx, k) idx := s.find(parentIdx, k)
first := false
if idx >= 0 { if idx >= 0 {
kind := s.entries[idx].kind kind := s.entries[idx].kind
if kind != tableKind { if kind != tableKind {
return fmt.Errorf("toml: key %s should be a table, not a %s", string(k), kind) return false, fmt.Errorf("toml: key %s should be a table, not a %s", string(k), kind)
} }
if s.entries[idx].explicit { if s.entries[idx].explicit {
return fmt.Errorf("toml: table %s already exists", string(k)) return false, fmt.Errorf("toml: table %s already exists", string(k))
} }
s.entries[idx].explicit = true s.entries[idx].explicit = true
} else { } else {
idx = s.create(parentIdx, k, tableKind, true, false) idx = s.create(parentIdx, k, tableKind, true, false)
first = true
} }
s.currentIdx = idx s.currentIdx = idx
return nil return first, nil
} }
func (s *SeenTracker) checkArrayTable(node *unstable.Node) error { func (s *SeenTracker) checkArrayTable(node *unstable.Node) (bool, error) {
if s.currentIdx >= 0 { if s.currentIdx >= 0 {
s.setExplicitFlag(s.currentIdx) s.setExplicitFlag(s.currentIdx)
} }
@@ -242,7 +245,7 @@ func (s *SeenTracker) checkArrayTable(node *unstable.Node) error {
} else { } else {
entry := s.entries[idx] entry := s.entries[idx]
if entry.kind == valueKind { if entry.kind == valueKind {
return fmt.Errorf("toml: expected %s to be a table, not a %s", string(k), entry.kind) return false, fmt.Errorf("toml: expected %s to be a table, not a %s", string(k), entry.kind)
} }
} }
@@ -252,22 +255,23 @@ func (s *SeenTracker) checkArrayTable(node *unstable.Node) error {
k := it.Node().Data k := it.Node().Data
idx := s.find(parentIdx, k) idx := s.find(parentIdx, k)
if idx >= 0 { firstTime := idx < 0
if firstTime {
idx = s.create(parentIdx, k, arrayTableKind, true, false)
} else {
kind := s.entries[idx].kind kind := s.entries[idx].kind
if kind != arrayTableKind { if kind != arrayTableKind {
return fmt.Errorf("toml: key %s already exists as a %s, but should be an array table", kind, string(k)) return false, fmt.Errorf("toml: key %s already exists as a %s, but should be an array table", kind, string(k))
} }
s.clear(idx) s.clear(idx)
} else {
idx = s.create(parentIdx, k, arrayTableKind, true, false)
} }
s.currentIdx = idx s.currentIdx = idx
return nil return firstTime, nil
} }
func (s *SeenTracker) checkKeyValue(node *unstable.Node) error { func (s *SeenTracker) checkKeyValue(node *unstable.Node) (bool, error) {
parentIdx := s.currentIdx parentIdx := s.currentIdx
it := node.Key() it := node.Key()
@@ -281,11 +285,11 @@ func (s *SeenTracker) checkKeyValue(node *unstable.Node) error {
} else { } else {
entry := s.entries[idx] entry := s.entries[idx]
if it.IsLast() { if it.IsLast() {
return fmt.Errorf("toml: key %s is already defined", string(k)) return false, fmt.Errorf("toml: key %s is already defined", string(k))
} else if entry.kind != tableKind { } else if entry.kind != tableKind {
return fmt.Errorf("toml: expected %s to be a table, not a %s", string(k), entry.kind) return false, fmt.Errorf("toml: expected %s to be a table, not a %s", string(k), entry.kind)
} else if entry.explicit { } else if entry.explicit {
return fmt.Errorf("toml: cannot redefine table %s that has already been explicitly defined", string(k)) return false, fmt.Errorf("toml: cannot redefine table %s that has already been explicitly defined", string(k))
} }
} }
@@ -303,30 +307,30 @@ func (s *SeenTracker) checkKeyValue(node *unstable.Node) error {
return s.checkArray(value) return s.checkArray(value)
} }
return nil return false, nil
} }
func (s *SeenTracker) checkArray(node *unstable.Node) error { func (s *SeenTracker) checkArray(node *unstable.Node) (first bool, err error) {
it := node.Children() it := node.Children()
for it.Next() { for it.Next() {
n := it.Node() n := it.Node()
switch n.Kind { switch n.Kind {
case unstable.InlineTable: case unstable.InlineTable:
err := s.checkInlineTable(n) first, err = s.checkInlineTable(n)
if err != nil { if err != nil {
return err return false, err
} }
case unstable.Array: case unstable.Array:
err := s.checkArray(n) first, err = s.checkArray(n)
if err != nil { if err != nil {
return err return false, err
} }
} }
} }
return nil return first, nil
} }
func (s *SeenTracker) checkInlineTable(node *unstable.Node) error { func (s *SeenTracker) checkInlineTable(node *unstable.Node) (first bool, err error) {
if pool.New == nil { if pool.New == nil {
pool.New = func() interface{} { pool.New = func() interface{} {
return &SeenTracker{} return &SeenTracker{}
@@ -339,9 +343,9 @@ func (s *SeenTracker) checkInlineTable(node *unstable.Node) error {
it := node.Children() it := node.Children()
for it.Next() { for it.Next() {
n := it.Node() n := it.Node()
err := s.checkKeyValue(n) first, err = s.checkKeyValue(n)
if err != nil { if err != nil {
return err return false, err
} }
} }
@@ -352,5 +356,5 @@ func (s *SeenTracker) checkInlineTable(node *unstable.Node) error {
// redefinition of its keys: check* functions cannot walk into // redefinition of its keys: check* functions cannot walk into
// a value. // a value.
pool.Put(s) pool.Put(s)
return nil return first, nil
} }
+16 -2
View File
@@ -127,6 +127,10 @@ type decoder struct {
// need to be skipped. // need to be skipped.
skipUntilTable bool skipUntilTable bool
// Flag indicating that the current array/slice table should be cleared because
// it is the first encounter of an array table.
clearArrayTable bool
// Tracks position in Go arrays. // Tracks position in Go arrays.
// This is used when decoding [[array tables]] into Go arrays. Given array // This is used when decoding [[array tables]] into Go arrays. Given array
// tables are separate TOML expression, we need to keep track of where we // tables are separate TOML expression, we need to keep track of where we
@@ -246,9 +250,10 @@ Rules for the unmarshal code:
func (d *decoder) handleRootExpression(expr *unstable.Node, v reflect.Value) error { func (d *decoder) handleRootExpression(expr *unstable.Node, v reflect.Value) error {
var x reflect.Value var x reflect.Value
var err error var err error
var first bool // used for to clear array tables on first use
if !(d.skipUntilTable && expr.Kind == unstable.KeyValue) { if !(d.skipUntilTable && expr.Kind == unstable.KeyValue) {
err = d.seen.CheckExpression(expr) first, err = d.seen.CheckExpression(expr)
if err != nil { if err != nil {
return err return err
} }
@@ -267,6 +272,7 @@ func (d *decoder) handleRootExpression(expr *unstable.Node, v reflect.Value) err
case unstable.ArrayTable: case unstable.ArrayTable:
d.skipUntilTable = false d.skipUntilTable = false
d.strict.EnterArrayTable(expr) d.strict.EnterArrayTable(expr)
d.clearArrayTable = first
x, err = d.handleArrayTable(expr.Key(), v) x, err = d.handleArrayTable(expr.Key(), v)
default: default:
panic(fmt.Errorf("parser should not permit expression of kind %s at document root", expr.Kind)) panic(fmt.Errorf("parser should not permit expression of kind %s at document root", expr.Kind))
@@ -307,6 +313,10 @@ func (d *decoder) handleArrayTableCollectionLast(key unstable.Iterator, v reflec
reflect.Copy(nelem, elem) reflect.Copy(nelem, elem)
elem = nelem elem = nelem
} }
if d.clearArrayTable && elem.Len() > 0 {
elem.SetLen(0)
d.clearArrayTable = false
}
} }
return d.handleArrayTableCollectionLast(key, elem) return d.handleArrayTableCollectionLast(key, elem)
case reflect.Ptr: case reflect.Ptr:
@@ -325,6 +335,10 @@ func (d *decoder) handleArrayTableCollectionLast(key unstable.Iterator, v reflec
return v, nil return v, nil
case reflect.Slice: case reflect.Slice:
if d.clearArrayTable && v.Len() > 0 {
v.SetLen(0)
d.clearArrayTable = false
}
elemType := v.Type().Elem() elemType := v.Type().Elem()
var elem reflect.Value var elem reflect.Value
if elemType.Kind() == reflect.Interface { if elemType.Kind() == reflect.Interface {
@@ -576,7 +590,7 @@ func (d *decoder) handleKeyValues(v reflect.Value) (reflect.Value, error) {
break break
} }
err := d.seen.CheckExpression(expr) _, err := d.seen.CheckExpression(expr)
if err != nil { if err != nil {
return reflect.Value{}, err return reflect.Value{}, err
} }
+70
View File
@@ -2823,6 +2823,76 @@ blah.a = "def"`)
require.Equal(t, "def", cfg.A) require.Equal(t, "def", cfg.A)
} }
func TestIssue931(t *testing.T) {
type item struct {
Name string
}
type items struct {
Slice []item
}
its := items{[]item{{"a"}, {"b"}}}
b := []byte(`
[[Slice]]
Name = 'c'
[[Slice]]
Name = 'd'
`)
toml.Unmarshal(b, &its)
require.Equal(t, items{[]item{{"c"}, {"d"}}}, its)
}
func TestIssue931Interface(t *testing.T) {
type items struct {
Slice interface{}
}
type item = map[string]interface{}
its := items{[]interface{}{item{"Name": "a"}, item{"Name": "b"}}}
b := []byte(`
[[Slice]]
Name = 'c'
[[Slice]]
Name = 'd'
`)
toml.Unmarshal(b, &its)
require.Equal(t, items{[]interface{}{item{"Name": "c"}, item{"Name": "d"}}}, its)
}
func TestIssue931SliceInterface(t *testing.T) {
type items struct {
Slice []interface{}
}
type item = map[string]interface{}
its := items{
[]interface{}{
item{"Name": "a"},
item{"Name": "b"},
},
}
b := []byte(`
[[Slice]]
Name = 'c'
[[Slice]]
Name = 'd'
`)
toml.Unmarshal(b, &its)
require.Equal(t, items{[]interface{}{item{"Name": "c"}, item{"Name": "d"}}}, its)
}
func TestUnmarshalDecodeErrors(t *testing.T) { func TestUnmarshalDecodeErrors(t *testing.T) {
examples := []struct { examples := []struct {
desc string desc string