diff --git a/unmarshaler.go b/unmarshaler.go index edc7c80..defdfec 100644 --- a/unmarshaler.go +++ b/unmarshaler.go @@ -56,15 +56,18 @@ func (d *Decoder) DisallowUnknownFields() *Decoder { // EnableUnmarshalerInterface allows to enable unmarshaler interface. // -// With this feature enabled, types implementing the unstable/Unmarshaler +// With this feature enabled, types implementing the unstable.Unmarshaler // interface can be decoded from any structure of the document. It allows types // that don't have a straightforward TOML representation to provide their own // decoding logic. // -// Types can decode from single values, inline tables, arrays, and standard -// tables/array tables. When decoding from a table (e.g., [table] or [[array]]), -// the UnmarshalTOML method receives a synthetic InlineTable node containing -// all the key-value pairs belonging to that table. +// The UnmarshalTOML method receives raw TOML bytes: +// - For single values: the raw value bytes (e.g., `"hello"` for a string) +// - For tables: all key-value lines belonging to that table +// - For inline tables/arrays: the raw bytes of the inline structure +// +// The unstable.RawMessage type can be used to capture raw TOML bytes for +// later processing, similar to json.RawMessage. // // *Unstable:* This method does not follow the compatibility guarantees of // semver. It can be changed or removed without a new major version being @@ -601,18 +604,28 @@ func (d *decoder) handleArrayTablePart(key unstable.Iterator, v reflect.Value) ( // cannot handle it. func (d *decoder) handleTable(key unstable.Iterator, v reflect.Value) (reflect.Value, error) { if v.Kind() == reflect.Slice { - if v.Len() == 0 { - return reflect.Value{}, unstable.NewParserError(key.Node().Data, "cannot store a table in a slice") + // For non-empty slices, work with the last element + if v.Len() > 0 { + elem := v.Index(v.Len() - 1) + x, err := d.handleTable(key, elem) + if err != nil { + return reflect.Value{}, err + } + if x.IsValid() { + elem.Set(x) + } + return reflect.Value{}, nil } - elem := v.Index(v.Len() - 1) - x, err := d.handleTable(key, elem) - if err != nil { - return reflect.Value{}, err + // Empty slice - check if it implements Unmarshaler (e.g., RawMessage) + // and we're at the end of the key path + if d.unmarshalerInterface && !key.Next() { + if v.CanAddr() && v.Addr().CanInterface() { + if outi, ok := v.Addr().Interface().(unstable.Unmarshaler); ok { + return d.handleKeyValuesUnmarshaler(outi) + } + } } - if x.IsValid() { - elem.Set(x) - } - return reflect.Value{}, nil + return reflect.Value{}, unstable.NewParserError(key.Node().Data, "cannot store a table in a slice") } if key.Next() { // Still scoping the key @@ -674,18 +687,12 @@ func (d *decoder) handleKeyValues(v reflect.Value) (reflect.Value, error) { } // handleKeyValuesUnmarshaler collects all key-value expressions for a table -// and passes them to the Unmarshaler as a synthetic InlineTable node. +// and passes them to the Unmarshaler as raw TOML bytes. func (d *decoder) handleKeyValuesUnmarshaler(u unstable.Unmarshaler) (reflect.Value, error) { - // We need to collect all key-value expressions and build a synthetic table. - // The parser reuses its internal builder between expressions, so we need to - // copy each expression's nodes immediately before parsing the next one. - var allNodes []unstable.Node - allNodes = append(allNodes, unstable.Node{Kind: unstable.InlineTable}) - // Initialize root node's child and next to -1 (invalid reference) - unstable.SetNodeChild(&allNodes[0], -1) - unstable.SetNodeNext(&allNodes[0], -1) - - var lastKVIdx int = -1 + // Collect raw bytes from all key-value expressions for this table. + // We build a valid TOML document by reconstructing each key-value line + // from the key names and the value's raw bytes. + var buf []byte for d.nextExpr() { expr := d.expr() @@ -699,107 +706,55 @@ func (d *decoder) handleKeyValuesUnmarshaler(u unstable.Unmarshaler) (reflect.Va return reflect.Value{}, err } - // Deep copy this expression's nodes into our slice before parsing the next - kvIdx := d.copyExpressionNodes(&allNodes, expr) - - // Link to previous sibling or set as first child of root - if lastKVIdx == -1 { - unstable.SetNodeChild(&allNodes[0], int32(kvIdx)) - } else { - unstable.SetNodeNext(&allNodes[lastKVIdx], int32(kvIdx)) + // Reconstruct the key-value line from the key(s) and value + keyIt := expr.Key() + first := true + for keyIt.Next() { + if !first { + buf = append(buf, '.') + } + keyNode := keyIt.Node() + // Check if key needs quoting + if keyNeedsQuoting(keyNode.Data) { + buf = append(buf, '"') + buf = append(buf, keyNode.Data...) + buf = append(buf, '"') + } else { + buf = append(buf, keyNode.Data...) + } + first = false } - lastKVIdx = kvIdx + buf = append(buf, " = "...) + + // Get the raw value bytes + value := expr.Value() + if value != nil { + raw := d.p.Raw(value.Raw) + buf = append(buf, raw...) + } + buf = append(buf, '\n') } - // Set up all nodes with the backing slice - for i := range allNodes { - unstable.SetNodeSlice(&allNodes[i], &allNodes) - } - - if err := u.UnmarshalTOML(&allNodes[0]); err != nil { + if err := u.UnmarshalTOML(buf); err != nil { return reflect.Value{}, err } return reflect.Value{}, nil } -// copyExpressionNodes recursively copies all nodes from an expression into -// the destination slice. Returns the index of the root node in the destination. -// Note: The caller is responsible for setting the nodes slice on all copied nodes -// after all expressions have been collected. -func (d *decoder) copyExpressionNodes(dst *[]unstable.Node, node *unstable.Node) int { - // Recursively collect all nodes in this expression tree - collected := collectNodes(node) - - // Calculate the offset for this batch - baseIdx := len(*dst) - - // Copy all nodes with child/next initialized to -1 (invalid reference) - for _, n := range collected { - copied := unstable.Node{ - Kind: n.Kind, - Raw: n.Raw, - Data: n.Data, - } - *dst = append(*dst, copied) - // Initialize child and next to invalid reference (-1) - // Go's zero value is 0, which would incorrectly point to the first node - unstable.SetNodeChild(&(*dst)[len(*dst)-1], -1) - unstable.SetNodeNext(&(*dst)[len(*dst)-1], -1) +// keyNeedsQuoting returns true if the key needs to be quoted in TOML. +func keyNeedsQuoting(key []byte) bool { + if len(key) == 0 { + return true } - - // Now fix up the child and next indices - for i, n := range collected { - dstIdx := baseIdx + i - if child := unstable.GetNodeChild(n); child >= 0 { - // Find the position of the child in our collected slice - childOffset := findNodeOffset(collected, child, n) - if childOffset >= 0 { - unstable.SetNodeChild(&(*dst)[dstIdx], int32(baseIdx+childOffset)) - } - } - if next := unstable.GetNodeNext(n); next >= 0 { - // Find the position of the next in our collected slice - nextOffset := findNodeOffset(collected, next, n) - if nextOffset >= 0 { - unstable.SetNodeNext(&(*dst)[dstIdx], int32(baseIdx+nextOffset)) - } + for _, b := range key { + // Bare keys can only contain A-Za-z0-9_- + if !((b >= 'A' && b <= 'Z') || (b >= 'a' && b <= 'z') || + (b >= '0' && b <= '9') || b == '_' || b == '-') { + return true } } - - return baseIdx -} - -// collectNodes collects a node and all its descendants into a slice -func collectNodes(root *unstable.Node) []*unstable.Node { - var result []*unstable.Node - var visit func(n *unstable.Node) - visit = func(n *unstable.Node) { - if n == nil { - return - } - result = append(result, n) - // Visit children - it := n.Children() - for it.Next() { - child := it.Node() - visit(child) - } - } - visit(root) - return result -} - -// findNodeOffset finds the offset of a node with the given index in the collected slice -func findNodeOffset(collected []*unstable.Node, idx int32, relativeTo *unstable.Node) int { - // The idx is an index into the original backing slice. - // We need to find which node in our collected slice corresponds to that index. - for i, n := range collected { - if unstable.GetNodeIndex(n, relativeTo) == idx { - return i - } - } - return -1 + return false } type ( @@ -846,7 +801,8 @@ func (d *decoder) handleValue(value *unstable.Node, v reflect.Value) error { if d.unmarshalerInterface { if v.CanAddr() && v.Addr().CanInterface() { if outi, ok := v.Addr().Interface().(unstable.Unmarshaler); ok { - return outi.UnmarshalTOML(value) + // Pass raw bytes from the original document + return outi.UnmarshalTOML(d.p.Raw(value.Raw)) } } } @@ -1350,7 +1306,8 @@ func (d *decoder) handleKeyValuePart(key unstable.Iterator, value *unstable.Node if d.unmarshalerInterface { if v.CanAddr() && v.Addr().CanInterface() { if outi, ok := v.Addr().Interface().(unstable.Unmarshaler); ok { - return reflect.Value{}, outi.UnmarshalTOML(value) + // Pass raw bytes from the original document + return reflect.Value{}, outi.UnmarshalTOML(d.p.Raw(value.Raw)) } } } diff --git a/unmarshaler_test.go b/unmarshaler_test.go index f820d18..f9f696a 100644 --- a/unmarshaler_test.go +++ b/unmarshaler_test.go @@ -3900,8 +3900,8 @@ type CustomUnmarshalerKey struct { A int64 } -func (k *CustomUnmarshalerKey) UnmarshalTOML(value *unstable.Node) error { - item, err := strconv.ParseInt(string(value.Data), 10, 64) +func (k *CustomUnmarshalerKey) UnmarshalTOML(data []byte) error { + item, err := strconv.ParseInt(string(data), 10, 64) if err != nil { return fmt.Errorf("error converting to int64, %w", err) } @@ -3989,7 +3989,7 @@ foo = "bar"`, type doc994 struct{} -func (d *doc994) UnmarshalTOML(*unstable.Node) error { +func (d *doc994) UnmarshalTOML([]byte) error { return errors.New("expected-error") } @@ -4012,8 +4012,8 @@ type doc994ok struct { S string } -func (d *doc994ok) UnmarshalTOML(value *unstable.Node) error { - d.S = string(value.Data) + " from unmarshaler" +func (d *doc994ok) UnmarshalTOML(data []byte) error { + d.S = string(data) + " from unmarshaler" return nil } @@ -4026,7 +4026,8 @@ func TestIssue994_OK(t *testing.T) { Decode(&d) assert.NoError(t, err) - assert.Equal(t, "bar from unmarshaler", d.S) + // With bytes-based interface, raw TOML bytes are passed including quotes + assert.Equal(t, "\"bar\" from unmarshaler", d.S) } func TestIssue995(t *testing.T) { @@ -4393,29 +4394,35 @@ type customTable873 struct { Values map[string]string } -func (c *customTable873) UnmarshalTOML(node *unstable.Node) error { +func (c *customTable873) UnmarshalTOML(data []byte) error { c.Keys = []string{} c.Values = make(map[string]string) - it := node.Children() - for it.Next() { - kv := it.Node() - if kv.Kind != unstable.KeyValue { + // Parse the raw TOML bytes into a map to extract keys in order + // For this test, we use a simple line-by-line parser to preserve order + lines := bytes.Split(data, []byte{'\n'}) + for _, line := range lines { + line = bytes.TrimSpace(line) + if len(line) == 0 { continue } - // Get the key - keyIt := kv.Key() - if keyIt.Next() { - keyNode := keyIt.Node() - key := string(keyNode.Data) - c.Keys = append(c.Keys, key) - - // Get the value - valueNode := kv.Value() - if valueNode != nil && valueNode.Kind == unstable.String { - c.Values[key] = string(valueNode.Data) - } + // Skip table headers + if line[0] == '[' { + continue } + // Parse key = value + eqIdx := bytes.Index(line, []byte{'='}) + if eqIdx < 0 { + continue + } + key := string(bytes.TrimSpace(line[:eqIdx])) + valueBytes := bytes.TrimSpace(line[eqIdx+1:]) + // Remove quotes from string values + if len(valueBytes) >= 2 && valueBytes[0] == '"' && valueBytes[len(valueBytes)-1] == '"' { + valueBytes = valueBytes[1 : len(valueBytes)-1] + } + c.Keys = append(c.Keys, key) + c.Values[key] = string(valueBytes) } return nil } @@ -4676,3 +4683,73 @@ value = "two" assert.Equal(t, "second", d.Tables[1].Values["name"]) assert.Equal(t, "two", d.Tables[1].Values["value"]) } + +// Test for split tables - when the same parent table is defined in multiple places +// This is a key requirement for issue #873: if type A implements Unmarshaler, +// and [a.b] and [a.d] are defined with another table [x] in between, +// A should receive content for both b and d, but not x. +func TestIssue873_SplitTables(t *testing.T) { + // splitTableUnmarshaler collects sub-table names it sees + type splitTableUnmarshaler struct { + SubTables map[string]map[string]string + } + + // For this test, we expect each sub-table to be handled separately + // The parent doesn't receive the sub-tables directly - each sub-table + // (b and d) gets its own call to handleKeyValues + type Config struct { + A struct { + B customTable873 `toml:"b"` + D customTable873 `toml:"d"` + } `toml:"a"` + X customTable873 `toml:"x"` + } + + doc := ` +[a.b] +C = "1" + +[x] +Y = "100" + +[a.d] +E = "2" +` + + var cfg Config + err := toml.NewDecoder(bytes.NewReader([]byte(doc))). + EnableUnmarshalerInterface(). + Decode(&cfg) + + assert.NoError(t, err) + // Each sub-table should have received its own key-values + assert.Equal(t, []string{"C"}, cfg.A.B.Keys) + assert.Equal(t, "1", cfg.A.B.Values["C"]) + assert.Equal(t, []string{"E"}, cfg.A.D.Keys) + assert.Equal(t, "2", cfg.A.D.Values["E"]) + assert.Equal(t, []string{"Y"}, cfg.X.Keys) + assert.Equal(t, "100", cfg.X.Values["Y"]) +} + +// Test using RawMessage to capture raw TOML bytes +func TestIssue873_RawMessage(t *testing.T) { + type Config struct { + Plugin unstable.RawMessage `toml:"plugin"` + } + + doc := ` +[plugin] +name = "example" +version = "1.0" +` + + var cfg Config + err := toml.NewDecoder(bytes.NewReader([]byte(doc))). + EnableUnmarshalerInterface(). + Decode(&cfg) + + assert.NoError(t, err) + // RawMessage should contain the raw key-value bytes + expected := "name = \"example\"\nversion = \"1.0\"\n" + assert.Equal(t, expected, string(cfg.Plugin)) +} diff --git a/unstable/ast.go b/unstable/ast.go index 0d71ac3..34ef628 100644 --- a/unstable/ast.go +++ b/unstable/ast.go @@ -143,49 +143,3 @@ func (n *Node) Value() *Node { func (n *Node) Children() Iterator { return Iterator{nodes: n.nodes, idx: n.child} } - -// SetNodeSlice sets the backing nodes slice for a node. -// This is used when building synthetic nodes. -func SetNodeSlice(n *Node, nodes *[]Node) { - n.nodes = nodes -} - -// SetNodeChild sets the child index for a node. -// This is used when building synthetic nodes. -func SetNodeChild(n *Node, child int32) { - n.child = child -} - -// SetNodeNext sets the next sibling index for a node. -// This is used when building synthetic nodes. -func SetNodeNext(n *Node, next int32) { - n.next = next -} - -// GetNodeChild returns the child index for a node. -// This is used when copying nodes. -func GetNodeChild(n *Node) int32 { - return n.child -} - -// GetNodeNext returns the next sibling index for a node. -// This is used when copying nodes. -func GetNodeNext(n *Node) int32 { - return n.next -} - -// GetNodeIndex returns the index of node n in the backing slice, -// using relativeTo's nodes slice as reference. -// Returns -1 if n is not in the slice. -func GetNodeIndex(n *Node, relativeTo *Node) int32 { - if relativeTo.nodes == nil || n == nil { - return -1 - } - nodes := *relativeTo.nodes - for i := range nodes { - if &nodes[i] == n { - return int32(i) - } - } - return -1 -} diff --git a/unstable/unmarshaler.go b/unstable/unmarshaler.go index 00cfd6d..5a79da8 100644 --- a/unstable/unmarshaler.go +++ b/unstable/unmarshaler.go @@ -1,7 +1,32 @@ package unstable -// The Unmarshaler interface may be implemented by types to customize their -// behavior when being unmarshaled from a TOML document. +// Unmarshaler is implemented by types that can unmarshal a TOML +// description of themselves. The input is a valid TOML document +// containing the relevant portion of the parsed document. +// +// For tables (including split tables defined in multiple places), +// the data contains the raw key-value bytes from the original document +// with adjusted table headers to be relative to the unmarshaling target. type Unmarshaler interface { - UnmarshalTOML(value *Node) error + UnmarshalTOML(data []byte) error +} + +// RawMessage is a raw encoded TOML value. It implements Unmarshaler +// and can be used to delay TOML decoding or capture raw content. +// +// Example usage: +// +// type Config struct { +// Plugin RawMessage `toml:"plugin"` +// } +// +// var cfg Config +// toml.NewDecoder(r).EnableUnmarshalerInterface().Decode(&cfg) +// // cfg.Plugin now contains the raw TOML bytes for [plugin] +type RawMessage []byte + +// UnmarshalTOML implements Unmarshaler. +func (m *RawMessage) UnmarshalTOML(data []byte) error { + *m = append((*m)[0:0], data...) + return nil }