From 8683be35f6b45b4c53f56573f6eca9936f714df2 Mon Sep 17 00:00:00 2001 From: Thomas Pelletier Date: Mon, 8 Nov 2021 21:53:02 -0500 Subject: [PATCH] seen: check inline tables (#660) Fixes #658 --- internal/tracker/seen.go | 33 ++++++++++++++++++++++++--------- unmarshaler_test.go | 6 ++++++ 2 files changed, 30 insertions(+), 9 deletions(-) diff --git a/internal/tracker/seen.go b/internal/tracker/seen.go index 8990dae..8adc4dc 100644 --- a/internal/tracker/seen.go +++ b/internal/tracker/seen.go @@ -65,7 +65,7 @@ type entry struct { explicit bool } -// Remove all descendent of node at position idx. +// Remove all descendants of node at position idx. func (s *SeenTracker) clear(idx int) { p := s.entries[idx].id rest := clear(p, s.entries[idx+1:]) @@ -102,19 +102,21 @@ func (s *SeenTracker) create(parentIdx int, name []byte, kind keyKind, explicit return idx } -// CheckExpression takes a top-level node and checks that it does not contain keys -// that have been seen in previous calls, and validates that types are consistent. +// CheckExpression takes a top-level node and checks that it does not contain +// keys that have been seen in previous calls, and validates that types are +// consistent. func (s *SeenTracker) CheckExpression(node *ast.Node) error { if s.entries == nil { - // Skip ID = 0 to remove the confusion between nodes whose parent has - // id 0 and root nodes (parent id is 0 because it's the zero value). + // Skip ID = 0 to remove the confusion between nodes whose + // parent has id 0 and root nodes (parent id is 0 because it's + // the zero value). s.nextID = 1 // Start unscoped, so idx is negative. s.currentIdx = -1 } switch node.Kind { case ast.KeyValue: - return s.checkKeyValue(node) + return s.checkKeyValue(s.currentIdx, node) case ast.Table: return s.checkTable(node) case ast.ArrayTable: @@ -206,11 +208,9 @@ func (s *SeenTracker) checkArrayTable(node *ast.Node) error { return nil } -func (s *SeenTracker) checkKeyValue(node *ast.Node) error { +func (s *SeenTracker) checkKeyValue(parentIdx int, node *ast.Node) error { it := node.Key() - parentIdx := s.currentIdx - for it.Next() { k := it.Node().Data @@ -230,12 +230,27 @@ func (s *SeenTracker) checkKeyValue(node *ast.Node) error { } kind := valueKind + var err error if node.Value().Kind == ast.InlineTable { kind = tableKind + err = s.checkInlineTable(parentIdx, node.Value()) } + s.entries[parentIdx].kind = kind + return err +} + +func (s *SeenTracker) checkInlineTable(parentIdx int, node *ast.Node) error { + it := node.Children() + for it.Next() { + n := it.Node() + err := s.checkKeyValue(parentIdx, n) + if err != nil { + return err + } + } return nil } diff --git a/unmarshaler_test.go b/unmarshaler_test.go index ee23808..26ea732 100644 --- a/unmarshaler_test.go +++ b/unmarshaler_test.go @@ -1910,6 +1910,12 @@ func TestIssue564(t *testing.T) { require.Equal(t, uuid{0xAA, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA}, config.ID) } +func TestIssue658(t *testing.T) { + var v map[string]interface{} + err := toml.Unmarshal([]byte("e={b=1,b=4}"), &v) + require.Error(t, err) +} + //nolint:funlen func TestUnmarshalDecodeErrors(t *testing.T) { examples := []struct {