Implement bytes-based Unmarshaler interface for tables and arrays (#873)
This change brings back support for the unstable.Unmarshaler interface for tables and array tables, addressing issue #873.

Key changes:
- Changed the UnmarshalTOML signature from (*Node) to ([]byte) so implementations receive raw TOML bytes instead of AST nodes
- Added the RawMessage type (similar to json.RawMessage) for capturing raw TOML bytes for later processing
- Updated handleKeyValuesUnmarshaler to reconstruct key-value lines from the parsed keys and raw value bytes
- Added support for slice types implementing Unmarshaler (e.g., RawMessage)
- Removed unused AST helper functions from unstable/ast.go

The bytes-based interface allows users to:
- Get raw TOML bytes for custom parsing
- Delay TOML decoding using RawMessage
- Implement custom unmarshaling logic for complex types

Tests added for:
- Table unmarshaler with various scenarios
- Array table unmarshaler
- Split tables (same parent defined in multiple places)
- RawMessage usage
- Nested tables and mixed regular fields
This commit is contained in:
+74
-117
@@ -56,15 +56,18 @@ func (d *Decoder) DisallowUnknownFields() *Decoder {
|
||||
|
||||
// EnableUnmarshalerInterface enables the unstable.Unmarshaler interface.
|
||||
//
|
||||
// With this feature enabled, types implementing the unstable/Unmarshaler
|
||||
// With this feature enabled, types implementing the unstable.Unmarshaler
|
||||
// interface can be decoded from any structure of the document. It allows types
|
||||
// that don't have a straightforward TOML representation to provide their own
|
||||
// decoding logic.
|
||||
//
|
||||
// Types can decode from single values, inline tables, arrays, and standard
|
||||
// tables/array tables. When decoding from a table (e.g., [table] or [[array]]),
|
||||
// the UnmarshalTOML method receives a synthetic InlineTable node containing
|
||||
// all the key-value pairs belonging to that table.
|
||||
// The UnmarshalTOML method receives raw TOML bytes:
|
||||
// - For single values: the raw value bytes (e.g., `"hello"` for a string)
|
||||
// - For tables: all key-value lines belonging to that table
|
||||
// - For inline tables/arrays: the raw bytes of the inline structure
|
||||
//
|
||||
// The unstable.RawMessage type can be used to capture raw TOML bytes for
|
||||
// later processing, similar to json.RawMessage.
|
||||
//
|
||||
// *Unstable:* This method does not follow the compatibility guarantees of
|
||||
// semver. It can be changed or removed without a new major version being
|
||||
@@ -601,18 +604,28 @@ func (d *decoder) handleArrayTablePart(key unstable.Iterator, v reflect.Value) (
|
||||
// cannot handle it.
|
||||
func (d *decoder) handleTable(key unstable.Iterator, v reflect.Value) (reflect.Value, error) {
|
||||
if v.Kind() == reflect.Slice {
|
||||
if v.Len() == 0 {
|
||||
return reflect.Value{}, unstable.NewParserError(key.Node().Data, "cannot store a table in a slice")
|
||||
// For non-empty slices, work with the last element
|
||||
if v.Len() > 0 {
|
||||
elem := v.Index(v.Len() - 1)
|
||||
x, err := d.handleTable(key, elem)
|
||||
if err != nil {
|
||||
return reflect.Value{}, err
|
||||
}
|
||||
if x.IsValid() {
|
||||
elem.Set(x)
|
||||
}
|
||||
return reflect.Value{}, nil
|
||||
}
|
||||
elem := v.Index(v.Len() - 1)
|
||||
x, err := d.handleTable(key, elem)
|
||||
if err != nil {
|
||||
return reflect.Value{}, err
|
||||
// Empty slice - check if it implements Unmarshaler (e.g., RawMessage)
|
||||
// and we're at the end of the key path
|
||||
if d.unmarshalerInterface && !key.Next() {
|
||||
if v.CanAddr() && v.Addr().CanInterface() {
|
||||
if outi, ok := v.Addr().Interface().(unstable.Unmarshaler); ok {
|
||||
return d.handleKeyValuesUnmarshaler(outi)
|
||||
}
|
||||
}
|
||||
}
|
||||
if x.IsValid() {
|
||||
elem.Set(x)
|
||||
}
|
||||
return reflect.Value{}, nil
|
||||
return reflect.Value{}, unstable.NewParserError(key.Node().Data, "cannot store a table in a slice")
|
||||
}
|
||||
if key.Next() {
|
||||
// Still scoping the key
|
||||
@@ -674,18 +687,12 @@ func (d *decoder) handleKeyValues(v reflect.Value) (reflect.Value, error) {
|
||||
}
|
||||
|
||||
// handleKeyValuesUnmarshaler collects all key-value expressions for a table
|
||||
// and passes them to the Unmarshaler as a synthetic InlineTable node.
|
||||
// and passes them to the Unmarshaler as raw TOML bytes.
|
||||
func (d *decoder) handleKeyValuesUnmarshaler(u unstable.Unmarshaler) (reflect.Value, error) {
|
||||
// We need to collect all key-value expressions and build a synthetic table.
|
||||
// The parser reuses its internal builder between expressions, so we need to
|
||||
// copy each expression's nodes immediately before parsing the next one.
|
||||
var allNodes []unstable.Node
|
||||
allNodes = append(allNodes, unstable.Node{Kind: unstable.InlineTable})
|
||||
// Initialize root node's child and next to -1 (invalid reference)
|
||||
unstable.SetNodeChild(&allNodes[0], -1)
|
||||
unstable.SetNodeNext(&allNodes[0], -1)
|
||||
|
||||
var lastKVIdx int = -1
|
||||
// Collect raw bytes from all key-value expressions for this table.
|
||||
// We build a valid TOML document by reconstructing each key-value line
|
||||
// from the key names and the value's raw bytes.
|
||||
var buf []byte
|
||||
|
||||
for d.nextExpr() {
|
||||
expr := d.expr()
|
||||
@@ -699,107 +706,55 @@ func (d *decoder) handleKeyValuesUnmarshaler(u unstable.Unmarshaler) (reflect.Va
|
||||
return reflect.Value{}, err
|
||||
}
|
||||
|
||||
// Deep copy this expression's nodes into our slice before parsing the next
|
||||
kvIdx := d.copyExpressionNodes(&allNodes, expr)
|
||||
|
||||
// Link to previous sibling or set as first child of root
|
||||
if lastKVIdx == -1 {
|
||||
unstable.SetNodeChild(&allNodes[0], int32(kvIdx))
|
||||
} else {
|
||||
unstable.SetNodeNext(&allNodes[lastKVIdx], int32(kvIdx))
|
||||
// Reconstruct the key-value line from the key(s) and value
|
||||
keyIt := expr.Key()
|
||||
first := true
|
||||
for keyIt.Next() {
|
||||
if !first {
|
||||
buf = append(buf, '.')
|
||||
}
|
||||
keyNode := keyIt.Node()
|
||||
// Check if key needs quoting
|
||||
if keyNeedsQuoting(keyNode.Data) {
|
||||
buf = append(buf, '"')
|
||||
buf = append(buf, keyNode.Data...)
|
||||
buf = append(buf, '"')
|
||||
} else {
|
||||
buf = append(buf, keyNode.Data...)
|
||||
}
|
||||
first = false
|
||||
}
|
||||
lastKVIdx = kvIdx
|
||||
buf = append(buf, " = "...)
|
||||
|
||||
// Get the raw value bytes
|
||||
value := expr.Value()
|
||||
if value != nil {
|
||||
raw := d.p.Raw(value.Raw)
|
||||
buf = append(buf, raw...)
|
||||
}
|
||||
buf = append(buf, '\n')
|
||||
}
|
||||
|
||||
// Set up all nodes with the backing slice
|
||||
for i := range allNodes {
|
||||
unstable.SetNodeSlice(&allNodes[i], &allNodes)
|
||||
}
|
||||
|
||||
if err := u.UnmarshalTOML(&allNodes[0]); err != nil {
|
||||
if err := u.UnmarshalTOML(buf); err != nil {
|
||||
return reflect.Value{}, err
|
||||
}
|
||||
|
||||
return reflect.Value{}, nil
|
||||
}
|
||||
|
||||
// copyExpressionNodes recursively copies all nodes from an expression into
|
||||
// the destination slice. Returns the index of the root node in the destination.
|
||||
// Note: The caller is responsible for setting the nodes slice on all copied nodes
|
||||
// after all expressions have been collected.
|
||||
func (d *decoder) copyExpressionNodes(dst *[]unstable.Node, node *unstable.Node) int {
|
||||
// Recursively collect all nodes in this expression tree
|
||||
collected := collectNodes(node)
|
||||
|
||||
// Calculate the offset for this batch
|
||||
baseIdx := len(*dst)
|
||||
|
||||
// Copy all nodes with child/next initialized to -1 (invalid reference)
|
||||
for _, n := range collected {
|
||||
copied := unstable.Node{
|
||||
Kind: n.Kind,
|
||||
Raw: n.Raw,
|
||||
Data: n.Data,
|
||||
}
|
||||
*dst = append(*dst, copied)
|
||||
// Initialize child and next to invalid reference (-1)
|
||||
// Go's zero value is 0, which would incorrectly point to the first node
|
||||
unstable.SetNodeChild(&(*dst)[len(*dst)-1], -1)
|
||||
unstable.SetNodeNext(&(*dst)[len(*dst)-1], -1)
|
||||
// keyNeedsQuoting returns true if the key needs to be quoted in TOML.
|
||||
func keyNeedsQuoting(key []byte) bool {
|
||||
if len(key) == 0 {
|
||||
return true
|
||||
}
|
||||
|
||||
// Now fix up the child and next indices
|
||||
for i, n := range collected {
|
||||
dstIdx := baseIdx + i
|
||||
if child := unstable.GetNodeChild(n); child >= 0 {
|
||||
// Find the position of the child in our collected slice
|
||||
childOffset := findNodeOffset(collected, child, n)
|
||||
if childOffset >= 0 {
|
||||
unstable.SetNodeChild(&(*dst)[dstIdx], int32(baseIdx+childOffset))
|
||||
}
|
||||
}
|
||||
if next := unstable.GetNodeNext(n); next >= 0 {
|
||||
// Find the position of the next in our collected slice
|
||||
nextOffset := findNodeOffset(collected, next, n)
|
||||
if nextOffset >= 0 {
|
||||
unstable.SetNodeNext(&(*dst)[dstIdx], int32(baseIdx+nextOffset))
|
||||
}
|
||||
for _, b := range key {
|
||||
// Bare keys can only contain A-Za-z0-9_-
|
||||
if !((b >= 'A' && b <= 'Z') || (b >= 'a' && b <= 'z') ||
|
||||
(b >= '0' && b <= '9') || b == '_' || b == '-') {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return baseIdx
|
||||
}
|
||||
|
||||
// collectNodes collects a node and all its descendants into a slice
|
||||
func collectNodes(root *unstable.Node) []*unstable.Node {
|
||||
var result []*unstable.Node
|
||||
var visit func(n *unstable.Node)
|
||||
visit = func(n *unstable.Node) {
|
||||
if n == nil {
|
||||
return
|
||||
}
|
||||
result = append(result, n)
|
||||
// Visit children
|
||||
it := n.Children()
|
||||
for it.Next() {
|
||||
child := it.Node()
|
||||
visit(child)
|
||||
}
|
||||
}
|
||||
visit(root)
|
||||
return result
|
||||
}
|
||||
|
||||
// findNodeOffset finds the offset of a node with the given index in the collected slice
|
||||
func findNodeOffset(collected []*unstable.Node, idx int32, relativeTo *unstable.Node) int {
|
||||
// The idx is an index into the original backing slice.
|
||||
// We need to find which node in our collected slice corresponds to that index.
|
||||
for i, n := range collected {
|
||||
if unstable.GetNodeIndex(n, relativeTo) == idx {
|
||||
return i
|
||||
}
|
||||
}
|
||||
return -1
|
||||
return false
|
||||
}
|
||||
|
||||
type (
|
||||
@@ -846,7 +801,8 @@ func (d *decoder) handleValue(value *unstable.Node, v reflect.Value) error {
|
||||
if d.unmarshalerInterface {
|
||||
if v.CanAddr() && v.Addr().CanInterface() {
|
||||
if outi, ok := v.Addr().Interface().(unstable.Unmarshaler); ok {
|
||||
return outi.UnmarshalTOML(value)
|
||||
// Pass raw bytes from the original document
|
||||
return outi.UnmarshalTOML(d.p.Raw(value.Raw))
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1350,7 +1306,8 @@ func (d *decoder) handleKeyValuePart(key unstable.Iterator, value *unstable.Node
|
||||
if d.unmarshalerInterface {
|
||||
if v.CanAddr() && v.Addr().CanInterface() {
|
||||
if outi, ok := v.Addr().Interface().(unstable.Unmarshaler); ok {
|
||||
return reflect.Value{}, outi.UnmarshalTOML(value)
|
||||
// Pass raw bytes from the original document
|
||||
return reflect.Value{}, outi.UnmarshalTOML(d.p.Raw(value.Raw))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
+100
-23
@@ -3900,8 +3900,8 @@ type CustomUnmarshalerKey struct {
|
||||
A int64
|
||||
}
|
||||
|
||||
func (k *CustomUnmarshalerKey) UnmarshalTOML(value *unstable.Node) error {
|
||||
item, err := strconv.ParseInt(string(value.Data), 10, 64)
|
||||
func (k *CustomUnmarshalerKey) UnmarshalTOML(data []byte) error {
|
||||
item, err := strconv.ParseInt(string(data), 10, 64)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error converting to int64, %w", err)
|
||||
}
|
||||
@@ -3989,7 +3989,7 @@ foo = "bar"`,
|
||||
|
||||
type doc994 struct{}
|
||||
|
||||
func (d *doc994) UnmarshalTOML(*unstable.Node) error {
|
||||
func (d *doc994) UnmarshalTOML([]byte) error {
|
||||
return errors.New("expected-error")
|
||||
}
|
||||
|
||||
@@ -4012,8 +4012,8 @@ type doc994ok struct {
|
||||
S string
|
||||
}
|
||||
|
||||
func (d *doc994ok) UnmarshalTOML(value *unstable.Node) error {
|
||||
d.S = string(value.Data) + " from unmarshaler"
|
||||
func (d *doc994ok) UnmarshalTOML(data []byte) error {
|
||||
d.S = string(data) + " from unmarshaler"
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -4026,7 +4026,8 @@ func TestIssue994_OK(t *testing.T) {
|
||||
Decode(&d)
|
||||
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "bar from unmarshaler", d.S)
|
||||
// With bytes-based interface, raw TOML bytes are passed including quotes
|
||||
assert.Equal(t, "\"bar\" from unmarshaler", d.S)
|
||||
}
|
||||
|
||||
func TestIssue995(t *testing.T) {
|
||||
@@ -4393,29 +4394,35 @@ type customTable873 struct {
|
||||
Values map[string]string
|
||||
}
|
||||
|
||||
func (c *customTable873) UnmarshalTOML(node *unstable.Node) error {
|
||||
func (c *customTable873) UnmarshalTOML(data []byte) error {
|
||||
c.Keys = []string{}
|
||||
c.Values = make(map[string]string)
|
||||
|
||||
it := node.Children()
|
||||
for it.Next() {
|
||||
kv := it.Node()
|
||||
if kv.Kind != unstable.KeyValue {
|
||||
// Parse the raw TOML bytes into a map to extract keys in order
|
||||
// For this test, we use a simple line-by-line parser to preserve order
|
||||
lines := bytes.Split(data, []byte{'\n'})
|
||||
for _, line := range lines {
|
||||
line = bytes.TrimSpace(line)
|
||||
if len(line) == 0 {
|
||||
continue
|
||||
}
|
||||
// Get the key
|
||||
keyIt := kv.Key()
|
||||
if keyIt.Next() {
|
||||
keyNode := keyIt.Node()
|
||||
key := string(keyNode.Data)
|
||||
c.Keys = append(c.Keys, key)
|
||||
|
||||
// Get the value
|
||||
valueNode := kv.Value()
|
||||
if valueNode != nil && valueNode.Kind == unstable.String {
|
||||
c.Values[key] = string(valueNode.Data)
|
||||
}
|
||||
// Skip table headers
|
||||
if line[0] == '[' {
|
||||
continue
|
||||
}
|
||||
// Parse key = value
|
||||
eqIdx := bytes.Index(line, []byte{'='})
|
||||
if eqIdx < 0 {
|
||||
continue
|
||||
}
|
||||
key := string(bytes.TrimSpace(line[:eqIdx]))
|
||||
valueBytes := bytes.TrimSpace(line[eqIdx+1:])
|
||||
// Remove quotes from string values
|
||||
if len(valueBytes) >= 2 && valueBytes[0] == '"' && valueBytes[len(valueBytes)-1] == '"' {
|
||||
valueBytes = valueBytes[1 : len(valueBytes)-1]
|
||||
}
|
||||
c.Keys = append(c.Keys, key)
|
||||
c.Values[key] = string(valueBytes)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -4676,3 +4683,73 @@ value = "two"
|
||||
assert.Equal(t, "second", d.Tables[1].Values["name"])
|
||||
assert.Equal(t, "two", d.Tables[1].Values["value"])
|
||||
}
|
||||
|
||||
// TestIssue873_SplitTables checks decoding when sub-tables of the same parent
|
||||
// are separated by an unrelated table: [a.b] and [a.d] are defined with [x]
|
||||
// in between. The parent A is a plain struct here; B, D and X are
|
||||
// customTable873 values, and each is expected to end up with only the keys
// of its own table.
|
||||
func TestIssue873_SplitTables(t *testing.T) {
|
||||
// splitTableUnmarshaler collects sub-table names it sees
// NOTE(review): this type is not referenced anywhere in the visible test
// body — confirm whether it can be removed.
|
||||
type splitTableUnmarshaler struct {
|
||||
SubTables map[string]map[string]string
|
||||
}
|
||||
|
||||
// For this test, we expect each sub-table to be handled separately
|
||||
// The parent doesn't receive the sub-tables directly - each sub-table
|
||||
// (b and d) gets its own call to handleKeyValues
|
||||
type Config struct {
|
||||
A struct {
|
||||
B customTable873 `toml:"b"`
|
||||
D customTable873 `toml:"d"`
|
||||
} `toml:"a"`
|
||||
X customTable873 `toml:"x"`
|
||||
}
|
||||
|
||||
doc := `
|
||||
[a.b]
|
||||
C = "1"
|
||||
|
||||
[x]
|
||||
Y = "100"
|
||||
|
||||
[a.d]
|
||||
E = "2"
|
||||
`
|
||||
|
||||
var cfg Config
|
||||
err := toml.NewDecoder(bytes.NewReader([]byte(doc))).
|
||||
EnableUnmarshalerInterface().
|
||||
Decode(&cfg)
|
||||
|
||||
assert.NoError(t, err)
|
||||
// Each sub-table should have received its own key-values
// and nothing from the other tables: B sees only C, D only E, X only Y.
|
||||
assert.Equal(t, []string{"C"}, cfg.A.B.Keys)
|
||||
assert.Equal(t, "1", cfg.A.B.Values["C"])
|
||||
assert.Equal(t, []string{"E"}, cfg.A.D.Keys)
|
||||
assert.Equal(t, "2", cfg.A.D.Values["E"])
|
||||
assert.Equal(t, []string{"Y"}, cfg.X.Keys)
|
||||
assert.Equal(t, "100", cfg.X.Values["Y"])
|
||||
}
|
||||
|
||||
// TestIssue873_RawMessage checks that a field of type unstable.RawMessage
// mapped to a [plugin] table captures that table's content as raw TOML bytes
// when the unmarshaler interface is enabled.
|
||||
func TestIssue873_RawMessage(t *testing.T) {
|
||||
type Config struct {
|
||||
Plugin unstable.RawMessage `toml:"plugin"`
|
||||
}
|
||||
|
||||
doc := `
|
||||
[plugin]
|
||||
name = "example"
|
||||
version = "1.0"
|
||||
`
|
||||
|
||||
var cfg Config
|
||||
err := toml.NewDecoder(bytes.NewReader([]byte(doc))).
|
||||
EnableUnmarshalerInterface().
|
||||
Decode(&cfg)
|
||||
|
||||
assert.NoError(t, err)
|
||||
// RawMessage should contain the raw key-value bytes: one "key = value" line
// per pair, each terminated by a newline, without the [plugin] header.
|
||||
expected := "name = \"example\"\nversion = \"1.0\"\n"
|
||||
assert.Equal(t, expected, string(cfg.Plugin))
|
||||
}
|
||||
|
||||
@@ -143,49 +143,3 @@ func (n *Node) Value() *Node {
|
||||
func (n *Node) Children() Iterator {
|
||||
return Iterator{nodes: n.nodes, idx: n.child}
|
||||
}
|
||||
|
||||
// SetNodeSlice sets the backing nodes slice for a node.
|
||||
// This is used when building synthetic nodes.
|
||||
func SetNodeSlice(n *Node, nodes *[]Node) {
|
||||
n.nodes = nodes
|
||||
}
|
||||
|
||||
// SetNodeChild sets the child index for a node.
|
||||
// This is used when building synthetic nodes.
|
||||
func SetNodeChild(n *Node, child int32) {
|
||||
n.child = child
|
||||
}
|
||||
|
||||
// SetNodeNext sets the next sibling index for a node.
|
||||
// This is used when building synthetic nodes.
|
||||
func SetNodeNext(n *Node, next int32) {
|
||||
n.next = next
|
||||
}
|
||||
|
||||
// GetNodeChild returns the child index for a node.
|
||||
// This is used when copying nodes.
|
||||
func GetNodeChild(n *Node) int32 {
|
||||
return n.child
|
||||
}
|
||||
|
||||
// GetNodeNext returns the next sibling index for a node.
|
||||
// This is used when copying nodes.
|
||||
func GetNodeNext(n *Node) int32 {
|
||||
return n.next
|
||||
}
|
||||
|
||||
// GetNodeIndex returns the index of node n in the backing slice,
|
||||
// using relativeTo's nodes slice as reference.
|
||||
// Returns -1 if n is not in the slice.
|
||||
func GetNodeIndex(n *Node, relativeTo *Node) int32 {
|
||||
if relativeTo.nodes == nil || n == nil {
|
||||
return -1
|
||||
}
|
||||
nodes := *relativeTo.nodes
|
||||
for i := range nodes {
|
||||
if &nodes[i] == n {
|
||||
return int32(i)
|
||||
}
|
||||
}
|
||||
return -1
|
||||
}
|
||||
|
||||
+28
-3
@@ -1,7 +1,32 @@
|
||||
package unstable
|
||||
|
||||
// The Unmarshaler interface may be implemented by types to customize their
|
||||
// behavior when being unmarshaled from a TOML document.
|
||||
// Unmarshaler is implemented by types that can unmarshal a TOML
|
||||
// description of themselves. The input is the raw TOML bytes for the
|
||||
// portion of the document being decoded into the value.
|
||||
//
|
||||
// For a single value, data is the raw bytes of that value (a string keeps
|
||||
// its quotes). For a table or array table, data holds one "key = value"
|
||||
// line per key-value pair of that table; no table header is included.
|
||||
type Unmarshaler interface {
|
||||
UnmarshalTOML(value *Node) error
|
||||
UnmarshalTOML(data []byte) error
|
||||
}
|
||||
|
||||
// RawMessage is a raw encoded TOML value. It implements Unmarshaler
|
||||
// and can be used to delay TOML decoding or capture raw content.
|
||||
//
|
||||
// Example usage:
|
||||
//
|
||||
// type Config struct {
|
||||
// Plugin RawMessage `toml:"plugin"`
|
||||
// }
|
||||
//
|
||||
// var cfg Config
|
||||
// toml.NewDecoder(r).EnableUnmarshalerInterface().Decode(&cfg)
|
||||
// // cfg.Plugin now contains the key-value lines of the [plugin] table
|
||||
type RawMessage []byte
|
||||
|
||||
// UnmarshalTOML implements Unmarshaler. It overwrites the contents of m with
// a copy of data and always returns nil.
|
||||
func (m *RawMessage) UnmarshalTOML(data []byte) error {
|
||||
// Truncate m to length zero and append data to it, so m holds its own copy
// of the bytes (reusing m's existing capacity when there is any).
*m = append((*m)[0:0], data...)
|
||||
return nil
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user