Compare commits

...

9 Commits

Author SHA1 Message Date
Thomas Pelletier 9b67e40640 decoder: strict mode (#512) 2021-04-20 21:26:22 -04:00
Vincent Serpoul dca2103910 golangci-lint: marshaler (#516) 2021-04-20 20:24:44 -04:00
Cameron Moore a713a96e69 Add more newline tests for scanner (#515) 2021-04-16 19:07:29 -04:00
Cameron Moore a7b50eb8f1 Tidy (#511)
* Disconnect package godoc comment from imported file

* Add missing newline in toml.abnf

* Tag testing helper funcs
2021-04-15 16:49:19 -04:00
Cameron Moore 24b62ebe61 Simplify scanFollows usage (#510)
Use static functions to avoid declaring global vars and creating more
package init costs.  This change has no negative effects on benchmarks
in my testing.
2021-04-15 16:48:19 -04:00
Thomas Pelletier 9bc4641a49 ci-lint: disable ifshort 2021-04-15 13:37:24 -04:00
Thomas Pelletier b86b890b8d decoder: handle private anonymous structs
Ref #508
2021-04-15 12:49:24 -04:00
Vincent Serpoul 080baa8574 golangci-lint: localtime (#509) 2021-04-15 12:44:31 -04:00
Thomas Pelletier 0537b928df decoder: add test for #507 2021-04-15 11:36:36 -04:00
20 changed files with 1087 additions and 408 deletions
+4 -1
View File
@@ -4,6 +4,9 @@ golangci-lint-version = "1.39.0"
[linters-settings.wsl]
allow-assign-and-anything = true
[linters-settings.exhaustive]
default-signifies-exhaustive = true
[linters]
disable-all = true
enable = [
@@ -45,7 +48,7 @@ enable = [
"gosec",
"gosimple",
"govet",
"ifshort",
# "ifshort",
"importas",
"ineffassign",
"lll",
+1 -1
View File
@@ -22,7 +22,7 @@ Development branch. Use at your own risk.
- [x] Abstract AST.
- [x] Original go-toml testgen tests pass.
- [x] Track file position (line, column) for errors.
- [ ] Strict mode.
- [x] Strict mode.
- [ ] Document Unmarshal / Decode
### Marshal
+38 -1
View File
@@ -18,15 +18,46 @@ type DecodeError struct {
message string
line int
column int
key Key
human string
}
// StrictMissingError occurs in a TOML document that does not have a
// corresponding field in the target value. It contains all the missing fields
// in Errors.
//
// Emitted by Decoder when SetStrict(true) was called.
type StrictMissingError struct {
// One error per field that could not be found.
Errors []DecodeError
}
// Error returns the cannonical string for this error.
func (s *StrictMissingError) Error() string {
return "strict mode: fields in the document are missing in the target struct"
}
// String returns a human readable description of all errors.
func (s *StrictMissingError) String() string {
var buf strings.Builder
for i, e := range s.Errors {
if i > 0 {
buf.WriteString("\n---\n")
}
buf.WriteString(e.String())
}
return buf.String()
}
type Key []string
// internal version of DecodeError that is used as the base to create a
// DecodeError with full context.
type decodeError struct {
highlight []byte
message string
key Key // optional
}
func (de *decodeError) Error() string {
@@ -56,6 +87,11 @@ func (e *DecodeError) Position() (row int, column int) {
return e.line, e.column
}
// Key that was being processed when the error occured.
func (e *DecodeError) Key() Key {
return e.key
}
// decodeErrorFromHighlight creates a DecodeError referencing to a highlighted
// range of bytes from document.
//
@@ -64,7 +100,7 @@ func (e *DecodeError) Position() (row int, column int) {
// The function copies all bytes used in DecodeError, so that document and
// highlight can be freely deallocated.
//nolint:funlen
func wrapDecodeError(document []byte, de *decodeError) error {
func wrapDecodeError(document []byte, de *decodeError) *DecodeError {
if de == nil {
return nil
}
@@ -137,6 +173,7 @@ func wrapDecodeError(document []byte, de *decodeError) error {
message: errMessage,
line: errLine,
column: errColumn,
key: de.key,
human: buf.String(),
}
}
@@ -7,6 +7,7 @@ package imported_tests
// marked as skipped until we figure out if that's something we want in v2.
import (
"bytes"
"errors"
"fmt"
"reflect"
@@ -1955,66 +1956,80 @@ String2="2"`
assert.Error(t, err)
}
func decoder(doc string) *toml.Decoder {
return toml.NewDecoder(bytes.NewReader([]byte(doc)))
}
func strictDecoder(doc string) *toml.Decoder {
d := decoder(doc)
d.SetStrict(true)
return d
}
func TestDecoderStrict(t *testing.T) {
t.Skip()
// input := `
//[decoded]
// key = ""
//
//[undecoded]
// key = ""
//
// [undecoded.inner]
// key = ""
//
// [[undecoded.array]]
// key = ""
//
// [[undecoded.array]]
// key = ""
//
//`
// var doc struct {
// Decoded struct {
// Key string
// }
// }
//
// expected := `undecoded keys: ["undecoded.array.0.key" "undecoded.array.1.key" "undecoded.inner.key" "undecoded.key"]`
//
// err := NewDecoder(bytes.NewReader([]byte(input))).Strict(true).Decode(&doc)
// if err == nil {
// t.Error("expected error, got none")
// } else if err.Error() != expected {
// t.Errorf("expect err: %s, got: %s", expected, err.Error())
// }
//
// if err := NewDecoder(bytes.NewReader([]byte(input))).Decode(&doc); err != nil {
// t.Errorf("unexpected err: %s", err)
// }
//
// var m map[string]interface{}
// if err := NewDecoder(bytes.NewReader([]byte(input))).Decode(&m); err != nil {
// t.Errorf("unexpected err: %s", err)
// }
input := `
[decoded]
key = ""
[undecoded]
key = ""
[undecoded.inner]
key = ""
[[undecoded.array]]
key = ""
[[undecoded.array]]
key = ""
`
var doc struct {
Decoded struct {
Key string
}
}
err := strictDecoder(input).Decode(&doc)
require.Error(t, err)
require.IsType(t, &toml.StrictMissingError{}, err)
se := err.(*toml.StrictMissingError)
keys := []toml.Key{}
for _, e := range se.Errors {
keys = append(keys, e.Key())
}
expectedKeys := []toml.Key{
{"undecoded"},
{"undecoded", "inner"},
{"undecoded", "array"},
{"undecoded", "array"},
}
require.Equal(t, expectedKeys, keys)
err = decoder(input).Decode(&doc)
require.NoError(t, err)
var m map[string]interface{}
err = decoder(input).Decode(&m)
}
func TestDecoderStrictValid(t *testing.T) {
t.Skip()
// input := `
//[decoded]
// key = ""
//`
// var doc struct {
// Decoded struct {
// Key string
// }
// }
//
// err := NewDecoder(bytes.NewReader([]byte(input))).Strict(true).Decode(&doc)
// if err != nil {
// t.Fatal("unexpected error:", err)
// }
input := `
[decoded]
key = ""
`
var doc struct {
Decoded struct {
Key string
}
}
err := strictDecoder(input).Decode(&doc)
require.NoError(t, err)
}
type docUnmarshalTOML struct {
+50
View File
@@ -0,0 +1,50 @@
package tracker
import (
"github.com/pelletier/go-toml/v2/internal/ast"
)
// KeyTracker is a tracker that keeps track of the current Key as the AST is
// walked.
type KeyTracker struct {
k []string
}
// UpdateTable sets the state of the tracker with the AST table node.
func (t *KeyTracker) UpdateTable(node ast.Node) {
t.reset()
t.Push(node)
}
// UpdateArrayTable sets the state of the tracker with the AST array table node.
func (t *KeyTracker) UpdateArrayTable(node ast.Node) {
t.reset()
t.Push(node)
}
// Push the given key on the stack.
func (t *KeyTracker) Push(node ast.Node) {
it := node.Key()
for it.Next() {
t.k = append(t.k, string(it.Node().Data))
}
}
// Pop key from stack.
func (t *KeyTracker) Pop(node ast.Node) {
it := node.Key()
for it.Next() {
t.k = t.k[:len(t.k)-1]
}
}
// Key returns the current key
func (t *KeyTracker) Key() []string {
k := make([]string, len(t.k))
copy(k, t.k)
return k
}
func (t *KeyTracker) reset() {
t.k = t.k[:0]
}
+200
View File
@@ -0,0 +1,200 @@
package tracker
import (
"fmt"
"github.com/pelletier/go-toml/v2/internal/ast"
)
type keyKind uint8
const (
invalidKind keyKind = iota
valueKind
tableKind
arrayTableKind
)
func (k keyKind) String() string {
switch k {
case invalidKind:
return "invalid"
case valueKind:
return "value"
case tableKind:
return "table"
case arrayTableKind:
return "array table"
}
panic("missing keyKind string mapping")
}
// SeenTracker tracks which keys have been seen with which TOML type to flag duplicates
// and mismatches according to the spec.
type SeenTracker struct {
root *info
current *info
}
type info struct {
parent *info
kind keyKind
children map[string]*info
explicit bool
}
func (i *info) Clear() {
i.children = nil
}
func (i *info) Has(k string) (*info, bool) {
c, ok := i.children[k]
return c, ok
}
func (i *info) SetKind(kind keyKind) {
i.kind = kind
}
func (i *info) CreateTable(k string, explicit bool) *info {
return i.createChild(k, tableKind, explicit)
}
func (i *info) CreateArrayTable(k string, explicit bool) *info {
return i.createChild(k, arrayTableKind, explicit)
}
func (i *info) createChild(k string, kind keyKind, explicit bool) *info {
if i.children == nil {
i.children = make(map[string]*info, 1)
}
x := &info{
parent: i,
kind: kind,
explicit: explicit,
}
i.children[k] = x
return x
}
// CheckExpression takes a top-level node and checks that it does not contain keys
// that have been seen in previous calls, and validates that types are consistent.
func (s *SeenTracker) CheckExpression(node ast.Node) error {
if s.root == nil {
s.root = &info{
kind: tableKind,
}
s.current = s.root
}
switch node.Kind {
case ast.KeyValue:
return s.checkKeyValue(s.current, node)
case ast.Table:
return s.checkTable(node)
case ast.ArrayTable:
return s.checkArrayTable(node)
default:
panic(fmt.Errorf("this should not be a top level node type: %s", node.Kind))
}
}
func (s *SeenTracker) checkTable(node ast.Node) error {
s.current = s.root
it := node.Key()
// handle the first parts of the key, excluding the last one
for it.Next() {
if !it.Node().Next().Valid() {
break
}
k := string(it.Node().Data)
child, found := s.current.Has(k)
if !found {
child = s.current.CreateTable(k, false)
}
s.current = child
}
// handle the last part of the key
k := string(it.Node().Data)
i, found := s.current.Has(k)
if found {
if i.kind != tableKind {
return fmt.Errorf("key %s should be a table", k)
}
if i.explicit {
return fmt.Errorf("table %s already exists", k)
}
i.explicit = true
s.current = i
} else {
s.current = s.current.CreateTable(k, true)
}
return nil
}
func (s *SeenTracker) checkArrayTable(node ast.Node) error {
s.current = s.root
it := node.Key()
// handle the first parts of the key, excluding the last one
for it.Next() {
if !it.Node().Next().Valid() {
break
}
k := string(it.Node().Data)
child, found := s.current.Has(k)
if !found {
child = s.current.CreateTable(k, false)
}
s.current = child
}
// handle the last part of the key
k := string(it.Node().Data)
info, found := s.current.Has(k)
if found {
if info.kind != arrayTableKind {
return fmt.Errorf("key %s already exists but is not an array table", k)
}
info.Clear()
} else {
info = s.current.CreateArrayTable(k, true)
}
s.current = info
return nil
}
func (s *SeenTracker) checkKeyValue(context *info, node ast.Node) error {
it := node.Key()
// handle the first parts of the key, excluding the last one
for it.Next() {
k := string(it.Node().Data)
child, found := context.Has(k)
if found {
if child.kind != tableKind {
return fmt.Errorf("expected %s to be a table, not a %s", k, child.kind)
}
} else {
child = context.CreateTable(k, false)
}
context = child
}
if node.Value().Kind == ast.InlineTable {
context.SetKind(tableKind)
} else {
context.SetKind(valueKind)
}
return nil
}
-199
View File
@@ -1,200 +1 @@
package tracker
import (
"fmt"
"github.com/pelletier/go-toml/v2/internal/ast"
)
type keyKind uint8
const (
invalidKind keyKind = iota
valueKind
tableKind
arrayTableKind
)
func (k keyKind) String() string {
switch k {
case invalidKind:
return "invalid"
case valueKind:
return "value"
case tableKind:
return "table"
case arrayTableKind:
return "array table"
}
panic("missing keyKind string mapping")
}
// Tracks which keys have been seen with which TOML type to flag duplicates
// and mismatches according to the spec.
type Seen struct {
root *info
current *info
}
type info struct {
parent *info
kind keyKind
children map[string]*info
explicit bool
}
func (i *info) Clear() {
i.children = nil
}
func (i *info) Has(k string) (*info, bool) {
c, ok := i.children[k]
return c, ok
}
func (i *info) SetKind(kind keyKind) {
i.kind = kind
}
func (i *info) CreateTable(k string, explicit bool) *info {
return i.createChild(k, tableKind, explicit)
}
func (i *info) CreateArrayTable(k string, explicit bool) *info {
return i.createChild(k, arrayTableKind, explicit)
}
func (i *info) createChild(k string, kind keyKind, explicit bool) *info {
if i.children == nil {
i.children = make(map[string]*info, 1)
}
x := &info{
parent: i,
kind: kind,
explicit: explicit,
}
i.children[k] = x
return x
}
// CheckExpression takes a top-level node and checks that it does not contain keys
// that have been seen in previous calls, and validates that types are consistent.
func (s *Seen) CheckExpression(node ast.Node) error {
if s.root == nil {
s.root = &info{
kind: tableKind,
}
s.current = s.root
}
switch node.Kind {
case ast.KeyValue:
return s.checkKeyValue(s.current, node)
case ast.Table:
return s.checkTable(node)
case ast.ArrayTable:
return s.checkArrayTable(node)
default:
panic(fmt.Errorf("this should not be a top level node type: %s", node.Kind))
}
}
func (s *Seen) checkTable(node ast.Node) error {
s.current = s.root
it := node.Key()
// handle the first parts of the key, excluding the last one
for it.Next() {
if !it.Node().Next().Valid() {
break
}
k := string(it.Node().Data)
child, found := s.current.Has(k)
if !found {
child = s.current.CreateTable(k, false)
}
s.current = child
}
// handle the last part of the key
k := string(it.Node().Data)
i, found := s.current.Has(k)
if found {
if i.kind != tableKind {
return fmt.Errorf("key %s should be a table", k)
}
if i.explicit {
return fmt.Errorf("table %s already exists", k)
}
i.explicit = true
s.current = i
} else {
s.current = s.current.CreateTable(k, true)
}
return nil
}
func (s *Seen) checkArrayTable(node ast.Node) error {
s.current = s.root
it := node.Key()
// handle the first parts of the key, excluding the last one
for it.Next() {
if !it.Node().Next().Valid() {
break
}
k := string(it.Node().Data)
child, found := s.current.Has(k)
if !found {
child = s.current.CreateTable(k, false)
}
s.current = child
}
// handle the last part of the key
k := string(it.Node().Data)
info, found := s.current.Has(k)
if found {
if info.kind != arrayTableKind {
return fmt.Errorf("key %s already exists but is not an array table", k)
}
info.Clear()
} else {
info = s.current.CreateArrayTable(k, true)
}
s.current = info
return nil
}
func (s *Seen) checkKeyValue(context *info, node ast.Node) error {
it := node.Key()
// handle the first parts of the key, excluding the last one
for it.Next() {
k := string(it.Node().Data)
child, found := context.Has(k)
if found {
if child.kind != tableKind {
return fmt.Errorf("expected %s to be a table, not a %s", k, child.kind)
}
} else {
child = context.CreateTable(k, false)
}
context = child
}
if node.Value().Kind == ast.InlineTable {
context.SetKind(tableKind)
} else {
context.SetKind(valueKind)
}
return nil
}
+24
View File
@@ -33,3 +33,27 @@ func SubsliceOffset(data []byte, subslice []byte) int {
return intoffset
}
func BytesRange(start []byte, end []byte) []byte {
if start == nil || end == nil {
panic("cannot call BytesRange with nil")
}
startp := (*reflect.SliceHeader)(unsafe.Pointer(&start))
endp := (*reflect.SliceHeader)(unsafe.Pointer(&end))
if startp.Data > endp.Data {
panic(fmt.Errorf("start pointer address (%d) is after end pointer address (%d)", startp.Data, endp.Data))
}
l := startp.Len
endLen := int(endp.Data-startp.Data) + endp.Len
if endLen > l {
l = endLen
}
if l > startp.Cap {
panic(fmt.Errorf("range length is larger than capacity"))
}
return start[:l]
}
+89
View File
@@ -77,3 +77,92 @@ func TestUnsafeSubsliceOffsetInvalid(t *testing.T) {
})
}
}
func TestUnsafeBytesRange(t *testing.T) {
type fn = func() ([]byte, []byte)
examples := []struct {
desc string
test fn
expected []byte
}{
{
desc: "simple",
test: func() ([]byte, []byte) {
full := []byte("hello world")
return full[1:3], full[6:8]
},
expected: []byte("ello wo"),
},
{
desc: "full",
test: func() ([]byte, []byte) {
full := []byte("hello world")
return full[0:1], full[len(full)-1:]
},
expected: []byte("hello world"),
},
{
desc: "end before start",
test: func() ([]byte, []byte) {
full := []byte("hello world")
return full[len(full)-1:], full[0:1]
},
},
{
desc: "nils",
test: func() ([]byte, []byte) {
return nil, nil
},
},
{
desc: "nils start",
test: func() ([]byte, []byte) {
return nil, []byte("foo")
},
},
{
desc: "nils end",
test: func() ([]byte, []byte) {
return []byte("foo"), nil
},
},
{
desc: "start is end",
test: func() ([]byte, []byte) {
full := []byte("hello world")
return full[1:3], full[1:3]
},
expected: []byte("el"),
},
{
desc: "end contained in start",
test: func() ([]byte, []byte) {
full := []byte("hello world")
return full[1:7], full[2:4]
},
expected: []byte("ello w"),
},
{
desc: "different backing arrays",
test: func() ([]byte, []byte) {
one := []byte("hello world")
two := []byte("hello world")
return one, two
},
},
}
for _, e := range examples {
t.Run(e.desc, func(t *testing.T) {
start, end := e.test()
if e.expected == nil {
require.Panics(t, func() {
unsafe.BytesRange(start, end)
})
} else {
res := unsafe.BytesRange(start, end)
require.Equal(t, e.expected, res)
}
})
}
}
+41 -22
View File
@@ -23,6 +23,7 @@
//
// Because they lack location information, these types do not represent unique
// moments or intervals of time. Use time.Time for that purpose.
package toml
import (
@@ -44,6 +45,7 @@ type LocalDate struct {
func LocalDateOf(t time.Time) LocalDate {
var d LocalDate
d.Year, d.Month, d.Day = t.Date()
return d
}
@@ -51,8 +53,9 @@ func LocalDateOf(t time.Time) LocalDate {
func ParseLocalDate(s string) (LocalDate, error) {
t, err := time.Parse("2006-01-02", s)
if err != nil {
return LocalDate{}, err
return LocalDate{}, fmt.Errorf("ParseLocalDate: %w", err)
}
return LocalDateOf(t), nil
}
@@ -92,23 +95,28 @@ func (d LocalDate) DaysSince(s LocalDate) (days int) {
// We convert to Unix time so we do not have to worry about leap seconds:
// Unix time increases by exactly 86400 seconds per day.
deltaUnix := d.In(time.UTC).Unix() - s.In(time.UTC).Unix()
return int(deltaUnix / 86400)
const secondsInADay = 86400
return int(deltaUnix / secondsInADay)
}
// Before reports whether d1 occurs before d2.
func (d1 LocalDate) Before(d2 LocalDate) bool {
if d1.Year != d2.Year {
return d1.Year < d2.Year
// Before reports whether d1 occurs before future date.
func (d LocalDate) Before(future LocalDate) bool {
if d.Year != future.Year {
return d.Year < future.Year
}
if d1.Month != d2.Month {
return d1.Month < d2.Month
if d.Month != future.Month {
return d.Month < future.Month
}
return d1.Day < d2.Day
return d.Day < future.Day
}
// After reports whether d1 occurs after d2.
func (d1 LocalDate) After(d2 LocalDate) bool {
return d2.Before(d1)
// After reports whether d1 occurs after past date.
func (d LocalDate) After(past LocalDate) bool {
return past.Before(d)
}
// MarshalText implements the encoding.TextMarshaler interface.
@@ -122,6 +130,7 @@ func (d LocalDate) MarshalText() ([]byte, error) {
func (d *LocalDate) UnmarshalText(data []byte) error {
var err error
*d, err = ParseLocalDate(string(data))
return err
}
@@ -145,6 +154,7 @@ func LocalTimeOf(t time.Time) LocalTime {
var tm LocalTime
tm.Hour, tm.Minute, tm.Second = t.Clock()
tm.Nanosecond = t.Nanosecond()
return tm
}
@@ -156,8 +166,9 @@ func LocalTimeOf(t time.Time) LocalTime {
func ParseLocalTime(s string) (LocalTime, error) {
t, err := time.Parse("15:04:05.999999999", s)
if err != nil {
return LocalTime{}, err
return LocalTime{}, fmt.Errorf("ParseLocalTime: %w", err)
}
return LocalTimeOf(t), nil
}
@@ -169,6 +180,7 @@ func (t LocalTime) String() string {
if t.Nanosecond == 0 {
return s
}
return s + fmt.Sprintf(".%09d", t.Nanosecond)
}
@@ -176,6 +188,7 @@ func (t LocalTime) String() string {
func (t LocalTime) IsValid() bool {
// Construct a non-zero time.
tm := time.Date(2, 2, 2, t.Hour, t.Minute, t.Second, t.Nanosecond, time.UTC)
return LocalTimeOf(tm) == t
}
@@ -190,6 +203,7 @@ func (t LocalTime) MarshalText() ([]byte, error) {
func (t *LocalTime) UnmarshalText(data []byte) error {
var err error
*t, err = ParseLocalTime(string(data))
return err
}
@@ -223,9 +237,10 @@ func ParseLocalDateTime(s string) (LocalDateTime, error) {
if err != nil {
t, err = time.Parse("2006-01-02t15:04:05.999999999", s)
if err != nil {
return LocalDateTime{}, err
return LocalDateTime{}, fmt.Errorf("ParseLocalDateTime: %w", err)
}
}
return LocalDateTimeOf(t), nil
}
@@ -253,17 +268,20 @@ func (dt LocalDateTime) IsValid() bool {
//
// In panics if loc is nil.
func (dt LocalDateTime) In(loc *time.Location) time.Time {
return time.Date(dt.Date.Year, dt.Date.Month, dt.Date.Day, dt.Time.Hour, dt.Time.Minute, dt.Time.Second, dt.Time.Nanosecond, loc)
return time.Date(
dt.Date.Year, dt.Date.Month, dt.Date.Day,
dt.Time.Hour, dt.Time.Minute, dt.Time.Second, dt.Time.Nanosecond, loc,
)
}
// Before reports whether dt1 occurs before dt2.
func (dt1 LocalDateTime) Before(dt2 LocalDateTime) bool {
return dt1.In(time.UTC).Before(dt2.In(time.UTC))
// Before reports whether dt occurs before future.
func (dt LocalDateTime) Before(future LocalDateTime) bool {
return dt.In(time.UTC).Before(future.In(time.UTC))
}
// After reports whether dt1 occurs after dt2.
func (dt1 LocalDateTime) After(dt2 LocalDateTime) bool {
return dt2.Before(dt1)
// After reports whether dt occurs after past.
func (dt LocalDateTime) After(past LocalDateTime) bool {
return past.Before(dt)
}
// MarshalText implements the encoding.TextMarshaler interface.
@@ -273,9 +291,10 @@ func (dt LocalDateTime) MarshalText() ([]byte, error) {
}
// UnmarshalText implements the encoding.TextUnmarshaler interface.
// The datetime is expected to be a string in a format accepted by ParseLocalDateTime
// The datetime is expected to be a string in a format accepted by ParseLocalDateTime.
func (dt *LocalDateTime) UnmarshalText(data []byte) error {
var err error
*dt, err = ParseLocalDateTime(string(data))
return err
}
+49 -6
View File
@@ -26,6 +26,8 @@ func cmpEqual(x, y interface{}) bool {
}
func TestDates(t *testing.T) {
t.Parallel()
for _, test := range []struct {
date LocalDate
loc *time.Location
@@ -61,6 +63,8 @@ func TestDates(t *testing.T) {
}
func TestDateIsValid(t *testing.T) {
t.Parallel()
for _, test := range []struct {
date LocalDate
want bool
@@ -86,6 +90,10 @@ func TestDateIsValid(t *testing.T) {
}
func TestParseDate(t *testing.T) {
t.Parallel()
var emptyDate LocalDate
for _, test := range []struct {
str string
want LocalDate // if empty, expect an error
@@ -93,21 +101,23 @@ func TestParseDate(t *testing.T) {
{"2016-01-02", LocalDate{2016, 1, 2}},
{"2016-12-31", LocalDate{2016, 12, 31}},
{"0003-02-04", LocalDate{3, 2, 4}},
{"999-01-26", LocalDate{}},
{"", LocalDate{}},
{"2016-01-02x", LocalDate{}},
{"999-01-26", emptyDate},
{"", emptyDate},
{"2016-01-02x", emptyDate},
} {
got, err := ParseLocalDate(test.str)
if got != test.want {
t.Errorf("ParseLocalDate(%q) = %+v, want %+v", test.str, got, test.want)
}
if err != nil && test.want != (LocalDate{}) {
if err != nil && test.want != (emptyDate) {
t.Errorf("Unexpected error %v from ParseLocalDate(%q)", err, test.str)
}
}
}
func TestDateArithmetic(t *testing.T) {
t.Parallel()
for _, test := range []struct {
desc string
start LocalDate
@@ -167,6 +177,8 @@ func TestDateArithmetic(t *testing.T) {
}
func TestDateBefore(t *testing.T) {
t.Parallel()
for _, test := range []struct {
d1, d2 LocalDate
want bool
@@ -183,6 +195,8 @@ func TestDateBefore(t *testing.T) {
}
func TestDateAfter(t *testing.T) {
t.Parallel()
for _, test := range []struct {
d1, d2 LocalDate
want bool
@@ -198,6 +212,8 @@ func TestDateAfter(t *testing.T) {
}
func TestTimeToString(t *testing.T) {
t.Parallel()
for _, test := range []struct {
str string
time LocalTime
@@ -212,6 +228,7 @@ func TestTimeToString(t *testing.T) {
gotTime, err := ParseLocalTime(test.str)
if err != nil {
t.Errorf("ParseLocalTime(%q): got error: %v", test.str, err)
continue
}
if gotTime != test.time {
@@ -227,6 +244,8 @@ func TestTimeToString(t *testing.T) {
}
func TestTimeOf(t *testing.T) {
t.Parallel()
for _, test := range []struct {
time time.Time
want LocalTime
@@ -241,6 +260,8 @@ func TestTimeOf(t *testing.T) {
}
func TestTimeIsValid(t *testing.T) {
t.Parallel()
for _, test := range []struct {
time LocalTime
want bool
@@ -265,6 +286,8 @@ func TestTimeIsValid(t *testing.T) {
}
func TestDateTimeToString(t *testing.T) {
t.Parallel()
for _, test := range []struct {
str string
dateTime LocalDateTime
@@ -277,6 +300,7 @@ func TestDateTimeToString(t *testing.T) {
gotDateTime, err := ParseLocalDateTime(test.str)
if err != nil {
t.Errorf("ParseLocalDateTime(%q): got error: %v", test.str, err)
continue
}
if gotDateTime != test.dateTime {
@@ -292,6 +316,8 @@ func TestDateTimeToString(t *testing.T) {
}
func TestParseDateTimeErrors(t *testing.T) {
t.Parallel()
for _, str := range []string{
"",
"2016-03-22", // just a date
@@ -306,6 +332,8 @@ func TestParseDateTimeErrors(t *testing.T) {
}
func TestDateTimeOf(t *testing.T) {
t.Parallel()
for _, test := range []struct {
time time.Time
want LocalDateTime
@@ -322,6 +350,8 @@ func TestDateTimeOf(t *testing.T) {
}
func TestDateTimeIsValid(t *testing.T) {
t.Parallel()
// No need to be exhaustive here; it's just LocalDate.IsValid && LocalTime.IsValid.
for _, test := range []struct {
dt LocalDateTime
@@ -339,19 +369,24 @@ func TestDateTimeIsValid(t *testing.T) {
}
func TestDateTimeIn(t *testing.T) {
t.Parallel()
dt := LocalDateTime{LocalDate{2016, 1, 2}, LocalTime{3, 4, 5, 6}}
got := dt.In(time.UTC)
want := time.Date(2016, 1, 2, 3, 4, 5, 6, time.UTC)
if !got.Equal(want) {
if got := dt.In(time.UTC); !got.Equal(want) {
t.Errorf("got %v, want %v", got, want)
}
}
func TestDateTimeBefore(t *testing.T) {
t.Parallel()
d1 := LocalDate{2016, 12, 31}
d2 := LocalDate{2017, 1, 1}
t1 := LocalTime{5, 6, 7, 8}
t2 := LocalTime{5, 6, 7, 9}
for _, test := range []struct {
dt1, dt2 LocalDateTime
want bool
@@ -368,10 +403,13 @@ func TestDateTimeBefore(t *testing.T) {
}
func TestDateTimeAfter(t *testing.T) {
t.Parallel()
d1 := LocalDate{2016, 12, 31}
d2 := LocalDate{2017, 1, 1}
t1 := LocalTime{5, 6, 7, 8}
t2 := LocalTime{5, 6, 7, 9}
for _, test := range []struct {
dt1, dt2 LocalDateTime
want bool
@@ -388,6 +426,8 @@ func TestDateTimeAfter(t *testing.T) {
}
func TestMarshalJSON(t *testing.T) {
t.Parallel()
for _, test := range []struct {
value interface{}
want string
@@ -407,9 +447,12 @@ func TestMarshalJSON(t *testing.T) {
}
func TestUnmarshalJSON(t *testing.T) {
t.Parallel()
var d LocalDate
var tm LocalTime
var dt LocalDateTime
for _, test := range []struct {
data string
ptr interface{}
+138 -76
View File
@@ -18,10 +18,12 @@ import (
func Marshal(v interface{}) ([]byte, error) {
var buf bytes.Buffer
enc := NewEncoder(&buf)
err := enc.Encode(v)
if err != nil {
return nil, err
}
return buf.Bytes(), nil
}
@@ -96,7 +98,7 @@ func NewEncoder(w io.Writer) *Encoder {
// 5. Intermediate tables are always printed.
//
// By default, strings are encoded as literal string, unless they contain either
// a newline character or a single quote. In that case they are emited as quoted
// a newline character or a single quote. In that case they are emitted as quoted
// strings.
//
// When encoding structs, fields are encoded in order of definition, with their
@@ -107,25 +109,38 @@ func NewEncoder(w io.Writer) *Encoder {
// `multiline:"true"`: when the field contains a string, it will be emitted as
// a quoted multi-line TOML string.
func (enc *Encoder) Encode(v interface{}) error {
var b []byte
var ctx encoderCtx
var (
b []byte
ctx encoderCtx
)
b, err := enc.encode(b, ctx, reflect.ValueOf(v))
if err != nil {
return err
return fmt.Errorf("Encode: %w", err)
}
_, err = enc.w.Write(b)
return err
if err != nil {
return fmt.Errorf("Encode: %w", err)
}
return nil
}
var errUnsupportedValue = errors.New("unsupported encode value kind")
//nolint:cyclop
func (enc *Encoder) encode(b []byte, ctx encoderCtx, v reflect.Value) ([]byte, error) {
//nolint:gocritic,godox
switch i := v.Interface().(type) {
case time.Time: // TODO: add TextMarshaler
b = i.AppendFormat(b, time.RFC3339)
return b, nil
}
// containers
switch v.Kind() {
// containers
case reflect.Map:
return enc.encodeMap(b, ctx, v)
case reflect.Struct:
@@ -136,19 +151,18 @@ func (enc *Encoder) encode(b []byte, ctx encoderCtx, v reflect.Value) ([]byte, e
if v.IsNil() {
return nil, errNilInterface
}
return enc.encode(b, ctx, v.Elem())
case reflect.Ptr:
if v.IsNil() {
return enc.encode(b, ctx, reflect.Zero(v.Type().Elem()))
}
return enc.encode(b, ctx, v.Elem())
}
// values
var err error
switch v.Kind() {
case reflect.String:
b, err = enc.encodeString(b, v.String(), ctx.options)
b = enc.encodeString(b, v.String(), ctx.options)
case reflect.Float32:
b = strconv.AppendFloat(b, v.Float(), 'f', -1, 32)
case reflect.Float64:
@@ -164,10 +178,7 @@ func (enc *Encoder) encode(b []byte, ctx encoderCtx, v reflect.Value) ([]byte, e
case reflect.Int64, reflect.Int32, reflect.Int16, reflect.Int8, reflect.Int:
b = strconv.AppendInt(b, v.Int(), 10)
default:
err = fmt.Errorf("unsupported encode value kind: %s", v.Kind())
}
if err != nil {
return nil, err
return nil, fmt.Errorf("encode(type %s): %w", v.Kind(), errUnsupportedValue)
}
return b, nil
@@ -217,30 +228,31 @@ func (enc *Encoder) encodeKv(b []byte, ctx encoderCtx, options valueOptions, v r
const literalQuote = '\''
func (enc *Encoder) encodeString(b []byte, v string, options valueOptions) ([]byte, error) {
func (enc *Encoder) encodeString(b []byte, v string, options valueOptions) []byte {
if needsQuoting(v) {
b = enc.encodeQuotedString(options.multiline, b, v)
} else {
b = enc.encodeLiteralString(b, v)
return enc.encodeQuotedString(options.multiline, b, v)
}
return b, nil
return enc.encodeLiteralString(b, v)
}
func needsQuoting(v string) bool {
return strings.ContainsAny(v, "'\b\f\n\r\t")
}
// caller should have checked that the string does not contain new lines or '
// caller should have checked that the string does not contain new lines or ' .
func (enc *Encoder) encodeLiteralString(b []byte, v string) []byte {
b = append(b, literalQuote)
b = append(b, v...)
b = append(b, literalQuote)
return b
}
//nolint:cyclop
func (enc *Encoder) encodeQuotedString(multiline bool, b []byte, v string) []byte {
const hextable = "0123456789ABCDEF"
stringQuote := `"`
if multiline {
stringQuote = `"""`
}
@@ -250,6 +262,16 @@ func (enc *Encoder) encodeQuotedString(multiline bool, b []byte, v string) []byt
b = append(b, '\n')
}
const (
hextable = "0123456789ABCDEF"
// U+0000 to U+0008, U+000A to U+001F, U+007F
nul = 0x0
bs = 0x8
lf = 0xa
us = 0x1f
del = 0x7f
)
for _, r := range []byte(v) {
switch r {
case '\\':
@@ -272,7 +294,7 @@ func (enc *Encoder) encodeQuotedString(multiline bool, b []byte, v string) []byt
b = append(b, `\t`...)
default:
switch {
case r >= 0x0 && r <= 0x8, r >= 0xA && r <= 0x1F, r == 0x7F:
case r >= nul && r <= bs, r >= lf && r <= us, r == del:
b = append(b, `\u00`...)
b = append(b, hextable[r>>4])
b = append(b, hextable[r&0x0f])
@@ -280,14 +302,14 @@ func (enc *Encoder) encodeQuotedString(multiline bool, b []byte, v string) []byt
b = append(b, r)
}
}
// U+0000 to U+0008, U+000A to U+001F, U+007F
}
b = append(b, stringQuote...)
return b
}
// called should have checked that the string is in A-Z / a-z / 0-9 / - / _
// called should have checked that the string is in A-Z / a-z / 0-9 / - / _ .
func (enc *Encoder) encodeUnquotedKey(b []byte, v string) []byte {
return append(b, v...)
}
@@ -300,6 +322,7 @@ func (enc *Encoder) encodeTableHeader(b []byte, key []string) ([]byte, error) {
b = append(b, '[')
var err error
b, err = enc.encodeKey(b, key[0])
if err != nil {
return nil, err
@@ -307,6 +330,7 @@ func (enc *Encoder) encodeTableHeader(b []byte, key []string) ([]byte, error) {
for _, k := range key[1:] {
b = append(b, '.')
b, err = enc.encodeKey(b, k)
if err != nil {
return nil, err
@@ -318,6 +342,9 @@ func (enc *Encoder) encodeTableHeader(b []byte, key []string) ([]byte, error) {
return b, nil
}
var errTomlNoMultiline = errors.New("TOML does not support multiline keys")
//nolint:cyclop
func (enc *Encoder) encodeKey(b []byte, k string) ([]byte, error) {
needsQuotation := false
cannotUseLiteral := false
@@ -326,32 +353,39 @@ func (enc *Encoder) encodeKey(b []byte, k string) ([]byte, error) {
if (c >= 'A' && c <= 'Z') || (c >= 'a' && c <= 'z') || (c >= '0' && c <= '9') || c == '-' || c == '_' {
continue
}
if c == '\n' {
return nil, fmt.Errorf("TOML does not support multiline keys")
return nil, errTomlNoMultiline
}
if c == literalQuote {
cannotUseLiteral = true
}
needsQuotation = true
}
if cannotUseLiteral {
b = enc.encodeQuotedString(false, b, k)
} else if needsQuotation {
b = enc.encodeLiteralString(b, k)
} else {
b = enc.encodeUnquotedKey(b, k)
switch {
case cannotUseLiteral:
return enc.encodeQuotedString(false, b, k), nil
case needsQuotation:
return enc.encodeLiteralString(b, k), nil
default:
return enc.encodeUnquotedKey(b, k), nil
}
return b, nil
}
var errNotSupportedAsMapKey = errors.New("type not supported as map key")
func (enc *Encoder) encodeMap(b []byte, ctx encoderCtx, v reflect.Value) ([]byte, error) {
if v.Type().Key().Kind() != reflect.String {
return nil, fmt.Errorf("type '%s' not supported as map key", v.Type().Key().Kind())
return nil, fmt.Errorf("encodeMap '%s': %w", v.Type().Key().Kind(), errNotSupportedAsMapKey)
}
t := table{}
var (
t table
emptyValueOptions valueOptions
)
iter := v.MapRange()
for iter.Next() {
@@ -368,9 +402,9 @@ func (enc *Encoder) encodeMap(b []byte, ctx encoderCtx, v reflect.Value) ([]byte
}
if table {
t.pushTable(k, v, valueOptions{})
t.pushTable(k, v, emptyValueOptions)
} else {
t.pushKV(k, v, valueOptions{})
t.pushKV(k, v, emptyValueOptions)
}
}
@@ -405,13 +439,10 @@ func (t *table) pushTable(k string, v reflect.Value, options valueOptions) {
t.tables = append(t.tables, entry{Key: k, Value: v, Options: options})
}
func (t *table) hasKVs() bool {
return len(t.kvs) > 0
}
func (enc *Encoder) encodeStruct(b []byte, ctx encoderCtx, v reflect.Value) ([]byte, error) {
t := table{}
var t table
//nolint:godox
// TODO: cache this?
typ := v.Type()
for i := 0; i < typ.NumField(); i++ {
@@ -443,7 +474,7 @@ func (enc *Encoder) encodeStruct(b []byte, ctx encoderCtx, v reflect.Value) ([]b
return nil, err
}
options := valueOptions{}
var options valueOptions
ml, ok := fieldType.Tag.Lookup("multiline")
if ok {
@@ -466,38 +497,7 @@ func (enc *Encoder) encodeTable(b []byte, ctx encoderCtx, t table) ([]byte, erro
ctx.shiftKey()
if ctx.insideKv {
b = append(b, '{')
first := true
for _, kv := range t.kvs {
if first {
first = false
} else {
b = append(b, `, `...)
}
ctx.setKey(kv.Key)
b, err = enc.encodeKv(b, ctx, kv.Options, kv.Value)
if err != nil {
return nil, err
}
}
for _, table := range t.tables {
if first {
first = false
} else {
b = append(b, `, `...)
}
ctx.setKey(table.Key)
b, err = enc.encode(b, ctx, table.Value)
if err != nil {
return nil, err
}
b = append(b, '\n')
}
b = append(b, "}\n"...)
return b, nil
return enc.encodeTableInsideKV(b, ctx, t)
}
if !ctx.skipTableHeader {
@@ -510,29 +510,76 @@ func (enc *Encoder) encodeTable(b []byte, ctx encoderCtx, t table) ([]byte, erro
for _, kv := range t.kvs {
ctx.setKey(kv.Key)
b, err = enc.encodeKv(b, ctx, kv.Options, kv.Value)
if err != nil {
return nil, err
}
b = append(b, '\n')
}
for _, table := range t.tables {
ctx.setKey(table.Key)
b, err = enc.encode(b, ctx, table.Value)
if err != nil {
return nil, err
}
b = append(b, '\n')
}
return b, nil
}
func (enc *Encoder) encodeTableInsideKV(b []byte, ctx encoderCtx, t table) ([]byte, error) {
var err error
b = append(b, '{')
first := true
for _, kv := range t.kvs {
if first {
first = false
} else {
b = append(b, `, `...)
}
ctx.setKey(kv.Key)
b, err = enc.encodeKv(b, ctx, kv.Options, kv.Value)
if err != nil {
return nil, err
}
}
for _, table := range t.tables {
if first {
first = false
} else {
b = append(b, `, `...)
}
ctx.setKey(table.Key)
b, err = enc.encode(b, ctx, table.Value)
if err != nil {
return nil, err
}
b = append(b, '\n')
}
b = append(b, "}\n"...)
return b, nil
}
var errNilInterface = errors.New("nil interface not supported")
var errNilPointer = errors.New("nil pointer not supported")
func willConvertToTable(v reflect.Value) (bool, error) {
//nolint:gocritic,godox
switch v.Interface().(type) {
case time.Time: // TODO: add TextMarshaler
return false, nil
@@ -546,11 +593,13 @@ func willConvertToTable(v reflect.Value) (bool, error) {
if v.IsNil() {
return false, errNilInterface
}
return willConvertToTable(v.Elem())
case reflect.Ptr:
if v.IsNil() {
return false, nil
}
return willConvertToTable(v.Elem())
default:
return false, nil
@@ -564,6 +613,7 @@ func willConvertToTableOrArrayTable(v reflect.Value) (bool, error) {
if v.IsNil() {
return false, errNilInterface
}
return willConvertToTableOrArrayTable(v.Elem())
}
@@ -572,15 +622,18 @@ func willConvertToTableOrArrayTable(v reflect.Value) (bool, error) {
// An empty slice should be a kv = [].
return false, nil
}
for i := 0; i < v.Len(); i++ {
t, err := willConvertToTable(v.Index(i))
if err != nil {
return false, err
}
if !t {
return false, nil
}
}
return true, nil
}
@@ -590,6 +643,7 @@ func willConvertToTableOrArrayTable(v reflect.Value) (bool, error) {
func (enc *Encoder) encodeSlice(b []byte, ctx encoderCtx, v reflect.Value) ([]byte, error) {
if v.Len() == 0 {
b = append(b, "[]"...)
return b, nil
}
@@ -617,25 +671,30 @@ func (enc *Encoder) encodeSliceAsArrayTable(b []byte, ctx encoderCtx, v reflect.
var err error
scratch := make([]byte, 0, 64)
scratch = append(scratch, "[["...)
for i, k := range ctx.parentKey {
if i > 0 {
scratch = append(scratch, '.')
}
scratch, err = enc.encodeKey(scratch, k)
if err != nil {
return nil, err
}
}
scratch = append(scratch, "]]\n"...)
ctx.skipTableHeader = true
for i := 0; i < v.Len(); i++ {
b = append(b, scratch...)
b, err = enc.encode(b, ctx, v.Index(i))
if err != nil {
return nil, err
}
}
return b, nil
}
@@ -644,10 +703,12 @@ func (enc *Encoder) encodeSliceAsArray(b []byte, ctx encoderCtx, v reflect.Value
var err error
first := true
for i := 0; i < v.Len(); i++ {
if !first {
b = append(b, ", "...)
}
first = false
b, err = enc.encode(b, ctx, v.Index(i))
@@ -657,5 +718,6 @@ func (enc *Encoder) encodeSliceAsArray(b []byte, ctx encoderCtx, v reflect.Value
}
b = append(b, ']')
return b, nil
}
+21 -2
View File
@@ -11,7 +11,10 @@ import (
"github.com/stretchr/testify/require"
)
//nolint:funlen
func TestMarshal(t *testing.T) {
t.Parallel()
examples := []struct {
desc string
v interface{}
@@ -65,6 +68,7 @@ hello = 'world'`,
a = 'test'`,
},
{
//nolint:godox
// TODO: this test is flaky because output changes depending on
// the map iteration order.
desc: "map in map in map and string with values",
@@ -89,6 +93,16 @@ a = 'test'`,
},
expected: `array = ['one', 'two', 'three']`,
},
{
desc: "empty string array",
v: map[string][]string{},
expected: ``,
},
{
desc: "map",
v: map[string][]string{},
expected: ``,
},
{
desc: "nested string arrays",
v: map[string][][]string{
@@ -104,7 +118,7 @@ a = 'test'`,
expected: `array = ['a string', ['one', 'two'], 'last']`,
},
{
desc: "slice of maps",
desc: "array of maps",
v: map[string][]map[string]string{
"top": {
{"map1.1": "v1.1"},
@@ -157,7 +171,7 @@ K2 = 'v2'
`,
},
{
desc: "structs in slice with interfaces",
desc: "structs in array with interfaces",
v: map[string]interface{}{
"root": map[string]interface{}{
"nested": []interface{}{
@@ -237,7 +251,10 @@ world"""`,
}
for _, e := range examples {
e := e
t.Run(e.desc, func(t *testing.T) {
t.Parallel()
b, err := toml.Marshal(e.v)
if e.err {
require.Error(t, err)
@@ -256,6 +273,8 @@ func equalStringsIgnoreNewlines(t *testing.T, expected string, actual string) {
}
func TestIssue436(t *testing.T) {
t.Parallel()
data := []byte(`{"a": [ { "b": { "c": "d" } } ]}`)
var v interface{}
+10 -6
View File
@@ -155,12 +155,14 @@ func TestParser_AST_Numbers(t *testing.T) {
}
}
type astRoot []astNode
type astNode struct {
Kind ast.Kind
Data []byte
Children []astNode
}
type (
astRoot []astNode
astNode struct {
Kind ast.Kind
Data []byte
Children []astNode
}
)
func compareAST(t *testing.T, expected astRoot, actual *ast.Root) {
it := actual.Iterator()
@@ -168,6 +170,7 @@ func compareAST(t *testing.T, expected astRoot, actual *ast.Root) {
}
func compareNode(t *testing.T, e astNode, n ast.Node) {
t.Helper()
require.Equal(t, e.Kind, n.Kind)
require.Equal(t, e.Data, n.Data)
@@ -175,6 +178,7 @@ func compareNode(t *testing.T, e astNode, n ast.Node) {
}
func compareIterator(t *testing.T, expected []astNode, actual ast.Iterator) {
t.Helper()
idx := 0
for actual.Next() {
+35 -27
View File
@@ -2,29 +2,37 @@ package toml
import "fmt"
func scanFollows(pattern []byte) func(b []byte) bool {
return func(b []byte) bool {
if len(b) < len(pattern) {
return false
}
for i, c := range pattern {
if b[i] != c {
return false
}
}
return true
}
func scanFollows(b []byte, pattern string) bool {
n := len(pattern)
return len(b) >= n && string(b[:n]) == pattern
}
var scanFollowsMultilineBasicStringDelimiter = scanFollows([]byte{'"', '"', '"'})
var scanFollowsMultilineLiteralStringDelimiter = scanFollows([]byte{'\'', '\'', '\''})
var scanFollowsTrue = scanFollows([]byte{'t', 'r', 'u', 'e'})
var scanFollowsFalse = scanFollows([]byte{'f', 'a', 'l', 's', 'e'})
var scanFollowsInf = scanFollows([]byte{'i', 'n', 'f'})
var scanFollowsNan = scanFollows([]byte{'n', 'a', 'n'})
func scanFollowsMultilineBasicStringDelimiter(b []byte) bool {
return scanFollows(b, `"""`)
}
func scanFollowsMultilineLiteralStringDelimiter(b []byte) bool {
return scanFollows(b, `'''`)
}
func scanFollowsTrue(b []byte) bool {
return scanFollows(b, `true`)
}
func scanFollowsFalse(b []byte) bool {
return scanFollows(b, `false`)
}
func scanFollowsInf(b []byte) bool {
return scanFollows(b, `inf`)
}
func scanFollowsNan(b []byte) bool {
return scanFollows(b, `nan`)
}
func scanUnquotedKey(b []byte) ([]byte, []byte, error) {
//unquoted-key = 1*( ALPHA / DIGIT / %x2D / %x5F ) ; A-Z / a-z / 0-9 / - / _
// unquoted-key = 1*( ALPHA / DIGIT / %x2D / %x5F ) ; A-Z / a-z / 0-9 / - / _
for i := 0; i < len(b); i++ {
if !isUnquotedKeyChar(b[i]) {
return b[:i], b[i:], nil
@@ -38,9 +46,9 @@ func isUnquotedKeyChar(r byte) bool {
}
func scanLiteralString(b []byte) ([]byte, []byte, error) {
//literal-string = apostrophe *literal-char apostrophe
//apostrophe = %x27 ; ' apostrophe
//literal-char = %x09 / %x20-26 / %x28-7E / non-ascii
// literal-string = apostrophe *literal-char apostrophe
// apostrophe = %x27 ; ' apostrophe
// literal-char = %x09 / %x20-26 / %x28-7E / non-ascii
for i := 1; i < len(b); i++ {
switch b[i] {
case '\'':
@@ -115,11 +123,11 @@ func scanComment(b []byte) ([]byte, []byte, error) {
// TODO perform validation on the string?
func scanBasicString(b []byte) ([]byte, []byte, error) {
//basic-string = quotation-mark *basic-char quotation-mark
//quotation-mark = %x22 ; "
//basic-char = basic-unescaped / escaped
//basic-unescaped = wschar / %x21 / %x23-5B / %x5D-7E / non-ascii
//escaped = escape escape-seq-char
// basic-string = quotation-mark *basic-char quotation-mark
// quotation-mark = %x22 ; "
// basic-char = basic-unescaped / escaped
// basic-unescaped = wschar / %x21 / %x23-5B / %x5D-7E / non-ascii
// escaped = escape escape-seq-char
for i := 1; i < len(b); i++ {
switch b[i] {
case '"':
+79
View File
@@ -0,0 +1,79 @@
package toml
import (
"github.com/pelletier/go-toml/v2/internal/ast"
"github.com/pelletier/go-toml/v2/internal/tracker"
)
type strict struct {
Enabled bool
// Tracks the current key being processed.
key tracker.KeyTracker
missing []decodeError
}
func (s *strict) EnterTable(node ast.Node) {
if !s.Enabled {
return
}
s.key.UpdateTable(node)
}
func (s *strict) EnterArrayTable(node ast.Node) {
if !s.Enabled {
return
}
s.key.UpdateArrayTable(node)
}
func (s *strict) EnterKeyValue(node ast.Node) {
if !s.Enabled {
return
}
s.key.Push(node)
}
func (s *strict) ExitKeyValue(node ast.Node) {
if !s.Enabled {
return
}
s.key.Pop(node)
}
func (s *strict) MissingTable(node ast.Node) {
if !s.Enabled {
return
}
s.missing = append(s.missing, decodeError{
highlight: keyLocation(node),
message: "missing table",
key: s.key.Key(),
})
}
func (s *strict) MissingField(node ast.Node) {
if !s.Enabled {
return
}
s.missing = append(s.missing, decodeError{
highlight: keyLocation(node),
message: "missing field",
key: s.key.Key(),
})
}
func (s *strict) Error(doc []byte) error {
if !s.Enabled || len(s.missing) == 0 {
return nil
}
err := &StrictMissingError{
Errors: make([]DecodeError, 0, len(s.missing)),
}
for _, derr := range s.missing {
err.Errors = append(err.Errors, *wrapDecodeError(doc, &derr))
}
return err
}
+4 -5
View File
@@ -24,7 +24,7 @@ type target interface {
// Store a float64 at the target
setFloat64(v float64) error
// Stores any value at the target
// Stores any value at the target
set(v reflect.Value) error
}
@@ -516,11 +516,10 @@ func scopeStruct(v reflect.Value, name string) (target, bool, error) {
l := len(path)
path = append(path, i)
f := t.Field(i)
if f.PkgPath != "" {
// only consider exported fields
} else if f.Anonymous {
if f.Anonymous {
walk(v.Field(i))
} else {
} else if f.PkgPath == "" {
// only consider exported fields
fieldName, ok := f.Tag.Lookup("toml")
if !ok {
fieldName = f.Name
+1
View File
@@ -59,6 +59,7 @@ val = string / boolean / array / inline-table / date-time / float / integer
;; String
string = ml-basic-string / basic-string / ml-literal-string / literal-string
;; Basic String
basic-string = quotation-mark *basic-char quotation-mark
+49 -2
View File
@@ -10,6 +10,7 @@ import (
"github.com/pelletier/go-toml/v2/internal/ast"
"github.com/pelletier/go-toml/v2/internal/tracker"
"github.com/pelletier/go-toml/v2/internal/unsafe"
)
func Unmarshal(data []byte, v interface{}) error {
@@ -21,7 +22,11 @@ func Unmarshal(data []byte, v interface{}) error {
// Decoder reads and decode a TOML document from an input stream.
type Decoder struct {
// input
r io.Reader
// global settings
strict bool
}
// NewDecoder creates a new Decoder that will read from r.
@@ -29,6 +34,16 @@ func NewDecoder(r io.Reader) *Decoder {
return &Decoder{r: r}
}
// SetStrict toggles decoding in stict mode.
//
// When the decoder is in strict mode, it will record fields from the document
// that could not be set on the target value. In that case, the decoder returns
// a StrictMissingError that can be used to retrieve the individual errors as
// well as generate a human readable description of the missing fields.
func (d *Decoder) SetStrict(strict bool) {
d.strict = strict
}
// Decode the whole content of r into v.
//
// When a TOML local date is decoded into a time.Time, its value is represented
@@ -43,7 +58,11 @@ func (d *Decoder) Decode(v interface{}) error {
}
p := parser{}
p.Reset(b)
dec := decoder{}
dec := decoder{
strict: strict{
Enabled: d.strict,
},
}
return dec.FromParser(&p, v)
}
@@ -52,7 +71,10 @@ type decoder struct {
arrayIndexes map[reflect.Value]int
// Tracks keys that have been seen, with which type.
seen tracker.Seen
seen tracker.SeenTracker
// Strict mode
strict strict
}
func (d *decoder) arrayIndex(append bool, v reflect.Value) int {
@@ -79,9 +101,27 @@ func (d *decoder) FromParser(p *parser, v interface{}) error {
err = wrapDecodeError(p.data, de)
}
}
if err == nil {
err = d.strict.Error(p.data)
}
return err
}
func keyLocation(node ast.Node) []byte {
k := node.Key()
hasOne := k.Next()
if !hasOne {
panic("should not be called with empty key")
}
start := k.Node().Data
end := k.Node().Data
for k.Next() {
end = k.Node().Data
}
return unsafe.BytesRange(start, end)
}
func (d *decoder) fromParser(p *parser, v interface{}) error {
r := reflect.ValueOf(v)
if r.Kind() != reflect.Ptr {
@@ -113,6 +153,7 @@ func (d *decoder) fromParser(p *parser, v interface{}) error {
err = d.unmarshalKeyValue(current, node)
found = true
case ast.Table:
d.strict.EnterTable(node)
current, found, err = d.scopeWithKey(root, node.Key())
if err == nil && found {
// In case this table points to an interface,
@@ -123,6 +164,7 @@ func (d *decoder) fromParser(p *parser, v interface{}) error {
ensureMapIfInterface(current)
}
case ast.ArrayTable:
d.strict.EnterArrayTable(node)
current, found, err = d.scopeWithArrayTable(root, node.Key())
default:
panic(fmt.Errorf("this should not be a top level node type: %s", node.Kind))
@@ -134,6 +176,7 @@ func (d *decoder) fromParser(p *parser, v interface{}) error {
if !found {
skipUntilTable = true
d.strict.MissingTable(node)
}
}
@@ -217,6 +260,9 @@ func (d *decoder) scopeWithArrayTable(x target, key ast.Iterator) (target, bool,
func (d *decoder) unmarshalKeyValue(x target, node ast.Node) error {
assertNode(ast.KeyValue, node)
d.strict.EnterKeyValue(node)
defer d.strict.ExitKeyValue(node)
x, found, err := d.scopeWithKey(x, node.Key())
if err != nil {
return err
@@ -224,6 +270,7 @@ func (d *decoder) unmarshalKeyValue(x target, node ast.Node) error {
// A struct in the path was not found. Skip this value.
if !found {
d.strict.MissingField(node)
return nil
}
+183 -4
View File
@@ -1,8 +1,10 @@
package toml_test
import (
"fmt"
"math"
"strconv"
"strings"
"testing"
"time"
@@ -132,6 +134,7 @@ func TestUnmarshal_Floats(t *testing.T) {
desc: "nan",
input: `nan`,
testFn: func(t *testing.T, v float64) {
t.Helper()
assert.True(t, math.IsNaN(v))
},
},
@@ -139,6 +142,7 @@ func TestUnmarshal_Floats(t *testing.T) {
desc: "nan negative",
input: `-nan`,
testFn: func(t *testing.T, v float64) {
t.Helper()
assert.True(t, math.IsNaN(v))
},
},
@@ -146,6 +150,7 @@ func TestUnmarshal_Floats(t *testing.T) {
desc: "nan positive",
input: `+nan`,
testFn: func(t *testing.T, v float64) {
t.Helper()
assert.True(t, math.IsNaN(v))
},
},
@@ -706,6 +711,42 @@ B = "data"`,
}
},
},
{
desc: "windows line endings",
input: "A = 1\r\n\r\nB = 2",
gen: func() test {
doc := map[string]interface{}{}
return test{
target: &doc,
expected: &map[string]interface{}{
"A": int64(1),
"B": int64(2),
},
}
},
},
{
desc: "dangling CR",
input: "A = 1\r",
gen: func() test {
doc := map[string]interface{}{}
return test{
target: &doc,
err: true,
}
},
},
{
desc: "missing NL after CR",
input: "A = 1\rB = 2",
gen: func() test {
doc := map[string]interface{}{}
return test{
target: &doc,
err: true,
}
},
},
}
for _, e := range examples {
@@ -759,8 +800,10 @@ func TestIssue484(t *testing.T) {
}, cfg)
}
type Map458 map[string]interface{}
type Slice458 []interface{}
type (
Map458 map[string]interface{}
Slice458 []interface{}
)
func (m Map458) A(s string) Slice458 {
return m[s].([]interface{})
@@ -779,7 +822,8 @@ version = "0.1.0"`)
map[string]interface{}{
"dependencies": []interface{}{"regex"},
"name": "decode",
"version": "0.1.0"},
"version": "0.1.0",
},
}
assert.Equal(t, expected, a)
}
@@ -790,7 +834,7 @@ func TestIssue252(t *testing.T) {
Val2 string `toml:"val2"`
}
var configFile = []byte(
configFile := []byte(
`
val1 = "test1"
`)
@@ -818,6 +862,13 @@ bar = 2021-04-08
require.NoError(t, err)
}
func TestIssue507(t *testing.T) {
data := []byte{'0', '=', '\n', '0', 'a', 'm', 'e'}
m := map[string]interface{}{}
err := toml.Unmarshal(data, &m)
require.Error(t, err)
}
func TestUnmarshalDecodeErrors(t *testing.T) {
examples := []struct {
desc string
@@ -924,3 +975,131 @@ func TestIssue287(t *testing.T) {
}
require.Equal(t, expected, v)
}
func TestIssue508(t *testing.T) {
type head struct {
Title string `toml:"title"`
}
type text struct {
head
}
b := []byte(`title = "This is a title"`)
t1 := text{}
err := toml.Unmarshal(b, &t1)
require.NoError(t, err)
require.Equal(t, "This is a title", t1.head.Title)
}
func TestDecoderStrict(t *testing.T) {
examples := []struct {
desc string
input string
expected string
target interface{}
}{
{
desc: "multiple missing root keys",
input: `
key1 = "value1"
key2 = "missing2"
key3 = "missing3"
key4 = "value4"
`,
expected: `
2| key1 = "value1"
3| key2 = "missing2"
| ~~~~ missing field
4| key3 = "missing3"
5| key4 = "value4"
---
2| key1 = "value1"
3| key2 = "missing2"
4| key3 = "missing3"
| ~~~~ missing field
5| key4 = "value4"
`,
target: &struct {
Key1 string
Key4 string
}{},
},
{
desc: "multi-part key",
input: `a.short.key="foo"`,
expected: `
1| a.short.key="foo"
| ~~~~~~~~~~~ missing field
`,
},
{
desc: "missing table",
input: `
[foo]
bar = 42
`,
expected: `
2| [foo]
| ~~~ missing table
3| bar = 42
`,
},
{
desc: "missing array table",
input: `
[[foo]]
bar = 42
`,
expected: `
2| [[foo]]
| ~~~ missing table
3| bar = 42
`,
},
}
for _, e := range examples {
t.Run(e.desc, func(t *testing.T) {
r := strings.NewReader(e.input)
d := toml.NewDecoder(r)
d.SetStrict(true)
x := e.target
if x == nil {
x = &struct{}{}
}
err := d.Decode(x)
details := err.(*toml.StrictMissingError)
equalStringsIgnoreNewlines(t, e.expected, details.String())
})
}
}
func ExampleDecoder_SetStrict() {
type S struct {
Key1 string
Key3 string
}
doc := `
key1 = "value1"
key2 = "value2"
key3 = "value3"
`
r := strings.NewReader(doc)
d := toml.NewDecoder(r)
d.SetStrict(true)
s := S{}
err := d.Decode(&s)
fmt.Println(err.Error())
// Output: strict mode: fields in the document are missing in the target struct
details := err.(*toml.StrictMissingError)
fmt.Println(details.String())
// Ouput:
// 2| key1 = "value1"
// 3| key2 = "value2"
// | ~~~~ missing field
// 4| key3 = "value3"
}