diff --git a/benchmark/benchmark_test.go b/benchmark/benchmark_test.go index b495140..1337895 100644 --- a/benchmark/benchmark_test.go +++ b/benchmark/benchmark_test.go @@ -10,16 +10,38 @@ import ( ) func BenchmarkUnmarshalSimple(b *testing.B) { - d := struct { - A string - }{} doc := []byte(`A = "hello"`) - for i := 0; i < b.N; i++ { - err := toml.Unmarshal(doc, &d) - if err != nil { - panic(err) + + b.Run("struct", func(b *testing.B) { + b.SetBytes(int64(len(doc))) + b.ReportAllocs() + b.ResetTimer() + + for i := 0; i < b.N; i++ { + d := struct { + A string + }{} + + err := toml.Unmarshal(doc, &d) + if err != nil { + panic(err) + } } - } + }) + + b.Run("map", func(b *testing.B) { + b.SetBytes(int64(len(doc))) + b.ReportAllocs() + b.ResetTimer() + + for i := 0; i < b.N; i++ { + d := map[string]interface{}{} + err := toml.Unmarshal(doc, &d) + if err != nil { + panic(err) + } + } + }) } type benchmarkDoc struct { @@ -133,33 +155,32 @@ func BenchmarkReferenceFile(b *testing.B) { if err != nil { b.Fatal(err) } - b.SetBytes(int64(len(bytes))) - b.ReportAllocs() - b.ResetTimer() - for i := 0; i < b.N; i++ { - d := benchmarkDoc{} - err := toml.Unmarshal(bytes, &d) - if err != nil { - panic(err) - } - } -} -func BenchmarkReferenceFileMap(b *testing.B) { - bytes, err := ioutil.ReadFile("benchmark.toml") - if err != nil { - b.Fatal(err) - } - b.SetBytes(int64(len(bytes))) - b.ReportAllocs() - b.ResetTimer() - for i := 0; i < b.N; i++ { - m := map[string]interface{}{} - err := toml.Unmarshal(bytes, &m) - if err != nil { - panic(err) + b.Run("struct", func(b *testing.B) { + b.SetBytes(int64(len(bytes))) + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + d := benchmarkDoc{} + err := toml.Unmarshal(bytes, &d) + if err != nil { + panic(err) + } } - } + }) + + b.Run("map", func(b *testing.B) { + b.SetBytes(int64(len(bytes))) + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + d := map[string]interface{}{} + err := toml.Unmarshal(bytes, &d) + if err != nil { + panic(err) + } + } + }) } func TestReferenceFile(t *testing.T) { @@ -169,3 +190,38 @@ func TestReferenceFile(t *testing.T) { err = toml.Unmarshal(bytes, &d) require.NoError(t, err) } + +func BenchmarkHugoFrontMatter(b *testing.B) { + bytes := []byte(` +categories = ["Development", "VIM"] +date = "2012-04-06" +description = "spf13-vim is a cross platform distribution of vim plugins and resources for Vim." +slug = "spf13-vim-3-0-release-and-new-website" +tags = [".vimrc", "plugins", "spf13-vim", "vim"] +title = "spf13-vim 3.0 release and new website" +include_toc = true +show_comments = false + +[[cascade]] + background = "yosemite.jpg" + [cascade._target] + kind = "page" + lang = "en" + path = "/blog/**" + +[[cascade]] + background = "goldenbridge.jpg" + [cascade._target] + kind = "section" +`) + b.SetBytes(int64(len(bytes))) + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + d := map[string]interface{}{} + err := toml.Unmarshal(bytes, &d) + if err != nil { + panic(err) + } + } +} diff --git a/internal/tracker/seen.go b/internal/tracker/seen.go index e245aac..4b5d392 100644 --- a/internal/tracker/seen.go +++ b/internal/tracker/seen.go @@ -1,6 +1,7 @@ package tracker import ( + "bytes" "fmt" "github.com/pelletier/go-toml/v2/internal/ast" @@ -29,67 +30,92 @@ func (k keyKind) String() string { panic("missing keyKind string mapping") } -// SeenTracker tracks which keys have been seen with which TOML type to flag duplicates -// and mismatches according to the spec. +// SeenTracker tracks which keys have been seen with which TOML type to flag +// duplicates and mismatches according to the spec. +// +// Each node in the visited tree is represented by an entry. Each entry has an +// identifier, which is provided by a counter. Entries are stored in the array +// entries. As new nodes are discovered (referenced for the first time in the +// TOML document), entries are created and appended to the array. An entry +// points to its parent using its id. +// +// To find whether a given key (sequence of []byte) has already been visited, +// the entries are linearly searched, looking for one with the right name and +// parent id. +// +// Given that all keys appear in the document after their parent, it is +// guaranteed that all descendants of a node are stored after the node, this +// speeds up the search process. +// +// When encountering [[array tables]], the descendants of that node are removed +// to allow that branch of the tree to be "rediscovered". To maintain the +// invariant above, the deletion process needs to keep the order of entries. +// This results in more copies in that case. type SeenTracker struct { - root *info - current *info + entries []entry + currentIdx int + nextID int } -type info struct { - parent *info +type entry struct { + id int + parent int + name []byte kind keyKind - children map[string]*info explicit bool } -func (i *info) Clear() { - i.children = nil +// Remove all descendent of node at position idx. +func (s *SeenTracker) clear(idx int) { + p := s.entries[idx].id + rest := clear(p, s.entries[idx+1:]) + s.entries = s.entries[:idx+1+len(rest)] } -func (i *info) Has(k string) (*info, bool) { - c, ok := i.children[k] - return c, ok -} - -func (i *info) SetKind(kind keyKind) { - i.kind = kind -} - -func (i *info) CreateTable(k string, explicit bool) *info { - return i.createChild(k, tableKind, explicit) -} - -func (i *info) CreateArrayTable(k string, explicit bool) *info { - return i.createChild(k, arrayTableKind, explicit) -} - -func (i *info) createChild(k string, kind keyKind, explicit bool) *info { - if i.children == nil { - i.children = make(map[string]*info, 1) +func clear(parentID int, entries []entry) []entry { + for i := 0; i < len(entries); { + if entries[i].parent == parentID { + id := entries[i].id + copy(entries[i:], entries[i+1:]) + entries = entries[:len(entries)-1] + rest := clear(id, entries[i:]) + entries = entries[:i+len(rest)] + } else { + i++ + } } + return entries +} - x := &info{ - parent: i, +func (s *SeenTracker) create(parentIdx int, name []byte, kind keyKind, explicit bool) int { + parentID := s.id(parentIdx) + + idx := len(s.entries) + s.entries = append(s.entries, entry{ + id: s.nextID, + parent: parentID, + name: name, kind: kind, explicit: explicit, - } - i.children[k] = x - return x + }) + s.nextID++ + return idx } // CheckExpression takes a top-level node and checks that it does not contain keys // that have been seen in previous calls, and validates that types are consistent. func (s *SeenTracker) CheckExpression(node ast.Node) error { - if s.root == nil { - s.root = &info{ - kind: tableKind, - } - s.current = s.root + if s.entries == nil { + //s.entries = make([]entry, 0, 8) + // Skip ID = 0 to remove the confusion between nodes whose parent has + // id 0 and root nodes (parent id is 0 because it's the zero value). + s.nextID = 1 + // Start unscoped, so idx is negative. + s.currentIdx = -1 } switch node.Kind { case ast.KeyValue: - return s.checkKeyValue(s.current, node) + return s.checkKeyValue(node) case ast.Table: return s.checkTable(node) case ast.ArrayTable: @@ -97,104 +123,135 @@ func (s *SeenTracker) CheckExpression(node ast.Node) error { default: panic(fmt.Errorf("this should not be a top level node type: %s", node.Kind)) } - } -func (s *SeenTracker) checkTable(node ast.Node) error { - s.current = s.root +func (s *SeenTracker) checkTable(node ast.Node) error { it := node.Key() - // handle the first parts of the key, excluding the last one + + parentIdx := -1 + + // This code is duplicated in checkArrayTable. This is because factoring + // it in a function requires to copy the iterator, or allocate it to the + // heap, which is not cheap. for it.Next() { if !it.Node().Next().Valid() { break } - k := string(it.Node().Data) - child, found := s.current.Has(k) - if !found { - child = s.current.CreateTable(k, false) + k := it.Node().Data + + idx := s.find(parentIdx, k) + + if idx < 0 { + idx = s.create(parentIdx, k, tableKind, false) } - s.current = child + parentIdx = idx } - // handle the last part of the key - k := string(it.Node().Data) + k := it.Node().Data + idx := s.find(parentIdx, k) - i, found := s.current.Has(k) - if found { - if i.kind != tableKind { - return fmt.Errorf("toml: key %s should be a table, not a %s", k, i.kind) + if idx >= 0 { + kind := s.entries[idx].kind + if kind != tableKind { + return fmt.Errorf("toml: key %s should be a table, not a %s", string(k), kind) } - if i.explicit { - return fmt.Errorf("toml: table %s already exists", k) + if s.entries[idx].explicit { + return fmt.Errorf("toml: table %s already exists", string(k)) } - i.explicit = true - s.current = i + s.entries[idx].explicit = true } else { - s.current = s.current.CreateTable(k, true) + idx = s.create(parentIdx, k, tableKind, true) } + s.currentIdx = idx + return nil } func (s *SeenTracker) checkArrayTable(node ast.Node) error { - s.current = s.root - it := node.Key() - // handle the first parts of the key, excluding the last one + parentIdx := -1 + for it.Next() { if !it.Node().Next().Valid() { break } - k := string(it.Node().Data) - child, found := s.current.Has(k) - if !found { - child = s.current.CreateTable(k, false) + k := it.Node().Data + + idx := s.find(parentIdx, k) + + if idx < 0 { + idx = s.create(parentIdx, k, tableKind, false) } - s.current = child + parentIdx = idx } - // handle the last part of the key - k := string(it.Node().Data) + k := it.Node().Data + idx := s.find(parentIdx, k) - info, found := s.current.Has(k) - if found { - if info.kind != arrayTableKind { - return fmt.Errorf("toml: key %s already exists as a %s, but should be an array table", info.kind, k) + if idx >= 0 { + kind := s.entries[idx].kind + if kind != arrayTableKind { + return fmt.Errorf("toml: key %s already exists as a %s, but should be an array table", kind, string(k)) } - info.Clear() + s.clear(idx) } else { - info = s.current.CreateArrayTable(k, true) + idx = s.create(parentIdx, k, arrayTableKind, true) } - s.current = info + s.currentIdx = idx + return nil } -func (s *SeenTracker) checkKeyValue(context *info, node ast.Node) error { +func (s *SeenTracker) checkKeyValue(node ast.Node) error { it := node.Key() - // handle the first parts of the key, excluding the last one + parentIdx := s.currentIdx + for it.Next() { - k := string(it.Node().Data) - child, found := context.Has(k) - if found { - if child.kind != tableKind { - return fmt.Errorf("toml: expected %s to be a table, not a %s", k, child.kind) + k := it.Node().Data + + idx := s.find(parentIdx, k) + + if idx >= 0 { + if s.entries[idx].kind != tableKind { + return fmt.Errorf("toml: expected %s to be a table, not a %s", string(k), s.entries[idx].kind) } } else { - child = context.CreateTable(k, false) + idx = s.create(parentIdx, k, tableKind, false) } - context = child + parentIdx = idx } + kind := valueKind + if node.Value().Kind == ast.InlineTable { - context.SetKind(tableKind) - } else { - context.SetKind(valueKind) + kind = tableKind } + s.entries[parentIdx].kind = kind return nil } + +func (s *SeenTracker) id(idx int) int { + if idx >= 0 { + return s.entries[idx].id + } + return 0 +} + +func (s *SeenTracker) find(parentIdx int, k []byte) int { + parentID := s.id(parentIdx) + + for i := parentIdx + 1; i < len(s.entries); i++ { + if s.entries[i].parent == parentID && bytes.Equal(s.entries[i].name, k) { + return i + } + } + + return -1 +}