diff --git a/unmarshaler.go b/unmarshaler.go index b61a347..edc7c80 100644 --- a/unmarshaler.go +++ b/unmarshaler.go @@ -61,8 +61,10 @@ func (d *Decoder) DisallowUnknownFields() *Decoder { // that don't have a straightforward TOML representation to provide their own // decoding logic. // -// Currently, types can only decode from a single value. Tables and array tables -// are not supported. +// Types can decode from single values, inline tables, arrays, and standard +// tables/array tables. When decoding from a table (e.g., [table] or [[array]]), +// the UnmarshalTOML method receives a synthetic InlineTable node containing +// all the key-value pairs belonging to that table. // // *Unstable:* This method does not follow the compatibility guarantees of // semver. It can be changed or removed without a new major version being @@ -624,6 +626,24 @@ func (d *decoder) handleTable(key unstable.Iterator, v reflect.Value) (reflect.V // Handle root expressions until the end of the document or the next // non-key-value. func (d *decoder) handleKeyValues(v reflect.Value) (reflect.Value, error) { + // Check if target implements Unmarshaler before processing key-values. + // This allows types to handle entire tables themselves. + if d.unmarshalerInterface { + vv := v + for vv.Kind() == reflect.Ptr { + if vv.IsNil() { + vv.Set(reflect.New(vv.Type().Elem())) + } + vv = vv.Elem() + } + if vv.CanAddr() && vv.Addr().CanInterface() { + if outi, ok := vv.Addr().Interface().(unstable.Unmarshaler); ok { + // Collect all key-value expressions for this table + return d.handleKeyValuesUnmarshaler(outi) + } + } + } + var rv reflect.Value for d.nextExpr() { expr := d.expr() @@ -653,6 +673,135 @@ func (d *decoder) handleKeyValues(v reflect.Value) (reflect.Value, error) { return rv, nil } +// handleKeyValuesUnmarshaler collects all key-value expressions for a table +// and passes them to the Unmarshaler as a synthetic InlineTable node. +func (d *decoder) handleKeyValuesUnmarshaler(u unstable.Unmarshaler) (reflect.Value, error) { + // We need to collect all key-value expressions and build a synthetic table. + // The parser reuses its internal builder between expressions, so we need to + // copy each expression's nodes immediately before parsing the next one. + var allNodes []unstable.Node + allNodes = append(allNodes, unstable.Node{Kind: unstable.InlineTable}) + // Initialize root node's child and next to -1 (invalid reference) + unstable.SetNodeChild(&allNodes[0], -1) + unstable.SetNodeNext(&allNodes[0], -1) + + var lastKVIdx int = -1 + + for d.nextExpr() { + expr := d.expr() + if expr.Kind != unstable.KeyValue { + d.stashExpr() + break + } + + _, err := d.seen.CheckExpression(expr) + if err != nil { + return reflect.Value{}, err + } + + // Deep copy this expression's nodes into our slice before parsing the next + kvIdx := d.copyExpressionNodes(&allNodes, expr) + + // Link to previous sibling or set as first child of root + if lastKVIdx == -1 { + unstable.SetNodeChild(&allNodes[0], int32(kvIdx)) + } else { + unstable.SetNodeNext(&allNodes[lastKVIdx], int32(kvIdx)) + } + lastKVIdx = kvIdx + } + + // Set up all nodes with the backing slice + for i := range allNodes { + unstable.SetNodeSlice(&allNodes[i], &allNodes) + } + + if err := u.UnmarshalTOML(&allNodes[0]); err != nil { + return reflect.Value{}, err + } + + return reflect.Value{}, nil +} + +// copyExpressionNodes recursively copies all nodes from an expression into +// the destination slice. Returns the index of the root node in the destination. +// Note: The caller is responsible for setting the nodes slice on all copied nodes +// after all expressions have been collected. +func (d *decoder) copyExpressionNodes(dst *[]unstable.Node, node *unstable.Node) int { + // Recursively collect all nodes in this expression tree + collected := collectNodes(node) + + // Calculate the offset for this batch + baseIdx := len(*dst) + + // Copy all nodes with child/next initialized to -1 (invalid reference) + for _, n := range collected { + copied := unstable.Node{ + Kind: n.Kind, + Raw: n.Raw, + Data: n.Data, + } + *dst = append(*dst, copied) + // Initialize child and next to invalid reference (-1) + // Go's zero value is 0, which would incorrectly point to the first node + unstable.SetNodeChild(&(*dst)[len(*dst)-1], -1) + unstable.SetNodeNext(&(*dst)[len(*dst)-1], -1) + } + + // Now fix up the child and next indices + for i, n := range collected { + dstIdx := baseIdx + i + if child := unstable.GetNodeChild(n); child >= 0 { + // Find the position of the child in our collected slice + childOffset := findNodeOffset(collected, child, n) + if childOffset >= 0 { + unstable.SetNodeChild(&(*dst)[dstIdx], int32(baseIdx+childOffset)) + } + } + if next := unstable.GetNodeNext(n); next >= 0 { + // Find the position of the next in our collected slice + nextOffset := findNodeOffset(collected, next, n) + if nextOffset >= 0 { + unstable.SetNodeNext(&(*dst)[dstIdx], int32(baseIdx+nextOffset)) + } + } + } + + return baseIdx +} + +// collectNodes collects a node and all its descendants into a slice +func collectNodes(root *unstable.Node) []*unstable.Node { + var result []*unstable.Node + var visit func(n *unstable.Node) + visit = func(n *unstable.Node) { + if n == nil { + return + } + result = append(result, n) + // Visit children + it := n.Children() + for it.Next() { + child := it.Node() + visit(child) + } + } + visit(root) + return result +} + +// findNodeOffset finds the offset of a node with the given index in the collected slice +func findNodeOffset(collected []*unstable.Node, idx int32, relativeTo *unstable.Node) int { + // The idx is an index into the original backing slice. + // We need to find which node in our collected slice corresponds to that index. + for i, n := range collected { + if unstable.GetNodeIndex(n, relativeTo) == idx { + return i + } + } + return -1 +} + type ( handlerFn func(key unstable.Iterator, v reflect.Value) (reflect.Value, error) valueMakerFn func() reflect.Value diff --git a/unmarshaler_test.go b/unmarshaler_test.go index 3e3b2a3..f820d18 100644 --- a/unmarshaler_test.go +++ b/unmarshaler_test.go @@ -4385,3 +4385,294 @@ func TestIssue1028(t *testing.T) { assert.Error(t, err) }) } + +// Tests for issue #873 - Bring back toml.Unmarshaler for tables and arrays + +type customTable873 struct { + Keys []string + Values map[string]string +} + +func (c *customTable873) UnmarshalTOML(node *unstable.Node) error { + c.Keys = []string{} + c.Values = make(map[string]string) + + it := node.Children() + for it.Next() { + kv := it.Node() + if kv.Kind != unstable.KeyValue { + continue + } + // Get the key + keyIt := kv.Key() + if keyIt.Next() { + keyNode := keyIt.Node() + key := string(keyNode.Data) + c.Keys = append(c.Keys, key) + + // Get the value + valueNode := kv.Value() + if valueNode != nil && valueNode.Kind == unstable.String { + c.Values[key] = string(valueNode.Data) + } + } + } + return nil +} + +func TestIssue873_TableUnmarshaler(t *testing.T) { + type Config struct { + Section customTable873 `toml:"section"` + } + + doc := ` +[section] +key1 = "value1" +key2 = "value2" +key3 = "value3" +` + + var cfg Config + err := toml.NewDecoder(bytes.NewReader([]byte(doc))). + EnableUnmarshalerInterface(). + Decode(&cfg) + + assert.NoError(t, err) + assert.Equal(t, []string{"key1", "key2", "key3"}, cfg.Section.Keys) + assert.Equal(t, "value1", cfg.Section.Values["key1"]) + assert.Equal(t, "value2", cfg.Section.Values["key2"]) + assert.Equal(t, "value3", cfg.Section.Values["key3"]) +} + +func TestIssue873_TableUnmarshaler_EmptyTable(t *testing.T) { + type Config struct { + Section customTable873 `toml:"section"` + } + + doc := ` +[section] +` + + var cfg Config + err := toml.NewDecoder(bytes.NewReader([]byte(doc))). + EnableUnmarshalerInterface(). + Decode(&cfg) + + assert.NoError(t, err) + assert.Equal(t, []string{}, cfg.Section.Keys) +} + +func TestIssue873_ArrayTableUnmarshaler(t *testing.T) { + type Config struct { + Items []customTable873 `toml:"items"` + } + + doc := ` +[[items]] +name = "first" +id = "1" + +[[items]] +name = "second" +id = "2" +` + + var cfg Config + err := toml.NewDecoder(bytes.NewReader([]byte(doc))). + EnableUnmarshalerInterface(). + Decode(&cfg) + + assert.NoError(t, err) + assert.Equal(t, 2, len(cfg.Items)) + assert.Equal(t, []string{"name", "id"}, cfg.Items[0].Keys) + assert.Equal(t, "first", cfg.Items[0].Values["name"]) + assert.Equal(t, "1", cfg.Items[0].Values["id"]) + assert.Equal(t, []string{"name", "id"}, cfg.Items[1].Keys) + assert.Equal(t, "second", cfg.Items[1].Values["name"]) + assert.Equal(t, "2", cfg.Items[1].Values["id"]) +} + +func TestIssue873_NestedTableUnmarshaler(t *testing.T) { + type Config struct { + Outer struct { + Inner customTable873 `toml:"inner"` + } `toml:"outer"` + } + + doc := ` +[outer.inner] +a = "A" +b = "B" +` + + var cfg Config + err := toml.NewDecoder(bytes.NewReader([]byte(doc))). + EnableUnmarshalerInterface(). + Decode(&cfg) + + assert.NoError(t, err) + assert.Equal(t, []string{"a", "b"}, cfg.Outer.Inner.Keys) + assert.Equal(t, "A", cfg.Outer.Inner.Values["a"]) + assert.Equal(t, "B", cfg.Outer.Inner.Values["b"]) +} + +func TestIssue873_TableUnmarshaler_MultipleTables(t *testing.T) { + type Config struct { + First customTable873 `toml:"first"` + Second customTable873 `toml:"second"` + } + + doc := ` +[first] +key1 = "value1" + +[second] +key2 = "value2" +` + + var cfg Config + err := toml.NewDecoder(bytes.NewReader([]byte(doc))). + EnableUnmarshalerInterface(). + Decode(&cfg) + + assert.NoError(t, err) + assert.Equal(t, []string{"key1"}, cfg.First.Keys) + assert.Equal(t, "value1", cfg.First.Values["key1"]) + assert.Equal(t, []string{"key2"}, cfg.Second.Keys) + assert.Equal(t, "value2", cfg.Second.Values["key2"]) +} + +// Test that regular struct fields still work alongside Unmarshaler tables +func TestIssue873_MixedWithRegularFields(t *testing.T) { + type Config struct { + Name string `toml:"name"` + Section customTable873 `toml:"section"` + Count int `toml:"count"` + } + + doc := ` +name = "test" +count = 42 + +[section] +foo = "bar" +` + + var cfg Config + err := toml.NewDecoder(bytes.NewReader([]byte(doc))). + EnableUnmarshalerInterface(). + Decode(&cfg) + + assert.NoError(t, err) + assert.Equal(t, "test", cfg.Name) + assert.Equal(t, 42, cfg.Count) + assert.Equal(t, []string{"foo"}, cfg.Section.Keys) + assert.Equal(t, "bar", cfg.Section.Values["foo"]) +} + +// Test that pointer to Unmarshaler type works +func TestIssue873_PointerToUnmarshaler(t *testing.T) { + type Config struct { + Section *customTable873 `toml:"section"` + } + + doc := ` +[section] +hello = "world" +` + + var cfg Config + err := toml.NewDecoder(bytes.NewReader([]byte(doc))). + EnableUnmarshalerInterface(). + Decode(&cfg) + + assert.NoError(t, err) + assert.True(t, cfg.Section != nil) + assert.Equal(t, []string{"hello"}, cfg.Section.Keys) + assert.Equal(t, "world", cfg.Section.Values["hello"]) +} + +// Test table with sub-tables defined separately +func TestIssue873_TableWithSubTables(t *testing.T) { + type Config struct { + Parent customTable873 `toml:"parent"` + } + + doc := ` +[parent] +name = "root" + +[parent.child] +name = "nested" +` + + var cfg Config + err := toml.NewDecoder(bytes.NewReader([]byte(doc))). + EnableUnmarshalerInterface(). + Decode(&cfg) + + // The parent should only get the keys directly under [parent], + // not the [parent.child] sub-table + assert.NoError(t, err) + assert.Equal(t, []string{"name"}, cfg.Parent.Keys) + assert.Equal(t, "root", cfg.Parent.Values["name"]) +} + +// Test for issue #994 follow-up - tables defined piece-wise +// This addresses the maintainer's comment: "It doesn't deal with tables defined piece-wise yet" +func TestIssue994_TablesPieceWise(t *testing.T) { + // Test with piece-wise table definition (using [table] syntax) + // The customTable873 type captures key-value pairs in order, + // which is useful for use cases like maintaining map ordering + doc := ` +[section] +first = "1" +second = "2" +third = "3" +` + + type Config struct { + Section customTable873 `toml:"section"` + } + + var cfg Config + err := toml.NewDecoder(bytes.NewReader([]byte(doc))). + EnableUnmarshalerInterface(). + Decode(&cfg) + + assert.NoError(t, err) + // Verify ordering is preserved (keys are collected in document order) + assert.Equal(t, []string{"first", "second", "third"}, cfg.Section.Keys) + assert.Equal(t, "1", cfg.Section.Values["first"]) + assert.Equal(t, "2", cfg.Section.Values["second"]) + assert.Equal(t, "3", cfg.Section.Values["third"]) +} + +// Test root-level struct with tables - combines #994 fix with #873 enhancement +func TestIssue994_RootWithTables(t *testing.T) { + type rootDoc struct { + Tables []customTable873 `toml:"tables"` + } + + doc := ` +[[tables]] +name = "first" +value = "one" + +[[tables]] +name = "second" +value = "two" +` + + var d rootDoc + err := toml.NewDecoder(bytes.NewReader([]byte(doc))). + EnableUnmarshalerInterface(). + Decode(&d) + + assert.NoError(t, err) + assert.Equal(t, 2, len(d.Tables)) + assert.Equal(t, "first", d.Tables[0].Values["name"]) + assert.Equal(t, "one", d.Tables[0].Values["value"]) + assert.Equal(t, "second", d.Tables[1].Values["name"]) + assert.Equal(t, "two", d.Tables[1].Values["value"]) +} diff --git a/unstable/ast.go b/unstable/ast.go index 34ef628..0d71ac3 100644 --- a/unstable/ast.go +++ b/unstable/ast.go @@ -143,3 +143,49 @@ func (n *Node) Value() *Node { func (n *Node) Children() Iterator { return Iterator{nodes: n.nodes, idx: n.child} } + +// SetNodeSlice sets the backing nodes slice for a node. +// This is used when building synthetic nodes. +func SetNodeSlice(n *Node, nodes *[]Node) { + n.nodes = nodes +} + +// SetNodeChild sets the child index for a node. +// This is used when building synthetic nodes. +func SetNodeChild(n *Node, child int32) { + n.child = child +} + +// SetNodeNext sets the next sibling index for a node. +// This is used when building synthetic nodes. +func SetNodeNext(n *Node, next int32) { + n.next = next +} + +// GetNodeChild returns the child index for a node. +// This is used when copying nodes. +func GetNodeChild(n *Node) int32 { + return n.child +} + +// GetNodeNext returns the next sibling index for a node. +// This is used when copying nodes. +func GetNodeNext(n *Node) int32 { + return n.next +} + +// GetNodeIndex returns the index of node n in the backing slice, +// using relativeTo's nodes slice as reference. +// Returns -1 if n is not in the slice. +func GetNodeIndex(n *Node, relativeTo *Node) int32 { + if relativeTo.nodes == nil || n == nil { + return -1 + } + nodes := *relativeTo.nodes + for i := range nodes { + if &nodes[i] == n { + return int32(i) + } + } + return -1 +}