Compare commits

...

9 Commits

Author SHA1 Message Date
Thomas Pelletier 9b67e40640 decoder: strict mode (#512) 2021-04-20 21:26:22 -04:00
Vincent Serpoul dca2103910 golangci-lint: marshaler (#516) 2021-04-20 20:24:44 -04:00
Cameron Moore a713a96e69 Add more newline tests for scanner (#515) 2021-04-16 19:07:29 -04:00
Cameron Moore a7b50eb8f1 Tidy (#511)
* Disconnect package godoc comment from imported file

* Add missing newline in toml.abnf

* Tag testing helper funcs
2021-04-15 16:49:19 -04:00
Cameron Moore 24b62ebe61 Simplify scanFollows usage (#510)
Use static functions to avoid declaring global vars and creating more
package init costs.  This change has no negative effects on benchmarks
in my testing.
2021-04-15 16:48:19 -04:00
Thomas Pelletier 9bc4641a49 ci-lint: disable ifshort 2021-04-15 13:37:24 -04:00
Thomas Pelletier b86b890b8d decoder: handle private anonymous structs
Ref #508
2021-04-15 12:49:24 -04:00
Vincent Serpoul 080baa8574 golangci-lint: localtime (#509) 2021-04-15 12:44:31 -04:00
Thomas Pelletier 0537b928df decoder: add test for #507 2021-04-15 11:36:36 -04:00
20 changed files with 1087 additions and 408 deletions
+4 -1
View File
@@ -4,6 +4,9 @@ golangci-lint-version = "1.39.0"
[linters-settings.wsl] [linters-settings.wsl]
allow-assign-and-anything = true allow-assign-and-anything = true
[linters-settings.exhaustive]
default-signifies-exhaustive = true
[linters] [linters]
disable-all = true disable-all = true
enable = [ enable = [
@@ -45,7 +48,7 @@ enable = [
"gosec", "gosec",
"gosimple", "gosimple",
"govet", "govet",
"ifshort", # "ifshort",
"importas", "importas",
"ineffassign", "ineffassign",
"lll", "lll",
+1 -1
View File
@@ -22,7 +22,7 @@ Development branch. Use at your own risk.
- [x] Abstract AST. - [x] Abstract AST.
- [x] Original go-toml testgen tests pass. - [x] Original go-toml testgen tests pass.
- [x] Track file position (line, column) for errors. - [x] Track file position (line, column) for errors.
- [ ] Strict mode. - [x] Strict mode.
- [ ] Document Unmarshal / Decode - [ ] Document Unmarshal / Decode
### Marshal ### Marshal
+38 -1
View File
@@ -18,15 +18,46 @@ type DecodeError struct {
message string message string
line int line int
column int column int
key Key
human string human string
} }
// StrictMissingError occurs in a TOML document that does not have a
// corresponding field in the target value. It contains all the missing fields
// in Errors.
//
// Emitted by Decoder when SetStrict(true) was called.
type StrictMissingError struct {
// One error per field that could not be found.
Errors []DecodeError
}
// Error returns the cannonical string for this error.
func (s *StrictMissingError) Error() string {
return "strict mode: fields in the document are missing in the target struct"
}
// String returns a human readable description of all errors.
func (s *StrictMissingError) String() string {
var buf strings.Builder
for i, e := range s.Errors {
if i > 0 {
buf.WriteString("\n---\n")
}
buf.WriteString(e.String())
}
return buf.String()
}
type Key []string
// internal version of DecodeError that is used as the base to create a // internal version of DecodeError that is used as the base to create a
// DecodeError with full context. // DecodeError with full context.
type decodeError struct { type decodeError struct {
highlight []byte highlight []byte
message string message string
key Key // optional
} }
func (de *decodeError) Error() string { func (de *decodeError) Error() string {
@@ -56,6 +87,11 @@ func (e *DecodeError) Position() (row int, column int) {
return e.line, e.column return e.line, e.column
} }
// Key that was being processed when the error occured.
func (e *DecodeError) Key() Key {
return e.key
}
// decodeErrorFromHighlight creates a DecodeError referencing to a highlighted // decodeErrorFromHighlight creates a DecodeError referencing to a highlighted
// range of bytes from document. // range of bytes from document.
// //
@@ -64,7 +100,7 @@ func (e *DecodeError) Position() (row int, column int) {
// The function copies all bytes used in DecodeError, so that document and // The function copies all bytes used in DecodeError, so that document and
// highlight can be freely deallocated. // highlight can be freely deallocated.
//nolint:funlen //nolint:funlen
func wrapDecodeError(document []byte, de *decodeError) error { func wrapDecodeError(document []byte, de *decodeError) *DecodeError {
if de == nil { if de == nil {
return nil return nil
} }
@@ -137,6 +173,7 @@ func wrapDecodeError(document []byte, de *decodeError) error {
message: errMessage, message: errMessage,
line: errLine, line: errLine,
column: errColumn, column: errColumn,
key: de.key,
human: buf.String(), human: buf.String(),
} }
} }
@@ -7,6 +7,7 @@ package imported_tests
// marked as skipped until we figure out if that's something we want in v2. // marked as skipped until we figure out if that's something we want in v2.
import ( import (
"bytes"
"errors" "errors"
"fmt" "fmt"
"reflect" "reflect"
@@ -1955,66 +1956,80 @@ String2="2"`
assert.Error(t, err) assert.Error(t, err)
} }
func decoder(doc string) *toml.Decoder {
return toml.NewDecoder(bytes.NewReader([]byte(doc)))
}
func strictDecoder(doc string) *toml.Decoder {
d := decoder(doc)
d.SetStrict(true)
return d
}
func TestDecoderStrict(t *testing.T) { func TestDecoderStrict(t *testing.T) {
t.Skip() input := `
// input := ` [decoded]
//[decoded] key = ""
// key = ""
// [undecoded]
//[undecoded] key = ""
// key = ""
// [undecoded.inner]
// [undecoded.inner] key = ""
// key = ""
// [[undecoded.array]]
// [[undecoded.array]] key = ""
// key = ""
// [[undecoded.array]]
// [[undecoded.array]] key = ""
// key = ""
// `
//` var doc struct {
// var doc struct { Decoded struct {
// Decoded struct { Key string
// Key string }
// } }
// }
// err := strictDecoder(input).Decode(&doc)
// expected := `undecoded keys: ["undecoded.array.0.key" "undecoded.array.1.key" "undecoded.inner.key" "undecoded.key"]` require.Error(t, err)
// require.IsType(t, &toml.StrictMissingError{}, err)
// err := NewDecoder(bytes.NewReader([]byte(input))).Strict(true).Decode(&doc) se := err.(*toml.StrictMissingError)
// if err == nil {
// t.Error("expected error, got none") keys := []toml.Key{}
// } else if err.Error() != expected {
// t.Errorf("expect err: %s, got: %s", expected, err.Error()) for _, e := range se.Errors {
// } keys = append(keys, e.Key())
// }
// if err := NewDecoder(bytes.NewReader([]byte(input))).Decode(&doc); err != nil {
// t.Errorf("unexpected err: %s", err) expectedKeys := []toml.Key{
// } {"undecoded"},
// {"undecoded", "inner"},
// var m map[string]interface{} {"undecoded", "array"},
// if err := NewDecoder(bytes.NewReader([]byte(input))).Decode(&m); err != nil { {"undecoded", "array"},
// t.Errorf("unexpected err: %s", err) }
// }
require.Equal(t, expectedKeys, keys)
err = decoder(input).Decode(&doc)
require.NoError(t, err)
var m map[string]interface{}
err = decoder(input).Decode(&m)
} }
func TestDecoderStrictValid(t *testing.T) { func TestDecoderStrictValid(t *testing.T) {
t.Skip() input := `
// input := ` [decoded]
//[decoded] key = ""
// key = "" `
//` var doc struct {
// var doc struct { Decoded struct {
// Decoded struct { Key string
// Key string }
// } }
// }
// err := strictDecoder(input).Decode(&doc)
// err := NewDecoder(bytes.NewReader([]byte(input))).Strict(true).Decode(&doc) require.NoError(t, err)
// if err != nil {
// t.Fatal("unexpected error:", err)
// }
} }
type docUnmarshalTOML struct { type docUnmarshalTOML struct {
+50
View File
@@ -0,0 +1,50 @@
package tracker
import (
"github.com/pelletier/go-toml/v2/internal/ast"
)
// KeyTracker is a tracker that keeps track of the current Key as the AST is
// walked.
type KeyTracker struct {
k []string
}
// UpdateTable sets the state of the tracker with the AST table node.
func (t *KeyTracker) UpdateTable(node ast.Node) {
t.reset()
t.Push(node)
}
// UpdateArrayTable sets the state of the tracker with the AST array table node.
func (t *KeyTracker) UpdateArrayTable(node ast.Node) {
t.reset()
t.Push(node)
}
// Push the given key on the stack.
func (t *KeyTracker) Push(node ast.Node) {
it := node.Key()
for it.Next() {
t.k = append(t.k, string(it.Node().Data))
}
}
// Pop key from stack.
func (t *KeyTracker) Pop(node ast.Node) {
it := node.Key()
for it.Next() {
t.k = t.k[:len(t.k)-1]
}
}
// Key returns the current key
func (t *KeyTracker) Key() []string {
k := make([]string, len(t.k))
copy(k, t.k)
return k
}
func (t *KeyTracker) reset() {
t.k = t.k[:0]
}
+200
View File
@@ -0,0 +1,200 @@
package tracker
import (
"fmt"
"github.com/pelletier/go-toml/v2/internal/ast"
)
type keyKind uint8
const (
invalidKind keyKind = iota
valueKind
tableKind
arrayTableKind
)
func (k keyKind) String() string {
switch k {
case invalidKind:
return "invalid"
case valueKind:
return "value"
case tableKind:
return "table"
case arrayTableKind:
return "array table"
}
panic("missing keyKind string mapping")
}
// SeenTracker tracks which keys have been seen with which TOML type to flag duplicates
// and mismatches according to the spec.
type SeenTracker struct {
root *info
current *info
}
type info struct {
parent *info
kind keyKind
children map[string]*info
explicit bool
}
func (i *info) Clear() {
i.children = nil
}
func (i *info) Has(k string) (*info, bool) {
c, ok := i.children[k]
return c, ok
}
func (i *info) SetKind(kind keyKind) {
i.kind = kind
}
func (i *info) CreateTable(k string, explicit bool) *info {
return i.createChild(k, tableKind, explicit)
}
func (i *info) CreateArrayTable(k string, explicit bool) *info {
return i.createChild(k, arrayTableKind, explicit)
}
func (i *info) createChild(k string, kind keyKind, explicit bool) *info {
if i.children == nil {
i.children = make(map[string]*info, 1)
}
x := &info{
parent: i,
kind: kind,
explicit: explicit,
}
i.children[k] = x
return x
}
// CheckExpression takes a top-level node and checks that it does not contain keys
// that have been seen in previous calls, and validates that types are consistent.
func (s *SeenTracker) CheckExpression(node ast.Node) error {
if s.root == nil {
s.root = &info{
kind: tableKind,
}
s.current = s.root
}
switch node.Kind {
case ast.KeyValue:
return s.checkKeyValue(s.current, node)
case ast.Table:
return s.checkTable(node)
case ast.ArrayTable:
return s.checkArrayTable(node)
default:
panic(fmt.Errorf("this should not be a top level node type: %s", node.Kind))
}
}
func (s *SeenTracker) checkTable(node ast.Node) error {
s.current = s.root
it := node.Key()
// handle the first parts of the key, excluding the last one
for it.Next() {
if !it.Node().Next().Valid() {
break
}
k := string(it.Node().Data)
child, found := s.current.Has(k)
if !found {
child = s.current.CreateTable(k, false)
}
s.current = child
}
// handle the last part of the key
k := string(it.Node().Data)
i, found := s.current.Has(k)
if found {
if i.kind != tableKind {
return fmt.Errorf("key %s should be a table", k)
}
if i.explicit {
return fmt.Errorf("table %s already exists", k)
}
i.explicit = true
s.current = i
} else {
s.current = s.current.CreateTable(k, true)
}
return nil
}
func (s *SeenTracker) checkArrayTable(node ast.Node) error {
s.current = s.root
it := node.Key()
// handle the first parts of the key, excluding the last one
for it.Next() {
if !it.Node().Next().Valid() {
break
}
k := string(it.Node().Data)
child, found := s.current.Has(k)
if !found {
child = s.current.CreateTable(k, false)
}
s.current = child
}
// handle the last part of the key
k := string(it.Node().Data)
info, found := s.current.Has(k)
if found {
if info.kind != arrayTableKind {
return fmt.Errorf("key %s already exists but is not an array table", k)
}
info.Clear()
} else {
info = s.current.CreateArrayTable(k, true)
}
s.current = info
return nil
}
func (s *SeenTracker) checkKeyValue(context *info, node ast.Node) error {
it := node.Key()
// handle the first parts of the key, excluding the last one
for it.Next() {
k := string(it.Node().Data)
child, found := context.Has(k)
if found {
if child.kind != tableKind {
return fmt.Errorf("expected %s to be a table, not a %s", k, child.kind)
}
} else {
child = context.CreateTable(k, false)
}
context = child
}
if node.Value().Kind == ast.InlineTable {
context.SetKind(tableKind)
} else {
context.SetKind(valueKind)
}
return nil
}
-199
View File
@@ -1,200 +1 @@
package tracker package tracker
import (
"fmt"
"github.com/pelletier/go-toml/v2/internal/ast"
)
type keyKind uint8
const (
invalidKind keyKind = iota
valueKind
tableKind
arrayTableKind
)
func (k keyKind) String() string {
switch k {
case invalidKind:
return "invalid"
case valueKind:
return "value"
case tableKind:
return "table"
case arrayTableKind:
return "array table"
}
panic("missing keyKind string mapping")
}
// Tracks which keys have been seen with which TOML type to flag duplicates
// and mismatches according to the spec.
type Seen struct {
root *info
current *info
}
type info struct {
parent *info
kind keyKind
children map[string]*info
explicit bool
}
func (i *info) Clear() {
i.children = nil
}
func (i *info) Has(k string) (*info, bool) {
c, ok := i.children[k]
return c, ok
}
func (i *info) SetKind(kind keyKind) {
i.kind = kind
}
func (i *info) CreateTable(k string, explicit bool) *info {
return i.createChild(k, tableKind, explicit)
}
func (i *info) CreateArrayTable(k string, explicit bool) *info {
return i.createChild(k, arrayTableKind, explicit)
}
func (i *info) createChild(k string, kind keyKind, explicit bool) *info {
if i.children == nil {
i.children = make(map[string]*info, 1)
}
x := &info{
parent: i,
kind: kind,
explicit: explicit,
}
i.children[k] = x
return x
}
// CheckExpression takes a top-level node and checks that it does not contain keys
// that have been seen in previous calls, and validates that types are consistent.
func (s *Seen) CheckExpression(node ast.Node) error {
if s.root == nil {
s.root = &info{
kind: tableKind,
}
s.current = s.root
}
switch node.Kind {
case ast.KeyValue:
return s.checkKeyValue(s.current, node)
case ast.Table:
return s.checkTable(node)
case ast.ArrayTable:
return s.checkArrayTable(node)
default:
panic(fmt.Errorf("this should not be a top level node type: %s", node.Kind))
}
}
func (s *Seen) checkTable(node ast.Node) error {
s.current = s.root
it := node.Key()
// handle the first parts of the key, excluding the last one
for it.Next() {
if !it.Node().Next().Valid() {
break
}
k := string(it.Node().Data)
child, found := s.current.Has(k)
if !found {
child = s.current.CreateTable(k, false)
}
s.current = child
}
// handle the last part of the key
k := string(it.Node().Data)
i, found := s.current.Has(k)
if found {
if i.kind != tableKind {
return fmt.Errorf("key %s should be a table", k)
}
if i.explicit {
return fmt.Errorf("table %s already exists", k)
}
i.explicit = true
s.current = i
} else {
s.current = s.current.CreateTable(k, true)
}
return nil
}
func (s *Seen) checkArrayTable(node ast.Node) error {
s.current = s.root
it := node.Key()
// handle the first parts of the key, excluding the last one
for it.Next() {
if !it.Node().Next().Valid() {
break
}
k := string(it.Node().Data)
child, found := s.current.Has(k)
if !found {
child = s.current.CreateTable(k, false)
}
s.current = child
}
// handle the last part of the key
k := string(it.Node().Data)
info, found := s.current.Has(k)
if found {
if info.kind != arrayTableKind {
return fmt.Errorf("key %s already exists but is not an array table", k)
}
info.Clear()
} else {
info = s.current.CreateArrayTable(k, true)
}
s.current = info
return nil
}
func (s *Seen) checkKeyValue(context *info, node ast.Node) error {
it := node.Key()
// handle the first parts of the key, excluding the last one
for it.Next() {
k := string(it.Node().Data)
child, found := context.Has(k)
if found {
if child.kind != tableKind {
return fmt.Errorf("expected %s to be a table, not a %s", k, child.kind)
}
} else {
child = context.CreateTable(k, false)
}
context = child
}
if node.Value().Kind == ast.InlineTable {
context.SetKind(tableKind)
} else {
context.SetKind(valueKind)
}
return nil
}
+24
View File
@@ -33,3 +33,27 @@ func SubsliceOffset(data []byte, subslice []byte) int {
return intoffset return intoffset
} }
func BytesRange(start []byte, end []byte) []byte {
if start == nil || end == nil {
panic("cannot call BytesRange with nil")
}
startp := (*reflect.SliceHeader)(unsafe.Pointer(&start))
endp := (*reflect.SliceHeader)(unsafe.Pointer(&end))
if startp.Data > endp.Data {
panic(fmt.Errorf("start pointer address (%d) is after end pointer address (%d)", startp.Data, endp.Data))
}
l := startp.Len
endLen := int(endp.Data-startp.Data) + endp.Len
if endLen > l {
l = endLen
}
if l > startp.Cap {
panic(fmt.Errorf("range length is larger than capacity"))
}
return start[:l]
}
+89
View File
@@ -77,3 +77,92 @@ func TestUnsafeSubsliceOffsetInvalid(t *testing.T) {
}) })
} }
} }
func TestUnsafeBytesRange(t *testing.T) {
type fn = func() ([]byte, []byte)
examples := []struct {
desc string
test fn
expected []byte
}{
{
desc: "simple",
test: func() ([]byte, []byte) {
full := []byte("hello world")
return full[1:3], full[6:8]
},
expected: []byte("ello wo"),
},
{
desc: "full",
test: func() ([]byte, []byte) {
full := []byte("hello world")
return full[0:1], full[len(full)-1:]
},
expected: []byte("hello world"),
},
{
desc: "end before start",
test: func() ([]byte, []byte) {
full := []byte("hello world")
return full[len(full)-1:], full[0:1]
},
},
{
desc: "nils",
test: func() ([]byte, []byte) {
return nil, nil
},
},
{
desc: "nils start",
test: func() ([]byte, []byte) {
return nil, []byte("foo")
},
},
{
desc: "nils end",
test: func() ([]byte, []byte) {
return []byte("foo"), nil
},
},
{
desc: "start is end",
test: func() ([]byte, []byte) {
full := []byte("hello world")
return full[1:3], full[1:3]
},
expected: []byte("el"),
},
{
desc: "end contained in start",
test: func() ([]byte, []byte) {
full := []byte("hello world")
return full[1:7], full[2:4]
},
expected: []byte("ello w"),
},
{
desc: "different backing arrays",
test: func() ([]byte, []byte) {
one := []byte("hello world")
two := []byte("hello world")
return one, two
},
},
}
for _, e := range examples {
t.Run(e.desc, func(t *testing.T) {
start, end := e.test()
if e.expected == nil {
require.Panics(t, func() {
unsafe.BytesRange(start, end)
})
} else {
res := unsafe.BytesRange(start, end)
require.Equal(t, e.expected, res)
}
})
}
}
+43 -24
View File
@@ -23,6 +23,7 @@
// //
// Because they lack location information, these types do not represent unique // Because they lack location information, these types do not represent unique
// moments or intervals of time. Use time.Time for that purpose. // moments or intervals of time. Use time.Time for that purpose.
package toml package toml
import ( import (
@@ -44,6 +45,7 @@ type LocalDate struct {
func LocalDateOf(t time.Time) LocalDate { func LocalDateOf(t time.Time) LocalDate {
var d LocalDate var d LocalDate
d.Year, d.Month, d.Day = t.Date() d.Year, d.Month, d.Day = t.Date()
return d return d
} }
@@ -51,8 +53,9 @@ func LocalDateOf(t time.Time) LocalDate {
func ParseLocalDate(s string) (LocalDate, error) { func ParseLocalDate(s string) (LocalDate, error) {
t, err := time.Parse("2006-01-02", s) t, err := time.Parse("2006-01-02", s)
if err != nil { if err != nil {
return LocalDate{}, err return LocalDate{}, fmt.Errorf("ParseLocalDate: %w", err)
} }
return LocalDateOf(t), nil return LocalDateOf(t), nil
} }
@@ -92,23 +95,28 @@ func (d LocalDate) DaysSince(s LocalDate) (days int) {
// We convert to Unix time so we do not have to worry about leap seconds: // We convert to Unix time so we do not have to worry about leap seconds:
// Unix time increases by exactly 86400 seconds per day. // Unix time increases by exactly 86400 seconds per day.
deltaUnix := d.In(time.UTC).Unix() - s.In(time.UTC).Unix() deltaUnix := d.In(time.UTC).Unix() - s.In(time.UTC).Unix()
return int(deltaUnix / 86400)
const secondsInADay = 86400
return int(deltaUnix / secondsInADay)
} }
// Before reports whether d1 occurs before d2. // Before reports whether d1 occurs before future date.
func (d1 LocalDate) Before(d2 LocalDate) bool { func (d LocalDate) Before(future LocalDate) bool {
if d1.Year != d2.Year { if d.Year != future.Year {
return d1.Year < d2.Year return d.Year < future.Year
}
if d1.Month != d2.Month {
return d1.Month < d2.Month
}
return d1.Day < d2.Day
} }
// After reports whether d1 occurs after d2. if d.Month != future.Month {
func (d1 LocalDate) After(d2 LocalDate) bool { return d.Month < future.Month
return d2.Before(d1) }
return d.Day < future.Day
}
// After reports whether d1 occurs after past date.
func (d LocalDate) After(past LocalDate) bool {
return past.Before(d)
} }
// MarshalText implements the encoding.TextMarshaler interface. // MarshalText implements the encoding.TextMarshaler interface.
@@ -122,6 +130,7 @@ func (d LocalDate) MarshalText() ([]byte, error) {
func (d *LocalDate) UnmarshalText(data []byte) error { func (d *LocalDate) UnmarshalText(data []byte) error {
var err error var err error
*d, err = ParseLocalDate(string(data)) *d, err = ParseLocalDate(string(data))
return err return err
} }
@@ -145,6 +154,7 @@ func LocalTimeOf(t time.Time) LocalTime {
var tm LocalTime var tm LocalTime
tm.Hour, tm.Minute, tm.Second = t.Clock() tm.Hour, tm.Minute, tm.Second = t.Clock()
tm.Nanosecond = t.Nanosecond() tm.Nanosecond = t.Nanosecond()
return tm return tm
} }
@@ -156,8 +166,9 @@ func LocalTimeOf(t time.Time) LocalTime {
func ParseLocalTime(s string) (LocalTime, error) { func ParseLocalTime(s string) (LocalTime, error) {
t, err := time.Parse("15:04:05.999999999", s) t, err := time.Parse("15:04:05.999999999", s)
if err != nil { if err != nil {
return LocalTime{}, err return LocalTime{}, fmt.Errorf("ParseLocalTime: %w", err)
} }
return LocalTimeOf(t), nil return LocalTimeOf(t), nil
} }
@@ -169,6 +180,7 @@ func (t LocalTime) String() string {
if t.Nanosecond == 0 { if t.Nanosecond == 0 {
return s return s
} }
return s + fmt.Sprintf(".%09d", t.Nanosecond) return s + fmt.Sprintf(".%09d", t.Nanosecond)
} }
@@ -176,6 +188,7 @@ func (t LocalTime) String() string {
func (t LocalTime) IsValid() bool { func (t LocalTime) IsValid() bool {
// Construct a non-zero time. // Construct a non-zero time.
tm := time.Date(2, 2, 2, t.Hour, t.Minute, t.Second, t.Nanosecond, time.UTC) tm := time.Date(2, 2, 2, t.Hour, t.Minute, t.Second, t.Nanosecond, time.UTC)
return LocalTimeOf(tm) == t return LocalTimeOf(tm) == t
} }
@@ -190,6 +203,7 @@ func (t LocalTime) MarshalText() ([]byte, error) {
func (t *LocalTime) UnmarshalText(data []byte) error { func (t *LocalTime) UnmarshalText(data []byte) error {
var err error var err error
*t, err = ParseLocalTime(string(data)) *t, err = ParseLocalTime(string(data))
return err return err
} }
@@ -223,9 +237,10 @@ func ParseLocalDateTime(s string) (LocalDateTime, error) {
if err != nil { if err != nil {
t, err = time.Parse("2006-01-02t15:04:05.999999999", s) t, err = time.Parse("2006-01-02t15:04:05.999999999", s)
if err != nil { if err != nil {
return LocalDateTime{}, err return LocalDateTime{}, fmt.Errorf("ParseLocalDateTime: %w", err)
} }
} }
return LocalDateTimeOf(t), nil return LocalDateTimeOf(t), nil
} }
@@ -253,17 +268,20 @@ func (dt LocalDateTime) IsValid() bool {
// //
// In panics if loc is nil. // In panics if loc is nil.
func (dt LocalDateTime) In(loc *time.Location) time.Time { func (dt LocalDateTime) In(loc *time.Location) time.Time {
return time.Date(dt.Date.Year, dt.Date.Month, dt.Date.Day, dt.Time.Hour, dt.Time.Minute, dt.Time.Second, dt.Time.Nanosecond, loc) return time.Date(
dt.Date.Year, dt.Date.Month, dt.Date.Day,
dt.Time.Hour, dt.Time.Minute, dt.Time.Second, dt.Time.Nanosecond, loc,
)
} }
// Before reports whether dt1 occurs before dt2. // Before reports whether dt occurs before future.
func (dt1 LocalDateTime) Before(dt2 LocalDateTime) bool { func (dt LocalDateTime) Before(future LocalDateTime) bool {
return dt1.In(time.UTC).Before(dt2.In(time.UTC)) return dt.In(time.UTC).Before(future.In(time.UTC))
} }
// After reports whether dt1 occurs after dt2. // After reports whether dt occurs after past.
func (dt1 LocalDateTime) After(dt2 LocalDateTime) bool { func (dt LocalDateTime) After(past LocalDateTime) bool {
return dt2.Before(dt1) return past.Before(dt)
} }
// MarshalText implements the encoding.TextMarshaler interface. // MarshalText implements the encoding.TextMarshaler interface.
@@ -273,9 +291,10 @@ func (dt LocalDateTime) MarshalText() ([]byte, error) {
} }
// UnmarshalText implements the encoding.TextUnmarshaler interface. // UnmarshalText implements the encoding.TextUnmarshaler interface.
// The datetime is expected to be a string in a format accepted by ParseLocalDateTime // The datetime is expected to be a string in a format accepted by ParseLocalDateTime.
func (dt *LocalDateTime) UnmarshalText(data []byte) error { func (dt *LocalDateTime) UnmarshalText(data []byte) error {
var err error var err error
*dt, err = ParseLocalDateTime(string(data)) *dt, err = ParseLocalDateTime(string(data))
return err return err
} }
+49 -6
View File
@@ -26,6 +26,8 @@ func cmpEqual(x, y interface{}) bool {
} }
func TestDates(t *testing.T) { func TestDates(t *testing.T) {
t.Parallel()
for _, test := range []struct { for _, test := range []struct {
date LocalDate date LocalDate
loc *time.Location loc *time.Location
@@ -61,6 +63,8 @@ func TestDates(t *testing.T) {
} }
func TestDateIsValid(t *testing.T) { func TestDateIsValid(t *testing.T) {
t.Parallel()
for _, test := range []struct { for _, test := range []struct {
date LocalDate date LocalDate
want bool want bool
@@ -86,6 +90,10 @@ func TestDateIsValid(t *testing.T) {
} }
func TestParseDate(t *testing.T) { func TestParseDate(t *testing.T) {
t.Parallel()
var emptyDate LocalDate
for _, test := range []struct { for _, test := range []struct {
str string str string
want LocalDate // if empty, expect an error want LocalDate // if empty, expect an error
@@ -93,21 +101,23 @@ func TestParseDate(t *testing.T) {
{"2016-01-02", LocalDate{2016, 1, 2}}, {"2016-01-02", LocalDate{2016, 1, 2}},
{"2016-12-31", LocalDate{2016, 12, 31}}, {"2016-12-31", LocalDate{2016, 12, 31}},
{"0003-02-04", LocalDate{3, 2, 4}}, {"0003-02-04", LocalDate{3, 2, 4}},
{"999-01-26", LocalDate{}}, {"999-01-26", emptyDate},
{"", LocalDate{}}, {"", emptyDate},
{"2016-01-02x", LocalDate{}}, {"2016-01-02x", emptyDate},
} { } {
got, err := ParseLocalDate(test.str) got, err := ParseLocalDate(test.str)
if got != test.want { if got != test.want {
t.Errorf("ParseLocalDate(%q) = %+v, want %+v", test.str, got, test.want) t.Errorf("ParseLocalDate(%q) = %+v, want %+v", test.str, got, test.want)
} }
if err != nil && test.want != (LocalDate{}) { if err != nil && test.want != (emptyDate) {
t.Errorf("Unexpected error %v from ParseLocalDate(%q)", err, test.str) t.Errorf("Unexpected error %v from ParseLocalDate(%q)", err, test.str)
} }
} }
} }
func TestDateArithmetic(t *testing.T) { func TestDateArithmetic(t *testing.T) {
t.Parallel()
for _, test := range []struct { for _, test := range []struct {
desc string desc string
start LocalDate start LocalDate
@@ -167,6 +177,8 @@ func TestDateArithmetic(t *testing.T) {
} }
func TestDateBefore(t *testing.T) { func TestDateBefore(t *testing.T) {
t.Parallel()
for _, test := range []struct { for _, test := range []struct {
d1, d2 LocalDate d1, d2 LocalDate
want bool want bool
@@ -183,6 +195,8 @@ func TestDateBefore(t *testing.T) {
} }
func TestDateAfter(t *testing.T) { func TestDateAfter(t *testing.T) {
t.Parallel()
for _, test := range []struct { for _, test := range []struct {
d1, d2 LocalDate d1, d2 LocalDate
want bool want bool
@@ -198,6 +212,8 @@ func TestDateAfter(t *testing.T) {
} }
func TestTimeToString(t *testing.T) { func TestTimeToString(t *testing.T) {
t.Parallel()
for _, test := range []struct { for _, test := range []struct {
str string str string
time LocalTime time LocalTime
@@ -212,6 +228,7 @@ func TestTimeToString(t *testing.T) {
gotTime, err := ParseLocalTime(test.str) gotTime, err := ParseLocalTime(test.str)
if err != nil { if err != nil {
t.Errorf("ParseLocalTime(%q): got error: %v", test.str, err) t.Errorf("ParseLocalTime(%q): got error: %v", test.str, err)
continue continue
} }
if gotTime != test.time { if gotTime != test.time {
@@ -227,6 +244,8 @@ func TestTimeToString(t *testing.T) {
} }
func TestTimeOf(t *testing.T) { func TestTimeOf(t *testing.T) {
t.Parallel()
for _, test := range []struct { for _, test := range []struct {
time time.Time time time.Time
want LocalTime want LocalTime
@@ -241,6 +260,8 @@ func TestTimeOf(t *testing.T) {
} }
func TestTimeIsValid(t *testing.T) { func TestTimeIsValid(t *testing.T) {
t.Parallel()
for _, test := range []struct { for _, test := range []struct {
time LocalTime time LocalTime
want bool want bool
@@ -265,6 +286,8 @@ func TestTimeIsValid(t *testing.T) {
} }
func TestDateTimeToString(t *testing.T) { func TestDateTimeToString(t *testing.T) {
t.Parallel()
for _, test := range []struct { for _, test := range []struct {
str string str string
dateTime LocalDateTime dateTime LocalDateTime
@@ -277,6 +300,7 @@ func TestDateTimeToString(t *testing.T) {
gotDateTime, err := ParseLocalDateTime(test.str) gotDateTime, err := ParseLocalDateTime(test.str)
if err != nil { if err != nil {
t.Errorf("ParseLocalDateTime(%q): got error: %v", test.str, err) t.Errorf("ParseLocalDateTime(%q): got error: %v", test.str, err)
continue continue
} }
if gotDateTime != test.dateTime { if gotDateTime != test.dateTime {
@@ -292,6 +316,8 @@ func TestDateTimeToString(t *testing.T) {
} }
func TestParseDateTimeErrors(t *testing.T) { func TestParseDateTimeErrors(t *testing.T) {
t.Parallel()
for _, str := range []string{ for _, str := range []string{
"", "",
"2016-03-22", // just a date "2016-03-22", // just a date
@@ -306,6 +332,8 @@ func TestParseDateTimeErrors(t *testing.T) {
} }
func TestDateTimeOf(t *testing.T) { func TestDateTimeOf(t *testing.T) {
t.Parallel()
for _, test := range []struct { for _, test := range []struct {
time time.Time time time.Time
want LocalDateTime want LocalDateTime
@@ -322,6 +350,8 @@ func TestDateTimeOf(t *testing.T) {
} }
func TestDateTimeIsValid(t *testing.T) { func TestDateTimeIsValid(t *testing.T) {
t.Parallel()
// No need to be exhaustive here; it's just LocalDate.IsValid && LocalTime.IsValid. // No need to be exhaustive here; it's just LocalDate.IsValid && LocalTime.IsValid.
for _, test := range []struct { for _, test := range []struct {
dt LocalDateTime dt LocalDateTime
@@ -339,19 +369,24 @@ func TestDateTimeIsValid(t *testing.T) {
} }
func TestDateTimeIn(t *testing.T) { func TestDateTimeIn(t *testing.T) {
t.Parallel()
dt := LocalDateTime{LocalDate{2016, 1, 2}, LocalTime{3, 4, 5, 6}} dt := LocalDateTime{LocalDate{2016, 1, 2}, LocalTime{3, 4, 5, 6}}
got := dt.In(time.UTC)
want := time.Date(2016, 1, 2, 3, 4, 5, 6, time.UTC) want := time.Date(2016, 1, 2, 3, 4, 5, 6, time.UTC)
if !got.Equal(want) { if got := dt.In(time.UTC); !got.Equal(want) {
t.Errorf("got %v, want %v", got, want) t.Errorf("got %v, want %v", got, want)
} }
} }
func TestDateTimeBefore(t *testing.T) { func TestDateTimeBefore(t *testing.T) {
t.Parallel()
d1 := LocalDate{2016, 12, 31} d1 := LocalDate{2016, 12, 31}
d2 := LocalDate{2017, 1, 1} d2 := LocalDate{2017, 1, 1}
t1 := LocalTime{5, 6, 7, 8} t1 := LocalTime{5, 6, 7, 8}
t2 := LocalTime{5, 6, 7, 9} t2 := LocalTime{5, 6, 7, 9}
for _, test := range []struct { for _, test := range []struct {
dt1, dt2 LocalDateTime dt1, dt2 LocalDateTime
want bool want bool
@@ -368,10 +403,13 @@ func TestDateTimeBefore(t *testing.T) {
} }
func TestDateTimeAfter(t *testing.T) { func TestDateTimeAfter(t *testing.T) {
t.Parallel()
d1 := LocalDate{2016, 12, 31} d1 := LocalDate{2016, 12, 31}
d2 := LocalDate{2017, 1, 1} d2 := LocalDate{2017, 1, 1}
t1 := LocalTime{5, 6, 7, 8} t1 := LocalTime{5, 6, 7, 8}
t2 := LocalTime{5, 6, 7, 9} t2 := LocalTime{5, 6, 7, 9}
for _, test := range []struct { for _, test := range []struct {
dt1, dt2 LocalDateTime dt1, dt2 LocalDateTime
want bool want bool
@@ -388,6 +426,8 @@ func TestDateTimeAfter(t *testing.T) {
} }
func TestMarshalJSON(t *testing.T) { func TestMarshalJSON(t *testing.T) {
t.Parallel()
for _, test := range []struct { for _, test := range []struct {
value interface{} value interface{}
want string want string
@@ -407,9 +447,12 @@ func TestMarshalJSON(t *testing.T) {
} }
func TestUnmarshalJSON(t *testing.T) { func TestUnmarshalJSON(t *testing.T) {
t.Parallel()
var d LocalDate var d LocalDate
var tm LocalTime var tm LocalTime
var dt LocalDateTime var dt LocalDateTime
for _, test := range []struct { for _, test := range []struct {
data string data string
ptr interface{} ptr interface{}
+140 -78
View File
@@ -18,10 +18,12 @@ import (
func Marshal(v interface{}) ([]byte, error) { func Marshal(v interface{}) ([]byte, error) {
var buf bytes.Buffer var buf bytes.Buffer
enc := NewEncoder(&buf) enc := NewEncoder(&buf)
err := enc.Encode(v) err := enc.Encode(v)
if err != nil { if err != nil {
return nil, err return nil, err
} }
return buf.Bytes(), nil return buf.Bytes(), nil
} }
@@ -96,7 +98,7 @@ func NewEncoder(w io.Writer) *Encoder {
// 5. Intermediate tables are always printed. // 5. Intermediate tables are always printed.
// //
// By default, strings are encoded as literal string, unless they contain either // By default, strings are encoded as literal string, unless they contain either
// a newline character or a single quote. In that case they are emited as quoted // a newline character or a single quote. In that case they are emitted as quoted
// strings. // strings.
// //
// When encoding structs, fields are encoded in order of definition, with their // When encoding structs, fields are encoded in order of definition, with their
@@ -107,25 +109,38 @@ func NewEncoder(w io.Writer) *Encoder {
// `multiline:"true"`: when the field contains a string, it will be emitted as // `multiline:"true"`: when the field contains a string, it will be emitted as
// a quoted multi-line TOML string. // a quoted multi-line TOML string.
func (enc *Encoder) Encode(v interface{}) error { func (enc *Encoder) Encode(v interface{}) error {
var b []byte var (
var ctx encoderCtx b []byte
ctx encoderCtx
)
b, err := enc.encode(b, ctx, reflect.ValueOf(v)) b, err := enc.encode(b, ctx, reflect.ValueOf(v))
if err != nil { if err != nil {
return err return fmt.Errorf("Encode: %w", err)
}
_, err = enc.w.Write(b)
return err
} }
_, err = enc.w.Write(b)
if err != nil {
return fmt.Errorf("Encode: %w", err)
}
return nil
}
var errUnsupportedValue = errors.New("unsupported encode value kind")
//nolint:cyclop
func (enc *Encoder) encode(b []byte, ctx encoderCtx, v reflect.Value) ([]byte, error) { func (enc *Encoder) encode(b []byte, ctx encoderCtx, v reflect.Value) ([]byte, error) {
//nolint:gocritic,godox
switch i := v.Interface().(type) { switch i := v.Interface().(type) {
case time.Time: // TODO: add TextMarshaler case time.Time: // TODO: add TextMarshaler
b = i.AppendFormat(b, time.RFC3339) b = i.AppendFormat(b, time.RFC3339)
return b, nil return b, nil
} }
// containers
switch v.Kind() { switch v.Kind() {
// containers
case reflect.Map: case reflect.Map:
return enc.encodeMap(b, ctx, v) return enc.encodeMap(b, ctx, v)
case reflect.Struct: case reflect.Struct:
@@ -136,19 +151,18 @@ func (enc *Encoder) encode(b []byte, ctx encoderCtx, v reflect.Value) ([]byte, e
if v.IsNil() { if v.IsNil() {
return nil, errNilInterface return nil, errNilInterface
} }
return enc.encode(b, ctx, v.Elem()) return enc.encode(b, ctx, v.Elem())
case reflect.Ptr: case reflect.Ptr:
if v.IsNil() { if v.IsNil() {
return enc.encode(b, ctx, reflect.Zero(v.Type().Elem())) return enc.encode(b, ctx, reflect.Zero(v.Type().Elem()))
} }
return enc.encode(b, ctx, v.Elem()) return enc.encode(b, ctx, v.Elem())
}
// values // values
var err error
switch v.Kind() {
case reflect.String: case reflect.String:
b, err = enc.encodeString(b, v.String(), ctx.options) b = enc.encodeString(b, v.String(), ctx.options)
case reflect.Float32: case reflect.Float32:
b = strconv.AppendFloat(b, v.Float(), 'f', -1, 32) b = strconv.AppendFloat(b, v.Float(), 'f', -1, 32)
case reflect.Float64: case reflect.Float64:
@@ -164,10 +178,7 @@ func (enc *Encoder) encode(b []byte, ctx encoderCtx, v reflect.Value) ([]byte, e
case reflect.Int64, reflect.Int32, reflect.Int16, reflect.Int8, reflect.Int: case reflect.Int64, reflect.Int32, reflect.Int16, reflect.Int8, reflect.Int:
b = strconv.AppendInt(b, v.Int(), 10) b = strconv.AppendInt(b, v.Int(), 10)
default: default:
err = fmt.Errorf("unsupported encode value kind: %s", v.Kind()) return nil, fmt.Errorf("encode(type %s): %w", v.Kind(), errUnsupportedValue)
}
if err != nil {
return nil, err
} }
return b, nil return b, nil
@@ -217,30 +228,31 @@ func (enc *Encoder) encodeKv(b []byte, ctx encoderCtx, options valueOptions, v r
const literalQuote = '\'' const literalQuote = '\''
func (enc *Encoder) encodeString(b []byte, v string, options valueOptions) ([]byte, error) { func (enc *Encoder) encodeString(b []byte, v string, options valueOptions) []byte {
if needsQuoting(v) { if needsQuoting(v) {
b = enc.encodeQuotedString(options.multiline, b, v) return enc.encodeQuotedString(options.multiline, b, v)
} else {
b = enc.encodeLiteralString(b, v)
} }
return b, nil
return enc.encodeLiteralString(b, v)
} }
func needsQuoting(v string) bool { func needsQuoting(v string) bool {
return strings.ContainsAny(v, "'\b\f\n\r\t") return strings.ContainsAny(v, "'\b\f\n\r\t")
} }
// caller should have checked that the string does not contain new lines or ' // caller should have checked that the string does not contain new lines or ' .
func (enc *Encoder) encodeLiteralString(b []byte, v string) []byte { func (enc *Encoder) encodeLiteralString(b []byte, v string) []byte {
b = append(b, literalQuote) b = append(b, literalQuote)
b = append(b, v...) b = append(b, v...)
b = append(b, literalQuote) b = append(b, literalQuote)
return b return b
} }
//nolint:cyclop
func (enc *Encoder) encodeQuotedString(multiline bool, b []byte, v string) []byte { func (enc *Encoder) encodeQuotedString(multiline bool, b []byte, v string) []byte {
const hextable = "0123456789ABCDEF"
stringQuote := `"` stringQuote := `"`
if multiline { if multiline {
stringQuote = `"""` stringQuote = `"""`
} }
@@ -250,6 +262,16 @@ func (enc *Encoder) encodeQuotedString(multiline bool, b []byte, v string) []byt
b = append(b, '\n') b = append(b, '\n')
} }
const (
hextable = "0123456789ABCDEF"
// U+0000 to U+0008, U+000A to U+001F, U+007F
nul = 0x0
bs = 0x8
lf = 0xa
us = 0x1f
del = 0x7f
)
for _, r := range []byte(v) { for _, r := range []byte(v) {
switch r { switch r {
case '\\': case '\\':
@@ -272,7 +294,7 @@ func (enc *Encoder) encodeQuotedString(multiline bool, b []byte, v string) []byt
b = append(b, `\t`...) b = append(b, `\t`...)
default: default:
switch { switch {
case r >= 0x0 && r <= 0x8, r >= 0xA && r <= 0x1F, r == 0x7F: case r >= nul && r <= bs, r >= lf && r <= us, r == del:
b = append(b, `\u00`...) b = append(b, `\u00`...)
b = append(b, hextable[r>>4]) b = append(b, hextable[r>>4])
b = append(b, hextable[r&0x0f]) b = append(b, hextable[r&0x0f])
@@ -280,14 +302,14 @@ func (enc *Encoder) encodeQuotedString(multiline bool, b []byte, v string) []byt
b = append(b, r) b = append(b, r)
} }
} }
// U+0000 to U+0008, U+000A to U+001F, U+007F
} }
b = append(b, stringQuote...) b = append(b, stringQuote...)
return b return b
} }
// called should have checked that the string is in A-Z / a-z / 0-9 / - / _ // called should have checked that the string is in A-Z / a-z / 0-9 / - / _ .
func (enc *Encoder) encodeUnquotedKey(b []byte, v string) []byte { func (enc *Encoder) encodeUnquotedKey(b []byte, v string) []byte {
return append(b, v...) return append(b, v...)
} }
@@ -300,6 +322,7 @@ func (enc *Encoder) encodeTableHeader(b []byte, key []string) ([]byte, error) {
b = append(b, '[') b = append(b, '[')
var err error var err error
b, err = enc.encodeKey(b, key[0]) b, err = enc.encodeKey(b, key[0])
if err != nil { if err != nil {
return nil, err return nil, err
@@ -307,6 +330,7 @@ func (enc *Encoder) encodeTableHeader(b []byte, key []string) ([]byte, error) {
for _, k := range key[1:] { for _, k := range key[1:] {
b = append(b, '.') b = append(b, '.')
b, err = enc.encodeKey(b, k) b, err = enc.encodeKey(b, k)
if err != nil { if err != nil {
return nil, err return nil, err
@@ -318,6 +342,9 @@ func (enc *Encoder) encodeTableHeader(b []byte, key []string) ([]byte, error) {
return b, nil return b, nil
} }
var errTomlNoMultiline = errors.New("TOML does not support multiline keys")
//nolint:cyclop
func (enc *Encoder) encodeKey(b []byte, k string) ([]byte, error) { func (enc *Encoder) encodeKey(b []byte, k string) ([]byte, error) {
needsQuotation := false needsQuotation := false
cannotUseLiteral := false cannotUseLiteral := false
@@ -326,32 +353,39 @@ func (enc *Encoder) encodeKey(b []byte, k string) ([]byte, error) {
if (c >= 'A' && c <= 'Z') || (c >= 'a' && c <= 'z') || (c >= '0' && c <= '9') || c == '-' || c == '_' { if (c >= 'A' && c <= 'Z') || (c >= 'a' && c <= 'z') || (c >= '0' && c <= '9') || c == '-' || c == '_' {
continue continue
} }
if c == '\n' { if c == '\n' {
return nil, fmt.Errorf("TOML does not support multiline keys") return nil, errTomlNoMultiline
} }
if c == literalQuote { if c == literalQuote {
cannotUseLiteral = true cannotUseLiteral = true
} }
needsQuotation = true needsQuotation = true
} }
if cannotUseLiteral { switch {
b = enc.encodeQuotedString(false, b, k) case cannotUseLiteral:
} else if needsQuotation { return enc.encodeQuotedString(false, b, k), nil
b = enc.encodeLiteralString(b, k) case needsQuotation:
} else { return enc.encodeLiteralString(b, k), nil
b = enc.encodeUnquotedKey(b, k) default:
return enc.encodeUnquotedKey(b, k), nil
}
} }
return b, nil var errNotSupportedAsMapKey = errors.New("type not supported as map key")
}
func (enc *Encoder) encodeMap(b []byte, ctx encoderCtx, v reflect.Value) ([]byte, error) { func (enc *Encoder) encodeMap(b []byte, ctx encoderCtx, v reflect.Value) ([]byte, error) {
if v.Type().Key().Kind() != reflect.String { if v.Type().Key().Kind() != reflect.String {
return nil, fmt.Errorf("type '%s' not supported as map key", v.Type().Key().Kind()) return nil, fmt.Errorf("encodeMap '%s': %w", v.Type().Key().Kind(), errNotSupportedAsMapKey)
} }
t := table{} var (
t table
emptyValueOptions valueOptions
)
iter := v.MapRange() iter := v.MapRange()
for iter.Next() { for iter.Next() {
@@ -368,9 +402,9 @@ func (enc *Encoder) encodeMap(b []byte, ctx encoderCtx, v reflect.Value) ([]byte
} }
if table { if table {
t.pushTable(k, v, valueOptions{}) t.pushTable(k, v, emptyValueOptions)
} else { } else {
t.pushKV(k, v, valueOptions{}) t.pushKV(k, v, emptyValueOptions)
} }
} }
@@ -405,13 +439,10 @@ func (t *table) pushTable(k string, v reflect.Value, options valueOptions) {
t.tables = append(t.tables, entry{Key: k, Value: v, Options: options}) t.tables = append(t.tables, entry{Key: k, Value: v, Options: options})
} }
func (t *table) hasKVs() bool {
return len(t.kvs) > 0
}
func (enc *Encoder) encodeStruct(b []byte, ctx encoderCtx, v reflect.Value) ([]byte, error) { func (enc *Encoder) encodeStruct(b []byte, ctx encoderCtx, v reflect.Value) ([]byte, error) {
t := table{} var t table
//nolint:godox
// TODO: cache this? // TODO: cache this?
typ := v.Type() typ := v.Type()
for i := 0; i < typ.NumField(); i++ { for i := 0; i < typ.NumField(); i++ {
@@ -443,7 +474,7 @@ func (enc *Encoder) encodeStruct(b []byte, ctx encoderCtx, v reflect.Value) ([]b
return nil, err return nil, err
} }
options := valueOptions{} var options valueOptions
ml, ok := fieldType.Tag.Lookup("multiline") ml, ok := fieldType.Tag.Lookup("multiline")
if ok { if ok {
@@ -466,38 +497,7 @@ func (enc *Encoder) encodeTable(b []byte, ctx encoderCtx, t table) ([]byte, erro
ctx.shiftKey() ctx.shiftKey()
if ctx.insideKv { if ctx.insideKv {
b = append(b, '{') return enc.encodeTableInsideKV(b, ctx, t)
first := true
for _, kv := range t.kvs {
if first {
first = false
} else {
b = append(b, `, `...)
}
ctx.setKey(kv.Key)
b, err = enc.encodeKv(b, ctx, kv.Options, kv.Value)
if err != nil {
return nil, err
}
}
for _, table := range t.tables {
if first {
first = false
} else {
b = append(b, `, `...)
}
ctx.setKey(table.Key)
b, err = enc.encode(b, ctx, table.Value)
if err != nil {
return nil, err
}
b = append(b, '\n')
}
b = append(b, "}\n"...)
return b, nil
} }
if !ctx.skipTableHeader { if !ctx.skipTableHeader {
@@ -510,29 +510,76 @@ func (enc *Encoder) encodeTable(b []byte, ctx encoderCtx, t table) ([]byte, erro
for _, kv := range t.kvs { for _, kv := range t.kvs {
ctx.setKey(kv.Key) ctx.setKey(kv.Key)
b, err = enc.encodeKv(b, ctx, kv.Options, kv.Value) b, err = enc.encodeKv(b, ctx, kv.Options, kv.Value)
if err != nil { if err != nil {
return nil, err return nil, err
} }
b = append(b, '\n') b = append(b, '\n')
} }
for _, table := range t.tables { for _, table := range t.tables {
ctx.setKey(table.Key) ctx.setKey(table.Key)
b, err = enc.encode(b, ctx, table.Value) b, err = enc.encode(b, ctx, table.Value)
if err != nil { if err != nil {
return nil, err return nil, err
} }
b = append(b, '\n') b = append(b, '\n')
} }
return b, nil return b, nil
} }
func (enc *Encoder) encodeTableInsideKV(b []byte, ctx encoderCtx, t table) ([]byte, error) {
var err error
b = append(b, '{')
first := true
for _, kv := range t.kvs {
if first {
first = false
} else {
b = append(b, `, `...)
}
ctx.setKey(kv.Key)
b, err = enc.encodeKv(b, ctx, kv.Options, kv.Value)
if err != nil {
return nil, err
}
}
for _, table := range t.tables {
if first {
first = false
} else {
b = append(b, `, `...)
}
ctx.setKey(table.Key)
b, err = enc.encode(b, ctx, table.Value)
if err != nil {
return nil, err
}
b = append(b, '\n')
}
b = append(b, "}\n"...)
return b, nil
}
var errNilInterface = errors.New("nil interface not supported") var errNilInterface = errors.New("nil interface not supported")
var errNilPointer = errors.New("nil pointer not supported")
func willConvertToTable(v reflect.Value) (bool, error) { func willConvertToTable(v reflect.Value) (bool, error) {
//nolint:gocritic,godox
switch v.Interface().(type) { switch v.Interface().(type) {
case time.Time: // TODO: add TextMarshaler case time.Time: // TODO: add TextMarshaler
return false, nil return false, nil
@@ -546,11 +593,13 @@ func willConvertToTable(v reflect.Value) (bool, error) {
if v.IsNil() { if v.IsNil() {
return false, errNilInterface return false, errNilInterface
} }
return willConvertToTable(v.Elem()) return willConvertToTable(v.Elem())
case reflect.Ptr: case reflect.Ptr:
if v.IsNil() { if v.IsNil() {
return false, nil return false, nil
} }
return willConvertToTable(v.Elem()) return willConvertToTable(v.Elem())
default: default:
return false, nil return false, nil
@@ -564,6 +613,7 @@ func willConvertToTableOrArrayTable(v reflect.Value) (bool, error) {
if v.IsNil() { if v.IsNil() {
return false, errNilInterface return false, errNilInterface
} }
return willConvertToTableOrArrayTable(v.Elem()) return willConvertToTableOrArrayTable(v.Elem())
} }
@@ -572,15 +622,18 @@ func willConvertToTableOrArrayTable(v reflect.Value) (bool, error) {
// An empty slice should be a kv = []. // An empty slice should be a kv = [].
return false, nil return false, nil
} }
for i := 0; i < v.Len(); i++ { for i := 0; i < v.Len(); i++ {
t, err := willConvertToTable(v.Index(i)) t, err := willConvertToTable(v.Index(i))
if err != nil { if err != nil {
return false, err return false, err
} }
if !t { if !t {
return false, nil return false, nil
} }
} }
return true, nil return true, nil
} }
@@ -590,6 +643,7 @@ func willConvertToTableOrArrayTable(v reflect.Value) (bool, error) {
func (enc *Encoder) encodeSlice(b []byte, ctx encoderCtx, v reflect.Value) ([]byte, error) { func (enc *Encoder) encodeSlice(b []byte, ctx encoderCtx, v reflect.Value) ([]byte, error) {
if v.Len() == 0 { if v.Len() == 0 {
b = append(b, "[]"...) b = append(b, "[]"...)
return b, nil return b, nil
} }
@@ -617,25 +671,30 @@ func (enc *Encoder) encodeSliceAsArrayTable(b []byte, ctx encoderCtx, v reflect.
var err error var err error
scratch := make([]byte, 0, 64) scratch := make([]byte, 0, 64)
scratch = append(scratch, "[["...) scratch = append(scratch, "[["...)
for i, k := range ctx.parentKey { for i, k := range ctx.parentKey {
if i > 0 { if i > 0 {
scratch = append(scratch, '.') scratch = append(scratch, '.')
} }
scratch, err = enc.encodeKey(scratch, k) scratch, err = enc.encodeKey(scratch, k)
if err != nil { if err != nil {
return nil, err return nil, err
} }
} }
scratch = append(scratch, "]]\n"...) scratch = append(scratch, "]]\n"...)
ctx.skipTableHeader = true ctx.skipTableHeader = true
for i := 0; i < v.Len(); i++ { for i := 0; i < v.Len(); i++ {
b = append(b, scratch...) b = append(b, scratch...)
b, err = enc.encode(b, ctx, v.Index(i)) b, err = enc.encode(b, ctx, v.Index(i))
if err != nil { if err != nil {
return nil, err return nil, err
} }
} }
return b, nil return b, nil
} }
@@ -644,10 +703,12 @@ func (enc *Encoder) encodeSliceAsArray(b []byte, ctx encoderCtx, v reflect.Value
var err error var err error
first := true first := true
for i := 0; i < v.Len(); i++ { for i := 0; i < v.Len(); i++ {
if !first { if !first {
b = append(b, ", "...) b = append(b, ", "...)
} }
first = false first = false
b, err = enc.encode(b, ctx, v.Index(i)) b, err = enc.encode(b, ctx, v.Index(i))
@@ -657,5 +718,6 @@ func (enc *Encoder) encodeSliceAsArray(b []byte, ctx encoderCtx, v reflect.Value
} }
b = append(b, ']') b = append(b, ']')
return b, nil return b, nil
} }
+21 -2
View File
@@ -11,7 +11,10 @@ import (
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
) )
//nolint:funlen
func TestMarshal(t *testing.T) { func TestMarshal(t *testing.T) {
t.Parallel()
examples := []struct { examples := []struct {
desc string desc string
v interface{} v interface{}
@@ -65,6 +68,7 @@ hello = 'world'`,
a = 'test'`, a = 'test'`,
}, },
{ {
//nolint:godox
// TODO: this test is flaky because output changes depending on // TODO: this test is flaky because output changes depending on
// the map iteration order. // the map iteration order.
desc: "map in map in map and string with values", desc: "map in map in map and string with values",
@@ -89,6 +93,16 @@ a = 'test'`,
}, },
expected: `array = ['one', 'two', 'three']`, expected: `array = ['one', 'two', 'three']`,
}, },
{
desc: "empty string array",
v: map[string][]string{},
expected: ``,
},
{
desc: "map",
v: map[string][]string{},
expected: ``,
},
{ {
desc: "nested string arrays", desc: "nested string arrays",
v: map[string][][]string{ v: map[string][][]string{
@@ -104,7 +118,7 @@ a = 'test'`,
expected: `array = ['a string', ['one', 'two'], 'last']`, expected: `array = ['a string', ['one', 'two'], 'last']`,
}, },
{ {
desc: "slice of maps", desc: "array of maps",
v: map[string][]map[string]string{ v: map[string][]map[string]string{
"top": { "top": {
{"map1.1": "v1.1"}, {"map1.1": "v1.1"},
@@ -157,7 +171,7 @@ K2 = 'v2'
`, `,
}, },
{ {
desc: "structs in slice with interfaces", desc: "structs in array with interfaces",
v: map[string]interface{}{ v: map[string]interface{}{
"root": map[string]interface{}{ "root": map[string]interface{}{
"nested": []interface{}{ "nested": []interface{}{
@@ -237,7 +251,10 @@ world"""`,
} }
for _, e := range examples { for _, e := range examples {
e := e
t.Run(e.desc, func(t *testing.T) { t.Run(e.desc, func(t *testing.T) {
t.Parallel()
b, err := toml.Marshal(e.v) b, err := toml.Marshal(e.v)
if e.err { if e.err {
require.Error(t, err) require.Error(t, err)
@@ -256,6 +273,8 @@ func equalStringsIgnoreNewlines(t *testing.T, expected string, actual string) {
} }
func TestIssue436(t *testing.T) { func TestIssue436(t *testing.T) {
t.Parallel()
data := []byte(`{"a": [ { "b": { "c": "d" } } ]}`) data := []byte(`{"a": [ { "b": { "c": "d" } } ]}`)
var v interface{} var v interface{}
+6 -2
View File
@@ -155,12 +155,14 @@ func TestParser_AST_Numbers(t *testing.T) {
} }
} }
type astRoot []astNode type (
type astNode struct { astRoot []astNode
astNode struct {
Kind ast.Kind Kind ast.Kind
Data []byte Data []byte
Children []astNode Children []astNode
} }
)
func compareAST(t *testing.T, expected astRoot, actual *ast.Root) { func compareAST(t *testing.T, expected astRoot, actual *ast.Root) {
it := actual.Iterator() it := actual.Iterator()
@@ -168,6 +170,7 @@ func compareAST(t *testing.T, expected astRoot, actual *ast.Root) {
} }
func compareNode(t *testing.T, e astNode, n ast.Node) { func compareNode(t *testing.T, e astNode, n ast.Node) {
t.Helper()
require.Equal(t, e.Kind, n.Kind) require.Equal(t, e.Kind, n.Kind)
require.Equal(t, e.Data, n.Data) require.Equal(t, e.Data, n.Data)
@@ -175,6 +178,7 @@ func compareNode(t *testing.T, e astNode, n ast.Node) {
} }
func compareIterator(t *testing.T, expected []astNode, actual ast.Iterator) { func compareIterator(t *testing.T, expected []astNode, actual ast.Iterator) {
t.Helper()
idx := 0 idx := 0
for actual.Next() { for actual.Next() {
+26 -18
View File
@@ -2,26 +2,34 @@ package toml
import "fmt" import "fmt"
func scanFollows(pattern []byte) func(b []byte) bool { func scanFollows(b []byte, pattern string) bool {
return func(b []byte) bool { n := len(pattern)
if len(b) < len(pattern) { return len(b) >= n && string(b[:n]) == pattern
return false
}
for i, c := range pattern {
if b[i] != c {
return false
}
}
return true
}
} }
var scanFollowsMultilineBasicStringDelimiter = scanFollows([]byte{'"', '"', '"'}) func scanFollowsMultilineBasicStringDelimiter(b []byte) bool {
var scanFollowsMultilineLiteralStringDelimiter = scanFollows([]byte{'\'', '\'', '\''}) return scanFollows(b, `"""`)
var scanFollowsTrue = scanFollows([]byte{'t', 'r', 'u', 'e'}) }
var scanFollowsFalse = scanFollows([]byte{'f', 'a', 'l', 's', 'e'})
var scanFollowsInf = scanFollows([]byte{'i', 'n', 'f'}) func scanFollowsMultilineLiteralStringDelimiter(b []byte) bool {
var scanFollowsNan = scanFollows([]byte{'n', 'a', 'n'}) return scanFollows(b, `'''`)
}
func scanFollowsTrue(b []byte) bool {
return scanFollows(b, `true`)
}
func scanFollowsFalse(b []byte) bool {
return scanFollows(b, `false`)
}
func scanFollowsInf(b []byte) bool {
return scanFollows(b, `inf`)
}
func scanFollowsNan(b []byte) bool {
return scanFollows(b, `nan`)
}
func scanUnquotedKey(b []byte) ([]byte, []byte, error) { func scanUnquotedKey(b []byte) ([]byte, []byte, error) {
// unquoted-key = 1*( ALPHA / DIGIT / %x2D / %x5F ) ; A-Z / a-z / 0-9 / - / _ // unquoted-key = 1*( ALPHA / DIGIT / %x2D / %x5F ) ; A-Z / a-z / 0-9 / - / _
+79
View File
@@ -0,0 +1,79 @@
package toml
import (
"github.com/pelletier/go-toml/v2/internal/ast"
"github.com/pelletier/go-toml/v2/internal/tracker"
)
type strict struct {
Enabled bool
// Tracks the current key being processed.
key tracker.KeyTracker
missing []decodeError
}
func (s *strict) EnterTable(node ast.Node) {
if !s.Enabled {
return
}
s.key.UpdateTable(node)
}
func (s *strict) EnterArrayTable(node ast.Node) {
if !s.Enabled {
return
}
s.key.UpdateArrayTable(node)
}
func (s *strict) EnterKeyValue(node ast.Node) {
if !s.Enabled {
return
}
s.key.Push(node)
}
func (s *strict) ExitKeyValue(node ast.Node) {
if !s.Enabled {
return
}
s.key.Pop(node)
}
func (s *strict) MissingTable(node ast.Node) {
if !s.Enabled {
return
}
s.missing = append(s.missing, decodeError{
highlight: keyLocation(node),
message: "missing table",
key: s.key.Key(),
})
}
func (s *strict) MissingField(node ast.Node) {
if !s.Enabled {
return
}
s.missing = append(s.missing, decodeError{
highlight: keyLocation(node),
message: "missing field",
key: s.key.Key(),
})
}
func (s *strict) Error(doc []byte) error {
if !s.Enabled || len(s.missing) == 0 {
return nil
}
err := &StrictMissingError{
Errors: make([]DecodeError, 0, len(s.missing)),
}
for _, derr := range s.missing {
err.Errors = append(err.Errors, *wrapDecodeError(doc, &derr))
}
return err
}
+3 -4
View File
@@ -516,11 +516,10 @@ func scopeStruct(v reflect.Value, name string) (target, bool, error) {
l := len(path) l := len(path)
path = append(path, i) path = append(path, i)
f := t.Field(i) f := t.Field(i)
if f.PkgPath != "" { if f.Anonymous {
// only consider exported fields
} else if f.Anonymous {
walk(v.Field(i)) walk(v.Field(i))
} else { } else if f.PkgPath == "" {
// only consider exported fields
fieldName, ok := f.Tag.Lookup("toml") fieldName, ok := f.Tag.Lookup("toml")
if !ok { if !ok {
fieldName = f.Name fieldName = f.Name
+1
View File
@@ -59,6 +59,7 @@ val = string / boolean / array / inline-table / date-time / float / integer
;; String ;; String
string = ml-basic-string / basic-string / ml-literal-string / literal-string string = ml-basic-string / basic-string / ml-literal-string / literal-string
;; Basic String ;; Basic String
basic-string = quotation-mark *basic-char quotation-mark basic-string = quotation-mark *basic-char quotation-mark
+49 -2
View File
@@ -10,6 +10,7 @@ import (
"github.com/pelletier/go-toml/v2/internal/ast" "github.com/pelletier/go-toml/v2/internal/ast"
"github.com/pelletier/go-toml/v2/internal/tracker" "github.com/pelletier/go-toml/v2/internal/tracker"
"github.com/pelletier/go-toml/v2/internal/unsafe"
) )
func Unmarshal(data []byte, v interface{}) error { func Unmarshal(data []byte, v interface{}) error {
@@ -21,7 +22,11 @@ func Unmarshal(data []byte, v interface{}) error {
// Decoder reads and decode a TOML document from an input stream. // Decoder reads and decode a TOML document from an input stream.
type Decoder struct { type Decoder struct {
// input
r io.Reader r io.Reader
// global settings
strict bool
} }
// NewDecoder creates a new Decoder that will read from r. // NewDecoder creates a new Decoder that will read from r.
@@ -29,6 +34,16 @@ func NewDecoder(r io.Reader) *Decoder {
return &Decoder{r: r} return &Decoder{r: r}
} }
// SetStrict toggles decoding in stict mode.
//
// When the decoder is in strict mode, it will record fields from the document
// that could not be set on the target value. In that case, the decoder returns
// a StrictMissingError that can be used to retrieve the individual errors as
// well as generate a human readable description of the missing fields.
func (d *Decoder) SetStrict(strict bool) {
d.strict = strict
}
// Decode the whole content of r into v. // Decode the whole content of r into v.
// //
// When a TOML local date is decoded into a time.Time, its value is represented // When a TOML local date is decoded into a time.Time, its value is represented
@@ -43,7 +58,11 @@ func (d *Decoder) Decode(v interface{}) error {
} }
p := parser{} p := parser{}
p.Reset(b) p.Reset(b)
dec := decoder{} dec := decoder{
strict: strict{
Enabled: d.strict,
},
}
return dec.FromParser(&p, v) return dec.FromParser(&p, v)
} }
@@ -52,7 +71,10 @@ type decoder struct {
arrayIndexes map[reflect.Value]int arrayIndexes map[reflect.Value]int
// Tracks keys that have been seen, with which type. // Tracks keys that have been seen, with which type.
seen tracker.Seen seen tracker.SeenTracker
// Strict mode
strict strict
} }
func (d *decoder) arrayIndex(append bool, v reflect.Value) int { func (d *decoder) arrayIndex(append bool, v reflect.Value) int {
@@ -79,9 +101,27 @@ func (d *decoder) FromParser(p *parser, v interface{}) error {
err = wrapDecodeError(p.data, de) err = wrapDecodeError(p.data, de)
} }
} }
if err == nil {
err = d.strict.Error(p.data)
}
return err return err
} }
func keyLocation(node ast.Node) []byte {
k := node.Key()
hasOne := k.Next()
if !hasOne {
panic("should not be called with empty key")
}
start := k.Node().Data
end := k.Node().Data
for k.Next() {
end = k.Node().Data
}
return unsafe.BytesRange(start, end)
}
func (d *decoder) fromParser(p *parser, v interface{}) error { func (d *decoder) fromParser(p *parser, v interface{}) error {
r := reflect.ValueOf(v) r := reflect.ValueOf(v)
if r.Kind() != reflect.Ptr { if r.Kind() != reflect.Ptr {
@@ -113,6 +153,7 @@ func (d *decoder) fromParser(p *parser, v interface{}) error {
err = d.unmarshalKeyValue(current, node) err = d.unmarshalKeyValue(current, node)
found = true found = true
case ast.Table: case ast.Table:
d.strict.EnterTable(node)
current, found, err = d.scopeWithKey(root, node.Key()) current, found, err = d.scopeWithKey(root, node.Key())
if err == nil && found { if err == nil && found {
// In case this table points to an interface, // In case this table points to an interface,
@@ -123,6 +164,7 @@ func (d *decoder) fromParser(p *parser, v interface{}) error {
ensureMapIfInterface(current) ensureMapIfInterface(current)
} }
case ast.ArrayTable: case ast.ArrayTable:
d.strict.EnterArrayTable(node)
current, found, err = d.scopeWithArrayTable(root, node.Key()) current, found, err = d.scopeWithArrayTable(root, node.Key())
default: default:
panic(fmt.Errorf("this should not be a top level node type: %s", node.Kind)) panic(fmt.Errorf("this should not be a top level node type: %s", node.Kind))
@@ -134,6 +176,7 @@ func (d *decoder) fromParser(p *parser, v interface{}) error {
if !found { if !found {
skipUntilTable = true skipUntilTable = true
d.strict.MissingTable(node)
} }
} }
@@ -217,6 +260,9 @@ func (d *decoder) scopeWithArrayTable(x target, key ast.Iterator) (target, bool,
func (d *decoder) unmarshalKeyValue(x target, node ast.Node) error { func (d *decoder) unmarshalKeyValue(x target, node ast.Node) error {
assertNode(ast.KeyValue, node) assertNode(ast.KeyValue, node)
d.strict.EnterKeyValue(node)
defer d.strict.ExitKeyValue(node)
x, found, err := d.scopeWithKey(x, node.Key()) x, found, err := d.scopeWithKey(x, node.Key())
if err != nil { if err != nil {
return err return err
@@ -224,6 +270,7 @@ func (d *decoder) unmarshalKeyValue(x target, node ast.Node) error {
// A struct in the path was not found. Skip this value. // A struct in the path was not found. Skip this value.
if !found { if !found {
d.strict.MissingField(node)
return nil return nil
} }
+183 -4
View File
@@ -1,8 +1,10 @@
package toml_test package toml_test
import ( import (
"fmt"
"math" "math"
"strconv" "strconv"
"strings"
"testing" "testing"
"time" "time"
@@ -132,6 +134,7 @@ func TestUnmarshal_Floats(t *testing.T) {
desc: "nan", desc: "nan",
input: `nan`, input: `nan`,
testFn: func(t *testing.T, v float64) { testFn: func(t *testing.T, v float64) {
t.Helper()
assert.True(t, math.IsNaN(v)) assert.True(t, math.IsNaN(v))
}, },
}, },
@@ -139,6 +142,7 @@ func TestUnmarshal_Floats(t *testing.T) {
desc: "nan negative", desc: "nan negative",
input: `-nan`, input: `-nan`,
testFn: func(t *testing.T, v float64) { testFn: func(t *testing.T, v float64) {
t.Helper()
assert.True(t, math.IsNaN(v)) assert.True(t, math.IsNaN(v))
}, },
}, },
@@ -146,6 +150,7 @@ func TestUnmarshal_Floats(t *testing.T) {
desc: "nan positive", desc: "nan positive",
input: `+nan`, input: `+nan`,
testFn: func(t *testing.T, v float64) { testFn: func(t *testing.T, v float64) {
t.Helper()
assert.True(t, math.IsNaN(v)) assert.True(t, math.IsNaN(v))
}, },
}, },
@@ -706,6 +711,42 @@ B = "data"`,
} }
}, },
}, },
{
desc: "windows line endings",
input: "A = 1\r\n\r\nB = 2",
gen: func() test {
doc := map[string]interface{}{}
return test{
target: &doc,
expected: &map[string]interface{}{
"A": int64(1),
"B": int64(2),
},
}
},
},
{
desc: "dangling CR",
input: "A = 1\r",
gen: func() test {
doc := map[string]interface{}{}
return test{
target: &doc,
err: true,
}
},
},
{
desc: "missing NL after CR",
input: "A = 1\rB = 2",
gen: func() test {
doc := map[string]interface{}{}
return test{
target: &doc,
err: true,
}
},
},
} }
for _, e := range examples { for _, e := range examples {
@@ -759,8 +800,10 @@ func TestIssue484(t *testing.T) {
}, cfg) }, cfg)
} }
type Map458 map[string]interface{} type (
type Slice458 []interface{} Map458 map[string]interface{}
Slice458 []interface{}
)
func (m Map458) A(s string) Slice458 { func (m Map458) A(s string) Slice458 {
return m[s].([]interface{}) return m[s].([]interface{})
@@ -779,7 +822,8 @@ version = "0.1.0"`)
map[string]interface{}{ map[string]interface{}{
"dependencies": []interface{}{"regex"}, "dependencies": []interface{}{"regex"},
"name": "decode", "name": "decode",
"version": "0.1.0"}, "version": "0.1.0",
},
} }
assert.Equal(t, expected, a) assert.Equal(t, expected, a)
} }
@@ -790,7 +834,7 @@ func TestIssue252(t *testing.T) {
Val2 string `toml:"val2"` Val2 string `toml:"val2"`
} }
var configFile = []byte( configFile := []byte(
` `
val1 = "test1" val1 = "test1"
`) `)
@@ -818,6 +862,13 @@ bar = 2021-04-08
require.NoError(t, err) require.NoError(t, err)
} }
func TestIssue507(t *testing.T) {
data := []byte{'0', '=', '\n', '0', 'a', 'm', 'e'}
m := map[string]interface{}{}
err := toml.Unmarshal(data, &m)
require.Error(t, err)
}
func TestUnmarshalDecodeErrors(t *testing.T) { func TestUnmarshalDecodeErrors(t *testing.T) {
examples := []struct { examples := []struct {
desc string desc string
@@ -924,3 +975,131 @@ func TestIssue287(t *testing.T) {
} }
require.Equal(t, expected, v) require.Equal(t, expected, v)
} }
func TestIssue508(t *testing.T) {
type head struct {
Title string `toml:"title"`
}
type text struct {
head
}
b := []byte(`title = "This is a title"`)
t1 := text{}
err := toml.Unmarshal(b, &t1)
require.NoError(t, err)
require.Equal(t, "This is a title", t1.head.Title)
}
func TestDecoderStrict(t *testing.T) {
examples := []struct {
desc string
input string
expected string
target interface{}
}{
{
desc: "multiple missing root keys",
input: `
key1 = "value1"
key2 = "missing2"
key3 = "missing3"
key4 = "value4"
`,
expected: `
2| key1 = "value1"
3| key2 = "missing2"
| ~~~~ missing field
4| key3 = "missing3"
5| key4 = "value4"
---
2| key1 = "value1"
3| key2 = "missing2"
4| key3 = "missing3"
| ~~~~ missing field
5| key4 = "value4"
`,
target: &struct {
Key1 string
Key4 string
}{},
},
{
desc: "multi-part key",
input: `a.short.key="foo"`,
expected: `
1| a.short.key="foo"
| ~~~~~~~~~~~ missing field
`,
},
{
desc: "missing table",
input: `
[foo]
bar = 42
`,
expected: `
2| [foo]
| ~~~ missing table
3| bar = 42
`,
},
{
desc: "missing array table",
input: `
[[foo]]
bar = 42
`,
expected: `
2| [[foo]]
| ~~~ missing table
3| bar = 42
`,
},
}
for _, e := range examples {
t.Run(e.desc, func(t *testing.T) {
r := strings.NewReader(e.input)
d := toml.NewDecoder(r)
d.SetStrict(true)
x := e.target
if x == nil {
x = &struct{}{}
}
err := d.Decode(x)
details := err.(*toml.StrictMissingError)
equalStringsIgnoreNewlines(t, e.expected, details.String())
})
}
}
func ExampleDecoder_SetStrict() {
type S struct {
Key1 string
Key3 string
}
doc := `
key1 = "value1"
key2 = "value2"
key3 = "value3"
`
r := strings.NewReader(doc)
d := toml.NewDecoder(r)
d.SetStrict(true)
s := S{}
err := d.Decode(&s)
fmt.Println(err.Error())
// Output: strict mode: fields in the document are missing in the target struct
details := err.(*toml.StrictMissingError)
fmt.Println(details.String())
// Ouput:
// 2| key1 = "value1"
// 3| key2 = "value2"
// | ~~~~ missing field
// 4| key3 = "value3"
}