diff --git a/internal/tracker/seen.go b/internal/tracker/seen.go index 40e23f8..ce7dd4a 100644 --- a/internal/tracker/seen.go +++ b/internal/tracker/seen.go @@ -149,8 +149,9 @@ func (s *SeenTracker) setExplicitFlag(parentIdx int) { // CheckExpression takes a top-level node and checks that it does not contain // keys that have been seen in previous calls, and validates that types are -// consistent. -func (s *SeenTracker) CheckExpression(node *unstable.Node) error { +// consistent. It returns true if it is the first time this node's key is seen. +// Useful to clear array tables on first use. +func (s *SeenTracker) CheckExpression(node *unstable.Node) (bool, error) { if s.entries == nil { s.reset() } @@ -166,7 +167,7 @@ func (s *SeenTracker) CheckExpression(node *unstable.Node) error { } } -func (s *SeenTracker) checkTable(node *unstable.Node) error { +func (s *SeenTracker) checkTable(node *unstable.Node) (bool, error) { if s.currentIdx >= 0 { s.setExplicitFlag(s.currentIdx) } @@ -192,7 +193,7 @@ func (s *SeenTracker) checkTable(node *unstable.Node) error { } else { entry := s.entries[idx] if entry.kind == valueKind { - return fmt.Errorf("toml: expected %s to be a table, not a %s", string(k), entry.kind) + return false, fmt.Errorf("toml: expected %s to be a table, not a %s", string(k), entry.kind) } } parentIdx = idx @@ -201,25 +202,27 @@ func (s *SeenTracker) checkTable(node *unstable.Node) error { k := it.Node().Data idx := s.find(parentIdx, k) + first := false if idx >= 0 { kind := s.entries[idx].kind if kind != tableKind { - return fmt.Errorf("toml: key %s should be a table, not a %s", string(k), kind) + return false, fmt.Errorf("toml: key %s should be a table, not a %s", string(k), kind) } if s.entries[idx].explicit { - return fmt.Errorf("toml: table %s already exists", string(k)) + return false, fmt.Errorf("toml: table %s already exists", string(k)) } s.entries[idx].explicit = true } else { idx = s.create(parentIdx, k, tableKind, true, false) + first = true } s.currentIdx = idx - return nil + return first, nil } -func (s *SeenTracker) checkArrayTable(node *unstable.Node) error { +func (s *SeenTracker) checkArrayTable(node *unstable.Node) (bool, error) { if s.currentIdx >= 0 { s.setExplicitFlag(s.currentIdx) } @@ -242,7 +245,7 @@ func (s *SeenTracker) checkArrayTable(node *unstable.Node) error { } else { entry := s.entries[idx] if entry.kind == valueKind { - return fmt.Errorf("toml: expected %s to be a table, not a %s", string(k), entry.kind) + return false, fmt.Errorf("toml: expected %s to be a table, not a %s", string(k), entry.kind) } } @@ -252,22 +255,23 @@ func (s *SeenTracker) checkArrayTable(node *unstable.Node) error { k := it.Node().Data idx := s.find(parentIdx, k) - if idx >= 0 { + firstTime := idx < 0 + if firstTime { + idx = s.create(parentIdx, k, arrayTableKind, true, false) + } else { kind := s.entries[idx].kind if kind != arrayTableKind { - return fmt.Errorf("toml: key %s already exists as a %s, but should be an array table", kind, string(k)) + return false, fmt.Errorf("toml: key %s already exists as a %s, but should be an array table", kind, string(k)) } s.clear(idx) - } else { - idx = s.create(parentIdx, k, arrayTableKind, true, false) } s.currentIdx = idx - return nil + return firstTime, nil } -func (s *SeenTracker) checkKeyValue(node *unstable.Node) error { +func (s *SeenTracker) checkKeyValue(node *unstable.Node) (bool, error) { parentIdx := s.currentIdx it := node.Key() @@ -281,11 +285,11 @@ func (s *SeenTracker) checkKeyValue(node *unstable.Node) error { } else { entry := s.entries[idx] if it.IsLast() { - return fmt.Errorf("toml: key %s is already defined", string(k)) + return false, fmt.Errorf("toml: key %s is already defined", string(k)) } else if entry.kind != tableKind { - return fmt.Errorf("toml: expected %s to be a table, not a %s", string(k), entry.kind) + return false, fmt.Errorf("toml: expected %s to be a table, not a %s", string(k), entry.kind) } else if entry.explicit { - return fmt.Errorf("toml: cannot redefine table %s that has already been explicitly defined", string(k)) + return false, fmt.Errorf("toml: cannot redefine table %s that has already been explicitly defined", string(k)) } } @@ -303,30 +307,30 @@ func (s *SeenTracker) checkKeyValue(node *unstable.Node) error { return s.checkArray(value) } - return nil + return false, nil } -func (s *SeenTracker) checkArray(node *unstable.Node) error { +func (s *SeenTracker) checkArray(node *unstable.Node) (first bool, err error) { it := node.Children() for it.Next() { n := it.Node() switch n.Kind { case unstable.InlineTable: - err := s.checkInlineTable(n) + first, err = s.checkInlineTable(n) if err != nil { - return err + return false, err } case unstable.Array: - err := s.checkArray(n) + first, err = s.checkArray(n) if err != nil { - return err + return false, err } } } - return nil + return first, nil } -func (s *SeenTracker) checkInlineTable(node *unstable.Node) error { +func (s *SeenTracker) checkInlineTable(node *unstable.Node) (first bool, err error) { if pool.New == nil { pool.New = func() interface{} { return &SeenTracker{} @@ -339,9 +343,9 @@ func (s *SeenTracker) checkInlineTable(node *unstable.Node) error { it := node.Children() for it.Next() { n := it.Node() - err := s.checkKeyValue(n) + first, err = s.checkKeyValue(n) if err != nil { - return err + return false, err } } @@ -352,5 +356,5 @@ func (s *SeenTracker) checkInlineTable(node *unstable.Node) error { // redefinition of its keys: check* functions cannot walk into // a value. pool.Put(s) - return nil + return first, nil } diff --git a/unmarshaler.go b/unmarshaler.go index c5e5f33..857f3cf 100644 --- a/unmarshaler.go +++ b/unmarshaler.go @@ -127,6 +127,10 @@ type decoder struct { // need to be skipped. skipUntilTable bool + // Flag indicating that the current array/slice table should be cleared because + // it is the first encounter of an array table. + clearArrayTable bool + // Tracks position in Go arrays. // This is used when decoding [[array tables]] into Go arrays. Given array // tables are separate TOML expression, we need to keep track of where we @@ -246,9 +250,10 @@ Rules for the unmarshal code: func (d *decoder) handleRootExpression(expr *unstable.Node, v reflect.Value) error { var x reflect.Value var err error + var first bool // used for to clear array tables on first use if !(d.skipUntilTable && expr.Kind == unstable.KeyValue) { - err = d.seen.CheckExpression(expr) + first, err = d.seen.CheckExpression(expr) if err != nil { return err } @@ -267,6 +272,7 @@ func (d *decoder) handleRootExpression(expr *unstable.Node, v reflect.Value) err case unstable.ArrayTable: d.skipUntilTable = false d.strict.EnterArrayTable(expr) + d.clearArrayTable = first x, err = d.handleArrayTable(expr.Key(), v) default: panic(fmt.Errorf("parser should not permit expression of kind %s at document root", expr.Kind)) @@ -307,6 +313,10 @@ func (d *decoder) handleArrayTableCollectionLast(key unstable.Iterator, v reflec reflect.Copy(nelem, elem) elem = nelem } + if d.clearArrayTable && elem.Len() > 0 { + elem.SetLen(0) + d.clearArrayTable = false + } } return d.handleArrayTableCollectionLast(key, elem) case reflect.Ptr: @@ -325,6 +335,10 @@ func (d *decoder) handleArrayTableCollectionLast(key unstable.Iterator, v reflec return v, nil case reflect.Slice: + if d.clearArrayTable && v.Len() > 0 { + v.SetLen(0) + d.clearArrayTable = false + } elemType := v.Type().Elem() var elem reflect.Value if elemType.Kind() == reflect.Interface { @@ -576,7 +590,7 @@ func (d *decoder) handleKeyValues(v reflect.Value) (reflect.Value, error) { break } - err := d.seen.CheckExpression(expr) + _, err := d.seen.CheckExpression(expr) if err != nil { return reflect.Value{}, err } diff --git a/unmarshaler_test.go b/unmarshaler_test.go index fa015c2..78e0689 100644 --- a/unmarshaler_test.go +++ b/unmarshaler_test.go @@ -2823,6 +2823,76 @@ blah.a = "def"`) require.Equal(t, "def", cfg.A) } +func TestIssue931(t *testing.T) { + type item struct { + Name string + } + + type items struct { + Slice []item + } + + its := items{[]item{{"a"}, {"b"}}} + + b := []byte(` + [[Slice]] + Name = 'c' + +[[Slice]] + Name = 'd' + `) + + toml.Unmarshal(b, &its) + require.Equal(t, items{[]item{{"c"}, {"d"}}}, its) +} + +func TestIssue931Interface(t *testing.T) { + type items struct { + Slice interface{} + } + + type item = map[string]interface{} + + its := items{[]interface{}{item{"Name": "a"}, item{"Name": "b"}}} + + b := []byte(` + [[Slice]] + Name = 'c' + +[[Slice]] + Name = 'd' + `) + + toml.Unmarshal(b, &its) + require.Equal(t, items{[]interface{}{item{"Name": "c"}, item{"Name": "d"}}}, its) +} + +func TestIssue931SliceInterface(t *testing.T) { + type items struct { + Slice []interface{} + } + + type item = map[string]interface{} + + its := items{ + []interface{}{ + item{"Name": "a"}, + item{"Name": "b"}, + }, + } + + b := []byte(` + [[Slice]] + Name = 'c' + +[[Slice]] + Name = 'd' + `) + + toml.Unmarshal(b, &its) + require.Equal(t, items{[]interface{}{item{"Name": "c"}, item{"Name": "d"}}}, its) +} + func TestUnmarshalDecodeErrors(t *testing.T) { examples := []struct { desc string