From 7f9822db3597de366661216b13cf08fb78591a8e Mon Sep 17 00:00:00 2001 From: Thomas Pelletier Date: Fri, 19 Feb 2021 08:39:18 -0500 Subject: [PATCH] Target set methods now check for types --- .../imported_tests/unmarshal_imported_test.go | 49 +++++++++---------- internal/reflectbuild/reflectbuild.go | 43 +++++++++------- 2 files changed, 46 insertions(+), 46 deletions(-) diff --git a/internal/imported_tests/unmarshal_imported_test.go b/internal/imported_tests/unmarshal_imported_test.go index 5b87dec..e5f282f 100644 --- a/internal/imported_tests/unmarshal_imported_test.go +++ b/internal/imported_tests/unmarshal_imported_test.go @@ -421,26 +421,6 @@ func TestErrUnmarshal(t *testing.T) { } } -type emptyMarshalTestStruct struct { - Title string `toml:"title"` - Bool bool `toml:"bool"` - Int int `toml:"int"` - String string `toml:"string"` - StringList []string `toml:"stringlist"` - Ptr *basicMarshalTestStruct `toml:"ptr"` - Map map[string]string `toml:"map"` -} - -var emptyTestData = emptyMarshalTestStruct{ - Title: "Placeholder", - Bool: false, - Int: 0, - String: "", - StringList: []string{}, - Ptr: nil, - Map: map[string]string{}, -} - var emptyTestToml = []byte(`bool = false int = 0 string = "" @@ -474,15 +454,30 @@ var emptyTestToml2 = []byte(`title = "Placeholder" `) func TestEmptytomlUnmarshal(t *testing.T) { + type emptyMarshalTestStruct struct { + Title string `toml:"title"` + Bool bool `toml:"bool"` + Int int `toml:"int"` + String string `toml:"string"` + StringList []string `toml:"stringlist"` + Ptr *basicMarshalTestStruct `toml:"ptr"` + Map map[string]string `toml:"map"` + } + + emptyTestData := emptyMarshalTestStruct{ + Title: "Placeholder", + Bool: false, + Int: 0, + String: "", + StringList: []string{}, + Ptr: nil, + Map: map[string]string{}, + } + result := emptyMarshalTestStruct{} err := toml.Unmarshal(emptyTestToml, &result) - expected := emptyTestData - if err != nil { - t.Fatal(err) - } - if !reflect.DeepEqual(result, expected) { - t.Errorf("Bad empty unmarshal: expected %v, got %v", expected, result) - } + require.NoError(t, err) + assert.Equal(t, emptyTestData, result) } func TestEmptyUnmarshalOmit(t *testing.T) { diff --git a/internal/reflectbuild/reflectbuild.go b/internal/reflectbuild/reflectbuild.go index 85a9e6b..d2ce732 100644 --- a/internal/reflectbuild/reflectbuild.go +++ b/internal/reflectbuild/reflectbuild.go @@ -17,19 +17,32 @@ type structFieldGetters map[string]fieldGetter type target interface { get() reflect.Value - set(value reflect.Value) + set(value reflect.Value) error fmt.Stringer } +func isAssignable(t reflect.Type, v reflect.Value) error { + if v.Type().AssignableTo(t) { + return nil + } + return fmt.Errorf("cannot assign '%s' ('%s') to a '%s'", v, v.Type(), t) +} + type valueTarget reflect.Value func (v valueTarget) get() reflect.Value { return reflect.Value(v) } -func (v valueTarget) set(value reflect.Value) { +func (v valueTarget) set(value reflect.Value) error { + rv := reflect.Value(v) + err := isAssignable(rv.Type(), value) + if err != nil { + return err + } reflect.Value(v).Set(value) + return nil } func (v valueTarget) String() string { @@ -45,8 +58,13 @@ func (v mapTarget) get() reflect.Value { return v.m.MapIndex(v.index) } -func (v mapTarget) set(value reflect.Value) { +func (v mapTarget) set(value reflect.Value) error { + err := isAssignable(v.m.Type().Elem(), value) + if err != nil { + return err + } v.m.SetMapIndex(v.index, value) + return nil } func (v mapTarget) String() string { @@ -413,23 +431,11 @@ func (b *Builder) SetString(s string) error { t := b.top() v := t.get() - if !v.IsValid() { - fmt.Println("============ INVALID ===========") - fmt.Println(b.Dump()) - fmt.Println("==================== ===========") - } - if v.Kind() == reflect.Ptr { v.Set(reflect.ValueOf(&s)) - } else { - err := checkKind(v.Type(), reflect.String) - if err != nil { - return err - } - - v.SetString(s) + return nil } - return nil + return t.set(reflect.ValueOf(s)) } // Set the value at the cursor to the given boolean. @@ -481,8 +487,7 @@ func (b *Builder) SetInt(n int64) error { func (b *Builder) Set(v reflect.Value) error { t := b.top() - t.set(v) - return nil + return t.set(v) } func checkKindInt(rt reflect.Type) error {