diff --git a/internal/ast/ast.go b/internal/ast/ast.go index c0748f6..25a89be 100644 --- a/internal/ast/ast.go +++ b/internal/ast/ast.go @@ -129,23 +129,29 @@ type Node struct { // InlineTables have one child per key-value pair in the table. // KeyValues have at least two children. The last one is the value. The // rest make a potentially dotted key. + // Table and Array table have one child per element of the key they + // represent (same as KeyValue, but without the last node being the value). Children []Node } var NoNode = Node{} -// Key returns the nodes making the Key of a KeyValue. +// Key returns the child nodes making the Key on a supported node. Panics +// otherwise. // They are guaranteed to be all be of the Kind Key. A simple key would return // just one element. -// Panics if not called on a KeyValue node, or if the Children are malformed. func (n *Node) Key() []Node { - if n.Kind != KeyValue { - panic(fmt.Errorf("Key() should only be called on on a KeyValue, not %s", n.Kind)) + switch n.Kind { + case KeyValue: + if len(n.Children) < 2 { + panic(fmt.Errorf("KeyValue should have at least two children, not %d", len(n.Children))) + } + return n.Children[:len(n.Children)-1] + case Table: + return n.Children + default: + panic(fmt.Errorf("Key() is not supported on a %s", n.Kind)) } - if len(n.Children) < 2 { - panic(fmt.Errorf("KeyValue should have at least two children, not %d", len(n.Children))) - } - return n.Children[:len(n.Children)-1] } // Value returns a pointer to the value node of a KeyValue. diff --git a/internal/unmarshaler/unmarshaler.go b/internal/unmarshaler/unmarshaler.go index 00bdae4..4e79a4e 100644 --- a/internal/unmarshaler/unmarshaler.go +++ b/internal/unmarshaler/unmarshaler.go @@ -7,19 +7,28 @@ import ( "github.com/pelletier/go-toml/v2/internal/ast" ) -func FromAst(tree ast.Root, target interface{}) error { - v := reflect.ValueOf(target) - if v.Kind() != reflect.Ptr { - return fmt.Errorf("need to target a pointer, not %s", v.Kind()) +func Unmarshal(data []byte, v interface{}) error { + p := parser{} + err := p.parse(data) + if err != nil { + return err } - if v.IsNil() { + return fromAst(p.tree, v) +} + +func fromAst(tree ast.Root, v interface{}) error { + r := reflect.ValueOf(v) + if r.Kind() != reflect.Ptr { + return fmt.Errorf("need to target a pointer, not %s", r.Kind()) + } + if r.IsNil() { return fmt.Errorf("target pointer must be non-nil") } - x := valueTarget(v.Elem()) - + var x target = valueTarget(r.Elem()) + var err error for _, node := range tree { - err := unmarshalTopLevelNode(x, &node) + x, err = unmarshalTopLevelNode(x, &node) if err != nil { return err } @@ -28,31 +37,39 @@ func FromAst(tree ast.Root, target interface{}) error { return nil } -func unmarshalTopLevelNode(x target, node *ast.Node) error { +// The target return value is the target for the next top-level node. Mostly +// unchanged, except by table and array table. +func unmarshalTopLevelNode(x target, node *ast.Node) (target, error) { switch node.Kind { case ast.Table: - panic("TODO") + return scopeWithKey(x, node.Key()) case ast.ArrayTable: panic("TODO") case ast.KeyValue: - return unmarshalKeyValue(x, node) + return x, unmarshalKeyValue(x, node) default: panic(fmt.Errorf("this should not be a top level node type: %s", node.Kind)) } } -func unmarshalKeyValue(x target, node *ast.Node) error { - assertNode(ast.KeyValue, node) - - key := node.Key() - +func scopeWithKey(x target, key []ast.Node) (target, error) { var err error for _, n := range key { x, err = scopeTarget(x, string(n.Data)) if err != nil { - return err + return nil, err } } + return x, nil +} + +func unmarshalKeyValue(x target, node *ast.Node) error { + assertNode(ast.KeyValue, node) + + x, err := scopeWithKey(x, node.Key()) + if err != nil { + return err + } return unmarshalValue(x, node.Value()) } diff --git a/internal/unmarshaler/unmarshaler_test.go b/internal/unmarshaler/unmarshaler_test.go index 06cf0eb..33b331d 100644 --- a/internal/unmarshaler/unmarshaler_test.go +++ b/internal/unmarshaler/unmarshaler_test.go @@ -1,4 +1,4 @@ -package unmarshaler_test +package unmarshaler import ( "testing" @@ -7,7 +7,6 @@ import ( "github.com/stretchr/testify/require" "github.com/pelletier/go-toml/v2/internal/ast" - "github.com/pelletier/go-toml/v2/internal/unmarshaler" ) func TestFromAst_KV(t *testing.T) { @@ -32,11 +31,112 @@ func TestFromAst_KV(t *testing.T) { } x := Doc{} - err := unmarshaler.FromAst(root, &x) + err := fromAst(root, &x) require.NoError(t, err) assert.Equal(t, Doc{Foo: "hello"}, x) } +func TestFromAst_Table(t *testing.T) { + t.Run("one level table on struct", func(t *testing.T) { + root := ast.Root{ + ast.Node{ + Kind: ast.Table, + Children: []ast.Node{ + {Kind: ast.Key, Data: []byte(`Level1`)}, + }, + }, + ast.Node{ + Kind: ast.KeyValue, + Children: []ast.Node{ + { + Kind: ast.Key, + Data: []byte(`A`), + }, + { + Kind: ast.String, + Data: []byte(`hello`), + }, + }, + }, + ast.Node{ + Kind: ast.KeyValue, + Children: []ast.Node{ + { + Kind: ast.Key, + Data: []byte(`B`), + }, + { + Kind: ast.String, + Data: []byte(`world`), + }, + }, + }, + } + + type Level1 struct { + A string + B string + } + + type Doc struct { + Level1 Level1 + } + + x := Doc{} + err := fromAst(root, &x) + require.NoError(t, err) + assert.Equal(t, Doc{ + Level1: Level1{ + A: "hello", + B: "world", + }, + }, x) + }) + t.Run("one level table on struct", func(t *testing.T) { + root := ast.Root{ + ast.Node{ + Kind: ast.Table, + Children: []ast.Node{ + {Kind: ast.Key, Data: []byte(`A`)}, + {Kind: ast.Key, Data: []byte(`B`)}, + }, + }, + ast.Node{ + Kind: ast.KeyValue, + Children: []ast.Node{ + { + Kind: ast.Key, + Data: []byte(`C`), + }, + { + Kind: ast.String, + Data: []byte(`value`), + }, + }, + }, + } + + type B struct { + C string + } + + type A struct { + B B + } + + type Doc struct { + A A + } + + x := Doc{} + err := fromAst(root, &x) + require.NoError(t, err) + assert.Equal(t, Doc{ + A: A{B: B{C: "value"}}, + }, x) + }) +} + func TestFromAst_Slice(t *testing.T) { t.Run("slice of string", func(t *testing.T) { root := ast.Root{ @@ -69,7 +169,7 @@ func TestFromAst_Slice(t *testing.T) { } x := Doc{} - err := unmarshaler.FromAst(root, &x) + err := fromAst(root, &x) require.NoError(t, err) assert.Equal(t, Doc{Foo: []string{"hello", "world"}}, x) }) @@ -105,7 +205,7 @@ func TestFromAst_Slice(t *testing.T) { } x := Doc{} - err := unmarshaler.FromAst(root, &x) + err := fromAst(root, &x) require.NoError(t, err) assert.Equal(t, Doc{Foo: []interface{}{"hello", "world"}}, x) }) @@ -150,7 +250,7 @@ func TestFromAst_Slice(t *testing.T) { } x := Doc{} - err := unmarshaler.FromAst(root, &x) + err := fromAst(root, &x) require.NoError(t, err) assert.Equal(t, Doc{Foo: []interface{}{"hello", []interface{}{"inner1", "inner2"}}}, x) })