From a1c9b661b4993215ad9a66b460230e1d9d080cc1 Mon Sep 17 00:00:00 2001 From: Thomas Pelletier Date: Mon, 8 Mar 2021 21:41:03 -0500 Subject: [PATCH] Allocate slice if needed --- internal/imported_tests/unmarshal_imported_test.go | 9 ++------- internal/reflectbuild/reflectbuild.go | 6 ++++++ unmarshal.go | 14 +++++++------- 3 files changed, 15 insertions(+), 14 deletions(-) diff --git a/internal/imported_tests/unmarshal_imported_test.go b/internal/imported_tests/unmarshal_imported_test.go index 4ef0fb2..911c3a9 100644 --- a/internal/imported_tests/unmarshal_imported_test.go +++ b/internal/imported_tests/unmarshal_imported_test.go @@ -1827,9 +1827,7 @@ type arrayTooSmallStruct struct { func TestUnmarshalSlice(t *testing.T) { var actual sliceStruct err := toml.Unmarshal(sliceTomlDemo, &actual) - if err != nil { - t.Error("shound not err", err) - } + require.NoError(t, err) expected := sliceStruct{ Slice: []string{"Howdy", "Hey There"}, SlicePtr: &[]string{"Howdy", "Hey There"}, @@ -1838,10 +1836,7 @@ func TestUnmarshalSlice(t *testing.T) { StructSlice: []basicMarshalTestSubStruct{{"1"}, {"2"}}, StructSlicePtr: &[]basicMarshalTestSubStruct{{"1"}, {"2"}}, } - if !reflect.DeepEqual(actual, expected) { - t.Errorf("Bad unmarshal: expected %v, got %v", expected, actual) - } - + assert.Equal(t, expected, actual) } func TestUnmarshalSliceFail(t *testing.T) { diff --git a/internal/reflectbuild/reflectbuild.go b/internal/reflectbuild/reflectbuild.go index 0ad06e4..b79807a 100644 --- a/internal/reflectbuild/reflectbuild.go +++ b/internal/reflectbuild/reflectbuild.go @@ -380,6 +380,12 @@ func (b *Builder) SliceNewElem() error { v := t.get() if v.Kind() == reflect.Ptr { + // if the pointer is nil we need to allocate the slice + if v.IsNil() { + x := reflect.New(v.Type().Elem()) + v.Set(x) + } + // target the slice itself v = v.Elem() } diff --git a/unmarshal.go b/unmarshal.go index 60b309d..58cdfac 100644 --- a/unmarshal.go +++ b/unmarshal.go @@ -165,7 +165,7 @@ func (u *unmarshaler) BoolValue(b bool) { if u.skipping() || u.err != nil { return } - if u.builder.IsSlice() { + if u.builder.IsSliceOrPtr() { u.builder.Save() u.err = u.builder.SliceAppend(reflect.ValueOf(&b)) if u.err != nil { @@ -181,7 +181,7 @@ func (u *unmarshaler) FloatValue(n float64) { if u.skipping() || u.err != nil { return } - if u.builder.IsSlice() { + if u.builder.IsSliceOrPtr() { u.builder.Save() u.err = u.builder.SliceAppend(reflect.ValueOf(&n)) if u.err != nil { @@ -198,7 +198,7 @@ func (u *unmarshaler) IntValue(n int64) { if u.skipping() || u.err != nil { return } - if u.builder.IsSlice() { + if u.builder.IsSliceOrPtr() { u.builder.Save() u.err = u.builder.SliceAppend(reflect.ValueOf(&n)) if u.err != nil { @@ -214,7 +214,7 @@ func (u *unmarshaler) LocalDateValue(date LocalDate) { if u.skipping() || u.err != nil { return } - if u.builder.IsSlice() { + if u.builder.IsSliceOrPtr() { u.builder.Save() u.err = u.builder.SliceAppend(reflect.ValueOf(&date)) if u.err != nil { @@ -230,7 +230,7 @@ func (u *unmarshaler) LocalDateTimeValue(dt LocalDateTime) { if u.skipping() || u.err != nil { return } - if u.builder.IsSlice() { + if u.builder.IsSliceOrPtr() { u.builder.Save() u.err = u.builder.SliceAppend(reflect.ValueOf(&dt)) if u.err != nil { @@ -246,7 +246,7 @@ func (u *unmarshaler) DateTimeValue(dt time.Time) { if u.skipping() || u.err != nil { return } - if u.builder.IsSlice() { + if u.builder.IsSliceOrPtr() { u.builder.Save() u.err = u.builder.SliceAppend(reflect.ValueOf(&dt)) if u.err != nil { @@ -262,7 +262,7 @@ func (u *unmarshaler) LocalTimeValue(localTime LocalTime) { if u.skipping() || u.err != nil { return } - if u.builder.IsSlice() { + if u.builder.IsSliceOrPtr() { u.builder.Save() u.err = u.builder.SliceAppend(reflect.ValueOf(&localTime)) if u.err != nil {