v2: errors (#534)

```
name                              old time/op    new time/op    delta
UnmarshalDataset/config-32          86.7ms ± 2%    87.5ms ± 2%     ~     (p=0.113 n=9+10)
UnmarshalDataset/canada-32           129ms ± 4%     106ms ± 3%  -17.94%  (p=0.000 n=10+10)
UnmarshalDataset/citm_catalog-32    59.4ms ± 5%    58.7ms ± 5%     ~     (p=0.393 n=10+10)
UnmarshalDataset/twitter-32         27.0ms ± 7%    26.9ms ± 6%     ~     (p=0.720 n=10+9)
UnmarshalDataset/code-32             326ms ± 4%     322ms ± 7%     ~     (p=0.661 n=9+10)
UnmarshalDataset/example-32          510µs ±11%     526µs ± 7%     ~     (p=0.182 n=10+9)
UnmarshalSimple-32                  1.41µs ± 6%    1.41µs ± 4%     ~     (p=0.736 n=10+9)
ReferenceFile-32                    45.6µs ± 3%    43.9µs ±10%     ~     (p=0.089 n=10+10)

name                              old speed      new speed      delta
UnmarshalDataset/config-32        12.1MB/s ± 2%  12.0MB/s ± 2%     ~     (p=0.108 n=9+10)
UnmarshalDataset/canada-32        17.1MB/s ± 4%  20.9MB/s ± 3%  +21.86%  (p=0.000 n=10+10)
UnmarshalDataset/citm_catalog-32  9.41MB/s ± 5%  9.51MB/s ± 5%     ~     (p=0.362 n=10+10)
UnmarshalDataset/twitter-32       16.4MB/s ± 8%  16.5MB/s ± 6%     ~     (p=0.704 n=10+9)
UnmarshalDataset/code-32          8.24MB/s ± 4%  8.34MB/s ± 7%     ~     (p=0.675 n=9+10)
UnmarshalDataset/example-32       15.9MB/s ±11%  15.4MB/s ± 7%     ~     (p=0.182 n=10+9)
ReferenceFile-32                   115MB/s ± 4%   120MB/s ±10%     ~     (p=0.085 n=10+10)

name                              old alloc/op   new alloc/op   delta
UnmarshalDataset/config-32          16.9MB ± 0%    16.9MB ± 0%   -0.02%  (p=0.000 n=10+10)
UnmarshalDataset/canada-32          76.8MB ± 0%    74.3MB ± 0%   -3.31%  (p=0.000 n=10+10)
UnmarshalDataset/citm_catalog-32    37.3MB ± 0%    37.1MB ± 0%   -0.60%  (p=0.000 n=9+10)
UnmarshalDataset/twitter-32         15.6MB ± 0%    15.6MB ± 0%   -0.09%  (p=0.000 n=10+10)
UnmarshalDataset/code-32            60.2MB ± 0%    59.3MB ± 0%   -1.51%  (p=0.000 n=10+9)
UnmarshalDataset/example-32          238kB ± 0%     238kB ± 0%   -0.18%  (p=0.000 n=10+10)
ReferenceFile-32                    11.8kB ± 0%    11.8kB ± 0%     ~     (all equal)

name                              old allocs/op  new allocs/op  delta
UnmarshalDataset/config-32            653k ± 0%      645k ± 0%   -1.20%  (p=0.000 n=10+6)
UnmarshalDataset/canada-32           1.01M ± 0%     0.90M ± 0%  -11.04%  (p=0.000 n=9+10)
UnmarshalDataset/citm_catalog-32      384k ± 0%      370k ± 0%   -3.75%  (p=0.000 n=10+10)
UnmarshalDataset/twitter-32           160k ± 0%      157k ± 0%   -1.32%  (p=0.000 n=10+10)
UnmarshalDataset/code-32             2.97M ± 0%     2.91M ± 0%   -2.15%  (p=0.000 n=10+7)
UnmarshalDataset/example-32          3.69k ± 0%     3.63k ± 0%   -1.52%  (p=0.000 n=10+10)
ReferenceFile-32                       253 ± 0%       253 ± 0%     ~     (all equal)
```
This commit is contained in:
Thomas Pelletier
2021-05-08 16:04:25 -04:00
committed by GitHub
parent 4545a3e94b
commit ea225df3ed
12 changed files with 325 additions and 656 deletions
+66 -111
View File
@@ -1,11 +1,8 @@
package toml
import (
"errors"
"fmt"
"math"
"strconv"
"strings"
"time"
)
@@ -59,14 +56,12 @@ func parseLocalDate(b []byte) (LocalDate, error) {
return date, nil
}
var errNotDigit = errors.New("not a digit")
func parseDecimalDigits(b []byte) (int, error) {
v := 0
for _, c := range b {
for i, c := range b {
if !isDigit(c) {
return 0, fmt.Errorf("%s: %w", b, errNotDigit)
return 0, newDecodeError(b[i:i+1], "should be a digit (0-9)")
}
v *= 10
@@ -76,13 +71,14 @@ func parseDecimalDigits(b []byte) (int, error) {
return v, nil
}
var errParseDateTimeMissingInfo = errors.New("date-time missing timezone information")
func parseDateTime(b []byte) (time.Time, error) {
// offset-date-time = full-date time-delim full-time
// full-time = partial-time time-offset
// time-offset = "Z" / time-numoffset
// time-numoffset = ( "+" / "-" ) time-hour ":" time-minute
originalBytes := b
dt, b, err := parseLocalDateTime(b)
if err != nil {
return time.Time{}, err
@@ -91,7 +87,7 @@ func parseDateTime(b []byte) (time.Time, error) {
var zone *time.Location
if len(b) == 0 {
return time.Time{}, errParseDateTimeMissingInfo
return time.Time{}, newDecodeError(originalBytes, "date-time is missing timezone")
}
if b[0] == 'Z' {
@@ -134,19 +130,12 @@ func parseDateTime(b []byte) (time.Time, error) {
return t, nil
}
var (
errParseLocalDateTimeWrongLength = errors.New(
"local datetimes are expected to have the format YYYY-MM-DDTHH:MM:SS[.NNNNNNNNN]",
)
errParseLocalDateTimeWrongSeparator = errors.New("datetime separator is expected to be T or a space")
)
func parseLocalDateTime(b []byte) (LocalDateTime, []byte, error) {
var dt LocalDateTime
const localDateTimeByteMinLen = 11
if len(b) < localDateTimeByteMinLen {
return dt, nil, errParseLocalDateTimeWrongLength
return dt, nil, newDecodeError(b, "local datetimes are expected to have the format YYYY-MM-DDTHH:MM:SS[.NNNNNNNNN]")
}
date, err := parseLocalDate(b[:10])
@@ -157,7 +146,7 @@ func parseLocalDateTime(b []byte) (LocalDateTime, []byte, error) {
sep := b[10]
if sep != 'T' && sep != ' ' {
return dt, nil, errParseLocalDateTimeWrongSeparator
return dt, nil, newDecodeError(b[10:11], "datetime separator is expected to be T or a space")
}
t, rest, err := parseLocalTime(b[11:])
@@ -169,8 +158,6 @@ func parseLocalDateTime(b []byte) (LocalDateTime, []byte, error) {
return dt, rest, nil
}
var errParseLocalTimeWrongLength = errors.New("times are expected to have the format HH:MM:SS[.NNNNNN]")
// parseLocalTime is a bit different because it also returns the remaining
// []byte that is didn't need. This is to allow parseDateTime to parse those
// remaining bytes as a timezone.
@@ -183,7 +170,7 @@ func parseLocalTime(b []byte) (LocalTime, []byte, error) {
const localTimeByteLen = 8
if len(b) < localTimeByteLen {
return t, nil, errParseLocalTimeWrongLength
return t, nil, newDecodeError(b, "times are expected to have the format HH:MM:SS[.NNNNNN]")
}
var err error
@@ -242,11 +229,6 @@ func parseLocalTime(b []byte) (LocalTime, []byte, error) {
return t, b[8:], nil
}
var (
errParseFloatStartDot = errors.New("float cannot start with a dot")
errParseFloatEndDot = errors.New("float cannot end with a dot")
)
//nolint:cyclop
func parseFloat(b []byte) (float64, error) {
//nolint:godox
@@ -255,150 +237,123 @@ func parseFloat(b []byte) (float64, error) {
return math.NaN(), nil
}
tok := string(b)
err := numberContainsInvalidUnderscore(tok)
cleaned, err := checkAndRemoveUnderscores(b)
if err != nil {
return 0, err
}
cleanedVal := cleanupNumberToken(tok)
if cleanedVal[0] == '.' {
return 0, errParseFloatStartDot
if cleaned[0] == '.' {
return 0, newDecodeError(b, "float cannot start with a dot")
}
if cleanedVal[len(cleanedVal)-1] == '.' {
return 0, errParseFloatEndDot
if cleaned[len(cleaned)-1] == '.' {
return 0, newDecodeError(b, "float cannot end with a dot")
}
f, err := strconv.ParseFloat(cleanedVal, 64)
f, err := strconv.ParseFloat(string(cleaned), 64)
if err != nil {
return 0, fmt.Errorf("coudn't ParseFloat %w", err)
return 0, newDecodeError(b, "coudn't parse float: %w", err)
}
return f, nil
}
func parseIntHex(b []byte) (int64, error) {
tok := string(b)
cleanedVal := cleanupNumberToken(tok)
err := hexNumberContainsInvalidUnderscore(cleanedVal)
cleaned, err := checkAndRemoveUnderscores(b[2:])
if err != nil {
return 0, err
}
i, err := strconv.ParseInt(cleanedVal[2:], 16, 64)
i, err := strconv.ParseInt(string(cleaned), 16, 64)
if err != nil {
return 0, fmt.Errorf("coudn't ParseIntHex %w", err)
return 0, newDecodeError(b, "couldn't parse hexadecimal number: %w", err)
}
return i, nil
}
func parseIntOct(b []byte) (int64, error) {
tok := string(b)
cleanedVal := cleanupNumberToken(tok)
err := numberContainsInvalidUnderscore(cleanedVal)
cleaned, err := checkAndRemoveUnderscores(b[2:])
if err != nil {
return 0, err
}
i, err := strconv.ParseInt(cleanedVal[2:], 8, 64)
i, err := strconv.ParseInt(string(cleaned), 8, 64)
if err != nil {
return 0, fmt.Errorf("coudn't ParseIntOct %w", err)
return 0, newDecodeError(b, "couldn't parse octal number: %w", err)
}
return i, nil
}
func parseIntBin(b []byte) (int64, error) {
tok := string(b)
cleanedVal := cleanupNumberToken(tok)
err := numberContainsInvalidUnderscore(cleanedVal)
cleaned, err := checkAndRemoveUnderscores(b[2:])
if err != nil {
return 0, err
}
i, err := strconv.ParseInt(cleanedVal[2:], 2, 64)
i, err := strconv.ParseInt(string(cleaned), 2, 64)
if err != nil {
return 0, fmt.Errorf("coudn't ParseIntBin %w", err)
return 0, newDecodeError(b, "couldn't parse binary number: %w", err)
}
return i, nil
}
func parseIntDec(b []byte) (int64, error) {
tok := string(b)
cleanedVal := cleanupNumberToken(tok)
err := numberContainsInvalidUnderscore(cleanedVal)
cleaned, err := checkAndRemoveUnderscores(b)
if err != nil {
return 0, err
}
i, err := strconv.ParseInt(cleanedVal, 10, 64)
i, err := strconv.ParseInt(string(cleaned), 10, 64)
if err != nil {
return 0, fmt.Errorf("coudn't parseIntDec %w", err)
return 0, newDecodeError(b, "couldn't parse decimal number: %w", err)
}
return i, nil
}
func numberContainsInvalidUnderscore(value string) error {
// For large numbers, you may use underscores between digits to enhance
// readability. Each underscore must be surrounded by at least one digit on
// each side.
hasBefore := false
for idx, r := range value {
if r == '_' {
if !hasBefore || idx+1 >= len(value) {
// can't end with an underscore
return errInvalidUnderscore
}
}
hasBefore = isDigitRune(r)
func checkAndRemoveUnderscores(b []byte) ([]byte, error) {
if len(b) == 0 {
return b, nil
}
return nil
}
func hexNumberContainsInvalidUnderscore(value string) error {
hasBefore := false
for idx, r := range value {
if r == '_' {
if !hasBefore || idx+1 >= len(value) {
// can't end with an underscore
return errInvalidUnderscoreHex
}
}
hasBefore = isHexDigit(r)
if b[0] == '_' {
return nil, newDecodeError(b[0:1], "number cannot start with underscore")
}
return nil
if b[len(b)-1] == '_' {
return nil, newDecodeError(b[len(b)-1:], "number cannot end with underscore")
}
// fast path
i := 0
for ; i < len(b); i++ {
if b[i] == '_' {
break
}
}
if i == len(b) {
return b, nil
}
before := false
cleaned := make([]byte, i, len(b))
copy(cleaned, b)
for i++; i < len(b); i++ {
c := b[i]
if c == '_' {
if !before {
return nil, newDecodeError(b[i-1:i+1], "number must have at least one digit between underscores")
}
before = false
} else {
before = true
cleaned = append(cleaned, c)
}
}
return cleaned, nil
}
func cleanupNumberToken(value string) string {
cleanedVal := strings.ReplaceAll(value, "_", "")
return cleanedVal
}
func isHexDigit(r rune) bool {
return isDigitRune(r) ||
(r >= 'a' && r <= 'f') ||
(r >= 'A' && r <= 'F')
}
func isDigitRune(r rune) bool {
return r >= '0' && r <= '9'
}
var (
errInvalidUnderscore = errors.New("invalid use of _ in number")
errInvalidUnderscoreHex = errors.New("invalid use of _ in hex number")
)
+2 -2
View File
@@ -70,13 +70,13 @@ func (de *decodeError) Error() string {
func newDecodeError(highlight []byte, format string, args ...interface{}) error {
return &decodeError{
highlight: highlight,
message: fmt.Sprintf(format, args...),
message: fmt.Errorf(format, args...).Error(),
}
}
// Error returns the error message contained in the DecodeError.
func (e *DecodeError) Error() string {
return e.message
return "toml: " + e.message
}
// String returns the human-readable contextualized error. This string is multi-line.
+4 -4
View File
@@ -123,10 +123,10 @@ func (s *SeenTracker) checkTable(node ast.Node) error {
i, found := s.current.Has(k)
if found {
if i.kind != tableKind {
return fmt.Errorf("key %s should be a table", k)
return fmt.Errorf("toml: key %s should be a table, not a %s", k, i.kind)
}
if i.explicit {
return fmt.Errorf("table %s already exists", k)
return fmt.Errorf("toml: table %s already exists", k)
}
i.explicit = true
s.current = i
@@ -162,7 +162,7 @@ func (s *SeenTracker) checkArrayTable(node ast.Node) error {
info, found := s.current.Has(k)
if found {
if info.kind != arrayTableKind {
return fmt.Errorf("key %s already exists but is not an array table", k)
return fmt.Errorf("toml: key %s already exists as a %s, but should be an array table", info.kind, k)
}
info.Clear()
} else {
@@ -182,7 +182,7 @@ func (s *SeenTracker) checkKeyValue(context *info, node ast.Node) error {
child, found := context.Has(k)
if found {
if child.kind != tableKind {
return fmt.Errorf("expected %s to be a table, not a %s", k, child.kind)
return fmt.Errorf("toml: expected %s to be a table, not a %s", k, child.kind)
}
} else {
child = context.CreateTable(k, false)
+3 -3
View File
@@ -53,7 +53,7 @@ func LocalDateOf(t time.Time) LocalDate {
func ParseLocalDate(s string) (LocalDate, error) {
t, err := time.Parse("2006-01-02", s)
if err != nil {
return LocalDate{}, fmt.Errorf("ParseLocalDate: %w", err)
return LocalDate{}, err
}
return LocalDateOf(t), nil
@@ -166,7 +166,7 @@ func LocalTimeOf(t time.Time) LocalTime {
func ParseLocalTime(s string) (LocalTime, error) {
t, err := time.Parse("15:04:05.999999999", s)
if err != nil {
return LocalTime{}, fmt.Errorf("ParseLocalTime: %w", err)
return LocalTime{}, err
}
return LocalTimeOf(t), nil
@@ -237,7 +237,7 @@ func ParseLocalDateTime(s string) (LocalDateTime, error) {
if err != nil {
t, err = time.Parse("2006-01-02t15:04:05.999999999", s)
if err != nil {
return LocalDateTime{}, fmt.Errorf("ParseLocalDateTime: %w", err)
return LocalDateTime{}, err
}
}
+11 -24
View File
@@ -3,7 +3,6 @@ package toml
import (
"bytes"
"encoding"
"errors"
"fmt"
"io"
"reflect"
@@ -116,12 +115,12 @@ func (enc *Encoder) Encode(v interface{}) error {
b, err := enc.encode(b, ctx, reflect.ValueOf(v))
if err != nil {
return fmt.Errorf("Encode: %w", err)
return err
}
_, err = enc.w.Write(b)
if err != nil {
return fmt.Errorf("Encode: %w", err)
return fmt.Errorf("toml: cannot write: %w", err)
}
return nil
@@ -178,11 +177,6 @@ func (ctx *encoderCtx) isRoot() bool {
return len(ctx.parentKey) == 0 && !ctx.hasKey
}
var (
errUnsupportedValue = errors.New("unsupported encode value kind")
errTextMarshalerCannotBeAtRoot = errors.New("type implementing TextMarshaler cannot be at root")
)
//nolint:cyclop,funlen
func (enc *Encoder) encode(b []byte, ctx encoderCtx, v reflect.Value) ([]byte, error) {
i, ok := v.Interface().(time.Time)
@@ -192,12 +186,12 @@ func (enc *Encoder) encode(b []byte, ctx encoderCtx, v reflect.Value) ([]byte, e
if v.Type().Implements(textMarshalerType) {
if ctx.isRoot() {
return nil, errTextMarshalerCannotBeAtRoot
return nil, fmt.Errorf("toml: type %s implementing the TextMarshaler interface cannot be a root element", v.Type())
}
text, err := v.Interface().(encoding.TextMarshaler).MarshalText()
if err != nil {
return nil, fmt.Errorf("encode: %w", err)
return nil, err
}
b = enc.encodeString(b, string(text), ctx.options)
@@ -215,7 +209,7 @@ func (enc *Encoder) encode(b []byte, ctx encoderCtx, v reflect.Value) ([]byte, e
return enc.encodeSlice(b, ctx, v)
case reflect.Interface:
if v.IsNil() {
return nil, errNilInterface
return nil, fmt.Errorf("toml: encoding a nil interface is not supported")
}
return enc.encode(b, ctx, v.Elem())
@@ -244,7 +238,7 @@ func (enc *Encoder) encode(b []byte, ctx encoderCtx, v reflect.Value) ([]byte, e
case reflect.Int64, reflect.Int32, reflect.Int16, reflect.Int8, reflect.Int:
b = strconv.AppendInt(b, v.Int(), 10)
default:
return nil, fmt.Errorf("encode(type %s): %w", v.Kind(), errUnsupportedValue)
return nil, fmt.Errorf("toml: cannot encode value of type %s", v.Kind())
}
return b, nil
@@ -412,8 +406,6 @@ func (enc *Encoder) encodeTableHeader(ctx encoderCtx, b []byte) ([]byte, error)
return b, nil
}
var errTomlNoMultiline = errors.New("TOML does not support multiline keys")
//nolint:cyclop
func (enc *Encoder) encodeKey(b []byte, k string) ([]byte, error) {
needsQuotation := false
@@ -425,7 +417,7 @@ func (enc *Encoder) encodeKey(b []byte, k string) ([]byte, error) {
}
if c == '\n' {
return nil, errTomlNoMultiline
return nil, fmt.Errorf("toml: new line characters in keys are not supported")
}
if c == literalQuote {
@@ -445,11 +437,9 @@ func (enc *Encoder) encodeKey(b []byte, k string) ([]byte, error) {
}
}
var errNotSupportedAsMapKey = errors.New("type not supported as map key")
func (enc *Encoder) encodeMap(b []byte, ctx encoderCtx, v reflect.Value) ([]byte, error) {
if v.Type().Key().Kind() != reflect.String {
return nil, fmt.Errorf("encodeMap '%s': %w", v.Type().Key().Kind(), errNotSupportedAsMapKey)
return nil, fmt.Errorf("toml: type %s is not supported as a map key", v.Type().Key().Kind())
}
var (
@@ -658,10 +648,7 @@ func (enc *Encoder) encodeTableInline(b []byte, ctx encoderCtx, t table) ([]byte
return b, nil
}
var (
errNilInterface = errors.New("nil interface not supported")
textMarshalerType = reflect.TypeOf(new(encoding.TextMarshaler)).Elem()
)
var textMarshalerType = reflect.TypeOf(new(encoding.TextMarshaler)).Elem()
func willConvertToTable(ctx encoderCtx, v reflect.Value) (bool, error) {
if v.Type() == timeType || v.Type().Implements(textMarshalerType) {
@@ -674,7 +661,7 @@ func willConvertToTable(ctx encoderCtx, v reflect.Value) (bool, error) {
return !ctx.inline, nil
case reflect.Interface:
if v.IsNil() {
return false, errNilInterface
return false, fmt.Errorf("toml: encoding a nil interface is not supported")
}
return willConvertToTable(ctx, v.Elem())
@@ -694,7 +681,7 @@ func willConvertToTableOrArrayTable(ctx encoderCtx, v reflect.Value) (bool, erro
if t.Kind() == reflect.Interface {
if v.IsNil() {
return false, errNilInterface
return false, fmt.Errorf("toml: encoding a nil interface is not supported")
}
return willConvertToTableOrArrayTable(ctx, v.Elem())
+23 -41
View File
@@ -2,7 +2,6 @@ package toml
import (
"bytes"
"errors"
"fmt"
"strconv"
@@ -225,19 +224,13 @@ func (p *parser) parseKeyval(b []byte) (ast.Reference, []byte, error) {
return ref, b, err
}
var (
errExpectedValNotEOF = errors.New("expected value, not eof")
errExpectedTrue = errors.New("expected 'true'")
errExpectedFalse = errors.New("expected 'false'")
)
//nolint:cyclop,funlen
func (p *parser) parseVal(b []byte) (ast.Reference, []byte, error) {
// val = string / boolean / array / inline-table / date-time / float / integer
var ref ast.Reference
if len(b) == 0 {
return ref, nil, errExpectedValNotEOF
return ref, nil, newDecodeError(b, "expected value, not eof")
}
var err error
@@ -278,7 +271,7 @@ func (p *parser) parseVal(b []byte) (ast.Reference, []byte, error) {
return ref, b, err
case 't':
if !scanFollowsTrue(b) {
return ref, nil, errExpectedTrue
return ref, nil, newDecodeError(atmost(b, 4), "expected 'true'")
}
ref = p.builder.Push(ast.Node{
@@ -289,7 +282,7 @@ func (p *parser) parseVal(b []byte) (ast.Reference, []byte, error) {
return ref, b[4:], nil
case 'f':
if !scanFollowsFalse(b) {
return ast.Reference{}, nil, errExpectedFalse
return ref, nil, newDecodeError(atmost(b, 5), "expected 'false'")
}
ref = p.builder.Push(ast.Node{
@@ -307,6 +300,13 @@ func (p *parser) parseVal(b []byte) (ast.Reference, []byte, error) {
}
}
func atmost(b []byte, n int) []byte {
if n >= len(b) {
return b
}
return b[:n]
}
func (p *parser) parseLiteralString(b []byte) ([]byte, []byte, error) {
v, rest, err := scanLiteralString(b)
if err != nil {
@@ -370,8 +370,6 @@ func (p *parser) parseInlineTable(b []byte) (ast.Reference, []byte, error) {
return parent, rest, err
}
var errArrayCannotStartWithComma = errors.New("array cannot start with comma")
//nolint:funlen,cyclop
func (p *parser) parseValArray(b []byte) (ast.Reference, []byte, error) {
// array = array-open [ array-values ] ws-comment-newline array-close
@@ -409,7 +407,7 @@ func (p *parser) parseValArray(b []byte) (ast.Reference, []byte, error) {
if b[0] == ',' {
if first {
return parent, nil, errArrayCannotStartWithComma
return parent, nil, newDecodeError(b[0:1], "array cannot start with comma")
}
b = b[1:]
@@ -494,8 +492,6 @@ func (p *parser) parseMultilineLiteralString(b []byte) ([]byte, []byte, error) {
return token[i : len(token)-3], rest, err
}
var errInvalidEscapeChar = errors.New("invalid escaped character")
//nolint:funlen,gocognit,cyclop
func (p *parser) parseMultilineBasicString(b []byte) ([]byte, []byte, error) {
// ml-basic-string = ml-basic-string-delim [ newline ] ml-basic-body
@@ -582,7 +578,7 @@ func (p *parser) parseMultilineBasicString(b []byte) ([]byte, []byte, error) {
builder.WriteString(x)
i += 8
default:
return nil, nil, fmt.Errorf("parseMultilineBasicString: %w - %#U", errInvalidEscapeChar, c)
return nil, nil, newDecodeError(token[i:i+1], "invalid escaped character %#U", c)
}
} else {
builder.WriteByte(c)
@@ -721,7 +717,7 @@ func (p *parser) parseBasicString(b []byte) ([]byte, []byte, error) {
builder.WriteString(x)
i += 8
default:
return nil, nil, fmt.Errorf("parseBasicString: %w - %#U", errInvalidEscapeChar, c)
return nil, nil, newDecodeError(token[i:i+1], "invalid escaped character %#U", c)
}
} else {
builder.WriteByte(c)
@@ -731,18 +727,17 @@ func (p *parser) parseBasicString(b []byte) ([]byte, []byte, error) {
return builder.Bytes(), rest, nil
}
var errUnicodePointNeedsRightCountChar = errors.New("unicode point needs right number of hex characters")
func hexToString(b []byte, length int) (string, error) {
if len(b) < length {
return "", fmt.Errorf("hexToString: %w - %d", errUnicodePointNeedsRightCountChar, length)
return "", newDecodeError(b, "unicode point needs %d character, not %d", length, len(b))
}
b = b[:length]
//nolint:godox
// TODO: slow
intcode, err := strconv.ParseInt(string(b[:length]), 16, 32)
intcode, err := strconv.ParseInt(string(b), 16, 32)
if err != nil {
return "", fmt.Errorf("hexToString: %w", err)
return "", newDecodeError(b, "couldn't parse hexadecimal number: %w", err)
}
return string(rune(intcode)), nil
@@ -757,17 +752,12 @@ func (p *parser) parseWhitespace(b []byte) []byte {
return rest
}
var (
errExpectedInf = errors.New("expected 'inf'")
errExpectedNan = errors.New("expected 'nan'")
)
//nolint:cyclop
func (p *parser) parseIntOrFloatOrDateTime(b []byte) (ast.Reference, []byte, error) {
switch b[0] {
case 'i':
if !scanFollowsInf(b) {
return ast.Reference{}, nil, errExpectedInf
return ast.Reference{}, nil, newDecodeError(atmost(b, 3), "expected 'inf'")
}
return p.builder.Push(ast.Node{
@@ -776,7 +766,7 @@ func (p *parser) parseIntOrFloatOrDateTime(b []byte) (ast.Reference, []byte, err
}), b[3:], nil
case 'n':
if !scanFollowsNan(b) {
return ast.Reference{}, nil, errExpectedNan
return ast.Reference{}, nil, newDecodeError(atmost(b, 3), "expected 'nan'")
}
return p.builder.Push(ast.Node{
@@ -821,8 +811,6 @@ func digitsToInt(b []byte) int {
return x
}
var errTimezoneButNoTimeComponent = errors.New("possible DateTime cannot have a timezone but no time component")
//nolint:gocognit,cyclop
func (p *parser) scanDateTime(b []byte) (ast.Reference, []byte, error) {
// scans for contiguous characters in [0-9T:Z.+-], and up to one space if
@@ -867,7 +855,7 @@ byteLoop:
}
} else {
if hasTz {
return ast.Reference{}, nil, errTimezoneButNoTimeComponent
return ast.Reference{}, nil, newDecodeError(b, "date-time has timezone but not time component")
}
kind = ast.LocalDate
}
@@ -878,12 +866,6 @@ byteLoop:
}), b[i:], nil
}
var (
errUnexpectedCharI = fmt.Errorf("unexpected character i while scanning for a number")
errUnexpectedCharN = fmt.Errorf("unexpected character n while scanning for a number")
errExpectedIntOrFloat = fmt.Errorf("expected integer or float")
)
//nolint:funlen,gocognit,cyclop
func (p *parser) scanIntOrFloat(b []byte) (ast.Reference, []byte, error) {
i := 0
@@ -940,7 +922,7 @@ func (p *parser) scanIntOrFloat(b []byte) (ast.Reference, []byte, error) {
}), b[i+3:], nil
}
return ast.Reference{}, nil, errUnexpectedCharI
return ast.Reference{}, nil, newDecodeError(b[i:i+1], "unexpected character 'i' while scanning for a number")
}
if c == 'n' {
@@ -951,14 +933,14 @@ func (p *parser) scanIntOrFloat(b []byte) (ast.Reference, []byte, error) {
}), b[i+3:], nil
}
return ast.Reference{}, nil, errUnexpectedCharN
return ast.Reference{}, nil, newDecodeError(b[i:i+1], "unexpected character 'n' while scanning for a number")
}
break
}
if i == 0 {
return ast.Reference{}, b, errExpectedIntOrFloat
return ast.Reference{}, b, newDecodeError(b, "incomplete number")
}
kind := ast.Integer
+6 -23
View File
@@ -1,9 +1,5 @@
package toml
import (
"errors"
)
func scanFollows(b []byte, pattern string) bool {
n := len(pattern)
@@ -83,22 +79,17 @@ func scanMultilineLiteralString(b []byte) ([]byte, []byte, error) {
return nil, nil, newDecodeError(b[len(b):], `multiline literal string not terminated by '''`)
}
var (
errWindowsNewLineMissing = errors.New(`windows new line missing \n`)
errWindowsNewLineCRLF = errors.New(`windows new line should be \r\n`)
)
func scanWindowsNewline(b []byte) ([]byte, []byte, error) {
const lenLF = 2
if len(b) < lenLF {
return nil, nil, errWindowsNewLineMissing
const lenCRLF = 2
if len(b) < lenCRLF {
return nil, nil, newDecodeError(b, "windows new line expected")
}
if b[1] != '\n' {
return nil, nil, errWindowsNewLineCRLF
return nil, nil, newDecodeError(b, `windows new line should be \r\n`)
}
return b[:lenLF], b[lenLF:], nil
return b[:lenCRLF], b[lenCRLF:], nil
}
func scanWhitespace(b []byte) ([]byte, []byte) {
@@ -116,8 +107,6 @@ func scanWhitespace(b []byte) ([]byte, []byte) {
//nolint:unparam
func scanComment(b []byte) ([]byte, []byte) {
// ;; Comment
//
// comment-start-symbol = %x23 ; #
// non-ascii = %x80-D7FF / %xE000-10FFFF
// non-eol = %x09 / %x20-7F / non-ascii
@@ -132,10 +121,6 @@ func scanComment(b []byte) ([]byte, []byte) {
return b, nil
}
var errBasicLineNotTerminatedByQuote = errors.New(`basic string not terminated by "`)
//nolint:godox
// TODO perform validation on the string?
func scanBasicString(b []byte) ([]byte, []byte, error) {
// basic-string = quotation-mark *basic-char quotation-mark
// quotation-mark = %x22 ; "
@@ -156,11 +141,9 @@ func scanBasicString(b []byte) ([]byte, []byte, error) {
}
}
return nil, nil, errBasicLineNotTerminatedByQuote
return nil, nil, newDecodeError(b[len(b):], `basic string not terminated by "`)
}
//nolint:godox
// TODO perform validation on the string?
func scanMultilineBasicString(b []byte) ([]byte, []byte, error) {
// ml-basic-string = ml-basic-string-delim [ newline ] ml-basic-body
// ml-basic-string-delim
+111 -358
View File
@@ -1,7 +1,6 @@
package toml
import (
"errors"
"fmt"
"math"
"reflect"
@@ -14,19 +13,19 @@ type target interface {
get() reflect.Value
// Store a string at the target.
setString(v string) error
setString(v string)
// Store a boolean at the target
setBool(v bool) error
setBool(v bool)
// Store an int64 at the target
setInt64(v int64) error
setInt64(v int64)
// Store a float64 at the target
setFloat64(v float64) error
setFloat64(v float64)
// Stores any value at the target
set(v reflect.Value) error
set(v reflect.Value)
}
// valueTarget just contains a reflect.Value that can be set.
@@ -37,34 +36,24 @@ func (t valueTarget) get() reflect.Value {
return reflect.Value(t)
}
func (t valueTarget) set(v reflect.Value) error {
func (t valueTarget) set(v reflect.Value) {
reflect.Value(t).Set(v)
return nil
}
func (t valueTarget) setString(v string) error {
func (t valueTarget) setString(v string) {
t.get().SetString(v)
return nil
}
func (t valueTarget) setBool(v bool) error {
func (t valueTarget) setBool(v bool) {
t.get().SetBool(v)
return nil
}
func (t valueTarget) setInt64(v int64) error {
func (t valueTarget) setInt64(v int64) {
t.get().SetInt(v)
return nil
}
func (t valueTarget) setFloat64(v float64) error {
func (t valueTarget) setFloat64(v float64) {
t.get().SetFloat(v)
return nil
}
// interfaceTarget wraps an other target to dereference on get.
@@ -76,49 +65,24 @@ func (t interfaceTarget) get() reflect.Value {
return t.x.get().Elem()
}
func (t interfaceTarget) set(v reflect.Value) error {
err := t.x.set(v)
if err != nil {
return fmt.Errorf("interfaceTarget set: %w", err)
}
return nil
func (t interfaceTarget) set(v reflect.Value) {
t.x.set(v)
}
func (t interfaceTarget) setString(v string) error {
err := t.x.setString(v)
if err != nil {
return fmt.Errorf("interfaceTarget setString: %w", err)
}
return nil
func (t interfaceTarget) setString(v string) {
t.x.setString(v)
}
func (t interfaceTarget) setBool(v bool) error {
err := t.x.setBool(v)
if err != nil {
return fmt.Errorf("interfaceTarget setBool: %w", err)
}
return nil
func (t interfaceTarget) setBool(v bool) {
t.x.setBool(v)
}
func (t interfaceTarget) setInt64(v int64) error {
err := t.x.setInt64(v)
if err != nil {
return fmt.Errorf("interfaceTarget setInt64: %w", err)
}
return nil
func (t interfaceTarget) setInt64(v int64) {
t.x.setInt64(v)
}
func (t interfaceTarget) setFloat64(v float64) error {
err := t.x.setFloat64(v)
if err != nil {
return fmt.Errorf("interfaceTarget setFloat64: %w", err)
}
return nil
func (t interfaceTarget) setFloat64(v float64) {
t.x.setFloat64(v)
}
// mapTarget targets a specific key of a map.
@@ -131,33 +95,26 @@ func (t mapTarget) get() reflect.Value {
return t.v.MapIndex(t.k)
}
func (t mapTarget) set(v reflect.Value) error {
func (t mapTarget) set(v reflect.Value) {
t.v.SetMapIndex(t.k, v)
return nil
}
func (t mapTarget) setString(v string) error {
return t.set(reflect.ValueOf(v))
func (t mapTarget) setString(v string) {
t.set(reflect.ValueOf(v))
}
func (t mapTarget) setBool(v bool) error {
return t.set(reflect.ValueOf(v))
func (t mapTarget) setBool(v bool) {
t.set(reflect.ValueOf(v))
}
func (t mapTarget) setInt64(v int64) error {
return t.set(reflect.ValueOf(v))
func (t mapTarget) setInt64(v int64) {
t.set(reflect.ValueOf(v))
}
func (t mapTarget) setFloat64(v float64) error {
return t.set(reflect.ValueOf(v))
func (t mapTarget) setFloat64(v float64) {
t.set(reflect.ValueOf(v))
}
var (
errValIndexExpectingSlice = errors.New("expecting a slice")
errValIndexCannotInitSlice = errors.New("cannot initialize a slice")
)
//nolint:cyclop
// makes sure that the value pointed at by t is indexable (Slice, Array), or
// dereferences to an indexable (Ptr, Interface).
@@ -167,43 +124,20 @@ func ensureValueIndexable(t target) error {
switch f.Type().Kind() {
case reflect.Slice:
if f.IsNil() {
err := t.set(reflect.MakeSlice(f.Type(), 0, 0))
if err != nil {
return fmt.Errorf("ensureValueIndexable: %w", err)
}
t.set(reflect.MakeSlice(f.Type(), 0, 0))
return nil
}
case reflect.Interface:
if f.IsNil() || f.Elem().Type() != sliceInterfaceType {
err := t.set(reflect.MakeSlice(sliceInterfaceType, 0, 0))
if err != nil {
return fmt.Errorf("ensureValueIndexable: %w", err)
}
t.set(reflect.MakeSlice(sliceInterfaceType, 0, 0))
return nil
}
if f.Elem().Type().Kind() != reflect.Slice {
return fmt.Errorf("ensureValueIndexable: %w, not a %s", errValIndexExpectingSlice, f.Kind())
}
case reflect.Ptr:
if f.IsNil() {
ptr := reflect.New(f.Type().Elem())
err := t.set(ptr)
if err != nil {
return fmt.Errorf("ensureValueIndexable: %w", err)
}
f = t.get()
}
return ensureValueIndexable(valueTarget(f.Elem()))
panic("pointer should have already been dereferenced")
case reflect.Array:
// arrays are always initialized.
default:
return fmt.Errorf("ensureValueIndexable: %w with %s", errValIndexCannotInitSlice, f.Kind())
return fmt.Errorf("toml: cannot store array in a %s", f.Kind())
}
return nil
@@ -214,69 +148,44 @@ var (
mapStringInterfaceType = reflect.TypeOf(map[string]interface{}{})
)
func ensureMapIfInterface(x target) error {
func ensureMapIfInterface(x target) {
v := x.get()
if v.Kind() == reflect.Interface && v.IsNil() {
newElement := reflect.MakeMap(mapStringInterfaceType)
err := x.set(newElement)
if err != nil {
return fmt.Errorf("ensureMapIfInterface: %w", err)
}
x.set(newElement)
}
return nil
}
var errSetStringCannotAssignString = errors.New("cannot assign string")
func setString(t target, v string) error {
f := t.get()
switch f.Kind() {
case reflect.String:
err := t.setString(v)
if err != nil {
return fmt.Errorf("setString: %w", err)
}
return nil
t.setString(v)
case reflect.Interface:
err := t.set(reflect.ValueOf(v))
if err != nil {
return fmt.Errorf("setString: %w", err)
}
return nil
t.set(reflect.ValueOf(v))
default:
return fmt.Errorf("setString: %w to a %s", errSetStringCannotAssignString, f.Kind())
return fmt.Errorf("toml: cannot assign string to a %s", f.Kind())
}
}
var errSetBoolCannotAssignBool = errors.New("cannot assign bool")
return nil
}
func setBool(t target, v bool) error {
f := t.get()
switch f.Kind() {
case reflect.Bool:
err := t.setBool(v)
if err != nil {
return fmt.Errorf("setBool: %w", err)
}
return nil
t.setBool(v)
case reflect.Interface:
err := t.set(reflect.ValueOf(v))
if err != nil {
return fmt.Errorf("setBool: %w", err)
}
return nil
t.set(reflect.ValueOf(v))
default:
return fmt.Errorf("setBool: %w to a %s", errSetBoolCannotAssignBool, f.String())
return fmt.Errorf("toml: cannot assign boolean to a %s", f.Kind())
}
return nil
}
const (
@@ -284,207 +193,104 @@ const (
minInt = -maxInt - 1
)
var (
errSetInt64InInt32 = errors.New("does not fit in an int32")
errSetInt64InInt16 = errors.New("does not fit in an int16")
errSetInt64InInt8 = errors.New("does not fit in an int8")
errSetInt64InInt = errors.New("does not fit in an int")
errSetInt64InUint64 = errors.New("negative integer does not fit in an uint64")
errSetInt64InUint32 = errors.New("negative integer does not fit in an uint32")
errSetInt64InUint32Max = errors.New("integer does not fit in an uint32")
errSetInt64InUint16 = errors.New("negative integer does not fit in an uint16")
errSetInt64InUint16Max = errors.New("integer does not fit in an uint16")
errSetInt64InUint8 = errors.New("negative integer does not fit in an uint8")
errSetInt64InUint8Max = errors.New("integer does not fit in an uint8")
errSetInt64InUint = errors.New("negative integer does not fit in an uint")
errSetInt64Unknown = errors.New("does not fit in an uint")
)
//nolint:funlen,gocognit,cyclop,gocyclo
func setInt64(t target, v int64) error {
f := t.get()
switch f.Kind() {
case reflect.Int64:
err := t.setInt64(v)
if err != nil {
return fmt.Errorf("setInt64: %w", err)
}
return nil
t.setInt64(v)
case reflect.Int32:
if v < math.MinInt32 || v > math.MaxInt32 {
return fmt.Errorf("setInt64: integer %d %w", v, errSetInt64InInt32)
}
err := t.set(reflect.ValueOf(int32(v)))
if err != nil {
return fmt.Errorf("setInt64: %w", err)
return fmt.Errorf("toml: number %d does not fit in an int32", v)
}
t.set(reflect.ValueOf(int32(v)))
return nil
case reflect.Int16:
if v < math.MinInt16 || v > math.MaxInt16 {
return fmt.Errorf("setInt64: integer %d %w", v, errSetInt64InInt16)
return fmt.Errorf("toml: number %d does not fit in an int16", v)
}
err := t.set(reflect.ValueOf(int16(v)))
if err != nil {
return fmt.Errorf("setInt64: %w", err)
}
return nil
t.set(reflect.ValueOf(int16(v)))
case reflect.Int8:
if v < math.MinInt8 || v > math.MaxInt8 {
return fmt.Errorf("setInt64: integer %d %w", v, errSetInt64InInt8)
return fmt.Errorf("toml: number %d does not fit in an int8", v)
}
err := t.set(reflect.ValueOf(int8(v)))
if err != nil {
return fmt.Errorf("setInt64: %w", err)
}
return nil
t.set(reflect.ValueOf(int8(v)))
case reflect.Int:
if v < minInt || v > maxInt {
return fmt.Errorf("setInt64: integer %d %w", v, errSetInt64InInt)
return fmt.Errorf("toml: number %d does not fit in an int", v)
}
err := t.set(reflect.ValueOf(int(v)))
if err != nil {
return fmt.Errorf("setInt64: %w", err)
}
return nil
t.set(reflect.ValueOf(int(v)))
case reflect.Uint64:
if v < 0 {
return fmt.Errorf("setInt64: %d, %w", v, errSetInt64InUint64)
return fmt.Errorf("toml: negative number %d does not fit in an uint64", v)
}
err := t.set(reflect.ValueOf(uint64(v)))
if err != nil {
return fmt.Errorf("setInt64: %w", err)
}
return nil
t.set(reflect.ValueOf(uint64(v)))
case reflect.Uint32:
if v < 0 {
return fmt.Errorf("setInt64: %d, %w", v, errSetInt64InUint32)
if v < 0 || v > math.MaxUint32 {
return fmt.Errorf("toml: negative number %d does not fit in an uint32", v)
}
if v > math.MaxUint32 {
return fmt.Errorf("setInt64: %d, %w", v, errSetInt64InUint32Max)
}
err := t.set(reflect.ValueOf(uint32(v)))
if err != nil {
return fmt.Errorf("setInt64: %w", err)
}
return nil
t.set(reflect.ValueOf(uint32(v)))
case reflect.Uint16:
if v < 0 {
return fmt.Errorf("setInt64: %d, %w", v, errSetInt64InUint16)
if v < 0 || v > math.MaxUint16 {
return fmt.Errorf("toml: negative number %d does not fit in an uint16", v)
}
if v > math.MaxUint16 {
return fmt.Errorf("setInt64: %d, %w", v, errSetInt64InUint16Max)
}
err := t.set(reflect.ValueOf(uint16(v)))
if err != nil {
return fmt.Errorf("setInt64: %w", err)
}
return nil
t.set(reflect.ValueOf(uint16(v)))
case reflect.Uint8:
if v < 0 {
return fmt.Errorf("setInt64: %d, %w", v, errSetInt64InUint8)
if v < 0 || v > math.MaxUint8 {
return fmt.Errorf("toml: negative number %d does not fit in an uint8", v)
}
if v > math.MaxUint8 {
return fmt.Errorf("setInt64: %d, %w", v, errSetInt64InUint8Max)
}
err := t.set(reflect.ValueOf(uint8(v)))
if err != nil {
return fmt.Errorf("setInt64: %w", err)
}
return nil
t.set(reflect.ValueOf(uint8(v)))
case reflect.Uint:
if v < 0 {
return fmt.Errorf("setInt64: %d, %w", v, errSetInt64InUint)
return fmt.Errorf("toml: negative number %d does not fit in an uint", v)
}
err := t.set(reflect.ValueOf(uint(v)))
if err != nil {
return fmt.Errorf("setInt64: %w", err)
}
return nil
t.set(reflect.ValueOf(uint(v)))
case reflect.Interface:
err := t.set(reflect.ValueOf(v))
if err != nil {
return fmt.Errorf("setInt64: %w", err)
}
return nil
t.set(reflect.ValueOf(v))
default:
return fmt.Errorf("setInt64: %s, %w", f.String(), errSetInt64Unknown)
return fmt.Errorf("toml: integer cannot be assigned to %s", f.Kind())
}
}
var (
errSetFloat64InFloat32Max = errors.New("does not fit in an float32")
errSetFloat64Unknown = errors.New("does not fit in an float32")
)
return nil
}
func setFloat64(t target, v float64) error {
f := t.get()
switch f.Kind() {
case reflect.Float64:
err := t.setFloat64(v)
if err != nil {
return fmt.Errorf("setFloat64: %w", err)
}
return nil
t.setFloat64(v)
case reflect.Float32:
if v > math.MaxFloat32 {
return fmt.Errorf("setFloat64: %f %w", v, errSetFloat64InFloat32Max)
return fmt.Errorf("toml: number %f does not fit in a float32", v)
}
err := t.set(reflect.ValueOf(float32(v)))
if err != nil {
return fmt.Errorf("setFloat64: %w", err)
}
return nil
t.set(reflect.ValueOf(float32(v)))
case reflect.Interface:
err := t.set(reflect.ValueOf(v))
if err != nil {
return fmt.Errorf("setFloat64: %w", err)
}
return nil
t.set(reflect.ValueOf(v))
default:
return fmt.Errorf("setFloat64: %s %w", f.String(), errSetFloat64Unknown)
return fmt.Errorf("toml: float cannot be assigned to %s", f.Kind())
}
}
var (
errElementAtCannotOn = errors.New("cannot elementAt")
errElementAtCannotOnUnknown = errors.New("cannot elementAt")
)
return nil
}
//nolint:cyclop
// Returns the element at idx of the value pointed at by target, or an error if
// t does not point to an indexable.
// If the target points to an Array and idx is out of bounds, it returns
// (nil, nil) as this is not a fatal error (the unmarshaler will skip).
func elementAt(t target, idx int) (target, error) {
func elementAt(t target, idx int) target {
f := t.get()
switch f.Kind() {
@@ -493,42 +299,30 @@ func elementAt(t target, idx int) (target, error) {
// TODO: use the idx function argument and avoid alloc if possible.
idx := f.Len()
err := t.set(reflect.Append(f, reflect.New(f.Type().Elem()).Elem()))
if err != nil {
return nil, fmt.Errorf("elementAt: %w", err)
}
t.set(reflect.Append(f, reflect.New(f.Type().Elem()).Elem()))
return valueTarget(t.get().Index(idx)), nil
return valueTarget(t.get().Index(idx))
case reflect.Array:
if idx >= f.Len() {
return nil, nil
return nil
}
return valueTarget(f.Index(idx)), nil
return valueTarget(f.Index(idx))
case reflect.Interface:
if f.IsNil() {
panic("interface should have been initialized")
}
// This function is called after ensureValueIndexable, so it's
// guaranteed that f contains an initialized slice.
ifaceElem := f.Elem()
if ifaceElem.Kind() != reflect.Slice {
return nil, fmt.Errorf("elementAt: %w on a %s", errElementAtCannotOn, f.Kind())
}
idx := ifaceElem.Len()
newElem := reflect.New(ifaceElem.Type().Elem()).Elem()
newSlice := reflect.Append(ifaceElem, newElem)
err := t.set(newSlice)
if err != nil {
return nil, fmt.Errorf("elementAt: %w", err)
}
t.set(newSlice)
return valueTarget(t.get().Elem().Index(idx)), nil
case reflect.Ptr:
return elementAt(valueTarget(f.Elem()), idx)
return valueTarget(t.get().Elem().Index(idx))
default:
return nil, fmt.Errorf("elementAt: %w on a %s", errElementAtCannotOnUnknown, f.Kind())
// Why ensureValueIndexable let it go through?
panic(fmt.Errorf("elementAt received unhandled value type: %s", f.Kind()))
}
}
@@ -539,31 +333,19 @@ func (d *decoder) scopeTableTarget(shouldAppend bool, t target, name string) (ta
switch x.Kind() {
// Kinds that need to recurse
case reflect.Interface:
t, err := scopeInterface(shouldAppend, t)
if err != nil {
return t, false, fmt.Errorf("scopeTableTarget: %w", err)
}
t := scopeInterface(shouldAppend, t)
return d.scopeTableTarget(shouldAppend, t, name)
case reflect.Ptr:
t, err := scopePtr(t)
if err != nil {
return t, false, fmt.Errorf("scopeTableTarget: %w", err)
}
t := scopePtr(t)
return d.scopeTableTarget(shouldAppend, t, name)
case reflect.Slice:
t, err := scopeSlice(shouldAppend, t)
if err != nil {
return t, false, fmt.Errorf("scopeTableTarget: %w", err)
}
t := scopeSlice(shouldAppend, t)
shouldAppend = false
return d.scopeTableTarget(shouldAppend, t, name)
case reflect.Array:
t, err := d.scopeArray(shouldAppend, t)
if err != nil {
return t, false, fmt.Errorf("scopeTableTarget: %w", err)
return t, false, err
}
shouldAppend = false
@@ -574,11 +356,7 @@ func (d *decoder) scopeTableTarget(shouldAppend bool, t target, name string) (ta
return scopeStruct(x, name)
case reflect.Map:
if x.IsNil() {
err := t.set(reflect.MakeMap(x.Type()))
if err != nil {
return t, false, fmt.Errorf("scopeTableTarget: %w", err)
}
t.set(reflect.MakeMap(x.Type()))
x = t.get()
}
@@ -588,42 +366,29 @@ func (d *decoder) scopeTableTarget(shouldAppend bool, t target, name string) (ta
}
}
func scopeInterface(shouldAppend bool, t target) (target, error) {
err := initInterface(shouldAppend, t)
if err != nil {
return t, err
}
return interfaceTarget{t}, nil
func scopeInterface(shouldAppend bool, t target) target {
initInterface(shouldAppend, t)
return interfaceTarget{t}
}
func scopePtr(t target) (target, error) {
err := initPtr(t)
if err != nil {
return t, err
}
return valueTarget(t.get().Elem()), nil
func scopePtr(t target) target {
initPtr(t)
return valueTarget(t.get().Elem())
}
func initPtr(t target) error {
func initPtr(t target) {
x := t.get()
if !x.IsNil() {
return nil
return
}
err := t.set(reflect.New(x.Type().Elem()))
if err != nil {
return fmt.Errorf("initPtr: %w", err)
}
return nil
t.set(reflect.New(x.Type().Elem()))
}
// initInterface makes sure that the interface pointed at by the target is not
// nil.
// Returns the target to the initialized value of the target.
func initInterface(shouldAppend bool, t target) error {
func initInterface(shouldAppend bool, t target) {
x := t.get()
if x.Kind() != reflect.Interface {
@@ -631,7 +396,7 @@ func initInterface(shouldAppend bool, t target) error {
}
if !x.IsNil() && (x.Elem().Type() == sliceInterfaceType || x.Elem().Type() == mapStringInterfaceType) {
return nil
return
}
var newElement reflect.Value
@@ -641,55 +406,43 @@ func initInterface(shouldAppend bool, t target) error {
newElement = reflect.MakeMap(mapStringInterfaceType)
}
err := t.set(newElement)
if err != nil {
return fmt.Errorf("initInterface: %w", err)
}
return nil
t.set(newElement)
}
func scopeSlice(shouldAppend bool, t target) (target, error) {
func scopeSlice(shouldAppend bool, t target) target {
v := t.get()
if shouldAppend {
newElem := reflect.New(v.Type().Elem())
newSlice := reflect.Append(v, newElem.Elem())
err := t.set(newSlice)
if err != nil {
return t, fmt.Errorf("scopeSlice: %w", err)
}
t.set(newSlice)
v = t.get()
}
return valueTarget(v.Index(v.Len() - 1)), nil
return valueTarget(v.Index(v.Len() - 1))
}
var errScopeArrayNotEnoughSpace = errors.New("not enough space in the array")
func (d *decoder) scopeArray(shouldAppend bool, t target) (target, error) {
v := t.get()
idx := d.arrayIndex(shouldAppend, v)
if idx >= v.Len() {
return nil, errScopeArrayNotEnoughSpace
return nil, fmt.Errorf("toml: impossible to insert element beyond array's size: %d", v.Len())
}
return valueTarget(v.Index(idx)), nil
}
var errScopeMapCannotConvertStringToKey = errors.New("cannot convert string into map key type")
func scopeMap(v reflect.Value, name string) (target, bool, error) {
k := reflect.ValueOf(name)
keyType := v.Type().Key()
if !k.Type().AssignableTo(keyType) {
if !k.Type().ConvertibleTo(keyType) {
return nil, false, fmt.Errorf("scopeMap: %w %s", errScopeMapCannotConvertStringToKey, keyType)
return nil, false, fmt.Errorf("toml: cannot convert map key of type %s to expected type %s", k.Type(), keyType)
}
k = k.Convert(keyType)
+6 -10
View File
@@ -128,14 +128,12 @@ func TestPushNew(t *testing.T) {
x, _, err := dec.scopeTableTarget(false, valueTarget(reflect.ValueOf(&d).Elem()), "A")
require.NoError(t, err)
n, err := elementAt(x, 0)
require.NoError(t, err)
require.NoError(t, n.setString("hello"))
n := elementAt(x, 0)
n.setString("hello")
require.Equal(t, []string{"hello"}, d.A)
n, err = elementAt(x, 1)
require.NoError(t, err)
require.NoError(t, n.setString("world"))
n = elementAt(x, 1)
n.setString("world")
require.Equal(t, []string{"hello", "world"}, d.A)
})
@@ -151,13 +149,11 @@ func TestPushNew(t *testing.T) {
x, _, err := dec.scopeTableTarget(false, valueTarget(reflect.ValueOf(&d).Elem()), "A")
require.NoError(t, err)
n, err := elementAt(x, 0)
require.NoError(t, err)
n := elementAt(x, 0)
require.NoError(t, setString(n, "hello"))
require.Equal(t, []interface{}{"hello"}, d.A)
n, err = elementAt(x, 1)
require.NoError(t, err)
n = elementAt(x, 1)
require.NoError(t, setString(n, "world"))
require.Equal(t, []interface{}{"hello", "world"}, d.A)
})
+2 -10
View File
@@ -94,12 +94,7 @@ func testGenTranslateDesc(input interface{}) interface{} {
if ok {
dvalue, ok = d["value"]
if ok {
var okdt bool
dtype, okdt = dtypeiface.(string)
if !okdt {
panic(fmt.Sprintf("dtypeiface should be valid string: %v", dtypeiface))
}
dtype = dtypeiface.(string)
switch dtype {
case "string":
@@ -132,10 +127,7 @@ func testGenTranslateDesc(input interface{}) interface{} {
return nil
}
a, oka := dvalue.([]interface{})
if !oka {
panic(fmt.Sprintf("a should be valid []interface{}: %v", a))
}
a := dvalue.([]interface{})
xs := make([]interface{}, len(a))
+34 -67
View File
@@ -56,7 +56,7 @@ func (d *Decoder) SetStrict(strict bool) {
func (d *Decoder) Decode(v interface{}) error {
b, err := ioutil.ReadAll(d.r)
if err != nil {
return fmt.Errorf("Decode: %w", err)
return fmt.Errorf("toml: %w", err)
}
p := parser{}
@@ -130,20 +130,15 @@ func keyLocation(node ast.Node) []byte {
return unsafe.BytesRange(start, end)
}
var (
errFromParserExpectingPointer = errors.New("expecting a pointer as target")
errFromParserExpectingNonNilPointer = errors.New("expecting non nil pointer as target")
)
//nolint:funlen,cyclop
func (d *decoder) fromParser(p *parser, v interface{}) error {
r := reflect.ValueOf(v)
if r.Kind() != reflect.Ptr {
return fmt.Errorf("fromParser: %w, not %s", errFromParserExpectingPointer, r.Kind())
return fmt.Errorf("toml: decoding can only be performed into a pointer, not %s", r.Kind())
}
if r.IsNil() {
return errFromParserExpectingNonNilPointer
return fmt.Errorf("toml: decoding pointer target cannot be nil")
}
var (
@@ -162,7 +157,7 @@ func (d *decoder) fromParser(p *parser, v interface{}) error {
err := d.seen.CheckExpression(node)
if err != nil {
return fmt.Errorf("fromParser: %w", err)
return err
}
var found bool
@@ -181,16 +176,13 @@ func (d *decoder) fromParser(p *parser, v interface{}) error {
// looks like a table. Otherwise the information
// of a table is lost, and marshal cannot do the
// round trip.
err := ensureMapIfInterface(current)
if err != nil {
panic(fmt.Sprintf("ensureMapIfInterface: %s", err))
}
ensureMapIfInterface(current)
}
case ast.ArrayTable:
d.strict.EnterArrayTable(node)
current, found, err = d.scopeWithArrayTable(root, node.Key())
default:
panic(fmt.Sprintf("fromParser: this should not be a top level node type: %s", node.Kind))
panic(fmt.Sprintf("this should not be a top level node type: %s", node.Kind))
}
if err != nil {
@@ -267,26 +259,18 @@ func (d *decoder) scopeWithArrayTable(x target, key ast.Iterator) (target, bool,
v := x.get()
if v.Kind() == reflect.Ptr {
x, err = scopePtr(x)
if err != nil {
return x, false, err
}
x = scopePtr(x)
v = x.get()
}
if v.Kind() == reflect.Interface {
x, err = scopeInterface(true, x)
if err != nil {
return x, found, err
}
x = scopeInterface(true, x)
v = x.get()
}
switch v.Kind() {
case reflect.Slice:
x, err = scopeSlice(true, x)
x = scopeSlice(true, x)
case reflect.Array:
x, err = d.scopeArray(true, x)
default:
@@ -334,7 +318,7 @@ func tryTextUnmarshaler(x target, node ast.Node) (bool, error) {
if v.Type().Implements(textUnmarshalerType) {
err := v.Interface().(encoding.TextUnmarshaler).UnmarshalText(node.Data)
if err != nil {
return false, fmt.Errorf("tryTextUnmarshaler: %w", err)
return false, newDecodeError(node.Data, "error calling UnmarshalText: %w", err)
}
return true, nil
@@ -343,7 +327,7 @@ func tryTextUnmarshaler(x target, node ast.Node) (bool, error) {
if v.CanAddr() && v.Addr().Type().Implements(textUnmarshalerType) {
err := v.Addr().Interface().(encoding.TextUnmarshaler).UnmarshalText(node.Data)
if err != nil {
return false, fmt.Errorf("tryTextUnmarshaler: %w", err)
return false, newDecodeError(node.Data, "error calling UnmarshalText: %w", err)
}
return true, nil
@@ -358,11 +342,7 @@ func (d *decoder) unmarshalValue(x target, node ast.Node) error {
if v.Kind() == reflect.Ptr {
if !v.Elem().IsValid() {
err := x.set(reflect.New(v.Type().Elem()))
if err != nil {
return fmt.Errorf("unmarshalValue: %w", err)
}
x.set(reflect.New(v.Type().Elem()))
v = x.get()
}
@@ -394,7 +374,7 @@ func (d *decoder) unmarshalValue(x target, node ast.Node) error {
case ast.LocalDate:
return unmarshalLocalDate(x, node)
default:
panic(fmt.Sprintf("unmarshalValue: unhandled unmarshalValue kind %s", node.Kind))
panic(fmt.Sprintf("unhandled node kind %s", node.Kind))
}
}
@@ -406,7 +386,9 @@ func unmarshalLocalDate(x target, node ast.Node) error {
return err
}
return setDate(x, v)
setDate(x, v)
return nil
}
func unmarshalLocalDateTime(x target, node ast.Node) error {
@@ -421,7 +403,9 @@ func unmarshalLocalDateTime(x target, node ast.Node) error {
return newDecodeError(rest, "extra characters at the end of a local date time")
}
return setLocalDateTime(x, v)
setLocalDateTime(x, v)
return nil
}
func unmarshalDateTime(x target, node ast.Node) error {
@@ -432,48 +416,37 @@ func unmarshalDateTime(x target, node ast.Node) error {
return err
}
return setDateTime(x, v)
setDateTime(x, v)
return nil
}
func setLocalDateTime(x target, v LocalDateTime) error {
func setLocalDateTime(x target, v LocalDateTime) {
if x.get().Type() == timeType {
cast := v.In(time.Local)
return setDateTime(x, cast)
setDateTime(x, cast)
return
}
err := x.set(reflect.ValueOf(v))
if err != nil {
return fmt.Errorf("setLocalDateTime: %w", err)
}
return nil
x.set(reflect.ValueOf(v))
}
func setDateTime(x target, v time.Time) error {
err := x.set(reflect.ValueOf(v))
if err != nil {
return fmt.Errorf("setDateTime: %w", err)
}
return nil
func setDateTime(x target, v time.Time) {
x.set(reflect.ValueOf(v))
}
var timeType = reflect.TypeOf(time.Time{})
func setDate(x target, v LocalDate) error {
func setDate(x target, v LocalDate) {
if x.get().Type() == timeType {
cast := v.In(time.Local)
return setDateTime(x, cast)
setDateTime(x, cast)
return
}
err := x.set(reflect.ValueOf(v))
if err != nil {
return fmt.Errorf("setDate: %w", err)
}
return nil
x.set(reflect.ValueOf(v))
}
func unmarshalString(x target, node ast.Node) error {
@@ -514,10 +487,7 @@ func unmarshalFloat(x target, node ast.Node) error {
func (d *decoder) unmarshalInlineTable(x target, node ast.Node) error {
assertNode(ast.InlineTable, node)
err := ensureMapIfInterface(x)
if err != nil {
return fmt.Errorf("unmarshalInlineTable: %w", err)
}
ensureMapIfInterface(x)
it := node.Children()
for it.Next() {
@@ -546,10 +516,7 @@ func (d *decoder) unmarshalArray(x target, node ast.Node) error {
for it.Next() {
n := it.Node()
v, err := elementAt(x, idx)
if err != nil {
return err
}
v := elementAt(x, idx)
if v == nil {
// when we go out of bound for an array just stop processing it to
+57 -3
View File
@@ -38,6 +38,11 @@ func TestUnmarshal_Integers(t *testing.T) {
input: `+99`,
expected: 99,
},
{
desc: "integer decimal underscore",
input: `123_456`,
expected: 123456,
},
{
desc: "integer hex uppercase",
input: `0xDEADBEEF`,
@@ -58,6 +63,21 @@ func TestUnmarshal_Integers(t *testing.T) {
input: `0b11010110`,
expected: 0b11010110,
},
{
desc: "double underscore",
input: "12__3",
err: true,
},
{
desc: "starts with underscore",
input: "_1",
err: true,
},
{
desc: "ends with underscore",
input: "1_",
err: true,
},
}
type doc struct {
@@ -71,8 +91,12 @@ func TestUnmarshal_Integers(t *testing.T) {
doc := doc{}
err := toml.Unmarshal([]byte(`A = `+e.input), &doc)
require.NoError(t, err)
assert.Equal(t, e.expected, doc.A)
if e.err {
require.Error(t, err)
} else {
require.NoError(t, err)
assert.Equal(t, e.expected, doc.A)
}
})
}
}
@@ -799,6 +823,33 @@ B = "data"`,
}
},
},
{
desc: "mismatch types int to string",
input: `A = 42`,
gen: func() test {
type S struct {
A string
}
return test{
target: &S{},
err: true,
}
},
},
{
desc: "mismatch types array of int to interface with non-slice",
input: `A = [[42]]`,
skip: true,
gen: func() test {
type S struct {
A *string
}
return test{
target: &S{},
expected: &S{},
}
},
},
}
for _, e := range examples {
@@ -815,6 +866,9 @@ B = "data"`,
}
err := toml.Unmarshal([]byte(e.input), test.target)
if test.err {
if err == nil {
t.Log("=>", test.target)
}
require.Error(t, err)
} else {
require.NoError(t, err)
@@ -1030,7 +1084,7 @@ world'`,
if e.msg != "" {
t.Log("\n" + de.String())
require.Equal(t, e.msg, de.Error())
require.Equal(t, "toml: "+e.msg, de.Error())
}
})
}