From d3c92c5999876ef37f03850546c93238a7841d5f Mon Sep 17 00:00:00 2001 From: Oncilla Date: Sat, 25 Apr 2020 13:58:55 +0200 Subject: [PATCH] unmarshal: add strict mode (#372) This PR adds a strict mode to the Decoder. It can be enabled with the `Strict` method. In the strict mode, the decoder fails if any fields that were part of the input do not have a corresponding field in the struct. Fixes #277 --- marshal.go | 94 ++++++++++++++++++++++++++++++++++++++++++++++++- marshal_test.go | 60 +++++++++++++++++++++++++++++++ 2 files changed, 153 insertions(+), 1 deletion(-) diff --git a/marshal.go b/marshal.go index 0832630..1045d3b 100644 --- a/marshal.go +++ b/marshal.go @@ -543,6 +543,8 @@ type Decoder struct { tval *Tree encOpts tagName string + strict bool + visitor visitorState } // NewDecoder returns a new decoder that reads from r. @@ -573,6 +575,13 @@ func (d *Decoder) SetTagName(v string) *Decoder { return d } +// Strict allows changing to strict decoding. Any fields that are found in the +// input data and do not have a corresponding struct member cause an error. +func (d *Decoder) Strict(strict bool) *Decoder { + d.strict = strict + return d +} + func (d *Decoder) unmarshal(v interface{}) error { mtype := reflect.TypeOf(v) if mtype == nil { @@ -596,10 +605,17 @@ func (d *Decoder) unmarshal(v interface{}) error { vv := reflect.ValueOf(v).Elem() + if d.strict { + d.visitor = newVisitorState(d.tval) + } + sval, err := d.valueFromTree(elem, d.tval, &vv) if err != nil { return err } + if err := d.visitor.validate(); err != nil { + return err + } reflect.ValueOf(v).Elem().Set(sval) return nil } @@ -645,6 +661,8 @@ func (d *Decoder) valueFromTree(mtype reflect.Type, tval *Tree, mval1 *reflect.V if !exists { continue } + + d.visitor.push(key) val := tval.Get(key) fval := mval.Field(i) mvalf, err := d.valueFromToml(mtypef.Type, val, &fval) @@ -653,6 +671,7 @@ func (d *Decoder) valueFromTree(mtype reflect.Type, tval *Tree, mval1 *reflect.V } mval.Field(i).Set(mvalf) found = true + d.visitor.pop() break } } @@ -685,7 +704,7 @@ func (d *Decoder) valueFromTree(mtype reflect.Type, tval *Tree, mval1 *reflect.V return mval.Field(i), err } default: - return mval.Field(i), fmt.Errorf("unsuported field type for default option") + return mval.Field(i), fmt.Errorf("unsupported field type for default option") } mval.Field(i).Set(reflect.ValueOf(val)) } @@ -707,6 +726,7 @@ func (d *Decoder) valueFromTree(mtype reflect.Type, tval *Tree, mval1 *reflect.V case reflect.Map: mval = reflect.MakeMap(mtype) for _, key := range tval.Keys() { + d.visitor.push(key) // TODO: path splits key val := tval.GetPath([]string{key}) mvalf, err := d.valueFromToml(mtype.Elem(), val, nil) @@ -714,6 +734,7 @@ func (d *Decoder) valueFromTree(mtype reflect.Type, tval *Tree, mval1 *reflect.V return mval, formatError(err, tval.GetPosition(key)) } mval.SetMapIndex(reflect.ValueOf(key).Convert(mtype.Key()), mvalf) + d.visitor.pop() } } return mval, nil @@ -723,11 +744,13 @@ func (d *Decoder) valueFromTree(mtype reflect.Type, tval *Tree, mval1 *reflect.V func (d *Decoder) valueFromTreeSlice(mtype reflect.Type, tval []*Tree) (reflect.Value, error) { mval := reflect.MakeSlice(mtype, len(tval), len(tval)) for i := 0; i < len(tval); i++ { + d.visitor.push(strconv.Itoa(i)) val, err := d.valueFromTree(mtype.Elem(), tval[i], nil) if err != nil { return mval, err } mval.Index(i).Set(val) + d.visitor.pop() } return mval, nil } @@ -802,6 +825,7 @@ func (d *Decoder) valueFromToml(mtype reflect.Type, tval interface{}, mval1 *ref } return reflect.ValueOf(nil), fmt.Errorf("Can't convert %v(%T) to trees", tval, tval) case []interface{}: + d.visitor.visit() if isOtherSequence(mtype) { return d.valueFromOtherSlice(mtype, t) } @@ -815,6 +839,7 @@ func (d *Decoder) valueFromToml(mtype reflect.Type, tval interface{}, mval1 *ref } return reflect.ValueOf(nil), fmt.Errorf("Can't convert %v(%T) to a slice", tval, tval) default: + d.visitor.visit() switch mtype.Kind() { case reflect.Bool, reflect.Struct: val := reflect.ValueOf(tval) @@ -991,3 +1016,70 @@ func formatError(err error, pos Position) error { } return fmt.Errorf("%s: %s", pos, err) } + +// visitorState keeps track of which keys were unmarshaled. +type visitorState struct { + tree *Tree + path []string + keys map[string]struct{} + active bool +} + +func newVisitorState(tree *Tree) visitorState { + path, result := []string{}, map[string]struct{}{} + insertKeys(path, result, tree) + return visitorState{ + tree: tree, + path: path[:0], + keys: result, + active: true, + } +} + +func (s *visitorState) push(key string) { + if s.active { + s.path = append(s.path, key) + } +} + +func (s *visitorState) pop() { + if s.active { + s.path = s.path[:len(s.path)-1] + } +} + +func (s *visitorState) visit() { + if s.active { + delete(s.keys, strings.Join(s.path, ".")) + } +} + +func (s *visitorState) validate() error { + if !s.active { + return nil + } + undecoded := make([]string, 0, len(s.keys)) + for key := range s.keys { + undecoded = append(undecoded, key) + } + sort.Strings(undecoded) + if len(undecoded) > 0 { + return fmt.Errorf("undecoded keys: %q", undecoded) + } + return nil +} + +func insertKeys(path []string, m map[string]struct{}, tree *Tree) { + for k, v := range tree.values { + switch node := v.(type) { + case []*Tree: + for i, item := range node { + insertKeys(append(path, k, strconv.Itoa(i)), m, item) + } + case *Tree: + insertKeys(append(path, k), m, node) + case *tomlValue: + m[strings.Join(append(path, k), ".")] = struct{}{} + } + } +} diff --git a/marshal_test.go b/marshal_test.go index 04c90ef..a8a5695 100644 --- a/marshal_test.go +++ b/marshal_test.go @@ -3052,3 +3052,63 @@ func TestUnmarshalSliceFail2(t *testing.T) { } } + +func TestDecoderStrict(t *testing.T) { + input := ` +[decoded] + key = "" + +[undecoded] + key = "" + + [undecoded.inner] + key = "" + + [[undecoded.array]] + key = "" + + [[undecoded.array]] + key = "" + +` + var doc struct { + Decoded struct { + Key string + } + } + + expected := `undecoded keys: ["undecoded.array.0.key" "undecoded.array.1.key" "undecoded.inner.key" "undecoded.key"]` + + err := NewDecoder(bytes.NewReader([]byte(input))).Strict(true).Decode(&doc) + if err == nil { + t.Error("expected error, got none") + } else if err.Error() != expected { + t.Errorf("expect err: %s, got: %s", expected, err.Error()) + } + + if err := NewDecoder(bytes.NewReader([]byte(input))).Decode(&doc); err != nil { + t.Errorf("unexpected err: %s", err) + } + + var m map[string]interface{} + if err := NewDecoder(bytes.NewReader([]byte(input))).Decode(&m); err != nil { + t.Errorf("unexpected err: %s", err) + } +} + +func TestDecoderStrictValid(t *testing.T) { + input := ` +[decoded] + key = "" +` + var doc struct { + Decoded struct { + Key string + } + } + + err := NewDecoder(bytes.NewReader([]byte(input))).Strict(true).Decode(&doc) + if err != nil { + t.Fatal("unexpected error:", err) + } +}