From a0d031abec86cd817124037da5557501bc1fb95a Mon Sep 17 00:00:00 2001 From: Thomas Pelletier Date: Wed, 24 Mar 2021 20:21:55 -0400 Subject: [PATCH] Arrays support --- README.md | 2 +- .../imported_tests/unmarshal_imported_test.go | 14 ---- targets.go | 53 ++++++++++--- targets_test.go | 25 +++--- unmarshaler.go | 79 +++++++++++++------ unmarshaler_test.go | 45 +++++++++-- 6 files changed, 152 insertions(+), 66 deletions(-) diff --git a/README.md b/README.md index cfc08dc..c7f0738 100644 --- a/README.md +++ b/README.md @@ -11,7 +11,7 @@ Development branch. Probably does not work. - [x] Unmarshal into pointers. - [x] Support Date / times. - [x] Support struct tags annotations. -- [ ] Support Arrays. +- [x] Support Arrays. - [ ] Support Unmarshaler interface. - [ ] Original go-toml unmarshal tests pass. - [ ] Benchmark! diff --git a/internal/imported_tests/unmarshal_imported_test.go b/internal/imported_tests/unmarshal_imported_test.go index 0e06d0c..602765f 100644 --- a/internal/imported_tests/unmarshal_imported_test.go +++ b/internal/imported_tests/unmarshal_imported_test.go @@ -1883,20 +1883,6 @@ func TestUnmarshalArray(t *testing.T) { assert.Equal(t, expected, actual) } -func TestUnmarshalArrayFail(t *testing.T) { - var actual arrayTooSmallStruct - err := toml.Unmarshal([]byte(`str_slice = ["Howdy", "Hey There"]`), &actual) - assert.Error(t, err) -} - -func TestUnmarshalArrayFail2(t *testing.T) { - doc := `str_slice=["Howdy","Hey There"]` - - var actual arrayTooSmallStruct - err := toml.Unmarshal([]byte(doc), &actual) - assert.Error(t, err) -} - func TestUnmarshalArrayFail3(t *testing.T) { doc := `[[struct_slice]] String2="1" diff --git a/targets.go b/targets.go index 8595f09..e620d28 100644 --- a/targets.go +++ b/targets.go @@ -120,7 +120,9 @@ func (t mapTarget) setFloat64(v float64) error { return t.set(reflect.ValueOf(v)) } -func ensureSlice(t target) error { +// makes sure that the value pointed at by t is indexable (Slice, Array), or +// dereferences to an indexable (Ptr, Interface). +func ensureValueIndexable(t target) error { f := t.get() switch f.Type().Kind() { @@ -144,7 +146,9 @@ func ensureSlice(t target) error { } f = t.get() } - return ensureSlice(valueTarget(f.Elem())) + return ensureValueIndexable(valueTarget(f.Elem())) + case reflect.Array: + // arrays are always initialized. default: return fmt.Errorf("cannot initialize a slice in %s", f.Kind()) } @@ -305,24 +309,34 @@ func setFloat64(t target, v float64) error { } } -func pushNew(t target) (target, error) { +// Returns the element at idx of the value pointed at by target, or an error if +// t does not point to an indexable. +// If the target points to an Array and idx is out of bounds, it returns +// (nil, nil) as this is not a fatal error (the unmarshaler will skip). +func elementAt(t target, idx int) (target, error) { f := t.get() switch f.Kind() { case reflect.Slice: + // TODO: use the idx function argument and avoid alloc if possible. idx := f.Len() err := t.set(reflect.Append(f, reflect.New(f.Type().Elem()).Elem())) if err != nil { return nil, err } return valueTarget(t.get().Index(idx)), nil + case reflect.Array: + if idx >= f.Len() { + return nil, nil + } + return valueTarget(f.Index(idx)), nil case reflect.Interface: if f.IsNil() { panic("interface should have been initialized") } ifaceElem := f.Elem() if ifaceElem.Kind() != reflect.Slice { - return nil, fmt.Errorf("cannot pushNew on a %s", f.Kind()) + return nil, fmt.Errorf("cannot elementAt on a %s", f.Kind()) } idx := ifaceElem.Len() newElem := reflect.New(ifaceElem.Type().Elem()).Elem() @@ -333,13 +347,13 @@ func pushNew(t target) (target, error) { } return valueTarget(t.get().Elem().Index(idx)), nil case reflect.Ptr: - return pushNew(valueTarget(f.Elem())) + return elementAt(valueTarget(f.Elem()), idx) default: - return nil, fmt.Errorf("cannot pushNew on a %s", f.Kind()) + return nil, fmt.Errorf("cannot elementAt on a %s", f.Kind()) } } -func scopeTableTarget(append bool, t target, name string) (target, bool, error) { +func (d *decoder) scopeTableTarget(append bool, t target, name string) (target, bool, error) { x := t.get() switch x.Kind() { @@ -350,20 +364,27 @@ func scopeTableTarget(append bool, t target, name string) (target, bool, error) if err != nil { return t, false, err } - return scopeTableTarget(append, t, name) + return d.scopeTableTarget(append, t, name) case reflect.Ptr: t, err := scopePtr(t) if err != nil { return t, false, err } - return scopeTableTarget(append, t, name) + return d.scopeTableTarget(append, t, name) case reflect.Slice: t, err := scopeSlice(append, t) if err != nil { return t, false, err } append = false - return scopeTableTarget(append, t, name) + return d.scopeTableTarget(append, t, name) + case reflect.Array: + t, err := d.scopeArray(append, t) + if err != nil { + return t, false, err + } + append = false + return d.scopeTableTarget(append, t, name) // Terminal kinds @@ -443,6 +464,18 @@ func scopeSlice(append bool, t target) (target, error) { return valueTarget(v.Index(v.Len() - 1)), nil } +func (d *decoder) scopeArray(append bool, t target) (target, error) { + v := t.get() + + idx := d.arrayIndex(append, v) + + if idx >= v.Len() { + return nil, fmt.Errorf("not enough space in the array") + } + + return valueTarget(v.Index(idx)), nil +} + func scopeMap(v reflect.Value, name string) (target, bool, error) { if v.IsNil() { v.Set(reflect.MakeMap(v.Type())) diff --git a/targets_test.go b/targets_test.go index 7316994..86aab96 100644 --- a/targets_test.go +++ b/targets_test.go @@ -39,9 +39,10 @@ func TestStructTarget_Ensure(t *testing.T) { for _, e := range examples { t.Run(e.desc, func(t *testing.T) { - target, _, err := scopeTableTarget(false, valueTarget(e.input), e.name) + d := decoder{} + target, _, err := d.scopeTableTarget(false, valueTarget(e.input), e.name) require.NoError(t, err) - err = ensureSlice(target) + err = ensureValueIndexable(target) v := target.get() e.test(v, err) }) @@ -86,7 +87,8 @@ func TestStructTarget_SetString(t *testing.T) { for _, e := range examples { t.Run(e.desc, func(t *testing.T) { - target, _, err := scopeTableTarget(false, valueTarget(e.input), e.name) + d := decoder{} + target, _, err := d.scopeTableTarget(false, valueTarget(e.input), e.name) require.NoError(t, err) err = setString(target, str) v := target.get() @@ -102,15 +104,16 @@ func TestPushNew(t *testing.T) { } d := Doc{} - x, _, err := scopeTableTarget(false, valueTarget(reflect.ValueOf(&d).Elem()), "A") + dec := decoder{} + x, _, err := dec.scopeTableTarget(false, valueTarget(reflect.ValueOf(&d).Elem()), "A") require.NoError(t, err) - n, err := pushNew(x) + n, err := elementAt(x, 0) require.NoError(t, err) require.NoError(t, n.setString("hello")) require.Equal(t, []string{"hello"}, d.A) - n, err = pushNew(x) + n, err = elementAt(x, 1) require.NoError(t, err) require.NoError(t, n.setString("world")) require.Equal(t, []string{"hello", "world"}, d.A) @@ -122,15 +125,16 @@ func TestPushNew(t *testing.T) { } d := Doc{} - x, _, err := scopeTableTarget(false, valueTarget(reflect.ValueOf(&d).Elem()), "A") + dec := decoder{} + x, _, err := dec.scopeTableTarget(false, valueTarget(reflect.ValueOf(&d).Elem()), "A") require.NoError(t, err) - n, err := pushNew(x) + n, err := elementAt(x, 0) require.NoError(t, err) require.NoError(t, setString(n, "hello")) require.Equal(t, []interface{}{"hello"}, d.A) - n, err = pushNew(x) + n, err = elementAt(x, 1) require.NoError(t, err) require.NoError(t, setString(n, "world")) require.Equal(t, []interface{}{"hello", "world"}, d.A) @@ -164,7 +168,8 @@ func TestScope_Struct(t *testing.T) { for _, e := range examples { t.Run(e.desc, func(t *testing.T) { - x, found, err := scopeTableTarget(false, valueTarget(e.input), e.name) + dec := decoder{} + x, found, err := dec.scopeTableTarget(false, valueTarget(e.input), e.name) assert.Equal(t, e.found, found) if e.err { assert.Error(t, err) diff --git a/unmarshaler.go b/unmarshaler.go index 47592ed..68881af 100644 --- a/unmarshaler.go +++ b/unmarshaler.go @@ -19,7 +19,9 @@ func Unmarshal(data []byte, v interface{}) error { // TODO: remove me; sanity check allValidOrDump(p.tree, p.tree) - return fromAst(p.tree, v) + d := decoder{} + + return d.fromAst(p.tree, v) } func allValidOrDump(tree ast.Root, nodes []ast.Node) bool { @@ -37,7 +39,28 @@ func allValidOrDump(tree ast.Root, nodes []ast.Node) bool { return true } -func fromAst(tree ast.Root, v interface{}) error { +type decoder struct { + // Tracks position in Go arrays. + arrayIndexes map[reflect.Value]int +} + +func (d *decoder) arrayIndex(append bool, v reflect.Value) int { + if d.arrayIndexes == nil { + d.arrayIndexes = make(map[reflect.Value]int, 1) + } + + idx, ok := d.arrayIndexes[v] + + if !ok { + d.arrayIndexes[v] = 0 + } else if append { + idx++ + d.arrayIndexes[v] = idx + } + return idx +} + +func (d *decoder) fromAst(tree ast.Root, v interface{}) error { r := reflect.ValueOf(v) if r.Kind() != reflect.Ptr { return fmt.Errorf("need to target a pointer, not %s", r.Kind()) @@ -57,12 +80,12 @@ func fromAst(tree ast.Root, v interface{}) error { if skipUntilTable { continue } - err = unmarshalKeyValue(current, &node) + err = d.unmarshalKeyValue(current, &node) found = true case ast.Table: - current, found, err = scopeWithKey(root, node.Key()) + current, found, err = d.scopeWithKey(root, node.Key()) case ast.ArrayTable: - current, found, err = scopeWithArrayTable(root, node.Key()) + current, found, err = d.scopeWithArrayTable(root, node.Key()) default: panic(fmt.Errorf("this should not be a top level node type: %s", node.Kind)) } @@ -87,11 +110,11 @@ func fromAst(tree ast.Root, v interface{}) error { // // When encountering slices, it should always use its last element, and error // if the slice does not have any. -func scopeWithKey(x target, key []ast.Node) (target, bool, error) { +func (d *decoder) scopeWithKey(x target, key []ast.Node) (target, bool, error) { var err error found := true for _, n := range key { - x, found, err = scopeTableTarget(false, x, string(n.Data)) + x, found, err = d.scopeTableTarget(false, x, string(n.Data)) if err != nil || !found { return nil, found, err } @@ -104,18 +127,18 @@ func scopeWithKey(x target, key []ast.Node) (target, bool, error) { // // It is the same as scopeWithKey, but when scoping the last part of the key // it creates a new element in the array instead of using the last one. -func scopeWithArrayTable(x target, key []ast.Node) (target, bool, error) { +func (d *decoder) scopeWithArrayTable(x target, key []ast.Node) (target, bool, error) { var err error found := true if len(key) > 1 { for _, n := range key[:len(key)-1] { - x, found, err = scopeTableTarget(false, x, string(n.Data)) + x, found, err = d.scopeTableTarget(false, x, string(n.Data)) if err != nil || !found { return nil, found, err } } } - x, found, err = scopeTableTarget(false, x, string(key[len(key)-1].Data)) + x, found, err = d.scopeTableTarget(false, x, string(key[len(key)-1].Data)) if err != nil || !found { return x, found, err } @@ -138,17 +161,20 @@ func scopeWithArrayTable(x target, key []ast.Node) (target, bool, error) { v = x.get() } - if v.Kind() == reflect.Slice { + switch v.Kind() { + case reflect.Slice: x, err = scopeSlice(true, x) + case reflect.Array: + x, err = d.scopeArray(true, x) } return x, found, err } -func unmarshalKeyValue(x target, node *ast.Node) error { +func (d *decoder) unmarshalKeyValue(x target, node *ast.Node) error { assertNode(ast.KeyValue, node) - x, found, err := scopeWithKey(x, node.Key()) + x, found, err := d.scopeWithKey(x, node.Key()) if err != nil { return err } @@ -158,10 +184,10 @@ func unmarshalKeyValue(x target, node *ast.Node) error { return nil } - return unmarshalValue(x, node.Value()) + return d.unmarshalValue(x, node.Value()) } -func unmarshalValue(x target, node *ast.Node) error { +func (d *decoder) unmarshalValue(x target, node *ast.Node) error { switch node.Kind { case ast.String: return unmarshalString(x, node) @@ -172,9 +198,9 @@ func unmarshalValue(x target, node *ast.Node) error { case ast.Float: return unmarshalFloat(x, node) case ast.Array: - return unmarshalArray(x, node) + return d.unmarshalArray(x, node) case ast.InlineTable: - return unmarshalInlineTable(x, node) + return d.unmarshalInlineTable(x, node) case ast.LocalDateTime: return unmarshalLocalDateTime(x, node) case ast.DateTime: @@ -242,11 +268,11 @@ func unmarshalFloat(x target, node *ast.Node) error { return setFloat64(x, v) } -func unmarshalInlineTable(x target, node *ast.Node) error { +func (d *decoder) unmarshalInlineTable(x target, node *ast.Node) error { assertNode(ast.InlineTable, node) for _, kv := range node.Children { - err := unmarshalKeyValue(x, &kv) + err := d.unmarshalKeyValue(x, &kv) if err != nil { return err } @@ -254,20 +280,25 @@ func unmarshalInlineTable(x target, node *ast.Node) error { return nil } -func unmarshalArray(x target, node *ast.Node) error { +func (d *decoder) unmarshalArray(x target, node *ast.Node) error { assertNode(ast.Array, node) - err := ensureSlice(x) + err := ensureValueIndexable(x) if err != nil { return err } - for _, n := range node.Children { - v, err := pushNew(x) + for idx, n := range node.Children { + v, err := elementAt(x, idx) if err != nil { return err } - err = unmarshalValue(v, &n) + if v == nil { + // when we go out of bound for an array just stop processing it to + // mimic encoding/json + break + } + err = d.unmarshalValue(v, &n) if err != nil { return err } diff --git a/unmarshaler_test.go b/unmarshaler_test.go index 8e4c084..af1592a 100644 --- a/unmarshaler_test.go +++ b/unmarshaler_test.go @@ -613,6 +613,30 @@ B = "data"`, } }, }, + { + desc: "array of structs with table arrays", + input: `[[A]] + B = "one" + [[A]] + B = "two"`, + gen: func() test { + type inner struct { + B string + } + type doc struct { + A [4]inner + } + return test{ + target: &doc{}, + expected: &doc{ + A: [4]inner{ + {B: "one"}, + {B: "two"}, + }, + }, + } + }, + }, } for _, e := range examples { @@ -657,7 +681,8 @@ func TestFromAst_KV(t *testing.T) { } x := Doc{} - err := fromAst(root, &x) + d := decoder{} + err := d.fromAst(root, &x) require.NoError(t, err) assert.Equal(t, Doc{Foo: "hello"}, x) } @@ -709,7 +734,8 @@ func TestFromAst_Table(t *testing.T) { } x := Doc{} - err := fromAst(root, &x) + d := decoder{} + err := d.fromAst(root, &x) require.NoError(t, err) assert.Equal(t, Doc{ Level1: Level1{ @@ -755,7 +781,8 @@ func TestFromAst_Table(t *testing.T) { } x := Doc{} - err := fromAst(root, &x) + d := decoder{} + err := d.fromAst(root, &x) require.NoError(t, err) assert.Equal(t, Doc{ A: A{B: B{C: "value"}}, @@ -805,7 +832,8 @@ func TestFromAst_InlineTable(t *testing.T) { } x := Doc{} - err := fromAst(root, &x) + d := decoder{} + err := d.fromAst(root, &x) require.NoError(t, err) assert.Equal(t, Doc{ Name: Name{ @@ -849,7 +877,8 @@ func TestFromAst_Slice(t *testing.T) { } x := Doc{} - err := fromAst(root, &x) + d := decoder{} + err := d.fromAst(root, &x) require.NoError(t, err) assert.Equal(t, Doc{Foo: []string{"hello", "world"}}, x) }) @@ -885,7 +914,8 @@ func TestFromAst_Slice(t *testing.T) { } x := Doc{} - err := fromAst(root, &x) + d := decoder{} + err := d.fromAst(root, &x) require.NoError(t, err) assert.Equal(t, Doc{Foo: []interface{}{"hello", "world"}}, x) }) @@ -930,7 +960,8 @@ func TestFromAst_Slice(t *testing.T) { } x := Doc{} - err := fromAst(root, &x) + d := decoder{} + err := d.fromAst(root, &x) require.NoError(t, err) assert.Equal(t, Doc{Foo: []interface{}{"hello", []interface{}{"inner1", "inner2"}}}, x) })