diff --git a/README.md b/README.md index 5b8ed22..2253d6e 100644 --- a/README.md +++ b/README.md @@ -8,7 +8,7 @@ Development branch. Probably does not work. - [x] Unmarshal into maps. - [x] Support Array Tables. -- [ ] Unmarshal into pointers. +- [x] Unmarshal into pointers. - [ ] Support Date / times. - [ ] Support Unmarshaler interface. - [ ] Support struct tags annotations. diff --git a/internal/imported_tests/unmarshal_imported_test.go b/internal/imported_tests/unmarshal_imported_test.go index d604cfe..9c92491 100644 --- a/internal/imported_tests/unmarshal_imported_test.go +++ b/internal/imported_tests/unmarshal_imported_test.go @@ -138,13 +138,8 @@ func TestInterface(t *testing.T) { func TestBasicUnmarshal(t *testing.T) { result := basicMarshalTestStruct{} err := toml.Unmarshal(basicTestToml, &result) - expected := basicTestData - if err != nil { - t.Fatal(err) - } - if !reflect.DeepEqual(result, expected) { - t.Errorf("Bad unmarshal: expected %v, got %v", expected, result) - } + require.NoError(t, err) + require.Equal(t, basicTestData, result) } type quotedKeyMarshalTestStruct struct { diff --git a/targets.go b/targets.go index 0b0401d..f7c0b02 100644 --- a/targets.go +++ b/targets.go @@ -229,16 +229,20 @@ func scopeTableTarget(append bool, t target, name string) (target, error) { x := t.get() switch x.Kind() { + // Kinds that need to recurse + case reflect.Interface: t, err := scopeInterface(append, t) if err != nil { return t, err } return scopeTableTarget(append, t, name) - case reflect.Struct: - return scopeStruct(x, name) - case reflect.Map: - return scopeMap(x, name) + case reflect.Ptr: + t, err := scopePtr(t) + if err != nil { + return t, err + } + return scopeTableTarget(append, t, name) case reflect.Slice: t, err := scopeSlice(append, t) if err != nil { @@ -246,6 +250,13 @@ func scopeTableTarget(append bool, t target, name string) (target, error) { } append = false return scopeTableTarget(append, t, name) + + // Terminal kinds + + case reflect.Struct: + return scopeStruct(x, name) + case reflect.Map: + return scopeMap(x, name) default: panic(fmt.Errorf("can't scope on a %s", x.Kind())) } @@ -260,6 +271,22 @@ func scopeInterface(append bool, t target) (target, error) { return interfaceTarget{t}, nil } +func scopePtr(t target) (target, error) { + err := initPtr(t) + if err != nil { + return t, err + } + return valueTarget(t.get().Elem()), nil +} + +func initPtr(t target) error { + x := t.get() + if !x.IsNil() { + return nil + } + return t.set(reflect.New(x.Type().Elem())) +} + // initInterface makes sure that the interface pointed at by the target is not // nil. // Returns the target to the initialized value of the target.