Implement bytes-based Unmarshaler interface for tables and arrays (#873)

This change brings back support for the unstable.Unmarshaler interface
for tables and array tables, addressing issue #873.

Key changes:
- Changed UnmarshalTOML signature from (*Node) to ([]byte) to provide
  raw TOML bytes instead of AST nodes
- Added RawMessage type (similar to json.RawMessage) for capturing raw
  TOML bytes for later processing
- Updated handleKeyValuesUnmarshaler to reconstruct key-value lines
  from the parsed keys and raw value bytes
- Added support for slice types implementing Unmarshaler (e.g., RawMessage)
- Removed unused AST helper functions from unstable/ast.go

The bytes-based interface allows users to:
- Get raw TOML bytes for custom parsing
- Delay TOML decoding using RawMessage
- Implement custom unmarshaling logic for complex types

Tests added for:
- Table unmarshaler with various scenarios
- Array table unmarshaler
- Split tables (same parent defined in multiple places)
- RawMessage usage
- Nested tables and mixed regular fields
This commit is contained in:
Claude
2026-01-15 12:13:14 +00:00
parent 5b6828661c
commit 2762e24a9c
4 changed files with 202 additions and 189 deletions
+74 -117
View File
@@ -56,15 +56,18 @@ func (d *Decoder) DisallowUnknownFields() *Decoder {
// EnableUnmarshalerInterface enables support for the unmarshaler interface.
//
// With this feature enabled, types implementing the unstable/Unmarshaler
// With this feature enabled, types implementing the unstable.Unmarshaler
// interface can be decoded from any structure of the document. It allows types
// that don't have a straightforward TOML representation to provide their own
// decoding logic.
//
// Types can decode from single values, inline tables, arrays, and standard
// tables/array tables. When decoding from a table (e.g., [table] or [[array]]),
// the UnmarshalTOML method receives a synthetic InlineTable node containing
// all the key-value pairs belonging to that table.
// The UnmarshalTOML method receives raw TOML bytes:
// - For single values: the raw value bytes (e.g., `"hello"` for a string)
// - For tables: the table's key-value lines, reconstructed from the parsed
//   keys and raw value bytes (comments and original spacing are not kept)
// - For inline tables/arrays: the raw bytes of the inline structure
//
// The unstable.RawMessage type can be used to capture raw TOML bytes for
// later processing, similar to json.RawMessage.
//
// *Unstable:* This method does not follow the compatibility guarantees of
// semver. It can be changed or removed without a new major version being
@@ -601,18 +604,28 @@ func (d *decoder) handleArrayTablePart(key unstable.Iterator, v reflect.Value) (
// cannot handle it.
func (d *decoder) handleTable(key unstable.Iterator, v reflect.Value) (reflect.Value, error) {
if v.Kind() == reflect.Slice {
if v.Len() == 0 {
return reflect.Value{}, unstable.NewParserError(key.Node().Data, "cannot store a table in a slice")
// For non-empty slices, work with the last element
if v.Len() > 0 {
elem := v.Index(v.Len() - 1)
x, err := d.handleTable(key, elem)
if err != nil {
return reflect.Value{}, err
}
if x.IsValid() {
elem.Set(x)
}
return reflect.Value{}, nil
}
elem := v.Index(v.Len() - 1)
x, err := d.handleTable(key, elem)
if err != nil {
return reflect.Value{}, err
// Empty slice - check if it implements Unmarshaler (e.g., RawMessage)
// and we're at the end of the key path
if d.unmarshalerInterface && !key.Next() {
if v.CanAddr() && v.Addr().CanInterface() {
if outi, ok := v.Addr().Interface().(unstable.Unmarshaler); ok {
return d.handleKeyValuesUnmarshaler(outi)
}
}
}
if x.IsValid() {
elem.Set(x)
}
return reflect.Value{}, nil
return reflect.Value{}, unstable.NewParserError(key.Node().Data, "cannot store a table in a slice")
}
if key.Next() {
// Still scoping the key
@@ -674,18 +687,12 @@ func (d *decoder) handleKeyValues(v reflect.Value) (reflect.Value, error) {
}
// handleKeyValuesUnmarshaler collects all key-value expressions for a table
// and passes them to the Unmarshaler as a synthetic InlineTable node.
// and passes them to the Unmarshaler as raw TOML bytes.
func (d *decoder) handleKeyValuesUnmarshaler(u unstable.Unmarshaler) (reflect.Value, error) {
// We need to collect all key-value expressions and build a synthetic table.
// The parser reuses its internal builder between expressions, so we need to
// copy each expression's nodes immediately before parsing the next one.
var allNodes []unstable.Node
allNodes = append(allNodes, unstable.Node{Kind: unstable.InlineTable})
// Initialize root node's child and next to -1 (invalid reference)
unstable.SetNodeChild(&allNodes[0], -1)
unstable.SetNodeNext(&allNodes[0], -1)
var lastKVIdx int = -1
// Collect raw bytes from all key-value expressions for this table.
// We build a valid TOML document by reconstructing each key-value line
// from the key names and the value's raw bytes.
var buf []byte
for d.nextExpr() {
expr := d.expr()
@@ -699,107 +706,55 @@ func (d *decoder) handleKeyValuesUnmarshaler(u unstable.Unmarshaler) (reflect.Va
return reflect.Value{}, err
}
// Deep copy this expression's nodes into our slice before parsing the next
kvIdx := d.copyExpressionNodes(&allNodes, expr)
// Link to previous sibling or set as first child of root
if lastKVIdx == -1 {
unstable.SetNodeChild(&allNodes[0], int32(kvIdx))
} else {
unstable.SetNodeNext(&allNodes[lastKVIdx], int32(kvIdx))
// Reconstruct the key-value line from the key(s) and value
keyIt := expr.Key()
first := true
for keyIt.Next() {
if !first {
buf = append(buf, '.')
}
keyNode := keyIt.Node()
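// Wrap the key in basic-string quotes when it is not a valid bare key.
// NOTE(review): the key bytes are appended without escaping; if Data holds
// the unescaped key, a '"' or '\' inside it yields invalid TOML — confirm.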
if keyNeedsQuoting(keyNode.Data) {
buf = append(buf, '"')
buf = append(buf, keyNode.Data...)
buf = append(buf, '"')
} else {
buf = append(buf, keyNode.Data...)
}
first = false
}
lastKVIdx = kvIdx
buf = append(buf, " = "...)
// Get the raw value bytes
value := expr.Value()
if value != nil {
raw := d.p.Raw(value.Raw)
buf = append(buf, raw...)
}
buf = append(buf, '\n')
}
// Set up all nodes with the backing slice
for i := range allNodes {
unstable.SetNodeSlice(&allNodes[i], &allNodes)
}
if err := u.UnmarshalTOML(&allNodes[0]); err != nil {
if err := u.UnmarshalTOML(buf); err != nil {
return reflect.Value{}, err
}
return reflect.Value{}, nil
}
// copyExpressionNodes recursively copies all nodes from an expression into
// the destination slice. Returns the index of the root node in the destination.
// Note: The caller is responsible for setting the nodes slice on all copied nodes
// after all expressions have been collected.
func (d *decoder) copyExpressionNodes(dst *[]unstable.Node, node *unstable.Node) int {
// Recursively collect all nodes in this expression tree
collected := collectNodes(node)
// Calculate the offset for this batch
baseIdx := len(*dst)
// Copy all nodes with child/next initialized to -1 (invalid reference)
for _, n := range collected {
copied := unstable.Node{
Kind: n.Kind,
Raw: n.Raw,
Data: n.Data,
}
*dst = append(*dst, copied)
// Initialize child and next to invalid reference (-1)
// Go's zero value is 0, which would incorrectly point to the first node
unstable.SetNodeChild(&(*dst)[len(*dst)-1], -1)
unstable.SetNodeNext(&(*dst)[len(*dst)-1], -1)
// keyNeedsQuoting returns true if the key needs to be quoted in TOML.
func keyNeedsQuoting(key []byte) bool {
if len(key) == 0 {
return true
}
// Now fix up the child and next indices
for i, n := range collected {
dstIdx := baseIdx + i
if child := unstable.GetNodeChild(n); child >= 0 {
// Find the position of the child in our collected slice
childOffset := findNodeOffset(collected, child, n)
if childOffset >= 0 {
unstable.SetNodeChild(&(*dst)[dstIdx], int32(baseIdx+childOffset))
}
}
if next := unstable.GetNodeNext(n); next >= 0 {
// Find the position of the next in our collected slice
nextOffset := findNodeOffset(collected, next, n)
if nextOffset >= 0 {
unstable.SetNodeNext(&(*dst)[dstIdx], int32(baseIdx+nextOffset))
}
for _, b := range key {
// Bare keys can only contain A-Za-z0-9_-
if !((b >= 'A' && b <= 'Z') || (b >= 'a' && b <= 'z') ||
(b >= '0' && b <= '9') || b == '_' || b == '-') {
return true
}
}
return baseIdx
}
// collectNodes returns root followed by every node reachable through
// Children(), in depth-first pre-order. A nil root yields a nil slice.
func collectNodes(root *unstable.Node) []*unstable.Node {
	var nodes []*unstable.Node

	// walk appends cur and then recurses into each of its children in order.
	var walk func(cur *unstable.Node)
	walk = func(cur *unstable.Node) {
		if cur == nil {
			return
		}
		nodes = append(nodes, cur)
		for children := cur.Children(); children.Next(); {
			walk(children.Node())
		}
	}

	walk(root)
	return nodes
}
// findNodeOffset finds the offset of a node with the given index in the collected slice
func findNodeOffset(collected []*unstable.Node, idx int32, relativeTo *unstable.Node) int {
// The idx is an index into the original backing slice.
// We need to find which node in our collected slice corresponds to that index.
for i, n := range collected {
if unstable.GetNodeIndex(n, relativeTo) == idx {
return i
}
}
return -1
return false
}
type (
@@ -846,7 +801,8 @@ func (d *decoder) handleValue(value *unstable.Node, v reflect.Value) error {
if d.unmarshalerInterface {
if v.CanAddr() && v.Addr().CanInterface() {
if outi, ok := v.Addr().Interface().(unstable.Unmarshaler); ok {
return outi.UnmarshalTOML(value)
// Pass raw bytes from the original document
return outi.UnmarshalTOML(d.p.Raw(value.Raw))
}
}
}
@@ -1350,7 +1306,8 @@ func (d *decoder) handleKeyValuePart(key unstable.Iterator, value *unstable.Node
if d.unmarshalerInterface {
if v.CanAddr() && v.Addr().CanInterface() {
if outi, ok := v.Addr().Interface().(unstable.Unmarshaler); ok {
return reflect.Value{}, outi.UnmarshalTOML(value)
// Pass raw bytes from the original document
return reflect.Value{}, outi.UnmarshalTOML(d.p.Raw(value.Raw))
}
}
}