Compare commits

...

45 Commits

Author SHA1 Message Date
Thomas Pelletier 1b1dd3d6d5 Exclude testing PRs from release notes 2021-12-29 09:53:43 -05:00
Cameron Moore 128b7a8bfb Decode: check buffer length before parsing simple key (#717)
Fixes #714
2021-12-29 08:58:42 -05:00
Cameron Moore 892df5c28e Decode: fix index out of range bug (#716)
Fixes #715
2021-12-29 08:49:33 -05:00
Thomas Pelletier d58eb50ebf Doc: clarify errors returned by Decode (#713)
Fixes #625
2021-12-26 20:04:09 +01:00
Thomas Pelletier 535fc65c5f Fix link in README 2021-12-26 19:49:35 +01:00
Thomas Pelletier f158d7d278 Readme: document more differences with v1 (#712)
* Readme: document more changes with v1

Closes #552
2021-12-26 19:47:03 +01:00
Thomas Pelletier 5fd6e9cce0 Encode: add comment struct tag (#711)
Similar to v1, add a `comment` struct that that makes the encoder emit a comment
before the annotated element, if permitted. Unlike v1, comments are compact by
default (and cannot be changed).

Fixes #595.
2021-12-26 18:29:46 +01:00
Thomas Pelletier 8ce5c3d78f Decoder: time allows extra precision (#710)
As discussed[1], this change allows times to provide precision beyond the
nanosecond (nine digits fractional part). Extra precision is truncated according
to the TOML specificiation.

[1]: https://github.com/pelletier/go-toml/discussions/707
2021-12-26 17:05:10 +01:00
Thomas Pelletier 177b4a5e53 Decode: allow \r\n as line whitespace before \ (#709)
Fixes #708
2021-12-26 16:38:15 +01:00
Cameron Moore 5cbdea6192 decode: fix maximum time offset values (#706)
According to RFC3339 section 5.6, the maximum time offset values for
hours and minutes is 23 and 59, respectively.
2021-12-22 10:29:52 +01:00
Thomas Pelletier 696dd25c17 Decoder: disallow modification of existing table (#704)
Fixes #703
2021-12-15 11:05:27 -05:00
Thomas Pelletier facb2b13e8 Decoder: prevent modification of inline tables (#702)
Fixes #701
2021-12-12 09:43:42 -05:00
Cameron Moore 8bbb519477 Decode: ensure signed exponents don't start with an underscore (#699) 2021-12-05 20:02:19 -05:00
Cameron Moore b37e11d74d Decode: allow maximum seconds value of 60 (#700)
RFC3339 allows seconds to be 60 when adding leap seconds
2021-12-05 20:00:42 -05:00
Cameron Moore 6cd86876b8 Decode: ensure signed numbers don't start with an underscore (#698) 2021-12-04 16:56:48 -05:00
Cameron Moore f53bc740c1 Decode: restrict timezone offset values (#696)
Don't allow hours greater than 24 and minutes greater than 60 per RFC
3339.
2021-12-02 18:59:32 -05:00
Thomas Pelletier 9bf9be681e Decoder: check for invalid chars in timezone (#695)
Fixes #694
2021-12-02 09:00:20 -05:00
Thomas Pelletier c862c344b3 Decoder: allow commas in tags (#693) 2021-11-30 21:59:22 -05:00
Thomas Pelletier 0d20a84523 Encoder: omitempty flag (#692)
Fixes #597
2021-11-30 21:32:28 -05:00
Thomas Pelletier 3990899d7e Decoder: check tz has : between hours and minutes (#691)
Fixes #690
2021-11-30 20:22:11 -05:00
Cameron Moore 4c7a337083 Decoder: fix typo in test description (#689) 2021-11-30 15:28:01 -05:00
Thomas Pelletier bbaae540ce Decoder: check timezones start with +,-,z,Z (#688)
Also simplifies local time seconds scanning.

Fixes #686
2021-11-30 13:01:15 -05:00
Thomas Pelletier ede6445608 Decoder: flag bad \r in literal multiline strings (#687)
Fixes #685
2021-11-30 10:44:48 -05:00
Thomas Pelletier b226db6a29 Decoder: show struct field in type mismatch errors (#684)
The goal is to provide some context as to why the type were mismatched. This
change only works for that case, on structs. This is the same a encoding/json. A
more general solution would be great, but this would require a broader change in
the decoder, which I don't think is necessary at the moment.

Fixes #628
2021-11-24 20:43:56 -05:00
Thomas Pelletier d8997efb5a Mention "-" to prevent encoding field in doc (#683) 2021-11-24 19:52:23 -05:00
Thomas Pelletier 79e78b234c Decoder: fix panic on table array behind a pointer (#682)
Fixes #677
2021-11-24 18:50:04 -05:00
Thomas Pelletier 1b5a25c0ef Decoder: fail on unescaped \r not followed by \n (#681)
Fixes #674
2021-11-24 18:11:36 -05:00
Thomas Pelletier 8eae15b2ee Decoder: validate bounds of day and month in dates (#680)
Fixes #676
2021-11-24 17:42:01 -05:00
Thomas Pelletier 2b3de620e8 Encoder: try to use pointer type TextMarshaler (#679)
If a type does not implement the encoding.TextMarshaler interface but
its pointer type does, use it if possible.

Fixes #678
2021-11-24 14:43:49 -05:00
Cameron Moore 8645d6376b Decoder: flag invalid carriage returns in literal strings (#673) 2021-11-23 22:41:59 -05:00
Thomas Pelletier 64fe47161f API: Encoder and Decoder options are chainable (#670)
Fixes #583
2021-11-13 19:04:53 -05:00
Thomas Pelletier 4dff8eaa4d Decoder: prevent duplicates of inline tables (#667)
* seen: prevent duplicates of inline tables

* Provide clearer error message for redefined keys

For example:

``
toml: key b is already defined
```
2021-11-10 10:04:43 -05:00
Cameron Moore 2dbd29a565 parser: Fix missing check for upper exponent (#665) 2021-11-09 21:15:23 -05:00
Thomas Pelletier f27a07d31a seen: verify arrays (#663)
Fixes #662
2021-11-09 20:26:30 -05:00
Thomas Pelletier 644515958c Update TOML test suite (#661)
Ref #658
2021-11-08 22:35:35 -05:00
Thomas Pelletier 8683be35f6 seen: check inline tables (#660)
Fixes #658
2021-11-08 21:53:02 -05:00
Thomas Pelletier dc1740d473 Decode: code cleanup for struct cache (#659) 2021-11-07 18:35:30 -05:00
Thomas Pelletier 11f789ef11 Decode: prevent comments that look like dates to be accepted (#657)
* parser: fix date detection

When the parser has to decide between parsing and integer or a date, it should
check that all characters are actually acceptable (digits, or date/time
elements).

Fixes #655
2021-11-04 22:06:12 -04:00
Thomas Pelletier 74d21b367f scanner: handle carriage return in comments (#656)
Fixes #653
2021-11-04 21:40:16 -04:00
Thomas Pelletier 6617e7e73d utf8: use lookup table to validate ASCII (#654) 2021-11-04 16:05:36 -04:00
Thomas Pelletier 3dbca20bc9 Decoder: flag invalid carriage returns in strings (#652)
Fixes #651
2021-11-02 10:02:25 -04:00
Thomas Pelletier 85c0658984 Decode: add missing checks for LocalTime (#650) 2021-10-29 22:13:08 -04:00
Thomas Pelletier 772d169b52 testsuite: return error when can't encode tag (#648) 2021-10-29 21:51:50 -04:00
Cameron Moore b4ec220f7e Update tomltestgen and regenerate tests (#645)
Remove testsuite build tag from generated tests file

Co-authored-by: Thomas Pelletier <thomas@pelletier.codes>
2021-10-28 20:46:08 -04:00
Thomas Pelletier 3694ae88f6 decode: error on _ before exponent in floats (#647)
Fixes #646
2021-10-28 20:41:10 -04:00
20 changed files with 1708 additions and 746 deletions
+1
View File
@@ -2,6 +2,7 @@ changelog:
exclude: exclude:
labels: labels:
- build - build
- testing
categories: categories:
- title: What's new - title: What's new
labels: labels:
+109 -4
View File
@@ -55,8 +55,9 @@ to check for typos. [See example in the documentation][strict].
### Contextualized errors ### Contextualized errors
When decoding errors occur, go-toml returns [`DecodeError`][decode-err]), which When most decoding errors occur, go-toml returns [`DecodeError`][decode-err]),
contains a human readable contextualized version of the error. For example: which contains a human readable contextualized version of the error. For
example:
``` ```
2| key1 = "value1" 2| key1 = "value1"
@@ -324,6 +325,29 @@ The recommended replacement is pre-filling the struct before unmarshaling.
[go-defaults]: https://github.com/mcuadros/go-defaults [go-defaults]: https://github.com/mcuadros/go-defaults
#### `toml.Tree` replacement
This structure was the initial attempt at providing a document model for
go-toml. It allows manipulating the structure of any document, encoding and
decoding from their TOML representation. While a more robust feature was
initially planned in go-toml v2, this has been ultimately [removed from
scope][nodoc] of this library, with no plan to add it back at the moment. The
closest equivalent at the moment would be to unmarshal into an `interface{}` and
use type assertions and/or reflection to manipulate the arbitrary
structure. However this would fall short of providing all of the TOML features
such as adding comments and be specific about whitespace.
#### `toml.Position` are not retrievable anymore
The API for retrieving the position (line, column) of a specific TOML element do
not exist anymore. This was done to minimize the amount of concepts introduced
by the library (query path), and avoid the performance hit related to storing
positions in the absence of a document model, for a feature that seemed to have
little use. Errors however have gained more detailed position
information. Position retrieval seems better fitted for a document model, which
has been [removed from the scope][nodoc] of go-toml v2 at the moment.
### Encoding / Marshal ### Encoding / Marshal
#### Default struct fields order #### Default struct fields order
@@ -359,7 +383,8 @@ fmt.Println("v2:\n" + string(b))
``` ```
There is no way to make v2 encoder behave like v1. A workaround could be to There is no way to make v2 encoder behave like v1. A workaround could be to
manually sort the fields alphabetically in the struct definition. manually sort the fields alphabetically in the struct definition, or generate
struct types using `reflect.StructOf`.
#### No indentation by default #### No indentation by default
@@ -407,7 +432,9 @@ fmt.Println("v2 Encoder:\n" + string(buf.Bytes()))
V1 always uses double quotes (`"`) around strings and keys that cannot be V1 always uses double quotes (`"`) around strings and keys that cannot be
represented bare (unquoted). V2 uses single quotes instead by default (`'`), represented bare (unquoted). V2 uses single quotes instead by default (`'`),
unless a character cannot be represented, then falls back to double quotes. unless a character cannot be represented, then falls back to double quotes. As a
result of this change, `Encoder.QuoteMapKeys` has been removed, as it is not
useful anymore.
There is no way to make v2 encoder behave like v1. There is no way to make v2 encoder behave like v1.
@@ -422,6 +449,84 @@ There is no way to make v2 encoder behave like v1.
[tm]: https://golang.org/pkg/encoding/#TextMarshaler [tm]: https://golang.org/pkg/encoding/#TextMarshaler
#### `Encoder.CompactComments` has been removed
Emitting compact comments is now the default behavior of go-toml. This option
is not necessary anymore.
#### Struct tags have been merged
V1 used to provide multiple struct tags: `comment`, `commented`, `multiline`,
`toml`, and `omitempty`. To behave more like the standard library, v2 has merged
`toml`, `multiline`, and `omitempty`. For example:
```go
type doc struct {
// v1
F string `toml:"field" multiline:"true" omitempty:"true"`
// v2
F string `toml:"field,multiline,omitempty"`
}
```
Has a result, the `Encoder.SetTag*` methods have been removed, as there is just
one tag now.
#### `commented` tag has been removed
There is no replacement for the `commented` tag. This feature would be better
suited in a proper document model for go-toml v2, which has been [cut from
scope][nodoc] at the moment.
#### `Encoder.ArraysWithOneElementPerLine` has been renamed
The new name is `Encoder.SetArraysMultiline`. The behavior should be the same.
#### `Encoder.Indentation` has been renamed
The new name is `Encoder.SetIndentSymbol`. The behavior should be the same.
#### Embedded structs are tables
V1 defaults to merging embedded struct fields into the embedding struct. This
behavior was unexpected because it does not follow the standard library. To
avoid breaking backward compatibility, the `Encoder.PromoteAnonymous` method was
added to make the encoder behave correctly. Given backward compatibility is not
a problem anymore, v2 does the right thing by default. There is no way to revert
to the old behavior, and `Encoder.PromoteAnonymous` has been removed.
```go
type Embedded struct {
Value string `toml:"value"`
}
type Doc struct {
Embedded
}
d := Doc{}
fmt.Println("v1:")
b, err := v1.Marshal(d)
fmt.Println(string(b))
fmt.Println("v2:")
b, err = v2.Marshal(d)
fmt.Println(string(b))
// Output:
// v1:
// value = ""
//
// v2:
// [Embedded]
// value = ''
```
[nodoc]: https://github.com/pelletier/go-toml/discussions/506#discussioncomment-1526038
## License ## License
The MIT License (MIT). Read [LICENSE](LICENSE). The MIT License (MIT). Read [LICENSE](LICENSE).
+1 -2
View File
@@ -43,8 +43,7 @@ type testsCollection struct {
Count int Count int
} }
const srcTemplate = "// +build testsuite\n\n" + const srcTemplate = "// Generated by tomltestgen for toml-test ref {{.Ref}} on {{.Timestamp}}\n" +
"// Generated by tomltestgen for toml-test ref {{.Ref}} on {{.Timestamp}}\n" +
"package toml_test\n" + "package toml_test\n" +
" import (\n" + " import (\n" +
" \"testing\"\n" + " \"testing\"\n" +
+107 -31
View File
@@ -35,13 +35,22 @@ func parseLocalDate(b []byte) (LocalDate, error) {
return date, newDecodeError(b, "dates are expected to have the format YYYY-MM-DD") return date, newDecodeError(b, "dates are expected to have the format YYYY-MM-DD")
} }
date.Year = parseDecimalDigits(b[0:4]) var err error
v := parseDecimalDigits(b[5:7]) date.Year, err = parseDecimalDigits(b[0:4])
if err != nil {
return LocalDate{}, err
}
date.Month = v date.Month, err = parseDecimalDigits(b[5:7])
if err != nil {
return LocalDate{}, err
}
date.Day = parseDecimalDigits(b[8:10]) date.Day, err = parseDecimalDigits(b[8:10])
if err != nil {
return LocalDate{}, err
}
if !isValidDate(date.Year, date.Month, date.Day) { if !isValidDate(date.Year, date.Month, date.Day) {
return LocalDate{}, newDecodeError(b, "impossible date") return LocalDate{}, newDecodeError(b, "impossible date")
@@ -50,15 +59,18 @@ func parseLocalDate(b []byte) (LocalDate, error) {
return date, nil return date, nil
} }
func parseDecimalDigits(b []byte) int { func parseDecimalDigits(b []byte) (int, error) {
v := 0 v := 0
for _, c := range b { for i, c := range b {
if c < '0' || c > '9' {
return 0, newDecodeError(b[i:i+1], "expected digit (0-9)")
}
v *= 10 v *= 10
v += int(c - '0') v += int(c - '0')
} }
return v return v, nil
} }
func parseDateTime(b []byte) (time.Time, error) { func parseDateTime(b []byte) (time.Time, error) {
@@ -87,13 +99,36 @@ func parseDateTime(b []byte) (time.Time, error) {
if len(b) != dateTimeByteLen { if len(b) != dateTimeByteLen {
return time.Time{}, newDecodeError(b, "invalid date-time timezone") return time.Time{}, newDecodeError(b, "invalid date-time timezone")
} }
direction := 1 var direction int
if b[0] == '-' { switch b[0] {
case '-':
direction = -1 direction = -1
case '+':
direction = +1
default:
return time.Time{}, newDecodeError(b[:1], "invalid timezone offset character")
}
if b[3] != ':' {
return time.Time{}, newDecodeError(b[3:4], "expected a : separator")
}
hours, err := parseDecimalDigits(b[1:3])
if err != nil {
return time.Time{}, err
}
if hours > 23 {
return time.Time{}, newDecodeError(b[:1], "invalid timezone offset hours")
}
minutes, err := parseDecimalDigits(b[4:6])
if err != nil {
return time.Time{}, err
}
if minutes > 59 {
return time.Time{}, newDecodeError(b[:1], "invalid timezone offset minutes")
} }
hours := digitsToInt(b[1:3])
minutes := digitsToInt(b[4:6])
seconds := direction * (hours*3600 + minutes*60) seconds := direction * (hours*3600 + minutes*60)
zone = time.FixedZone("", seconds) zone = time.FixedZone("", seconds)
b = b[dateTimeByteLen:] b = b[dateTimeByteLen:]
@@ -159,7 +194,13 @@ func parseLocalTime(b []byte) (LocalTime, []byte, error) {
return t, nil, newDecodeError(b, "times are expected to have the format HH:MM:SS[.NNNNNN]") return t, nil, newDecodeError(b, "times are expected to have the format HH:MM:SS[.NNNNNN]")
} }
t.Hour = parseDecimalDigits(b[0:2]) var err error
t.Hour, err = parseDecimalDigits(b[0:2])
if err != nil {
return t, nil, err
}
if t.Hour > 23 { if t.Hour > 23 {
return t, nil, newDecodeError(b[0:2], "hour cannot be greater 23") return t, nil, newDecodeError(b[0:2], "hour cannot be greater 23")
} }
@@ -167,7 +208,10 @@ func parseLocalTime(b []byte) (LocalTime, []byte, error) {
return t, nil, newDecodeError(b[2:3], "expecting colon between hours and minutes") return t, nil, newDecodeError(b[2:3], "expecting colon between hours and minutes")
} }
t.Minute = parseDecimalDigits(b[3:5]) t.Minute, err = parseDecimalDigits(b[3:5])
if err != nil {
return t, nil, err
}
if t.Minute > 59 { if t.Minute > 59 {
return t, nil, newDecodeError(b[3:5], "minutes cannot be greater 59") return t, nil, newDecodeError(b[3:5], "minutes cannot be greater 59")
} }
@@ -175,42 +219,58 @@ func parseLocalTime(b []byte) (LocalTime, []byte, error) {
return t, nil, newDecodeError(b[5:6], "expecting colon between minutes and seconds") return t, nil, newDecodeError(b[5:6], "expecting colon between minutes and seconds")
} }
t.Second = parseDecimalDigits(b[6:8]) t.Second, err = parseDecimalDigits(b[6:8])
if t.Second > 59 { if err != nil {
return t, nil, newDecodeError(b[3:5], "seconds cannot be greater 59") return t, nil, err
} }
const minLengthWithFrac = 9 if t.Second > 60 {
if len(b) >= minLengthWithFrac && b[minLengthWithFrac-1] == '.' { return t, nil, newDecodeError(b[6:8], "seconds cannot be greater 60")
}
b = b[8:]
if len(b) >= 1 && b[0] == '.' {
frac := 0 frac := 0
precision := 0
digits := 0 digits := 0
for i, c := range b[minLengthWithFrac:] { for i, c := range b[1:] {
if !isDigit(c) { if !isDigit(c) {
if i == 0 { if i == 0 {
return t, nil, newDecodeError(b[i:i+1], "need at least one digit after fraction point") return t, nil, newDecodeError(b[0:1], "need at least one digit after fraction point")
} }
break break
} }
digits++
const maxFracPrecision = 9 const maxFracPrecision = 9
if i >= maxFracPrecision { if i >= maxFracPrecision {
return t, nil, newDecodeError(b[i:i+1], "maximum precision for date time is nanosecond") // go-toml allows decoding fractional seconds
// beyond the supported precision of 9
// digits. It truncates the fractional component
// to the supported precision and ignores the
// remaining digits.
//
// https://github.com/pelletier/go-toml/discussions/707
continue
} }
frac *= 10 frac *= 10
frac += int(c - '0') frac += int(c - '0')
digits++ precision++
} }
t.Nanosecond = frac * nspow[digits] if precision == 0 {
t.Precision = digits return t, nil, newDecodeError(b[:1], "nanoseconds need at least one digit")
}
return t, b[9+digits:], nil t.Nanosecond = frac * nspow[precision]
t.Precision = precision
return t, b[1+digits:], nil
} }
return t, b, nil
return t, b[8:], nil
} }
//nolint:cyclop //nolint:cyclop
@@ -335,8 +395,17 @@ func parseIntDec(b []byte) (int64, error) {
} }
func checkAndRemoveUnderscoresIntegers(b []byte) ([]byte, error) { func checkAndRemoveUnderscoresIntegers(b []byte) ([]byte, error) {
if b[0] == '_' { start := 0
return nil, newDecodeError(b[0:1], "number cannot start with underscore") if b[start] == '+' || b[start] == '-' {
start++
}
if len(b) == start {
return b, nil
}
if b[start] == '_' {
return nil, newDecodeError(b[start:start+1], "number cannot start with underscore")
} }
if b[len(b)-1] == '_' { if b[len(b)-1] == '_' {
@@ -405,6 +474,13 @@ func checkAndRemoveUnderscoresFloats(b []byte) ([]byte, error) {
if !before { if !before {
return nil, newDecodeError(b[i-1:i+1], "number must have at least one digit between underscores") return nil, newDecodeError(b[i-1:i+1], "number must have at least one digit between underscores")
} }
if i < len(b)-1 && (b[i+1] == 'e' || b[i+1] == 'E') {
return nil, newDecodeError(b[i+1:i+2], "cannot have underscore before exponent")
}
before = false
case '+', '-':
// signed exponents
cleaned = append(cleaned, c)
before = false before = false
case 'e', 'E': case 'e', 'E':
if i < len(b)-1 && b[i+1] == '_' { if i < len(b)-1 && b[i+1] == '_' {
@@ -430,7 +506,7 @@ func checkAndRemoveUnderscoresFloats(b []byte) ([]byte, error) {
// isValidDate checks if a provided date is a date that exists. // isValidDate checks if a provided date is a date that exists.
func isValidDate(year int, month int, day int) bool { func isValidDate(year int, month int, day int) bool {
return day <= daysIn(month, year) return month > 0 && month < 13 && day > 0 && day <= daysIn(month, year)
} }
// daysBefore[m] counts the number of days in a non-leap year // daysBefore[m] counts the number of days in a non-leap year
+6 -6
View File
@@ -212,12 +212,12 @@ func ExampleDecodeError() {
fmt.Println(err) fmt.Println(err)
//nolint:errorlint var derr *DecodeError
de := err.(*DecodeError) if errors.As(err, &derr) {
fmt.Println(de.String()) fmt.Println(derr.String())
row, col := derr.Position()
row, col := de.Position() fmt.Println("error occurred at row", row, "column", col)
fmt.Println("error occurred at row", row, "column", col) }
// Output: // Output:
// toml: number must have at least one digit between underscores // toml: number must have at least one digit between underscores
// 1| name = 123__456 // 1| name = 123__456
+24 -32
View File
@@ -20,8 +20,8 @@ type Iterator struct {
node *Node node *Node
} }
// Next moves the iterator forward and returns true if points to a node, false // Next moves the iterator forward and returns true if points to a
// otherwise. // node, false otherwise.
func (c *Iterator) Next() bool { func (c *Iterator) Next() bool {
if !c.started { if !c.started {
c.started = true c.started = true
@@ -31,8 +31,8 @@ func (c *Iterator) Next() bool {
return c.node.Valid() return c.node.Valid()
} }
// IsLast returns true if the current node of the iterator is the last one. // IsLast returns true if the current node of the iterator is the last
// Subsequent call to Next() will return false. // one. Subsequent call to Next() will return false.
func (c *Iterator) IsLast() bool { func (c *Iterator) IsLast() bool {
return c.node.next == 0 return c.node.next == 0
} }
@@ -62,20 +62,20 @@ func (r *Root) at(idx Reference) *Node {
return &r.nodes[idx] return &r.nodes[idx]
} }
// Arrays have one child per element in the array. // Arrays have one child per element in the array. InlineTables have
// InlineTables have one child per key-value pair in the table. // one child per key-value pair in the table. KeyValues have at least
// KeyValues have at least two children. The first one is the value. The // two children. The first one is the value. The rest make a
// rest make a potentially dotted key. // potentially dotted key. Table and Array table have one child per
// Table and Array table have one child per element of the key they // element of the key they represent (same as KeyValue, but without
// represent (same as KeyValue, but without the last node being the value). // the last node being the value).
// children []Node
type Node struct { type Node struct {
Kind Kind Kind Kind
Raw Range // Raw bytes from the input. Raw Range // Raw bytes from the input.
Data []byte // Node value (could be either allocated or referencing the input). Data []byte // Node value (either allocated or referencing the input).
// References to other nodes, as offsets in the backing array from this // References to other nodes, as offsets in the backing array
// node. References can go backward, so those can be negative. // from this node. References can go backward, so those can be
// negative.
next int // 0 if last element next int // 0 if last element
child int // 0 if no child child int // 0 if no child
} }
@@ -85,8 +85,8 @@ type Range struct {
Length uint32 Length uint32
} }
// Next returns a copy of the next node, or an invalid Node if there is no // Next returns a copy of the next node, or an invalid Node if there
// next node. // is no next node.
func (n *Node) Next() *Node { func (n *Node) Next() *Node {
if n.next == 0 { if n.next == 0 {
return nil return nil
@@ -96,9 +96,9 @@ func (n *Node) Next() *Node {
return (*Node)(danger.Stride(ptr, size, n.next)) return (*Node)(danger.Stride(ptr, size, n.next))
} }
// Child returns a copy of the first child node of this node. Other children // Child returns a copy of the first child node of this node. Other
// can be accessed calling Next on the first child. // children can be accessed calling Next on the first child. Returns
// Returns an invalid Node if there is none. // an invalid Node if there is none.
func (n *Node) Child() *Node { func (n *Node) Child() *Node {
if n.child == 0 { if n.child == 0 {
return nil return nil
@@ -113,10 +113,9 @@ func (n *Node) Valid() bool {
return n != nil return n != nil
} }
// Key returns the child nodes making the Key on a supported node. Panics // Key returns the child nodes making the Key on a supported
// otherwise. // node. Panics otherwise. They are guaranteed to be all be of the
// They are guaranteed to be all be of the Kind Key. A simple key would return // Kind Key. A simple key would return just one element.
// just one element.
func (n *Node) Key() Iterator { func (n *Node) Key() Iterator {
switch n.Kind { switch n.Kind {
case KeyValue: case KeyValue:
@@ -133,10 +132,9 @@ func (n *Node) Key() Iterator {
} }
// Value returns a pointer to the value node of a KeyValue. // Value returns a pointer to the value node of a KeyValue.
// Guaranteed to be non-nil. // Guaranteed to be non-nil. Panics if not called on a KeyValue node,
// Panics if not called on a KeyValue node, or if the Children are malformed. // or if the Children are malformed.
func (n *Node) Value() *Node { func (n *Node) Value() *Node {
assertKind(KeyValue, *n)
return n.Child() return n.Child()
} }
@@ -144,9 +142,3 @@ func (n *Node) Value() *Node {
func (n *Node) Children() Iterator { func (n *Node) Children() Iterator {
return Iterator{node: n.Child()} return Iterator{node: n.Child()}
} }
func assertKind(k Kind, n Node) {
if n.Kind != k {
panic(fmt.Errorf("method was expecting a %s, not a %s", k, n.Kind))
}
}
+23
View File
@@ -0,0 +1,23 @@
package danger
import (
"reflect"
"unsafe"
)
// typeID is used as key in encoder and decoder caches to enable using
// the optimize runtime.mapaccess2_fast64 function instead of the more
// expensive lookup if we were to use reflect.Type as map key.
//
// typeID holds the pointer to the reflect.Type value, which is unique
// in the program.
//
// https://github.com/segmentio/encoding/blob/master/json/codec.go#L59-L61
type TypeID unsafe.Pointer
func MakeTypeID(t reflect.Type) TypeID {
// reflect.Type has the fields:
// typ unsafe.Pointer
// ptr unsafe.Pointer
return TypeID((*[2]unsafe.Pointer)(unsafe.Pointer(&t))[1])
}
@@ -457,35 +457,6 @@ func TestEmptytomlUnmarshal(t *testing.T) {
assert.Equal(t, emptyTestData, result) assert.Equal(t, emptyTestData, result)
} }
func TestEmptyUnmarshalOmit(t *testing.T) {
t.Skipf("Have not figured yet if omitempty is a good idea")
type emptyMarshalTestStruct2 struct {
Title string `toml:"title"`
Bool bool `toml:"bool,omitempty"`
Int int `toml:"int, omitempty"`
String string `toml:"string,omitempty "`
StringList []string `toml:"stringlist,omitempty"`
Ptr *basicMarshalTestStruct `toml:"ptr,omitempty"`
Map map[string]string `toml:"map,omitempty"`
}
emptyTestData2 := emptyMarshalTestStruct2{
Title: "Placeholder",
Bool: false,
Int: 0,
String: "",
StringList: []string{},
Ptr: nil,
Map: map[string]string{},
}
result := emptyMarshalTestStruct2{}
err := toml.Unmarshal(emptyTestToml, &result)
require.NoError(t, err)
assert.Equal(t, emptyTestData2, result)
}
type pointerMarshalTestStruct struct { type pointerMarshalTestStruct struct {
Str *string Str *string
List *[]string List *[]string
+153 -61
View File
@@ -3,6 +3,7 @@ package tracker
import ( import (
"bytes" "bytes"
"fmt" "fmt"
"sync"
"github.com/pelletier/go-toml/v2/internal/ast" "github.com/pelletier/go-toml/v2/internal/ast"
) )
@@ -54,63 +55,99 @@ func (k keyKind) String() string {
type SeenTracker struct { type SeenTracker struct {
entries []entry entries []entry
currentIdx int currentIdx int
nextID int }
var pool sync.Pool
func (s *SeenTracker) reset() {
// Always contains a root element at index 0.
s.currentIdx = 0
if len(s.entries) == 0 {
s.entries = make([]entry, 1, 2)
} else {
s.entries = s.entries[:1]
}
s.entries[0].child = -1
s.entries[0].next = -1
} }
type entry struct { type entry struct {
id int // Use -1 to indicate no child or no sibling.
parent int child int
next int
name []byte name []byte
kind keyKind kind keyKind
explicit bool explicit bool
} }
// Remove all descendent of node at position idx. // Find the index of the child of parentIdx with key k. Returns -1 if
func (s *SeenTracker) clear(idx int) { // it does not exist.
p := s.entries[idx].id func (s *SeenTracker) find(parentIdx int, k []byte) int {
rest := clear(p, s.entries[idx+1:]) for i := s.entries[parentIdx].child; i >= 0; i = s.entries[i].next {
s.entries = s.entries[:idx+1+len(rest)] if bytes.Equal(s.entries[i].name, k) {
} return i
func clear(parentID int, entries []entry) []entry {
for i := 0; i < len(entries); {
if entries[i].parent == parentID {
id := entries[i].id
copy(entries[i:], entries[i+1:])
entries = entries[:len(entries)-1]
rest := clear(id, entries[i:])
entries = entries[:i+len(rest)]
} else {
i++
} }
} }
return entries return -1
}
// Remove all descendants of node at position idx.
func (s *SeenTracker) clear(idx int) {
if idx >= len(s.entries) {
return
}
for i := s.entries[idx].child; i >= 0; {
next := s.entries[i].next
n := s.entries[0].next
s.entries[0].next = i
s.entries[i].next = n
s.entries[i].name = nil
s.clear(i)
i = next
}
s.entries[idx].child = -1
} }
func (s *SeenTracker) create(parentIdx int, name []byte, kind keyKind, explicit bool) int { func (s *SeenTracker) create(parentIdx int, name []byte, kind keyKind, explicit bool) int {
parentID := s.id(parentIdx) e := entry{
child: -1,
next: s.entries[parentIdx].child,
idx := len(s.entries)
s.entries = append(s.entries, entry{
id: s.nextID,
parent: parentID,
name: name, name: name,
kind: kind, kind: kind,
explicit: explicit, explicit: explicit,
}) }
s.nextID++ var idx int
if s.entries[0].next >= 0 {
idx = s.entries[0].next
s.entries[0].next = s.entries[idx].next
s.entries[idx] = e
} else {
idx = len(s.entries)
s.entries = append(s.entries, e)
}
s.entries[parentIdx].child = idx
return idx return idx
} }
// CheckExpression takes a top-level node and checks that it does not contain keys func (s *SeenTracker) setExplicitFlag(parentIdx int) {
// that have been seen in previous calls, and validates that types are consistent. for i := s.entries[parentIdx].child; i >= 0; i = s.entries[i].next {
s.entries[i].explicit = true
s.setExplicitFlag(i)
}
}
// CheckExpression takes a top-level node and checks that it does not contain
// keys that have been seen in previous calls, and validates that types are
// consistent.
func (s *SeenTracker) CheckExpression(node *ast.Node) error { func (s *SeenTracker) CheckExpression(node *ast.Node) error {
if s.entries == nil { if s.entries == nil {
// Skip ID = 0 to remove the confusion between nodes whose parent has s.reset()
// id 0 and root nodes (parent id is 0 because it's the zero value).
s.nextID = 1
// Start unscoped, so idx is negative.
s.currentIdx = -1
} }
switch node.Kind { switch node.Kind {
case ast.KeyValue: case ast.KeyValue:
@@ -125,9 +162,13 @@ func (s *SeenTracker) CheckExpression(node *ast.Node) error {
} }
func (s *SeenTracker) checkTable(node *ast.Node) error { func (s *SeenTracker) checkTable(node *ast.Node) error {
if s.currentIdx >= 0 {
s.setExplicitFlag(s.currentIdx)
}
it := node.Key() it := node.Key()
parentIdx := -1 parentIdx := 0
// This code is duplicated in checkArrayTable. This is because factoring // This code is duplicated in checkArrayTable. This is because factoring
// it in a function requires to copy the iterator, or allocate it to the // it in a function requires to copy the iterator, or allocate it to the
@@ -143,6 +184,11 @@ func (s *SeenTracker) checkTable(node *ast.Node) error {
if idx < 0 { if idx < 0 {
idx = s.create(parentIdx, k, tableKind, false) idx = s.create(parentIdx, k, tableKind, false)
} else {
entry := s.entries[idx]
if entry.kind == valueKind {
return fmt.Errorf("toml: expected %s to be a table, not a %s", string(k), entry.kind)
}
} }
parentIdx = idx parentIdx = idx
} }
@@ -169,9 +215,13 @@ func (s *SeenTracker) checkTable(node *ast.Node) error {
} }
func (s *SeenTracker) checkArrayTable(node *ast.Node) error { func (s *SeenTracker) checkArrayTable(node *ast.Node) error {
if s.currentIdx >= 0 {
s.setExplicitFlag(s.currentIdx)
}
it := node.Key() it := node.Key()
parentIdx := -1 parentIdx := 0
for it.Next() { for it.Next() {
if it.IsLast() { if it.IsLast() {
@@ -184,7 +234,13 @@ func (s *SeenTracker) checkArrayTable(node *ast.Node) error {
if idx < 0 { if idx < 0 {
idx = s.create(parentIdx, k, tableKind, false) idx = s.create(parentIdx, k, tableKind, false)
} else {
entry := s.entries[idx]
if entry.kind == valueKind {
return fmt.Errorf("toml: expected %s to be a table, not a %s", string(k), entry.kind)
}
} }
parentIdx = idx parentIdx = idx
} }
@@ -207,53 +263,89 @@ func (s *SeenTracker) checkArrayTable(node *ast.Node) error {
} }
func (s *SeenTracker) checkKeyValue(node *ast.Node) error { func (s *SeenTracker) checkKeyValue(node *ast.Node) error {
it := node.Key()
parentIdx := s.currentIdx parentIdx := s.currentIdx
it := node.Key()
for it.Next() { for it.Next() {
k := it.Node().Data k := it.Node().Data
idx := s.find(parentIdx, k) idx := s.find(parentIdx, k)
if idx >= 0 { if idx < 0 {
if s.entries[idx].kind != tableKind { idx = s.create(parentIdx, k, tableKind, false)
return fmt.Errorf("toml: expected %s to be a table, not a %s", string(k), s.entries[idx].kind) } else {
} entry := s.entries[idx]
if s.entries[idx].explicit { if it.IsLast() {
return fmt.Errorf("toml: key %s is already defined", string(k))
} else if entry.kind != tableKind {
return fmt.Errorf("toml: expected %s to be a table, not a %s", string(k), entry.kind)
} else if entry.explicit {
return fmt.Errorf("toml: cannot redefine table %s that has already been explicitly defined", string(k)) return fmt.Errorf("toml: cannot redefine table %s that has already been explicitly defined", string(k))
} }
} else {
idx = s.create(parentIdx, k, tableKind, false)
} }
parentIdx = idx parentIdx = idx
} }
kind := valueKind s.entries[parentIdx].kind = valueKind
if node.Value().Kind == ast.InlineTable { value := node.Value()
kind = tableKind
switch value.Kind {
case ast.InlineTable:
return s.checkInlineTable(value)
case ast.Array:
return s.checkArray(value)
} }
s.entries[parentIdx].kind = kind
return nil return nil
} }
func (s *SeenTracker) id(idx int) int { func (s *SeenTracker) checkArray(node *ast.Node) error {
if idx >= 0 { it := node.Children()
return s.entries[idx].id for it.Next() {
n := it.Node()
switch n.Kind {
case ast.InlineTable:
err := s.checkInlineTable(n)
if err != nil {
return err
}
case ast.Array:
err := s.checkArray(n)
if err != nil {
return err
}
}
} }
return 0 return nil
} }
func (s *SeenTracker) find(parentIdx int, k []byte) int { func (s *SeenTracker) checkInlineTable(node *ast.Node) error {
parentID := s.id(parentIdx) if pool.New == nil {
pool.New = func() interface{} {
for i := parentIdx + 1; i < len(s.entries); i++ { return &SeenTracker{}
if s.entries[i].parent == parentID && bytes.Equal(s.entries[i].name, k) {
return i
} }
} }
return -1 s = pool.Get().(*SeenTracker)
s.reset()
it := node.Children()
for it.Next() {
n := it.Node()
err := s.checkKeyValue(n)
if err != nil {
return err
}
}
// As inline tables are self-contained, the tracker does not
// need to retain the details of what they contain. The
// keyValue element that creates the inline table is kept to
// mark the presence of the inline table and prevent
// redefinition of its keys: check* functions cannot walk into
// a value.
pool.Put(s)
return nil
} }
+139 -35
View File
@@ -11,6 +11,7 @@ import (
"strconv" "strconv"
"strings" "strings"
"time" "time"
"unicode"
) )
// Marshal serializes a Go value as a TOML document. // Marshal serializes a Go value as a TOML document.
@@ -54,8 +55,9 @@ func NewEncoder(w io.Writer) *Encoder {
// inline tag: // inline tag:
// //
// MyField `inline:"true"` // MyField `inline:"true"`
func (enc *Encoder) SetTablesInline(inline bool) { func (enc *Encoder) SetTablesInline(inline bool) *Encoder {
enc.tablesInline = inline enc.tablesInline = inline
return enc
} }
// SetArraysMultiline forces the encoder to emit all arrays with one element per // SetArraysMultiline forces the encoder to emit all arrays with one element per
@@ -64,20 +66,23 @@ func (enc *Encoder) SetTablesInline(inline bool) {
// This behavior can be controlled on an individual struct field basis with the multiline tag: // This behavior can be controlled on an individual struct field basis with the multiline tag:
// //
// MyField `multiline:"true"` // MyField `multiline:"true"`
func (enc *Encoder) SetArraysMultiline(multiline bool) { func (enc *Encoder) SetArraysMultiline(multiline bool) *Encoder {
enc.arraysMultiline = multiline enc.arraysMultiline = multiline
return enc
} }
// SetIndentSymbol defines the string that should be used for indentation. The // SetIndentSymbol defines the string that should be used for indentation. The
// provided string is repeated for each indentation level. Defaults to two // provided string is repeated for each indentation level. Defaults to two
// spaces. // spaces.
func (enc *Encoder) SetIndentSymbol(s string) { func (enc *Encoder) SetIndentSymbol(s string) *Encoder {
enc.indentSymbol = s enc.indentSymbol = s
return enc
} }
// SetIndentTables forces the encoder to intent tables and array tables. // SetIndentTables forces the encoder to intent tables and array tables.
func (enc *Encoder) SetIndentTables(indent bool) { func (enc *Encoder) SetIndentTables(indent bool) *Encoder {
enc.indentTables = indent enc.indentTables = indent
return enc
} }
// Encode writes a TOML representation of v to the stream. // Encode writes a TOML representation of v to the stream.
@@ -99,27 +104,31 @@ func (enc *Encoder) SetIndentTables(indent bool) {
// Intermediate tables are always printed. // Intermediate tables are always printed.
// //
// By default, strings are encoded as literal string, unless they contain either // By default, strings are encoded as literal string, unless they contain either
// a newline character or a single quote. In that case they are emitted as quoted // a newline character or a single quote. In that case they are emitted as
// strings. // quoted strings.
// //
// When encoding structs, fields are encoded in order of definition, with their // When encoding structs, fields are encoded in order of definition, with their
// exact name. // exact name.
// //
// Struct tags // Struct tags
// //
// The following struct tags are available to tweak encoding on a per-field // The encoding of each public struct field can be customized by the format
// basis: // string in the "toml" key of the struct field's tag. This follows
// encoding/json's convention. The format string starts with the name of the
// field, optionally followed by a comma-separated list of options. The name may
// be empty in order to provide options without overriding the default name.
// //
// toml:"foo" // The "multiline" option emits strings as quoted multi-line TOML strings. It
// Changes the name of the key to use for the field to foo. // has no effect on fields that would not be encoded as strings.
// //
// multiline:"true" // The "inline" option turns fields that would be emitted as tables into inline
// When the field contains a string, it will be emitted as a quoted // tables instead. It has no effect on other fields.
// multi-line TOML string.
// //
// inline:"true" // The "omitempty" option prevents empty values or groups from being emitted.
// When the field would normally be encoded as a table, it is instead //
// encoded as an inline table. // In addition to the "toml" tag struct tag, a "comment" tag can be used to emit
// a TOML comment before the value being annotated. Comments are ignored inside
// inline tables.
func (enc *Encoder) Encode(v interface{}) error { func (enc *Encoder) Encode(v interface{}) error {
var ( var (
b []byte b []byte
@@ -147,6 +156,8 @@ func (enc *Encoder) Encode(v interface{}) error {
type valueOptions struct { type valueOptions struct {
multiline bool multiline bool
omitempty bool
comment string
} }
type encoderCtx struct { type encoderCtx struct {
@@ -196,7 +207,6 @@ func (ctx *encoderCtx) isRoot() bool {
return len(ctx.parentKey) == 0 && !ctx.hasKey return len(ctx.parentKey) == 0 && !ctx.hasKey
} }
//nolint:cyclop,funlen
func (enc *Encoder) encode(b []byte, ctx encoderCtx, v reflect.Value) ([]byte, error) { func (enc *Encoder) encode(b []byte, ctx encoderCtx, v reflect.Value) ([]byte, error) {
if !v.IsZero() { if !v.IsZero() {
i, ok := v.Interface().(time.Time) i, ok := v.Interface().(time.Time)
@@ -205,7 +215,12 @@ func (enc *Encoder) encode(b []byte, ctx encoderCtx, v reflect.Value) ([]byte, e
} }
} }
if v.Type().Implements(textMarshalerType) { hasTextMarshaler := v.Type().Implements(textMarshalerType)
if hasTextMarshaler || (v.CanAddr() && reflect.PtrTo(v.Type()).Implements(textMarshalerType)) {
if !hasTextMarshaler {
v = v.Addr()
}
if ctx.isRoot() { if ctx.isRoot() {
return nil, fmt.Errorf("toml: type %s implementing the TextMarshaler interface cannot be a root element", v.Type()) return nil, fmt.Errorf("toml: type %s implementing the TextMarshaler interface cannot be a root element", v.Type())
} }
@@ -288,6 +303,15 @@ func (enc *Encoder) encodeKv(b []byte, ctx encoderCtx, options valueOptions, v r
if !ctx.hasKey { if !ctx.hasKey {
panic("caller of encodeKv should have set the key in the context") panic("caller of encodeKv should have set the key in the context")
} }
if (ctx.options.omitempty || options.omitempty) && isEmptyValue(v) {
return b, nil
}
if !ctx.inline {
b = enc.encodeComment(ctx.indent, options.comment, b)
}
b = enc.indent(ctx.indent, b) b = enc.indent(ctx.indent, b)
b, err = enc.encodeKey(b, ctx.key) b, err = enc.encodeKey(b, ctx.key)
@@ -312,6 +336,24 @@ func (enc *Encoder) encodeKv(b []byte, ctx encoderCtx, options valueOptions, v r
return b, nil return b, nil
} }
func isEmptyValue(v reflect.Value) bool {
switch v.Kind() {
case reflect.Array, reflect.Map, reflect.Slice, reflect.String:
return v.Len() == 0
case reflect.Bool:
return !v.Bool()
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
return v.Int() == 0
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
return v.Uint() == 0
case reflect.Float32, reflect.Float64:
return v.Float() == 0
case reflect.Interface, reflect.Ptr:
return v.IsNil()
}
return false
}
const literalQuote = '\'' const literalQuote = '\''
func (enc *Encoder) encodeString(b []byte, v string, options valueOptions) []byte { func (enc *Encoder) encodeString(b []byte, v string, options valueOptions) []byte {
@@ -405,6 +447,8 @@ func (enc *Encoder) encodeTableHeader(ctx encoderCtx, b []byte) ([]byte, error)
return b, nil return b, nil
} }
b = enc.encodeComment(ctx.indent, ctx.options.comment, b)
b = enc.indent(ctx.indent, b) b = enc.indent(ctx.indent, b)
b = append(b, '[') b = append(b, '[')
@@ -521,8 +565,7 @@ func (t *table) pushTable(k string, v reflect.Value, options valueOptions) {
func (enc *Encoder) encodeStruct(b []byte, ctx encoderCtx, v reflect.Value) ([]byte, error) { func (enc *Encoder) encodeStruct(b []byte, ctx encoderCtx, v reflect.Value) ([]byte, error) {
var t table var t table
//nolint:godox // TODO: cache this
// TODO: cache this?
typ := v.Type() typ := v.Type()
for i := 0; i < typ.NumField(); i++ { for i := 0; i < typ.NumField(); i++ {
fieldType := typ.Field(i) fieldType := typ.Field(i)
@@ -532,16 +575,20 @@ func (enc *Encoder) encodeStruct(b []byte, ctx encoderCtx, v reflect.Value) ([]b
continue continue
} }
k, ok := fieldType.Tag.Lookup("toml") k := fieldType.Name
if !ok {
k = fieldType.Name tag := fieldType.Tag.Get("toml")
}
// special field name to skip field // special field name to skip field
if k == "-" { if tag == "-" {
continue continue
} }
name, opts := parseTag(tag)
if isValidName(name) {
k = name
}
f := v.Field(i) f := v.Field(i)
if isNil(f) { if isNil(f) {
@@ -549,12 +596,12 @@ func (enc *Encoder) encodeStruct(b []byte, ctx encoderCtx, v reflect.Value) ([]b
} }
options := valueOptions{ options := valueOptions{
multiline: fieldBoolTag(fieldType, "multiline"), multiline: opts.multiline,
omitempty: opts.omitempty,
comment: fieldType.Tag.Get("comment"),
} }
inline := fieldBoolTag(fieldType, "inline") if opts.inline || !willConvertToTableOrArrayTable(ctx, f) {
if inline || !willConvertToTableOrArrayTable(ctx, f) {
t.pushKV(k, f, options) t.pushKV(k, f, options)
} else { } else {
t.pushTable(k, f, options) t.pushTable(k, f, options)
@@ -564,13 +611,70 @@ func (enc *Encoder) encodeStruct(b []byte, ctx encoderCtx, v reflect.Value) ([]b
return enc.encodeTable(b, ctx, t) return enc.encodeTable(b, ctx, t)
} }
func fieldBoolTag(field reflect.StructField, tag string) bool { func (enc *Encoder) encodeComment(indent int, comment string, b []byte) []byte {
x, ok := field.Tag.Lookup(tag) if comment != "" {
b = enc.indent(indent, b)
return ok && x == "true" b = append(b, "# "...)
b = append(b, comment...)
b = append(b, '\n')
}
return b
}
func isValidName(s string) bool {
if s == "" {
return false
}
for _, c := range s {
switch {
case strings.ContainsRune("!#$%&()*+-./:;<=>?@[]^_{|}~ ", c):
// Backslash and quote chars are reserved, but
// otherwise any punctuation chars are allowed
// in a tag name.
case !unicode.IsLetter(c) && !unicode.IsDigit(c):
return false
}
}
return true
}
type tagOptions struct {
multiline bool
inline bool
omitempty bool
}
func parseTag(tag string) (string, tagOptions) {
opts := tagOptions{}
idx := strings.Index(tag, ",")
if idx == -1 {
return tag, opts
}
raw := tag[idx+1:]
tag = string(tag[:idx])
for raw != "" {
var o string
i := strings.Index(raw, ",")
if i >= 0 {
o, raw = raw[:i], raw[i+1:]
} else {
o, raw = raw, ""
}
switch o {
case "multiline":
opts.multiline = true
case "inline":
opts.inline = true
case "omitempty":
opts.omitempty = true
}
}
return tag, opts
} }
//nolint:cyclop
func (enc *Encoder) encodeTable(b []byte, ctx encoderCtx, t table) ([]byte, error) { func (enc *Encoder) encodeTable(b []byte, ctx encoderCtx, t table) ([]byte, error) {
var err error var err error
@@ -653,7 +757,7 @@ func willConvertToTable(ctx encoderCtx, v reflect.Value) bool {
if !v.IsValid() { if !v.IsValid() {
return false return false
} }
if v.Type() == timeType || v.Type().Implements(textMarshalerType) { if v.Type() == timeType || v.Type().Implements(textMarshalerType) || (v.Kind() != reflect.Ptr && v.CanAddr() && reflect.PtrTo(v.Type()).Implements(textMarshalerType)) {
return false return false
} }
+159 -33
View File
@@ -4,20 +4,27 @@ import (
"bytes" "bytes"
"encoding/json" "encoding/json"
"fmt" "fmt"
"math/big"
"strings" "strings"
"testing" "testing"
"time"
"github.com/pelletier/go-toml/v2" "github.com/pelletier/go-toml/v2"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
) )
//nolint:funlen
func TestMarshal(t *testing.T) { func TestMarshal(t *testing.T) {
someInt := 42 someInt := 42
type structInline struct { type structInline struct {
A interface{} `inline:"true"` A interface{} `toml:",inline"`
}
type comments struct {
One int
Two int `comment:"Before kv"`
Three []int `comment:"Before array"`
} }
examples := []struct { examples := []struct {
@@ -193,9 +200,9 @@ name = 'Alice'
{ {
desc: "string escapes", desc: "string escapes",
v: map[string]interface{}{ v: map[string]interface{}{
"a": `'"\`, "a": "'\b\f\r\t\"\\",
}, },
expected: `a = "'\"\\"`, expected: `a = "'\b\f\r\t\"\\"`,
}, },
{ {
desc: "string utf8 low", desc: "string utf8 low",
@@ -242,7 +249,7 @@ name = 'Alice'
{ {
desc: "multi-line forced", desc: "multi-line forced",
v: struct { v: struct {
A string `multiline:"true"` A string `toml:",multiline"`
}{ }{
A: "hello\nworld", A: "hello\nworld",
}, },
@@ -253,7 +260,7 @@ world"""`,
{ {
desc: "inline field", desc: "inline field",
v: struct { v: struct {
A map[string]string `inline:"true"` A map[string]string `toml:",inline"`
B map[string]string B map[string]string
}{ }{
A: map[string]string{ A: map[string]string{
@@ -272,7 +279,7 @@ isinline = 'no'
{ {
desc: "mutiline array int", desc: "mutiline array int",
v: struct { v: struct {
A []int `multiline:"true"` A []int `toml:",multiline"`
B []int B []int
}{ }{
A: []int{1, 2, 3, 4}, A: []int{1, 2, 3, 4},
@@ -291,7 +298,7 @@ B = [1, 2, 3, 4]
{ {
desc: "mutiline array in array", desc: "mutiline array in array",
v: struct { v: struct {
A [][]int `multiline:"true"` A [][]int `toml:",multiline"`
}{ }{
A: [][]int{{1, 2}, {3, 4}}, A: [][]int{{1, 2}, {3, 4}},
}, },
@@ -469,6 +476,28 @@ hello = 'world'`,
}, },
err: true, err: true,
}, },
{
desc: "time",
v: struct {
T time.Time
}{
T: time.Time{},
},
expected: `T = '0001-01-01T00:00:00Z'`,
},
{
desc: "bool",
v: struct {
A bool
B bool
}{
A: false,
B: true,
},
expected: `
A = false
B = true`,
},
{ {
desc: "numbers", desc: "numbers",
v: struct { v: struct {
@@ -483,6 +512,7 @@ hello = 'world'`,
I int16 I int16
J int8 J int8
K int K int
L float64
}{ }{
A: 1.1, A: 1.1,
B: 42, B: 42,
@@ -495,6 +525,7 @@ hello = 'world'`,
I: 42, I: 42,
J: 42, J: 42,
K: 42, K: 42,
L: 2.2,
}, },
expected: ` expected: `
A = 1.1 A = 1.1
@@ -507,7 +538,29 @@ G = 42
H = 42 H = 42
I = 42 I = 42
J = 42 J = 42
K = 42`, K = 42
L = 2.2`,
},
{
desc: "comments",
v: struct {
Table comments `comment:"Before table"`
}{
Table: comments{
One: 1,
Two: 2,
Three: []int{1, 2, 3},
},
},
expected: `
# Before table
[Table]
One = 1
# Before kv
Two = 2
# Before array
Three = [1, 2, 3]
`,
}, },
} }
@@ -551,7 +604,7 @@ K = 42`,
type flagsSetters []struct { type flagsSetters []struct {
name string name string
f func(enc *toml.Encoder, flag bool) f func(enc *toml.Encoder, flag bool) *toml.Encoder
} }
var allFlags = flagsSetters{ var allFlags = flagsSetters{
@@ -734,6 +787,60 @@ func TestEncoderSetIndentSymbol(t *testing.T) {
equalStringsIgnoreNewlines(t, expected, w.String()) equalStringsIgnoreNewlines(t, expected, w.String())
} }
func TestEncoderOmitempty(t *testing.T) {
type doc struct {
String string `toml:",omitempty,multiline"`
Bool bool `toml:",omitempty,multiline"`
Int int `toml:",omitempty,multiline"`
Int8 int8 `toml:",omitempty,multiline"`
Int16 int16 `toml:",omitempty,multiline"`
Int32 int32 `toml:",omitempty,multiline"`
Int64 int64 `toml:",omitempty,multiline"`
Uint uint `toml:",omitempty,multiline"`
Uint8 uint8 `toml:",omitempty,multiline"`
Uint16 uint16 `toml:",omitempty,multiline"`
Uint32 uint32 `toml:",omitempty,multiline"`
Uint64 uint64 `toml:",omitempty,multiline"`
Float32 float32 `toml:",omitempty,multiline"`
Float64 float64 `toml:",omitempty,multiline"`
MapNil map[string]string `toml:",omitempty,multiline"`
Slice []string `toml:",omitempty,multiline"`
Ptr *string `toml:",omitempty,multiline"`
Iface interface{} `toml:",omitempty,multiline"`
Struct struct{} `toml:",omitempty,multiline"`
}
d := doc{}
b, err := toml.Marshal(d)
require.NoError(t, err)
expected := `[Struct]`
equalStringsIgnoreNewlines(t, expected, string(b))
}
func TestEncoderTagFieldName(t *testing.T) {
type doc struct {
String string `toml:"hello"`
OkSym string `toml:"#"`
Bad string `toml:"\"`
}
d := doc{String: "world"}
b, err := toml.Marshal(d)
require.NoError(t, err)
expected := `
hello = 'world'
'#' = ''
Bad = ''
`
equalStringsIgnoreNewlines(t, expected, string(b))
}
func TestIssue436(t *testing.T) { func TestIssue436(t *testing.T) {
data := []byte(`{"a": [ { "b": { "c": "d" } } ]}`) data := []byte(`{"a": [ { "b": { "c": "d" } } ]}`)
@@ -798,6 +905,48 @@ func TestIssue590(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
} }
func TestIssue571(t *testing.T) {
type Foo struct {
Float32 float32
Float64 float64
}
const closeEnough = 1e-9
foo := Foo{
Float32: 42,
Float64: 43,
}
b, err := toml.Marshal(foo)
require.NoError(t, err)
var foo2 Foo
err = toml.Unmarshal(b, &foo2)
require.NoError(t, err)
assert.InDelta(t, 42, foo2.Float32, closeEnough)
assert.InDelta(t, 43, foo2.Float64, closeEnough)
}
func TestIssue678(t *testing.T) {
type Config struct {
BigInt big.Int
}
cfg := &Config{
BigInt: *big.NewInt(123),
}
out, err := toml.Marshal(cfg)
require.NoError(t, err)
equalStringsIgnoreNewlines(t, "BigInt = '123'", string(out))
cfg2 := &Config{}
err = toml.Unmarshal(out, cfg2)
require.NoError(t, err)
require.Equal(t, cfg, cfg2)
}
func ExampleMarshal() { func ExampleMarshal() {
type MyConfig struct { type MyConfig struct {
Version int Version int
@@ -822,26 +971,3 @@ func ExampleMarshal() {
// Name = 'go-toml' // Name = 'go-toml'
// Tags = ['go', 'toml'] // Tags = ['go', 'toml']
} }
func TestIssue571(t *testing.T) {
type Foo struct {
Float32 float32
Float64 float64
}
const closeEnough = 1e-9
foo := Foo{
Float32: 42,
Float64: 43,
}
b, err := toml.Marshal(foo)
require.NoError(t, err)
var foo2 Foo
err = toml.Unmarshal(b, &foo2)
require.NoError(t, err)
assert.InDelta(t, 42, foo2.Float32, closeEnough)
assert.InDelta(t, 43, foo2.Float64, closeEnough)
}
+13 -20
View File
@@ -549,7 +549,7 @@ func (p *parser) parseMultilineBasicString(b []byte) ([]byte, []byte, []byte, er
startIdx := i startIdx := i
endIdx := len(token) - len(`"""`) endIdx := len(token) - len(`"""`)
if escaped < 0 { if !escaped {
str := token[startIdx:endIdx] str := token[startIdx:endIdx]
verr := utf8TomlValidAlreadyEscaped(str) verr := utf8TomlValidAlreadyEscaped(str)
if verr.Zero() { if verr.Zero() {
@@ -578,6 +578,10 @@ func (p *parser) parseMultilineBasicString(b []byte) ([]byte, []byte, []byte, er
switch token[i+j] { switch token[i+j] {
case ' ', '\t': case ' ', '\t':
continue continue
case '\r':
if token[i+j+1] == '\n' {
continue
}
case '\n': case '\n':
isLastNonWhitespaceOnLine = true isLastNonWhitespaceOnLine = true
} }
@@ -689,13 +693,13 @@ func (p *parser) parseKey(b []byte) (ast.Reference, []byte, error) {
} }
func (p *parser) parseSimpleKey(b []byte) (raw, key, rest []byte, err error) { func (p *parser) parseSimpleKey(b []byte) (raw, key, rest []byte, err error) {
if len(b) == 0 {
return nil, nil, nil, newDecodeError(b, "expected key but found none")
}
// simple-key = quoted-key / unquoted-key // simple-key = quoted-key / unquoted-key
// unquoted-key = 1*( ALPHA / DIGIT / %x2D / %x5F ) ; A-Z / a-z / 0-9 / - / _ // unquoted-key = 1*( ALPHA / DIGIT / %x2D / %x5F ) ; A-Z / a-z / 0-9 / - / _
// quoted-key = basic-string / literal-string // quoted-key = basic-string / literal-string
if len(b) == 0 {
return nil, nil, nil, newDecodeError(b, "key is incomplete")
}
switch { switch {
case b[0] == '\'': case b[0] == '\'':
return p.parseLiteralString(b) return p.parseLiteralString(b)
@@ -736,7 +740,7 @@ func (p *parser) parseBasicString(b []byte) ([]byte, []byte, []byte, error) {
// Fast path. If there is no escape sequence, the string should just be // Fast path. If there is no escape sequence, the string should just be
// an UTF-8 encoded string, which is the same as Go. In that case, // an UTF-8 encoded string, which is the same as Go. In that case,
// validate the string and return a direct reference to the buffer. // validate the string and return a direct reference to the buffer.
if escaped < 0 { if !escaped {
str := token[startIdx:endIdx] str := token[startIdx:endIdx]
verr := utf8TomlValidAlreadyEscaped(str) verr := utf8TomlValidAlreadyEscaped(str)
if verr.Zero() { if verr.Zero() {
@@ -866,7 +870,6 @@ func (p *parser) parseIntOrFloatOrDateTime(b []byte) (ast.Reference, []byte, err
return p.scanIntOrFloat(b) return p.scanIntOrFloat(b)
} }
//nolint:gomnd
if len(b) < 3 { if len(b) < 3 {
return p.scanIntOrFloat(b) return p.scanIntOrFloat(b)
} }
@@ -884,23 +887,13 @@ func (p *parser) parseIntOrFloatOrDateTime(b []byte) (ast.Reference, []byte, err
if idx == 2 && c == ':' || (idx == 4 && c == '-') { if idx == 2 && c == ':' || (idx == 4 && c == '-') {
return p.scanDateTime(b) return p.scanDateTime(b)
} }
break
} }
return p.scanIntOrFloat(b) return p.scanIntOrFloat(b)
} }
func digitsToInt(b []byte) int {
x := 0
for _, d := range b {
x *= 10
x += int(d - '0')
}
return x
}
//nolint:gocognit,cyclop
func (p *parser) scanDateTime(b []byte) (ast.Reference, []byte, error) { func (p *parser) scanDateTime(b []byte) (ast.Reference, []byte, error) {
// scans for contiguous characters in [0-9T:Z.+-], and up to one space if // scans for contiguous characters in [0-9T:Z.+-], and up to one space if
// followed by a digit. // followed by a digit.
@@ -970,7 +963,7 @@ byteLoop:
func (p *parser) scanIntOrFloat(b []byte) (ast.Reference, []byte, error) { func (p *parser) scanIntOrFloat(b []byte) (ast.Reference, []byte, error) {
i := 0 i := 0
if len(b) > 2 && b[0] == '0' && b[1] != '.' && b[1] != 'e' { if len(b) > 2 && b[0] == '0' && b[1] != '.' && b[1] != 'e' && b[1] != 'E' {
var isValidRune validRuneFn var isValidRune validRuneFn
switch b[1] { switch b[1] {
+19
View File
@@ -1,6 +1,8 @@
package toml package toml
import ( import (
"strconv"
"strings"
"testing" "testing"
"github.com/pelletier/go-toml/v2/internal/ast" "github.com/pelletier/go-toml/v2/internal/ast"
@@ -371,6 +373,23 @@ func BenchmarkParseBasicStringWithUnicode(b *testing.B) {
}) })
} }
func BenchmarkParseBasicStringsEasy(b *testing.B) {
p := &parser{}
for _, size := range []int{1, 4, 8, 16, 21} {
b.Run(strconv.Itoa(size), func(b *testing.B) {
input := []byte(`"` + strings.Repeat("A", size) + `"`)
b.ReportAllocs()
b.SetBytes(int64(len(input)))
for i := 0; i < b.N; i++ {
p.parseBasicString(input)
}
})
}
}
func TestParser_AST_DateTimes(t *testing.T) { func TestParser_AST_DateTimes(t *testing.T) {
examples := []struct { examples := []struct {
desc string desc string
+52 -58
View File
@@ -53,7 +53,7 @@ func scanLiteralString(b []byte) ([]byte, []byte, error) {
switch b[i] { switch b[i] {
case '\'': case '\'':
return b[:i+1], b[i+1:], nil return b[:i+1], b[i+1:], nil
case '\n': case '\n', '\r':
return nil, nil, newDecodeError(b[i:i+1], "literal strings cannot have new lines") return nil, nil, newDecodeError(b[i:i+1], "literal strings cannot have new lines")
} }
size := utf8ValidNext(b[i:]) size := utf8ValidNext(b[i:])
@@ -76,30 +76,42 @@ func scanMultilineLiteralString(b []byte) ([]byte, []byte, error) {
// mll-char = %x09 / %x20-26 / %x28-7E / non-ascii // mll-char = %x09 / %x20-26 / %x28-7E / non-ascii
// mll-quotes = 1*2apostrophe // mll-quotes = 1*2apostrophe
for i := 3; i < len(b); { for i := 3; i < len(b); {
if scanFollowsMultilineLiteralStringDelimiter(b[i:]) { switch b[i] {
i += 3 case '\'':
if scanFollowsMultilineLiteralStringDelimiter(b[i:]) {
i += 3
// At that point we found 3 apostrophe, and i is the // At that point we found 3 apostrophe, and i is the
// index of the byte after the third one. The scanner // index of the byte after the third one. The scanner
// needs to be eager, because there can be an extra 2 // needs to be eager, because there can be an extra 2
// apostrophe that can be accepted at the end of the // apostrophe that can be accepted at the end of the
// string. // string.
if i >= len(b) || b[i] != '\'' {
return b[:i], b[i:], nil
}
i++
if i >= len(b) || b[i] != '\'' {
return b[:i], b[i:], nil
}
i++
if i < len(b) && b[i] == '\'' {
return nil, nil, newDecodeError(b[i-3:i+1], "''' not allowed in multiline literal string")
}
if i >= len(b) || b[i] != '\'' {
return b[:i], b[i:], nil return b[:i], b[i:], nil
} }
i++ case '\r':
if len(b) < i+2 {
if i >= len(b) || b[i] != '\'' { return nil, nil, newDecodeError(b[len(b):], `need a \n after \r`)
return b[:i], b[i:], nil
} }
i++ if b[i+1] != '\n' {
return nil, nil, newDecodeError(b[i:i+2], `need a \n after \r`)
if i < len(b) && b[i] == '\'' {
return nil, nil, newDecodeError(b[i-3:i+1], "''' not allowed in multiline literal string")
} }
i += 2 // skip the \n
return b[:i], b[i:], nil continue
} }
size := utf8ValidNext(b[i:]) size := utf8ValidNext(b[i:])
if size == 0 { if size == 0 {
@@ -149,6 +161,12 @@ func scanComment(b []byte) ([]byte, []byte, error) {
if b[i] == '\n' { if b[i] == '\n' {
return b[:i], b[i:], nil return b[:i], b[i:], nil
} }
if b[i] == '\r' {
if i+1 < len(b) && b[i+1] == '\n' {
return b[:i+1], b[i+1:], nil
}
return nil, nil, newDecodeError(b[i:i+1], "invalid character in comment")
}
size := utf8ValidNext(b[i:]) size := utf8ValidNext(b[i:])
if size == 0 { if size == 0 {
return nil, nil, newDecodeError(b[i:i+1], "invalid character in comment") return nil, nil, newDecodeError(b[i:i+1], "invalid character in comment")
@@ -160,42 +178,26 @@ func scanComment(b []byte) ([]byte, []byte, error) {
return b, b[len(b):], nil return b, b[len(b):], nil
} }
func scanBasicString(b []byte) ([]byte, int, []byte, error) { func scanBasicString(b []byte) ([]byte, bool, []byte, error) {
// basic-string = quotation-mark *basic-char quotation-mark // basic-string = quotation-mark *basic-char quotation-mark
// quotation-mark = %x22 ; " // quotation-mark = %x22 ; "
// basic-char = basic-unescaped / escaped // basic-char = basic-unescaped / escaped
// basic-unescaped = wschar / %x21 / %x23-5B / %x5D-7E / non-ascii // basic-unescaped = wschar / %x21 / %x23-5B / %x5D-7E / non-ascii
// escaped = escape escape-seq-char // escaped = escape escape-seq-char
escaped := -1 // index of the first \. -1 means no escape character in there. escaped := false
i := 1 i := 1
loop:
for ; i < len(b); i++ { for ; i < len(b); i++ {
switch b[i] { switch b[i] {
case '"': case '"':
return b[:i+1], escaped, b[i+1:], nil return b[:i+1], escaped, b[i+1:], nil
case '\n': case '\n', '\r':
return nil, escaped, nil, newDecodeError(b[i:i+1], "basic strings cannot have new lines")
case '\\':
if len(b) < i+2 {
return nil, escaped, nil, newDecodeError(b[i:i+1], "need a character after \\")
}
escaped = i
i += 2 // skip the next character
break loop
}
}
for ; i < len(b); i++ {
switch b[i] {
case '"':
return b[:i+1], escaped, b[i+1:], nil
case '\n':
return nil, escaped, nil, newDecodeError(b[i:i+1], "basic strings cannot have new lines") return nil, escaped, nil, newDecodeError(b[i:i+1], "basic strings cannot have new lines")
case '\\': case '\\':
if len(b) < i+2 { if len(b) < i+2 {
return nil, escaped, nil, newDecodeError(b[i:i+1], "need a character after \\") return nil, escaped, nil, newDecodeError(b[i:i+1], "need a character after \\")
} }
escaped = true
i++ // skip the next character i++ // skip the next character
} }
} }
@@ -203,7 +205,7 @@ loop:
return nil, escaped, nil, newDecodeError(b[len(b):], `basic string not terminated by "`) return nil, escaped, nil, newDecodeError(b[len(b):], `basic string not terminated by "`)
} }
func scanMultilineBasicString(b []byte) ([]byte, int, []byte, error) { func scanMultilineBasicString(b []byte) ([]byte, bool, []byte, error) {
// ml-basic-string = ml-basic-string-delim [ newline ] ml-basic-body // ml-basic-string = ml-basic-string-delim [ newline ] ml-basic-body
// ml-basic-string-delim // ml-basic-string-delim
// ml-basic-string-delim = 3quotation-mark // ml-basic-string-delim = 3quotation-mark
@@ -215,10 +217,9 @@ func scanMultilineBasicString(b []byte) ([]byte, int, []byte, error) {
// mlb-unescaped = wschar / %x21 / %x23-5B / %x5D-7E / non-ascii // mlb-unescaped = wschar / %x21 / %x23-5B / %x5D-7E / non-ascii
// mlb-escaped-nl = escape ws newline *( wschar / newline ) // mlb-escaped-nl = escape ws newline *( wschar / newline )
escaped := -1 escaped := false
i := 3 i := 3
loop:
for ; i < len(b); i++ { for ; i < len(b); i++ {
switch b[i] { switch b[i] {
case '"': case '"':
@@ -251,23 +252,16 @@ loop:
if len(b) < i+2 { if len(b) < i+2 {
return nil, escaped, nil, newDecodeError(b[len(b):], "need a character after \\") return nil, escaped, nil, newDecodeError(b[len(b):], "need a character after \\")
} }
escaped = i escaped = true
i += 2 // skip the next character
break loop
}
}
for ; i < len(b); i++ {
switch b[i] {
case '"':
if scanFollowsMultilineBasicStringDelimiter(b[i:]) {
return b[:i+3], escaped, b[i+3:], nil
}
case '\\':
if len(b) < i+2 {
return nil, escaped, nil, newDecodeError(b[len(b):], "need a character after \\")
}
i++ // skip the next character i++ // skip the next character
case '\r':
if len(b) < i+2 {
return nil, escaped, nil, newDecodeError(b[len(b):], `need a \n after \r`)
}
if b[i+1] != '\n' {
return nil, escaped, nil, newDecodeError(b[i:i+2], `need a \n after \r`)
}
i++ // skip the \n
} }
} }
+1 -1
View File
@@ -43,7 +43,7 @@ func DecodeStdin() error {
j := json.NewEncoder(os.Stdout) j := json.NewEncoder(os.Stdout)
j.SetIndent("", " ") j.SetIndent("", " ")
if err := j.Encode(addTag("", decoded)); err != nil { if err := j.Encode(addTag("", decoded)); err != nil {
fmt.Errorf("Error encoding JSON: %s", err) return fmt.Errorf("Error encoding JSON: %s", err)
} }
return nil return nil
+11 -1
View File
@@ -1,4 +1,4 @@
// Generated by tomltestgen for toml-test ref master on 2021-09-30T20:29:36-05:00 // Generated by tomltestgen for toml-test ref master on 2021-11-08T22:33:24-05:00
package toml_test package toml_test
import ( import (
@@ -375,6 +375,11 @@ func TestTOMLTest_Invalid_Float_TrailingPoint(t *testing.T) {
testgenInvalid(t, input) testgenInvalid(t, input)
} }
func TestTOMLTest_Invalid_Float_TrailingUsExp(t *testing.T) {
input := "# trailing underscore in integer part is not allowed\ntrailing-us-exp = 1_e2\n# trailing underscore in float part is not allowed\ntrailing-us-exp2 = 1.2_e2\n"
testgenInvalid(t, input)
}
func TestTOMLTest_Invalid_Float_TrailingUs(t *testing.T) { func TestTOMLTest_Invalid_Float_TrailingUs(t *testing.T) {
input := "trailing-us = 1.2_\n" input := "trailing-us = 1.2_\n"
testgenInvalid(t, input) testgenInvalid(t, input)
@@ -395,6 +400,11 @@ func TestTOMLTest_Invalid_InlineTable_DoubleComma(t *testing.T) {
testgenInvalid(t, input) testgenInvalid(t, input)
} }
func TestTOMLTest_Invalid_InlineTable_DuplicateKey(t *testing.T) {
input := "# Duplicate keys within an inline table are invalid\na={b=1, b=2}\n"
testgenInvalid(t, input)
}
func TestTOMLTest_Invalid_InlineTable_Empty(t *testing.T) { func TestTOMLTest_Invalid_InlineTable_Empty(t *testing.T) {
input := "t = {,}\n" input := "t = {,}\n"
testgenInvalid(t, input) testgenInvalid(t, input)
+1
View File
@@ -11,3 +11,4 @@ var textMarshalerType = reflect.TypeOf(new(encoding.TextMarshaler)).Elem()
var textUnmarshalerType = reflect.TypeOf(new(encoding.TextUnmarshaler)).Elem() var textUnmarshalerType = reflect.TypeOf(new(encoding.TextUnmarshaler)).Elem()
var mapStringInterfaceType = reflect.TypeOf(map[string]interface{}{}) var mapStringInterfaceType = reflect.TypeOf(map[string]interface{}{})
var sliceInterfaceType = reflect.TypeOf([]interface{}{}) var sliceInterfaceType = reflect.TypeOf([]interface{}{})
var stringType = reflect.TypeOf("")
+128 -85
View File
@@ -9,10 +9,11 @@ import (
"math" "math"
"reflect" "reflect"
"strings" "strings"
"sync" "sync/atomic"
"time" "time"
"github.com/pelletier/go-toml/v2/internal/ast" "github.com/pelletier/go-toml/v2/internal/ast"
"github.com/pelletier/go-toml/v2/internal/danger"
"github.com/pelletier/go-toml/v2/internal/tracker" "github.com/pelletier/go-toml/v2/internal/tracker"
) )
@@ -47,8 +48,9 @@ func NewDecoder(r io.Reader) *Decoder {
// that could not be set on the target value. In that case, the decoder returns // that could not be set on the target value. In that case, the decoder returns
// a StrictMissingError that can be used to retrieve the individual errors as // a StrictMissingError that can be used to retrieve the individual errors as
// well as generate a human readable description of the missing fields. // well as generate a human readable description of the missing fields.
func (d *Decoder) SetStrict(strict bool) { func (d *Decoder) SetStrict(strict bool) *Decoder {
d.strict = strict d.strict = strict
return d
} }
// Decode the whole content of r into v. // Decode the whole content of r into v.
@@ -58,7 +60,8 @@ func (d *Decoder) SetStrict(strict bool) {
// //
// When a TOML local date, time, or date-time is decoded into a time.Time, its // When a TOML local date, time, or date-time is decoded into a time.Time, its
// value is represented in time.Local timezone. Otherwise the approriate Local* // value is represented in time.Local timezone. Otherwise the approriate Local*
// structure is used. // structure is used. For time values, precision up to the nanosecond is
// supported by truncating extra digits.
// //
// Empty tables decoded in an interface{} create an empty initialized // Empty tables decoded in an interface{} create an empty initialized
// map[string]interface{}. // map[string]interface{}.
@@ -70,6 +73,11 @@ func (d *Decoder) SetStrict(strict bool) {
// bounds for the target type (which includes negative numbers when decoding // bounds for the target type (which includes negative numbers when decoding
// into an unsigned int). // into an unsigned int).
// //
// If an error occurs while decoding the content of the document, this function
// returns a toml.DecodeError, providing context about the issue. When using
// strict mode and a field is missing, a `toml.StrictMissingError` is
// returned. In any other case, this function returns a standard Go error.
//
// Type mapping // Type mapping
// //
// List of supported TOML types and their associated accepted Go types: // List of supported TOML types and their associated accepted Go types:
@@ -129,6 +137,23 @@ type decoder struct {
// Strict mode // Strict mode
strict strict strict strict
// Current context for the error.
errorContext *errorContext
}
type errorContext struct {
Struct reflect.Type
Field []int
}
func (d *decoder) typeMismatchError(toml string, target reflect.Type) error {
if d.errorContext != nil && d.errorContext.Struct != nil {
ctx := d.errorContext
f := ctx.Struct.FieldByIndex(ctx.Field)
return fmt.Errorf("toml: cannot decode TOML %s into struct field %s.%s of type %s", toml, ctx.Struct, f.Name, f.Type)
}
return fmt.Errorf("toml: cannot decode TOML %s into a Go value of type %s", toml, target)
} }
func (d *decoder) expr() *ast.Node { func (d *decoder) expr() *ast.Node {
@@ -343,7 +368,9 @@ func (d *decoder) handleArrayTableCollection(key ast.Iterator, v reflect.Value)
if err != nil { if err != nil {
return reflect.Value{}, err return reflect.Value{}, err
} }
v.Elem().Set(elem) if elem.IsValid() {
v.Elem().Set(elem)
}
return v, nil return v, nil
case reflect.Slice: case reflect.Slice:
@@ -384,12 +411,14 @@ func (d *decoder) handleKeyPart(key ast.Iterator, v reflect.Value, nextFn handle
elem = v.Elem() elem = v.Elem()
return d.handleKeyPart(key, elem, nextFn, makeFn) return d.handleKeyPart(key, elem, nextFn, makeFn)
case reflect.Map: case reflect.Map:
// Create the key for the map element. For now assume it's a string. // Create the key for the map element. For now assume it's a string.
mk := reflect.ValueOf(string(key.Node().Data)) mk := reflect.ValueOf(string(key.Node().Data))
// If the map does not exist, create it. // If the map does not exist, create it.
if v.IsNil() { if v.IsNil() {
v = reflect.MakeMap(v.Type()) vt := v.Type()
v = reflect.MakeMap(vt)
rv = v rv = v
} }
@@ -401,7 +430,8 @@ func (d *decoder) handleKeyPart(key ast.Iterator, v reflect.Value, nextFn handle
// map[string]interface{} or a []interface{} depending on whether // map[string]interface{} or a []interface{} depending on whether
// this is the last part of the array table key. // this is the last part of the array table key.
t := v.Type().Elem() vt := v.Type()
t := vt.Elem()
if t.Kind() == reflect.Interface { if t.Kind() == reflect.Interface {
mv = makeFn() mv = makeFn()
} else { } else {
@@ -415,7 +445,8 @@ func (d *decoder) handleKeyPart(key ast.Iterator, v reflect.Value, nextFn handle
} }
set = true set = true
} else if !mv.CanAddr() { } else if !mv.CanAddr() {
t := v.Type().Elem() vt := v.Type()
t := vt.Elem()
oldmv := mv oldmv := mv
mv = reflect.New(t).Elem() mv = reflect.New(t).Elem()
mv.Set(oldmv) mv.Set(oldmv)
@@ -436,12 +467,20 @@ func (d *decoder) handleKeyPart(key ast.Iterator, v reflect.Value, nextFn handle
v.SetMapIndex(mk, mv) v.SetMapIndex(mk, mv)
} }
case reflect.Struct: case reflect.Struct:
f, found := structField(v, string(key.Node().Data)) path, found := structFieldPath(v, string(key.Node().Data))
if !found { if !found {
d.skipUntilTable = true d.skipUntilTable = true
return reflect.Value{}, nil return reflect.Value{}, nil
} }
if d.errorContext == nil {
d.errorContext = new(errorContext)
}
t := v.Type()
d.errorContext.Struct = t
d.errorContext.Field = path
f := v.FieldByIndex(path)
x, err := nextFn(key, f) x, err := nextFn(key, f)
if err != nil || d.skipUntilTable { if err != nil || d.skipUntilTable {
return reflect.Value{}, err return reflect.Value{}, err
@@ -449,11 +488,13 @@ func (d *decoder) handleKeyPart(key ast.Iterator, v reflect.Value, nextFn handle
if x.IsValid() { if x.IsValid() {
f.Set(x) f.Set(x)
} }
d.errorContext.Field = nil
d.errorContext.Struct = nil
case reflect.Interface: case reflect.Interface:
if v.Elem().IsValid() { if v.Elem().IsValid() {
v = v.Elem() v = v.Elem()
} else { } else {
v = reflect.MakeMap(mapStringInterfaceType) v = makeMapStringInterface()
} }
x, err := d.handleKeyPart(key, v, nextFn, makeFn) x, err := d.handleKeyPart(key, v, nextFn, makeFn)
@@ -649,7 +690,7 @@ func (d *decoder) unmarshalArray(array *ast.Node, v reflect.Value) error {
default: default:
// TODO: use newDecodeError, but first the parser needs to fill // TODO: use newDecodeError, but first the parser needs to fill
// array.Data. // array.Data.
return fmt.Errorf("toml: cannot store array in Go type %s", v.Kind()) return d.typeMismatchError("array", v.Type())
} }
elemType := v.Type().Elem() elemType := v.Type().Elem()
@@ -697,7 +738,7 @@ func (d *decoder) unmarshalInlineTable(itable *ast.Node, v reflect.Value) error
case reflect.Interface: case reflect.Interface:
elem := v.Elem() elem := v.Elem()
if !elem.IsValid() { if !elem.IsValid() {
elem = reflect.MakeMap(mapStringInterfaceType) elem = makeMapStringInterface()
v.Set(elem) v.Set(elem)
} }
return d.unmarshalInlineTable(itable, elem) return d.unmarshalInlineTable(itable, elem)
@@ -896,7 +937,7 @@ func (d *decoder) unmarshalInteger(value *ast.Node, v reflect.Value) error {
case reflect.Interface: case reflect.Interface:
r = reflect.ValueOf(i) r = reflect.ValueOf(i)
default: default:
return fmt.Errorf("toml: cannot store TOML integer into a Go %s", v.Kind()) return d.typeMismatchError("integer", v.Type())
} }
if !r.Type().AssignableTo(v.Type()) { if !r.Type().AssignableTo(v.Type()) {
@@ -953,12 +994,15 @@ func (d *decoder) handleKeyValuePart(key ast.Iterator, value *ast.Node, v reflec
// There is no guarantee over what it could be. // There is no guarantee over what it could be.
switch v.Kind() { switch v.Kind() {
case reflect.Map: case reflect.Map:
mk := reflect.ValueOf(string(key.Node().Data)) vt := v.Type()
keyType := v.Type().Key() mk := reflect.ValueOf(string(key.Node().Data))
if !mk.Type().AssignableTo(keyType) { mkt := stringType
if !mk.Type().ConvertibleTo(keyType) {
return reflect.Value{}, fmt.Errorf("toml: cannot convert map key of type %s to expected type %s", mk.Type(), keyType) keyType := vt.Key()
if !mkt.AssignableTo(keyType) {
if !mkt.ConvertibleTo(keyType) {
return reflect.Value{}, fmt.Errorf("toml: cannot convert map key of type %s to expected type %s", mkt, keyType)
} }
mk = mk.Convert(keyType) mk = mk.Convert(keyType)
@@ -966,7 +1010,7 @@ func (d *decoder) handleKeyValuePart(key ast.Iterator, value *ast.Node, v reflec
// If the map does not exist, create it. // If the map does not exist, create it.
if v.IsNil() { if v.IsNil() {
v = reflect.MakeMap(v.Type()) v = reflect.MakeMap(vt)
rv = v rv = v
} }
@@ -996,12 +1040,20 @@ func (d *decoder) handleKeyValuePart(key ast.Iterator, value *ast.Node, v reflec
v.SetMapIndex(mk, mv) v.SetMapIndex(mk, mv)
} }
case reflect.Struct: case reflect.Struct:
f, found := structField(v, string(key.Node().Data)) path, found := structFieldPath(v, string(key.Node().Data))
if !found { if !found {
d.skipUntilTable = true d.skipUntilTable = true
break break
} }
if d.errorContext == nil {
d.errorContext = new(errorContext)
}
t := v.Type()
d.errorContext.Struct = t
d.errorContext.Field = path
f := v.FieldByIndex(path)
x, err := d.handleKeyValueInner(key, value, f) x, err := d.handleKeyValueInner(key, value, f)
if err != nil { if err != nil {
return reflect.Value{}, err return reflect.Value{}, err
@@ -1010,14 +1062,17 @@ func (d *decoder) handleKeyValuePart(key ast.Iterator, value *ast.Node, v reflec
if x.IsValid() { if x.IsValid() {
f.Set(x) f.Set(x)
} }
d.errorContext.Struct = nil
d.errorContext.Field = nil
case reflect.Interface: case reflect.Interface:
v = v.Elem() v = v.Elem()
// Following encoding/toml: decoding an object into an interface{}, it // Following encoding/json: decoding an object into an
// needs to always hold a map[string]interface{}. This is for the types // interface{}, it needs to always hold a
// to be consistent whether a previous value was set or not. // map[string]interface{}. This is for the types to be
// consistent whether a previous value was set or not.
if !v.IsValid() || v.Type() != mapStringInterfaceType { if !v.IsValid() || v.Type() != mapStringInterfaceType {
v = reflect.MakeMap(mapStringInterfaceType) v = makeMapStringInterface()
} }
x, err := d.handleKeyValuePart(key, value, v) x, err := d.handleKeyValuePart(key, value, v)
@@ -1064,80 +1119,68 @@ func initAndDereferencePointer(v reflect.Value) reflect.Value {
type fieldPathsMap = map[string][]int type fieldPathsMap = map[string][]int
type fieldPathsCache struct { var globalFieldPathsCache atomic.Value // map[danger.TypeID]fieldPathsMap
m map[reflect.Type]fieldPathsMap
l sync.RWMutex
}
func (c *fieldPathsCache) get(t reflect.Type) (fieldPathsMap, bool) { func structFieldPath(v reflect.Value, name string) ([]int, bool) {
c.l.RLock() t := v.Type()
paths, ok := c.m[t]
c.l.RUnlock()
return paths, ok cache, _ := globalFieldPathsCache.Load().(map[danger.TypeID]fieldPathsMap)
} fieldPaths, ok := cache[danger.MakeTypeID(t)]
func (c *fieldPathsCache) set(t reflect.Type, m fieldPathsMap) {
c.l.Lock()
c.m[t] = m
c.l.Unlock()
}
var globalFieldPathsCache = fieldPathsCache{
m: map[reflect.Type]fieldPathsMap{},
l: sync.RWMutex{},
}
func structField(v reflect.Value, name string) (reflect.Value, bool) {
//nolint:godox
// TODO: cache this, and reduce allocations
fieldPaths, ok := globalFieldPathsCache.get(v.Type())
if !ok { if !ok {
fieldPaths = map[string][]int{} fieldPaths = map[string][]int{}
path := make([]int, 0, 16) forEachField(t, nil, func(name string, path []int) {
fieldPaths[name] = path
// extra copy for the case-insensitive match
fieldPaths[strings.ToLower(name)] = path
})
var walk func(reflect.Value) newCache := make(map[danger.TypeID]fieldPathsMap, len(cache)+1)
walk = func(v reflect.Value) { newCache[danger.MakeTypeID(t)] = fieldPaths
t := v.Type() for k, v := range cache {
for i := 0; i < t.NumField(); i++ { newCache[k] = v
l := len(path)
path = append(path, i)
f := t.Field(i)
if f.Anonymous {
walk(v.Field(i))
} else if f.PkgPath == "" {
// only consider exported fields
fieldName, ok := f.Tag.Lookup("toml")
if !ok {
fieldName = f.Name
}
pathCopy := make([]int, len(path))
copy(pathCopy, path)
fieldPaths[fieldName] = pathCopy
// extra copy for the case-insensitive match
fieldPaths[strings.ToLower(fieldName)] = pathCopy
}
path = path[:l]
}
} }
globalFieldPathsCache.Store(newCache)
walk(v)
globalFieldPathsCache.set(v.Type(), fieldPaths)
} }
path, ok := fieldPaths[name] path, ok := fieldPaths[name]
if !ok { if !ok {
path, ok = fieldPaths[strings.ToLower(name)] path, ok = fieldPaths[strings.ToLower(name)]
} }
return path, ok
if !ok { }
return reflect.Value{}, false
} func forEachField(t reflect.Type, path []int, do func(name string, path []int)) {
n := t.NumField()
return v.FieldByIndex(path), true for i := 0; i < n; i++ {
f := t.Field(i)
if !f.Anonymous && f.PkgPath != "" {
// only consider exported fields.
continue
}
fieldPath := append(path, i)
fieldPath = fieldPath[:len(fieldPath):len(fieldPath)]
if f.Anonymous {
forEachField(f.Type, fieldPath, do)
continue
}
name := f.Tag.Get("toml")
if name == "-" {
continue
}
if i := strings.IndexByte(name, ','); i >= 0 {
name = name[:i]
}
if name == "" {
name = f.Name
}
do(name, fieldPath)
}
} }
+723 -347
View File
File diff suppressed because it is too large Load Diff
+38 -1
View File
@@ -140,8 +140,45 @@ func utf8ValidNext(p []byte) int {
return size return size
} }
var invalidAsciiTable = [256]bool{
0x00: true,
0x01: true,
0x02: true,
0x03: true,
0x04: true,
0x05: true,
0x06: true,
0x07: true,
0x08: true,
// 0x09 TAB
// 0x0A LF
0x0B: true,
0x0C: true,
// 0x0D CR
0x0E: true,
0x0F: true,
0x10: true,
0x11: true,
0x12: true,
0x13: true,
0x14: true,
0x15: true,
0x16: true,
0x17: true,
0x18: true,
0x19: true,
0x1A: true,
0x1B: true,
0x1C: true,
0x1D: true,
0x1E: true,
0x1F: true,
// 0x20 - 0x7E Printable ASCII characters
0x7F: true,
}
func invalidAscii(b byte) bool { func invalidAscii(b byte) bool {
return b <= 0x08 || (b > 0x0A && b < 0x0D) || (b > 0x0D && b <= 0x1F) || b == 0x7F return invalidAsciiTable[b]
} }
// acceptRange gives the range of valid values for the second byte in a UTF-8 // acceptRange gives the range of valid values for the second byte in a UTF-8