Decoder: disallow modification of existing table (#704)

Fixes #703
This commit is contained in:
Thomas Pelletier
2021-12-15 11:05:27 -05:00
committed by GitHub
parent facb2b13e8
commit 696dd25c17
3 changed files with 130 additions and 97 deletions
+24 -25
View File
@@ -20,8 +20,8 @@ type Iterator struct {
node *Node node *Node
} }
// Next moves the iterator forward and returns true if points to a node, false // Next moves the iterator forward and returns true if points to a
// otherwise. // node, false otherwise.
func (c *Iterator) Next() bool { func (c *Iterator) Next() bool {
if !c.started { if !c.started {
c.started = true c.started = true
@@ -31,8 +31,8 @@ func (c *Iterator) Next() bool {
return c.node.Valid() return c.node.Valid()
} }
// IsLast returns true if the current node of the iterator is the last one. // IsLast returns true if the current node of the iterator is the last
// Subsequent call to Next() will return false. // one. Subsequent call to Next() will return false.
func (c *Iterator) IsLast() bool { func (c *Iterator) IsLast() bool {
return c.node.next == 0 return c.node.next == 0
} }
@@ -62,20 +62,20 @@ func (r *Root) at(idx Reference) *Node {
return &r.nodes[idx] return &r.nodes[idx]
} }
// Arrays have one child per element in the array. // Arrays have one child per element in the array. InlineTables have
// InlineTables have one child per key-value pair in the table. // one child per key-value pair in the table. KeyValues have at least
// KeyValues have at least two children. The first one is the value. The // two children. The first one is the value. The rest make a
// rest make a potentially dotted key. // potentially dotted key. Table and Array table have one child per
// Table and Array table have one child per element of the key they // element of the key they represent (same as KeyValue, but without
// represent (same as KeyValue, but without the last node being the value). // the last node being the value).
// children []Node
type Node struct { type Node struct {
Kind Kind Kind Kind
Raw Range // Raw bytes from the input. Raw Range // Raw bytes from the input.
Data []byte // Node value (could be either allocated or referencing the input). Data []byte // Node value (either allocated or referencing the input).
// References to other nodes, as offsets in the backing array from this // References to other nodes, as offsets in the backing array
// node. References can go backward, so those can be negative. // from this node. References can go backward, so those can be
// negative.
next int // 0 if last element next int // 0 if last element
child int // 0 if no child child int // 0 if no child
} }
@@ -85,8 +85,8 @@ type Range struct {
Length uint32 Length uint32
} }
// Next returns a copy of the next node, or an invalid Node if there is no // Next returns a copy of the next node, or an invalid Node if there
// next node. // is no next node.
func (n *Node) Next() *Node { func (n *Node) Next() *Node {
if n.next == 0 { if n.next == 0 {
return nil return nil
@@ -96,9 +96,9 @@ func (n *Node) Next() *Node {
return (*Node)(danger.Stride(ptr, size, n.next)) return (*Node)(danger.Stride(ptr, size, n.next))
} }
// Child returns a copy of the first child node of this node. Other children // Child returns a copy of the first child node of this node. Other
// can be accessed calling Next on the first child. // children can be accessed calling Next on the first child. Returns
// Returns an invalid Node if there is none. // an invalid Node if there is none.
func (n *Node) Child() *Node { func (n *Node) Child() *Node {
if n.child == 0 { if n.child == 0 {
return nil return nil
@@ -113,10 +113,9 @@ func (n *Node) Valid() bool {
return n != nil return n != nil
} }
// Key returns the child nodes making the Key on a supported node. Panics // Key returns the child nodes making the Key on a supported
// otherwise. // node. Panics otherwise. They are guaranteed to be all be of the
// They are guaranteed to be all be of the Kind Key. A simple key would return // Kind Key. A simple key would return just one element.
// just one element.
func (n *Node) Key() Iterator { func (n *Node) Key() Iterator {
switch n.Kind { switch n.Kind {
case KeyValue: case KeyValue:
@@ -133,8 +132,8 @@ func (n *Node) Key() Iterator {
} }
// Value returns a pointer to the value node of a KeyValue. // Value returns a pointer to the value node of a KeyValue.
// Guaranteed to be non-nil. // Guaranteed to be non-nil. Panics if not called on a KeyValue node,
// Panics if not called on a KeyValue node, or if the Children are malformed. // or if the Children are malformed.
func (n *Node) Value() *Node { func (n *Node) Value() *Node {
return n.Child() return n.Child()
} }
+99 -71
View File
@@ -3,6 +3,7 @@ package tracker
import ( import (
"bytes" "bytes"
"fmt" "fmt"
"sync"
"github.com/pelletier/go-toml/v2/internal/ast" "github.com/pelletier/go-toml/v2/internal/ast"
) )
@@ -54,69 +55,103 @@ func (k keyKind) String() string {
type SeenTracker struct { type SeenTracker struct {
entries []entry entries []entry
currentIdx int currentIdx int
nextID int }
var pool sync.Pool
func (s *SeenTracker) reset() {
// Always contains a root element at index 0.
s.currentIdx = 0
if len(s.entries) == 0 {
s.entries = make([]entry, 1, 2)
} else {
s.entries = s.entries[:1]
}
s.entries[0].child = -1
s.entries[0].next = -1
} }
type entry struct { type entry struct {
id int // Use -1 to indicate no child or no sibling.
parent int child int
next int
name []byte name []byte
kind keyKind kind keyKind
explicit bool explicit bool
} }
// Remove all descendants of node at position idx. // Find the index of the child of parentIdx with key k. Returns -1 if
func (s *SeenTracker) clear(idx int) { // it does not exist.
p := s.entries[idx].id func (s *SeenTracker) find(parentIdx int, k []byte) int {
rest := clear(p, s.entries[idx+1:]) for i := s.entries[parentIdx].child; i >= 0; i = s.entries[i].next {
s.entries = s.entries[:idx+1+len(rest)] if bytes.Equal(s.entries[i].name, k) {
} return i
func clear(parentID int, entries []entry) []entry {
for i := 0; i < len(entries); {
if entries[i].parent == parentID {
id := entries[i].id
copy(entries[i:], entries[i+1:])
entries = entries[:len(entries)-1]
rest := clear(id, entries[i:])
entries = entries[:i+len(rest)]
} else {
i++
} }
} }
return entries return -1
}
// Remove all descendants of node at position idx.
func (s *SeenTracker) clear(idx int) {
if idx >= len(s.entries) {
return
}
for i := s.entries[idx].child; i >= 0; {
next := s.entries[i].next
n := s.entries[0].next
s.entries[0].next = i
s.entries[i].next = n
s.entries[i].name = nil
s.clear(i)
i = next
}
s.entries[idx].child = -1
} }
func (s *SeenTracker) create(parentIdx int, name []byte, kind keyKind, explicit bool) int { func (s *SeenTracker) create(parentIdx int, name []byte, kind keyKind, explicit bool) int {
parentID := s.id(parentIdx) e := entry{
child: -1,
next: s.entries[parentIdx].child,
idx := len(s.entries)
s.entries = append(s.entries, entry{
id: s.nextID,
parent: parentID,
name: name, name: name,
kind: kind, kind: kind,
explicit: explicit, explicit: explicit,
}) }
s.nextID++ var idx int
if s.entries[0].next >= 0 {
idx = s.entries[0].next
s.entries[0].next = s.entries[idx].next
s.entries[idx] = e
} else {
idx = len(s.entries)
s.entries = append(s.entries, e)
}
s.entries[parentIdx].child = idx
return idx return idx
} }
func (s *SeenTracker) setExplicitFlag(parentIdx int) {
for i := s.entries[parentIdx].child; i >= 0; i = s.entries[i].next {
s.entries[i].explicit = true
s.setExplicitFlag(i)
}
}
// CheckExpression takes a top-level node and checks that it does not contain // CheckExpression takes a top-level node and checks that it does not contain
// keys that have been seen in previous calls, and validates that types are // keys that have been seen in previous calls, and validates that types are
// consistent. // consistent.
func (s *SeenTracker) CheckExpression(node *ast.Node) error { func (s *SeenTracker) CheckExpression(node *ast.Node) error {
if s.entries == nil { if s.entries == nil {
// Skip ID = 0 to remove the confusion between nodes whose s.reset()
// parent has id 0 and root nodes (parent id is 0 because it's
// the zero value).
s.nextID = 1
// Start unscoped, so idx is negative.
s.currentIdx = -1
} }
switch node.Kind { switch node.Kind {
case ast.KeyValue: case ast.KeyValue:
return s.checkKeyValue(s.currentIdx, node) return s.checkKeyValue(node)
case ast.Table: case ast.Table:
return s.checkTable(node) return s.checkTable(node)
case ast.ArrayTable: case ast.ArrayTable:
@@ -127,9 +162,13 @@ func (s *SeenTracker) CheckExpression(node *ast.Node) error {
} }
func (s *SeenTracker) checkTable(node *ast.Node) error { func (s *SeenTracker) checkTable(node *ast.Node) error {
if s.currentIdx >= 0 {
s.setExplicitFlag(s.currentIdx)
}
it := node.Key() it := node.Key()
parentIdx := -1 parentIdx := 0
// This code is duplicated in checkArrayTable. This is because factoring // This code is duplicated in checkArrayTable. This is because factoring
// it in a function requires to copy the iterator, or allocate it to the // it in a function requires to copy the iterator, or allocate it to the
@@ -176,9 +215,13 @@ func (s *SeenTracker) checkTable(node *ast.Node) error {
} }
func (s *SeenTracker) checkArrayTable(node *ast.Node) error { func (s *SeenTracker) checkArrayTable(node *ast.Node) error {
if s.currentIdx >= 0 {
s.setExplicitFlag(s.currentIdx)
}
it := node.Key() it := node.Key()
parentIdx := -1 parentIdx := 0
for it.Next() { for it.Next() {
if it.IsLast() { if it.IsLast() {
@@ -219,7 +262,8 @@ func (s *SeenTracker) checkArrayTable(node *ast.Node) error {
return nil return nil
} }
func (s *SeenTracker) checkKeyValue(parentIdx int, node *ast.Node) error { func (s *SeenTracker) checkKeyValue(node *ast.Node) error {
parentIdx := s.currentIdx
it := node.Key() it := node.Key()
for it.Next() { for it.Next() {
@@ -249,45 +293,48 @@ func (s *SeenTracker) checkKeyValue(parentIdx int, node *ast.Node) error {
switch value.Kind { switch value.Kind {
case ast.InlineTable: case ast.InlineTable:
return s.checkInlineTable(parentIdx, value) return s.checkInlineTable(value)
case ast.Array: case ast.Array:
return s.checkArray(parentIdx, value) return s.checkArray(value)
} }
return nil return nil
} }
func (s *SeenTracker) checkArray(parentIdx int, node *ast.Node) error { func (s *SeenTracker) checkArray(node *ast.Node) error {
set := false
it := node.Children() it := node.Children()
for it.Next() { for it.Next() {
if set {
s.clear(parentIdx)
}
n := it.Node() n := it.Node()
switch n.Kind { switch n.Kind {
case ast.InlineTable: case ast.InlineTable:
err := s.checkInlineTable(parentIdx, n) err := s.checkInlineTable(n)
if err != nil { if err != nil {
return err return err
} }
set = true
case ast.Array: case ast.Array:
err := s.checkArray(parentIdx, n) err := s.checkArray(n)
if err != nil { if err != nil {
return err return err
} }
set = true
} }
} }
return nil return nil
} }
func (s *SeenTracker) checkInlineTable(parentIdx int, node *ast.Node) error { func (s *SeenTracker) checkInlineTable(node *ast.Node) error {
if pool.New == nil {
pool.New = func() interface{} {
return &SeenTracker{}
}
}
s = pool.Get().(*SeenTracker)
s.reset()
it := node.Children() it := node.Children()
for it.Next() { for it.Next() {
n := it.Node() n := it.Node()
err := s.checkKeyValue(parentIdx, n) err := s.checkKeyValue(n)
if err != nil { if err != nil {
return err return err
} }
@@ -299,25 +346,6 @@ func (s *SeenTracker) checkInlineTable(parentIdx int, node *ast.Node) error {
// mark the presence of the inline table and prevent // mark the presence of the inline table and prevent
// redefinition of its keys: check* functions cannot walk into // redefinition of its keys: check* functions cannot walk into
// a value. // a value.
s.clear(parentIdx) pool.Put(s)
return nil return nil
} }
func (s *SeenTracker) id(idx int) int {
if idx >= 0 {
return s.entries[idx].id
}
return 0
}
func (s *SeenTracker) find(parentIdx int, k []byte) int {
parentID := s.id(parentIdx)
for i := parentIdx + 1; i < len(s.entries); i++ {
if s.entries[i].parent == parentID && bytes.Equal(s.entries[i].name, k) {
return i
}
}
return -1
}
+7 -1
View File
@@ -2129,7 +2129,7 @@ xz_hash = "1a48f723fea1f17d786ce6eadd9d00914d38062d28fd9c455ed3c3801905b388"
expected := doc{ expected := doc{
Pkg: map[string]pkg{ Pkg: map[string]pkg{
"cargo": pkg{ "cargo": {
Target: map[string]target{ Target: map[string]target{
"aarch64-apple-darwin": { "aarch64-apple-darwin": {
XZ_URL: "https://static.rust-lang.org/dist/2021-07-29/cargo-1.54.0-aarch64-apple-darwin.tar.xz", XZ_URL: "https://static.rust-lang.org/dist/2021-07-29/cargo-1.54.0-aarch64-apple-darwin.tar.xz",
@@ -2298,6 +2298,12 @@ z=0
} }
} }
func TestIssue703(t *testing.T) {
var v interface{}
err := toml.Unmarshal([]byte("[a]\nx.y=0\n[a.x]"), &v)
require.Error(t, err)
}
func TestUnmarshalDecodeErrors(t *testing.T) { func TestUnmarshalDecodeErrors(t *testing.T) {
examples := []struct { examples := []struct {
desc string desc string