Support Unmarshaler interface for tables and array tables (#873)

Extend the unstable.Unmarshaler interface support to work with tables
and array tables, not just single values.

When a type implementing unstable.Unmarshaler is the target of a table
(e.g., [table] or [[array]]), the UnmarshalTOML method receives a
synthetic InlineTable node containing all the key-value pairs belonging
to that table.

Key changes:
- Add handleKeyValuesUnmarshaler to collect and process table content
- Add copyExpressionNodes to deep-copy AST nodes for synthetic tables
- Add helper functions in unstable/ast.go for node manipulation
- Update documentation for EnableUnmarshalerInterface
- Add comprehensive tests for table and array table unmarshaling
This commit is contained in:
Claude
2026-01-15 03:01:45 +00:00
parent 2edc61f171
commit 5b6828661c
3 changed files with 488 additions and 2 deletions
+151 -2
View File
@@ -61,8 +61,10 @@ func (d *Decoder) DisallowUnknownFields() *Decoder {
// that don't have a straightforward TOML representation to provide their own
// decoding logic.
//
// Currently, types can only decode from a single value. Tables and array tables
// are not supported.
// Types can decode from single values, inline tables, arrays, and standard
// tables/array tables. When decoding from a table (e.g., [table] or [[array]]),
// the UnmarshalTOML method receives a synthetic InlineTable node containing
// all the key-value pairs belonging to that table.
//
// *Unstable:* This method does not follow the compatibility guarantees of
// semver. It can be changed or removed without a new major version being
@@ -624,6 +626,24 @@ func (d *decoder) handleTable(key unstable.Iterator, v reflect.Value) (reflect.V
// Handle root expressions until the end of the document or the next
// non-key-value.
func (d *decoder) handleKeyValues(v reflect.Value) (reflect.Value, error) {
// Check if target implements Unmarshaler before processing key-values.
// This allows types to handle entire tables themselves.
if d.unmarshalerInterface {
vv := v
for vv.Kind() == reflect.Ptr {
if vv.IsNil() {
vv.Set(reflect.New(vv.Type().Elem()))
}
vv = vv.Elem()
}
if vv.CanAddr() && vv.Addr().CanInterface() {
if outi, ok := vv.Addr().Interface().(unstable.Unmarshaler); ok {
// Collect all key-value expressions for this table
return d.handleKeyValuesUnmarshaler(outi)
}
}
}
var rv reflect.Value
for d.nextExpr() {
expr := d.expr()
@@ -653,6 +673,135 @@ func (d *decoder) handleKeyValues(v reflect.Value) (reflect.Value, error) {
return rv, nil
}
// handleKeyValuesUnmarshaler gathers the run of consecutive key-value
// expressions at the current parser position and hands them to u as the
// children of a synthetic InlineTable node.
//
// Iteration stops at the first expression that is not a KeyValue; that
// expression is stashed so it can be handled by whoever reads the next
// expression. Each key-value still goes through d.seen.CheckExpression. The
// first error from that check, or the error from u.UnmarshalTOML, is returned.
// The returned reflect.Value is always the zero Value.
func (d *decoder) handleKeyValuesUnmarshaler(u unstable.Unmarshaler) (reflect.Value, error) {
	// Index 0 is the synthetic table root. Its child/next start at -1 because
	// the zero value would read as a reference to node 0.
	nodes := []unstable.Node{{Kind: unstable.InlineTable}}
	unstable.SetNodeChild(&nodes[0], -1)
	unstable.SetNodeNext(&nodes[0], -1)

	// prev is the index of the most recently linked KeyValue, or -1 when the
	// root has no child yet.
	prev := -1
	for d.nextExpr() {
		expr := d.expr()
		if expr.Kind != unstable.KeyValue {
			d.stashExpr()
			break
		}

		if _, err := d.seen.CheckExpression(expr); err != nil {
			return reflect.Value{}, err
		}

		// The parser reuses its internal builder between expressions, so the
		// expression has to be copied before the next one is parsed.
		cur := d.copyExpressionNodes(&nodes, expr)
		if prev != -1 {
			// Chain onto the previous key-value as its next sibling.
			unstable.SetNodeNext(&nodes[prev], int32(cur))
		} else {
			// First key-value becomes the root's first child.
			unstable.SetNodeChild(&nodes[0], int32(cur))
		}
		prev = cur
	}

	// Only now is the slice final: appends above may have moved it, so the
	// backing-slice pointer is attached to every node at the end.
	for i := range nodes {
		unstable.SetNodeSlice(&nodes[i], &nodes)
	}

	return reflect.Value{}, u.UnmarshalTOML(&nodes[0])
}
// copyExpressionNodes appends a copy of node and all of its descendants to
// *dst and returns the index in *dst at which the copy of node was placed.
//
// Kind, Raw and Data are copied from each source node. Child and next links
// are re-expressed as indices into *dst; a link whose target is not part of
// the copied subtree is left at -1. The caller must still attach the backing
// slice to the copied nodes once it has finished appending.
func (d *decoder) copyExpressionNodes(dst *[]unstable.Node, node *unstable.Node) int {
	subtree := collectNodes(node)
	base := len(*dst)

	// Pass 1: append the payload of every node. Links are set to -1 explicitly,
	// since the zero value 0 would be read as a reference to the first node.
	for _, src := range subtree {
		*dst = append(*dst, unstable.Node{Kind: src.Kind, Raw: src.Raw, Data: src.Data})
		appended := &(*dst)[len(*dst)-1]
		unstable.SetNodeChild(appended, -1)
		unstable.SetNodeNext(appended, -1)
	}

	// Pass 2: translate each source link into the matching destination index.
	// No appends happen here, so pointers into *dst stay valid.
	for off, src := range subtree {
		target := &(*dst)[base+off]

		if srcChild := unstable.GetNodeChild(src); srcChild >= 0 {
			if rel := findNodeOffset(subtree, srcChild, src); rel >= 0 {
				unstable.SetNodeChild(target, int32(base+rel))
			}
		}

		if srcNext := unstable.GetNodeNext(src); srcNext >= 0 {
			if rel := findNodeOffset(subtree, srcNext, src); rel >= 0 {
				unstable.SetNodeNext(target, int32(base+rel))
			}
		}
	}

	return base
}
// collectNodes returns root followed by every node reachable from it through
// Children(), in depth-first pre-order (a node appears before its children).
// A nil root yields a nil slice.
func collectNodes(root *unstable.Node) []*unstable.Node {
	var flat []*unstable.Node

	var walk func(*unstable.Node)
	walk = func(cur *unstable.Node) {
		if cur == nil {
			return
		}
		flat = append(flat, cur)
		// Descend into each child in iteration order.
		for kids := cur.Children(); kids.Next(); {
			walk(kids.Node())
		}
	}

	walk(root)
	return flat
}
// findNodeOffset returns the position in collected of the node whose index in
// relativeTo's backing slice equals idx, or -1 when no collected node has that
// index.
func findNodeOffset(collected []*unstable.Node, idx int32, relativeTo *unstable.Node) int {
	for pos := range collected {
		// idx refers to the original backing slice, so each candidate is
		// mapped back to its index there before comparing.
		if unstable.GetNodeIndex(collected[pos], relativeTo) != idx {
			continue
		}
		return pos
	}
	return -1
}
type (
handlerFn func(key unstable.Iterator, v reflect.Value) (reflect.Value, error)
valueMakerFn func() reflect.Value
+291
View File
@@ -4385,3 +4385,294 @@ func TestIssue1028(t *testing.T) {
assert.Error(t, err)
})
}
// Tests for issue #873 - Bring back toml.Unmarshaler for tables and arrays

// customTable873 is a test fixture that records what its UnmarshalTOML method
// was handed, so tests can assert on the keys and string values of a table.
type customTable873 struct {
// Keys holds the first key part of every KeyValue child, in iteration order.
Keys []string
// Values maps each recorded key to its value's data; only String-kind values
// are stored.
Values map[string]string
}
// UnmarshalTOML resets c and then records every KeyValue child of node: the
// first part of its key is appended to Keys, and if the value is a String its
// data is stored in Values under that key. Children of any other kind, and
// key-values without a key part, are ignored. It always returns nil.
func (c *customTable873) UnmarshalTOML(node *unstable.Node) error {
	// Keys is deliberately a non-nil empty slice so an empty table compares
	// equal to []string{}.
	c.Keys = make([]string, 0)
	c.Values = map[string]string{}

	for entries := node.Children(); entries.Next(); {
		entry := entries.Node()
		if entry.Kind != unstable.KeyValue {
			continue
		}

		parts := entry.Key()
		if !parts.Next() {
			continue
		}
		name := string(parts.Node().Data)
		c.Keys = append(c.Keys, name)

		if val := entry.Value(); val != nil && val.Kind == unstable.String {
			c.Values[name] = string(val.Data)
		}
	}
	return nil
}
// TestIssue873_TableUnmarshaler decodes a [section] table into a field of type
// customTable873 with the Unmarshaler interface enabled, and expects the
// table's three key-value pairs to be recorded in document order.
func TestIssue873_TableUnmarshaler(t *testing.T) {
	type Config struct {
		Section customTable873 `toml:"section"`
	}

	const input = `
[section]
key1 = "value1"
key2 = "value2"
key3 = "value3"
`

	var got Config
	dec := toml.NewDecoder(bytes.NewReader([]byte(input))).EnableUnmarshalerInterface()
	assert.NoError(t, dec.Decode(&got))

	assert.Equal(t, []string{"key1", "key2", "key3"}, got.Section.Keys)
	for _, want := range []struct{ key, value string }{
		{"key1", "value1"},
		{"key2", "value2"},
		{"key3", "value3"},
	} {
		assert.Equal(t, want.value, got.Section.Values[want.key])
	}
}
// TestIssue873_TableUnmarshaler_EmptyTable decodes a [section] header with no
// key-value pairs and expects the recorded key list to be empty (and non-nil).
func TestIssue873_TableUnmarshaler_EmptyTable(t *testing.T) {
	type Config struct {
		Section customTable873 `toml:"section"`
	}

	const input = `
[section]
`

	var got Config
	dec := toml.NewDecoder(bytes.NewReader([]byte(input))).EnableUnmarshalerInterface()
	assert.NoError(t, dec.Decode(&got))

	assert.Equal(t, []string{}, got.Section.Keys)
}
// TestIssue873_ArrayTableUnmarshaler decodes two [[items]] entries into a
// slice of customTable873 and expects each element to record only its own
// key-value pairs, in document order.
func TestIssue873_ArrayTableUnmarshaler(t *testing.T) {
	type Config struct {
		Items []customTable873 `toml:"items"`
	}

	const input = `
[[items]]
name = "first"
id = "1"
[[items]]
name = "second"
id = "2"
`

	var got Config
	dec := toml.NewDecoder(bytes.NewReader([]byte(input))).EnableUnmarshalerInterface()
	assert.NoError(t, dec.Decode(&got))

	assert.Equal(t, 2, len(got.Items))
	for i, want := range []struct{ name, id string }{
		{"first", "1"},
		{"second", "2"},
	} {
		assert.Equal(t, []string{"name", "id"}, got.Items[i].Keys)
		assert.Equal(t, want.name, got.Items[i].Values["name"])
		assert.Equal(t, want.id, got.Items[i].Values["id"])
	}
}
// TestIssue873_NestedTableUnmarshaler decodes a dotted table header
// ([outer.inner]) where only the innermost field is a customTable873, and
// expects that field to record the table's two key-value pairs.
func TestIssue873_NestedTableUnmarshaler(t *testing.T) {
	type Config struct {
		Outer struct {
			Inner customTable873 `toml:"inner"`
		} `toml:"outer"`
	}

	const input = `
[outer.inner]
a = "A"
b = "B"
`

	var got Config
	dec := toml.NewDecoder(bytes.NewReader([]byte(input))).EnableUnmarshalerInterface()
	assert.NoError(t, dec.Decode(&got))

	inner := &got.Outer.Inner
	assert.Equal(t, []string{"a", "b"}, inner.Keys)
	assert.Equal(t, "A", inner.Values["a"])
	assert.Equal(t, "B", inner.Values["b"])
}
// TestIssue873_TableUnmarshaler_MultipleTables decodes two sibling tables that
// both target customTable873 fields and expects each field to record only the
// key-value pair under its own header.
func TestIssue873_TableUnmarshaler_MultipleTables(t *testing.T) {
	type Config struct {
		First  customTable873 `toml:"first"`
		Second customTable873 `toml:"second"`
	}

	const input = `
[first]
key1 = "value1"
[second]
key2 = "value2"
`

	var got Config
	dec := toml.NewDecoder(bytes.NewReader([]byte(input))).EnableUnmarshalerInterface()
	assert.NoError(t, dec.Decode(&got))

	first, second := &got.First, &got.Second
	assert.Equal(t, []string{"key1"}, first.Keys)
	assert.Equal(t, "value1", first.Values["key1"])
	assert.Equal(t, []string{"key2"}, second.Keys)
	assert.Equal(t, "value2", second.Values["key2"])
}
// TestIssue873_MixedWithRegularFields checks that ordinary struct fields set
// by top-level keys are still decoded when a sibling field is filled through
// the Unmarshaler table path.
func TestIssue873_MixedWithRegularFields(t *testing.T) {
	type Config struct {
		Name    string         `toml:"name"`
		Section customTable873 `toml:"section"`
		Count   int            `toml:"count"`
	}

	const input = `
name = "test"
count = 42
[section]
foo = "bar"
`

	var got Config
	dec := toml.NewDecoder(bytes.NewReader([]byte(input))).EnableUnmarshalerInterface()
	assert.NoError(t, dec.Decode(&got))

	// Plain fields first, then the table captured by the custom type.
	assert.Equal(t, "test", got.Name)
	assert.Equal(t, 42, got.Count)
	assert.Equal(t, []string{"foo"}, got.Section.Keys)
	assert.Equal(t, "bar", got.Section.Values["foo"])
}
// TestIssue873_PointerToUnmarshaler checks that a nil *customTable873 field is
// allocated and populated when its table is decoded.
func TestIssue873_PointerToUnmarshaler(t *testing.T) {
	type Config struct {
		Section *customTable873 `toml:"section"`
	}

	const input = `
[section]
hello = "world"
`

	var got Config
	dec := toml.NewDecoder(bytes.NewReader([]byte(input))).EnableUnmarshalerInterface()
	assert.NoError(t, dec.Decode(&got))

	assert.True(t, got.Section != nil)
	assert.Equal(t, []string{"hello"}, got.Section.Keys)
	assert.Equal(t, "world", got.Section.Values["hello"])
}
// TestIssue873_TableWithSubTables decodes a [parent] table that is followed by
// a separately declared [parent.child] sub-table. The parent is expected to
// record only the key directly under [parent]; the sub-table's content must
// not show up in it.
func TestIssue873_TableWithSubTables(t *testing.T) {
	type Config struct {
		Parent customTable873 `toml:"parent"`
	}

	const input = `
[parent]
name = "root"
[parent.child]
name = "nested"
`

	var got Config
	dec := toml.NewDecoder(bytes.NewReader([]byte(input))).EnableUnmarshalerInterface()
	assert.NoError(t, dec.Decode(&got))

	assert.Equal(t, []string{"name"}, got.Parent.Keys)
	assert.Equal(t, "root", got.Parent.Values["name"])
}
// TestIssue994_TablesPieceWise is a follow-up to issue #994 (maintainer note:
// "It doesn't deal with tables defined piece-wise yet"). It checks that a
// table written with [table] syntax reaches customTable873 with its keys in
// document order, which matters for use cases such as ordered maps.
//
// NOTE(review): the input below declares [section] under a single header, so
// this exercises key ordering rather than a table split across several
// headers — confirm whether a split-definition case should be added.
func TestIssue994_TablesPieceWise(t *testing.T) {
	type Config struct {
		Section customTable873 `toml:"section"`
	}

	const input = `
[section]
first = "1"
second = "2"
third = "3"
`

	var got Config
	dec := toml.NewDecoder(bytes.NewReader([]byte(input))).EnableUnmarshalerInterface()
	assert.NoError(t, dec.Decode(&got))

	// Keys must come back in the order they appear in the document.
	assert.Equal(t, []string{"first", "second", "third"}, got.Section.Keys)
	for _, want := range []struct{ key, value string }{
		{"first", "1"},
		{"second", "2"},
		{"third", "3"},
	} {
		assert.Equal(t, want.value, got.Section.Values[want.key])
	}
}
// TestIssue994_RootWithTables decodes a root document whose only field is a
// slice of customTable873 filled from [[tables]] entries, combining the #994
// fix with the #873 enhancement.
func TestIssue994_RootWithTables(t *testing.T) {
	type rootDoc struct {
		Tables []customTable873 `toml:"tables"`
	}

	const input = `
[[tables]]
name = "first"
value = "one"
[[tables]]
name = "second"
value = "two"
`

	var got rootDoc
	dec := toml.NewDecoder(bytes.NewReader([]byte(input))).EnableUnmarshalerInterface()
	assert.NoError(t, dec.Decode(&got))

	assert.Equal(t, 2, len(got.Tables))
	for i, want := range []struct{ name, value string }{
		{"first", "one"},
		{"second", "two"},
	} {
		assert.Equal(t, want.name, got.Tables[i].Values["name"])
		assert.Equal(t, want.value, got.Tables[i].Values["value"])
	}
}
+46
View File
@@ -143,3 +143,49 @@ func (n *Node) Value() *Node {
func (n *Node) Children() Iterator {
// The iterator starts at this node's child index and carries the same
// backing-slice pointer as n.
return Iterator{nodes: n.nodes, idx: n.child}
}
// SetNodeSlice sets the backing nodes slice for a node.
// This is used when building synthetic nodes.
//
// The pointer is stored as-is; the slice it points to is not copied.
// Children hands this pointer, together with the node's child index, to the
// Iterator it returns, so child/next indices are presumably resolved against
// *nodes (the Iterator implementation is not visible here — confirm).
func SetNodeSlice(n *Node, nodes *[]Node) {
n.nodes = nodes
}
// SetNodeChild sets the child index for a node.
// This is used when building synthetic nodes.
//
// child is stored unchecked. The decoder code that builds synthetic tables
// passes -1 to mean "no child", because the zero value 0 would refer to the
// first node of the backing slice.
func SetNodeChild(n *Node, child int32) {
n.child = child
}
// SetNodeNext sets the next sibling index for a node.
// This is used when building synthetic nodes.
//
// next is stored unchecked. The decoder code that builds synthetic tables
// passes -1 to mean "no next sibling", because the zero value 0 would refer to
// the first node of the backing slice.
func SetNodeNext(n *Node, next int32) {
n.next = next
}
// GetNodeChild returns the child index for a node.
// This is used when copying nodes.
//
// The value is returned exactly as stored; callers that copy nodes treat a
// negative result as "no child".
func GetNodeChild(n *Node) int32 {
return n.child
}
// GetNodeNext returns the next sibling index for a node.
// This is used when copying nodes.
//
// The value is returned exactly as stored; callers that copy nodes treat a
// negative result as "no next sibling".
func GetNodeNext(n *Node) int32 {
return n.next
}
// GetNodeIndex returns the index of node n in the backing slice that
// relativeTo refers to.
//
// n is matched by pointer identity, so it must be the address of an element
// of that slice; a copy of an element is never found. The lookup is a linear
// scan over the backing slice.
//
// It returns -1 when n is nil, when relativeTo is nil or has no backing
// slice, or when n is not an element of that slice.
func GetNodeIndex(n *Node, relativeTo *Node) int32 {
	// Guard relativeTo itself as well as its slice pointer: a nil reference
	// node has no slice to search, which is simply "not found".
	if n == nil || relativeTo == nil || relativeTo.nodes == nil {
		return -1
	}

	nodes := *relativeTo.nodes
	for i := range nodes {
		if &nodes[i] == n {
			return int32(i)
		}
	}
	return -1
}