Compare commits

..

25 Commits

Author SHA1 Message Date
Thomas Pelletier 45ea20024b Readme (#535) 2021-05-08 17:03:51 -04:00
Thomas Pelletier ea225df3ed v2: errors (#534)
```
name                              old time/op    new time/op    delta
UnmarshalDataset/config-32          86.7ms ± 2%    87.5ms ± 2%     ~     (p=0.113 n=9+10)
UnmarshalDataset/canada-32           129ms ± 4%     106ms ± 3%  -17.94%  (p=0.000 n=10+10)
UnmarshalDataset/citm_catalog-32    59.4ms ± 5%    58.7ms ± 5%     ~     (p=0.393 n=10+10)
UnmarshalDataset/twitter-32         27.0ms ± 7%    26.9ms ± 6%     ~     (p=0.720 n=10+9)
UnmarshalDataset/code-32             326ms ± 4%     322ms ± 7%     ~     (p=0.661 n=9+10)
UnmarshalDataset/example-32          510µs ±11%     526µs ± 7%     ~     (p=0.182 n=10+9)
UnmarshalSimple-32                  1.41µs ± 6%    1.41µs ± 4%     ~     (p=0.736 n=10+9)
ReferenceFile-32                    45.6µs ± 3%    43.9µs ±10%     ~     (p=0.089 n=10+10)

name                              old speed      new speed      delta
UnmarshalDataset/config-32        12.1MB/s ± 2%  12.0MB/s ± 2%     ~     (p=0.108 n=9+10)
UnmarshalDataset/canada-32        17.1MB/s ± 4%  20.9MB/s ± 3%  +21.86%  (p=0.000 n=10+10)
UnmarshalDataset/citm_catalog-32  9.41MB/s ± 5%  9.51MB/s ± 5%     ~     (p=0.362 n=10+10)
UnmarshalDataset/twitter-32       16.4MB/s ± 8%  16.5MB/s ± 6%     ~     (p=0.704 n=10+9)
UnmarshalDataset/code-32          8.24MB/s ± 4%  8.34MB/s ± 7%     ~     (p=0.675 n=9+10)
UnmarshalDataset/example-32       15.9MB/s ±11%  15.4MB/s ± 7%     ~     (p=0.182 n=10+9)
ReferenceFile-32                   115MB/s ± 4%   120MB/s ±10%     ~     (p=0.085 n=10+10)

name                              old alloc/op   new alloc/op   delta
UnmarshalDataset/config-32          16.9MB ± 0%    16.9MB ± 0%   -0.02%  (p=0.000 n=10+10)
UnmarshalDataset/canada-32          76.8MB ± 0%    74.3MB ± 0%   -3.31%  (p=0.000 n=10+10)
UnmarshalDataset/citm_catalog-32    37.3MB ± 0%    37.1MB ± 0%   -0.60%  (p=0.000 n=9+10)
UnmarshalDataset/twitter-32         15.6MB ± 0%    15.6MB ± 0%   -0.09%  (p=0.000 n=10+10)
UnmarshalDataset/code-32            60.2MB ± 0%    59.3MB ± 0%   -1.51%  (p=0.000 n=10+9)
UnmarshalDataset/example-32          238kB ± 0%     238kB ± 0%   -0.18%  (p=0.000 n=10+10)
ReferenceFile-32                    11.8kB ± 0%    11.8kB ± 0%     ~     (all equal)

name                              old allocs/op  new allocs/op  delta
UnmarshalDataset/config-32            653k ± 0%      645k ± 0%   -1.20%  (p=0.000 n=10+6)
UnmarshalDataset/canada-32           1.01M ± 0%     0.90M ± 0%  -11.04%  (p=0.000 n=9+10)
UnmarshalDataset/citm_catalog-32      384k ± 0%      370k ± 0%   -3.75%  (p=0.000 n=10+10)
UnmarshalDataset/twitter-32           160k ± 0%      157k ± 0%   -1.32%  (p=0.000 n=10+10)
UnmarshalDataset/code-32             2.97M ± 0%     2.91M ± 0%   -2.15%  (p=0.000 n=10+7)
UnmarshalDataset/example-32          3.69k ± 0%     3.63k ± 0%   -1.52%  (p=0.000 n=10+10)
ReferenceFile-32                       253 ± 0%       253 ± 0%     ~     (all equal)
```
2021-05-08 16:04:25 -04:00
Thomas Pelletier 4545a3e94b ci: remove benchmarks
Both github actions and my own VPS have too much noise to be useful.
2021-05-07 23:34:17 -04:00
Vincent Serpoul 3f2bb0b363 golangci-lint (#530) 2021-05-06 22:29:21 -04:00
Vincent Serpoul 201d5dd422 golangci-lint: misc (#529) 2021-04-27 20:29:00 -04:00
Thomas Pelletier 1e80267558 parser: require \n after parsing integer in kv (#527)
Fixes #526
2021-04-24 09:57:21 -04:00
Thomas Pelletier 931f02a519 encoder: support indentation (#525) 2021-04-23 17:08:27 -04:00
Thomas Pelletier a533331aee v2: benchdiff (#524) 2021-04-23 15:21:41 -04:00
Vincent Serpoul 466faaab9f golangci-lint: marshaler, strict (#523) 2021-04-23 10:41:21 -04:00
Thomas Pelletier e443b4fdb8 encoder: support TextMarshaler (#522)
Fixes #521
2021-04-22 10:13:41 -04:00
Vincent Serpoul 2b1c52dddd golangci-lint: decoder/unmarshal (#518) 2021-04-22 09:29:23 -04:00
Thomas Pelletier 21445f5170 Add test for issue #424 2021-04-21 22:27:30 -04:00
Thomas Pelletier 9ba52996d8 Encoder multiline array (#520) 2021-04-21 22:13:45 -04:00
Thomas Pelletier 6fe332a869 Encoder inline tables (#519) 2021-04-21 19:11:15 -04:00
Thomas Pelletier 32c1a8d372 encoder: move nspow into the parseLocalTime 2021-04-20 23:19:40 -04:00
Thomas Pelletier ee102a3528 decoder: fix time fractional parsing 2021-04-20 23:16:08 -04:00
Thomas Pelletier 9b67e40640 decoder: strict mode (#512) 2021-04-20 21:26:22 -04:00
Vincent Serpoul dca2103910 golangci-lint: marshaler (#516) 2021-04-20 20:24:44 -04:00
Cameron Moore a713a96e69 Add more newline tests for scanner (#515) 2021-04-16 19:07:29 -04:00
Cameron Moore a7b50eb8f1 Tidy (#511)
* Disconnect package godoc comment from imported file

* Add missing newline in toml.abnf

* Tag testing helper funcs
2021-04-15 16:49:19 -04:00
Cameron Moore 24b62ebe61 Simplify scanFollows usage (#510)
Use static functions to avoid declaring global vars and creating more
package init costs.  This change has no negative effects on benchmarks
in my testing.
2021-04-15 16:48:19 -04:00
Thomas Pelletier 9bc4641a49 ci-lint: disable ifshort 2021-04-15 13:37:24 -04:00
Thomas Pelletier b86b890b8d decoder: handle private anonymous structs
Ref #508
2021-04-15 12:49:24 -04:00
Vincent Serpoul 080baa8574 golangci-lint: localtime (#509) 2021-04-15 12:44:31 -04:00
Thomas Pelletier 0537b928df decoder: add test for #507 2021-04-15 11:36:36 -04:00
33 changed files with 3339 additions and 1471 deletions
-3
View File
@@ -23,6 +23,3 @@ jobs:
go-version: ${{ matrix.go }} go-version: ${{ matrix.go }}
- name: Run unit tests - name: Run unit tests
run: go test -race ./... run: go test -race ./...
- name: Run benchmark tests
run: go test -race ./...
working-directory: benchmark
+7 -4
View File
@@ -4,6 +4,9 @@ golangci-lint-version = "1.39.0"
[linters-settings.wsl] [linters-settings.wsl]
allow-assign-and-anything = true allow-assign-and-anything = true
[linters-settings.exhaustive]
default-signifies-exhaustive = true
[linters] [linters]
disable-all = true disable-all = true
enable = [ enable = [
@@ -18,13 +21,13 @@ enable = [
"errcheck", "errcheck",
"errorlint", "errorlint",
"exhaustive", "exhaustive",
"exhaustivestruct", # "exhaustivestruct",
"exportloopref", "exportloopref",
"forbidigo", "forbidigo",
"forcetypeassert", "forcetypeassert",
"funlen", "funlen",
"gci", "gci",
"gochecknoglobals", # "gochecknoglobals",
"gochecknoinits", "gochecknoinits",
"gocognit", "gocognit",
"goconst", "goconst",
@@ -45,7 +48,7 @@ enable = [
"gosec", "gosec",
"gosimple", "gosimple",
"govet", "govet",
"ifshort", # "ifshort",
"importas", "importas",
"ineffassign", "ineffassign",
"lll", "lll",
@@ -76,6 +79,6 @@ enable = [
"varcheck", "varcheck",
"wastedassign", "wastedassign",
"whitespace", "whitespace",
"wrapcheck", # "wrapcheck",
"wsl" "wsl"
] ]
+334 -40
View File
@@ -1,57 +1,351 @@
# go-toml V2 # go-toml v2
Development branch. Use at your own risk. Go library for the [TOML](https://toml.io/en/) format.
[👉 Discussion on github](https://github.com/pelletier/go-toml/discussions/471). This library supports [TOML v1.0.0](https://toml.io/en/v1.0.0).
* `toml.Unmarshal()` should work as well as v1.
## Must do ## Development status
### Unmarshal This is the upcoming major version of go-toml. It is currently in active
development. As of release v2.0.0-beta.1, the library has reached feature parity
with v1, and fixes a lot known bugs and performance issues along the way.
- [x] Unmarshal into maps. If you do not need the advanced document editing features of v1, you are
- [x] Support Array Tables. encouraged to try out this version.
- [x] Unmarshal into pointers.
- [x] Support Date / times.
- [x] Support struct tags annotations.
- [x] Support Arrays.
- [x] Support Unmarshaler interface.
- [x] Original go-toml unmarshal tests pass.
- [x] Benchmark!
- [x] Abstract AST.
- [x] Original go-toml testgen tests pass.
- [x] Track file position (line, column) for errors.
- [ ] Strict mode.
- [ ] Document Unmarshal / Decode
### Marshal 👉 [Roadmap for v2](https://github.com/pelletier/go-toml/discussions/506).
- [x] Minimal implementation
- [x] Multiline strings
- [ ] Multiline arrays
- [ ] `inline` tag for tables
- [ ] Optional indentation
- [ ] Option to pick default quotes
### Document ## Documentation
- [ ] Gather requirements and design API. Full API, examples, and implementation notes are available in the Go documentation.
## Ideas [![Go Reference](https://pkg.go.dev/badge/github.com/pelletier/go-toml/v2.svg)](https://pkg.go.dev/github.com/pelletier/go-toml/v2)
- [ ] Allow types to implement a `ASTUnmarshaler` interface to unmarshal
straight from the AST?
- [x] Rewrite AST to use a single array as storage instead of one allocation per
node.
- [ ] Provide "minimal allocations" option that uses `unsafe` to reuse the input
byte array as storage for strings.
- [x] Cache reflection operations per type.
- [ ] Optimize tracker pass.
## Differences with v1 ## Import
* [unmarshal](https://github.com/pelletier/go-toml/discussions/488) ```go
import "github.com/pelletier/go-toml/v2"
```
## Features
### Stdlib behavior
As much as possible, this library is designed to behave similarly as the
standard library's `encoding/json`.
### Performance
While go-toml favors usability, it is written with performance in mind. Most
operations should not be shockingly slow.
### Strict mode
`Decoder` can be set to "strict mode", which makes it error when some parts of
the TOML document was not prevent in the target structure. This is a great way
to check for typos. [See example in the documentation][strict].
[strict]: https://pkg.go.dev/github.com/pelletier/go-toml/v2#example-Decoder.SetStrict
### Contextualized errors
When decoding errors occur, go-toml returns [`DecodeError`][decode-err]), which
contains a human readable contextualized version of the error. For example:
```
2| key1 = "value1"
3| key2 = "missing2"
| ~~~~ missing field
4| key3 = "missing3"
5| key4 = "value4"
```
[decode-err]: https://pkg.go.dev/github.com/pelletier/go-toml/v2#DecodeError
### Local date and time support
TOML supports native [local date/times][ldt]. It allows to represent a given
date, time, or date-time without relation to a timezone or offset. To support
this use-case, go-toml provides [`LocalDate`][tld], [`LocalTime`][tlt], and
[`LocalDateTime`][tldt]. Those types can be transformed to and from `time.Time`,
making them convenient yet unambiguous structures for their respective TOML
representation.
[ldt]: https://toml.io/en/v1.0.0#local-date-time
[tld]: https://pkg.go.dev/github.com/pelletier/go-toml/v2#LocalDate
[tlt]: https://pkg.go.dev/github.com/pelletier/go-toml/v2#LocalTime
[tldt]: https://pkg.go.dev/github.com/pelletier/go-toml/v2#LocalDateTime
## Getting started
Given the following struct, let's see how to read it and write it as TOML:
```go
type MyConfig struct {
Version int
Name string
Tags []string
}
```
### Unmarshaling
[`Unmarshal`][unmarshal] reads a TOML document and fills a Go structure with its
content. For example:
```go
doc := `
version = 2
name = "go-toml"
tags = ["go", "toml"]
`
var cfg MyConfig
err := toml.Unmarshal([]byte(doc), &cfg)
if err != nil {
panic(err)
}
fmt.Println("version:", cfg.Version)
fmt.Println("name:", cfg.Name)
fmt.Println("tags:", cfg.Tags)
// Output:
// version: 2
// name: go-toml
// tags: [go toml]
```
[unmarshal]: https://pkg.go.dev/github.com/pelletier/go-toml/v2#Unmarshal
### Marshaling
[`Marshal`][marshal] is the opposite of Unmarshal: it represents a Go structure
as a TOML document:
```go
cfg := MyConfig{
Version: 2,
Name: "go-toml",
Tags: []string{"go", "toml"},
}
b, err := toml.Marshal(cfg)
if err != nil {
panic(err)
}
fmt.Println(string(b))
// Output:
// Version = 2
// Name = 'go-toml'
// Tags = ['go', 'toml']
```
[marshal]: https://pkg.go.dev/github.com/pelletier/go-toml/v2#Marshal
## Migrating from v1
This section describes the differences between v1 and v2, with some pointers on
how to get the original behavior when possible.
### Decoding / Unmarshal
#### Automatic field name guessing
When unmarshaling to a struct, if a key in the TOML document does not exactly
match the name of a struct field or any of the `toml`-tagged field, v1 tries
multiple variations of the key ([code][v1-keys]).
V2 instead does a case-insensitive matching, like `encoding/json`.
This could impact you if you are relying on casing to differentiate two fields,
and one of them is a not using the `toml` struct tag. The recommended solution
is to be specific about tag names for those fields using the `toml` struct tag.
[v1-keys]: https://github.com/pelletier/go-toml/blob/a2e52561804c6cd9392ebf0048ca64fe4af67a43/marshal.go#L775-L781
#### Ignore preexisting value in interface
When decoding into a non-nil `interface{}`, go-toml v1 uses the type of the
element in the interface to decode the object. For example:
```go
type inner struct {
B interface{}
}
type doc struct {
A interface{}
}
d := doc{
A: inner{
B: "Before",
},
}
data := `
[A]
B = "After"
`
toml.Unmarshal([]byte(data), &d)
fmt.Printf("toml v1: %#v\n", d)
// toml v1: main.doc{A:main.inner{B:"After"}}
```
In this case, field `A` is of type `interface{}`, containing a `inner` struct.
V1 sees that type and uses it when decoding the object.
When decoding an object into an `interface{}`, V2 instead disregards whatever
value the `interface{}` may contain and replaces it with a
`map[string]interface{}`. With the same data structure as above, here is what
the result looks like:
```go
toml.Unmarshal([]byte(data), &d)
fmt.Printf("toml v2: %#v\n", d)
// toml v2: main.doc{A:map[string]interface {}{"B":"After"}}
```
This is to match `encoding/json`'s behavior. There is no way to make the v2
decoder behave like v1.
#### Values out of array bounds ignored
When decoding into an array, v1 returns an error when the number of elements
contained in the doc is superior to the capacity of the array. For example:
```go
type doc struct {
A [2]string
}
d := doc{}
err := toml.Unmarshal([]byte(`A = ["one", "two", "many"]`), &d)
fmt.Println(err)
// (1, 1): unmarshal: TOML array length (3) exceeds destination array length (2)
```
In the same situation, v2 ignores the last value:
```go
err := toml.Unmarshal([]byte(`A = ["one", "two", "many"]`), &d)
fmt.Println("err:", err, "d:", d)
// err: <nil> d: {[one two]}
```
This is to match `encoding/json`'s behavior. There is no way to make the v2
decoder behave like v1.
#### Support for `toml.Unmarshaler` has been dropped
This method was not widely used, poorly defined, and added a lot of complexity.
A similar effect can be achieved by implementing the `encoding.TextUnmarshaler`
interface and use strings.
### Encoding / Marshal
#### Default struct fields order
V1 emits struct fields order alphabetically by default. V2 struct fields are
emitted in order they are defined. For example:
```go
type S struct {
B string
A string
}
data := S{
B: "B",
A: "A",
}
b, _ := tomlv1.Marshal(data)
fmt.Println("v1:\n" + string(b))
b, _ = tomlv2.Marshal(data)
fmt.Println("v2:\n" + string(b))
// Output:
// v1:
// A = "A"
// B = "B"
// v2:
// B = 'B'
// A = 'A'
```
There is no way to make v2 encoder behave like v1. A workaround could be to
manually sort the fields alphabetically in the struct definition.
#### No indentation by default
V1 automatically indents content of tables by default. V2 does not. However the
same behavior can be obtained using [`Encoder.SetIndentTables`][sit]. For example:
```go
data := map[string]interface{}{
"table": map[string]string{
"key": "value",
},
}
b, _ := tomlv1.Marshal(data)
fmt.Println("v1:\n" + string(b))
b, _ = tomlv2.Marshal(data)
fmt.Println("v2:\n" + string(b))
buf := bytes.Buffer{}
enc := tomlv2.NewEncoder(&buf)
enc.SetIndentTables(true)
enc.Encode(data)
fmt.Println("v2 Encoder:\n" + string(buf.Bytes()))
// Output:
// v1:
//
// [table]
// key = "value"
//
// v2:
// [table]
// key = 'value'
//
//
// v2 Encoder:
// [table]
// key = 'value'
```
[sit]: https://pkg.go.dev/github.com/pelletier/go-toml/v2#Encoder.SetIndentTables
#### Keys and strings are single quoted
V1 always uses double quotes (`"`) around strings and keys that cannot be
represented bare (unquoted). V2 uses single quotes instead by default (`'`),
unless a character cannot be represented, then falls back to double quotes.
There is no way to make v2 encoder behave like v1.
#### `TextMarshaler` emits as a string, not TOML
Types that implement [`encoding.TextMarshaler`][tm] can emit arbitrary TOML in
v1. The encoder would append the result to the output directly. In v2 the result
is wrapped in a string. As a result, this interface cannot be implemented by the
root object.
There is no way to make v2 encoder behave like v1.
[tm]: https://golang.org/pkg/encoding/#TextMarshaler
## License ## License
+13 -26
View File
@@ -8,6 +8,7 @@ import (
"path/filepath" "path/filepath"
"testing" "testing"
"github.com/pelletier/go-toml/v2"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
) )
@@ -32,20 +33,12 @@ func TestUnmarshalDatasetCode(t *testing.T) {
for _, tc := range bench_inputs { for _, tc := range bench_inputs {
buf := fixture(t, tc.name) buf := fixture(t, tc.name)
t.Run(tc.name, func(t *testing.T) { t.Run(tc.name, func(t *testing.T) {
for _, r := range runners { var v interface{}
if r.name == "bs" && tc.name == "canada" { check(t, toml.Unmarshal(buf, &v))
t.Skip("skipping: burntsushi can't handle mixed arrays")
}
t.Run(r.name, func(t *testing.T) { b, err := json.Marshal(v)
var v interface{} check(t, err)
check(t, r.unmarshal(buf, &v)) require.Equal(t, len(b), tc.jsonLen)
b, err := json.Marshal(v)
check(t, err)
require.Equal(t, len(b), tc.jsonLen)
})
}
}) })
} }
} }
@@ -54,19 +47,13 @@ func BenchmarkUnmarshalDataset(b *testing.B) {
for _, tc := range bench_inputs { for _, tc := range bench_inputs {
buf := fixture(b, tc.name) buf := fixture(b, tc.name)
b.Run(tc.name, func(b *testing.B) { b.Run(tc.name, func(b *testing.B) {
bench(b, func(r runner, b *testing.B) { b.SetBytes(int64(len(buf)))
if r.name == "bs" && tc.name == "canada" { b.ReportAllocs()
b.Skip("skipping: burntsushi can't handle mixed arrays") b.ResetTimer()
} for i := 0; i < b.N; i++ {
var v interface{}
b.SetBytes(int64(len(buf))) check(b, toml.Unmarshal(buf, &v))
b.ReportAllocs() }
b.ResetTimer()
for i := 0; i < b.N; i++ {
var v interface{}
check(b, r.unmarshal(buf, &v))
}
})
}) })
} }
} }
+21 -46
View File
@@ -5,44 +5,21 @@ import (
"testing" "testing"
"time" "time"
tomlbs "github.com/BurntSushi/toml"
tomlv1 "github.com/pelletier/go-toml-v1"
"github.com/pelletier/go-toml/v2" "github.com/pelletier/go-toml/v2"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
) )
type runner struct {
name string
unmarshal func([]byte, interface{}) error
}
var runners = []runner{
{"v2", toml.Unmarshal},
{"v1", tomlv1.Unmarshal},
{"bs", tomlbs.Unmarshal},
}
func bench(b *testing.B, f func(r runner, b *testing.B)) {
for _, r := range runners {
b.Run(r.name, func(b *testing.B) {
f(r, b)
})
}
}
func BenchmarkUnmarshalSimple(b *testing.B) { func BenchmarkUnmarshalSimple(b *testing.B) {
bench(b, func(r runner, b *testing.B) { d := struct {
d := struct { A string
A string }{}
}{} doc := []byte(`A = "hello"`)
doc := []byte(`A = "hello"`) for i := 0; i < b.N; i++ {
for i := 0; i < b.N; i++ { err := toml.Unmarshal(doc, &d)
err := r.unmarshal(doc, &d) if err != nil {
if err != nil { panic(err)
panic(err)
}
} }
}) }
} }
type benchmarkDoc struct { type benchmarkDoc struct {
@@ -152,22 +129,20 @@ type benchmarkDoc struct {
} }
func BenchmarkReferenceFile(b *testing.B) { func BenchmarkReferenceFile(b *testing.B) {
bench(b, func(r runner, b *testing.B) { bytes, err := ioutil.ReadFile("benchmark.toml")
bytes, err := ioutil.ReadFile("benchmark.toml") if err != nil {
b.Fatal(err)
}
b.SetBytes(int64(len(bytes)))
b.ReportAllocs()
b.ResetTimer()
for i := 0; i < b.N; i++ {
d := benchmarkDoc{}
err := toml.Unmarshal(bytes, &d)
if err != nil { if err != nil {
b.Fatal(err) panic(err)
} }
b.SetBytes(int64(len(bytes))) }
b.ReportAllocs()
b.ResetTimer()
for i := 0; i < b.N; i++ {
d := benchmarkDoc{}
err := r.unmarshal(bytes, &d)
if err != nil {
panic(err)
}
}
})
} }
func TestReferenceFile(t *testing.T) { func TestReferenceFile(t *testing.T) {
-14
View File
@@ -1,14 +0,0 @@
module github.com/pelletier/go-toml/v2/benchmark
go 1.16
replace github.com/pelletier/go-toml/v2 => ../
replace github.com/pelletier/go-toml-v1 => github.com/pelletier/go-toml v1.8.1
require (
github.com/BurntSushi/toml v0.3.1
github.com/pelletier/go-toml-v1 v0.0.0-00010101000000-000000000000
github.com/pelletier/go-toml/v2 v2.0.0-00010101000000-000000000000
github.com/stretchr/testify v1.7.0
)
-16
View File
@@ -1,16 +0,0 @@
github.com/BurntSushi/toml v0.3.1 h1:WXkYYl6Yr3qBf1K79EBnL4mak0OimBfB0XUf9Vl28OQ=
github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU=
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/pelletier/go-toml v1.8.1 h1:1Nf83orprkJyknT6h7zbuEGUEjcyVlCxSUGTENmNCRM=
github.com/pelletier/go-toml v1.8.1/go.mod h1:T2/BmBdy8dvIRq1a/8aqjN41wvWlN4lrapLU/GW4pbc=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
github.com/stretchr/testify v1.7.0 h1:nwc3DEeHmmLAfoZucVR881uASk0Mfjw8xYJ99tb5CcY=
github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c h1:dUUwHk2QECo/6vqA44rthZ8ie2QXMNeKRTHCNY2nXvo=
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
+97 -119
View File
@@ -1,11 +1,8 @@
package toml package toml
import ( import (
"errors"
"fmt"
"math" "math"
"strconv" "strconv"
"strings"
"time" "time"
) )
@@ -59,14 +56,12 @@ func parseLocalDate(b []byte) (LocalDate, error) {
return date, nil return date, nil
} }
var errNotDigit = errors.New("not a digit")
func parseDecimalDigits(b []byte) (int, error) { func parseDecimalDigits(b []byte) (int, error) {
v := 0 v := 0
for _, c := range b { for i, c := range b {
if !isDigit(c) { if !isDigit(c) {
return 0, fmt.Errorf("%s: %w", b, errNotDigit) return 0, newDecodeError(b[i:i+1], "should be a digit (0-9)")
} }
v *= 10 v *= 10
@@ -76,13 +71,14 @@ func parseDecimalDigits(b []byte) (int, error) {
return v, nil return v, nil
} }
var errParseDateTimeMissingInfo = errors.New("date-time missing timezone information")
func parseDateTime(b []byte) (time.Time, error) { func parseDateTime(b []byte) (time.Time, error) {
// offset-date-time = full-date time-delim full-time // offset-date-time = full-date time-delim full-time
// full-time = partial-time time-offset // full-time = partial-time time-offset
// time-offset = "Z" / time-numoffset // time-offset = "Z" / time-numoffset
// time-numoffset = ( "+" / "-" ) time-hour ":" time-minute // time-numoffset = ( "+" / "-" ) time-hour ":" time-minute
originalBytes := b
dt, b, err := parseLocalDateTime(b) dt, b, err := parseLocalDateTime(b)
if err != nil { if err != nil {
return time.Time{}, err return time.Time{}, err
@@ -91,7 +87,7 @@ func parseDateTime(b []byte) (time.Time, error) {
var zone *time.Location var zone *time.Location
if len(b) == 0 { if len(b) == 0 {
return time.Time{}, errParseDateTimeMissingInfo return time.Time{}, newDecodeError(originalBytes, "date-time is missing timezone")
} }
if b[0] == 'Z' { if b[0] == 'Z' {
@@ -134,19 +130,12 @@ func parseDateTime(b []byte) (time.Time, error) {
return t, nil return t, nil
} }
var (
errParseLocalDateTimeWrongLength = errors.New(
"local datetimes are expected to have the format YYYY-MM-DDTHH:MM:SS[.NNNNNN]",
)
errParseLocalDateTimeWrongSeparator = errors.New("datetime separator is expected to be T or a space")
)
func parseLocalDateTime(b []byte) (LocalDateTime, []byte, error) { func parseLocalDateTime(b []byte) (LocalDateTime, []byte, error) {
var dt LocalDateTime var dt LocalDateTime
const localDateTimeByteLen = 11 const localDateTimeByteMinLen = 11
if len(b) < localDateTimeByteLen { if len(b) < localDateTimeByteMinLen {
return dt, nil, errParseLocalDateTimeWrongLength return dt, nil, newDecodeError(b, "local datetimes are expected to have the format YYYY-MM-DDTHH:MM:SS[.NNNNNNNNN]")
} }
date, err := parseLocalDate(b[:10]) date, err := parseLocalDate(b[:10])
@@ -157,7 +146,7 @@ func parseLocalDateTime(b []byte) (LocalDateTime, []byte, error) {
sep := b[10] sep := b[10]
if sep != 'T' && sep != ' ' { if sep != 'T' && sep != ' ' {
return dt, nil, errParseLocalDateTimeWrongSeparator return dt, nil, newDecodeError(b[10:11], "datetime separator is expected to be T or a space")
} }
t, rest, err := parseLocalTime(b[11:]) t, rest, err := parseLocalTime(b[11:])
@@ -169,17 +158,19 @@ func parseLocalDateTime(b []byte) (LocalDateTime, []byte, error) {
return dt, rest, nil return dt, rest, nil
} }
var errParseLocalTimeWrongLength = errors.New("times are expected to have the format HH:MM:SS[.NNNNNN]")
// parseLocalTime is a bit different because it also returns the remaining // parseLocalTime is a bit different because it also returns the remaining
// []byte that is didn't need. This is to allow parseDateTime to parse those // []byte that is didn't need. This is to allow parseDateTime to parse those
// remaining bytes as a timezone. // remaining bytes as a timezone.
//nolint:cyclop,funlen
func parseLocalTime(b []byte) (LocalTime, []byte, error) { func parseLocalTime(b []byte) (LocalTime, []byte, error) {
var t LocalTime var (
nspow = [10]int{0, 1e8, 1e7, 1e6, 1e5, 1e4, 1e3, 1e2, 1e1, 1e0}
t LocalTime
)
const localTimeByteLen = 8 const localTimeByteLen = 8
if len(b) < localTimeByteLen { if len(b) < localTimeByteLen {
return t, nil, errParseLocalTimeWrongLength return t, nil, newDecodeError(b, "times are expected to have the format HH:MM:SS[.NNNNNN]")
} }
var err error var err error
@@ -207,23 +198,37 @@ func parseLocalTime(b []byte) (LocalTime, []byte, error) {
return t, nil, err return t, nil, err
} }
if len(b) >= 15 && b[8] == '.' { if len(b) >= 9 && b[8] == '.' {
t.Nanosecond, err = parseDecimalDigits(b[9:15]) frac := 0
if err != nil { digits := 0
return t, nil, err
for i, c := range b[9:] {
if !isDigit(c) {
if i == 0 {
return t, nil, newDecodeError(b[i:i+1], "need at least one digit after fraction point")
}
break
}
//nolint:gomnd
if i >= 9 {
return t, nil, newDecodeError(b[i:i+1], "maximum precision for date time is nanosecond")
}
frac *= 10
frac += int(c - '0')
digits++
} }
return t, b[15:], nil t.Nanosecond = frac * nspow[digits]
return t, b[9+digits:], nil
} }
return t, b[8:], nil return t, b[8:], nil
} }
var (
errParseFloatStartDot = errors.New("float cannot start with a dot")
errParseFloatEndDot = errors.New("float cannot end with a dot")
)
//nolint:cyclop //nolint:cyclop
func parseFloat(b []byte) (float64, error) { func parseFloat(b []byte) (float64, error) {
//nolint:godox //nolint:godox
@@ -232,150 +237,123 @@ func parseFloat(b []byte) (float64, error) {
return math.NaN(), nil return math.NaN(), nil
} }
tok := string(b) cleaned, err := checkAndRemoveUnderscores(b)
err := numberContainsInvalidUnderscore(tok)
if err != nil { if err != nil {
return 0, err return 0, err
} }
cleanedVal := cleanupNumberToken(tok) if cleaned[0] == '.' {
if cleanedVal[0] == '.' { return 0, newDecodeError(b, "float cannot start with a dot")
return 0, errParseFloatStartDot
} }
if cleanedVal[len(cleanedVal)-1] == '.' { if cleaned[len(cleaned)-1] == '.' {
return 0, errParseFloatEndDot return 0, newDecodeError(b, "float cannot end with a dot")
} }
f, err := strconv.ParseFloat(cleanedVal, 64) f, err := strconv.ParseFloat(string(cleaned), 64)
if err != nil { if err != nil {
return 0, fmt.Errorf("coudn't ParseFloat %w", err) return 0, newDecodeError(b, "coudn't parse float: %w", err)
} }
return f, nil return f, nil
} }
func parseIntHex(b []byte) (int64, error) { func parseIntHex(b []byte) (int64, error) {
tok := string(b) cleaned, err := checkAndRemoveUnderscores(b[2:])
cleanedVal := cleanupNumberToken(tok)
err := hexNumberContainsInvalidUnderscore(cleanedVal)
if err != nil { if err != nil {
return 0, err return 0, err
} }
i, err := strconv.ParseInt(cleanedVal[2:], 16, 64) i, err := strconv.ParseInt(string(cleaned), 16, 64)
if err != nil { if err != nil {
return 0, fmt.Errorf("coudn't ParseIntHex %w", err) return 0, newDecodeError(b, "couldn't parse hexadecimal number: %w", err)
} }
return i, nil return i, nil
} }
func parseIntOct(b []byte) (int64, error) { func parseIntOct(b []byte) (int64, error) {
tok := string(b) cleaned, err := checkAndRemoveUnderscores(b[2:])
cleanedVal := cleanupNumberToken(tok)
err := numberContainsInvalidUnderscore(cleanedVal)
if err != nil { if err != nil {
return 0, err return 0, err
} }
i, err := strconv.ParseInt(cleanedVal[2:], 8, 64) i, err := strconv.ParseInt(string(cleaned), 8, 64)
if err != nil { if err != nil {
return 0, fmt.Errorf("coudn't ParseIntOct %w", err) return 0, newDecodeError(b, "couldn't parse octal number: %w", err)
} }
return i, nil return i, nil
} }
func parseIntBin(b []byte) (int64, error) { func parseIntBin(b []byte) (int64, error) {
tok := string(b) cleaned, err := checkAndRemoveUnderscores(b[2:])
cleanedVal := cleanupNumberToken(tok)
err := numberContainsInvalidUnderscore(cleanedVal)
if err != nil { if err != nil {
return 0, err return 0, err
} }
i, err := strconv.ParseInt(cleanedVal[2:], 2, 64) i, err := strconv.ParseInt(string(cleaned), 2, 64)
if err != nil { if err != nil {
return 0, fmt.Errorf("coudn't ParseIntBin %w", err) return 0, newDecodeError(b, "couldn't parse binary number: %w", err)
} }
return i, nil return i, nil
} }
func parseIntDec(b []byte) (int64, error) { func parseIntDec(b []byte) (int64, error) {
tok := string(b) cleaned, err := checkAndRemoveUnderscores(b)
cleanedVal := cleanupNumberToken(tok)
err := numberContainsInvalidUnderscore(cleanedVal)
if err != nil { if err != nil {
return 0, err return 0, err
} }
i, err := strconv.ParseInt(cleanedVal, 10, 64) i, err := strconv.ParseInt(string(cleaned), 10, 64)
if err != nil { if err != nil {
return 0, fmt.Errorf("coudn't parseIntDec %w", err) return 0, newDecodeError(b, "couldn't parse decimal number: %w", err)
} }
return i, nil return i, nil
} }
func numberContainsInvalidUnderscore(value string) error { func checkAndRemoveUnderscores(b []byte) ([]byte, error) {
// For large numbers, you may use underscores between digits to enhance if len(b) == 0 {
// readability. Each underscore must be surrounded by at least one digit on return b, nil
// each side.
hasBefore := false
for idx, r := range value {
if r == '_' {
if !hasBefore || idx+1 >= len(value) {
// can't end with an underscore
return errInvalidUnderscore
}
}
hasBefore = isDigitRune(r)
} }
return nil if b[0] == '_' {
} return nil, newDecodeError(b[0:1], "number cannot start with underscore")
func hexNumberContainsInvalidUnderscore(value string) error {
hasBefore := false
for idx, r := range value {
if r == '_' {
if !hasBefore || idx+1 >= len(value) {
// can't end with an underscore
return errInvalidUnderscoreHex
}
}
hasBefore = isHexDigit(r)
} }
return nil if b[len(b)-1] == '_' {
return nil, newDecodeError(b[len(b)-1:], "number cannot end with underscore")
}
// fast path
i := 0
for ; i < len(b); i++ {
if b[i] == '_' {
break
}
}
if i == len(b) {
return b, nil
}
before := false
cleaned := make([]byte, i, len(b))
copy(cleaned, b)
for i++; i < len(b); i++ {
c := b[i]
if c == '_' {
if !before {
return nil, newDecodeError(b[i-1:i+1], "number must have at least one digit between underscores")
}
before = false
} else {
before = true
cleaned = append(cleaned, c)
}
}
return cleaned, nil
} }
func cleanupNumberToken(value string) string {
cleanedVal := strings.ReplaceAll(value, "_", "")
return cleanedVal
}
func isHexDigit(r rune) bool {
return isDigitRune(r) ||
(r >= 'a' && r <= 'f') ||
(r >= 'A' && r <= 'F')
}
func isDigitRune(r rune) bool {
return r >= '0' && r <= '9'
}
var (
errInvalidUnderscore = errors.New("invalid use of _ in number")
errInvalidUnderscoreHex = errors.New("invalid use of _ in hex number")
)
+4
View File
@@ -0,0 +1,4 @@
/*
Package toml is a library to read and write TOML documents.
*/
package toml
+45 -4
View File
@@ -18,15 +18,49 @@ type DecodeError struct {
message string message string
line int line int
column int column int
key Key
human string human string
} }
// StrictMissingError occurs in a TOML document that does not have a
// corresponding field in the target value. It contains all the missing fields
// in Errors.
//
// Emitted by Decoder when SetStrict(true) was called.
type StrictMissingError struct {
// One error per field that could not be found.
Errors []DecodeError
}
// Error returns the canonical string for this error.
func (s *StrictMissingError) Error() string {
return "strict mode: fields in the document are missing in the target struct"
}
// String returns a human readable description of all errors.
func (s *StrictMissingError) String() string {
var buf strings.Builder
for i, e := range s.Errors {
if i > 0 {
buf.WriteString("\n---\n")
}
buf.WriteString(e.String())
}
return buf.String()
}
type Key []string
// internal version of DecodeError that is used as the base to create a // internal version of DecodeError that is used as the base to create a
// DecodeError with full context. // DecodeError with full context.
type decodeError struct { type decodeError struct {
highlight []byte highlight []byte
message string message string
key Key // optional
} }
func (de *decodeError) Error() string { func (de *decodeError) Error() string {
@@ -36,13 +70,13 @@ func (de *decodeError) Error() string {
func newDecodeError(highlight []byte, format string, args ...interface{}) error { func newDecodeError(highlight []byte, format string, args ...interface{}) error {
return &decodeError{ return &decodeError{
highlight: highlight, highlight: highlight,
message: fmt.Sprintf(format, args...), message: fmt.Errorf(format, args...).Error(),
} }
} }
// Error returns the error message contained in the DecodeError. // Error returns the error message contained in the DecodeError.
func (e *DecodeError) Error() string { func (e *DecodeError) Error() string {
return e.message return "toml: " + e.message
} }
// String returns the human-readable contextualized error. This string is multi-line. // String returns the human-readable contextualized error. This string is multi-line.
@@ -56,7 +90,13 @@ func (e *DecodeError) Position() (row int, column int) {
return e.line, e.column return e.line, e.column
} }
// decodeErrorFromHighlight creates a DecodeError referencing to a highlighted // Key that was being processed when the error occurred. The key is present only
// if this DecodeError is part of a StrictMissingError.
func (e *DecodeError) Key() Key {
return e.key
}
// decodeErrorFromHighlight creates a DecodeError referencing a highlighted
// range of bytes from document. // range of bytes from document.
// //
// highlight needs to be a sub-slice of document, or this function panics. // highlight needs to be a sub-slice of document, or this function panics.
@@ -64,7 +104,7 @@ func (e *DecodeError) Position() (row int, column int) {
// The function copies all bytes used in DecodeError, so that document and // The function copies all bytes used in DecodeError, so that document and
// highlight can be freely deallocated. // highlight can be freely deallocated.
//nolint:funlen //nolint:funlen
func wrapDecodeError(document []byte, de *decodeError) error { func wrapDecodeError(document []byte, de *decodeError) *DecodeError {
if de == nil { if de == nil {
return nil return nil
} }
@@ -137,6 +177,7 @@ func wrapDecodeError(document []byte, de *decodeError) error {
message: errMessage, message: errMessage,
line: errLine, line: errLine,
column: errColumn, column: errColumn,
key: de.key,
human: buf.String(), human: buf.String(),
} }
} }
+21
View File
@@ -3,6 +3,7 @@ package toml
import ( import (
"bytes" "bytes"
"errors" "errors"
"fmt"
"strings" "strings"
"testing" "testing"
@@ -179,3 +180,23 @@ line 5`,
}) })
} }
} }
func ExampleDecodeError() {
doc := `name = 123__456`
s := map[string]interface{}{}
err := Unmarshal([]byte(doc), &s)
fmt.Println(err)
de := err.(*DecodeError)
fmt.Println(de.String())
row, col := de.Position()
fmt.Println("error occured at row", row, "column", col)
// Output:
// toml: number must have at least one digit between underscores
// 1| name = 123__456
// | ~~ number must have at least one digit between underscores
// error occured at row 1 column 11
}
@@ -4,6 +4,7 @@ package imported_tests
// defaults of v2. // defaults of v2.
import ( import (
"fmt"
"testing" "testing"
"time" "time"
@@ -164,3 +165,34 @@ stringlist = []
require.Equal(t, string(expected), string(result)) require.Equal(t, string(expected), string(result))
} }
type textMarshaler struct {
FirstName string
LastName string
}
func (m textMarshaler) MarshalText() ([]byte, error) {
fullName := fmt.Sprintf("%s %s", m.FirstName, m.LastName)
return []byte(fullName), nil
}
func TestTextMarshaler(t *testing.T) {
type wrap struct {
TM textMarshaler
}
m := textMarshaler{FirstName: "Sally", LastName: "Fields"}
t.Run("at root", func(t *testing.T) {
_, err := toml.Marshal(m)
// in v2 we do not allow TextMarshaler at root
require.Error(t, err)
})
t.Run("leaf", func(t *testing.T) {
res, err := toml.Marshal(wrap{m})
require.NoError(t, err)
require.Equal(t, "TM = 'Sally Fields'\n", string(res))
})
}
+275 -271
View File
@@ -7,6 +7,7 @@ package imported_tests
// marked as skipped until we figure out if that's something we want in v2. // marked as skipped until we figure out if that's something we want in v2.
import ( import (
"bytes"
"errors" "errors"
"fmt" "fmt"
"reflect" "reflect"
@@ -611,16 +612,6 @@ func (x *IntOrString) MarshalTOML() ([]byte, error) {
return []byte(s), nil return []byte(s), nil
} }
type textMarshaler struct {
FirstName string
LastName string
}
func (m textMarshaler) MarshalText() ([]byte, error) {
fullName := fmt.Sprintf("%s %s", m.FirstName, m.LastName)
return []byte(fullName), nil
}
func TestUnmarshalTextMarshaler(t *testing.T) { func TestUnmarshalTextMarshaler(t *testing.T) {
var nested = struct { var nested = struct {
Friends textMarshaler `toml:"friends"` Friends textMarshaler `toml:"friends"`
@@ -1429,211 +1420,210 @@ func TestUnmarshalPreservesUnexportedFields(t *testing.T) {
}) })
} }
// func TestUnmarshalLocalDate(t *testing.T) {
//func TestUnmarshalLocalDate(t *testing.T) { t.Run("ToLocalDate", func(t *testing.T) {
// t.Run("ToLocalDate", func(t *testing.T) { type dateStruct struct {
// type dateStruct struct { Date toml.LocalDate
// Date toml.LocalDate }
// }
// doc := `date = 1979-05-27`
// doc := `date = 1979-05-27`
// var obj dateStruct
// var obj dateStruct
// err := toml.Unmarshal([]byte(doc), &obj)
// err := toml.Unmarshal([]byte(doc), &obj)
// if err != nil {
// if err != nil { t.Fatal(err)
// t.Fatal(err) }
// }
// if obj.Date.Year != 1979 {
// if obj.Date.Year != 1979 { t.Errorf("expected year 1979, got %d", obj.Date.Year)
// t.Errorf("expected year 1979, got %d", obj.Date.Year) }
// } if obj.Date.Month != 5 {
// if obj.Date.Month != 5 { t.Errorf("expected month 5, got %d", obj.Date.Month)
// t.Errorf("expected month 5, got %d", obj.Date.Month) }
// } if obj.Date.Day != 27 {
// if obj.Date.Day != 27 { t.Errorf("expected day 27, got %d", obj.Date.Day)
// t.Errorf("expected day 27, got %d", obj.Date.Day) }
// } })
// })
// t.Run("ToLocalDate", func(t *testing.T) {
// t.Run("ToLocalDate", func(t *testing.T) { type dateStruct struct {
// type dateStruct struct { Date time.Time
// Date time.Time }
// }
// doc := `date = 1979-05-27`
// doc := `date = 1979-05-27`
// var obj dateStruct
// var obj dateStruct
// err := toml.Unmarshal([]byte(doc), &obj)
// err := toml.Unmarshal([]byte(doc), &obj)
// if err != nil {
// if err != nil { t.Fatal(err)
// t.Fatal(err) }
// }
// if obj.Date.Year() != 1979 {
// if obj.Date.Year() != 1979 { t.Errorf("expected year 1979, got %d", obj.Date.Year())
// t.Errorf("expected year 1979, got %d", obj.Date.Year()) }
// } if obj.Date.Month() != 5 {
// if obj.Date.Month() != 5 { t.Errorf("expected month 5, got %d", obj.Date.Month())
// t.Errorf("expected month 5, got %d", obj.Date.Month()) }
// } if obj.Date.Day() != 27 {
// if obj.Date.Day() != 27 { t.Errorf("expected day 27, got %d", obj.Date.Day())
// t.Errorf("expected day 27, got %d", obj.Date.Day()) }
// } })
// }) }
//}
// func TestUnmarshalLocalDateTime(t *testing.T) {
//func TestUnmarshalLocalDateTime(t *testing.T) { examples := []struct {
// examples := []struct { name string
// name string in string
// in string out toml.LocalDateTime
// out toml.LocalDateTime }{
// }{ {
// { name: "normal",
// name: "normal", in: "1979-05-27T07:32:00",
// in: "1979-05-27T07:32:00", out: toml.LocalDateTime{
// out: toml.LocalDateTime{ Date: toml.LocalDate{
// Date: toml.LocalDate{ Year: 1979,
// Year: 1979, Month: 5,
// Month: 5, Day: 27,
// Day: 27, },
// }, Time: toml.LocalTime{
// Time: toml.LocalTime{ Hour: 7,
// Hour: 7, Minute: 32,
// Minute: 32, Second: 0,
// Second: 0, Nanosecond: 0,
// Nanosecond: 0, },
// }, }},
// }}, {
// { name: "with nanoseconds",
// name: "with nanoseconds", in: "1979-05-27T00:32:00.999999",
// in: "1979-05-27T00:32:00.999999", out: toml.LocalDateTime{
// out: toml.LocalDateTime{ Date: toml.LocalDate{
// Date: toml.LocalDate{ Year: 1979,
// Year: 1979, Month: 5,
// Month: 5, Day: 27,
// Day: 27, },
// }, Time: toml.LocalTime{
// Time: toml.LocalTime{ Hour: 0,
// Hour: 0, Minute: 32,
// Minute: 32, Second: 0,
// Second: 0, Nanosecond: 999999000,
// Nanosecond: 999999000, },
// }, },
// }, },
// }, }
// }
// for i, example := range examples {
// for i, example := range examples { doc := fmt.Sprintf(`date = %s`, example.in)
// doc := fmt.Sprintf(`date = %s`, example.in)
// t.Run(fmt.Sprintf("ToLocalDateTime_%d_%s", i, example.name), func(t *testing.T) {
// t.Run(fmt.Sprintf("ToLocalDateTime_%d_%s", i, example.name), func(t *testing.T) { type dateStruct struct {
// type dateStruct struct { Date toml.LocalDateTime
// Date toml.LocalDateTime }
// }
// var obj dateStruct
// var obj dateStruct
// err := toml.Unmarshal([]byte(doc), &obj)
// err := toml.Unmarshal([]byte(doc), &obj)
// if err != nil {
// if err != nil { t.Fatal(err)
// t.Fatal(err) }
// }
// if obj.Date != example.out {
// if obj.Date != example.out { t.Errorf("expected '%s', got '%s'", example.out, obj.Date)
// t.Errorf("expected '%s', got '%s'", example.out, obj.Date) }
// } })
// })
// t.Run(fmt.Sprintf("ToTime_%d_%s", i, example.name), func(t *testing.T) {
// t.Run(fmt.Sprintf("ToTime_%d_%s", i, example.name), func(t *testing.T) { type dateStruct struct {
// type dateStruct struct { Date time.Time
// Date time.Time }
// }
// var obj dateStruct
// var obj dateStruct
// err := toml.Unmarshal([]byte(doc), &obj)
// err := toml.Unmarshal([]byte(doc), &obj)
// if err != nil {
// if err != nil { t.Fatal(err)
// t.Fatal(err) }
// }
// if obj.Date.Year() != example.out.Date.Year {
// if obj.Date.Year() != example.out.Date.Year { t.Errorf("expected year %d, got %d", example.out.Date.Year, obj.Date.Year())
// t.Errorf("expected year %d, got %d", example.out.Date.Year, obj.Date.Year()) }
// } if obj.Date.Month() != example.out.Date.Month {
// if obj.Date.Month() != example.out.Date.Month { t.Errorf("expected month %d, got %d", example.out.Date.Month, obj.Date.Month())
// t.Errorf("expected month %d, got %d", example.out.Date.Month, obj.Date.Month()) }
// } if obj.Date.Day() != example.out.Date.Day {
// if obj.Date.Day() != example.out.Date.Day { t.Errorf("expected day %d, got %d", example.out.Date.Day, obj.Date.Day())
// t.Errorf("expected day %d, got %d", example.out.Date.Day, obj.Date.Day()) }
// } if obj.Date.Hour() != example.out.Time.Hour {
// if obj.Date.Hour() != example.out.Time.Hour { t.Errorf("expected hour %d, got %d", example.out.Time.Hour, obj.Date.Hour())
// t.Errorf("expected hour %d, got %d", example.out.Time.Hour, obj.Date.Hour()) }
// } if obj.Date.Minute() != example.out.Time.Minute {
// if obj.Date.Minute() != example.out.Time.Minute { t.Errorf("expected minute %d, got %d", example.out.Time.Minute, obj.Date.Minute())
// t.Errorf("expected minute %d, got %d", example.out.Time.Minute, obj.Date.Minute()) }
// } if obj.Date.Second() != example.out.Time.Second {
// if obj.Date.Second() != example.out.Time.Second { t.Errorf("expected second %d, got %d", example.out.Time.Second, obj.Date.Second())
// t.Errorf("expected second %d, got %d", example.out.Time.Second, obj.Date.Second()) }
// } if obj.Date.Nanosecond() != example.out.Time.Nanosecond {
// if obj.Date.Nanosecond() != example.out.Time.Nanosecond { t.Errorf("expected nanoseconds %d, got %d", example.out.Time.Nanosecond, obj.Date.Nanosecond())
// t.Errorf("expected nanoseconds %d, got %d", example.out.Time.Nanosecond, obj.Date.Nanosecond()) }
// } })
// }) }
// } }
//}
// func TestUnmarshalLocalTime(t *testing.T) {
//func TestUnmarshalLocalTime(t *testing.T) { examples := []struct {
// examples := []struct { name string
// name string in string
// in string out toml.LocalTime
// out toml.LocalTime }{
// }{ {
// { name: "normal",
// name: "normal", in: "07:32:00",
// in: "07:32:00", out: toml.LocalTime{
// out: toml.LocalTime{ Hour: 7,
// Hour: 7, Minute: 32,
// Minute: 32, Second: 0,
// Second: 0, Nanosecond: 0,
// Nanosecond: 0, },
// }, },
// }, {
// { name: "with nanoseconds",
// name: "with nanoseconds", in: "00:32:00.999999",
// in: "00:32:00.999999", out: toml.LocalTime{
// out: toml.LocalTime{ Hour: 0,
// Hour: 0, Minute: 32,
// Minute: 32, Second: 0,
// Second: 0, Nanosecond: 999999000,
// Nanosecond: 999999000, },
// }, },
// }, }
// }
// for i, example := range examples {
// for i, example := range examples { doc := fmt.Sprintf(`Time = %s`, example.in)
// doc := fmt.Sprintf(`Time = %s`, example.in)
// t.Run(fmt.Sprintf("ToLocalTime_%d_%s", i, example.name), func(t *testing.T) {
// t.Run(fmt.Sprintf("ToLocalTime_%d_%s", i, example.name), func(t *testing.T) { type dateStruct struct {
// type dateStruct struct { Time toml.LocalTime
// Time toml.LocalTime }
// }
// var obj dateStruct
// var obj dateStruct
// err := toml.Unmarshal([]byte(doc), &obj)
// err := toml.Unmarshal([]byte(doc), &obj)
// if err != nil {
// if err != nil { t.Fatal(err)
// t.Fatal(err) }
// }
// if obj.Time != example.out {
// if obj.Time != example.out { t.Errorf("expected '%s', got '%s'", example.out, obj.Time)
// t.Errorf("expected '%s', got '%s'", example.out, obj.Time) }
// } })
// }) }
// } }
//}
// test case for issue #339 // test case for issue #339
func TestUnmarshalSameInnerField(t *testing.T) { func TestUnmarshalSameInnerField(t *testing.T) {
@@ -1955,66 +1945,80 @@ String2="2"`
assert.Error(t, err) assert.Error(t, err)
} }
func decoder(doc string) *toml.Decoder {
return toml.NewDecoder(bytes.NewReader([]byte(doc)))
}
func strictDecoder(doc string) *toml.Decoder {
d := decoder(doc)
d.SetStrict(true)
return d
}
func TestDecoderStrict(t *testing.T) { func TestDecoderStrict(t *testing.T) {
t.Skip() input := `
// input := ` [decoded]
//[decoded] key = ""
// key = ""
// [undecoded]
//[undecoded] key = ""
// key = ""
// [undecoded.inner]
// [undecoded.inner] key = ""
// key = ""
// [[undecoded.array]]
// [[undecoded.array]] key = ""
// key = ""
// [[undecoded.array]]
// [[undecoded.array]] key = ""
// key = ""
// `
//` var doc struct {
// var doc struct { Decoded struct {
// Decoded struct { Key string
// Key string }
// } }
// }
// err := strictDecoder(input).Decode(&doc)
// expected := `undecoded keys: ["undecoded.array.0.key" "undecoded.array.1.key" "undecoded.inner.key" "undecoded.key"]` require.Error(t, err)
// require.IsType(t, &toml.StrictMissingError{}, err)
// err := NewDecoder(bytes.NewReader([]byte(input))).Strict(true).Decode(&doc) se := err.(*toml.StrictMissingError)
// if err == nil {
// t.Error("expected error, got none") keys := []toml.Key{}
// } else if err.Error() != expected {
// t.Errorf("expect err: %s, got: %s", expected, err.Error()) for _, e := range se.Errors {
// } keys = append(keys, e.Key())
// }
// if err := NewDecoder(bytes.NewReader([]byte(input))).Decode(&doc); err != nil {
// t.Errorf("unexpected err: %s", err) expectedKeys := []toml.Key{
// } {"undecoded"},
// {"undecoded", "inner"},
// var m map[string]interface{} {"undecoded", "array"},
// if err := NewDecoder(bytes.NewReader([]byte(input))).Decode(&m); err != nil { {"undecoded", "array"},
// t.Errorf("unexpected err: %s", err) }
// }
require.Equal(t, expectedKeys, keys)
err = decoder(input).Decode(&doc)
require.NoError(t, err)
var m map[string]interface{}
err = decoder(input).Decode(&m)
} }
func TestDecoderStrictValid(t *testing.T) { func TestDecoderStrictValid(t *testing.T) {
t.Skip() input := `
// input := ` [decoded]
//[decoded] key = ""
// key = "" `
//` var doc struct {
// var doc struct { Decoded struct {
// Decoded struct { Key string
// Key string }
// } }
// }
// err := strictDecoder(input).Decode(&doc)
// err := NewDecoder(bytes.NewReader([]byte(input))).Strict(true).Decode(&doc) require.NoError(t, err)
// if err != nil {
// t.Fatal("unexpected error:", err)
// }
} }
type docUnmarshalTOML struct { type docUnmarshalTOML struct {
+50
View File
@@ -0,0 +1,50 @@
package tracker
import (
"github.com/pelletier/go-toml/v2/internal/ast"
)
// KeyTracker is a tracker that keeps track of the current Key as the AST is
// walked.
type KeyTracker struct {
k []string
}
// UpdateTable sets the state of the tracker with the AST table node.
func (t *KeyTracker) UpdateTable(node ast.Node) {
t.reset()
t.Push(node)
}
// UpdateArrayTable sets the state of the tracker with the AST array table node.
func (t *KeyTracker) UpdateArrayTable(node ast.Node) {
t.reset()
t.Push(node)
}
// Push the given key on the stack.
func (t *KeyTracker) Push(node ast.Node) {
it := node.Key()
for it.Next() {
t.k = append(t.k, string(it.Node().Data))
}
}
// Pop key from stack.
func (t *KeyTracker) Pop(node ast.Node) {
it := node.Key()
for it.Next() {
t.k = t.k[:len(t.k)-1]
}
}
// Key returns the current key
func (t *KeyTracker) Key() []string {
k := make([]string, len(t.k))
copy(k, t.k)
return k
}
func (t *KeyTracker) reset() {
t.k = t.k[:0]
}
+200
View File
@@ -0,0 +1,200 @@
package tracker
import (
"fmt"
"github.com/pelletier/go-toml/v2/internal/ast"
)
type keyKind uint8
const (
invalidKind keyKind = iota
valueKind
tableKind
arrayTableKind
)
func (k keyKind) String() string {
switch k {
case invalidKind:
return "invalid"
case valueKind:
return "value"
case tableKind:
return "table"
case arrayTableKind:
return "array table"
}
panic("missing keyKind string mapping")
}
// SeenTracker tracks which keys have been seen with which TOML type to flag duplicates
// and mismatches according to the spec.
type SeenTracker struct {
root *info
current *info
}
type info struct {
parent *info
kind keyKind
children map[string]*info
explicit bool
}
func (i *info) Clear() {
i.children = nil
}
func (i *info) Has(k string) (*info, bool) {
c, ok := i.children[k]
return c, ok
}
func (i *info) SetKind(kind keyKind) {
i.kind = kind
}
func (i *info) CreateTable(k string, explicit bool) *info {
return i.createChild(k, tableKind, explicit)
}
func (i *info) CreateArrayTable(k string, explicit bool) *info {
return i.createChild(k, arrayTableKind, explicit)
}
func (i *info) createChild(k string, kind keyKind, explicit bool) *info {
if i.children == nil {
i.children = make(map[string]*info, 1)
}
x := &info{
parent: i,
kind: kind,
explicit: explicit,
}
i.children[k] = x
return x
}
// CheckExpression takes a top-level node and checks that it does not contain keys
// that have been seen in previous calls, and validates that types are consistent.
func (s *SeenTracker) CheckExpression(node ast.Node) error {
if s.root == nil {
s.root = &info{
kind: tableKind,
}
s.current = s.root
}
switch node.Kind {
case ast.KeyValue:
return s.checkKeyValue(s.current, node)
case ast.Table:
return s.checkTable(node)
case ast.ArrayTable:
return s.checkArrayTable(node)
default:
panic(fmt.Errorf("this should not be a top level node type: %s", node.Kind))
}
}
func (s *SeenTracker) checkTable(node ast.Node) error {
s.current = s.root
it := node.Key()
// handle the first parts of the key, excluding the last one
for it.Next() {
if !it.Node().Next().Valid() {
break
}
k := string(it.Node().Data)
child, found := s.current.Has(k)
if !found {
child = s.current.CreateTable(k, false)
}
s.current = child
}
// handle the last part of the key
k := string(it.Node().Data)
i, found := s.current.Has(k)
if found {
if i.kind != tableKind {
return fmt.Errorf("toml: key %s should be a table, not a %s", k, i.kind)
}
if i.explicit {
return fmt.Errorf("toml: table %s already exists", k)
}
i.explicit = true
s.current = i
} else {
s.current = s.current.CreateTable(k, true)
}
return nil
}
func (s *SeenTracker) checkArrayTable(node ast.Node) error {
s.current = s.root
it := node.Key()
// handle the first parts of the key, excluding the last one
for it.Next() {
if !it.Node().Next().Valid() {
break
}
k := string(it.Node().Data)
child, found := s.current.Has(k)
if !found {
child = s.current.CreateTable(k, false)
}
s.current = child
}
// handle the last part of the key
k := string(it.Node().Data)
info, found := s.current.Has(k)
if found {
if info.kind != arrayTableKind {
return fmt.Errorf("toml: key %s already exists as a %s, but should be an array table", info.kind, k)
}
info.Clear()
} else {
info = s.current.CreateArrayTable(k, true)
}
s.current = info
return nil
}
func (s *SeenTracker) checkKeyValue(context *info, node ast.Node) error {
it := node.Key()
// handle the first parts of the key, excluding the last one
for it.Next() {
k := string(it.Node().Data)
child, found := context.Has(k)
if found {
if child.kind != tableKind {
return fmt.Errorf("toml: expected %s to be a table, not a %s", k, child.kind)
}
} else {
child = context.CreateTable(k, false)
}
context = child
}
if node.Value().Kind == ast.InlineTable {
context.SetKind(tableKind)
} else {
context.SetKind(valueKind)
}
return nil
}
-199
View File
@@ -1,200 +1 @@
package tracker package tracker
import (
"fmt"
"github.com/pelletier/go-toml/v2/internal/ast"
)
type keyKind uint8
const (
invalidKind keyKind = iota
valueKind
tableKind
arrayTableKind
)
func (k keyKind) String() string {
switch k {
case invalidKind:
return "invalid"
case valueKind:
return "value"
case tableKind:
return "table"
case arrayTableKind:
return "array table"
}
panic("missing keyKind string mapping")
}
// Tracks which keys have been seen with which TOML type to flag duplicates
// and mismatches according to the spec.
type Seen struct {
root *info
current *info
}
type info struct {
parent *info
kind keyKind
children map[string]*info
explicit bool
}
func (i *info) Clear() {
i.children = nil
}
func (i *info) Has(k string) (*info, bool) {
c, ok := i.children[k]
return c, ok
}
func (i *info) SetKind(kind keyKind) {
i.kind = kind
}
func (i *info) CreateTable(k string, explicit bool) *info {
return i.createChild(k, tableKind, explicit)
}
func (i *info) CreateArrayTable(k string, explicit bool) *info {
return i.createChild(k, arrayTableKind, explicit)
}
func (i *info) createChild(k string, kind keyKind, explicit bool) *info {
if i.children == nil {
i.children = make(map[string]*info, 1)
}
x := &info{
parent: i,
kind: kind,
explicit: explicit,
}
i.children[k] = x
return x
}
// CheckExpression takes a top-level node and checks that it does not contain keys
// that have been seen in previous calls, and validates that types are consistent.
func (s *Seen) CheckExpression(node ast.Node) error {
if s.root == nil {
s.root = &info{
kind: tableKind,
}
s.current = s.root
}
switch node.Kind {
case ast.KeyValue:
return s.checkKeyValue(s.current, node)
case ast.Table:
return s.checkTable(node)
case ast.ArrayTable:
return s.checkArrayTable(node)
default:
panic(fmt.Errorf("this should not be a top level node type: %s", node.Kind))
}
}
func (s *Seen) checkTable(node ast.Node) error {
s.current = s.root
it := node.Key()
// handle the first parts of the key, excluding the last one
for it.Next() {
if !it.Node().Next().Valid() {
break
}
k := string(it.Node().Data)
child, found := s.current.Has(k)
if !found {
child = s.current.CreateTable(k, false)
}
s.current = child
}
// handle the last part of the key
k := string(it.Node().Data)
i, found := s.current.Has(k)
if found {
if i.kind != tableKind {
return fmt.Errorf("key %s should be a table", k)
}
if i.explicit {
return fmt.Errorf("table %s already exists", k)
}
i.explicit = true
s.current = i
} else {
s.current = s.current.CreateTable(k, true)
}
return nil
}
func (s *Seen) checkArrayTable(node ast.Node) error {
s.current = s.root
it := node.Key()
// handle the first parts of the key, excluding the last one
for it.Next() {
if !it.Node().Next().Valid() {
break
}
k := string(it.Node().Data)
child, found := s.current.Has(k)
if !found {
child = s.current.CreateTable(k, false)
}
s.current = child
}
// handle the last part of the key
k := string(it.Node().Data)
info, found := s.current.Has(k)
if found {
if info.kind != arrayTableKind {
return fmt.Errorf("key %s already exists but is not an array table", k)
}
info.Clear()
} else {
info = s.current.CreateArrayTable(k, true)
}
s.current = info
return nil
}
func (s *Seen) checkKeyValue(context *info, node ast.Node) error {
it := node.Key()
// handle the first parts of the key, excluding the last one
for it.Next() {
k := string(it.Node().Data)
child, found := context.Has(k)
if found {
if child.kind != tableKind {
return fmt.Errorf("expected %s to be a table, not a %s", k, child.kind)
}
} else {
child = context.CreateTable(k, false)
}
context = child
}
if node.Value().Kind == ast.InlineTable {
context.SetKind(tableKind)
} else {
context.SetKind(valueKind)
}
return nil
}
+24
View File
@@ -33,3 +33,27 @@ func SubsliceOffset(data []byte, subslice []byte) int {
return intoffset return intoffset
} }
func BytesRange(start []byte, end []byte) []byte {
if start == nil || end == nil {
panic("cannot call BytesRange with nil")
}
startp := (*reflect.SliceHeader)(unsafe.Pointer(&start))
endp := (*reflect.SliceHeader)(unsafe.Pointer(&end))
if startp.Data > endp.Data {
panic(fmt.Errorf("start pointer address (%d) is after end pointer address (%d)", startp.Data, endp.Data))
}
l := startp.Len
endLen := int(endp.Data-startp.Data) + endp.Len
if endLen > l {
l = endLen
}
if l > startp.Cap {
panic(fmt.Errorf("range length is larger than capacity"))
}
return start[:l]
}
+89
View File
@@ -77,3 +77,92 @@ func TestUnsafeSubsliceOffsetInvalid(t *testing.T) {
}) })
} }
} }
func TestUnsafeBytesRange(t *testing.T) {
type fn = func() ([]byte, []byte)
examples := []struct {
desc string
test fn
expected []byte
}{
{
desc: "simple",
test: func() ([]byte, []byte) {
full := []byte("hello world")
return full[1:3], full[6:8]
},
expected: []byte("ello wo"),
},
{
desc: "full",
test: func() ([]byte, []byte) {
full := []byte("hello world")
return full[0:1], full[len(full)-1:]
},
expected: []byte("hello world"),
},
{
desc: "end before start",
test: func() ([]byte, []byte) {
full := []byte("hello world")
return full[len(full)-1:], full[0:1]
},
},
{
desc: "nils",
test: func() ([]byte, []byte) {
return nil, nil
},
},
{
desc: "nils start",
test: func() ([]byte, []byte) {
return nil, []byte("foo")
},
},
{
desc: "nils end",
test: func() ([]byte, []byte) {
return []byte("foo"), nil
},
},
{
desc: "start is end",
test: func() ([]byte, []byte) {
full := []byte("hello world")
return full[1:3], full[1:3]
},
expected: []byte("el"),
},
{
desc: "end contained in start",
test: func() ([]byte, []byte) {
full := []byte("hello world")
return full[1:7], full[2:4]
},
expected: []byte("ello w"),
},
{
desc: "different backing arrays",
test: func() ([]byte, []byte) {
one := []byte("hello world")
two := []byte("hello world")
return one, two
},
},
}
for _, e := range examples {
t.Run(e.desc, func(t *testing.T) {
start, end := e.test()
if e.expected == nil {
require.Panics(t, func() {
unsafe.BytesRange(start, end)
})
} else {
res := unsafe.BytesRange(start, end)
require.Equal(t, e.expected, res)
}
})
}
}
+38 -19
View File
@@ -23,6 +23,7 @@
// //
// Because they lack location information, these types do not represent unique // Because they lack location information, these types do not represent unique
// moments or intervals of time. Use time.Time for that purpose. // moments or intervals of time. Use time.Time for that purpose.
package toml package toml
import ( import (
@@ -44,6 +45,7 @@ type LocalDate struct {
func LocalDateOf(t time.Time) LocalDate { func LocalDateOf(t time.Time) LocalDate {
var d LocalDate var d LocalDate
d.Year, d.Month, d.Day = t.Date() d.Year, d.Month, d.Day = t.Date()
return d return d
} }
@@ -53,6 +55,7 @@ func ParseLocalDate(s string) (LocalDate, error) {
if err != nil { if err != nil {
return LocalDate{}, err return LocalDate{}, err
} }
return LocalDateOf(t), nil return LocalDateOf(t), nil
} }
@@ -92,23 +95,28 @@ func (d LocalDate) DaysSince(s LocalDate) (days int) {
// We convert to Unix time so we do not have to worry about leap seconds: // We convert to Unix time so we do not have to worry about leap seconds:
// Unix time increases by exactly 86400 seconds per day. // Unix time increases by exactly 86400 seconds per day.
deltaUnix := d.In(time.UTC).Unix() - s.In(time.UTC).Unix() deltaUnix := d.In(time.UTC).Unix() - s.In(time.UTC).Unix()
return int(deltaUnix / 86400)
const secondsInADay = 86400
return int(deltaUnix / secondsInADay)
} }
// Before reports whether d1 occurs before d2. // Before reports whether d1 occurs before future date.
func (d1 LocalDate) Before(d2 LocalDate) bool { func (d LocalDate) Before(future LocalDate) bool {
if d1.Year != d2.Year { if d.Year != future.Year {
return d1.Year < d2.Year return d.Year < future.Year
} }
if d1.Month != d2.Month {
return d1.Month < d2.Month if d.Month != future.Month {
return d.Month < future.Month
} }
return d1.Day < d2.Day
return d.Day < future.Day
} }
// After reports whether d1 occurs after d2. // After reports whether d1 occurs after past date.
func (d1 LocalDate) After(d2 LocalDate) bool { func (d LocalDate) After(past LocalDate) bool {
return d2.Before(d1) return past.Before(d)
} }
// MarshalText implements the encoding.TextMarshaler interface. // MarshalText implements the encoding.TextMarshaler interface.
@@ -122,6 +130,7 @@ func (d LocalDate) MarshalText() ([]byte, error) {
func (d *LocalDate) UnmarshalText(data []byte) error { func (d *LocalDate) UnmarshalText(data []byte) error {
var err error var err error
*d, err = ParseLocalDate(string(data)) *d, err = ParseLocalDate(string(data))
return err return err
} }
@@ -145,6 +154,7 @@ func LocalTimeOf(t time.Time) LocalTime {
var tm LocalTime var tm LocalTime
tm.Hour, tm.Minute, tm.Second = t.Clock() tm.Hour, tm.Minute, tm.Second = t.Clock()
tm.Nanosecond = t.Nanosecond() tm.Nanosecond = t.Nanosecond()
return tm return tm
} }
@@ -158,6 +168,7 @@ func ParseLocalTime(s string) (LocalTime, error) {
if err != nil { if err != nil {
return LocalTime{}, err return LocalTime{}, err
} }
return LocalTimeOf(t), nil return LocalTimeOf(t), nil
} }
@@ -169,6 +180,7 @@ func (t LocalTime) String() string {
if t.Nanosecond == 0 { if t.Nanosecond == 0 {
return s return s
} }
return s + fmt.Sprintf(".%09d", t.Nanosecond) return s + fmt.Sprintf(".%09d", t.Nanosecond)
} }
@@ -176,6 +188,7 @@ func (t LocalTime) String() string {
func (t LocalTime) IsValid() bool { func (t LocalTime) IsValid() bool {
// Construct a non-zero time. // Construct a non-zero time.
tm := time.Date(2, 2, 2, t.Hour, t.Minute, t.Second, t.Nanosecond, time.UTC) tm := time.Date(2, 2, 2, t.Hour, t.Minute, t.Second, t.Nanosecond, time.UTC)
return LocalTimeOf(tm) == t return LocalTimeOf(tm) == t
} }
@@ -190,6 +203,7 @@ func (t LocalTime) MarshalText() ([]byte, error) {
func (t *LocalTime) UnmarshalText(data []byte) error { func (t *LocalTime) UnmarshalText(data []byte) error {
var err error var err error
*t, err = ParseLocalTime(string(data)) *t, err = ParseLocalTime(string(data))
return err return err
} }
@@ -226,6 +240,7 @@ func ParseLocalDateTime(s string) (LocalDateTime, error) {
return LocalDateTime{}, err return LocalDateTime{}, err
} }
} }
return LocalDateTimeOf(t), nil return LocalDateTimeOf(t), nil
} }
@@ -253,17 +268,20 @@ func (dt LocalDateTime) IsValid() bool {
// //
// In panics if loc is nil. // In panics if loc is nil.
func (dt LocalDateTime) In(loc *time.Location) time.Time { func (dt LocalDateTime) In(loc *time.Location) time.Time {
return time.Date(dt.Date.Year, dt.Date.Month, dt.Date.Day, dt.Time.Hour, dt.Time.Minute, dt.Time.Second, dt.Time.Nanosecond, loc) return time.Date(
dt.Date.Year, dt.Date.Month, dt.Date.Day,
dt.Time.Hour, dt.Time.Minute, dt.Time.Second, dt.Time.Nanosecond, loc,
)
} }
// Before reports whether dt1 occurs before dt2. // Before reports whether dt occurs before future.
func (dt1 LocalDateTime) Before(dt2 LocalDateTime) bool { func (dt LocalDateTime) Before(future LocalDateTime) bool {
return dt1.In(time.UTC).Before(dt2.In(time.UTC)) return dt.In(time.UTC).Before(future.In(time.UTC))
} }
// After reports whether dt1 occurs after dt2. // After reports whether dt occurs after past.
func (dt1 LocalDateTime) After(dt2 LocalDateTime) bool { func (dt LocalDateTime) After(past LocalDateTime) bool {
return dt2.Before(dt1) return past.Before(dt)
} }
// MarshalText implements the encoding.TextMarshaler interface. // MarshalText implements the encoding.TextMarshaler interface.
@@ -273,9 +291,10 @@ func (dt LocalDateTime) MarshalText() ([]byte, error) {
} }
// UnmarshalText implements the encoding.TextUnmarshaler interface. // UnmarshalText implements the encoding.TextUnmarshaler interface.
// The datetime is expected to be a string in a format accepted by ParseLocalDateTime // The datetime is expected to be a string in a format accepted by ParseLocalDateTime.
func (dt *LocalDateTime) UnmarshalText(data []byte) error { func (dt *LocalDateTime) UnmarshalText(data []byte) error {
var err error var err error
*dt, err = ParseLocalDateTime(string(data)) *dt, err = ParseLocalDateTime(string(data))
return err return err
} }
+78 -17
View File
@@ -26,6 +26,8 @@ func cmpEqual(x, y interface{}) bool {
} }
func TestDates(t *testing.T) { func TestDates(t *testing.T) {
t.Parallel()
for _, test := range []struct { for _, test := range []struct {
date LocalDate date LocalDate
loc *time.Location loc *time.Location
@@ -54,6 +56,7 @@ func TestDates(t *testing.T) {
if got := test.date.String(); got != test.wantStr { if got := test.date.String(); got != test.wantStr {
t.Errorf("%#v.String() = %q, want %q", test.date, got, test.wantStr) t.Errorf("%#v.String() = %q, want %q", test.date, got, test.wantStr)
} }
if got := test.date.In(test.loc); !got.Equal(test.wantTime) { if got := test.date.In(test.loc); !got.Equal(test.wantTime) {
t.Errorf("%#v.In(%v) = %v, want %v", test.date, test.loc, got, test.wantTime) t.Errorf("%#v.In(%v) = %v, want %v", test.date, test.loc, got, test.wantTime)
} }
@@ -61,6 +64,8 @@ func TestDates(t *testing.T) {
} }
func TestDateIsValid(t *testing.T) { func TestDateIsValid(t *testing.T) {
t.Parallel()
for _, test := range []struct { for _, test := range []struct {
date LocalDate date LocalDate
want bool want bool
@@ -86,6 +91,10 @@ func TestDateIsValid(t *testing.T) {
} }
func TestParseDate(t *testing.T) { func TestParseDate(t *testing.T) {
t.Parallel()
var emptyDate LocalDate
for _, test := range []struct { for _, test := range []struct {
str string str string
want LocalDate // if empty, expect an error want LocalDate // if empty, expect an error
@@ -93,21 +102,24 @@ func TestParseDate(t *testing.T) {
{"2016-01-02", LocalDate{2016, 1, 2}}, {"2016-01-02", LocalDate{2016, 1, 2}},
{"2016-12-31", LocalDate{2016, 12, 31}}, {"2016-12-31", LocalDate{2016, 12, 31}},
{"0003-02-04", LocalDate{3, 2, 4}}, {"0003-02-04", LocalDate{3, 2, 4}},
{"999-01-26", LocalDate{}}, {"999-01-26", emptyDate},
{"", LocalDate{}}, {"", emptyDate},
{"2016-01-02x", LocalDate{}}, {"2016-01-02x", emptyDate},
} { } {
got, err := ParseLocalDate(test.str) got, err := ParseLocalDate(test.str)
if got != test.want { if got != test.want {
t.Errorf("ParseLocalDate(%q) = %+v, want %+v", test.str, got, test.want) t.Errorf("ParseLocalDate(%q) = %+v, want %+v", test.str, got, test.want)
} }
if err != nil && test.want != (LocalDate{}) {
if err != nil && test.want != (emptyDate) {
t.Errorf("Unexpected error %v from ParseLocalDate(%q)", err, test.str) t.Errorf("Unexpected error %v from ParseLocalDate(%q)", err, test.str)
} }
} }
} }
func TestDateArithmetic(t *testing.T) { func TestDateArithmetic(t *testing.T) {
t.Parallel()
for _, test := range []struct { for _, test := range []struct {
desc string desc string
start LocalDate start LocalDate
@@ -160,6 +172,7 @@ func TestDateArithmetic(t *testing.T) {
if got := test.start.AddDays(test.days); got != test.end { if got := test.start.AddDays(test.days); got != test.end {
t.Errorf("[%s] %#v.AddDays(%v) = %#v, want %#v", test.desc, test.start, test.days, got, test.end) t.Errorf("[%s] %#v.AddDays(%v) = %#v, want %#v", test.desc, test.start, test.days, got, test.end)
} }
if got := test.end.DaysSince(test.start); got != test.days { if got := test.end.DaysSince(test.start); got != test.days {
t.Errorf("[%s] %#v.Sub(%#v) = %v, want %v", test.desc, test.end, test.start, got, test.days) t.Errorf("[%s] %#v.Sub(%#v) = %v, want %v", test.desc, test.end, test.start, got, test.days)
} }
@@ -167,6 +180,8 @@ func TestDateArithmetic(t *testing.T) {
} }
func TestDateBefore(t *testing.T) { func TestDateBefore(t *testing.T) {
t.Parallel()
for _, test := range []struct { for _, test := range []struct {
d1, d2 LocalDate d1, d2 LocalDate
want bool want bool
@@ -183,6 +198,8 @@ func TestDateBefore(t *testing.T) {
} }
func TestDateAfter(t *testing.T) { func TestDateAfter(t *testing.T) {
t.Parallel()
for _, test := range []struct { for _, test := range []struct {
d1, d2 LocalDate d1, d2 LocalDate
want bool want bool
@@ -198,6 +215,8 @@ func TestDateAfter(t *testing.T) {
} }
func TestTimeToString(t *testing.T) { func TestTimeToString(t *testing.T) {
t.Parallel()
for _, test := range []struct { for _, test := range []struct {
str string str string
time LocalTime time LocalTime
@@ -212,11 +231,14 @@ func TestTimeToString(t *testing.T) {
gotTime, err := ParseLocalTime(test.str) gotTime, err := ParseLocalTime(test.str)
if err != nil { if err != nil {
t.Errorf("ParseLocalTime(%q): got error: %v", test.str, err) t.Errorf("ParseLocalTime(%q): got error: %v", test.str, err)
continue continue
} }
if gotTime != test.time { if gotTime != test.time {
t.Errorf("ParseLocalTime(%q) = %+v, want %+v", test.str, gotTime, test.time) t.Errorf("ParseLocalTime(%q) = %+v, want %+v", test.str, gotTime, test.time)
} }
if test.roundTrip { if test.roundTrip {
gotStr := test.time.String() gotStr := test.time.String()
if gotStr != test.str { if gotStr != test.str {
@@ -227,6 +249,8 @@ func TestTimeToString(t *testing.T) {
} }
func TestTimeOf(t *testing.T) { func TestTimeOf(t *testing.T) {
t.Parallel()
for _, test := range []struct { for _, test := range []struct {
time time.Time time time.Time
want LocalTime want LocalTime
@@ -241,6 +265,8 @@ func TestTimeOf(t *testing.T) {
} }
func TestTimeIsValid(t *testing.T) { func TestTimeIsValid(t *testing.T) {
t.Parallel()
for _, test := range []struct { for _, test := range []struct {
time LocalTime time LocalTime
want bool want bool
@@ -265,23 +291,28 @@ func TestTimeIsValid(t *testing.T) {
} }
func TestDateTimeToString(t *testing.T) { func TestDateTimeToString(t *testing.T) {
t.Parallel()
for _, test := range []struct { for _, test := range []struct {
str string str string
dateTime LocalDateTime dateTime LocalDateTime
roundTrip bool // ParseLocalDateTime(str).String() == str? roundTrip bool // ParseLocalDateTime(str).String() == str?
}{ }{
{"2016-03-22T13:26:33", LocalDateTime{LocalDate{2016, 03, 22}, LocalTime{13, 26, 33, 0}}, true}, {"2016-03-22T13:26:33", LocalDateTime{LocalDate{2016, 3, 22}, LocalTime{13, 26, 33, 0}}, true},
{"2016-03-22T13:26:33.000000600", LocalDateTime{LocalDate{2016, 03, 22}, LocalTime{13, 26, 33, 600}}, true}, {"2016-03-22T13:26:33.000000600", LocalDateTime{LocalDate{2016, 3, 22}, LocalTime{13, 26, 33, 600}}, true},
{"2016-03-22t13:26:33", LocalDateTime{LocalDate{2016, 03, 22}, LocalTime{13, 26, 33, 0}}, false}, {"2016-03-22t13:26:33", LocalDateTime{LocalDate{2016, 3, 22}, LocalTime{13, 26, 33, 0}}, false},
} { } {
gotDateTime, err := ParseLocalDateTime(test.str) gotDateTime, err := ParseLocalDateTime(test.str)
if err != nil { if err != nil {
t.Errorf("ParseLocalDateTime(%q): got error: %v", test.str, err) t.Errorf("ParseLocalDateTime(%q): got error: %v", test.str, err)
continue continue
} }
if gotDateTime != test.dateTime { if gotDateTime != test.dateTime {
t.Errorf("ParseLocalDateTime(%q) = %+v, want %+v", test.str, gotDateTime, test.dateTime) t.Errorf("ParseLocalDateTime(%q) = %+v, want %+v", test.str, gotDateTime, test.dateTime)
} }
if test.roundTrip { if test.roundTrip {
gotStr := test.dateTime.String() gotStr := test.dateTime.String()
if gotStr != test.str { if gotStr != test.str {
@@ -292,6 +323,8 @@ func TestDateTimeToString(t *testing.T) {
} }
func TestParseDateTimeErrors(t *testing.T) { func TestParseDateTimeErrors(t *testing.T) {
t.Parallel()
for _, str := range []string{ for _, str := range []string{
"", "",
"2016-03-22", // just a date "2016-03-22", // just a date
@@ -306,14 +339,20 @@ func TestParseDateTimeErrors(t *testing.T) {
} }
func TestDateTimeOf(t *testing.T) { func TestDateTimeOf(t *testing.T) {
t.Parallel()
for _, test := range []struct { for _, test := range []struct {
time time.Time time time.Time
want LocalDateTime want LocalDateTime
}{ }{
{time.Date(2014, 8, 20, 15, 8, 43, 1, time.Local), {
LocalDateTime{LocalDate{2014, 8, 20}, LocalTime{15, 8, 43, 1}}}, time.Date(2014, 8, 20, 15, 8, 43, 1, time.Local),
{time.Date(1, 1, 1, 0, 0, 0, 0, time.UTC), LocalDateTime{LocalDate{2014, 8, 20}, LocalTime{15, 8, 43, 1}},
LocalDateTime{LocalDate{1, 1, 1}, LocalTime{0, 0, 0, 0}}}, },
{
time.Date(1, 1, 1, 0, 0, 0, 0, time.UTC),
LocalDateTime{LocalDate{1, 1, 1}, LocalTime{0, 0, 0, 0}},
},
} { } {
if got := LocalDateTimeOf(test.time); got != test.want { if got := LocalDateTimeOf(test.time); got != test.want {
t.Errorf("LocalDateTimeOf(%v) = %+v, want %+v", test.time, got, test.want) t.Errorf("LocalDateTimeOf(%v) = %+v, want %+v", test.time, got, test.want)
@@ -322,6 +361,8 @@ func TestDateTimeOf(t *testing.T) {
} }
func TestDateTimeIsValid(t *testing.T) { func TestDateTimeIsValid(t *testing.T) {
t.Parallel()
// No need to be exhaustive here; it's just LocalDate.IsValid && LocalTime.IsValid. // No need to be exhaustive here; it's just LocalDate.IsValid && LocalTime.IsValid.
for _, test := range []struct { for _, test := range []struct {
dt LocalDateTime dt LocalDateTime
@@ -339,19 +380,24 @@ func TestDateTimeIsValid(t *testing.T) {
} }
func TestDateTimeIn(t *testing.T) { func TestDateTimeIn(t *testing.T) {
t.Parallel()
dt := LocalDateTime{LocalDate{2016, 1, 2}, LocalTime{3, 4, 5, 6}} dt := LocalDateTime{LocalDate{2016, 1, 2}, LocalTime{3, 4, 5, 6}}
got := dt.In(time.UTC)
want := time.Date(2016, 1, 2, 3, 4, 5, 6, time.UTC) want := time.Date(2016, 1, 2, 3, 4, 5, 6, time.UTC)
if !got.Equal(want) { if got := dt.In(time.UTC); !got.Equal(want) {
t.Errorf("got %v, want %v", got, want) t.Errorf("got %v, want %v", got, want)
} }
} }
func TestDateTimeBefore(t *testing.T) { func TestDateTimeBefore(t *testing.T) {
t.Parallel()
d1 := LocalDate{2016, 12, 31} d1 := LocalDate{2016, 12, 31}
d2 := LocalDate{2017, 1, 1} d2 := LocalDate{2017, 1, 1}
t1 := LocalTime{5, 6, 7, 8} t1 := LocalTime{5, 6, 7, 8}
t2 := LocalTime{5, 6, 7, 9} t2 := LocalTime{5, 6, 7, 9}
for _, test := range []struct { for _, test := range []struct {
dt1, dt2 LocalDateTime dt1, dt2 LocalDateTime
want bool want bool
@@ -368,10 +414,13 @@ func TestDateTimeBefore(t *testing.T) {
} }
func TestDateTimeAfter(t *testing.T) { func TestDateTimeAfter(t *testing.T) {
t.Parallel()
d1 := LocalDate{2016, 12, 31} d1 := LocalDate{2016, 12, 31}
d2 := LocalDate{2017, 1, 1} d2 := LocalDate{2017, 1, 1}
t1 := LocalTime{5, 6, 7, 8} t1 := LocalTime{5, 6, 7, 8}
t2 := LocalTime{5, 6, 7, 9} t2 := LocalTime{5, 6, 7, 9}
for _, test := range []struct { for _, test := range []struct {
dt1, dt2 LocalDateTime dt1, dt2 LocalDateTime
want bool want bool
@@ -388,6 +437,8 @@ func TestDateTimeAfter(t *testing.T) {
} }
func TestMarshalJSON(t *testing.T) { func TestMarshalJSON(t *testing.T) {
t.Parallel()
for _, test := range []struct { for _, test := range []struct {
value interface{} value interface{}
want string want string
@@ -400,6 +451,7 @@ func TestMarshalJSON(t *testing.T) {
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
if got := string(bgot); got != test.want { if got := string(bgot); got != test.want {
t.Errorf("%#v: got %s, want %s", test.value, got, test.want) t.Errorf("%#v: got %s, want %s", test.value, got, test.want)
} }
@@ -407,9 +459,14 @@ func TestMarshalJSON(t *testing.T) {
} }
func TestUnmarshalJSON(t *testing.T) { func TestUnmarshalJSON(t *testing.T) {
var d LocalDate t.Parallel()
var tm LocalTime
var dt LocalDateTime var (
d LocalDate
tm LocalTime
dt LocalDateTime
)
for _, test := range []struct { for _, test := range []struct {
data string data string
ptr interface{} ptr interface{}
@@ -423,12 +480,14 @@ func TestUnmarshalJSON(t *testing.T) {
if err := json.Unmarshal([]byte(test.data), test.ptr); err != nil { if err := json.Unmarshal([]byte(test.data), test.ptr); err != nil {
t.Fatalf("%s: %v", test.data, err) t.Fatalf("%s: %v", test.data, err)
} }
if !cmpEqual(test.ptr, test.want) { if !cmpEqual(test.ptr, test.want) {
t.Errorf("%s: got %#v, want %#v", test.data, test.ptr, test.want) t.Errorf("%s: got %#v, want %#v", test.data, test.ptr, test.want)
} }
} }
for _, bad := range []string{"", `""`, `"bad"`, `"1987-04-15x"`, for _, bad := range []string{
"", `""`, `"bad"`, `"1987-04-15x"`,
`19870415`, // a JSON number `19870415`, // a JSON number
`11987-04-15x`, // not a JSON string `11987-04-15x`, // not a JSON string
@@ -436,9 +495,11 @@ func TestUnmarshalJSON(t *testing.T) {
if json.Unmarshal([]byte(bad), &d) == nil { if json.Unmarshal([]byte(bad), &d) == nil {
t.Errorf("%q, LocalDate: got nil, want error", bad) t.Errorf("%q, LocalDate: got nil, want error", bad)
} }
if json.Unmarshal([]byte(bad), &tm) == nil { if json.Unmarshal([]byte(bad), &tm) == nil {
t.Errorf("%q, LocalTime: got nil, want error", bad) t.Errorf("%q, LocalTime: got nil, want error", bad)
} }
if json.Unmarshal([]byte(bad), &dt) == nil { if json.Unmarshal([]byte(bad), &dt) == nil {
t.Errorf("%q, LocalDateTime: got nil, want error", bad) t.Errorf("%q, LocalDateTime: got nil, want error", bad)
} }
+337 -160
View File
@@ -2,7 +2,7 @@ package toml
import ( import (
"bytes" "bytes"
"errors" "encoding"
"fmt" "fmt"
"io" "io"
"reflect" "reflect"
@@ -18,16 +18,130 @@ import (
func Marshal(v interface{}) ([]byte, error) { func Marshal(v interface{}) ([]byte, error) {
var buf bytes.Buffer var buf bytes.Buffer
enc := NewEncoder(&buf) enc := NewEncoder(&buf)
err := enc.Encode(v) err := enc.Encode(v)
if err != nil { if err != nil {
return nil, err return nil, err
} }
return buf.Bytes(), nil return buf.Bytes(), nil
} }
// Encoder writes a TOML document to an output stream. // Encoder writes a TOML document to an output stream.
type Encoder struct { type Encoder struct {
// output
w io.Writer w io.Writer
// global settings
tablesInline bool
arraysMultiline bool
indentSymbol string
indentTables bool
}
// NewEncoder returns a new Encoder that writes to w.
func NewEncoder(w io.Writer) *Encoder {
return &Encoder{
w: w,
indentSymbol: " ",
}
}
// SetTablesInline forces the encoder to emit all tables inline.
//
// This behavior can be controlled on an individual struct field basis with the
// inline tag:
//
// MyField `inline:"true"`
func (enc *Encoder) SetTablesInline(inline bool) {
enc.tablesInline = inline
}
// SetArraysMultiline forces the encoder to emit all arrays with one element per
// line.
//
// This behavior can be controlled on an individual struct field basis with the multiline tag:
//
// MyField `multiline:"true"`
func (enc *Encoder) SetArraysMultiline(multiline bool) {
enc.arraysMultiline = multiline
}
// SetIndentSymbol defines the string that should be used for indentation. The
// provided string is repeated for each indentation level. Defaults to two
// spaces.
func (enc *Encoder) SetIndentSymbol(s string) {
enc.indentSymbol = s
}
// SetIndentTables forces the encoder to intent tables and array tables.
func (enc *Encoder) SetIndentTables(indent bool) {
enc.indentTables = indent
}
// Encode writes a TOML representation of v to the stream.
//
// If v cannot be represented to TOML it returns an error.
//
// Encoding rules
//
// A top level slice containing only maps or structs is encoded as [[table
// array]].
//
// All slices not matching rule 1 are encoded as [array]. As a result, any map
// or struct they contain is encoded as an {inline table}.
//
// Nil interfaces and nil pointers are not supported.
//
// Keys in key-values always have one part.
//
// Intermediate tables are always printed.
//
// By default, strings are encoded as literal string, unless they contain either
// a newline character or a single quote. In that case they are emitted as quoted
// strings.
//
// When encoding structs, fields are encoded in order of definition, with their
// exact name.
//
// Struct tags
//
// The following struct tags are available to tweak encoding on a per-field
// basis:
//
// toml:"foo"
// Changes the name of the key to use for the field to foo.
//
// multiline:"true"
// When the field contains a string, it will be emitted as a quoted
// multi-line TOML string.
//
// inline:"true"
// When the field would normally be encoded as a table, it is instead
// encoded as an inline table.
func (enc *Encoder) Encode(v interface{}) error {
var (
b []byte
ctx encoderCtx
)
ctx.inline = enc.tablesInline
b, err := enc.encode(b, ctx, reflect.ValueOf(v))
if err != nil {
return err
}
_, err = enc.w.Write(b)
if err != nil {
return fmt.Errorf("toml: cannot write: %w", err)
}
return nil
}
type valueOptions struct {
multiline bool
} }
type encoderCtx struct { type encoderCtx struct {
@@ -46,11 +160,14 @@ type encoderCtx struct {
// Set to true to skip the first table header in an array table. // Set to true to skip the first table header in an array table.
skipTableHeader bool skipTableHeader bool
options valueOptions // Should the next table be encoded as inline
} inline bool
type valueOptions struct { // Indentation level
multiline bool indent int
// Options coming from struct tags
options valueOptions
} }
func (ctx *encoderCtx) shiftKey() { func (ctx *encoderCtx) shiftKey() {
@@ -70,62 +187,34 @@ func (ctx *encoderCtx) clearKey() {
ctx.hasKey = false ctx.hasKey = false
} }
// NewEncoder returns a new Encoder that writes to w. func (ctx *encoderCtx) isRoot() bool {
func NewEncoder(w io.Writer) *Encoder { return len(ctx.parentKey) == 0 && !ctx.hasKey
return &Encoder{
w: w,
}
}
// Encode writes a TOML representation of v to the stream.
//
// If v cannot be represented to TOML it returns an error.
//
// Encoding rules:
//
// 1. A top level slice containing only maps or structs is encoded as [[table
// array]].
//
// 2. All slices not matching rule 1 are encoded as [array]. As a result, any
// map or struct they contain is encoded as an {inline table}.
//
// 3. Nil interfaces and nil pointers are not supported.
//
// 4. Keys in key-values always have one part.
//
// 5. Intermediate tables are always printed.
//
// By default, strings are encoded as literal string, unless they contain either
// a newline character or a single quote. In that case they are emited as quoted
// strings.
//
// When encoding structs, fields are encoded in order of definition, with their
// exact name. The following struct tags are available:
//
// `toml:"foo"`: changes the name of the key to use for the field to foo.
//
// `multiline:"true"`: when the field contains a string, it will be emitted as
// a quoted multi-line TOML string.
func (enc *Encoder) Encode(v interface{}) error {
var b []byte
var ctx encoderCtx
b, err := enc.encode(b, ctx, reflect.ValueOf(v))
if err != nil {
return err
}
_, err = enc.w.Write(b)
return err
} }
//nolint:cyclop,funlen
func (enc *Encoder) encode(b []byte, ctx encoderCtx, v reflect.Value) ([]byte, error) { func (enc *Encoder) encode(b []byte, ctx encoderCtx, v reflect.Value) ([]byte, error) {
switch i := v.Interface().(type) { i, ok := v.Interface().(time.Time)
case time.Time: // TODO: add TextMarshaler if ok {
b = i.AppendFormat(b, time.RFC3339) return i.AppendFormat(b, time.RFC3339), nil
}
if v.Type().Implements(textMarshalerType) {
if ctx.isRoot() {
return nil, fmt.Errorf("toml: type %s implementing the TextMarshaler interface cannot be a root element", v.Type())
}
text, err := v.Interface().(encoding.TextMarshaler).MarshalText()
if err != nil {
return nil, err
}
b = enc.encodeString(b, string(text), ctx.options)
return b, nil return b, nil
} }
// containers
switch v.Kind() { switch v.Kind() {
// containers
case reflect.Map: case reflect.Map:
return enc.encodeMap(b, ctx, v) return enc.encodeMap(b, ctx, v)
case reflect.Struct: case reflect.Struct:
@@ -134,21 +223,20 @@ func (enc *Encoder) encode(b []byte, ctx encoderCtx, v reflect.Value) ([]byte, e
return enc.encodeSlice(b, ctx, v) return enc.encodeSlice(b, ctx, v)
case reflect.Interface: case reflect.Interface:
if v.IsNil() { if v.IsNil() {
return nil, errNilInterface return nil, fmt.Errorf("toml: encoding a nil interface is not supported")
} }
return enc.encode(b, ctx, v.Elem()) return enc.encode(b, ctx, v.Elem())
case reflect.Ptr: case reflect.Ptr:
if v.IsNil() { if v.IsNil() {
return enc.encode(b, ctx, reflect.Zero(v.Type().Elem())) return enc.encode(b, ctx, reflect.Zero(v.Type().Elem()))
} }
return enc.encode(b, ctx, v.Elem()) return enc.encode(b, ctx, v.Elem())
}
// values // values
var err error
switch v.Kind() {
case reflect.String: case reflect.String:
b, err = enc.encodeString(b, v.String(), ctx.options) b = enc.encodeString(b, v.String(), ctx.options)
case reflect.Float32: case reflect.Float32:
b = strconv.AppendFloat(b, v.Float(), 'f', -1, 32) b = strconv.AppendFloat(b, v.Float(), 'f', -1, 32)
case reflect.Float64: case reflect.Float64:
@@ -164,10 +252,7 @@ func (enc *Encoder) encode(b []byte, ctx encoderCtx, v reflect.Value) ([]byte, e
case reflect.Int64, reflect.Int32, reflect.Int16, reflect.Int8, reflect.Int: case reflect.Int64, reflect.Int32, reflect.Int16, reflect.Int8, reflect.Int:
b = strconv.AppendInt(b, v.Int(), 10) b = strconv.AppendInt(b, v.Int(), 10)
default: default:
err = fmt.Errorf("unsupported encode value kind: %s", v.Kind()) return nil, fmt.Errorf("toml: cannot encode value of type %s", v.Kind())
}
if err != nil {
return nil, err
} }
return b, nil return b, nil
@@ -193,6 +278,8 @@ func (enc *Encoder) encodeKv(b []byte, ctx encoderCtx, options valueOptions, v r
return b, nil return b, nil
} }
b = enc.indent(ctx.indent, b)
b, err = enc.encodeKey(b, ctx.key) b, err = enc.encodeKey(b, ctx.key)
if err != nil { if err != nil {
return nil, err return nil, err
@@ -217,30 +304,31 @@ func (enc *Encoder) encodeKv(b []byte, ctx encoderCtx, options valueOptions, v r
const literalQuote = '\'' const literalQuote = '\''
func (enc *Encoder) encodeString(b []byte, v string, options valueOptions) ([]byte, error) { func (enc *Encoder) encodeString(b []byte, v string, options valueOptions) []byte {
if needsQuoting(v) { if needsQuoting(v) {
b = enc.encodeQuotedString(options.multiline, b, v) return enc.encodeQuotedString(options.multiline, b, v)
} else {
b = enc.encodeLiteralString(b, v)
} }
return b, nil
return enc.encodeLiteralString(b, v)
} }
func needsQuoting(v string) bool { func needsQuoting(v string) bool {
return strings.ContainsAny(v, "'\b\f\n\r\t") return strings.ContainsAny(v, "'\b\f\n\r\t")
} }
// caller should have checked that the string does not contain new lines or ' // caller should have checked that the string does not contain new lines or ' .
func (enc *Encoder) encodeLiteralString(b []byte, v string) []byte { func (enc *Encoder) encodeLiteralString(b []byte, v string) []byte {
b = append(b, literalQuote) b = append(b, literalQuote)
b = append(b, v...) b = append(b, v...)
b = append(b, literalQuote) b = append(b, literalQuote)
return b return b
} }
//nolint:cyclop
func (enc *Encoder) encodeQuotedString(multiline bool, b []byte, v string) []byte { func (enc *Encoder) encodeQuotedString(multiline bool, b []byte, v string) []byte {
const hextable = "0123456789ABCDEF"
stringQuote := `"` stringQuote := `"`
if multiline { if multiline {
stringQuote = `"""` stringQuote = `"""`
} }
@@ -250,6 +338,16 @@ func (enc *Encoder) encodeQuotedString(multiline bool, b []byte, v string) []byt
b = append(b, '\n') b = append(b, '\n')
} }
const (
hextable = "0123456789ABCDEF"
// U+0000 to U+0008, U+000A to U+001F, U+007F
nul = 0x0
bs = 0x8
lf = 0xa
us = 0x1f
del = 0x7f
)
for _, r := range []byte(v) { for _, r := range []byte(v) {
switch r { switch r {
case '\\': case '\\':
@@ -272,7 +370,7 @@ func (enc *Encoder) encodeQuotedString(multiline bool, b []byte, v string) []byt
b = append(b, `\t`...) b = append(b, `\t`...)
default: default:
switch { switch {
case r >= 0x0 && r <= 0x8, r >= 0xA && r <= 0x1F, r == 0x7F: case r >= nul && r <= bs, r >= lf && r <= us, r == del:
b = append(b, `\u00`...) b = append(b, `\u00`...)
b = append(b, hextable[r>>4]) b = append(b, hextable[r>>4])
b = append(b, hextable[r&0x0f]) b = append(b, hextable[r&0x0f])
@@ -280,33 +378,37 @@ func (enc *Encoder) encodeQuotedString(multiline bool, b []byte, v string) []byt
b = append(b, r) b = append(b, r)
} }
} }
// U+0000 to U+0008, U+000A to U+001F, U+007F
} }
b = append(b, stringQuote...) b = append(b, stringQuote...)
return b return b
} }
// called should have checked that the string is in A-Z / a-z / 0-9 / - / _ // called should have checked that the string is in A-Z / a-z / 0-9 / - / _ .
func (enc *Encoder) encodeUnquotedKey(b []byte, v string) []byte { func (enc *Encoder) encodeUnquotedKey(b []byte, v string) []byte {
return append(b, v...) return append(b, v...)
} }
func (enc *Encoder) encodeTableHeader(b []byte, key []string) ([]byte, error) { func (enc *Encoder) encodeTableHeader(ctx encoderCtx, b []byte) ([]byte, error) {
if len(key) == 0 { if len(ctx.parentKey) == 0 {
return b, nil return b, nil
} }
b = enc.indent(ctx.indent, b)
b = append(b, '[') b = append(b, '[')
var err error var err error
b, err = enc.encodeKey(b, key[0])
b, err = enc.encodeKey(b, ctx.parentKey[0])
if err != nil { if err != nil {
return nil, err return nil, err
} }
for _, k := range key[1:] { for _, k := range ctx.parentKey[1:] {
b = append(b, '.') b = append(b, '.')
b, err = enc.encodeKey(b, k) b, err = enc.encodeKey(b, k)
if err != nil { if err != nil {
return nil, err return nil, err
@@ -318,6 +420,7 @@ func (enc *Encoder) encodeTableHeader(b []byte, key []string) ([]byte, error) {
return b, nil return b, nil
} }
//nolint:cyclop
func (enc *Encoder) encodeKey(b []byte, k string) ([]byte, error) { func (enc *Encoder) encodeKey(b []byte, k string) ([]byte, error) {
needsQuotation := false needsQuotation := false
cannotUseLiteral := false cannotUseLiteral := false
@@ -326,32 +429,37 @@ func (enc *Encoder) encodeKey(b []byte, k string) ([]byte, error) {
if (c >= 'A' && c <= 'Z') || (c >= 'a' && c <= 'z') || (c >= '0' && c <= '9') || c == '-' || c == '_' { if (c >= 'A' && c <= 'Z') || (c >= 'a' && c <= 'z') || (c >= '0' && c <= '9') || c == '-' || c == '_' {
continue continue
} }
if c == '\n' { if c == '\n' {
return nil, fmt.Errorf("TOML does not support multiline keys") return nil, fmt.Errorf("toml: new line characters in keys are not supported")
} }
if c == literalQuote { if c == literalQuote {
cannotUseLiteral = true cannotUseLiteral = true
} }
needsQuotation = true needsQuotation = true
} }
if cannotUseLiteral { switch {
b = enc.encodeQuotedString(false, b, k) case cannotUseLiteral:
} else if needsQuotation { return enc.encodeQuotedString(false, b, k), nil
b = enc.encodeLiteralString(b, k) case needsQuotation:
} else { return enc.encodeLiteralString(b, k), nil
b = enc.encodeUnquotedKey(b, k) default:
return enc.encodeUnquotedKey(b, k), nil
} }
return b, nil
} }
func (enc *Encoder) encodeMap(b []byte, ctx encoderCtx, v reflect.Value) ([]byte, error) { func (enc *Encoder) encodeMap(b []byte, ctx encoderCtx, v reflect.Value) ([]byte, error) {
if v.Type().Key().Kind() != reflect.String { if v.Type().Key().Kind() != reflect.String {
return nil, fmt.Errorf("type '%s' not supported as map key", v.Type().Key().Kind()) return nil, fmt.Errorf("toml: type %s is not supported as a map key", v.Type().Key().Kind())
} }
t := table{} var (
t table
emptyValueOptions valueOptions
)
iter := v.MapRange() iter := v.MapRange()
for iter.Next() { for iter.Next() {
@@ -362,15 +470,15 @@ func (enc *Encoder) encodeMap(b []byte, ctx encoderCtx, v reflect.Value) ([]byte
continue continue
} }
table, err := willConvertToTableOrArrayTable(v) table, err := willConvertToTableOrArrayTable(ctx, v)
if err != nil { if err != nil {
return nil, err return nil, err
} }
if table { if table {
t.pushTable(k, v, valueOptions{}) t.pushTable(k, v, emptyValueOptions)
} else { } else {
t.pushKV(k, v, valueOptions{}) t.pushKV(k, v, emptyValueOptions)
} }
} }
@@ -405,13 +513,10 @@ func (t *table) pushTable(k string, v reflect.Value, options valueOptions) {
t.tables = append(t.tables, entry{Key: k, Value: v, Options: options}) t.tables = append(t.tables, entry{Key: k, Value: v, Options: options})
} }
func (t *table) hasKVs() bool {
return len(t.kvs) > 0
}
func (enc *Encoder) encodeStruct(b []byte, ctx encoderCtx, v reflect.Value) ([]byte, error) { func (enc *Encoder) encodeStruct(b []byte, ctx encoderCtx, v reflect.Value) ([]byte, error) {
t := table{} var t table
//nolint:godox
// TODO: cache this? // TODO: cache this?
typ := v.Type() typ := v.Type()
for i := 0; i < typ.NumField(); i++ { for i := 0; i < typ.NumField(); i++ {
@@ -438,133 +543,162 @@ func (enc *Encoder) encodeStruct(b []byte, ctx encoderCtx, v reflect.Value) ([]b
continue continue
} }
willConvert, err := willConvertToTableOrArrayTable(f) willConvert, err := willConvertToTableOrArrayTable(ctx, f)
if err != nil { if err != nil {
return nil, err return nil, err
} }
options := valueOptions{} options := valueOptions{
multiline: fieldBoolTag(fieldType, "multiline"),
ml, ok := fieldType.Tag.Lookup("multiline")
if ok {
options.multiline = ml == "true"
} }
if willConvert { inline := fieldBoolTag(fieldType, "inline")
t.pushTable(k, f, options)
} else { if inline || !willConvert {
t.pushKV(k, f, options) t.pushKV(k, f, options)
} else {
t.pushTable(k, f, options)
} }
} }
return enc.encodeTable(b, ctx, t) return enc.encodeTable(b, ctx, t)
} }
func fieldBoolTag(field reflect.StructField, tag string) bool {
x, ok := field.Tag.Lookup(tag)
return ok && x == "true"
}
//nolint:cyclop
func (enc *Encoder) encodeTable(b []byte, ctx encoderCtx, t table) ([]byte, error) { func (enc *Encoder) encodeTable(b []byte, ctx encoderCtx, t table) ([]byte, error) {
var err error var err error
ctx.shiftKey() ctx.shiftKey()
if ctx.insideKv { if ctx.insideKv || (ctx.inline && !ctx.isRoot()) {
b = append(b, '{') return enc.encodeTableInline(b, ctx, t)
first := true
for _, kv := range t.kvs {
if first {
first = false
} else {
b = append(b, `, `...)
}
ctx.setKey(kv.Key)
b, err = enc.encodeKv(b, ctx, kv.Options, kv.Value)
if err != nil {
return nil, err
}
}
for _, table := range t.tables {
if first {
first = false
} else {
b = append(b, `, `...)
}
ctx.setKey(table.Key)
b, err = enc.encode(b, ctx, table.Value)
if err != nil {
return nil, err
}
b = append(b, '\n')
}
b = append(b, "}\n"...)
return b, nil
} }
if !ctx.skipTableHeader { if !ctx.skipTableHeader {
b, err = enc.encodeTableHeader(b, ctx.parentKey) b, err = enc.encodeTableHeader(ctx, b)
if err != nil { if err != nil {
return nil, err return nil, err
} }
if enc.indentTables && len(ctx.parentKey) > 0 {
ctx.indent++
}
} }
ctx.skipTableHeader = false ctx.skipTableHeader = false
for _, kv := range t.kvs { for _, kv := range t.kvs {
ctx.setKey(kv.Key) ctx.setKey(kv.Key)
b, err = enc.encodeKv(b, ctx, kv.Options, kv.Value) b, err = enc.encodeKv(b, ctx, kv.Options, kv.Value)
if err != nil { if err != nil {
return nil, err return nil, err
} }
b = append(b, '\n') b = append(b, '\n')
} }
for _, table := range t.tables { for _, table := range t.tables {
ctx.setKey(table.Key) ctx.setKey(table.Key)
ctx.options = table.Options
b, err = enc.encode(b, ctx, table.Value) b, err = enc.encode(b, ctx, table.Value)
if err != nil { if err != nil {
return nil, err return nil, err
} }
b = append(b, '\n') b = append(b, '\n')
} }
return b, nil return b, nil
} }
var errNilInterface = errors.New("nil interface not supported") func (enc *Encoder) encodeTableInline(b []byte, ctx encoderCtx, t table) ([]byte, error) {
var errNilPointer = errors.New("nil pointer not supported") var err error
func willConvertToTable(v reflect.Value) (bool, error) { b = append(b, '{')
switch v.Interface().(type) {
case time.Time: // TODO: add TextMarshaler first := true
for _, kv := range t.kvs {
if first {
first = false
} else {
b = append(b, `, `...)
}
ctx.setKey(kv.Key)
b, err = enc.encodeKv(b, ctx, kv.Options, kv.Value)
if err != nil {
return nil, err
}
}
for _, table := range t.tables {
if first {
first = false
} else {
b = append(b, `, `...)
}
ctx.setKey(table.Key)
b, err = enc.encode(b, ctx, table.Value)
if err != nil {
return nil, err
}
b = append(b, '\n')
}
b = append(b, "}"...)
return b, nil
}
var textMarshalerType = reflect.TypeOf(new(encoding.TextMarshaler)).Elem()
func willConvertToTable(ctx encoderCtx, v reflect.Value) (bool, error) {
if v.Type() == timeType || v.Type().Implements(textMarshalerType) {
return false, nil return false, nil
} }
t := v.Type() t := v.Type()
switch t.Kind() { switch t.Kind() {
case reflect.Map, reflect.Struct: case reflect.Map, reflect.Struct:
return true, nil return !ctx.inline, nil
case reflect.Interface: case reflect.Interface:
if v.IsNil() { if v.IsNil() {
return false, errNilInterface return false, fmt.Errorf("toml: encoding a nil interface is not supported")
} }
return willConvertToTable(v.Elem())
return willConvertToTable(ctx, v.Elem())
case reflect.Ptr: case reflect.Ptr:
if v.IsNil() { if v.IsNil() {
return false, nil return false, nil
} }
return willConvertToTable(v.Elem())
return willConvertToTable(ctx, v.Elem())
default: default:
return false, nil return false, nil
} }
} }
func willConvertToTableOrArrayTable(v reflect.Value) (bool, error) { func willConvertToTableOrArrayTable(ctx encoderCtx, v reflect.Value) (bool, error) {
t := v.Type() t := v.Type()
if t.Kind() == reflect.Interface { if t.Kind() == reflect.Interface {
if v.IsNil() { if v.IsNil() {
return false, errNilInterface return false, fmt.Errorf("toml: encoding a nil interface is not supported")
} }
return willConvertToTableOrArrayTable(v.Elem())
return willConvertToTableOrArrayTable(ctx, v.Elem())
} }
if t.Kind() == reflect.Slice { if t.Kind() == reflect.Slice {
@@ -572,28 +706,32 @@ func willConvertToTableOrArrayTable(v reflect.Value) (bool, error) {
// An empty slice should be a kv = []. // An empty slice should be a kv = [].
return false, nil return false, nil
} }
for i := 0; i < v.Len(); i++ { for i := 0; i < v.Len(); i++ {
t, err := willConvertToTable(v.Index(i)) t, err := willConvertToTable(ctx, v.Index(i))
if err != nil { if err != nil {
return false, err return false, err
} }
if !t { if !t {
return false, nil return false, nil
} }
} }
return true, nil return true, nil
} }
return willConvertToTable(v) return willConvertToTable(ctx, v)
} }
func (enc *Encoder) encodeSlice(b []byte, ctx encoderCtx, v reflect.Value) ([]byte, error) { func (enc *Encoder) encodeSlice(b []byte, ctx encoderCtx, v reflect.Value) ([]byte, error) {
if v.Len() == 0 { if v.Len() == 0 {
b = append(b, "[]"...) b = append(b, "[]"...)
return b, nil return b, nil
} }
allTables, err := willConvertToTableOrArrayTable(v) allTables, err := willConvertToTableOrArrayTable(ctx, v)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -617,45 +755,84 @@ func (enc *Encoder) encodeSliceAsArrayTable(b []byte, ctx encoderCtx, v reflect.
var err error var err error
scratch := make([]byte, 0, 64) scratch := make([]byte, 0, 64)
scratch = append(scratch, "[["...) scratch = append(scratch, "[["...)
for i, k := range ctx.parentKey { for i, k := range ctx.parentKey {
if i > 0 { if i > 0 {
scratch = append(scratch, '.') scratch = append(scratch, '.')
} }
scratch, err = enc.encodeKey(scratch, k) scratch, err = enc.encodeKey(scratch, k)
if err != nil { if err != nil {
return nil, err return nil, err
} }
} }
scratch = append(scratch, "]]\n"...) scratch = append(scratch, "]]\n"...)
ctx.skipTableHeader = true ctx.skipTableHeader = true
for i := 0; i < v.Len(); i++ { for i := 0; i < v.Len(); i++ {
b = append(b, scratch...) b = append(b, scratch...)
b, err = enc.encode(b, ctx, v.Index(i)) b, err = enc.encode(b, ctx, v.Index(i))
if err != nil { if err != nil {
return nil, err return nil, err
} }
} }
return b, nil return b, nil
} }
func (enc *Encoder) encodeSliceAsArray(b []byte, ctx encoderCtx, v reflect.Value) ([]byte, error) { func (enc *Encoder) encodeSliceAsArray(b []byte, ctx encoderCtx, v reflect.Value) ([]byte, error) {
multiline := ctx.options.multiline || enc.arraysMultiline
separator := ", "
b = append(b, '[') b = append(b, '[')
subCtx := ctx
subCtx.options = valueOptions{}
if multiline {
separator = ",\n"
b = append(b, '\n')
subCtx.indent++
}
var err error var err error
first := true first := true
for i := 0; i < v.Len(); i++ {
if !first {
b = append(b, ", "...)
}
first = false
b, err = enc.encode(b, ctx, v.Index(i)) for i := 0; i < v.Len(); i++ {
if first {
first = false
} else {
b = append(b, separator...)
}
if multiline {
b = enc.indent(subCtx.indent, b)
}
b, err = enc.encode(b, subCtx, v.Index(i))
if err != nil { if err != nil {
return nil, err return nil, err
} }
} }
if multiline {
b = append(b, '\n')
b = enc.indent(ctx.indent, b)
}
b = append(b, ']') b = append(b, ']')
return b, nil return b, nil
} }
func (enc *Encoder) indent(level int, b []byte) []byte {
for i := 0; i < level; i++ {
b = append(b, enc.indentSymbol...)
}
return b
}
+270 -7
View File
@@ -3,6 +3,7 @@ package toml_test
import ( import (
"bytes" "bytes"
"encoding/json" "encoding/json"
"fmt"
"strings" "strings"
"testing" "testing"
@@ -11,7 +12,10 @@ import (
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
) )
//nolint:funlen
func TestMarshal(t *testing.T) { func TestMarshal(t *testing.T) {
t.Parallel()
examples := []struct { examples := []struct {
desc string desc string
v interface{} v interface{}
@@ -65,8 +69,6 @@ hello = 'world'`,
a = 'test'`, a = 'test'`,
}, },
{ {
// TODO: this test is flaky because output changes depending on
// the map iteration order.
desc: "map in map in map and string with values", desc: "map in map in map and string with values",
v: map[string]interface{}{ v: map[string]interface{}{
"this": map[string]interface{}{ "this": map[string]interface{}{
@@ -89,6 +91,16 @@ a = 'test'`,
}, },
expected: `array = ['one', 'two', 'three']`, expected: `array = ['one', 'two', 'three']`,
}, },
{
desc: "empty string array",
v: map[string][]string{},
expected: ``,
},
{
desc: "map",
v: map[string][]string{},
expected: ``,
},
{ {
desc: "nested string arrays", desc: "nested string arrays",
v: map[string][][]string{ v: map[string][][]string{
@@ -104,7 +116,7 @@ a = 'test'`,
expected: `array = ['a string', ['one', 'two'], 'last']`, expected: `array = ['a string', ['one', 'two'], 'last']`,
}, },
{ {
desc: "slice of maps", desc: "array of maps",
v: map[string][]map[string]string{ v: map[string][]map[string]string{
"top": { "top": {
{"map1.1": "v1.1"}, {"map1.1": "v1.1"},
@@ -157,7 +169,7 @@ K2 = 'v2'
`, `,
}, },
{ {
desc: "structs in slice with interfaces", desc: "structs in array with interfaces",
v: map[string]interface{}{ v: map[string]interface{}{
"root": map[string]interface{}{ "root": map[string]interface{}{
"nested": []interface{}{ "nested": []interface{}{
@@ -234,17 +246,144 @@ name = 'Alice'
hello hello
world"""`, world"""`,
}, },
{
desc: "inline field",
v: struct {
A map[string]string `inline:"true"`
B map[string]string
}{
A: map[string]string{
"isinline": "yes",
},
B: map[string]string{
"isinline": "no",
},
},
expected: `
A = {isinline = 'yes'}
[B]
isinline = 'no'
`,
},
{
desc: "mutiline array int",
v: struct {
A []int `multiline:"true"`
B []int
}{
A: []int{1, 2, 3, 4},
B: []int{1, 2, 3, 4},
},
expected: `
A = [
1,
2,
3,
4
]
B = [1, 2, 3, 4]
`,
},
{
desc: "mutiline array in array",
v: struct {
A [][]int `multiline:"true"`
}{
A: [][]int{{1, 2}, {3, 4}},
},
expected: `
A = [
[1, 2],
[3, 4]
]
`,
},
} }
for _, e := range examples { for _, e := range examples {
e := e
t.Run(e.desc, func(t *testing.T) { t.Run(e.desc, func(t *testing.T) {
t.Parallel()
b, err := toml.Marshal(e.v) b, err := toml.Marshal(e.v)
if e.err { if e.err {
require.Error(t, err) require.Error(t, err)
} else {
require.NoError(t, err) return
equalStringsIgnoreNewlines(t, e.expected, string(b))
} }
require.NoError(t, err)
equalStringsIgnoreNewlines(t, e.expected, string(b))
// make sure the output is always valid TOML
defaultMap := map[string]interface{}{}
err = toml.Unmarshal(b, &defaultMap)
require.NoError(t, err)
testWithAllFlags(t, func(t *testing.T, flags int) {
t.Helper()
var buf bytes.Buffer
enc := toml.NewEncoder(&buf)
setFlags(enc, flags)
err := enc.Encode(e.v)
require.NoError(t, err)
inlineMap := map[string]interface{}{}
err = toml.Unmarshal(buf.Bytes(), &inlineMap)
require.NoError(t, err)
require.Equal(t, defaultMap, inlineMap)
})
})
}
}
type flagsSetters []struct {
name string
f func(enc *toml.Encoder, flag bool)
}
var allFlags = flagsSetters{
{"arrays-multiline", (*toml.Encoder).SetArraysMultiline},
{"tables-inline", (*toml.Encoder).SetTablesInline},
{"indent-tables", (*toml.Encoder).SetIndentTables},
}
func setFlags(enc *toml.Encoder, flags int) {
for i := 0; i < len(allFlags); i++ {
enabled := flags&1 > 0
allFlags[i].f(enc, enabled)
}
}
func testWithAllFlags(t *testing.T, testfn func(t *testing.T, flags int)) {
t.Helper()
testWithFlags(t, 0, allFlags, testfn)
}
func testWithFlags(t *testing.T, flags int, setters flagsSetters, testfn func(t *testing.T, flags int)) {
t.Helper()
if len(setters) == 0 {
testfn(t, flags)
return
}
s := setters[0]
for _, enabled := range []bool{false, true} {
name := fmt.Sprintf("%s=%t", s.name, enabled)
newFlags := flags << 1
if enabled {
newFlags++
}
t.Run(name, func(t *testing.T) {
testWithFlags(t, newFlags, setters[1:], testfn)
}) })
} }
} }
@@ -255,7 +394,75 @@ func equalStringsIgnoreNewlines(t *testing.T, expected string, actual string) {
assert.Equal(t, strings.Trim(expected, cutset), strings.Trim(actual, cutset)) assert.Equal(t, strings.Trim(expected, cutset), strings.Trim(actual, cutset))
} }
//nolint:funlen
func TestMarshalIndentTables(t *testing.T) {
t.Parallel()
examples := []struct {
desc string
v interface{}
expected string
}{
{
desc: "one kv",
v: map[string]interface{}{
"foo": "bar",
},
expected: `foo = 'bar'`,
},
{
desc: "one level table",
v: map[string]map[string]string{
"foo": {
"one": "value1",
"two": "value2",
},
},
expected: `
[foo]
one = 'value1'
two = 'value2'
`,
},
{
desc: "two levels table",
v: map[string]interface{}{
"root": "value0",
"level1": map[string]interface{}{
"one": "value1",
"level2": map[string]interface{}{
"two": "value2",
},
},
},
expected: `
root = 'value0'
[level1]
one = 'value1'
[level1.level2]
two = 'value2'
`,
},
}
for _, e := range examples {
e := e
t.Run(e.desc, func(t *testing.T) {
t.Parallel()
var buf strings.Builder
enc := toml.NewEncoder(&buf)
enc.SetIndentTables(true)
err := enc.Encode(e.v)
require.NoError(t, err)
equalStringsIgnoreNewlines(t, e.expected, buf.String())
})
}
}
func TestIssue436(t *testing.T) { func TestIssue436(t *testing.T) {
t.Parallel()
data := []byte(`{"a": [ { "b": { "c": "d" } } ]}`) data := []byte(`{"a": [ { "b": { "c": "d" } } ]}`)
var v interface{} var v interface{}
@@ -273,3 +480,59 @@ c = 'd'
` `
equalStringsIgnoreNewlines(t, expected, buf.String()) equalStringsIgnoreNewlines(t, expected, buf.String())
} }
func TestIssue424(t *testing.T) {
t.Parallel()
type Message1 struct {
Text string
}
type Message2 struct {
Text string `multiline:"true"`
}
msg1 := Message1{"Hello\\World"}
msg2 := Message2{"Hello\\World"}
toml1, err := toml.Marshal(msg1)
require.NoError(t, err)
toml2, err := toml.Marshal(msg2)
require.NoError(t, err)
msg1parsed := Message1{}
err = toml.Unmarshal(toml1, &msg1parsed)
require.NoError(t, err)
require.Equal(t, msg1, msg1parsed)
msg2parsed := Message2{}
err = toml.Unmarshal(toml2, &msg2parsed)
require.NoError(t, err)
require.Equal(t, msg2, msg2parsed)
}
func ExampleMarshal() {
type MyConfig struct {
Version int
Name string
Tags []string
}
cfg := MyConfig{
Version: 2,
Name: "go-toml",
Tags: []string{"go", "toml"},
}
b, err := toml.Marshal(cfg)
if err != nil {
panic(err)
}
fmt.Println(string(b))
// Output:
// Version = 2
// Name = 'go-toml'
// Tags = ['go', 'toml']
}
+199 -117
View File
@@ -26,6 +26,7 @@ func (p *parser) Reset(b []byte) {
p.first = true p.first = true
} }
//nolint:cyclop
func (p *parser) NextExpression() bool { func (p *parser) NextExpression() bool {
if len(p.left) == 0 || p.err != nil { if len(p.left) == 0 || p.err != nil {
return false return false
@@ -53,11 +54,11 @@ func (p *parser) NextExpression() bool {
return false return false
} }
p.first = false
if p.ref.Valid() { if p.ref.Valid() {
return true return true
} }
p.first = false
} }
} }
@@ -73,18 +74,20 @@ func (p *parser) parseNewline(b []byte) ([]byte, error) {
if b[0] == '\n' { if b[0] == '\n' {
return b[1:], nil return b[1:], nil
} }
if b[0] == '\r' { if b[0] == '\r' {
_, rest, err := scanWindowsNewline(b) _, rest, err := scanWindowsNewline(b)
return rest, err return rest, err
} }
return nil, fmt.Errorf("expected newline but got %#U", b[0])
return nil, newDecodeError(b[0:1], "expected newline but got %#U", b[0])
} }
func (p *parser) parseExpression(b []byte) (ast.Reference, []byte, error) { func (p *parser) parseExpression(b []byte) (ast.Reference, []byte, error) {
//expression = ws [ comment ] // expression = ws [ comment ]
//expression =/ ws keyval ws [ comment ] // expression =/ ws keyval ws [ comment ]
//expression =/ ws table ws [ comment ] // expression =/ ws table ws [ comment ]
var ref ast.Reference var ref ast.Reference
b = p.parseWhitespace(b) b = p.parseWhitespace(b)
@@ -94,9 +97,11 @@ func (p *parser) parseExpression(b []byte) (ast.Reference, []byte, error) {
} }
if b[0] == '#' { if b[0] == '#' {
_, rest, err := scanComment(b) _, rest := scanComment(b)
return ref, rest, err
return ref, rest, nil
} }
if b[0] == '\n' || b[0] == '\r' { if b[0] == '\n' || b[0] == '\r' {
return ref, b, nil return ref, b, nil
} }
@@ -107,6 +112,7 @@ func (p *parser) parseExpression(b []byte) (ast.Reference, []byte, error) {
} else { } else {
ref, b, err = p.parseKeyval(b) ref, b, err = p.parseKeyval(b)
} }
if err != nil { if err != nil {
return ref, nil, err return ref, nil, err
} }
@@ -114,57 +120,63 @@ func (p *parser) parseExpression(b []byte) (ast.Reference, []byte, error) {
b = p.parseWhitespace(b) b = p.parseWhitespace(b)
if len(b) > 0 && b[0] == '#' { if len(b) > 0 && b[0] == '#' {
_, rest, err := scanComment(b) _, rest := scanComment(b)
return ref, rest, err
return ref, rest, nil
} }
return ref, b, nil return ref, b, nil
} }
func (p *parser) parseTable(b []byte) (ast.Reference, []byte, error) { func (p *parser) parseTable(b []byte) (ast.Reference, []byte, error) {
//table = std-table / array-table // table = std-table / array-table
if len(b) > 1 && b[1] == '[' { if len(b) > 1 && b[1] == '[' {
return p.parseArrayTable(b) return p.parseArrayTable(b)
} }
return p.parseStdTable(b) return p.parseStdTable(b)
} }
func (p *parser) parseArrayTable(b []byte) (ast.Reference, []byte, error) { func (p *parser) parseArrayTable(b []byte) (ast.Reference, []byte, error) {
//array-table = array-table-open key array-table-close // array-table = array-table-open key array-table-close
//array-table-open = %x5B.5B ws ; [[ Double left square bracket // array-table-open = %x5B.5B ws ; [[ Double left square bracket
//array-table-close = ws %x5D.5D ; ]] Double right square bracket // array-table-close = ws %x5D.5D ; ]] Double right square bracket
ref := p.builder.Push(ast.Node{ ref := p.builder.Push(ast.Node{
Kind: ast.ArrayTable, Kind: ast.ArrayTable,
}) })
b = b[2:] b = b[2:]
b = p.parseWhitespace(b) b = p.parseWhitespace(b)
k, b, err := p.parseKey(b) k, b, err := p.parseKey(b)
if err != nil { if err != nil {
return ref, nil, err return ref, nil, err
} }
p.builder.AttachChild(ref, k) p.builder.AttachChild(ref, k)
b = p.parseWhitespace(b) b = p.parseWhitespace(b)
b, err = expect(']', b) b, err = expect(']', b)
if err != nil { if err != nil {
return ref, nil, err return ref, nil, err
} }
b, err = expect(']', b) b, err = expect(']', b)
return ref, b, err return ref, b, err
} }
func (p *parser) parseStdTable(b []byte) (ast.Reference, []byte, error) { func (p *parser) parseStdTable(b []byte) (ast.Reference, []byte, error) {
//std-table = std-table-open key std-table-close // std-table = std-table-open key std-table-close
//std-table-open = %x5B ws ; [ Left square bracket // std-table-open = %x5B ws ; [ Left square bracket
//std-table-close = ws %x5D ; ] Right square bracket // std-table-close = ws %x5D ; ] Right square bracket
ref := p.builder.Push(ast.Node{ ref := p.builder.Push(ast.Node{
Kind: ast.Table, Kind: ast.Table,
}) })
b = b[1:] b = b[1:]
b = p.parseWhitespace(b) b = p.parseWhitespace(b)
key, b, err := p.parseKey(b) key, b, err := p.parseKey(b)
if err != nil { if err != nil {
return ref, nil, err return ref, nil, err
@@ -180,8 +192,7 @@ func (p *parser) parseStdTable(b []byte) (ast.Reference, []byte, error) {
} }
func (p *parser) parseKeyval(b []byte) (ast.Reference, []byte, error) { func (p *parser) parseKeyval(b []byte) (ast.Reference, []byte, error) {
//keyval = key keyval-sep val // keyval = key keyval-sep val
ref := p.builder.Push(ast.Node{ ref := p.builder.Push(ast.Node{
Kind: ast.KeyValue, Kind: ast.KeyValue,
}) })
@@ -191,31 +202,35 @@ func (p *parser) parseKeyval(b []byte) (ast.Reference, []byte, error) {
return ast.Reference{}, nil, err return ast.Reference{}, nil, err
} }
//keyval-sep = ws %x3D ws ; = // keyval-sep = ws %x3D ws ; =
b = p.parseWhitespace(b) b = p.parseWhitespace(b)
b, err = expect('=', b) b, err = expect('=', b)
if err != nil { if err != nil {
return ast.Reference{}, nil, err return ast.Reference{}, nil, err
} }
b = p.parseWhitespace(b) b = p.parseWhitespace(b)
valRef, b, err := p.parseVal(b) valRef, b, err := p.parseVal(b)
if err != nil { if err != nil {
return ref, b, err return ref, b, err
} }
p.builder.Chain(valRef, key) p.builder.Chain(valRef, key)
p.builder.AttachChild(ref, valRef) p.builder.AttachChild(ref, valRef)
return ref, b, err return ref, b, err
} }
//nolint:cyclop,funlen
func (p *parser) parseVal(b []byte) (ast.Reference, []byte, error) { func (p *parser) parseVal(b []byte) (ast.Reference, []byte, error) {
// val = string / boolean / array / inline-table / date-time / float / integer // val = string / boolean / array / inline-table / date-time / float / integer
var ref ast.Reference var ref ast.Reference
if len(b) == 0 { if len(b) == 0 {
return ref, nil, fmt.Errorf("expected value, not eof") return ref, nil, newDecodeError(b, "expected value, not eof")
} }
var err error var err error
@@ -229,12 +244,14 @@ func (p *parser) parseVal(b []byte) (ast.Reference, []byte, error) {
} else { } else {
v, b, err = p.parseBasicString(b) v, b, err = p.parseBasicString(b)
} }
if err == nil { if err == nil {
ref = p.builder.Push(ast.Node{ ref = p.builder.Push(ast.Node{
Kind: ast.String, Kind: ast.String,
Data: v, Data: v,
}) })
} }
return ref, b, err return ref, b, err
case '\'': case '\'':
var v []byte var v []byte
@@ -243,30 +260,36 @@ func (p *parser) parseVal(b []byte) (ast.Reference, []byte, error) {
} else { } else {
v, b, err = p.parseLiteralString(b) v, b, err = p.parseLiteralString(b)
} }
if err == nil { if err == nil {
ref = p.builder.Push(ast.Node{ ref = p.builder.Push(ast.Node{
Kind: ast.String, Kind: ast.String,
Data: v, Data: v,
}) })
} }
return ref, b, err return ref, b, err
case 't': case 't':
if !scanFollowsTrue(b) { if !scanFollowsTrue(b) {
return ref, nil, fmt.Errorf("expected 'true'") return ref, nil, newDecodeError(atmost(b, 4), "expected 'true'")
} }
ref = p.builder.Push(ast.Node{ ref = p.builder.Push(ast.Node{
Kind: ast.Bool, Kind: ast.Bool,
Data: b[:4], Data: b[:4],
}) })
return ref, b[4:], nil return ref, b[4:], nil
case 'f': case 'f':
if !scanFollowsFalse(b) { if !scanFollowsFalse(b) {
return ast.Reference{}, nil, fmt.Errorf("expected 'false'") return ref, nil, newDecodeError(atmost(b, 5), "expected 'false'")
} }
ref = p.builder.Push(ast.Node{ ref = p.builder.Push(ast.Node{
Kind: ast.Bool, Kind: ast.Bool,
Data: b[:5], Data: b[:5],
}) })
return ref, b[5:], nil return ref, b[5:], nil
case '[': case '[':
return p.parseValArray(b) return p.parseValArray(b)
@@ -277,31 +300,40 @@ func (p *parser) parseVal(b []byte) (ast.Reference, []byte, error) {
} }
} }
func atmost(b []byte, n int) []byte {
if n >= len(b) {
return b
}
return b[:n]
}
func (p *parser) parseLiteralString(b []byte) ([]byte, []byte, error) { func (p *parser) parseLiteralString(b []byte) ([]byte, []byte, error) {
v, rest, err := scanLiteralString(b) v, rest, err := scanLiteralString(b)
if err != nil { if err != nil {
return nil, nil, err return nil, nil, err
} }
return v[1 : len(v)-1], rest, nil return v[1 : len(v)-1], rest, nil
} }
func (p *parser) parseInlineTable(b []byte) (ast.Reference, []byte, error) { func (p *parser) parseInlineTable(b []byte) (ast.Reference, []byte, error) {
//inline-table = inline-table-open [ inline-table-keyvals ] inline-table-close // inline-table = inline-table-open [ inline-table-keyvals ] inline-table-close
//inline-table-open = %x7B ws ; { // inline-table-open = %x7B ws ; {
//inline-table-close = ws %x7D ; } // inline-table-close = ws %x7D ; }
//inline-table-sep = ws %x2C ws ; , Comma // inline-table-sep = ws %x2C ws ; , Comma
//inline-table-keyvals = keyval [ inline-table-sep inline-table-keyvals ] // inline-table-keyvals = keyval [ inline-table-sep inline-table-keyvals ]
parent := p.builder.Push(ast.Node{ parent := p.builder.Push(ast.Node{
Kind: ast.InlineTable, Kind: ast.InlineTable,
}) })
first := true first := true
var child ast.Reference var child ast.Reference
b = b[1:] b = b[1:]
var err error var err error
for len(b) > 0 { for len(b) > 0 {
b = p.parseWhitespace(b) b = p.parseWhitespace(b)
if b[0] == '}' { if b[0] == '}' {
@@ -315,7 +347,9 @@ func (p *parser) parseInlineTable(b []byte) (ast.Reference, []byte, error) {
} }
b = p.parseWhitespace(b) b = p.parseWhitespace(b)
} }
var kv ast.Reference var kv ast.Reference
kv, b, err = p.parseKeyval(b) kv, b, err = p.parseKeyval(b)
if err != nil { if err != nil {
return parent, nil, err return parent, nil, err
@@ -323,7 +357,6 @@ func (p *parser) parseInlineTable(b []byte) (ast.Reference, []byte, error) {
if first { if first {
p.builder.AttachChild(parent, kv) p.builder.AttachChild(parent, kv)
first = false
} else { } else {
p.builder.Chain(child, kv) p.builder.Chain(child, kv)
} }
@@ -333,18 +366,19 @@ func (p *parser) parseInlineTable(b []byte) (ast.Reference, []byte, error) {
} }
rest, err := expect('}', b) rest, err := expect('}', b)
return parent, rest, err return parent, rest, err
} }
//nolint:funlen,cyclop
func (p *parser) parseValArray(b []byte) (ast.Reference, []byte, error) { func (p *parser) parseValArray(b []byte) (ast.Reference, []byte, error) {
//array = array-open [ array-values ] ws-comment-newline array-close // array = array-open [ array-values ] ws-comment-newline array-close
//array-open = %x5B ; [ // array-open = %x5B ; [
//array-close = %x5D ; ] // array-close = %x5D ; ]
//array-values = ws-comment-newline val ws-comment-newline array-sep array-values // array-values = ws-comment-newline val ws-comment-newline array-sep array-values
//array-values =/ ws-comment-newline val ws-comment-newline [ array-sep ] // array-values =/ ws-comment-newline val ws-comment-newline [ array-sep ]
//array-sep = %x2C ; , Comma // array-sep = %x2C ; , Comma
//ws-comment-newline = *( wschar / [ comment ] newline ) // ws-comment-newline = *( wschar / [ comment ] newline )
b = b[1:] b = b[1:]
parent := p.builder.Push(ast.Node{ parent := p.builder.Push(ast.Node{
@@ -352,6 +386,7 @@ func (p *parser) parseValArray(b []byte) (ast.Reference, []byte, error) {
}) })
first := true first := true
var lastChild ast.Reference var lastChild ast.Reference
var err error var err error
@@ -362,17 +397,20 @@ func (p *parser) parseValArray(b []byte) (ast.Reference, []byte, error) {
} }
if len(b) == 0 { if len(b) == 0 {
//nolint:godox
return parent, nil, unexpectedCharacter{b: b} // TODO: should be unexpected EOF return parent, nil, unexpectedCharacter{b: b} // TODO: should be unexpected EOF
} }
if b[0] == ']' { if b[0] == ']' {
break break
} }
if b[0] == ',' { if b[0] == ',' {
if first { if first {
return parent, nil, fmt.Errorf("array cannot start with comma") return parent, nil, newDecodeError(b[0:1], "array cannot start with comma")
} }
b = b[1:] b = b[1:]
b, err = p.parseOptionalWhitespaceCommentNewline(b) b, err = p.parseOptionalWhitespaceCommentNewline(b)
if err != nil { if err != nil {
return parent, nil, err return parent, nil, err
@@ -385,6 +423,7 @@ func (p *parser) parseValArray(b []byte) (ast.Reference, []byte, error) {
} }
var valueRef ast.Reference var valueRef ast.Reference
valueRef, b, err = p.parseVal(b) valueRef, b, err = p.parseVal(b)
if err != nil { if err != nil {
return parent, nil, err return parent, nil, err
@@ -392,7 +431,6 @@ func (p *parser) parseValArray(b []byte) (ast.Reference, []byte, error) {
if first { if first {
p.builder.AttachChild(parent, valueRef) p.builder.AttachChild(parent, valueRef)
first = false
} else { } else {
p.builder.Chain(lastChild, valueRef) p.builder.Chain(lastChild, valueRef)
} }
@@ -406,6 +444,7 @@ func (p *parser) parseValArray(b []byte) (ast.Reference, []byte, error) {
} }
rest, err := expect(']', b) rest, err := expect(']', b)
return parent, rest, err return parent, rest, err
} }
@@ -413,15 +452,15 @@ func (p *parser) parseOptionalWhitespaceCommentNewline(b []byte) ([]byte, error)
for len(b) > 0 { for len(b) > 0 {
var err error var err error
b = p.parseWhitespace(b) b = p.parseWhitespace(b)
if len(b) > 0 && b[0] == '#' { if len(b) > 0 && b[0] == '#' {
_, b, err = scanComment(b) _, b = scanComment(b)
if err != nil {
return nil, err
}
} }
if len(b) == 0 { if len(b) == 0 {
break break
} }
if b[0] == '\n' || b[0] == '\r' { if b[0] == '\n' || b[0] == '\r' {
b, err = p.parseNewline(b) b, err = p.parseNewline(b)
if err != nil { if err != nil {
@@ -431,6 +470,7 @@ func (p *parser) parseOptionalWhitespaceCommentNewline(b []byte) ([]byte, error)
break break
} }
} }
return b, nil return b, nil
} }
@@ -448,25 +488,27 @@ func (p *parser) parseMultilineLiteralString(b []byte) ([]byte, []byte, error) {
} else if token[i] == '\r' && token[i+1] == '\n' { } else if token[i] == '\r' && token[i+1] == '\n' {
i += 2 i += 2
} }
return token[i : len(token)-3], rest, err return token[i : len(token)-3], rest, err
} }
//nolint:funlen,gocognit,cyclop
func (p *parser) parseMultilineBasicString(b []byte) ([]byte, []byte, error) { func (p *parser) parseMultilineBasicString(b []byte) ([]byte, []byte, error) {
//ml-basic-string = ml-basic-string-delim [ newline ] ml-basic-body // ml-basic-string = ml-basic-string-delim [ newline ] ml-basic-body
//ml-basic-string-delim // ml-basic-string-delim
//ml-basic-string-delim = 3quotation-mark // ml-basic-string-delim = 3quotation-mark
//ml-basic-body = *mlb-content *( mlb-quotes 1*mlb-content ) [ mlb-quotes ] // ml-basic-body = *mlb-content *( mlb-quotes 1*mlb-content ) [ mlb-quotes ]
// //
//mlb-content = mlb-char / newline / mlb-escaped-nl // mlb-content = mlb-char / newline / mlb-escaped-nl
//mlb-char = mlb-unescaped / escaped // mlb-char = mlb-unescaped / escaped
//mlb-quotes = 1*2quotation-mark // mlb-quotes = 1*2quotation-mark
//mlb-unescaped = wschar / %x21 / %x23-5B / %x5D-7E / non-ascii // mlb-unescaped = wschar / %x21 / %x23-5B / %x5D-7E / non-ascii
//mlb-escaped-nl = escape ws newline *( wschar / newline ) // mlb-escaped-nl = escape ws newline *( wschar / newline )
token, rest, err := scanMultilineBasicString(b) token, rest, err := scanMultilineBasicString(b)
if err != nil { if err != nil {
return nil, nil, err return nil, nil, err
} }
var builder bytes.Buffer var builder bytes.Buffer
i := 3 i := 3
@@ -482,6 +524,8 @@ func (p *parser) parseMultilineBasicString(b []byte) ([]byte, []byte, error) {
// escapes are balanced. // escapes are balanced.
for ; i < len(token)-3; i++ { for ; i < len(token)-3; i++ {
c := token[i] c := token[i]
//nolint:nestif
if c == '\\' { if c == '\\' {
// When the last non-whitespace character on a line is an unescaped \, // When the last non-whitespace character on a line is an unescaped \,
// it will be trimmed along with all whitespace (including newlines) up // it will be trimmed along with all whitespace (including newlines) up
@@ -492,15 +536,18 @@ func (p *parser) parseMultilineBasicString(b []byte) ([]byte, []byte, error) {
c := token[i] c := token[i]
if !(c == '\n' || c == '\r' || c == ' ' || c == '\t') { if !(c == '\n' || c == '\r' || c == ' ' || c == '\t') {
i-- i--
break break
} }
} }
continue continue
} }
// handle escaping // handle escaping
i++ i++
c = token[i] c = token[i]
switch c { switch c {
case '"', '\\': case '"', '\\':
builder.WriteByte(c) builder.WriteByte(c)
@@ -519,6 +566,7 @@ func (p *parser) parseMultilineBasicString(b []byte) ([]byte, []byte, error) {
if err != nil { if err != nil {
return nil, nil, err return nil, nil, err
} }
builder.WriteString(x) builder.WriteString(x)
i += 4 i += 4
case 'U': case 'U':
@@ -526,10 +574,11 @@ func (p *parser) parseMultilineBasicString(b []byte) ([]byte, []byte, error) {
if err != nil { if err != nil {
return nil, nil, err return nil, nil, err
} }
builder.WriteString(x) builder.WriteString(x)
i += 8 i += 8
default: default:
return nil, nil, fmt.Errorf("invalid escaped character: %#U", c) return nil, nil, newDecodeError(token[i:i+1], "invalid escaped character %#U", c)
} }
} else { } else {
builder.WriteByte(c) builder.WriteByte(c)
@@ -540,15 +589,14 @@ func (p *parser) parseMultilineBasicString(b []byte) ([]byte, []byte, error) {
} }
func (p *parser) parseKey(b []byte) (ast.Reference, []byte, error) { func (p *parser) parseKey(b []byte) (ast.Reference, []byte, error) {
//key = simple-key / dotted-key // key = simple-key / dotted-key
//simple-key = quoted-key / unquoted-key // simple-key = quoted-key / unquoted-key
// //
//unquoted-key = 1*( ALPHA / DIGIT / %x2D / %x5F ) ; A-Z / a-z / 0-9 / - / _ // unquoted-key = 1*( ALPHA / DIGIT / %x2D / %x5F ) ; A-Z / a-z / 0-9 / - / _
//quoted-key = basic-string / literal-string // quoted-key = basic-string / literal-string
//dotted-key = simple-key 1*( dot-sep simple-key ) // dotted-key = simple-key 1*( dot-sep simple-key )
// //
//dot-sep = ws %x2E ws ; . Period // dot-sep = ws %x2E ws ; . Period
key, b, err := p.parseSimpleKey(b) key, b, err := p.parseSimpleKey(b)
if err != nil { if err != nil {
return ast.Reference{}, nil, err return ast.Reference{}, nil, err
@@ -566,11 +614,14 @@ func (p *parser) parseKey(b []byte) (ast.Reference, []byte, error) {
if err != nil { if err != nil {
return ref, nil, err return ref, nil, err
} }
b = p.parseWhitespace(b) b = p.parseWhitespace(b)
key, b, err = p.parseSimpleKey(b) key, b, err = p.parseSimpleKey(b)
if err != nil { if err != nil {
return ref, nil, err return ref, nil, err
} }
p.builder.PushAndChain(ast.Node{ p.builder.PushAndChain(ast.Node{
Kind: ast.Key, Kind: ast.Key,
Data: key, Data: key,
@@ -584,46 +635,48 @@ func (p *parser) parseKey(b []byte) (ast.Reference, []byte, error) {
} }
func (p *parser) parseSimpleKey(b []byte) (key, rest []byte, err error) { func (p *parser) parseSimpleKey(b []byte) (key, rest []byte, err error) {
//simple-key = quoted-key / unquoted-key // simple-key = quoted-key / unquoted-key
//unquoted-key = 1*( ALPHA / DIGIT / %x2D / %x5F ) ; A-Z / a-z / 0-9 / - / _ // unquoted-key = 1*( ALPHA / DIGIT / %x2D / %x5F ) ; A-Z / a-z / 0-9 / - / _
//quoted-key = basic-string / literal-string // quoted-key = basic-string / literal-string
if len(b) == 0 { if len(b) == 0 {
//nolint:godox
return nil, nil, unexpectedCharacter{b: b} // TODO: should be unexpected EOF return nil, nil, unexpectedCharacter{b: b} // TODO: should be unexpected EOF
} }
if b[0] == '\'' { switch {
key, rest, err = p.parseLiteralString(b) case b[0] == '\'':
} else if b[0] == '"' { return p.parseLiteralString(b)
key, rest, err = p.parseBasicString(b) case b[0] == '"':
} else if isUnquotedKeyChar(b[0]) { return p.parseBasicString(b)
key, rest, err = scanUnquotedKey(b) case isUnquotedKeyChar(b[0]):
} else { return scanUnquotedKey(b)
err = unexpectedCharacter{b: b} // TODO: should contain expected characters default:
//nolint:godox
return nil, nil, unexpectedCharacter{b: b} // TODO: should be unexpected EOF
} }
return
} }
//nolint:funlen,cyclop
func (p *parser) parseBasicString(b []byte) ([]byte, []byte, error) { func (p *parser) parseBasicString(b []byte) ([]byte, []byte, error) {
//basic-string = quotation-mark *basic-char quotation-mark // basic-string = quotation-mark *basic-char quotation-mark
//quotation-mark = %x22 ; " // quotation-mark = %x22 ; "
//basic-char = basic-unescaped / escaped // basic-char = basic-unescaped / escaped
//basic-unescaped = wschar / %x21 / %x23-5B / %x5D-7E / non-ascii // basic-unescaped = wschar / %x21 / %x23-5B / %x5D-7E / non-ascii
//escaped = escape escape-seq-char // escaped = escape escape-seq-char
//escape-seq-char = %x22 ; " quotation mark U+0022 // escape-seq-char = %x22 ; " quotation mark U+0022
//escape-seq-char =/ %x5C ; \ reverse solidus U+005C // escape-seq-char =/ %x5C ; \ reverse solidus U+005C
//escape-seq-char =/ %x62 ; b backspace U+0008 // escape-seq-char =/ %x62 ; b backspace U+0008
//escape-seq-char =/ %x66 ; f form feed U+000C // escape-seq-char =/ %x66 ; f form feed U+000C
//escape-seq-char =/ %x6E ; n line feed U+000A // escape-seq-char =/ %x6E ; n line feed U+000A
//escape-seq-char =/ %x72 ; r carriage return U+000D // escape-seq-char =/ %x72 ; r carriage return U+000D
//escape-seq-char =/ %x74 ; t tab U+0009 // escape-seq-char =/ %x74 ; t tab U+0009
//escape-seq-char =/ %x75 4HEXDIG ; uXXXX U+XXXX // escape-seq-char =/ %x75 4HEXDIG ; uXXXX U+XXXX
//escape-seq-char =/ %x55 8HEXDIG ; UXXXXXXXX U+XXXXXXXX // escape-seq-char =/ %x55 8HEXDIG ; UXXXXXXXX U+XXXXXXXX
token, rest, err := scanBasicString(b) token, rest, err := scanBasicString(b)
if err != nil { if err != nil {
return nil, nil, err return nil, nil, err
} }
var builder bytes.Buffer var builder bytes.Buffer
// The scanner ensures that the token starts and ends with quotes and that // The scanner ensures that the token starts and ends with quotes and that
@@ -633,6 +686,7 @@ func (p *parser) parseBasicString(b []byte) ([]byte, []byte, error) {
if c == '\\' { if c == '\\' {
i++ i++
c = token[i] c = token[i]
switch c { switch c {
case '"', '\\': case '"', '\\':
builder.WriteByte(c) builder.WriteByte(c)
@@ -651,6 +705,7 @@ func (p *parser) parseBasicString(b []byte) ([]byte, []byte, error) {
if err != nil { if err != nil {
return nil, nil, err return nil, nil, err
} }
builder.WriteString(x) builder.WriteString(x)
i += 4 i += 4
case 'U': case 'U':
@@ -658,10 +713,11 @@ func (p *parser) parseBasicString(b []byte) ([]byte, []byte, error) {
if err != nil { if err != nil {
return nil, nil, err return nil, nil, err
} }
builder.WriteString(x) builder.WriteString(x)
i += 8 i += 8
default: default:
return nil, nil, fmt.Errorf("invalid escaped character: %#U", c) return nil, nil, newDecodeError(token[i:i+1], "invalid escaped character %#U", c)
} }
} else { } else {
builder.WriteByte(c) builder.WriteByte(c)
@@ -673,39 +729,46 @@ func (p *parser) parseBasicString(b []byte) ([]byte, []byte, error) {
func hexToString(b []byte, length int) (string, error) { func hexToString(b []byte, length int) (string, error) {
if len(b) < length { if len(b) < length {
return "", fmt.Errorf("unicode point needs %d hex characters", length) return "", newDecodeError(b, "unicode point needs %d character, not %d", length, len(b))
} }
b = b[:length]
//nolint:godox
// TODO: slow // TODO: slow
intcode, err := strconv.ParseInt(string(b[:length]), 16, 32) intcode, err := strconv.ParseInt(string(b), 16, 32)
if err != nil { if err != nil {
return "", err return "", newDecodeError(b, "couldn't parse hexadecimal number: %w", err)
} }
return string(rune(intcode)), nil return string(rune(intcode)), nil
} }
func (p *parser) parseWhitespace(b []byte) []byte { func (p *parser) parseWhitespace(b []byte) []byte {
//ws = *wschar // ws = *wschar
//wschar = %x20 ; Space // wschar = %x20 ; Space
//wschar =/ %x09 ; Horizontal tab // wschar =/ %x09 ; Horizontal tab
_, rest := scanWhitespace(b) _, rest := scanWhitespace(b)
return rest return rest
} }
//nolint:cyclop
func (p *parser) parseIntOrFloatOrDateTime(b []byte) (ast.Reference, []byte, error) { func (p *parser) parseIntOrFloatOrDateTime(b []byte) (ast.Reference, []byte, error) {
switch b[0] { switch b[0] {
case 'i': case 'i':
if !scanFollowsInf(b) { if !scanFollowsInf(b) {
return ast.Reference{}, nil, fmt.Errorf("expected 'inf'") return ast.Reference{}, nil, newDecodeError(atmost(b, 3), "expected 'inf'")
} }
return p.builder.Push(ast.Node{ return p.builder.Push(ast.Node{
Kind: ast.Float, Kind: ast.Float,
Data: b[:3], Data: b[:3],
}), b[3:], nil }), b[3:], nil
case 'n': case 'n':
if !scanFollowsNan(b) { if !scanFollowsNan(b) {
return ast.Reference{}, nil, fmt.Errorf("expected 'nan'") return ast.Reference{}, nil, newDecodeError(atmost(b, 3), "expected 'nan'")
} }
return p.builder.Push(ast.Node{ return p.builder.Push(ast.Node{
Kind: ast.Float, Kind: ast.Float,
Data: b[:3], Data: b[:3],
@@ -714,60 +777,71 @@ func (p *parser) parseIntOrFloatOrDateTime(b []byte) (ast.Reference, []byte, err
return p.scanIntOrFloat(b) return p.scanIntOrFloat(b)
} }
//nolint:gomnd
if len(b) < 3 { if len(b) < 3 {
return p.scanIntOrFloat(b) return p.scanIntOrFloat(b)
} }
s := 5 s := 5
if len(b) < s { if len(b) < s {
s = len(b) s = len(b)
} }
for idx, c := range b[:s] { for idx, c := range b[:s] {
if isDigit(c) { if isDigit(c) {
continue continue
} }
if idx == 2 && c == ':' || (idx == 4 && c == '-') { if idx == 2 && c == ':' || (idx == 4 && c == '-') {
return p.scanDateTime(b) return p.scanDateTime(b)
} }
} }
return p.scanIntOrFloat(b) return p.scanIntOrFloat(b)
} }
func digitsToInt(b []byte) int { func digitsToInt(b []byte) int {
x := 0 x := 0
for _, d := range b { for _, d := range b {
x *= 10 x *= 10
x += int(d - '0') x += int(d - '0')
} }
return x return x
} }
//nolint:gocognit,cyclop
func (p *parser) scanDateTime(b []byte) (ast.Reference, []byte, error) { func (p *parser) scanDateTime(b []byte) (ast.Reference, []byte, error) {
// scans for contiguous characters in [0-9T:Z.+-], and up to one space if // scans for contiguous characters in [0-9T:Z.+-], and up to one space if
// followed by a digit. // followed by a digit.
hasTime := false hasTime := false
hasTz := false hasTz := false
seenSpace := false seenSpace := false
i := 0 i := 0
byteLoop:
for ; i < len(b); i++ { for ; i < len(b); i++ {
c := b[i] c := b[i]
if isDigit(c) || c == '-' {
} else if c == 'T' || c == ':' || c == '.' { switch {
case isDigit(c) || c == '-':
case c == 'T' || c == ':' || c == '.':
hasTime = true hasTime = true
continue
} else if c == '+' || c == '-' || c == 'Z' { continue byteLoop
case c == '+' || c == '-' || c == 'Z':
hasTz = true hasTz = true
} else if c == ' ' { case c == ' ':
if !seenSpace && i+1 < len(b) && isDigit(b[i+1]) { if !seenSpace && i+1 < len(b) && isDigit(b[i+1]) {
i += 2 i += 2
seenSpace = true seenSpace = true
hasTime = true hasTime = true
} else { } else {
break break byteLoop
} }
} else { default:
break break byteLoop
} }
} }
@@ -781,7 +855,7 @@ func (p *parser) scanDateTime(b []byte) (ast.Reference, []byte, error) {
} }
} else { } else {
if hasTz { if hasTz {
return ast.Reference{}, nil, fmt.Errorf("possible DateTime cannot have a timezone but no time component") return ast.Reference{}, nil, newDecodeError(b, "date-time has timezone but not time component")
} }
kind = ast.LocalDate kind = ast.LocalDate
} }
@@ -792,11 +866,13 @@ func (p *parser) scanDateTime(b []byte) (ast.Reference, []byte, error) {
}), b[i:], nil }), b[i:], nil
} }
//nolint:funlen,gocognit,cyclop
func (p *parser) scanIntOrFloat(b []byte) (ast.Reference, []byte, error) { func (p *parser) scanIntOrFloat(b []byte) (ast.Reference, []byte, error) {
i := 0 i := 0
if len(b) > 2 && b[0] == '0' && b[1] != '.' { if len(b) > 2 && b[0] == '0' && b[1] != '.' {
var isValidRune validRuneFn var isValidRune validRuneFn
switch b[1] { switch b[1] {
case 'x': case 'x':
isValidRune = isValidHexRune isValidRune = isValidHexRune
@@ -834,6 +910,7 @@ func (p *parser) scanIntOrFloat(b []byte) (ast.Reference, []byte, error) {
if c == '.' || c == 'e' || c == 'E' { if c == '.' || c == 'e' || c == 'E' {
isFloat = true isFloat = true
continue continue
} }
@@ -844,8 +921,10 @@ func (p *parser) scanIntOrFloat(b []byte) (ast.Reference, []byte, error) {
Data: b[:i+3], Data: b[:i+3],
}), b[i+3:], nil }), b[i+3:], nil
} }
return ast.Reference{}, nil, fmt.Errorf("unexpected character i while scanning for a number")
return ast.Reference{}, nil, newDecodeError(b[i:i+1], "unexpected character 'i' while scanning for a number")
} }
if c == 'n' { if c == 'n' {
if scanFollowsNan(b[i:]) { if scanFollowsNan(b[i:]) {
return p.builder.Push(ast.Node{ return p.builder.Push(ast.Node{
@@ -853,14 +932,15 @@ func (p *parser) scanIntOrFloat(b []byte) (ast.Reference, []byte, error) {
Data: b[:i+3], Data: b[:i+3],
}), b[i+3:], nil }), b[i+3:], nil
} }
return ast.Reference{}, nil, fmt.Errorf("unexpected character n while scanning for a number")
return ast.Reference{}, nil, newDecodeError(b[i:i+1], "unexpected character 'n' while scanning for a number")
} }
break break
} }
if i == 0 { if i == 0 {
return ast.Reference{}, b, fmt.Errorf("expected integer or float") return ast.Reference{}, b, newDecodeError(b, "incomplete number")
} }
kind := ast.Integer kind := ast.Integer
@@ -900,9 +980,11 @@ func expect(x byte, b []byte) ([]byte, error) {
if len(b) == 0 { if len(b) == 0 {
return nil, newDecodeError(b[:0], "expecting %#U", x) return nil, newDecodeError(b[:0], "expecting %#U", x)
} }
if b[0] != x { if b[0] != x {
return nil, newDecodeError(b[0:1], "expected character %U", x) return nil, newDecodeError(b[0:1], "expected character %U", x)
} }
return b[1:], nil return b[1:], nil
} }
@@ -914,7 +996,7 @@ type unexpectedCharacter struct {
func (u unexpectedCharacter) Error() string { func (u unexpectedCharacter) Error() string {
if len(u.b) == 0 { if len(u.b) == 0 {
return fmt.Sprintf("expected %#U, not EOF", u.r) return fmt.Sprintf("expected %#U, not EOF", u.r)
} }
return fmt.Sprintf("expected %#U, not %#U", u.r, u.b[0]) return fmt.Sprintf("expected %#U, not %#U", u.r, u.b[0])
} }
+19 -59
View File
@@ -7,7 +7,10 @@ import (
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
) )
//nolint:funlen
func TestParser_AST_Numbers(t *testing.T) { func TestParser_AST_Numbers(t *testing.T) {
t.Parallel()
examples := []struct { examples := []struct {
desc string desc string
input string input string
@@ -132,7 +135,9 @@ func TestParser_AST_Numbers(t *testing.T) {
} }
for _, e := range examples { for _, e := range examples {
e := e
t.Run(e.desc, func(t *testing.T) { t.Run(e.desc, func(t *testing.T) {
t.Parallel()
p := parser{} p := parser{}
p.Reset([]byte(`A = ` + e.input)) p.Reset([]byte(`A = ` + e.input))
p.NextExpression() p.NextExpression()
@@ -155,19 +160,16 @@ func TestParser_AST_Numbers(t *testing.T) {
} }
} }
type astRoot []astNode type (
type astNode struct { astNode struct {
Kind ast.Kind Kind ast.Kind
Data []byte Data []byte
Children []astNode Children []astNode
} }
)
func compareAST(t *testing.T, expected astRoot, actual *ast.Root) {
it := actual.Iterator()
compareIterator(t, expected, it)
}
func compareNode(t *testing.T, e astNode, n ast.Node) { func compareNode(t *testing.T, e astNode, n ast.Node) {
t.Helper()
require.Equal(t, e.Kind, n.Kind) require.Equal(t, e.Kind, n.Kind)
require.Equal(t, e.Data, n.Data) require.Equal(t, e.Data, n.Data)
@@ -175,6 +177,7 @@ func compareNode(t *testing.T, e astNode, n ast.Node) {
} }
func compareIterator(t *testing.T, expected []astNode, actual ast.Iterator) { func compareIterator(t *testing.T, expected []astNode, actual ast.Iterator) {
t.Helper()
idx := 0 idx := 0
for actual.Next() { for actual.Next() {
@@ -195,55 +198,10 @@ func compareIterator(t *testing.T, expected []astNode, actual ast.Iterator) {
} }
} }
func (r astRoot) toOrig() *ast.Root { //nolint:funlen
builder := &ast.Builder{}
var last ast.Reference
for i, n := range r {
ref := builder.Push(ast.Node{
Kind: n.Kind,
Data: n.Data,
})
if i > 0 {
builder.Chain(last, ref)
}
last = ref
if len(n.Children) > 0 {
c := childrenToOrig(builder, n.Children)
builder.AttachChild(ref, c)
}
}
return builder.Tree()
}
func childrenToOrig(b *ast.Builder, nodes []astNode) ast.Reference {
var first ast.Reference
var last ast.Reference
for i, n := range nodes {
ref := b.Push(ast.Node{
Kind: n.Kind,
Data: n.Data,
})
if i == 0 {
first = ref
} else {
b.Chain(last, ref)
}
last = ref
if len(n.Children) > 0 {
c := childrenToOrig(b, n.Children)
b.AttachChild(ref, c)
}
}
return first
}
func TestParser_AST(t *testing.T) { func TestParser_AST(t *testing.T) {
t.Parallel()
examples := []struct { examples := []struct {
desc string desc string
input string input string
@@ -380,7 +338,9 @@ func TestParser_AST(t *testing.T) {
} }
for _, e := range examples { for _, e := range examples {
e := e
t.Run(e.desc, func(t *testing.T) { t.Run(e.desc, func(t *testing.T) {
t.Parallel()
p := parser{} p := parser{}
p.Reset([]byte(e.input)) p.Reset([]byte(e.input))
p.NextExpression() p.NextExpression()
+74 -69
View File
@@ -1,35 +1,43 @@
package toml package toml
import "fmt" func scanFollows(b []byte, pattern string) bool {
n := len(pattern)
func scanFollows(pattern []byte) func(b []byte) bool { return len(b) >= n && string(b[:n]) == pattern
return func(b []byte) bool {
if len(b) < len(pattern) {
return false
}
for i, c := range pattern {
if b[i] != c {
return false
}
}
return true
}
} }
var scanFollowsMultilineBasicStringDelimiter = scanFollows([]byte{'"', '"', '"'}) func scanFollowsMultilineBasicStringDelimiter(b []byte) bool {
var scanFollowsMultilineLiteralStringDelimiter = scanFollows([]byte{'\'', '\'', '\''}) return scanFollows(b, `"""`)
var scanFollowsTrue = scanFollows([]byte{'t', 'r', 'u', 'e'}) }
var scanFollowsFalse = scanFollows([]byte{'f', 'a', 'l', 's', 'e'})
var scanFollowsInf = scanFollows([]byte{'i', 'n', 'f'}) func scanFollowsMultilineLiteralStringDelimiter(b []byte) bool {
var scanFollowsNan = scanFollows([]byte{'n', 'a', 'n'}) return scanFollows(b, `'''`)
}
func scanFollowsTrue(b []byte) bool {
return scanFollows(b, `true`)
}
func scanFollowsFalse(b []byte) bool {
return scanFollows(b, `false`)
}
func scanFollowsInf(b []byte) bool {
return scanFollows(b, `inf`)
}
func scanFollowsNan(b []byte) bool {
return scanFollows(b, `nan`)
}
func scanUnquotedKey(b []byte) ([]byte, []byte, error) { func scanUnquotedKey(b []byte) ([]byte, []byte, error) {
//unquoted-key = 1*( ALPHA / DIGIT / %x2D / %x5F ) ; A-Z / a-z / 0-9 / - / _ // unquoted-key = 1*( ALPHA / DIGIT / %x2D / %x5F ) ; A-Z / a-z / 0-9 / - / _
for i := 0; i < len(b); i++ { for i := 0; i < len(b); i++ {
if !isUnquotedKeyChar(b[i]) { if !isUnquotedKeyChar(b[i]) {
return b[:i], b[i:], nil return b[:i], b[i:], nil
} }
} }
return b, b[len(b):], nil return b, b[len(b):], nil
} }
@@ -38,9 +46,9 @@ func isUnquotedKeyChar(r byte) bool {
} }
func scanLiteralString(b []byte) ([]byte, []byte, error) { func scanLiteralString(b []byte) ([]byte, []byte, error) {
//literal-string = apostrophe *literal-char apostrophe // literal-string = apostrophe *literal-char apostrophe
//apostrophe = %x27 ; ' apostrophe // apostrophe = %x27 ; ' apostrophe
//literal-char = %x09 / %x20-26 / %x28-7E / non-ascii // literal-char = %x09 / %x20-26 / %x28-7E / non-ascii
for i := 1; i < len(b); i++ { for i := 1; i < len(b); i++ {
switch b[i] { switch b[i] {
case '\'': case '\'':
@@ -49,24 +57,22 @@ func scanLiteralString(b []byte) ([]byte, []byte, error) {
return nil, nil, newDecodeError(b[i:i+1], "literal strings cannot have new lines") return nil, nil, newDecodeError(b[i:i+1], "literal strings cannot have new lines")
} }
} }
return nil, nil, newDecodeError(b[len(b):], "unterminated literal string") return nil, nil, newDecodeError(b[len(b):], "unterminated literal string")
} }
func scanMultilineLiteralString(b []byte) ([]byte, []byte, error) { func scanMultilineLiteralString(b []byte) ([]byte, []byte, error) {
//ml-literal-string = ml-literal-string-delim [ newline ] ml-literal-body // ml-literal-string = ml-literal-string-delim [ newline ] ml-literal-body
//ml-literal-string-delim // ml-literal-string-delim
//ml-literal-string-delim = 3apostrophe // ml-literal-string-delim = 3apostrophe
//ml-literal-body = *mll-content *( mll-quotes 1*mll-content ) [ mll-quotes ] // ml-literal-body = *mll-content *( mll-quotes 1*mll-content ) [ mll-quotes ]
// //
//mll-content = mll-char / newline // mll-content = mll-char / newline
//mll-char = %x09 / %x20-26 / %x28-7E / non-ascii // mll-char = %x09 / %x20-26 / %x28-7E / non-ascii
//mll-quotes = 1*2apostrophe // mll-quotes = 1*2apostrophe
for i := 3; i < len(b); i++ { for i := 3; i < len(b); i++ {
switch b[i] { if b[i] == '\'' && scanFollowsMultilineLiteralStringDelimiter(b[i:]) {
case '\'': return b[:i+3], b[i+3:], nil
if scanFollowsMultilineLiteralStringDelimiter(b[i:]) {
return b[:i+3], b[i+3:], nil
}
} }
} }
@@ -74,13 +80,16 @@ func scanMultilineLiteralString(b []byte) ([]byte, []byte, error) {
} }
func scanWindowsNewline(b []byte) ([]byte, []byte, error) { func scanWindowsNewline(b []byte) ([]byte, []byte, error) {
if len(b) < 2 { const lenCRLF = 2
return nil, nil, fmt.Errorf(`windows new line missing \n`) if len(b) < lenCRLF {
return nil, nil, newDecodeError(b, "windows new line expected")
} }
if b[1] != '\n' { if b[1] != '\n' {
return nil, nil, fmt.Errorf(`windows new line should be \r\n`) return nil, nil, newDecodeError(b, `windows new line should be \r\n`)
} }
return b[:2], b[2:], nil
return b[:lenCRLF], b[lenCRLF:], nil
} }
func scanWhitespace(b []byte) ([]byte, []byte) { func scanWhitespace(b []byte) ([]byte, []byte) {
@@ -92,34 +101,32 @@ func scanWhitespace(b []byte) ([]byte, []byte) {
return b[:i], b[i:] return b[:i], b[i:]
} }
} }
return b, b[len(b):] return b, b[len(b):]
} }
func scanComment(b []byte) ([]byte, []byte, error) { //nolint:unparam
//;; Comment func scanComment(b []byte) ([]byte, []byte) {
// comment-start-symbol = %x23 ; #
// non-ascii = %x80-D7FF / %xE000-10FFFF
// non-eol = %x09 / %x20-7F / non-ascii
// //
//comment-start-symbol = %x23 ; # // comment = comment-start-symbol *non-eol
//non-ascii = %x80-D7FF / %xE000-10FFFF
//non-eol = %x09 / %x20-7F / non-ascii
//
//comment = comment-start-symbol *non-eol
for i := 1; i < len(b); i++ { for i := 1; i < len(b); i++ {
switch b[i] { if b[i] == '\n' {
case '\n': return b[:i], b[i:]
return b[:i], b[i:], nil
} }
} }
return b, nil, nil
return b, nil
} }
// TODO perform validation on the string?
func scanBasicString(b []byte) ([]byte, []byte, error) { func scanBasicString(b []byte) ([]byte, []byte, error) {
//basic-string = quotation-mark *basic-char quotation-mark // basic-string = quotation-mark *basic-char quotation-mark
//quotation-mark = %x22 ; " // quotation-mark = %x22 ; "
//basic-char = basic-unescaped / escaped // basic-char = basic-unescaped / escaped
//basic-unescaped = wschar / %x21 / %x23-5B / %x5D-7E / non-ascii // basic-unescaped = wschar / %x21 / %x23-5B / %x5D-7E / non-ascii
//escaped = escape escape-seq-char // escaped = escape escape-seq-char
for i := 1; i < len(b); i++ { for i := 1; i < len(b); i++ {
switch b[i] { switch b[i] {
case '"': case '"':
@@ -134,22 +141,20 @@ func scanBasicString(b []byte) ([]byte, []byte, error) {
} }
} }
return nil, nil, fmt.Errorf(`basic string not terminated by "`) return nil, nil, newDecodeError(b[len(b):], `basic string not terminated by "`)
} }
// TODO perform validation on the string?
func scanMultilineBasicString(b []byte) ([]byte, []byte, error) { func scanMultilineBasicString(b []byte) ([]byte, []byte, error) {
//ml-basic-string = ml-basic-string-delim [ newline ] ml-basic-body // ml-basic-string = ml-basic-string-delim [ newline ] ml-basic-body
//ml-basic-string-delim // ml-basic-string-delim
//ml-basic-string-delim = 3quotation-mark // ml-basic-string-delim = 3quotation-mark
//ml-basic-body = *mlb-content *( mlb-quotes 1*mlb-content ) [ mlb-quotes ] // ml-basic-body = *mlb-content *( mlb-quotes 1*mlb-content ) [ mlb-quotes ]
// //
//mlb-content = mlb-char / newline / mlb-escaped-nl // mlb-content = mlb-char / newline / mlb-escaped-nl
//mlb-char = mlb-unescaped / escaped // mlb-char = mlb-unescaped / escaped
//mlb-quotes = 1*2quotation-mark // mlb-quotes = 1*2quotation-mark
//mlb-unescaped = wschar / %x21 / %x23-5B / %x5D-7E / non-ascii // mlb-unescaped = wschar / %x21 / %x23-5B / %x5D-7E / non-ascii
//mlb-escaped-nl = escape ws newline *( wschar / newline ) // mlb-escaped-nl = escape ws newline *( wschar / newline )
for i := 3; i < len(b); i++ { for i := 3; i < len(b); i++ {
switch b[i] { switch b[i] {
case '"': case '"':
+88
View File
@@ -0,0 +1,88 @@
package toml
import (
"github.com/pelletier/go-toml/v2/internal/ast"
"github.com/pelletier/go-toml/v2/internal/tracker"
)
type strict struct {
Enabled bool
// Tracks the current key being processed.
key tracker.KeyTracker
missing []decodeError
}
func (s *strict) EnterTable(node ast.Node) {
if !s.Enabled {
return
}
s.key.UpdateTable(node)
}
func (s *strict) EnterArrayTable(node ast.Node) {
if !s.Enabled {
return
}
s.key.UpdateArrayTable(node)
}
func (s *strict) EnterKeyValue(node ast.Node) {
if !s.Enabled {
return
}
s.key.Push(node)
}
func (s *strict) ExitKeyValue(node ast.Node) {
if !s.Enabled {
return
}
s.key.Pop(node)
}
func (s *strict) MissingTable(node ast.Node) {
if !s.Enabled {
return
}
s.missing = append(s.missing, decodeError{
highlight: keyLocation(node),
message: "missing table",
key: s.key.Key(),
})
}
func (s *strict) MissingField(node ast.Node) {
if !s.Enabled {
return
}
s.missing = append(s.missing, decodeError{
highlight: keyLocation(node),
message: "missing field",
key: s.key.Key(),
})
}
func (s *strict) Error(doc []byte) error {
if !s.Enabled || len(s.missing) == 0 {
return nil
}
err := &StrictMissingError{
Errors: make([]DecodeError, 0, len(s.missing)),
}
for _, derr := range s.missing {
derr := derr
err.Errors = append(err.Errors, *wrapDecodeError(doc, &derr))
}
return err
}
+169 -183
View File
@@ -13,19 +13,19 @@ type target interface {
get() reflect.Value get() reflect.Value
// Store a string at the target. // Store a string at the target.
setString(v string) error setString(v string)
// Store a boolean at the target // Store a boolean at the target
setBool(v bool) error setBool(v bool)
// Store an int64 at the target // Store an int64 at the target
setInt64(v int64) error setInt64(v int64)
// Store a float64 at the target // Store a float64 at the target
setFloat64(v float64) error setFloat64(v float64)
// Stores any value at the target // Stores any value at the target
set(v reflect.Value) error set(v reflect.Value)
} }
// valueTarget just contains a reflect.Value that can be set. // valueTarget just contains a reflect.Value that can be set.
@@ -36,29 +36,24 @@ func (t valueTarget) get() reflect.Value {
return reflect.Value(t) return reflect.Value(t)
} }
func (t valueTarget) set(v reflect.Value) error { func (t valueTarget) set(v reflect.Value) {
reflect.Value(t).Set(v) reflect.Value(t).Set(v)
return nil
} }
func (t valueTarget) setString(v string) error { func (t valueTarget) setString(v string) {
t.get().SetString(v) t.get().SetString(v)
return nil
} }
func (t valueTarget) setBool(v bool) error { func (t valueTarget) setBool(v bool) {
t.get().SetBool(v) t.get().SetBool(v)
return nil
} }
func (t valueTarget) setInt64(v int64) error { func (t valueTarget) setInt64(v int64) {
t.get().SetInt(v) t.get().SetInt(v)
return nil
} }
func (t valueTarget) setFloat64(v float64) error { func (t valueTarget) setFloat64(v float64) {
t.get().SetFloat(v) t.get().SetFloat(v)
return nil
} }
// interfaceTarget wraps an other target to dereference on get. // interfaceTarget wraps an other target to dereference on get.
@@ -70,24 +65,24 @@ func (t interfaceTarget) get() reflect.Value {
return t.x.get().Elem() return t.x.get().Elem()
} }
func (t interfaceTarget) set(v reflect.Value) error { func (t interfaceTarget) set(v reflect.Value) {
return t.x.set(v) t.x.set(v)
} }
func (t interfaceTarget) setString(v string) error { func (t interfaceTarget) setString(v string) {
return t.x.setString(v) t.x.setString(v)
} }
func (t interfaceTarget) setBool(v bool) error { func (t interfaceTarget) setBool(v bool) {
return t.x.setBool(v) t.x.setBool(v)
} }
func (t interfaceTarget) setInt64(v int64) error { func (t interfaceTarget) setInt64(v int64) {
return t.x.setInt64(v) t.x.setInt64(v)
} }
func (t interfaceTarget) setFloat64(v float64) error { func (t interfaceTarget) setFloat64(v float64) {
return t.x.setFloat64(v) t.x.setFloat64(v)
} }
// mapTarget targets a specific key of a map. // mapTarget targets a specific key of a map.
@@ -100,27 +95,27 @@ func (t mapTarget) get() reflect.Value {
return t.v.MapIndex(t.k) return t.v.MapIndex(t.k)
} }
func (t mapTarget) set(v reflect.Value) error { func (t mapTarget) set(v reflect.Value) {
t.v.SetMapIndex(t.k, v) t.v.SetMapIndex(t.k, v)
return nil
} }
func (t mapTarget) setString(v string) error { func (t mapTarget) setString(v string) {
return t.set(reflect.ValueOf(v)) t.set(reflect.ValueOf(v))
} }
func (t mapTarget) setBool(v bool) error { func (t mapTarget) setBool(v bool) {
return t.set(reflect.ValueOf(v)) t.set(reflect.ValueOf(v))
} }
func (t mapTarget) setInt64(v int64) error { func (t mapTarget) setInt64(v int64) {
return t.set(reflect.ValueOf(v)) t.set(reflect.ValueOf(v))
} }
func (t mapTarget) setFloat64(v float64) error { func (t mapTarget) setFloat64(v float64) {
return t.set(reflect.ValueOf(v)) t.set(reflect.ValueOf(v))
} }
//nolint:cyclop
// makes sure that the value pointed at by t is indexable (Slice, Array), or // makes sure that the value pointed at by t is indexable (Slice, Array), or
// dereferences to an indexable (Ptr, Interface). // dereferences to an indexable (Ptr, Interface).
func ensureValueIndexable(t target) error { func ensureValueIndexable(t target) error {
@@ -129,40 +124,36 @@ func ensureValueIndexable(t target) error {
switch f.Type().Kind() { switch f.Type().Kind() {
case reflect.Slice: case reflect.Slice:
if f.IsNil() { if f.IsNil() {
return t.set(reflect.MakeSlice(f.Type(), 0, 0)) t.set(reflect.MakeSlice(f.Type(), 0, 0))
return nil
} }
case reflect.Interface: case reflect.Interface:
if f.IsNil() || f.Elem().Type() != sliceInterfaceType { if f.IsNil() || f.Elem().Type() != sliceInterfaceType {
return t.set(reflect.MakeSlice(sliceInterfaceType, 0, 0)) t.set(reflect.MakeSlice(sliceInterfaceType, 0, 0))
} return nil
if f.Elem().Type().Kind() != reflect.Slice {
return fmt.Errorf("interface is pointing to a %s, not a slice", f.Kind())
} }
case reflect.Ptr: case reflect.Ptr:
if f.IsNil() { panic("pointer should have already been dereferenced")
ptr := reflect.New(f.Type().Elem())
err := t.set(ptr)
if err != nil {
return err
}
f = t.get()
}
return ensureValueIndexable(valueTarget(f.Elem()))
case reflect.Array: case reflect.Array:
// arrays are always initialized. // arrays are always initialized.
default: default:
return fmt.Errorf("cannot initialize a slice in %s", f.Kind()) return fmt.Errorf("toml: cannot store array in a %s", f.Kind())
} }
return nil return nil
} }
var sliceInterfaceType = reflect.TypeOf([]interface{}{}) var (
var mapStringInterfaceType = reflect.TypeOf(map[string]interface{}{}) sliceInterfaceType = reflect.TypeOf([]interface{}{})
mapStringInterfaceType = reflect.TypeOf(map[string]interface{}{})
)
func ensureMapIfInterface(x target) { func ensureMapIfInterface(x target) {
v := x.get() v := x.get()
if v.Kind() == reflect.Interface && v.IsNil() { if v.Kind() == reflect.Interface && v.IsNil() {
newElement := reflect.MakeMap(mapStringInterfaceType) newElement := reflect.MakeMap(mapStringInterfaceType)
x.set(newElement) x.set(newElement)
} }
} }
@@ -172,12 +163,14 @@ func setString(t target, v string) error {
switch f.Kind() { switch f.Kind() {
case reflect.String: case reflect.String:
return t.setString(v) t.setString(v)
case reflect.Interface: case reflect.Interface:
return t.set(reflect.ValueOf(v)) t.set(reflect.ValueOf(v))
default: default:
return fmt.Errorf("cannot assign string to a %s", f.Kind()) return fmt.Errorf("toml: cannot assign string to a %s", f.Kind())
} }
return nil
} }
func setBool(t target, v bool) error { func setBool(t target, v bool) error {
@@ -185,83 +178,90 @@ func setBool(t target, v bool) error {
switch f.Kind() { switch f.Kind() {
case reflect.Bool: case reflect.Bool:
return t.setBool(v) t.setBool(v)
case reflect.Interface: case reflect.Interface:
return t.set(reflect.ValueOf(v)) t.set(reflect.ValueOf(v))
default: default:
return fmt.Errorf("cannot assign bool to a %s", f.String()) return fmt.Errorf("toml: cannot assign boolean to a %s", f.Kind())
} }
return nil
} }
const maxInt = int64(^uint(0) >> 1) const (
const minInt = -maxInt - 1 maxInt = int64(^uint(0) >> 1)
minInt = -maxInt - 1
)
//nolint:funlen,gocognit,cyclop,gocyclo
func setInt64(t target, v int64) error { func setInt64(t target, v int64) error {
f := t.get() f := t.get()
switch f.Kind() { switch f.Kind() {
case reflect.Int64: case reflect.Int64:
return t.setInt64(v) t.setInt64(v)
case reflect.Int32: case reflect.Int32:
if v < math.MinInt32 || v > math.MaxInt32 { if v < math.MinInt32 || v > math.MaxInt32 {
return fmt.Errorf("integer %d does not fit in an int32", v) return fmt.Errorf("toml: number %d does not fit in an int32", v)
} }
return t.set(reflect.ValueOf(int32(v)))
t.set(reflect.ValueOf(int32(v)))
return nil
case reflect.Int16: case reflect.Int16:
if v < math.MinInt16 || v > math.MaxInt16 { if v < math.MinInt16 || v > math.MaxInt16 {
return fmt.Errorf("integer %d does not fit in an int16", v) return fmt.Errorf("toml: number %d does not fit in an int16", v)
} }
return t.set(reflect.ValueOf(int16(v)))
t.set(reflect.ValueOf(int16(v)))
case reflect.Int8: case reflect.Int8:
if v < math.MinInt8 || v > math.MaxInt8 { if v < math.MinInt8 || v > math.MaxInt8 {
return fmt.Errorf("integer %d does not fit in an int8", v) return fmt.Errorf("toml: number %d does not fit in an int8", v)
} }
return t.set(reflect.ValueOf(int8(v)))
t.set(reflect.ValueOf(int8(v)))
case reflect.Int: case reflect.Int:
if v < minInt || v > maxInt { if v < minInt || v > maxInt {
return fmt.Errorf("integer %d does not fit in an int", v) return fmt.Errorf("toml: number %d does not fit in an int", v)
} }
return t.set(reflect.ValueOf(int(v)))
t.set(reflect.ValueOf(int(v)))
case reflect.Uint64: case reflect.Uint64:
if v < 0 { if v < 0 {
return fmt.Errorf("negative integer %d cannot be stored in an uint64", v) return fmt.Errorf("toml: negative number %d does not fit in an uint64", v)
} }
return t.set(reflect.ValueOf(uint64(v)))
t.set(reflect.ValueOf(uint64(v)))
case reflect.Uint32: case reflect.Uint32:
if v < 0 { if v < 0 || v > math.MaxUint32 {
return fmt.Errorf("negative integer %d cannot be stored in an uint32", v) return fmt.Errorf("toml: negative number %d does not fit in an uint32", v)
} }
if v > math.MaxUint32 {
return fmt.Errorf("integer %d cannot be stored in an uint32", v) t.set(reflect.ValueOf(uint32(v)))
}
return t.set(reflect.ValueOf(uint32(v)))
case reflect.Uint16: case reflect.Uint16:
if v < 0 { if v < 0 || v > math.MaxUint16 {
return fmt.Errorf("negative integer %d cannot be stored in an uint16", v) return fmt.Errorf("toml: negative number %d does not fit in an uint16", v)
} }
if v > math.MaxUint16 {
return fmt.Errorf("integer %d cannot be stored in an uint16", v) t.set(reflect.ValueOf(uint16(v)))
}
return t.set(reflect.ValueOf(uint16(v)))
case reflect.Uint8: case reflect.Uint8:
if v < 0 { if v < 0 || v > math.MaxUint8 {
return fmt.Errorf("negative integer %d cannot be stored in an uint8", v) return fmt.Errorf("toml: negative number %d does not fit in an uint8", v)
} }
if v > math.MaxUint8 {
return fmt.Errorf("integer %d cannot be stored in an uint8", v) t.set(reflect.ValueOf(uint8(v)))
}
return t.set(reflect.ValueOf(uint8(v)))
case reflect.Uint: case reflect.Uint:
if v < 0 { if v < 0 {
return fmt.Errorf("negative integer %d cannot be stored in an uint", v) return fmt.Errorf("toml: negative number %d does not fit in an uint", v)
} }
return t.set(reflect.ValueOf(uint(v)))
t.set(reflect.ValueOf(uint(v)))
case reflect.Interface: case reflect.Interface:
return t.set(reflect.ValueOf(v)) t.set(reflect.ValueOf(v))
default: default:
return fmt.Errorf("cannot assign int64 to a %s", f.String()) return fmt.Errorf("toml: integer cannot be assigned to %s", f.Kind())
} }
return nil
} }
func setFloat64(t target, v float64) error { func setFloat64(t target, v float64) error {
@@ -269,98 +269,89 @@ func setFloat64(t target, v float64) error {
switch f.Kind() { switch f.Kind() {
case reflect.Float64: case reflect.Float64:
return t.setFloat64(v) t.setFloat64(v)
case reflect.Float32: case reflect.Float32:
if v > math.MaxFloat32 { if v > math.MaxFloat32 {
return fmt.Errorf("float %f cannot be stored in a float32", v) return fmt.Errorf("toml: number %f does not fit in a float32", v)
} }
return t.set(reflect.ValueOf(float32(v)))
t.set(reflect.ValueOf(float32(v)))
case reflect.Interface: case reflect.Interface:
return t.set(reflect.ValueOf(v)) t.set(reflect.ValueOf(v))
default: default:
return fmt.Errorf("cannot assign float64 to a %s", f.String()) return fmt.Errorf("toml: float cannot be assigned to %s", f.Kind())
} }
return nil
} }
//nolint:cyclop
// Returns the element at idx of the value pointed at by target, or an error if // Returns the element at idx of the value pointed at by target, or an error if
// t does not point to an indexable. // t does not point to an indexable.
// If the target points to an Array and idx is out of bounds, it returns // If the target points to an Array and idx is out of bounds, it returns
// (nil, nil) as this is not a fatal error (the unmarshaler will skip). // (nil, nil) as this is not a fatal error (the unmarshaler will skip).
func elementAt(t target, idx int) (target, error) { func elementAt(t target, idx int) target {
f := t.get() f := t.get()
switch f.Kind() { switch f.Kind() {
case reflect.Slice: case reflect.Slice:
//nolint:godox
// TODO: use the idx function argument and avoid alloc if possible. // TODO: use the idx function argument and avoid alloc if possible.
idx := f.Len() idx := f.Len()
err := t.set(reflect.Append(f, reflect.New(f.Type().Elem()).Elem()))
if err != nil { t.set(reflect.Append(f, reflect.New(f.Type().Elem()).Elem()))
return nil, err
} return valueTarget(t.get().Index(idx))
return valueTarget(t.get().Index(idx)), nil
case reflect.Array: case reflect.Array:
if idx >= f.Len() { if idx >= f.Len() {
return nil, nil return nil
} }
return valueTarget(f.Index(idx)), nil
return valueTarget(f.Index(idx))
case reflect.Interface: case reflect.Interface:
if f.IsNil() { // This function is called after ensureValueIndexable, so it's
panic("interface should have been initialized") // guaranteed that f contains an initialized slice.
}
ifaceElem := f.Elem() ifaceElem := f.Elem()
if ifaceElem.Kind() != reflect.Slice {
return nil, fmt.Errorf("cannot elementAt on a %s", f.Kind())
}
idx := ifaceElem.Len() idx := ifaceElem.Len()
newElem := reflect.New(ifaceElem.Type().Elem()).Elem() newElem := reflect.New(ifaceElem.Type().Elem()).Elem()
newSlice := reflect.Append(ifaceElem, newElem) newSlice := reflect.Append(ifaceElem, newElem)
err := t.set(newSlice)
if err != nil { t.set(newSlice)
return nil, err
} return valueTarget(t.get().Elem().Index(idx))
return valueTarget(t.get().Elem().Index(idx)), nil
case reflect.Ptr:
return elementAt(valueTarget(f.Elem()), idx)
default: default:
return nil, fmt.Errorf("cannot elementAt on a %s", f.Kind()) // Why ensureValueIndexable let it go through?
panic(fmt.Errorf("elementAt received unhandled value type: %s", f.Kind()))
} }
} }
func (d *decoder) scopeTableTarget(append bool, t target, name string) (target, bool, error) { //nolint:cyclop
func (d *decoder) scopeTableTarget(shouldAppend bool, t target, name string) (target, bool, error) {
x := t.get() x := t.get()
switch x.Kind() { switch x.Kind() {
// Kinds that need to recurse // Kinds that need to recurse
case reflect.Interface: case reflect.Interface:
t, err := scopeInterface(append, t) t := scopeInterface(shouldAppend, t)
if err != nil { return d.scopeTableTarget(shouldAppend, t, name)
return t, false, err
}
return d.scopeTableTarget(append, t, name)
case reflect.Ptr: case reflect.Ptr:
t, err := scopePtr(t) t := scopePtr(t)
if err != nil { return d.scopeTableTarget(shouldAppend, t, name)
return t, false, err
}
return d.scopeTableTarget(append, t, name)
case reflect.Slice: case reflect.Slice:
t, err := scopeSlice(append, t) t := scopeSlice(shouldAppend, t)
if err != nil { shouldAppend = false
return t, false, err return d.scopeTableTarget(shouldAppend, t, name)
}
append = false
return d.scopeTableTarget(append, t, name)
case reflect.Array: case reflect.Array:
t, err := d.scopeArray(append, t) t, err := d.scopeArray(shouldAppend, t)
if err != nil { if err != nil {
return t, false, err return t, false, err
} }
append = false shouldAppend = false
return d.scopeTableTarget(append, t, name)
return d.scopeTableTarget(shouldAppend, t, name)
// Terminal kinds // Terminal kinds
case reflect.Struct: case reflect.Struct:
return scopeStruct(x, name) return scopeStruct(x, name)
case reflect.Map: case reflect.Map:
@@ -371,38 +362,33 @@ func (d *decoder) scopeTableTarget(append bool, t target, name string) (target,
return scopeMap(x, name) return scopeMap(x, name)
default: default:
panic(fmt.Errorf("can't scope on a %s", x.Kind())) panic(fmt.Sprintf("can't scope on a %s", x.Kind()))
} }
} }
func scopeInterface(append bool, t target) (target, error) { func scopeInterface(shouldAppend bool, t target) target {
err := initInterface(append, t) initInterface(shouldAppend, t)
if err != nil { return interfaceTarget{t}
return t, err
}
return interfaceTarget{t}, nil
} }
func scopePtr(t target) (target, error) { func scopePtr(t target) target {
err := initPtr(t) initPtr(t)
if err != nil { return valueTarget(t.get().Elem())
return t, err
}
return valueTarget(t.get().Elem()), nil
} }
func initPtr(t target) error { func initPtr(t target) {
x := t.get() x := t.get()
if !x.IsNil() { if !x.IsNil() {
return nil return
} }
return t.set(reflect.New(x.Type().Elem()))
t.set(reflect.New(x.Type().Elem()))
} }
// initInterface makes sure that the interface pointed at by the target is not // initInterface makes sure that the interface pointed at by the target is not
// nil. // nil.
// Returns the target to the initialized value of the target. // Returns the target to the initialized value of the target.
func initInterface(append bool, t target) error { func initInterface(shouldAppend bool, t target) {
x := t.get() x := t.get()
if x.Kind() != reflect.Interface { if x.Kind() != reflect.Interface {
@@ -410,45 +396,41 @@ func initInterface(append bool, t target) error {
} }
if !x.IsNil() && (x.Elem().Type() == sliceInterfaceType || x.Elem().Type() == mapStringInterfaceType) { if !x.IsNil() && (x.Elem().Type() == sliceInterfaceType || x.Elem().Type() == mapStringInterfaceType) {
return nil return
} }
var newElement reflect.Value var newElement reflect.Value
if append { if shouldAppend {
newElement = reflect.MakeSlice(sliceInterfaceType, 0, 0) newElement = reflect.MakeSlice(sliceInterfaceType, 0, 0)
} else { } else {
newElement = reflect.MakeMap(mapStringInterfaceType) newElement = reflect.MakeMap(mapStringInterfaceType)
} }
err := t.set(newElement)
if err != nil {
return err
}
return nil t.set(newElement)
} }
func scopeSlice(append bool, t target) (target, error) { func scopeSlice(shouldAppend bool, t target) target {
v := t.get() v := t.get()
if append { if shouldAppend {
newElem := reflect.New(v.Type().Elem()) newElem := reflect.New(v.Type().Elem())
newSlice := reflect.Append(v, newElem.Elem()) newSlice := reflect.Append(v, newElem.Elem())
err := t.set(newSlice)
if err != nil { t.set(newSlice)
return t, err
}
v = t.get() v = t.get()
} }
return valueTarget(v.Index(v.Len() - 1)), nil
return valueTarget(v.Index(v.Len() - 1))
} }
func (d *decoder) scopeArray(append bool, t target) (target, error) { func (d *decoder) scopeArray(shouldAppend bool, t target) (target, error) {
v := t.get() v := t.get()
idx := d.arrayIndex(append, v) idx := d.arrayIndex(shouldAppend, v)
if idx >= v.Len() { if idx >= v.Len() {
return nil, fmt.Errorf("not enough space in the array") return nil, fmt.Errorf("toml: impossible to insert element beyond array's size: %d", v.Len())
} }
return valueTarget(v.Index(idx)), nil return valueTarget(v.Index(idx)), nil
@@ -460,8 +442,9 @@ func scopeMap(v reflect.Value, name string) (target, bool, error) {
keyType := v.Type().Key() keyType := v.Type().Key()
if !k.Type().AssignableTo(keyType) { if !k.Type().AssignableTo(keyType) {
if !k.Type().ConvertibleTo(keyType) { if !k.Type().ConvertibleTo(keyType) {
return nil, false, fmt.Errorf("cannot convert string into map key type %s", keyType) return nil, false, fmt.Errorf("toml: cannot convert map key of type %s to expected type %s", k.Type(), keyType)
} }
k = k.Convert(keyType) k = k.Convert(keyType)
} }
@@ -487,6 +470,7 @@ func (c *fieldPathsCache) get(t reflect.Type) (fieldPathsMap, bool) {
c.l.RLock() c.l.RLock()
paths, ok := c.m[t] paths, ok := c.m[t]
c.l.RUnlock() c.l.RUnlock()
return paths, ok return paths, ok
} }
@@ -502,13 +486,14 @@ var globalFieldPathsCache = fieldPathsCache{
} }
func scopeStruct(v reflect.Value, name string) (target, bool, error) { func scopeStruct(v reflect.Value, name string) (target, bool, error) {
//nolint:godox
// TODO: cache this, and reduce allocations // TODO: cache this, and reduce allocations
fieldPaths, ok := globalFieldPathsCache.get(v.Type()) fieldPaths, ok := globalFieldPathsCache.get(v.Type())
if !ok { if !ok {
fieldPaths = map[string][]int{} fieldPaths = map[string][]int{}
path := make([]int, 0, 16) path := make([]int, 0, 16)
var walk func(reflect.Value) var walk func(reflect.Value)
walk = func(v reflect.Value) { walk = func(v reflect.Value) {
t := v.Type() t := v.Type()
@@ -516,11 +501,11 @@ func scopeStruct(v reflect.Value, name string) (target, bool, error) {
l := len(path) l := len(path)
path = append(path, i) path = append(path, i)
f := t.Field(i) f := t.Field(i)
if f.PkgPath != "" {
// only consider exported fields if f.Anonymous {
} else if f.Anonymous {
walk(v.Field(i)) walk(v.Field(i))
} else { } else if f.PkgPath == "" {
// only consider exported fields
fieldName, ok := f.Tag.Lookup("toml") fieldName, ok := f.Tag.Lookup("toml")
if !ok { if !ok {
fieldName = f.Name fieldName = f.Name
@@ -546,6 +531,7 @@ func scopeStruct(v reflect.Value, name string) (target, bool, error) {
if !ok { if !ok {
path, ok = fieldPaths[strings.ToLower(name)] path, ok = fieldPaths[strings.ToLower(name)]
} }
if !ok { if !ok {
return nil, false, nil return nil, false, nil
} }
+34 -11
View File
@@ -9,6 +9,8 @@ import (
) )
func TestStructTarget_Ensure(t *testing.T) { func TestStructTarget_Ensure(t *testing.T) {
t.Parallel()
examples := []struct { examples := []struct {
desc string desc string
input reflect.Value input reflect.Value
@@ -31,14 +33,23 @@ func TestStructTarget_Ensure(t *testing.T) {
test: func(v reflect.Value, err error) { test: func(v reflect.Value, err error) {
assert.NoError(t, err) assert.NoError(t, err)
require.False(t, v.IsNil()) require.False(t, v.IsNil())
s := v.Interface().([]string)
s, ok := v.Interface().([]string)
if !ok {
t.Errorf("interface %v should be castable into []string", s)
return
}
assert.Equal(t, []string{"foo"}, s) assert.Equal(t, []string{"foo"}, s)
}, },
}, },
} }
for _, e := range examples { for _, e := range examples {
e := e
t.Run(e.desc, func(t *testing.T) { t.Run(e.desc, func(t *testing.T) {
t.Parallel()
d := decoder{} d := decoder{}
target, _, err := d.scopeTableTarget(false, valueTarget(e.input), e.name) target, _, err := d.scopeTableTarget(false, valueTarget(e.input), e.name)
require.NoError(t, err) require.NoError(t, err)
@@ -50,6 +61,8 @@ func TestStructTarget_Ensure(t *testing.T) {
} }
func TestStructTarget_SetString(t *testing.T) { func TestStructTarget_SetString(t *testing.T) {
t.Parallel()
str := "value" str := "value"
examples := []struct { examples := []struct {
@@ -86,7 +99,10 @@ func TestStructTarget_SetString(t *testing.T) {
} }
for _, e := range examples { for _, e := range examples {
e := e
t.Run(e.desc, func(t *testing.T) { t.Run(e.desc, func(t *testing.T) {
t.Parallel()
d := decoder{} d := decoder{}
target, _, err := d.scopeTableTarget(false, valueTarget(e.input), e.name) target, _, err := d.scopeTableTarget(false, valueTarget(e.input), e.name)
require.NoError(t, err) require.NoError(t, err)
@@ -98,7 +114,11 @@ func TestStructTarget_SetString(t *testing.T) {
} }
func TestPushNew(t *testing.T) { func TestPushNew(t *testing.T) {
t.Parallel()
t.Run("slice of strings", func(t *testing.T) { t.Run("slice of strings", func(t *testing.T) {
t.Parallel()
type Doc struct { type Doc struct {
A []string A []string
} }
@@ -108,18 +128,18 @@ func TestPushNew(t *testing.T) {
x, _, err := dec.scopeTableTarget(false, valueTarget(reflect.ValueOf(&d).Elem()), "A") x, _, err := dec.scopeTableTarget(false, valueTarget(reflect.ValueOf(&d).Elem()), "A")
require.NoError(t, err) require.NoError(t, err)
n, err := elementAt(x, 0) n := elementAt(x, 0)
require.NoError(t, err) n.setString("hello")
require.NoError(t, n.setString("hello"))
require.Equal(t, []string{"hello"}, d.A) require.Equal(t, []string{"hello"}, d.A)
n, err = elementAt(x, 1) n = elementAt(x, 1)
require.NoError(t, err) n.setString("world")
require.NoError(t, n.setString("world"))
require.Equal(t, []string{"hello", "world"}, d.A) require.Equal(t, []string{"hello", "world"}, d.A)
}) })
t.Run("slice of interfaces", func(t *testing.T) { t.Run("slice of interfaces", func(t *testing.T) {
t.Parallel()
type Doc struct { type Doc struct {
A []interface{} A []interface{}
} }
@@ -129,19 +149,19 @@ func TestPushNew(t *testing.T) {
x, _, err := dec.scopeTableTarget(false, valueTarget(reflect.ValueOf(&d).Elem()), "A") x, _, err := dec.scopeTableTarget(false, valueTarget(reflect.ValueOf(&d).Elem()), "A")
require.NoError(t, err) require.NoError(t, err)
n, err := elementAt(x, 0) n := elementAt(x, 0)
require.NoError(t, err)
require.NoError(t, setString(n, "hello")) require.NoError(t, setString(n, "hello"))
require.Equal(t, []interface{}{"hello"}, d.A) require.Equal(t, []interface{}{"hello"}, d.A)
n, err = elementAt(x, 1) n = elementAt(x, 1)
require.NoError(t, err)
require.NoError(t, setString(n, "world")) require.NoError(t, setString(n, "world"))
require.Equal(t, []interface{}{"hello", "world"}, d.A) require.Equal(t, []interface{}{"hello", "world"}, d.A)
}) })
} }
func TestScope_Struct(t *testing.T) { func TestScope_Struct(t *testing.T) {
t.Parallel()
examples := []struct { examples := []struct {
desc string desc string
input reflect.Value input reflect.Value
@@ -167,7 +187,10 @@ func TestScope_Struct(t *testing.T) {
} }
for _, e := range examples { for _, e := range examples {
e := e
t.Run(e.desc, func(t *testing.T) { t.Run(e.desc, func(t *testing.T) {
t.Parallel()
dec := decoder{} dec := decoder{}
x, found, err := dec.scopeTableTarget(false, valueTarget(e.input), e.name) x, found, err := dec.scopeTableTarget(false, valueTarget(e.input), e.name)
assert.Equal(t, e.found, found) assert.Equal(t, e.found, found)
+1
View File
@@ -59,6 +59,7 @@ val = string / boolean / array / inline-table / date-time / float / integer
;; String ;; String
string = ml-basic-string / basic-string / ml-literal-string / literal-string string = ml-basic-string / basic-string / ml-literal-string / literal-string
;; Basic String ;; Basic String
basic-string = quotation-mark *basic-char quotation-mark basic-string = quotation-mark *basic-char quotation-mark
+27 -13
View File
@@ -30,6 +30,7 @@ func testgenValid(t *testing.T, input string, jsonRef string) {
t.Logf("Input TOML:\n%s", input) t.Logf("Input TOML:\n%s", input)
doc := map[string]interface{}{} doc := map[string]interface{}{}
err := toml.Unmarshal([]byte(input), &doc) err := toml.Unmarshal([]byte(input), &doc)
if err != nil { if err != nil {
t.Fatalf("failed parsing toml: %s", err) t.Fatalf("failed parsing toml: %s", err)
@@ -49,25 +50,23 @@ func testgenValid(t *testing.T, input string, jsonRef string) {
require.Equal(t, refDoc, doc2) require.Equal(t, refDoc, doc2)
} }
type testGenDescNode struct {
Type string
Value interface{}
}
func testgenBuildRefDoc(jsonRef string) map[string]interface{} { func testgenBuildRefDoc(jsonRef string) map[string]interface{} {
descTree := map[string]interface{}{} descTree := map[string]interface{}{}
err := json.Unmarshal([]byte(jsonRef), &descTree) err := json.Unmarshal([]byte(jsonRef), &descTree)
if err != nil { if err != nil {
panic(fmt.Errorf("reference doc should be valid JSON: %s", err)) panic(fmt.Sprintf("reference doc should be valid JSON: %s", err))
} }
doc := testGenTranslateDesc(descTree) doc := testGenTranslateDesc(descTree)
if doc == nil { if doc == nil {
return map[string]interface{}{} return map[string]interface{}{}
} }
return doc.(map[string]interface{}) return doc.(map[string]interface{})
} }
//nolint:funlen,gocognit,cyclop
func testGenTranslateDesc(input interface{}) interface{} { func testGenTranslateDesc(input interface{}) interface{} {
a, ok := input.([]interface{}) a, ok := input.([]interface{})
if ok { if ok {
@@ -75,48 +74,61 @@ func testGenTranslateDesc(input interface{}) interface{} {
for i, v := range a { for i, v := range a {
xs[i] = testGenTranslateDesc(v) xs[i] = testGenTranslateDesc(v)
} }
return xs return xs
} }
d := input.(map[string]interface{}) d, ok := input.(map[string]interface{})
if !ok {
panic(fmt.Sprintf("input should be valid map[string]: %v", input))
}
var dtype string var (
var dvalue interface{} dtype string
dvalue interface{}
)
//nolint:nestif
if len(d) == 2 { if len(d) == 2 {
dtypeiface, ok := d["type"] dtypeiface, ok := d["type"]
if ok { if ok {
dvalue, ok = d["value"] dvalue, ok = d["value"]
if ok { if ok {
dtype = dtypeiface.(string) dtype = dtypeiface.(string)
switch dtype { switch dtype {
case "string": case "string":
return dvalue.(string) return dvalue.(string)
case "float": case "float":
v, err := strconv.ParseFloat(dvalue.(string), 64) v, err := strconv.ParseFloat(dvalue.(string), 64)
if err != nil { if err != nil {
panic(fmt.Errorf("invalid float '%s': %s", dvalue, err)) panic(fmt.Sprintf("invalid float '%s': %s", dvalue, err))
} }
return v return v
case "integer": case "integer":
v, err := strconv.ParseInt(dvalue.(string), 10, 64) v, err := strconv.ParseInt(dvalue.(string), 10, 64)
if err != nil { if err != nil {
panic(fmt.Errorf("invalid int '%s': %s", dvalue, err)) panic(fmt.Sprintf("invalid int '%s': %s", dvalue, err))
} }
return v return v
case "bool": case "bool":
return dvalue.(string) == "true" return dvalue.(string) == "true"
case "datetime": case "datetime":
dt, err := time.Parse("2006-01-02T15:04:05Z", dvalue.(string)) dt, err := time.Parse("2006-01-02T15:04:05Z", dvalue.(string))
if err != nil { if err != nil {
panic(fmt.Errorf("invalid datetime '%s': %s", dvalue, err)) panic(fmt.Sprintf("invalid datetime '%s': %s", dvalue, err))
} }
return dt return dt
case "array": case "array":
if dvalue == nil { if dvalue == nil {
return nil return nil
} }
a := dvalue.([]interface{}) a := dvalue.([]interface{})
xs := make([]interface{}, len(a)) xs := make([]interface{}, len(a))
for i, v := range a { for i, v := range a {
@@ -125,7 +137,8 @@ func testGenTranslateDesc(input interface{}) interface{} {
return xs return xs
} }
panic(fmt.Errorf("unknown type: %s", dtype))
panic(fmt.Sprintf("unknown type: %s", dtype))
} }
} }
} }
@@ -134,5 +147,6 @@ func testGenTranslateDesc(input interface{}) interface{} {
for k, v := range d { for k, v := range d {
dest[k] = testGenTranslateDesc(v) dest[k] = testGenTranslateDesc(v)
} }
return dest return dest
} }
+150 -2
View File
@@ -6,26 +6,36 @@ import (
) )
func TestInvalidDatetimeMalformedNoLeads(t *testing.T) { func TestInvalidDatetimeMalformedNoLeads(t *testing.T) {
t.Parallel()
input := `no-leads = 1987-7-05T17:45:00Z` input := `no-leads = 1987-7-05T17:45:00Z`
testgenInvalid(t, input) testgenInvalid(t, input)
} }
func TestInvalidDatetimeMalformedNoSecs(t *testing.T) { func TestInvalidDatetimeMalformedNoSecs(t *testing.T) {
t.Parallel()
input := `no-secs = 1987-07-05T17:45Z` input := `no-secs = 1987-07-05T17:45Z`
testgenInvalid(t, input) testgenInvalid(t, input)
} }
func TestInvalidDatetimeMalformedNoT(t *testing.T) { func TestInvalidDatetimeMalformedNoT(t *testing.T) {
t.Parallel()
input := `no-t = 1987-07-0517:45:00Z` input := `no-t = 1987-07-0517:45:00Z`
testgenInvalid(t, input) testgenInvalid(t, input)
} }
func TestInvalidDatetimeMalformedWithMilli(t *testing.T) { func TestInvalidDatetimeMalformedWithMilli(t *testing.T) {
t.Parallel()
input := `with-milli = 1987-07-5T17:45:00.12Z` input := `with-milli = 1987-07-5T17:45:00.12Z`
testgenInvalid(t, input) testgenInvalid(t, input)
} }
func TestInvalidDuplicateKeyTable(t *testing.T) { func TestInvalidDuplicateKeyTable(t *testing.T) {
t.Parallel()
input := `[fruit] input := `[fruit]
type = "apple" type = "apple"
@@ -35,71 +45,97 @@ apple = "yes"`
} }
func TestInvalidDuplicateKeys(t *testing.T) { func TestInvalidDuplicateKeys(t *testing.T) {
t.Parallel()
input := `dupe = false input := `dupe = false
dupe = true` dupe = true`
testgenInvalid(t, input) testgenInvalid(t, input)
} }
func TestInvalidDuplicateTables(t *testing.T) { func TestInvalidDuplicateTables(t *testing.T) {
t.Parallel()
input := `[a] input := `[a]
[a]` [a]`
testgenInvalid(t, input) testgenInvalid(t, input)
} }
func TestInvalidEmptyImplicitTable(t *testing.T) { func TestInvalidEmptyImplicitTable(t *testing.T) {
t.Parallel()
input := `[naughty..naughty]` input := `[naughty..naughty]`
testgenInvalid(t, input) testgenInvalid(t, input)
} }
func TestInvalidEmptyTable(t *testing.T) { func TestInvalidEmptyTable(t *testing.T) {
t.Parallel()
input := `[]` input := `[]`
testgenInvalid(t, input) testgenInvalid(t, input)
} }
func TestInvalidFloatNoLeadingZero(t *testing.T) { func TestInvalidFloatNoLeadingZero(t *testing.T) {
t.Parallel()
input := `answer = .12345 input := `answer = .12345
neganswer = -.12345` neganswer = -.12345`
testgenInvalid(t, input) testgenInvalid(t, input)
} }
func TestInvalidFloatNoTrailingDigits(t *testing.T) { func TestInvalidFloatNoTrailingDigits(t *testing.T) {
t.Parallel()
input := `answer = 1. input := `answer = 1.
neganswer = -1.` neganswer = -1.`
testgenInvalid(t, input) testgenInvalid(t, input)
} }
func TestInvalidKeyEmpty(t *testing.T) { func TestInvalidKeyEmpty(t *testing.T) {
t.Parallel()
input := ` = 1` input := ` = 1`
testgenInvalid(t, input) testgenInvalid(t, input)
} }
func TestInvalidKeyHash(t *testing.T) { func TestInvalidKeyHash(t *testing.T) {
t.Parallel()
input := `a# = 1` input := `a# = 1`
testgenInvalid(t, input) testgenInvalid(t, input)
} }
func TestInvalidKeyNewline(t *testing.T) { func TestInvalidKeyNewline(t *testing.T) {
t.Parallel()
input := `a input := `a
= 1` = 1`
testgenInvalid(t, input) testgenInvalid(t, input)
} }
func TestInvalidKeyOpenBracket(t *testing.T) { func TestInvalidKeyOpenBracket(t *testing.T) {
t.Parallel()
input := `[abc = 1` input := `[abc = 1`
testgenInvalid(t, input) testgenInvalid(t, input)
} }
func TestInvalidKeySingleOpenBracket(t *testing.T) { func TestInvalidKeySingleOpenBracket(t *testing.T) {
t.Parallel()
input := `[` input := `[`
testgenInvalid(t, input) testgenInvalid(t, input)
} }
func TestInvalidKeySpace(t *testing.T) { func TestInvalidKeySpace(t *testing.T) {
t.Parallel()
input := `a b = 1` input := `a b = 1`
testgenInvalid(t, input) testgenInvalid(t, input)
} }
func TestInvalidKeyStartBracket(t *testing.T) { func TestInvalidKeyStartBracket(t *testing.T) {
t.Parallel()
input := `[a] input := `[a]
[xyz = 5 [xyz = 5
[b]` [b]`
@@ -107,31 +143,43 @@ func TestInvalidKeyStartBracket(t *testing.T) {
} }
func TestInvalidKeyTwoEquals(t *testing.T) { func TestInvalidKeyTwoEquals(t *testing.T) {
t.Parallel()
input := `key= = 1` input := `key= = 1`
testgenInvalid(t, input) testgenInvalid(t, input)
} }
func TestInvalidStringBadByteEscape(t *testing.T) { func TestInvalidStringBadByteEscape(t *testing.T) {
t.Parallel()
input := `naughty = "\xAg"` input := `naughty = "\xAg"`
testgenInvalid(t, input) testgenInvalid(t, input)
} }
func TestInvalidStringBadEscape(t *testing.T) { func TestInvalidStringBadEscape(t *testing.T) {
t.Parallel()
input := `invalid-escape = "This string has a bad \a escape character."` input := `invalid-escape = "This string has a bad \a escape character."`
testgenInvalid(t, input) testgenInvalid(t, input)
} }
func TestInvalidStringByteEscapes(t *testing.T) { func TestInvalidStringByteEscapes(t *testing.T) {
t.Parallel()
input := `answer = "\x33"` input := `answer = "\x33"`
testgenInvalid(t, input) testgenInvalid(t, input)
} }
func TestInvalidStringNoClose(t *testing.T) { func TestInvalidStringNoClose(t *testing.T) {
t.Parallel()
input := `no-ending-quote = "One time, at band camp` input := `no-ending-quote = "One time, at band camp`
testgenInvalid(t, input) testgenInvalid(t, input)
} }
func TestInvalidTableArrayImplicit(t *testing.T) { func TestInvalidTableArrayImplicit(t *testing.T) {
t.Parallel()
input := "# This test is a bit tricky. It should fail because the first use of\n" + input := "# This test is a bit tricky. It should fail because the first use of\n" +
"# `[[albums.songs]]` without first declaring `albums` implies that `albums`\n" + "# `[[albums.songs]]` without first declaring `albums` implies that `albums`\n" +
"# must be a table. The alternative would be quite weird. Namely, it wouldn't\n" + "# must be a table. The alternative would be quite weird. Namely, it wouldn't\n" +
@@ -150,46 +198,62 @@ func TestInvalidTableArrayImplicit(t *testing.T) {
} }
func TestInvalidTableArrayMalformedBracket(t *testing.T) { func TestInvalidTableArrayMalformedBracket(t *testing.T) {
t.Parallel()
input := `[[albums] input := `[[albums]
name = "Born to Run"` name = "Born to Run"`
testgenInvalid(t, input) testgenInvalid(t, input)
} }
func TestInvalidTableArrayMalformedEmpty(t *testing.T) { func TestInvalidTableArrayMalformedEmpty(t *testing.T) {
t.Parallel()
input := `[[]] input := `[[]]
name = "Born to Run"` name = "Born to Run"`
testgenInvalid(t, input) testgenInvalid(t, input)
} }
func TestInvalidTableEmpty(t *testing.T) { func TestInvalidTableEmpty(t *testing.T) {
t.Parallel()
input := `[]` input := `[]`
testgenInvalid(t, input) testgenInvalid(t, input)
} }
func TestInvalidTableNestedBracketsClose(t *testing.T) { func TestInvalidTableNestedBracketsClose(t *testing.T) {
t.Parallel()
input := `[a]b] input := `[a]b]
zyx = 42` zyx = 42`
testgenInvalid(t, input) testgenInvalid(t, input)
} }
func TestInvalidTableNestedBracketsOpen(t *testing.T) { func TestInvalidTableNestedBracketsOpen(t *testing.T) {
t.Parallel()
input := `[a[b] input := `[a[b]
zyx = 42` zyx = 42`
testgenInvalid(t, input) testgenInvalid(t, input)
} }
func TestInvalidTableWhitespace(t *testing.T) { func TestInvalidTableWhitespace(t *testing.T) {
t.Parallel()
input := `[invalid key]` input := `[invalid key]`
testgenInvalid(t, input) testgenInvalid(t, input)
} }
func TestInvalidTableWithPound(t *testing.T) { func TestInvalidTableWithPound(t *testing.T) {
t.Parallel()
input := `[key#group] input := `[key#group]
answer = 42` answer = 42`
testgenInvalid(t, input) testgenInvalid(t, input)
} }
func TestInvalidTextAfterArrayEntries(t *testing.T) { func TestInvalidTextAfterArrayEntries(t *testing.T) {
t.Parallel()
input := `array = [ input := `array = [
"Is there life after an array separator?", No "Is there life after an array separator?", No
"Entry" "Entry"
@@ -198,21 +262,29 @@ func TestInvalidTextAfterArrayEntries(t *testing.T) {
} }
func TestInvalidTextAfterInteger(t *testing.T) { func TestInvalidTextAfterInteger(t *testing.T) {
t.Parallel()
input := `answer = 42 the ultimate answer?` input := `answer = 42 the ultimate answer?`
testgenInvalid(t, input) testgenInvalid(t, input)
} }
func TestInvalidTextAfterString(t *testing.T) { func TestInvalidTextAfterString(t *testing.T) {
t.Parallel()
input := `string = "Is there life after strings?" No.` input := `string = "Is there life after strings?" No.`
testgenInvalid(t, input) testgenInvalid(t, input)
} }
func TestInvalidTextAfterTable(t *testing.T) { func TestInvalidTextAfterTable(t *testing.T) {
t.Parallel()
input := `[error] this shouldn't be here` input := `[error] this shouldn't be here`
testgenInvalid(t, input) testgenInvalid(t, input)
} }
func TestInvalidTextBeforeArraySeparator(t *testing.T) { func TestInvalidTextBeforeArraySeparator(t *testing.T) {
t.Parallel()
input := `array = [ input := `array = [
"Is there life before an array separator?" No, "Is there life before an array separator?" No,
"Entry" "Entry"
@@ -221,6 +293,8 @@ func TestInvalidTextBeforeArraySeparator(t *testing.T) {
} }
func TestInvalidTextInArray(t *testing.T) { func TestInvalidTextInArray(t *testing.T) {
t.Parallel()
input := `array = [ input := `array = [
"Entry 1", "Entry 1",
I don't belong, I don't belong,
@@ -230,6 +304,8 @@ func TestInvalidTextInArray(t *testing.T) {
} }
func TestValidArrayEmpty(t *testing.T) { func TestValidArrayEmpty(t *testing.T) {
t.Parallel()
input := `thevoid = [[[[[]]]]]` input := `thevoid = [[[[[]]]]]`
jsonRef := `{ jsonRef := `{
"thevoid": { "type": "array", "value": [ "thevoid": { "type": "array", "value": [
@@ -246,6 +322,8 @@ func TestValidArrayEmpty(t *testing.T) {
} }
func TestValidArrayNospaces(t *testing.T) { func TestValidArrayNospaces(t *testing.T) {
t.Parallel()
input := `ints = [1,2,3]` input := `ints = [1,2,3]`
jsonRef := `{ jsonRef := `{
"ints": { "ints": {
@@ -261,6 +339,8 @@ func TestValidArrayNospaces(t *testing.T) {
} }
func TestValidArraysHetergeneous(t *testing.T) { func TestValidArraysHetergeneous(t *testing.T) {
t.Parallel()
input := `mixed = [[1, 2], ["a", "b"], [1.1, 2.1]]` input := `mixed = [[1, 2], ["a", "b"], [1.1, 2.1]]`
jsonRef := `{ jsonRef := `{
"mixed": { "mixed": {
@@ -285,6 +365,8 @@ func TestValidArraysHetergeneous(t *testing.T) {
} }
func TestValidArraysNested(t *testing.T) { func TestValidArraysNested(t *testing.T) {
t.Parallel()
input := `nest = [["a"], ["b"]]` input := `nest = [["a"], ["b"]]`
jsonRef := `{ jsonRef := `{
"nest": { "nest": {
@@ -303,6 +385,8 @@ func TestValidArraysNested(t *testing.T) {
} }
func TestValidArrays(t *testing.T) { func TestValidArrays(t *testing.T) {
t.Parallel()
input := `ints = [1, 2, 3] input := `ints = [1, 2, 3]
floats = [1.1, 2.1, 3.1] floats = [1.1, 2.1, 3.1]
strings = ["a", "b", "c"] strings = ["a", "b", "c"]
@@ -349,6 +433,8 @@ dates = [
} }
func TestValidBool(t *testing.T) { func TestValidBool(t *testing.T) {
t.Parallel()
input := `t = true input := `t = true
f = false` f = false`
jsonRef := `{ jsonRef := `{
@@ -359,6 +445,8 @@ f = false`
} }
func TestValidCommentsEverywhere(t *testing.T) { func TestValidCommentsEverywhere(t *testing.T) {
t.Parallel()
input := `# Top comment. input := `# Top comment.
# Top comment. # Top comment.
# Top comment. # Top comment.
@@ -368,7 +456,7 @@ func TestValidCommentsEverywhere(t *testing.T) {
[group] # Comment [group] # Comment
answer = 42 # Comment answer = 42 # Comment
# no-extraneous-keys-please = 999 # no-extraneous-keys-please = 999
# Inbetween comment. # In between comment.
more = [ # Comment more = [ # Comment
# What about multiple # comments? # What about multiple # comments?
# Can you handle it? # Can you handle it?
@@ -399,6 +487,8 @@ more = [ # Comment
} }
func TestValidDatetime(t *testing.T) { func TestValidDatetime(t *testing.T) {
t.Parallel()
input := `bestdayever = 1987-07-05T17:45:00Z` input := `bestdayever = 1987-07-05T17:45:00Z`
jsonRef := `{ jsonRef := `{
"bestdayever": {"type": "datetime", "value": "1987-07-05T17:45:00Z"} "bestdayever": {"type": "datetime", "value": "1987-07-05T17:45:00Z"}
@@ -407,12 +497,16 @@ func TestValidDatetime(t *testing.T) {
} }
func TestValidEmpty(t *testing.T) { func TestValidEmpty(t *testing.T) {
t.Parallel()
input := `` input := ``
jsonRef := `{}` jsonRef := `{}`
testgenValid(t, input, jsonRef) testgenValid(t, input, jsonRef)
} }
func TestValidExample(t *testing.T) { func TestValidExample(t *testing.T) {
t.Parallel()
input := `best-day-ever = 1987-07-05T17:45:00Z input := `best-day-ever = 1987-07-05T17:45:00Z
[numtheory] [numtheory]
@@ -436,6 +530,8 @@ perfection = [6, 28, 496]`
} }
func TestValidFloat(t *testing.T) { func TestValidFloat(t *testing.T) {
t.Parallel()
input := `pi = 3.14 input := `pi = 3.14
negpi = -3.14` negpi = -3.14`
jsonRef := `{ jsonRef := `{
@@ -446,6 +542,8 @@ negpi = -3.14`
} }
func TestValidImplicitAndExplicitAfter(t *testing.T) { func TestValidImplicitAndExplicitAfter(t *testing.T) {
t.Parallel()
input := `[a.b.c] input := `[a.b.c]
answer = 42 answer = 42
@@ -465,6 +563,8 @@ better = 43`
} }
func TestValidImplicitAndExplicitBefore(t *testing.T) { func TestValidImplicitAndExplicitBefore(t *testing.T) {
t.Parallel()
input := `[a] input := `[a]
better = 43 better = 43
@@ -484,6 +584,8 @@ answer = 42`
} }
func TestValidImplicitGroups(t *testing.T) { func TestValidImplicitGroups(t *testing.T) {
t.Parallel()
input := `[a.b.c] input := `[a.b.c]
answer = 42` answer = 42`
jsonRef := `{ jsonRef := `{
@@ -499,6 +601,8 @@ answer = 42`
} }
func TestValidInteger(t *testing.T) { func TestValidInteger(t *testing.T) {
t.Parallel()
input := `answer = 42 input := `answer = 42
neganswer = -42` neganswer = -42`
jsonRef := `{ jsonRef := `{
@@ -509,6 +613,8 @@ neganswer = -42`
} }
func TestValidKeyEqualsNospace(t *testing.T) { func TestValidKeyEqualsNospace(t *testing.T) {
t.Parallel()
input := `answer=42` input := `answer=42`
jsonRef := `{ jsonRef := `{
"answer": {"type": "integer", "value": "42"} "answer": {"type": "integer", "value": "42"}
@@ -517,6 +623,8 @@ func TestValidKeyEqualsNospace(t *testing.T) {
} }
func TestValidKeySpace(t *testing.T) { func TestValidKeySpace(t *testing.T) {
t.Parallel()
input := `"a b" = 1` input := `"a b" = 1`
jsonRef := `{ jsonRef := `{
"a b": {"type": "integer", "value": "1"} "a b": {"type": "integer", "value": "1"}
@@ -525,6 +633,8 @@ func TestValidKeySpace(t *testing.T) {
} }
func TestValidKeySpecialChars(t *testing.T) { func TestValidKeySpecialChars(t *testing.T) {
t.Parallel()
input := "\"~!@$^&*()_+-`1234567890[]|/?><.,;:'\" = 1\n" input := "\"~!@$^&*()_+-`1234567890[]|/?><.,;:'\" = 1\n"
jsonRef := "{\n" + jsonRef := "{\n" +
" \"~!@$^&*()_+-`1234567890[]|/?><.,;:'\": {\n" + " \"~!@$^&*()_+-`1234567890[]|/?><.,;:'\": {\n" +
@@ -535,6 +645,8 @@ func TestValidKeySpecialChars(t *testing.T) {
} }
func TestValidLongFloat(t *testing.T) { func TestValidLongFloat(t *testing.T) {
t.Parallel()
input := `longpi = 3.141592653589793 input := `longpi = 3.141592653589793
neglongpi = -3.141592653589793` neglongpi = -3.141592653589793`
jsonRef := `{ jsonRef := `{
@@ -545,6 +657,8 @@ neglongpi = -3.141592653589793`
} }
func TestValidLongInteger(t *testing.T) { func TestValidLongInteger(t *testing.T) {
t.Parallel()
input := `answer = 9223372036854775807 input := `answer = 9223372036854775807
neganswer = -9223372036854775808` neganswer = -9223372036854775808`
jsonRef := `{ jsonRef := `{
@@ -555,6 +669,8 @@ neganswer = -9223372036854775808`
} }
func TestValidMultilineString(t *testing.T) { func TestValidMultilineString(t *testing.T) {
t.Parallel()
input := `multiline_empty_one = """""" input := `multiline_empty_one = """"""
multiline_empty_two = """ multiline_empty_two = """
""" """
@@ -612,6 +728,8 @@ equivalent_three = """\
} }
func TestValidRawMultilineString(t *testing.T) { func TestValidRawMultilineString(t *testing.T) {
t.Parallel()
input := `oneline = '''This string has a ' quote character.''' input := `oneline = '''This string has a ' quote character.'''
firstnl = ''' firstnl = '''
This string has a ' quote character.''' This string has a ' quote character.'''
@@ -639,6 +757,8 @@ in it.'''`
} }
func TestValidRawString(t *testing.T) { func TestValidRawString(t *testing.T) {
t.Parallel()
input := `backspace = 'This string has a \b backspace character.' input := `backspace = 'This string has a \b backspace character.'
tab = 'This string has a \t tab character.' tab = 'This string has a \t tab character.'
newline = 'This string has a \n new line character.' newline = 'This string has a \n new line character.'
@@ -680,6 +800,8 @@ backslash = 'This string has a \\ backslash character.'`
} }
func TestValidStringEmpty(t *testing.T) { func TestValidStringEmpty(t *testing.T) {
t.Parallel()
input := `answer = ""` input := `answer = ""`
jsonRef := `{ jsonRef := `{
"answer": { "answer": {
@@ -691,6 +813,8 @@ func TestValidStringEmpty(t *testing.T) {
} }
func TestValidStringEscapes(t *testing.T) { func TestValidStringEscapes(t *testing.T) {
t.Parallel()
input := `backspace = "This string has a \b backspace character." input := `backspace = "This string has a \b backspace character."
tab = "This string has a \t tab character." tab = "This string has a \t tab character."
newline = "This string has a \n new line character." newline = "This string has a \n new line character."
@@ -752,6 +876,8 @@ notunicode4 = "This string does not have a unicode \\\u0075 escape."`
} }
func TestValidStringSimple(t *testing.T) { func TestValidStringSimple(t *testing.T) {
t.Parallel()
input := `answer = "You are not drinking enough whisky."` input := `answer = "You are not drinking enough whisky."`
jsonRef := `{ jsonRef := `{
"answer": { "answer": {
@@ -763,6 +889,8 @@ func TestValidStringSimple(t *testing.T) {
} }
func TestValidStringWithPound(t *testing.T) { func TestValidStringWithPound(t *testing.T) {
t.Parallel()
input := `pound = "We see no # comments here." input := `pound = "We see no # comments here."
poundcomment = "But there are # some comments here." # Did I # mess you up?` poundcomment = "But there are # some comments here." # Did I # mess you up?`
jsonRef := `{ jsonRef := `{
@@ -776,6 +904,8 @@ poundcomment = "But there are # some comments here." # Did I # mess you up?`
} }
func TestValidTableArrayImplicit(t *testing.T) { func TestValidTableArrayImplicit(t *testing.T) {
t.Parallel()
input := `[[albums.songs]] input := `[[albums.songs]]
name = "Glory Days"` name = "Glory Days"`
jsonRef := `{ jsonRef := `{
@@ -789,6 +919,8 @@ name = "Glory Days"`
} }
func TestValidTableArrayMany(t *testing.T) { func TestValidTableArrayMany(t *testing.T) {
t.Parallel()
input := `[[people]] input := `[[people]]
first_name = "Bruce" first_name = "Bruce"
last_name = "Springsteen" last_name = "Springsteen"
@@ -820,6 +952,8 @@ last_name = "Seger"`
} }
func TestValidTableArrayNest(t *testing.T) { func TestValidTableArrayNest(t *testing.T) {
t.Parallel()
input := `[[albums]] input := `[[albums]]
name = "Born to Run" name = "Born to Run"
@@ -831,7 +965,7 @@ name = "Born to Run"
[[albums]] [[albums]]
name = "Born in the USA" name = "Born in the USA"
[[albums.songs]] [[albums.songs]]
name = "Glory Days" name = "Glory Days"
@@ -859,6 +993,8 @@ name = "Born in the USA"
} }
func TestValidTableArrayOne(t *testing.T) { func TestValidTableArrayOne(t *testing.T) {
t.Parallel()
input := `[[people]] input := `[[people]]
first_name = "Bruce" first_name = "Bruce"
last_name = "Springsteen"` last_name = "Springsteen"`
@@ -874,6 +1010,8 @@ last_name = "Springsteen"`
} }
func TestValidTableEmpty(t *testing.T) { func TestValidTableEmpty(t *testing.T) {
t.Parallel()
input := `[a]` input := `[a]`
jsonRef := `{ jsonRef := `{
"a": {} "a": {}
@@ -882,6 +1020,8 @@ func TestValidTableEmpty(t *testing.T) {
} }
func TestValidTableSubEmpty(t *testing.T) { func TestValidTableSubEmpty(t *testing.T) {
t.Parallel()
input := `[a] input := `[a]
[a.b]` [a.b]`
jsonRef := `{ jsonRef := `{
@@ -891,6 +1031,8 @@ func TestValidTableSubEmpty(t *testing.T) {
} }
func TestValidTableWhitespace(t *testing.T) { func TestValidTableWhitespace(t *testing.T) {
t.Parallel()
input := `["valid key"]` input := `["valid key"]`
jsonRef := `{ jsonRef := `{
"valid key": {} "valid key": {}
@@ -899,6 +1041,8 @@ func TestValidTableWhitespace(t *testing.T) {
} }
func TestValidTableWithPound(t *testing.T) { func TestValidTableWithPound(t *testing.T) {
t.Parallel()
input := `["key#group"] input := `["key#group"]
answer = 42` answer = 42`
jsonRef := `{ jsonRef := `{
@@ -910,6 +1054,8 @@ answer = 42`
} }
func TestValidUnicodeEscape(t *testing.T) { func TestValidUnicodeEscape(t *testing.T) {
t.Parallel()
input := `answer4 = "\u03B4" input := `answer4 = "\u03B4"
answer8 = "\U000003B4"` answer8 = "\U000003B4"`
jsonRef := `{ jsonRef := `{
@@ -920,6 +1066,8 @@ answer8 = "\U000003B4"`
} }
func TestValidUnicodeLiteral(t *testing.T) { func TestValidUnicodeLiteral(t *testing.T) {
t.Parallel()
input := `answer = "δ"` input := `answer = "δ"`
jsonRef := `{ jsonRef := `{
"answer": {"type": "string", "value": "δ"} "answer": {"type": "string", "value": "δ"}
+202 -56
View File
@@ -2,6 +2,7 @@ package toml
import ( import (
"encoding" "encoding"
"errors"
"fmt" "fmt"
"io" "io"
"io/ioutil" "io/ioutil"
@@ -10,18 +11,27 @@ import (
"github.com/pelletier/go-toml/v2/internal/ast" "github.com/pelletier/go-toml/v2/internal/ast"
"github.com/pelletier/go-toml/v2/internal/tracker" "github.com/pelletier/go-toml/v2/internal/tracker"
"github.com/pelletier/go-toml/v2/internal/unsafe"
) )
// Unmarshal deserializes a TOML document into a Go value.
//
// It is a shortcut for Decoder.Decode() with the default options.
func Unmarshal(data []byte, v interface{}) error { func Unmarshal(data []byte, v interface{}) error {
p := parser{} p := parser{}
p.Reset(data) p.Reset(data)
d := decoder{} d := decoder{}
return d.FromParser(&p, v) return d.FromParser(&p, v)
} }
// Decoder reads and decode a TOML document from an input stream. // Decoder reads and decode a TOML document from an input stream.
type Decoder struct { type Decoder struct {
// input
r io.Reader r io.Reader
// global settings
strict bool
} }
// NewDecoder creates a new Decoder that will read from r. // NewDecoder creates a new Decoder that will read from r.
@@ -29,21 +39,65 @@ func NewDecoder(r io.Reader) *Decoder {
return &Decoder{r: r} return &Decoder{r: r}
} }
// SetStrict toggles decoding in stict mode.
//
// When the decoder is in strict mode, it will record fields from the document
// that could not be set on the target value. In that case, the decoder returns
// a StrictMissingError that can be used to retrieve the individual errors as
// well as generate a human readable description of the missing fields.
func (d *Decoder) SetStrict(strict bool) {
d.strict = strict
}
// Decode the whole content of r into v. // Decode the whole content of r into v.
// //
// When a TOML local date is decoded into a time.Time, its value is represented // By default, values in the document that don't exist in the target Go value
// in time.Local timezone. // are ignored. See Decoder.SetStrict() to change this behavior.
//
// When a TOML local date, time, or date-time is decoded into a time.Time, its
// value is represented in time.Local timezone. Otherwise the approriate Local*
// structure is used.
// //
// Empty tables decoded in an interface{} create an empty initialized // Empty tables decoded in an interface{} create an empty initialized
// map[string]interface{}. // map[string]interface{}.
//
// Types implementing the encoding.TextUnmarshaler interface are decoded from a
// TOML string.
//
// When decoding a number, go-toml will return an error if the number is out of
// bounds for the target type (which includes negative numbers when decoding
// into an unsigned int).
//
// Type mapping
//
// List of supported TOML types and their associated accepted Go types:
//
// String -> string
// Integer -> uint*, int*, depending on size
// Float -> float*, depending on size
// Boolean -> bool
// Offset Date-Time -> time.Time
// Local Date-time -> LocalDateTime, time.Time
// Local Date -> LocalDate, time.Time
// Local Time -> LocalTime, time.Time
// Array -> slice and array, depending on elements types
// Table -> map and struct
// Inline Table -> same as Table
// Array of Tables -> same as Array and Table
func (d *Decoder) Decode(v interface{}) error { func (d *Decoder) Decode(v interface{}) error {
b, err := ioutil.ReadAll(d.r) b, err := ioutil.ReadAll(d.r)
if err != nil { if err != nil {
return err return fmt.Errorf("toml: %w", err)
} }
p := parser{} p := parser{}
p.Reset(b) p.Reset(b)
dec := decoder{} dec := decoder{
strict: strict{
Enabled: d.strict,
},
}
return dec.FromParser(&p, v) return dec.FromParser(&p, v)
} }
@@ -52,10 +106,13 @@ type decoder struct {
arrayIndexes map[reflect.Value]int arrayIndexes map[reflect.Value]int
// Tracks keys that have been seen, with which type. // Tracks keys that have been seen, with which type.
seen tracker.Seen seen tracker.SeenTracker
// Strict mode
strict strict
} }
func (d *decoder) arrayIndex(append bool, v reflect.Value) int { func (d *decoder) arrayIndex(shouldAppend bool, v reflect.Value) int {
if d.arrayIndexes == nil { if d.arrayIndexes == nil {
d.arrayIndexes = make(map[reflect.Value]int, 1) d.arrayIndexes = make(map[reflect.Value]int, 1)
} }
@@ -64,35 +121,62 @@ func (d *decoder) arrayIndex(append bool, v reflect.Value) int {
if !ok { if !ok {
d.arrayIndexes[v] = 0 d.arrayIndexes[v] = 0
} else if append { } else if shouldAppend {
idx++ idx++
d.arrayIndexes[v] = idx d.arrayIndexes[v] = idx
} }
return idx return idx
} }
func (d *decoder) FromParser(p *parser, v interface{}) error { func (d *decoder) FromParser(p *parser, v interface{}) error {
err := d.fromParser(p, v) err := d.fromParser(p, v)
if err != nil { if err == nil {
de, ok := err.(*decodeError) return d.strict.Error(p.data)
if ok {
err = wrapDecodeError(p.data, de)
}
} }
var e *decodeError
if errors.As(err, &e) {
return wrapDecodeError(p.data, e)
}
return err return err
} }
func keyLocation(node ast.Node) []byte {
k := node.Key()
hasOne := k.Next()
if !hasOne {
panic("should not be called with empty key")
}
start := k.Node().Data
end := k.Node().Data
for k.Next() {
end = k.Node().Data
}
return unsafe.BytesRange(start, end)
}
//nolint:funlen,cyclop
func (d *decoder) fromParser(p *parser, v interface{}) error { func (d *decoder) fromParser(p *parser, v interface{}) error {
r := reflect.ValueOf(v) r := reflect.ValueOf(v)
if r.Kind() != reflect.Ptr { if r.Kind() != reflect.Ptr {
return fmt.Errorf("need to target a pointer, not %s", r.Kind()) return fmt.Errorf("toml: decoding can only be performed into a pointer, not %s", r.Kind())
}
if r.IsNil() {
return fmt.Errorf("target pointer must be non-nil")
} }
var skipUntilTable bool if r.IsNil() {
var root target = valueTarget(r.Elem()) return fmt.Errorf("toml: decoding pointer target cannot be nil")
}
var (
skipUntilTable bool
root target = valueTarget(r.Elem())
)
current := root current := root
for p.NextExpression() { for p.NextExpression() {
@@ -108,11 +192,14 @@ func (d *decoder) fromParser(p *parser, v interface{}) error {
} }
var found bool var found bool
switch node.Kind { switch node.Kind {
case ast.KeyValue: case ast.KeyValue:
err = d.unmarshalKeyValue(current, node) err = d.unmarshalKeyValue(current, node)
found = true found = true
case ast.Table: case ast.Table:
d.strict.EnterTable(node)
current, found, err = d.scopeWithKey(root, node.Key()) current, found, err = d.scopeWithKey(root, node.Key())
if err == nil && found { if err == nil && found {
// In case this table points to an interface, // In case this table points to an interface,
@@ -123,9 +210,10 @@ func (d *decoder) fromParser(p *parser, v interface{}) error {
ensureMapIfInterface(current) ensureMapIfInterface(current)
} }
case ast.ArrayTable: case ast.ArrayTable:
d.strict.EnterArrayTable(node)
current, found, err = d.scopeWithArrayTable(root, node.Key()) current, found, err = d.scopeWithArrayTable(root, node.Key())
default: default:
panic(fmt.Errorf("this should not be a top level node type: %s", node.Kind)) panic(fmt.Sprintf("this should not be a top level node type: %s", node.Kind))
} }
if err != nil { if err != nil {
@@ -134,6 +222,8 @@ func (d *decoder) fromParser(p *parser, v interface{}) error {
if !found { if !found {
skipUntilTable = true skipUntilTable = true
d.strict.MissingTable(node)
} }
} }
@@ -149,38 +239,49 @@ func (d *decoder) fromParser(p *parser, v interface{}) error {
// When encountering slices, it should always use its last element, and error // When encountering slices, it should always use its last element, and error
// if the slice does not have any. // if the slice does not have any.
func (d *decoder) scopeWithKey(x target, key ast.Iterator) (target, bool, error) { func (d *decoder) scopeWithKey(x target, key ast.Iterator) (target, bool, error) {
var err error var (
found := true err error
found bool
)
for key.Next() { for key.Next() {
n := key.Node() n := key.Node()
x, found, err = d.scopeTableTarget(false, x, string(n.Data)) x, found, err = d.scopeTableTarget(false, x, string(n.Data))
if err != nil || !found { if err != nil || !found {
return nil, found, err return nil, found, err
} }
} }
return x, true, nil return x, true, nil
} }
//nolint:cyclop
// scopeWithArrayTable performs target scoping when unmarshaling an // scopeWithArrayTable performs target scoping when unmarshaling an
// ast.ArrayTable node. // ast.ArrayTable node.
// //
// It is the same as scopeWithKey, but when scoping the last part of the key // It is the same as scopeWithKey, but when scoping the last part of the key
// it creates a new element in the array instead of using the last one. // it creates a new element in the array instead of using the last one.
func (d *decoder) scopeWithArrayTable(x target, key ast.Iterator) (target, bool, error) { func (d *decoder) scopeWithArrayTable(x target, key ast.Iterator) (target, bool, error) {
var err error var (
found := true err error
found bool
)
for key.Next() { for key.Next() {
n := key.Node() n := key.Node()
if !n.Next().Valid() { // want to stop at one before last if !n.Next().Valid() { // want to stop at one before last
break break
} }
x, found, err = d.scopeTableTarget(false, x, string(n.Data)) x, found, err = d.scopeTableTarget(false, x, string(n.Data))
if err != nil || !found { if err != nil || !found {
return nil, found, err return nil, found, err
} }
} }
n := key.Node() n := key.Node()
x, found, err = d.scopeTableTarget(false, x, string(n.Data)) x, found, err = d.scopeTableTarget(false, x, string(n.Data))
if err != nil || !found { if err != nil || !found {
return x, found, err return x, found, err
@@ -189,26 +290,21 @@ func (d *decoder) scopeWithArrayTable(x target, key ast.Iterator) (target, bool,
v := x.get() v := x.get()
if v.Kind() == reflect.Ptr { if v.Kind() == reflect.Ptr {
x, err = scopePtr(x) x = scopePtr(x)
if err != nil {
return x, false, err
}
v = x.get() v = x.get()
} }
if v.Kind() == reflect.Interface { if v.Kind() == reflect.Interface {
x, err = scopeInterface(true, x) x = scopeInterface(true, x)
if err != nil {
return x, found, err
}
v = x.get() v = x.get()
} }
switch v.Kind() { switch v.Kind() {
case reflect.Slice: case reflect.Slice:
x, err = scopeSlice(true, x) x = scopeSlice(true, x)
case reflect.Array: case reflect.Array:
x, err = d.scopeArray(true, x) x, err = d.scopeArray(true, x)
default:
} }
return x, found, err return x, found, err
@@ -217,6 +313,9 @@ func (d *decoder) scopeWithArrayTable(x target, key ast.Iterator) (target, bool,
func (d *decoder) unmarshalKeyValue(x target, node ast.Node) error { func (d *decoder) unmarshalKeyValue(x target, node ast.Node) error {
assertNode(ast.KeyValue, node) assertNode(ast.KeyValue, node)
d.strict.EnterKeyValue(node)
defer d.strict.ExitKeyValue(node)
x, found, err := d.scopeWithKey(x, node.Key()) x, found, err := d.scopeWithKey(x, node.Key())
if err != nil { if err != nil {
return err return err
@@ -224,6 +323,8 @@ func (d *decoder) unmarshalKeyValue(x target, node ast.Node) error {
// A struct in the path was not found. Skip this value. // A struct in the path was not found. Skip this value.
if !found { if !found {
d.strict.MissingField(node)
return nil return nil
} }
@@ -239,31 +340,43 @@ func tryTextUnmarshaler(x target, node ast.Node) (bool, error) {
return false, nil return false, nil
} }
// Special case for time, becase we allow to unmarshal to it from // Special case for time, because we allow to unmarshal to it from
// different kind of AST nodes. // different kind of AST nodes.
if v.Type() == timeType { if v.Type() == timeType {
return false, nil return false, nil
} }
if v.Type().Implements(textUnmarshalerType) { if v.Type().Implements(textUnmarshalerType) {
return true, v.Interface().(encoding.TextUnmarshaler).UnmarshalText(node.Data) err := v.Interface().(encoding.TextUnmarshaler).UnmarshalText(node.Data)
if err != nil {
return false, newDecodeError(node.Data, "error calling UnmarshalText: %w", err)
}
return true, nil
} }
if v.CanAddr() && v.Addr().Type().Implements(textUnmarshalerType) { if v.CanAddr() && v.Addr().Type().Implements(textUnmarshalerType) {
return true, v.Addr().Interface().(encoding.TextUnmarshaler).UnmarshalText(node.Data) err := v.Addr().Interface().(encoding.TextUnmarshaler).UnmarshalText(node.Data)
if err != nil {
return false, newDecodeError(node.Data, "error calling UnmarshalText: %w", err)
}
return true, nil
} }
return false, nil return false, nil
} }
//nolint:cyclop
func (d *decoder) unmarshalValue(x target, node ast.Node) error { func (d *decoder) unmarshalValue(x target, node ast.Node) error {
v := x.get() v := x.get()
if v.Kind() == reflect.Ptr { if v.Kind() == reflect.Ptr {
if !v.Elem().IsValid() { if !v.Elem().IsValid() {
err := x.set(reflect.New(v.Type().Elem())) x.set(reflect.New(v.Type().Elem()))
if err != nil {
return err
}
v = x.get() v = x.get()
} }
return d.unmarshalValue(valueTarget(v.Elem()), node) return d.unmarshalValue(valueTarget(v.Elem()), node)
} }
@@ -292,85 +405,113 @@ func (d *decoder) unmarshalValue(x target, node ast.Node) error {
case ast.LocalDate: case ast.LocalDate:
return unmarshalLocalDate(x, node) return unmarshalLocalDate(x, node)
default: default:
panic(fmt.Errorf("unhandled unmarshalValue kind %s", node.Kind)) panic(fmt.Sprintf("unhandled node kind %s", node.Kind))
} }
} }
func unmarshalLocalDate(x target, node ast.Node) error { func unmarshalLocalDate(x target, node ast.Node) error {
assertNode(ast.LocalDate, node) assertNode(ast.LocalDate, node)
v, err := parseLocalDate(node.Data) v, err := parseLocalDate(node.Data)
if err != nil { if err != nil {
return err return err
} }
return setDate(x, v)
setDate(x, v)
return nil
} }
func unmarshalLocalDateTime(x target, node ast.Node) error { func unmarshalLocalDateTime(x target, node ast.Node) error {
assertNode(ast.LocalDateTime, node) assertNode(ast.LocalDateTime, node)
v, rest, err := parseLocalDateTime(node.Data) v, rest, err := parseLocalDateTime(node.Data)
if err != nil { if err != nil {
return err return err
} }
if len(rest) > 0 { if len(rest) > 0 {
return newDecodeError(rest, "extra characters at the end of a local date time") return newDecodeError(rest, "extra characters at the end of a local date time")
} }
return setLocalDateTime(x, v)
setLocalDateTime(x, v)
return nil
} }
func unmarshalDateTime(x target, node ast.Node) error { func unmarshalDateTime(x target, node ast.Node) error {
assertNode(ast.DateTime, node) assertNode(ast.DateTime, node)
v, err := parseDateTime(node.Data) v, err := parseDateTime(node.Data)
if err != nil { if err != nil {
return err return err
} }
return setDateTime(x, v)
setDateTime(x, v)
return nil
} }
func setLocalDateTime(x target, v LocalDateTime) error { func setLocalDateTime(x target, v LocalDateTime) {
return x.set(reflect.ValueOf(v)) if x.get().Type() == timeType {
cast := v.In(time.Local)
setDateTime(x, cast)
return
}
x.set(reflect.ValueOf(v))
} }
func setDateTime(x target, v time.Time) error { func setDateTime(x target, v time.Time) {
return x.set(reflect.ValueOf(v)) x.set(reflect.ValueOf(v))
} }
var timeType = reflect.TypeOf(time.Time{}) var timeType = reflect.TypeOf(time.Time{})
func setDate(x target, v LocalDate) error { func setDate(x target, v LocalDate) {
if x.get().Type() == timeType { if x.get().Type() == timeType {
cast := v.In(time.Local) cast := v.In(time.Local)
return setDateTime(x, cast)
setDateTime(x, cast)
return
} }
return x.set(reflect.ValueOf(v)) x.set(reflect.ValueOf(v))
} }
func unmarshalString(x target, node ast.Node) error { func unmarshalString(x target, node ast.Node) error {
assertNode(ast.String, node) assertNode(ast.String, node)
return setString(x, string(node.Data)) return setString(x, string(node.Data))
} }
func unmarshalBool(x target, node ast.Node) error { func unmarshalBool(x target, node ast.Node) error {
assertNode(ast.Bool, node) assertNode(ast.Bool, node)
v := node.Data[0] == 't' v := node.Data[0] == 't'
return setBool(x, v) return setBool(x, v)
} }
func unmarshalInteger(x target, node ast.Node) error { func unmarshalInteger(x target, node ast.Node) error {
assertNode(ast.Integer, node) assertNode(ast.Integer, node)
v, err := parseInteger(node.Data) v, err := parseInteger(node.Data)
if err != nil { if err != nil {
return err return err
} }
return setInt64(x, v) return setInt64(x, v)
} }
func unmarshalFloat(x target, node ast.Node) error { func unmarshalFloat(x target, node ast.Node) error {
assertNode(ast.Float, node) assertNode(ast.Float, node)
v, err := parseFloat(node.Data) v, err := parseFloat(node.Data)
if err != nil { if err != nil {
return err return err
} }
return setFloat64(x, v) return setFloat64(x, v)
} }
@@ -382,11 +523,13 @@ func (d *decoder) unmarshalInlineTable(x target, node ast.Node) error {
it := node.Children() it := node.Children()
for it.Next() { for it.Next() {
n := it.Node() n := it.Node()
err := d.unmarshalKeyValue(x, n) err := d.unmarshalKeyValue(x, n)
if err != nil { if err != nil {
return err return err
} }
} }
return nil return nil
} }
@@ -398,30 +541,33 @@ func (d *decoder) unmarshalArray(x target, node ast.Node) error {
return err return err
} }
it := node.Children()
idx := 0 idx := 0
it := node.Children()
for it.Next() { for it.Next() {
n := it.Node() n := it.Node()
v, err := elementAt(x, idx)
if err != nil { v := elementAt(x, idx)
return err
}
if v == nil { if v == nil {
// when we go out of bound for an array just stop processing it to // when we go out of bound for an array just stop processing it to
// mimic encoding/json // mimic encoding/json
break break
} }
err = d.unmarshalValue(v, n) err = d.unmarshalValue(v, n)
if err != nil { if err != nil {
return err return err
} }
idx++ idx++
} }
return nil return nil
} }
func assertNode(expected ast.Kind, node ast.Node) { func assertNode(expected ast.Kind, node ast.Node) {
if node.Kind != expected { if node.Kind != expected {
panic(fmt.Errorf("expected node of kind %s, not %s", expected, node.Kind)) panic(fmt.Sprintf("expected node of kind %s, not %s", expected, node.Kind))
} }
} }
+441 -16
View File
@@ -1,8 +1,11 @@
package toml_test package toml_test
import ( import (
"errors"
"fmt"
"math" "math"
"strconv" "strconv"
"strings"
"testing" "testing"
"time" "time"
@@ -12,6 +15,8 @@ import (
) )
func TestUnmarshal_Integers(t *testing.T) { func TestUnmarshal_Integers(t *testing.T) {
t.Parallel()
examples := []struct { examples := []struct {
desc string desc string
input string input string
@@ -33,6 +38,11 @@ func TestUnmarshal_Integers(t *testing.T) {
input: `+99`, input: `+99`,
expected: 99, expected: 99,
}, },
{
desc: "integer decimal underscore",
input: `123_456`,
expected: 123456,
},
{ {
desc: "integer hex uppercase", desc: "integer hex uppercase",
input: `0xDEADBEEF`, input: `0xDEADBEEF`,
@@ -53,6 +63,21 @@ func TestUnmarshal_Integers(t *testing.T) {
input: `0b11010110`, input: `0b11010110`,
expected: 0b11010110, expected: 0b11010110,
}, },
{
desc: "double underscore",
input: "12__3",
err: true,
},
{
desc: "starts with underscore",
input: "_1",
err: true,
},
{
desc: "ends with underscore",
input: "1_",
err: true,
},
} }
type doc struct { type doc struct {
@@ -60,16 +85,26 @@ func TestUnmarshal_Integers(t *testing.T) {
} }
for _, e := range examples { for _, e := range examples {
e := e
t.Run(e.desc, func(t *testing.T) { t.Run(e.desc, func(t *testing.T) {
t.Parallel()
doc := doc{} doc := doc{}
err := toml.Unmarshal([]byte(`A = `+e.input), &doc) err := toml.Unmarshal([]byte(`A = `+e.input), &doc)
require.NoError(t, err) if e.err {
assert.Equal(t, e.expected, doc.A) require.Error(t, err)
} else {
require.NoError(t, err)
assert.Equal(t, e.expected, doc.A)
}
}) })
} }
} }
//nolint:funlen
func TestUnmarshal_Floats(t *testing.T) { func TestUnmarshal_Floats(t *testing.T) {
t.Parallel()
examples := []struct { examples := []struct {
desc string desc string
input string input string
@@ -132,6 +167,7 @@ func TestUnmarshal_Floats(t *testing.T) {
desc: "nan", desc: "nan",
input: `nan`, input: `nan`,
testFn: func(t *testing.T, v float64) { testFn: func(t *testing.T, v float64) {
t.Helper()
assert.True(t, math.IsNaN(v)) assert.True(t, math.IsNaN(v))
}, },
}, },
@@ -139,6 +175,7 @@ func TestUnmarshal_Floats(t *testing.T) {
desc: "nan negative", desc: "nan negative",
input: `-nan`, input: `-nan`,
testFn: func(t *testing.T, v float64) { testFn: func(t *testing.T, v float64) {
t.Helper()
assert.True(t, math.IsNaN(v)) assert.True(t, math.IsNaN(v))
}, },
}, },
@@ -146,6 +183,7 @@ func TestUnmarshal_Floats(t *testing.T) {
desc: "nan positive", desc: "nan positive",
input: `+nan`, input: `+nan`,
testFn: func(t *testing.T, v float64) { testFn: func(t *testing.T, v float64) {
t.Helper()
assert.True(t, math.IsNaN(v)) assert.True(t, math.IsNaN(v))
}, },
}, },
@@ -156,7 +194,10 @@ func TestUnmarshal_Floats(t *testing.T) {
} }
for _, e := range examples { for _, e := range examples {
e := e
t.Run(e.desc, func(t *testing.T) { t.Run(e.desc, func(t *testing.T) {
t.Parallel()
doc := doc{} doc := doc{}
err := toml.Unmarshal([]byte(`A = `+e.input), &doc) err := toml.Unmarshal([]byte(`A = `+e.input), &doc)
require.NoError(t, err) require.NoError(t, err)
@@ -169,7 +210,10 @@ func TestUnmarshal_Floats(t *testing.T) {
} }
} }
//nolint:funlen
func TestUnmarshal(t *testing.T) { func TestUnmarshal(t *testing.T) {
t.Parallel()
type test struct { type test struct {
target interface{} target interface{}
expected interface{} expected interface{}
@@ -188,6 +232,7 @@ func TestUnmarshal(t *testing.T) {
type doc struct { type doc struct {
A string A string
} }
return test{ return test{
target: &doc{}, target: &doc{},
expected: &doc{A: "foo"}, expected: &doc{A: "foo"},
@@ -200,6 +245,7 @@ func TestUnmarshal(t *testing.T) {
fruit . flavor = "banana"`, fruit . flavor = "banana"`,
gen: func() test { gen: func() test {
m := map[string]interface{}{} m := map[string]interface{}{}
return test{ return test{
target: &m, target: &m,
expected: &map[string]interface{}{ expected: &map[string]interface{}{
@@ -217,6 +263,7 @@ func TestUnmarshal(t *testing.T) {
"\"b\"" = 2`, "\"b\"" = 2`,
gen: func() test { gen: func() test {
m := map[string]interface{}{} m := map[string]interface{}{}
return test{ return test{
target: &m, target: &m,
expected: &map[string]interface{}{ expected: &map[string]interface{}{
@@ -234,6 +281,7 @@ func TestUnmarshal(t *testing.T) {
type doc struct { type doc struct {
A string A string
} }
return test{ return test{
target: &doc{}, target: &doc{},
expected: &doc{A: "Test"}, expected: &doc{A: "Test"},
@@ -247,6 +295,7 @@ func TestUnmarshal(t *testing.T) {
type doc struct { type doc struct {
A bool A bool
} }
return test{ return test{
target: &doc{}, target: &doc{},
expected: &doc{A: true}, expected: &doc{A: true},
@@ -260,6 +309,7 @@ func TestUnmarshal(t *testing.T) {
type doc struct { type doc struct {
A bool A bool
} }
return test{ return test{
target: &doc{A: true}, target: &doc{A: true},
expected: &doc{A: false}, expected: &doc{A: false},
@@ -273,6 +323,7 @@ func TestUnmarshal(t *testing.T) {
type doc struct { type doc struct {
A []string A []string
} }
return test{ return test{
target: &doc{}, target: &doc{},
expected: &doc{A: []string{"foo", "bar"}}, expected: &doc{A: []string{"foo", "bar"}},
@@ -290,6 +341,7 @@ B = "data"`,
type doc struct { type doc struct {
A A A A
} }
return test{ return test{
target: &doc{}, target: &doc{},
expected: &doc{A: A{B: "data"}}, expected: &doc{A: A{B: "data"}},
@@ -301,6 +353,7 @@ B = "data"`,
input: `[A]`, input: `[A]`,
gen: func() test { gen: func() test {
var v map[string]interface{} var v map[string]interface{}
return test{ return test{
target: &v, target: &v,
expected: &map[string]interface{}{`A`: map[string]interface{}{}}, expected: &map[string]interface{}{`A`: map[string]interface{}{}},
@@ -318,6 +371,7 @@ B = "data"`,
type doc struct { type doc struct {
Name name Name name
} }
return test{ return test{
target: &doc{}, target: &doc{},
expected: &doc{Name: name{ expected: &doc{Name: name{
@@ -332,6 +386,7 @@ B = "data"`,
input: `A = {}`, input: `A = {}`,
gen: func() test { gen: func() test {
var v map[string]interface{} var v map[string]interface{}
return test{ return test{
target: &v, target: &v,
expected: &map[string]interface{}{`A`: map[string]interface{}{}}, expected: &map[string]interface{}{`A`: map[string]interface{}{}},
@@ -349,6 +404,7 @@ B = "data"`,
type doc struct { type doc struct {
Names []name Names []name
} }
return test{ return test{
target: &doc{}, target: &doc{},
expected: &doc{ expected: &doc{
@@ -371,6 +427,7 @@ B = "data"`,
input: `A = "foo"`, input: `A = "foo"`,
gen: func() test { gen: func() test {
doc := map[string]interface{}{} doc := map[string]interface{}{}
return test{ return test{
target: &doc, target: &doc,
expected: &map[string]interface{}{ expected: &map[string]interface{}{
@@ -385,6 +442,7 @@ B = "data"`,
B = 42`, B = 42`,
gen: func() test { gen: func() test {
doc := map[string]interface{}{} doc := map[string]interface{}{}
return test{ return test{
target: &doc, target: &doc,
expected: &map[string]interface{}{ expected: &map[string]interface{}{
@@ -399,6 +457,7 @@ B = "data"`,
input: `A = ["foo", "bar"]`, input: `A = ["foo", "bar"]`,
gen: func() test { gen: func() test {
doc := map[string]interface{}{} doc := map[string]interface{}{}
return test{ return test{
target: &doc, target: &doc,
expected: &map[string]interface{}{ expected: &map[string]interface{}{
@@ -412,6 +471,7 @@ B = "data"`,
input: `A = "foo"`, input: `A = "foo"`,
gen: func() test { gen: func() test {
doc := map[string]string{} doc := map[string]string{}
return test{ return test{
target: &doc, target: &doc,
expected: &map[string]string{ expected: &map[string]string{
@@ -425,6 +485,7 @@ B = "data"`,
input: `A = 42.0`, input: `A = 42.0`,
gen: func() test { gen: func() test {
doc := map[string]string{} doc := map[string]string{}
return test{ return test{
target: &doc, target: &doc,
err: true, err: true,
@@ -442,6 +503,7 @@ B = "data"`,
type Doc struct { type Doc struct {
First []First First []First
} }
return test{ return test{
target: &Doc{}, target: &Doc{},
expected: &Doc{ expected: &Doc{
@@ -459,13 +521,13 @@ B = "data"`,
input: `[[Products]] input: `[[Products]]
Name = "Hammer" Name = "Hammer"
Sku = 738594937 Sku = 738594937
[[Products]] # empty table within the array [[Products]] # empty table within the array
[[Products]] [[Products]]
Name = "Nail" Name = "Nail"
Sku = 284758393 Sku = 284758393
Color = "gray"`, Color = "gray"`,
gen: func() test { gen: func() test {
type Product struct { type Product struct {
@@ -476,6 +538,7 @@ B = "data"`,
type Doc struct { type Doc struct {
Products []Product Products []Product
} }
return test{ return test{
target: &Doc{}, target: &Doc{},
expected: &Doc{ expected: &Doc{
@@ -493,13 +556,13 @@ B = "data"`,
input: `[[Products]] input: `[[Products]]
Name = "Hammer" Name = "Hammer"
Sku = 738594937 Sku = 738594937
[[Products]] # empty table within the array [[Products]] # empty table within the array
[[Products]] [[Products]]
Name = "Nail" Name = "Nail"
Sku = 284758393 Sku = 284758393
Color = "gray"`, Color = "gray"`,
gen: func() test { gen: func() test {
return test{ return test{
@@ -649,6 +712,7 @@ B = "data"`,
A *[]*string A *[]*string
} }
hello := "Hello" hello := "Hello"
return test{ return test{
target: &doc{}, target: &doc{},
expected: &doc{ expected: &doc{
@@ -668,6 +732,7 @@ B = "data"`,
type doc struct { type doc struct {
A interface{} A interface{}
} }
return test{ return test{
target: &doc{ target: &doc{
A: inner{ A: inner{
@@ -695,6 +760,7 @@ B = "data"`,
type doc struct { type doc struct {
A [4]inner A [4]inner
} }
return test{ return test{
target: &doc{}, target: &doc{},
expected: &doc{ expected: &doc{
@@ -706,10 +772,91 @@ B = "data"`,
} }
}, },
}, },
{
desc: "windows line endings",
input: "A = 1\r\n\r\nB = 2",
gen: func() test {
doc := map[string]interface{}{}
return test{
target: &doc,
expected: &map[string]interface{}{
"A": int64(1),
"B": int64(2),
},
}
},
},
{
desc: "dangling CR",
input: "A = 1\r",
gen: func() test {
doc := map[string]interface{}{}
return test{
target: &doc,
err: true,
}
},
},
{
desc: "missing NL after CR",
input: "A = 1\rB = 2",
gen: func() test {
doc := map[string]interface{}{}
return test{
target: &doc,
err: true,
}
},
},
{
desc: "no newline (#526)",
input: `a = 1z = 2`,
gen: func() test {
m := map[string]interface{}{}
return test{
target: &m,
err: true,
}
},
},
{
desc: "mismatch types int to string",
input: `A = 42`,
gen: func() test {
type S struct {
A string
}
return test{
target: &S{},
err: true,
}
},
},
{
desc: "mismatch types array of int to interface with non-slice",
input: `A = [[42]]`,
skip: true,
gen: func() test {
type S struct {
A *string
}
return test{
target: &S{},
expected: &S{},
}
},
},
} }
for _, e := range examples { for _, e := range examples {
e := e
t.Run(e.desc, func(t *testing.T) { t.Run(e.desc, func(t *testing.T) {
t.Parallel()
if e.skip { if e.skip {
t.Skip() t.Skip()
} }
@@ -719,6 +866,9 @@ B = "data"`,
} }
err := toml.Unmarshal([]byte(e.input), test.target) err := toml.Unmarshal([]byte(e.input), test.target)
if test.err { if test.err {
if err == nil {
t.Log("=>", test.target)
}
require.Error(t, err) require.Error(t, err)
} else { } else {
require.NoError(t, err) require.NoError(t, err)
@@ -739,9 +889,10 @@ func (i Integer484) MarshalText() ([]byte, error) {
func (i *Integer484) UnmarshalText(data []byte) error { func (i *Integer484) UnmarshalText(data []byte) error {
conv, err := strconv.Atoi(string(data)) conv, err := strconv.Atoi(string(data))
if err != nil { if err != nil {
return err return fmt.Errorf("UnmarshalText: %w", err)
} }
i.Value = conv i.Value = conv
return nil return nil
} }
@@ -750,7 +901,10 @@ type Config484 struct {
} }
func TestIssue484(t *testing.T) { func TestIssue484(t *testing.T) {
t.Parallel()
raw := []byte(`integers = ["1","2","3","100"]`) raw := []byte(`integers = ["1","2","3","100"]`)
var cfg Config484 var cfg Config484
err := toml.Unmarshal(raw, &cfg) err := toml.Unmarshal(raw, &cfg)
require.NoError(t, err) require.NoError(t, err)
@@ -759,14 +913,18 @@ func TestIssue484(t *testing.T) {
}, cfg) }, cfg)
} }
type Map458 map[string]interface{} type (
type Slice458 []interface{} Map458 map[string]interface{}
Slice458 []interface{}
)
func (m Map458) A(s string) Slice458 { func (m Map458) A(s string) Slice458 {
return m[s].([]interface{}) return m[s].([]interface{})
} }
func TestIssue458(t *testing.T) { func TestIssue458(t *testing.T) {
t.Parallel()
s := []byte(`[[package]] s := []byte(`[[package]]
dependencies = ["regex"] dependencies = ["regex"]
name = "decode" name = "decode"
@@ -779,18 +937,21 @@ version = "0.1.0"`)
map[string]interface{}{ map[string]interface{}{
"dependencies": []interface{}{"regex"}, "dependencies": []interface{}{"regex"},
"name": "decode", "name": "decode",
"version": "0.1.0"}, "version": "0.1.0",
},
} }
assert.Equal(t, expected, a) assert.Equal(t, expected, a)
} }
func TestIssue252(t *testing.T) { func TestIssue252(t *testing.T) {
t.Parallel()
type config struct { type config struct {
Val1 string `toml:"val1"` Val1 string `toml:"val1"`
Val2 string `toml:"val2"` Val2 string `toml:"val2"`
} }
var configFile = []byte( configFile := []byte(
` `
val1 = "test1" val1 = "test1"
`) `)
@@ -805,10 +966,13 @@ val1 = "test1"
} }
func TestIssue494(t *testing.T) { func TestIssue494(t *testing.T) {
t.Parallel()
data := ` data := `
foo = 2021-04-08 foo = 2021-04-08
bar = 2021-04-08 bar = 2021-04-08
` `
type s struct { type s struct {
Foo time.Time `toml:"foo"` Foo time.Time `toml:"foo"`
Bar time.Time `toml:"bar"` Bar time.Time `toml:"bar"`
@@ -818,7 +982,19 @@ bar = 2021-04-08
require.NoError(t, err) require.NoError(t, err)
} }
func TestIssue507(t *testing.T) {
t.Parallel()
data := []byte{'0', '=', '\n', '0', 'a', 'm', 'e'}
m := map[string]interface{}{}
err := toml.Unmarshal(data, &m)
require.Error(t, err)
}
//nolint:funlen
func TestUnmarshalDecodeErrors(t *testing.T) { func TestUnmarshalDecodeErrors(t *testing.T) {
t.Parallel()
examples := []struct { examples := []struct {
desc string desc string
data string data string
@@ -893,23 +1069,98 @@ world'`,
} }
for _, e := range examples { for _, e := range examples {
e := e
t.Run(e.desc, func(t *testing.T) { t.Run(e.desc, func(t *testing.T) {
t.Parallel()
m := map[string]interface{}{} m := map[string]interface{}{}
err := toml.Unmarshal([]byte(e.data), &m) err := toml.Unmarshal([]byte(e.data), &m)
require.Error(t, err) require.Error(t, err)
de, ok := err.(*toml.DecodeError)
if !ok { var de *toml.DecodeError
if !errors.As(err, &de) {
t.Fatalf("err should have been a *toml.DecodeError, but got %s (%T)", err, err) t.Fatalf("err should have been a *toml.DecodeError, but got %s (%T)", err, err)
} }
if e.msg != "" { if e.msg != "" {
t.Log("\n" + de.String()) t.Log("\n" + de.String())
require.Equal(t, e.msg, de.Error()) require.Equal(t, "toml: "+e.msg, de.Error())
} }
}) })
} }
} }
//nolint:funlen
func TestLocalDateTime(t *testing.T) {
t.Parallel()
examples := []struct {
desc string
input string
}{
{
desc: "9 digits",
input: "2006-01-02T15:04:05.123456789",
},
{
desc: "8 digits",
input: "2006-01-02T15:04:05.12345678",
},
{
desc: "7 digits",
input: "2006-01-02T15:04:05.1234567",
},
{
desc: "6 digits",
input: "2006-01-02T15:04:05.123456",
},
{
desc: "5 digits",
input: "2006-01-02T15:04:05.12345",
},
{
desc: "4 digits",
input: "2006-01-02T15:04:05.1234",
},
{
desc: "3 digits",
input: "2006-01-02T15:04:05.123",
},
{
desc: "2 digits",
input: "2006-01-02T15:04:05.12",
},
{
desc: "1 digit",
input: "2006-01-02T15:04:05.1",
},
{
desc: "0 digit",
input: "2006-01-02T15:04:05",
},
}
for _, e := range examples {
e := e
t.Run(e.desc, func(t *testing.T) {
t.Parallel()
t.Log("input:", e.input)
doc := `a = ` + e.input
m := map[string]toml.LocalDateTime{}
err := toml.Unmarshal([]byte(doc), &m)
require.NoError(t, err)
actual := m["a"]
golang, err := time.Parse("2006-01-02T15:04:05.999999999", e.input)
require.NoError(t, err)
expected := toml.LocalDateTimeOf(golang)
require.Equal(t, expected, actual)
})
}
}
func TestIssue287(t *testing.T) { func TestIssue287(t *testing.T) {
t.Parallel()
b := `y=[[{}]]` b := `y=[[{}]]`
v := map[string]interface{}{} v := map[string]interface{}{}
err := toml.Unmarshal([]byte(b), &v) err := toml.Unmarshal([]byte(b), &v)
@@ -924,3 +1175,177 @@ func TestIssue287(t *testing.T) {
} }
require.Equal(t, expected, v) require.Equal(t, expected, v)
} }
func TestIssue508(t *testing.T) {
t.Parallel()
type head struct {
Title string `toml:"title"`
}
type text struct {
head
}
b := []byte(`title = "This is a title"`)
t1 := text{}
err := toml.Unmarshal(b, &t1)
require.NoError(t, err)
require.Equal(t, "This is a title", t1.head.Title)
}
//nolint:funlen
func TestDecoderStrict(t *testing.T) {
t.Parallel()
examples := []struct {
desc string
input string
expected string
target interface{}
}{
{
desc: "multiple missing root keys",
input: `
key1 = "value1"
key2 = "missing2"
key3 = "missing3"
key4 = "value4"
`,
expected: `
2| key1 = "value1"
3| key2 = "missing2"
| ~~~~ missing field
4| key3 = "missing3"
5| key4 = "value4"
---
2| key1 = "value1"
3| key2 = "missing2"
4| key3 = "missing3"
| ~~~~ missing field
5| key4 = "value4"
`,
target: &struct {
Key1 string
Key4 string
}{},
},
{
desc: "multi-part key",
input: `a.short.key="foo"`,
expected: `
1| a.short.key="foo"
| ~~~~~~~~~~~ missing field
`,
},
{
desc: "missing table",
input: `
[foo]
bar = 42
`,
expected: `
2| [foo]
| ~~~ missing table
3| bar = 42
`,
},
{
desc: "missing array table",
input: `
[[foo]]
bar = 42
`,
expected: `
2| [[foo]]
| ~~~ missing table
3| bar = 42
`,
},
}
for _, e := range examples {
e := e
t.Run(e.desc, func(t *testing.T) {
t.Parallel()
r := strings.NewReader(e.input)
d := toml.NewDecoder(r)
d.SetStrict(true)
x := e.target
if x == nil {
x = &struct{}{}
}
err := d.Decode(x)
var tsm *toml.StrictMissingError
if errors.As(err, &tsm) {
equalStringsIgnoreNewlines(t, e.expected, tsm.String())
} else {
t.Fatalf("err should have been a *toml.StrictMissingError, but got %s (%T)", err, err)
}
})
}
}
func ExampleDecoder_SetStrict() {
type S struct {
Key1 string
Key3 string
}
doc := `
key1 = "value1"
key2 = "value2"
key3 = "value3"
`
r := strings.NewReader(doc)
d := toml.NewDecoder(r)
d.SetStrict(true)
s := S{}
err := d.Decode(&s)
fmt.Println(err.Error())
var details *toml.StrictMissingError
if !errors.As(err, &details) {
panic(fmt.Sprintf("err should have been a *toml.StrictMissingError, but got %s (%T)", err, err))
}
fmt.Println(details.String())
// Output:
// strict mode: fields in the document are missing in the target struct
// 2| key1 = "value1"
// 3| key2 = "value2"
// | ~~~~ missing field
// 4| key3 = "value3"
}
func ExampleUnmarshal() {
type MyConfig struct {
Version int
Name string
Tags []string
}
doc := `
version = 2
name = "go-toml"
tags = ["go", "toml"]
`
var cfg MyConfig
err := toml.Unmarshal([]byte(doc), &cfg)
if err != nil {
panic(err)
}
fmt.Println("version:", cfg.Version)
fmt.Println("name:", cfg.Name)
fmt.Println("tags:", cfg.Tags)
// Output:
// version: 2
// name: go-toml
// tags: [go toml]
}