unmarshal: add strict mode (#372)

This PR adds a strict mode to the Decoder. It can be enabled with the
`Strict` method.

In the strict mode, the decoder fails if any fields that were part
of the input do not have a corresponding field in the struct.

Fixes #277
This commit is contained in:
Oncilla
2020-04-25 13:58:55 +02:00
committed by Thomas Pelletier
parent d1e0fc37ce
commit d3c92c5999
2 changed files with 153 additions and 1 deletions
+93 -1
View File
@@ -543,6 +543,8 @@ type Decoder struct {
tval *Tree
encOpts
tagName string
strict bool
visitor visitorState
}
// NewDecoder returns a new decoder that reads from r.
@@ -573,6 +575,13 @@ func (d *Decoder) SetTagName(v string) *Decoder {
return d
}
// Strict allows changing to strict decoding. Any fields that are found in the
// input data and do not have a corresponding struct member cause an error.
func (d *Decoder) Strict(strict bool) *Decoder {
d.strict = strict
return d
}
func (d *Decoder) unmarshal(v interface{}) error {
mtype := reflect.TypeOf(v)
if mtype == nil {
@@ -596,10 +605,17 @@ func (d *Decoder) unmarshal(v interface{}) error {
vv := reflect.ValueOf(v).Elem()
if d.strict {
d.visitor = newVisitorState(d.tval)
}
sval, err := d.valueFromTree(elem, d.tval, &vv)
if err != nil {
return err
}
if err := d.visitor.validate(); err != nil {
return err
}
reflect.ValueOf(v).Elem().Set(sval)
return nil
}
@@ -645,6 +661,8 @@ func (d *Decoder) valueFromTree(mtype reflect.Type, tval *Tree, mval1 *reflect.V
if !exists {
continue
}
d.visitor.push(key)
val := tval.Get(key)
fval := mval.Field(i)
mvalf, err := d.valueFromToml(mtypef.Type, val, &fval)
@@ -653,6 +671,7 @@ func (d *Decoder) valueFromTree(mtype reflect.Type, tval *Tree, mval1 *reflect.V
}
mval.Field(i).Set(mvalf)
found = true
d.visitor.pop()
break
}
}
@@ -685,7 +704,7 @@ func (d *Decoder) valueFromTree(mtype reflect.Type, tval *Tree, mval1 *reflect.V
return mval.Field(i), err
}
default:
return mval.Field(i), fmt.Errorf("unsuported field type for default option")
return mval.Field(i), fmt.Errorf("unsupported field type for default option")
}
mval.Field(i).Set(reflect.ValueOf(val))
}
@@ -707,6 +726,7 @@ func (d *Decoder) valueFromTree(mtype reflect.Type, tval *Tree, mval1 *reflect.V
case reflect.Map:
mval = reflect.MakeMap(mtype)
for _, key := range tval.Keys() {
d.visitor.push(key)
// TODO: path splits key
val := tval.GetPath([]string{key})
mvalf, err := d.valueFromToml(mtype.Elem(), val, nil)
@@ -714,6 +734,7 @@ func (d *Decoder) valueFromTree(mtype reflect.Type, tval *Tree, mval1 *reflect.V
return mval, formatError(err, tval.GetPosition(key))
}
mval.SetMapIndex(reflect.ValueOf(key).Convert(mtype.Key()), mvalf)
d.visitor.pop()
}
}
return mval, nil
@@ -723,11 +744,13 @@ func (d *Decoder) valueFromTree(mtype reflect.Type, tval *Tree, mval1 *reflect.V
func (d *Decoder) valueFromTreeSlice(mtype reflect.Type, tval []*Tree) (reflect.Value, error) {
mval := reflect.MakeSlice(mtype, len(tval), len(tval))
for i := 0; i < len(tval); i++ {
d.visitor.push(strconv.Itoa(i))
val, err := d.valueFromTree(mtype.Elem(), tval[i], nil)
if err != nil {
return mval, err
}
mval.Index(i).Set(val)
d.visitor.pop()
}
return mval, nil
}
@@ -802,6 +825,7 @@ func (d *Decoder) valueFromToml(mtype reflect.Type, tval interface{}, mval1 *ref
}
return reflect.ValueOf(nil), fmt.Errorf("Can't convert %v(%T) to trees", tval, tval)
case []interface{}:
d.visitor.visit()
if isOtherSequence(mtype) {
return d.valueFromOtherSlice(mtype, t)
}
@@ -815,6 +839,7 @@ func (d *Decoder) valueFromToml(mtype reflect.Type, tval interface{}, mval1 *ref
}
return reflect.ValueOf(nil), fmt.Errorf("Can't convert %v(%T) to a slice", tval, tval)
default:
d.visitor.visit()
switch mtype.Kind() {
case reflect.Bool, reflect.Struct:
val := reflect.ValueOf(tval)
@@ -991,3 +1016,70 @@ func formatError(err error, pos Position) error {
}
return fmt.Errorf("%s: %s", pos, err)
}
// visitorState keeps track of which keys were unmarshaled.
type visitorState struct {
tree *Tree
path []string
keys map[string]struct{}
active bool
}
func newVisitorState(tree *Tree) visitorState {
path, result := []string{}, map[string]struct{}{}
insertKeys(path, result, tree)
return visitorState{
tree: tree,
path: path[:0],
keys: result,
active: true,
}
}
func (s *visitorState) push(key string) {
if s.active {
s.path = append(s.path, key)
}
}
func (s *visitorState) pop() {
if s.active {
s.path = s.path[:len(s.path)-1]
}
}
func (s *visitorState) visit() {
if s.active {
delete(s.keys, strings.Join(s.path, "."))
}
}
func (s *visitorState) validate() error {
if !s.active {
return nil
}
undecoded := make([]string, 0, len(s.keys))
for key := range s.keys {
undecoded = append(undecoded, key)
}
sort.Strings(undecoded)
if len(undecoded) > 0 {
return fmt.Errorf("undecoded keys: %q", undecoded)
}
return nil
}
func insertKeys(path []string, m map[string]struct{}, tree *Tree) {
for k, v := range tree.values {
switch node := v.(type) {
case []*Tree:
for i, item := range node {
insertKeys(append(path, k, strconv.Itoa(i)), m, item)
}
case *Tree:
insertKeys(append(path, k), m, node)
case *tomlValue:
m[strings.Join(append(path, k), ".")] = struct{}{}
}
}
}