From 618f0181ac76015c3efdaf8b8ab5a468293f6e82 Mon Sep 17 00:00:00 2001 From: Thomas Pelletier Date: Thu, 3 Jun 2021 21:48:51 -0400 Subject: [PATCH] AST Tweaks (#551) * Use pointers instead of copying around ast.Node Node is a 56B struct that is constantly in the hot path. Passing nodes around by copy had a cost that started to add up. This change replaces them by pointers. Using unsafe pointer arithmetic and converting sibling/child indexes to relative offsets, it removes the need to carry around a pointer to the root of the tree. This saves 8B per Node. This space will be used to store an extra []byte slice to provide contextual error handling on all nodes, including the ones whose data is different than the raw input (for example: strings with escaped characters), while staying under the size of a cache line. * Remove conditional * Add Raw to track range in data for parsed values * Simplify reference tracking --- README.md | 32 +++--- errors.go | 4 +- internal/ast/ast.go | 66 ++++++----- internal/ast/builder.go | 31 ++--- .../{unsafe/unsafe.go => danger/danger.go} | 8 +- .../unsafe_test.go => danger/danger_test.go} | 28 +++-- internal/tracker/key.go | 8 +- internal/tracker/seen.go | 8 +- parser.go | 106 ++++++++++-------- parser_test.go | 6 +- strict.go | 18 +-- unmarshaler.go | 41 ++++--- unmarshaler_test.go | 48 ++++++++ 13 files changed, 239 insertions(+), 165 deletions(-) rename internal/{unsafe/unsafe.go => danger/danger.go} (84%) rename internal/{unsafe/unsafe_test.go => danger/danger_test.go} (82%) diff --git a/README.md b/README.md index 5a03665..698121b 100644 --- a/README.md +++ b/README.md @@ -156,12 +156,12 @@ Execution time speedup compared to other Go TOML libraries: Benchmarkgo-toml v1BurntSushi/toml - Marshal/HugoFrontMatter1.9x1.9x - Marshal/ReferenceFile/map1.7x1.9x - Marshal/ReferenceFile/struct2.7x2.9x - Unmarshal/HugoFrontMatter2.9x2.4x - Unmarshal/ReferenceFile/map3.1x3.0x - Unmarshal/ReferenceFile/struct5.5x5.8x + Marshal/HugoFrontMatter2.0x2.0x + Marshal/ReferenceFile/map1.8x2.0x + Marshal/ReferenceFile/struct2.7x2.7x + Unmarshal/HugoFrontMatter3.0x2.6x + Unmarshal/ReferenceFile/map3.0x3.1x + Unmarshal/ReferenceFile/struct5.9x6.6x
See more @@ -174,16 +174,16 @@ provided for completeness.

Benchmarkgo-toml v1BurntSushi/toml - Marshal/SimpleDocument/map1.8x2.4x - Marshal/SimpleDocument/struct2.7x3.5x - Unmarshal/SimpleDocument/map4.3x2.4x - Unmarshal/SimpleDocument/struct5.8x3.3x - UnmarshalDataset/example3.1x2.2x - UnmarshalDataset/code1.8x2.1x - UnmarshalDataset/twitter2.7x1.9x - UnmarshalDataset/citm_catalog1.8x1.2x - UnmarshalDataset/config3.4x2.8x - [Geo mean]2.8x2.5x + Marshal/SimpleDocument/map1.7x2.1x + Marshal/SimpleDocument/struct2.6x2.9x + Unmarshal/SimpleDocument/map4.1x2.9x + Unmarshal/SimpleDocument/struct6.3x4.1x + UnmarshalDataset/example3.5x2.4x + UnmarshalDataset/code2.2x2.8x + UnmarshalDataset/twitter2.8x2.1x + UnmarshalDataset/citm_catalog2.3x1.5x + UnmarshalDataset/config4.2x3.2x + [Geo mean]3.0x2.7x

This table can be generated with ./ci.sh benchmark -a -html.

diff --git a/errors.go b/errors.go index 712765b..a00924b 100644 --- a/errors.go +++ b/errors.go @@ -5,7 +5,7 @@ import ( "strconv" "strings" - "github.com/pelletier/go-toml/v2/internal/unsafe" + "github.com/pelletier/go-toml/v2/internal/danger" ) // DecodeError represents an error encountered during the parsing or decoding @@ -105,7 +105,7 @@ func (e *DecodeError) Key() Key { // highlight can be freely deallocated. //nolint:funlen func wrapDecodeError(document []byte, de *decodeError) *DecodeError { - offset := unsafe.SubsliceOffset(document, de.highlight) + offset := danger.SubsliceOffset(document, de.highlight) errMessage := de.Error() errLine, errColumn := positionAtEnd(document[:offset]) diff --git a/internal/ast/ast.go b/internal/ast/ast.go index f9059d8..82c1cb9 100644 --- a/internal/ast/ast.go +++ b/internal/ast/ast.go @@ -2,6 +2,9 @@ package ast import ( "fmt" + "unsafe" + + "github.com/pelletier/go-toml/v2/internal/danger" ) // Iterator starts uninitialized, you need to call Next() first. @@ -14,7 +17,7 @@ import ( // } type Iterator struct { started bool - node Node + node *Node } // Next moves the iterator forward and returns true if points to a node, false @@ -31,11 +34,11 @@ func (c *Iterator) Next() bool { // IsLast returns true if the current node of the iterator is the last one. // Subsequent call to Next() will return false. func (c *Iterator) IsLast() bool { - return c.node.next <= 0 + return c.node.next == 0 } // Node returns a copy of the node pointed at by the iterator. -func (c *Iterator) Node() Node { +func (c *Iterator) Node() *Node { return c.node } @@ -50,14 +53,13 @@ type Root struct { func (r *Root) Iterator() Iterator { it := Iterator{} if len(r.nodes) > 0 { - it.node = r.nodes[0] + it.node = &r.nodes[0] } return it } -func (r *Root) at(idx int) Node { - // TODO: unsafe to point to the node directly - return r.nodes[idx] +func (r *Root) at(idx Reference) *Node { + return &r.nodes[idx] } // Arrays have one child per element in the array. @@ -69,42 +71,48 @@ func (r *Root) at(idx int) Node { // children []Node type Node struct { Kind Kind - Data []byte // Raw bytes from the input + Raw Range // Raw bytes from the input. + Data []byte // Node value (could be either allocated or referencing the input). - // next idx (in the root array). 0 if last of the collection. - next int - // child idx (in the root array). 0 if no child. - child int - // pointer to the root array - root *Root + // References to other nodes, as offsets in the backing array from this + // node. References can go backward, so those can be negative. + next int // 0 if last element + child int // 0 if no child +} + +type Range struct { + Offset uint32 + Length uint32 } // Next returns a copy of the next node, or an invalid Node if there is no // next node. -func (n Node) Next() Node { - if n.next <= 0 { - return noNode +func (n *Node) Next() *Node { + if n.next == 0 { + return nil } - return n.root.at(n.next) + ptr := unsafe.Pointer(n) + size := unsafe.Sizeof(Node{}) + return (*Node)(danger.Stride(ptr, size, n.next)) } // Child returns a copy of the first child node of this node. Other children // can be accessed calling Next on the first child. // Returns an invalid Node if there is none. -func (n Node) Child() Node { - if n.child <= 0 { - return noNode +func (n *Node) Child() *Node { + if n.child == 0 { + return nil } - return n.root.at(n.child) + ptr := unsafe.Pointer(n) + size := unsafe.Sizeof(Node{}) + return (*Node)(danger.Stride(ptr, size, n.child)) } // Valid returns true if the node's kind is set (not to Invalid). -func (n Node) Valid() bool { - return n.Kind != Invalid +func (n *Node) Valid() bool { + return n != nil } -var noNode = Node{} - // Key returns the child nodes making the Key on a supported node. Panics // otherwise. // They are guaranteed to be all be of the Kind Key. A simple key would return @@ -127,13 +135,13 @@ func (n *Node) Key() Iterator { // Value returns a pointer to the value node of a KeyValue. // Guaranteed to be non-nil. // Panics if not called on a KeyValue node, or if the Children are malformed. -func (n Node) Value() Node { - assertKind(KeyValue, n) +func (n *Node) Value() *Node { + assertKind(KeyValue, *n) return n.Child() } // Children returns an iterator over a node's children. -func (n Node) Children() Iterator { +func (n *Node) Children() Iterator { return Iterator{node: n.Child()} } diff --git a/internal/ast/builder.go b/internal/ast/builder.go index 796b7f1..120f16e 100644 --- a/internal/ast/builder.go +++ b/internal/ast/builder.go @@ -1,12 +1,11 @@ package ast -type Reference struct { - idx int - set bool -} +type Reference int + +const InvalidReference Reference = -1 func (r Reference) Valid() bool { - return r.set + return r != InvalidReference } type Builder struct { @@ -18,8 +17,8 @@ func (b *Builder) Tree() *Root { return &b.tree } -func (b *Builder) NodeAt(ref Reference) Node { - return b.tree.at(ref.idx) +func (b *Builder) NodeAt(ref Reference) *Node { + return b.tree.at(ref) } func (b *Builder) Reset() { @@ -28,33 +27,25 @@ func (b *Builder) Reset() { } func (b *Builder) Push(n Node) Reference { - n.root = &b.tree b.lastIdx = len(b.tree.nodes) b.tree.nodes = append(b.tree.nodes, n) - return Reference{ - idx: b.lastIdx, - set: true, - } + return Reference(b.lastIdx) } func (b *Builder) PushAndChain(n Node) Reference { - n.root = &b.tree newIdx := len(b.tree.nodes) b.tree.nodes = append(b.tree.nodes, n) if b.lastIdx >= 0 { - b.tree.nodes[b.lastIdx].next = newIdx + b.tree.nodes[b.lastIdx].next = newIdx - b.lastIdx } b.lastIdx = newIdx - return Reference{ - idx: b.lastIdx, - set: true, - } + return Reference(b.lastIdx) } func (b *Builder) AttachChild(parent Reference, child Reference) { - b.tree.nodes[parent.idx].child = child.idx + b.tree.nodes[parent].child = int(child) - int(parent) } func (b *Builder) Chain(from Reference, to Reference) { - b.tree.nodes[from.idx].next = to.idx + b.tree.nodes[from].next = int(to) - int(from) } diff --git a/internal/unsafe/unsafe.go b/internal/danger/danger.go similarity index 84% rename from internal/unsafe/unsafe.go rename to internal/danger/danger.go index 742c6ab..e38e113 100644 --- a/internal/unsafe/unsafe.go +++ b/internal/danger/danger.go @@ -1,4 +1,4 @@ -package unsafe +package danger import ( "fmt" @@ -57,3 +57,9 @@ func BytesRange(start []byte, end []byte) []byte { return start[:l] } + +func Stride(ptr unsafe.Pointer, size uintptr, offset int) unsafe.Pointer { + // TODO: replace with unsafe.Add when Go 1.17 is released + // https://github.com/golang/go/issues/40481 + return unsafe.Pointer(uintptr(ptr) + uintptr(int(size)*offset)) +} diff --git a/internal/unsafe/unsafe_test.go b/internal/danger/danger_test.go similarity index 82% rename from internal/unsafe/unsafe_test.go rename to internal/danger/danger_test.go index 5462d08..cb975f8 100644 --- a/internal/unsafe/unsafe_test.go +++ b/internal/danger/danger_test.go @@ -1,15 +1,16 @@ -package unsafe_test +package danger_test import ( "testing" + "unsafe" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - "github.com/pelletier/go-toml/v2/internal/unsafe" + "github.com/pelletier/go-toml/v2/internal/danger" ) -func TestUnsafeSubsliceOffsetValid(t *testing.T) { +func TestSubsliceOffsetValid(t *testing.T) { examples := []struct { desc string test func() ([]byte, []byte) @@ -28,13 +29,13 @@ func TestUnsafeSubsliceOffsetValid(t *testing.T) { for _, e := range examples { t.Run(e.desc, func(t *testing.T) { d, s := e.test() - offset := unsafe.SubsliceOffset(d, s) + offset := danger.SubsliceOffset(d, s) assert.Equal(t, e.offset, offset) }) } } -func TestUnsafeSubsliceOffsetInvalid(t *testing.T) { +func TestSubsliceOffsetInvalid(t *testing.T) { examples := []struct { desc string test func() ([]byte, []byte) @@ -72,13 +73,22 @@ func TestUnsafeSubsliceOffsetInvalid(t *testing.T) { t.Run(e.desc, func(t *testing.T) { d, s := e.test() require.Panics(t, func() { - unsafe.SubsliceOffset(d, s) + danger.SubsliceOffset(d, s) }) }) } } -func TestUnsafeBytesRange(t *testing.T) { +func TestStride(t *testing.T) { + a := []byte{1, 2, 3, 4} + x := &a[1] + n := (*byte)(danger.Stride(unsafe.Pointer(x), unsafe.Sizeof(byte(0)), 1)) + require.Equal(t, &a[2], n) + n = (*byte)(danger.Stride(unsafe.Pointer(x), unsafe.Sizeof(byte(0)), -1)) + require.Equal(t, &a[0], n) +} + +func TestBytesRange(t *testing.T) { type fn = func() ([]byte, []byte) examples := []struct { desc string @@ -157,10 +167,10 @@ func TestUnsafeBytesRange(t *testing.T) { start, end := e.test() if e.expected == nil { require.Panics(t, func() { - unsafe.BytesRange(start, end) + danger.BytesRange(start, end) }) } else { - res := unsafe.BytesRange(start, end) + res := danger.BytesRange(start, end) require.Equal(t, e.expected, res) } }) diff --git a/internal/tracker/key.go b/internal/tracker/key.go index be99f72..7c148f4 100644 --- a/internal/tracker/key.go +++ b/internal/tracker/key.go @@ -11,19 +11,19 @@ type KeyTracker struct { } // UpdateTable sets the state of the tracker with the AST table node. -func (t *KeyTracker) UpdateTable(node ast.Node) { +func (t *KeyTracker) UpdateTable(node *ast.Node) { t.reset() t.Push(node) } // UpdateArrayTable sets the state of the tracker with the AST array table node. -func (t *KeyTracker) UpdateArrayTable(node ast.Node) { +func (t *KeyTracker) UpdateArrayTable(node *ast.Node) { t.reset() t.Push(node) } // Push the given key on the stack. -func (t *KeyTracker) Push(node ast.Node) { +func (t *KeyTracker) Push(node *ast.Node) { it := node.Key() for it.Next() { t.k = append(t.k, string(it.Node().Data)) @@ -31,7 +31,7 @@ func (t *KeyTracker) Push(node ast.Node) { } // Pop key from stack. -func (t *KeyTracker) Pop(node ast.Node) { +func (t *KeyTracker) Pop(node *ast.Node) { it := node.Key() for it.Next() { t.k = t.k[:len(t.k)-1] diff --git a/internal/tracker/seen.go b/internal/tracker/seen.go index 0f6bd01..af70241 100644 --- a/internal/tracker/seen.go +++ b/internal/tracker/seen.go @@ -104,7 +104,7 @@ func (s *SeenTracker) create(parentIdx int, name []byte, kind keyKind, explicit // CheckExpression takes a top-level node and checks that it does not contain keys // that have been seen in previous calls, and validates that types are consistent. -func (s *SeenTracker) CheckExpression(node ast.Node) error { +func (s *SeenTracker) CheckExpression(node *ast.Node) error { if s.entries == nil { // s.entries = make([]entry, 0, 8) // Skip ID = 0 to remove the confusion between nodes whose parent has @@ -125,7 +125,7 @@ func (s *SeenTracker) CheckExpression(node ast.Node) error { } } -func (s *SeenTracker) checkTable(node ast.Node) error { +func (s *SeenTracker) checkTable(node *ast.Node) error { it := node.Key() parentIdx := -1 @@ -169,7 +169,7 @@ func (s *SeenTracker) checkTable(node ast.Node) error { return nil } -func (s *SeenTracker) checkArrayTable(node ast.Node) error { +func (s *SeenTracker) checkArrayTable(node *ast.Node) error { it := node.Key() parentIdx := -1 @@ -207,7 +207,7 @@ func (s *SeenTracker) checkArrayTable(node ast.Node) error { return nil } -func (s *SeenTracker) checkKeyValue(node ast.Node) error { +func (s *SeenTracker) checkKeyValue(node *ast.Node) error { it := node.Key() parentIdx := s.currentIdx diff --git a/parser.go b/parser.go index aa97e2e..453b08d 100644 --- a/parser.go +++ b/parser.go @@ -5,6 +5,7 @@ import ( "strconv" "github.com/pelletier/go-toml/v2/internal/ast" + "github.com/pelletier/go-toml/v2/internal/danger" ) type parser struct { @@ -16,9 +17,20 @@ type parser struct { first bool } +func (p *parser) Range(b []byte) ast.Range { + return ast.Range{ + Offset: uint32(danger.SubsliceOffset(p.data, b)), + Length: uint32(len(b)), + } +} + +func (p *parser) Raw(raw ast.Range) []byte { + return p.data[raw.Offset : raw.Offset+raw.Length] +} + func (p *parser) Reset(b []byte) { p.builder.Reset() - p.ref = ast.Reference{} + p.ref = ast.InvalidReference p.data = b p.left = b p.err = nil @@ -32,7 +44,7 @@ func (p *parser) NextExpression() bool { } p.builder.Reset() - p.ref = ast.Reference{} + p.ref = ast.InvalidReference for { if len(p.left) == 0 || p.err != nil { @@ -61,7 +73,7 @@ func (p *parser) NextExpression() bool { } } -func (p *parser) Expression() ast.Node { +func (p *parser) Expression() *ast.Node { return p.builder.NodeAt(p.ref) } @@ -86,7 +98,7 @@ func (p *parser) parseExpression(b []byte) (ast.Reference, []byte, error) { // expression = ws [ comment ] // expression =/ ws keyval ws [ comment ] // expression =/ ws table ws [ comment ] - var ref ast.Reference + ref := ast.InvalidReference b = p.parseWhitespace(b) @@ -197,7 +209,7 @@ func (p *parser) parseKeyval(b []byte) (ast.Reference, []byte, error) { key, b, err := p.parseKey(b) if err != nil { - return ast.Reference{}, nil, err + return ast.InvalidReference, nil, err } // keyval-sep = ws %x3D ws ; = @@ -205,12 +217,12 @@ func (p *parser) parseKeyval(b []byte) (ast.Reference, []byte, error) { b = p.parseWhitespace(b) if len(b) == 0 { - return ast.Reference{}, nil, newDecodeError(b, "expected = after a key, but the document ends there") + return ast.InvalidReference, nil, newDecodeError(b, "expected = after a key, but the document ends there") } b, err = expect('=', b) if err != nil { - return ast.Reference{}, nil, err + return ast.InvalidReference, nil, err } b = p.parseWhitespace(b) @@ -229,7 +241,7 @@ func (p *parser) parseKeyval(b []byte) (ast.Reference, []byte, error) { //nolint:cyclop,funlen func (p *parser) parseVal(b []byte) (ast.Reference, []byte, error) { // val = string / boolean / array / inline-table / date-time / float / integer - var ref ast.Reference + ref := ast.InvalidReference if len(b) == 0 { return ref, nil, newDecodeError(b, "expected value, not eof") @@ -240,32 +252,36 @@ func (p *parser) parseVal(b []byte) (ast.Reference, []byte, error) { switch c { case '"': + var raw []byte var v []byte if scanFollowsMultilineBasicStringDelimiter(b) { - v, b, err = p.parseMultilineBasicString(b) + raw, v, b, err = p.parseMultilineBasicString(b) } else { - v, b, err = p.parseBasicString(b) + raw, v, b, err = p.parseBasicString(b) } if err == nil { ref = p.builder.Push(ast.Node{ Kind: ast.String, + Raw: p.Range(raw), Data: v, }) } return ref, b, err case '\'': + var raw []byte var v []byte if scanFollowsMultilineLiteralStringDelimiter(b) { - v, b, err = p.parseMultilineLiteralString(b) + raw, v, b, err = p.parseMultilineLiteralString(b) } else { - v, b, err = p.parseLiteralString(b) + raw, v, b, err = p.parseLiteralString(b) } if err == nil { ref = p.builder.Push(ast.Node{ Kind: ast.String, + Raw: p.Range(raw), Data: v, }) } @@ -310,13 +326,13 @@ func atmost(b []byte, n int) []byte { return b[:n] } -func (p *parser) parseLiteralString(b []byte) ([]byte, []byte, error) { +func (p *parser) parseLiteralString(b []byte) ([]byte, []byte, []byte, error) { v, rest, err := scanLiteralString(b) if err != nil { - return nil, nil, err + return nil, nil, nil, err } - return v[1 : len(v)-1], rest, nil + return v, v[1 : len(v)-1], rest, nil } func (p *parser) parseInlineTable(b []byte) (ast.Reference, []byte, error) { @@ -476,10 +492,10 @@ func (p *parser) parseOptionalWhitespaceCommentNewline(b []byte) ([]byte, error) return b, nil } -func (p *parser) parseMultilineLiteralString(b []byte) ([]byte, []byte, error) { +func (p *parser) parseMultilineLiteralString(b []byte) ([]byte, []byte, []byte, error) { token, rest, err := scanMultilineLiteralString(b) if err != nil { - return nil, nil, err + return nil, nil, nil, err } i := 3 @@ -491,11 +507,11 @@ func (p *parser) parseMultilineLiteralString(b []byte) ([]byte, []byte, error) { i += 2 } - return token[i : len(token)-3], rest, err + return token, token[i : len(token)-3], rest, err } //nolint:funlen,gocognit,cyclop -func (p *parser) parseMultilineBasicString(b []byte) ([]byte, []byte, error) { +func (p *parser) parseMultilineBasicString(b []byte) ([]byte, []byte, []byte, error) { // ml-basic-string = ml-basic-string-delim [ newline ] ml-basic-body // ml-basic-string-delim // ml-basic-string-delim = 3quotation-mark @@ -508,7 +524,7 @@ func (p *parser) parseMultilineBasicString(b []byte) ([]byte, []byte, error) { // mlb-escaped-nl = escape ws newline *( wschar / newline ) token, rest, err := scanMultilineBasicString(b) if err != nil { - return nil, nil, err + return nil, nil, nil, err } i := 3 @@ -529,7 +545,7 @@ func (p *parser) parseMultilineBasicString(b []byte) ([]byte, []byte, error) { } } if i == endIdx { - return token[startIdx:endIdx], rest, nil + return token, token[startIdx:endIdx], rest, nil } var builder bytes.Buffer @@ -579,7 +595,7 @@ func (p *parser) parseMultilineBasicString(b []byte) ([]byte, []byte, error) { case 'u': x, err := hexToString(atmost(token[i+1:], 4), 4) if err != nil { - return nil, nil, err + return nil, nil, nil, err } builder.WriteString(x) @@ -587,20 +603,20 @@ func (p *parser) parseMultilineBasicString(b []byte) ([]byte, []byte, error) { case 'U': x, err := hexToString(atmost(token[i+1:], 8), 8) if err != nil { - return nil, nil, err + return nil, nil, nil, err } builder.WriteString(x) i += 8 default: - return nil, nil, newDecodeError(token[i:i+1], "invalid escaped character %#U", c) + return nil, nil, nil, newDecodeError(token[i:i+1], "invalid escaped character %#U", c) } } else { builder.WriteByte(c) } } - return builder.Bytes(), rest, nil + return token, builder.Bytes(), rest, nil } func (p *parser) parseKey(b []byte) (ast.Reference, []byte, error) { @@ -612,13 +628,14 @@ func (p *parser) parseKey(b []byte) (ast.Reference, []byte, error) { // dotted-key = simple-key 1*( dot-sep simple-key ) // // dot-sep = ws %x2E ws ; . Period - key, b, err := p.parseSimpleKey(b) + raw, key, b, err := p.parseSimpleKey(b) if err != nil { - return ast.Reference{}, nil, err + return ast.InvalidReference, nil, err } ref := p.builder.Push(ast.Node{ Kind: ast.Key, + Raw: p.Range(raw), Data: key, }) @@ -627,13 +644,14 @@ func (p *parser) parseKey(b []byte) (ast.Reference, []byte, error) { if len(b) > 0 && b[0] == '.' { b = p.parseWhitespace(b[1:]) - key, b, err = p.parseSimpleKey(b) + raw, key, b, err = p.parseSimpleKey(b) if err != nil { return ref, nil, err } p.builder.PushAndChain(ast.Node{ Kind: ast.Key, + Raw: p.Range(raw), Data: key, }) } else { @@ -644,12 +662,12 @@ func (p *parser) parseKey(b []byte) (ast.Reference, []byte, error) { return ref, b, nil } -func (p *parser) parseSimpleKey(b []byte) (key, rest []byte, err error) { +func (p *parser) parseSimpleKey(b []byte) (raw, key, rest []byte, err error) { // simple-key = quoted-key / unquoted-key // unquoted-key = 1*( ALPHA / DIGIT / %x2D / %x5F ) ; A-Z / a-z / 0-9 / - / _ // quoted-key = basic-string / literal-string if len(b) == 0 { - return nil, nil, newDecodeError(b, "key is incomplete") + return nil, nil, nil, newDecodeError(b, "key is incomplete") } switch { @@ -659,14 +677,14 @@ func (p *parser) parseSimpleKey(b []byte) (key, rest []byte, err error) { return p.parseBasicString(b) case isUnquotedKeyChar(b[0]): key, rest = scanUnquotedKey(b) - return key, rest, nil + return key, key, rest, nil default: - return nil, nil, newDecodeError(b[0:1], "invalid character at start of key: %c", b[0]) + return nil, nil, nil, newDecodeError(b[0:1], "invalid character at start of key: %c", b[0]) } } //nolint:funlen,cyclop -func (p *parser) parseBasicString(b []byte) ([]byte, []byte, error) { +func (p *parser) parseBasicString(b []byte) ([]byte, []byte, []byte, error) { // basic-string = quotation-mark *basic-char quotation-mark // quotation-mark = %x22 ; " // basic-char = basic-unescaped / escaped @@ -683,7 +701,7 @@ func (p *parser) parseBasicString(b []byte) ([]byte, []byte, error) { // escape-seq-char =/ %x55 8HEXDIG ; UXXXXXXXX U+XXXXXXXX token, rest, err := scanBasicString(b) if err != nil { - return nil, nil, err + return nil, nil, nil, err } // fast path @@ -696,7 +714,7 @@ func (p *parser) parseBasicString(b []byte) ([]byte, []byte, error) { } } if i == endIdx { - return token[startIdx:endIdx], rest, nil + return token, token[startIdx:endIdx], rest, nil } var builder bytes.Buffer @@ -726,7 +744,7 @@ func (p *parser) parseBasicString(b []byte) ([]byte, []byte, error) { case 'u': x, err := hexToString(token[i+1:len(token)-1], 4) if err != nil { - return nil, nil, err + return nil, nil, nil, err } builder.WriteString(x) @@ -734,20 +752,20 @@ func (p *parser) parseBasicString(b []byte) ([]byte, []byte, error) { case 'U': x, err := hexToString(token[i+1:len(token)-1], 8) if err != nil { - return nil, nil, err + return nil, nil, nil, err } builder.WriteString(x) i += 8 default: - return nil, nil, newDecodeError(token[i:i+1], "invalid escaped character %#U", c) + return nil, nil, nil, newDecodeError(token[i:i+1], "invalid escaped character %#U", c) } } else { builder.WriteByte(c) } } - return builder.Bytes(), rest, nil + return token, builder.Bytes(), rest, nil } func hexToString(b []byte, length int) (string, error) { @@ -780,7 +798,7 @@ func (p *parser) parseIntOrFloatOrDateTime(b []byte) (ast.Reference, []byte, err switch b[0] { case 'i': if !scanFollowsInf(b) { - return ast.Reference{}, nil, newDecodeError(atmost(b, 3), "expected 'inf'") + return ast.InvalidReference, nil, newDecodeError(atmost(b, 3), "expected 'inf'") } return p.builder.Push(ast.Node{ @@ -789,7 +807,7 @@ func (p *parser) parseIntOrFloatOrDateTime(b []byte) (ast.Reference, []byte, err }), b[3:], nil case 'n': if !scanFollowsNan(b) { - return ast.Reference{}, nil, newDecodeError(atmost(b, 3), "expected 'nan'") + return ast.InvalidReference, nil, newDecodeError(atmost(b, 3), "expected 'nan'") } return p.builder.Push(ast.Node{ @@ -945,7 +963,7 @@ func (p *parser) scanIntOrFloat(b []byte) (ast.Reference, []byte, error) { }), b[i+3:], nil } - return ast.Reference{}, nil, newDecodeError(b[i:i+1], "unexpected character 'i' while scanning for a number") + return ast.InvalidReference, nil, newDecodeError(b[i:i+1], "unexpected character 'i' while scanning for a number") } if c == 'n' { @@ -956,14 +974,14 @@ func (p *parser) scanIntOrFloat(b []byte) (ast.Reference, []byte, error) { }), b[i+3:], nil } - return ast.Reference{}, nil, newDecodeError(b[i:i+1], "unexpected character 'n' while scanning for a number") + return ast.InvalidReference, nil, newDecodeError(b[i:i+1], "unexpected character 'n' while scanning for a number") } break } if i == 0 { - return ast.Reference{}, b, newDecodeError(b, "incomplete number") + return ast.InvalidReference, b, newDecodeError(b, "incomplete number") } kind := ast.Integer diff --git a/parser_test.go b/parser_test.go index fdb4f27..9fda429 100644 --- a/parser_test.go +++ b/parser_test.go @@ -9,7 +9,6 @@ import ( //nolint:funlen func TestParser_AST_Numbers(t *testing.T) { - examples := []struct { desc string input string @@ -136,7 +135,6 @@ func TestParser_AST_Numbers(t *testing.T) { for _, e := range examples { e := e t.Run(e.desc, func(t *testing.T) { - p := parser{} p.Reset([]byte(`A = ` + e.input)) p.NextExpression() @@ -167,7 +165,7 @@ type ( } ) -func compareNode(t *testing.T, e astNode, n ast.Node) { +func compareNode(t *testing.T, e astNode, n *ast.Node) { t.Helper() require.Equal(t, e.Kind, n.Kind) require.Equal(t, e.Data, n.Data) @@ -199,7 +197,6 @@ func compareIterator(t *testing.T, expected []astNode, actual ast.Iterator) { //nolint:funlen func TestParser_AST(t *testing.T) { - examples := []struct { desc string input string @@ -338,7 +335,6 @@ func TestParser_AST(t *testing.T) { for _, e := range examples { e := e t.Run(e.desc, func(t *testing.T) { - p := parser{} p.Reset([]byte(e.input)) p.NextExpression() diff --git a/strict.go b/strict.go index ca482c4..b7830d1 100644 --- a/strict.go +++ b/strict.go @@ -2,8 +2,8 @@ package toml import ( "github.com/pelletier/go-toml/v2/internal/ast" + "github.com/pelletier/go-toml/v2/internal/danger" "github.com/pelletier/go-toml/v2/internal/tracker" - "github.com/pelletier/go-toml/v2/internal/unsafe" ) type strict struct { @@ -15,7 +15,7 @@ type strict struct { missing []decodeError } -func (s *strict) EnterTable(node ast.Node) { +func (s *strict) EnterTable(node *ast.Node) { if !s.Enabled { return } @@ -23,7 +23,7 @@ func (s *strict) EnterTable(node ast.Node) { s.key.UpdateTable(node) } -func (s *strict) EnterArrayTable(node ast.Node) { +func (s *strict) EnterArrayTable(node *ast.Node) { if !s.Enabled { return } @@ -31,7 +31,7 @@ func (s *strict) EnterArrayTable(node ast.Node) { s.key.UpdateArrayTable(node) } -func (s *strict) EnterKeyValue(node ast.Node) { +func (s *strict) EnterKeyValue(node *ast.Node) { if !s.Enabled { return } @@ -39,7 +39,7 @@ func (s *strict) EnterKeyValue(node ast.Node) { s.key.Push(node) } -func (s *strict) ExitKeyValue(node ast.Node) { +func (s *strict) ExitKeyValue(node *ast.Node) { if !s.Enabled { return } @@ -47,7 +47,7 @@ func (s *strict) ExitKeyValue(node ast.Node) { s.key.Pop(node) } -func (s *strict) MissingTable(node ast.Node) { +func (s *strict) MissingTable(node *ast.Node) { if !s.Enabled { return } @@ -59,7 +59,7 @@ func (s *strict) MissingTable(node ast.Node) { }) } -func (s *strict) MissingField(node ast.Node) { +func (s *strict) MissingField(node *ast.Node) { if !s.Enabled { return } @@ -88,7 +88,7 @@ func (s *strict) Error(doc []byte) error { return err } -func keyLocation(node ast.Node) []byte { +func keyLocation(node *ast.Node) []byte { k := node.Key() hasOne := k.Next() @@ -103,5 +103,5 @@ func keyLocation(node ast.Node) []byte { end = k.Node().Data } - return unsafe.BytesRange(start, end) + return danger.BytesRange(start, end) } diff --git a/unmarshaler.go b/unmarshaler.go index a328045..b514413 100644 --- a/unmarshaler.go +++ b/unmarshaler.go @@ -106,7 +106,6 @@ func (d *Decoder) Decode(v interface{}) error { type decoder struct { // Which parser instance in use for this decoding session. - // TODO: Think about removing later. p *parser // Flag indicating that the current expression is stashed. @@ -132,7 +131,7 @@ type decoder struct { strict strict } -func (d *decoder) expr() ast.Node { +func (d *decoder) expr() *ast.Node { return d.p.Expression() } @@ -208,7 +207,7 @@ Rules for the unmarshal code: - An "object" is either a struct or a map. */ -func (d *decoder) handleRootExpression(expr ast.Node, v reflect.Value) error { +func (d *decoder) handleRootExpression(expr *ast.Node, v reflect.Value) error { var x reflect.Value var err error @@ -533,7 +532,7 @@ func (d *decoder) handleTablePart(key ast.Iterator, v reflect.Value) (reflect.Va return d.handleKeyPart(key, v, d.handleTable, makeMapStringInterface) } -func tryTextUnmarshaler(node ast.Node, v reflect.Value) (bool, error) { +func (d *decoder) tryTextUnmarshaler(node *ast.Node, v reflect.Value) (bool, error) { if v.Kind() != reflect.Struct { return false, nil } @@ -547,9 +546,7 @@ func tryTextUnmarshaler(node ast.Node, v reflect.Value) (bool, error) { if v.CanAddr() && v.Addr().Type().Implements(textUnmarshalerType) { err := v.Addr().Interface().(encoding.TextUnmarshaler).UnmarshalText(node.Data) if err != nil { - return false, fmt.Errorf("toml: error calling UnmarshalText: %w", err) - // TODO: same as above - // return false, newDecodeError(node.Data, "error calling UnmarshalText: %w", err) + return false, newDecodeError(d.p.Raw(node.Raw), "error calling UnmarshalText: %w", err) } return true, nil @@ -558,12 +555,12 @@ func tryTextUnmarshaler(node ast.Node, v reflect.Value) (bool, error) { return false, nil } -func (d *decoder) handleValue(value ast.Node, v reflect.Value) error { +func (d *decoder) handleValue(value *ast.Node, v reflect.Value) error { for v.Kind() == reflect.Ptr { v = initAndDereferencePointer(v) } - ok, err := tryTextUnmarshaler(value, v) + ok, err := d.tryTextUnmarshaler(value, v) if ok || err != nil { return err } @@ -592,7 +589,7 @@ func (d *decoder) handleValue(value ast.Node, v reflect.Value) error { } } -func (d *decoder) unmarshalArray(array ast.Node, v reflect.Value) error { +func (d *decoder) unmarshalArray(array *ast.Node, v reflect.Value) error { switch v.Kind() { case reflect.Slice: if v.IsNil() { @@ -663,7 +660,7 @@ func (d *decoder) unmarshalArray(array ast.Node, v reflect.Value) error { return nil } -func (d *decoder) unmarshalInlineTable(itable ast.Node, v reflect.Value) error { +func (d *decoder) unmarshalInlineTable(itable *ast.Node, v reflect.Value) error { // Make sure v is an initialized object. switch v.Kind() { case reflect.Map: @@ -699,7 +696,7 @@ func (d *decoder) unmarshalInlineTable(itable ast.Node, v reflect.Value) error { return nil } -func (d *decoder) unmarshalDateTime(value ast.Node, v reflect.Value) error { +func (d *decoder) unmarshalDateTime(value *ast.Node, v reflect.Value) error { dt, err := parseDateTime(value.Data) if err != nil { return err @@ -709,7 +706,7 @@ func (d *decoder) unmarshalDateTime(value ast.Node, v reflect.Value) error { return nil } -func (d *decoder) unmarshalLocalDate(value ast.Node, v reflect.Value) error { +func (d *decoder) unmarshalLocalDate(value *ast.Node, v reflect.Value) error { ld, err := parseLocalDate(value.Data) if err != nil { return err @@ -727,7 +724,7 @@ func (d *decoder) unmarshalLocalDate(value ast.Node, v reflect.Value) error { return nil } -func (d *decoder) unmarshalLocalDateTime(value ast.Node, v reflect.Value) error { +func (d *decoder) unmarshalLocalDateTime(value *ast.Node, v reflect.Value) error { ldt, rest, err := parseLocalDateTime(value.Data) if err != nil { return err @@ -749,7 +746,7 @@ func (d *decoder) unmarshalLocalDateTime(value ast.Node, v reflect.Value) error return nil } -func (d *decoder) unmarshalBool(value ast.Node, v reflect.Value) error { +func (d *decoder) unmarshalBool(value *ast.Node, v reflect.Value) error { b := value.Data[0] == 't' switch v.Kind() { @@ -764,7 +761,7 @@ func (d *decoder) unmarshalBool(value ast.Node, v reflect.Value) error { return nil } -func (d *decoder) unmarshalFloat(value ast.Node, v reflect.Value) error { +func (d *decoder) unmarshalFloat(value *ast.Node, v reflect.Value) error { f, err := parseFloat(value.Data) if err != nil { return err @@ -787,7 +784,7 @@ func (d *decoder) unmarshalFloat(value ast.Node, v reflect.Value) error { return nil } -func (d *decoder) unmarshalInteger(value ast.Node, v reflect.Value) error { +func (d *decoder) unmarshalInteger(value *ast.Node, v reflect.Value) error { const ( maxInt = int64(^uint(0) >> 1) minInt = -maxInt - 1 @@ -865,7 +862,7 @@ func (d *decoder) unmarshalInteger(value ast.Node, v reflect.Value) error { return err } -func (d *decoder) unmarshalString(value ast.Node, v reflect.Value) error { +func (d *decoder) unmarshalString(value *ast.Node, v reflect.Value) error { var err error switch v.Kind() { @@ -874,13 +871,13 @@ func (d *decoder) unmarshalString(value ast.Node, v reflect.Value) error { case reflect.Interface: v.Set(reflect.ValueOf(string(value.Data))) default: - err = fmt.Errorf("toml: cannot store TOML string into a Go %s", v.Kind()) + err = newDecodeError(d.p.Raw(value.Raw), "cannot store TOML string into a Go %s", v.Kind()) } return err } -func (d *decoder) handleKeyValue(expr ast.Node, v reflect.Value) (reflect.Value, error) { +func (d *decoder) handleKeyValue(expr *ast.Node, v reflect.Value) (reflect.Value, error) { d.strict.EnterKeyValue(expr) v, err := d.handleKeyValueInner(expr.Key(), expr.Value(), v) @@ -894,7 +891,7 @@ func (d *decoder) handleKeyValue(expr ast.Node, v reflect.Value) (reflect.Value, return v, err } -func (d *decoder) handleKeyValueInner(key ast.Iterator, value ast.Node, v reflect.Value) (reflect.Value, error) { +func (d *decoder) handleKeyValueInner(key ast.Iterator, value *ast.Node, v reflect.Value) (reflect.Value, error) { if key.Next() { // Still scoping the key return d.handleKeyValuePart(key, value, v) @@ -904,7 +901,7 @@ func (d *decoder) handleKeyValueInner(key ast.Iterator, value ast.Node, v reflec return reflect.Value{}, d.handleValue(value, v) } -func (d *decoder) handleKeyValuePart(key ast.Iterator, value ast.Node, v reflect.Value) (reflect.Value, error) { +func (d *decoder) handleKeyValuePart(key ast.Iterator, value *ast.Node, v reflect.Value) (reflect.Value, error) { // contains the replacement for v var rv reflect.Value diff --git a/unmarshaler_test.go b/unmarshaler_test.go index e14c9be..35f4ad3 100644 --- a/unmarshaler_test.go +++ b/unmarshaler_test.go @@ -287,6 +287,54 @@ func TestUnmarshal(t *testing.T) { } }, }, + { + desc: "local datetime into time.Time", + input: `a = 1979-05-27T00:32:00`, + gen: func() test { + type doc struct { + A time.Time + } + + return test{ + target: &doc{}, + expected: &doc{ + A: time.Date(1979, 5, 27, 0, 32, 0, 0, time.Local), + }, + } + }, + }, + { + desc: "local datetime into interface", + input: `a = 1979-05-27T00:32:00`, + gen: func() test { + type doc struct { + A interface{} + } + + return test{ + target: &doc{}, + expected: &doc{ + A: toml.LocalDateTimeOf(time.Date(1979, 5, 27, 0, 32, 0, 0, time.Local)), + }, + } + }, + }, + { + desc: "local date into interface", + input: `a = 1979-05-27`, + gen: func() test { + type doc struct { + A interface{} + } + + return test{ + target: &doc{}, + expected: &doc{ + A: toml.LocalDateOf(time.Date(1979, 5, 27, 0, 32, 0, 0, time.Local)), + }, + } + }, + }, { desc: "issue 475 - space between dots in key", input: `fruit. color = "yellow"