From 7dbf7554c41cfcbedc0c9173a364a647f2fd0586 Mon Sep 17 00:00:00 2001 From: Thomas Pelletier Date: Tue, 9 Feb 2021 19:22:54 -0500 Subject: [PATCH] Nested arrays --- internal/reflectbuild/reflectbuild.go | 239 +++++++++++++++++++++ internal/reflectbuild/reflectbuild_test.go | 167 ++++++++++++++ unmarshal.go | 183 +++++++--------- unmarshal_test.go | 36 ++++ 4 files changed, 523 insertions(+), 102 deletions(-) create mode 100644 internal/reflectbuild/reflectbuild.go create mode 100644 internal/reflectbuild/reflectbuild_test.go diff --git a/internal/reflectbuild/reflectbuild.go b/internal/reflectbuild/reflectbuild.go new file mode 100644 index 0000000..7199f9f --- /dev/null +++ b/internal/reflectbuild/reflectbuild.go @@ -0,0 +1,239 @@ +// reflectbuild is a package that provides utility functions to build Go +// objects using reflection. +package reflectbuild + +import ( + "fmt" + "reflect" + "strings" +) + +// Builder wraps a value and provides method to modify its structure. +// It is a stateful object that keeps a cursor of what part of the object is +// being modified. +// Create a Builder with NewBuilder. +type Builder struct { + root reflect.Value + // Root is always a pointer to a non-nil value. + // Cursor is the top of the stack. + stack []reflect.Value +} + +// NewBuilder creates a Builder to construct v. +// If v is nil or not a pointer, an error will be returned. +func NewBuilder(v interface{}) (Builder, error) { + if v == nil { + return Builder{}, fmt.Errorf("cannot build a nil value") + } + + rv := reflect.ValueOf(v) + if rv.Type().Kind() != reflect.Ptr { + return Builder{}, fmt.Errorf("cannot build a %s: need a pointer", rv.Type().Kind()) + } + + return Builder{ + root: rv.Elem(), + stack: []reflect.Value{rv.Elem()}, + }, nil +} + +func (b *Builder) top() reflect.Value { + return b.stack[len(b.stack)-1] +} + +func (b *Builder) push(v reflect.Value) { + b.stack = append(b.stack, v) +} + +func (b *Builder) pop() { + b.stack = b.stack[:len(b.stack)-1] +} + +func (b *Builder) len() int { + return len(b.stack) +} + +func (b *Builder) Dump() string { + str := strings.Builder{} + str.WriteByte('[') + + for i, x := range b.stack { + if i > 0 { + str.WriteString(" | ") + } + fmt.Fprintf(&str, "%s (%s)", x.Type(), x) + } + + str.WriteByte(']') + return str.String() +} + +func (b *Builder) replace(v reflect.Value) { + b.stack[len(b.stack)-1] = v +} + +// DigField pushes the cursor into a field of the current struct. +// Errors if the current value is not a struct, or the field does not exist. +func (b *Builder) DigField(s string) error { + t := b.top() + + err := checkKind(t.Type(), reflect.Struct) + if err != nil { + return err + } + + f := t.FieldByName(s) + if !f.IsValid() { + return FieldNotFoundError{FieldName: s, Struct: t} + } + + b.replace(f) + + return nil +} + +// Save stores a copy of the current cursor position. +// It can be restored using Back(). +// Save points are stored as a stack. +func (b *Builder) Save() { + b.push(b.top()) +} + +// Reset brings the cursor back to the root object. +func (b *Builder) Reset() { + b.stack = b.stack[:1] + b.stack[0] = b.root +} + +// Load is the opposite of Save. It discards the current cursor and loads the +// last saved cursor. +// Panics if no cursor has been saved. +func (b *Builder) Load() { + if b.len() < 2 { + panic(fmt.Errorf("tried to Back() when cursor was already at root")) + } + b.pop() +} + +// Cursor returns the value pointed at by the cursor. +func (b *Builder) Cursor() reflect.Value { + return b.top() +} + +func (b *Builder) IsSlice() bool { + return b.top().Kind() == reflect.Slice +} + +// Last moves the cursor to the last value of the current value. +// For a slice or an array, it is the last element they contain, if any. +// For anything else, it's a no-op. +func (b *Builder) Last() { + switch b.Cursor().Kind() { + case reflect.Slice, reflect.Array: + length := b.Cursor().Len() + if length > 0 { + x := b.Cursor().Index(length - 1) + b.replace(x) + } + } +} + +// SliceLastOrCreate moves the cursor to the last element of the slice if any. +// Otherwise creates a new element in that slice and moves to it. +func (b *Builder) SliceLastOrCreate() error { + t := b.top() + err := checkKind(t.Type(), reflect.Slice) + if err != nil { + return err + } + + if t.Len() == 0 { + return b.SliceNewElem() + } + b.Last() + return nil +} + +// SliceNewElem operates on a slice. It creates a new object (of type contained +// by the slice), append it to the slice, and moves the cursor to the new +// object. +func (b *Builder) SliceNewElem() error { + t := b.top() + err := checkKind(t.Type(), reflect.Slice) + if err != nil { + return err + } + elem := reflect.New(t.Type().Elem()) + newSlice := reflect.Append(t, elem.Elem()) + t.Set(newSlice) + b.replace(t.Index(t.Len() - 1)) + return nil +} + +func (b *Builder) SliceAppend(v reflect.Value) error { + t := b.top() + err := checkKind(t.Type(), reflect.Slice) + if err != nil { + return err + } + newSlice := reflect.Append(t, v) + t.Set(newSlice) + b.replace(t.Index(t.Len() - 1)) + return nil +} + +// Set the value at the cursor to the given string. +// Errors if a string cannot be assigned to the current value. +func (b *Builder) SetString(s string) error { + t := b.top() + + err := checkKind(t.Type(), reflect.String) + if err != nil { + return err + } + + t.SetString(s) + return nil +} + +// Set the value at the cursor to the given boolean. +// Errors if a boolean cannot be assigned to the current value. +func (b *Builder) SetBool(v bool) error { + t := b.top() + + err := checkKind(t.Type(), reflect.Bool) + if err != nil { + return err + } + + t.SetBool(v) + return nil +} + +func checkKind(rt reflect.Type, expected reflect.Kind) error { + if rt.Kind() != expected { + return IncorrectKindError{ + Actual: rt.Kind(), + Expected: expected, + } + } + return nil +} + +type IncorrectKindError struct { + Actual reflect.Kind + Expected reflect.Kind +} + +func (e IncorrectKindError) Error() string { + return fmt.Sprintf("incorrect kind: expected '%s', got '%s'", e.Expected, e.Actual) +} + +type FieldNotFoundError struct { + Struct reflect.Value + FieldName string +} + +func (e FieldNotFoundError) Error() string { + return fmt.Sprintf("field not found: '%s' on '%s'", e.FieldName, e.Struct.Type()) +} diff --git a/internal/reflectbuild/reflectbuild_test.go b/internal/reflectbuild/reflectbuild_test.go new file mode 100644 index 0000000..433a341 --- /dev/null +++ b/internal/reflectbuild/reflectbuild_test.go @@ -0,0 +1,167 @@ +package reflectbuild_test + +import ( + "reflect" + "testing" + + "github.com/pelletier/go-toml/v2/internal/reflectbuild" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestNewBuilderSuccess(t *testing.T) { + x := struct{}{} + _, err := reflectbuild.NewBuilder(&x) + assert.NoError(t, err) +} + +func TestNewBuilderNil(t *testing.T) { + _, err := reflectbuild.NewBuilder(nil) + assert.Error(t, err) +} + +func TestNewBuilderNonPtr(t *testing.T) { + _, err := reflectbuild.NewBuilder(struct{}{}) + assert.Error(t, err) +} + +func TestDigField(t *testing.T) { + x := struct { + Field string + }{} + b, err := reflectbuild.NewBuilder(&x) + require.NoError(t, err) + assert.Error(t, b.DigField("oops")) + assert.NoError(t, b.DigField("Field")) + assert.Error(t, b.DigField("does not work on strings")) +} + +func TestBack(t *testing.T) { + x := struct { + A string + B string + }{} + b, err := reflectbuild.NewBuilder(&x) + require.NoError(t, err) + b.Save() + assert.NoError(t, b.DigField("A")) + assert.NoError(t, b.SetString("A")) + b.Load() + b.Save() + assert.NoError(t, b.DigField("B")) + assert.NoError(t, b.SetString("B")) + assert.Equal(t, "A", x.A) + assert.Equal(t, "B", x.B) + b.Load() // back to root + assert.Panics(t, func() { + b.Load() // oops + }) +} + +func TestReset(t *testing.T) { + x := struct { + A []string + B string + }{} + b, err := reflectbuild.NewBuilder(&x) + require.NoError(t, err) + require.NoError(t, b.DigField("A")) + require.NoError(t, b.SliceNewElem()) + require.NoError(t, b.SetString("hello")) + b.Reset() + require.NoError(t, b.DigField("B")) + require.NoError(t, b.SetString("world")) + + assert.Equal(t, []string{"hello"}, x.A) + assert.Equal(t, "world", x.B) +} + +func TestSetString(t *testing.T) { + x := struct { + Field string + }{} + b, err := reflectbuild.NewBuilder(&x) + require.NoError(t, err) + assert.Error(t, b.SetString("oops")) + require.NoError(t, b.DigField("Field")) + require.NoError(t, b.SetString("hello!")) + assert.Equal(t, "hello!", x.Field) +} + +func TestSliceNewElem(t *testing.T) { + x := struct { + Field []string + }{} + b, err := reflectbuild.NewBuilder(&x) + require.NoError(t, err) + require.NoError(t, b.DigField("Field")) + b.Save() + + require.NoError(t, b.SliceNewElem()) + require.NoError(t, b.SetString("Val1")) + b.Load() + require.NoError(t, b.SliceNewElem()) + require.NoError(t, b.SetString("Val2")) + + require.Error(t, b.SliceNewElem()) + + assert.Equal(t, []string{"Val1", "Val2"}, x.Field) +} + +func TestSliceNewElemNested(t *testing.T) { + x := struct { + Field [][]string + }{} + b, err := reflectbuild.NewBuilder(&x) + require.NoError(t, err) + require.NoError(t, b.DigField("Field")) + + b.Save() + + require.NoError(t, b.SliceNewElem()) + require.NoError(t, b.SliceNewElem()) + require.NoError(t, b.SetString("Val1.1")) + b.Load() + b.Save() + + require.NoError(t, b.SliceNewElem()) + b.Save() + require.NoError(t, b.SliceNewElem()) + require.NoError(t, b.SetString("Val2.1")) + b.Load() + require.NoError(t, b.SliceNewElem()) + require.NoError(t, b.SetString("Val2.2")) + b.Load() + require.NoError(t, b.SliceNewElem()) + + assert.Equal(t, [][]string{{"Val1.1"}, {"Val2.1", "Val2.2"}, nil}, x.Field) +} + +func TestIncorrectKindError(t *testing.T) { + err := reflectbuild.IncorrectKindError{ + Actual: reflect.String, + Expected: reflect.Struct, + } + assert.NotEmpty(t, err.Error()) +} + +func TestFieldNotFoundError(t *testing.T) { + err := reflectbuild.FieldNotFoundError{ + Struct: reflect.ValueOf(struct { + Blah string + }{}), + FieldName: "Foo", + } + assert.NotEmpty(t, err.Error()) +} + +func TestCursor(t *testing.T) { + x := struct { + Field string + }{} + b, err := reflectbuild.NewBuilder(&x) + require.NoError(t, err) + assert.Equal(t, b.Cursor().Kind(), reflect.Struct) + require.NoError(t, b.DigField("Field")) + assert.Equal(t, b.Cursor().Kind(), reflect.String) +} diff --git a/unmarshal.go b/unmarshal.go index 6e83f02..170ee8d 100644 --- a/unmarshal.go +++ b/unmarshal.go @@ -5,30 +5,24 @@ import ( "encoding/hex" "fmt" "reflect" + + "github.com/pelletier/go-toml/v2/internal/reflectbuild" ) func Unmarshal(data []byte, v interface{}) error { - if v == nil { - return fmt.Errorf("cannot unmarshal to nil target") - } - rv := reflect.ValueOf(v) - if rv.Kind() != reflect.Ptr { - return fmt.Errorf("can only marshal to pointer, not %s", rv.Kind()) - } - - u := &unmarshaler{stack: []reflect.Value{rv.Elem()}} - parseErr := parser{builder: u}.parse(data) - if parseErr != nil { - return parseErr + u := &unmarshaler{} + u.builder, u.err = reflectbuild.NewBuilder(v) + if u.err == nil { + parseErr := parser{builder: u}.parse(data) + if parseErr != nil { + return parseErr + } } return u.err } type unmarshaler struct { - // Each stack frame is a pointer to the root object that should be - // considered when settings values. - // It at least contains the root object passed to Unmarshal. - stack []reflect.Value + builder reflectbuild.Builder // First error that appeared during the construction of the object. // When set all callbacks are no-ops. @@ -41,6 +35,35 @@ type unmarshaler struct { // Table Arrays need a buffer of keys because we need to know which one is // the last one, as it may result in creating a new element in the array. arrayTableKey [][]byte + + // Flag to indicate that the next value is an an assignment. + // Assignments are when the builder already points to the value, and should + // be directly assigned. This is used to distinguish between assigning or + // appending to arrays. + assign bool +} + +func (u *unmarshaler) Assignation() { + u.assign = true +} + +func (u *unmarshaler) ArrayBegin() { + if u.err != nil { + return + } + u.builder.Save() + if u.assign { + u.assign = false + } else { + u.builder.SliceNewElem() + } +} + +func (u *unmarshaler) ArrayEnd() { + if u.err != nil { + return + } + u.builder.Load() } func (u *unmarshaler) ArrayTableBegin() { @@ -56,99 +79,48 @@ func (u *unmarshaler) ArrayTableEnd() { return } - u.parsingTableArray = false + u.builder.Reset() - u.stack = u.stack[:1] - - parent := u.top() - for _, k := range u.arrayTableKey { - switch parent.Type().Kind() { - case reflect.Slice: - l := parent.Len() - parent = parent.Index(l - 1) - case reflect.Struct: - default: - u.err = fmt.Errorf("value of type '%s' cannot have children", parent) + for _, v := range u.arrayTableKey[:len(u.arrayTableKey)-1] { + u.err = u.builder.DigField(string(v)) + if u.err != nil { return } - - f := parent.FieldByName(string(k)) - if !f.IsValid() { - // TODO: implement alternative names - u.err = fmt.Errorf("field '%s' not found", string(k)) - return - } - parent = f + u.err = u.builder.SliceLastOrCreate() } - if parent.Type().Kind() != reflect.Slice { - u.err = fmt.Errorf("array table key is not a slice") + v := u.arrayTableKey[len(u.arrayTableKey)-1] + u.err = u.builder.DigField(string(v)) + if u.err != nil { return } + u.err = u.builder.SliceNewElem() - n := reflect.New(parent.Type().Elem()) - parent.Set(reflect.Append(parent, n.Elem())) - last := parent.Index(parent.Len() - 1) - u.push(last) + u.parsingTableArray = false u.arrayTableKey = u.arrayTableKey[:0] } func (u *unmarshaler) KeyValBegin() { - u.push(u.top()) + u.builder.Save() } func (u *unmarshaler) KeyValEnd() { - u.pop() -} - -func (u *unmarshaler) getOrCreateChild(key string) (reflect.Value, error) { - parent := u.top() - switch parent.Type().Kind() { - case reflect.Slice: - l := parent.Len() - parent = parent.Index(l - 1) - case reflect.Struct: - default: - return reflect.Value{}, fmt.Errorf("value of type '%s' cannot have children", parent) - } - - f := parent.FieldByName(key) - if !f.IsValid() { - // TODO: implement alternative names - return reflect.Value{}, fmt.Errorf("field '%s' not found", key) - } - // TODO create things - return f, nil -} - -func (u *unmarshaler) top() reflect.Value { - return u.stack[len(u.stack)-1] -} - -func (u *unmarshaler) push(v reflect.Value) { - u.stack = append(u.stack, v) -} - -func (u *unmarshaler) pop() { - u.stack = u.stack[:len(u.stack)-1] -} - -func (u *unmarshaler) replace(v reflect.Value) { - u.stack[len(u.stack)-1] = v + u.builder.Load() } func (u *unmarshaler) StringValue(v []byte) { if u.err != nil { return } - - t := u.top() - if t.Type().Kind() == reflect.Slice { - s := reflect.ValueOf(string(v)) - n := reflect.Append(t, s) - t.Set(n) + if u.builder.IsSlice() { + u.builder.Save() + u.err = u.builder.SliceAppend(reflect.ValueOf(string(v))) + if u.err != nil { + return + } + u.builder.Load() } else { - u.top().SetString(string(v)) + u.err = u.builder.SetString(string(v)) } } @@ -156,14 +128,15 @@ func (u *unmarshaler) BoolValue(b bool) { if u.err != nil { return } - - t := u.top() - if t.Type().Kind() == reflect.Slice { - s := reflect.ValueOf(b) - n := reflect.Append(t, s) - t.Set(n) + if u.builder.IsSlice() { + u.builder.Save() + u.err = u.builder.SliceAppend(reflect.ValueOf(b)) + if u.err != nil { + return + } + u.builder.Load() } else { - u.top().SetBool(b) + u.err = u.builder.SetBool(b) } } @@ -175,12 +148,13 @@ func (u *unmarshaler) SimpleKey(v []byte) { if u.parsingTableArray { u.arrayTableKey = append(u.arrayTableKey, v) } else { - target, err := u.getOrCreateChild(string(v)) - if err != nil { - u.err = err - return + if u.builder.Cursor().Kind() == reflect.Slice { + u.err = u.builder.SliceLastOrCreate() + if u.err != nil { + return + } } - u.replace(target) + u.err = u.builder.DigField(string(v)) } } @@ -190,9 +164,7 @@ func (u *unmarshaler) StandardTableBegin() { } // tables are only top-level - u.stack = u.stack[:1] - - u.push(u.top()) + u.builder.Reset() } func (u *unmarshaler) StandardTableEnd() { @@ -210,6 +182,9 @@ type builder interface { ArrayTableEnd() KeyValBegin() KeyValEnd() + ArrayBegin() + ArrayEnd() + Assignation() StringValue(v []byte) BoolValue(b bool) @@ -355,6 +330,7 @@ func (p parser) parseKeyval(b []byte) ([]byte, error) { if err != nil { return nil, err } + p.builder.Assignation() b = p.parseWhitespace(b) return p.parseVal(b) @@ -471,6 +447,9 @@ func (p parser) parseValArray(b []byte) ([]byte, error) { //array-sep = %x2C ; , Comma //ws-comment-newline = *( wschar / [ comment ] newline ) + p.builder.ArrayBegin() + defer p.builder.ArrayEnd() + b = b[1:] first := true diff --git a/unmarshal_test.go b/unmarshal_test.go index af7a893..b876aec 100644 --- a/unmarshal_test.go +++ b/unmarshal_test.go @@ -245,6 +245,42 @@ func TestArraySimple(t *testing.T) { assert.Equal(t, []string{"hello", "world"}, x.Foo) } +func TestArrayNestedInTable(t *testing.T) { + x := struct { + Wrapper struct { + Foo []string + } + }{} + doc := `[Wrapper] +Foo = ["hello", "world"]` + err := toml.Unmarshal([]byte(doc), &x) + require.NoError(t, err) + assert.Equal(t, []string{"hello", "world"}, x.Wrapper.Foo) +} + +func TestArrayMixed(t *testing.T) { + x := struct { + Wrapper struct { + Foo []interface{} + } + }{} + doc := `[Wrapper] +Foo = ["hello", true]` + err := toml.Unmarshal([]byte(doc), &x) + require.NoError(t, err) + assert.Equal(t, []interface{}{"hello", true}, x.Wrapper.Foo) +} + +func TestArrayNested(t *testing.T) { + x := struct { + Foo [][]string + }{} + doc := `Foo = [["hello", "world"], ["a"], []]` + err := toml.Unmarshal([]byte(doc), &x) + require.NoError(t, err) + assert.Equal(t, [][]string{{"hello", "world"}, {"a"}, nil}, x.Foo) +} + func TestBool(t *testing.T) { x := struct { Truthy bool