Support Unmarshaler interface for tables and array tables (#873)
Extend the unstable.Unmarshaler interface support to work with tables and array tables, not just single values. When a type implementing unstable.Unmarshaler is the target of a table (e.g., [table] or [[array]]), the UnmarshalTOML method receives a synthetic InlineTable node containing all the key-value pairs belonging to that table. Key changes: - Add handleKeyValuesUnmarshaler to collect and process table content - Add copyExpressionNodes to deep-copy AST nodes for synthetic tables - Add helper functions in unstable/ast.go for node manipulation - Update documentation for EnableUnmarshalerInterface - Add comprehensive tests for table and array table unmarshaling
This commit is contained in:
+151
-2
@@ -61,8 +61,10 @@ func (d *Decoder) DisallowUnknownFields() *Decoder {
|
||||
// that don't have a straightforward TOML representation to provide their own
|
||||
// decoding logic.
|
||||
//
|
||||
// Currently, types can only decode from a single value. Tables and array tables
|
||||
// are not supported.
|
||||
// Types can decode from single values, inline tables, arrays, and standard
|
||||
// tables/array tables. When decoding from a table (e.g., [table] or [[array]]),
|
||||
// the UnmarshalTOML method receives a synthetic InlineTable node containing
|
||||
// all the key-value pairs belonging to that table.
|
||||
//
|
||||
// *Unstable:* This method does not follow the compatibility guarantees of
|
||||
// semver. It can be changed or removed without a new major version being
|
||||
@@ -624,6 +626,24 @@ func (d *decoder) handleTable(key unstable.Iterator, v reflect.Value) (reflect.V
|
||||
// Handle root expressions until the end of the document or the next
|
||||
// non-key-value.
|
||||
func (d *decoder) handleKeyValues(v reflect.Value) (reflect.Value, error) {
|
||||
// Check if target implements Unmarshaler before processing key-values.
|
||||
// This allows types to handle entire tables themselves.
|
||||
if d.unmarshalerInterface {
|
||||
vv := v
|
||||
for vv.Kind() == reflect.Ptr {
|
||||
if vv.IsNil() {
|
||||
vv.Set(reflect.New(vv.Type().Elem()))
|
||||
}
|
||||
vv = vv.Elem()
|
||||
}
|
||||
if vv.CanAddr() && vv.Addr().CanInterface() {
|
||||
if outi, ok := vv.Addr().Interface().(unstable.Unmarshaler); ok {
|
||||
// Collect all key-value expressions for this table
|
||||
return d.handleKeyValuesUnmarshaler(outi)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
var rv reflect.Value
|
||||
for d.nextExpr() {
|
||||
expr := d.expr()
|
||||
@@ -653,6 +673,135 @@ func (d *decoder) handleKeyValues(v reflect.Value) (reflect.Value, error) {
|
||||
return rv, nil
|
||||
}
|
||||
|
||||
// handleKeyValuesUnmarshaler collects all key-value expressions for a table
|
||||
// and passes them to the Unmarshaler as a synthetic InlineTable node.
|
||||
func (d *decoder) handleKeyValuesUnmarshaler(u unstable.Unmarshaler) (reflect.Value, error) {
|
||||
// We need to collect all key-value expressions and build a synthetic table.
|
||||
// The parser reuses its internal builder between expressions, so we need to
|
||||
// copy each expression's nodes immediately before parsing the next one.
|
||||
var allNodes []unstable.Node
|
||||
allNodes = append(allNodes, unstable.Node{Kind: unstable.InlineTable})
|
||||
// Initialize root node's child and next to -1 (invalid reference)
|
||||
unstable.SetNodeChild(&allNodes[0], -1)
|
||||
unstable.SetNodeNext(&allNodes[0], -1)
|
||||
|
||||
var lastKVIdx int = -1
|
||||
|
||||
for d.nextExpr() {
|
||||
expr := d.expr()
|
||||
if expr.Kind != unstable.KeyValue {
|
||||
d.stashExpr()
|
||||
break
|
||||
}
|
||||
|
||||
_, err := d.seen.CheckExpression(expr)
|
||||
if err != nil {
|
||||
return reflect.Value{}, err
|
||||
}
|
||||
|
||||
// Deep copy this expression's nodes into our slice before parsing the next
|
||||
kvIdx := d.copyExpressionNodes(&allNodes, expr)
|
||||
|
||||
// Link to previous sibling or set as first child of root
|
||||
if lastKVIdx == -1 {
|
||||
unstable.SetNodeChild(&allNodes[0], int32(kvIdx))
|
||||
} else {
|
||||
unstable.SetNodeNext(&allNodes[lastKVIdx], int32(kvIdx))
|
||||
}
|
||||
lastKVIdx = kvIdx
|
||||
}
|
||||
|
||||
// Set up all nodes with the backing slice
|
||||
for i := range allNodes {
|
||||
unstable.SetNodeSlice(&allNodes[i], &allNodes)
|
||||
}
|
||||
|
||||
if err := u.UnmarshalTOML(&allNodes[0]); err != nil {
|
||||
return reflect.Value{}, err
|
||||
}
|
||||
|
||||
return reflect.Value{}, nil
|
||||
}
|
||||
|
||||
// copyExpressionNodes recursively copies all nodes from an expression into
|
||||
// the destination slice. Returns the index of the root node in the destination.
|
||||
// Note: The caller is responsible for setting the nodes slice on all copied nodes
|
||||
// after all expressions have been collected.
|
||||
func (d *decoder) copyExpressionNodes(dst *[]unstable.Node, node *unstable.Node) int {
|
||||
// Recursively collect all nodes in this expression tree
|
||||
collected := collectNodes(node)
|
||||
|
||||
// Calculate the offset for this batch
|
||||
baseIdx := len(*dst)
|
||||
|
||||
// Copy all nodes with child/next initialized to -1 (invalid reference)
|
||||
for _, n := range collected {
|
||||
copied := unstable.Node{
|
||||
Kind: n.Kind,
|
||||
Raw: n.Raw,
|
||||
Data: n.Data,
|
||||
}
|
||||
*dst = append(*dst, copied)
|
||||
// Initialize child and next to invalid reference (-1)
|
||||
// Go's zero value is 0, which would incorrectly point to the first node
|
||||
unstable.SetNodeChild(&(*dst)[len(*dst)-1], -1)
|
||||
unstable.SetNodeNext(&(*dst)[len(*dst)-1], -1)
|
||||
}
|
||||
|
||||
// Now fix up the child and next indices
|
||||
for i, n := range collected {
|
||||
dstIdx := baseIdx + i
|
||||
if child := unstable.GetNodeChild(n); child >= 0 {
|
||||
// Find the position of the child in our collected slice
|
||||
childOffset := findNodeOffset(collected, child, n)
|
||||
if childOffset >= 0 {
|
||||
unstable.SetNodeChild(&(*dst)[dstIdx], int32(baseIdx+childOffset))
|
||||
}
|
||||
}
|
||||
if next := unstable.GetNodeNext(n); next >= 0 {
|
||||
// Find the position of the next in our collected slice
|
||||
nextOffset := findNodeOffset(collected, next, n)
|
||||
if nextOffset >= 0 {
|
||||
unstable.SetNodeNext(&(*dst)[dstIdx], int32(baseIdx+nextOffset))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return baseIdx
|
||||
}
|
||||
|
||||
// collectNodes collects a node and all its descendants into a slice
|
||||
func collectNodes(root *unstable.Node) []*unstable.Node {
|
||||
var result []*unstable.Node
|
||||
var visit func(n *unstable.Node)
|
||||
visit = func(n *unstable.Node) {
|
||||
if n == nil {
|
||||
return
|
||||
}
|
||||
result = append(result, n)
|
||||
// Visit children
|
||||
it := n.Children()
|
||||
for it.Next() {
|
||||
child := it.Node()
|
||||
visit(child)
|
||||
}
|
||||
}
|
||||
visit(root)
|
||||
return result
|
||||
}
|
||||
|
||||
// findNodeOffset finds the offset of a node with the given index in the collected slice
|
||||
func findNodeOffset(collected []*unstable.Node, idx int32, relativeTo *unstable.Node) int {
|
||||
// The idx is an index into the original backing slice.
|
||||
// We need to find which node in our collected slice corresponds to that index.
|
||||
for i, n := range collected {
|
||||
if unstable.GetNodeIndex(n, relativeTo) == idx {
|
||||
return i
|
||||
}
|
||||
}
|
||||
return -1
|
||||
}
|
||||
|
||||
type (
|
||||
handlerFn func(key unstable.Iterator, v reflect.Value) (reflect.Value, error)
|
||||
valueMakerFn func() reflect.Value
|
||||
|
||||
@@ -4385,3 +4385,294 @@ func TestIssue1028(t *testing.T) {
|
||||
assert.Error(t, err)
|
||||
})
|
||||
}
|
||||
|
||||
// Tests for issue #873 - Bring back toml.Unmarshaler for tables and arrays
|
||||
|
||||
type customTable873 struct {
|
||||
Keys []string
|
||||
Values map[string]string
|
||||
}
|
||||
|
||||
func (c *customTable873) UnmarshalTOML(node *unstable.Node) error {
|
||||
c.Keys = []string{}
|
||||
c.Values = make(map[string]string)
|
||||
|
||||
it := node.Children()
|
||||
for it.Next() {
|
||||
kv := it.Node()
|
||||
if kv.Kind != unstable.KeyValue {
|
||||
continue
|
||||
}
|
||||
// Get the key
|
||||
keyIt := kv.Key()
|
||||
if keyIt.Next() {
|
||||
keyNode := keyIt.Node()
|
||||
key := string(keyNode.Data)
|
||||
c.Keys = append(c.Keys, key)
|
||||
|
||||
// Get the value
|
||||
valueNode := kv.Value()
|
||||
if valueNode != nil && valueNode.Kind == unstable.String {
|
||||
c.Values[key] = string(valueNode.Data)
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func TestIssue873_TableUnmarshaler(t *testing.T) {
|
||||
type Config struct {
|
||||
Section customTable873 `toml:"section"`
|
||||
}
|
||||
|
||||
doc := `
|
||||
[section]
|
||||
key1 = "value1"
|
||||
key2 = "value2"
|
||||
key3 = "value3"
|
||||
`
|
||||
|
||||
var cfg Config
|
||||
err := toml.NewDecoder(bytes.NewReader([]byte(doc))).
|
||||
EnableUnmarshalerInterface().
|
||||
Decode(&cfg)
|
||||
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, []string{"key1", "key2", "key3"}, cfg.Section.Keys)
|
||||
assert.Equal(t, "value1", cfg.Section.Values["key1"])
|
||||
assert.Equal(t, "value2", cfg.Section.Values["key2"])
|
||||
assert.Equal(t, "value3", cfg.Section.Values["key3"])
|
||||
}
|
||||
|
||||
func TestIssue873_TableUnmarshaler_EmptyTable(t *testing.T) {
|
||||
type Config struct {
|
||||
Section customTable873 `toml:"section"`
|
||||
}
|
||||
|
||||
doc := `
|
||||
[section]
|
||||
`
|
||||
|
||||
var cfg Config
|
||||
err := toml.NewDecoder(bytes.NewReader([]byte(doc))).
|
||||
EnableUnmarshalerInterface().
|
||||
Decode(&cfg)
|
||||
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, []string{}, cfg.Section.Keys)
|
||||
}
|
||||
|
||||
func TestIssue873_ArrayTableUnmarshaler(t *testing.T) {
|
||||
type Config struct {
|
||||
Items []customTable873 `toml:"items"`
|
||||
}
|
||||
|
||||
doc := `
|
||||
[[items]]
|
||||
name = "first"
|
||||
id = "1"
|
||||
|
||||
[[items]]
|
||||
name = "second"
|
||||
id = "2"
|
||||
`
|
||||
|
||||
var cfg Config
|
||||
err := toml.NewDecoder(bytes.NewReader([]byte(doc))).
|
||||
EnableUnmarshalerInterface().
|
||||
Decode(&cfg)
|
||||
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, 2, len(cfg.Items))
|
||||
assert.Equal(t, []string{"name", "id"}, cfg.Items[0].Keys)
|
||||
assert.Equal(t, "first", cfg.Items[0].Values["name"])
|
||||
assert.Equal(t, "1", cfg.Items[0].Values["id"])
|
||||
assert.Equal(t, []string{"name", "id"}, cfg.Items[1].Keys)
|
||||
assert.Equal(t, "second", cfg.Items[1].Values["name"])
|
||||
assert.Equal(t, "2", cfg.Items[1].Values["id"])
|
||||
}
|
||||
|
||||
func TestIssue873_NestedTableUnmarshaler(t *testing.T) {
|
||||
type Config struct {
|
||||
Outer struct {
|
||||
Inner customTable873 `toml:"inner"`
|
||||
} `toml:"outer"`
|
||||
}
|
||||
|
||||
doc := `
|
||||
[outer.inner]
|
||||
a = "A"
|
||||
b = "B"
|
||||
`
|
||||
|
||||
var cfg Config
|
||||
err := toml.NewDecoder(bytes.NewReader([]byte(doc))).
|
||||
EnableUnmarshalerInterface().
|
||||
Decode(&cfg)
|
||||
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, []string{"a", "b"}, cfg.Outer.Inner.Keys)
|
||||
assert.Equal(t, "A", cfg.Outer.Inner.Values["a"])
|
||||
assert.Equal(t, "B", cfg.Outer.Inner.Values["b"])
|
||||
}
|
||||
|
||||
func TestIssue873_TableUnmarshaler_MultipleTables(t *testing.T) {
|
||||
type Config struct {
|
||||
First customTable873 `toml:"first"`
|
||||
Second customTable873 `toml:"second"`
|
||||
}
|
||||
|
||||
doc := `
|
||||
[first]
|
||||
key1 = "value1"
|
||||
|
||||
[second]
|
||||
key2 = "value2"
|
||||
`
|
||||
|
||||
var cfg Config
|
||||
err := toml.NewDecoder(bytes.NewReader([]byte(doc))).
|
||||
EnableUnmarshalerInterface().
|
||||
Decode(&cfg)
|
||||
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, []string{"key1"}, cfg.First.Keys)
|
||||
assert.Equal(t, "value1", cfg.First.Values["key1"])
|
||||
assert.Equal(t, []string{"key2"}, cfg.Second.Keys)
|
||||
assert.Equal(t, "value2", cfg.Second.Values["key2"])
|
||||
}
|
||||
|
||||
// Test that regular struct fields still work alongside Unmarshaler tables
|
||||
func TestIssue873_MixedWithRegularFields(t *testing.T) {
|
||||
type Config struct {
|
||||
Name string `toml:"name"`
|
||||
Section customTable873 `toml:"section"`
|
||||
Count int `toml:"count"`
|
||||
}
|
||||
|
||||
doc := `
|
||||
name = "test"
|
||||
count = 42
|
||||
|
||||
[section]
|
||||
foo = "bar"
|
||||
`
|
||||
|
||||
var cfg Config
|
||||
err := toml.NewDecoder(bytes.NewReader([]byte(doc))).
|
||||
EnableUnmarshalerInterface().
|
||||
Decode(&cfg)
|
||||
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "test", cfg.Name)
|
||||
assert.Equal(t, 42, cfg.Count)
|
||||
assert.Equal(t, []string{"foo"}, cfg.Section.Keys)
|
||||
assert.Equal(t, "bar", cfg.Section.Values["foo"])
|
||||
}
|
||||
|
||||
// Test that pointer to Unmarshaler type works
|
||||
func TestIssue873_PointerToUnmarshaler(t *testing.T) {
|
||||
type Config struct {
|
||||
Section *customTable873 `toml:"section"`
|
||||
}
|
||||
|
||||
doc := `
|
||||
[section]
|
||||
hello = "world"
|
||||
`
|
||||
|
||||
var cfg Config
|
||||
err := toml.NewDecoder(bytes.NewReader([]byte(doc))).
|
||||
EnableUnmarshalerInterface().
|
||||
Decode(&cfg)
|
||||
|
||||
assert.NoError(t, err)
|
||||
assert.True(t, cfg.Section != nil)
|
||||
assert.Equal(t, []string{"hello"}, cfg.Section.Keys)
|
||||
assert.Equal(t, "world", cfg.Section.Values["hello"])
|
||||
}
|
||||
|
||||
// Test table with sub-tables defined separately
|
||||
func TestIssue873_TableWithSubTables(t *testing.T) {
|
||||
type Config struct {
|
||||
Parent customTable873 `toml:"parent"`
|
||||
}
|
||||
|
||||
doc := `
|
||||
[parent]
|
||||
name = "root"
|
||||
|
||||
[parent.child]
|
||||
name = "nested"
|
||||
`
|
||||
|
||||
var cfg Config
|
||||
err := toml.NewDecoder(bytes.NewReader([]byte(doc))).
|
||||
EnableUnmarshalerInterface().
|
||||
Decode(&cfg)
|
||||
|
||||
// The parent should only get the keys directly under [parent],
|
||||
// not the [parent.child] sub-table
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, []string{"name"}, cfg.Parent.Keys)
|
||||
assert.Equal(t, "root", cfg.Parent.Values["name"])
|
||||
}
|
||||
|
||||
// Test for issue #994 follow-up - tables defined piece-wise
|
||||
// This addresses the maintainer's comment: "It doesn't deal with tables defined piece-wise yet"
|
||||
func TestIssue994_TablesPieceWise(t *testing.T) {
|
||||
// Test with piece-wise table definition (using [table] syntax)
|
||||
// The customTable873 type captures key-value pairs in order,
|
||||
// which is useful for use cases like maintaining map ordering
|
||||
doc := `
|
||||
[section]
|
||||
first = "1"
|
||||
second = "2"
|
||||
third = "3"
|
||||
`
|
||||
|
||||
type Config struct {
|
||||
Section customTable873 `toml:"section"`
|
||||
}
|
||||
|
||||
var cfg Config
|
||||
err := toml.NewDecoder(bytes.NewReader([]byte(doc))).
|
||||
EnableUnmarshalerInterface().
|
||||
Decode(&cfg)
|
||||
|
||||
assert.NoError(t, err)
|
||||
// Verify ordering is preserved (keys are collected in document order)
|
||||
assert.Equal(t, []string{"first", "second", "third"}, cfg.Section.Keys)
|
||||
assert.Equal(t, "1", cfg.Section.Values["first"])
|
||||
assert.Equal(t, "2", cfg.Section.Values["second"])
|
||||
assert.Equal(t, "3", cfg.Section.Values["third"])
|
||||
}
|
||||
|
||||
// Test root-level struct with tables - combines #994 fix with #873 enhancement
|
||||
func TestIssue994_RootWithTables(t *testing.T) {
|
||||
type rootDoc struct {
|
||||
Tables []customTable873 `toml:"tables"`
|
||||
}
|
||||
|
||||
doc := `
|
||||
[[tables]]
|
||||
name = "first"
|
||||
value = "one"
|
||||
|
||||
[[tables]]
|
||||
name = "second"
|
||||
value = "two"
|
||||
`
|
||||
|
||||
var d rootDoc
|
||||
err := toml.NewDecoder(bytes.NewReader([]byte(doc))).
|
||||
EnableUnmarshalerInterface().
|
||||
Decode(&d)
|
||||
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, 2, len(d.Tables))
|
||||
assert.Equal(t, "first", d.Tables[0].Values["name"])
|
||||
assert.Equal(t, "one", d.Tables[0].Values["value"])
|
||||
assert.Equal(t, "second", d.Tables[1].Values["name"])
|
||||
assert.Equal(t, "two", d.Tables[1].Values["value"])
|
||||
}
|
||||
|
||||
@@ -143,3 +143,49 @@ func (n *Node) Value() *Node {
|
||||
func (n *Node) Children() Iterator {
|
||||
return Iterator{nodes: n.nodes, idx: n.child}
|
||||
}
|
||||
|
||||
// SetNodeSlice sets the backing nodes slice for a node.
|
||||
// This is used when building synthetic nodes.
|
||||
func SetNodeSlice(n *Node, nodes *[]Node) {
|
||||
n.nodes = nodes
|
||||
}
|
||||
|
||||
// SetNodeChild sets the child index for a node.
|
||||
// This is used when building synthetic nodes.
|
||||
func SetNodeChild(n *Node, child int32) {
|
||||
n.child = child
|
||||
}
|
||||
|
||||
// SetNodeNext sets the next sibling index for a node.
|
||||
// This is used when building synthetic nodes.
|
||||
func SetNodeNext(n *Node, next int32) {
|
||||
n.next = next
|
||||
}
|
||||
|
||||
// GetNodeChild returns the child index for a node.
|
||||
// This is used when copying nodes.
|
||||
func GetNodeChild(n *Node) int32 {
|
||||
return n.child
|
||||
}
|
||||
|
||||
// GetNodeNext returns the next sibling index for a node.
|
||||
// This is used when copying nodes.
|
||||
func GetNodeNext(n *Node) int32 {
|
||||
return n.next
|
||||
}
|
||||
|
||||
// GetNodeIndex returns the index of node n in the backing slice,
|
||||
// using relativeTo's nodes slice as reference.
|
||||
// Returns -1 if n is not in the slice.
|
||||
func GetNodeIndex(n *Node, relativeTo *Node) int32 {
|
||||
if relativeTo.nodes == nil || n == nil {
|
||||
return -1
|
||||
}
|
||||
nodes := *relativeTo.nodes
|
||||
for i := range nodes {
|
||||
if &nodes[i] == n {
|
||||
return int32(i)
|
||||
}
|
||||
}
|
||||
return -1
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user