Decoder: disallow modification of existing table (#704)

Fixes #703
This commit is contained in:
Thomas Pelletier
2021-12-15 11:05:27 -05:00
committed by GitHub
parent facb2b13e8
commit 696dd25c17
3 changed files with 130 additions and 97 deletions
+99 -71
View File
@@ -3,6 +3,7 @@ package tracker
import (
"bytes"
"fmt"
"sync"
"github.com/pelletier/go-toml/v2/internal/ast"
)
@@ -54,69 +55,103 @@ func (k keyKind) String() string {
type SeenTracker struct {
entries []entry
currentIdx int
nextID int
}
var pool sync.Pool
func (s *SeenTracker) reset() {
// Always contains a root element at index 0.
s.currentIdx = 0
if len(s.entries) == 0 {
s.entries = make([]entry, 1, 2)
} else {
s.entries = s.entries[:1]
}
s.entries[0].child = -1
s.entries[0].next = -1
}
type entry struct {
id int
parent int
// Use -1 to indicate no child or no sibling.
child int
next int
name []byte
kind keyKind
explicit bool
}
// Remove all descendants of node at position idx.
func (s *SeenTracker) clear(idx int) {
p := s.entries[idx].id
rest := clear(p, s.entries[idx+1:])
s.entries = s.entries[:idx+1+len(rest)]
}
func clear(parentID int, entries []entry) []entry {
for i := 0; i < len(entries); {
if entries[i].parent == parentID {
id := entries[i].id
copy(entries[i:], entries[i+1:])
entries = entries[:len(entries)-1]
rest := clear(id, entries[i:])
entries = entries[:i+len(rest)]
} else {
i++
// Find the index of the child of parentIdx with key k. Returns -1 if
// it does not exist.
func (s *SeenTracker) find(parentIdx int, k []byte) int {
for i := s.entries[parentIdx].child; i >= 0; i = s.entries[i].next {
if bytes.Equal(s.entries[i].name, k) {
return i
}
}
return entries
return -1
}
// Remove all descendants of node at position idx.
func (s *SeenTracker) clear(idx int) {
if idx >= len(s.entries) {
return
}
for i := s.entries[idx].child; i >= 0; {
next := s.entries[i].next
n := s.entries[0].next
s.entries[0].next = i
s.entries[i].next = n
s.entries[i].name = nil
s.clear(i)
i = next
}
s.entries[idx].child = -1
}
func (s *SeenTracker) create(parentIdx int, name []byte, kind keyKind, explicit bool) int {
parentID := s.id(parentIdx)
e := entry{
child: -1,
next: s.entries[parentIdx].child,
idx := len(s.entries)
s.entries = append(s.entries, entry{
id: s.nextID,
parent: parentID,
name: name,
kind: kind,
explicit: explicit,
})
s.nextID++
}
var idx int
if s.entries[0].next >= 0 {
idx = s.entries[0].next
s.entries[0].next = s.entries[idx].next
s.entries[idx] = e
} else {
idx = len(s.entries)
s.entries = append(s.entries, e)
}
s.entries[parentIdx].child = idx
return idx
}
func (s *SeenTracker) setExplicitFlag(parentIdx int) {
for i := s.entries[parentIdx].child; i >= 0; i = s.entries[i].next {
s.entries[i].explicit = true
s.setExplicitFlag(i)
}
}
// CheckExpression takes a top-level node and checks that it does not contain
// keys that have been seen in previous calls, and validates that types are
// consistent.
func (s *SeenTracker) CheckExpression(node *ast.Node) error {
if s.entries == nil {
// Skip ID = 0 to remove the confusion between nodes whose
// parent has id 0 and root nodes (parent id is 0 because it's
// the zero value).
s.nextID = 1
// Start unscoped, so idx is negative.
s.currentIdx = -1
s.reset()
}
switch node.Kind {
case ast.KeyValue:
return s.checkKeyValue(s.currentIdx, node)
return s.checkKeyValue(node)
case ast.Table:
return s.checkTable(node)
case ast.ArrayTable:
@@ -127,9 +162,13 @@ func (s *SeenTracker) CheckExpression(node *ast.Node) error {
}
func (s *SeenTracker) checkTable(node *ast.Node) error {
if s.currentIdx >= 0 {
s.setExplicitFlag(s.currentIdx)
}
it := node.Key()
parentIdx := -1
parentIdx := 0
// This code is duplicated in checkArrayTable. This is because factoring
// it in a function requires to copy the iterator, or allocate it to the
@@ -176,9 +215,13 @@ func (s *SeenTracker) checkTable(node *ast.Node) error {
}
func (s *SeenTracker) checkArrayTable(node *ast.Node) error {
if s.currentIdx >= 0 {
s.setExplicitFlag(s.currentIdx)
}
it := node.Key()
parentIdx := -1
parentIdx := 0
for it.Next() {
if it.IsLast() {
@@ -219,7 +262,8 @@ func (s *SeenTracker) checkArrayTable(node *ast.Node) error {
return nil
}
func (s *SeenTracker) checkKeyValue(parentIdx int, node *ast.Node) error {
func (s *SeenTracker) checkKeyValue(node *ast.Node) error {
parentIdx := s.currentIdx
it := node.Key()
for it.Next() {
@@ -249,45 +293,48 @@ func (s *SeenTracker) checkKeyValue(parentIdx int, node *ast.Node) error {
switch value.Kind {
case ast.InlineTable:
return s.checkInlineTable(parentIdx, value)
return s.checkInlineTable(value)
case ast.Array:
return s.checkArray(parentIdx, value)
return s.checkArray(value)
}
return nil
}
func (s *SeenTracker) checkArray(parentIdx int, node *ast.Node) error {
set := false
func (s *SeenTracker) checkArray(node *ast.Node) error {
it := node.Children()
for it.Next() {
if set {
s.clear(parentIdx)
}
n := it.Node()
switch n.Kind {
case ast.InlineTable:
err := s.checkInlineTable(parentIdx, n)
err := s.checkInlineTable(n)
if err != nil {
return err
}
set = true
case ast.Array:
err := s.checkArray(parentIdx, n)
err := s.checkArray(n)
if err != nil {
return err
}
set = true
}
}
return nil
}
func (s *SeenTracker) checkInlineTable(parentIdx int, node *ast.Node) error {
func (s *SeenTracker) checkInlineTable(node *ast.Node) error {
if pool.New == nil {
pool.New = func() interface{} {
return &SeenTracker{}
}
}
s = pool.Get().(*SeenTracker)
s.reset()
it := node.Children()
for it.Next() {
n := it.Node()
err := s.checkKeyValue(parentIdx, n)
err := s.checkKeyValue(n)
if err != nil {
return err
}
@@ -299,25 +346,6 @@ func (s *SeenTracker) checkInlineTable(parentIdx int, node *ast.Node) error {
// mark the presence of the inline table and prevent
// redefinition of its keys: check* functions cannot walk into
// a value.
s.clear(parentIdx)
pool.Put(s)
return nil
}
func (s *SeenTracker) id(idx int) int {
if idx >= 0 {
return s.entries[idx].id
}
return 0
}
func (s *SeenTracker) find(parentIdx int, k []byte) int {
parentID := s.id(parentIdx)
for i := parentIdx + 1; i < len(s.entries); i++ {
if s.entries[i].parent == parentID && bytes.Equal(s.entries[i].name, k) {
return i
}
}
return -1
}