diff --git a/internal/imported_tests/unmarshal_imported_test.go b/internal/imported_tests/unmarshal_imported_test.go index 1a9e5d5..5b96dcd 100644 --- a/internal/imported_tests/unmarshal_imported_test.go +++ b/internal/imported_tests/unmarshal_imported_test.go @@ -296,14 +296,6 @@ func TestDocUnmarshal(t *testing.T) { expected := docData require.NoError(t, err) assert.Equal(t, expected, result) - //if err != nil { - // t.Fatal(err) - //} - //if !reflect.DeepEqual(result, expected) { - // resStr, _ := json.MarshalIndent(result, "", " ") - // expStr, _ := json.MarshalIndent(expected, "", " ") - // t.Errorf("Bad unmarshal: expected\n-----\n%s\n-----\ngot\n-----\n%s\n-----\n", expStr, resStr) - //} } type tomlTypeCheckTest struct { diff --git a/internal/reflectbuild/reflectbuild.go b/internal/reflectbuild/reflectbuild.go index 7b8576e..2203973 100644 --- a/internal/reflectbuild/reflectbuild.go +++ b/internal/reflectbuild/reflectbuild.go @@ -44,7 +44,6 @@ func (b *Builder) getOrGenerateFieldGettersRecursive(m structFieldGetters, idx [ // only consider exported fields continue } - // TODO: handle embedded structs if f.Anonymous { b.getOrGenerateFieldGettersRecursive(m, copyAndAppend(idx, i), f.Type) } else { @@ -223,6 +222,10 @@ func (b *Builder) IsSlice() bool { return b.top().Kind() == reflect.Slice } +func (b *Builder) IsSliceOrPtr() bool { + return b.top().Kind() == reflect.Slice || (b.top().Kind() == reflect.Ptr && b.top().Type().Elem().Kind() == reflect.Slice) +} + // Last moves the cursor to the last value of the current value. // For a slice or an array, it is the last element they contain, if any. // For anything else, it's a no-op. @@ -269,12 +272,40 @@ func (b *Builder) SliceNewElem() error { return nil } +func assertPtr(v reflect.Value) { + if v.Kind() != reflect.Ptr { + panic(fmt.Sprintf("value '%s' should be a ptr, not '%s'", v, v.Kind())) + } +} + func (b *Builder) SliceAppend(v reflect.Value) error { + assertPtr(v) + t := b.top() + + // pointer to a slice + if t.Kind() == reflect.Ptr { + // if the pointer is nil we need to allocate the slice + if t.IsNil() { + x := reflect.New(t.Type().Elem()) + t.Set(x) + } + // target the slice itself + t = t.Elem() + } + err := checkKind(t.Type(), reflect.Slice) if err != nil { return err } + + if t.Type().Elem().Kind() == reflect.Ptr { + // if it is a slice of pointers, we can just append + } else { + // otherwise we need to reference the value + v = v.Elem() + } + newSlice := reflect.Append(t, v) t.Set(newSlice) b.replace(t.Index(t.Len() - 1)) @@ -286,12 +317,16 @@ func (b *Builder) SliceAppend(v reflect.Value) error { func (b *Builder) SetString(s string) error { t := b.top() - err := checkKind(t.Type(), reflect.String) - if err != nil { - return err - } + if t.Kind() == reflect.Ptr { + t.Set(reflect.ValueOf(&s)) + } else { + err := checkKind(t.Type(), reflect.String) + if err != nil { + return err + } - t.SetString(s) + t.SetString(s) + } return nil } diff --git a/internal/reflectbuild/reflectbuild_test.go b/internal/reflectbuild/reflectbuild_test.go index 9c80dfe..26d526c 100644 --- a/internal/reflectbuild/reflectbuild_test.go +++ b/internal/reflectbuild/reflectbuild_test.go @@ -165,3 +165,41 @@ func TestCursor(t *testing.T) { require.NoError(t, b.DigField("Field")) assert.Equal(t, b.Cursor().Kind(), reflect.String) } + +func TestStringPtr(t *testing.T) { + x := struct { + Field *string + }{} + b, err := reflectbuild.NewBuilder("", &x) + require.NoError(t, err) + assert.Equal(t, b.Cursor().Kind(), reflect.Struct) + require.NoError(t, b.DigField("Field")) + assert.NoError(t, b.SetString("A")) + assert.Equal(t, "A", *x.Field) +} + +func TestAppendSlicePtr(t *testing.T) { + x := struct { + Field *[]string + }{} + b, err := reflectbuild.NewBuilder("", &x) + require.NoError(t, err) + assert.Equal(t, b.Cursor().Kind(), reflect.Struct) + require.NoError(t, b.DigField("Field")) + v := "A" + assert.NoError(t, b.SliceAppend(reflect.ValueOf(&v))) + assert.Equal(t, []string{"A"}, *x.Field) +} + +func TestAppendPtrSlicePtr(t *testing.T) { + x := struct { + Field *[]*string + }{} + b, err := reflectbuild.NewBuilder("", &x) + require.NoError(t, err) + assert.Equal(t, b.Cursor().Kind(), reflect.Struct) + require.NoError(t, b.DigField("Field")) + v := "A" + assert.NoError(t, b.SliceAppend(reflect.ValueOf(&v))) + assert.Equal(t, "A", *(*x.Field)[0]) +} diff --git a/unmarshal.go b/unmarshal.go index 4dc20c1..43f8a94 100644 --- a/unmarshal.go +++ b/unmarshal.go @@ -139,9 +139,10 @@ func (u *unmarshaler) StringValue(v []byte) { if u.skipping() || u.err != nil { return } - if u.builder.IsSlice() { + if u.builder.IsSliceOrPtr() { u.builder.Save() - u.err = u.builder.SliceAppend(reflect.ValueOf(string(v))) + s := string(v) + u.err = u.builder.SliceAppend(reflect.ValueOf(&s)) if u.err != nil { return } @@ -157,7 +158,7 @@ func (u *unmarshaler) BoolValue(b bool) { } if u.builder.IsSlice() { u.builder.Save() - u.err = u.builder.SliceAppend(reflect.ValueOf(b)) + u.err = u.builder.SliceAppend(reflect.ValueOf(&b)) if u.err != nil { return } @@ -173,7 +174,7 @@ func (u *unmarshaler) FloatValue(n float64) { } if u.builder.IsSlice() { u.builder.Save() - u.err = u.builder.SliceAppend(reflect.ValueOf(n)) + u.err = u.builder.SliceAppend(reflect.ValueOf(&n)) if u.err != nil { return } @@ -189,7 +190,7 @@ func (u *unmarshaler) IntValue(n int64) { } if u.builder.IsSlice() { u.builder.Save() - u.err = u.builder.SliceAppend(reflect.ValueOf(n)) + u.err = u.builder.SliceAppend(reflect.ValueOf(&n)) if u.err != nil { return } @@ -205,7 +206,7 @@ func (u *unmarshaler) LocalDateValue(date LocalDate) { } if u.builder.IsSlice() { u.builder.Save() - u.err = u.builder.SliceAppend(reflect.ValueOf(date)) + u.err = u.builder.SliceAppend(reflect.ValueOf(&date)) if u.err != nil { return } @@ -221,7 +222,7 @@ func (u *unmarshaler) LocalDateTimeValue(dt LocalDateTime) { } if u.builder.IsSlice() { u.builder.Save() - u.err = u.builder.SliceAppend(reflect.ValueOf(dt)) + u.err = u.builder.SliceAppend(reflect.ValueOf(&dt)) if u.err != nil { return } @@ -237,7 +238,7 @@ func (u *unmarshaler) DateTimeValue(dt time.Time) { } if u.builder.IsSlice() { u.builder.Save() - u.err = u.builder.SliceAppend(reflect.ValueOf(dt)) + u.err = u.builder.SliceAppend(reflect.ValueOf(&dt)) if u.err != nil { return } @@ -253,7 +254,7 @@ func (u *unmarshaler) LocalTimeValue(localTime LocalTime) { } if u.builder.IsSlice() { u.builder.Save() - u.err = u.builder.SliceAppend(reflect.ValueOf(localTime)) + u.err = u.builder.SliceAppend(reflect.ValueOf(&localTime)) if u.err != nil { return }