Implement bytes-based Unmarshaler interface for tables and arrays (#873)
This change brings back support for the unstable.Unmarshaler interface for tables and array tables, addressing issue #873. Key changes: - Changed UnmarshalTOML signature from (*Node) to ([]byte) to provide raw TOML bytes instead of AST nodes - Added RawMessage type (similar to json.RawMessage) for capturing raw TOML bytes for later processing - Updated handleKeyValuesUnmarshaler to reconstruct key-value lines from the parsed keys and raw value bytes - Added support for slice types implementing Unmarshaler (e.g., RawMessage) - Removed unused AST helper functions from unstable/ast.go The bytes-based interface allows users to: - Get raw TOML bytes for custom parsing - Delay TOML decoding using RawMessage - Implement custom unmarshaling logic for complex types Tests added for: - Table unmarshaler with various scenarios - Array table unmarshaler - Split tables (same parent defined in multiple places) - RawMessage usage - Nested tables and mixed regular fields
This commit is contained in:
+74
-117
@@ -56,15 +56,18 @@ func (d *Decoder) DisallowUnknownFields() *Decoder {
|
||||
|
||||
// EnableUnmarshalerInterface allows to enable unmarshaler interface.
|
||||
//
|
||||
// With this feature enabled, types implementing the unstable/Unmarshaler
|
||||
// With this feature enabled, types implementing the unstable.Unmarshaler
|
||||
// interface can be decoded from any structure of the document. It allows types
|
||||
// that don't have a straightforward TOML representation to provide their own
|
||||
// decoding logic.
|
||||
//
|
||||
// Types can decode from single values, inline tables, arrays, and standard
|
||||
// tables/array tables. When decoding from a table (e.g., [table] or [[array]]),
|
||||
// the UnmarshalTOML method receives a synthetic InlineTable node containing
|
||||
// all the key-value pairs belonging to that table.
|
||||
// The UnmarshalTOML method receives raw TOML bytes:
|
||||
// - For single values: the raw value bytes (e.g., `"hello"` for a string)
|
||||
// - For tables: all key-value lines belonging to that table
|
||||
// - For inline tables/arrays: the raw bytes of the inline structure
|
||||
//
|
||||
// The unstable.RawMessage type can be used to capture raw TOML bytes for
|
||||
// later processing, similar to json.RawMessage.
|
||||
//
|
||||
// *Unstable:* This method does not follow the compatibility guarantees of
|
||||
// semver. It can be changed or removed without a new major version being
|
||||
@@ -601,18 +604,28 @@ func (d *decoder) handleArrayTablePart(key unstable.Iterator, v reflect.Value) (
|
||||
// cannot handle it.
|
||||
func (d *decoder) handleTable(key unstable.Iterator, v reflect.Value) (reflect.Value, error) {
|
||||
if v.Kind() == reflect.Slice {
|
||||
if v.Len() == 0 {
|
||||
return reflect.Value{}, unstable.NewParserError(key.Node().Data, "cannot store a table in a slice")
|
||||
// For non-empty slices, work with the last element
|
||||
if v.Len() > 0 {
|
||||
elem := v.Index(v.Len() - 1)
|
||||
x, err := d.handleTable(key, elem)
|
||||
if err != nil {
|
||||
return reflect.Value{}, err
|
||||
}
|
||||
if x.IsValid() {
|
||||
elem.Set(x)
|
||||
}
|
||||
return reflect.Value{}, nil
|
||||
}
|
||||
elem := v.Index(v.Len() - 1)
|
||||
x, err := d.handleTable(key, elem)
|
||||
if err != nil {
|
||||
return reflect.Value{}, err
|
||||
// Empty slice - check if it implements Unmarshaler (e.g., RawMessage)
|
||||
// and we're at the end of the key path
|
||||
if d.unmarshalerInterface && !key.Next() {
|
||||
if v.CanAddr() && v.Addr().CanInterface() {
|
||||
if outi, ok := v.Addr().Interface().(unstable.Unmarshaler); ok {
|
||||
return d.handleKeyValuesUnmarshaler(outi)
|
||||
}
|
||||
}
|
||||
}
|
||||
if x.IsValid() {
|
||||
elem.Set(x)
|
||||
}
|
||||
return reflect.Value{}, nil
|
||||
return reflect.Value{}, unstable.NewParserError(key.Node().Data, "cannot store a table in a slice")
|
||||
}
|
||||
if key.Next() {
|
||||
// Still scoping the key
|
||||
@@ -674,18 +687,12 @@ func (d *decoder) handleKeyValues(v reflect.Value) (reflect.Value, error) {
|
||||
}
|
||||
|
||||
// handleKeyValuesUnmarshaler collects all key-value expressions for a table
|
||||
// and passes them to the Unmarshaler as a synthetic InlineTable node.
|
||||
// and passes them to the Unmarshaler as raw TOML bytes.
|
||||
func (d *decoder) handleKeyValuesUnmarshaler(u unstable.Unmarshaler) (reflect.Value, error) {
|
||||
// We need to collect all key-value expressions and build a synthetic table.
|
||||
// The parser reuses its internal builder between expressions, so we need to
|
||||
// copy each expression's nodes immediately before parsing the next one.
|
||||
var allNodes []unstable.Node
|
||||
allNodes = append(allNodes, unstable.Node{Kind: unstable.InlineTable})
|
||||
// Initialize root node's child and next to -1 (invalid reference)
|
||||
unstable.SetNodeChild(&allNodes[0], -1)
|
||||
unstable.SetNodeNext(&allNodes[0], -1)
|
||||
|
||||
var lastKVIdx int = -1
|
||||
// Collect raw bytes from all key-value expressions for this table.
|
||||
// We build a valid TOML document by reconstructing each key-value line
|
||||
// from the key names and the value's raw bytes.
|
||||
var buf []byte
|
||||
|
||||
for d.nextExpr() {
|
||||
expr := d.expr()
|
||||
@@ -699,107 +706,55 @@ func (d *decoder) handleKeyValuesUnmarshaler(u unstable.Unmarshaler) (reflect.Va
|
||||
return reflect.Value{}, err
|
||||
}
|
||||
|
||||
// Deep copy this expression's nodes into our slice before parsing the next
|
||||
kvIdx := d.copyExpressionNodes(&allNodes, expr)
|
||||
|
||||
// Link to previous sibling or set as first child of root
|
||||
if lastKVIdx == -1 {
|
||||
unstable.SetNodeChild(&allNodes[0], int32(kvIdx))
|
||||
} else {
|
||||
unstable.SetNodeNext(&allNodes[lastKVIdx], int32(kvIdx))
|
||||
// Reconstruct the key-value line from the key(s) and value
|
||||
keyIt := expr.Key()
|
||||
first := true
|
||||
for keyIt.Next() {
|
||||
if !first {
|
||||
buf = append(buf, '.')
|
||||
}
|
||||
keyNode := keyIt.Node()
|
||||
// Check if key needs quoting
|
||||
if keyNeedsQuoting(keyNode.Data) {
|
||||
buf = append(buf, '"')
|
||||
buf = append(buf, keyNode.Data...)
|
||||
buf = append(buf, '"')
|
||||
} else {
|
||||
buf = append(buf, keyNode.Data...)
|
||||
}
|
||||
first = false
|
||||
}
|
||||
lastKVIdx = kvIdx
|
||||
buf = append(buf, " = "...)
|
||||
|
||||
// Get the raw value bytes
|
||||
value := expr.Value()
|
||||
if value != nil {
|
||||
raw := d.p.Raw(value.Raw)
|
||||
buf = append(buf, raw...)
|
||||
}
|
||||
buf = append(buf, '\n')
|
||||
}
|
||||
|
||||
// Set up all nodes with the backing slice
|
||||
for i := range allNodes {
|
||||
unstable.SetNodeSlice(&allNodes[i], &allNodes)
|
||||
}
|
||||
|
||||
if err := u.UnmarshalTOML(&allNodes[0]); err != nil {
|
||||
if err := u.UnmarshalTOML(buf); err != nil {
|
||||
return reflect.Value{}, err
|
||||
}
|
||||
|
||||
return reflect.Value{}, nil
|
||||
}
|
||||
|
||||
// copyExpressionNodes recursively copies all nodes from an expression into
|
||||
// the destination slice. Returns the index of the root node in the destination.
|
||||
// Note: The caller is responsible for setting the nodes slice on all copied nodes
|
||||
// after all expressions have been collected.
|
||||
func (d *decoder) copyExpressionNodes(dst *[]unstable.Node, node *unstable.Node) int {
|
||||
// Recursively collect all nodes in this expression tree
|
||||
collected := collectNodes(node)
|
||||
|
||||
// Calculate the offset for this batch
|
||||
baseIdx := len(*dst)
|
||||
|
||||
// Copy all nodes with child/next initialized to -1 (invalid reference)
|
||||
for _, n := range collected {
|
||||
copied := unstable.Node{
|
||||
Kind: n.Kind,
|
||||
Raw: n.Raw,
|
||||
Data: n.Data,
|
||||
}
|
||||
*dst = append(*dst, copied)
|
||||
// Initialize child and next to invalid reference (-1)
|
||||
// Go's zero value is 0, which would incorrectly point to the first node
|
||||
unstable.SetNodeChild(&(*dst)[len(*dst)-1], -1)
|
||||
unstable.SetNodeNext(&(*dst)[len(*dst)-1], -1)
|
||||
// keyNeedsQuoting returns true if the key needs to be quoted in TOML.
|
||||
func keyNeedsQuoting(key []byte) bool {
|
||||
if len(key) == 0 {
|
||||
return true
|
||||
}
|
||||
|
||||
// Now fix up the child and next indices
|
||||
for i, n := range collected {
|
||||
dstIdx := baseIdx + i
|
||||
if child := unstable.GetNodeChild(n); child >= 0 {
|
||||
// Find the position of the child in our collected slice
|
||||
childOffset := findNodeOffset(collected, child, n)
|
||||
if childOffset >= 0 {
|
||||
unstable.SetNodeChild(&(*dst)[dstIdx], int32(baseIdx+childOffset))
|
||||
}
|
||||
}
|
||||
if next := unstable.GetNodeNext(n); next >= 0 {
|
||||
// Find the position of the next in our collected slice
|
||||
nextOffset := findNodeOffset(collected, next, n)
|
||||
if nextOffset >= 0 {
|
||||
unstable.SetNodeNext(&(*dst)[dstIdx], int32(baseIdx+nextOffset))
|
||||
}
|
||||
for _, b := range key {
|
||||
// Bare keys can only contain A-Za-z0-9_-
|
||||
if !((b >= 'A' && b <= 'Z') || (b >= 'a' && b <= 'z') ||
|
||||
(b >= '0' && b <= '9') || b == '_' || b == '-') {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return baseIdx
|
||||
}
|
||||
|
||||
// collectNodes collects a node and all its descendants into a slice
|
||||
func collectNodes(root *unstable.Node) []*unstable.Node {
|
||||
var result []*unstable.Node
|
||||
var visit func(n *unstable.Node)
|
||||
visit = func(n *unstable.Node) {
|
||||
if n == nil {
|
||||
return
|
||||
}
|
||||
result = append(result, n)
|
||||
// Visit children
|
||||
it := n.Children()
|
||||
for it.Next() {
|
||||
child := it.Node()
|
||||
visit(child)
|
||||
}
|
||||
}
|
||||
visit(root)
|
||||
return result
|
||||
}
|
||||
|
||||
// findNodeOffset finds the offset of a node with the given index in the collected slice
|
||||
func findNodeOffset(collected []*unstable.Node, idx int32, relativeTo *unstable.Node) int {
|
||||
// The idx is an index into the original backing slice.
|
||||
// We need to find which node in our collected slice corresponds to that index.
|
||||
for i, n := range collected {
|
||||
if unstable.GetNodeIndex(n, relativeTo) == idx {
|
||||
return i
|
||||
}
|
||||
}
|
||||
return -1
|
||||
return false
|
||||
}
|
||||
|
||||
type (
|
||||
@@ -846,7 +801,8 @@ func (d *decoder) handleValue(value *unstable.Node, v reflect.Value) error {
|
||||
if d.unmarshalerInterface {
|
||||
if v.CanAddr() && v.Addr().CanInterface() {
|
||||
if outi, ok := v.Addr().Interface().(unstable.Unmarshaler); ok {
|
||||
return outi.UnmarshalTOML(value)
|
||||
// Pass raw bytes from the original document
|
||||
return outi.UnmarshalTOML(d.p.Raw(value.Raw))
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1350,7 +1306,8 @@ func (d *decoder) handleKeyValuePart(key unstable.Iterator, value *unstable.Node
|
||||
if d.unmarshalerInterface {
|
||||
if v.CanAddr() && v.Addr().CanInterface() {
|
||||
if outi, ok := v.Addr().Interface().(unstable.Unmarshaler); ok {
|
||||
return reflect.Value{}, outi.UnmarshalTOML(value)
|
||||
// Pass raw bytes from the original document
|
||||
return reflect.Value{}, outi.UnmarshalTOML(d.p.Raw(value.Raw))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user