diff --git a/internal/ast/ast.go b/internal/ast/ast.go index 3b86002..33c7f91 100644 --- a/internal/ast/ast.go +++ b/internal/ast/ast.go @@ -20,8 +20,8 @@ type Iterator struct { node *Node } -// Next moves the iterator forward and returns true if points to a node, false -// otherwise. +// Next moves the iterator forward and returns true if points to a +// node, false otherwise. func (c *Iterator) Next() bool { if !c.started { c.started = true @@ -31,8 +31,8 @@ func (c *Iterator) Next() bool { return c.node.Valid() } -// IsLast returns true if the current node of the iterator is the last one. -// Subsequent call to Next() will return false. +// IsLast returns true if the current node of the iterator is the last +// one. Subsequent call to Next() will return false. func (c *Iterator) IsLast() bool { return c.node.next == 0 } @@ -62,20 +62,20 @@ func (r *Root) at(idx Reference) *Node { return &r.nodes[idx] } -// Arrays have one child per element in the array. -// InlineTables have one child per key-value pair in the table. -// KeyValues have at least two children. The first one is the value. The -// rest make a potentially dotted key. -// Table and Array table have one child per element of the key they -// represent (same as KeyValue, but without the last node being the value). -// children []Node +// Arrays have one child per element in the array. InlineTables have +// one child per key-value pair in the table. KeyValues have at least +// two children. The first one is the value. The rest make a +// potentially dotted key. Table and Array table have one child per +// element of the key they represent (same as KeyValue, but without +// the last node being the value). type Node struct { Kind Kind Raw Range // Raw bytes from the input. - Data []byte // Node value (could be either allocated or referencing the input). + Data []byte // Node value (either allocated or referencing the input). - // References to other nodes, as offsets in the backing array from this - // node. References can go backward, so those can be negative. + // References to other nodes, as offsets in the backing array + // from this node. References can go backward, so those can be + // negative. next int // 0 if last element child int // 0 if no child } @@ -85,8 +85,8 @@ type Range struct { Length uint32 } -// Next returns a copy of the next node, or an invalid Node if there is no -// next node. +// Next returns a copy of the next node, or an invalid Node if there +// is no next node. func (n *Node) Next() *Node { if n.next == 0 { return nil @@ -96,9 +96,9 @@ func (n *Node) Next() *Node { return (*Node)(danger.Stride(ptr, size, n.next)) } -// Child returns a copy of the first child node of this node. Other children -// can be accessed calling Next on the first child. -// Returns an invalid Node if there is none. +// Child returns a copy of the first child node of this node. Other +// children can be accessed calling Next on the first child. Returns +// an invalid Node if there is none. func (n *Node) Child() *Node { if n.child == 0 { return nil @@ -113,10 +113,9 @@ func (n *Node) Valid() bool { return n != nil } -// Key returns the child nodes making the Key on a supported node. Panics -// otherwise. -// They are guaranteed to be all be of the Kind Key. A simple key would return -// just one element. +// Key returns the child nodes making the Key on a supported +// node. Panics otherwise. They are guaranteed to be all be of the +// Kind Key. A simple key would return just one element. func (n *Node) Key() Iterator { switch n.Kind { case KeyValue: @@ -133,8 +132,8 @@ func (n *Node) Key() Iterator { } // Value returns a pointer to the value node of a KeyValue. -// Guaranteed to be non-nil. -// Panics if not called on a KeyValue node, or if the Children are malformed. +// Guaranteed to be non-nil. Panics if not called on a KeyValue node, +// or if the Children are malformed. func (n *Node) Value() *Node { return n.Child() } diff --git a/internal/tracker/seen.go b/internal/tracker/seen.go index 167790d..434b02c 100644 --- a/internal/tracker/seen.go +++ b/internal/tracker/seen.go @@ -3,6 +3,7 @@ package tracker import ( "bytes" "fmt" + "sync" "github.com/pelletier/go-toml/v2/internal/ast" ) @@ -54,69 +55,103 @@ func (k keyKind) String() string { type SeenTracker struct { entries []entry currentIdx int - nextID int +} + +var pool sync.Pool + +func (s *SeenTracker) reset() { + // Always contains a root element at index 0. + s.currentIdx = 0 + if len(s.entries) == 0 { + s.entries = make([]entry, 1, 2) + } else { + s.entries = s.entries[:1] + } + s.entries[0].child = -1 + s.entries[0].next = -1 } type entry struct { - id int - parent int + // Use -1 to indicate no child or no sibling. + child int + next int + name []byte kind keyKind explicit bool } -// Remove all descendants of node at position idx. -func (s *SeenTracker) clear(idx int) { - p := s.entries[idx].id - rest := clear(p, s.entries[idx+1:]) - s.entries = s.entries[:idx+1+len(rest)] -} - -func clear(parentID int, entries []entry) []entry { - for i := 0; i < len(entries); { - if entries[i].parent == parentID { - id := entries[i].id - copy(entries[i:], entries[i+1:]) - entries = entries[:len(entries)-1] - rest := clear(id, entries[i:]) - entries = entries[:i+len(rest)] - } else { - i++ +// Find the index of the child of parentIdx with key k. Returns -1 if +// it does not exist. +func (s *SeenTracker) find(parentIdx int, k []byte) int { + for i := s.entries[parentIdx].child; i >= 0; i = s.entries[i].next { + if bytes.Equal(s.entries[i].name, k) { + return i } } - return entries + return -1 +} + +// Remove all descendants of node at position idx. +func (s *SeenTracker) clear(idx int) { + if idx >= len(s.entries) { + return + } + + for i := s.entries[idx].child; i >= 0; { + next := s.entries[i].next + n := s.entries[0].next + s.entries[0].next = i + s.entries[i].next = n + s.entries[i].name = nil + s.clear(i) + i = next + } + + s.entries[idx].child = -1 } func (s *SeenTracker) create(parentIdx int, name []byte, kind keyKind, explicit bool) int { - parentID := s.id(parentIdx) + e := entry{ + child: -1, + next: s.entries[parentIdx].child, - idx := len(s.entries) - s.entries = append(s.entries, entry{ - id: s.nextID, - parent: parentID, name: name, kind: kind, explicit: explicit, - }) - s.nextID++ + } + var idx int + if s.entries[0].next >= 0 { + idx = s.entries[0].next + s.entries[0].next = s.entries[idx].next + s.entries[idx] = e + } else { + idx = len(s.entries) + s.entries = append(s.entries, e) + } + + s.entries[parentIdx].child = idx + return idx } +func (s *SeenTracker) setExplicitFlag(parentIdx int) { + for i := s.entries[parentIdx].child; i >= 0; i = s.entries[i].next { + s.entries[i].explicit = true + s.setExplicitFlag(i) + } +} + // CheckExpression takes a top-level node and checks that it does not contain // keys that have been seen in previous calls, and validates that types are // consistent. func (s *SeenTracker) CheckExpression(node *ast.Node) error { if s.entries == nil { - // Skip ID = 0 to remove the confusion between nodes whose - // parent has id 0 and root nodes (parent id is 0 because it's - // the zero value). - s.nextID = 1 - // Start unscoped, so idx is negative. - s.currentIdx = -1 + s.reset() } switch node.Kind { case ast.KeyValue: - return s.checkKeyValue(s.currentIdx, node) + return s.checkKeyValue(node) case ast.Table: return s.checkTable(node) case ast.ArrayTable: @@ -127,9 +162,13 @@ func (s *SeenTracker) CheckExpression(node *ast.Node) error { } func (s *SeenTracker) checkTable(node *ast.Node) error { + if s.currentIdx >= 0 { + s.setExplicitFlag(s.currentIdx) + } + it := node.Key() - parentIdx := -1 + parentIdx := 0 // This code is duplicated in checkArrayTable. This is because factoring // it in a function requires to copy the iterator, or allocate it to the @@ -176,9 +215,13 @@ func (s *SeenTracker) checkTable(node *ast.Node) error { } func (s *SeenTracker) checkArrayTable(node *ast.Node) error { + if s.currentIdx >= 0 { + s.setExplicitFlag(s.currentIdx) + } + it := node.Key() - parentIdx := -1 + parentIdx := 0 for it.Next() { if it.IsLast() { @@ -219,7 +262,8 @@ func (s *SeenTracker) checkArrayTable(node *ast.Node) error { return nil } -func (s *SeenTracker) checkKeyValue(parentIdx int, node *ast.Node) error { +func (s *SeenTracker) checkKeyValue(node *ast.Node) error { + parentIdx := s.currentIdx it := node.Key() for it.Next() { @@ -249,45 +293,48 @@ func (s *SeenTracker) checkKeyValue(parentIdx int, node *ast.Node) error { switch value.Kind { case ast.InlineTable: - return s.checkInlineTable(parentIdx, value) + return s.checkInlineTable(value) case ast.Array: - return s.checkArray(parentIdx, value) + return s.checkArray(value) } return nil } -func (s *SeenTracker) checkArray(parentIdx int, node *ast.Node) error { - set := false +func (s *SeenTracker) checkArray(node *ast.Node) error { it := node.Children() for it.Next() { - if set { - s.clear(parentIdx) - } n := it.Node() switch n.Kind { case ast.InlineTable: - err := s.checkInlineTable(parentIdx, n) + err := s.checkInlineTable(n) if err != nil { return err } - set = true case ast.Array: - err := s.checkArray(parentIdx, n) + err := s.checkArray(n) if err != nil { return err } - set = true } } return nil } -func (s *SeenTracker) checkInlineTable(parentIdx int, node *ast.Node) error { +func (s *SeenTracker) checkInlineTable(node *ast.Node) error { + if pool.New == nil { + pool.New = func() interface{} { + return &SeenTracker{} + } + } + + s = pool.Get().(*SeenTracker) + s.reset() + it := node.Children() for it.Next() { n := it.Node() - err := s.checkKeyValue(parentIdx, n) + err := s.checkKeyValue(n) if err != nil { return err } @@ -299,25 +346,6 @@ func (s *SeenTracker) checkInlineTable(parentIdx int, node *ast.Node) error { // mark the presence of the inline table and prevent // redefinition of its keys: check* functions cannot walk into // a value. - s.clear(parentIdx) + pool.Put(s) return nil } - -func (s *SeenTracker) id(idx int) int { - if idx >= 0 { - return s.entries[idx].id - } - return 0 -} - -func (s *SeenTracker) find(parentIdx int, k []byte) int { - parentID := s.id(parentIdx) - - for i := parentIdx + 1; i < len(s.entries); i++ { - if s.entries[i].parent == parentID && bytes.Equal(s.entries[i].name, k) { - return i - } - } - - return -1 -} diff --git a/unmarshaler_test.go b/unmarshaler_test.go index 5339b6a..17bd891 100644 --- a/unmarshaler_test.go +++ b/unmarshaler_test.go @@ -2129,7 +2129,7 @@ xz_hash = "1a48f723fea1f17d786ce6eadd9d00914d38062d28fd9c455ed3c3801905b388" expected := doc{ Pkg: map[string]pkg{ - "cargo": pkg{ + "cargo": { Target: map[string]target{ "aarch64-apple-darwin": { XZ_URL: "https://static.rust-lang.org/dist/2021-07-29/cargo-1.54.0-aarch64-apple-darwin.tar.xz", @@ -2298,6 +2298,12 @@ z=0 } } +func TestIssue703(t *testing.T) { + var v interface{} + err := toml.Unmarshal([]byte("[a]\nx.y=0\n[a.x]"), &v) + require.Error(t, err) +} + func TestUnmarshalDecodeErrors(t *testing.T) { examples := []struct { desc string