From 84da2c4a25c585816f2c4211b699228d111d18ab Mon Sep 17 00:00:00 2001 From: Roberth Kulbin Date: Thu, 25 Jul 2019 08:06:17 +0100 Subject: [PATCH] Merge struct fields in Unmarshal (#284) * add test for unexported field preservation * merge struct values instead of replacing them * use struct merging on nested value structs * unmarshalling merges nested struct pointers when non-nil --- marshal.go | 54 +++++++++++++++++-------- marshal_test.go | 105 ++++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 143 insertions(+), 16 deletions(-) diff --git a/marshal.go b/marshal.go index a0a5ca1..6984dd8 100644 --- a/marshal.go +++ b/marshal.go @@ -526,7 +526,9 @@ func (d *Decoder) unmarshal(v interface{}) error { return errors.New("only a pointer to struct or map can be unmarshaled from TOML") } - sval, err := d.valueFromTree(elem, d.tval) + vv := reflect.ValueOf(v).Elem() + + sval, err := d.valueFromTree(elem, d.tval, &vv) if err != nil { return err } @@ -534,15 +536,21 @@ func (d *Decoder) unmarshal(v interface{}) error { return nil } -// Convert toml tree to marshal struct or map, using marshal type -func (d *Decoder) valueFromTree(mtype reflect.Type, tval *Tree) (reflect.Value, error) { +// Convert toml tree to marshal struct or map, using marshal type. When mval1 +// is non-nil, merge fields into the given value instead of allocating a new one. +func (d *Decoder) valueFromTree(mtype reflect.Type, tval *Tree, mval1 *reflect.Value) (reflect.Value, error) { if mtype.Kind() == reflect.Ptr { - return d.unwrapPointer(mtype, tval) + return d.unwrapPointer(mtype, tval, mval1) } var mval reflect.Value switch mtype.Kind() { case reflect.Struct: - mval = reflect.New(mtype).Elem() + if mval1 != nil { + mval = *mval1 + } else { + mval = reflect.New(mtype).Elem() + } + for i := 0; i < mtype.NumField(); i++ { mtypef := mtype.Field(i) an := annotation{tag: d.tagName} @@ -563,7 +571,8 @@ func (d *Decoder) valueFromTree(mtype reflect.Type, tval *Tree) (reflect.Value, continue } val := tval.Get(key) - mvalf, err := d.valueFromToml(mtypef.Type, val) + fval := mval.Field(i) + mvalf, err := d.valueFromToml(mtypef.Type, val, &fval) if err != nil { return mval, formatError(err, tval.GetPosition(key)) } @@ -607,7 +616,7 @@ func (d *Decoder) valueFromTree(mtype reflect.Type, tval *Tree) (reflect.Value, // save the old behavior above and try to check anonymous structs if !found && opts.defaultValue == "" && mtypef.Anonymous && mtypef.Type.Kind() == reflect.Struct { - v, err := d.valueFromTree(mtypef.Type, tval) + v, err := d.valueFromTree(mtypef.Type, tval, nil) if err != nil { return v, err } @@ -620,7 +629,7 @@ func (d *Decoder) valueFromTree(mtype reflect.Type, tval *Tree) (reflect.Value, for _, key := range tval.Keys() { // TODO: path splits key val := tval.GetPath([]string{key}) - mvalf, err := d.valueFromToml(mtype.Elem(), val) + mvalf, err := d.valueFromToml(mtype.Elem(), val, nil) if err != nil { return mval, formatError(err, tval.GetPosition(key)) } @@ -634,7 +643,7 @@ func (d *Decoder) valueFromTree(mtype reflect.Type, tval *Tree) (reflect.Value, func (d *Decoder) valueFromTreeSlice(mtype reflect.Type, tval []*Tree) (reflect.Value, error) { mval := reflect.MakeSlice(mtype, len(tval), len(tval)) for i := 0; i < len(tval); i++ { - val, err := d.valueFromTree(mtype.Elem(), tval[i]) + val, err := d.valueFromTree(mtype.Elem(), tval[i], nil) if err != nil { return mval, err } @@ -647,7 +656,7 @@ func (d *Decoder) valueFromTreeSlice(mtype reflect.Type, tval []*Tree) (reflect. func (d *Decoder) valueFromOtherSlice(mtype reflect.Type, tval []interface{}) (reflect.Value, error) { mval := reflect.MakeSlice(mtype, len(tval), len(tval)) for i := 0; i < len(tval); i++ { - val, err := d.valueFromToml(mtype.Elem(), tval[i]) + val, err := d.valueFromToml(mtype.Elem(), tval[i], nil) if err != nil { return mval, err } @@ -656,16 +665,22 @@ func (d *Decoder) valueFromOtherSlice(mtype reflect.Type, tval []interface{}) (r return mval, nil } -// Convert toml value to marshal value, using marshal type -func (d *Decoder) valueFromToml(mtype reflect.Type, tval interface{}) (reflect.Value, error) { +// Convert toml value to marshal value, using marshal type. When mval1 is non-nil +// and the given type is a struct value, merge fields into it. +func (d *Decoder) valueFromToml(mtype reflect.Type, tval interface{}, mval1 *reflect.Value) (reflect.Value, error) { if mtype.Kind() == reflect.Ptr { - return d.unwrapPointer(mtype, tval) + return d.unwrapPointer(mtype, tval, mval1) } switch t := tval.(type) { case *Tree: + var mval11 *reflect.Value + if mtype.Kind() == reflect.Struct { + mval11 = mval1 + } + if isTree(mtype) { - return d.valueFromTree(mtype, t) + return d.valueFromTree(mtype, t, mval11) } return reflect.ValueOf(nil), fmt.Errorf("Can't convert %v(%T) to a tree", tval, tval) case []*Tree: @@ -743,8 +758,15 @@ func (d *Decoder) valueFromToml(mtype reflect.Type, tval interface{}) (reflect.V } } -func (d *Decoder) unwrapPointer(mtype reflect.Type, tval interface{}) (reflect.Value, error) { - val, err := d.valueFromToml(mtype.Elem(), tval) +func (d *Decoder) unwrapPointer(mtype reflect.Type, tval interface{}, mval1 *reflect.Value) (reflect.Value, error) { + var melem *reflect.Value + + if mval1 != nil && !mval1.IsNil() && mtype.Elem().Kind() == reflect.Struct { + elem := mval1.Elem() + melem = &elem + } + + val, err := d.valueFromToml(mtype.Elem(), tval, melem) if err != nil { return reflect.ValueOf(nil), err } diff --git a/marshal_test.go b/marshal_test.go index 02dc9fe..45c0e92 100644 --- a/marshal_test.go +++ b/marshal_test.go @@ -1460,3 +1460,108 @@ func TestUnmarshalNestedAnonymousStructs_Controversial(t *testing.T) { t.Fatal("should error") } } + +type unexportedFieldPreservationTest struct { + Exported string `toml:"exported"` + unexported string + Nested1 unexportedFieldPreservationTestNested `toml:"nested1"` + Nested2 *unexportedFieldPreservationTestNested `toml:"nested2"` + Nested3 *unexportedFieldPreservationTestNested `toml:"nested3"` + Slice1 []unexportedFieldPreservationTestNested `toml:"slice1"` + Slice2 []*unexportedFieldPreservationTestNested `toml:"slice2"` +} + +type unexportedFieldPreservationTestNested struct { + Exported1 string `toml:"exported1"` + unexported1 string +} + +func TestUnmarshalPreservesUnexportedFields(t *testing.T) { + toml := ` + exported = "visible" + unexported = "ignored" + + [nested1] + exported1 = "visible1" + unexported1 = "ignored1" + + [nested2] + exported1 = "visible2" + unexported1 = "ignored2" + + [nested3] + exported1 = "visible3" + unexported1 = "ignored3" + + [[slice1]] + exported1 = "visible3" + + [[slice1]] + exported1 = "visible4" + + [[slice2]] + exported1 = "visible5" + ` + + t.Run("unexported field should not be set from toml", func(t *testing.T) { + var actual unexportedFieldPreservationTest + err := Unmarshal([]byte(toml), &actual) + + if err != nil { + t.Fatal("did not expect an error") + } + + expect := unexportedFieldPreservationTest{ + Exported: "visible", + unexported: "", + Nested1: unexportedFieldPreservationTestNested{"visible1", ""}, + Nested2: &unexportedFieldPreservationTestNested{"visible2", ""}, + Nested3: &unexportedFieldPreservationTestNested{"visible3", ""}, + Slice1: []unexportedFieldPreservationTestNested{ + {Exported1: "visible3"}, + {Exported1: "visible4"}, + }, + Slice2: []*unexportedFieldPreservationTestNested{ + {Exported1: "visible5"}, + }, + } + + if !reflect.DeepEqual(actual, expect) { + t.Fatalf("%+v did not equal %+v", actual, expect) + } + }) + + t.Run("unexported field should be preserved", func(t *testing.T) { + actual := unexportedFieldPreservationTest{ + Exported: "foo", + unexported: "bar", + Nested1: unexportedFieldPreservationTestNested{"baz", "bax"}, + Nested2: nil, + Nested3: &unexportedFieldPreservationTestNested{"baz", "bax"}, + } + err := Unmarshal([]byte(toml), &actual) + + if err != nil { + t.Fatal("did not expect an error") + } + + expect := unexportedFieldPreservationTest{ + Exported: "visible", + unexported: "bar", + Nested1: unexportedFieldPreservationTestNested{"visible1", "bax"}, + Nested2: &unexportedFieldPreservationTestNested{"visible2", ""}, + Nested3: &unexportedFieldPreservationTestNested{"visible3", "bax"}, + Slice1: []unexportedFieldPreservationTestNested{ + {Exported1: "visible3"}, + {Exported1: "visible4"}, + }, + Slice2: []*unexportedFieldPreservationTestNested{ + {Exported1: "visible5"}, + }, + } + + if !reflect.DeepEqual(actual, expect) { + t.Fatalf("%+v did not equal %+v", actual, expect) + } + }) +}