wip: correctness pass on the AST
This commit is contained in:
@@ -0,0 +1,98 @@
|
|||||||
|
package tracker
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
|
||||||
|
"github.com/pelletier/go-toml/v2/internal/ast"
|
||||||
|
)
|
||||||
|
|
||||||
|
type keyKind uint8
|
||||||
|
|
||||||
|
const (
|
||||||
|
invalid keyKind = iota // also used for the root key
|
||||||
|
value
|
||||||
|
table
|
||||||
|
arrayTable
|
||||||
|
)
|
||||||
|
|
||||||
|
type key string
|
||||||
|
|
||||||
|
type builder struct {
|
||||||
|
prefix [][]byte
|
||||||
|
local [][]byte
|
||||||
|
}
|
||||||
|
|
||||||
|
func (b *builder) Reset(prefix [][]byte) {
|
||||||
|
b.prefix = prefix
|
||||||
|
b.local = b.local[:0]
|
||||||
|
}
|
||||||
|
|
||||||
|
// Computes the number of bytes required to store the full key.
|
||||||
|
func (b *builder) size() int {
|
||||||
|
size := len(b.prefix) + len(b.local) - 1
|
||||||
|
for _, p := range b.prefix {
|
||||||
|
size += len(p)
|
||||||
|
}
|
||||||
|
for _, p := range b.local {
|
||||||
|
size += len(p)
|
||||||
|
}
|
||||||
|
return size
|
||||||
|
}
|
||||||
|
|
||||||
|
func (b *builder) copy(firstJoin bool, from [][]byte, to []byte) int {
|
||||||
|
offset := 0
|
||||||
|
for i, p := range from {
|
||||||
|
if i > 0 || firstJoin {
|
||||||
|
to[offset] = 0x1E
|
||||||
|
offset++
|
||||||
|
}
|
||||||
|
copy(to[offset:], p)
|
||||||
|
offset += len(p)
|
||||||
|
}
|
||||||
|
return offset
|
||||||
|
}
|
||||||
|
|
||||||
|
func (b *builder) MakeKey() key {
|
||||||
|
k := make([]byte, b.size())
|
||||||
|
b.copy(false, b.prefix, k)
|
||||||
|
b.copy(len(b.prefix) > 0, b.local, k)
|
||||||
|
return key(k)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (b *builder) Append(k []byte) {
|
||||||
|
b.local = append(b.local, k)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Tracks which keys have been seen with which TOML type to flag duplicates
|
||||||
|
// and mismatches according to the spec.
|
||||||
|
type Seen struct {
|
||||||
|
keys map[key]keyKind
|
||||||
|
|
||||||
|
// scoping from the previous CheckExpression call.
|
||||||
|
current [][]byte
|
||||||
|
|
||||||
|
// key builder
|
||||||
|
builder builder
|
||||||
|
}
|
||||||
|
|
||||||
|
// CheckExpression takes a top-level node and checks that it does not contain keys
|
||||||
|
// that have been seen in previous calls, and validates that types are consistent.
|
||||||
|
func (s *Seen) CheckExpression(node ast.Node) error {
|
||||||
|
s.builder.Reset(s.current)
|
||||||
|
switch node.Kind {
|
||||||
|
case ast.KeyValue:
|
||||||
|
return s.checkKeyValue(node)
|
||||||
|
case ast.Table:
|
||||||
|
case ast.ArrayTable:
|
||||||
|
default:
|
||||||
|
panic(fmt.Errorf("this should not be a top level node type: %s", node.Kind))
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Seen) checkKeyValue(node ast.Node) error {
|
||||||
|
it := node.Key()
|
||||||
|
for it.Next() {
|
||||||
|
s.builder.Append(it.Node().Data)
|
||||||
|
}
|
||||||
|
}
|
||||||
+14
-4
@@ -7,6 +7,7 @@ import (
|
|||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/pelletier/go-toml/v2/internal/ast"
|
"github.com/pelletier/go-toml/v2/internal/ast"
|
||||||
|
"github.com/pelletier/go-toml/v2/internal/tracker"
|
||||||
)
|
)
|
||||||
|
|
||||||
func Unmarshal(data []byte, v interface{}) error {
|
func Unmarshal(data []byte, v interface{}) error {
|
||||||
@@ -19,6 +20,9 @@ func Unmarshal(data []byte, v interface{}) error {
|
|||||||
type decoder struct {
|
type decoder struct {
|
||||||
// Tracks position in Go arrays.
|
// Tracks position in Go arrays.
|
||||||
arrayIndexes map[reflect.Value]int
|
arrayIndexes map[reflect.Value]int
|
||||||
|
|
||||||
|
// Tracks keys that have been seen, with which type.
|
||||||
|
seen tracker.Seen
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d *decoder) arrayIndex(append bool, v reflect.Value) int {
|
func (d *decoder) arrayIndex(append bool, v reflect.Value) int {
|
||||||
@@ -46,19 +50,25 @@ func (d *decoder) FromParser(p *parser, v interface{}) error {
|
|||||||
return fmt.Errorf("target pointer must be non-nil")
|
return fmt.Errorf("target pointer must be non-nil")
|
||||||
}
|
}
|
||||||
|
|
||||||
var err error
|
|
||||||
var skipUntilTable bool
|
var skipUntilTable bool
|
||||||
var root target = valueTarget(r.Elem())
|
var root target = valueTarget(r.Elem())
|
||||||
current := root
|
current := root
|
||||||
|
|
||||||
for p.NextExpression() {
|
for p.NextExpression() {
|
||||||
node := p.Expression()
|
node := p.Expression()
|
||||||
|
|
||||||
|
if node.Kind == ast.KeyValue && skipUntilTable {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
err := d.seen.CheckExpression(node)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
var found bool
|
var found bool
|
||||||
switch node.Kind {
|
switch node.Kind {
|
||||||
case ast.KeyValue:
|
case ast.KeyValue:
|
||||||
if skipUntilTable {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
err = d.unmarshalKeyValue(current, node)
|
err = d.unmarshalKeyValue(current, node)
|
||||||
found = true
|
found = true
|
||||||
case ast.Table:
|
case ast.Table:
|
||||||
|
|||||||
Reference in New Issue
Block a user