From 250e073408fa9921f651b466e05d6e45f2ad7b80 Mon Sep 17 00:00:00 2001 From: Thomas Pelletier Date: Mon, 31 May 2021 12:14:13 -0400 Subject: [PATCH] Stack-based unmarshaler (#546) * Benchmark script * Rewrite unmarshaler using the stack Instead of tracking the build chain using `target`s, use the stack instead. Working and most benchmarks look good, but regression on structs unmarshalling. ~60% slower on ReferenceFile/struct. * Shortcut to check if last node of iterator * Remove unecessary pointer allocation * Skip over unused keys without marking them as seen * Add some tests * Fix mktemp on macos --- .golangci.toml | 2 +- benchmark/bench_datasets_test.go | 33 +- benchmark/benchmark_test.go | 12 + ci.sh | 37 +- errors_test.go | 4 +- fast_test.go | 100 ++ internal/ast/ast.go | 6 + .../imported_tests/unmarshal_imported_test.go | 78 +- internal/tracker/seen.go | 6 +- localtime_test.go | 18 - marshaler.go | 2 - marshaler_test.go | 22 - parser_test.go | 6 +- strict.go | 19 + targets.go | 536 -------- targets_test.go | 207 --- toml_testgen_test.go | 74 -- types.go | 13 + unmarshaler.go | 1165 ++++++++++++----- unmarshaler_test.go | 296 ++++- 20 files changed, 1340 insertions(+), 1296 deletions(-) create mode 100644 fast_test.go delete mode 100644 targets.go delete mode 100644 targets_test.go create mode 100644 types.go diff --git a/.golangci.toml b/.golangci.toml index fdf167b..067db55 100644 --- a/.golangci.toml +++ b/.golangci.toml @@ -60,7 +60,7 @@ enable = [ # "nlreturn", "noctx", "nolintlint", - "paralleltest", + #"paralleltest", "prealloc", "predeclared", "revive", diff --git a/benchmark/bench_datasets_test.go b/benchmark/bench_datasets_test.go index 1d668d9..ca974fd 100644 --- a/benchmark/bench_datasets_test.go +++ b/benchmark/bench_datasets_test.go @@ -31,13 +31,14 @@ var bench_inputs = []struct { func TestUnmarshalDatasetCode(t *testing.T) { for _, tc := range bench_inputs { - buf := fixture(t, tc.name) t.Run(tc.name, func(t *testing.T) { + buf := fixture(t, tc.name) + var v interface{} - check(t, toml.Unmarshal(buf, &v)) + require.NoError(t, toml.Unmarshal(buf, &v)) b, err := json.Marshal(v) - check(t, err) + require.NoError(t, err) require.Equal(t, len(b), tc.jsonLen) }) } @@ -45,14 +46,14 @@ func TestUnmarshalDatasetCode(t *testing.T) { func BenchmarkUnmarshalDataset(b *testing.B) { for _, tc := range bench_inputs { - buf := fixture(b, tc.name) b.Run(tc.name, func(b *testing.B) { + buf := fixture(b, tc.name) b.SetBytes(int64(len(buf))) b.ReportAllocs() b.ResetTimer() for i := 0; i < b.N; i++ { var v interface{} - check(b, toml.Unmarshal(buf, &v)) + require.NoError(b, toml.Unmarshal(buf, &v)) } }) } @@ -60,22 +61,20 @@ func BenchmarkUnmarshalDataset(b *testing.B) { // fixture returns the uncompressed contents of path. func fixture(tb testing.TB, path string) []byte { - f, err := os.Open(filepath.Join("testdata", path+".toml.gz")) - check(tb, err) + tb.Helper() + + file := path + ".toml.gz" + f, err := os.Open(filepath.Join("testdata", file)) + if os.IsNotExist(err) { + tb.Skip("benchmark fixture not found:", file) + } + require.NoError(tb, err) defer f.Close() gz, err := gzip.NewReader(f) - check(tb, err) + require.NoError(tb, err) buf, err := ioutil.ReadAll(gz) - check(tb, err) - + require.NoError(tb, err) return buf } - -func check(tb testing.TB, err error) { - if err != nil { - tb.Helper() - tb.Fatal(err) - } -} diff --git a/benchmark/benchmark_test.go b/benchmark/benchmark_test.go index da56f06..6d90571 100644 --- a/benchmark/benchmark_test.go +++ b/benchmark/benchmark_test.go @@ -9,6 +9,18 @@ import ( "github.com/stretchr/testify/require" ) +func TestUnmarshalSimple(t *testing.T) { + doc := []byte(`A = "hello"`) + d := struct { + A string + }{} + + err := toml.Unmarshal(doc, &d) + if err != nil { + panic(err) + } +} + func BenchmarkUnmarshalSimple(b *testing.B) { doc := []byte(`A = "hello"`) diff --git a/ci.sh b/ci.sh index 75d7008..d7e10cd 100755 --- a/ci.sh +++ b/ci.sh @@ -39,6 +39,9 @@ benchmark [OPTIONS...] [BRANCH] -d Compare benchmarks of HEAD with BRANCH using benchstats. In this form the BRANCH argument is required. + -a Compare benchmarks of HEAD against go-toml v1 and + BurntSushi/toml. + coverage [OPTIONS...] [BRANCH] Generates code coverage. @@ -118,6 +121,7 @@ coverage() { bench() { branch="${1}" out="${2}" + replace="${3}" dir="$(mktemp -d)" stderr "Executing benchmark for ${branch} at ${dir}" @@ -129,6 +133,15 @@ bench() { fi pushd "$dir" + + if [ "${replace}" != "" ]; then + find ./benchmark/ -iname '*.go' -exec sed -i -E "s|github.com/pelletier/go-toml/v2|${replace}|g" {} \; + go get "${replace}" + # hack: remove canada.toml.gz because it is not supported by + # burntsushi, and replace is only used for benchmark -a + rm -f benchmark/testdata/canada.toml.gz + fi + go test -bench=. -count=10 ./... | tee "${out}" popd @@ -142,14 +155,34 @@ benchmark() { -d) shift target="${1?Need to provide a target branch argument}" - old=`mktemp` + + old=`mktemp --suffix=-${target}` bench "${target}" "${old}" - new=`mktemp` + new=`mktemp --suffix=-HEAD` bench HEAD "${new}" + benchstat "${old}" "${new}" return 0 ;; + -a) + shift + + v2stats=`mktemp -t go-toml-v2` + bench HEAD "${v2stats}" "github.com/pelletier/go-toml/v2" + v1stats=`mktemp -t go-toml-v1` + bench HEAD "${v1stats}" "github.com/pelletier/go-toml" + bsstats=`mktemp -t bs-toml` + bench HEAD "${bsstats}" "github.com/BurntSushi/toml" + + cp "${v2stats}" go-toml-v2.txt + cp "${v1stats}" go-toml-v1.txt + cp "${bsstats}" bs-toml.txt + + benchstat -geomean go-toml-v2.txt go-toml-v1.txt bs-toml.txt + + rm -f go-toml-v2.txt go-toml-v1.txt bs-toml.txt + return $? esac bench "${1-HEAD}" `mktemp` diff --git a/errors_test.go b/errors_test.go index d6af314..d098647 100644 --- a/errors_test.go +++ b/errors_test.go @@ -12,7 +12,6 @@ import ( //nolint:funlen func TestDecodeError(t *testing.T) { - t.Parallel() examples := []struct { desc string @@ -154,7 +153,7 @@ line 5`, for _, e := range examples { e := e t.Run(e.desc, func(t *testing.T) { - t.Parallel() + b := bytes.Buffer{} b.Write([]byte(e.doc[0])) start := b.Len() @@ -182,7 +181,6 @@ line 5`, } func TestDecodeError_Accessors(t *testing.T) { - t.Parallel() e := DecodeError{ message: "foo", diff --git a/fast_test.go b/fast_test.go new file mode 100644 index 0000000..02910bb --- /dev/null +++ b/fast_test.go @@ -0,0 +1,100 @@ +package toml_test + +import ( + "testing" + + "github.com/pelletier/go-toml/v2" + "github.com/stretchr/testify/require" +) + +func TestFastSimple(t *testing.T) { + m := map[string]int64{} + err := toml.Unmarshal([]byte(`a = 42`), &m) + require.NoError(t, err) + require.Equal(t, map[string]int64{"a": 42}, m) +} + +func TestFastSimpleString(t *testing.T) { + m := map[string]string{} + err := toml.Unmarshal([]byte(`a = "hello"`), &m) + require.NoError(t, err) + require.Equal(t, map[string]string{"a": "hello"}, m) +} + +func TestFastSimpleInterface(t *testing.T) { + m := map[string]interface{}{} + err := toml.Unmarshal([]byte(` + a = "hello" + b = 42`), &m) + require.NoError(t, err) + require.Equal(t, map[string]interface{}{ + "a": "hello", + "b": int64(42), + }, m) +} + +func TestFastMultipartKeyInterface(t *testing.T) { + m := map[string]interface{}{} + err := toml.Unmarshal([]byte(` + a.interim = "test" + a.b.c = "hello" + b = 42`), &m) + require.NoError(t, err) + require.Equal(t, map[string]interface{}{ + "a": map[string]interface{}{ + "interim": "test", + "b": map[string]interface{}{ + "c": "hello", + }, + }, + "b": int64(42), + }, m) +} + +func TestFastExistingMap(t *testing.T) { + m := map[string]interface{}{ + "ints": map[string]int{}, + } + err := toml.Unmarshal([]byte(` + ints.one = 1 + ints.two = 2 + strings.yo = "hello"`), &m) + require.NoError(t, err) + require.Equal(t, map[string]interface{}{ + "ints": map[string]interface{}{ + "one": int64(1), + "two": int64(2), + }, + "strings": map[string]interface{}{ + "yo": "hello", + }, + }, m) +} + +func TestFastArrayTable(t *testing.T) { + b := []byte(` + [root] + [[root.nested]] + name = 'Bob' + [[root.nested]] + name = 'Alice' + `) + + m := map[string]interface{}{} + + err := toml.Unmarshal(b, &m) + require.NoError(t, err) + + require.Equal(t, map[string]interface{}{ + "root": map[string]interface{}{ + "nested": []interface{}{ + map[string]interface{}{ + "name": "Bob", + }, + map[string]interface{}{ + "name": "Alice", + }, + }, + }, + }, m) +} diff --git a/internal/ast/ast.go b/internal/ast/ast.go index ba2729e..f9059d8 100644 --- a/internal/ast/ast.go +++ b/internal/ast/ast.go @@ -28,6 +28,12 @@ func (c *Iterator) Next() bool { return c.node.Valid() } +// IsLast returns true if the current node of the iterator is the last one. +// Subsequent call to Next() will return false. +func (c *Iterator) IsLast() bool { + return c.node.next <= 0 +} + // Node returns a copy of the node pointed at by the iterator. func (c *Iterator) Node() Node { return c.node diff --git a/internal/imported_tests/unmarshal_imported_test.go b/internal/imported_tests/unmarshal_imported_test.go index d3f54d8..2345445 100644 --- a/internal/imported_tests/unmarshal_imported_test.go +++ b/internal/imported_tests/unmarshal_imported_test.go @@ -223,11 +223,13 @@ type testSubDoc struct { unexported int `toml:"shouldntBeHere"` } -var biteMe = "Bite me" -var float1 float32 = 12.3 -var float2 float32 = 45.6 -var float3 float32 = 78.9 -var subdoc = testSubDoc{"Second", 0} +var ( + biteMe = "Bite me" + float1 float32 = 12.3 + float2 float32 = 45.6 + float3 float32 = 78.9 + subdoc = testSubDoc{"Second", 0} +) var docData = testDoc{ Title: "TOML Marshal Testing", @@ -382,7 +384,7 @@ var intErrTomls = []string{ } func TestErrUnmarshal(t *testing.T) { - var errTomls = []string{ + errTomls := []string{ "bool = truly\ndate = 1979-05-27T07:32:00Z\nfloat = 123.4\nint = 5000\nstring = \"Bite me\"", "bool = true\ndate = 1979-05-27T07:3200Z\nfloat = 123.4\nint = 5000\nstring = \"Bite me\"", "bool = true\ndate = 1979-05-27T07:32:00Z\nfloat = 123a4\nint = 5000\nstring = \"Bite me\"", @@ -468,7 +470,7 @@ func TestEmptyUnmarshalOmit(t *testing.T) { Map map[string]string `toml:"map,omitempty"` } - var emptyTestData2 = emptyMarshalTestStruct2{ + emptyTestData2 := emptyMarshalTestStruct2{ Title: "Placeholder", Bool: false, Int: 0, @@ -496,21 +498,23 @@ type pointerMarshalTestStruct struct { DblPtr *[]*[]*string } -var pointerStr = "Hello" -var pointerList = []string{"Hello back"} -var pointerListPtr = []*string{&pointerStr} -var pointerMap = map[string]string{"response": "Goodbye"} -var pointerMapPtr = map[string]*string{"alternate": &pointerStr} -var pointerTestData = pointerMarshalTestStruct{ - Str: &pointerStr, - List: &pointerList, - ListPtr: &pointerListPtr, - Map: &pointerMap, - MapPtr: &pointerMapPtr, - EmptyStr: nil, - EmptyList: nil, - EmptyMap: nil, -} +var ( + pointerStr = "Hello" + pointerList = []string{"Hello back"} + pointerListPtr = []*string{&pointerStr} + pointerMap = map[string]string{"response": "Goodbye"} + pointerMapPtr = map[string]*string{"alternate": &pointerStr} + pointerTestData = pointerMarshalTestStruct{ + Str: &pointerStr, + List: &pointerList, + ListPtr: &pointerListPtr, + Map: &pointerMap, + MapPtr: &pointerMapPtr, + EmptyStr: nil, + EmptyList: nil, + EmptyMap: nil, + } +) var pointerTestToml = []byte(`List = ["Hello back"] ListPtr = ["Hello"] @@ -538,15 +542,17 @@ func TestUnmarshalTypeMismatch(t *testing.T) { type nestedMarshalTestStruct struct { String [][]string - //Struct [][]basicMarshalTestSubStruct + // Struct [][]basicMarshalTestSubStruct StringPtr *[]*[]*string // StructPtr *[]*[]*basicMarshalTestSubStruct } -var str1 = "Three" -var str2 = "Four" -var strPtr = []*string{&str1, &str2} -var strPtr2 = []*[]*string{&strPtr} +var ( + str1 = "Three" + str2 = "Four" + strPtr = []*string{&str1, &str2} + strPtr2 = []*[]*string{&strPtr} +) var nestedTestData = nestedMarshalTestStruct{ String: [][]string{{"Five", "Six"}, {"One", "Two"}}, @@ -597,6 +603,7 @@ var nestedCustomMarshalerData = customMarshalerParent{ var nestedCustomMarshalerToml = []byte(`friends = ["Sally Fields"] me = "Maiku Suteda" `) + var nestedCustomMarshalerTomlForUnmarshal = []byte(`[friends] FirstName = "Sally" LastName = "Fields"`) @@ -613,11 +620,11 @@ func (x *IntOrString) MarshalTOML() ([]byte, error) { } func TestUnmarshalTextMarshaler(t *testing.T) { - var nested = struct { + nested := struct { Friends textMarshaler `toml:"friends"` }{} - var expected = struct { + expected := struct { Friends textMarshaler `toml:"friends"` }{ Friends: textMarshaler{FirstName: "Sally", LastName: "Fields"}, @@ -1360,7 +1367,6 @@ func TestUnmarshalPreservesUnexportedFields(t *testing.T) { t.Run("unexported field should not be set from toml", func(t *testing.T) { var actual unexportedFieldPreservationTest err := toml.Unmarshal([]byte(doc), &actual) - if err != nil { t.Fatal("did not expect an error") } @@ -1394,7 +1400,6 @@ func TestUnmarshalPreservesUnexportedFields(t *testing.T) { Nested3: &unexportedFieldPreservationTestNested{"baz", "bax"}, } err := toml.Unmarshal([]byte(doc), &actual) - if err != nil { t.Fatal("did not expect an error") } @@ -1431,7 +1436,6 @@ func TestUnmarshalLocalDate(t *testing.T) { var obj dateStruct err := toml.Unmarshal([]byte(doc), &obj) - if err != nil { t.Fatal(err) } @@ -1457,7 +1461,6 @@ func TestUnmarshalLocalDate(t *testing.T) { var obj dateStruct err := toml.Unmarshal([]byte(doc), &obj) - if err != nil { t.Fatal(err) } @@ -1495,7 +1498,8 @@ func TestUnmarshalLocalDateTime(t *testing.T) { Second: 0, Nanosecond: 0, }, - }}, + }, + }, { name: "with nanoseconds", in: "1979-05-27T00:32:00.999999", @@ -1526,7 +1530,6 @@ func TestUnmarshalLocalDateTime(t *testing.T) { var obj dateStruct err := toml.Unmarshal([]byte(doc), &obj) - if err != nil { t.Fatal(err) } @@ -1544,7 +1547,6 @@ func TestUnmarshalLocalDateTime(t *testing.T) { var obj dateStruct err := toml.Unmarshal([]byte(doc), &obj) - if err != nil { t.Fatal(err) } @@ -1613,7 +1615,6 @@ func TestUnmarshalLocalTime(t *testing.T) { var obj dateStruct err := toml.Unmarshal([]byte(doc), &obj) - if err != nil { t.Fatal(err) } @@ -2283,8 +2284,7 @@ func (d *durationString) UnmarshalTOML(v interface{}) error { return nil } -type config437Error struct { -} +type config437Error struct{} func (e *config437Error) UnmarshalTOML(v interface{}) error { return errors.New("expected") diff --git a/internal/tracker/seen.go b/internal/tracker/seen.go index 4b5d392..0f6bd01 100644 --- a/internal/tracker/seen.go +++ b/internal/tracker/seen.go @@ -106,7 +106,7 @@ func (s *SeenTracker) create(parentIdx int, name []byte, kind keyKind, explicit // that have been seen in previous calls, and validates that types are consistent. func (s *SeenTracker) CheckExpression(node ast.Node) error { if s.entries == nil { - //s.entries = make([]entry, 0, 8) + // s.entries = make([]entry, 0, 8) // Skip ID = 0 to remove the confusion between nodes whose parent has // id 0 and root nodes (parent id is 0 because it's the zero value). s.nextID = 1 @@ -134,7 +134,7 @@ func (s *SeenTracker) checkTable(node ast.Node) error { // it in a function requires to copy the iterator, or allocate it to the // heap, which is not cheap. for it.Next() { - if !it.Node().Next().Valid() { + if it.IsLast() { break } @@ -175,7 +175,7 @@ func (s *SeenTracker) checkArrayTable(node ast.Node) error { parentIdx := -1 for it.Next() { - if !it.Node().Next().Valid() { + if it.IsLast() { break } diff --git a/localtime_test.go b/localtime_test.go index 6741504..646a3db 100644 --- a/localtime_test.go +++ b/localtime_test.go @@ -26,7 +26,6 @@ func cmpEqual(x, y interface{}) bool { } func TestDates(t *testing.T) { - t.Parallel() for _, test := range []struct { date LocalDate @@ -64,7 +63,6 @@ func TestDates(t *testing.T) { } func TestDateIsValid(t *testing.T) { - t.Parallel() for _, test := range []struct { date LocalDate @@ -91,7 +89,6 @@ func TestDateIsValid(t *testing.T) { } func TestParseDate(t *testing.T) { - t.Parallel() var emptyDate LocalDate @@ -118,7 +115,6 @@ func TestParseDate(t *testing.T) { } func TestDateArithmetic(t *testing.T) { - t.Parallel() for _, test := range []struct { desc string @@ -180,7 +176,6 @@ func TestDateArithmetic(t *testing.T) { } func TestDateBefore(t *testing.T) { - t.Parallel() for _, test := range []struct { d1, d2 LocalDate @@ -198,7 +193,6 @@ func TestDateBefore(t *testing.T) { } func TestDateAfter(t *testing.T) { - t.Parallel() for _, test := range []struct { d1, d2 LocalDate @@ -215,7 +209,6 @@ func TestDateAfter(t *testing.T) { } func TestTimeToString(t *testing.T) { - t.Parallel() for _, test := range []struct { str string @@ -249,7 +242,6 @@ func TestTimeToString(t *testing.T) { } func TestTimeOf(t *testing.T) { - t.Parallel() for _, test := range []struct { time time.Time @@ -265,7 +257,6 @@ func TestTimeOf(t *testing.T) { } func TestTimeIsValid(t *testing.T) { - t.Parallel() for _, test := range []struct { time LocalTime @@ -291,7 +282,6 @@ func TestTimeIsValid(t *testing.T) { } func TestDateTimeToString(t *testing.T) { - t.Parallel() for _, test := range []struct { str string @@ -323,7 +313,6 @@ func TestDateTimeToString(t *testing.T) { } func TestParseDateTimeErrors(t *testing.T) { - t.Parallel() for _, str := range []string{ "", @@ -339,7 +328,6 @@ func TestParseDateTimeErrors(t *testing.T) { } func TestDateTimeOf(t *testing.T) { - t.Parallel() for _, test := range []struct { time time.Time @@ -361,7 +349,6 @@ func TestDateTimeOf(t *testing.T) { } func TestDateTimeIsValid(t *testing.T) { - t.Parallel() // No need to be exhaustive here; it's just LocalDate.IsValid && LocalTime.IsValid. for _, test := range []struct { @@ -380,7 +367,6 @@ func TestDateTimeIsValid(t *testing.T) { } func TestDateTimeIn(t *testing.T) { - t.Parallel() dt := LocalDateTime{LocalDate{2016, 1, 2}, LocalTime{3, 4, 5, 6}} @@ -391,7 +377,6 @@ func TestDateTimeIn(t *testing.T) { } func TestDateTimeBefore(t *testing.T) { - t.Parallel() d1 := LocalDate{2016, 12, 31} d2 := LocalDate{2017, 1, 1} @@ -414,7 +399,6 @@ func TestDateTimeBefore(t *testing.T) { } func TestDateTimeAfter(t *testing.T) { - t.Parallel() d1 := LocalDate{2016, 12, 31} d2 := LocalDate{2017, 1, 1} @@ -437,7 +421,6 @@ func TestDateTimeAfter(t *testing.T) { } func TestMarshalJSON(t *testing.T) { - t.Parallel() for _, test := range []struct { value interface{} @@ -459,7 +442,6 @@ func TestMarshalJSON(t *testing.T) { } func TestUnmarshalJSON(t *testing.T) { - t.Parallel() var ( d LocalDate diff --git a/marshaler.go b/marshaler.go index ce9972a..0d8ed9c 100644 --- a/marshaler.go +++ b/marshaler.go @@ -640,8 +640,6 @@ func (enc *Encoder) encodeTableInline(b []byte, ctx encoderCtx, t table) ([]byte return b, nil } -var textMarshalerType = reflect.TypeOf(new(encoding.TextMarshaler)).Elem() - func willConvertToTable(ctx encoderCtx, v reflect.Value) bool { if v.Type() == timeType || v.Type().Implements(textMarshalerType) { return false diff --git a/marshaler_test.go b/marshaler_test.go index 5c946c3..333c93c 100644 --- a/marshaler_test.go +++ b/marshaler_test.go @@ -14,8 +14,6 @@ import ( //nolint:funlen func TestMarshal(t *testing.T) { - t.Parallel() - someInt := 42 type structInline struct { @@ -516,8 +514,6 @@ K = 42`, for _, e := range examples { e := e t.Run(e.desc, func(t *testing.T) { - t.Parallel() - b, err := toml.Marshal(e.v) if e.err { require.Error(t, err) @@ -609,8 +605,6 @@ func equalStringsIgnoreNewlines(t *testing.T, expected string, actual string) { //nolint:funlen func TestMarshalIndentTables(t *testing.T) { - t.Parallel() - examples := []struct { desc string v interface{} @@ -661,8 +655,6 @@ root = 'value0' for _, e := range examples { e := e t.Run(e.desc, func(t *testing.T) { - t.Parallel() - var buf strings.Builder enc := toml.NewEncoder(&buf) enc.SetIndentTables(true) @@ -685,24 +677,18 @@ func (c *customTextMarshaler) MarshalText() ([]byte, error) { } func TestMarshalTextMarshaler_NoRoot(t *testing.T) { - t.Parallel() - c := customTextMarshaler{} _, err := toml.Marshal(&c) require.Error(t, err) } func TestMarshalTextMarshaler_Error(t *testing.T) { - t.Parallel() - m := map[string]interface{}{"a": &customTextMarshaler{value: 1}} _, err := toml.Marshal(m) require.Error(t, err) } func TestMarshalTextMarshaler_ErrorInline(t *testing.T) { - t.Parallel() - type s struct { A map[string]interface{} `inline:"true"` } @@ -716,8 +702,6 @@ func TestMarshalTextMarshaler_ErrorInline(t *testing.T) { } func TestMarshalTextMarshaler(t *testing.T) { - t.Parallel() - m := map[string]interface{}{"a": &customTextMarshaler{value: 2}} r, err := toml.Marshal(m) require.NoError(t, err) @@ -731,7 +715,6 @@ func (b *brokenWriter) Write([]byte) (int, error) { } func TestEncodeToBrokenWriter(t *testing.T) { - t.Parallel() w := brokenWriter{} enc := toml.NewEncoder(&w) err := enc.Encode(map[string]string{"hello": "world"}) @@ -739,7 +722,6 @@ func TestEncodeToBrokenWriter(t *testing.T) { } func TestEncoderSetIndentSymbol(t *testing.T) { - t.Parallel() var w strings.Builder enc := toml.NewEncoder(&w) enc.SetIndentTables(true) @@ -753,8 +735,6 @@ func TestEncoderSetIndentSymbol(t *testing.T) { } func TestIssue436(t *testing.T) { - t.Parallel() - data := []byte(`{"a": [ { "b": { "c": "d" } } ]}`) var v interface{} @@ -774,8 +754,6 @@ c = 'd' } func TestIssue424(t *testing.T) { - t.Parallel() - type Message1 struct { Text string } diff --git a/parser_test.go b/parser_test.go index bb3d3bd..fdb4f27 100644 --- a/parser_test.go +++ b/parser_test.go @@ -9,7 +9,6 @@ import ( //nolint:funlen func TestParser_AST_Numbers(t *testing.T) { - t.Parallel() examples := []struct { desc string @@ -137,7 +136,7 @@ func TestParser_AST_Numbers(t *testing.T) { for _, e := range examples { e := e t.Run(e.desc, func(t *testing.T) { - t.Parallel() + p := parser{} p.Reset([]byte(`A = ` + e.input)) p.NextExpression() @@ -200,7 +199,6 @@ func compareIterator(t *testing.T, expected []astNode, actual ast.Iterator) { //nolint:funlen func TestParser_AST(t *testing.T) { - t.Parallel() examples := []struct { desc string @@ -340,7 +338,7 @@ func TestParser_AST(t *testing.T) { for _, e := range examples { e := e t.Run(e.desc, func(t *testing.T) { - t.Parallel() + p := parser{} p.Reset([]byte(e.input)) p.NextExpression() diff --git a/strict.go b/strict.go index 2b2e7d6..ca482c4 100644 --- a/strict.go +++ b/strict.go @@ -3,6 +3,7 @@ package toml import ( "github.com/pelletier/go-toml/v2/internal/ast" "github.com/pelletier/go-toml/v2/internal/tracker" + "github.com/pelletier/go-toml/v2/internal/unsafe" ) type strict struct { @@ -86,3 +87,21 @@ func (s *strict) Error(doc []byte) error { return err } + +func keyLocation(node ast.Node) []byte { + k := node.Key() + + hasOne := k.Next() + if !hasOne { + panic("should not be called with empty key") + } + + start := k.Node().Data + end := k.Node().Data + + for k.Next() { + end = k.Node().Data + } + + return unsafe.BytesRange(start, end) +} diff --git a/targets.go b/targets.go deleted file mode 100644 index 370892f..0000000 --- a/targets.go +++ /dev/null @@ -1,536 +0,0 @@ -package toml - -import ( - "fmt" - "math" - "reflect" - "strings" - "sync" -) - -type target interface { - // Dereferences the target. - get() reflect.Value - - // Store a string at the target. - setString(v string) - - // Store a boolean at the target - setBool(v bool) - - // Store an int64 at the target - setInt64(v int64) - - // Store a float64 at the target - setFloat64(v float64) - - // Stores any value at the target - set(v reflect.Value) -} - -// valueTarget just contains a reflect.Value that can be set. -// It is used for struct fields. -type valueTarget reflect.Value - -func (t valueTarget) get() reflect.Value { - return reflect.Value(t) -} - -func (t valueTarget) set(v reflect.Value) { - reflect.Value(t).Set(v) -} - -func (t valueTarget) setString(v string) { - t.get().SetString(v) -} - -func (t valueTarget) setBool(v bool) { - t.get().SetBool(v) -} - -func (t valueTarget) setInt64(v int64) { - t.get().SetInt(v) -} - -func (t valueTarget) setFloat64(v float64) { - t.get().SetFloat(v) -} - -// interfaceTarget wraps an other target to dereference on get. -type interfaceTarget struct { - x target -} - -func (t interfaceTarget) get() reflect.Value { - return t.x.get().Elem() -} - -func (t interfaceTarget) set(v reflect.Value) { - t.x.set(v) -} - -func (t interfaceTarget) setString(v string) { - panic("interface targets should always go through set") -} - -func (t interfaceTarget) setBool(v bool) { - panic("interface targets should always go through set") -} - -func (t interfaceTarget) setInt64(v int64) { - panic("interface targets should always go through set") -} - -func (t interfaceTarget) setFloat64(v float64) { - panic("interface targets should always go through set") -} - -// mapTarget targets a specific key of a map. -type mapTarget struct { - v reflect.Value - k reflect.Value -} - -func (t mapTarget) get() reflect.Value { - return t.v.MapIndex(t.k) -} - -func (t mapTarget) set(v reflect.Value) { - t.v.SetMapIndex(t.k, v) -} - -func (t mapTarget) setString(v string) { - t.set(reflect.ValueOf(v)) -} - -func (t mapTarget) setBool(v bool) { - t.set(reflect.ValueOf(v)) -} - -func (t mapTarget) setInt64(v int64) { - t.set(reflect.ValueOf(v)) -} - -func (t mapTarget) setFloat64(v float64) { - t.set(reflect.ValueOf(v)) -} - -// makes sure that the value pointed at by t is indexable (Slice, Array), or -// dereferences to an indexable (Ptr, Interface). -func ensureValueIndexable(t target) error { - f := t.get() - - switch f.Type().Kind() { - case reflect.Slice: - if f.IsNil() { - t.set(reflect.MakeSlice(f.Type(), 0, 0)) - return nil - } - case reflect.Interface: - if f.IsNil() || f.Elem().Type() != sliceInterfaceType { - t.set(reflect.MakeSlice(sliceInterfaceType, 0, 0)) - return nil - } - case reflect.Ptr: - panic("pointer should have already been dereferenced") - case reflect.Array: - // arrays are always initialized. - default: - return fmt.Errorf("toml: cannot store array in a %s", f.Kind()) - } - - return nil -} - -var ( - sliceInterfaceType = reflect.TypeOf([]interface{}{}) - mapStringInterfaceType = reflect.TypeOf(map[string]interface{}{}) -) - -func ensureMapIfInterface(x target) { - v := x.get() - - if v.Kind() == reflect.Interface && v.IsNil() { - newElement := reflect.MakeMap(mapStringInterfaceType) - - x.set(newElement) - } -} - -func setString(t target, v string) error { - f := t.get() - - switch f.Kind() { - case reflect.String: - t.setString(v) - case reflect.Interface: - t.set(reflect.ValueOf(v)) - default: - return fmt.Errorf("toml: cannot assign string to a %s", f.Kind()) - } - - return nil -} - -func setBool(t target, v bool) error { - f := t.get() - - switch f.Kind() { - case reflect.Bool: - t.setBool(v) - case reflect.Interface: - t.set(reflect.ValueOf(v)) - default: - return fmt.Errorf("toml: cannot assign boolean to a %s", f.Kind()) - } - - return nil -} - -const ( - maxInt = int64(^uint(0) >> 1) - minInt = -maxInt - 1 -) - -//nolint:funlen,gocognit,cyclop -func setInt64(t target, v int64) error { - f := t.get() - - switch f.Kind() { - case reflect.Int64: - t.setInt64(v) - case reflect.Int32: - if v < math.MinInt32 || v > math.MaxInt32 { - return fmt.Errorf("toml: number %d does not fit in an int32", v) - } - - t.set(reflect.ValueOf(int32(v))) - return nil - case reflect.Int16: - if v < math.MinInt16 || v > math.MaxInt16 { - return fmt.Errorf("toml: number %d does not fit in an int16", v) - } - - t.set(reflect.ValueOf(int16(v))) - case reflect.Int8: - if v < math.MinInt8 || v > math.MaxInt8 { - return fmt.Errorf("toml: number %d does not fit in an int8", v) - } - - t.set(reflect.ValueOf(int8(v))) - case reflect.Int: - if v < minInt || v > maxInt { - return fmt.Errorf("toml: number %d does not fit in an int", v) - } - - t.set(reflect.ValueOf(int(v))) - case reflect.Uint64: - if v < 0 { - return fmt.Errorf("toml: negative number %d does not fit in an uint64", v) - } - - t.set(reflect.ValueOf(uint64(v))) - case reflect.Uint32: - if v < 0 || v > math.MaxUint32 { - return fmt.Errorf("toml: negative number %d does not fit in an uint32", v) - } - - t.set(reflect.ValueOf(uint32(v))) - case reflect.Uint16: - if v < 0 || v > math.MaxUint16 { - return fmt.Errorf("toml: negative number %d does not fit in an uint16", v) - } - - t.set(reflect.ValueOf(uint16(v))) - case reflect.Uint8: - if v < 0 || v > math.MaxUint8 { - return fmt.Errorf("toml: negative number %d does not fit in an uint8", v) - } - - t.set(reflect.ValueOf(uint8(v))) - case reflect.Uint: - if v < 0 { - return fmt.Errorf("toml: negative number %d does not fit in an uint", v) - } - - t.set(reflect.ValueOf(uint(v))) - case reflect.Interface: - t.set(reflect.ValueOf(v)) - default: - return fmt.Errorf("toml: integer cannot be assigned to %s", f.Kind()) - } - - return nil -} - -func setFloat64(t target, v float64) error { - f := t.get() - - switch f.Kind() { - case reflect.Float64: - t.setFloat64(v) - case reflect.Float32: - if v > math.MaxFloat32 { - return fmt.Errorf("toml: number %f does not fit in a float32", v) - } - - t.set(reflect.ValueOf(float32(v))) - case reflect.Interface: - t.set(reflect.ValueOf(v)) - default: - return fmt.Errorf("toml: float cannot be assigned to %s", f.Kind()) - } - - return nil -} - -// Returns the element at idx of the value pointed at by target, or an error if -// t does not point to an indexable. -// If the target points to an Array and idx is out of bounds, it returns -// (nil, nil) as this is not a fatal error (the unmarshaler will skip). -func elementAt(t target, idx int) target { - f := t.get() - - switch f.Kind() { - case reflect.Slice: - //nolint:godox - // TODO: use the idx function argument and avoid alloc if possible. - idx := f.Len() - - t.set(reflect.Append(f, reflect.New(f.Type().Elem()).Elem())) - - return valueTarget(t.get().Index(idx)) - case reflect.Array: - if idx >= f.Len() { - return nil - } - - return valueTarget(f.Index(idx)) - case reflect.Interface: - // This function is called after ensureValueIndexable, so it's - // guaranteed that f contains an initialized slice. - ifaceElem := f.Elem() - idx := ifaceElem.Len() - newElem := reflect.New(ifaceElem.Type().Elem()).Elem() - newSlice := reflect.Append(ifaceElem, newElem) - - t.set(newSlice) - - return valueTarget(t.get().Elem().Index(idx)) - default: - // Why ensureValueIndexable let it go through? - panic(fmt.Errorf("elementAt received unhandled value type: %s", f.Kind())) - } -} - -func (d *decoder) scopeTableTarget(shouldAppend bool, t target, name string) (target, bool, error) { - x := t.get() - - switch x.Kind() { - // Kinds that need to recurse - case reflect.Interface: - t := scopeInterface(shouldAppend, t) - return d.scopeTableTarget(shouldAppend, t, name) - case reflect.Ptr: - t := scopePtr(t) - return d.scopeTableTarget(shouldAppend, t, name) - case reflect.Slice: - t := scopeSlice(shouldAppend, t) - shouldAppend = false - return d.scopeTableTarget(shouldAppend, t, name) - case reflect.Array: - t, err := d.scopeArray(shouldAppend, t) - if err != nil { - return t, false, err - } - shouldAppend = false - - return d.scopeTableTarget(shouldAppend, t, name) - - // Terminal kinds - case reflect.Struct: - return scopeStruct(x, name) - case reflect.Map: - if x.IsNil() { - t.set(reflect.MakeMap(x.Type())) - x = t.get() - } - - return scopeMap(x, name) - default: - panic(fmt.Sprintf("can't scope on a %s", x.Kind())) - } -} - -func scopeInterface(shouldAppend bool, t target) target { - initInterface(shouldAppend, t) - return interfaceTarget{t} -} - -func scopePtr(t target) target { - initPtr(t) - return valueTarget(t.get().Elem()) -} - -func initPtr(t target) { - x := t.get() - if !x.IsNil() { - return - } - - t.set(reflect.New(x.Type().Elem())) -} - -// initInterface makes sure that the interface pointed at by the target is not -// nil. -// Returns the target to the initialized value of the target. -func initInterface(shouldAppend bool, t target) { - x := t.get() - - if x.Kind() != reflect.Interface { - panic("this should only be called on interfaces") - } - - if !x.IsNil() && (x.Elem().Type() == sliceInterfaceType || x.Elem().Type() == mapStringInterfaceType) { - return - } - - var newElement reflect.Value - if shouldAppend { - newElement = reflect.MakeSlice(sliceInterfaceType, 0, 0) - } else { - newElement = reflect.MakeMap(mapStringInterfaceType) - } - - t.set(newElement) -} - -func scopeSlice(shouldAppend bool, t target) target { - v := t.get() - - if shouldAppend { - newElem := reflect.New(v.Type().Elem()) - newSlice := reflect.Append(v, newElem.Elem()) - - t.set(newSlice) - - v = t.get() - } - - return valueTarget(v.Index(v.Len() - 1)) -} - -func (d *decoder) scopeArray(shouldAppend bool, t target) (target, error) { - v := t.get() - - idx := d.arrayIndex(shouldAppend, v) - - if idx >= v.Len() { - return nil, fmt.Errorf("toml: impossible to insert element beyond array's size: %d", v.Len()) - } - - return valueTarget(v.Index(idx)), nil -} - -func scopeMap(v reflect.Value, name string) (target, bool, error) { - k := reflect.ValueOf(name) - - keyType := v.Type().Key() - if !k.Type().AssignableTo(keyType) { - if !k.Type().ConvertibleTo(keyType) { - return nil, false, fmt.Errorf("toml: cannot convert map key of type %s to expected type %s", k.Type(), keyType) - } - - k = k.Convert(keyType) - } - - if !v.MapIndex(k).IsValid() { - newElem := reflect.New(v.Type().Elem()) - v.SetMapIndex(k, newElem.Elem()) - } - - return mapTarget{ - v: v, - k: k, - }, true, nil -} - -type fieldPathsMap = map[string][]int - -type fieldPathsCache struct { - m map[reflect.Type]fieldPathsMap - l sync.RWMutex -} - -func (c *fieldPathsCache) get(t reflect.Type) (fieldPathsMap, bool) { - c.l.RLock() - paths, ok := c.m[t] - c.l.RUnlock() - - return paths, ok -} - -func (c *fieldPathsCache) set(t reflect.Type, m fieldPathsMap) { - c.l.Lock() - c.m[t] = m - c.l.Unlock() -} - -var globalFieldPathsCache = fieldPathsCache{ - m: map[reflect.Type]fieldPathsMap{}, - l: sync.RWMutex{}, -} - -func scopeStruct(v reflect.Value, name string) (target, bool, error) { - //nolint:godox - // TODO: cache this, and reduce allocations - fieldPaths, ok := globalFieldPathsCache.get(v.Type()) - if !ok { - fieldPaths = map[string][]int{} - - path := make([]int, 0, 16) - - var walk func(reflect.Value) - walk = func(v reflect.Value) { - t := v.Type() - for i := 0; i < t.NumField(); i++ { - l := len(path) - path = append(path, i) - f := t.Field(i) - - if f.Anonymous { - walk(v.Field(i)) - } else if f.PkgPath == "" { - // only consider exported fields - fieldName, ok := f.Tag.Lookup("toml") - if !ok { - fieldName = f.Name - } - - pathCopy := make([]int, len(path)) - copy(pathCopy, path) - - fieldPaths[fieldName] = pathCopy - // extra copy for the case-insensitive match - fieldPaths[strings.ToLower(fieldName)] = pathCopy - } - path = path[:l] - } - } - - walk(v) - - globalFieldPathsCache.set(v.Type(), fieldPaths) - } - - path, ok := fieldPaths[name] - if !ok { - path, ok = fieldPaths[strings.ToLower(name)] - } - - if !ok { - return nil, false, nil - } - - return valueTarget(v.FieldByIndex(path)), true, nil -} diff --git a/targets_test.go b/targets_test.go deleted file mode 100644 index c895ad5..0000000 --- a/targets_test.go +++ /dev/null @@ -1,207 +0,0 @@ -package toml - -import ( - "reflect" - "testing" - - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" -) - -func TestStructTarget_Ensure(t *testing.T) { - t.Parallel() - - examples := []struct { - desc string - input reflect.Value - name string - test func(v reflect.Value, err error) - }{ - { - desc: "handle a nil slice of string", - input: reflect.ValueOf(&struct{ A []string }{}).Elem(), - name: "A", - test: func(v reflect.Value, err error) { - assert.NoError(t, err) - assert.False(t, v.IsNil()) - }, - }, - { - desc: "handle an existing slice of string", - input: reflect.ValueOf(&struct{ A []string }{A: []string{"foo"}}).Elem(), - name: "A", - test: func(v reflect.Value, err error) { - assert.NoError(t, err) - require.False(t, v.IsNil()) - - s, ok := v.Interface().([]string) - if !ok { - t.Errorf("interface %v should be castable into []string", s) - return - } - - assert.Equal(t, []string{"foo"}, s) - }, - }, - } - - for _, e := range examples { - e := e - t.Run(e.desc, func(t *testing.T) { - t.Parallel() - - d := decoder{} - target, _, err := d.scopeTableTarget(false, valueTarget(e.input), e.name) - require.NoError(t, err) - err = ensureValueIndexable(target) - v := target.get() - e.test(v, err) - }) - } -} - -func TestStructTarget_SetString(t *testing.T) { - t.Parallel() - - str := "value" - - examples := []struct { - desc string - input reflect.Value - name string - test func(v reflect.Value, err error) - }{ - { - desc: "sets a string", - input: reflect.ValueOf(&struct{ A string }{}).Elem(), - name: "A", - test: func(v reflect.Value, err error) { - assert.NoError(t, err) - assert.Equal(t, str, v.String()) - }, - }, - { - desc: "fails on a float", - input: reflect.ValueOf(&struct{ A float64 }{}).Elem(), - name: "A", - test: func(v reflect.Value, err error) { - assert.Error(t, err) - }, - }, - { - desc: "fails on a slice", - input: reflect.ValueOf(&struct{ A []string }{}).Elem(), - name: "A", - test: func(v reflect.Value, err error) { - assert.Error(t, err) - }, - }, - } - - for _, e := range examples { - e := e - t.Run(e.desc, func(t *testing.T) { - t.Parallel() - - d := decoder{} - target, _, err := d.scopeTableTarget(false, valueTarget(e.input), e.name) - require.NoError(t, err) - err = setString(target, str) - v := target.get() - e.test(v, err) - }) - } -} - -func TestPushNew(t *testing.T) { - t.Parallel() - - t.Run("slice of strings", func(t *testing.T) { - t.Parallel() - - type Doc struct { - A []string - } - d := Doc{} - - dec := decoder{} - x, _, err := dec.scopeTableTarget(false, valueTarget(reflect.ValueOf(&d).Elem()), "A") - require.NoError(t, err) - - n := elementAt(x, 0) - n.setString("hello") - require.Equal(t, []string{"hello"}, d.A) - - n = elementAt(x, 1) - n.setString("world") - require.Equal(t, []string{"hello", "world"}, d.A) - }) - - t.Run("slice of interfaces", func(t *testing.T) { - t.Parallel() - - type Doc struct { - A []interface{} - } - d := Doc{} - - dec := decoder{} - x, _, err := dec.scopeTableTarget(false, valueTarget(reflect.ValueOf(&d).Elem()), "A") - require.NoError(t, err) - - n := elementAt(x, 0) - require.NoError(t, setString(n, "hello")) - require.Equal(t, []interface{}{"hello"}, d.A) - - n = elementAt(x, 1) - require.NoError(t, setString(n, "world")) - require.Equal(t, []interface{}{"hello", "world"}, d.A) - }) -} - -func TestScope_Struct(t *testing.T) { - t.Parallel() - - examples := []struct { - desc string - input reflect.Value - name string - err bool - found bool - idx []int - }{ - { - desc: "simple field", - input: reflect.ValueOf(&struct{ A string }{}).Elem(), - name: "A", - idx: []int{0}, - found: true, - }, - { - desc: "fails not-exported field", - input: reflect.ValueOf(&struct{ a string }{}).Elem(), - name: "a", - err: false, - found: false, - }, - } - - for _, e := range examples { - e := e - t.Run(e.desc, func(t *testing.T) { - t.Parallel() - - dec := decoder{} - x, found, err := dec.scopeTableTarget(false, valueTarget(e.input), e.name) - assert.Equal(t, e.found, found) - if e.err { - assert.Error(t, err) - } - if found { - x2, ok := x.(valueTarget) - require.True(t, ok) - x2.get() - } - }) - } -} diff --git a/toml_testgen_test.go b/toml_testgen_test.go index b0d82cd..1be0d14 100644 --- a/toml_testgen_test.go +++ b/toml_testgen_test.go @@ -6,35 +6,30 @@ import ( ) func TestInvalidDatetimeMalformedNoLeads(t *testing.T) { - t.Parallel() input := `no-leads = 1987-7-05T17:45:00Z` testgenInvalid(t, input) } func TestInvalidDatetimeMalformedNoSecs(t *testing.T) { - t.Parallel() input := `no-secs = 1987-07-05T17:45Z` testgenInvalid(t, input) } func TestInvalidDatetimeMalformedNoT(t *testing.T) { - t.Parallel() input := `no-t = 1987-07-0517:45:00Z` testgenInvalid(t, input) } func TestInvalidDatetimeMalformedWithMilli(t *testing.T) { - t.Parallel() input := `with-milli = 1987-07-5T17:45:00.12Z` testgenInvalid(t, input) } func TestInvalidDuplicateKeyTable(t *testing.T) { - t.Parallel() input := `[fruit] type = "apple" @@ -45,7 +40,6 @@ apple = "yes"` } func TestInvalidDuplicateKeys(t *testing.T) { - t.Parallel() input := `dupe = false dupe = true` @@ -53,7 +47,6 @@ dupe = true` } func TestInvalidDuplicateTables(t *testing.T) { - t.Parallel() input := `[a] [a]` @@ -61,21 +54,18 @@ func TestInvalidDuplicateTables(t *testing.T) { } func TestInvalidEmptyImplicitTable(t *testing.T) { - t.Parallel() input := `[naughty..naughty]` testgenInvalid(t, input) } func TestInvalidEmptyTable(t *testing.T) { - t.Parallel() input := `[]` testgenInvalid(t, input) } func TestInvalidFloatNoLeadingZero(t *testing.T) { - t.Parallel() input := `answer = .12345 neganswer = -.12345` @@ -83,7 +73,6 @@ neganswer = -.12345` } func TestInvalidFloatNoTrailingDigits(t *testing.T) { - t.Parallel() input := `answer = 1. neganswer = -1.` @@ -91,21 +80,18 @@ neganswer = -1.` } func TestInvalidKeyEmpty(t *testing.T) { - t.Parallel() input := ` = 1` testgenInvalid(t, input) } func TestInvalidKeyHash(t *testing.T) { - t.Parallel() input := `a# = 1` testgenInvalid(t, input) } func TestInvalidKeyNewline(t *testing.T) { - t.Parallel() input := `a = 1` @@ -113,28 +99,24 @@ func TestInvalidKeyNewline(t *testing.T) { } func TestInvalidKeyOpenBracket(t *testing.T) { - t.Parallel() input := `[abc = 1` testgenInvalid(t, input) } func TestInvalidKeySingleOpenBracket(t *testing.T) { - t.Parallel() input := `[` testgenInvalid(t, input) } func TestInvalidKeySpace(t *testing.T) { - t.Parallel() input := `a b = 1` testgenInvalid(t, input) } func TestInvalidKeyStartBracket(t *testing.T) { - t.Parallel() input := `[a] [xyz = 5 @@ -143,42 +125,36 @@ func TestInvalidKeyStartBracket(t *testing.T) { } func TestInvalidKeyTwoEquals(t *testing.T) { - t.Parallel() input := `key= = 1` testgenInvalid(t, input) } func TestInvalidStringBadByteEscape(t *testing.T) { - t.Parallel() input := `naughty = "\xAg"` testgenInvalid(t, input) } func TestInvalidStringBadEscape(t *testing.T) { - t.Parallel() input := `invalid-escape = "This string has a bad \a escape character."` testgenInvalid(t, input) } func TestInvalidStringByteEscapes(t *testing.T) { - t.Parallel() input := `answer = "\x33"` testgenInvalid(t, input) } func TestInvalidStringNoClose(t *testing.T) { - t.Parallel() input := `no-ending-quote = "One time, at band camp` testgenInvalid(t, input) } func TestInvalidTableArrayImplicit(t *testing.T) { - t.Parallel() input := "# This test is a bit tricky. It should fail because the first use of\n" + "# `[[albums.songs]]` without first declaring `albums` implies that `albums`\n" + @@ -198,7 +174,6 @@ func TestInvalidTableArrayImplicit(t *testing.T) { } func TestInvalidTableArrayMalformedBracket(t *testing.T) { - t.Parallel() input := `[[albums] name = "Born to Run"` @@ -206,7 +181,6 @@ name = "Born to Run"` } func TestInvalidTableArrayMalformedEmpty(t *testing.T) { - t.Parallel() input := `[[]] name = "Born to Run"` @@ -214,14 +188,12 @@ name = "Born to Run"` } func TestInvalidTableEmpty(t *testing.T) { - t.Parallel() input := `[]` testgenInvalid(t, input) } func TestInvalidTableNestedBracketsClose(t *testing.T) { - t.Parallel() input := `[a]b] zyx = 42` @@ -229,7 +201,6 @@ zyx = 42` } func TestInvalidTableNestedBracketsOpen(t *testing.T) { - t.Parallel() input := `[a[b] zyx = 42` @@ -237,14 +208,12 @@ zyx = 42` } func TestInvalidTableWhitespace(t *testing.T) { - t.Parallel() input := `[invalid key]` testgenInvalid(t, input) } func TestInvalidTableWithPound(t *testing.T) { - t.Parallel() input := `[key#group] answer = 42` @@ -252,7 +221,6 @@ answer = 42` } func TestInvalidTextAfterArrayEntries(t *testing.T) { - t.Parallel() input := `array = [ "Is there life after an array separator?", No @@ -262,28 +230,24 @@ func TestInvalidTextAfterArrayEntries(t *testing.T) { } func TestInvalidTextAfterInteger(t *testing.T) { - t.Parallel() input := `answer = 42 the ultimate answer?` testgenInvalid(t, input) } func TestInvalidTextAfterString(t *testing.T) { - t.Parallel() input := `string = "Is there life after strings?" No.` testgenInvalid(t, input) } func TestInvalidTextAfterTable(t *testing.T) { - t.Parallel() input := `[error] this shouldn't be here` testgenInvalid(t, input) } func TestInvalidTextBeforeArraySeparator(t *testing.T) { - t.Parallel() input := `array = [ "Is there life before an array separator?" No, @@ -293,7 +257,6 @@ func TestInvalidTextBeforeArraySeparator(t *testing.T) { } func TestInvalidTextInArray(t *testing.T) { - t.Parallel() input := `array = [ "Entry 1", @@ -304,7 +267,6 @@ func TestInvalidTextInArray(t *testing.T) { } func TestValidArrayEmpty(t *testing.T) { - t.Parallel() input := `thevoid = [[[[[]]]]]` jsonRef := `{ @@ -322,7 +284,6 @@ func TestValidArrayEmpty(t *testing.T) { } func TestValidArrayNospaces(t *testing.T) { - t.Parallel() input := `ints = [1,2,3]` jsonRef := `{ @@ -339,7 +300,6 @@ func TestValidArrayNospaces(t *testing.T) { } func TestValidArraysHetergeneous(t *testing.T) { - t.Parallel() input := `mixed = [[1, 2], ["a", "b"], [1.1, 2.1]]` jsonRef := `{ @@ -365,7 +325,6 @@ func TestValidArraysHetergeneous(t *testing.T) { } func TestValidArraysNested(t *testing.T) { - t.Parallel() input := `nest = [["a"], ["b"]]` jsonRef := `{ @@ -385,7 +344,6 @@ func TestValidArraysNested(t *testing.T) { } func TestValidArrays(t *testing.T) { - t.Parallel() input := `ints = [1, 2, 3] floats = [1.1, 2.1, 3.1] @@ -433,7 +391,6 @@ dates = [ } func TestValidBool(t *testing.T) { - t.Parallel() input := `t = true f = false` @@ -445,7 +402,6 @@ f = false` } func TestValidCommentsEverywhere(t *testing.T) { - t.Parallel() input := `# Top comment. # Top comment. @@ -487,7 +443,6 @@ more = [ # Comment } func TestValidDatetime(t *testing.T) { - t.Parallel() input := `bestdayever = 1987-07-05T17:45:00Z` jsonRef := `{ @@ -497,7 +452,6 @@ func TestValidDatetime(t *testing.T) { } func TestValidEmpty(t *testing.T) { - t.Parallel() input := `` jsonRef := `{}` @@ -505,7 +459,6 @@ func TestValidEmpty(t *testing.T) { } func TestValidExample(t *testing.T) { - t.Parallel() input := `best-day-ever = 1987-07-05T17:45:00Z @@ -530,7 +483,6 @@ perfection = [6, 28, 496]` } func TestValidFloat(t *testing.T) { - t.Parallel() input := `pi = 3.14 negpi = -3.14` @@ -542,7 +494,6 @@ negpi = -3.14` } func TestValidImplicitAndExplicitAfter(t *testing.T) { - t.Parallel() input := `[a.b.c] answer = 42 @@ -563,7 +514,6 @@ better = 43` } func TestValidImplicitAndExplicitBefore(t *testing.T) { - t.Parallel() input := `[a] better = 43 @@ -584,7 +534,6 @@ answer = 42` } func TestValidImplicitGroups(t *testing.T) { - t.Parallel() input := `[a.b.c] answer = 42` @@ -601,7 +550,6 @@ answer = 42` } func TestValidInteger(t *testing.T) { - t.Parallel() input := `answer = 42 neganswer = -42` @@ -613,7 +561,6 @@ neganswer = -42` } func TestValidKeyEqualsNospace(t *testing.T) { - t.Parallel() input := `answer=42` jsonRef := `{ @@ -623,7 +570,6 @@ func TestValidKeyEqualsNospace(t *testing.T) { } func TestValidKeySpace(t *testing.T) { - t.Parallel() input := `"a b" = 1` jsonRef := `{ @@ -633,7 +579,6 @@ func TestValidKeySpace(t *testing.T) { } func TestValidKeySpecialChars(t *testing.T) { - t.Parallel() input := "\"~!@$^&*()_+-`1234567890[]|/?><.,;:'\" = 1\n" jsonRef := "{\n" + @@ -645,7 +590,6 @@ func TestValidKeySpecialChars(t *testing.T) { } func TestValidLongFloat(t *testing.T) { - t.Parallel() input := `longpi = 3.141592653589793 neglongpi = -3.141592653589793` @@ -657,7 +601,6 @@ neglongpi = -3.141592653589793` } func TestValidLongInteger(t *testing.T) { - t.Parallel() input := `answer = 9223372036854775807 neganswer = -9223372036854775808` @@ -669,7 +612,6 @@ neganswer = -9223372036854775808` } func TestValidMultilineString(t *testing.T) { - t.Parallel() input := `multiline_empty_one = """""" multiline_empty_two = """ @@ -728,7 +670,6 @@ equivalent_three = """\ } func TestValidRawMultilineString(t *testing.T) { - t.Parallel() input := `oneline = '''This string has a ' quote character.''' firstnl = ''' @@ -757,7 +698,6 @@ in it.'''` } func TestValidRawString(t *testing.T) { - t.Parallel() input := `backspace = 'This string has a \b backspace character.' tab = 'This string has a \t tab character.' @@ -800,7 +740,6 @@ backslash = 'This string has a \\ backslash character.'` } func TestValidStringEmpty(t *testing.T) { - t.Parallel() input := `answer = ""` jsonRef := `{ @@ -813,7 +752,6 @@ func TestValidStringEmpty(t *testing.T) { } func TestValidStringEscapes(t *testing.T) { - t.Parallel() input := `backspace = "This string has a \b backspace character." tab = "This string has a \t tab character." @@ -876,7 +814,6 @@ notunicode4 = "This string does not have a unicode \\\u0075 escape."` } func TestValidStringSimple(t *testing.T) { - t.Parallel() input := `answer = "You are not drinking enough whisky."` jsonRef := `{ @@ -889,7 +826,6 @@ func TestValidStringSimple(t *testing.T) { } func TestValidStringWithPound(t *testing.T) { - t.Parallel() input := `pound = "We see no # comments here." poundcomment = "But there are # some comments here." # Did I # mess you up?` @@ -904,7 +840,6 @@ poundcomment = "But there are # some comments here." # Did I # mess you up?` } func TestValidTableArrayImplicit(t *testing.T) { - t.Parallel() input := `[[albums.songs]] name = "Glory Days"` @@ -919,7 +854,6 @@ name = "Glory Days"` } func TestValidTableArrayMany(t *testing.T) { - t.Parallel() input := `[[people]] first_name = "Bruce" @@ -952,7 +886,6 @@ last_name = "Seger"` } func TestValidTableArrayNest(t *testing.T) { - t.Parallel() input := `[[albums]] name = "Born to Run" @@ -993,7 +926,6 @@ name = "Born in the USA" } func TestValidTableArrayOne(t *testing.T) { - t.Parallel() input := `[[people]] first_name = "Bruce" @@ -1010,7 +942,6 @@ last_name = "Springsteen"` } func TestValidTableEmpty(t *testing.T) { - t.Parallel() input := `[a]` jsonRef := `{ @@ -1020,7 +951,6 @@ func TestValidTableEmpty(t *testing.T) { } func TestValidTableSubEmpty(t *testing.T) { - t.Parallel() input := `[a] [a.b]` @@ -1031,7 +961,6 @@ func TestValidTableSubEmpty(t *testing.T) { } func TestValidTableWhitespace(t *testing.T) { - t.Parallel() input := `["valid key"]` jsonRef := `{ @@ -1041,7 +970,6 @@ func TestValidTableWhitespace(t *testing.T) { } func TestValidTableWithPound(t *testing.T) { - t.Parallel() input := `["key#group"] answer = 42` @@ -1054,7 +982,6 @@ answer = 42` } func TestValidUnicodeEscape(t *testing.T) { - t.Parallel() input := `answer4 = "\u03B4" answer8 = "\U000003B4"` @@ -1066,7 +993,6 @@ answer8 = "\U000003B4"` } func TestValidUnicodeLiteral(t *testing.T) { - t.Parallel() input := `answer = "δ"` jsonRef := `{ diff --git a/types.go b/types.go new file mode 100644 index 0000000..b5ec165 --- /dev/null +++ b/types.go @@ -0,0 +1,13 @@ +package toml + +import ( + "encoding" + "reflect" + "time" +) + +var timeType = reflect.TypeOf(time.Time{}) +var textMarshalerType = reflect.TypeOf(new(encoding.TextMarshaler)).Elem() +var textUnmarshalerType = reflect.TypeOf(new(encoding.TextUnmarshaler)).Elem() +var mapStringInterfaceType = reflect.TypeOf(map[string]interface{}{}) +var sliceInterfaceType = reflect.TypeOf([]interface{}{}) diff --git a/unmarshaler.go b/unmarshaler.go index fda9438..6dcb5c5 100644 --- a/unmarshaler.go +++ b/unmarshaler.go @@ -6,12 +6,14 @@ import ( "fmt" "io" "io/ioutil" + "math" "reflect" + "strings" + "sync" "time" "github.com/pelletier/go-toml/v2/internal/ast" "github.com/pelletier/go-toml/v2/internal/tracker" - "github.com/pelletier/go-toml/v2/internal/unsafe" ) // Unmarshal deserializes a TOML document into a Go value. @@ -20,9 +22,9 @@ import ( func Unmarshal(data []byte, v interface{}) error { p := parser{} p.Reset(data) - d := decoder{} + d := decoder{p: &p} - return d.FromParser(&p, v) + return d.FromParser(v) } // Decoder reads and decode a TOML document from an input stream. @@ -93,16 +95,34 @@ func (d *Decoder) Decode(v interface{}) error { p := parser{} p.Reset(b) dec := decoder{ + p: &p, strict: strict{ Enabled: d.strict, }, } - return dec.FromParser(&p, v) + return dec.FromParser(v) } type decoder struct { + // Which parser instance in use for this decoding session. + // TODO: Think about removing later. + p *parser + + // Flag indicating that the current expression is stashed. + // If set to true, calling nextExpr will not actually pull a new expression + // but turn off the flag instead. + stashedExpr bool + + // Skip expressions until a table is found. This is set to true when a + // table could not be create (missing field in map), so all KV expressions + // need to be skipped. + skipUntilTable bool + // Tracks position in Go arrays. + // This is used when decoding [[array tables]] into Go arrays. Given array + // tables are separate TOML expression, we need to keep track of where we + // are at in the Go array, as we can't just introspect its size. arrayIndexes map[reflect.Value]int // Tracks keys that have been seen, with which type. @@ -112,6 +132,22 @@ type decoder struct { strict strict } +func (d *decoder) expr() ast.Node { + return d.p.Expression() +} + +func (d *decoder) nextExpr() bool { + if d.stashedExpr { + d.stashedExpr = false + return true + } + return d.p.NextExpression() +} + +func (d *decoder) stashExpr() { + d.stashedExpr = true +} + func (d *decoder) arrayIndex(shouldAppend bool, v reflect.Value) int { if d.arrayIndexes == nil { d.arrayIndexes = make(map[reflect.Value]int, 1) @@ -129,40 +165,7 @@ func (d *decoder) arrayIndex(shouldAppend bool, v reflect.Value) int { return idx } -func (d *decoder) FromParser(p *parser, v interface{}) error { - err := d.fromParser(p, v) - if err == nil { - return d.strict.Error(p.data) - } - - var e *decodeError - if errors.As(err, &e) { - return wrapDecodeError(p.data, e) - } - - return err -} - -func keyLocation(node ast.Node) []byte { - k := node.Key() - - hasOne := k.Next() - if !hasOne { - panic("should not be called with empty key") - } - - start := k.Node().Data - end := k.Node().Data - - for k.Next() { - end = k.Node().Data - } - - return unsafe.BytesRange(start, end) -} - -//nolint:funlen,cyclop -func (d *decoder) fromParser(p *parser, v interface{}) error { +func (d *decoder) FromParser(v interface{}) error { r := reflect.ValueOf(v) if r.Kind() != reflect.Ptr { return fmt.Errorf("toml: decoding can only be performed into a pointer, not %s", r.Kind()) @@ -172,172 +175,361 @@ func (d *decoder) fromParser(p *parser, v interface{}) error { return fmt.Errorf("toml: decoding pointer target cannot be nil") } - var ( - skipUntilTable bool - root target = valueTarget(r.Elem()) - ) + err := d.fromParser(r.Elem()) + if err == nil { + return d.strict.Error(d.p.data) + } - current := root + var e *decodeError + if errors.As(err, &e) { + return wrapDecodeError(d.p.data, e) + } - for p.NextExpression() { - node := p.Expression() + return err +} - if node.Kind == ast.KeyValue && skipUntilTable { - continue - } - - err := d.seen.CheckExpression(node) +func (d *decoder) fromParser(root reflect.Value) error { + for d.nextExpr() { + err := d.handleRootExpression(d.expr(), root) if err != nil { return err } + } - var found bool + return d.p.Error() +} - switch node.Kind { - case ast.KeyValue: - err = d.unmarshalKeyValue(current, node) - found = true - case ast.Table: - skipUntilTable = false - d.strict.EnterTable(node) +/* +Rules for the unmarshal code: - current, found, err = d.scopeWithKey(root, node.Key()) - if err == nil && found { - // In case this table points to an interface, - // make sure it at least holds something that - // looks like a table. Otherwise the information - // of a table is lost, and marshal cannot do the - // round trip. - ensureMapIfInterface(current) +- The stack is used to keep track of which values need to be set where. +- handle* functions <=> switch on a given ast.Kind. +- unmarshalX* functions need to unmarshal a node of kind X. +- An "object" is either a struct or a map. +*/ + +func (d *decoder) handleRootExpression(expr ast.Node, v reflect.Value) error { + var x reflect.Value + var err error + + if !(d.skipUntilTable && expr.Kind == ast.KeyValue) { + err = d.seen.CheckExpression(expr) + if err != nil { + return err + } + } + + switch expr.Kind { + case ast.KeyValue: + if d.skipUntilTable { + return nil + } + x, err = d.handleKeyValue(expr, v) + case ast.Table: + d.skipUntilTable = false + d.strict.EnterTable(expr) + x, err = d.handleTable(expr.Key(), v) + case ast.ArrayTable: + d.skipUntilTable = false + d.strict.EnterArrayTable(expr) + x, err = d.handleArrayTable(expr.Key(), v) + default: + panic(fmt.Errorf("parser should not permit expression of kind %s at document root", expr.Kind)) + } + + if d.skipUntilTable { + if expr.Kind == ast.Table || expr.Kind == ast.ArrayTable { + d.strict.MissingTable(expr) + } + } else if err == nil && x.IsValid() { + v.Set(x) + } + + return err +} + +func (d *decoder) handleArrayTable(key ast.Iterator, v reflect.Value) (reflect.Value, error) { + if key.Next() { + return d.handleArrayTablePart(key, v) + } + return d.handleKeyValues(v) +} + +func (d *decoder) handleArrayTableCollectionLast(key ast.Iterator, v reflect.Value) (reflect.Value, error) { + switch v.Kind() { + case reflect.Interface: + elem := v.Elem() + if !elem.IsValid() { + elem = reflect.New(sliceInterfaceType).Elem() + elem.Set(reflect.MakeSlice(sliceInterfaceType, 0, 16)) + } else if elem.Kind() == reflect.Slice { + if elem.Type() != sliceInterfaceType { + elem = reflect.New(sliceInterfaceType).Elem() + elem.Set(reflect.MakeSlice(sliceInterfaceType, 0, 16)) + } else if !elem.CanSet() { + nelem := reflect.New(sliceInterfaceType).Elem() + nelem.Set(reflect.MakeSlice(sliceInterfaceType, elem.Len(), elem.Cap())) + reflect.Copy(nelem, elem) + elem = nelem } - case ast.ArrayTable: - skipUntilTable = false - d.strict.EnterArrayTable(node) - current, found, err = d.scopeWithArrayTable(root, node.Key()) - default: - panic(fmt.Sprintf("this should not be a top level node type: %s", node.Kind)) + } + return d.handleArrayTableCollectionLast(key, elem) + case reflect.Ptr: + elem := v.Elem() + if !elem.IsValid() { + ptr := reflect.New(v.Type().Elem()) + v.Set(ptr) + elem = ptr.Elem() } + elem, err := d.handleArrayTableCollectionLast(key, elem) if err != nil { - return err + return reflect.Value{}, err } + v.Elem().Set(elem) - if !found { - skipUntilTable = true - - d.strict.MissingTable(node) + return v, nil + case reflect.Slice: + elem := reflect.New(v.Type().Elem()).Elem() + elem2, err := d.handleArrayTable(key, elem) + if err != nil { + return reflect.Value{}, err } + if elem2.IsValid() { + elem = elem2 + } + return reflect.Append(v, elem), nil + case reflect.Array: + idx := d.arrayIndex(true, v) + if idx >= v.Len() { + return v, fmt.Errorf("toml: cannot decode array table into %s at position %d", v.Type(), idx) + } + elem := v.Index(idx) + _, err := d.handleArrayTable(key, elem) + return v, err } - return p.Error() + return d.handleArrayTable(key, v) } -// scopeWithKey performs target scoping when unmarshaling an ast.KeyValue node. -// -// The goal is to hop from target to target recursively using the names in key. -// Parts of the key should be used to resolve field names for structs, and as -// keys when targeting maps. -// -// When encountering slices, it should always use its last element, and error -// if the slice does not have any. -func (d *decoder) scopeWithKey(x target, key ast.Iterator) (target, bool, error) { - var ( - err error - found bool - ) - - for key.Next() { - n := key.Node() - - x, found, err = d.scopeTableTarget(false, x, string(n.Data)) - if err != nil || !found { - return nil, found, err - } - } - - return x, true, nil -} - -//nolint:cyclop -// scopeWithArrayTable performs target scoping when unmarshaling an -// ast.ArrayTable node. -// -// It is the same as scopeWithKey, but when scoping the last part of the key -// it creates a new element in the array instead of using the last one. -func (d *decoder) scopeWithArrayTable(x target, key ast.Iterator) (target, bool, error) { - var ( - err error - found bool - ) - - for key.Next() { - n := key.Node() - if !n.Next().Valid() { // want to stop at one before last - break - } - - x, found, err = d.scopeTableTarget(false, x, string(n.Data)) - if err != nil || !found { - return nil, found, err - } - } - - n := key.Node() - - x, found, err = d.scopeTableTarget(false, x, string(n.Data)) - if err != nil || !found { - return x, found, err - } - - v := x.get() - - if v.Kind() == reflect.Ptr { - x = scopePtr(x) - v = x.get() - } - - if v.Kind() == reflect.Interface { - x = scopeInterface(true, x) - v = x.get() +// When parsing an array table expression, each part of the key needs to be +// evaluated like a normal key, but if it returns a collection, it also needs to +// point to the last element of the collection. Unless it is the last part of +// the key, then it needs to create a new element at the end. +func (d *decoder) handleArrayTableCollection(key ast.Iterator, v reflect.Value) (reflect.Value, error) { + if key.IsLast() { + return d.handleArrayTableCollectionLast(key, v) } switch v.Kind() { + case reflect.Ptr: + elem := v.Elem() + if !elem.IsValid() { + ptr := reflect.New(v.Type().Elem()) + v.Set(ptr) + elem = ptr.Elem() + } + + elem, err := d.handleArrayTableCollection(key, elem) + if err != nil { + return reflect.Value{}, err + } + v.Elem().Set(elem) + + return v, nil case reflect.Slice: - x = scopeSlice(true, x) + elem := v.Index(v.Len() - 1) + x, err := d.handleArrayTable(key, elem) + if err != nil || d.skipUntilTable { + return reflect.Value{}, err + } + if x.IsValid() { + elem.Set(x) + } + + return v, err case reflect.Array: - x, err = d.scopeArray(true, x) + idx := d.arrayIndex(false, v) + if idx >= v.Len() { + return v, fmt.Errorf("toml: cannot decode array table into %s at position %d", v.Type(), idx) + } + elem := v.Index(idx) + _, err := d.handleArrayTable(key, elem) + return v, err + } + + return d.handleArrayTable(key, v) +} + +func (d *decoder) handleKeyPart(key ast.Iterator, v reflect.Value, nextFn handlerFn, makeFn valueMakerFn) (reflect.Value, error) { + var rv reflect.Value + + // First, dispatch over v to make sure it is a valid object. + // There is no guarantee over what it could be. + switch v.Kind() { + case reflect.Map: + // Create the key for the map element. For now assume it's a string. + mk := reflect.ValueOf(string(key.Node().Data)) + + // If the map does not exist, create it. + if v.IsNil() { + v = reflect.MakeMap(v.Type()) + rv = v + } + + mv := v.MapIndex(mk) + set := false + if !mv.IsValid() { + // If there is no value in the map, create a new one according to + // the map type. If the element type is interface, create either a + // map[string]interface{} or a []interface{} depending on whether + // this is the last part of the array table key. + + t := v.Type().Elem() + if t.Kind() == reflect.Interface { + mv = makeFn() + } else { + mv = reflect.New(t).Elem() + } + set = true + } else if mv.Kind() == reflect.Interface { + mv = mv.Elem() + if !mv.IsValid() { + mv = makeFn() + } + set = true + } + + x, err := nextFn(key, mv) + if err != nil { + return reflect.Value{}, err + } + + if x.IsValid() { + mv = x + set = true + } + + if set { + v.SetMapIndex(mk, mv) + } + case reflect.Struct: + f, found := structField(v, string(key.Node().Data)) + if !found { + d.skipUntilTable = true + return reflect.Value{}, nil + } + + x, err := nextFn(key, f) + if err != nil || d.skipUntilTable { + return reflect.Value{}, err + } + if x.IsValid() { + f.Set(x) + } + case reflect.Interface: + if v.Elem().IsValid() { + v = v.Elem() + } else { + v = reflect.MakeMap(mapStringInterfaceType) + } + + x, err := d.handleKeyPart(key, v, nextFn, makeFn) + if err != nil { + return reflect.Value{}, err + } + if x.IsValid() { + v = x + } + rv = v default: + panic(fmt.Errorf("unhandled part: %s", v.Kind())) } - return x, found, err + return rv, nil } -func (d *decoder) unmarshalKeyValue(x target, node ast.Node) error { - assertNode(ast.KeyValue, node) - - d.strict.EnterKeyValue(node) - defer d.strict.ExitKeyValue(node) - - x, found, err := d.scopeWithKey(x, node.Key()) - if err != nil { - return err +// HandleArrayTablePart navigates the Go structure v using the key v. It is +// only used for the prefix (non-last) parts of an array-table. When +// encountering a collection, it should go to the last element. +func (d *decoder) handleArrayTablePart(key ast.Iterator, v reflect.Value) (reflect.Value, error) { + var makeFn valueMakerFn + if key.IsLast() { + makeFn = makeSliceInterface + } else { + makeFn = makeMapStringInterface } - - // A struct in the path was not found. Skip this value. - if !found { - d.strict.MissingField(node) - - return nil - } - - return d.unmarshalValue(x, node.Value()) + return d.handleKeyPart(key, v, d.handleArrayTableCollection, makeFn) } -var textUnmarshalerType = reflect.TypeOf(new(encoding.TextUnmarshaler)).Elem() +// HandleTable returns a reference when it has checked the next expression but +// cannot handle it. +func (d *decoder) handleTable(key ast.Iterator, v reflect.Value) (reflect.Value, error) { + if v.Kind() == reflect.Slice { + elem := v.Index(v.Len() - 1) + x, err := d.handleTable(key, elem) + if err != nil { + return reflect.Value{}, err + } + if x.IsValid() { + elem.Set(x) + } + return reflect.Value{}, nil + } + if key.Next() { + // Still scoping the key + return d.handleTablePart(key, v) + } + // Done scoping the key. + // Now handle all the key-value expressions in this table. + return d.handleKeyValues(v) +} -func tryTextUnmarshaler(x target, node ast.Node) (bool, error) { - v := x.get() +// Handle root expressions until the end of the document or the next +// non-key-value. +func (d *decoder) handleKeyValues(v reflect.Value) (reflect.Value, error) { + var rv reflect.Value + for d.nextExpr() { + expr := d.expr() + if expr.Kind != ast.KeyValue { + // Stash the expression so that fromParser can just loop and use + // the right handler. + // We could just recurse ourselves here, but at least this gives a + // chance to pop the stack a bit. + d.stashExpr() + break + } + x, err := d.handleKeyValue(expr, v) + if err != nil { + return reflect.Value{}, err + } + if x.IsValid() { + v = x + rv = x + } + } + return rv, nil +} + +type ( + handlerFn func(key ast.Iterator, v reflect.Value) (reflect.Value, error) + valueMakerFn func() reflect.Value +) + +func makeMapStringInterface() reflect.Value { + return reflect.MakeMap(mapStringInterfaceType) +} + +func makeSliceInterface() reflect.Value { + return reflect.MakeSlice(sliceInterfaceType, 0, 16) +} + +func (d *decoder) handleTablePart(key ast.Iterator, v reflect.Value) (reflect.Value, error) { + return d.handleKeyPart(key, v, d.handleTable, makeMapStringInterface) +} + +func tryTextUnmarshaler(node ast.Node, v reflect.Value) (bool, error) { if v.Kind() != reflect.Struct { return false, nil } @@ -348,19 +540,12 @@ func tryTextUnmarshaler(x target, node ast.Node) (bool, error) { return false, nil } - if v.Type().Implements(textUnmarshalerType) { - err := v.Interface().(encoding.TextUnmarshaler).UnmarshalText(node.Data) - if err != nil { - return false, newDecodeError(node.Data, "error calling UnmarshalText: %w", err) - } - - return true, nil - } - if v.CanAddr() && v.Addr().Type().Implements(textUnmarshalerType) { err := v.Addr().Interface().(encoding.TextUnmarshaler).UnmarshalText(node.Data) if err != nil { - return false, newDecodeError(node.Data, "error calling UnmarshalText: %w", err) + return false, fmt.Errorf("toml: error calling UnmarshalText: %w", err) + // TODO: same as above + // return false, newDecodeError(node.Data, "error calling UnmarshalText: %w", err) } return true, nil @@ -369,65 +554,177 @@ func tryTextUnmarshaler(x target, node ast.Node) (bool, error) { return false, nil } -//nolint:cyclop -func (d *decoder) unmarshalValue(x target, node ast.Node) error { - v := x.get() - - if v.Kind() == reflect.Ptr { - if !v.Elem().IsValid() { - x.set(reflect.New(v.Type().Elem())) - v = x.get() - } - - return d.unmarshalValue(valueTarget(v.Elem()), node) +func (d *decoder) handleValue(value ast.Node, v reflect.Value) error { + for v.Kind() == reflect.Ptr { + v = initAndDereferencePointer(v) } - ok, err := tryTextUnmarshaler(x, node) - if ok { + ok, err := tryTextUnmarshaler(value, v) + if ok || err != nil { return err } - switch node.Kind { + switch value.Kind { case ast.String: - return unmarshalString(x, node) - case ast.Bool: - return unmarshalBool(x, node) + return d.unmarshalString(value, v) case ast.Integer: - return unmarshalInteger(x, node) + return d.unmarshalInteger(value, v) case ast.Float: - return unmarshalFloat(x, node) - case ast.Array: - return d.unmarshalArray(x, node) - case ast.InlineTable: - return d.unmarshalInlineTable(x, node) - case ast.LocalDateTime: - return unmarshalLocalDateTime(x, node) + return d.unmarshalFloat(value, v) + case ast.Bool: + return d.unmarshalBool(value, v) case ast.DateTime: - return unmarshalDateTime(x, node) + return d.unmarshalDateTime(value, v) case ast.LocalDate: - return unmarshalLocalDate(x, node) + return d.unmarshalLocalDate(value, v) + case ast.LocalDateTime: + return d.unmarshalLocalDateTime(value, v) + case ast.InlineTable: + return d.unmarshalInlineTable(value, v) + case ast.Array: + return d.unmarshalArray(value, v) default: - panic(fmt.Sprintf("unhandled node kind %s", node.Kind)) + panic(fmt.Errorf("handleValue not implemented for %s", value.Kind)) } } -func unmarshalLocalDate(x target, node ast.Node) error { - assertNode(ast.LocalDate, node) - - v, err := parseLocalDate(node.Data) - if err != nil { - return err +func (d *decoder) unmarshalArray(array ast.Node, v reflect.Value) error { + switch v.Kind() { + case reflect.Slice: + if v.IsNil() { + v.Set(reflect.MakeSlice(v.Type(), 0, 16)) + } else { + v.SetLen(0) + } + case reflect.Array: + // arrays are always initialized + case reflect.Interface: + elem := v.Elem() + if !elem.IsValid() { + elem = reflect.New(sliceInterfaceType).Elem() + elem.Set(reflect.MakeSlice(sliceInterfaceType, 0, 16)) + } else if elem.Kind() == reflect.Slice { + if elem.Type() != sliceInterfaceType { + elem = reflect.New(sliceInterfaceType).Elem() + elem.Set(reflect.MakeSlice(sliceInterfaceType, 0, 16)) + } else if !elem.CanSet() { + nelem := reflect.New(sliceInterfaceType).Elem() + nelem.Set(reflect.MakeSlice(sliceInterfaceType, elem.Len(), elem.Cap())) + reflect.Copy(nelem, elem) + elem = nelem + } + } + err := d.unmarshalArray(array, elem) + if err != nil { + return err + } + v.Set(elem) + return nil + default: + // TODO: use newDecodeError, but first the parser needs to fill + // array.Data. + return fmt.Errorf("toml: cannot store array in Go type %s", v.Kind()) } - setDate(x, v) + elemType := v.Type().Elem() + + it := array.Children() + idx := 0 + for it.Next() { + n := it.Node() + + // TODO: optimize + if v.Kind() == reflect.Slice { + elem := reflect.New(elemType).Elem() + + err := d.handleValue(n, elem) + if err != nil { + return err + } + + v.Set(reflect.Append(v, elem)) + } else { // array + if idx >= v.Len() { + return nil + } + elem := v.Index(idx) + err := d.handleValue(n, elem) + if err != nil { + return err + } + idx++ + } + } return nil } -func unmarshalLocalDateTime(x target, node ast.Node) error { - assertNode(ast.LocalDateTime, node) +func (d *decoder) unmarshalInlineTable(itable ast.Node, v reflect.Value) error { + // Make sure v is an initialized object. + switch v.Kind() { + case reflect.Map: + if v.IsNil() { + v.Set(reflect.MakeMap(v.Type())) + } + case reflect.Struct: + // structs are always initialized. + case reflect.Interface: + elem := v.Elem() + if !elem.IsValid() { + elem = reflect.MakeMap(mapStringInterfaceType) + v.Set(elem) + } + return d.unmarshalInlineTable(itable, elem) + default: + return newDecodeError(itable.Data, "cannot store inline table in Go type %s", v.Kind()) + } - v, rest, err := parseLocalDateTime(node.Data) + it := itable.Children() + for it.Next() { + n := it.Node() + + x, err := d.handleKeyValue(n, v) + if err != nil { + return err + } + if x.IsValid() { + v = x + } + } + + return nil +} + +func (d *decoder) unmarshalDateTime(value ast.Node, v reflect.Value) error { + dt, err := parseDateTime(value.Data) + if err != nil { + return err + } + + v.Set(reflect.ValueOf(dt)) + return nil +} + +func (d *decoder) unmarshalLocalDate(value ast.Node, v reflect.Value) error { + ld, err := parseLocalDate(value.Data) + if err != nil { + return err + } + + if v.Type() == timeType { + cast := ld.In(time.Local) + + v.Set(reflect.ValueOf(cast)) + return nil + } + + v.Set(reflect.ValueOf(ld)) + + return nil +} + +func (d *decoder) unmarshalLocalDateTime(value ast.Node, v reflect.Value) error { + ldt, rest, err := parseLocalDateTime(value.Data) if err != nil { return err } @@ -436,160 +733,366 @@ func unmarshalLocalDateTime(x target, node ast.Node) error { return newDecodeError(rest, "extra characters at the end of a local date time") } - setLocalDateTime(x, v) + if v.Type() == timeType { + cast := ldt.In(time.Local) + + v.Set(reflect.ValueOf(cast)) + return nil + } + + v.Set(reflect.ValueOf(ldt)) return nil } -func unmarshalDateTime(x target, node ast.Node) error { - assertNode(ast.DateTime, node) +func (d *decoder) unmarshalBool(value ast.Node, v reflect.Value) error { + b := value.Data[0] == 't' - v, err := parseDateTime(node.Data) - if err != nil { - return err + switch v.Kind() { + case reflect.Bool: + v.SetBool(b) + case reflect.Interface: + v.Set(reflect.ValueOf(b)) + default: + return newDecodeError(value.Data, "cannot assign boolean to a %t", b) } - setDateTime(x, v) - return nil } -func setLocalDateTime(x target, v LocalDateTime) { - if x.get().Type() == timeType { - cast := v.In(time.Local) - - setDateTime(x, cast) - return - } - - x.set(reflect.ValueOf(v)) -} - -func setDateTime(x target, v time.Time) { - x.set(reflect.ValueOf(v)) -} - -var timeType = reflect.TypeOf(time.Time{}) - -func setDate(x target, v LocalDate) { - if x.get().Type() == timeType { - cast := v.In(time.Local) - - setDateTime(x, cast) - return - } - - x.set(reflect.ValueOf(v)) -} - -func unmarshalString(x target, node ast.Node) error { - assertNode(ast.String, node) - - return setString(x, string(node.Data)) -} - -func unmarshalBool(x target, node ast.Node) error { - assertNode(ast.Bool, node) - v := node.Data[0] == 't' - - return setBool(x, v) -} - -func unmarshalInteger(x target, node ast.Node) error { - assertNode(ast.Integer, node) - - v, err := parseInteger(node.Data) +func (d *decoder) unmarshalFloat(value ast.Node, v reflect.Value) error { + f, err := parseFloat(value.Data) if err != nil { return err } - return setInt64(x, v) -} - -func unmarshalFloat(x target, node ast.Node) error { - assertNode(ast.Float, node) - - v, err := parseFloat(node.Data) - if err != nil { - return err - } - - return setFloat64(x, v) -} - -func (d *decoder) unmarshalInlineTable(x target, node ast.Node) error { - assertNode(ast.InlineTable, node) - - ensureMapIfInterface(x) - - it := node.Children() - for it.Next() { - n := it.Node() - - err := d.unmarshalKeyValue(x, n) - if err != nil { - return err + switch v.Kind() { + case reflect.Float64: + v.SetFloat(f) + case reflect.Float32: + if f > math.MaxFloat32 { + return newDecodeError(value.Data, "number %f does not fit in a float32", f) } + v.SetFloat(f) + case reflect.Interface: + v.Set(reflect.ValueOf(f)) + default: + return newDecodeError(value.Data, "float cannot be assigned to %s", v.Kind()) } return nil } -func (d *decoder) unmarshalArray(x target, node ast.Node) error { - assertNode(ast.Array, node) +func (d *decoder) unmarshalInteger(value ast.Node, v reflect.Value) error { + const ( + maxInt = int64(^uint(0) >> 1) + minInt = -maxInt - 1 + ) - err := ensureValueIndexable(x) + i, err := parseInteger(value.Data) if err != nil { return err } - // Special work around when unmarshaling into an array. - // If the array is not addressable, for example when stored as a value in a - // map, calling elementAt in the inner function would fail. - // Instead, we allocate a new array that will be filled then inserted into - // the container. - // This problem does not exist with slices because they are addressable. - // There may be a better way of doing this, but it is not obvious to me - // with the target system. - if x.get().Kind() == reflect.Array { - container := x - newArrayPtr := reflect.New(x.get().Type()) - x = valueTarget(newArrayPtr.Elem()) - defer func() { - container.set(newArrayPtr.Elem()) - }() + switch v.Kind() { + case reflect.Int64: + v.SetInt(i) + case reflect.Int32: + if i < math.MinInt32 || i > math.MaxInt32 { + return fmt.Errorf("toml: number %d does not fit in an int32", i) + } + + v.Set(reflect.ValueOf(int32(i))) + return nil + case reflect.Int16: + if i < math.MinInt16 || i > math.MaxInt16 { + return fmt.Errorf("toml: number %d does not fit in an int16", i) + } + + v.Set(reflect.ValueOf(int16(i))) + case reflect.Int8: + if i < math.MinInt8 || i > math.MaxInt8 { + return fmt.Errorf("toml: number %d does not fit in an int8", i) + } + + v.Set(reflect.ValueOf(int8(i))) + case reflect.Int: + if i < minInt || i > maxInt { + return fmt.Errorf("toml: number %d does not fit in an int", i) + } + + v.Set(reflect.ValueOf(int(i))) + case reflect.Uint64: + if i < 0 { + return fmt.Errorf("toml: negative number %d does not fit in an uint64", i) + } + + v.Set(reflect.ValueOf(uint64(i))) + case reflect.Uint32: + if i < 0 || i > math.MaxUint32 { + return fmt.Errorf("toml: negative number %d does not fit in an uint32", i) + } + + v.Set(reflect.ValueOf(uint32(i))) + case reflect.Uint16: + if i < 0 || i > math.MaxUint16 { + return fmt.Errorf("toml: negative number %d does not fit in an uint16", i) + } + + v.Set(reflect.ValueOf(uint16(i))) + case reflect.Uint8: + if i < 0 || i > math.MaxUint8 { + return fmt.Errorf("toml: negative number %d does not fit in an uint8", i) + } + + v.Set(reflect.ValueOf(uint8(i))) + case reflect.Uint: + if i < 0 { + return fmt.Errorf("toml: negative number %d does not fit in an uint", i) + } + + v.Set(reflect.ValueOf(uint(i))) + case reflect.Interface: + v.Set(reflect.ValueOf(i)) + default: + err = fmt.Errorf("toml: cannot store TOML integer into a Go %s", v.Kind()) } - return d.unmarshalArrayInner(x, node) + return err } -func (d *decoder) unmarshalArrayInner(x target, node ast.Node) error { - idx := 0 +func (d *decoder) unmarshalString(value ast.Node, v reflect.Value) error { + var err error - it := node.Children() - for it.Next() { - n := it.Node() + switch v.Kind() { + case reflect.String: + v.SetString(string(value.Data)) + case reflect.Interface: + v.Set(reflect.ValueOf(string(value.Data))) + default: + err = fmt.Errorf("toml: cannot store TOML string into a Go %s", v.Kind()) + } - v := elementAt(x, idx) + return err +} - if v == nil { - // when we go out of bound for an array just stop processing it to - // mimic encoding/json +func (d *decoder) handleKeyValue(expr ast.Node, v reflect.Value) (reflect.Value, error) { + d.strict.EnterKeyValue(expr) + + v, err := d.handleKeyValueInner(expr.Key(), expr.Value(), v) + if d.skipUntilTable { + d.strict.MissingField(expr) + d.skipUntilTable = false + } + + d.strict.ExitKeyValue(expr) + + return v, err +} + +func (d *decoder) handleKeyValueInner(key ast.Iterator, value ast.Node, v reflect.Value) (reflect.Value, error) { + if key.Next() { + // Still scoping the key + return d.handleKeyValuePart(key, value, v) + } + // Done scoping the key. + // v is whatever Go value we need to fill. + return reflect.Value{}, d.handleValue(value, v) +} + +func (d *decoder) handleKeyValuePart(key ast.Iterator, value ast.Node, v reflect.Value) (reflect.Value, error) { + // contains the replacement for v + var rv reflect.Value + + // First, dispatch over v to make sure it is a valid object. + // There is no guarantee over what it could be. + switch v.Kind() { + case reflect.Map: + mk := reflect.ValueOf(string(key.Node().Data)) + + keyType := v.Type().Key() + if !mk.Type().AssignableTo(keyType) { + if !mk.Type().ConvertibleTo(keyType) { + return reflect.Value{}, fmt.Errorf("toml: cannot convert map key of type %s to expected type %s", mk.Type(), keyType) + } + + mk = mk.Convert(keyType) + } + + // If the map does not exist, create it. + if v.IsNil() { + v = reflect.MakeMap(v.Type()) + rv = v + } + + mv := v.MapIndex(mk) + set := false + if !mv.IsValid() { + set = true + mv = reflect.New(v.Type().Elem()).Elem() + } else { + if key.IsLast() { + var x interface{} + mv = reflect.ValueOf(&x).Elem() + set = true + } + } + + nv, err := d.handleKeyValueInner(key, value, mv) + if err != nil { + return reflect.Value{}, err + } + if nv.IsValid() { + mv = nv + set = true + } + + if set { + v.SetMapIndex(mk, mv) + } + case reflect.Struct: + f, found := structField(v, string(key.Node().Data)) + if !found { + d.skipUntilTable = true break } - err := d.unmarshalValue(v, n) + x, err := d.handleKeyValueInner(key, value, f) if err != nil { - return err + return reflect.Value{}, err } - idx++ + if x.IsValid() { + f.Set(x) + } + case reflect.Interface: + v = v.Elem() + + // Following encoding/toml: decoding an object into an interface{}, it + // needs to always hold a map[string]interface{}. This is for the types + // to be consistent whether a previous value was set or not. + if !v.IsValid() || v.Type() != mapStringInterfaceType { + v = reflect.MakeMap(mapStringInterfaceType) + } + + x, err := d.handleKeyValuePart(key, value, v) + if err != nil { + return reflect.Value{}, err + } + if x.IsValid() { + v = x + } + rv = v + case reflect.Ptr: + elem := v.Elem() + if !elem.IsValid() { + ptr := reflect.New(v.Type().Elem()) + v.Set(ptr) + rv = v + elem = ptr.Elem() + } + + elem2, err := d.handleKeyValuePart(key, value, elem) + if err != nil { + return reflect.Value{}, err + } + if elem2.IsValid() { + elem = elem2 + } + v.Elem().Set(elem) + default: + return reflect.Value{}, fmt.Errorf("unhandled kv part: %s", v.Kind()) } - return nil + + return rv, nil } -func assertNode(expected ast.Kind, node ast.Node) { - if node.Kind != expected { - panic(fmt.Sprintf("expected node of kind %s, not %s", expected, node.Kind)) +func initAndDereferencePointer(v reflect.Value) reflect.Value { + var elem reflect.Value + if v.IsNil() { + ptr := reflect.New(v.Type().Elem()) + v.Set(ptr) } + elem = v.Elem() + return elem +} + +type fieldPathsMap = map[string][]int + +type fieldPathsCache struct { + m map[reflect.Type]fieldPathsMap + l sync.RWMutex +} + +func (c *fieldPathsCache) get(t reflect.Type) (fieldPathsMap, bool) { + c.l.RLock() + paths, ok := c.m[t] + c.l.RUnlock() + + return paths, ok +} + +func (c *fieldPathsCache) set(t reflect.Type, m fieldPathsMap) { + c.l.Lock() + c.m[t] = m + c.l.Unlock() +} + +var globalFieldPathsCache = fieldPathsCache{ + m: map[reflect.Type]fieldPathsMap{}, + l: sync.RWMutex{}, +} + +func structField(v reflect.Value, name string) (reflect.Value, bool) { + //nolint:godox + // TODO: cache this, and reduce allocations + fieldPaths, ok := globalFieldPathsCache.get(v.Type()) + if !ok { + fieldPaths = map[string][]int{} + + path := make([]int, 0, 16) + + var walk func(reflect.Value) + walk = func(v reflect.Value) { + t := v.Type() + for i := 0; i < t.NumField(); i++ { + l := len(path) + path = append(path, i) + f := t.Field(i) + + if f.Anonymous { + walk(v.Field(i)) + } else if f.PkgPath == "" { + // only consider exported fields + fieldName, ok := f.Tag.Lookup("toml") + if !ok { + fieldName = f.Name + } + + pathCopy := make([]int, len(path)) + copy(pathCopy, path) + + fieldPaths[fieldName] = pathCopy + // extra copy for the case-insensitive match + fieldPaths[strings.ToLower(fieldName)] = pathCopy + } + path = path[:l] + } + } + + walk(v) + + globalFieldPathsCache.set(v.Type(), fieldPaths) + } + + path, ok := fieldPaths[name] + if !ok { + path, ok = fieldPaths[strings.ToLower(name)] + } + + if !ok { + return reflect.Value{}, false + } + + return v.FieldByIndex(path), true } diff --git a/unmarshaler_test.go b/unmarshaler_test.go index d9203a6..f3deb23 100644 --- a/unmarshaler_test.go +++ b/unmarshaler_test.go @@ -14,10 +14,23 @@ import ( "github.com/stretchr/testify/require" ) +type badReader struct{} + +func (r *badReader) Read([]byte) (int, error) { + return 0, fmt.Errorf("testing error") +} + +func TestDecodeReaderError(t *testing.T) { + r := &badReader{} + + dec := toml.NewDecoder(r) + m := map[string]interface{}{} + err := dec.Decode(&m) + require.Error(t, err) +} + // nolint:funlen func TestUnmarshal_Integers(t *testing.T) { - t.Parallel() - examples := []struct { desc string input string @@ -88,8 +101,6 @@ func TestUnmarshal_Integers(t *testing.T) { for _, e := range examples { e := e t.Run(e.desc, func(t *testing.T) { - t.Parallel() - doc := doc{} err := toml.Unmarshal([]byte(`A = `+e.input), &doc) if e.err { @@ -104,8 +115,6 @@ func TestUnmarshal_Integers(t *testing.T) { //nolint:funlen func TestUnmarshal_Floats(t *testing.T) { - t.Parallel() - examples := []struct { desc string input string @@ -197,8 +206,6 @@ func TestUnmarshal_Floats(t *testing.T) { for _, e := range examples { e := e t.Run(e.desc, func(t *testing.T) { - t.Parallel() - doc := doc{} err := toml.Unmarshal([]byte(`A = `+e.input), &doc) require.NoError(t, err) @@ -213,8 +220,6 @@ func TestUnmarshal_Floats(t *testing.T) { //nolint:funlen func TestUnmarshal(t *testing.T) { - t.Parallel() - type test struct { target interface{} expected interface{} @@ -814,6 +819,240 @@ B = "data"`, } }, }, + { + desc: "array table into interface in struct", + input: `[[foo]] + bar = "hello"`, + gen: func() test { + type doc struct { + Foo interface{} + } + return test{ + target: &doc{}, + expected: &doc{ + Foo: []interface{}{ + map[string]interface{}{ + "bar": "hello", + }, + }, + }, + } + }, + }, + { + desc: "array table into interface in struct already initialized with right type", + input: `[[foo]] + bar = "hello"`, + gen: func() test { + type doc struct { + Foo interface{} + } + return test{ + target: &doc{ + Foo: []interface{}{}, + }, + expected: &doc{ + Foo: []interface{}{ + map[string]interface{}{ + "bar": "hello", + }, + }, + }, + } + }, + }, + { + desc: "array table into interface in struct already initialized with wrong type", + input: `[[foo]] + bar = "hello"`, + gen: func() test { + type doc struct { + Foo interface{} + } + return test{ + target: &doc{ + Foo: []string{}, + }, + expected: &doc{ + Foo: []interface{}{ + map[string]interface{}{ + "bar": "hello", + }, + }, + }, + } + }, + }, + { + desc: "array table into nil ptr", + input: `[[foo]] + bar = "hello"`, + gen: func() test { + type doc struct { + Foo *[]interface{} + } + return test{ + target: &doc{}, + expected: &doc{ + Foo: &[]interface{}{ + map[string]interface{}{ + "bar": "hello", + }, + }, + }, + } + }, + }, + { + desc: "array table into nil ptr of invalid type", + input: `[[foo]] + bar = "hello"`, + gen: func() test { + type doc struct { + Foo *string + } + return test{ + target: &doc{}, + err: true, + } + }, + }, + { + desc: "array table with intermediate ptr", + input: `[[foo.bar]] + bar = "hello"`, + gen: func() test { + type doc struct { + Foo *map[string]interface{} + } + return test{ + target: &doc{}, + expected: &doc{ + Foo: &map[string]interface{}{ + "bar": []interface{}{ + map[string]interface{}{ + "bar": "hello", + }, + }, + }, + }, + } + }, + }, + { + desc: "unmarshal array into interface that contains a slice", + input: `a = [1,2,3]`, + gen: func() test { + type doc struct { + A interface{} + } + return test{ + target: &doc{ + A: []string{}, + }, + expected: &doc{ + A: []interface{}{ + int64(1), + int64(2), + int64(3), + }, + }, + } + }, + }, + { + desc: "unmarshal array into interface that contains a []interface{}", + input: `a = [1,2,3]`, + gen: func() test { + type doc struct { + A interface{} + } + return test{ + target: &doc{ + A: []interface{}{}, + }, + expected: &doc{ + A: []interface{}{ + int64(1), + int64(2), + int64(3), + }, + }, + } + }, + }, + { + desc: "unmarshal key into map with existing value", + input: `a = "new"`, + gen: func() test { + return test{ + target: &map[string]interface{}{"a": "old"}, + expected: &map[string]interface{}{"a": "new"}, + } + }, + }, + { + desc: "unmarshal key into map with existing value", + input: `a.b = "new"`, + gen: func() test { + type doc struct { + A interface{} + } + return test{ + target: &doc{}, + expected: &doc{ + A: map[string]interface{}{ + "b": "new", + }, + }, + } + }, + }, + { + desc: "unmarshal array into struct field with existing array", + input: `a = [1,2]`, + gen: func() test { + type doc struct { + A []int + } + return test{ + target: &doc{}, + expected: &doc{ + A: []int{1, 2}, + }, + } + }, + }, + { + desc: "unmarshal inline table into map", + input: `a = {b="hello"}`, + gen: func() test { + type doc struct { + A map[string]interface{} + } + return test{ + target: &doc{}, + expected: &doc{ + A: map[string]interface{}{ + "b": "hello", + }, + }, + } + }, + }, + { + desc: "unmarshal inline table into map of incorrect type", + input: `a = {b="hello"}`, + gen: func() test { + type doc struct { + A map[string]int + } + return test{ + target: &doc{}, + err: true, + } + }, + }, { desc: "slice pointer in slice pointer", input: `A = ["Hello"]`, @@ -1155,8 +1394,6 @@ B = "data"`, for _, e := range examples { e := e t.Run(e.desc, func(t *testing.T) { - t.Parallel() - if e.skip { t.Skip() } @@ -1241,6 +1478,16 @@ func TestUnmarshalOverflows(t *testing.T) { } } +func TestUnmarshalInvalidTarget(t *testing.T) { + x := "foo" + err := toml.Unmarshal([]byte{}, x) + require.Error(t, err) + + var m *map[string]interface{} + err = toml.Unmarshal([]byte{}, m) + require.Error(t, err) +} + func TestUnmarshalFloat32(t *testing.T) { t.Run("fits", func(t *testing.T) { doc := "A = 1.2" @@ -1277,8 +1524,6 @@ type Config484 struct { } func TestIssue484(t *testing.T) { - t.Parallel() - raw := []byte(`integers = ["1","2","3","100"]`) var cfg Config484 @@ -1299,8 +1544,6 @@ func (m Map458) A(s string) Slice458 { } func TestIssue458(t *testing.T) { - t.Parallel() - s := []byte(`[[package]] dependencies = ["regex"] name = "decode" @@ -1320,8 +1563,6 @@ version = "0.1.0"`) } func TestIssue252(t *testing.T) { - t.Parallel() - type config struct { Val1 string `toml:"val1"` Val2 string `toml:"val2"` @@ -1342,8 +1583,6 @@ val1 = "test1" } func TestIssue494(t *testing.T) { - t.Parallel() - data := ` foo = 2021-04-08 bar = 2021-04-08 @@ -1359,8 +1598,6 @@ bar = 2021-04-08 } func TestIssue507(t *testing.T) { - t.Parallel() - data := []byte{'0', '=', '\n', '0', 'a', 'm', 'e'} m := map[string]interface{}{} err := toml.Unmarshal(data, &m) @@ -1369,8 +1606,6 @@ func TestIssue507(t *testing.T) { //nolint:funlen func TestUnmarshalDecodeErrors(t *testing.T) { - t.Parallel() - examples := []struct { desc string data string @@ -1603,8 +1838,6 @@ world'`, for _, e := range examples { e := e t.Run(e.desc, func(t *testing.T) { - t.Parallel() - m := map[string]interface{}{} err := toml.Unmarshal([]byte(e.data), &m) require.Error(t, err) @@ -1624,8 +1857,6 @@ world'`, //nolint:funlen func TestLocalDateTime(t *testing.T) { - t.Parallel() - examples := []struct { desc string input string @@ -1675,7 +1906,6 @@ func TestLocalDateTime(t *testing.T) { for _, e := range examples { e := e t.Run(e.desc, func(t *testing.T) { - t.Parallel() t.Log("input:", e.input) doc := `a = ` + e.input m := map[string]toml.LocalDateTime{} @@ -1691,8 +1921,6 @@ func TestLocalDateTime(t *testing.T) { } func TestIssue287(t *testing.T) { - t.Parallel() - b := `y=[[{}]]` v := map[string]interface{}{} err := toml.Unmarshal([]byte(b), &v) @@ -1709,8 +1937,6 @@ func TestIssue287(t *testing.T) { } func TestIssue508(t *testing.T) { - t.Parallel() - type head struct { Title string `toml:"title"` } @@ -1729,8 +1955,6 @@ func TestIssue508(t *testing.T) { //nolint:funlen func TestDecoderStrict(t *testing.T) { - t.Parallel() - examples := []struct { desc string input string @@ -1801,8 +2025,6 @@ bar = 42 for _, e := range examples { e := e t.Run(e.desc, func(t *testing.T) { - t.Parallel() - t.Run("strict", func(t *testing.T) { r := strings.NewReader(e.input) d := toml.NewDecoder(r)