Arrays support

This commit is contained in:
Thomas Pelletier
2021-03-24 20:21:55 -04:00
parent a25f636a07
commit a0d031abec
6 changed files with 152 additions and 66 deletions
+1 -1
View File
@@ -11,7 +11,7 @@ Development branch. Probably does not work.
- [x] Unmarshal into pointers.
- [x] Support Date / times.
- [x] Support struct tags annotations.
- [ ] Support Arrays.
- [x] Support Arrays.
- [ ] Support Unmarshaler interface.
- [ ] Original go-toml unmarshal tests pass.
- [ ] Benchmark!
@@ -1883,20 +1883,6 @@ func TestUnmarshalArray(t *testing.T) {
assert.Equal(t, expected, actual)
}
func TestUnmarshalArrayFail(t *testing.T) {
var actual arrayTooSmallStruct
err := toml.Unmarshal([]byte(`str_slice = ["Howdy", "Hey There"]`), &actual)
assert.Error(t, err)
}
func TestUnmarshalArrayFail2(t *testing.T) {
doc := `str_slice=["Howdy","Hey There"]`
var actual arrayTooSmallStruct
err := toml.Unmarshal([]byte(doc), &actual)
assert.Error(t, err)
}
func TestUnmarshalArrayFail3(t *testing.T) {
doc := `[[struct_slice]]
String2="1"
+43 -10
View File
@@ -120,7 +120,9 @@ func (t mapTarget) setFloat64(v float64) error {
return t.set(reflect.ValueOf(v))
}
func ensureSlice(t target) error {
// makes sure that the value pointed at by t is indexable (Slice, Array), or
// dereferences to an indexable (Ptr, Interface).
func ensureValueIndexable(t target) error {
f := t.get()
switch f.Type().Kind() {
@@ -144,7 +146,9 @@ func ensureSlice(t target) error {
}
f = t.get()
}
return ensureSlice(valueTarget(f.Elem()))
return ensureValueIndexable(valueTarget(f.Elem()))
case reflect.Array:
// arrays are always initialized.
default:
return fmt.Errorf("cannot initialize a slice in %s", f.Kind())
}
@@ -305,24 +309,34 @@ func setFloat64(t target, v float64) error {
}
}
func pushNew(t target) (target, error) {
// Returns the element at idx of the value pointed at by target, or an error if
// t does not point to an indexable.
// If the target points to an Array and idx is out of bounds, it returns
// (nil, nil) as this is not a fatal error (the unmarshaler will skip).
func elementAt(t target, idx int) (target, error) {
f := t.get()
switch f.Kind() {
case reflect.Slice:
// TODO: use the idx function argument and avoid alloc if possible.
idx := f.Len()
err := t.set(reflect.Append(f, reflect.New(f.Type().Elem()).Elem()))
if err != nil {
return nil, err
}
return valueTarget(t.get().Index(idx)), nil
case reflect.Array:
if idx >= f.Len() {
return nil, nil
}
return valueTarget(f.Index(idx)), nil
case reflect.Interface:
if f.IsNil() {
panic("interface should have been initialized")
}
ifaceElem := f.Elem()
if ifaceElem.Kind() != reflect.Slice {
return nil, fmt.Errorf("cannot pushNew on a %s", f.Kind())
return nil, fmt.Errorf("cannot elementAt on a %s", f.Kind())
}
idx := ifaceElem.Len()
newElem := reflect.New(ifaceElem.Type().Elem()).Elem()
@@ -333,13 +347,13 @@ func pushNew(t target) (target, error) {
}
return valueTarget(t.get().Elem().Index(idx)), nil
case reflect.Ptr:
return pushNew(valueTarget(f.Elem()))
return elementAt(valueTarget(f.Elem()), idx)
default:
return nil, fmt.Errorf("cannot pushNew on a %s", f.Kind())
return nil, fmt.Errorf("cannot elementAt on a %s", f.Kind())
}
}
func scopeTableTarget(append bool, t target, name string) (target, bool, error) {
func (d *decoder) scopeTableTarget(append bool, t target, name string) (target, bool, error) {
x := t.get()
switch x.Kind() {
@@ -350,20 +364,27 @@ func scopeTableTarget(append bool, t target, name string) (target, bool, error)
if err != nil {
return t, false, err
}
return scopeTableTarget(append, t, name)
return d.scopeTableTarget(append, t, name)
case reflect.Ptr:
t, err := scopePtr(t)
if err != nil {
return t, false, err
}
return scopeTableTarget(append, t, name)
return d.scopeTableTarget(append, t, name)
case reflect.Slice:
t, err := scopeSlice(append, t)
if err != nil {
return t, false, err
}
append = false
return scopeTableTarget(append, t, name)
return d.scopeTableTarget(append, t, name)
case reflect.Array:
t, err := d.scopeArray(append, t)
if err != nil {
return t, false, err
}
append = false
return d.scopeTableTarget(append, t, name)
// Terminal kinds
@@ -443,6 +464,18 @@ func scopeSlice(append bool, t target) (target, error) {
return valueTarget(v.Index(v.Len() - 1)), nil
}
func (d *decoder) scopeArray(append bool, t target) (target, error) {
v := t.get()
idx := d.arrayIndex(append, v)
if idx >= v.Len() {
return nil, fmt.Errorf("not enough space in the array")
}
return valueTarget(v.Index(idx)), nil
}
func scopeMap(v reflect.Value, name string) (target, bool, error) {
if v.IsNil() {
v.Set(reflect.MakeMap(v.Type()))
+15 -10
View File
@@ -39,9 +39,10 @@ func TestStructTarget_Ensure(t *testing.T) {
for _, e := range examples {
t.Run(e.desc, func(t *testing.T) {
target, _, err := scopeTableTarget(false, valueTarget(e.input), e.name)
d := decoder{}
target, _, err := d.scopeTableTarget(false, valueTarget(e.input), e.name)
require.NoError(t, err)
err = ensureSlice(target)
err = ensureValueIndexable(target)
v := target.get()
e.test(v, err)
})
@@ -86,7 +87,8 @@ func TestStructTarget_SetString(t *testing.T) {
for _, e := range examples {
t.Run(e.desc, func(t *testing.T) {
target, _, err := scopeTableTarget(false, valueTarget(e.input), e.name)
d := decoder{}
target, _, err := d.scopeTableTarget(false, valueTarget(e.input), e.name)
require.NoError(t, err)
err = setString(target, str)
v := target.get()
@@ -102,15 +104,16 @@ func TestPushNew(t *testing.T) {
}
d := Doc{}
x, _, err := scopeTableTarget(false, valueTarget(reflect.ValueOf(&d).Elem()), "A")
dec := decoder{}
x, _, err := dec.scopeTableTarget(false, valueTarget(reflect.ValueOf(&d).Elem()), "A")
require.NoError(t, err)
n, err := pushNew(x)
n, err := elementAt(x, 0)
require.NoError(t, err)
require.NoError(t, n.setString("hello"))
require.Equal(t, []string{"hello"}, d.A)
n, err = pushNew(x)
n, err = elementAt(x, 1)
require.NoError(t, err)
require.NoError(t, n.setString("world"))
require.Equal(t, []string{"hello", "world"}, d.A)
@@ -122,15 +125,16 @@ func TestPushNew(t *testing.T) {
}
d := Doc{}
x, _, err := scopeTableTarget(false, valueTarget(reflect.ValueOf(&d).Elem()), "A")
dec := decoder{}
x, _, err := dec.scopeTableTarget(false, valueTarget(reflect.ValueOf(&d).Elem()), "A")
require.NoError(t, err)
n, err := pushNew(x)
n, err := elementAt(x, 0)
require.NoError(t, err)
require.NoError(t, setString(n, "hello"))
require.Equal(t, []interface{}{"hello"}, d.A)
n, err = pushNew(x)
n, err = elementAt(x, 1)
require.NoError(t, err)
require.NoError(t, setString(n, "world"))
require.Equal(t, []interface{}{"hello", "world"}, d.A)
@@ -164,7 +168,8 @@ func TestScope_Struct(t *testing.T) {
for _, e := range examples {
t.Run(e.desc, func(t *testing.T) {
x, found, err := scopeTableTarget(false, valueTarget(e.input), e.name)
dec := decoder{}
x, found, err := dec.scopeTableTarget(false, valueTarget(e.input), e.name)
assert.Equal(t, e.found, found)
if e.err {
assert.Error(t, err)
+55 -24
View File
@@ -19,7 +19,9 @@ func Unmarshal(data []byte, v interface{}) error {
// TODO: remove me; sanity check
allValidOrDump(p.tree, p.tree)
return fromAst(p.tree, v)
d := decoder{}
return d.fromAst(p.tree, v)
}
func allValidOrDump(tree ast.Root, nodes []ast.Node) bool {
@@ -37,7 +39,28 @@ func allValidOrDump(tree ast.Root, nodes []ast.Node) bool {
return true
}
func fromAst(tree ast.Root, v interface{}) error {
type decoder struct {
// Tracks position in Go arrays.
arrayIndexes map[reflect.Value]int
}
func (d *decoder) arrayIndex(append bool, v reflect.Value) int {
if d.arrayIndexes == nil {
d.arrayIndexes = make(map[reflect.Value]int, 1)
}
idx, ok := d.arrayIndexes[v]
if !ok {
d.arrayIndexes[v] = 0
} else if append {
idx++
d.arrayIndexes[v] = idx
}
return idx
}
func (d *decoder) fromAst(tree ast.Root, v interface{}) error {
r := reflect.ValueOf(v)
if r.Kind() != reflect.Ptr {
return fmt.Errorf("need to target a pointer, not %s", r.Kind())
@@ -57,12 +80,12 @@ func fromAst(tree ast.Root, v interface{}) error {
if skipUntilTable {
continue
}
err = unmarshalKeyValue(current, &node)
err = d.unmarshalKeyValue(current, &node)
found = true
case ast.Table:
current, found, err = scopeWithKey(root, node.Key())
current, found, err = d.scopeWithKey(root, node.Key())
case ast.ArrayTable:
current, found, err = scopeWithArrayTable(root, node.Key())
current, found, err = d.scopeWithArrayTable(root, node.Key())
default:
panic(fmt.Errorf("this should not be a top level node type: %s", node.Kind))
}
@@ -87,11 +110,11 @@ func fromAst(tree ast.Root, v interface{}) error {
//
// When encountering slices, it should always use its last element, and error
// if the slice does not have any.
func scopeWithKey(x target, key []ast.Node) (target, bool, error) {
func (d *decoder) scopeWithKey(x target, key []ast.Node) (target, bool, error) {
var err error
found := true
for _, n := range key {
x, found, err = scopeTableTarget(false, x, string(n.Data))
x, found, err = d.scopeTableTarget(false, x, string(n.Data))
if err != nil || !found {
return nil, found, err
}
@@ -104,18 +127,18 @@ func scopeWithKey(x target, key []ast.Node) (target, bool, error) {
//
// It is the same as scopeWithKey, but when scoping the last part of the key
// it creates a new element in the array instead of using the last one.
func scopeWithArrayTable(x target, key []ast.Node) (target, bool, error) {
func (d *decoder) scopeWithArrayTable(x target, key []ast.Node) (target, bool, error) {
var err error
found := true
if len(key) > 1 {
for _, n := range key[:len(key)-1] {
x, found, err = scopeTableTarget(false, x, string(n.Data))
x, found, err = d.scopeTableTarget(false, x, string(n.Data))
if err != nil || !found {
return nil, found, err
}
}
}
x, found, err = scopeTableTarget(false, x, string(key[len(key)-1].Data))
x, found, err = d.scopeTableTarget(false, x, string(key[len(key)-1].Data))
if err != nil || !found {
return x, found, err
}
@@ -138,17 +161,20 @@ func scopeWithArrayTable(x target, key []ast.Node) (target, bool, error) {
v = x.get()
}
if v.Kind() == reflect.Slice {
switch v.Kind() {
case reflect.Slice:
x, err = scopeSlice(true, x)
case reflect.Array:
x, err = d.scopeArray(true, x)
}
return x, found, err
}
func unmarshalKeyValue(x target, node *ast.Node) error {
func (d *decoder) unmarshalKeyValue(x target, node *ast.Node) error {
assertNode(ast.KeyValue, node)
x, found, err := scopeWithKey(x, node.Key())
x, found, err := d.scopeWithKey(x, node.Key())
if err != nil {
return err
}
@@ -158,10 +184,10 @@ func unmarshalKeyValue(x target, node *ast.Node) error {
return nil
}
return unmarshalValue(x, node.Value())
return d.unmarshalValue(x, node.Value())
}
func unmarshalValue(x target, node *ast.Node) error {
func (d *decoder) unmarshalValue(x target, node *ast.Node) error {
switch node.Kind {
case ast.String:
return unmarshalString(x, node)
@@ -172,9 +198,9 @@ func unmarshalValue(x target, node *ast.Node) error {
case ast.Float:
return unmarshalFloat(x, node)
case ast.Array:
return unmarshalArray(x, node)
return d.unmarshalArray(x, node)
case ast.InlineTable:
return unmarshalInlineTable(x, node)
return d.unmarshalInlineTable(x, node)
case ast.LocalDateTime:
return unmarshalLocalDateTime(x, node)
case ast.DateTime:
@@ -242,11 +268,11 @@ func unmarshalFloat(x target, node *ast.Node) error {
return setFloat64(x, v)
}
func unmarshalInlineTable(x target, node *ast.Node) error {
func (d *decoder) unmarshalInlineTable(x target, node *ast.Node) error {
assertNode(ast.InlineTable, node)
for _, kv := range node.Children {
err := unmarshalKeyValue(x, &kv)
err := d.unmarshalKeyValue(x, &kv)
if err != nil {
return err
}
@@ -254,20 +280,25 @@ func unmarshalInlineTable(x target, node *ast.Node) error {
return nil
}
func unmarshalArray(x target, node *ast.Node) error {
func (d *decoder) unmarshalArray(x target, node *ast.Node) error {
assertNode(ast.Array, node)
err := ensureSlice(x)
err := ensureValueIndexable(x)
if err != nil {
return err
}
for _, n := range node.Children {
v, err := pushNew(x)
for idx, n := range node.Children {
v, err := elementAt(x, idx)
if err != nil {
return err
}
err = unmarshalValue(v, &n)
if v == nil {
// when we go out of bound for an array just stop processing it to
// mimic encoding/json
break
}
err = d.unmarshalValue(v, &n)
if err != nil {
return err
}
+38 -7
View File
@@ -613,6 +613,30 @@ B = "data"`,
}
},
},
{
desc: "array of structs with table arrays",
input: `[[A]]
B = "one"
[[A]]
B = "two"`,
gen: func() test {
type inner struct {
B string
}
type doc struct {
A [4]inner
}
return test{
target: &doc{},
expected: &doc{
A: [4]inner{
{B: "one"},
{B: "two"},
},
},
}
},
},
}
for _, e := range examples {
@@ -657,7 +681,8 @@ func TestFromAst_KV(t *testing.T) {
}
x := Doc{}
err := fromAst(root, &x)
d := decoder{}
err := d.fromAst(root, &x)
require.NoError(t, err)
assert.Equal(t, Doc{Foo: "hello"}, x)
}
@@ -709,7 +734,8 @@ func TestFromAst_Table(t *testing.T) {
}
x := Doc{}
err := fromAst(root, &x)
d := decoder{}
err := d.fromAst(root, &x)
require.NoError(t, err)
assert.Equal(t, Doc{
Level1: Level1{
@@ -755,7 +781,8 @@ func TestFromAst_Table(t *testing.T) {
}
x := Doc{}
err := fromAst(root, &x)
d := decoder{}
err := d.fromAst(root, &x)
require.NoError(t, err)
assert.Equal(t, Doc{
A: A{B: B{C: "value"}},
@@ -805,7 +832,8 @@ func TestFromAst_InlineTable(t *testing.T) {
}
x := Doc{}
err := fromAst(root, &x)
d := decoder{}
err := d.fromAst(root, &x)
require.NoError(t, err)
assert.Equal(t, Doc{
Name: Name{
@@ -849,7 +877,8 @@ func TestFromAst_Slice(t *testing.T) {
}
x := Doc{}
err := fromAst(root, &x)
d := decoder{}
err := d.fromAst(root, &x)
require.NoError(t, err)
assert.Equal(t, Doc{Foo: []string{"hello", "world"}}, x)
})
@@ -885,7 +914,8 @@ func TestFromAst_Slice(t *testing.T) {
}
x := Doc{}
err := fromAst(root, &x)
d := decoder{}
err := d.fromAst(root, &x)
require.NoError(t, err)
assert.Equal(t, Doc{Foo: []interface{}{"hello", "world"}}, x)
})
@@ -930,7 +960,8 @@ func TestFromAst_Slice(t *testing.T) {
}
x := Doc{}
err := fromAst(root, &x)
d := decoder{}
err := d.fromAst(root, &x)
require.NoError(t, err)
assert.Equal(t, Doc{Foo: []interface{}{"hello", []interface{}{"inner1", "inner2"}}}, x)
})