decoder: strict mode (#512)

This commit is contained in:
Thomas Pelletier
2021-04-20 21:26:22 -04:00
committed by GitHub
parent dca2103910
commit 9b67e40640
11 changed files with 715 additions and 259 deletions
+1 -1
View File
@@ -22,7 +22,7 @@ Development branch. Use at your own risk.
- [x] Abstract AST. - [x] Abstract AST.
- [x] Original go-toml testgen tests pass. - [x] Original go-toml testgen tests pass.
- [x] Track file position (line, column) for errors. - [x] Track file position (line, column) for errors.
- [ ] Strict mode. - [x] Strict mode.
- [ ] Document Unmarshal / Decode - [ ] Document Unmarshal / Decode
### Marshal ### Marshal
+38 -1
View File
@@ -18,15 +18,46 @@ type DecodeError struct {
message string message string
line int line int
column int column int
key Key
human string human string
} }
// StrictMissingError occurs in a TOML document that does not have a
// corresponding field in the target value. It contains all the missing fields
// in Errors.
//
// Emitted by Decoder when SetStrict(true) was called.
type StrictMissingError struct {
// One error per field that could not be found.
Errors []DecodeError
}
// Error returns the cannonical string for this error.
func (s *StrictMissingError) Error() string {
return "strict mode: fields in the document are missing in the target struct"
}
// String returns a human readable description of all errors.
func (s *StrictMissingError) String() string {
var buf strings.Builder
for i, e := range s.Errors {
if i > 0 {
buf.WriteString("\n---\n")
}
buf.WriteString(e.String())
}
return buf.String()
}
type Key []string
// internal version of DecodeError that is used as the base to create a // internal version of DecodeError that is used as the base to create a
// DecodeError with full context. // DecodeError with full context.
type decodeError struct { type decodeError struct {
highlight []byte highlight []byte
message string message string
key Key // optional
} }
func (de *decodeError) Error() string { func (de *decodeError) Error() string {
@@ -56,6 +87,11 @@ func (e *DecodeError) Position() (row int, column int) {
return e.line, e.column return e.line, e.column
} }
// Key that was being processed when the error occured.
func (e *DecodeError) Key() Key {
return e.key
}
// decodeErrorFromHighlight creates a DecodeError referencing to a highlighted // decodeErrorFromHighlight creates a DecodeError referencing to a highlighted
// range of bytes from document. // range of bytes from document.
// //
@@ -64,7 +100,7 @@ func (e *DecodeError) Position() (row int, column int) {
// The function copies all bytes used in DecodeError, so that document and // The function copies all bytes used in DecodeError, so that document and
// highlight can be freely deallocated. // highlight can be freely deallocated.
//nolint:funlen //nolint:funlen
func wrapDecodeError(document []byte, de *decodeError) error { func wrapDecodeError(document []byte, de *decodeError) *DecodeError {
if de == nil { if de == nil {
return nil return nil
} }
@@ -137,6 +173,7 @@ func wrapDecodeError(document []byte, de *decodeError) error {
message: errMessage, message: errMessage,
line: errLine, line: errLine,
column: errColumn, column: errColumn,
key: de.key,
human: buf.String(), human: buf.String(),
} }
} }
@@ -7,6 +7,7 @@ package imported_tests
// marked as skipped until we figure out if that's something we want in v2. // marked as skipped until we figure out if that's something we want in v2.
import ( import (
"bytes"
"errors" "errors"
"fmt" "fmt"
"reflect" "reflect"
@@ -1955,66 +1956,80 @@ String2="2"`
assert.Error(t, err) assert.Error(t, err)
} }
func decoder(doc string) *toml.Decoder {
return toml.NewDecoder(bytes.NewReader([]byte(doc)))
}
func strictDecoder(doc string) *toml.Decoder {
d := decoder(doc)
d.SetStrict(true)
return d
}
func TestDecoderStrict(t *testing.T) { func TestDecoderStrict(t *testing.T) {
t.Skip() input := `
// input := ` [decoded]
//[decoded] key = ""
// key = ""
// [undecoded]
//[undecoded] key = ""
// key = ""
// [undecoded.inner]
// [undecoded.inner] key = ""
// key = ""
// [[undecoded.array]]
// [[undecoded.array]] key = ""
// key = ""
// [[undecoded.array]]
// [[undecoded.array]] key = ""
// key = ""
// `
//` var doc struct {
// var doc struct { Decoded struct {
// Decoded struct { Key string
// Key string }
// } }
// }
// err := strictDecoder(input).Decode(&doc)
// expected := `undecoded keys: ["undecoded.array.0.key" "undecoded.array.1.key" "undecoded.inner.key" "undecoded.key"]` require.Error(t, err)
// require.IsType(t, &toml.StrictMissingError{}, err)
// err := NewDecoder(bytes.NewReader([]byte(input))).Strict(true).Decode(&doc) se := err.(*toml.StrictMissingError)
// if err == nil {
// t.Error("expected error, got none") keys := []toml.Key{}
// } else if err.Error() != expected {
// t.Errorf("expect err: %s, got: %s", expected, err.Error()) for _, e := range se.Errors {
// } keys = append(keys, e.Key())
// }
// if err := NewDecoder(bytes.NewReader([]byte(input))).Decode(&doc); err != nil {
// t.Errorf("unexpected err: %s", err) expectedKeys := []toml.Key{
// } {"undecoded"},
// {"undecoded", "inner"},
// var m map[string]interface{} {"undecoded", "array"},
// if err := NewDecoder(bytes.NewReader([]byte(input))).Decode(&m); err != nil { {"undecoded", "array"},
// t.Errorf("unexpected err: %s", err) }
// }
require.Equal(t, expectedKeys, keys)
err = decoder(input).Decode(&doc)
require.NoError(t, err)
var m map[string]interface{}
err = decoder(input).Decode(&m)
} }
func TestDecoderStrictValid(t *testing.T) { func TestDecoderStrictValid(t *testing.T) {
t.Skip() input := `
// input := ` [decoded]
//[decoded] key = ""
// key = "" `
//` var doc struct {
// var doc struct { Decoded struct {
// Decoded struct { Key string
// Key string }
// } }
// }
// err := strictDecoder(input).Decode(&doc)
// err := NewDecoder(bytes.NewReader([]byte(input))).Strict(true).Decode(&doc) require.NoError(t, err)
// if err != nil {
// t.Fatal("unexpected error:", err)
// }
} }
type docUnmarshalTOML struct { type docUnmarshalTOML struct {
+50
View File
@@ -0,0 +1,50 @@
package tracker
import (
"github.com/pelletier/go-toml/v2/internal/ast"
)
// KeyTracker is a tracker that keeps track of the current Key as the AST is
// walked.
type KeyTracker struct {
k []string
}
// UpdateTable sets the state of the tracker with the AST table node.
func (t *KeyTracker) UpdateTable(node ast.Node) {
t.reset()
t.Push(node)
}
// UpdateArrayTable sets the state of the tracker with the AST array table node.
func (t *KeyTracker) UpdateArrayTable(node ast.Node) {
t.reset()
t.Push(node)
}
// Push the given key on the stack.
func (t *KeyTracker) Push(node ast.Node) {
it := node.Key()
for it.Next() {
t.k = append(t.k, string(it.Node().Data))
}
}
// Pop key from stack.
func (t *KeyTracker) Pop(node ast.Node) {
it := node.Key()
for it.Next() {
t.k = t.k[:len(t.k)-1]
}
}
// Key returns the current key
func (t *KeyTracker) Key() []string {
k := make([]string, len(t.k))
copy(k, t.k)
return k
}
func (t *KeyTracker) reset() {
t.k = t.k[:0]
}
+200
View File
@@ -0,0 +1,200 @@
package tracker
import (
"fmt"
"github.com/pelletier/go-toml/v2/internal/ast"
)
type keyKind uint8
const (
invalidKind keyKind = iota
valueKind
tableKind
arrayTableKind
)
func (k keyKind) String() string {
switch k {
case invalidKind:
return "invalid"
case valueKind:
return "value"
case tableKind:
return "table"
case arrayTableKind:
return "array table"
}
panic("missing keyKind string mapping")
}
// SeenTracker tracks which keys have been seen with which TOML type to flag duplicates
// and mismatches according to the spec.
type SeenTracker struct {
root *info
current *info
}
type info struct {
parent *info
kind keyKind
children map[string]*info
explicit bool
}
func (i *info) Clear() {
i.children = nil
}
func (i *info) Has(k string) (*info, bool) {
c, ok := i.children[k]
return c, ok
}
func (i *info) SetKind(kind keyKind) {
i.kind = kind
}
func (i *info) CreateTable(k string, explicit bool) *info {
return i.createChild(k, tableKind, explicit)
}
func (i *info) CreateArrayTable(k string, explicit bool) *info {
return i.createChild(k, arrayTableKind, explicit)
}
func (i *info) createChild(k string, kind keyKind, explicit bool) *info {
if i.children == nil {
i.children = make(map[string]*info, 1)
}
x := &info{
parent: i,
kind: kind,
explicit: explicit,
}
i.children[k] = x
return x
}
// CheckExpression takes a top-level node and checks that it does not contain keys
// that have been seen in previous calls, and validates that types are consistent.
func (s *SeenTracker) CheckExpression(node ast.Node) error {
if s.root == nil {
s.root = &info{
kind: tableKind,
}
s.current = s.root
}
switch node.Kind {
case ast.KeyValue:
return s.checkKeyValue(s.current, node)
case ast.Table:
return s.checkTable(node)
case ast.ArrayTable:
return s.checkArrayTable(node)
default:
panic(fmt.Errorf("this should not be a top level node type: %s", node.Kind))
}
}
func (s *SeenTracker) checkTable(node ast.Node) error {
s.current = s.root
it := node.Key()
// handle the first parts of the key, excluding the last one
for it.Next() {
if !it.Node().Next().Valid() {
break
}
k := string(it.Node().Data)
child, found := s.current.Has(k)
if !found {
child = s.current.CreateTable(k, false)
}
s.current = child
}
// handle the last part of the key
k := string(it.Node().Data)
i, found := s.current.Has(k)
if found {
if i.kind != tableKind {
return fmt.Errorf("key %s should be a table", k)
}
if i.explicit {
return fmt.Errorf("table %s already exists", k)
}
i.explicit = true
s.current = i
} else {
s.current = s.current.CreateTable(k, true)
}
return nil
}
func (s *SeenTracker) checkArrayTable(node ast.Node) error {
s.current = s.root
it := node.Key()
// handle the first parts of the key, excluding the last one
for it.Next() {
if !it.Node().Next().Valid() {
break
}
k := string(it.Node().Data)
child, found := s.current.Has(k)
if !found {
child = s.current.CreateTable(k, false)
}
s.current = child
}
// handle the last part of the key
k := string(it.Node().Data)
info, found := s.current.Has(k)
if found {
if info.kind != arrayTableKind {
return fmt.Errorf("key %s already exists but is not an array table", k)
}
info.Clear()
} else {
info = s.current.CreateArrayTable(k, true)
}
s.current = info
return nil
}
func (s *SeenTracker) checkKeyValue(context *info, node ast.Node) error {
it := node.Key()
// handle the first parts of the key, excluding the last one
for it.Next() {
k := string(it.Node().Data)
child, found := context.Has(k)
if found {
if child.kind != tableKind {
return fmt.Errorf("expected %s to be a table, not a %s", k, child.kind)
}
} else {
child = context.CreateTable(k, false)
}
context = child
}
if node.Value().Kind == ast.InlineTable {
context.SetKind(tableKind)
} else {
context.SetKind(valueKind)
}
return nil
}
-199
View File
@@ -1,200 +1 @@
package tracker package tracker
import (
"fmt"
"github.com/pelletier/go-toml/v2/internal/ast"
)
type keyKind uint8
const (
invalidKind keyKind = iota
valueKind
tableKind
arrayTableKind
)
func (k keyKind) String() string {
switch k {
case invalidKind:
return "invalid"
case valueKind:
return "value"
case tableKind:
return "table"
case arrayTableKind:
return "array table"
}
panic("missing keyKind string mapping")
}
// Tracks which keys have been seen with which TOML type to flag duplicates
// and mismatches according to the spec.
type Seen struct {
root *info
current *info
}
type info struct {
parent *info
kind keyKind
children map[string]*info
explicit bool
}
func (i *info) Clear() {
i.children = nil
}
func (i *info) Has(k string) (*info, bool) {
c, ok := i.children[k]
return c, ok
}
func (i *info) SetKind(kind keyKind) {
i.kind = kind
}
func (i *info) CreateTable(k string, explicit bool) *info {
return i.createChild(k, tableKind, explicit)
}
func (i *info) CreateArrayTable(k string, explicit bool) *info {
return i.createChild(k, arrayTableKind, explicit)
}
func (i *info) createChild(k string, kind keyKind, explicit bool) *info {
if i.children == nil {
i.children = make(map[string]*info, 1)
}
x := &info{
parent: i,
kind: kind,
explicit: explicit,
}
i.children[k] = x
return x
}
// CheckExpression takes a top-level node and checks that it does not contain keys
// that have been seen in previous calls, and validates that types are consistent.
func (s *Seen) CheckExpression(node ast.Node) error {
if s.root == nil {
s.root = &info{
kind: tableKind,
}
s.current = s.root
}
switch node.Kind {
case ast.KeyValue:
return s.checkKeyValue(s.current, node)
case ast.Table:
return s.checkTable(node)
case ast.ArrayTable:
return s.checkArrayTable(node)
default:
panic(fmt.Errorf("this should not be a top level node type: %s", node.Kind))
}
}
func (s *Seen) checkTable(node ast.Node) error {
s.current = s.root
it := node.Key()
// handle the first parts of the key, excluding the last one
for it.Next() {
if !it.Node().Next().Valid() {
break
}
k := string(it.Node().Data)
child, found := s.current.Has(k)
if !found {
child = s.current.CreateTable(k, false)
}
s.current = child
}
// handle the last part of the key
k := string(it.Node().Data)
i, found := s.current.Has(k)
if found {
if i.kind != tableKind {
return fmt.Errorf("key %s should be a table", k)
}
if i.explicit {
return fmt.Errorf("table %s already exists", k)
}
i.explicit = true
s.current = i
} else {
s.current = s.current.CreateTable(k, true)
}
return nil
}
func (s *Seen) checkArrayTable(node ast.Node) error {
s.current = s.root
it := node.Key()
// handle the first parts of the key, excluding the last one
for it.Next() {
if !it.Node().Next().Valid() {
break
}
k := string(it.Node().Data)
child, found := s.current.Has(k)
if !found {
child = s.current.CreateTable(k, false)
}
s.current = child
}
// handle the last part of the key
k := string(it.Node().Data)
info, found := s.current.Has(k)
if found {
if info.kind != arrayTableKind {
return fmt.Errorf("key %s already exists but is not an array table", k)
}
info.Clear()
} else {
info = s.current.CreateArrayTable(k, true)
}
s.current = info
return nil
}
func (s *Seen) checkKeyValue(context *info, node ast.Node) error {
it := node.Key()
// handle the first parts of the key, excluding the last one
for it.Next() {
k := string(it.Node().Data)
child, found := context.Has(k)
if found {
if child.kind != tableKind {
return fmt.Errorf("expected %s to be a table, not a %s", k, child.kind)
}
} else {
child = context.CreateTable(k, false)
}
context = child
}
if node.Value().Kind == ast.InlineTable {
context.SetKind(tableKind)
} else {
context.SetKind(valueKind)
}
return nil
}
+24
View File
@@ -33,3 +33,27 @@ func SubsliceOffset(data []byte, subslice []byte) int {
return intoffset return intoffset
} }
func BytesRange(start []byte, end []byte) []byte {
if start == nil || end == nil {
panic("cannot call BytesRange with nil")
}
startp := (*reflect.SliceHeader)(unsafe.Pointer(&start))
endp := (*reflect.SliceHeader)(unsafe.Pointer(&end))
if startp.Data > endp.Data {
panic(fmt.Errorf("start pointer address (%d) is after end pointer address (%d)", startp.Data, endp.Data))
}
l := startp.Len
endLen := int(endp.Data-startp.Data) + endp.Len
if endLen > l {
l = endLen
}
if l > startp.Cap {
panic(fmt.Errorf("range length is larger than capacity"))
}
return start[:l]
}
+89
View File
@@ -77,3 +77,92 @@ func TestUnsafeSubsliceOffsetInvalid(t *testing.T) {
}) })
} }
} }
func TestUnsafeBytesRange(t *testing.T) {
type fn = func() ([]byte, []byte)
examples := []struct {
desc string
test fn
expected []byte
}{
{
desc: "simple",
test: func() ([]byte, []byte) {
full := []byte("hello world")
return full[1:3], full[6:8]
},
expected: []byte("ello wo"),
},
{
desc: "full",
test: func() ([]byte, []byte) {
full := []byte("hello world")
return full[0:1], full[len(full)-1:]
},
expected: []byte("hello world"),
},
{
desc: "end before start",
test: func() ([]byte, []byte) {
full := []byte("hello world")
return full[len(full)-1:], full[0:1]
},
},
{
desc: "nils",
test: func() ([]byte, []byte) {
return nil, nil
},
},
{
desc: "nils start",
test: func() ([]byte, []byte) {
return nil, []byte("foo")
},
},
{
desc: "nils end",
test: func() ([]byte, []byte) {
return []byte("foo"), nil
},
},
{
desc: "start is end",
test: func() ([]byte, []byte) {
full := []byte("hello world")
return full[1:3], full[1:3]
},
expected: []byte("el"),
},
{
desc: "end contained in start",
test: func() ([]byte, []byte) {
full := []byte("hello world")
return full[1:7], full[2:4]
},
expected: []byte("ello w"),
},
{
desc: "different backing arrays",
test: func() ([]byte, []byte) {
one := []byte("hello world")
two := []byte("hello world")
return one, two
},
},
}
for _, e := range examples {
t.Run(e.desc, func(t *testing.T) {
start, end := e.test()
if e.expected == nil {
require.Panics(t, func() {
unsafe.BytesRange(start, end)
})
} else {
res := unsafe.BytesRange(start, end)
require.Equal(t, e.expected, res)
}
})
}
}
+79
View File
@@ -0,0 +1,79 @@
package toml
import (
"github.com/pelletier/go-toml/v2/internal/ast"
"github.com/pelletier/go-toml/v2/internal/tracker"
)
type strict struct {
Enabled bool
// Tracks the current key being processed.
key tracker.KeyTracker
missing []decodeError
}
func (s *strict) EnterTable(node ast.Node) {
if !s.Enabled {
return
}
s.key.UpdateTable(node)
}
func (s *strict) EnterArrayTable(node ast.Node) {
if !s.Enabled {
return
}
s.key.UpdateArrayTable(node)
}
func (s *strict) EnterKeyValue(node ast.Node) {
if !s.Enabled {
return
}
s.key.Push(node)
}
func (s *strict) ExitKeyValue(node ast.Node) {
if !s.Enabled {
return
}
s.key.Pop(node)
}
func (s *strict) MissingTable(node ast.Node) {
if !s.Enabled {
return
}
s.missing = append(s.missing, decodeError{
highlight: keyLocation(node),
message: "missing table",
key: s.key.Key(),
})
}
func (s *strict) MissingField(node ast.Node) {
if !s.Enabled {
return
}
s.missing = append(s.missing, decodeError{
highlight: keyLocation(node),
message: "missing field",
key: s.key.Key(),
})
}
func (s *strict) Error(doc []byte) error {
if !s.Enabled || len(s.missing) == 0 {
return nil
}
err := &StrictMissingError{
Errors: make([]DecodeError, 0, len(s.missing)),
}
for _, derr := range s.missing {
err.Errors = append(err.Errors, *wrapDecodeError(doc, &derr))
}
return err
}
+49 -2
View File
@@ -10,6 +10,7 @@ import (
"github.com/pelletier/go-toml/v2/internal/ast" "github.com/pelletier/go-toml/v2/internal/ast"
"github.com/pelletier/go-toml/v2/internal/tracker" "github.com/pelletier/go-toml/v2/internal/tracker"
"github.com/pelletier/go-toml/v2/internal/unsafe"
) )
func Unmarshal(data []byte, v interface{}) error { func Unmarshal(data []byte, v interface{}) error {
@@ -21,7 +22,11 @@ func Unmarshal(data []byte, v interface{}) error {
// Decoder reads and decode a TOML document from an input stream. // Decoder reads and decode a TOML document from an input stream.
type Decoder struct { type Decoder struct {
// input
r io.Reader r io.Reader
// global settings
strict bool
} }
// NewDecoder creates a new Decoder that will read from r. // NewDecoder creates a new Decoder that will read from r.
@@ -29,6 +34,16 @@ func NewDecoder(r io.Reader) *Decoder {
return &Decoder{r: r} return &Decoder{r: r}
} }
// SetStrict toggles decoding in stict mode.
//
// When the decoder is in strict mode, it will record fields from the document
// that could not be set on the target value. In that case, the decoder returns
// a StrictMissingError that can be used to retrieve the individual errors as
// well as generate a human readable description of the missing fields.
func (d *Decoder) SetStrict(strict bool) {
d.strict = strict
}
// Decode the whole content of r into v. // Decode the whole content of r into v.
// //
// When a TOML local date is decoded into a time.Time, its value is represented // When a TOML local date is decoded into a time.Time, its value is represented
@@ -43,7 +58,11 @@ func (d *Decoder) Decode(v interface{}) error {
} }
p := parser{} p := parser{}
p.Reset(b) p.Reset(b)
dec := decoder{} dec := decoder{
strict: strict{
Enabled: d.strict,
},
}
return dec.FromParser(&p, v) return dec.FromParser(&p, v)
} }
@@ -52,7 +71,10 @@ type decoder struct {
arrayIndexes map[reflect.Value]int arrayIndexes map[reflect.Value]int
// Tracks keys that have been seen, with which type. // Tracks keys that have been seen, with which type.
seen tracker.Seen seen tracker.SeenTracker
// Strict mode
strict strict
} }
func (d *decoder) arrayIndex(append bool, v reflect.Value) int { func (d *decoder) arrayIndex(append bool, v reflect.Value) int {
@@ -79,9 +101,27 @@ func (d *decoder) FromParser(p *parser, v interface{}) error {
err = wrapDecodeError(p.data, de) err = wrapDecodeError(p.data, de)
} }
} }
if err == nil {
err = d.strict.Error(p.data)
}
return err return err
} }
func keyLocation(node ast.Node) []byte {
k := node.Key()
hasOne := k.Next()
if !hasOne {
panic("should not be called with empty key")
}
start := k.Node().Data
end := k.Node().Data
for k.Next() {
end = k.Node().Data
}
return unsafe.BytesRange(start, end)
}
func (d *decoder) fromParser(p *parser, v interface{}) error { func (d *decoder) fromParser(p *parser, v interface{}) error {
r := reflect.ValueOf(v) r := reflect.ValueOf(v)
if r.Kind() != reflect.Ptr { if r.Kind() != reflect.Ptr {
@@ -113,6 +153,7 @@ func (d *decoder) fromParser(p *parser, v interface{}) error {
err = d.unmarshalKeyValue(current, node) err = d.unmarshalKeyValue(current, node)
found = true found = true
case ast.Table: case ast.Table:
d.strict.EnterTable(node)
current, found, err = d.scopeWithKey(root, node.Key()) current, found, err = d.scopeWithKey(root, node.Key())
if err == nil && found { if err == nil && found {
// In case this table points to an interface, // In case this table points to an interface,
@@ -123,6 +164,7 @@ func (d *decoder) fromParser(p *parser, v interface{}) error {
ensureMapIfInterface(current) ensureMapIfInterface(current)
} }
case ast.ArrayTable: case ast.ArrayTable:
d.strict.EnterArrayTable(node)
current, found, err = d.scopeWithArrayTable(root, node.Key()) current, found, err = d.scopeWithArrayTable(root, node.Key())
default: default:
panic(fmt.Errorf("this should not be a top level node type: %s", node.Kind)) panic(fmt.Errorf("this should not be a top level node type: %s", node.Kind))
@@ -134,6 +176,7 @@ func (d *decoder) fromParser(p *parser, v interface{}) error {
if !found { if !found {
skipUntilTable = true skipUntilTable = true
d.strict.MissingTable(node)
} }
} }
@@ -217,6 +260,9 @@ func (d *decoder) scopeWithArrayTable(x target, key ast.Iterator) (target, bool,
func (d *decoder) unmarshalKeyValue(x target, node ast.Node) error { func (d *decoder) unmarshalKeyValue(x target, node ast.Node) error {
assertNode(ast.KeyValue, node) assertNode(ast.KeyValue, node)
d.strict.EnterKeyValue(node)
defer d.strict.ExitKeyValue(node)
x, found, err := d.scopeWithKey(x, node.Key()) x, found, err := d.scopeWithKey(x, node.Key())
if err != nil { if err != nil {
return err return err
@@ -224,6 +270,7 @@ func (d *decoder) unmarshalKeyValue(x target, node ast.Node) error {
// A struct in the path was not found. Skip this value. // A struct in the path was not found. Skip this value.
if !found { if !found {
d.strict.MissingField(node)
return nil return nil
} }
+114
View File
@@ -1,8 +1,10 @@
package toml_test package toml_test
import ( import (
"fmt"
"math" "math"
"strconv" "strconv"
"strings"
"testing" "testing"
"time" "time"
@@ -989,3 +991,115 @@ func TestIssue508(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, "This is a title", t1.head.Title) require.Equal(t, "This is a title", t1.head.Title)
} }
func TestDecoderStrict(t *testing.T) {
examples := []struct {
desc string
input string
expected string
target interface{}
}{
{
desc: "multiple missing root keys",
input: `
key1 = "value1"
key2 = "missing2"
key3 = "missing3"
key4 = "value4"
`,
expected: `
2| key1 = "value1"
3| key2 = "missing2"
| ~~~~ missing field
4| key3 = "missing3"
5| key4 = "value4"
---
2| key1 = "value1"
3| key2 = "missing2"
4| key3 = "missing3"
| ~~~~ missing field
5| key4 = "value4"
`,
target: &struct {
Key1 string
Key4 string
}{},
},
{
desc: "multi-part key",
input: `a.short.key="foo"`,
expected: `
1| a.short.key="foo"
| ~~~~~~~~~~~ missing field
`,
},
{
desc: "missing table",
input: `
[foo]
bar = 42
`,
expected: `
2| [foo]
| ~~~ missing table
3| bar = 42
`,
},
{
desc: "missing array table",
input: `
[[foo]]
bar = 42
`,
expected: `
2| [[foo]]
| ~~~ missing table
3| bar = 42
`,
},
}
for _, e := range examples {
t.Run(e.desc, func(t *testing.T) {
r := strings.NewReader(e.input)
d := toml.NewDecoder(r)
d.SetStrict(true)
x := e.target
if x == nil {
x = &struct{}{}
}
err := d.Decode(x)
details := err.(*toml.StrictMissingError)
equalStringsIgnoreNewlines(t, e.expected, details.String())
})
}
}
func ExampleDecoder_SetStrict() {
type S struct {
Key1 string
Key3 string
}
doc := `
key1 = "value1"
key2 = "value2"
key3 = "value3"
`
r := strings.NewReader(doc)
d := toml.NewDecoder(r)
d.SetStrict(true)
s := S{}
err := d.Decode(&s)
fmt.Println(err.Error())
// Output: strict mode: fields in the document are missing in the target struct
details := err.(*toml.StrictMissingError)
fmt.Println(details.String())
// Ouput:
// 2| key1 = "value1"
// 3| key2 = "value2"
// | ~~~~ missing field
// 4| key3 = "value3"
}