Merge struct fields in Unmarshal (#284)
* add test for unexported field preservation * merge struct values instead of replacing them * use struct merging on nested value structs * unmarshalling merges nested struct pointers when non-nil
This commit is contained in:
committed by
Thomas Pelletier
parent
dba45d427f
commit
84da2c4a25
+38
-16
@@ -526,7 +526,9 @@ func (d *Decoder) unmarshal(v interface{}) error {
|
||||
return errors.New("only a pointer to struct or map can be unmarshaled from TOML")
|
||||
}
|
||||
|
||||
sval, err := d.valueFromTree(elem, d.tval)
|
||||
vv := reflect.ValueOf(v).Elem()
|
||||
|
||||
sval, err := d.valueFromTree(elem, d.tval, &vv)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -534,15 +536,21 @@ func (d *Decoder) unmarshal(v interface{}) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Convert toml tree to marshal struct or map, using marshal type
|
||||
func (d *Decoder) valueFromTree(mtype reflect.Type, tval *Tree) (reflect.Value, error) {
|
||||
// Convert toml tree to marshal struct or map, using marshal type. When mval1
|
||||
// is non-nil, merge fields into the given value instead of allocating a new one.
|
||||
func (d *Decoder) valueFromTree(mtype reflect.Type, tval *Tree, mval1 *reflect.Value) (reflect.Value, error) {
|
||||
if mtype.Kind() == reflect.Ptr {
|
||||
return d.unwrapPointer(mtype, tval)
|
||||
return d.unwrapPointer(mtype, tval, mval1)
|
||||
}
|
||||
var mval reflect.Value
|
||||
switch mtype.Kind() {
|
||||
case reflect.Struct:
|
||||
mval = reflect.New(mtype).Elem()
|
||||
if mval1 != nil {
|
||||
mval = *mval1
|
||||
} else {
|
||||
mval = reflect.New(mtype).Elem()
|
||||
}
|
||||
|
||||
for i := 0; i < mtype.NumField(); i++ {
|
||||
mtypef := mtype.Field(i)
|
||||
an := annotation{tag: d.tagName}
|
||||
@@ -563,7 +571,8 @@ func (d *Decoder) valueFromTree(mtype reflect.Type, tval *Tree) (reflect.Value,
|
||||
continue
|
||||
}
|
||||
val := tval.Get(key)
|
||||
mvalf, err := d.valueFromToml(mtypef.Type, val)
|
||||
fval := mval.Field(i)
|
||||
mvalf, err := d.valueFromToml(mtypef.Type, val, &fval)
|
||||
if err != nil {
|
||||
return mval, formatError(err, tval.GetPosition(key))
|
||||
}
|
||||
@@ -607,7 +616,7 @@ func (d *Decoder) valueFromTree(mtype reflect.Type, tval *Tree) (reflect.Value,
|
||||
|
||||
// save the old behavior above and try to check anonymous structs
|
||||
if !found && opts.defaultValue == "" && mtypef.Anonymous && mtypef.Type.Kind() == reflect.Struct {
|
||||
v, err := d.valueFromTree(mtypef.Type, tval)
|
||||
v, err := d.valueFromTree(mtypef.Type, tval, nil)
|
||||
if err != nil {
|
||||
return v, err
|
||||
}
|
||||
@@ -620,7 +629,7 @@ func (d *Decoder) valueFromTree(mtype reflect.Type, tval *Tree) (reflect.Value,
|
||||
for _, key := range tval.Keys() {
|
||||
// TODO: path splits key
|
||||
val := tval.GetPath([]string{key})
|
||||
mvalf, err := d.valueFromToml(mtype.Elem(), val)
|
||||
mvalf, err := d.valueFromToml(mtype.Elem(), val, nil)
|
||||
if err != nil {
|
||||
return mval, formatError(err, tval.GetPosition(key))
|
||||
}
|
||||
@@ -634,7 +643,7 @@ func (d *Decoder) valueFromTree(mtype reflect.Type, tval *Tree) (reflect.Value,
|
||||
func (d *Decoder) valueFromTreeSlice(mtype reflect.Type, tval []*Tree) (reflect.Value, error) {
|
||||
mval := reflect.MakeSlice(mtype, len(tval), len(tval))
|
||||
for i := 0; i < len(tval); i++ {
|
||||
val, err := d.valueFromTree(mtype.Elem(), tval[i])
|
||||
val, err := d.valueFromTree(mtype.Elem(), tval[i], nil)
|
||||
if err != nil {
|
||||
return mval, err
|
||||
}
|
||||
@@ -647,7 +656,7 @@ func (d *Decoder) valueFromTreeSlice(mtype reflect.Type, tval []*Tree) (reflect.
|
||||
func (d *Decoder) valueFromOtherSlice(mtype reflect.Type, tval []interface{}) (reflect.Value, error) {
|
||||
mval := reflect.MakeSlice(mtype, len(tval), len(tval))
|
||||
for i := 0; i < len(tval); i++ {
|
||||
val, err := d.valueFromToml(mtype.Elem(), tval[i])
|
||||
val, err := d.valueFromToml(mtype.Elem(), tval[i], nil)
|
||||
if err != nil {
|
||||
return mval, err
|
||||
}
|
||||
@@ -656,16 +665,22 @@ func (d *Decoder) valueFromOtherSlice(mtype reflect.Type, tval []interface{}) (r
|
||||
return mval, nil
|
||||
}
|
||||
|
||||
// Convert toml value to marshal value, using marshal type
|
||||
func (d *Decoder) valueFromToml(mtype reflect.Type, tval interface{}) (reflect.Value, error) {
|
||||
// Convert toml value to marshal value, using marshal type. When mval1 is non-nil
|
||||
// and the given type is a struct value, merge fields into it.
|
||||
func (d *Decoder) valueFromToml(mtype reflect.Type, tval interface{}, mval1 *reflect.Value) (reflect.Value, error) {
|
||||
if mtype.Kind() == reflect.Ptr {
|
||||
return d.unwrapPointer(mtype, tval)
|
||||
return d.unwrapPointer(mtype, tval, mval1)
|
||||
}
|
||||
|
||||
switch t := tval.(type) {
|
||||
case *Tree:
|
||||
var mval11 *reflect.Value
|
||||
if mtype.Kind() == reflect.Struct {
|
||||
mval11 = mval1
|
||||
}
|
||||
|
||||
if isTree(mtype) {
|
||||
return d.valueFromTree(mtype, t)
|
||||
return d.valueFromTree(mtype, t, mval11)
|
||||
}
|
||||
return reflect.ValueOf(nil), fmt.Errorf("Can't convert %v(%T) to a tree", tval, tval)
|
||||
case []*Tree:
|
||||
@@ -743,8 +758,15 @@ func (d *Decoder) valueFromToml(mtype reflect.Type, tval interface{}) (reflect.V
|
||||
}
|
||||
}
|
||||
|
||||
func (d *Decoder) unwrapPointer(mtype reflect.Type, tval interface{}) (reflect.Value, error) {
|
||||
val, err := d.valueFromToml(mtype.Elem(), tval)
|
||||
func (d *Decoder) unwrapPointer(mtype reflect.Type, tval interface{}, mval1 *reflect.Value) (reflect.Value, error) {
|
||||
var melem *reflect.Value
|
||||
|
||||
if mval1 != nil && !mval1.IsNil() && mtype.Elem().Kind() == reflect.Struct {
|
||||
elem := mval1.Elem()
|
||||
melem = &elem
|
||||
}
|
||||
|
||||
val, err := d.valueFromToml(mtype.Elem(), tval, melem)
|
||||
if err != nil {
|
||||
return reflect.ValueOf(nil), err
|
||||
}
|
||||
|
||||
+105
@@ -1460,3 +1460,108 @@ func TestUnmarshalNestedAnonymousStructs_Controversial(t *testing.T) {
|
||||
t.Fatal("should error")
|
||||
}
|
||||
}
|
||||
|
||||
type unexportedFieldPreservationTest struct {
|
||||
Exported string `toml:"exported"`
|
||||
unexported string
|
||||
Nested1 unexportedFieldPreservationTestNested `toml:"nested1"`
|
||||
Nested2 *unexportedFieldPreservationTestNested `toml:"nested2"`
|
||||
Nested3 *unexportedFieldPreservationTestNested `toml:"nested3"`
|
||||
Slice1 []unexportedFieldPreservationTestNested `toml:"slice1"`
|
||||
Slice2 []*unexportedFieldPreservationTestNested `toml:"slice2"`
|
||||
}
|
||||
|
||||
type unexportedFieldPreservationTestNested struct {
|
||||
Exported1 string `toml:"exported1"`
|
||||
unexported1 string
|
||||
}
|
||||
|
||||
func TestUnmarshalPreservesUnexportedFields(t *testing.T) {
|
||||
toml := `
|
||||
exported = "visible"
|
||||
unexported = "ignored"
|
||||
|
||||
[nested1]
|
||||
exported1 = "visible1"
|
||||
unexported1 = "ignored1"
|
||||
|
||||
[nested2]
|
||||
exported1 = "visible2"
|
||||
unexported1 = "ignored2"
|
||||
|
||||
[nested3]
|
||||
exported1 = "visible3"
|
||||
unexported1 = "ignored3"
|
||||
|
||||
[[slice1]]
|
||||
exported1 = "visible3"
|
||||
|
||||
[[slice1]]
|
||||
exported1 = "visible4"
|
||||
|
||||
[[slice2]]
|
||||
exported1 = "visible5"
|
||||
`
|
||||
|
||||
t.Run("unexported field should not be set from toml", func(t *testing.T) {
|
||||
var actual unexportedFieldPreservationTest
|
||||
err := Unmarshal([]byte(toml), &actual)
|
||||
|
||||
if err != nil {
|
||||
t.Fatal("did not expect an error")
|
||||
}
|
||||
|
||||
expect := unexportedFieldPreservationTest{
|
||||
Exported: "visible",
|
||||
unexported: "",
|
||||
Nested1: unexportedFieldPreservationTestNested{"visible1", ""},
|
||||
Nested2: &unexportedFieldPreservationTestNested{"visible2", ""},
|
||||
Nested3: &unexportedFieldPreservationTestNested{"visible3", ""},
|
||||
Slice1: []unexportedFieldPreservationTestNested{
|
||||
{Exported1: "visible3"},
|
||||
{Exported1: "visible4"},
|
||||
},
|
||||
Slice2: []*unexportedFieldPreservationTestNested{
|
||||
{Exported1: "visible5"},
|
||||
},
|
||||
}
|
||||
|
||||
if !reflect.DeepEqual(actual, expect) {
|
||||
t.Fatalf("%+v did not equal %+v", actual, expect)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("unexported field should be preserved", func(t *testing.T) {
|
||||
actual := unexportedFieldPreservationTest{
|
||||
Exported: "foo",
|
||||
unexported: "bar",
|
||||
Nested1: unexportedFieldPreservationTestNested{"baz", "bax"},
|
||||
Nested2: nil,
|
||||
Nested3: &unexportedFieldPreservationTestNested{"baz", "bax"},
|
||||
}
|
||||
err := Unmarshal([]byte(toml), &actual)
|
||||
|
||||
if err != nil {
|
||||
t.Fatal("did not expect an error")
|
||||
}
|
||||
|
||||
expect := unexportedFieldPreservationTest{
|
||||
Exported: "visible",
|
||||
unexported: "bar",
|
||||
Nested1: unexportedFieldPreservationTestNested{"visible1", "bax"},
|
||||
Nested2: &unexportedFieldPreservationTestNested{"visible2", ""},
|
||||
Nested3: &unexportedFieldPreservationTestNested{"visible3", "bax"},
|
||||
Slice1: []unexportedFieldPreservationTestNested{
|
||||
{Exported1: "visible3"},
|
||||
{Exported1: "visible4"},
|
||||
},
|
||||
Slice2: []*unexportedFieldPreservationTestNested{
|
||||
{Exported1: "visible5"},
|
||||
},
|
||||
}
|
||||
|
||||
if !reflect.DeepEqual(actual, expect) {
|
||||
t.Fatalf("%+v did not equal %+v", actual, expect)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user