v2: errors (#534)

```
name                              old time/op    new time/op    delta
UnmarshalDataset/config-32          86.7ms ± 2%    87.5ms ± 2%     ~     (p=0.113 n=9+10)
UnmarshalDataset/canada-32           129ms ± 4%     106ms ± 3%  -17.94%  (p=0.000 n=10+10)
UnmarshalDataset/citm_catalog-32    59.4ms ± 5%    58.7ms ± 5%     ~     (p=0.393 n=10+10)
UnmarshalDataset/twitter-32         27.0ms ± 7%    26.9ms ± 6%     ~     (p=0.720 n=10+9)
UnmarshalDataset/code-32             326ms ± 4%     322ms ± 7%     ~     (p=0.661 n=9+10)
UnmarshalDataset/example-32          510µs ±11%     526µs ± 7%     ~     (p=0.182 n=10+9)
UnmarshalSimple-32                  1.41µs ± 6%    1.41µs ± 4%     ~     (p=0.736 n=10+9)
ReferenceFile-32                    45.6µs ± 3%    43.9µs ±10%     ~     (p=0.089 n=10+10)

name                              old speed      new speed      delta
UnmarshalDataset/config-32        12.1MB/s ± 2%  12.0MB/s ± 2%     ~     (p=0.108 n=9+10)
UnmarshalDataset/canada-32        17.1MB/s ± 4%  20.9MB/s ± 3%  +21.86%  (p=0.000 n=10+10)
UnmarshalDataset/citm_catalog-32  9.41MB/s ± 5%  9.51MB/s ± 5%     ~     (p=0.362 n=10+10)
UnmarshalDataset/twitter-32       16.4MB/s ± 8%  16.5MB/s ± 6%     ~     (p=0.704 n=10+9)
UnmarshalDataset/code-32          8.24MB/s ± 4%  8.34MB/s ± 7%     ~     (p=0.675 n=9+10)
UnmarshalDataset/example-32       15.9MB/s ±11%  15.4MB/s ± 7%     ~     (p=0.182 n=10+9)
ReferenceFile-32                   115MB/s ± 4%   120MB/s ±10%     ~     (p=0.085 n=10+10)

name                              old alloc/op   new alloc/op   delta
UnmarshalDataset/config-32          16.9MB ± 0%    16.9MB ± 0%   -0.02%  (p=0.000 n=10+10)
UnmarshalDataset/canada-32          76.8MB ± 0%    74.3MB ± 0%   -3.31%  (p=0.000 n=10+10)
UnmarshalDataset/citm_catalog-32    37.3MB ± 0%    37.1MB ± 0%   -0.60%  (p=0.000 n=9+10)
UnmarshalDataset/twitter-32         15.6MB ± 0%    15.6MB ± 0%   -0.09%  (p=0.000 n=10+10)
UnmarshalDataset/code-32            60.2MB ± 0%    59.3MB ± 0%   -1.51%  (p=0.000 n=10+9)
UnmarshalDataset/example-32          238kB ± 0%     238kB ± 0%   -0.18%  (p=0.000 n=10+10)
ReferenceFile-32                    11.8kB ± 0%    11.8kB ± 0%     ~     (all equal)

name                              old allocs/op  new allocs/op  delta
UnmarshalDataset/config-32            653k ± 0%      645k ± 0%   -1.20%  (p=0.000 n=10+6)
UnmarshalDataset/canada-32           1.01M ± 0%     0.90M ± 0%  -11.04%  (p=0.000 n=9+10)
UnmarshalDataset/citm_catalog-32      384k ± 0%      370k ± 0%   -3.75%  (p=0.000 n=10+10)
UnmarshalDataset/twitter-32           160k ± 0%      157k ± 0%   -1.32%  (p=0.000 n=10+10)
UnmarshalDataset/code-32             2.97M ± 0%     2.91M ± 0%   -2.15%  (p=0.000 n=10+7)
UnmarshalDataset/example-32          3.69k ± 0%     3.63k ± 0%   -1.52%  (p=0.000 n=10+10)
ReferenceFile-32                       253 ± 0%       253 ± 0%     ~     (all equal)
```
This commit is contained in:
Thomas Pelletier
2021-05-08 16:04:25 -04:00
committed by GitHub
parent 4545a3e94b
commit ea225df3ed
12 changed files with 325 additions and 656 deletions
+66 -111
View File
@@ -1,11 +1,8 @@
package toml package toml
import ( import (
"errors"
"fmt"
"math" "math"
"strconv" "strconv"
"strings"
"time" "time"
) )
@@ -59,14 +56,12 @@ func parseLocalDate(b []byte) (LocalDate, error) {
return date, nil return date, nil
} }
var errNotDigit = errors.New("not a digit")
func parseDecimalDigits(b []byte) (int, error) { func parseDecimalDigits(b []byte) (int, error) {
v := 0 v := 0
for _, c := range b { for i, c := range b {
if !isDigit(c) { if !isDigit(c) {
return 0, fmt.Errorf("%s: %w", b, errNotDigit) return 0, newDecodeError(b[i:i+1], "should be a digit (0-9)")
} }
v *= 10 v *= 10
@@ -76,13 +71,14 @@ func parseDecimalDigits(b []byte) (int, error) {
return v, nil return v, nil
} }
var errParseDateTimeMissingInfo = errors.New("date-time missing timezone information")
func parseDateTime(b []byte) (time.Time, error) { func parseDateTime(b []byte) (time.Time, error) {
// offset-date-time = full-date time-delim full-time // offset-date-time = full-date time-delim full-time
// full-time = partial-time time-offset // full-time = partial-time time-offset
// time-offset = "Z" / time-numoffset // time-offset = "Z" / time-numoffset
// time-numoffset = ( "+" / "-" ) time-hour ":" time-minute // time-numoffset = ( "+" / "-" ) time-hour ":" time-minute
originalBytes := b
dt, b, err := parseLocalDateTime(b) dt, b, err := parseLocalDateTime(b)
if err != nil { if err != nil {
return time.Time{}, err return time.Time{}, err
@@ -91,7 +87,7 @@ func parseDateTime(b []byte) (time.Time, error) {
var zone *time.Location var zone *time.Location
if len(b) == 0 { if len(b) == 0 {
return time.Time{}, errParseDateTimeMissingInfo return time.Time{}, newDecodeError(originalBytes, "date-time is missing timezone")
} }
if b[0] == 'Z' { if b[0] == 'Z' {
@@ -134,19 +130,12 @@ func parseDateTime(b []byte) (time.Time, error) {
return t, nil return t, nil
} }
var (
errParseLocalDateTimeWrongLength = errors.New(
"local datetimes are expected to have the format YYYY-MM-DDTHH:MM:SS[.NNNNNNNNN]",
)
errParseLocalDateTimeWrongSeparator = errors.New("datetime separator is expected to be T or a space")
)
func parseLocalDateTime(b []byte) (LocalDateTime, []byte, error) { func parseLocalDateTime(b []byte) (LocalDateTime, []byte, error) {
var dt LocalDateTime var dt LocalDateTime
const localDateTimeByteMinLen = 11 const localDateTimeByteMinLen = 11
if len(b) < localDateTimeByteMinLen { if len(b) < localDateTimeByteMinLen {
return dt, nil, errParseLocalDateTimeWrongLength return dt, nil, newDecodeError(b, "local datetimes are expected to have the format YYYY-MM-DDTHH:MM:SS[.NNNNNNNNN]")
} }
date, err := parseLocalDate(b[:10]) date, err := parseLocalDate(b[:10])
@@ -157,7 +146,7 @@ func parseLocalDateTime(b []byte) (LocalDateTime, []byte, error) {
sep := b[10] sep := b[10]
if sep != 'T' && sep != ' ' { if sep != 'T' && sep != ' ' {
return dt, nil, errParseLocalDateTimeWrongSeparator return dt, nil, newDecodeError(b[10:11], "datetime separator is expected to be T or a space")
} }
t, rest, err := parseLocalTime(b[11:]) t, rest, err := parseLocalTime(b[11:])
@@ -169,8 +158,6 @@ func parseLocalDateTime(b []byte) (LocalDateTime, []byte, error) {
return dt, rest, nil return dt, rest, nil
} }
var errParseLocalTimeWrongLength = errors.New("times are expected to have the format HH:MM:SS[.NNNNNN]")
// parseLocalTime is a bit different because it also returns the remaining // parseLocalTime is a bit different because it also returns the remaining
// []byte that is didn't need. This is to allow parseDateTime to parse those // []byte that is didn't need. This is to allow parseDateTime to parse those
// remaining bytes as a timezone. // remaining bytes as a timezone.
@@ -183,7 +170,7 @@ func parseLocalTime(b []byte) (LocalTime, []byte, error) {
const localTimeByteLen = 8 const localTimeByteLen = 8
if len(b) < localTimeByteLen { if len(b) < localTimeByteLen {
return t, nil, errParseLocalTimeWrongLength return t, nil, newDecodeError(b, "times are expected to have the format HH:MM:SS[.NNNNNN]")
} }
var err error var err error
@@ -242,11 +229,6 @@ func parseLocalTime(b []byte) (LocalTime, []byte, error) {
return t, b[8:], nil return t, b[8:], nil
} }
var (
errParseFloatStartDot = errors.New("float cannot start with a dot")
errParseFloatEndDot = errors.New("float cannot end with a dot")
)
//nolint:cyclop //nolint:cyclop
func parseFloat(b []byte) (float64, error) { func parseFloat(b []byte) (float64, error) {
//nolint:godox //nolint:godox
@@ -255,150 +237,123 @@ func parseFloat(b []byte) (float64, error) {
return math.NaN(), nil return math.NaN(), nil
} }
tok := string(b) cleaned, err := checkAndRemoveUnderscores(b)
err := numberContainsInvalidUnderscore(tok)
if err != nil { if err != nil {
return 0, err return 0, err
} }
cleanedVal := cleanupNumberToken(tok) if cleaned[0] == '.' {
if cleanedVal[0] == '.' { return 0, newDecodeError(b, "float cannot start with a dot")
return 0, errParseFloatStartDot
} }
if cleanedVal[len(cleanedVal)-1] == '.' { if cleaned[len(cleaned)-1] == '.' {
return 0, errParseFloatEndDot return 0, newDecodeError(b, "float cannot end with a dot")
} }
f, err := strconv.ParseFloat(cleanedVal, 64) f, err := strconv.ParseFloat(string(cleaned), 64)
if err != nil { if err != nil {
return 0, fmt.Errorf("coudn't ParseFloat %w", err) return 0, newDecodeError(b, "coudn't parse float: %w", err)
} }
return f, nil return f, nil
} }
func parseIntHex(b []byte) (int64, error) { func parseIntHex(b []byte) (int64, error) {
tok := string(b) cleaned, err := checkAndRemoveUnderscores(b[2:])
cleanedVal := cleanupNumberToken(tok)
err := hexNumberContainsInvalidUnderscore(cleanedVal)
if err != nil { if err != nil {
return 0, err return 0, err
} }
i, err := strconv.ParseInt(cleanedVal[2:], 16, 64) i, err := strconv.ParseInt(string(cleaned), 16, 64)
if err != nil { if err != nil {
return 0, fmt.Errorf("coudn't ParseIntHex %w", err) return 0, newDecodeError(b, "couldn't parse hexadecimal number: %w", err)
} }
return i, nil return i, nil
} }
func parseIntOct(b []byte) (int64, error) { func parseIntOct(b []byte) (int64, error) {
tok := string(b) cleaned, err := checkAndRemoveUnderscores(b[2:])
cleanedVal := cleanupNumberToken(tok)
err := numberContainsInvalidUnderscore(cleanedVal)
if err != nil { if err != nil {
return 0, err return 0, err
} }
i, err := strconv.ParseInt(cleanedVal[2:], 8, 64) i, err := strconv.ParseInt(string(cleaned), 8, 64)
if err != nil { if err != nil {
return 0, fmt.Errorf("coudn't ParseIntOct %w", err) return 0, newDecodeError(b, "couldn't parse octal number: %w", err)
} }
return i, nil return i, nil
} }
func parseIntBin(b []byte) (int64, error) { func parseIntBin(b []byte) (int64, error) {
tok := string(b) cleaned, err := checkAndRemoveUnderscores(b[2:])
cleanedVal := cleanupNumberToken(tok)
err := numberContainsInvalidUnderscore(cleanedVal)
if err != nil { if err != nil {
return 0, err return 0, err
} }
i, err := strconv.ParseInt(cleanedVal[2:], 2, 64) i, err := strconv.ParseInt(string(cleaned), 2, 64)
if err != nil { if err != nil {
return 0, fmt.Errorf("coudn't ParseIntBin %w", err) return 0, newDecodeError(b, "couldn't parse binary number: %w", err)
} }
return i, nil return i, nil
} }
func parseIntDec(b []byte) (int64, error) { func parseIntDec(b []byte) (int64, error) {
tok := string(b) cleaned, err := checkAndRemoveUnderscores(b)
cleanedVal := cleanupNumberToken(tok)
err := numberContainsInvalidUnderscore(cleanedVal)
if err != nil { if err != nil {
return 0, err return 0, err
} }
i, err := strconv.ParseInt(cleanedVal, 10, 64) i, err := strconv.ParseInt(string(cleaned), 10, 64)
if err != nil { if err != nil {
return 0, fmt.Errorf("coudn't parseIntDec %w", err) return 0, newDecodeError(b, "couldn't parse decimal number: %w", err)
} }
return i, nil return i, nil
} }
func numberContainsInvalidUnderscore(value string) error { func checkAndRemoveUnderscores(b []byte) ([]byte, error) {
// For large numbers, you may use underscores between digits to enhance if len(b) == 0 {
// readability. Each underscore must be surrounded by at least one digit on return b, nil
// each side.
hasBefore := false
for idx, r := range value {
if r == '_' {
if !hasBefore || idx+1 >= len(value) {
// can't end with an underscore
return errInvalidUnderscore
}
}
hasBefore = isDigitRune(r)
} }
return nil if b[0] == '_' {
} return nil, newDecodeError(b[0:1], "number cannot start with underscore")
func hexNumberContainsInvalidUnderscore(value string) error {
hasBefore := false
for idx, r := range value {
if r == '_' {
if !hasBefore || idx+1 >= len(value) {
// can't end with an underscore
return errInvalidUnderscoreHex
}
}
hasBefore = isHexDigit(r)
} }
return nil if b[len(b)-1] == '_' {
return nil, newDecodeError(b[len(b)-1:], "number cannot end with underscore")
}
// fast path
i := 0
for ; i < len(b); i++ {
if b[i] == '_' {
break
}
}
if i == len(b) {
return b, nil
}
before := false
cleaned := make([]byte, i, len(b))
copy(cleaned, b)
for i++; i < len(b); i++ {
c := b[i]
if c == '_' {
if !before {
return nil, newDecodeError(b[i-1:i+1], "number must have at least one digit between underscores")
}
before = false
} else {
before = true
cleaned = append(cleaned, c)
}
}
return cleaned, nil
} }
func cleanupNumberToken(value string) string {
cleanedVal := strings.ReplaceAll(value, "_", "")
return cleanedVal
}
func isHexDigit(r rune) bool {
return isDigitRune(r) ||
(r >= 'a' && r <= 'f') ||
(r >= 'A' && r <= 'F')
}
func isDigitRune(r rune) bool {
return r >= '0' && r <= '9'
}
var (
errInvalidUnderscore = errors.New("invalid use of _ in number")
errInvalidUnderscoreHex = errors.New("invalid use of _ in hex number")
)
+2 -2
View File
@@ -70,13 +70,13 @@ func (de *decodeError) Error() string {
func newDecodeError(highlight []byte, format string, args ...interface{}) error { func newDecodeError(highlight []byte, format string, args ...interface{}) error {
return &decodeError{ return &decodeError{
highlight: highlight, highlight: highlight,
message: fmt.Sprintf(format, args...), message: fmt.Errorf(format, args...).Error(),
} }
} }
// Error returns the error message contained in the DecodeError. // Error returns the error message contained in the DecodeError.
func (e *DecodeError) Error() string { func (e *DecodeError) Error() string {
return e.message return "toml: " + e.message
} }
// String returns the human-readable contextualized error. This string is multi-line. // String returns the human-readable contextualized error. This string is multi-line.
+4 -4
View File
@@ -123,10 +123,10 @@ func (s *SeenTracker) checkTable(node ast.Node) error {
i, found := s.current.Has(k) i, found := s.current.Has(k)
if found { if found {
if i.kind != tableKind { if i.kind != tableKind {
return fmt.Errorf("key %s should be a table", k) return fmt.Errorf("toml: key %s should be a table, not a %s", k, i.kind)
} }
if i.explicit { if i.explicit {
return fmt.Errorf("table %s already exists", k) return fmt.Errorf("toml: table %s already exists", k)
} }
i.explicit = true i.explicit = true
s.current = i s.current = i
@@ -162,7 +162,7 @@ func (s *SeenTracker) checkArrayTable(node ast.Node) error {
info, found := s.current.Has(k) info, found := s.current.Has(k)
if found { if found {
if info.kind != arrayTableKind { if info.kind != arrayTableKind {
return fmt.Errorf("key %s already exists but is not an array table", k) return fmt.Errorf("toml: key %s already exists as a %s, but should be an array table", info.kind, k)
} }
info.Clear() info.Clear()
} else { } else {
@@ -182,7 +182,7 @@ func (s *SeenTracker) checkKeyValue(context *info, node ast.Node) error {
child, found := context.Has(k) child, found := context.Has(k)
if found { if found {
if child.kind != tableKind { if child.kind != tableKind {
return fmt.Errorf("expected %s to be a table, not a %s", k, child.kind) return fmt.Errorf("toml: expected %s to be a table, not a %s", k, child.kind)
} }
} else { } else {
child = context.CreateTable(k, false) child = context.CreateTable(k, false)
+3 -3
View File
@@ -53,7 +53,7 @@ func LocalDateOf(t time.Time) LocalDate {
func ParseLocalDate(s string) (LocalDate, error) { func ParseLocalDate(s string) (LocalDate, error) {
t, err := time.Parse("2006-01-02", s) t, err := time.Parse("2006-01-02", s)
if err != nil { if err != nil {
return LocalDate{}, fmt.Errorf("ParseLocalDate: %w", err) return LocalDate{}, err
} }
return LocalDateOf(t), nil return LocalDateOf(t), nil
@@ -166,7 +166,7 @@ func LocalTimeOf(t time.Time) LocalTime {
func ParseLocalTime(s string) (LocalTime, error) { func ParseLocalTime(s string) (LocalTime, error) {
t, err := time.Parse("15:04:05.999999999", s) t, err := time.Parse("15:04:05.999999999", s)
if err != nil { if err != nil {
return LocalTime{}, fmt.Errorf("ParseLocalTime: %w", err) return LocalTime{}, err
} }
return LocalTimeOf(t), nil return LocalTimeOf(t), nil
@@ -237,7 +237,7 @@ func ParseLocalDateTime(s string) (LocalDateTime, error) {
if err != nil { if err != nil {
t, err = time.Parse("2006-01-02t15:04:05.999999999", s) t, err = time.Parse("2006-01-02t15:04:05.999999999", s)
if err != nil { if err != nil {
return LocalDateTime{}, fmt.Errorf("ParseLocalDateTime: %w", err) return LocalDateTime{}, err
} }
} }
+11 -24
View File
@@ -3,7 +3,6 @@ package toml
import ( import (
"bytes" "bytes"
"encoding" "encoding"
"errors"
"fmt" "fmt"
"io" "io"
"reflect" "reflect"
@@ -116,12 +115,12 @@ func (enc *Encoder) Encode(v interface{}) error {
b, err := enc.encode(b, ctx, reflect.ValueOf(v)) b, err := enc.encode(b, ctx, reflect.ValueOf(v))
if err != nil { if err != nil {
return fmt.Errorf("Encode: %w", err) return err
} }
_, err = enc.w.Write(b) _, err = enc.w.Write(b)
if err != nil { if err != nil {
return fmt.Errorf("Encode: %w", err) return fmt.Errorf("toml: cannot write: %w", err)
} }
return nil return nil
@@ -178,11 +177,6 @@ func (ctx *encoderCtx) isRoot() bool {
return len(ctx.parentKey) == 0 && !ctx.hasKey return len(ctx.parentKey) == 0 && !ctx.hasKey
} }
var (
errUnsupportedValue = errors.New("unsupported encode value kind")
errTextMarshalerCannotBeAtRoot = errors.New("type implementing TextMarshaler cannot be at root")
)
//nolint:cyclop,funlen //nolint:cyclop,funlen
func (enc *Encoder) encode(b []byte, ctx encoderCtx, v reflect.Value) ([]byte, error) { func (enc *Encoder) encode(b []byte, ctx encoderCtx, v reflect.Value) ([]byte, error) {
i, ok := v.Interface().(time.Time) i, ok := v.Interface().(time.Time)
@@ -192,12 +186,12 @@ func (enc *Encoder) encode(b []byte, ctx encoderCtx, v reflect.Value) ([]byte, e
if v.Type().Implements(textMarshalerType) { if v.Type().Implements(textMarshalerType) {
if ctx.isRoot() { if ctx.isRoot() {
return nil, errTextMarshalerCannotBeAtRoot return nil, fmt.Errorf("toml: type %s implementing the TextMarshaler interface cannot be a root element", v.Type())
} }
text, err := v.Interface().(encoding.TextMarshaler).MarshalText() text, err := v.Interface().(encoding.TextMarshaler).MarshalText()
if err != nil { if err != nil {
return nil, fmt.Errorf("encode: %w", err) return nil, err
} }
b = enc.encodeString(b, string(text), ctx.options) b = enc.encodeString(b, string(text), ctx.options)
@@ -215,7 +209,7 @@ func (enc *Encoder) encode(b []byte, ctx encoderCtx, v reflect.Value) ([]byte, e
return enc.encodeSlice(b, ctx, v) return enc.encodeSlice(b, ctx, v)
case reflect.Interface: case reflect.Interface:
if v.IsNil() { if v.IsNil() {
return nil, errNilInterface return nil, fmt.Errorf("toml: encoding a nil interface is not supported")
} }
return enc.encode(b, ctx, v.Elem()) return enc.encode(b, ctx, v.Elem())
@@ -244,7 +238,7 @@ func (enc *Encoder) encode(b []byte, ctx encoderCtx, v reflect.Value) ([]byte, e
case reflect.Int64, reflect.Int32, reflect.Int16, reflect.Int8, reflect.Int: case reflect.Int64, reflect.Int32, reflect.Int16, reflect.Int8, reflect.Int:
b = strconv.AppendInt(b, v.Int(), 10) b = strconv.AppendInt(b, v.Int(), 10)
default: default:
return nil, fmt.Errorf("encode(type %s): %w", v.Kind(), errUnsupportedValue) return nil, fmt.Errorf("toml: cannot encode value of type %s", v.Kind())
} }
return b, nil return b, nil
@@ -412,8 +406,6 @@ func (enc *Encoder) encodeTableHeader(ctx encoderCtx, b []byte) ([]byte, error)
return b, nil return b, nil
} }
var errTomlNoMultiline = errors.New("TOML does not support multiline keys")
//nolint:cyclop //nolint:cyclop
func (enc *Encoder) encodeKey(b []byte, k string) ([]byte, error) { func (enc *Encoder) encodeKey(b []byte, k string) ([]byte, error) {
needsQuotation := false needsQuotation := false
@@ -425,7 +417,7 @@ func (enc *Encoder) encodeKey(b []byte, k string) ([]byte, error) {
} }
if c == '\n' { if c == '\n' {
return nil, errTomlNoMultiline return nil, fmt.Errorf("toml: new line characters in keys are not supported")
} }
if c == literalQuote { if c == literalQuote {
@@ -445,11 +437,9 @@ func (enc *Encoder) encodeKey(b []byte, k string) ([]byte, error) {
} }
} }
var errNotSupportedAsMapKey = errors.New("type not supported as map key")
func (enc *Encoder) encodeMap(b []byte, ctx encoderCtx, v reflect.Value) ([]byte, error) { func (enc *Encoder) encodeMap(b []byte, ctx encoderCtx, v reflect.Value) ([]byte, error) {
if v.Type().Key().Kind() != reflect.String { if v.Type().Key().Kind() != reflect.String {
return nil, fmt.Errorf("encodeMap '%s': %w", v.Type().Key().Kind(), errNotSupportedAsMapKey) return nil, fmt.Errorf("toml: type %s is not supported as a map key", v.Type().Key().Kind())
} }
var ( var (
@@ -658,10 +648,7 @@ func (enc *Encoder) encodeTableInline(b []byte, ctx encoderCtx, t table) ([]byte
return b, nil return b, nil
} }
var ( var textMarshalerType = reflect.TypeOf(new(encoding.TextMarshaler)).Elem()
errNilInterface = errors.New("nil interface not supported")
textMarshalerType = reflect.TypeOf(new(encoding.TextMarshaler)).Elem()
)
func willConvertToTable(ctx encoderCtx, v reflect.Value) (bool, error) { func willConvertToTable(ctx encoderCtx, v reflect.Value) (bool, error) {
if v.Type() == timeType || v.Type().Implements(textMarshalerType) { if v.Type() == timeType || v.Type().Implements(textMarshalerType) {
@@ -674,7 +661,7 @@ func willConvertToTable(ctx encoderCtx, v reflect.Value) (bool, error) {
return !ctx.inline, nil return !ctx.inline, nil
case reflect.Interface: case reflect.Interface:
if v.IsNil() { if v.IsNil() {
return false, errNilInterface return false, fmt.Errorf("toml: encoding a nil interface is not supported")
} }
return willConvertToTable(ctx, v.Elem()) return willConvertToTable(ctx, v.Elem())
@@ -694,7 +681,7 @@ func willConvertToTableOrArrayTable(ctx encoderCtx, v reflect.Value) (bool, erro
if t.Kind() == reflect.Interface { if t.Kind() == reflect.Interface {
if v.IsNil() { if v.IsNil() {
return false, errNilInterface return false, fmt.Errorf("toml: encoding a nil interface is not supported")
} }
return willConvertToTableOrArrayTable(ctx, v.Elem()) return willConvertToTableOrArrayTable(ctx, v.Elem())
+23 -41
View File
@@ -2,7 +2,6 @@ package toml
import ( import (
"bytes" "bytes"
"errors"
"fmt" "fmt"
"strconv" "strconv"
@@ -225,19 +224,13 @@ func (p *parser) parseKeyval(b []byte) (ast.Reference, []byte, error) {
return ref, b, err return ref, b, err
} }
var (
errExpectedValNotEOF = errors.New("expected value, not eof")
errExpectedTrue = errors.New("expected 'true'")
errExpectedFalse = errors.New("expected 'false'")
)
//nolint:cyclop,funlen //nolint:cyclop,funlen
func (p *parser) parseVal(b []byte) (ast.Reference, []byte, error) { func (p *parser) parseVal(b []byte) (ast.Reference, []byte, error) {
// val = string / boolean / array / inline-table / date-time / float / integer // val = string / boolean / array / inline-table / date-time / float / integer
var ref ast.Reference var ref ast.Reference
if len(b) == 0 { if len(b) == 0 {
return ref, nil, errExpectedValNotEOF return ref, nil, newDecodeError(b, "expected value, not eof")
} }
var err error var err error
@@ -278,7 +271,7 @@ func (p *parser) parseVal(b []byte) (ast.Reference, []byte, error) {
return ref, b, err return ref, b, err
case 't': case 't':
if !scanFollowsTrue(b) { if !scanFollowsTrue(b) {
return ref, nil, errExpectedTrue return ref, nil, newDecodeError(atmost(b, 4), "expected 'true'")
} }
ref = p.builder.Push(ast.Node{ ref = p.builder.Push(ast.Node{
@@ -289,7 +282,7 @@ func (p *parser) parseVal(b []byte) (ast.Reference, []byte, error) {
return ref, b[4:], nil return ref, b[4:], nil
case 'f': case 'f':
if !scanFollowsFalse(b) { if !scanFollowsFalse(b) {
return ast.Reference{}, nil, errExpectedFalse return ref, nil, newDecodeError(atmost(b, 5), "expected 'false'")
} }
ref = p.builder.Push(ast.Node{ ref = p.builder.Push(ast.Node{
@@ -307,6 +300,13 @@ func (p *parser) parseVal(b []byte) (ast.Reference, []byte, error) {
} }
} }
func atmost(b []byte, n int) []byte {
if n >= len(b) {
return b
}
return b[:n]
}
func (p *parser) parseLiteralString(b []byte) ([]byte, []byte, error) { func (p *parser) parseLiteralString(b []byte) ([]byte, []byte, error) {
v, rest, err := scanLiteralString(b) v, rest, err := scanLiteralString(b)
if err != nil { if err != nil {
@@ -370,8 +370,6 @@ func (p *parser) parseInlineTable(b []byte) (ast.Reference, []byte, error) {
return parent, rest, err return parent, rest, err
} }
var errArrayCannotStartWithComma = errors.New("array cannot start with comma")
//nolint:funlen,cyclop //nolint:funlen,cyclop
func (p *parser) parseValArray(b []byte) (ast.Reference, []byte, error) { func (p *parser) parseValArray(b []byte) (ast.Reference, []byte, error) {
// array = array-open [ array-values ] ws-comment-newline array-close // array = array-open [ array-values ] ws-comment-newline array-close
@@ -409,7 +407,7 @@ func (p *parser) parseValArray(b []byte) (ast.Reference, []byte, error) {
if b[0] == ',' { if b[0] == ',' {
if first { if first {
return parent, nil, errArrayCannotStartWithComma return parent, nil, newDecodeError(b[0:1], "array cannot start with comma")
} }
b = b[1:] b = b[1:]
@@ -494,8 +492,6 @@ func (p *parser) parseMultilineLiteralString(b []byte) ([]byte, []byte, error) {
return token[i : len(token)-3], rest, err return token[i : len(token)-3], rest, err
} }
var errInvalidEscapeChar = errors.New("invalid escaped character")
//nolint:funlen,gocognit,cyclop //nolint:funlen,gocognit,cyclop
func (p *parser) parseMultilineBasicString(b []byte) ([]byte, []byte, error) { func (p *parser) parseMultilineBasicString(b []byte) ([]byte, []byte, error) {
// ml-basic-string = ml-basic-string-delim [ newline ] ml-basic-body // ml-basic-string = ml-basic-string-delim [ newline ] ml-basic-body
@@ -582,7 +578,7 @@ func (p *parser) parseMultilineBasicString(b []byte) ([]byte, []byte, error) {
builder.WriteString(x) builder.WriteString(x)
i += 8 i += 8
default: default:
return nil, nil, fmt.Errorf("parseMultilineBasicString: %w - %#U", errInvalidEscapeChar, c) return nil, nil, newDecodeError(token[i:i+1], "invalid escaped character %#U", c)
} }
} else { } else {
builder.WriteByte(c) builder.WriteByte(c)
@@ -721,7 +717,7 @@ func (p *parser) parseBasicString(b []byte) ([]byte, []byte, error) {
builder.WriteString(x) builder.WriteString(x)
i += 8 i += 8
default: default:
return nil, nil, fmt.Errorf("parseBasicString: %w - %#U", errInvalidEscapeChar, c) return nil, nil, newDecodeError(token[i:i+1], "invalid escaped character %#U", c)
} }
} else { } else {
builder.WriteByte(c) builder.WriteByte(c)
@@ -731,18 +727,17 @@ func (p *parser) parseBasicString(b []byte) ([]byte, []byte, error) {
return builder.Bytes(), rest, nil return builder.Bytes(), rest, nil
} }
var errUnicodePointNeedsRightCountChar = errors.New("unicode point needs right number of hex characters")
func hexToString(b []byte, length int) (string, error) { func hexToString(b []byte, length int) (string, error) {
if len(b) < length { if len(b) < length {
return "", fmt.Errorf("hexToString: %w - %d", errUnicodePointNeedsRightCountChar, length) return "", newDecodeError(b, "unicode point needs %d character, not %d", length, len(b))
} }
b = b[:length]
//nolint:godox //nolint:godox
// TODO: slow // TODO: slow
intcode, err := strconv.ParseInt(string(b[:length]), 16, 32) intcode, err := strconv.ParseInt(string(b), 16, 32)
if err != nil { if err != nil {
return "", fmt.Errorf("hexToString: %w", err) return "", newDecodeError(b, "couldn't parse hexadecimal number: %w", err)
} }
return string(rune(intcode)), nil return string(rune(intcode)), nil
@@ -757,17 +752,12 @@ func (p *parser) parseWhitespace(b []byte) []byte {
return rest return rest
} }
var (
errExpectedInf = errors.New("expected 'inf'")
errExpectedNan = errors.New("expected 'nan'")
)
//nolint:cyclop //nolint:cyclop
func (p *parser) parseIntOrFloatOrDateTime(b []byte) (ast.Reference, []byte, error) { func (p *parser) parseIntOrFloatOrDateTime(b []byte) (ast.Reference, []byte, error) {
switch b[0] { switch b[0] {
case 'i': case 'i':
if !scanFollowsInf(b) { if !scanFollowsInf(b) {
return ast.Reference{}, nil, errExpectedInf return ast.Reference{}, nil, newDecodeError(atmost(b, 3), "expected 'inf'")
} }
return p.builder.Push(ast.Node{ return p.builder.Push(ast.Node{
@@ -776,7 +766,7 @@ func (p *parser) parseIntOrFloatOrDateTime(b []byte) (ast.Reference, []byte, err
}), b[3:], nil }), b[3:], nil
case 'n': case 'n':
if !scanFollowsNan(b) { if !scanFollowsNan(b) {
return ast.Reference{}, nil, errExpectedNan return ast.Reference{}, nil, newDecodeError(atmost(b, 3), "expected 'nan'")
} }
return p.builder.Push(ast.Node{ return p.builder.Push(ast.Node{
@@ -821,8 +811,6 @@ func digitsToInt(b []byte) int {
return x return x
} }
var errTimezoneButNoTimeComponent = errors.New("possible DateTime cannot have a timezone but no time component")
//nolint:gocognit,cyclop //nolint:gocognit,cyclop
func (p *parser) scanDateTime(b []byte) (ast.Reference, []byte, error) { func (p *parser) scanDateTime(b []byte) (ast.Reference, []byte, error) {
// scans for contiguous characters in [0-9T:Z.+-], and up to one space if // scans for contiguous characters in [0-9T:Z.+-], and up to one space if
@@ -867,7 +855,7 @@ byteLoop:
} }
} else { } else {
if hasTz { if hasTz {
return ast.Reference{}, nil, errTimezoneButNoTimeComponent return ast.Reference{}, nil, newDecodeError(b, "date-time has timezone but not time component")
} }
kind = ast.LocalDate kind = ast.LocalDate
} }
@@ -878,12 +866,6 @@ byteLoop:
}), b[i:], nil }), b[i:], nil
} }
var (
errUnexpectedCharI = fmt.Errorf("unexpected character i while scanning for a number")
errUnexpectedCharN = fmt.Errorf("unexpected character n while scanning for a number")
errExpectedIntOrFloat = fmt.Errorf("expected integer or float")
)
//nolint:funlen,gocognit,cyclop //nolint:funlen,gocognit,cyclop
func (p *parser) scanIntOrFloat(b []byte) (ast.Reference, []byte, error) { func (p *parser) scanIntOrFloat(b []byte) (ast.Reference, []byte, error) {
i := 0 i := 0
@@ -940,7 +922,7 @@ func (p *parser) scanIntOrFloat(b []byte) (ast.Reference, []byte, error) {
}), b[i+3:], nil }), b[i+3:], nil
} }
return ast.Reference{}, nil, errUnexpectedCharI return ast.Reference{}, nil, newDecodeError(b[i:i+1], "unexpected character 'i' while scanning for a number")
} }
if c == 'n' { if c == 'n' {
@@ -951,14 +933,14 @@ func (p *parser) scanIntOrFloat(b []byte) (ast.Reference, []byte, error) {
}), b[i+3:], nil }), b[i+3:], nil
} }
return ast.Reference{}, nil, errUnexpectedCharN return ast.Reference{}, nil, newDecodeError(b[i:i+1], "unexpected character 'n' while scanning for a number")
} }
break break
} }
if i == 0 { if i == 0 {
return ast.Reference{}, b, errExpectedIntOrFloat return ast.Reference{}, b, newDecodeError(b, "incomplete number")
} }
kind := ast.Integer kind := ast.Integer
+6 -23
View File
@@ -1,9 +1,5 @@
package toml package toml
import (
"errors"
)
func scanFollows(b []byte, pattern string) bool { func scanFollows(b []byte, pattern string) bool {
n := len(pattern) n := len(pattern)
@@ -83,22 +79,17 @@ func scanMultilineLiteralString(b []byte) ([]byte, []byte, error) {
return nil, nil, newDecodeError(b[len(b):], `multiline literal string not terminated by '''`) return nil, nil, newDecodeError(b[len(b):], `multiline literal string not terminated by '''`)
} }
var (
errWindowsNewLineMissing = errors.New(`windows new line missing \n`)
errWindowsNewLineCRLF = errors.New(`windows new line should be \r\n`)
)
func scanWindowsNewline(b []byte) ([]byte, []byte, error) { func scanWindowsNewline(b []byte) ([]byte, []byte, error) {
const lenLF = 2 const lenCRLF = 2
if len(b) < lenLF { if len(b) < lenCRLF {
return nil, nil, errWindowsNewLineMissing return nil, nil, newDecodeError(b, "windows new line expected")
} }
if b[1] != '\n' { if b[1] != '\n' {
return nil, nil, errWindowsNewLineCRLF return nil, nil, newDecodeError(b, `windows new line should be \r\n`)
} }
return b[:lenLF], b[lenLF:], nil return b[:lenCRLF], b[lenCRLF:], nil
} }
func scanWhitespace(b []byte) ([]byte, []byte) { func scanWhitespace(b []byte) ([]byte, []byte) {
@@ -116,8 +107,6 @@ func scanWhitespace(b []byte) ([]byte, []byte) {
//nolint:unparam //nolint:unparam
func scanComment(b []byte) ([]byte, []byte) { func scanComment(b []byte) ([]byte, []byte) {
// ;; Comment
//
// comment-start-symbol = %x23 ; # // comment-start-symbol = %x23 ; #
// non-ascii = %x80-D7FF / %xE000-10FFFF // non-ascii = %x80-D7FF / %xE000-10FFFF
// non-eol = %x09 / %x20-7F / non-ascii // non-eol = %x09 / %x20-7F / non-ascii
@@ -132,10 +121,6 @@ func scanComment(b []byte) ([]byte, []byte) {
return b, nil return b, nil
} }
var errBasicLineNotTerminatedByQuote = errors.New(`basic string not terminated by "`)
//nolint:godox
// TODO perform validation on the string?
func scanBasicString(b []byte) ([]byte, []byte, error) { func scanBasicString(b []byte) ([]byte, []byte, error) {
// basic-string = quotation-mark *basic-char quotation-mark // basic-string = quotation-mark *basic-char quotation-mark
// quotation-mark = %x22 ; " // quotation-mark = %x22 ; "
@@ -156,11 +141,9 @@ func scanBasicString(b []byte) ([]byte, []byte, error) {
} }
} }
return nil, nil, errBasicLineNotTerminatedByQuote return nil, nil, newDecodeError(b[len(b):], `basic string not terminated by "`)
} }
//nolint:godox
// TODO perform validation on the string?
func scanMultilineBasicString(b []byte) ([]byte, []byte, error) { func scanMultilineBasicString(b []byte) ([]byte, []byte, error) {
// ml-basic-string = ml-basic-string-delim [ newline ] ml-basic-body // ml-basic-string = ml-basic-string-delim [ newline ] ml-basic-body
// ml-basic-string-delim // ml-basic-string-delim
+111 -358
View File
@@ -1,7 +1,6 @@
package toml package toml
import ( import (
"errors"
"fmt" "fmt"
"math" "math"
"reflect" "reflect"
@@ -14,19 +13,19 @@ type target interface {
get() reflect.Value get() reflect.Value
// Store a string at the target. // Store a string at the target.
setString(v string) error setString(v string)
// Store a boolean at the target // Store a boolean at the target
setBool(v bool) error setBool(v bool)
// Store an int64 at the target // Store an int64 at the target
setInt64(v int64) error setInt64(v int64)
// Store a float64 at the target // Store a float64 at the target
setFloat64(v float64) error setFloat64(v float64)
// Stores any value at the target // Stores any value at the target
set(v reflect.Value) error set(v reflect.Value)
} }
// valueTarget just contains a reflect.Value that can be set. // valueTarget just contains a reflect.Value that can be set.
@@ -37,34 +36,24 @@ func (t valueTarget) get() reflect.Value {
return reflect.Value(t) return reflect.Value(t)
} }
func (t valueTarget) set(v reflect.Value) error { func (t valueTarget) set(v reflect.Value) {
reflect.Value(t).Set(v) reflect.Value(t).Set(v)
return nil
} }
func (t valueTarget) setString(v string) error { func (t valueTarget) setString(v string) {
t.get().SetString(v) t.get().SetString(v)
return nil
} }
func (t valueTarget) setBool(v bool) error { func (t valueTarget) setBool(v bool) {
t.get().SetBool(v) t.get().SetBool(v)
return nil
} }
func (t valueTarget) setInt64(v int64) error { func (t valueTarget) setInt64(v int64) {
t.get().SetInt(v) t.get().SetInt(v)
return nil
} }
func (t valueTarget) setFloat64(v float64) error { func (t valueTarget) setFloat64(v float64) {
t.get().SetFloat(v) t.get().SetFloat(v)
return nil
} }
// interfaceTarget wraps an other target to dereference on get. // interfaceTarget wraps an other target to dereference on get.
@@ -76,49 +65,24 @@ func (t interfaceTarget) get() reflect.Value {
return t.x.get().Elem() return t.x.get().Elem()
} }
func (t interfaceTarget) set(v reflect.Value) error { func (t interfaceTarget) set(v reflect.Value) {
err := t.x.set(v) t.x.set(v)
if err != nil {
return fmt.Errorf("interfaceTarget set: %w", err)
}
return nil
} }
func (t interfaceTarget) setString(v string) error { func (t interfaceTarget) setString(v string) {
err := t.x.setString(v) t.x.setString(v)
if err != nil {
return fmt.Errorf("interfaceTarget setString: %w", err)
}
return nil
} }
func (t interfaceTarget) setBool(v bool) error { func (t interfaceTarget) setBool(v bool) {
err := t.x.setBool(v) t.x.setBool(v)
if err != nil {
return fmt.Errorf("interfaceTarget setBool: %w", err)
}
return nil
} }
func (t interfaceTarget) setInt64(v int64) error { func (t interfaceTarget) setInt64(v int64) {
err := t.x.setInt64(v) t.x.setInt64(v)
if err != nil {
return fmt.Errorf("interfaceTarget setInt64: %w", err)
}
return nil
} }
func (t interfaceTarget) setFloat64(v float64) error { func (t interfaceTarget) setFloat64(v float64) {
err := t.x.setFloat64(v) t.x.setFloat64(v)
if err != nil {
return fmt.Errorf("interfaceTarget setFloat64: %w", err)
}
return nil
} }
// mapTarget targets a specific key of a map. // mapTarget targets a specific key of a map.
@@ -131,33 +95,26 @@ func (t mapTarget) get() reflect.Value {
return t.v.MapIndex(t.k) return t.v.MapIndex(t.k)
} }
func (t mapTarget) set(v reflect.Value) error { func (t mapTarget) set(v reflect.Value) {
t.v.SetMapIndex(t.k, v) t.v.SetMapIndex(t.k, v)
return nil
} }
func (t mapTarget) setString(v string) error { func (t mapTarget) setString(v string) {
return t.set(reflect.ValueOf(v)) t.set(reflect.ValueOf(v))
} }
func (t mapTarget) setBool(v bool) error { func (t mapTarget) setBool(v bool) {
return t.set(reflect.ValueOf(v)) t.set(reflect.ValueOf(v))
} }
func (t mapTarget) setInt64(v int64) error { func (t mapTarget) setInt64(v int64) {
return t.set(reflect.ValueOf(v)) t.set(reflect.ValueOf(v))
} }
func (t mapTarget) setFloat64(v float64) error { func (t mapTarget) setFloat64(v float64) {
return t.set(reflect.ValueOf(v)) t.set(reflect.ValueOf(v))
} }
var (
errValIndexExpectingSlice = errors.New("expecting a slice")
errValIndexCannotInitSlice = errors.New("cannot initialize a slice")
)
//nolint:cyclop //nolint:cyclop
// makes sure that the value pointed at by t is indexable (Slice, Array), or // makes sure that the value pointed at by t is indexable (Slice, Array), or
// dereferences to an indexable (Ptr, Interface). // dereferences to an indexable (Ptr, Interface).
@@ -167,43 +124,20 @@ func ensureValueIndexable(t target) error {
switch f.Type().Kind() { switch f.Type().Kind() {
case reflect.Slice: case reflect.Slice:
if f.IsNil() { if f.IsNil() {
err := t.set(reflect.MakeSlice(f.Type(), 0, 0)) t.set(reflect.MakeSlice(f.Type(), 0, 0))
if err != nil {
return fmt.Errorf("ensureValueIndexable: %w", err)
}
return nil return nil
} }
case reflect.Interface: case reflect.Interface:
if f.IsNil() || f.Elem().Type() != sliceInterfaceType { if f.IsNil() || f.Elem().Type() != sliceInterfaceType {
err := t.set(reflect.MakeSlice(sliceInterfaceType, 0, 0)) t.set(reflect.MakeSlice(sliceInterfaceType, 0, 0))
if err != nil {
return fmt.Errorf("ensureValueIndexable: %w", err)
}
return nil return nil
} }
if f.Elem().Type().Kind() != reflect.Slice {
return fmt.Errorf("ensureValueIndexable: %w, not a %s", errValIndexExpectingSlice, f.Kind())
}
case reflect.Ptr: case reflect.Ptr:
if f.IsNil() { panic("pointer should have already been dereferenced")
ptr := reflect.New(f.Type().Elem())
err := t.set(ptr)
if err != nil {
return fmt.Errorf("ensureValueIndexable: %w", err)
}
f = t.get()
}
return ensureValueIndexable(valueTarget(f.Elem()))
case reflect.Array: case reflect.Array:
// arrays are always initialized. // arrays are always initialized.
default: default:
return fmt.Errorf("ensureValueIndexable: %w with %s", errValIndexCannotInitSlice, f.Kind()) return fmt.Errorf("toml: cannot store array in a %s", f.Kind())
} }
return nil return nil
@@ -214,69 +148,44 @@ var (
mapStringInterfaceType = reflect.TypeOf(map[string]interface{}{}) mapStringInterfaceType = reflect.TypeOf(map[string]interface{}{})
) )
func ensureMapIfInterface(x target) error { func ensureMapIfInterface(x target) {
v := x.get() v := x.get()
if v.Kind() == reflect.Interface && v.IsNil() { if v.Kind() == reflect.Interface && v.IsNil() {
newElement := reflect.MakeMap(mapStringInterfaceType) newElement := reflect.MakeMap(mapStringInterfaceType)
err := x.set(newElement) x.set(newElement)
if err != nil {
return fmt.Errorf("ensureMapIfInterface: %w", err)
}
} }
return nil
} }
var errSetStringCannotAssignString = errors.New("cannot assign string")
func setString(t target, v string) error { func setString(t target, v string) error {
f := t.get() f := t.get()
switch f.Kind() { switch f.Kind() {
case reflect.String: case reflect.String:
err := t.setString(v) t.setString(v)
if err != nil {
return fmt.Errorf("setString: %w", err)
}
return nil
case reflect.Interface: case reflect.Interface:
err := t.set(reflect.ValueOf(v)) t.set(reflect.ValueOf(v))
if err != nil {
return fmt.Errorf("setString: %w", err)
}
return nil
default: default:
return fmt.Errorf("setString: %w to a %s", errSetStringCannotAssignString, f.Kind()) return fmt.Errorf("toml: cannot assign string to a %s", f.Kind())
} }
}
var errSetBoolCannotAssignBool = errors.New("cannot assign bool") return nil
}
func setBool(t target, v bool) error { func setBool(t target, v bool) error {
f := t.get() f := t.get()
switch f.Kind() { switch f.Kind() {
case reflect.Bool: case reflect.Bool:
err := t.setBool(v) t.setBool(v)
if err != nil {
return fmt.Errorf("setBool: %w", err)
}
return nil
case reflect.Interface: case reflect.Interface:
err := t.set(reflect.ValueOf(v)) t.set(reflect.ValueOf(v))
if err != nil {
return fmt.Errorf("setBool: %w", err)
}
return nil
default: default:
return fmt.Errorf("setBool: %w to a %s", errSetBoolCannotAssignBool, f.String()) return fmt.Errorf("toml: cannot assign boolean to a %s", f.Kind())
} }
return nil
} }
const ( const (
@@ -284,207 +193,104 @@ const (
minInt = -maxInt - 1 minInt = -maxInt - 1
) )
var (
errSetInt64InInt32 = errors.New("does not fit in an int32")
errSetInt64InInt16 = errors.New("does not fit in an int16")
errSetInt64InInt8 = errors.New("does not fit in an int8")
errSetInt64InInt = errors.New("does not fit in an int")
errSetInt64InUint64 = errors.New("negative integer does not fit in an uint64")
errSetInt64InUint32 = errors.New("negative integer does not fit in an uint32")
errSetInt64InUint32Max = errors.New("integer does not fit in an uint32")
errSetInt64InUint16 = errors.New("negative integer does not fit in an uint16")
errSetInt64InUint16Max = errors.New("integer does not fit in an uint16")
errSetInt64InUint8 = errors.New("negative integer does not fit in an uint8")
errSetInt64InUint8Max = errors.New("integer does not fit in an uint8")
errSetInt64InUint = errors.New("negative integer does not fit in an uint")
errSetInt64Unknown = errors.New("does not fit in an uint")
)
//nolint:funlen,gocognit,cyclop,gocyclo //nolint:funlen,gocognit,cyclop,gocyclo
func setInt64(t target, v int64) error { func setInt64(t target, v int64) error {
f := t.get() f := t.get()
switch f.Kind() { switch f.Kind() {
case reflect.Int64: case reflect.Int64:
err := t.setInt64(v) t.setInt64(v)
if err != nil {
return fmt.Errorf("setInt64: %w", err)
}
return nil
case reflect.Int32: case reflect.Int32:
if v < math.MinInt32 || v > math.MaxInt32 { if v < math.MinInt32 || v > math.MaxInt32 {
return fmt.Errorf("setInt64: integer %d %w", v, errSetInt64InInt32) return fmt.Errorf("toml: number %d does not fit in an int32", v)
}
err := t.set(reflect.ValueOf(int32(v)))
if err != nil {
return fmt.Errorf("setInt64: %w", err)
} }
t.set(reflect.ValueOf(int32(v)))
return nil return nil
case reflect.Int16: case reflect.Int16:
if v < math.MinInt16 || v > math.MaxInt16 { if v < math.MinInt16 || v > math.MaxInt16 {
return fmt.Errorf("setInt64: integer %d %w", v, errSetInt64InInt16) return fmt.Errorf("toml: number %d does not fit in an int16", v)
} }
err := t.set(reflect.ValueOf(int16(v))) t.set(reflect.ValueOf(int16(v)))
if err != nil {
return fmt.Errorf("setInt64: %w", err)
}
return nil
case reflect.Int8: case reflect.Int8:
if v < math.MinInt8 || v > math.MaxInt8 { if v < math.MinInt8 || v > math.MaxInt8 {
return fmt.Errorf("setInt64: integer %d %w", v, errSetInt64InInt8) return fmt.Errorf("toml: number %d does not fit in an int8", v)
} }
err := t.set(reflect.ValueOf(int8(v))) t.set(reflect.ValueOf(int8(v)))
if err != nil {
return fmt.Errorf("setInt64: %w", err)
}
return nil
case reflect.Int: case reflect.Int:
if v < minInt || v > maxInt { if v < minInt || v > maxInt {
return fmt.Errorf("setInt64: integer %d %w", v, errSetInt64InInt) return fmt.Errorf("toml: number %d does not fit in an int", v)
} }
err := t.set(reflect.ValueOf(int(v))) t.set(reflect.ValueOf(int(v)))
if err != nil {
return fmt.Errorf("setInt64: %w", err)
}
return nil
case reflect.Uint64: case reflect.Uint64:
if v < 0 { if v < 0 {
return fmt.Errorf("setInt64: %d, %w", v, errSetInt64InUint64) return fmt.Errorf("toml: negative number %d does not fit in an uint64", v)
} }
err := t.set(reflect.ValueOf(uint64(v))) t.set(reflect.ValueOf(uint64(v)))
if err != nil {
return fmt.Errorf("setInt64: %w", err)
}
return nil
case reflect.Uint32: case reflect.Uint32:
if v < 0 { if v < 0 || v > math.MaxUint32 {
return fmt.Errorf("setInt64: %d, %w", v, errSetInt64InUint32) return fmt.Errorf("toml: negative number %d does not fit in an uint32", v)
} }
if v > math.MaxUint32 { t.set(reflect.ValueOf(uint32(v)))
return fmt.Errorf("setInt64: %d, %w", v, errSetInt64InUint32Max)
}
err := t.set(reflect.ValueOf(uint32(v)))
if err != nil {
return fmt.Errorf("setInt64: %w", err)
}
return nil
case reflect.Uint16: case reflect.Uint16:
if v < 0 { if v < 0 || v > math.MaxUint16 {
return fmt.Errorf("setInt64: %d, %w", v, errSetInt64InUint16) return fmt.Errorf("toml: negative number %d does not fit in an uint16", v)
} }
if v > math.MaxUint16 { t.set(reflect.ValueOf(uint16(v)))
return fmt.Errorf("setInt64: %d, %w", v, errSetInt64InUint16Max)
}
err := t.set(reflect.ValueOf(uint16(v)))
if err != nil {
return fmt.Errorf("setInt64: %w", err)
}
return nil
case reflect.Uint8: case reflect.Uint8:
if v < 0 { if v < 0 || v > math.MaxUint8 {
return fmt.Errorf("setInt64: %d, %w", v, errSetInt64InUint8) return fmt.Errorf("toml: negative number %d does not fit in an uint8", v)
} }
if v > math.MaxUint8 { t.set(reflect.ValueOf(uint8(v)))
return fmt.Errorf("setInt64: %d, %w", v, errSetInt64InUint8Max)
}
err := t.set(reflect.ValueOf(uint8(v)))
if err != nil {
return fmt.Errorf("setInt64: %w", err)
}
return nil
case reflect.Uint: case reflect.Uint:
if v < 0 { if v < 0 {
return fmt.Errorf("setInt64: %d, %w", v, errSetInt64InUint) return fmt.Errorf("toml: negative number %d does not fit in an uint", v)
} }
err := t.set(reflect.ValueOf(uint(v))) t.set(reflect.ValueOf(uint(v)))
if err != nil {
return fmt.Errorf("setInt64: %w", err)
}
return nil
case reflect.Interface: case reflect.Interface:
err := t.set(reflect.ValueOf(v)) t.set(reflect.ValueOf(v))
if err != nil {
return fmt.Errorf("setInt64: %w", err)
}
return nil
default: default:
return fmt.Errorf("setInt64: %s, %w", f.String(), errSetInt64Unknown) return fmt.Errorf("toml: integer cannot be assigned to %s", f.Kind())
} }
}
var ( return nil
errSetFloat64InFloat32Max = errors.New("does not fit in an float32") }
errSetFloat64Unknown = errors.New("does not fit in an float32")
)
func setFloat64(t target, v float64) error { func setFloat64(t target, v float64) error {
f := t.get() f := t.get()
switch f.Kind() { switch f.Kind() {
case reflect.Float64: case reflect.Float64:
err := t.setFloat64(v) t.setFloat64(v)
if err != nil {
return fmt.Errorf("setFloat64: %w", err)
}
return nil
case reflect.Float32: case reflect.Float32:
if v > math.MaxFloat32 { if v > math.MaxFloat32 {
return fmt.Errorf("setFloat64: %f %w", v, errSetFloat64InFloat32Max) return fmt.Errorf("toml: number %f does not fit in a float32", v)
} }
err := t.set(reflect.ValueOf(float32(v))) t.set(reflect.ValueOf(float32(v)))
if err != nil {
return fmt.Errorf("setFloat64: %w", err)
}
return nil
case reflect.Interface: case reflect.Interface:
err := t.set(reflect.ValueOf(v)) t.set(reflect.ValueOf(v))
if err != nil {
return fmt.Errorf("setFloat64: %w", err)
}
return nil
default: default:
return fmt.Errorf("setFloat64: %s %w", f.String(), errSetFloat64Unknown) return fmt.Errorf("toml: float cannot be assigned to %s", f.Kind())
} }
}
var ( return nil
errElementAtCannotOn = errors.New("cannot elementAt") }
errElementAtCannotOnUnknown = errors.New("cannot elementAt")
)
//nolint:cyclop //nolint:cyclop
// Returns the element at idx of the value pointed at by target, or an error if // Returns the element at idx of the value pointed at by target, or an error if
// t does not point to an indexable. // t does not point to an indexable.
// If the target points to an Array and idx is out of bounds, it returns // If the target points to an Array and idx is out of bounds, it returns
// (nil, nil) as this is not a fatal error (the unmarshaler will skip). // (nil, nil) as this is not a fatal error (the unmarshaler will skip).
func elementAt(t target, idx int) (target, error) { func elementAt(t target, idx int) target {
f := t.get() f := t.get()
switch f.Kind() { switch f.Kind() {
@@ -493,42 +299,30 @@ func elementAt(t target, idx int) (target, error) {
// TODO: use the idx function argument and avoid alloc if possible. // TODO: use the idx function argument and avoid alloc if possible.
idx := f.Len() idx := f.Len()
err := t.set(reflect.Append(f, reflect.New(f.Type().Elem()).Elem())) t.set(reflect.Append(f, reflect.New(f.Type().Elem()).Elem()))
if err != nil {
return nil, fmt.Errorf("elementAt: %w", err)
}
return valueTarget(t.get().Index(idx)), nil return valueTarget(t.get().Index(idx))
case reflect.Array: case reflect.Array:
if idx >= f.Len() { if idx >= f.Len() {
return nil, nil return nil
} }
return valueTarget(f.Index(idx)), nil return valueTarget(f.Index(idx))
case reflect.Interface: case reflect.Interface:
if f.IsNil() { // This function is called after ensureValueIndexable, so it's
panic("interface should have been initialized") // guaranteed that f contains an initialized slice.
}
ifaceElem := f.Elem() ifaceElem := f.Elem()
if ifaceElem.Kind() != reflect.Slice {
return nil, fmt.Errorf("elementAt: %w on a %s", errElementAtCannotOn, f.Kind())
}
idx := ifaceElem.Len() idx := ifaceElem.Len()
newElem := reflect.New(ifaceElem.Type().Elem()).Elem() newElem := reflect.New(ifaceElem.Type().Elem()).Elem()
newSlice := reflect.Append(ifaceElem, newElem) newSlice := reflect.Append(ifaceElem, newElem)
err := t.set(newSlice) t.set(newSlice)
if err != nil {
return nil, fmt.Errorf("elementAt: %w", err)
}
return valueTarget(t.get().Elem().Index(idx)), nil return valueTarget(t.get().Elem().Index(idx))
case reflect.Ptr:
return elementAt(valueTarget(f.Elem()), idx)
default: default:
return nil, fmt.Errorf("elementAt: %w on a %s", errElementAtCannotOnUnknown, f.Kind()) // Why ensureValueIndexable let it go through?
panic(fmt.Errorf("elementAt received unhandled value type: %s", f.Kind()))
} }
} }
@@ -539,31 +333,19 @@ func (d *decoder) scopeTableTarget(shouldAppend bool, t target, name string) (ta
switch x.Kind() { switch x.Kind() {
// Kinds that need to recurse // Kinds that need to recurse
case reflect.Interface: case reflect.Interface:
t, err := scopeInterface(shouldAppend, t) t := scopeInterface(shouldAppend, t)
if err != nil {
return t, false, fmt.Errorf("scopeTableTarget: %w", err)
}
return d.scopeTableTarget(shouldAppend, t, name) return d.scopeTableTarget(shouldAppend, t, name)
case reflect.Ptr: case reflect.Ptr:
t, err := scopePtr(t) t := scopePtr(t)
if err != nil {
return t, false, fmt.Errorf("scopeTableTarget: %w", err)
}
return d.scopeTableTarget(shouldAppend, t, name) return d.scopeTableTarget(shouldAppend, t, name)
case reflect.Slice: case reflect.Slice:
t, err := scopeSlice(shouldAppend, t) t := scopeSlice(shouldAppend, t)
if err != nil {
return t, false, fmt.Errorf("scopeTableTarget: %w", err)
}
shouldAppend = false shouldAppend = false
return d.scopeTableTarget(shouldAppend, t, name) return d.scopeTableTarget(shouldAppend, t, name)
case reflect.Array: case reflect.Array:
t, err := d.scopeArray(shouldAppend, t) t, err := d.scopeArray(shouldAppend, t)
if err != nil { if err != nil {
return t, false, fmt.Errorf("scopeTableTarget: %w", err) return t, false, err
} }
shouldAppend = false shouldAppend = false
@@ -574,11 +356,7 @@ func (d *decoder) scopeTableTarget(shouldAppend bool, t target, name string) (ta
return scopeStruct(x, name) return scopeStruct(x, name)
case reflect.Map: case reflect.Map:
if x.IsNil() { if x.IsNil() {
err := t.set(reflect.MakeMap(x.Type())) t.set(reflect.MakeMap(x.Type()))
if err != nil {
return t, false, fmt.Errorf("scopeTableTarget: %w", err)
}
x = t.get() x = t.get()
} }
@@ -588,42 +366,29 @@ func (d *decoder) scopeTableTarget(shouldAppend bool, t target, name string) (ta
} }
} }
func scopeInterface(shouldAppend bool, t target) (target, error) { func scopeInterface(shouldAppend bool, t target) target {
err := initInterface(shouldAppend, t) initInterface(shouldAppend, t)
if err != nil { return interfaceTarget{t}
return t, err
}
return interfaceTarget{t}, nil
} }
func scopePtr(t target) (target, error) { func scopePtr(t target) target {
err := initPtr(t) initPtr(t)
if err != nil { return valueTarget(t.get().Elem())
return t, err
}
return valueTarget(t.get().Elem()), nil
} }
func initPtr(t target) error { func initPtr(t target) {
x := t.get() x := t.get()
if !x.IsNil() { if !x.IsNil() {
return nil return
} }
err := t.set(reflect.New(x.Type().Elem())) t.set(reflect.New(x.Type().Elem()))
if err != nil {
return fmt.Errorf("initPtr: %w", err)
}
return nil
} }
// initInterface makes sure that the interface pointed at by the target is not // initInterface makes sure that the interface pointed at by the target is not
// nil. // nil.
// Returns the target to the initialized value of the target. // Returns the target to the initialized value of the target.
func initInterface(shouldAppend bool, t target) error { func initInterface(shouldAppend bool, t target) {
x := t.get() x := t.get()
if x.Kind() != reflect.Interface { if x.Kind() != reflect.Interface {
@@ -631,7 +396,7 @@ func initInterface(shouldAppend bool, t target) error {
} }
if !x.IsNil() && (x.Elem().Type() == sliceInterfaceType || x.Elem().Type() == mapStringInterfaceType) { if !x.IsNil() && (x.Elem().Type() == sliceInterfaceType || x.Elem().Type() == mapStringInterfaceType) {
return nil return
} }
var newElement reflect.Value var newElement reflect.Value
@@ -641,55 +406,43 @@ func initInterface(shouldAppend bool, t target) error {
newElement = reflect.MakeMap(mapStringInterfaceType) newElement = reflect.MakeMap(mapStringInterfaceType)
} }
err := t.set(newElement) t.set(newElement)
if err != nil {
return fmt.Errorf("initInterface: %w", err)
}
return nil
} }
func scopeSlice(shouldAppend bool, t target) (target, error) { func scopeSlice(shouldAppend bool, t target) target {
v := t.get() v := t.get()
if shouldAppend { if shouldAppend {
newElem := reflect.New(v.Type().Elem()) newElem := reflect.New(v.Type().Elem())
newSlice := reflect.Append(v, newElem.Elem()) newSlice := reflect.Append(v, newElem.Elem())
err := t.set(newSlice) t.set(newSlice)
if err != nil {
return t, fmt.Errorf("scopeSlice: %w", err)
}
v = t.get() v = t.get()
} }
return valueTarget(v.Index(v.Len() - 1)), nil return valueTarget(v.Index(v.Len() - 1))
} }
var errScopeArrayNotEnoughSpace = errors.New("not enough space in the array")
func (d *decoder) scopeArray(shouldAppend bool, t target) (target, error) { func (d *decoder) scopeArray(shouldAppend bool, t target) (target, error) {
v := t.get() v := t.get()
idx := d.arrayIndex(shouldAppend, v) idx := d.arrayIndex(shouldAppend, v)
if idx >= v.Len() { if idx >= v.Len() {
return nil, errScopeArrayNotEnoughSpace return nil, fmt.Errorf("toml: impossible to insert element beyond array's size: %d", v.Len())
} }
return valueTarget(v.Index(idx)), nil return valueTarget(v.Index(idx)), nil
} }
var errScopeMapCannotConvertStringToKey = errors.New("cannot convert string into map key type")
func scopeMap(v reflect.Value, name string) (target, bool, error) { func scopeMap(v reflect.Value, name string) (target, bool, error) {
k := reflect.ValueOf(name) k := reflect.ValueOf(name)
keyType := v.Type().Key() keyType := v.Type().Key()
if !k.Type().AssignableTo(keyType) { if !k.Type().AssignableTo(keyType) {
if !k.Type().ConvertibleTo(keyType) { if !k.Type().ConvertibleTo(keyType) {
return nil, false, fmt.Errorf("scopeMap: %w %s", errScopeMapCannotConvertStringToKey, keyType) return nil, false, fmt.Errorf("toml: cannot convert map key of type %s to expected type %s", k.Type(), keyType)
} }
k = k.Convert(keyType) k = k.Convert(keyType)
+6 -10
View File
@@ -128,14 +128,12 @@ func TestPushNew(t *testing.T) {
x, _, err := dec.scopeTableTarget(false, valueTarget(reflect.ValueOf(&d).Elem()), "A") x, _, err := dec.scopeTableTarget(false, valueTarget(reflect.ValueOf(&d).Elem()), "A")
require.NoError(t, err) require.NoError(t, err)
n, err := elementAt(x, 0) n := elementAt(x, 0)
require.NoError(t, err) n.setString("hello")
require.NoError(t, n.setString("hello"))
require.Equal(t, []string{"hello"}, d.A) require.Equal(t, []string{"hello"}, d.A)
n, err = elementAt(x, 1) n = elementAt(x, 1)
require.NoError(t, err) n.setString("world")
require.NoError(t, n.setString("world"))
require.Equal(t, []string{"hello", "world"}, d.A) require.Equal(t, []string{"hello", "world"}, d.A)
}) })
@@ -151,13 +149,11 @@ func TestPushNew(t *testing.T) {
x, _, err := dec.scopeTableTarget(false, valueTarget(reflect.ValueOf(&d).Elem()), "A") x, _, err := dec.scopeTableTarget(false, valueTarget(reflect.ValueOf(&d).Elem()), "A")
require.NoError(t, err) require.NoError(t, err)
n, err := elementAt(x, 0) n := elementAt(x, 0)
require.NoError(t, err)
require.NoError(t, setString(n, "hello")) require.NoError(t, setString(n, "hello"))
require.Equal(t, []interface{}{"hello"}, d.A) require.Equal(t, []interface{}{"hello"}, d.A)
n, err = elementAt(x, 1) n = elementAt(x, 1)
require.NoError(t, err)
require.NoError(t, setString(n, "world")) require.NoError(t, setString(n, "world"))
require.Equal(t, []interface{}{"hello", "world"}, d.A) require.Equal(t, []interface{}{"hello", "world"}, d.A)
}) })
+2 -10
View File
@@ -94,12 +94,7 @@ func testGenTranslateDesc(input interface{}) interface{} {
if ok { if ok {
dvalue, ok = d["value"] dvalue, ok = d["value"]
if ok { if ok {
var okdt bool dtype = dtypeiface.(string)
dtype, okdt = dtypeiface.(string)
if !okdt {
panic(fmt.Sprintf("dtypeiface should be valid string: %v", dtypeiface))
}
switch dtype { switch dtype {
case "string": case "string":
@@ -132,10 +127,7 @@ func testGenTranslateDesc(input interface{}) interface{} {
return nil return nil
} }
a, oka := dvalue.([]interface{}) a := dvalue.([]interface{})
if !oka {
panic(fmt.Sprintf("a should be valid []interface{}: %v", a))
}
xs := make([]interface{}, len(a)) xs := make([]interface{}, len(a))
+34 -67
View File
@@ -56,7 +56,7 @@ func (d *Decoder) SetStrict(strict bool) {
func (d *Decoder) Decode(v interface{}) error { func (d *Decoder) Decode(v interface{}) error {
b, err := ioutil.ReadAll(d.r) b, err := ioutil.ReadAll(d.r)
if err != nil { if err != nil {
return fmt.Errorf("Decode: %w", err) return fmt.Errorf("toml: %w", err)
} }
p := parser{} p := parser{}
@@ -130,20 +130,15 @@ func keyLocation(node ast.Node) []byte {
return unsafe.BytesRange(start, end) return unsafe.BytesRange(start, end)
} }
var (
errFromParserExpectingPointer = errors.New("expecting a pointer as target")
errFromParserExpectingNonNilPointer = errors.New("expecting non nil pointer as target")
)
//nolint:funlen,cyclop //nolint:funlen,cyclop
func (d *decoder) fromParser(p *parser, v interface{}) error { func (d *decoder) fromParser(p *parser, v interface{}) error {
r := reflect.ValueOf(v) r := reflect.ValueOf(v)
if r.Kind() != reflect.Ptr { if r.Kind() != reflect.Ptr {
return fmt.Errorf("fromParser: %w, not %s", errFromParserExpectingPointer, r.Kind()) return fmt.Errorf("toml: decoding can only be performed into a pointer, not %s", r.Kind())
} }
if r.IsNil() { if r.IsNil() {
return errFromParserExpectingNonNilPointer return fmt.Errorf("toml: decoding pointer target cannot be nil")
} }
var ( var (
@@ -162,7 +157,7 @@ func (d *decoder) fromParser(p *parser, v interface{}) error {
err := d.seen.CheckExpression(node) err := d.seen.CheckExpression(node)
if err != nil { if err != nil {
return fmt.Errorf("fromParser: %w", err) return err
} }
var found bool var found bool
@@ -181,16 +176,13 @@ func (d *decoder) fromParser(p *parser, v interface{}) error {
// looks like a table. Otherwise the information // looks like a table. Otherwise the information
// of a table is lost, and marshal cannot do the // of a table is lost, and marshal cannot do the
// round trip. // round trip.
err := ensureMapIfInterface(current) ensureMapIfInterface(current)
if err != nil {
panic(fmt.Sprintf("ensureMapIfInterface: %s", err))
}
} }
case ast.ArrayTable: case ast.ArrayTable:
d.strict.EnterArrayTable(node) d.strict.EnterArrayTable(node)
current, found, err = d.scopeWithArrayTable(root, node.Key()) current, found, err = d.scopeWithArrayTable(root, node.Key())
default: default:
panic(fmt.Sprintf("fromParser: this should not be a top level node type: %s", node.Kind)) panic(fmt.Sprintf("this should not be a top level node type: %s", node.Kind))
} }
if err != nil { if err != nil {
@@ -267,26 +259,18 @@ func (d *decoder) scopeWithArrayTable(x target, key ast.Iterator) (target, bool,
v := x.get() v := x.get()
if v.Kind() == reflect.Ptr { if v.Kind() == reflect.Ptr {
x, err = scopePtr(x) x = scopePtr(x)
if err != nil {
return x, false, err
}
v = x.get() v = x.get()
} }
if v.Kind() == reflect.Interface { if v.Kind() == reflect.Interface {
x, err = scopeInterface(true, x) x = scopeInterface(true, x)
if err != nil {
return x, found, err
}
v = x.get() v = x.get()
} }
switch v.Kind() { switch v.Kind() {
case reflect.Slice: case reflect.Slice:
x, err = scopeSlice(true, x) x = scopeSlice(true, x)
case reflect.Array: case reflect.Array:
x, err = d.scopeArray(true, x) x, err = d.scopeArray(true, x)
default: default:
@@ -334,7 +318,7 @@ func tryTextUnmarshaler(x target, node ast.Node) (bool, error) {
if v.Type().Implements(textUnmarshalerType) { if v.Type().Implements(textUnmarshalerType) {
err := v.Interface().(encoding.TextUnmarshaler).UnmarshalText(node.Data) err := v.Interface().(encoding.TextUnmarshaler).UnmarshalText(node.Data)
if err != nil { if err != nil {
return false, fmt.Errorf("tryTextUnmarshaler: %w", err) return false, newDecodeError(node.Data, "error calling UnmarshalText: %w", err)
} }
return true, nil return true, nil
@@ -343,7 +327,7 @@ func tryTextUnmarshaler(x target, node ast.Node) (bool, error) {
if v.CanAddr() && v.Addr().Type().Implements(textUnmarshalerType) { if v.CanAddr() && v.Addr().Type().Implements(textUnmarshalerType) {
err := v.Addr().Interface().(encoding.TextUnmarshaler).UnmarshalText(node.Data) err := v.Addr().Interface().(encoding.TextUnmarshaler).UnmarshalText(node.Data)
if err != nil { if err != nil {
return false, fmt.Errorf("tryTextUnmarshaler: %w", err) return false, newDecodeError(node.Data, "error calling UnmarshalText: %w", err)
} }
return true, nil return true, nil
@@ -358,11 +342,7 @@ func (d *decoder) unmarshalValue(x target, node ast.Node) error {
if v.Kind() == reflect.Ptr { if v.Kind() == reflect.Ptr {
if !v.Elem().IsValid() { if !v.Elem().IsValid() {
err := x.set(reflect.New(v.Type().Elem())) x.set(reflect.New(v.Type().Elem()))
if err != nil {
return fmt.Errorf("unmarshalValue: %w", err)
}
v = x.get() v = x.get()
} }
@@ -394,7 +374,7 @@ func (d *decoder) unmarshalValue(x target, node ast.Node) error {
case ast.LocalDate: case ast.LocalDate:
return unmarshalLocalDate(x, node) return unmarshalLocalDate(x, node)
default: default:
panic(fmt.Sprintf("unmarshalValue: unhandled unmarshalValue kind %s", node.Kind)) panic(fmt.Sprintf("unhandled node kind %s", node.Kind))
} }
} }
@@ -406,7 +386,9 @@ func unmarshalLocalDate(x target, node ast.Node) error {
return err return err
} }
return setDate(x, v) setDate(x, v)
return nil
} }
func unmarshalLocalDateTime(x target, node ast.Node) error { func unmarshalLocalDateTime(x target, node ast.Node) error {
@@ -421,7 +403,9 @@ func unmarshalLocalDateTime(x target, node ast.Node) error {
return newDecodeError(rest, "extra characters at the end of a local date time") return newDecodeError(rest, "extra characters at the end of a local date time")
} }
return setLocalDateTime(x, v) setLocalDateTime(x, v)
return nil
} }
func unmarshalDateTime(x target, node ast.Node) error { func unmarshalDateTime(x target, node ast.Node) error {
@@ -432,48 +416,37 @@ func unmarshalDateTime(x target, node ast.Node) error {
return err return err
} }
return setDateTime(x, v) setDateTime(x, v)
return nil
} }
func setLocalDateTime(x target, v LocalDateTime) error { func setLocalDateTime(x target, v LocalDateTime) {
if x.get().Type() == timeType { if x.get().Type() == timeType {
cast := v.In(time.Local) cast := v.In(time.Local)
return setDateTime(x, cast) setDateTime(x, cast)
return
} }
err := x.set(reflect.ValueOf(v)) x.set(reflect.ValueOf(v))
if err != nil {
return fmt.Errorf("setLocalDateTime: %w", err)
}
return nil
} }
func setDateTime(x target, v time.Time) error { func setDateTime(x target, v time.Time) {
err := x.set(reflect.ValueOf(v)) x.set(reflect.ValueOf(v))
if err != nil {
return fmt.Errorf("setDateTime: %w", err)
}
return nil
} }
var timeType = reflect.TypeOf(time.Time{}) var timeType = reflect.TypeOf(time.Time{})
func setDate(x target, v LocalDate) error { func setDate(x target, v LocalDate) {
if x.get().Type() == timeType { if x.get().Type() == timeType {
cast := v.In(time.Local) cast := v.In(time.Local)
return setDateTime(x, cast) setDateTime(x, cast)
return
} }
err := x.set(reflect.ValueOf(v)) x.set(reflect.ValueOf(v))
if err != nil {
return fmt.Errorf("setDate: %w", err)
}
return nil
} }
func unmarshalString(x target, node ast.Node) error { func unmarshalString(x target, node ast.Node) error {
@@ -514,10 +487,7 @@ func unmarshalFloat(x target, node ast.Node) error {
func (d *decoder) unmarshalInlineTable(x target, node ast.Node) error { func (d *decoder) unmarshalInlineTable(x target, node ast.Node) error {
assertNode(ast.InlineTable, node) assertNode(ast.InlineTable, node)
err := ensureMapIfInterface(x) ensureMapIfInterface(x)
if err != nil {
return fmt.Errorf("unmarshalInlineTable: %w", err)
}
it := node.Children() it := node.Children()
for it.Next() { for it.Next() {
@@ -546,10 +516,7 @@ func (d *decoder) unmarshalArray(x target, node ast.Node) error {
for it.Next() { for it.Next() {
n := it.Node() n := it.Node()
v, err := elementAt(x, idx) v := elementAt(x, idx)
if err != nil {
return err
}
if v == nil { if v == nil {
// when we go out of bound for an array just stop processing it to // when we go out of bound for an array just stop processing it to
+57 -3
View File
@@ -38,6 +38,11 @@ func TestUnmarshal_Integers(t *testing.T) {
input: `+99`, input: `+99`,
expected: 99, expected: 99,
}, },
{
desc: "integer decimal underscore",
input: `123_456`,
expected: 123456,
},
{ {
desc: "integer hex uppercase", desc: "integer hex uppercase",
input: `0xDEADBEEF`, input: `0xDEADBEEF`,
@@ -58,6 +63,21 @@ func TestUnmarshal_Integers(t *testing.T) {
input: `0b11010110`, input: `0b11010110`,
expected: 0b11010110, expected: 0b11010110,
}, },
{
desc: "double underscore",
input: "12__3",
err: true,
},
{
desc: "starts with underscore",
input: "_1",
err: true,
},
{
desc: "ends with underscore",
input: "1_",
err: true,
},
} }
type doc struct { type doc struct {
@@ -71,8 +91,12 @@ func TestUnmarshal_Integers(t *testing.T) {
doc := doc{} doc := doc{}
err := toml.Unmarshal([]byte(`A = `+e.input), &doc) err := toml.Unmarshal([]byte(`A = `+e.input), &doc)
require.NoError(t, err) if e.err {
assert.Equal(t, e.expected, doc.A) require.Error(t, err)
} else {
require.NoError(t, err)
assert.Equal(t, e.expected, doc.A)
}
}) })
} }
} }
@@ -799,6 +823,33 @@ B = "data"`,
} }
}, },
}, },
{
desc: "mismatch types int to string",
input: `A = 42`,
gen: func() test {
type S struct {
A string
}
return test{
target: &S{},
err: true,
}
},
},
{
desc: "mismatch types array of int to interface with non-slice",
input: `A = [[42]]`,
skip: true,
gen: func() test {
type S struct {
A *string
}
return test{
target: &S{},
expected: &S{},
}
},
},
} }
for _, e := range examples { for _, e := range examples {
@@ -815,6 +866,9 @@ B = "data"`,
} }
err := toml.Unmarshal([]byte(e.input), test.target) err := toml.Unmarshal([]byte(e.input), test.target)
if test.err { if test.err {
if err == nil {
t.Log("=>", test.target)
}
require.Error(t, err) require.Error(t, err)
} else { } else {
require.NoError(t, err) require.NoError(t, err)
@@ -1030,7 +1084,7 @@ world'`,
if e.msg != "" { if e.msg != "" {
t.Log("\n" + de.String()) t.Log("\n" + de.String())
require.Equal(t, e.msg, de.Error()) require.Equal(t, "toml: "+e.msg, de.Error())
} }
}) })
} }