decoder: strict mode (#512)
This commit is contained in:
@@ -22,7 +22,7 @@ Development branch. Use at your own risk.
|
||||
- [x] Abstract AST.
|
||||
- [x] Original go-toml testgen tests pass.
|
||||
- [x] Track file position (line, column) for errors.
|
||||
- [ ] Strict mode.
|
||||
- [x] Strict mode.
|
||||
- [ ] Document Unmarshal / Decode
|
||||
|
||||
### Marshal
|
||||
|
||||
@@ -18,15 +18,46 @@ type DecodeError struct {
|
||||
message string
|
||||
line int
|
||||
column int
|
||||
key Key
|
||||
|
||||
human string
|
||||
}
|
||||
|
||||
// StrictMissingError occurs in a TOML document that does not have a
|
||||
// corresponding field in the target value. It contains all the missing fields
|
||||
// in Errors.
|
||||
//
|
||||
// Emitted by Decoder when SetStrict(true) was called.
|
||||
type StrictMissingError struct {
|
||||
// One error per field that could not be found.
|
||||
Errors []DecodeError
|
||||
}
|
||||
|
||||
// Error returns the cannonical string for this error.
|
||||
func (s *StrictMissingError) Error() string {
|
||||
return "strict mode: fields in the document are missing in the target struct"
|
||||
}
|
||||
|
||||
// String returns a human readable description of all errors.
|
||||
func (s *StrictMissingError) String() string {
|
||||
var buf strings.Builder
|
||||
for i, e := range s.Errors {
|
||||
if i > 0 {
|
||||
buf.WriteString("\n---\n")
|
||||
}
|
||||
buf.WriteString(e.String())
|
||||
}
|
||||
return buf.String()
|
||||
}
|
||||
|
||||
type Key []string
|
||||
|
||||
// internal version of DecodeError that is used as the base to create a
|
||||
// DecodeError with full context.
|
||||
type decodeError struct {
|
||||
highlight []byte
|
||||
message string
|
||||
key Key // optional
|
||||
}
|
||||
|
||||
func (de *decodeError) Error() string {
|
||||
@@ -56,6 +87,11 @@ func (e *DecodeError) Position() (row int, column int) {
|
||||
return e.line, e.column
|
||||
}
|
||||
|
||||
// Key that was being processed when the error occured.
|
||||
func (e *DecodeError) Key() Key {
|
||||
return e.key
|
||||
}
|
||||
|
||||
// decodeErrorFromHighlight creates a DecodeError referencing to a highlighted
|
||||
// range of bytes from document.
|
||||
//
|
||||
@@ -64,7 +100,7 @@ func (e *DecodeError) Position() (row int, column int) {
|
||||
// The function copies all bytes used in DecodeError, so that document and
|
||||
// highlight can be freely deallocated.
|
||||
//nolint:funlen
|
||||
func wrapDecodeError(document []byte, de *decodeError) error {
|
||||
func wrapDecodeError(document []byte, de *decodeError) *DecodeError {
|
||||
if de == nil {
|
||||
return nil
|
||||
}
|
||||
@@ -137,6 +173,7 @@ func wrapDecodeError(document []byte, de *decodeError) error {
|
||||
message: errMessage,
|
||||
line: errLine,
|
||||
column: errColumn,
|
||||
key: de.key,
|
||||
human: buf.String(),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -7,6 +7,7 @@ package imported_tests
|
||||
// marked as skipped until we figure out if that's something we want in v2.
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"errors"
|
||||
"fmt"
|
||||
"reflect"
|
||||
@@ -1955,66 +1956,80 @@ String2="2"`
|
||||
assert.Error(t, err)
|
||||
}
|
||||
|
||||
func decoder(doc string) *toml.Decoder {
|
||||
return toml.NewDecoder(bytes.NewReader([]byte(doc)))
|
||||
}
|
||||
|
||||
func strictDecoder(doc string) *toml.Decoder {
|
||||
d := decoder(doc)
|
||||
d.SetStrict(true)
|
||||
return d
|
||||
}
|
||||
|
||||
func TestDecoderStrict(t *testing.T) {
|
||||
t.Skip()
|
||||
// input := `
|
||||
//[decoded]
|
||||
// key = ""
|
||||
//
|
||||
//[undecoded]
|
||||
// key = ""
|
||||
//
|
||||
// [undecoded.inner]
|
||||
// key = ""
|
||||
//
|
||||
// [[undecoded.array]]
|
||||
// key = ""
|
||||
//
|
||||
// [[undecoded.array]]
|
||||
// key = ""
|
||||
//
|
||||
//`
|
||||
// var doc struct {
|
||||
// Decoded struct {
|
||||
// Key string
|
||||
// }
|
||||
// }
|
||||
//
|
||||
// expected := `undecoded keys: ["undecoded.array.0.key" "undecoded.array.1.key" "undecoded.inner.key" "undecoded.key"]`
|
||||
//
|
||||
// err := NewDecoder(bytes.NewReader([]byte(input))).Strict(true).Decode(&doc)
|
||||
// if err == nil {
|
||||
// t.Error("expected error, got none")
|
||||
// } else if err.Error() != expected {
|
||||
// t.Errorf("expect err: %s, got: %s", expected, err.Error())
|
||||
// }
|
||||
//
|
||||
// if err := NewDecoder(bytes.NewReader([]byte(input))).Decode(&doc); err != nil {
|
||||
// t.Errorf("unexpected err: %s", err)
|
||||
// }
|
||||
//
|
||||
// var m map[string]interface{}
|
||||
// if err := NewDecoder(bytes.NewReader([]byte(input))).Decode(&m); err != nil {
|
||||
// t.Errorf("unexpected err: %s", err)
|
||||
// }
|
||||
input := `
|
||||
[decoded]
|
||||
key = ""
|
||||
|
||||
[undecoded]
|
||||
key = ""
|
||||
|
||||
[undecoded.inner]
|
||||
key = ""
|
||||
|
||||
[[undecoded.array]]
|
||||
key = ""
|
||||
|
||||
[[undecoded.array]]
|
||||
key = ""
|
||||
|
||||
`
|
||||
var doc struct {
|
||||
Decoded struct {
|
||||
Key string
|
||||
}
|
||||
}
|
||||
|
||||
err := strictDecoder(input).Decode(&doc)
|
||||
require.Error(t, err)
|
||||
require.IsType(t, &toml.StrictMissingError{}, err)
|
||||
se := err.(*toml.StrictMissingError)
|
||||
|
||||
keys := []toml.Key{}
|
||||
|
||||
for _, e := range se.Errors {
|
||||
keys = append(keys, e.Key())
|
||||
}
|
||||
|
||||
expectedKeys := []toml.Key{
|
||||
{"undecoded"},
|
||||
{"undecoded", "inner"},
|
||||
{"undecoded", "array"},
|
||||
{"undecoded", "array"},
|
||||
}
|
||||
|
||||
require.Equal(t, expectedKeys, keys)
|
||||
|
||||
err = decoder(input).Decode(&doc)
|
||||
require.NoError(t, err)
|
||||
|
||||
var m map[string]interface{}
|
||||
err = decoder(input).Decode(&m)
|
||||
}
|
||||
|
||||
func TestDecoderStrictValid(t *testing.T) {
|
||||
t.Skip()
|
||||
// input := `
|
||||
//[decoded]
|
||||
// key = ""
|
||||
//`
|
||||
// var doc struct {
|
||||
// Decoded struct {
|
||||
// Key string
|
||||
// }
|
||||
// }
|
||||
//
|
||||
// err := NewDecoder(bytes.NewReader([]byte(input))).Strict(true).Decode(&doc)
|
||||
// if err != nil {
|
||||
// t.Fatal("unexpected error:", err)
|
||||
// }
|
||||
input := `
|
||||
[decoded]
|
||||
key = ""
|
||||
`
|
||||
var doc struct {
|
||||
Decoded struct {
|
||||
Key string
|
||||
}
|
||||
}
|
||||
|
||||
err := strictDecoder(input).Decode(&doc)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
type docUnmarshalTOML struct {
|
||||
|
||||
@@ -0,0 +1,50 @@
|
||||
package tracker
|
||||
|
||||
import (
|
||||
"github.com/pelletier/go-toml/v2/internal/ast"
|
||||
)
|
||||
|
||||
// KeyTracker is a tracker that keeps track of the current Key as the AST is
|
||||
// walked.
|
||||
type KeyTracker struct {
|
||||
k []string
|
||||
}
|
||||
|
||||
// UpdateTable sets the state of the tracker with the AST table node.
|
||||
func (t *KeyTracker) UpdateTable(node ast.Node) {
|
||||
t.reset()
|
||||
t.Push(node)
|
||||
}
|
||||
|
||||
// UpdateArrayTable sets the state of the tracker with the AST array table node.
|
||||
func (t *KeyTracker) UpdateArrayTable(node ast.Node) {
|
||||
t.reset()
|
||||
t.Push(node)
|
||||
}
|
||||
|
||||
// Push the given key on the stack.
|
||||
func (t *KeyTracker) Push(node ast.Node) {
|
||||
it := node.Key()
|
||||
for it.Next() {
|
||||
t.k = append(t.k, string(it.Node().Data))
|
||||
}
|
||||
}
|
||||
|
||||
// Pop key from stack.
|
||||
func (t *KeyTracker) Pop(node ast.Node) {
|
||||
it := node.Key()
|
||||
for it.Next() {
|
||||
t.k = t.k[:len(t.k)-1]
|
||||
}
|
||||
}
|
||||
|
||||
// Key returns the current key
|
||||
func (t *KeyTracker) Key() []string {
|
||||
k := make([]string, len(t.k))
|
||||
copy(k, t.k)
|
||||
return k
|
||||
}
|
||||
|
||||
func (t *KeyTracker) reset() {
|
||||
t.k = t.k[:0]
|
||||
}
|
||||
@@ -0,0 +1,200 @@
|
||||
package tracker
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/pelletier/go-toml/v2/internal/ast"
|
||||
)
|
||||
|
||||
type keyKind uint8
|
||||
|
||||
const (
|
||||
invalidKind keyKind = iota
|
||||
valueKind
|
||||
tableKind
|
||||
arrayTableKind
|
||||
)
|
||||
|
||||
func (k keyKind) String() string {
|
||||
switch k {
|
||||
case invalidKind:
|
||||
return "invalid"
|
||||
case valueKind:
|
||||
return "value"
|
||||
case tableKind:
|
||||
return "table"
|
||||
case arrayTableKind:
|
||||
return "array table"
|
||||
}
|
||||
panic("missing keyKind string mapping")
|
||||
}
|
||||
|
||||
// SeenTracker tracks which keys have been seen with which TOML type to flag duplicates
|
||||
// and mismatches according to the spec.
|
||||
type SeenTracker struct {
|
||||
root *info
|
||||
current *info
|
||||
}
|
||||
|
||||
type info struct {
|
||||
parent *info
|
||||
kind keyKind
|
||||
children map[string]*info
|
||||
explicit bool
|
||||
}
|
||||
|
||||
func (i *info) Clear() {
|
||||
i.children = nil
|
||||
}
|
||||
|
||||
func (i *info) Has(k string) (*info, bool) {
|
||||
c, ok := i.children[k]
|
||||
return c, ok
|
||||
}
|
||||
|
||||
func (i *info) SetKind(kind keyKind) {
|
||||
i.kind = kind
|
||||
}
|
||||
|
||||
func (i *info) CreateTable(k string, explicit bool) *info {
|
||||
return i.createChild(k, tableKind, explicit)
|
||||
}
|
||||
|
||||
func (i *info) CreateArrayTable(k string, explicit bool) *info {
|
||||
return i.createChild(k, arrayTableKind, explicit)
|
||||
}
|
||||
|
||||
func (i *info) createChild(k string, kind keyKind, explicit bool) *info {
|
||||
if i.children == nil {
|
||||
i.children = make(map[string]*info, 1)
|
||||
}
|
||||
|
||||
x := &info{
|
||||
parent: i,
|
||||
kind: kind,
|
||||
explicit: explicit,
|
||||
}
|
||||
i.children[k] = x
|
||||
return x
|
||||
}
|
||||
|
||||
// CheckExpression takes a top-level node and checks that it does not contain keys
|
||||
// that have been seen in previous calls, and validates that types are consistent.
|
||||
func (s *SeenTracker) CheckExpression(node ast.Node) error {
|
||||
if s.root == nil {
|
||||
s.root = &info{
|
||||
kind: tableKind,
|
||||
}
|
||||
s.current = s.root
|
||||
}
|
||||
switch node.Kind {
|
||||
case ast.KeyValue:
|
||||
return s.checkKeyValue(s.current, node)
|
||||
case ast.Table:
|
||||
return s.checkTable(node)
|
||||
case ast.ArrayTable:
|
||||
return s.checkArrayTable(node)
|
||||
default:
|
||||
panic(fmt.Errorf("this should not be a top level node type: %s", node.Kind))
|
||||
}
|
||||
|
||||
}
|
||||
func (s *SeenTracker) checkTable(node ast.Node) error {
|
||||
s.current = s.root
|
||||
|
||||
it := node.Key()
|
||||
// handle the first parts of the key, excluding the last one
|
||||
for it.Next() {
|
||||
if !it.Node().Next().Valid() {
|
||||
break
|
||||
}
|
||||
|
||||
k := string(it.Node().Data)
|
||||
child, found := s.current.Has(k)
|
||||
if !found {
|
||||
child = s.current.CreateTable(k, false)
|
||||
}
|
||||
s.current = child
|
||||
}
|
||||
|
||||
// handle the last part of the key
|
||||
k := string(it.Node().Data)
|
||||
|
||||
i, found := s.current.Has(k)
|
||||
if found {
|
||||
if i.kind != tableKind {
|
||||
return fmt.Errorf("key %s should be a table", k)
|
||||
}
|
||||
if i.explicit {
|
||||
return fmt.Errorf("table %s already exists", k)
|
||||
}
|
||||
i.explicit = true
|
||||
s.current = i
|
||||
} else {
|
||||
s.current = s.current.CreateTable(k, true)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *SeenTracker) checkArrayTable(node ast.Node) error {
|
||||
s.current = s.root
|
||||
|
||||
it := node.Key()
|
||||
|
||||
// handle the first parts of the key, excluding the last one
|
||||
for it.Next() {
|
||||
if !it.Node().Next().Valid() {
|
||||
break
|
||||
}
|
||||
|
||||
k := string(it.Node().Data)
|
||||
child, found := s.current.Has(k)
|
||||
if !found {
|
||||
child = s.current.CreateTable(k, false)
|
||||
}
|
||||
s.current = child
|
||||
}
|
||||
|
||||
// handle the last part of the key
|
||||
k := string(it.Node().Data)
|
||||
|
||||
info, found := s.current.Has(k)
|
||||
if found {
|
||||
if info.kind != arrayTableKind {
|
||||
return fmt.Errorf("key %s already exists but is not an array table", k)
|
||||
}
|
||||
info.Clear()
|
||||
} else {
|
||||
info = s.current.CreateArrayTable(k, true)
|
||||
}
|
||||
|
||||
s.current = info
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *SeenTracker) checkKeyValue(context *info, node ast.Node) error {
|
||||
it := node.Key()
|
||||
|
||||
// handle the first parts of the key, excluding the last one
|
||||
for it.Next() {
|
||||
k := string(it.Node().Data)
|
||||
child, found := context.Has(k)
|
||||
if found {
|
||||
if child.kind != tableKind {
|
||||
return fmt.Errorf("expected %s to be a table, not a %s", k, child.kind)
|
||||
}
|
||||
} else {
|
||||
child = context.CreateTable(k, false)
|
||||
}
|
||||
context = child
|
||||
}
|
||||
|
||||
if node.Value().Kind == ast.InlineTable {
|
||||
context.SetKind(tableKind)
|
||||
} else {
|
||||
context.SetKind(valueKind)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
@@ -1,200 +1 @@
|
||||
package tracker
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/pelletier/go-toml/v2/internal/ast"
|
||||
)
|
||||
|
||||
type keyKind uint8
|
||||
|
||||
const (
|
||||
invalidKind keyKind = iota
|
||||
valueKind
|
||||
tableKind
|
||||
arrayTableKind
|
||||
)
|
||||
|
||||
func (k keyKind) String() string {
|
||||
switch k {
|
||||
case invalidKind:
|
||||
return "invalid"
|
||||
case valueKind:
|
||||
return "value"
|
||||
case tableKind:
|
||||
return "table"
|
||||
case arrayTableKind:
|
||||
return "array table"
|
||||
}
|
||||
panic("missing keyKind string mapping")
|
||||
}
|
||||
|
||||
// Tracks which keys have been seen with which TOML type to flag duplicates
|
||||
// and mismatches according to the spec.
|
||||
type Seen struct {
|
||||
root *info
|
||||
current *info
|
||||
}
|
||||
|
||||
type info struct {
|
||||
parent *info
|
||||
kind keyKind
|
||||
children map[string]*info
|
||||
explicit bool
|
||||
}
|
||||
|
||||
func (i *info) Clear() {
|
||||
i.children = nil
|
||||
}
|
||||
|
||||
func (i *info) Has(k string) (*info, bool) {
|
||||
c, ok := i.children[k]
|
||||
return c, ok
|
||||
}
|
||||
|
||||
func (i *info) SetKind(kind keyKind) {
|
||||
i.kind = kind
|
||||
}
|
||||
|
||||
func (i *info) CreateTable(k string, explicit bool) *info {
|
||||
return i.createChild(k, tableKind, explicit)
|
||||
}
|
||||
|
||||
func (i *info) CreateArrayTable(k string, explicit bool) *info {
|
||||
return i.createChild(k, arrayTableKind, explicit)
|
||||
}
|
||||
|
||||
func (i *info) createChild(k string, kind keyKind, explicit bool) *info {
|
||||
if i.children == nil {
|
||||
i.children = make(map[string]*info, 1)
|
||||
}
|
||||
|
||||
x := &info{
|
||||
parent: i,
|
||||
kind: kind,
|
||||
explicit: explicit,
|
||||
}
|
||||
i.children[k] = x
|
||||
return x
|
||||
}
|
||||
|
||||
// CheckExpression takes a top-level node and checks that it does not contain keys
|
||||
// that have been seen in previous calls, and validates that types are consistent.
|
||||
func (s *Seen) CheckExpression(node ast.Node) error {
|
||||
if s.root == nil {
|
||||
s.root = &info{
|
||||
kind: tableKind,
|
||||
}
|
||||
s.current = s.root
|
||||
}
|
||||
switch node.Kind {
|
||||
case ast.KeyValue:
|
||||
return s.checkKeyValue(s.current, node)
|
||||
case ast.Table:
|
||||
return s.checkTable(node)
|
||||
case ast.ArrayTable:
|
||||
return s.checkArrayTable(node)
|
||||
default:
|
||||
panic(fmt.Errorf("this should not be a top level node type: %s", node.Kind))
|
||||
}
|
||||
|
||||
}
|
||||
func (s *Seen) checkTable(node ast.Node) error {
|
||||
s.current = s.root
|
||||
|
||||
it := node.Key()
|
||||
// handle the first parts of the key, excluding the last one
|
||||
for it.Next() {
|
||||
if !it.Node().Next().Valid() {
|
||||
break
|
||||
}
|
||||
|
||||
k := string(it.Node().Data)
|
||||
child, found := s.current.Has(k)
|
||||
if !found {
|
||||
child = s.current.CreateTable(k, false)
|
||||
}
|
||||
s.current = child
|
||||
}
|
||||
|
||||
// handle the last part of the key
|
||||
k := string(it.Node().Data)
|
||||
|
||||
i, found := s.current.Has(k)
|
||||
if found {
|
||||
if i.kind != tableKind {
|
||||
return fmt.Errorf("key %s should be a table", k)
|
||||
}
|
||||
if i.explicit {
|
||||
return fmt.Errorf("table %s already exists", k)
|
||||
}
|
||||
i.explicit = true
|
||||
s.current = i
|
||||
} else {
|
||||
s.current = s.current.CreateTable(k, true)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Seen) checkArrayTable(node ast.Node) error {
|
||||
s.current = s.root
|
||||
|
||||
it := node.Key()
|
||||
|
||||
// handle the first parts of the key, excluding the last one
|
||||
for it.Next() {
|
||||
if !it.Node().Next().Valid() {
|
||||
break
|
||||
}
|
||||
|
||||
k := string(it.Node().Data)
|
||||
child, found := s.current.Has(k)
|
||||
if !found {
|
||||
child = s.current.CreateTable(k, false)
|
||||
}
|
||||
s.current = child
|
||||
}
|
||||
|
||||
// handle the last part of the key
|
||||
k := string(it.Node().Data)
|
||||
|
||||
info, found := s.current.Has(k)
|
||||
if found {
|
||||
if info.kind != arrayTableKind {
|
||||
return fmt.Errorf("key %s already exists but is not an array table", k)
|
||||
}
|
||||
info.Clear()
|
||||
} else {
|
||||
info = s.current.CreateArrayTable(k, true)
|
||||
}
|
||||
|
||||
s.current = info
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Seen) checkKeyValue(context *info, node ast.Node) error {
|
||||
it := node.Key()
|
||||
|
||||
// handle the first parts of the key, excluding the last one
|
||||
for it.Next() {
|
||||
k := string(it.Node().Data)
|
||||
child, found := context.Has(k)
|
||||
if found {
|
||||
if child.kind != tableKind {
|
||||
return fmt.Errorf("expected %s to be a table, not a %s", k, child.kind)
|
||||
}
|
||||
} else {
|
||||
child = context.CreateTable(k, false)
|
||||
}
|
||||
context = child
|
||||
}
|
||||
|
||||
if node.Value().Kind == ast.InlineTable {
|
||||
context.SetKind(tableKind)
|
||||
} else {
|
||||
context.SetKind(valueKind)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -33,3 +33,27 @@ func SubsliceOffset(data []byte, subslice []byte) int {
|
||||
|
||||
return intoffset
|
||||
}
|
||||
|
||||
func BytesRange(start []byte, end []byte) []byte {
|
||||
if start == nil || end == nil {
|
||||
panic("cannot call BytesRange with nil")
|
||||
}
|
||||
startp := (*reflect.SliceHeader)(unsafe.Pointer(&start))
|
||||
endp := (*reflect.SliceHeader)(unsafe.Pointer(&end))
|
||||
|
||||
if startp.Data > endp.Data {
|
||||
panic(fmt.Errorf("start pointer address (%d) is after end pointer address (%d)", startp.Data, endp.Data))
|
||||
}
|
||||
|
||||
l := startp.Len
|
||||
endLen := int(endp.Data-startp.Data) + endp.Len
|
||||
if endLen > l {
|
||||
l = endLen
|
||||
}
|
||||
|
||||
if l > startp.Cap {
|
||||
panic(fmt.Errorf("range length is larger than capacity"))
|
||||
}
|
||||
|
||||
return start[:l]
|
||||
}
|
||||
|
||||
@@ -77,3 +77,92 @@ func TestUnsafeSubsliceOffsetInvalid(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestUnsafeBytesRange(t *testing.T) {
|
||||
type fn = func() ([]byte, []byte)
|
||||
examples := []struct {
|
||||
desc string
|
||||
test fn
|
||||
expected []byte
|
||||
}{
|
||||
{
|
||||
desc: "simple",
|
||||
test: func() ([]byte, []byte) {
|
||||
full := []byte("hello world")
|
||||
return full[1:3], full[6:8]
|
||||
},
|
||||
expected: []byte("ello wo"),
|
||||
},
|
||||
{
|
||||
desc: "full",
|
||||
test: func() ([]byte, []byte) {
|
||||
full := []byte("hello world")
|
||||
return full[0:1], full[len(full)-1:]
|
||||
},
|
||||
expected: []byte("hello world"),
|
||||
},
|
||||
{
|
||||
desc: "end before start",
|
||||
test: func() ([]byte, []byte) {
|
||||
full := []byte("hello world")
|
||||
return full[len(full)-1:], full[0:1]
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "nils",
|
||||
test: func() ([]byte, []byte) {
|
||||
return nil, nil
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "nils start",
|
||||
test: func() ([]byte, []byte) {
|
||||
return nil, []byte("foo")
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "nils end",
|
||||
test: func() ([]byte, []byte) {
|
||||
return []byte("foo"), nil
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "start is end",
|
||||
test: func() ([]byte, []byte) {
|
||||
full := []byte("hello world")
|
||||
return full[1:3], full[1:3]
|
||||
},
|
||||
expected: []byte("el"),
|
||||
},
|
||||
{
|
||||
desc: "end contained in start",
|
||||
test: func() ([]byte, []byte) {
|
||||
full := []byte("hello world")
|
||||
return full[1:7], full[2:4]
|
||||
},
|
||||
expected: []byte("ello w"),
|
||||
},
|
||||
{
|
||||
desc: "different backing arrays",
|
||||
test: func() ([]byte, []byte) {
|
||||
one := []byte("hello world")
|
||||
two := []byte("hello world")
|
||||
return one, two
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, e := range examples {
|
||||
t.Run(e.desc, func(t *testing.T) {
|
||||
start, end := e.test()
|
||||
if e.expected == nil {
|
||||
require.Panics(t, func() {
|
||||
unsafe.BytesRange(start, end)
|
||||
})
|
||||
} else {
|
||||
res := unsafe.BytesRange(start, end)
|
||||
require.Equal(t, e.expected, res)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -0,0 +1,79 @@
|
||||
package toml
|
||||
|
||||
import (
|
||||
"github.com/pelletier/go-toml/v2/internal/ast"
|
||||
"github.com/pelletier/go-toml/v2/internal/tracker"
|
||||
)
|
||||
|
||||
type strict struct {
|
||||
Enabled bool
|
||||
|
||||
// Tracks the current key being processed.
|
||||
key tracker.KeyTracker
|
||||
|
||||
missing []decodeError
|
||||
}
|
||||
|
||||
func (s *strict) EnterTable(node ast.Node) {
|
||||
if !s.Enabled {
|
||||
return
|
||||
}
|
||||
s.key.UpdateTable(node)
|
||||
}
|
||||
|
||||
func (s *strict) EnterArrayTable(node ast.Node) {
|
||||
if !s.Enabled {
|
||||
return
|
||||
}
|
||||
s.key.UpdateArrayTable(node)
|
||||
}
|
||||
|
||||
func (s *strict) EnterKeyValue(node ast.Node) {
|
||||
if !s.Enabled {
|
||||
return
|
||||
}
|
||||
s.key.Push(node)
|
||||
}
|
||||
|
||||
func (s *strict) ExitKeyValue(node ast.Node) {
|
||||
if !s.Enabled {
|
||||
return
|
||||
}
|
||||
s.key.Pop(node)
|
||||
}
|
||||
|
||||
func (s *strict) MissingTable(node ast.Node) {
|
||||
if !s.Enabled {
|
||||
return
|
||||
}
|
||||
s.missing = append(s.missing, decodeError{
|
||||
highlight: keyLocation(node),
|
||||
message: "missing table",
|
||||
key: s.key.Key(),
|
||||
})
|
||||
}
|
||||
|
||||
func (s *strict) MissingField(node ast.Node) {
|
||||
if !s.Enabled {
|
||||
return
|
||||
}
|
||||
s.missing = append(s.missing, decodeError{
|
||||
highlight: keyLocation(node),
|
||||
message: "missing field",
|
||||
key: s.key.Key(),
|
||||
})
|
||||
}
|
||||
|
||||
func (s *strict) Error(doc []byte) error {
|
||||
if !s.Enabled || len(s.missing) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
err := &StrictMissingError{
|
||||
Errors: make([]DecodeError, 0, len(s.missing)),
|
||||
}
|
||||
for _, derr := range s.missing {
|
||||
err.Errors = append(err.Errors, *wrapDecodeError(doc, &derr))
|
||||
}
|
||||
return err
|
||||
}
|
||||
+49
-2
@@ -10,6 +10,7 @@ import (
|
||||
|
||||
"github.com/pelletier/go-toml/v2/internal/ast"
|
||||
"github.com/pelletier/go-toml/v2/internal/tracker"
|
||||
"github.com/pelletier/go-toml/v2/internal/unsafe"
|
||||
)
|
||||
|
||||
func Unmarshal(data []byte, v interface{}) error {
|
||||
@@ -21,7 +22,11 @@ func Unmarshal(data []byte, v interface{}) error {
|
||||
|
||||
// Decoder reads and decode a TOML document from an input stream.
|
||||
type Decoder struct {
|
||||
// input
|
||||
r io.Reader
|
||||
|
||||
// global settings
|
||||
strict bool
|
||||
}
|
||||
|
||||
// NewDecoder creates a new Decoder that will read from r.
|
||||
@@ -29,6 +34,16 @@ func NewDecoder(r io.Reader) *Decoder {
|
||||
return &Decoder{r: r}
|
||||
}
|
||||
|
||||
// SetStrict toggles decoding in stict mode.
|
||||
//
|
||||
// When the decoder is in strict mode, it will record fields from the document
|
||||
// that could not be set on the target value. In that case, the decoder returns
|
||||
// a StrictMissingError that can be used to retrieve the individual errors as
|
||||
// well as generate a human readable description of the missing fields.
|
||||
func (d *Decoder) SetStrict(strict bool) {
|
||||
d.strict = strict
|
||||
}
|
||||
|
||||
// Decode the whole content of r into v.
|
||||
//
|
||||
// When a TOML local date is decoded into a time.Time, its value is represented
|
||||
@@ -43,7 +58,11 @@ func (d *Decoder) Decode(v interface{}) error {
|
||||
}
|
||||
p := parser{}
|
||||
p.Reset(b)
|
||||
dec := decoder{}
|
||||
dec := decoder{
|
||||
strict: strict{
|
||||
Enabled: d.strict,
|
||||
},
|
||||
}
|
||||
return dec.FromParser(&p, v)
|
||||
}
|
||||
|
||||
@@ -52,7 +71,10 @@ type decoder struct {
|
||||
arrayIndexes map[reflect.Value]int
|
||||
|
||||
// Tracks keys that have been seen, with which type.
|
||||
seen tracker.Seen
|
||||
seen tracker.SeenTracker
|
||||
|
||||
// Strict mode
|
||||
strict strict
|
||||
}
|
||||
|
||||
func (d *decoder) arrayIndex(append bool, v reflect.Value) int {
|
||||
@@ -79,9 +101,27 @@ func (d *decoder) FromParser(p *parser, v interface{}) error {
|
||||
err = wrapDecodeError(p.data, de)
|
||||
}
|
||||
}
|
||||
if err == nil {
|
||||
err = d.strict.Error(p.data)
|
||||
}
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
func keyLocation(node ast.Node) []byte {
|
||||
k := node.Key()
|
||||
hasOne := k.Next()
|
||||
if !hasOne {
|
||||
panic("should not be called with empty key")
|
||||
}
|
||||
start := k.Node().Data
|
||||
end := k.Node().Data
|
||||
for k.Next() {
|
||||
end = k.Node().Data
|
||||
}
|
||||
return unsafe.BytesRange(start, end)
|
||||
}
|
||||
|
||||
func (d *decoder) fromParser(p *parser, v interface{}) error {
|
||||
r := reflect.ValueOf(v)
|
||||
if r.Kind() != reflect.Ptr {
|
||||
@@ -113,6 +153,7 @@ func (d *decoder) fromParser(p *parser, v interface{}) error {
|
||||
err = d.unmarshalKeyValue(current, node)
|
||||
found = true
|
||||
case ast.Table:
|
||||
d.strict.EnterTable(node)
|
||||
current, found, err = d.scopeWithKey(root, node.Key())
|
||||
if err == nil && found {
|
||||
// In case this table points to an interface,
|
||||
@@ -123,6 +164,7 @@ func (d *decoder) fromParser(p *parser, v interface{}) error {
|
||||
ensureMapIfInterface(current)
|
||||
}
|
||||
case ast.ArrayTable:
|
||||
d.strict.EnterArrayTable(node)
|
||||
current, found, err = d.scopeWithArrayTable(root, node.Key())
|
||||
default:
|
||||
panic(fmt.Errorf("this should not be a top level node type: %s", node.Kind))
|
||||
@@ -134,6 +176,7 @@ func (d *decoder) fromParser(p *parser, v interface{}) error {
|
||||
|
||||
if !found {
|
||||
skipUntilTable = true
|
||||
d.strict.MissingTable(node)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -217,6 +260,9 @@ func (d *decoder) scopeWithArrayTable(x target, key ast.Iterator) (target, bool,
|
||||
func (d *decoder) unmarshalKeyValue(x target, node ast.Node) error {
|
||||
assertNode(ast.KeyValue, node)
|
||||
|
||||
d.strict.EnterKeyValue(node)
|
||||
defer d.strict.ExitKeyValue(node)
|
||||
|
||||
x, found, err := d.scopeWithKey(x, node.Key())
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -224,6 +270,7 @@ func (d *decoder) unmarshalKeyValue(x target, node ast.Node) error {
|
||||
|
||||
// A struct in the path was not found. Skip this value.
|
||||
if !found {
|
||||
d.strict.MissingField(node)
|
||||
return nil
|
||||
}
|
||||
|
||||
|
||||
@@ -1,8 +1,10 @@
|
||||
package toml_test
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"math"
|
||||
"strconv"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
@@ -989,3 +991,115 @@ func TestIssue508(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "This is a title", t1.head.Title)
|
||||
}
|
||||
|
||||
func TestDecoderStrict(t *testing.T) {
|
||||
examples := []struct {
|
||||
desc string
|
||||
input string
|
||||
expected string
|
||||
target interface{}
|
||||
}{
|
||||
{
|
||||
desc: "multiple missing root keys",
|
||||
input: `
|
||||
key1 = "value1"
|
||||
key2 = "missing2"
|
||||
key3 = "missing3"
|
||||
key4 = "value4"
|
||||
`,
|
||||
expected: `
|
||||
2| key1 = "value1"
|
||||
3| key2 = "missing2"
|
||||
| ~~~~ missing field
|
||||
4| key3 = "missing3"
|
||||
5| key4 = "value4"
|
||||
---
|
||||
2| key1 = "value1"
|
||||
3| key2 = "missing2"
|
||||
4| key3 = "missing3"
|
||||
| ~~~~ missing field
|
||||
5| key4 = "value4"
|
||||
`,
|
||||
target: &struct {
|
||||
Key1 string
|
||||
Key4 string
|
||||
}{},
|
||||
},
|
||||
{
|
||||
desc: "multi-part key",
|
||||
input: `a.short.key="foo"`,
|
||||
expected: `
|
||||
1| a.short.key="foo"
|
||||
| ~~~~~~~~~~~ missing field
|
||||
`,
|
||||
},
|
||||
{
|
||||
desc: "missing table",
|
||||
input: `
|
||||
[foo]
|
||||
bar = 42
|
||||
`,
|
||||
expected: `
|
||||
2| [foo]
|
||||
| ~~~ missing table
|
||||
3| bar = 42
|
||||
`,
|
||||
},
|
||||
|
||||
{
|
||||
desc: "missing array table",
|
||||
input: `
|
||||
[[foo]]
|
||||
bar = 42
|
||||
`,
|
||||
expected: `
|
||||
2| [[foo]]
|
||||
| ~~~ missing table
|
||||
3| bar = 42
|
||||
`,
|
||||
},
|
||||
}
|
||||
|
||||
for _, e := range examples {
|
||||
t.Run(e.desc, func(t *testing.T) {
|
||||
r := strings.NewReader(e.input)
|
||||
d := toml.NewDecoder(r)
|
||||
d.SetStrict(true)
|
||||
x := e.target
|
||||
if x == nil {
|
||||
x = &struct{}{}
|
||||
}
|
||||
err := d.Decode(x)
|
||||
details := err.(*toml.StrictMissingError)
|
||||
equalStringsIgnoreNewlines(t, e.expected, details.String())
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func ExampleDecoder_SetStrict() {
|
||||
type S struct {
|
||||
Key1 string
|
||||
Key3 string
|
||||
}
|
||||
doc := `
|
||||
key1 = "value1"
|
||||
key2 = "value2"
|
||||
key3 = "value3"
|
||||
`
|
||||
r := strings.NewReader(doc)
|
||||
d := toml.NewDecoder(r)
|
||||
d.SetStrict(true)
|
||||
s := S{}
|
||||
err := d.Decode(&s)
|
||||
|
||||
fmt.Println(err.Error())
|
||||
// Output: strict mode: fields in the document are missing in the target struct
|
||||
|
||||
details := err.(*toml.StrictMissingError)
|
||||
fmt.Println(details.String())
|
||||
// Ouput:
|
||||
// 2| key1 = "value1"
|
||||
// 3| key2 = "value2"
|
||||
// | ~~~~ missing field
|
||||
// 4| key3 = "value3"
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user