decoder: strict mode (#512)
This commit is contained in:
+49
-2
@@ -10,6 +10,7 @@ import (
|
||||
|
||||
"github.com/pelletier/go-toml/v2/internal/ast"
|
||||
"github.com/pelletier/go-toml/v2/internal/tracker"
|
||||
"github.com/pelletier/go-toml/v2/internal/unsafe"
|
||||
)
|
||||
|
||||
func Unmarshal(data []byte, v interface{}) error {
|
||||
@@ -21,7 +22,11 @@ func Unmarshal(data []byte, v interface{}) error {
|
||||
|
||||
// Decoder reads and decode a TOML document from an input stream.
|
||||
type Decoder struct {
|
||||
// input
|
||||
r io.Reader
|
||||
|
||||
// global settings
|
||||
strict bool
|
||||
}
|
||||
|
||||
// NewDecoder creates a new Decoder that will read from r.
|
||||
@@ -29,6 +34,16 @@ func NewDecoder(r io.Reader) *Decoder {
|
||||
return &Decoder{r: r}
|
||||
}
|
||||
|
||||
// SetStrict toggles decoding in stict mode.
|
||||
//
|
||||
// When the decoder is in strict mode, it will record fields from the document
|
||||
// that could not be set on the target value. In that case, the decoder returns
|
||||
// a StrictMissingError that can be used to retrieve the individual errors as
|
||||
// well as generate a human readable description of the missing fields.
|
||||
func (d *Decoder) SetStrict(strict bool) {
|
||||
d.strict = strict
|
||||
}
|
||||
|
||||
// Decode the whole content of r into v.
|
||||
//
|
||||
// When a TOML local date is decoded into a time.Time, its value is represented
|
||||
@@ -43,7 +58,11 @@ func (d *Decoder) Decode(v interface{}) error {
|
||||
}
|
||||
p := parser{}
|
||||
p.Reset(b)
|
||||
dec := decoder{}
|
||||
dec := decoder{
|
||||
strict: strict{
|
||||
Enabled: d.strict,
|
||||
},
|
||||
}
|
||||
return dec.FromParser(&p, v)
|
||||
}
|
||||
|
||||
@@ -52,7 +71,10 @@ type decoder struct {
|
||||
arrayIndexes map[reflect.Value]int
|
||||
|
||||
// Tracks keys that have been seen, with which type.
|
||||
seen tracker.Seen
|
||||
seen tracker.SeenTracker
|
||||
|
||||
// Strict mode
|
||||
strict strict
|
||||
}
|
||||
|
||||
func (d *decoder) arrayIndex(append bool, v reflect.Value) int {
|
||||
@@ -79,9 +101,27 @@ func (d *decoder) FromParser(p *parser, v interface{}) error {
|
||||
err = wrapDecodeError(p.data, de)
|
||||
}
|
||||
}
|
||||
if err == nil {
|
||||
err = d.strict.Error(p.data)
|
||||
}
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
func keyLocation(node ast.Node) []byte {
|
||||
k := node.Key()
|
||||
hasOne := k.Next()
|
||||
if !hasOne {
|
||||
panic("should not be called with empty key")
|
||||
}
|
||||
start := k.Node().Data
|
||||
end := k.Node().Data
|
||||
for k.Next() {
|
||||
end = k.Node().Data
|
||||
}
|
||||
return unsafe.BytesRange(start, end)
|
||||
}
|
||||
|
||||
func (d *decoder) fromParser(p *parser, v interface{}) error {
|
||||
r := reflect.ValueOf(v)
|
||||
if r.Kind() != reflect.Ptr {
|
||||
@@ -113,6 +153,7 @@ func (d *decoder) fromParser(p *parser, v interface{}) error {
|
||||
err = d.unmarshalKeyValue(current, node)
|
||||
found = true
|
||||
case ast.Table:
|
||||
d.strict.EnterTable(node)
|
||||
current, found, err = d.scopeWithKey(root, node.Key())
|
||||
if err == nil && found {
|
||||
// In case this table points to an interface,
|
||||
@@ -123,6 +164,7 @@ func (d *decoder) fromParser(p *parser, v interface{}) error {
|
||||
ensureMapIfInterface(current)
|
||||
}
|
||||
case ast.ArrayTable:
|
||||
d.strict.EnterArrayTable(node)
|
||||
current, found, err = d.scopeWithArrayTable(root, node.Key())
|
||||
default:
|
||||
panic(fmt.Errorf("this should not be a top level node type: %s", node.Kind))
|
||||
@@ -134,6 +176,7 @@ func (d *decoder) fromParser(p *parser, v interface{}) error {
|
||||
|
||||
if !found {
|
||||
skipUntilTable = true
|
||||
d.strict.MissingTable(node)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -217,6 +260,9 @@ func (d *decoder) scopeWithArrayTable(x target, key ast.Iterator) (target, bool,
|
||||
func (d *decoder) unmarshalKeyValue(x target, node ast.Node) error {
|
||||
assertNode(ast.KeyValue, node)
|
||||
|
||||
d.strict.EnterKeyValue(node)
|
||||
defer d.strict.ExitKeyValue(node)
|
||||
|
||||
x, found, err := d.scopeWithKey(x, node.Key())
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -224,6 +270,7 @@ func (d *decoder) unmarshalKeyValue(x target, node ast.Node) error {
|
||||
|
||||
// A struct in the path was not found. Skip this value.
|
||||
if !found {
|
||||
d.strict.MissingField(node)
|
||||
return nil
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user