From 4526154571816caab84c5d33c1c7d21cb3247bcd Mon Sep 17 00:00:00 2001 From: Thomas Pelletier Date: Fri, 19 Feb 2021 09:39:50 -0500 Subject: [PATCH] wip --- .../imported_tests/unmarshal_imported_test.go | 64 ++++++++----------- internal/reflectbuild/reflectbuild.go | 60 +++++++++++++++-- internal/reflectbuild/reflectbuild_test.go | 2 +- unmarshal.go | 17 +++-- 4 files changed, 94 insertions(+), 49 deletions(-) diff --git a/internal/imported_tests/unmarshal_imported_test.go b/internal/imported_tests/unmarshal_imported_test.go index e5f282f..d642fc1 100644 --- a/internal/imported_tests/unmarshal_imported_test.go +++ b/internal/imported_tests/unmarshal_imported_test.go @@ -430,29 +430,6 @@ title = "Placeholder" [map] `) -type emptyMarshalTestStruct2 struct { - Title string `toml:"title"` - Bool bool `toml:"bool,omitempty"` - Int int `toml:"int, omitempty"` - String string `toml:"string,omitempty "` - StringList []string `toml:"stringlist,omitempty"` - Ptr *basicMarshalTestStruct `toml:"ptr,omitempty"` - Map map[string]string `toml:"map,omitempty"` -} - -var emptyTestData2 = emptyMarshalTestStruct2{ - Title: "Placeholder", - Bool: false, - Int: 0, - String: "", - StringList: []string{}, - Ptr: nil, - Map: map[string]string{}, -} - -var emptyTestToml2 = []byte(`title = "Placeholder" -`) - func TestEmptytomlUnmarshal(t *testing.T) { type emptyMarshalTestStruct struct { Title string `toml:"title"` @@ -481,15 +458,32 @@ func TestEmptytomlUnmarshal(t *testing.T) { } func TestEmptyUnmarshalOmit(t *testing.T) { + t.Skipf("Have not figured yet if omitempty is a good idea") + + type emptyMarshalTestStruct2 struct { + Title string `toml:"title"` + Bool bool `toml:"bool,omitempty"` + Int int `toml:"int, omitempty"` + String string `toml:"string,omitempty "` + StringList []string `toml:"stringlist,omitempty"` + Ptr *basicMarshalTestStruct `toml:"ptr,omitempty"` + Map map[string]string `toml:"map,omitempty"` + } + + var emptyTestData2 = emptyMarshalTestStruct2{ + Title: "Placeholder", + Bool: false, + Int: 0, + String: "", + StringList: []string{}, + Ptr: nil, + Map: map[string]string{}, + } + result := emptyMarshalTestStruct2{} err := toml.Unmarshal(emptyTestToml, &result) - expected := emptyTestData2 - if err != nil { - t.Fatal(err) - } - if !reflect.DeepEqual(result, expected) { - t.Errorf("Bad empty omit unmarshal: expected %v, got %v", expected, result) - } + require.NoError(t, err) + assert.Equal(t, emptyTestData2, result) } type pointerMarshalTestStruct struct { @@ -532,15 +526,11 @@ Str = "Hello" `) func TestPointerUnmarshal(t *testing.T) { + t.Log("TOML data:", string(pointerTestToml)) result := pointerMarshalTestStruct{} err := toml.Unmarshal(pointerTestToml, &result) - expected := pointerTestData - if err != nil { - t.Fatal(err) - } - if !reflect.DeepEqual(result, expected) { - t.Errorf("Bad pointer unmarshal: expected %v, got %v", expected, result) - } + require.NoError(t, err) + assert.Equal(t, pointerTestData, result) } func TestUnmarshalTypeMismatch(t *testing.T) { diff --git a/internal/reflectbuild/reflectbuild.go b/internal/reflectbuild/reflectbuild.go index 4a67797..edbcc82 100644 --- a/internal/reflectbuild/reflectbuild.go +++ b/internal/reflectbuild/reflectbuild.go @@ -37,6 +37,14 @@ func (v valueTarget) get() reflect.Value { func (v valueTarget) set(value reflect.Value) error { rv := reflect.Value(v) + + // value is guaranteed to be a pointer + + if rv.Kind() != reflect.Ptr { + // TODO: check value is nil? + value = value.Elem() + } + err := isAssignable(rv.Type(), value) if err != nil { return err @@ -59,6 +67,13 @@ func (v mapTarget) get() reflect.Value { } func (v mapTarget) set(value reflect.Value) error { + // value is guaranteed to be a pointer + + if v.m.Type().Elem().Kind() != reflect.Ptr { + // TODO: check value is nil? + value = value.Elem() + } + err := isAssignable(v.m.Type().Elem(), value) if err != nil { return err @@ -486,6 +501,7 @@ func (b *Builder) SetInt(n int64) error { } func (b *Builder) Set(v reflect.Value) error { + assertPtr(v) t := b.top() return t.set(v) } @@ -494,8 +510,16 @@ func (b *Builder) Set(v reflect.Value) error { func (b *Builder) EnsureSlice() error { t := b.top() v := t.get() + + if v.Kind() == reflect.Ptr { + if v.IsNil() { + v.Set(reflect.New(v.Type().Elem())) + } + v = v.Elem() + } + if v.Kind() != reflect.Slice { - return IncorrectKindError{Actual: v.Kind(), Expected: reflect.Slice} + return IncorrectKindError{Actual: v.Kind(), Expected: []reflect.Kind{reflect.Slice}} } if v.IsNil() { @@ -505,6 +529,27 @@ func (b *Builder) EnsureSlice() error { return nil } +// EnsureStructOrMap makes sure that the cursor points to an initialized +// struct or map. +func (b *Builder) EnsureStructOrMap() error { + t := b.top() + v := t.get() + + switch v.Kind() { + case reflect.Struct: + case reflect.Map: + if v.IsNil() { + return t.set(reflect.MakeMap(v.Type())) + } + default: + return IncorrectKindError{ + Actual: v.Kind(), + Expected: []reflect.Kind{reflect.Struct, reflect.Map}, + } + } + return nil +} + func checkKindInt(rt reflect.Type) error { switch rt.Kind() { case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: @@ -513,7 +558,7 @@ func checkKindInt(rt reflect.Type) error { return IncorrectKindError{ Actual: rt.Kind(), - Expected: reflect.Int, + Expected: []reflect.Kind{reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64}, } } @@ -525,7 +570,7 @@ func checkKindFloat(rt reflect.Type) error { return IncorrectKindError{ Actual: rt.Kind(), - Expected: reflect.Float64, + Expected: []reflect.Kind{reflect.Float64}, } } @@ -533,7 +578,7 @@ func checkKind(rt reflect.Type, expected reflect.Kind) error { if rt.Kind() != expected { return IncorrectKindError{ Actual: rt.Kind(), - Expected: expected, + Expected: []reflect.Kind{expected}, } } return nil @@ -541,11 +586,14 @@ func checkKind(rt reflect.Type, expected reflect.Kind) error { type IncorrectKindError struct { Actual reflect.Kind - Expected reflect.Kind + Expected []reflect.Kind } func (e IncorrectKindError) Error() string { - return fmt.Sprintf("incorrect kind: expected '%s', got '%s'", e.Expected, e.Actual) + if len(e.Expected) < 2 { + return fmt.Sprintf("incorrect kind: expected '%s', got '%s'", e.Expected[0], e.Actual) + } + return fmt.Sprintf("incorrect kind: expected any of '%s', got '%s'", e.Expected, e.Actual) } type FieldNotFoundError struct { diff --git a/internal/reflectbuild/reflectbuild_test.go b/internal/reflectbuild/reflectbuild_test.go index 26d526c..692e0c2 100644 --- a/internal/reflectbuild/reflectbuild_test.go +++ b/internal/reflectbuild/reflectbuild_test.go @@ -140,7 +140,7 @@ func TestSliceNewElemNested(t *testing.T) { func TestIncorrectKindError(t *testing.T) { err := reflectbuild.IncorrectKindError{ Actual: reflect.String, - Expected: reflect.Struct, + Expected: []reflect.Kind{reflect.Struct}, } assert.NotEmpty(t, err.Error()) } diff --git a/unmarshal.go b/unmarshal.go index 3dd4559..06d8eb9 100644 --- a/unmarshal.go +++ b/unmarshal.go @@ -45,6 +45,10 @@ type unmarshaler struct { // keyval if a field is missing. parsingTable bool + // Counters that indicate that we are skipping TOML expressions. It happens + // when the document contains values that are not in the target struct. + // TODO: signal the parser that it can just scan to avoid processing the + // unused data. skipKeyValCount uint skipTable bool } @@ -152,7 +156,8 @@ func (u *unmarshaler) StringValue(v []byte) { } u.builder.Load() } else { - u.err = u.builder.SetString(string(v)) + s := string(v) + u.err = u.builder.Set(reflect.ValueOf(&s)) } } @@ -216,7 +221,7 @@ func (u *unmarshaler) LocalDateValue(date LocalDate) { } u.builder.Load() } else { - u.err = u.builder.Set(reflect.ValueOf(date)) + u.err = u.builder.Set(reflect.ValueOf(&date)) } } @@ -232,7 +237,7 @@ func (u *unmarshaler) LocalDateTimeValue(dt LocalDateTime) { } u.builder.Load() } else { - u.err = u.builder.Set(reflect.ValueOf(dt)) + u.err = u.builder.Set(reflect.ValueOf(&dt)) } } @@ -248,7 +253,7 @@ func (u *unmarshaler) DateTimeValue(dt time.Time) { } u.builder.Load() } else { - u.err = u.builder.Set(reflect.ValueOf(dt)) + u.err = u.builder.Set(reflect.ValueOf(&dt)) } } @@ -264,7 +269,7 @@ func (u *unmarshaler) LocalTimeValue(localTime LocalTime) { } u.builder.Load() } else { - u.err = u.builder.Set(reflect.ValueOf(localTime)) + u.err = u.builder.Set(reflect.ValueOf(&localTime)) } } @@ -313,4 +318,6 @@ func (u *unmarshaler) StandardTableEnd() { if u.skipping() || u.err != nil { return } + + u.builder.EnsureStructOrMap() }