Compare commits

...

64 Commits

Author SHA1 Message Date
Thomas Pelletier dc72d75f3e Keep separate fn for []interface{} unmarshal 2021-11-13 19:20:20 -05:00
Thomas Pelletier f77775b59e Use less reflection when making slices
```
name                               old time/op    new time/op    delta
UnmarshalDataset/config-2            24.9ms ± 0%    24.6ms ± 0%  -1.09%  (p=0.029 n=4+4)
UnmarshalDataset/canada-2            61.7ms ± 1%    62.1ms ± 3%    ~     (p=1.000 n=5+5)
UnmarshalDataset/citm_catalog-2      24.7ms ± 1%    24.2ms ± 0%  -2.30%  (p=0.008 n=5+5)
UnmarshalDataset/twitter-2           10.9ms ± 2%    10.7ms ± 1%  -1.46%  (p=0.008 n=5+5)
UnmarshalDataset/code-2               108ms ± 0%     106ms ± 0%  -1.91%  (p=0.008 n=5+5)
UnmarshalDataset/example-2            176µs ± 0%     173µs ± 0%  -1.83%  (p=0.008 n=5+5)
Unmarshal/SimpleDocument/struct-2     586ns ± 1%     587ns ± 0%    ~     (p=0.690 n=5+5)
Unmarshal/SimpleDocument/map-2        876ns ± 0%     872ns ± 0%    ~     (p=0.095 n=5+5)
Unmarshal/ReferenceFile/struct-2     49.5µs ± 0%    49.5µs ± 0%    ~     (p=0.222 n=5+5)
Unmarshal/ReferenceFile/map-2        79.6µs ± 0%    79.1µs ± 0%  -0.62%  (p=0.008 n=5+5)
Unmarshal/HugoFrontMatter-2          13.7µs ± 0%    13.5µs ± 0%  -0.91%  (p=0.008 n=5+5)

name                               old speed      new speed      delta
UnmarshalDataset/config-2          42.2MB/s ± 0%  42.7MB/s ± 0%  +1.10%  (p=0.029 n=4+4)
UnmarshalDataset/canada-2          35.7MB/s ± 1%  35.5MB/s ± 3%    ~     (p=1.000 n=5+5)
UnmarshalDataset/citm_catalog-2    22.6MB/s ± 1%  23.1MB/s ± 0%  +2.36%  (p=0.008 n=5+5)
UnmarshalDataset/twitter-2         40.6MB/s ± 2%  41.2MB/s ± 1%  +1.47%  (p=0.008 n=5+5)
UnmarshalDataset/code-2            24.9MB/s ± 0%  25.4MB/s ± 0%  +1.95%  (p=0.008 n=5+5)
UnmarshalDataset/example-2         46.0MB/s ± 0%  46.9MB/s ± 0%  +1.86%  (p=0.008 n=5+5)
Unmarshal/SimpleDocument/struct-2  18.8MB/s ± 1%  18.7MB/s ± 0%    ~     (p=0.651 n=5+5)
Unmarshal/SimpleDocument/map-2     12.6MB/s ± 0%  12.6MB/s ± 0%    ~     (p=0.087 n=5+5)
Unmarshal/ReferenceFile/struct-2    106MB/s ± 0%   106MB/s ± 0%    ~     (p=0.222 n=5+5)
Unmarshal/ReferenceFile/map-2      65.8MB/s ± 0%  66.2MB/s ± 0%  +0.63%  (p=0.008 n=5+5)
Unmarshal/HugoFrontMatter-2        40.0MB/s ± 0%  40.3MB/s ± 0%  +0.92%  (p=0.008 n=5+5)

name                               old alloc/op   new alloc/op   delta
UnmarshalDataset/config-2            5.85MB ± 0%    5.85MB ± 0%    ~     (p=1.000 n=5+5)
UnmarshalDataset/canada-2            75.2MB ± 0%    75.2MB ± 0%    ~     (p=1.000 n=5+5)
UnmarshalDataset/citm_catalog-2      35.0MB ± 0%    35.0MB ± 0%    ~     (p=0.841 n=5+5)
UnmarshalDataset/twitter-2           13.5MB ± 0%    13.5MB ± 0%    ~     (p=0.548 n=5+5)
UnmarshalDataset/code-2              22.0MB ± 0%    22.0MB ± 0%    ~     (p=0.738 n=5+5)
UnmarshalDataset/example-2            203kB ± 0%     203kB ± 0%    ~     (p=0.714 n=5+5)
Unmarshal/SimpleDocument/struct-2      709B ± 0%      709B ± 0%    ~     (all equal)
Unmarshal/SimpleDocument/map-2       1.08kB ± 0%    1.08kB ± 0%    ~     (all equal)
Unmarshal/ReferenceFile/struct-2     19.7kB ± 0%    19.7kB ± 0%    ~     (all equal)
Unmarshal/ReferenceFile/map-2        37.0kB ± 0%    37.0kB ± 0%    ~     (p=0.333 n=4+5)
Unmarshal/HugoFrontMatter-2          7.22kB ± 0%    7.22kB ± 0%    ~     (all equal)

name                               old allocs/op  new allocs/op  delta
UnmarshalDataset/config-2              230k ± 0%      230k ± 0%    ~     (p=0.556 n=4+5)
UnmarshalDataset/canada-2              391k ± 0%      391k ± 0%    ~     (all equal)
UnmarshalDataset/citm_catalog-2        158k ± 0%      158k ± 0%    ~     (p=1.000 n=4+5)
UnmarshalDataset/twitter-2            54.7k ± 0%     54.7k ± 0%    ~     (p=1.000 n=4+5)
UnmarshalDataset/code-2               1.05M ± 0%     1.05M ± 0%    ~     (all equal)
UnmarshalDataset/example-2            1.28k ± 0%     1.28k ± 0%    ~     (all equal)
Unmarshal/SimpleDocument/struct-2      8.00 ± 0%      8.00 ± 0%    ~     (all equal)
Unmarshal/SimpleDocument/map-2         13.0 ± 0%      13.0 ± 0%    ~     (all equal)
Unmarshal/ReferenceFile/struct-2        123 ± 0%       123 ± 0%    ~     (all equal)
Unmarshal/ReferenceFile/map-2           590 ± 0%       590 ± 0%    ~     (all equal)
Unmarshal/HugoFrontMatter-2             130 ± 0%       130 ± 0%    ~     (all equal)
```
2021-11-13 19:20:20 -05:00
Thomas Pelletier b52f6c9823 Remove some allocs for slices in interfaces
```
name                               old time/op    new time/op    delta
UnmarshalDataset/config-2            24.9ms ± 1%    24.9ms ± 0%     ~     (p=0.413 n=5+4)
UnmarshalDataset/canada-2            66.1ms ± 0%    61.7ms ± 1%   -6.63%  (p=0.008 n=5+5)
UnmarshalDataset/citm_catalog-2      25.3ms ± 5%    24.7ms ± 1%   -2.09%  (p=0.032 n=5+5)
UnmarshalDataset/twitter-2           10.9ms ± 2%    10.9ms ± 2%     ~     (p=1.000 n=5+5)
UnmarshalDataset/code-2               108ms ± 0%     108ms ± 0%     ~     (p=0.095 n=5+5)
UnmarshalDataset/example-2            177µs ± 2%     176µs ± 0%     ~     (p=0.841 n=5+5)
Unmarshal/SimpleDocument/struct-2     579ns ± 0%     586ns ± 1%   +1.30%  (p=0.008 n=5+5)
Unmarshal/SimpleDocument/map-2        875ns ± 1%     876ns ± 0%     ~     (p=0.548 n=5+5)
Unmarshal/ReferenceFile/struct-2     49.7µs ± 1%    49.5µs ± 0%     ~     (p=0.095 n=5+5)
Unmarshal/ReferenceFile/map-2        80.4µs ± 0%    79.6µs ± 0%   -0.99%  (p=0.008 n=5+5)
Unmarshal/HugoFrontMatter-2          13.9µs ± 0%    13.7µs ± 0%   -1.70%  (p=0.008 n=5+5)

name                               old speed      new speed      delta
UnmarshalDataset/config-2          42.1MB/s ± 1%  42.2MB/s ± 0%     ~     (p=0.381 n=5+4)
UnmarshalDataset/canada-2          33.3MB/s ± 0%  35.7MB/s ± 1%   +7.11%  (p=0.008 n=5+5)
UnmarshalDataset/citm_catalog-2    22.1MB/s ± 5%  22.6MB/s ± 1%   +2.08%  (p=0.032 n=5+5)
UnmarshalDataset/twitter-2         40.7MB/s ± 2%  40.6MB/s ± 2%     ~     (p=1.000 n=5+5)
UnmarshalDataset/code-2            24.8MB/s ± 0%  24.9MB/s ± 0%     ~     (p=0.103 n=5+5)
UnmarshalDataset/example-2         45.8MB/s ± 2%  46.0MB/s ± 0%     ~     (p=0.841 n=5+5)
Unmarshal/SimpleDocument/struct-2  19.0MB/s ± 0%  18.8MB/s ± 1%   -1.26%  (p=0.008 n=5+5)
Unmarshal/SimpleDocument/map-2     12.6MB/s ± 1%  12.6MB/s ± 0%     ~     (p=0.508 n=5+5)
Unmarshal/ReferenceFile/struct-2    105MB/s ± 1%   106MB/s ± 0%     ~     (p=0.095 n=5+5)
Unmarshal/ReferenceFile/map-2      65.2MB/s ± 0%  65.8MB/s ± 0%   +1.00%  (p=0.008 n=5+5)
Unmarshal/HugoFrontMatter-2        39.3MB/s ± 0%  40.0MB/s ± 0%   +1.73%  (p=0.008 n=5+5)

name                               old alloc/op   new alloc/op   delta
UnmarshalDataset/config-2            5.85MB ± 0%    5.85MB ± 0%   -0.00%  (p=0.008 n=5+5)
UnmarshalDataset/canada-2            76.6MB ± 0%    75.2MB ± 0%   -1.76%  (p=0.016 n=4+5)
UnmarshalDataset/citm_catalog-2      35.3MB ± 0%    35.0MB ± 0%   -0.71%  (p=0.008 n=5+5)
UnmarshalDataset/twitter-2           13.5MB ± 0%    13.5MB ± 0%   -0.19%  (p=0.016 n=4+5)
UnmarshalDataset/code-2              22.3MB ± 0%    22.0MB ± 0%   -1.31%  (p=0.008 n=5+5)
UnmarshalDataset/example-2            204kB ± 0%     203kB ± 0%   -0.34%  (p=0.008 n=5+5)
Unmarshal/SimpleDocument/struct-2      709B ± 0%      709B ± 0%     ~     (all equal)
Unmarshal/SimpleDocument/map-2       1.08kB ± 0%    1.08kB ± 0%     ~     (all equal)
Unmarshal/ReferenceFile/struct-2     19.8kB ± 0%    19.7kB ± 0%   -0.24%  (p=0.008 n=5+5)
Unmarshal/ReferenceFile/map-2        37.3kB ± 0%    37.0kB ± 0%   -0.64%  (p=0.029 n=4+4)
Unmarshal/HugoFrontMatter-2          7.26kB ± 0%    7.22kB ± 0%   -0.66%  (p=0.008 n=5+5)

name                               old allocs/op  new allocs/op  delta
UnmarshalDataset/config-2              230k ± 0%      230k ± 0%   -0.00%  (p=0.000 n=5+4)
UnmarshalDataset/canada-2              447k ± 0%      391k ± 0%  -12.53%  (p=0.008 n=5+5)
UnmarshalDataset/citm_catalog-2        169k ± 0%      158k ± 0%   -6.20%  (p=0.029 n=4+4)
UnmarshalDataset/twitter-2            55.8k ± 0%     54.7k ± 0%   -1.88%  (p=0.029 n=4+4)
UnmarshalDataset/code-2               1.06M ± 0%     1.05M ± 0%   -1.14%  (p=0.008 n=5+5)
UnmarshalDataset/example-2            1.31k ± 0%     1.28k ± 0%   -2.21%  (p=0.008 n=5+5)
Unmarshal/SimpleDocument/struct-2      8.00 ± 0%      8.00 ± 0%     ~     (all equal)
Unmarshal/SimpleDocument/map-2         13.0 ± 0%      13.0 ± 0%     ~     (all equal)
Unmarshal/ReferenceFile/struct-2        125 ± 0%       123 ± 0%   -1.60%  (p=0.008 n=5+5)
Unmarshal/ReferenceFile/map-2           600 ± 0%       590 ± 0%   -1.67%  (p=0.008 n=5+5)
Unmarshal/HugoFrontMatter-2             132 ± 0%       130 ± 0%   -1.52%  (p=0.008 n=5+5)
```
2021-11-13 19:20:20 -05:00
Thomas Pelletier 12244064bb Use global cache to unmarshal all slice types 2021-11-13 19:20:20 -05:00
Thomas Pelletier 6430ee0bfa Generic slice unmarshal fn 2021-11-13 19:20:20 -05:00
Thomas Pelletier cf530eba46 Specialize array unmarshal into []interface{} 2021-11-13 19:20:19 -05:00
Thomas Pelletier 64fe47161f API: Encoder and Decoder options are chainable (#670)
Fixes #583
2021-11-13 19:04:53 -05:00
Thomas Pelletier 4dff8eaa4d Decoder: prevent duplicates of inline tables (#667)
* seen: prevent duplicates of inline tables

* Provide clearer error message for redefined keys

For example:

```
toml: key b is already defined
```
2021-11-10 10:04:43 -05:00
Cameron Moore 2dbd29a565 parser: Fix missing check for upper exponent (#665) 2021-11-09 21:15:23 -05:00
Thomas Pelletier f27a07d31a seen: verify arrays (#663)
Fixes #662
2021-11-09 20:26:30 -05:00
Thomas Pelletier 644515958c Update TOML test suite (#661)
Ref #658
2021-11-08 22:35:35 -05:00
Thomas Pelletier 8683be35f6 seen: check inline tables (#660)
Fixes #658
2021-11-08 21:53:02 -05:00
Thomas Pelletier dc1740d473 Decode: code cleanup for struct cache (#659) 2021-11-07 18:35:30 -05:00
Thomas Pelletier 11f789ef11 Decode: prevent comments that look like dates to be accepted (#657)
* parser: fix date detection

When the parser has to decide between parsing an integer or a date, it should
check that all characters are actually acceptable (digits, or date/time
elements).

Fixes #655
2021-11-04 22:06:12 -04:00
Thomas Pelletier 74d21b367f scanner: handle carriage return in comments (#656)
Fixes #653
2021-11-04 21:40:16 -04:00
Thomas Pelletier 6617e7e73d utf8: use lookup table to validate ASCII (#654) 2021-11-04 16:05:36 -04:00
Thomas Pelletier 3dbca20bc9 Decoder: flag invalid carriage returns in strings (#652)
Fixes #651
2021-11-02 10:02:25 -04:00
Thomas Pelletier 85c0658984 Decode: add missing checks for LocalTime (#650) 2021-10-29 22:13:08 -04:00
Thomas Pelletier 772d169b52 testsuite: return error when can't encode tag (#648) 2021-10-29 21:51:50 -04:00
Cameron Moore b4ec220f7e Update tomltestgen and regenerate tests (#645)
Remove testsuite build tag from generated tests file

Co-authored-by: Thomas Pelletier <thomas@pelletier.codes>
2021-10-28 20:46:08 -04:00
Thomas Pelletier 3694ae88f6 decode: error on _ before exponent in floats (#647)
Fixes #646
2021-10-28 20:41:10 -04:00
Thomas Pelletier 19751e8a51 Missing performance section 2021-10-28 19:10:43 -04:00
Thomas Pelletier 925f214125 Add GitHub release configuration (#644) 2021-10-28 19:06:14 -04:00
Thomas Pelletier 39f893ad99 Multiline strings fixes (#643)
* scanner: allow multiline strings to end with "" or ''

* parser: trim all whitespaces after \ in multiline
2021-10-28 18:26:34 -04:00
Thomas Pelletier c871a61015 unmarshal: use UnmarshalText for any type (#642)
Not only structs can implement TextUnmarshaler.

Fixes #564
2021-10-28 17:02:47 -04:00
Thomas Pelletier d0d001625c unmarshal: don't panic when storing table in slice (#641)
New error message:

```
toml: cannot store a table in a slice
1| [things]
 |  ~~~~~~ cannot store a table in a slice
2| foo = "bar"
```

Fixes #623
2021-10-25 16:47:10 -04:00
Thomas Pelletier 64941b99e2 unmarshal: empty document results in map (#640)
Fixes #602
2021-10-25 15:55:54 -04:00
Thomas Pelletier ed02a1f192 seen: check for explicit tables on dotted keys (#639)
The TOML spec is being clarified to say that dotted keys "define" their
intermediate tables. Therefore the seen tracker needs to verify that none of
them reference an explicit table.

Also added a missing seen expression check for key-values parsed as part of a
table section.

See https://github.com/toml-lang/toml/issues/846
2021-10-22 23:25:28 -04:00
Thomas Pelletier 4d7c9ddac7 Floats and integers parsing fixes (#638)
* parser: fix scan of float with exp but no decimal
* decoder: validate leading zeros for decimals
2021-10-22 22:25:56 -04:00
Thomas Pelletier feb1830dcc tomltest: enable TestTOMLTest_Valid_Comment_Tricky 2021-10-21 22:30:58 -04:00
Thomas Pelletier 1c33d6ce20 tomltest: custom comparison functions (#637) 2021-10-21 22:29:04 -04:00
Thomas Pelletier 3000471a12 parser: improve floats validation (#636) 2021-10-20 08:49:28 -04:00
Thomas Pelletier 1f33a6a476 tomltest: enable Valid_Datetime_LocalTime 2021-10-19 21:20:45 -04:00
Thomas Pelletier 2700aad5d2 tomltest: run UTF8 tests (#634) 2021-10-19 16:00:56 -04:00
Thomas Pelletier 7ccaa2744e tomltest: unmarshal JSONs for tests (#633)
Comparing the output and the expected results byte-wise means we get false
negatives when order doesn't matter (for example the ValidTableKeyword test).
2021-10-19 15:29:49 -04:00
Johanan Idicula df4bb061f8 time: follow RFC3339 spec for datetime (#632) 2021-10-18 09:56:07 -04:00
Thomas Pelletier 9e81ce1c33 Create SECURITY.md 2021-10-17 20:57:34 -04:00
Cameron Moore a23850f29b decode: preserve nanosecond precision when decoding time (#626)
Co-authored-by: Thomas Pelletier <thomas@pelletier.codes>
2021-10-17 20:43:29 -04:00
Johanan Idicula 76f53c857b unmarshal: validate date (#622) 2021-10-17 20:18:20 -04:00
Thomas Pelletier 85f5d567e4 parser: validate invalid ASCII control characters 2021-10-16 07:41:12 -04:00
Thomas Pelletier bd5cba0b0b Update benchmarks readme (#630)
* Fix ci.sh for new benchmarks

Nice + taskset are more stable on my machine. We want to exclude non high-level
benchmarks. BurntSushi/toml now supports canada.toml.

* Update latest benchmarks in README
2021-10-15 19:53:40 -04:00
Thomas Pelletier cd54472d03 Validate UTF-8 (#629) 2021-10-15 19:13:21 -04:00
Thomas Pelletier cc0d1a90ff testgen: skip currently failing tests (#627) 2021-10-14 11:14:44 -04:00
Sterling Hanenkamp 4984dcb5e9 encode: ensure floats have decimal point (#615)
Fixes #571

Co-authored-by: Sterling Hanenkamp <sterling@ziprecruiter.com>
2021-10-14 08:34:54 -04:00
jidicula 86632bc190 parser: fail when missing array separator (#616)
Co-authored-by: Thomas Pelletier <thomas@pelletier.codes>
2021-10-14 08:26:29 -04:00
Cameron Moore d25eec183f gotoml-test-decoder: add toml-test decoder command (#619) 2021-10-14 08:14:34 -04:00
Riya John e96746311c decoder: fix panic date time should have a timezone (#614)
Fixes #596

Co-authored-by: Thomas Pelletier <thomas@pelletier.codes>
2021-10-06 21:24:25 -04:00
Cameron Moore 62acca2b68 tomltestgen: add toml-test unit test generation command (#610)
Tests are hidden behind a "testsuite" build tag for now since many tests
are failing.  Use `go test -tags testsuite` to activate.

Use `go generate` to regenerate toml_testgen_test.go.

Co-authored-by: Thomas Pelletier <thomas@pelletier.codes>
2021-10-03 22:15:30 -04:00
Cameron Moore 476492a85c unmarshal: support lowercase 'T' and 'Z' in date-time parsing (#601)
RFC3339 allows for lowercase 't' and 'z' in date-time values.

Fixes #600
2021-09-25 10:02:23 -07:00
Thomas Pelletier ee9b902222 unmarshal: convert ints if target type is compatible (#594)
This is required to support custom types.

Fixes #590
2021-09-09 21:25:14 -04:00
Thomas Pelletier fa56f48daf parser: don't overflow when parsing bad times (#593)
Fixes #585
2021-09-09 11:59:37 -04:00
Thomas Pelletier f34c9c332f scanner: fix error reporting for last comments (#591)
When an invalid TOML expression ends with a comment before the end of
file, the decode error would take a nil from scanComment, which is not
part of the document.

Fixes #588
2021-09-08 21:54:30 -04:00
Thomas Pelletier a0d685d482 unmarshal: don't crash on unterminated inline table (#587)
Fixes #586
2021-09-07 20:08:59 -04:00
Thomas Pelletier 4a5ae9e81e errors: fix context generation with only one line 2021-09-07 10:36:22 -04:00
Thomas Pelletier 7e2fa1bc80 unmarshal: fix non-terminated array error
Fixes #581
2021-09-07 10:36:22 -04:00
Thomas Pelletier 40cfb6f458 parser: don't crash on unterminated table key (#580)
* parser: don't crash on unterminated table key

Fixes #579

* parser: fix format of error returned by expect

EOF was missing the format string and %U is not very human friendly.
2021-09-06 12:18:45 -04:00
Thomas Pelletier 1230ca485e unmarshal: make copy of non addressable values (#576)
When unmarshaling into a nested struct in a map, the value is not
addressable. In that case, make a copy of it and modify it instead.

Fixes #575
2021-08-31 20:22:38 -04:00
Thomas Pelletier 69ab7e10d1 Go 1.17 release (#574)
Minimum supported version: Go 1.16.
2021-08-17 09:43:52 -04:00
Thomas Pelletier fa07960695 Add installation instructions (#572) 2021-07-27 18:12:44 -04:00
kkHAIKE 8be357dfa1 Add LocalTime to interface{} decode support (#567)
Co-authored-by: Thomas Pelletier <thomas@pelletier.codes>
2021-07-21 17:50:12 +02:00
kkHAIKE a93b34d984 Unicode parsing optimization (#568)
Inline call to hexToRune and uses specialized parsing, as found in encoding/json.

Co-authored-by: Thomas Pelletier <thomas@pelletier.codes>
2021-07-21 10:50:03 +02:00
Matthieu MOREL 9c24fbeaad Set up Dependabot for GitHub actions and docker (#570) 2021-07-20 16:54:26 +02:00
Thomas Pelletier f6b38c33b7 Provide own implementation of Local* (#558)
* Reduces the public API.
* Reuses optimized parsing functions.
* Removes reliance on Google code under Apache license.
2021-06-08 20:27:05 -04:00
Thomas Pelletier 773f10110c Unmarshal recursive structs (#557)
Co-authored-by: Nabetani <takenori@nabetani.sakura.ne.jp>
2021-06-08 14:22:39 -04:00
40 changed files with 4683 additions and 2163 deletions
+15 -4
View File
@@ -1,6 +1,17 @@
version: 2 version: 2
updates: updates:
- package-ecosystem: "gomod" - package-ecosystem: gomod
directory: "/" # Location of package manifests directory: /
schedule: schedule:
interval: "daily" interval: daily
open-pull-requests-limit: 10
- package-ecosystem: github-actions
directory: /
schedule:
interval: daily
open-pull-requests-limit: 10
- package-ecosystem: docker
directory: /
schedule:
interval: daily
open-pull-requests-limit: 10
+20
View File
@@ -0,0 +1,20 @@
changelog:
exclude:
labels:
- build
categories:
- title: What's new
labels:
- feature
- title: Performance
labels:
- performance
- title: Fixed bugs
labels:
- bug
- title: Documentation
labels:
- doc
- title: Other changes
labels:
- "*"
+1 -1
View File
@@ -12,7 +12,7 @@ jobs:
strategy: strategy:
matrix: matrix:
os: [ 'ubuntu-latest', 'windows-latest', 'macos-latest'] os: [ 'ubuntu-latest', 'windows-latest', 'macos-latest']
go: [ '1.15', '1.16' ] go: [ '1.16', '1.17' ]
runs-on: ${{ matrix.os }} runs-on: ${{ matrix.os }}
name: ${{ matrix.go }}/${{ matrix.os }} name: ${{ matrix.go }}/${{ matrix.os }}
steps: steps:
+33 -16
View File
@@ -31,6 +31,8 @@ Full API, examples, and implementation notes are available in the Go documentati
import "github.com/pelletier/go-toml/v2" import "github.com/pelletier/go-toml/v2"
``` ```
See [Modules](#Modules).
## Features ## Features
### Stdlib behavior ### Stdlib behavior
@@ -156,12 +158,12 @@ Execution time speedup compared to other Go TOML libraries:
<tr><th>Benchmark</th><th>go-toml v1</th><th>BurntSushi/toml</th></tr> <tr><th>Benchmark</th><th>go-toml v1</th><th>BurntSushi/toml</th></tr>
</thead> </thead>
<tbody> <tbody>
<tr><td>Marshal/HugoFrontMatter</td><td>2.0x</td><td>2.0x</td></tr> <tr><td>Marshal/HugoFrontMatter-2</td><td>1.9x</td><td>1.9x</td></tr>
<tr><td>Marshal/ReferenceFile/map</td><td>1.8x</td><td>2.0x</td></tr> <tr><td>Marshal/ReferenceFile/map-2</td><td>1.7x</td><td>1.9x</td></tr>
<tr><td>Marshal/ReferenceFile/struct</td><td>2.7x</td><td>2.7x</td></tr> <tr><td>Marshal/ReferenceFile/struct-2</td><td>2.4x</td><td>2.6x</td></tr>
<tr><td>Unmarshal/HugoFrontMatter</td><td>3.0x</td><td>2.6x</td></tr> <tr><td>Unmarshal/HugoFrontMatter-2</td><td>2.9x</td><td>2.5x</td></tr>
<tr><td>Unmarshal/ReferenceFile/map</td><td>3.0x</td><td>3.1x</td></tr> <tr><td>Unmarshal/ReferenceFile/map-2</td><td>2.7x</td><td>2.6x</td></tr>
<tr><td>Unmarshal/ReferenceFile/struct</td><td>5.9x</td><td>6.6x</td></tr> <tr><td>Unmarshal/ReferenceFile/struct-2</td><td>4.8x</td><td>5.1x</td></tr>
</tbody> </tbody>
</table> </table>
<details><summary>See more</summary> <details><summary>See more</summary>
@@ -174,21 +176,36 @@ provided for completeness.</p>
<tr><th>Benchmark</th><th>go-toml v1</th><th>BurntSushi/toml</th></tr> <tr><th>Benchmark</th><th>go-toml v1</th><th>BurntSushi/toml</th></tr>
</thead> </thead>
<tbody> <tbody>
<tr><td>Marshal/SimpleDocument/map</td><td>1.7x</td><td>2.1x</td></tr> <tr><td>Marshal/SimpleDocument/map-2</td><td>1.7x</td><td>2.1x</td></tr>
<tr><td>Marshal/SimpleDocument/struct</td><td>2.6x</td><td>2.9x</td></tr> <tr><td>Marshal/SimpleDocument/struct-2</td><td>2.5x</td><td>2.8x</td></tr>
<tr><td>Unmarshal/SimpleDocument/map</td><td>4.1x</td><td>2.9x</td></tr> <tr><td>Unmarshal/SimpleDocument/map-2</td><td>4.1x</td><td>3.1x</td></tr>
<tr><td>Unmarshal/SimpleDocument/struct</td><td>6.3x</td><td>4.1x</td></tr> <tr><td>Unmarshal/SimpleDocument/struct-2</td><td>6.4x</td><td>4.3x</td></tr>
<tr><td>UnmarshalDataset/example</td><td>3.5x</td><td>2.4x</td></tr> <tr><td>UnmarshalDataset/example-2</td><td>3.4x</td><td>3.2x</td></tr>
<tr><td>UnmarshalDataset/code</td><td>2.2x</td><td>2.8x</td></tr> <tr><td>UnmarshalDataset/code-2</td><td>2.2x</td><td>2.5x</td></tr>
<tr><td>UnmarshalDataset/twitter</td><td>2.8x</td><td>2.1x</td></tr> <tr><td>UnmarshalDataset/twitter-2</td><td>2.8x</td><td>2.7x</td></tr>
<tr><td>UnmarshalDataset/citm_catalog</td><td>2.3x</td><td>1.5x</td></tr> <tr><td>UnmarshalDataset/citm_catalog-2</td><td>2.2x</td><td>2.0x</td></tr>
<tr><td>UnmarshalDataset/config</td><td>4.2x</td><td>3.2x</td></tr> <tr><td>UnmarshalDataset/canada-2</td><td>1.8x</td><td>1.4x</td></tr>
<tr><td>[Geo mean]</td><td>3.0x</td><td>2.7x</td></tr> <tr><td>UnmarshalDataset/config-2</td><td>4.4x</td><td>2.9x</td></tr>
<tr><td>[Geo mean]</td><td>2.8x</td><td>2.6x</td></tr>
</tbody> </tbody>
</table> </table>
<p>This table can be generated with <code>./ci.sh benchmark -a -html</code>.</p> <p>This table can be generated with <code>./ci.sh benchmark -a -html</code>.</p>
</details> </details>
## Modules
go-toml uses Go's standard modules system.
Installation instructions:
- Go ≥ 1.16: Nothing to do. Use the import in your code. The `go` command deals
with it automatically.
- Go ≥ 1.13: `GO111MODULE=on go get github.com/pelletier/go-toml/v2`.
In case of trouble: [Go Modules FAQ][mod-faq].
[mod-faq]: https://github.com/golang/go/wiki/Modules#why-does-installing-a-tool-via-go-get-fail-with-error-cannot-find-main-module
## Migrating from v1 ## Migrating from v1
This section describes the differences between v1 and v2, with some pointers on This section describes the differences between v1 and v2, with some pointers on
+19
View File
@@ -0,0 +1,19 @@
# Security Policy
## Supported Versions
Use this section to tell people about which versions of your project are
currently being supported with security updates.
| Version | Supported |
| ---------- | ------------------ |
| Latest 2.x | :white_check_mark: |
| All 1.x | :x: |
| All 0.x | :x: |
## Reporting a Vulnerability
Email a vulnerability report to `security@pelletier.codes`. Make sure to include
as many details as possible to reproduce the vulnerability. This is a
side-project: I will try to get back to you as quickly as possible, time
permitting in my personal life. Providing a working patch helps very much!
-1
View File
@@ -321,7 +321,6 @@ type benchmarkDoc struct {
Key1 []int64 Key1 []int64
Key2 []string Key2 []string
Key3 [][]int64 Key3 [][]int64
// TODO: Key4 not supported by go-toml's Unmarshal
Key4 []interface{} Key4 []interface{}
Key5 []int64 Key5 []int64
Key6 []int64 Key6 []int64
+71
View File
@@ -0,0 +1,71 @@
package toml
import (
"bytes"
"testing"
)
var valid10Ascii = []byte("1234567890")
var valid10Utf8 = []byte("日本語a")
var valid1kUtf8 = bytes.Repeat([]byte("0123456789日本語日本語日本語日abcdefghijklmnopqrstuvwx"), 16)
var valid1MUtf8 = bytes.Repeat(valid1kUtf8, 1024)
var valid1kAscii = bytes.Repeat([]byte("012345678998jhjklasDJKLAAdjdfjsdklfjdslkabcdefghijklmnopqrstuvwx"), 16)
var valid1MAscii = bytes.Repeat(valid1kAscii, 1024)
func BenchmarkScanComments(b *testing.B) {
wrap := func(x []byte) []byte {
return []byte("# " + string(x) + "\n")
}
inputs := map[string][]byte{
"10Valid": wrap(valid10Ascii),
"1kValid": wrap(valid1kAscii),
"1MValid": wrap(valid1MAscii),
"10ValidUtf8": wrap(valid10Utf8),
"1kValidUtf8": wrap(valid1kUtf8),
"1MValidUtf8": wrap(valid1MUtf8),
}
for name, input := range inputs {
b.Run(name, func(b *testing.B) {
b.SetBytes(int64(len(input)))
b.ReportAllocs()
b.ResetTimer()
for i := 0; i < b.N; i++ {
scanComment(input)
}
})
}
}
func BenchmarkParseLiteralStringValid(b *testing.B) {
wrap := func(x []byte) []byte {
return []byte("'" + string(x) + "'")
}
inputs := map[string][]byte{
"10Valid": wrap(valid10Ascii),
"1kValid": wrap(valid1kAscii),
"1MValid": wrap(valid1MAscii),
"10ValidUtf8": wrap(valid10Utf8),
"1kValidUtf8": wrap(valid1kUtf8),
"1MValidUtf8": wrap(valid1MUtf8),
}
for name, input := range inputs {
b.Run(name, func(b *testing.B) {
p := parser{}
b.SetBytes(int64(len(input)))
b.ReportAllocs()
b.ResetTimer()
for i := 0; i < b.N; i++ {
_, _, _, err := p.parseLiteralString(input)
if err != nil {
panic(err)
}
}
})
}
}
+2 -4
View File
@@ -140,12 +140,10 @@ bench() {
if [ "${replace}" != "" ]; then if [ "${replace}" != "" ]; then
find ./benchmark/ -iname '*.go' -exec sed -i -E "s|github.com/pelletier/go-toml/v2|${replace}|g" {} \; find ./benchmark/ -iname '*.go' -exec sed -i -E "s|github.com/pelletier/go-toml/v2|${replace}|g" {} \;
go get "${replace}" go get "${replace}"
# hack: remove canada.toml.gz because it is not supported by
# burntsushi, and replace is only used for benchmark -a
rm -f benchmark/testdata/canada.toml.gz
fi fi
go test -bench=. -count=10 ./... | tee "${out}" export GOMAXPROCS=2
nice -n -19 taskset --cpu-list 0,1 go test '-bench=^Benchmark(Un)?[mM]arshal' -count=5 -run=Nothing ./... | tee "${out}"
popd popd
if [ "${branch}" != "HEAD" ]; then if [ "${branch}" != "HEAD" ]; then
+30
View File
@@ -0,0 +1,30 @@
package main
import (
"flag"
"log"
"os"
"path"
"github.com/pelletier/go-toml/v2/testsuite"
)
func main() {
log.SetFlags(0)
flag.Usage = usage
flag.Parse()
if flag.NArg() != 0 {
flag.Usage()
}
err := testsuite.DecodeStdin()
if err != nil {
log.Fatal(err)
}
}
func usage() {
log.Printf("Usage: %s < toml-file\n", path.Base(os.Args[0]))
flag.PrintDefaults()
os.Exit(1)
}
+223
View File
@@ -0,0 +1,223 @@
// tomltestgen retrieves a given version of the language-agnostic TOML test suite in
// https://github.com/BurntSushi/toml-test and generates go-toml unit tests.
//
// Within the go-toml package, run `go generate`. Otherwise, use:
//
// go run github.com/pelletier/go-toml/cmd/tomltestgen -o toml_testgen_test.go
package main
import (
"archive/zip"
"bytes"
"flag"
"fmt"
"go/format"
"io"
"io/ioutil"
"log"
"net/http"
"os"
"regexp"
"strconv"
"strings"
"text/template"
"time"
)
type invalid struct {
Name string
Input string
}
type valid struct {
Name string
Input string
JsonRef string
}
type testsCollection struct {
Ref string
Timestamp string
Invalid []invalid
Valid []valid
Count int
}
const srcTemplate = "// Generated by tomltestgen for toml-test ref {{.Ref}} on {{.Timestamp}}\n" +
"package toml_test\n" +
" import (\n" +
" \"testing\"\n" +
")\n" +
"{{range .Invalid}}\n" +
"func TestTOMLTest_Invalid_{{.Name}}(t *testing.T) {\n" +
" input := {{.Input|gostr}}\n" +
" testgenInvalid(t, input)\n" +
"}\n" +
"{{end}}\n" +
"\n" +
"{{range .Valid}}\n" +
"func TestTOMLTest_Valid_{{.Name}}(t *testing.T) {\n" +
" input := {{.Input|gostr}}\n" +
" jsonRef := {{.JsonRef|gostr}}\n" +
" testgenValid(t, input, jsonRef)\n" +
"}\n" +
"{{end}}\n"
func downloadTmpFile(url string) string {
log.Println("starting to download file from", url)
resp, err := http.Get(url)
if err != nil {
panic(err)
}
defer resp.Body.Close()
tmpfile, err := ioutil.TempFile("", "toml-test-*.zip")
if err != nil {
panic(err)
}
defer tmpfile.Close()
copiedLen, err := io.Copy(tmpfile, resp.Body)
if err != nil {
panic(err)
}
if resp.ContentLength > 0 && copiedLen != resp.ContentLength {
panic(fmt.Errorf("copied %d bytes, request body had %d", copiedLen, resp.ContentLength))
}
return tmpfile.Name()
}
func kebabToCamel(kebab string) string {
camel := ""
nextUpper := true
for _, c := range kebab {
if nextUpper {
camel += strings.ToUpper(string(c))
nextUpper = false
} else if c == '-' {
nextUpper = true
} else if c == '/' {
nextUpper = true
camel += "_"
} else {
camel += string(c)
}
}
return camel
}
func readFileFromZip(f *zip.File) string {
reader, err := f.Open()
if err != nil {
panic(err)
}
defer reader.Close()
bytes, err := ioutil.ReadAll(reader)
if err != nil {
panic(err)
}
return string(bytes)
}
func templateGoStr(input string) string {
return strconv.Quote(input)
}
var (
ref = flag.String("r", "master", "git reference")
out = flag.String("o", "", "output file")
)
func usage() {
_, _ = fmt.Fprintf(os.Stderr, "usage: tomltestgen [flags]\n")
flag.PrintDefaults()
}
func main() {
flag.Usage = usage
flag.Parse()
url := "https://codeload.github.com/BurntSushi/toml-test/zip/" + *ref
resultFile := downloadTmpFile(url)
defer os.Remove(resultFile)
log.Println("file written to", resultFile)
zipReader, err := zip.OpenReader(resultFile)
if err != nil {
panic(err)
}
defer zipReader.Close()
collection := testsCollection{
Ref: *ref,
Timestamp: time.Now().Format(time.RFC3339),
}
zipFilesMap := map[string]*zip.File{}
for _, f := range zipReader.File {
zipFilesMap[f.Name] = f
}
testFileRegexp := regexp.MustCompile(`([^/]+/tests/(valid|invalid)/(.+))\.(toml)`)
for _, f := range zipReader.File {
groups := testFileRegexp.FindStringSubmatch(f.Name)
if len(groups) > 0 {
name := kebabToCamel(groups[3])
testType := groups[2]
log.Printf("> [%s] %s\n", testType, name)
tomlContent := readFileFromZip(f)
switch testType {
case "invalid":
collection.Invalid = append(collection.Invalid, invalid{
Name: name,
Input: tomlContent,
})
collection.Count++
case "valid":
baseFilePath := groups[1]
jsonFilePath := baseFilePath + ".json"
jsonContent := readFileFromZip(zipFilesMap[jsonFilePath])
collection.Valid = append(collection.Valid, valid{
Name: name,
Input: tomlContent,
JsonRef: jsonContent,
})
collection.Count++
default:
panic(fmt.Sprintf("unknown test type: %s", testType))
}
}
}
log.Printf("Collected %d tests from toml-test\n", collection.Count)
funcMap := template.FuncMap{
"gostr": templateGoStr,
}
t := template.Must(template.New("src").Funcs(funcMap).Parse(srcTemplate))
buf := new(bytes.Buffer)
err = t.Execute(buf, collection)
if err != nil {
panic(err)
}
outputBytes, err := format.Source(buf.Bytes())
if err != nil {
panic(err)
}
if *out == "" {
fmt.Println(string(outputBytes))
return
}
err = os.WriteFile(*out, outputBytes, 0644)
if err != nil {
panic(err)
}
}
+201 -27
View File
@@ -35,26 +35,42 @@ func parseLocalDate(b []byte) (LocalDate, error) {
return date, newDecodeError(b, "dates are expected to have the format YYYY-MM-DD") return date, newDecodeError(b, "dates are expected to have the format YYYY-MM-DD")
} }
date.Year = parseDecimalDigits(b[0:4]) var err error
v := parseDecimalDigits(b[5:7]) date.Year, err = parseDecimalDigits(b[0:4])
if err != nil {
return LocalDate{}, err
}
date.Month = time.Month(v) date.Month, err = parseDecimalDigits(b[5:7])
if err != nil {
return LocalDate{}, err
}
date.Day = parseDecimalDigits(b[8:10]) date.Day, err = parseDecimalDigits(b[8:10])
if err != nil {
return LocalDate{}, err
}
if !isValidDate(date.Year, date.Month, date.Day) {
return LocalDate{}, newDecodeError(b, "impossible date")
}
return date, nil return date, nil
} }
func parseDecimalDigits(b []byte) int { func parseDecimalDigits(b []byte) (int, error) {
v := 0 v := 0
for _, c := range b { for i, c := range b {
if c < '0' || c > '9' {
return 0, newDecodeError(b[i:i+1], "expected digit (0-9)")
}
v *= 10 v *= 10
v += int(c - '0') v += int(c - '0')
} }
return v return v, nil
} }
func parseDateTime(b []byte) (time.Time, error) { func parseDateTime(b []byte) (time.Time, error) {
@@ -75,7 +91,7 @@ func parseDateTime(b []byte) (time.Time, error) {
panic("date time should have a timezone") panic("date time should have a timezone")
} }
if b[0] == 'Z' { if b[0] == 'Z' || b[0] == 'z' {
b = b[1:] b = b[1:]
zone = time.UTC zone = time.UTC
} else { } else {
@@ -100,13 +116,13 @@ func parseDateTime(b []byte) (time.Time, error) {
} }
t := time.Date( t := time.Date(
dt.Date.Year, dt.Year,
dt.Date.Month, time.Month(dt.Month),
dt.Date.Day, dt.Day,
dt.Time.Hour, dt.Hour,
dt.Time.Minute, dt.Minute,
dt.Time.Second, dt.Second,
dt.Time.Nanosecond, dt.Nanosecond,
zone) zone)
return t, nil return t, nil
@@ -124,10 +140,10 @@ func parseLocalDateTime(b []byte) (LocalDateTime, []byte, error) {
if err != nil { if err != nil {
return dt, nil, err return dt, nil, err
} }
dt.Date = date dt.LocalDate = date
sep := b[10] sep := b[10]
if sep != 'T' && sep != ' ' { if sep != 'T' && sep != ' ' && sep != 't' {
return dt, nil, newDecodeError(b[10:11], "datetime separator is expected to be T or a space") return dt, nil, newDecodeError(b[10:11], "datetime separator is expected to be T or a space")
} }
@@ -135,7 +151,7 @@ func parseLocalDateTime(b []byte) (LocalDateTime, []byte, error) {
if err != nil { if err != nil {
return dt, nil, err return dt, nil, err
} }
dt.Time = t dt.LocalTime = t
return dt, rest, nil return dt, rest, nil
} }
@@ -149,22 +165,45 @@ func parseLocalTime(b []byte) (LocalTime, []byte, error) {
t LocalTime t LocalTime
) )
// check if b matches to have expected format HH:MM:SS[.NNNNNN]
const localTimeByteLen = 8 const localTimeByteLen = 8
if len(b) < localTimeByteLen { if len(b) < localTimeByteLen {
return t, nil, newDecodeError(b, "times are expected to have the format HH:MM:SS[.NNNNNN]") return t, nil, newDecodeError(b, "times are expected to have the format HH:MM:SS[.NNNNNN]")
} }
t.Hour = parseDecimalDigits(b[0:2]) var err error
t.Hour, err = parseDecimalDigits(b[0:2])
if err != nil {
return t, nil, err
}
if t.Hour > 23 {
return t, nil, newDecodeError(b[0:2], "hour cannot be greater 23")
}
if b[2] != ':' { if b[2] != ':' {
return t, nil, newDecodeError(b[2:3], "expecting colon between hours and minutes") return t, nil, newDecodeError(b[2:3], "expecting colon between hours and minutes")
} }
t.Minute = parseDecimalDigits(b[3:5]) t.Minute, err = parseDecimalDigits(b[3:5])
if err != nil {
return t, nil, err
}
if t.Minute > 59 {
return t, nil, newDecodeError(b[3:5], "minutes cannot be greater 59")
}
if b[5] != ':' { if b[5] != ':' {
return t, nil, newDecodeError(b[5:6], "expecting colon between minutes and seconds") return t, nil, newDecodeError(b[5:6], "expecting colon between minutes and seconds")
} }
t.Second = parseDecimalDigits(b[6:8]) t.Second, err = parseDecimalDigits(b[6:8])
if err != nil {
return t, nil, err
}
if t.Second > 59 {
return t, nil, newDecodeError(b[3:5], "seconds cannot be greater 59")
}
const minLengthWithFrac = 9 const minLengthWithFrac = 9
if len(b) >= minLengthWithFrac && b[minLengthWithFrac-1] == '.' { if len(b) >= minLengthWithFrac && b[minLengthWithFrac-1] == '.' {
@@ -190,7 +229,12 @@ func parseLocalTime(b []byte) (LocalTime, []byte, error) {
digits++ digits++
} }
if digits == 0 {
return t, nil, newDecodeError(b[minLengthWithFrac-1:minLengthWithFrac], "nanoseconds need at least one digit")
}
t.Nanosecond = frac * nspow[digits] t.Nanosecond = frac * nspow[digits]
t.Precision = digits
return t, b[9+digits:], nil return t, b[9+digits:], nil
} }
@@ -204,7 +248,7 @@ func parseFloat(b []byte) (float64, error) {
return math.NaN(), nil return math.NaN(), nil
} }
cleaned, err := checkAndRemoveUnderscores(b) cleaned, err := checkAndRemoveUnderscoresFloats(b)
if err != nil { if err != nil {
return 0, err return 0, err
} }
@@ -217,6 +261,30 @@ func parseFloat(b []byte) (float64, error) {
return 0, newDecodeError(b, "float cannot end with a dot") return 0, newDecodeError(b, "float cannot end with a dot")
} }
dotAlreadySeen := false
for i, c := range cleaned {
if c == '.' {
if dotAlreadySeen {
return 0, newDecodeError(b[i:i+1], "float can have at most one decimal point")
}
if !isDigit(cleaned[i-1]) {
return 0, newDecodeError(b[i-1:i+1], "float decimal point must be preceded by a digit")
}
if !isDigit(cleaned[i+1]) {
return 0, newDecodeError(b[i:i+2], "float decimal point must be followed by a digit")
}
dotAlreadySeen = true
}
}
start := 0
if b[0] == '+' || b[0] == '-' {
start = 1
}
if b[start] == '0' && isDigit(b[start+1]) {
return 0, newDecodeError(b, "float integer part cannot have leading zeroes")
}
f, err := strconv.ParseFloat(string(cleaned), 64) f, err := strconv.ParseFloat(string(cleaned), 64)
if err != nil { if err != nil {
return 0, newDecodeError(b, "unable to parse float: %w", err) return 0, newDecodeError(b, "unable to parse float: %w", err)
@@ -226,7 +294,7 @@ func parseFloat(b []byte) (float64, error) {
} }
func parseIntHex(b []byte) (int64, error) { func parseIntHex(b []byte) (int64, error) {
cleaned, err := checkAndRemoveUnderscores(b[2:]) cleaned, err := checkAndRemoveUnderscoresIntegers(b[2:])
if err != nil { if err != nil {
return 0, err return 0, err
} }
@@ -240,7 +308,7 @@ func parseIntHex(b []byte) (int64, error) {
} }
func parseIntOct(b []byte) (int64, error) { func parseIntOct(b []byte) (int64, error) {
cleaned, err := checkAndRemoveUnderscores(b[2:]) cleaned, err := checkAndRemoveUnderscoresIntegers(b[2:])
if err != nil { if err != nil {
return 0, err return 0, err
} }
@@ -254,7 +322,7 @@ func parseIntOct(b []byte) (int64, error) {
} }
func parseIntBin(b []byte) (int64, error) { func parseIntBin(b []byte) (int64, error) {
cleaned, err := checkAndRemoveUnderscores(b[2:]) cleaned, err := checkAndRemoveUnderscoresIntegers(b[2:])
if err != nil { if err != nil {
return 0, err return 0, err
} }
@@ -267,12 +335,26 @@ func parseIntBin(b []byte) (int64, error) {
return i, nil return i, nil
} }
func isSign(b byte) bool {
return b == '+' || b == '-'
}
func parseIntDec(b []byte) (int64, error) { func parseIntDec(b []byte) (int64, error) {
cleaned, err := checkAndRemoveUnderscores(b) cleaned, err := checkAndRemoveUnderscoresIntegers(b)
if err != nil { if err != nil {
return 0, err return 0, err
} }
startIdx := 0
if isSign(cleaned[0]) {
startIdx++
}
if len(cleaned) > startIdx+1 && cleaned[startIdx] == '0' {
return 0, newDecodeError(b, "leading zero not allowed on decimal number")
}
i, err := strconv.ParseInt(string(cleaned), 10, 64) i, err := strconv.ParseInt(string(cleaned), 10, 64)
if err != nil { if err != nil {
return 0, newDecodeError(b, "couldn't parse decimal number: %w", err) return 0, newDecodeError(b, "couldn't parse decimal number: %w", err)
@@ -281,7 +363,7 @@ func parseIntDec(b []byte) (int64, error) {
return i, nil return i, nil
} }
func checkAndRemoveUnderscores(b []byte) ([]byte, error) { func checkAndRemoveUnderscoresIntegers(b []byte) ([]byte, error) {
if b[0] == '_' { if b[0] == '_' {
return nil, newDecodeError(b[0:1], "number cannot start with underscore") return nil, newDecodeError(b[0:1], "number cannot start with underscore")
} }
@@ -320,3 +402,95 @@ func checkAndRemoveUnderscores(b []byte) ([]byte, error) {
return cleaned, nil return cleaned, nil
} }
func checkAndRemoveUnderscoresFloats(b []byte) ([]byte, error) {
if b[0] == '_' {
return nil, newDecodeError(b[0:1], "number cannot start with underscore")
}
if b[len(b)-1] == '_' {
return nil, newDecodeError(b[len(b)-1:], "number cannot end with underscore")
}
// fast path
i := 0
for ; i < len(b); i++ {
if b[i] == '_' {
break
}
}
if i == len(b) {
return b, nil
}
before := false
cleaned := make([]byte, 0, len(b))
for i := 0; i < len(b); i++ {
c := b[i]
switch c {
case '_':
if !before {
return nil, newDecodeError(b[i-1:i+1], "number must have at least one digit between underscores")
}
if i < len(b)-1 && (b[i+1] == 'e' || b[i+1] == 'E') {
return nil, newDecodeError(b[i+1:i+2], "cannot have underscore before exponent")
}
before = false
case 'e', 'E':
if i < len(b)-1 && b[i+1] == '_' {
return nil, newDecodeError(b[i+1:i+2], "cannot have underscore after exponent")
}
cleaned = append(cleaned, c)
case '.':
if i < len(b)-1 && b[i+1] == '_' {
return nil, newDecodeError(b[i+1:i+2], "cannot have underscore after decimal point")
}
if i > 0 && b[i-1] == '_' {
return nil, newDecodeError(b[i-1:i], "cannot have underscore before decimal point")
}
cleaned = append(cleaned, c)
default:
before = true
cleaned = append(cleaned, c)
}
}
return cleaned, nil
}
// isValidDate checks if a provided date is a date that exists.
func isValidDate(year int, month int, day int) bool {
return day <= daysIn(month, year)
}
// daysBefore[m] counts the number of days in a non-leap year
// before month m begins. There is an entry for m=12, counting
// the number of days before January of next year (365).
var daysBefore = [...]int32{
0,
31,
31 + 28,
31 + 28 + 31,
31 + 28 + 31 + 30,
31 + 28 + 31 + 30 + 31,
31 + 28 + 31 + 30 + 31 + 30,
31 + 28 + 31 + 30 + 31 + 30 + 31,
31 + 28 + 31 + 30 + 31 + 30 + 31 + 31,
31 + 28 + 31 + 30 + 31 + 30 + 31 + 31 + 30,
31 + 28 + 31 + 30 + 31 + 30 + 31 + 31 + 30 + 31,
31 + 28 + 31 + 30 + 31 + 30 + 31 + 31 + 30 + 31 + 30,
31 + 28 + 31 + 30 + 31 + 30 + 31 + 31 + 30 + 31 + 30 + 31,
}
func daysIn(m int, year int) int {
if m == 2 && isLeap(year) {
return 29
}
return int(daysBefore[m] - daysBefore[m-1])
}
func isLeap(year int) bool {
return year%4 == 0 && (year%100 != 0 || year%400 == 0)
}
+10 -1
View File
@@ -116,6 +116,7 @@ func wrapDecodeError(document []byte, de *decodeError) *DecodeError {
maxLine := errLine + len(after) - 1 maxLine := errLine + len(after) - 1
lineColumnWidth := len(strconv.Itoa(maxLine)) lineColumnWidth := len(strconv.Itoa(maxLine))
// Write the lines of context strictly before the error.
for i := len(before) - 1; i > 0; i-- { for i := len(before) - 1; i > 0; i-- {
line := errLine - i line := errLine - i
buf.WriteString(formatLineNumber(line, lineColumnWidth)) buf.WriteString(formatLineNumber(line, lineColumnWidth))
@@ -129,6 +130,8 @@ func wrapDecodeError(document []byte, de *decodeError) *DecodeError {
buf.WriteRune('\n') buf.WriteRune('\n')
} }
// Write the document line that contains the error.
buf.WriteString(formatLineNumber(errLine, lineColumnWidth)) buf.WriteString(formatLineNumber(errLine, lineColumnWidth))
buf.WriteString("| ") buf.WriteString("| ")
@@ -143,6 +146,10 @@ func wrapDecodeError(document []byte, de *decodeError) *DecodeError {
} }
buf.WriteRune('\n') buf.WriteRune('\n')
// Write the line with the error message itself (so it does not have a line
// number).
buf.WriteString(strings.Repeat(" ", lineColumnWidth)) buf.WriteString(strings.Repeat(" ", lineColumnWidth))
buf.WriteString("| ") buf.WriteString("| ")
@@ -157,6 +164,8 @@ func wrapDecodeError(document []byte, de *decodeError) *DecodeError {
buf.WriteString(errMessage) buf.WriteString(errMessage)
} }
// Write the lines of context strictly after the error.
for i := 1; i < len(after); i++ { for i := 1; i < len(after); i++ {
buf.WriteRune('\n') buf.WriteRune('\n')
line := errLine + i line := errLine + i
@@ -230,7 +239,7 @@ forward:
rest = rest[o+1:] rest = rest[o+1:]
o = 0 o = 0
case o == len(rest)-1 && o > 0: case o == len(rest)-1:
// add last line only if it's non-empty // add last line only if it's non-empty
afterLines = append(afterLines, rest) afterLines = append(afterLines, rest)
+7
View File
@@ -148,6 +148,13 @@ line 5`,
6| 6|
7| line 4`, 7| line 4`,
}, },
{
desc: "handle remainder of the error line when there is only one line",
doc: [3]string{`P=`, `[`, `#`},
msg: "array is incomplete",
expected: `1| P=[#
| ~ array is incomplete`,
},
} }
for _, e := range examples { for _, e := range examples {
+1 -1
View File
@@ -1,6 +1,6 @@
module github.com/pelletier/go-toml/v2 module github.com/pelletier/go-toml/v2
go 1.15 go 1.16
// latest (v1.7.0) doesn't have the fix for time.Time // latest (v1.7.0) doesn't have the fix for time.Time
require github.com/stretchr/testify v1.7.1-0.20210427113832-6241f9ab9942 require github.com/stretchr/testify v1.7.1-0.20210427113832-6241f9ab9942
-7
View File
@@ -136,7 +136,6 @@ func (n *Node) Key() Iterator {
// Guaranteed to be non-nil. // Guaranteed to be non-nil.
// Panics if not called on a KeyValue node, or if the Children are malformed. // Panics if not called on a KeyValue node, or if the Children are malformed.
func (n *Node) Value() *Node { func (n *Node) Value() *Node {
assertKind(KeyValue, *n)
return n.Child() return n.Child()
} }
@@ -144,9 +143,3 @@ func (n *Node) Value() *Node {
func (n *Node) Children() Iterator { func (n *Node) Children() Iterator {
return Iterator{node: n.Child()} return Iterator{node: n.Child()}
} }
func assertKind(k Kind, n Node) {
if n.Kind != k {
panic(fmt.Errorf("method was expecting a %s, not a %s", k, n.Kind))
}
}
+3 -3
View File
@@ -25,9 +25,9 @@ const (
Float Float
Integer Integer
LocalDate LocalDate
LocalTime
LocalDateTime LocalDateTime
DateTime DateTime
Time
) )
func (k Kind) String() string { func (k Kind) String() string {
@@ -58,12 +58,12 @@ func (k Kind) String() string {
return "Integer" return "Integer"
case LocalDate: case LocalDate:
return "LocalDate" return "LocalDate"
case LocalTime:
return "LocalTime"
case LocalDateTime: case LocalDateTime:
return "LocalDateTime" return "LocalDateTime"
case DateTime: case DateTime:
return "DateTime" return "DateTime"
case Time:
return "Time"
} }
panic(fmt.Errorf("Kind.String() not implemented for '%d'", k)) panic(fmt.Errorf("Kind.String() not implemented for '%d'", k))
} }
+11
View File
@@ -63,3 +63,14 @@ func Stride(ptr unsafe.Pointer, size uintptr, offset int) unsafe.Pointer {
// https://github.com/golang/go/issues/40481 // https://github.com/golang/go/issues/40481
return unsafe.Pointer(uintptr(ptr) + uintptr(int(size)*offset)) return unsafe.Pointer(uintptr(ptr) + uintptr(int(size)*offset))
} }
type Slice struct {
Data unsafe.Pointer
Len int
Cap int
}
type iface struct {
typ unsafe.Pointer
ptr unsafe.Pointer
}
+20
View File
@@ -0,0 +1,20 @@
//go:build go1.18
// +build go1.18
package danger
import (
"reflect"
"unsafe"
)
func ExtendSlice(t reflect.Type, s *Slice, n int) Slice {
arrayType := reflect.ArrayOf(n, t.Elem())
arrayData := reflect.New(arrayType)
reflect.Copy(arrayData.Elem(), reflect.NewAt(t, unsafe.Pointer(s)).Elem())
return Slice{
Data: unsafe.Pointer(arrayData.Pointer()),
Len: s.Len,
Cap: n,
}
}
+30
View File
@@ -0,0 +1,30 @@
//go:build !go1.18
// +build !go1.18
package danger
import (
"reflect"
"unsafe"
)
//go:linkname unsafe_NewArray reflect.unsafe_NewArray
func unsafe_NewArray(rtype unsafe.Pointer, length int) unsafe.Pointer
//go:linkname typedslicecopy reflect.typedslicecopy
//go:noescape
func typedslicecopy(elemType unsafe.Pointer, dst, src Slice) int
func ExtendSlice(t reflect.Type, s *Slice, n int) Slice {
elemTypeRef := t.Elem()
elemTypePtr := ((*iface)(unsafe.Pointer(&elemTypeRef))).ptr
d := Slice{
Data: unsafe_NewArray(elemTypePtr, n),
Len: s.Len,
Cap: n,
}
typedslicecopy(elemTypePtr, d, *s)
return d
}
+23
View File
@@ -0,0 +1,23 @@
package danger
import (
"reflect"
"unsafe"
)
// typeID is used as key in encoder and decoder caches to enable using
// the optimize runtime.mapaccess2_fast64 function instead of the more
// expensive lookup if we were to use reflect.Type as map key.
//
// typeID holds the pointer to the reflect.Type value, which is unique
// in the program.
//
// https://github.com/segmentio/encoding/blob/master/json/codec.go#L59-L61
type TypeID unsafe.Pointer
func MakeTypeID(t reflect.Type) TypeID {
// reflect.Type has the fields:
// typ unsafe.Pointer
// ptr unsafe.Pointer
return TypeID((*[2]unsafe.Pointer)(unsafe.Pointer(&t))[1])
}
@@ -1487,12 +1487,12 @@ func TestUnmarshalLocalDateTime(t *testing.T) {
name: "normal", name: "normal",
in: "1979-05-27T07:32:00", in: "1979-05-27T07:32:00",
out: toml.LocalDateTime{ out: toml.LocalDateTime{
Date: toml.LocalDate{ LocalDate: toml.LocalDate{
Year: 1979, Year: 1979,
Month: 5, Month: 5,
Day: 27, Day: 27,
}, },
Time: toml.LocalTime{ LocalTime: toml.LocalTime{
Hour: 7, Hour: 7,
Minute: 32, Minute: 32,
Second: 0, Second: 0,
@@ -1504,16 +1504,17 @@ func TestUnmarshalLocalDateTime(t *testing.T) {
name: "with nanoseconds", name: "with nanoseconds",
in: "1979-05-27T00:32:00.999999", in: "1979-05-27T00:32:00.999999",
out: toml.LocalDateTime{ out: toml.LocalDateTime{
Date: toml.LocalDate{ LocalDate: toml.LocalDate{
Year: 1979, Year: 1979,
Month: 5, Month: 5,
Day: 27, Day: 27,
}, },
Time: toml.LocalTime{ LocalTime: toml.LocalTime{
Hour: 0, Hour: 0,
Minute: 32, Minute: 32,
Second: 0, Second: 0,
Nanosecond: 999999000, Nanosecond: 999999000,
Precision: 6,
}, },
}, },
}, },
@@ -1551,26 +1552,26 @@ func TestUnmarshalLocalDateTime(t *testing.T) {
t.Fatal(err) t.Fatal(err)
} }
if obj.Date.Year() != example.out.Date.Year { if obj.Date.Year() != example.out.Year {
t.Errorf("expected year %d, got %d", example.out.Date.Year, obj.Date.Year()) t.Errorf("expected year %d, got %d", example.out.Year, obj.Date.Year())
} }
if obj.Date.Month() != example.out.Date.Month { if obj.Date.Month() != time.Month(example.out.Month) {
t.Errorf("expected month %d, got %d", example.out.Date.Month, obj.Date.Month()) t.Errorf("expected month %d, got %d", example.out.Month, obj.Date.Month())
} }
if obj.Date.Day() != example.out.Date.Day { if obj.Date.Day() != example.out.Day {
t.Errorf("expected day %d, got %d", example.out.Date.Day, obj.Date.Day()) t.Errorf("expected day %d, got %d", example.out.Day, obj.Date.Day())
} }
if obj.Date.Hour() != example.out.Time.Hour { if obj.Date.Hour() != example.out.Hour {
t.Errorf("expected hour %d, got %d", example.out.Time.Hour, obj.Date.Hour()) t.Errorf("expected hour %d, got %d", example.out.Hour, obj.Date.Hour())
} }
if obj.Date.Minute() != example.out.Time.Minute { if obj.Date.Minute() != example.out.Minute {
t.Errorf("expected minute %d, got %d", example.out.Time.Minute, obj.Date.Minute()) t.Errorf("expected minute %d, got %d", example.out.Minute, obj.Date.Minute())
} }
if obj.Date.Second() != example.out.Time.Second { if obj.Date.Second() != example.out.Second {
t.Errorf("expected second %d, got %d", example.out.Time.Second, obj.Date.Second()) t.Errorf("expected second %d, got %d", example.out.Second, obj.Date.Second())
} }
if obj.Date.Nanosecond() != example.out.Time.Nanosecond { if obj.Date.Nanosecond() != example.out.Nanosecond {
t.Errorf("expected nanoseconds %d, got %d", example.out.Time.Nanosecond, obj.Date.Nanosecond()) t.Errorf("expected nanoseconds %d, got %d", example.out.Nanosecond, obj.Date.Nanosecond())
} }
}) })
} }
@@ -1600,6 +1601,7 @@ func TestUnmarshalLocalTime(t *testing.T) {
Minute: 32, Minute: 32,
Second: 0, Second: 0,
Nanosecond: 999999000, Nanosecond: 999999000,
Precision: 6,
}, },
}, },
} }
+66 -19
View File
@@ -65,7 +65,7 @@ type entry struct {
explicit bool explicit bool
} }
// Remove all descendent of node at position idx. // Remove all descendants of node at position idx.
func (s *SeenTracker) clear(idx int) { func (s *SeenTracker) clear(idx int) {
p := s.entries[idx].id p := s.entries[idx].id
rest := clear(p, s.entries[idx+1:]) rest := clear(p, s.entries[idx+1:])
@@ -102,20 +102,21 @@ func (s *SeenTracker) create(parentIdx int, name []byte, kind keyKind, explicit
return idx return idx
} }
// CheckExpression takes a top-level node and checks that it does not contain keys // CheckExpression takes a top-level node and checks that it does not contain
// that have been seen in previous calls, and validates that types are consistent. // keys that have been seen in previous calls, and validates that types are
// consistent.
func (s *SeenTracker) CheckExpression(node *ast.Node) error { func (s *SeenTracker) CheckExpression(node *ast.Node) error {
if s.entries == nil { if s.entries == nil {
// s.entries = make([]entry, 0, 8) // Skip ID = 0 to remove the confusion between nodes whose
// Skip ID = 0 to remove the confusion between nodes whose parent has // parent has id 0 and root nodes (parent id is 0 because it's
// id 0 and root nodes (parent id is 0 because it's the zero value). // the zero value).
s.nextID = 1 s.nextID = 1
// Start unscoped, so idx is negative. // Start unscoped, so idx is negative.
s.currentIdx = -1 s.currentIdx = -1
} }
switch node.Kind { switch node.Kind {
case ast.KeyValue: case ast.KeyValue:
return s.checkKeyValue(node) return s.checkKeyValue(s.currentIdx, node)
case ast.Table: case ast.Table:
return s.checkTable(node) return s.checkTable(node)
case ast.ArrayTable: case ast.ArrayTable:
@@ -207,36 +208,82 @@ func (s *SeenTracker) checkArrayTable(node *ast.Node) error {
return nil return nil
} }
func (s *SeenTracker) checkKeyValue(node *ast.Node) error { func (s *SeenTracker) checkKeyValue(parentIdx int, node *ast.Node) error {
it := node.Key() it := node.Key()
parentIdx := s.currentIdx
for it.Next() { for it.Next() {
k := it.Node().Data k := it.Node().Data
idx := s.find(parentIdx, k) idx := s.find(parentIdx, k)
if idx >= 0 { if idx < 0 {
if s.entries[idx].kind != tableKind {
return fmt.Errorf("toml: expected %s to be a table, not a %s", string(k), s.entries[idx].kind)
}
} else {
idx = s.create(parentIdx, k, tableKind, false) idx = s.create(parentIdx, k, tableKind, false)
} else {
entry := s.entries[idx]
if it.IsLast() {
return fmt.Errorf("toml: key %s is already defined", string(k))
} else if entry.kind != tableKind {
return fmt.Errorf("toml: expected %s to be a table, not a %s", string(k), entry.kind)
} else if entry.explicit {
return fmt.Errorf("toml: cannot redefine table %s that has already been explicitly defined", string(k))
}
} }
parentIdx = idx parentIdx = idx
} }
kind := valueKind s.entries[parentIdx].kind = valueKind
if node.Value().Kind == ast.InlineTable { value := node.Value()
kind = tableKind
switch value.Kind {
case ast.InlineTable:
return s.checkInlineTable(parentIdx, value)
case ast.Array:
return s.checkArray(parentIdx, value)
} }
s.entries[parentIdx].kind = kind
return nil return nil
} }
func (s *SeenTracker) checkArray(parentIdx int, node *ast.Node) error {
set := false
it := node.Children()
for it.Next() {
if set {
s.clear(parentIdx)
}
n := it.Node()
switch n.Kind {
case ast.InlineTable:
err := s.checkInlineTable(parentIdx, n)
if err != nil {
return err
}
set = true
case ast.Array:
err := s.checkArray(parentIdx, n)
if err != nil {
return err
}
set = true
}
}
return nil
}
func (s *SeenTracker) checkInlineTable(parentIdx int, node *ast.Node) error {
it := node.Children()
for it.Next() {
n := it.Node()
err := s.checkKeyValue(parentIdx, n)
if err != nil {
return err
}
}
return nil
}
func (s *SeenTracker) id(idx int) int { func (s *SeenTracker) id(idx int) int {
if idx >= 0 { if idx >= 0 {
return s.entries[idx].id return s.entries[idx].id
+79 -259
View File
@@ -1,300 +1,120 @@
// Implementation of TOML's local date/time.
// Copied over from https://github.com/googleapis/google-cloud-go/blob/master/civil/civil.go
// to avoid pulling all the Google dependencies.
//
// Copyright 2016 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// Package civil implements types for civil time, a time-zone-independent
// representation of time that follows the rules of the proleptic
// Gregorian calendar with exactly 24-hour days, 60-minute hours, and 60-second
// minutes.
//
// Because they lack location information, these types do not represent unique
// moments or intervals of time. Use time.Time for that purpose.
package toml package toml
import ( import (
"fmt" "fmt"
"strings"
"time" "time"
) )
// A LocalDate represents a date (year, month, day). // LocalDate represents a calendar day in no specific timezone.
//
// This type does not include location information, and therefore does not
// describe a unique 24-hour timespan.
type LocalDate struct { type LocalDate struct {
Year int // Year (e.g., 2014). Year int
Month time.Month // Month of the year (January = 1, ...). Month int
Day int // Day of the month, starting at 1. Day int
} }
// LocalDateOf returns the LocalDate in which a time occurs in that time's location. // AsTime converts d into a specific time instance at midnight in zone.
func LocalDateOf(t time.Time) LocalDate { func (d LocalDate) AsTime(zone *time.Location) time.Time {
var d LocalDate return time.Date(d.Year, time.Month(d.Month), d.Day, 0, 0, 0, 0, zone)
d.Year, d.Month, d.Day = t.Date()
return d
} }
// ParseLocalDate parses a string in RFC3339 full-date format and returns the date value it represents. // String returns RFC 3339 representation of d.
func ParseLocalDate(s string) (LocalDate, error) {
t, err := time.Parse("2006-01-02", s)
if err != nil {
return LocalDate{}, err
}
return LocalDateOf(t), nil
}
// String returns the date in RFC3339 full-date format.
func (d LocalDate) String() string { func (d LocalDate) String() string {
return fmt.Sprintf("%04d-%02d-%02d", d.Year, d.Month, d.Day) return fmt.Sprintf("%04d-%02d-%02d", d.Year, d.Month, d.Day)
} }
// IsValid reports whether the date is valid. // MarshalText returns RFC 3339 representation of d.
func (d LocalDate) IsValid() bool {
return LocalDateOf(d.In(time.UTC)) == d
}
// In returns the time corresponding to time 00:00:00 of the date in the location.
//
// In is always consistent with time.LocalDate, even when time.LocalDate returns a time
// on a different day. For example, if loc is America/Indiana/Vincennes, then both
// time.LocalDate(1955, time.May, 1, 0, 0, 0, 0, loc)
// and
// civil.LocalDate{Year: 1955, Month: time.May, Day: 1}.In(loc)
// return 23:00:00 on April 30, 1955.
//
// In panics if loc is nil.
func (d LocalDate) In(loc *time.Location) time.Time {
return time.Date(d.Year, d.Month, d.Day, 0, 0, 0, 0, loc)
}
// AddDays returns the date that is n days in the future.
// n can also be negative to go into the past.
func (d LocalDate) AddDays(n int) LocalDate {
return LocalDateOf(d.In(time.UTC).AddDate(0, 0, n))
}
// DaysSince returns the signed number of days between the date and s, not including the end day.
// This is the inverse operation to AddDays.
func (d LocalDate) DaysSince(s LocalDate) (days int) {
// We convert to Unix time so we do not have to worry about leap seconds:
// Unix time increases by exactly 86400 seconds per day.
deltaUnix := d.In(time.UTC).Unix() - s.In(time.UTC).Unix()
const secondsInADay = 86400
return int(deltaUnix / secondsInADay)
}
// Before reports whether d1 occurs before future date.
func (d LocalDate) Before(future LocalDate) bool {
if d.Year != future.Year {
return d.Year < future.Year
}
if d.Month != future.Month {
return d.Month < future.Month
}
return d.Day < future.Day
}
// After reports whether d1 occurs after past date.
func (d LocalDate) After(past LocalDate) bool {
return past.Before(d)
}
// MarshalText implements the encoding.TextMarshaler interface.
// The output is the result of d.String().
func (d LocalDate) MarshalText() ([]byte, error) { func (d LocalDate) MarshalText() ([]byte, error) {
return []byte(d.String()), nil return []byte(d.String()), nil
} }
// UnmarshalText implements the encoding.TextUnmarshaler interface. // UnmarshalText parses b using RFC 3339 to fill d.
// The date is expected to be a string in a format accepted by ParseLocalDate. func (d *LocalDate) UnmarshalText(b []byte) error {
func (d *LocalDate) UnmarshalText(data []byte) error { res, err := parseLocalDate(b)
var err error if err != nil {
*d, err = ParseLocalDate(string(data)) return err
}
return err *d = res
return nil
} }
// A LocalTime represents a time with nanosecond precision. // LocalTime represents a time of day of no specific day in no specific
// // timezone.
// This type does not include location information, and therefore does not
// describe a unique moment in time.
//
// This type exists to represent the TIME type in storage-based APIs like BigQuery.
// Most operations on Times are unlikely to be meaningful. Prefer the LocalDateTime type.
type LocalTime struct { type LocalTime struct {
Hour int // The hour of the day in 24-hour format; range [0-23] Hour int // Hour of the day: [0; 24[
Minute int // The minute of the hour; range [0-59] Minute int // Minute of the hour: [0; 60[
Second int // The second of the minute; range [0-59] Second int // Second of the minute: [0; 60[
Nanosecond int // The nanosecond of the second; range [0-999999999] Nanosecond int // Nanoseconds within the second: [0, 1000000000[
Precision int // Number of digits to display for Nanosecond.
} }
// LocalTimeOf returns the LocalTime representing the time of day in which a time occurs // String returns RFC 3339 representation of d.
// in that time's location. It ignores the date. // If d.Nanosecond and d.Precision are zero, the time won't have a nanosecond
func LocalTimeOf(t time.Time) LocalTime { // component. If d.Nanosecond > 0 but d.Precision = 0, then the minimum number
var tm LocalTime // of digits for nanoseconds is provided.
tm.Hour, tm.Minute, tm.Second = t.Clock() func (d LocalTime) String() string {
tm.Nanosecond = t.Nanosecond() s := fmt.Sprintf("%02d:%02d:%02d", d.Hour, d.Minute, d.Second)
return tm if d.Precision > 0 {
s += fmt.Sprintf(".%09d", d.Nanosecond)[:d.Precision+1]
} else if d.Nanosecond > 0 {
// Nanoseconds are specified, but precision is not provided. Use the
// minimum.
s += strings.Trim(fmt.Sprintf(".%09d", d.Nanosecond), "0")
}
return s
} }
// ParseLocalTime parses a string and returns the time value it represents. // MarshalText returns RFC 3339 representation of d.
// ParseLocalTime accepts an extended form of the RFC3339 partial-time format. After func (d LocalTime) MarshalText() ([]byte, error) {
// the HH:MM:SS part of the string, an optional fractional part may appear, return []byte(d.String()), nil
// consisting of a decimal point followed by one to nine decimal digits. }
// (RFC3339 admits only one digit after the decimal point).
func ParseLocalTime(s string) (LocalTime, error) { // UnmarshalText parses b using RFC 3339 to fill d.
t, err := time.Parse("15:04:05.999999999", s) func (d *LocalTime) UnmarshalText(b []byte) error {
res, left, err := parseLocalTime(b)
if err == nil && len(left) != 0 {
err = newDecodeError(left, "extra characters")
}
if err != nil { if err != nil {
return LocalTime{}, err return err
} }
*d = res
return LocalTimeOf(t), nil return nil
} }
// String returns the date in the format described in ParseLocalTime. If Nanoseconds // LocalDateTime represents a time of a specific day in no specific timezone.
// is zero, no fractional part will be generated. Otherwise, the result will
// end with a fractional part consisting of a decimal point and nine digits.
func (t LocalTime) String() string {
s := fmt.Sprintf("%02d:%02d:%02d", t.Hour, t.Minute, t.Second)
if t.Nanosecond == 0 {
return s
}
return s + fmt.Sprintf(".%09d", t.Nanosecond)
}
// IsValid reports whether the time is valid.
func (t LocalTime) IsValid() bool {
// Construct a non-zero time.
tm := time.Date(2, 2, 2, t.Hour, t.Minute, t.Second, t.Nanosecond, time.UTC)
return LocalTimeOf(tm) == t
}
// MarshalText implements the encoding.TextMarshaler interface.
// The output is the result of t.String().
func (t LocalTime) MarshalText() ([]byte, error) {
return []byte(t.String()), nil
}
// UnmarshalText implements the encoding.TextUnmarshaler interface.
// The time is expected to be a string in a format accepted by ParseLocalTime.
func (t *LocalTime) UnmarshalText(data []byte) error {
var err error
*t, err = ParseLocalTime(string(data))
return err
}
// A LocalDateTime represents a date and time.
//
// This type does not include location information, and therefore does not
// describe a unique moment in time.
type LocalDateTime struct { type LocalDateTime struct {
Date LocalDate LocalDate
Time LocalTime LocalTime
} }
// Note: We deliberately do not embed LocalDate into LocalDateTime, to avoid promoting AddDays and Sub. // AsTime converts d into a specific time instance in zone.
func (d LocalDateTime) AsTime(zone *time.Location) time.Time {
return time.Date(d.Year, time.Month(d.Month), d.Day, d.Hour, d.Minute, d.Second, d.Nanosecond, zone)
}
// LocalDateTimeOf returns the LocalDateTime in which a time occurs in that time's location. // String returns RFC 3339 representation of d.
func LocalDateTimeOf(t time.Time) LocalDateTime { func (d LocalDateTime) String() string {
return LocalDateTime{ return d.LocalDate.String() + "T" + d.LocalTime.String()
Date: LocalDateOf(t), }
Time: LocalTimeOf(t),
// MarshalText returns RFC 3339 representation of d.
func (d LocalDateTime) MarshalText() ([]byte, error) {
return []byte(d.String()), nil
}
// UnmarshalText parses b using RFC 3339 to fill d.
func (d *LocalDateTime) UnmarshalText(data []byte) error {
res, left, err := parseLocalDateTime(data)
if err == nil && len(left) != 0 {
err = newDecodeError(left, "extra characters")
} }
}
// ParseLocalDateTime parses a string and returns the LocalDateTime it represents.
// ParseLocalDateTime accepts a variant of the RFC3339 date-time format that omits
// the time offset but includes an optional fractional time, as described in
// ParseLocalTime. Informally, the accepted format is
// YYYY-MM-DDTHH:MM:SS[.FFFFFFFFF]
// where the 'T' may be a lower-case 't'.
func ParseLocalDateTime(s string) (LocalDateTime, error) {
t, err := time.Parse("2006-01-02T15:04:05.999999999", s)
if err != nil { if err != nil {
t, err = time.Parse("2006-01-02t15:04:05.999999999", s) return err
if err != nil {
return LocalDateTime{}, err
}
} }
return LocalDateTimeOf(t), nil *d = res
} return nil
// String returns the date in the format described in ParseLocalDate.
func (dt LocalDateTime) String() string {
return dt.Date.String() + "T" + dt.Time.String()
}
// IsValid reports whether the datetime is valid.
func (dt LocalDateTime) IsValid() bool {
return dt.Date.IsValid() && dt.Time.IsValid()
}
// In returns the time corresponding to the LocalDateTime in the given location.
//
// If the time is missing or ambigous at the location, In returns the same
// result as time.LocalDate. For example, if loc is America/Indiana/Vincennes, then
// both
// time.LocalDate(1955, time.May, 1, 0, 30, 0, 0, loc)
// and
// civil.LocalDateTime{
// civil.LocalDate{Year: 1955, Month: time.May, Day: 1}},
// civil.LocalTime{Minute: 30}}.In(loc)
// return 23:30:00 on April 30, 1955.
//
// In panics if loc is nil.
func (dt LocalDateTime) In(loc *time.Location) time.Time {
return time.Date(
dt.Date.Year, dt.Date.Month, dt.Date.Day,
dt.Time.Hour, dt.Time.Minute, dt.Time.Second, dt.Time.Nanosecond, loc,
)
}
// Before reports whether dt occurs before future.
func (dt LocalDateTime) Before(future LocalDateTime) bool {
return dt.In(time.UTC).Before(future.In(time.UTC))
}
// After reports whether dt occurs after past.
func (dt LocalDateTime) After(past LocalDateTime) bool {
return past.Before(dt)
}
// MarshalText implements the encoding.TextMarshaler interface.
// The output is the result of dt.String().
func (dt LocalDateTime) MarshalText() ([]byte, error) {
return []byte(dt.String()), nil
}
// UnmarshalText implements the encoding.TextUnmarshaler interface.
// The datetime is expected to be a string in a format accepted by ParseLocalDateTime.
func (dt *LocalDateTime) UnmarshalText(data []byte) error {
var err error
*dt, err = ParseLocalDateTime(string(data))
return err
} }
+83 -454
View File
@@ -1,489 +1,118 @@
// Copyright 2016 Google LLC package toml_test
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package toml
import ( import (
"encoding/json"
"reflect"
"testing" "testing"
"time" "time"
"github.com/pelletier/go-toml/v2"
"github.com/stretchr/testify/require"
) )
func cmpEqual(x, y interface{}) bool { func TestLocalDate_AsTime(t *testing.T) {
return reflect.DeepEqual(x, y) d := toml.LocalDate{2021, 6, 8}
cast := d.AsTime(time.UTC)
require.Equal(t, time.Date(2021, time.June, 8, 0, 0, 0, 0, time.UTC), cast)
} }
func TestDates(t *testing.T) { func TestLocalDate_String(t *testing.T) {
d := toml.LocalDate{2021, 6, 8}
for _, test := range []struct { require.Equal(t, "2021-06-08", d.String())
date LocalDate
loc *time.Location
wantStr string
wantTime time.Time
}{
{
date: LocalDate{2014, 7, 29},
loc: time.Local,
wantStr: "2014-07-29",
wantTime: time.Date(2014, time.July, 29, 0, 0, 0, 0, time.Local),
},
{
date: LocalDateOf(time.Date(2014, 8, 20, 15, 8, 43, 1, time.Local)),
loc: time.UTC,
wantStr: "2014-08-20",
wantTime: time.Date(2014, 8, 20, 0, 0, 0, 0, time.UTC),
},
{
date: LocalDateOf(time.Date(999, time.January, 26, 0, 0, 0, 0, time.Local)),
loc: time.UTC,
wantStr: "0999-01-26",
wantTime: time.Date(999, 1, 26, 0, 0, 0, 0, time.UTC),
},
} {
if got := test.date.String(); got != test.wantStr {
t.Errorf("%#v.String() = %q, want %q", test.date, got, test.wantStr)
}
if got := test.date.In(test.loc); !got.Equal(test.wantTime) {
t.Errorf("%#v.In(%v) = %v, want %v", test.date, test.loc, got, test.wantTime)
}
}
} }
func TestDateIsValid(t *testing.T) { func TestLocalDate_MarshalText(t *testing.T) {
d := toml.LocalDate{2021, 6, 8}
for _, test := range []struct { b, err := d.MarshalText()
date LocalDate require.NoError(t, err)
want bool require.Equal(t, []byte("2021-06-08"), b)
}{
{LocalDate{2014, 7, 29}, true},
{LocalDate{2000, 2, 29}, true},
{LocalDate{10000, 12, 31}, true},
{LocalDate{1, 1, 1}, true},
{LocalDate{0, 1, 1}, true}, // year zero is OK
{LocalDate{-1, 1, 1}, true}, // negative year is OK
{LocalDate{1, 0, 1}, false},
{LocalDate{1, 1, 0}, false},
{LocalDate{2016, 1, 32}, false},
{LocalDate{2016, 13, 1}, false},
{LocalDate{1, -1, 1}, false},
{LocalDate{1, 1, -1}, false},
} {
got := test.date.IsValid()
if got != test.want {
t.Errorf("%#v: got %t, want %t", test.date, got, test.want)
}
}
} }
func TestParseDate(t *testing.T) { func TestLocalDate_UnmarshalMarshalText(t *testing.T) {
d := toml.LocalDate{}
err := d.UnmarshalText([]byte("2021-06-08"))
require.NoError(t, err)
require.Equal(t, toml.LocalDate{2021, 6, 8}, d)
var emptyDate LocalDate err = d.UnmarshalText([]byte("what"))
require.Error(t, err)
for _, test := range []struct {
str string
want LocalDate // if empty, expect an error
}{
{"2016-01-02", LocalDate{2016, 1, 2}},
{"2016-12-31", LocalDate{2016, 12, 31}},
{"0003-02-04", LocalDate{3, 2, 4}},
{"999-01-26", emptyDate},
{"", emptyDate},
{"2016-01-02x", emptyDate},
} {
got, err := ParseLocalDate(test.str)
if got != test.want {
t.Errorf("ParseLocalDate(%q) = %+v, want %+v", test.str, got, test.want)
}
if err != nil && test.want != (emptyDate) {
t.Errorf("Unexpected error %v from ParseLocalDate(%q)", err, test.str)
}
}
} }
func TestDateArithmetic(t *testing.T) { func TestLocalTime_String(t *testing.T) {
d := toml.LocalTime{20, 12, 1, 2, 9}
for _, test := range []struct { require.Equal(t, "20:12:01.000000002", d.String())
desc string d = toml.LocalTime{20, 12, 1, 0, 0}
start LocalDate require.Equal(t, "20:12:01", d.String())
end LocalDate d = toml.LocalTime{20, 12, 1, 0, 9}
days int require.Equal(t, "20:12:01.000000000", d.String())
}{ d = toml.LocalTime{20, 12, 1, 100, 0}
{ require.Equal(t, "20:12:01.0000001", d.String())
desc: "zero days noop",
start: LocalDate{2014, 5, 9},
end: LocalDate{2014, 5, 9},
days: 0,
},
{
desc: "crossing a year boundary",
start: LocalDate{2014, 12, 31},
end: LocalDate{2015, 1, 1},
days: 1,
},
{
desc: "negative number of days",
start: LocalDate{2015, 1, 1},
end: LocalDate{2014, 12, 31},
days: -1,
},
{
desc: "full leap year",
start: LocalDate{2004, 1, 1},
end: LocalDate{2005, 1, 1},
days: 366,
},
{
desc: "full non-leap year",
start: LocalDate{2001, 1, 1},
end: LocalDate{2002, 1, 1},
days: 365,
},
{
desc: "crossing a leap second",
start: LocalDate{1972, 6, 30},
end: LocalDate{1972, 7, 1},
days: 1,
},
{
desc: "dates before the unix epoch",
start: LocalDate{101, 1, 1},
end: LocalDate{102, 1, 1},
days: 365,
},
} {
if got := test.start.AddDays(test.days); got != test.end {
t.Errorf("[%s] %#v.AddDays(%v) = %#v, want %#v", test.desc, test.start, test.days, got, test.end)
}
if got := test.end.DaysSince(test.start); got != test.days {
t.Errorf("[%s] %#v.Sub(%#v) = %v, want %v", test.desc, test.end, test.start, got, test.days)
}
}
} }
func TestDateBefore(t *testing.T) { func TestLocalTime_MarshalText(t *testing.T) {
d := toml.LocalTime{20, 12, 1, 2, 9}
for _, test := range []struct { b, err := d.MarshalText()
d1, d2 LocalDate require.NoError(t, err)
want bool require.Equal(t, []byte("20:12:01.000000002"), b)
}{
{LocalDate{2016, 12, 31}, LocalDate{2017, 1, 1}, true},
{LocalDate{2016, 1, 1}, LocalDate{2016, 1, 1}, false},
{LocalDate{2016, 12, 30}, LocalDate{2016, 12, 31}, true},
{LocalDate{2016, 1, 30}, LocalDate{2016, 12, 31}, true},
} {
if got := test.d1.Before(test.d2); got != test.want {
t.Errorf("%v.Before(%v): got %t, want %t", test.d1, test.d2, got, test.want)
}
}
} }
func TestDateAfter(t *testing.T) { func TestLocalTime_UnmarshalMarshalText(t *testing.T) {
d := toml.LocalTime{}
err := d.UnmarshalText([]byte("20:12:01.000000002"))
require.NoError(t, err)
require.Equal(t, toml.LocalTime{20, 12, 1, 2, 9}, d)
for _, test := range []struct { err = d.UnmarshalText([]byte("what"))
d1, d2 LocalDate require.Error(t, err)
want bool
}{ err = d.UnmarshalText([]byte("20:12:01.000000002 bad"))
{LocalDate{2016, 12, 31}, LocalDate{2017, 1, 1}, false}, require.Error(t, err)
{LocalDate{2016, 1, 1}, LocalDate{2016, 1, 1}, false},
{LocalDate{2016, 12, 30}, LocalDate{2016, 12, 31}, false},
} {
if got := test.d1.After(test.d2); got != test.want {
t.Errorf("%v.After(%v): got %t, want %t", test.d1, test.d2, got, test.want)
}
}
} }
func TestTimeToString(t *testing.T) { func TestLocalTime_RoundTrip(t *testing.T) {
var d struct{ A toml.LocalTime }
for _, test := range []struct { err := toml.Unmarshal([]byte("a=20:12:01.500"), &d)
str string require.NoError(t, err)
time LocalTime require.Equal(t, "20:12:01.500", d.A.String())
roundTrip bool // ParseLocalTime(str).String() == str?
}{
{"13:26:33", LocalTime{13, 26, 33, 0}, true},
{"01:02:03.000023456", LocalTime{1, 2, 3, 23456}, true},
{"00:00:00.000000001", LocalTime{0, 0, 0, 1}, true},
{"13:26:03.1", LocalTime{13, 26, 3, 100000000}, false},
{"13:26:33.0000003", LocalTime{13, 26, 33, 300}, false},
} {
gotTime, err := ParseLocalTime(test.str)
if err != nil {
t.Errorf("ParseLocalTime(%q): got error: %v", test.str, err)
continue
}
if gotTime != test.time {
t.Errorf("ParseLocalTime(%q) = %+v, want %+v", test.str, gotTime, test.time)
}
if test.roundTrip {
gotStr := test.time.String()
if gotStr != test.str {
t.Errorf("%#v.String() = %q, want %q", test.time, gotStr, test.str)
}
}
}
} }
func TestTimeOf(t *testing.T) { func TestLocalDateTime_AsTime(t *testing.T) {
d := toml.LocalDateTime{
for _, test := range []struct { toml.LocalDate{2021, 6, 8},
time time.Time toml.LocalTime{20, 12, 1, 2, 9},
want LocalTime
}{
{time.Date(2014, 8, 20, 15, 8, 43, 1, time.Local), LocalTime{15, 8, 43, 1}},
{time.Date(1, 1, 1, 0, 0, 0, 0, time.UTC), LocalTime{0, 0, 0, 0}},
} {
if got := LocalTimeOf(test.time); got != test.want {
t.Errorf("LocalTimeOf(%v) = %+v, want %+v", test.time, got, test.want)
}
} }
cast := d.AsTime(time.UTC)
require.Equal(t, time.Date(2021, time.June, 8, 20, 12, 1, 2, time.UTC), cast)
} }
func TestTimeIsValid(t *testing.T) { func TestLocalDateTime_String(t *testing.T) {
d := toml.LocalDateTime{
for _, test := range []struct { toml.LocalDate{2021, 6, 8},
time LocalTime toml.LocalTime{20, 12, 1, 2, 9},
want bool
}{
{LocalTime{0, 0, 0, 0}, true},
{LocalTime{23, 0, 0, 0}, true},
{LocalTime{23, 59, 59, 999999999}, true},
{LocalTime{24, 59, 59, 999999999}, false},
{LocalTime{23, 60, 59, 999999999}, false},
{LocalTime{23, 59, 60, 999999999}, false},
{LocalTime{23, 59, 59, 1000000000}, false},
{LocalTime{-1, 0, 0, 0}, false},
{LocalTime{0, -1, 0, 0}, false},
{LocalTime{0, 0, -1, 0}, false},
{LocalTime{0, 0, 0, -1}, false},
} {
got := test.time.IsValid()
if got != test.want {
t.Errorf("%#v: got %t, want %t", test.time, got, test.want)
}
} }
require.Equal(t, "2021-06-08T20:12:01.000000002", d.String())
} }
func TestDateTimeToString(t *testing.T) { func TestLocalDateTime_MarshalText(t *testing.T) {
d := toml.LocalDateTime{
for _, test := range []struct { toml.LocalDate{2021, 6, 8},
str string toml.LocalTime{20, 12, 1, 2, 9},
dateTime LocalDateTime
roundTrip bool // ParseLocalDateTime(str).String() == str?
}{
{"2016-03-22T13:26:33", LocalDateTime{LocalDate{2016, 3, 22}, LocalTime{13, 26, 33, 0}}, true},
{"2016-03-22T13:26:33.000000600", LocalDateTime{LocalDate{2016, 3, 22}, LocalTime{13, 26, 33, 600}}, true},
{"2016-03-22t13:26:33", LocalDateTime{LocalDate{2016, 3, 22}, LocalTime{13, 26, 33, 0}}, false},
} {
gotDateTime, err := ParseLocalDateTime(test.str)
if err != nil {
t.Errorf("ParseLocalDateTime(%q): got error: %v", test.str, err)
continue
}
if gotDateTime != test.dateTime {
t.Errorf("ParseLocalDateTime(%q) = %+v, want %+v", test.str, gotDateTime, test.dateTime)
}
if test.roundTrip {
gotStr := test.dateTime.String()
if gotStr != test.str {
t.Errorf("%#v.String() = %q, want %q", test.dateTime, gotStr, test.str)
}
}
} }
b, err := d.MarshalText()
require.NoError(t, err)
require.Equal(t, []byte("2021-06-08T20:12:01.000000002"), b)
} }
func TestParseDateTimeErrors(t *testing.T) { func TestLocalDateTime_UnmarshalMarshalText(t *testing.T) {
d := toml.LocalDateTime{}
err := d.UnmarshalText([]byte("2021-06-08 20:12:01.000000002"))
require.NoError(t, err)
require.Equal(t, toml.LocalDateTime{
toml.LocalDate{2021, 6, 8},
toml.LocalTime{20, 12, 1, 2, 9},
}, d)
for _, str := range []string{ err = d.UnmarshalText([]byte("what"))
"", require.Error(t, err)
"2016-03-22", // just a date
"13:26:33", // just a time err = d.UnmarshalText([]byte("2021-06-08 20:12:01.000000002 bad"))
"2016-03-22 13:26:33", // wrong separating character require.Error(t, err)
"2016-03-22T13:26:33x", // extra at end
} {
if _, err := ParseLocalDateTime(str); err == nil {
t.Errorf("ParseLocalDateTime(%q) succeeded, want error", str)
}
}
}
func TestDateTimeOf(t *testing.T) {
for _, test := range []struct {
time time.Time
want LocalDateTime
}{
{
time.Date(2014, 8, 20, 15, 8, 43, 1, time.Local),
LocalDateTime{LocalDate{2014, 8, 20}, LocalTime{15, 8, 43, 1}},
},
{
time.Date(1, 1, 1, 0, 0, 0, 0, time.UTC),
LocalDateTime{LocalDate{1, 1, 1}, LocalTime{0, 0, 0, 0}},
},
} {
if got := LocalDateTimeOf(test.time); got != test.want {
t.Errorf("LocalDateTimeOf(%v) = %+v, want %+v", test.time, got, test.want)
}
}
}
func TestDateTimeIsValid(t *testing.T) {
// No need to be exhaustive here; it's just LocalDate.IsValid && LocalTime.IsValid.
for _, test := range []struct {
dt LocalDateTime
want bool
}{
{LocalDateTime{LocalDate{2016, 3, 20}, LocalTime{0, 0, 0, 0}}, true},
{LocalDateTime{LocalDate{2016, -3, 20}, LocalTime{0, 0, 0, 0}}, false},
{LocalDateTime{LocalDate{2016, 3, 20}, LocalTime{24, 0, 0, 0}}, false},
} {
got := test.dt.IsValid()
if got != test.want {
t.Errorf("%#v: got %t, want %t", test.dt, got, test.want)
}
}
}
func TestDateTimeIn(t *testing.T) {
dt := LocalDateTime{LocalDate{2016, 1, 2}, LocalTime{3, 4, 5, 6}}
want := time.Date(2016, 1, 2, 3, 4, 5, 6, time.UTC)
if got := dt.In(time.UTC); !got.Equal(want) {
t.Errorf("got %v, want %v", got, want)
}
}
func TestDateTimeBefore(t *testing.T) {
d1 := LocalDate{2016, 12, 31}
d2 := LocalDate{2017, 1, 1}
t1 := LocalTime{5, 6, 7, 8}
t2 := LocalTime{5, 6, 7, 9}
for _, test := range []struct {
dt1, dt2 LocalDateTime
want bool
}{
{LocalDateTime{d1, t1}, LocalDateTime{d2, t1}, true},
{LocalDateTime{d1, t1}, LocalDateTime{d1, t2}, true},
{LocalDateTime{d2, t1}, LocalDateTime{d1, t1}, false},
{LocalDateTime{d2, t1}, LocalDateTime{d2, t1}, false},
} {
if got := test.dt1.Before(test.dt2); got != test.want {
t.Errorf("%v.Before(%v): got %t, want %t", test.dt1, test.dt2, got, test.want)
}
}
}
func TestDateTimeAfter(t *testing.T) {
d1 := LocalDate{2016, 12, 31}
d2 := LocalDate{2017, 1, 1}
t1 := LocalTime{5, 6, 7, 8}
t2 := LocalTime{5, 6, 7, 9}
for _, test := range []struct {
dt1, dt2 LocalDateTime
want bool
}{
{LocalDateTime{d1, t1}, LocalDateTime{d2, t1}, false},
{LocalDateTime{d1, t1}, LocalDateTime{d1, t2}, false},
{LocalDateTime{d2, t1}, LocalDateTime{d1, t1}, true},
{LocalDateTime{d2, t1}, LocalDateTime{d2, t1}, false},
} {
if got := test.dt1.After(test.dt2); got != test.want {
t.Errorf("%v.After(%v): got %t, want %t", test.dt1, test.dt2, got, test.want)
}
}
}
func TestMarshalJSON(t *testing.T) {
for _, test := range []struct {
value interface{}
want string
}{
{LocalDate{1987, 4, 15}, `"1987-04-15"`},
{LocalTime{18, 54, 2, 0}, `"18:54:02"`},
{LocalDateTime{LocalDate{1987, 4, 15}, LocalTime{18, 54, 2, 0}}, `"1987-04-15T18:54:02"`},
} {
bgot, err := json.Marshal(test.value)
if err != nil {
t.Fatal(err)
}
if got := string(bgot); got != test.want {
t.Errorf("%#v: got %s, want %s", test.value, got, test.want)
}
}
}
func TestUnmarshalJSON(t *testing.T) {
var (
d LocalDate
tm LocalTime
dt LocalDateTime
)
for _, test := range []struct {
data string
ptr interface{}
want interface{}
}{
{`"1987-04-15"`, &d, &LocalDate{1987, 4, 15}},
{`"1987-04-\u0031\u0035"`, &d, &LocalDate{1987, 4, 15}},
{`"18:54:02"`, &tm, &LocalTime{18, 54, 2, 0}},
{`"1987-04-15T18:54:02"`, &dt, &LocalDateTime{LocalDate{1987, 4, 15}, LocalTime{18, 54, 2, 0}}},
} {
if err := json.Unmarshal([]byte(test.data), test.ptr); err != nil {
t.Fatalf("%s: %v", test.data, err)
}
if !cmpEqual(test.ptr, test.want) {
t.Errorf("%s: got %#v, want %#v", test.data, test.ptr, test.want)
}
}
for _, bad := range []string{
"", `""`, `"bad"`, `"1987-04-15x"`,
`19870415`, // a JSON number
`11987-04-15x`, // not a JSON string
} {
if json.Unmarshal([]byte(bad), &d) == nil {
t.Errorf("%q, LocalDate: got nil, want error", bad)
}
if json.Unmarshal([]byte(bad), &tm) == nil {
t.Errorf("%q, LocalTime: got nil, want error", bad)
}
if json.Unmarshal([]byte(bad), &dt) == nil {
t.Errorf("%q, LocalDateTime: got nil, want error", bad)
}
}
} }
+19 -6
View File
@@ -5,6 +5,7 @@ import (
"encoding" "encoding"
"fmt" "fmt"
"io" "io"
"math"
"reflect" "reflect"
"sort" "sort"
"strconv" "strconv"
@@ -53,8 +54,9 @@ func NewEncoder(w io.Writer) *Encoder {
// inline tag: // inline tag:
// //
// MyField `inline:"true"` // MyField `inline:"true"`
func (enc *Encoder) SetTablesInline(inline bool) { func (enc *Encoder) SetTablesInline(inline bool) *Encoder {
enc.tablesInline = inline enc.tablesInline = inline
return enc
} }
// SetArraysMultiline forces the encoder to emit all arrays with one element per // SetArraysMultiline forces the encoder to emit all arrays with one element per
@@ -63,20 +65,23 @@ func (enc *Encoder) SetTablesInline(inline bool) {
// This behavior can be controlled on an individual struct field basis with the multiline tag: // This behavior can be controlled on an individual struct field basis with the multiline tag:
// //
// MyField `multiline:"true"` // MyField `multiline:"true"`
func (enc *Encoder) SetArraysMultiline(multiline bool) { func (enc *Encoder) SetArraysMultiline(multiline bool) *Encoder {
enc.arraysMultiline = multiline enc.arraysMultiline = multiline
return enc
} }
// SetIndentSymbol defines the string that should be used for indentation. The // SetIndentSymbol defines the string that should be used for indentation. The
// provided string is repeated for each indentation level. Defaults to two // provided string is repeated for each indentation level. Defaults to two
// spaces. // spaces.
func (enc *Encoder) SetIndentSymbol(s string) { func (enc *Encoder) SetIndentSymbol(s string) *Encoder {
enc.indentSymbol = s enc.indentSymbol = s
return enc
} }
// SetIndentTables forces the encoder to intent tables and array tables. // SetIndentTables forces the encoder to intent tables and array tables.
func (enc *Encoder) SetIndentTables(indent bool) { func (enc *Encoder) SetIndentTables(indent bool) *Encoder {
enc.indentTables = indent enc.indentTables = indent
return enc
} }
// Encode writes a TOML representation of v to the stream. // Encode writes a TOML representation of v to the stream.
@@ -244,9 +249,17 @@ func (enc *Encoder) encode(b []byte, ctx encoderCtx, v reflect.Value) ([]byte, e
case reflect.String: case reflect.String:
b = enc.encodeString(b, v.String(), ctx.options) b = enc.encodeString(b, v.String(), ctx.options)
case reflect.Float32: case reflect.Float32:
b = strconv.AppendFloat(b, v.Float(), 'f', -1, 32) if math.Trunc(v.Float()) == v.Float() {
b = strconv.AppendFloat(b, v.Float(), 'f', 1, 32)
} else {
b = strconv.AppendFloat(b, v.Float(), 'f', -1, 32)
}
case reflect.Float64: case reflect.Float64:
b = strconv.AppendFloat(b, v.Float(), 'f', -1, 64) if math.Trunc(v.Float()) == v.Float() {
b = strconv.AppendFloat(b, v.Float(), 'f', 1, 64)
} else {
b = strconv.AppendFloat(b, v.Float(), 'f', -1, 64)
}
case reflect.Bool: case reflect.Bool:
if v.Bool() { if v.Bool() {
b = append(b, "true"...) b = append(b, "true"...)
+40 -1
View File
@@ -551,7 +551,7 @@ K = 42`,
type flagsSetters []struct { type flagsSetters []struct {
name string name string
f func(enc *toml.Encoder, flag bool) f func(enc *toml.Encoder, flag bool) *toml.Encoder
} }
var allFlags = flagsSetters{ var allFlags = flagsSetters{
@@ -782,6 +782,22 @@ func TestIssue424(t *testing.T) {
require.Equal(t, msg2, msg2parsed) require.Equal(t, msg2, msg2parsed)
} }
func TestIssue567(t *testing.T) {
var m map[string]interface{}
err := toml.Unmarshal([]byte("A = 12:08:05"), &m)
require.NoError(t, err)
require.IsType(t, m["A"], toml.LocalTime{})
}
func TestIssue590(t *testing.T) {
type CustomType int
var cfg struct {
Option CustomType `toml:"option"`
}
err := toml.Unmarshal([]byte("option = 42"), &cfg)
require.NoError(t, err)
}
func ExampleMarshal() { func ExampleMarshal() {
type MyConfig struct { type MyConfig struct {
Version int Version int
@@ -806,3 +822,26 @@ func ExampleMarshal() {
// Name = 'go-toml' // Name = 'go-toml'
// Tags = ['go', 'toml'] // Tags = ['go', 'toml']
} }
func TestIssue571(t *testing.T) {
type Foo struct {
Float32 float32
Float64 float64
}
const closeEnough = 1e-9
foo := Foo{
Float32: 42,
Float64: 43,
}
b, err := toml.Marshal(foo)
require.NoError(t, err)
var foo2 Foo
err = toml.Unmarshal(b, &foo2)
require.NoError(t, err)
assert.InDelta(t, 42, foo2.Float32, closeEnough)
assert.InDelta(t, 43, foo2.Float64, closeEnough)
}
+126 -65
View File
@@ -2,7 +2,7 @@ package toml
import ( import (
"bytes" "bytes"
"strconv" "unicode"
"github.com/pelletier/go-toml/v2/internal/ast" "github.com/pelletier/go-toml/v2/internal/ast"
"github.com/pelletier/go-toml/v2/internal/danger" "github.com/pelletier/go-toml/v2/internal/danger"
@@ -107,9 +107,8 @@ func (p *parser) parseExpression(b []byte) (ast.Reference, []byte, error) {
} }
if b[0] == '#' { if b[0] == '#' {
_, rest := scanComment(b) _, rest, err := scanComment(b)
return ref, rest, err
return ref, rest, nil
} }
if b[0] == '\n' || b[0] == '\r' { if b[0] == '\n' || b[0] == '\r' {
@@ -130,9 +129,8 @@ func (p *parser) parseExpression(b []byte) (ast.Reference, []byte, error) {
b = p.parseWhitespace(b) b = p.parseWhitespace(b)
if len(b) > 0 && b[0] == '#' { if len(b) > 0 && b[0] == '#' {
_, rest := scanComment(b) _, rest, err := scanComment(b)
return ref, rest, err
return ref, rest, nil
} }
return ref, b, nil return ref, b, nil
@@ -354,7 +352,13 @@ func (p *parser) parseInlineTable(b []byte) (ast.Reference, []byte, error) {
var err error var err error
for len(b) > 0 { for len(b) > 0 {
previousB := b
b = p.parseWhitespace(b) b = p.parseWhitespace(b)
if len(b) == 0 {
return parent, nil, newDecodeError(previousB[:1], "inline table is incomplete")
}
if b[0] == '}' { if b[0] == '}' {
break break
} }
@@ -398,6 +402,7 @@ func (p *parser) parseValArray(b []byte) (ast.Reference, []byte, error) {
// array-values =/ ws-comment-newline val ws-comment-newline [ array-sep ] // array-values =/ ws-comment-newline val ws-comment-newline [ array-sep ]
// array-sep = %x2C ; , Comma // array-sep = %x2C ; , Comma
// ws-comment-newline = *( wschar / [ comment ] newline ) // ws-comment-newline = *( wschar / [ comment ] newline )
arrayStart := b
b = b[1:] b = b[1:]
parent := p.builder.Push(ast.Node{ parent := p.builder.Push(ast.Node{
@@ -416,7 +421,7 @@ func (p *parser) parseValArray(b []byte) (ast.Reference, []byte, error) {
} }
if len(b) == 0 { if len(b) == 0 {
return parent, nil, newDecodeError(b, "array is incomplete") return parent, nil, newDecodeError(arrayStart[:1], "array is incomplete")
} }
if b[0] == ']' { if b[0] == ']' {
@@ -433,6 +438,8 @@ func (p *parser) parseValArray(b []byte) (ast.Reference, []byte, error) {
if err != nil { if err != nil {
return parent, nil, err return parent, nil, err
} }
} else if !first {
return parent, nil, newDecodeError(b[0:1], "array elements must be separated by commas")
} }
// TOML allows trailing commas in arrays. // TOML allows trailing commas in arrays.
@@ -441,7 +448,6 @@ func (p *parser) parseValArray(b []byte) (ast.Reference, []byte, error) {
} }
var valueRef ast.Reference var valueRef ast.Reference
valueRef, b, err = p.parseVal(b) valueRef, b, err = p.parseVal(b)
if err != nil { if err != nil {
return parent, nil, err return parent, nil, err
@@ -472,7 +478,10 @@ func (p *parser) parseOptionalWhitespaceCommentNewline(b []byte) ([]byte, error)
b = p.parseWhitespace(b) b = p.parseWhitespace(b)
if len(b) > 0 && b[0] == '#' { if len(b) > 0 && b[0] == '#' {
_, b = scanComment(b) _, b, err = scanComment(b)
if err != nil {
return nil, err
}
} }
if len(b) == 0 { if len(b) == 0 {
@@ -522,7 +531,7 @@ func (p *parser) parseMultilineBasicString(b []byte) ([]byte, []byte, []byte, er
// mlb-quotes = 1*2quotation-mark // mlb-quotes = 1*2quotation-mark
// mlb-unescaped = wschar / %x21 / %x23-5B / %x5D-7E / non-ascii // mlb-unescaped = wschar / %x21 / %x23-5B / %x5D-7E / non-ascii
// mlb-escaped-nl = escape ws newline *( wschar / newline ) // mlb-escaped-nl = escape ws newline *( wschar / newline )
token, rest, err := scanMultilineBasicString(b) token, escaped, rest, err := scanMultilineBasicString(b)
if err != nil { if err != nil {
return nil, nil, nil, err return nil, nil, nil, err
} }
@@ -539,21 +548,21 @@ func (p *parser) parseMultilineBasicString(b []byte) ([]byte, []byte, []byte, er
// fast path // fast path
startIdx := i startIdx := i
endIdx := len(token) - len(`"""`) endIdx := len(token) - len(`"""`)
for ; i < endIdx; i++ {
if token[i] == '\\' { if !escaped {
break str := token[startIdx:endIdx]
verr := utf8TomlValidAlreadyEscaped(str)
if verr.Zero() {
return token, str, rest, nil
} }
} return nil, nil, nil, newDecodeError(str[verr.Index:verr.Index+verr.Size], "invalid UTF-8")
if i == endIdx {
return token, token[startIdx:endIdx], rest, nil
} }
var builder bytes.Buffer var builder bytes.Buffer
builder.Write(token[startIdx:i])
// The scanner ensures that the token starts and ends with quotes and that // The scanner ensures that the token starts and ends with quotes and that
// escapes are balanced. // escapes are balanced.
for ; i < len(token)-3; i++ { for i < len(token)-3 {
c := token[i] c := token[i]
//nolint:nestif //nolint:nestif
@@ -561,17 +570,29 @@ func (p *parser) parseMultilineBasicString(b []byte) ([]byte, []byte, []byte, er
// When the last non-whitespace character on a line is an unescaped \, // When the last non-whitespace character on a line is an unescaped \,
// it will be trimmed along with all whitespace (including newlines) up // it will be trimmed along with all whitespace (including newlines) up
// to the next non-whitespace character or closing delimiter. // to the next non-whitespace character or closing delimiter.
if token[i+1] == '\n' || (token[i+1] == '\r' && token[i+2] == '\n') {
i++ // skip the \ isLastNonWhitespaceOnLine := false
j := 1
findEOLLoop:
for ; j < len(token)-3-i; j++ {
switch token[i+j] {
case ' ', '\t':
continue
case '\n':
isLastNonWhitespaceOnLine = true
}
break findEOLLoop
}
if isLastNonWhitespaceOnLine {
i += j
for ; i < len(token)-3; i++ { for ; i < len(token)-3; i++ {
c := token[i] c := token[i]
if !(c == '\n' || c == '\r' || c == ' ' || c == '\t') { if !(c == '\n' || c == '\r' || c == ' ' || c == '\t') {
i-- i--
break break
} }
} }
i++
continue continue
} }
@@ -593,26 +614,31 @@ func (p *parser) parseMultilineBasicString(b []byte) ([]byte, []byte, []byte, er
case 't': case 't':
builder.WriteByte('\t') builder.WriteByte('\t')
case 'u': case 'u':
x, err := hexToString(atmost(token[i+1:], 4), 4) x, err := hexToRune(atmost(token[i+1:], 4), 4)
if err != nil { if err != nil {
return nil, nil, nil, err return nil, nil, nil, err
} }
builder.WriteRune(x)
builder.WriteString(x)
i += 4 i += 4
case 'U': case 'U':
x, err := hexToString(atmost(token[i+1:], 8), 8) x, err := hexToRune(atmost(token[i+1:], 8), 8)
if err != nil { if err != nil {
return nil, nil, nil, err return nil, nil, nil, err
} }
builder.WriteString(x) builder.WriteRune(x)
i += 8 i += 8
default: default:
return nil, nil, nil, newDecodeError(token[i:i+1], "invalid escaped character %#U", c) return nil, nil, nil, newDecodeError(token[i:i+1], "invalid escaped character %#U", c)
} }
i++
} else { } else {
builder.WriteByte(c) size := utf8ValidNext(token[i:])
if size == 0 {
return nil, nil, nil, newDecodeError(token[i:i+1], "invalid character %#U", c)
}
builder.Write(token[i : i+size])
i += size
} }
} }
@@ -666,10 +692,6 @@ func (p *parser) parseSimpleKey(b []byte) (raw, key, rest []byte, err error) {
// simple-key = quoted-key / unquoted-key // simple-key = quoted-key / unquoted-key
// unquoted-key = 1*( ALPHA / DIGIT / %x2D / %x5F ) ; A-Z / a-z / 0-9 / - / _ // unquoted-key = 1*( ALPHA / DIGIT / %x2D / %x5F ) ; A-Z / a-z / 0-9 / - / _
// quoted-key = basic-string / literal-string // quoted-key = basic-string / literal-string
if len(b) == 0 {
return nil, nil, nil, newDecodeError(b, "key is incomplete")
}
switch { switch {
case b[0] == '\'': case b[0] == '\'':
return p.parseLiteralString(b) return p.parseLiteralString(b)
@@ -699,30 +721,33 @@ func (p *parser) parseBasicString(b []byte) ([]byte, []byte, []byte, error) {
// escape-seq-char =/ %x74 ; t tab U+0009 // escape-seq-char =/ %x74 ; t tab U+0009
// escape-seq-char =/ %x75 4HEXDIG ; uXXXX U+XXXX // escape-seq-char =/ %x75 4HEXDIG ; uXXXX U+XXXX
// escape-seq-char =/ %x55 8HEXDIG ; UXXXXXXXX U+XXXXXXXX // escape-seq-char =/ %x55 8HEXDIG ; UXXXXXXXX U+XXXXXXXX
token, rest, err := scanBasicString(b) token, escaped, rest, err := scanBasicString(b)
if err != nil { if err != nil {
return nil, nil, nil, err return nil, nil, nil, err
} }
// fast path startIdx := len(`"`)
i := len(`"`)
startIdx := i
endIdx := len(token) - len(`"`) endIdx := len(token) - len(`"`)
for ; i < endIdx; i++ {
if token[i] == '\\' { // Fast path. If there is no escape sequence, the string should just be
break // an UTF-8 encoded string, which is the same as Go. In that case,
// validate the string and return a direct reference to the buffer.
if !escaped {
str := token[startIdx:endIdx]
verr := utf8TomlValidAlreadyEscaped(str)
if verr.Zero() {
return token, str, rest, nil
} }
} return nil, nil, nil, newDecodeError(str[verr.Index:verr.Index+verr.Size], "invalid UTF-8")
if i == endIdx {
return token, token[startIdx:endIdx], rest, nil
} }
i := startIdx
var builder bytes.Buffer var builder bytes.Buffer
builder.Write(token[startIdx:i])
// The scanner ensures that the token starts and ends with quotes and that // The scanner ensures that the token starts and ends with quotes and that
// escapes are balanced. // escapes are balanced.
for ; i < len(token)-1; i++ { for i < len(token)-1 {
c := token[i] c := token[i]
if c == '\\' { if c == '\\' {
i++ i++
@@ -742,46 +767,65 @@ func (p *parser) parseBasicString(b []byte) ([]byte, []byte, []byte, error) {
case 't': case 't':
builder.WriteByte('\t') builder.WriteByte('\t')
case 'u': case 'u':
x, err := hexToString(token[i+1:len(token)-1], 4) x, err := hexToRune(token[i+1:len(token)-1], 4)
if err != nil { if err != nil {
return nil, nil, nil, err return nil, nil, nil, err
} }
builder.WriteString(x) builder.WriteRune(x)
i += 4 i += 4
case 'U': case 'U':
x, err := hexToString(token[i+1:len(token)-1], 8) x, err := hexToRune(token[i+1:len(token)-1], 8)
if err != nil { if err != nil {
return nil, nil, nil, err return nil, nil, nil, err
} }
builder.WriteString(x) builder.WriteRune(x)
i += 8 i += 8
default: default:
return nil, nil, nil, newDecodeError(token[i:i+1], "invalid escaped character %#U", c) return nil, nil, nil, newDecodeError(token[i:i+1], "invalid escaped character %#U", c)
} }
i++
} else { } else {
builder.WriteByte(c) size := utf8ValidNext(token[i:])
if size == 0 {
return nil, nil, nil, newDecodeError(token[i:i+1], "invalid character %#U", c)
}
builder.Write(token[i : i+size])
i += size
} }
} }
return token, builder.Bytes(), rest, nil return token, builder.Bytes(), rest, nil
} }
func hexToString(b []byte, length int) (string, error) { func hexToRune(b []byte, length int) (rune, error) {
if len(b) < length { if len(b) < length {
return "", newDecodeError(b, "unicode point needs %d character, not %d", length, len(b)) return -1, newDecodeError(b, "unicode point needs %d character, not %d", length, len(b))
} }
b = b[:length] b = b[:length]
//nolint:godox var r uint32
// TODO: slow for i, c := range b {
intcode, err := strconv.ParseInt(string(b), 16, 32) d := uint32(0)
if err != nil { switch {
return "", newDecodeError(b, "couldn't parse hexadecimal number: %w", err) case '0' <= c && c <= '9':
d = uint32(c - '0')
case 'a' <= c && c <= 'f':
d = uint32(c - 'a' + 10)
case 'A' <= c && c <= 'F':
d = uint32(c - 'A' + 10)
default:
return -1, newDecodeError(b[i:i+1], "non-hex character")
}
r = r*16 + d
} }
return string(rune(intcode)), nil if r > unicode.MaxRune || 0xD800 <= r && r < 0xE000 {
return -1, newDecodeError(b, "escape sequence is invalid Unicode code point")
}
return rune(r), nil
} }
func (p *parser) parseWhitespace(b []byte) []byte { func (p *parser) parseWhitespace(b []byte) []byte {
@@ -836,6 +880,8 @@ func (p *parser) parseIntOrFloatOrDateTime(b []byte) (ast.Reference, []byte, err
if idx == 2 && c == ':' || (idx == 4 && c == '-') { if idx == 2 && c == ':' || (idx == 4 && c == '-') {
return p.scanDateTime(b) return p.scanDateTime(b)
} }
break
} }
return p.scanIntOrFloat(b) return p.scanIntOrFloat(b)
@@ -856,6 +902,7 @@ func digitsToInt(b []byte) int {
func (p *parser) scanDateTime(b []byte) (ast.Reference, []byte, error) { func (p *parser) scanDateTime(b []byte) (ast.Reference, []byte, error) {
// scans for contiguous characters in [0-9T:Z.+-], and up to one space if // scans for contiguous characters in [0-9T:Z.+-], and up to one space if
// followed by a digit. // followed by a digit.
hasDate := false
hasTime := false hasTime := false
hasTz := false hasTz := false
seenSpace := false seenSpace := false
@@ -868,17 +915,23 @@ byteLoop:
switch { switch {
case isDigit(c): case isDigit(c):
case c == '-': case c == '-':
hasDate = true
const minOffsetOfTz = 8 const minOffsetOfTz = 8
if i >= minOffsetOfTz { if i >= minOffsetOfTz {
hasTz = true hasTz = true
} }
case c == 'T' || c == ':' || c == '.': case c == 'T' || c == 't' || c == ':' || c == '.':
hasTime = true hasTime = true
case c == '+' || c == '-' || c == 'Z': case c == '+' || c == '-' || c == 'Z' || c == 'z':
hasTz = true hasTz = true
case c == ' ': case c == ' ':
if !seenSpace && i+1 < len(b) && isDigit(b[i+1]) { if !seenSpace && i+1 < len(b) && isDigit(b[i+1]) {
i += 2 i += 2
// Avoid reaching past the end of the document in case the time
// is malformed. See TestIssue585.
if i >= len(b) {
i--
}
seenSpace = true seenSpace = true
hasTime = true hasTime = true
} else { } else {
@@ -892,10 +945,14 @@ byteLoop:
var kind ast.Kind var kind ast.Kind
if hasTime { if hasTime {
if hasTz { if hasDate {
kind = ast.DateTime if hasTz {
kind = ast.DateTime
} else {
kind = ast.LocalDateTime
}
} else { } else {
kind = ast.LocalDateTime kind = ast.LocalTime
} }
} else { } else {
kind = ast.LocalDate kind = ast.LocalDate
@@ -911,7 +968,7 @@ byteLoop:
func (p *parser) scanIntOrFloat(b []byte) (ast.Reference, []byte, error) { func (p *parser) scanIntOrFloat(b []byte) (ast.Reference, []byte, error) {
i := 0 i := 0
if len(b) > 2 && b[0] == '0' && b[1] != '.' { if len(b) > 2 && b[0] == '0' && b[1] != '.' && b[1] != 'e' && b[1] != 'E' {
var isValidRune validRuneFn var isValidRune validRuneFn
switch b[1] { switch b[1] {
@@ -1018,8 +1075,12 @@ func isValidBinaryRune(r byte) bool {
} }
func expect(x byte, b []byte) ([]byte, error) { func expect(x byte, b []byte) ([]byte, error) {
if len(b) == 0 {
return nil, newDecodeError(b, "expected character %c but the document ended here", x)
}
if b[0] != x { if b[0] != x {
return nil, newDecodeError(b[0:1], "expected character %U", x) return nil, newDecodeError(b[0:1], "expected character %c", x)
} }
return b[1:], nil return b[1:], nil
+100
View File
@@ -1,6 +1,8 @@
package toml package toml
import ( import (
"strconv"
"strings"
"testing" "testing"
"github.com/pelletier/go-toml/v2/internal/ast" "github.com/pelletier/go-toml/v2/internal/ast"
@@ -348,3 +350,101 @@ func TestParser_AST(t *testing.T) {
}) })
} }
} }
// BenchmarkParseBasicStringWithUnicode measures parseBasicString on inputs
// consisting entirely of unicode escape sequences, i.e. the slow (escaped)
// path of the parser.
func BenchmarkParseBasicStringWithUnicode(b *testing.B) {
	p := &parser{}
	b.Run("4", func(b *testing.B) {
		// Six 4-hex-digit \u escapes.
		input := []byte(`"\u1234\u5678\u9ABC\u1234\u5678\u9ABC"`)
		b.ReportAllocs()
		b.SetBytes(int64(len(input)))
		for i := 0; i < b.N; i++ {
			p.parseBasicString(input)
		}
	})
	b.Run("8", func(b *testing.B) {
		// NOTE(review): subtest is named "8" but the input uses \u (4 hex
		// digits) followed by 4 literal characters; \U was probably intended
		// for 8-digit escapes — confirm.
		input := []byte(`"\u12345678\u9ABCDEF0\u12345678\u9ABCDEF0"`)
		b.ReportAllocs()
		b.SetBytes(int64(len(input)))
		for i := 0; i < b.N; i++ {
			p.parseBasicString(input)
		}
	})
}
// BenchmarkParseBasicStringsEasy measures parseBasicString on escape-free
// ASCII strings of several sizes, i.e. the fast path that can return a
// direct reference into the input buffer.
func BenchmarkParseBasicStringsEasy(b *testing.B) {
	p := &parser{}
	for _, size := range []int{1, 4, 8, 16, 21} {
		b.Run(strconv.Itoa(size), func(b *testing.B) {
			input := []byte(`"` + strings.Repeat("A", size) + `"`)
			b.ReportAllocs()
			b.SetBytes(int64(len(input)))
			for i := 0; i < b.N; i++ {
				p.parseBasicString(input)
			}
		})
	}
}
// TestParser_AST_DateTimes checks that the parser classifies date/time
// values into the correct AST kind (DateTime, LocalDateTime, LocalDate)
// depending on which components (date, time, timezone) are present.
func TestParser_AST_DateTimes(t *testing.T) {
	examples := []struct {
		desc  string
		input string
		kind  ast.Kind // expected AST kind for the value node
		err   bool     // whether parsing is expected to fail
	}{
		{
			desc:  "offset-date-time with delim 'T' and UTC offset",
			input: `2021-07-21T12:08:05Z`,
			kind:  ast.DateTime,
		},
		{
			desc:  "offset-date-time with space delim and +8hours offset",
			input: `2021-07-21 12:08:05+08:00`,
			kind:  ast.DateTime,
		},
		{
			desc:  "local-date-time with nano second",
			input: `2021-07-21T12:08:05.666666666`,
			kind:  ast.LocalDateTime,
		},
		{
			desc:  "local-date-time",
			input: `2021-07-21T12:08:05`,
			kind:  ast.LocalDateTime,
		},
		{
			desc:  "local-date",
			input: `2021-07-21`,
			kind:  ast.LocalDate,
		},
	}

	for _, e := range examples {
		e := e // capture range variable for the subtest closure
		t.Run(e.desc, func(t *testing.T) {
			// Parse a single key-value expression: A = <input>.
			p := parser{}
			p.Reset([]byte(`A = ` + e.input))
			p.NextExpression()
			err := p.Error()
			if e.err {
				require.Error(t, err)
			} else {
				require.NoError(t, err)

				expected := astNode{
					Kind: ast.KeyValue,
					Children: []astNode{
						{Kind: e.kind, Data: []byte(e.input)},
						{Kind: ast.Key, Data: []byte(`A`)},
					},
				}
				compareNode(t, expected, p.Expression())
			}
		})
	}
}
+96 -20
View File
@@ -49,13 +49,18 @@ func scanLiteralString(b []byte) ([]byte, []byte, error) {
// literal-string = apostrophe *literal-char apostrophe // literal-string = apostrophe *literal-char apostrophe
// apostrophe = %x27 ; ' apostrophe // apostrophe = %x27 ; ' apostrophe
// literal-char = %x09 / %x20-26 / %x28-7E / non-ascii // literal-char = %x09 / %x20-26 / %x28-7E / non-ascii
for i := 1; i < len(b); i++ { for i := 1; i < len(b); {
switch b[i] { switch b[i] {
case '\'': case '\'':
return b[:i+1], b[i+1:], nil return b[:i+1], b[i+1:], nil
case '\n': case '\n':
return nil, nil, newDecodeError(b[i:i+1], "literal strings cannot have new lines") return nil, nil, newDecodeError(b[i:i+1], "literal strings cannot have new lines")
} }
size := utf8ValidNext(b[i:])
if size == 0 {
return nil, nil, newDecodeError(b[i:i+1], "invalid character")
}
i += size
} }
return nil, nil, newDecodeError(b[len(b):], "unterminated literal string") return nil, nil, newDecodeError(b[len(b):], "unterminated literal string")
@@ -70,10 +75,37 @@ func scanMultilineLiteralString(b []byte) ([]byte, []byte, error) {
// mll-content = mll-char / newline // mll-content = mll-char / newline
// mll-char = %x09 / %x20-26 / %x28-7E / non-ascii // mll-char = %x09 / %x20-26 / %x28-7E / non-ascii
// mll-quotes = 1*2apostrophe // mll-quotes = 1*2apostrophe
for i := 3; i < len(b); i++ { for i := 3; i < len(b); {
if b[i] == '\'' && scanFollowsMultilineLiteralStringDelimiter(b[i:]) { if scanFollowsMultilineLiteralStringDelimiter(b[i:]) {
return b[:i+3], b[i+3:], nil i += 3
// At that point we found 3 apostrophe, and i is the
// index of the byte after the third one. The scanner
// needs to be eager, because there can be an extra 2
// apostrophe that can be accepted at the end of the
// string.
if i >= len(b) || b[i] != '\'' {
return b[:i], b[i:], nil
}
i++
if i >= len(b) || b[i] != '\'' {
return b[:i], b[i:], nil
}
i++
if i < len(b) && b[i] == '\'' {
return nil, nil, newDecodeError(b[i-3:i+1], "''' not allowed in multiline literal string")
}
return b[:i], b[i:], nil
} }
size := utf8ValidNext(b[i:])
if size == 0 {
return nil, nil, newDecodeError(b[i:i+1], "invalid character")
}
i += size
} }
return nil, nil, newDecodeError(b[len(b):], `multiline literal string not terminated by '''`) return nil, nil, newDecodeError(b[len(b):], `multiline literal string not terminated by '''`)
@@ -106,45 +138,62 @@ func scanWhitespace(b []byte) ([]byte, []byte) {
} }
//nolint:unparam //nolint:unparam
func scanComment(b []byte) ([]byte, []byte) { func scanComment(b []byte) ([]byte, []byte, error) {
// comment-start-symbol = %x23 ; # // comment-start-symbol = %x23 ; #
// non-ascii = %x80-D7FF / %xE000-10FFFF // non-ascii = %x80-D7FF / %xE000-10FFFF
// non-eol = %x09 / %x20-7F / non-ascii // non-eol = %x09 / %x20-7F / non-ascii
// //
// comment = comment-start-symbol *non-eol // comment = comment-start-symbol *non-eol
for i := 1; i < len(b); i++ {
for i := 1; i < len(b); {
if b[i] == '\n' { if b[i] == '\n' {
return b[:i], b[i:] return b[:i], b[i:], nil
} }
if b[i] == '\r' {
if i+1 < len(b) && b[i+1] == '\n' {
return b[:i+1], b[i+1:], nil
}
return nil, nil, newDecodeError(b[i:i+1], "invalid character in comment")
}
size := utf8ValidNext(b[i:])
if size == 0 {
return nil, nil, newDecodeError(b[i:i+1], "invalid character in comment")
}
i += size
} }
return b, nil return b, b[len(b):], nil
} }
func scanBasicString(b []byte) ([]byte, []byte, error) { func scanBasicString(b []byte) ([]byte, bool, []byte, error) {
// basic-string = quotation-mark *basic-char quotation-mark // basic-string = quotation-mark *basic-char quotation-mark
// quotation-mark = %x22 ; " // quotation-mark = %x22 ; "
// basic-char = basic-unescaped / escaped // basic-char = basic-unescaped / escaped
// basic-unescaped = wschar / %x21 / %x23-5B / %x5D-7E / non-ascii // basic-unescaped = wschar / %x21 / %x23-5B / %x5D-7E / non-ascii
// escaped = escape escape-seq-char // escaped = escape escape-seq-char
for i := 1; i < len(b); i++ { escaped := false
i := 1
for ; i < len(b); i++ {
switch b[i] { switch b[i] {
case '"': case '"':
return b[:i+1], b[i+1:], nil return b[:i+1], escaped, b[i+1:], nil
case '\n': case '\n', '\r':
return nil, nil, newDecodeError(b[i:i+1], "basic strings cannot have new lines") return nil, escaped, nil, newDecodeError(b[i:i+1], "basic strings cannot have new lines")
case '\\': case '\\':
if len(b) < i+2 { if len(b) < i+2 {
return nil, nil, newDecodeError(b[i:i+1], "need a character after \\") return nil, escaped, nil, newDecodeError(b[i:i+1], "need a character after \\")
} }
escaped = true
i++ // skip the next character i++ // skip the next character
} }
} }
return nil, nil, newDecodeError(b[len(b):], `basic string not terminated by "`) return nil, escaped, nil, newDecodeError(b[len(b):], `basic string not terminated by "`)
} }
func scanMultilineBasicString(b []byte) ([]byte, []byte, error) { func scanMultilineBasicString(b []byte) ([]byte, bool, []byte, error) {
// ml-basic-string = ml-basic-string-delim [ newline ] ml-basic-body // ml-basic-string = ml-basic-string-delim [ newline ] ml-basic-body
// ml-basic-string-delim // ml-basic-string-delim
// ml-basic-string-delim = 3quotation-mark // ml-basic-string-delim = 3quotation-mark
@@ -155,19 +204,46 @@ func scanMultilineBasicString(b []byte) ([]byte, []byte, error) {
// mlb-quotes = 1*2quotation-mark // mlb-quotes = 1*2quotation-mark
// mlb-unescaped = wschar / %x21 / %x23-5B / %x5D-7E / non-ascii // mlb-unescaped = wschar / %x21 / %x23-5B / %x5D-7E / non-ascii
// mlb-escaped-nl = escape ws newline *( wschar / newline ) // mlb-escaped-nl = escape ws newline *( wschar / newline )
for i := 3; i < len(b); i++ {
escaped := false
i := 3
for ; i < len(b); i++ {
switch b[i] { switch b[i] {
case '"': case '"':
if scanFollowsMultilineBasicStringDelimiter(b[i:]) { if scanFollowsMultilineBasicStringDelimiter(b[i:]) {
return b[:i+3], b[i+3:], nil i += 3
// At that point we found 3 apostrophe, and i is the
// index of the byte after the third one. The scanner
// needs to be eager, because there can be an extra 2
// apostrophe that can be accepted at the end of the
// string.
if i >= len(b) || b[i] != '"' {
return b[:i], escaped, b[i:], nil
}
i++
if i >= len(b) || b[i] != '"' {
return b[:i], escaped, b[i:], nil
}
i++
if i < len(b) && b[i] == '"' {
return nil, escaped, nil, newDecodeError(b[i-3:i+1], `""" not allowed in multiline basic string`)
}
return b[:i], escaped, b[i:], nil
} }
case '\\': case '\\':
if len(b) < i+2 { if len(b) < i+2 {
return nil, nil, newDecodeError(b[len(b):], "need a character after \\") return nil, escaped, nil, newDecodeError(b[len(b):], "need a character after \\")
} }
escaped = true
i++ // skip the next character i++ // skip the next character
} }
} }
return nil, nil, newDecodeError(b[len(b):], `multiline basic string not terminated by """`) return nil, escaped, nil, newDecodeError(b[len(b):], `multiline basic string not terminated by """`)
} }
+74
View File
@@ -0,0 +1,74 @@
package testsuite
import (
"fmt"
"math"
"time"
"github.com/pelletier/go-toml/v2"
)
// addTag converts a decoded TOML value into the tagged JSON form expected by
// toml-test: tables and arrays are walked recursively, while every primitive
// is wrapped in a {"type": ..., "value": ...} object via tag.
func addTag(key string, tomlData interface{}) interface{} {
	switch v := tomlData.(type) {
	case map[string]interface{}:
		// Table: no tag needed, just recurse into every entry.
		out := make(map[string]interface{}, len(v))
		for name, val := range v {
			out[name] = addTag(name, val)
		}
		return out
	case []map[string]interface{}:
		// Array of tables: recurse into every element.
		out := make([]map[string]interface{}, len(v))
		for i, val := range v {
			out[i] = addTag("", val).(map[string]interface{})
		}
		return out
	case []interface{}:
		// Plain array: recurse into every element.
		out := make([]interface{}, len(v))
		for i, val := range v {
			out[i] = addTag("", val)
		}
		return out
	case toml.LocalTime:
		return tag("time-local", v.String())
	case toml.LocalDate:
		return tag("date-local", v.String())
	case toml.LocalDateTime:
		return tag("datetime-local", v.String())
	case time.Time:
		return tag("datetime", v.Format("2006-01-02T15:04:05.999999999Z07:00"))
	case bool:
		return tag("bool", fmt.Sprintf("%v", v))
	case string:
		return tag("string", v)
	case int64:
		return tag("integer", fmt.Sprintf("%d", v))
	case float64:
		// NaN never compares equal to itself, so spell it out explicitly.
		if math.IsNaN(v) {
			return tag("float", "nan")
		}
		return tag("float", fmt.Sprintf("%v", v))
	default:
		panic(fmt.Sprintf("Unknown type: %T", tomlData))
	}
}
// tag wraps a primitive value in the two-key {"type", "value"} object used
// by the toml-test JSON encoding.
func tag(typeName string, data interface{}) map[string]interface{} {
	out := make(map[string]interface{}, 2)
	out["type"] = typeName
	out["value"] = data
	return out
}
+244
View File
@@ -0,0 +1,244 @@
package testsuite
import (
"fmt"
"strconv"
"strings"
"testing"
"time"
)
// CmpJSON compares the tagged-JSON document produced by the encoder under
// test (have) against the expected document (want), failing t on any
// difference. key is the dotted path used in failure messages.
func CmpJSON(t *testing.T, key string, want, have interface{}) {
	if m, ok := want.(map[string]interface{}); ok {
		cmpJSONMaps(t, key, m, have)
		return
	}
	if a, ok := want.([]interface{}); ok {
		cmpJSONArrays(t, key, a, have)
		return
	}
	t.Errorf(
		"Key '%s' in expected output should be a map or a list of maps, but it's a %T",
		key, want)
}
// cmpJSONMaps compares an expected JSON object against the encoder's output.
// A map is either a tagged primitive ({"type", "value"} — see isValue) or a
// TOML table; both sides must agree on which, and tables are compared key by
// key recursively.
func cmpJSONMaps(t *testing.T, key string, want map[string]interface{}, have interface{}) {
	haveMap, ok := have.(map[string]interface{})
	if !ok {
		mismatch(t, key, "table", want, haveMap)
		return
	}

	// Check to make sure both or neither are values.
	if isValue(want) && !isValue(haveMap) {
		t.Fatalf("Key '%s' is supposed to be a value, but the parser reports it as a table", key)
	}
	if !isValue(want) && isValue(haveMap) {
		t.Fatalf("Key '%s' is supposed to be a table, but the parser reports it as a value", key)
	}
	if isValue(want) && isValue(haveMap) {
		cmpJSONValues(t, key, want, haveMap)
		return
	}

	// Check that the keys of each map are equivalent (both directions, so
	// the failure message says which side is missing the key).
	for k := range want {
		if _, ok := haveMap[k]; !ok {
			bunk := kjoin(key, k)
			t.Fatalf("Could not find key '%s' in parser output.", bunk)
		}
	}
	for k := range haveMap {
		if _, ok := want[k]; !ok {
			bunk := kjoin(key, k)
			t.Fatalf("Could not find key '%s' in expected output.", bunk)
		}
	}

	// Okay, now make sure that each value is equivalent.
	for k := range want {
		CmpJSON(t, kjoin(key, k), want[k], haveMap[k])
	}
}
// cmpJSONArrays compares two JSON arrays element by element, failing t when
// the lengths or any corresponding elements differ.
func cmpJSONArrays(t *testing.T, key string, want, have interface{}) {
	expected, ok := want.([]interface{})
	if !ok {
		panic(fmt.Sprintf("'value' should be a JSON array when 'type=array', but it is a %T", want))
	}
	actual, ok := have.([]interface{})
	if !ok {
		t.Fatalf("Malformed output from your encoder: 'value' is not a JSON array: %T", have)
	}
	if len(expected) != len(actual) {
		t.Fatalf("Array lengths differ for key '%s':\n"+
			" Expected: %d\n"+
			" Your encoder: %d",
			key, len(expected), len(actual))
	}
	for i := range expected {
		CmpJSON(t, key, expected[i], actual[i])
	}
}
// cmpJSONValues compares two tagged primitives ({"type", "value"} objects):
// the types must match, then the values are compared with a type-appropriate
// comparator (float, datetime, or plain string).
func cmpJSONValues(t *testing.T, key string, want, have map[string]interface{}) {
	wantType, ok := want["type"].(string)
	if !ok {
		panic(fmt.Sprintf("'type' should be a string, but it is a %T", want["type"]))
	}
	haveType, ok := have["type"].(string)
	if !ok {
		t.Fatalf("Malformed output from your encoder: 'type' is not a string: %T", have["type"])
	}
	if wantType != haveType {
		valMismatch(t, key, wantType, haveType, want, have)
	}

	// If this is an array, then we've got to do some work to check equality.
	// NOTE(review): this passes the whole tagged maps to cmpJSONArrays, which
	// expects []interface{}; it looks reachable only for legacy tagged-array
	// output — confirm whether this branch can ever trigger.
	if wantType == "array" {
		cmpJSONArrays(t, key, want, have)
		return
	}

	// Atomic values are always strings
	wantVal, ok := want["value"].(string)
	if !ok {
		panic(fmt.Sprintf("'value' %v should be a string, but it is a %[1]T", want["value"]))
	}
	haveVal, ok := have["value"].(string)
	if !ok {
		panic(fmt.Sprintf("Malformed output from your encoder: %T is not a string", have["value"]))
	}

	// Excepting floats and datetimes, other values can be compared as strings.
	switch wantType {
	case "float":
		cmpFloats(t, key, wantVal, haveVal)
	case "datetime", "datetime-local", "date-local", "time-local":
		cmpAsDatetimes(t, key, wantType, wantVal, haveVal)
	default:
		cmpAsStrings(t, key, wantVal, haveVal)
	}
}
// cmpAsStrings fails t when the two string representations are not exactly
// equal.
func cmpAsStrings(t *testing.T, key string, want, have string) {
	if want == have {
		return
	}
	t.Fatalf("Values for key '%s' don't match:\n"+
		" Expected: %s\n"+
		" Your encoder: %s",
		key, want, have)
}
// cmpFloats compares two float values given as strings, failing t when they
// differ numerically. NaN is handled textually because NaN != NaN.
func cmpFloats(t *testing.T, key string, want, have string) {
	// Special case: any value ending in "nan" is compared as text.
	if strings.HasSuffix(want, "nan") || strings.HasSuffix(have, "nan") {
		if want == have {
			return
		}
		t.Fatalf("Values for key '%s' don't match:\n"+
			" Expected: %v\n"+
			" Your encoder: %v",
			key, want, have)
		return
	}

	wantF, err := strconv.ParseFloat(want, 64)
	if err != nil {
		panic(fmt.Sprintf("Could not read '%s' as a float value for key '%s'", want, key))
	}
	haveF, err := strconv.ParseFloat(have, 64)
	if err != nil {
		panic(fmt.Sprintf("Malformed output from your encoder: key '%s' is not a float: '%s'", key, have))
	}
	if wantF == haveF {
		return
	}
	t.Fatalf("Values for key '%s' don't match:\n"+
		" Expected: %v\n"+
		" Your encoder: %v",
		key, wantF, haveF)
}
// datetimeRepl normalizes the case/separator variations TOML allows in
// datetimes ('t', 'z', and a space delimiter) to the canonical forms used by
// the Go layouts below.
var datetimeRepl = strings.NewReplacer(
	" ", "T",
	"t", "T",
	"z", "Z")

// layouts maps each toml-test datetime tag to the Go time layout used to
// parse its value.
var layouts = map[string]string{
	"datetime":       time.RFC3339Nano,
	"datetime-local": "2006-01-02T15:04:05.999999999",
	"date-local":     "2006-01-02",
	"time-local":     "15:04:05",
}

// cmpAsDatetimes parses both values using the layout selected by kind and
// fails t when they do not denote the same instant.
func cmpAsDatetimes(t *testing.T, key string, kind, want, have string) {
	layout, ok := layouts[kind]
	if !ok {
		panic("should never happen")
	}

	wantT, err := time.Parse(layout, datetimeRepl.Replace(want))
	if err != nil {
		panic(fmt.Sprintf("Could not read '%s' as a datetime value for key '%s'", want, key))
	}
	// Bug fix: parse the encoder's output (have), not the expected value a
	// second time — the previous code never inspected `have`, so malformed
	// or mismatched encoder output was silently accepted.
	haveT, err := time.Parse(layout, datetimeRepl.Replace(have))
	if err != nil {
		t.Fatalf("Malformed output from your encoder: key '%s' is not a datetime: '%s'", key, have)
		return
	}
	if !wantT.Equal(haveT) {
		t.Fatalf("Values for key '%s' don't match:\n"+
			" Expected: %v\n"+
			" Your encoder: %v",
			key, wantT, haveT)
	}
}
// cmpAsDatetimesLocal compares two local datetimes textually after
// normalizing TOML's allowed spelling variations with datetimeRepl.
func cmpAsDatetimesLocal(t *testing.T, key string, want, have string) {
	if datetimeRepl.Replace(want) == datetimeRepl.Replace(have) {
		return
	}
	t.Fatalf("Values for key '%s' don't match:\n"+
		" Expected: %v\n"+
		" Your encoder: %v",
		key, want, have)
}
// kjoin appends key to the dotted path old, omitting the separator when old
// is empty.
func kjoin(old, key string) string {
	if old == "" {
		return key
	}
	return old + "." + key
}
// isValue reports whether m is a tagged primitive, i.e. an object holding
// exactly the two keys "type" and "value".
func isValue(m map[string]interface{}) bool {
	if len(m) != 2 {
		return false
	}
	_, hasType := m["type"]
	_, hasValue := m["value"]
	return hasType && hasValue
}
// mismatch reports a structural mismatch: the value at key was expected to
// be of wantType but the encoder produced a value of a different Go type.
// Explicit argument indexes (%[3]v, %[4]T) reuse want/have across verbs.
func mismatch(t *testing.T, key string, wantType string, want, have interface{}) {
	t.Fatalf("Key '%s' is not an %s but %[4]T:\n"+
		" Expected: %#[3]v\n"+
		" Your encoder: %#[4]v",
		key, wantType, want, have)
}
// valMismatch reports that the value at key has different tagged types in
// the expected output (wantType) and the encoder's output (haveType).
func valMismatch(t *testing.T, key string, wantType, haveType string, want, have interface{}) {
	// Bug fix: haveType was accepted but never passed to Fatalf, so the
	// "but %s" clause printed the expected map instead of the actual type.
	// Pass it and shift the indexed verbs accordingly.
	t.Fatalf("Key '%s' is not an %s but %s:\n"+
		" Expected: %#[4]v\n"+
		" Your encoder: %#[5]v",
		key, wantType, haveType, want, have)
}
+69
View File
@@ -0,0 +1,69 @@
package testsuite
import (
"bytes"
"encoding/json"
"fmt"
"github.com/pelletier/go-toml/v2"
)
// parser carries the Decode and Encode entry points used to drive go-toml
// from the toml-test harness; the zero value is ready to use.
type parser struct{}
// Decode parses input as TOML and returns its tagged-JSON representation.
// A TOML parse failure is reported as (message, true, nil) — an expected
// outcome for invalid-input tests — while panics are converted into retErr
// by the deferred recover.
func (p parser) Decode(input string) (output string, outputIsError bool, retErr error) {
	defer func() {
		if r := recover(); r != nil {
			switch rr := r.(type) {
			case error:
				retErr = rr
			default:
				retErr = fmt.Errorf("%s", rr)
			}
		}
	}()

	var v interface{}
	if err := toml.Unmarshal([]byte(input), &v); err != nil {
		return err.Error(), true, nil
	}
	j, err := json.MarshalIndent(addTag("", v), "", " ")
	if err != nil {
		// Bug fix: return the marshaling error instead of the named retErr,
		// which is still nil here and silently discarded the failure.
		return "", false, err
	}
	return string(j), false, retErr
}
// Encode reads a tagged-JSON document from input, strips the toml-test tags,
// and emits the corresponding TOML. Encoding failures are reported as
// (message, true, nil); panics are converted into retErr by the deferred
// recover.
func (p parser) Encode(input string) (output string, outputIsError bool, retErr error) {
	defer func() {
		r := recover()
		if r == nil {
			return
		}
		if err, ok := r.(error); ok {
			retErr = err
		} else {
			retErr = fmt.Errorf("%s", r)
		}
	}()

	var doc interface{}
	if err := json.Unmarshal([]byte(input), &doc); err != nil {
		return "", false, err
	}

	stripped, err := rmTag(doc)
	if err != nil {
		return err.Error(), true, retErr
	}

	var buf bytes.Buffer
	if err := toml.NewEncoder(&buf).Encode(stripped); err != nil {
		return err.Error(), true, retErr
	}
	return buf.String(), false, retErr
}
+110
View File
@@ -0,0 +1,110 @@
package testsuite
import (
"fmt"
"strconv"
"time"
)
// rmTag removes the JSON tags from a data structure as returned by
// toml-test, converting tagged primitives back into native Go values and
// recursing through tables and arrays.
func rmTag(typedJson interface{}) (interface{}, error) {
	// Check if key is in the table m.
	in := func(key string, m map[string]interface{}) bool {
		_, ok := m[key]
		return ok
	}

	// Switch on the data type.
	switch v := typedJson.(type) {
	// Object: this can either be a TOML table or a primitive with tags.
	case map[string]interface{}:
		// This value represents a primitive: remove the tags and return just
		// the primitive value.
		if len(v) == 2 && in("type", v) && in("value", v) {
			ut, err := untag(v)
			if err != nil {
				return ut, fmt.Errorf("tag.Remove: %w", err)
			}
			return ut, nil
		}

		// Table: remove tags on all children.
		m := make(map[string]interface{}, len(v))
		for k, v2 := range v {
			var err error
			m[k], err = rmTag(v2)
			if err != nil {
				return nil, err
			}
		}
		return m, nil
	// Array: remove tags from all items.
	case []interface{}:
		a := make([]interface{}, len(v))
		for i := range v {
			var err error
			a[i], err = rmTag(v[i])
			if err != nil {
				return nil, err
			}
		}
		return a, nil
	}

	// The top level must be an object or array.
	return nil, fmt.Errorf("unrecognized JSON format '%T'", typedJson)
}
// untag converts a tagged primitive ({"type": ..., "value": ...}) back into
// its native Go value, dispatching on the tag type.
func untag(typed map[string]interface{}) (interface{}, error) {
	kind := typed["type"].(string)
	raw := typed["value"].(string)

	switch kind {
	case "string":
		return raw, nil
	case "integer":
		n, err := strconv.ParseInt(raw, 10, 64)
		if err != nil {
			return nil, fmt.Errorf("untag: %w", err)
		}
		return n, nil
	case "float":
		f, err := strconv.ParseFloat(raw, 64)
		if err != nil {
			return nil, fmt.Errorf("untag: %w", err)
		}
		return f, nil
	case "datetime":
		return parseTime(raw, "2006-01-02T15:04:05.999999999Z07:00", false)
	case "datetime-local":
		return parseTime(raw, "2006-01-02T15:04:05.999999999", true)
	case "date-local":
		return parseTime(raw, "2006-01-02", true)
	case "time-local":
		return parseTime(raw, "15:04:05.999999999", true)
	case "bool":
		if raw == "true" {
			return true, nil
		}
		if raw == "false" {
			return false, nil
		}
		return nil, fmt.Errorf("untag: could not parse %q as a boolean", raw)
	default:
		return nil, fmt.Errorf("untag: unrecognized tag type %q", kind)
	}
}

// parseTime parses v with the given layout, using the local time zone when
// local is true and the layout's own zone information otherwise.
func parseTime(v, format string, local bool) (time.Time, error) {
	var (
		t   time.Time
		err error
	)
	if local {
		t, err = time.ParseInLocation(format, v, time.Local)
	} else {
		t, err = time.Parse(format, v)
	}
	if err != nil {
		return time.Time{}, fmt.Errorf("Could not parse %q as a datetime: %w", v, err)
	}
	return t, nil
}
+50
View File
@@ -0,0 +1,50 @@
// Package testsuite provides helper functions for interoperating with the
// language-agnostic TOML test suite at github.com/BurntSushi/toml-test.
package testsuite
import (
"encoding/json"
"fmt"
"os"
"github.com/pelletier/go-toml/v2"
)
// Marshal is a helper function for calling toml.Marshal.
//
// Only needed to avoid package import loops.
func Marshal(v interface{}) ([]byte, error) {
	return toml.Marshal(v)
}
// Unmarshal is a helper function for calling toml.Unmarshal.
//
// Only needed to avoid package import loops.
func Unmarshal(data []byte, v interface{}) error {
	return toml.Unmarshal(data, v)
}
// ValueToTaggedJSON takes a data structure and returns the tagged JSON
// representation expected by toml-test (see addTag), indented for
// readability.
func ValueToTaggedJSON(doc interface{}) ([]byte, error) {
	return json.MarshalIndent(addTag("", doc), "", " ")
}
// DecodeStdin is a helper function for the toml-test binary interface. TOML input
// is read from STDIN and a resulting tagged JSON representation is written to
// STDOUT.
func DecodeStdin() error {
	// Decode into a plain map so addTag can walk the generic structure.
	var decoded map[string]interface{}
	if err := toml.NewDecoder(os.Stdin).Decode(&decoded); err != nil {
		return fmt.Errorf("Error decoding TOML: %s", err)
	}

	j := json.NewEncoder(os.Stdout)
	j.SetIndent("", " ")
	if err := j.Encode(addTag("", decoded)); err != nil {
		return fmt.Errorf("Error encoding JSON: %s", err)
	}
	return nil
}
+23 -117
View File
@@ -1,14 +1,14 @@
//go:generate go run ./cmd/tomltestgen/main.go -o toml_testgen_test.go
// This is a support file for toml_testgen_test.go // This is a support file for toml_testgen_test.go
package toml_test package toml_test
import ( import (
"encoding/json" "encoding/json"
"fmt"
"strconv"
"testing" "testing"
"time"
"github.com/pelletier/go-toml/v2" "github.com/pelletier/go-toml/v2"
"github.com/pelletier/go-toml/v2/testsuite"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
) )
@@ -17,10 +17,14 @@ func testgenInvalid(t *testing.T, input string) {
t.Logf("Input TOML:\n%s", input) t.Logf("Input TOML:\n%s", input)
doc := map[string]interface{}{} doc := map[string]interface{}{}
err := toml.Unmarshal([]byte(input), &doc) err := testsuite.Unmarshal([]byte(input), &doc)
if err == nil { if err == nil {
t.Log(json.Marshal(doc)) out, err := json.Marshal(doc)
if err != nil {
panic("could not marshal map to json")
}
t.Log("JSON output from unmarshal:", string(out))
t.Fatalf("test did not fail") t.Fatalf("test did not fail")
} }
} }
@@ -29,124 +33,26 @@ func testgenValid(t *testing.T, input string, jsonRef string) {
t.Helper() t.Helper()
t.Logf("Input TOML:\n%s", input) t.Logf("Input TOML:\n%s", input)
doc := map[string]interface{}{} // TODO: change this to interface{}
var doc map[string]interface{}
err := toml.Unmarshal([]byte(input), &doc) err := testsuite.Unmarshal([]byte(input), &doc)
if err != nil { if err != nil {
if de, ok := err.(*toml.DecodeError); ok {
t.Logf("%s\n%s", err, de)
}
t.Fatalf("failed parsing toml: %s", err) t.Fatalf("failed parsing toml: %s", err)
} }
j, err := testsuite.ValueToTaggedJSON(doc)
refDoc := testgenBuildRefDoc(jsonRef)
require.Equal(t, refDoc, doc)
out, err := toml.Marshal(doc)
require.NoError(t, err) require.NoError(t, err)
doc2 := map[string]interface{}{} var ref interface{}
err = toml.Unmarshal(out, &doc2) err = json.Unmarshal([]byte(jsonRef), &ref)
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, refDoc, doc2) var actual interface{}
} err = json.Unmarshal([]byte(j), &actual)
require.NoError(t, err)
func testgenBuildRefDoc(jsonRef string) map[string]interface{} {
descTree := map[string]interface{}{} testsuite.CmpJSON(t, "", ref, actual)
err := json.Unmarshal([]byte(jsonRef), &descTree)
if err != nil {
panic(fmt.Sprintf("reference doc should be valid JSON: %s", err))
}
doc := testGenTranslateDesc(descTree)
if doc == nil {
return map[string]interface{}{}
}
return doc.(map[string]interface{})
}
//nolint:funlen,gocognit,cyclop
func testGenTranslateDesc(input interface{}) interface{} {
a, ok := input.([]interface{})
if ok {
xs := make([]interface{}, len(a))
for i, v := range a {
xs[i] = testGenTranslateDesc(v)
}
return xs
}
d, ok := input.(map[string]interface{})
if !ok {
panic(fmt.Sprintf("input should be valid map[string]: %v", input))
}
var (
dtype string
dvalue interface{}
)
//nolint:nestif
if len(d) == 2 {
dtypeiface, ok := d["type"]
if ok {
dvalue, ok = d["value"]
if ok {
dtype = dtypeiface.(string)
switch dtype {
case "string":
return dvalue.(string)
case "float":
v, err := strconv.ParseFloat(dvalue.(string), 64)
if err != nil {
panic(fmt.Sprintf("invalid float '%s': %s", dvalue, err))
}
return v
case "integer":
v, err := strconv.ParseInt(dvalue.(string), 10, 64)
if err != nil {
panic(fmt.Sprintf("invalid int '%s': %s", dvalue, err))
}
return v
case "bool":
return dvalue.(string) == "true"
case "datetime":
dt, err := time.Parse("2006-01-02T15:04:05Z", dvalue.(string))
if err != nil {
panic(fmt.Sprintf("invalid datetime '%s': %s", dvalue, err))
}
return dt
case "array":
if dvalue == nil {
return nil
}
a := dvalue.([]interface{})
xs := make([]interface{}, len(a))
for i, v := range a {
xs[i] = testGenTranslateDesc(v)
}
return xs
}
panic(fmt.Sprintf("unknown type: %s", dtype))
}
}
}
dest := map[string]interface{}{}
for k, v := range d {
dest[k] = testGenTranslateDesc(v)
}
return dest
} }
+1303 -772
View File
File diff suppressed because it is too large Load Diff
+1
View File
@@ -11,3 +11,4 @@ var textMarshalerType = reflect.TypeOf(new(encoding.TextMarshaler)).Elem()
var textUnmarshalerType = reflect.TypeOf(new(encoding.TextUnmarshaler)).Elem() var textUnmarshalerType = reflect.TypeOf(new(encoding.TextUnmarshaler)).Elem()
var mapStringInterfaceType = reflect.TypeOf(map[string]interface{}{}) var mapStringInterfaceType = reflect.TypeOf(map[string]interface{}{})
var sliceInterfaceType = reflect.TypeOf([]interface{}{}) var sliceInterfaceType = reflect.TypeOf([]interface{}{})
var stringType = reflect.TypeOf("")
+285 -144
View File
@@ -9,10 +9,12 @@ import (
"math" "math"
"reflect" "reflect"
"strings" "strings"
"sync" "sync/atomic"
"time" "time"
"unsafe"
"github.com/pelletier/go-toml/v2/internal/ast" "github.com/pelletier/go-toml/v2/internal/ast"
"github.com/pelletier/go-toml/v2/internal/danger"
"github.com/pelletier/go-toml/v2/internal/tracker" "github.com/pelletier/go-toml/v2/internal/tracker"
) )
@@ -47,8 +49,9 @@ func NewDecoder(r io.Reader) *Decoder {
// that could not be set on the target value. In that case, the decoder returns // that could not be set on the target value. In that case, the decoder returns
// a StrictMissingError that can be used to retrieve the individual errors as // a StrictMissingError that can be used to retrieve the individual errors as
// well as generate a human readable description of the missing fields. // well as generate a human readable description of the missing fields.
func (d *Decoder) SetStrict(strict bool) { func (d *Decoder) SetStrict(strict bool) *Decoder {
d.strict = strict d.strict = strict
return d
} }
// Decode the whole content of r into v. // Decode the whole content of r into v.
@@ -174,7 +177,13 @@ func (d *decoder) FromParser(v interface{}) error {
return fmt.Errorf("toml: decoding pointer target cannot be nil") return fmt.Errorf("toml: decoding pointer target cannot be nil")
} }
err := d.fromParser(r.Elem()) r = r.Elem()
if r.Kind() == reflect.Interface && r.IsNil() {
newMap := map[string]interface{}{}
r.Set(reflect.ValueOf(newMap))
}
err := d.fromParser(r)
if err == nil { if err == nil {
return d.strict.Error(d.p.data) return d.strict.Error(d.p.data)
} }
@@ -370,13 +379,22 @@ func (d *decoder) handleKeyPart(key ast.Iterator, v reflect.Value, nextFn handle
// First, dispatch over v to make sure it is a valid object. // First, dispatch over v to make sure it is a valid object.
// There is no guarantee over what it could be. // There is no guarantee over what it could be.
switch v.Kind() { switch v.Kind() {
case reflect.Ptr:
elem := v.Elem()
if !elem.IsValid() {
v.Set(reflect.New(v.Type().Elem()))
}
elem = v.Elem()
return d.handleKeyPart(key, elem, nextFn, makeFn)
case reflect.Map: case reflect.Map:
// Create the key for the map element. For now assume it's a string. // Create the key for the map element. For now assume it's a string.
mk := reflect.ValueOf(string(key.Node().Data)) mk := reflect.ValueOf(string(key.Node().Data))
// If the map does not exist, create it. // If the map does not exist, create it.
if v.IsNil() { if v.IsNil() {
v = reflect.MakeMap(v.Type()) vt := v.Type()
v = reflect.MakeMap(vt)
rv = v rv = v
} }
@@ -388,7 +406,8 @@ func (d *decoder) handleKeyPart(key ast.Iterator, v reflect.Value, nextFn handle
// map[string]interface{} or a []interface{} depending on whether // map[string]interface{} or a []interface{} depending on whether
// this is the last part of the array table key. // this is the last part of the array table key.
t := v.Type().Elem() vt := v.Type()
t := vt.Elem()
if t.Kind() == reflect.Interface { if t.Kind() == reflect.Interface {
mv = makeFn() mv = makeFn()
} else { } else {
@@ -401,6 +420,13 @@ func (d *decoder) handleKeyPart(key ast.Iterator, v reflect.Value, nextFn handle
mv = makeFn() mv = makeFn()
} }
set = true set = true
} else if !mv.CanAddr() {
vt := v.Type()
t := vt.Elem()
oldmv := mv
mv = reflect.New(t).Elem()
mv.Set(oldmv)
set = true
} }
x, err := nextFn(key, mv) x, err := nextFn(key, mv)
@@ -434,7 +460,7 @@ func (d *decoder) handleKeyPart(key ast.Iterator, v reflect.Value, nextFn handle
if v.Elem().IsValid() { if v.Elem().IsValid() {
v = v.Elem() v = v.Elem()
} else { } else {
v = reflect.MakeMap(mapStringInterfaceType) v = makeMapStringInterface()
} }
x, err := d.handleKeyPart(key, v, nextFn, makeFn) x, err := d.handleKeyPart(key, v, nextFn, makeFn)
@@ -469,6 +495,9 @@ func (d *decoder) handleArrayTablePart(key ast.Iterator, v reflect.Value) (refle
// cannot handle it. // cannot handle it.
func (d *decoder) handleTable(key ast.Iterator, v reflect.Value) (reflect.Value, error) { func (d *decoder) handleTable(key ast.Iterator, v reflect.Value) (reflect.Value, error) {
if v.Kind() == reflect.Slice { if v.Kind() == reflect.Slice {
if v.Len() == 0 {
return reflect.Value{}, newDecodeError(key.Node().Data, "cannot store a table in a slice")
}
elem := v.Index(v.Len() - 1) elem := v.Index(v.Len() - 1)
x, err := d.handleTable(key, elem) x, err := d.handleTable(key, elem)
if err != nil { if err != nil {
@@ -503,6 +532,11 @@ func (d *decoder) handleKeyValues(v reflect.Value) (reflect.Value, error) {
break break
} }
err := d.seen.CheckExpression(expr)
if err != nil {
return reflect.Value{}, err
}
x, err := d.handleKeyValue(expr, v) x, err := d.handleKeyValue(expr, v)
if err != nil { if err != nil {
return reflect.Value{}, err return reflect.Value{}, err
@@ -533,10 +567,6 @@ func (d *decoder) handleTablePart(key ast.Iterator, v reflect.Value) (reflect.Va
} }
func (d *decoder) tryTextUnmarshaler(node *ast.Node, v reflect.Value) (bool, error) { func (d *decoder) tryTextUnmarshaler(node *ast.Node, v reflect.Value) (bool, error) {
if v.Kind() != reflect.Struct {
return false, nil
}
// Special case for time, because we allow to unmarshal to it from // Special case for time, because we allow to unmarshal to it from
// different kind of AST nodes. // different kind of AST nodes.
if v.Type() == timeType { if v.Type() == timeType {
@@ -578,6 +608,8 @@ func (d *decoder) handleValue(value *ast.Node, v reflect.Value) error {
return d.unmarshalDateTime(value, v) return d.unmarshalDateTime(value, v)
case ast.LocalDate: case ast.LocalDate:
return d.unmarshalLocalDate(value, v) return d.unmarshalLocalDate(value, v)
case ast.LocalTime:
return d.unmarshalLocalTime(value, v)
case ast.LocalDateTime: case ast.LocalDateTime:
return d.unmarshalLocalDateTime(value, v) return d.unmarshalLocalDateTime(value, v)
case ast.InlineTable: case ast.InlineTable:
@@ -589,62 +621,128 @@ func (d *decoder) handleValue(value *ast.Node, v reflect.Value) error {
} }
} }
func (d *decoder) unmarshalArray(array *ast.Node, v reflect.Value) error { type unmarshalArrayFn func(d *decoder, array *ast.Node, v reflect.Value) error
switch v.Kind() {
case reflect.Slice: var globalUnmarshalArrayFnCache atomic.Value // map[danger.TypeID]unmarshalArrayFn
if v.IsNil() {
v.Set(reflect.MakeSlice(v.Type(), 0, 16)) func unmarshalArrayFnForSlice(vt reflect.Type) unmarshalArrayFn {
} else { tid := danger.MakeTypeID(vt)
v.SetLen(0)
} cache, _ := globalUnmarshalArrayFnCache.Load().(map[danger.TypeID]unmarshalArrayFn)
case reflect.Array: fn, ok := cache[tid]
// arrays are always initialized
case reflect.Interface: if ok {
elem := v.Elem() return fn
if !elem.IsValid() {
elem = reflect.New(sliceInterfaceType).Elem()
elem.Set(reflect.MakeSlice(sliceInterfaceType, 0, 16))
} else if elem.Kind() == reflect.Slice {
if elem.Type() != sliceInterfaceType {
elem = reflect.New(sliceInterfaceType).Elem()
elem.Set(reflect.MakeSlice(sliceInterfaceType, 0, 16))
} else if !elem.CanSet() {
nelem := reflect.New(sliceInterfaceType).Elem()
nelem.Set(reflect.MakeSlice(sliceInterfaceType, elem.Len(), elem.Cap()))
reflect.Copy(nelem, elem)
elem = nelem
}
}
err := d.unmarshalArray(array, elem)
if err != nil {
return err
}
v.Set(elem)
return nil
default:
// TODO: use newDecodeError, but first the parser needs to fill
// array.Data.
return fmt.Errorf("toml: cannot store array in Go type %s", v.Kind())
} }
elemType := v.Type().Elem() elemType := vt.Elem()
elemSize := elemType.Size()
it := array.Children() fn = func(d *decoder, array *ast.Node, v reflect.Value) error {
idx := 0 sp := (*danger.Slice)(unsafe.Pointer(v.UnsafeAddr()))
for it.Next() {
n := it.Node()
// TODO: optimize sp.Len = 0
if v.Kind() == reflect.Slice {
elem := reflect.New(elemType).Elem() it := array.Children()
for it.Next() {
n := it.Node()
idx := sp.Len
if sp.Len == sp.Cap {
c := sp.Cap
if c == 0 {
c = 16
} else {
c *= 2
}
*sp = danger.ExtendSlice(vt, sp, c)
}
datap := unsafe.Pointer(sp.Data)
elemp := danger.Stride(datap, elemSize, idx)
elem := reflect.NewAt(elemType, elemp).Elem()
err := d.handleValue(n, elem) err := d.handleValue(n, elem)
if err != nil { if err != nil {
return err return err
} }
v.Set(reflect.Append(v, elem)) sp.Len++
} else { // array }
if sp.Data == nil {
*sp = danger.ExtendSlice(vt, sp, 0)
}
return nil
}
newCache := make(map[danger.TypeID]unmarshalArrayFn, len(cache)+1)
newCache[tid] = fn
for k, v := range cache {
newCache[k] = v
}
globalUnmarshalArrayFnCache.Store(newCache)
return fn
}
func unmarshalArraySliceInterface(d *decoder, array *ast.Node, v reflect.Value) error {
sp := (*danger.Slice)(unsafe.Pointer(v.UnsafeAddr()))
sp.Len = 0
var x interface{}
it := array.Children()
for it.Next() {
n := it.Node()
idx := sp.Len
if sp.Len == sp.Cap {
c := sp.Cap
if c == 0 {
c = 16
} else {
c *= 2
}
*sp = danger.ExtendSlice(sliceInterfaceType, sp, c)
}
datap := unsafe.Pointer(sp.Data)
elemp := danger.Stride(datap, unsafe.Sizeof(x), idx)
elem := reflect.NewAt(sliceInterfaceType.Elem(), elemp).Elem()
err := d.handleValue(n, elem)
if err != nil {
return err
}
sp.Len++
}
if sp.Data == nil {
*sp = danger.ExtendSlice(sliceInterfaceType, sp, 0)
}
return nil
}
func (d *decoder) unmarshalArray(array *ast.Node, v reflect.Value) error {
switch v.Kind() {
case reflect.Slice:
fn := unmarshalArrayFnForSlice(v.Type())
return fn(d, array, v)
case reflect.Array:
// arrays are always initialized
it := array.Children()
idx := 0
for it.Next() {
n := it.Node()
if idx >= v.Len() { if idx >= v.Len() {
return nil return nil
} }
@@ -655,6 +753,39 @@ func (d *decoder) unmarshalArray(array *ast.Node, v reflect.Value) error {
} }
idx++ idx++
} }
case reflect.Interface:
elemIsSliceInterface := false
elem := v.Elem()
if !elem.IsValid() {
s := make([]interface{}, 0, 16)
elem = reflect.ValueOf(&s).Elem()
elemIsSliceInterface = true
} else if elem.Kind() == reflect.Slice {
if elem.Type() != sliceInterfaceType {
s := make([]interface{}, 0, 16)
elem = reflect.ValueOf(&s).Elem()
} else if !elem.CanSet() {
s := make([]interface{}, elem.Len(), elem.Cap())
nelem := reflect.ValueOf(&s).Elem()
reflect.Copy(nelem, elem)
elem = nelem
}
elemIsSliceInterface = true
}
var err error
if elemIsSliceInterface {
err = unmarshalArraySliceInterface(d, array, elem)
} else {
err = d.unmarshalArray(array, elem)
}
v.Set(elem)
return err
default:
// TODO: use newDecodeError, but first the parser needs to fill
// array.Data.
return fmt.Errorf("toml: cannot store array in Go type %s", v.Kind())
} }
return nil return nil
@@ -672,7 +803,7 @@ func (d *decoder) unmarshalInlineTable(itable *ast.Node, v reflect.Value) error
case reflect.Interface: case reflect.Interface:
elem := v.Elem() elem := v.Elem()
if !elem.IsValid() { if !elem.IsValid() {
elem = reflect.MakeMap(mapStringInterfaceType) elem = makeMapStringInterface()
v.Set(elem) v.Set(elem)
} }
return d.unmarshalInlineTable(itable, elem) return d.unmarshalInlineTable(itable, elem)
@@ -713,8 +844,7 @@ func (d *decoder) unmarshalLocalDate(value *ast.Node, v reflect.Value) error {
} }
if v.Type() == timeType { if v.Type() == timeType {
cast := ld.In(time.Local) cast := ld.AsTime(time.Local)
v.Set(reflect.ValueOf(cast)) v.Set(reflect.ValueOf(cast))
return nil return nil
} }
@@ -724,6 +854,20 @@ func (d *decoder) unmarshalLocalDate(value *ast.Node, v reflect.Value) error {
return nil return nil
} }
func (d *decoder) unmarshalLocalTime(value *ast.Node, v reflect.Value) error {
lt, rest, err := parseLocalTime(value.Data)
if err != nil {
return err
}
if len(rest) > 0 {
return newDecodeError(rest, "extra characters at the end of a local time")
}
v.Set(reflect.ValueOf(lt))
return nil
}
func (d *decoder) unmarshalLocalDateTime(value *ast.Node, v reflect.Value) error { func (d *decoder) unmarshalLocalDateTime(value *ast.Node, v reflect.Value) error {
ldt, rest, err := parseLocalDateTime(value.Data) ldt, rest, err := parseLocalDateTime(value.Data)
if err != nil { if err != nil {
@@ -735,7 +879,7 @@ func (d *decoder) unmarshalLocalDateTime(value *ast.Node, v reflect.Value) error
} }
if v.Type() == timeType { if v.Type() == timeType {
cast := ldt.In(time.Local) cast := ldt.AsTime(time.Local)
v.Set(reflect.ValueOf(cast)) v.Set(reflect.ValueOf(cast))
return nil return nil
@@ -795,86 +939,92 @@ func (d *decoder) unmarshalInteger(value *ast.Node, v reflect.Value) error {
return err return err
} }
var r reflect.Value
switch v.Kind() { switch v.Kind() {
case reflect.Int64: case reflect.Int64:
v.SetInt(i) v.SetInt(i)
return nil
case reflect.Int32: case reflect.Int32:
if i < math.MinInt32 || i > math.MaxInt32 { if i < math.MinInt32 || i > math.MaxInt32 {
return fmt.Errorf("toml: number %d does not fit in an int32", i) return fmt.Errorf("toml: number %d does not fit in an int32", i)
} }
v.Set(reflect.ValueOf(int32(i))) r = reflect.ValueOf(int32(i))
return nil
case reflect.Int16: case reflect.Int16:
if i < math.MinInt16 || i > math.MaxInt16 { if i < math.MinInt16 || i > math.MaxInt16 {
return fmt.Errorf("toml: number %d does not fit in an int16", i) return fmt.Errorf("toml: number %d does not fit in an int16", i)
} }
v.Set(reflect.ValueOf(int16(i))) r = reflect.ValueOf(int16(i))
case reflect.Int8: case reflect.Int8:
if i < math.MinInt8 || i > math.MaxInt8 { if i < math.MinInt8 || i > math.MaxInt8 {
return fmt.Errorf("toml: number %d does not fit in an int8", i) return fmt.Errorf("toml: number %d does not fit in an int8", i)
} }
v.Set(reflect.ValueOf(int8(i))) r = reflect.ValueOf(int8(i))
case reflect.Int: case reflect.Int:
if i < minInt || i > maxInt { if i < minInt || i > maxInt {
return fmt.Errorf("toml: number %d does not fit in an int", i) return fmt.Errorf("toml: number %d does not fit in an int", i)
} }
v.Set(reflect.ValueOf(int(i))) r = reflect.ValueOf(int(i))
case reflect.Uint64: case reflect.Uint64:
if i < 0 { if i < 0 {
return fmt.Errorf("toml: negative number %d does not fit in an uint64", i) return fmt.Errorf("toml: negative number %d does not fit in an uint64", i)
} }
v.Set(reflect.ValueOf(uint64(i))) r = reflect.ValueOf(uint64(i))
case reflect.Uint32: case reflect.Uint32:
if i < 0 || i > math.MaxUint32 { if i < 0 || i > math.MaxUint32 {
return fmt.Errorf("toml: negative number %d does not fit in an uint32", i) return fmt.Errorf("toml: negative number %d does not fit in an uint32", i)
} }
v.Set(reflect.ValueOf(uint32(i))) r = reflect.ValueOf(uint32(i))
case reflect.Uint16: case reflect.Uint16:
if i < 0 || i > math.MaxUint16 { if i < 0 || i > math.MaxUint16 {
return fmt.Errorf("toml: negative number %d does not fit in an uint16", i) return fmt.Errorf("toml: negative number %d does not fit in an uint16", i)
} }
v.Set(reflect.ValueOf(uint16(i))) r = reflect.ValueOf(uint16(i))
case reflect.Uint8: case reflect.Uint8:
if i < 0 || i > math.MaxUint8 { if i < 0 || i > math.MaxUint8 {
return fmt.Errorf("toml: negative number %d does not fit in an uint8", i) return fmt.Errorf("toml: negative number %d does not fit in an uint8", i)
} }
v.Set(reflect.ValueOf(uint8(i))) r = reflect.ValueOf(uint8(i))
case reflect.Uint: case reflect.Uint:
if i < 0 { if i < 0 {
return fmt.Errorf("toml: negative number %d does not fit in an uint", i) return fmt.Errorf("toml: negative number %d does not fit in an uint", i)
} }
v.Set(reflect.ValueOf(uint(i))) r = reflect.ValueOf(uint(i))
case reflect.Interface: case reflect.Interface:
v.Set(reflect.ValueOf(i)) r = reflect.ValueOf(i)
default: default:
err = fmt.Errorf("toml: cannot store TOML integer into a Go %s", v.Kind()) return fmt.Errorf("toml: cannot store TOML integer into a Go %s", v.Kind())
} }
return err if !r.Type().AssignableTo(v.Type()) {
r = r.Convert(v.Type())
}
v.Set(r)
return nil
} }
func (d *decoder) unmarshalString(value *ast.Node, v reflect.Value) error { func (d *decoder) unmarshalString(value *ast.Node, v reflect.Value) error {
var err error
switch v.Kind() { switch v.Kind() {
case reflect.String: case reflect.String:
v.SetString(string(value.Data)) v.SetString(string(value.Data))
case reflect.Interface: case reflect.Interface:
v.Set(reflect.ValueOf(string(value.Data))) v.Set(reflect.ValueOf(string(value.Data)))
default: default:
err = newDecodeError(d.p.Raw(value.Raw), "cannot store TOML string into a Go %s", v.Kind()) return newDecodeError(d.p.Raw(value.Raw), "cannot store TOML string into a Go %s", v.Kind())
} }
return err return nil
} }
func (d *decoder) handleKeyValue(expr *ast.Node, v reflect.Value) (reflect.Value, error) { func (d *decoder) handleKeyValue(expr *ast.Node, v reflect.Value) (reflect.Value, error) {
@@ -909,12 +1059,15 @@ func (d *decoder) handleKeyValuePart(key ast.Iterator, value *ast.Node, v reflec
// There is no guarantee over what it could be. // There is no guarantee over what it could be.
switch v.Kind() { switch v.Kind() {
case reflect.Map: case reflect.Map:
mk := reflect.ValueOf(string(key.Node().Data)) vt := v.Type()
keyType := v.Type().Key() mk := reflect.ValueOf(string(key.Node().Data))
if !mk.Type().AssignableTo(keyType) { mkt := stringType
if !mk.Type().ConvertibleTo(keyType) {
return reflect.Value{}, fmt.Errorf("toml: cannot convert map key of type %s to expected type %s", mk.Type(), keyType) keyType := vt.Key()
if !mkt.AssignableTo(keyType) {
if !mkt.ConvertibleTo(keyType) {
return reflect.Value{}, fmt.Errorf("toml: cannot convert map key of type %s to expected type %s", mkt, keyType)
} }
mk = mk.Convert(keyType) mk = mk.Convert(keyType)
@@ -922,7 +1075,7 @@ func (d *decoder) handleKeyValuePart(key ast.Iterator, value *ast.Node, v reflec
// If the map does not exist, create it. // If the map does not exist, create it.
if v.IsNil() { if v.IsNil() {
v = reflect.MakeMap(v.Type()) v = reflect.MakeMap(vt)
rv = v rv = v
} }
@@ -969,11 +1122,12 @@ func (d *decoder) handleKeyValuePart(key ast.Iterator, value *ast.Node, v reflec
case reflect.Interface: case reflect.Interface:
v = v.Elem() v = v.Elem()
// Following encoding/toml: decoding an object into an interface{}, it // Following encoding/json: decoding an object into an
// needs to always hold a map[string]interface{}. This is for the types // interface{}, it needs to always hold a
// to be consistent whether a previous value was set or not. // map[string]interface{}. This is for the types to be
// consistent whether a previous value was set or not.
if !v.IsValid() || v.Type() != mapStringInterfaceType { if !v.IsValid() || v.Type() != mapStringInterfaceType {
v = reflect.MakeMap(mapStringInterfaceType) v = makeMapStringInterface()
} }
x, err := d.handleKeyValuePart(key, value, v) x, err := d.handleKeyValuePart(key, value, v)
@@ -1020,70 +1174,30 @@ func initAndDereferencePointer(v reflect.Value) reflect.Value {
type fieldPathsMap = map[string][]int type fieldPathsMap = map[string][]int
type fieldPathsCache struct { var globalFieldPathsCache atomic.Value // map[danger.TypeID]fieldPathsMap
m map[reflect.Type]fieldPathsMap
l sync.RWMutex
}
func (c *fieldPathsCache) get(t reflect.Type) (fieldPathsMap, bool) {
c.l.RLock()
paths, ok := c.m[t]
c.l.RUnlock()
return paths, ok
}
func (c *fieldPathsCache) set(t reflect.Type, m fieldPathsMap) {
c.l.Lock()
c.m[t] = m
c.l.Unlock()
}
var globalFieldPathsCache = fieldPathsCache{
m: map[reflect.Type]fieldPathsMap{},
l: sync.RWMutex{},
}
func structField(v reflect.Value, name string) (reflect.Value, bool) { func structField(v reflect.Value, name string) (reflect.Value, bool) {
//nolint:godox t := v.Type()
// TODO: cache this, and reduce allocations tid := danger.MakeTypeID(t)
fieldPaths, ok := globalFieldPathsCache.get(v.Type())
cache, _ := globalFieldPathsCache.Load().(map[danger.TypeID]fieldPathsMap)
fieldPaths, ok := cache[tid]
if !ok { if !ok {
fieldPaths = map[string][]int{} fieldPaths = map[string][]int{}
path := make([]int, 0, 16) forEachField(t, nil, func(name string, path []int) {
fieldPaths[name] = path
// extra copy for the case-insensitive match
fieldPaths[strings.ToLower(name)] = path
})
var walk func(reflect.Value) newCache := make(map[danger.TypeID]fieldPathsMap, len(cache)+1)
walk = func(v reflect.Value) { newCache[tid] = fieldPaths
t := v.Type() for k, v := range cache {
for i := 0; i < t.NumField(); i++ { newCache[k] = v
l := len(path)
path = append(path, i)
f := t.Field(i)
if f.Anonymous {
walk(v.Field(i))
} else if f.PkgPath == "" {
// only consider exported fields
fieldName, ok := f.Tag.Lookup("toml")
if !ok {
fieldName = f.Name
}
pathCopy := make([]int, len(path))
copy(pathCopy, path)
fieldPaths[fieldName] = pathCopy
// extra copy for the case-insensitive match
fieldPaths[strings.ToLower(fieldName)] = pathCopy
}
path = path[:l]
}
} }
globalFieldPathsCache.Store(newCache)
walk(v)
globalFieldPathsCache.set(v.Type(), fieldPaths)
} }
path, ok := fieldPaths[name] path, ok := fieldPaths[name]
@@ -1097,3 +1211,30 @@ func structField(v reflect.Value, name string) (reflect.Value, bool) {
return v.FieldByIndex(path), true return v.FieldByIndex(path), true
} }
func forEachField(t reflect.Type, path []int, do func(name string, path []int)) {
n := t.NumField()
for i := 0; i < n; i++ {
f := t.Field(i)
if !f.Anonymous && f.PkgPath != "" {
// only consider exported fields.
continue
}
fieldPath := append(path, i)
fieldPath = fieldPath[:len(fieldPath):len(fieldPath)]
if f.Anonymous {
forEachField(f.Type, fieldPath, do)
continue
}
name, ok := f.Tag.Lookup("toml")
if !ok {
name = f.Name
}
do(name, fieldPath)
}
}
+935 -223
View File
File diff suppressed because it is too large Load Diff
+240
View File
@@ -0,0 +1,240 @@
package toml
import (
"unicode/utf8"
)
// utf8Err describes an invalid character found while validating a TOML
// string: the byte offset where it starts and its size in bytes.
type utf8Err struct {
	Index int // byte offset of the offending character
	Size  int // byte length of the offending character; 0 when the input is valid
}

// Zero reports whether validation succeeded, i.e. no invalid character was
// recorded (Size is only set to a non-zero value on failure).
func (e utf8Err) Zero() bool {
	return e.Size == 0
}
// utf8TomlValidAlreadyEscaped verifies that a given string is only made of
// valid UTF-8 characters allowed by the TOML spec:
//
// Any Unicode character may be used except those that must be escaped:
// quotation mark, backslash, and the control characters other than tab (U+0000
// to U+0008, U+000A to U+001F, U+007F).
//
// It is a copy of the Go 1.17 utf8.Valid implementation, tweaked to exit early
// when a character is not allowed.
//
// The returned utf8Err is Zero() if the string is valid, or contains the byte
// index and size of the invalid character.
//
// quotation mark => already checked
// backslash => already checked
// 0-0x8 => invalid
// 0x9 => tab, ok
// 0xA - 0x1F => invalid
// 0x7F => invalid
func utf8TomlValidAlreadyEscaped(p []byte) (err utf8Err) {
	// Fast path. Check for and skip 8 bytes of ASCII characters per iteration.
	offset := 0
	for len(p) >= 8 {
		// Combining two 32 bit loads allows the same code to be used
		// for 32 and 64 bit platforms.
		// The compiler can generate a 32bit load for first32 and second32
		// on many platforms. See test/codegen/memcombine.go.
		first32 := uint32(p[0]) | uint32(p[1])<<8 | uint32(p[2])<<16 | uint32(p[3])<<24
		second32 := uint32(p[4]) | uint32(p[5])<<8 | uint32(p[6])<<16 | uint32(p[7])<<24
		if (first32|second32)&0x80808080 != 0 {
			// Found a non ASCII byte (>= RuneSelf); fall through to the
			// rune-by-rune loop below.
			break
		}
		// All 8 bytes are ASCII, but they still need the TOML
		// control-character check.
		for i, b := range p[:8] {
			if invalidAscii(b) {
				err.Index = offset + i
				err.Size = 1
				return
			}
		}
		p = p[8:]
		offset += 8
	}
	n := len(p)
	for i := 0; i < n; {
		pi := p[i]
		if pi < utf8.RuneSelf {
			// Single-byte (ASCII) character.
			if invalidAscii(pi) {
				err.Index = offset + i
				err.Size = 1
				return
			}
			i++
			continue
		}
		// Multi-byte character: the leading byte determines the expected
		// sequence size (low bits of x) and which acceptRange applies to
		// the second byte (high bits of x).
		x := first[pi]
		if x == xx {
			// Illegal starter byte.
			err.Index = offset + i
			err.Size = 1
			return
		}
		size := int(x & 7)
		if i+size > n {
			// Short or invalid.
			err.Index = offset + i
			err.Size = n - i
			return
		}
		accept := acceptRanges[x>>4]
		if c := p[i+1]; c < accept.lo || accept.hi < c {
			err.Index = offset + i
			err.Size = 2
			return
		} else if size == 2 {
		} else if c := p[i+2]; c < locb || hicb < c {
			err.Index = offset + i
			err.Size = 3
			return
		} else if size == 3 {
		} else if c := p[i+3]; c < locb || hicb < c {
			err.Index = offset + i
			err.Size = 4
			return
		}
		i += size
	}
	return
}
// utf8ValidNext returns the byte size of the next rune in p if it is a valid
// UTF-8 sequence whose first byte is also allowed by the TOML string rules,
// and 0 otherwise (illegal starter, truncated sequence, bad continuation
// byte, or forbidden control character). p must be non-empty: p[0] is read
// unconditionally.
func utf8ValidNext(p []byte) int {
	c := p[0]
	if c < utf8.RuneSelf {
		// Single-byte (ASCII) character; reject forbidden control bytes.
		if invalidAscii(c) {
			return 0
		}
		return 1
	}
	x := first[c]
	if x == xx {
		// Illegal starter byte.
		return 0
	}
	size := int(x & 7)
	if size > len(p) {
		// Short or invalid.
		return 0
	}
	// Validate the continuation bytes: the second byte against the range
	// selected by the leading byte, the rest against the default range.
	accept := acceptRanges[x>>4]
	if c := p[1]; c < accept.lo || accept.hi < c {
		return 0
	} else if size == 2 {
	} else if c := p[2]; c < locb || hicb < c {
		return 0
	} else if size == 3 {
	} else if c := p[3]; c < locb || hicb < c {
		return 0
	}
	return size
}
// invalidAsciiTable flags the ASCII bytes that may not appear unescaped in a
// TOML string: the control characters other than tab (0x09), LF (0x0A), and
// CR (0x0D), plus DEL (0x7F). All other entries are false.
var invalidAsciiTable = [256]bool{
	0x00: true,
	0x01: true,
	0x02: true,
	0x03: true,
	0x04: true,
	0x05: true,
	0x06: true,
	0x07: true,
	0x08: true,
	// 0x09 TAB
	// 0x0A LF
	0x0B: true,
	0x0C: true,
	// 0x0D CR
	0x0E: true,
	0x0F: true,
	0x10: true,
	0x11: true,
	0x12: true,
	0x13: true,
	0x14: true,
	0x15: true,
	0x16: true,
	0x17: true,
	0x18: true,
	0x19: true,
	0x1A: true,
	0x1B: true,
	0x1C: true,
	0x1D: true,
	0x1E: true,
	0x1F: true,
	// 0x20 - 0x7E Printable ASCII characters
	0x7F: true,
}

// invalidAscii reports whether b is an ASCII byte that is forbidden
// unescaped in a TOML string. A plain table lookup so it stays cheap on the
// per-byte validation path.
func invalidAscii(b byte) bool {
	return invalidAsciiTable[b]
}
// acceptRange gives the range of valid values for the second byte in a UTF-8
// sequence.
type acceptRange struct {
	lo uint8 // lowest value for second byte.
	hi uint8 // highest value for second byte.
}

// acceptRanges has size 16 to avoid bounds checks in the code that uses it.
// It is indexed by the high nibble of the entries in the first table.
var acceptRanges = [16]acceptRange{
	0: {locb, hicb},
	1: {0xA0, hicb},
	2: {locb, 0x9F},
	3: {0x90, hicb},
	4: {locb, 0x8F},
}
// first is information about the first byte in a UTF-8 sequence. Each entry
// packs the sequence size and an index into acceptRanges; see the
// xx/as/s1-s7 constants for the encoding.
var first = [256]uint8{
	// 1 2 3 4 5 6 7 8 9 A B C D E F
	as, as, as, as, as, as, as, as, as, as, as, as, as, as, as, as, // 0x00-0x0F
	as, as, as, as, as, as, as, as, as, as, as, as, as, as, as, as, // 0x10-0x1F
	as, as, as, as, as, as, as, as, as, as, as, as, as, as, as, as, // 0x20-0x2F
	as, as, as, as, as, as, as, as, as, as, as, as, as, as, as, as, // 0x30-0x3F
	as, as, as, as, as, as, as, as, as, as, as, as, as, as, as, as, // 0x40-0x4F
	as, as, as, as, as, as, as, as, as, as, as, as, as, as, as, as, // 0x50-0x5F
	as, as, as, as, as, as, as, as, as, as, as, as, as, as, as, as, // 0x60-0x6F
	as, as, as, as, as, as, as, as, as, as, as, as, as, as, as, as, // 0x70-0x7F
	// 1 2 3 4 5 6 7 8 9 A B C D E F
	xx, xx, xx, xx, xx, xx, xx, xx, xx, xx, xx, xx, xx, xx, xx, xx, // 0x80-0x8F
	xx, xx, xx, xx, xx, xx, xx, xx, xx, xx, xx, xx, xx, xx, xx, xx, // 0x90-0x9F
	xx, xx, xx, xx, xx, xx, xx, xx, xx, xx, xx, xx, xx, xx, xx, xx, // 0xA0-0xAF
	xx, xx, xx, xx, xx, xx, xx, xx, xx, xx, xx, xx, xx, xx, xx, xx, // 0xB0-0xBF
	xx, xx, s1, s1, s1, s1, s1, s1, s1, s1, s1, s1, s1, s1, s1, s1, // 0xC0-0xCF
	s1, s1, s1, s1, s1, s1, s1, s1, s1, s1, s1, s1, s1, s1, s1, s1, // 0xD0-0xDF
	s2, s3, s3, s3, s3, s3, s3, s3, s3, s3, s3, s3, s3, s4, s3, s3, // 0xE0-0xEF
	s5, s6, s6, s6, s7, xx, xx, xx, xx, xx, xx, xx, xx, xx, xx, xx, // 0xF0-0xFF
}
const (
	// The default lowest and highest continuation byte.
	locb = 0b10000000
	hicb = 0b10111111
	// The names of these constants are chosen to give nice alignment in the
	// first table above. The first nibble is an index into acceptRanges or F
	// for special one-byte cases. The second nibble is the Rune length or the
	// Status for the special one-byte case.
	xx = 0xF1 // invalid: size 1
	as = 0xF0 // ASCII: size 1
	s1 = 0x02 // accept 0, size 2
	s2 = 0x13 // accept 1, size 3
	s3 = 0x03 // accept 0, size 3
	s4 = 0x23 // accept 2, size 3
	s5 = 0x34 // accept 3, size 4
	s6 = 0x04 // accept 0, size 4
	s7 = 0x44 // accept 4, size 4
)