diff --git a/decode.go b/decode.go index 87330a7..afa6db3 100644 --- a/decode.go +++ b/decode.go @@ -1,11 +1,8 @@ package toml import ( - "errors" - "fmt" "math" "strconv" - "strings" "time" ) @@ -59,14 +56,12 @@ func parseLocalDate(b []byte) (LocalDate, error) { return date, nil } -var errNotDigit = errors.New("not a digit") - func parseDecimalDigits(b []byte) (int, error) { v := 0 - for _, c := range b { + for i, c := range b { if !isDigit(c) { - return 0, fmt.Errorf("%s: %w", b, errNotDigit) + return 0, newDecodeError(b[i:i+1], "should be a digit (0-9)") } v *= 10 @@ -76,13 +71,14 @@ func parseDecimalDigits(b []byte) (int, error) { return v, nil } -var errParseDateTimeMissingInfo = errors.New("date-time missing timezone information") - func parseDateTime(b []byte) (time.Time, error) { // offset-date-time = full-date time-delim full-time // full-time = partial-time time-offset // time-offset = "Z" / time-numoffset // time-numoffset = ( "+" / "-" ) time-hour ":" time-minute + + originalBytes := b + dt, b, err := parseLocalDateTime(b) if err != nil { return time.Time{}, err @@ -91,7 +87,7 @@ func parseDateTime(b []byte) (time.Time, error) { var zone *time.Location if len(b) == 0 { - return time.Time{}, errParseDateTimeMissingInfo + return time.Time{}, newDecodeError(originalBytes, "date-time is missing timezone") } if b[0] == 'Z' { @@ -134,19 +130,12 @@ func parseDateTime(b []byte) (time.Time, error) { return t, nil } -var ( - errParseLocalDateTimeWrongLength = errors.New( - "local datetimes are expected to have the format YYYY-MM-DDTHH:MM:SS[.NNNNNNNNN]", - ) - errParseLocalDateTimeWrongSeparator = errors.New("datetime separator is expected to be T or a space") -) - func parseLocalDateTime(b []byte) (LocalDateTime, []byte, error) { var dt LocalDateTime const localDateTimeByteMinLen = 11 if len(b) < localDateTimeByteMinLen { - return dt, nil, errParseLocalDateTimeWrongLength + return dt, nil, newDecodeError(b, "local datetimes are expected to have the format YYYY-MM-DDTHH:MM:SS[.NNNNNNNNN]") } date, err := parseLocalDate(b[:10]) @@ -157,7 +146,7 @@ func parseLocalDateTime(b []byte) (LocalDateTime, []byte, error) { sep := b[10] if sep != 'T' && sep != ' ' { - return dt, nil, errParseLocalDateTimeWrongSeparator + return dt, nil, newDecodeError(b[10:11], "datetime separator is expected to be T or a space") } t, rest, err := parseLocalTime(b[11:]) @@ -169,8 +158,6 @@ func parseLocalDateTime(b []byte) (LocalDateTime, []byte, error) { return dt, rest, nil } -var errParseLocalTimeWrongLength = errors.New("times are expected to have the format HH:MM:SS[.NNNNNN]") - // parseLocalTime is a bit different because it also returns the remaining // []byte that is didn't need. This is to allow parseDateTime to parse those // remaining bytes as a timezone. @@ -183,7 +170,7 @@ func parseLocalTime(b []byte) (LocalTime, []byte, error) { const localTimeByteLen = 8 if len(b) < localTimeByteLen { - return t, nil, errParseLocalTimeWrongLength + return t, nil, newDecodeError(b, "times are expected to have the format HH:MM:SS[.NNNNNN]") } var err error @@ -242,11 +229,6 @@ func parseLocalTime(b []byte) (LocalTime, []byte, error) { return t, b[8:], nil } -var ( - errParseFloatStartDot = errors.New("float cannot start with a dot") - errParseFloatEndDot = errors.New("float cannot end with a dot") -) - //nolint:cyclop func parseFloat(b []byte) (float64, error) { //nolint:godox @@ -255,150 +237,123 @@ func parseFloat(b []byte) (float64, error) { return math.NaN(), nil } - tok := string(b) - - err := numberContainsInvalidUnderscore(tok) + cleaned, err := checkAndRemoveUnderscores(b) if err != nil { return 0, err } - cleanedVal := cleanupNumberToken(tok) - if cleanedVal[0] == '.' { - return 0, errParseFloatStartDot + if cleaned[0] == '.' { + return 0, newDecodeError(b, "float cannot start with a dot") } - if cleanedVal[len(cleanedVal)-1] == '.' { - return 0, errParseFloatEndDot + if cleaned[len(cleaned)-1] == '.' { + return 0, newDecodeError(b, "float cannot end with a dot") } - f, err := strconv.ParseFloat(cleanedVal, 64) + f, err := strconv.ParseFloat(string(cleaned), 64) if err != nil { - return 0, fmt.Errorf("coudn't ParseFloat %w", err) + return 0, newDecodeError(b, "coudn't parse float: %w", err) } return f, nil } func parseIntHex(b []byte) (int64, error) { - tok := string(b) - cleanedVal := cleanupNumberToken(tok) - - err := hexNumberContainsInvalidUnderscore(cleanedVal) + cleaned, err := checkAndRemoveUnderscores(b[2:]) if err != nil { return 0, err } - i, err := strconv.ParseInt(cleanedVal[2:], 16, 64) + i, err := strconv.ParseInt(string(cleaned), 16, 64) if err != nil { - return 0, fmt.Errorf("coudn't ParseIntHex %w", err) + return 0, newDecodeError(b, "couldn't parse hexadecimal number: %w", err) } return i, nil } func parseIntOct(b []byte) (int64, error) { - tok := string(b) - cleanedVal := cleanupNumberToken(tok) - - err := numberContainsInvalidUnderscore(cleanedVal) + cleaned, err := checkAndRemoveUnderscores(b[2:]) if err != nil { return 0, err } - i, err := strconv.ParseInt(cleanedVal[2:], 8, 64) + i, err := strconv.ParseInt(string(cleaned), 8, 64) if err != nil { - return 0, fmt.Errorf("coudn't ParseIntOct %w", err) + return 0, newDecodeError(b, "couldn't parse octal number: %w", err) } return i, nil } func parseIntBin(b []byte) (int64, error) { - tok := string(b) - cleanedVal := cleanupNumberToken(tok) - - err := numberContainsInvalidUnderscore(cleanedVal) + cleaned, err := checkAndRemoveUnderscores(b[2:]) if err != nil { return 0, err } - i, err := strconv.ParseInt(cleanedVal[2:], 2, 64) + i, err := strconv.ParseInt(string(cleaned), 2, 64) if err != nil { - return 0, fmt.Errorf("coudn't ParseIntBin %w", err) + return 0, newDecodeError(b, "couldn't parse binary number: %w", err) } return i, nil } func parseIntDec(b []byte) (int64, error) { - tok := string(b) - cleanedVal := cleanupNumberToken(tok) - - err := numberContainsInvalidUnderscore(cleanedVal) + cleaned, err := checkAndRemoveUnderscores(b) if err != nil { return 0, err } - i, err := strconv.ParseInt(cleanedVal, 10, 64) + i, err := strconv.ParseInt(string(cleaned), 10, 64) if err != nil { - return 0, fmt.Errorf("coudn't parseIntDec %w", err) + return 0, newDecodeError(b, "couldn't parse decimal number: %w", err) } return i, nil } -func numberContainsInvalidUnderscore(value string) error { - // For large numbers, you may use underscores between digits to enhance - // readability. Each underscore must be surrounded by at least one digit on - // each side. - hasBefore := false - - for idx, r := range value { - if r == '_' { - if !hasBefore || idx+1 >= len(value) { - // can't end with an underscore - return errInvalidUnderscore - } - } - hasBefore = isDigitRune(r) +func checkAndRemoveUnderscores(b []byte) ([]byte, error) { + if len(b) == 0 { + return b, nil } - return nil -} - -func hexNumberContainsInvalidUnderscore(value string) error { - hasBefore := false - - for idx, r := range value { - if r == '_' { - if !hasBefore || idx+1 >= len(value) { - // can't end with an underscore - return errInvalidUnderscoreHex - } - } - hasBefore = isHexDigit(r) + if b[0] == '_' { + return nil, newDecodeError(b[0:1], "number cannot start with underscore") } - return nil + if b[len(b)-1] == '_' { + return nil, newDecodeError(b[len(b)-1:], "number cannot end with underscore") + } + + // fast path + i := 0 + for ; i < len(b); i++ { + if b[i] == '_' { + break + } + } + if i == len(b) { + return b, nil + } + + before := false + cleaned := make([]byte, i, len(b)) + copy(cleaned, b) + + for i++; i < len(b); i++ { + c := b[i] + if c == '_' { + if !before { + return nil, newDecodeError(b[i-1:i+1], "number must have at least one digit between underscores") + } + before = false + } else { + before = true + cleaned = append(cleaned, c) + } + } + + return cleaned, nil } - -func cleanupNumberToken(value string) string { - cleanedVal := strings.ReplaceAll(value, "_", "") - - return cleanedVal -} - -func isHexDigit(r rune) bool { - return isDigitRune(r) || - (r >= 'a' && r <= 'f') || - (r >= 'A' && r <= 'F') -} - -func isDigitRune(r rune) bool { - return r >= '0' && r <= '9' -} - -var ( - errInvalidUnderscore = errors.New("invalid use of _ in number") - errInvalidUnderscoreHex = errors.New("invalid use of _ in hex number") -) diff --git a/errors.go b/errors.go index 7773796..82c21ef 100644 --- a/errors.go +++ b/errors.go @@ -70,13 +70,13 @@ func (de *decodeError) Error() string { func newDecodeError(highlight []byte, format string, args ...interface{}) error { return &decodeError{ highlight: highlight, - message: fmt.Sprintf(format, args...), + message: fmt.Errorf(format, args...).Error(), } } // Error returns the error message contained in the DecodeError. func (e *DecodeError) Error() string { - return e.message + return "toml: " + e.message } // String returns the human-readable contextualized error. This string is multi-line. diff --git a/internal/tracker/seen.go b/internal/tracker/seen.go index f80052d..e245aac 100644 --- a/internal/tracker/seen.go +++ b/internal/tracker/seen.go @@ -123,10 +123,10 @@ func (s *SeenTracker) checkTable(node ast.Node) error { i, found := s.current.Has(k) if found { if i.kind != tableKind { - return fmt.Errorf("key %s should be a table", k) + return fmt.Errorf("toml: key %s should be a table, not a %s", k, i.kind) } if i.explicit { - return fmt.Errorf("table %s already exists", k) + return fmt.Errorf("toml: table %s already exists", k) } i.explicit = true s.current = i @@ -162,7 +162,7 @@ func (s *SeenTracker) checkArrayTable(node ast.Node) error { info, found := s.current.Has(k) if found { if info.kind != arrayTableKind { - return fmt.Errorf("key %s already exists but is not an array table", k) + return fmt.Errorf("toml: key %s already exists as a %s, but should be an array table", info.kind, k) } info.Clear() } else { @@ -182,7 +182,7 @@ func (s *SeenTracker) checkKeyValue(context *info, node ast.Node) error { child, found := context.Has(k) if found { if child.kind != tableKind { - return fmt.Errorf("expected %s to be a table, not a %s", k, child.kind) + return fmt.Errorf("toml: expected %s to be a table, not a %s", k, child.kind) } } else { child = context.CreateTable(k, false) diff --git a/localtime.go b/localtime.go index f271bef..a947044 100644 --- a/localtime.go +++ b/localtime.go @@ -53,7 +53,7 @@ func LocalDateOf(t time.Time) LocalDate { func ParseLocalDate(s string) (LocalDate, error) { t, err := time.Parse("2006-01-02", s) if err != nil { - return LocalDate{}, fmt.Errorf("ParseLocalDate: %w", err) + return LocalDate{}, err } return LocalDateOf(t), nil @@ -166,7 +166,7 @@ func LocalTimeOf(t time.Time) LocalTime { func ParseLocalTime(s string) (LocalTime, error) { t, err := time.Parse("15:04:05.999999999", s) if err != nil { - return LocalTime{}, fmt.Errorf("ParseLocalTime: %w", err) + return LocalTime{}, err } return LocalTimeOf(t), nil @@ -237,7 +237,7 @@ func ParseLocalDateTime(s string) (LocalDateTime, error) { if err != nil { t, err = time.Parse("2006-01-02t15:04:05.999999999", s) if err != nil { - return LocalDateTime{}, fmt.Errorf("ParseLocalDateTime: %w", err) + return LocalDateTime{}, err } } diff --git a/marshaler.go b/marshaler.go index 172c033..aa8783a 100644 --- a/marshaler.go +++ b/marshaler.go @@ -3,7 +3,6 @@ package toml import ( "bytes" "encoding" - "errors" "fmt" "io" "reflect" @@ -116,12 +115,12 @@ func (enc *Encoder) Encode(v interface{}) error { b, err := enc.encode(b, ctx, reflect.ValueOf(v)) if err != nil { - return fmt.Errorf("Encode: %w", err) + return err } _, err = enc.w.Write(b) if err != nil { - return fmt.Errorf("Encode: %w", err) + return fmt.Errorf("toml: cannot write: %w", err) } return nil @@ -178,11 +177,6 @@ func (ctx *encoderCtx) isRoot() bool { return len(ctx.parentKey) == 0 && !ctx.hasKey } -var ( - errUnsupportedValue = errors.New("unsupported encode value kind") - errTextMarshalerCannotBeAtRoot = errors.New("type implementing TextMarshaler cannot be at root") -) - //nolint:cyclop,funlen func (enc *Encoder) encode(b []byte, ctx encoderCtx, v reflect.Value) ([]byte, error) { i, ok := v.Interface().(time.Time) @@ -192,12 +186,12 @@ func (enc *Encoder) encode(b []byte, ctx encoderCtx, v reflect.Value) ([]byte, e if v.Type().Implements(textMarshalerType) { if ctx.isRoot() { - return nil, errTextMarshalerCannotBeAtRoot + return nil, fmt.Errorf("toml: type %s implementing the TextMarshaler interface cannot be a root element", v.Type()) } text, err := v.Interface().(encoding.TextMarshaler).MarshalText() if err != nil { - return nil, fmt.Errorf("encode: %w", err) + return nil, err } b = enc.encodeString(b, string(text), ctx.options) @@ -215,7 +209,7 @@ func (enc *Encoder) encode(b []byte, ctx encoderCtx, v reflect.Value) ([]byte, e return enc.encodeSlice(b, ctx, v) case reflect.Interface: if v.IsNil() { - return nil, errNilInterface + return nil, fmt.Errorf("toml: encoding a nil interface is not supported") } return enc.encode(b, ctx, v.Elem()) @@ -244,7 +238,7 @@ func (enc *Encoder) encode(b []byte, ctx encoderCtx, v reflect.Value) ([]byte, e case reflect.Int64, reflect.Int32, reflect.Int16, reflect.Int8, reflect.Int: b = strconv.AppendInt(b, v.Int(), 10) default: - return nil, fmt.Errorf("encode(type %s): %w", v.Kind(), errUnsupportedValue) + return nil, fmt.Errorf("toml: cannot encode value of type %s", v.Kind()) } return b, nil @@ -412,8 +406,6 @@ func (enc *Encoder) encodeTableHeader(ctx encoderCtx, b []byte) ([]byte, error) return b, nil } -var errTomlNoMultiline = errors.New("TOML does not support multiline keys") - //nolint:cyclop func (enc *Encoder) encodeKey(b []byte, k string) ([]byte, error) { needsQuotation := false @@ -425,7 +417,7 @@ func (enc *Encoder) encodeKey(b []byte, k string) ([]byte, error) { } if c == '\n' { - return nil, errTomlNoMultiline + return nil, fmt.Errorf("toml: new line characters in keys are not supported") } if c == literalQuote { @@ -445,11 +437,9 @@ func (enc *Encoder) encodeKey(b []byte, k string) ([]byte, error) { } } -var errNotSupportedAsMapKey = errors.New("type not supported as map key") - func (enc *Encoder) encodeMap(b []byte, ctx encoderCtx, v reflect.Value) ([]byte, error) { if v.Type().Key().Kind() != reflect.String { - return nil, fmt.Errorf("encodeMap '%s': %w", v.Type().Key().Kind(), errNotSupportedAsMapKey) + return nil, fmt.Errorf("toml: type %s is not supported as a map key", v.Type().Key().Kind()) } var ( @@ -658,10 +648,7 @@ func (enc *Encoder) encodeTableInline(b []byte, ctx encoderCtx, t table) ([]byte return b, nil } -var ( - errNilInterface = errors.New("nil interface not supported") - textMarshalerType = reflect.TypeOf(new(encoding.TextMarshaler)).Elem() -) +var textMarshalerType = reflect.TypeOf(new(encoding.TextMarshaler)).Elem() func willConvertToTable(ctx encoderCtx, v reflect.Value) (bool, error) { if v.Type() == timeType || v.Type().Implements(textMarshalerType) { @@ -674,7 +661,7 @@ func willConvertToTable(ctx encoderCtx, v reflect.Value) (bool, error) { return !ctx.inline, nil case reflect.Interface: if v.IsNil() { - return false, errNilInterface + return false, fmt.Errorf("toml: encoding a nil interface is not supported") } return willConvertToTable(ctx, v.Elem()) @@ -694,7 +681,7 @@ func willConvertToTableOrArrayTable(ctx encoderCtx, v reflect.Value) (bool, erro if t.Kind() == reflect.Interface { if v.IsNil() { - return false, errNilInterface + return false, fmt.Errorf("toml: encoding a nil interface is not supported") } return willConvertToTableOrArrayTable(ctx, v.Elem()) diff --git a/parser.go b/parser.go index 9724190..5b6e3ba 100644 --- a/parser.go +++ b/parser.go @@ -2,7 +2,6 @@ package toml import ( "bytes" - "errors" "fmt" "strconv" @@ -225,19 +224,13 @@ func (p *parser) parseKeyval(b []byte) (ast.Reference, []byte, error) { return ref, b, err } -var ( - errExpectedValNotEOF = errors.New("expected value, not eof") - errExpectedTrue = errors.New("expected 'true'") - errExpectedFalse = errors.New("expected 'false'") -) - //nolint:cyclop,funlen func (p *parser) parseVal(b []byte) (ast.Reference, []byte, error) { // val = string / boolean / array / inline-table / date-time / float / integer var ref ast.Reference if len(b) == 0 { - return ref, nil, errExpectedValNotEOF + return ref, nil, newDecodeError(b, "expected value, not eof") } var err error @@ -278,7 +271,7 @@ func (p *parser) parseVal(b []byte) (ast.Reference, []byte, error) { return ref, b, err case 't': if !scanFollowsTrue(b) { - return ref, nil, errExpectedTrue + return ref, nil, newDecodeError(atmost(b, 4), "expected 'true'") } ref = p.builder.Push(ast.Node{ @@ -289,7 +282,7 @@ func (p *parser) parseVal(b []byte) (ast.Reference, []byte, error) { return ref, b[4:], nil case 'f': if !scanFollowsFalse(b) { - return ast.Reference{}, nil, errExpectedFalse + return ref, nil, newDecodeError(atmost(b, 5), "expected 'false'") } ref = p.builder.Push(ast.Node{ @@ -307,6 +300,13 @@ func (p *parser) parseVal(b []byte) (ast.Reference, []byte, error) { } } +func atmost(b []byte, n int) []byte { + if n >= len(b) { + return b + } + return b[:n] +} + func (p *parser) parseLiteralString(b []byte) ([]byte, []byte, error) { v, rest, err := scanLiteralString(b) if err != nil { @@ -370,8 +370,6 @@ func (p *parser) parseInlineTable(b []byte) (ast.Reference, []byte, error) { return parent, rest, err } -var errArrayCannotStartWithComma = errors.New("array cannot start with comma") - //nolint:funlen,cyclop func (p *parser) parseValArray(b []byte) (ast.Reference, []byte, error) { // array = array-open [ array-values ] ws-comment-newline array-close @@ -409,7 +407,7 @@ func (p *parser) parseValArray(b []byte) (ast.Reference, []byte, error) { if b[0] == ',' { if first { - return parent, nil, errArrayCannotStartWithComma + return parent, nil, newDecodeError(b[0:1], "array cannot start with comma") } b = b[1:] @@ -494,8 +492,6 @@ func (p *parser) parseMultilineLiteralString(b []byte) ([]byte, []byte, error) { return token[i : len(token)-3], rest, err } -var errInvalidEscapeChar = errors.New("invalid escaped character") - //nolint:funlen,gocognit,cyclop func (p *parser) parseMultilineBasicString(b []byte) ([]byte, []byte, error) { // ml-basic-string = ml-basic-string-delim [ newline ] ml-basic-body @@ -582,7 +578,7 @@ func (p *parser) parseMultilineBasicString(b []byte) ([]byte, []byte, error) { builder.WriteString(x) i += 8 default: - return nil, nil, fmt.Errorf("parseMultilineBasicString: %w - %#U", errInvalidEscapeChar, c) + return nil, nil, newDecodeError(token[i:i+1], "invalid escaped character %#U", c) } } else { builder.WriteByte(c) @@ -721,7 +717,7 @@ func (p *parser) parseBasicString(b []byte) ([]byte, []byte, error) { builder.WriteString(x) i += 8 default: - return nil, nil, fmt.Errorf("parseBasicString: %w - %#U", errInvalidEscapeChar, c) + return nil, nil, newDecodeError(token[i:i+1], "invalid escaped character %#U", c) } } else { builder.WriteByte(c) @@ -731,18 +727,17 @@ func (p *parser) parseBasicString(b []byte) ([]byte, []byte, error) { return builder.Bytes(), rest, nil } -var errUnicodePointNeedsRightCountChar = errors.New("unicode point needs right number of hex characters") - func hexToString(b []byte, length int) (string, error) { if len(b) < length { - return "", fmt.Errorf("hexToString: %w - %d", errUnicodePointNeedsRightCountChar, length) + return "", newDecodeError(b, "unicode point needs %d character, not %d", length, len(b)) } + b = b[:length] //nolint:godox // TODO: slow - intcode, err := strconv.ParseInt(string(b[:length]), 16, 32) + intcode, err := strconv.ParseInt(string(b), 16, 32) if err != nil { - return "", fmt.Errorf("hexToString: %w", err) + return "", newDecodeError(b, "couldn't parse hexadecimal number: %w", err) } return string(rune(intcode)), nil @@ -757,17 +752,12 @@ func (p *parser) parseWhitespace(b []byte) []byte { return rest } -var ( - errExpectedInf = errors.New("expected 'inf'") - errExpectedNan = errors.New("expected 'nan'") -) - //nolint:cyclop func (p *parser) parseIntOrFloatOrDateTime(b []byte) (ast.Reference, []byte, error) { switch b[0] { case 'i': if !scanFollowsInf(b) { - return ast.Reference{}, nil, errExpectedInf + return ast.Reference{}, nil, newDecodeError(atmost(b, 3), "expected 'inf'") } return p.builder.Push(ast.Node{ @@ -776,7 +766,7 @@ func (p *parser) parseIntOrFloatOrDateTime(b []byte) (ast.Reference, []byte, err }), b[3:], nil case 'n': if !scanFollowsNan(b) { - return ast.Reference{}, nil, errExpectedNan + return ast.Reference{}, nil, newDecodeError(atmost(b, 3), "expected 'nan'") } return p.builder.Push(ast.Node{ @@ -821,8 +811,6 @@ func digitsToInt(b []byte) int { return x } -var errTimezoneButNoTimeComponent = errors.New("possible DateTime cannot have a timezone but no time component") - //nolint:gocognit,cyclop func (p *parser) scanDateTime(b []byte) (ast.Reference, []byte, error) { // scans for contiguous characters in [0-9T:Z.+-], and up to one space if @@ -867,7 +855,7 @@ byteLoop: } } else { if hasTz { - return ast.Reference{}, nil, errTimezoneButNoTimeComponent + return ast.Reference{}, nil, newDecodeError(b, "date-time has timezone but not time component") } kind = ast.LocalDate } @@ -878,12 +866,6 @@ byteLoop: }), b[i:], nil } -var ( - errUnexpectedCharI = fmt.Errorf("unexpected character i while scanning for a number") - errUnexpectedCharN = fmt.Errorf("unexpected character n while scanning for a number") - errExpectedIntOrFloat = fmt.Errorf("expected integer or float") -) - //nolint:funlen,gocognit,cyclop func (p *parser) scanIntOrFloat(b []byte) (ast.Reference, []byte, error) { i := 0 @@ -940,7 +922,7 @@ func (p *parser) scanIntOrFloat(b []byte) (ast.Reference, []byte, error) { }), b[i+3:], nil } - return ast.Reference{}, nil, errUnexpectedCharI + return ast.Reference{}, nil, newDecodeError(b[i:i+1], "unexpected character 'i' while scanning for a number") } if c == 'n' { @@ -951,14 +933,14 @@ func (p *parser) scanIntOrFloat(b []byte) (ast.Reference, []byte, error) { }), b[i+3:], nil } - return ast.Reference{}, nil, errUnexpectedCharN + return ast.Reference{}, nil, newDecodeError(b[i:i+1], "unexpected character 'n' while scanning for a number") } break } if i == 0 { - return ast.Reference{}, b, errExpectedIntOrFloat + return ast.Reference{}, b, newDecodeError(b, "incomplete number") } kind := ast.Integer diff --git a/scanner.go b/scanner.go index b63ccb4..047aef5 100644 --- a/scanner.go +++ b/scanner.go @@ -1,9 +1,5 @@ package toml -import ( - "errors" -) - func scanFollows(b []byte, pattern string) bool { n := len(pattern) @@ -83,22 +79,17 @@ func scanMultilineLiteralString(b []byte) ([]byte, []byte, error) { return nil, nil, newDecodeError(b[len(b):], `multiline literal string not terminated by '''`) } -var ( - errWindowsNewLineMissing = errors.New(`windows new line missing \n`) - errWindowsNewLineCRLF = errors.New(`windows new line should be \r\n`) -) - func scanWindowsNewline(b []byte) ([]byte, []byte, error) { - const lenLF = 2 - if len(b) < lenLF { - return nil, nil, errWindowsNewLineMissing + const lenCRLF = 2 + if len(b) < lenCRLF { + return nil, nil, newDecodeError(b, "windows new line expected") } if b[1] != '\n' { - return nil, nil, errWindowsNewLineCRLF + return nil, nil, newDecodeError(b, `windows new line should be \r\n`) } - return b[:lenLF], b[lenLF:], nil + return b[:lenCRLF], b[lenCRLF:], nil } func scanWhitespace(b []byte) ([]byte, []byte) { @@ -116,8 +107,6 @@ func scanWhitespace(b []byte) ([]byte, []byte) { //nolint:unparam func scanComment(b []byte) ([]byte, []byte) { - // ;; Comment - // // comment-start-symbol = %x23 ; # // non-ascii = %x80-D7FF / %xE000-10FFFF // non-eol = %x09 / %x20-7F / non-ascii @@ -132,10 +121,6 @@ func scanComment(b []byte) ([]byte, []byte) { return b, nil } -var errBasicLineNotTerminatedByQuote = errors.New(`basic string not terminated by "`) - -//nolint:godox -// TODO perform validation on the string? func scanBasicString(b []byte) ([]byte, []byte, error) { // basic-string = quotation-mark *basic-char quotation-mark // quotation-mark = %x22 ; " @@ -156,11 +141,9 @@ func scanBasicString(b []byte) ([]byte, []byte, error) { } } - return nil, nil, errBasicLineNotTerminatedByQuote + return nil, nil, newDecodeError(b[len(b):], `basic string not terminated by "`) } -//nolint:godox -// TODO perform validation on the string? func scanMultilineBasicString(b []byte) ([]byte, []byte, error) { // ml-basic-string = ml-basic-string-delim [ newline ] ml-basic-body // ml-basic-string-delim diff --git a/targets.go b/targets.go index 68f9428..fd3eaec 100644 --- a/targets.go +++ b/targets.go @@ -1,7 +1,6 @@ package toml import ( - "errors" "fmt" "math" "reflect" @@ -14,19 +13,19 @@ type target interface { get() reflect.Value // Store a string at the target. - setString(v string) error + setString(v string) // Store a boolean at the target - setBool(v bool) error + setBool(v bool) // Store an int64 at the target - setInt64(v int64) error + setInt64(v int64) // Store a float64 at the target - setFloat64(v float64) error + setFloat64(v float64) // Stores any value at the target - set(v reflect.Value) error + set(v reflect.Value) } // valueTarget just contains a reflect.Value that can be set. @@ -37,34 +36,24 @@ func (t valueTarget) get() reflect.Value { return reflect.Value(t) } -func (t valueTarget) set(v reflect.Value) error { +func (t valueTarget) set(v reflect.Value) { reflect.Value(t).Set(v) - - return nil } -func (t valueTarget) setString(v string) error { +func (t valueTarget) setString(v string) { t.get().SetString(v) - - return nil } -func (t valueTarget) setBool(v bool) error { +func (t valueTarget) setBool(v bool) { t.get().SetBool(v) - - return nil } -func (t valueTarget) setInt64(v int64) error { +func (t valueTarget) setInt64(v int64) { t.get().SetInt(v) - - return nil } -func (t valueTarget) setFloat64(v float64) error { +func (t valueTarget) setFloat64(v float64) { t.get().SetFloat(v) - - return nil } // interfaceTarget wraps an other target to dereference on get. @@ -76,49 +65,24 @@ func (t interfaceTarget) get() reflect.Value { return t.x.get().Elem() } -func (t interfaceTarget) set(v reflect.Value) error { - err := t.x.set(v) - if err != nil { - return fmt.Errorf("interfaceTarget set: %w", err) - } - - return nil +func (t interfaceTarget) set(v reflect.Value) { + t.x.set(v) } -func (t interfaceTarget) setString(v string) error { - err := t.x.setString(v) - if err != nil { - return fmt.Errorf("interfaceTarget setString: %w", err) - } - - return nil +func (t interfaceTarget) setString(v string) { + t.x.setString(v) } -func (t interfaceTarget) setBool(v bool) error { - err := t.x.setBool(v) - if err != nil { - return fmt.Errorf("interfaceTarget setBool: %w", err) - } - - return nil +func (t interfaceTarget) setBool(v bool) { + t.x.setBool(v) } -func (t interfaceTarget) setInt64(v int64) error { - err := t.x.setInt64(v) - if err != nil { - return fmt.Errorf("interfaceTarget setInt64: %w", err) - } - - return nil +func (t interfaceTarget) setInt64(v int64) { + t.x.setInt64(v) } -func (t interfaceTarget) setFloat64(v float64) error { - err := t.x.setFloat64(v) - if err != nil { - return fmt.Errorf("interfaceTarget setFloat64: %w", err) - } - - return nil +func (t interfaceTarget) setFloat64(v float64) { + t.x.setFloat64(v) } // mapTarget targets a specific key of a map. @@ -131,33 +95,26 @@ func (t mapTarget) get() reflect.Value { return t.v.MapIndex(t.k) } -func (t mapTarget) set(v reflect.Value) error { +func (t mapTarget) set(v reflect.Value) { t.v.SetMapIndex(t.k, v) - - return nil } -func (t mapTarget) setString(v string) error { - return t.set(reflect.ValueOf(v)) +func (t mapTarget) setString(v string) { + t.set(reflect.ValueOf(v)) } -func (t mapTarget) setBool(v bool) error { - return t.set(reflect.ValueOf(v)) +func (t mapTarget) setBool(v bool) { + t.set(reflect.ValueOf(v)) } -func (t mapTarget) setInt64(v int64) error { - return t.set(reflect.ValueOf(v)) +func (t mapTarget) setInt64(v int64) { + t.set(reflect.ValueOf(v)) } -func (t mapTarget) setFloat64(v float64) error { - return t.set(reflect.ValueOf(v)) +func (t mapTarget) setFloat64(v float64) { + t.set(reflect.ValueOf(v)) } -var ( - errValIndexExpectingSlice = errors.New("expecting a slice") - errValIndexCannotInitSlice = errors.New("cannot initialize a slice") -) - //nolint:cyclop // makes sure that the value pointed at by t is indexable (Slice, Array), or // dereferences to an indexable (Ptr, Interface). @@ -167,43 +124,20 @@ func ensureValueIndexable(t target) error { switch f.Type().Kind() { case reflect.Slice: if f.IsNil() { - err := t.set(reflect.MakeSlice(f.Type(), 0, 0)) - if err != nil { - return fmt.Errorf("ensureValueIndexable: %w", err) - } - + t.set(reflect.MakeSlice(f.Type(), 0, 0)) return nil } case reflect.Interface: if f.IsNil() || f.Elem().Type() != sliceInterfaceType { - err := t.set(reflect.MakeSlice(sliceInterfaceType, 0, 0)) - if err != nil { - return fmt.Errorf("ensureValueIndexable: %w", err) - } - + t.set(reflect.MakeSlice(sliceInterfaceType, 0, 0)) return nil } - - if f.Elem().Type().Kind() != reflect.Slice { - return fmt.Errorf("ensureValueIndexable: %w, not a %s", errValIndexExpectingSlice, f.Kind()) - } case reflect.Ptr: - if f.IsNil() { - ptr := reflect.New(f.Type().Elem()) - - err := t.set(ptr) - if err != nil { - return fmt.Errorf("ensureValueIndexable: %w", err) - } - - f = t.get() - } - - return ensureValueIndexable(valueTarget(f.Elem())) + panic("pointer should have already been dereferenced") case reflect.Array: // arrays are always initialized. default: - return fmt.Errorf("ensureValueIndexable: %w with %s", errValIndexCannotInitSlice, f.Kind()) + return fmt.Errorf("toml: cannot store array in a %s", f.Kind()) } return nil @@ -214,69 +148,44 @@ var ( mapStringInterfaceType = reflect.TypeOf(map[string]interface{}{}) ) -func ensureMapIfInterface(x target) error { +func ensureMapIfInterface(x target) { v := x.get() if v.Kind() == reflect.Interface && v.IsNil() { newElement := reflect.MakeMap(mapStringInterfaceType) - err := x.set(newElement) - if err != nil { - return fmt.Errorf("ensureMapIfInterface: %w", err) - } + x.set(newElement) } - - return nil } -var errSetStringCannotAssignString = errors.New("cannot assign string") - func setString(t target, v string) error { f := t.get() switch f.Kind() { case reflect.String: - err := t.setString(v) - if err != nil { - return fmt.Errorf("setString: %w", err) - } - - return nil + t.setString(v) case reflect.Interface: - err := t.set(reflect.ValueOf(v)) - if err != nil { - return fmt.Errorf("setString: %w", err) - } - - return nil + t.set(reflect.ValueOf(v)) default: - return fmt.Errorf("setString: %w to a %s", errSetStringCannotAssignString, f.Kind()) + return fmt.Errorf("toml: cannot assign string to a %s", f.Kind()) } -} -var errSetBoolCannotAssignBool = errors.New("cannot assign bool") + return nil +} func setBool(t target, v bool) error { f := t.get() switch f.Kind() { case reflect.Bool: - err := t.setBool(v) - if err != nil { - return fmt.Errorf("setBool: %w", err) - } - - return nil + t.setBool(v) case reflect.Interface: - err := t.set(reflect.ValueOf(v)) - if err != nil { - return fmt.Errorf("setBool: %w", err) - } - - return nil + t.set(reflect.ValueOf(v)) default: - return fmt.Errorf("setBool: %w to a %s", errSetBoolCannotAssignBool, f.String()) + return fmt.Errorf("toml: cannot assign boolean to a %s", f.Kind()) } + + return nil } const ( @@ -284,207 +193,104 @@ const ( minInt = -maxInt - 1 ) -var ( - errSetInt64InInt32 = errors.New("does not fit in an int32") - errSetInt64InInt16 = errors.New("does not fit in an int16") - errSetInt64InInt8 = errors.New("does not fit in an int8") - errSetInt64InInt = errors.New("does not fit in an int") - errSetInt64InUint64 = errors.New("negative integer does not fit in an uint64") - errSetInt64InUint32 = errors.New("negative integer does not fit in an uint32") - errSetInt64InUint32Max = errors.New("integer does not fit in an uint32") - errSetInt64InUint16 = errors.New("negative integer does not fit in an uint16") - errSetInt64InUint16Max = errors.New("integer does not fit in an uint16") - errSetInt64InUint8 = errors.New("negative integer does not fit in an uint8") - errSetInt64InUint8Max = errors.New("integer does not fit in an uint8") - errSetInt64InUint = errors.New("negative integer does not fit in an uint") - errSetInt64Unknown = errors.New("does not fit in an uint") -) - //nolint:funlen,gocognit,cyclop,gocyclo func setInt64(t target, v int64) error { f := t.get() switch f.Kind() { case reflect.Int64: - err := t.setInt64(v) - if err != nil { - return fmt.Errorf("setInt64: %w", err) - } - - return nil + t.setInt64(v) case reflect.Int32: if v < math.MinInt32 || v > math.MaxInt32 { - return fmt.Errorf("setInt64: integer %d %w", v, errSetInt64InInt32) - } - - err := t.set(reflect.ValueOf(int32(v))) - if err != nil { - return fmt.Errorf("setInt64: %w", err) + return fmt.Errorf("toml: number %d does not fit in an int32", v) } + t.set(reflect.ValueOf(int32(v))) return nil case reflect.Int16: if v < math.MinInt16 || v > math.MaxInt16 { - return fmt.Errorf("setInt64: integer %d %w", v, errSetInt64InInt16) + return fmt.Errorf("toml: number %d does not fit in an int16", v) } - err := t.set(reflect.ValueOf(int16(v))) - if err != nil { - return fmt.Errorf("setInt64: %w", err) - } - - return nil + t.set(reflect.ValueOf(int16(v))) case reflect.Int8: if v < math.MinInt8 || v > math.MaxInt8 { - return fmt.Errorf("setInt64: integer %d %w", v, errSetInt64InInt8) + return fmt.Errorf("toml: number %d does not fit in an int8", v) } - err := t.set(reflect.ValueOf(int8(v))) - if err != nil { - return fmt.Errorf("setInt64: %w", err) - } - - return nil + t.set(reflect.ValueOf(int8(v))) case reflect.Int: if v < minInt || v > maxInt { - return fmt.Errorf("setInt64: integer %d %w", v, errSetInt64InInt) + return fmt.Errorf("toml: number %d does not fit in an int", v) } - err := t.set(reflect.ValueOf(int(v))) - if err != nil { - return fmt.Errorf("setInt64: %w", err) - } - - return nil + t.set(reflect.ValueOf(int(v))) case reflect.Uint64: if v < 0 { - return fmt.Errorf("setInt64: %d, %w", v, errSetInt64InUint64) + return fmt.Errorf("toml: negative number %d does not fit in an uint64", v) } - err := t.set(reflect.ValueOf(uint64(v))) - if err != nil { - return fmt.Errorf("setInt64: %w", err) - } - - return nil + t.set(reflect.ValueOf(uint64(v))) case reflect.Uint32: - if v < 0 { - return fmt.Errorf("setInt64: %d, %w", v, errSetInt64InUint32) + if v < 0 || v > math.MaxUint32 { + return fmt.Errorf("toml: negative number %d does not fit in an uint32", v) } - if v > math.MaxUint32 { - return fmt.Errorf("setInt64: %d, %w", v, errSetInt64InUint32Max) - } - - err := t.set(reflect.ValueOf(uint32(v))) - if err != nil { - return fmt.Errorf("setInt64: %w", err) - } - - return nil + t.set(reflect.ValueOf(uint32(v))) case reflect.Uint16: - if v < 0 { - return fmt.Errorf("setInt64: %d, %w", v, errSetInt64InUint16) + if v < 0 || v > math.MaxUint16 { + return fmt.Errorf("toml: negative number %d does not fit in an uint16", v) } - if v > math.MaxUint16 { - return fmt.Errorf("setInt64: %d, %w", v, errSetInt64InUint16Max) - } - - err := t.set(reflect.ValueOf(uint16(v))) - if err != nil { - return fmt.Errorf("setInt64: %w", err) - } - - return nil + t.set(reflect.ValueOf(uint16(v))) case reflect.Uint8: - if v < 0 { - return fmt.Errorf("setInt64: %d, %w", v, errSetInt64InUint8) + if v < 0 || v > math.MaxUint8 { + return fmt.Errorf("toml: negative number %d does not fit in an uint8", v) } - if v > math.MaxUint8 { - return fmt.Errorf("setInt64: %d, %w", v, errSetInt64InUint8Max) - } - - err := t.set(reflect.ValueOf(uint8(v))) - if err != nil { - return fmt.Errorf("setInt64: %w", err) - } - - return nil + t.set(reflect.ValueOf(uint8(v))) case reflect.Uint: if v < 0 { - return fmt.Errorf("setInt64: %d, %w", v, errSetInt64InUint) + return fmt.Errorf("toml: negative number %d does not fit in an uint", v) } - err := t.set(reflect.ValueOf(uint(v))) - if err != nil { - return fmt.Errorf("setInt64: %w", err) - } - - return nil + t.set(reflect.ValueOf(uint(v))) case reflect.Interface: - err := t.set(reflect.ValueOf(v)) - if err != nil { - return fmt.Errorf("setInt64: %w", err) - } - - return nil + t.set(reflect.ValueOf(v)) default: - return fmt.Errorf("setInt64: %s, %w", f.String(), errSetInt64Unknown) + return fmt.Errorf("toml: integer cannot be assigned to %s", f.Kind()) } -} -var ( - errSetFloat64InFloat32Max = errors.New("does not fit in an float32") - errSetFloat64Unknown = errors.New("does not fit in an float32") -) + return nil +} func setFloat64(t target, v float64) error { f := t.get() switch f.Kind() { case reflect.Float64: - err := t.setFloat64(v) - if err != nil { - return fmt.Errorf("setFloat64: %w", err) - } - - return nil + t.setFloat64(v) case reflect.Float32: if v > math.MaxFloat32 { - return fmt.Errorf("setFloat64: %f %w", v, errSetFloat64InFloat32Max) + return fmt.Errorf("toml: number %f does not fit in a float32", v) } - err := t.set(reflect.ValueOf(float32(v))) - if err != nil { - return fmt.Errorf("setFloat64: %w", err) - } - - return nil + t.set(reflect.ValueOf(float32(v))) case reflect.Interface: - err := t.set(reflect.ValueOf(v)) - if err != nil { - return fmt.Errorf("setFloat64: %w", err) - } - - return nil + t.set(reflect.ValueOf(v)) default: - return fmt.Errorf("setFloat64: %s %w", f.String(), errSetFloat64Unknown) + return fmt.Errorf("toml: float cannot be assigned to %s", f.Kind()) } -} -var ( - errElementAtCannotOn = errors.New("cannot elementAt") - errElementAtCannotOnUnknown = errors.New("cannot elementAt") -) + return nil +} //nolint:cyclop // Returns the element at idx of the value pointed at by target, or an error if // t does not point to an indexable. // If the target points to an Array and idx is out of bounds, it returns // (nil, nil) as this is not a fatal error (the unmarshaler will skip). -func elementAt(t target, idx int) (target, error) { +func elementAt(t target, idx int) target { f := t.get() switch f.Kind() { @@ -493,42 +299,30 @@ func elementAt(t target, idx int) (target, error) { // TODO: use the idx function argument and avoid alloc if possible. idx := f.Len() - err := t.set(reflect.Append(f, reflect.New(f.Type().Elem()).Elem())) - if err != nil { - return nil, fmt.Errorf("elementAt: %w", err) - } + t.set(reflect.Append(f, reflect.New(f.Type().Elem()).Elem())) - return valueTarget(t.get().Index(idx)), nil + return valueTarget(t.get().Index(idx)) case reflect.Array: if idx >= f.Len() { - return nil, nil + return nil } - return valueTarget(f.Index(idx)), nil + return valueTarget(f.Index(idx)) case reflect.Interface: - if f.IsNil() { - panic("interface should have been initialized") - } + // This function is called after ensureValueIndexable, so it's + // guaranteed that f contains an initialized slice. ifaceElem := f.Elem() - if ifaceElem.Kind() != reflect.Slice { - return nil, fmt.Errorf("elementAt: %w on a %s", errElementAtCannotOn, f.Kind()) - } - idx := ifaceElem.Len() newElem := reflect.New(ifaceElem.Type().Elem()).Elem() newSlice := reflect.Append(ifaceElem, newElem) - err := t.set(newSlice) - if err != nil { - return nil, fmt.Errorf("elementAt: %w", err) - } + t.set(newSlice) - return valueTarget(t.get().Elem().Index(idx)), nil - case reflect.Ptr: - return elementAt(valueTarget(f.Elem()), idx) + return valueTarget(t.get().Elem().Index(idx)) default: - return nil, fmt.Errorf("elementAt: %w on a %s", errElementAtCannotOnUnknown, f.Kind()) + // Why ensureValueIndexable let it go through? + panic(fmt.Errorf("elementAt received unhandled value type: %s", f.Kind())) } } @@ -539,31 +333,19 @@ func (d *decoder) scopeTableTarget(shouldAppend bool, t target, name string) (ta switch x.Kind() { // Kinds that need to recurse case reflect.Interface: - t, err := scopeInterface(shouldAppend, t) - if err != nil { - return t, false, fmt.Errorf("scopeTableTarget: %w", err) - } - + t := scopeInterface(shouldAppend, t) return d.scopeTableTarget(shouldAppend, t, name) case reflect.Ptr: - t, err := scopePtr(t) - if err != nil { - return t, false, fmt.Errorf("scopeTableTarget: %w", err) - } - + t := scopePtr(t) return d.scopeTableTarget(shouldAppend, t, name) case reflect.Slice: - t, err := scopeSlice(shouldAppend, t) - if err != nil { - return t, false, fmt.Errorf("scopeTableTarget: %w", err) - } + t := scopeSlice(shouldAppend, t) shouldAppend = false - return d.scopeTableTarget(shouldAppend, t, name) case reflect.Array: t, err := d.scopeArray(shouldAppend, t) if err != nil { - return t, false, fmt.Errorf("scopeTableTarget: %w", err) + return t, false, err } shouldAppend = false @@ -574,11 +356,7 @@ func (d *decoder) scopeTableTarget(shouldAppend bool, t target, name string) (ta return scopeStruct(x, name) case reflect.Map: if x.IsNil() { - err := t.set(reflect.MakeMap(x.Type())) - if err != nil { - return t, false, fmt.Errorf("scopeTableTarget: %w", err) - } - + t.set(reflect.MakeMap(x.Type())) x = t.get() } @@ -588,42 +366,29 @@ func (d *decoder) scopeTableTarget(shouldAppend bool, t target, name string) (ta } } -func scopeInterface(shouldAppend bool, t target) (target, error) { - err := initInterface(shouldAppend, t) - if err != nil { - return t, err - } - - return interfaceTarget{t}, nil +func scopeInterface(shouldAppend bool, t target) target { + initInterface(shouldAppend, t) + return interfaceTarget{t} } -func scopePtr(t target) (target, error) { - err := initPtr(t) - if err != nil { - return t, err - } - - return valueTarget(t.get().Elem()), nil +func scopePtr(t target) target { + initPtr(t) + return valueTarget(t.get().Elem()) } -func initPtr(t target) error { +func initPtr(t target) { x := t.get() if !x.IsNil() { - return nil + return } - err := t.set(reflect.New(x.Type().Elem())) - if err != nil { - return fmt.Errorf("initPtr: %w", err) - } - - return nil + t.set(reflect.New(x.Type().Elem())) } // initInterface makes sure that the interface pointed at by the target is not // nil. // Returns the target to the initialized value of the target. -func initInterface(shouldAppend bool, t target) error { +func initInterface(shouldAppend bool, t target) { x := t.get() if x.Kind() != reflect.Interface { @@ -631,7 +396,7 @@ func initInterface(shouldAppend bool, t target) error { } if !x.IsNil() && (x.Elem().Type() == sliceInterfaceType || x.Elem().Type() == mapStringInterfaceType) { - return nil + return } var newElement reflect.Value @@ -641,55 +406,43 @@ func initInterface(shouldAppend bool, t target) error { newElement = reflect.MakeMap(mapStringInterfaceType) } - err := t.set(newElement) - if err != nil { - return fmt.Errorf("initInterface: %w", err) - } - - return nil + t.set(newElement) } -func scopeSlice(shouldAppend bool, t target) (target, error) { +func scopeSlice(shouldAppend bool, t target) target { v := t.get() if shouldAppend { newElem := reflect.New(v.Type().Elem()) newSlice := reflect.Append(v, newElem.Elem()) - err := t.set(newSlice) - if err != nil { - return t, fmt.Errorf("scopeSlice: %w", err) - } + t.set(newSlice) v = t.get() } - return valueTarget(v.Index(v.Len() - 1)), nil + return valueTarget(v.Index(v.Len() - 1)) } -var errScopeArrayNotEnoughSpace = errors.New("not enough space in the array") - func (d *decoder) scopeArray(shouldAppend bool, t target) (target, error) { v := t.get() idx := d.arrayIndex(shouldAppend, v) if idx >= v.Len() { - return nil, errScopeArrayNotEnoughSpace + return nil, fmt.Errorf("toml: impossible to insert element beyond array's size: %d", v.Len()) } return valueTarget(v.Index(idx)), nil } -var errScopeMapCannotConvertStringToKey = errors.New("cannot convert string into map key type") - func scopeMap(v reflect.Value, name string) (target, bool, error) { k := reflect.ValueOf(name) keyType := v.Type().Key() if !k.Type().AssignableTo(keyType) { if !k.Type().ConvertibleTo(keyType) { - return nil, false, fmt.Errorf("scopeMap: %w %s", errScopeMapCannotConvertStringToKey, keyType) + return nil, false, fmt.Errorf("toml: cannot convert map key of type %s to expected type %s", k.Type(), keyType) } k = k.Convert(keyType) diff --git a/targets_test.go b/targets_test.go index 7b57fe0..c895ad5 100644 --- a/targets_test.go +++ b/targets_test.go @@ -128,14 +128,12 @@ func TestPushNew(t *testing.T) { x, _, err := dec.scopeTableTarget(false, valueTarget(reflect.ValueOf(&d).Elem()), "A") require.NoError(t, err) - n, err := elementAt(x, 0) - require.NoError(t, err) - require.NoError(t, n.setString("hello")) + n := elementAt(x, 0) + n.setString("hello") require.Equal(t, []string{"hello"}, d.A) - n, err = elementAt(x, 1) - require.NoError(t, err) - require.NoError(t, n.setString("world")) + n = elementAt(x, 1) + n.setString("world") require.Equal(t, []string{"hello", "world"}, d.A) }) @@ -151,13 +149,11 @@ func TestPushNew(t *testing.T) { x, _, err := dec.scopeTableTarget(false, valueTarget(reflect.ValueOf(&d).Elem()), "A") require.NoError(t, err) - n, err := elementAt(x, 0) - require.NoError(t, err) + n := elementAt(x, 0) require.NoError(t, setString(n, "hello")) require.Equal(t, []interface{}{"hello"}, d.A) - n, err = elementAt(x, 1) - require.NoError(t, err) + n = elementAt(x, 1) require.NoError(t, setString(n, "world")) require.Equal(t, []interface{}{"hello", "world"}, d.A) }) diff --git a/toml_testgen_support_test.go b/toml_testgen_support_test.go index e2617e6..071ea6d 100644 --- a/toml_testgen_support_test.go +++ b/toml_testgen_support_test.go @@ -94,12 +94,7 @@ func testGenTranslateDesc(input interface{}) interface{} { if ok { dvalue, ok = d["value"] if ok { - var okdt bool - - dtype, okdt = dtypeiface.(string) - if !okdt { - panic(fmt.Sprintf("dtypeiface should be valid string: %v", dtypeiface)) - } + dtype = dtypeiface.(string) switch dtype { case "string": @@ -132,10 +127,7 @@ func testGenTranslateDesc(input interface{}) interface{} { return nil } - a, oka := dvalue.([]interface{}) - if !oka { - panic(fmt.Sprintf("a should be valid []interface{}: %v", a)) - } + a := dvalue.([]interface{}) xs := make([]interface{}, len(a)) diff --git a/unmarshaler.go b/unmarshaler.go index 47a9aec..3864be6 100644 --- a/unmarshaler.go +++ b/unmarshaler.go @@ -56,7 +56,7 @@ func (d *Decoder) SetStrict(strict bool) { func (d *Decoder) Decode(v interface{}) error { b, err := ioutil.ReadAll(d.r) if err != nil { - return fmt.Errorf("Decode: %w", err) + return fmt.Errorf("toml: %w", err) } p := parser{} @@ -130,20 +130,15 @@ func keyLocation(node ast.Node) []byte { return unsafe.BytesRange(start, end) } -var ( - errFromParserExpectingPointer = errors.New("expecting a pointer as target") - errFromParserExpectingNonNilPointer = errors.New("expecting non nil pointer as target") -) - //nolint:funlen,cyclop func (d *decoder) fromParser(p *parser, v interface{}) error { r := reflect.ValueOf(v) if r.Kind() != reflect.Ptr { - return fmt.Errorf("fromParser: %w, not %s", errFromParserExpectingPointer, r.Kind()) + return fmt.Errorf("toml: decoding can only be performed into a pointer, not %s", r.Kind()) } if r.IsNil() { - return errFromParserExpectingNonNilPointer + return fmt.Errorf("toml: decoding pointer target cannot be nil") } var ( @@ -162,7 +157,7 @@ func (d *decoder) fromParser(p *parser, v interface{}) error { err := d.seen.CheckExpression(node) if err != nil { - return fmt.Errorf("fromParser: %w", err) + return err } var found bool @@ -181,16 +176,13 @@ func (d *decoder) fromParser(p *parser, v interface{}) error { // looks like a table. Otherwise the information // of a table is lost, and marshal cannot do the // round trip. - err := ensureMapIfInterface(current) - if err != nil { - panic(fmt.Sprintf("ensureMapIfInterface: %s", err)) - } + ensureMapIfInterface(current) } case ast.ArrayTable: d.strict.EnterArrayTable(node) current, found, err = d.scopeWithArrayTable(root, node.Key()) default: - panic(fmt.Sprintf("fromParser: this should not be a top level node type: %s", node.Kind)) + panic(fmt.Sprintf("this should not be a top level node type: %s", node.Kind)) } if err != nil { @@ -267,26 +259,18 @@ func (d *decoder) scopeWithArrayTable(x target, key ast.Iterator) (target, bool, v := x.get() if v.Kind() == reflect.Ptr { - x, err = scopePtr(x) - if err != nil { - return x, false, err - } - + x = scopePtr(x) v = x.get() } if v.Kind() == reflect.Interface { - x, err = scopeInterface(true, x) - if err != nil { - return x, found, err - } - + x = scopeInterface(true, x) v = x.get() } switch v.Kind() { case reflect.Slice: - x, err = scopeSlice(true, x) + x = scopeSlice(true, x) case reflect.Array: x, err = d.scopeArray(true, x) default: @@ -334,7 +318,7 @@ func tryTextUnmarshaler(x target, node ast.Node) (bool, error) { if v.Type().Implements(textUnmarshalerType) { err := v.Interface().(encoding.TextUnmarshaler).UnmarshalText(node.Data) if err != nil { - return false, fmt.Errorf("tryTextUnmarshaler: %w", err) + return false, newDecodeError(node.Data, "error calling UnmarshalText: %w", err) } return true, nil @@ -343,7 +327,7 @@ func tryTextUnmarshaler(x target, node ast.Node) (bool, error) { if v.CanAddr() && v.Addr().Type().Implements(textUnmarshalerType) { err := v.Addr().Interface().(encoding.TextUnmarshaler).UnmarshalText(node.Data) if err != nil { - return false, fmt.Errorf("tryTextUnmarshaler: %w", err) + return false, newDecodeError(node.Data, "error calling UnmarshalText: %w", err) } return true, nil @@ -358,11 +342,7 @@ func (d *decoder) unmarshalValue(x target, node ast.Node) error { if v.Kind() == reflect.Ptr { if !v.Elem().IsValid() { - err := x.set(reflect.New(v.Type().Elem())) - if err != nil { - return fmt.Errorf("unmarshalValue: %w", err) - } - + x.set(reflect.New(v.Type().Elem())) v = x.get() } @@ -394,7 +374,7 @@ func (d *decoder) unmarshalValue(x target, node ast.Node) error { case ast.LocalDate: return unmarshalLocalDate(x, node) default: - panic(fmt.Sprintf("unmarshalValue: unhandled unmarshalValue kind %s", node.Kind)) + panic(fmt.Sprintf("unhandled node kind %s", node.Kind)) } } @@ -406,7 +386,9 @@ func unmarshalLocalDate(x target, node ast.Node) error { return err } - return setDate(x, v) + setDate(x, v) + + return nil } func unmarshalLocalDateTime(x target, node ast.Node) error { @@ -421,7 +403,9 @@ func unmarshalLocalDateTime(x target, node ast.Node) error { return newDecodeError(rest, "extra characters at the end of a local date time") } - return setLocalDateTime(x, v) + setLocalDateTime(x, v) + + return nil } func unmarshalDateTime(x target, node ast.Node) error { @@ -432,48 +416,37 @@ func unmarshalDateTime(x target, node ast.Node) error { return err } - return setDateTime(x, v) + setDateTime(x, v) + + return nil } -func setLocalDateTime(x target, v LocalDateTime) error { +func setLocalDateTime(x target, v LocalDateTime) { if x.get().Type() == timeType { cast := v.In(time.Local) - return setDateTime(x, cast) + setDateTime(x, cast) + return } - err := x.set(reflect.ValueOf(v)) - if err != nil { - return fmt.Errorf("setLocalDateTime: %w", err) - } - - return nil + x.set(reflect.ValueOf(v)) } -func setDateTime(x target, v time.Time) error { - err := x.set(reflect.ValueOf(v)) - if err != nil { - return fmt.Errorf("setDateTime: %w", err) - } - - return nil +func setDateTime(x target, v time.Time) { + x.set(reflect.ValueOf(v)) } var timeType = reflect.TypeOf(time.Time{}) -func setDate(x target, v LocalDate) error { +func setDate(x target, v LocalDate) { if x.get().Type() == timeType { cast := v.In(time.Local) - return setDateTime(x, cast) + setDateTime(x, cast) + return } - err := x.set(reflect.ValueOf(v)) - if err != nil { - return fmt.Errorf("setDate: %w", err) - } - - return nil + x.set(reflect.ValueOf(v)) } func unmarshalString(x target, node ast.Node) error { @@ -514,10 +487,7 @@ func unmarshalFloat(x target, node ast.Node) error { func (d *decoder) unmarshalInlineTable(x target, node ast.Node) error { assertNode(ast.InlineTable, node) - err := ensureMapIfInterface(x) - if err != nil { - return fmt.Errorf("unmarshalInlineTable: %w", err) - } + ensureMapIfInterface(x) it := node.Children() for it.Next() { @@ -546,10 +516,7 @@ func (d *decoder) unmarshalArray(x target, node ast.Node) error { for it.Next() { n := it.Node() - v, err := elementAt(x, idx) - if err != nil { - return err - } + v := elementAt(x, idx) if v == nil { // when we go out of bound for an array just stop processing it to diff --git a/unmarshaler_test.go b/unmarshaler_test.go index 6caaadb..9814cd1 100644 --- a/unmarshaler_test.go +++ b/unmarshaler_test.go @@ -38,6 +38,11 @@ func TestUnmarshal_Integers(t *testing.T) { input: `+99`, expected: 99, }, + { + desc: "integer decimal underscore", + input: `123_456`, + expected: 123456, + }, { desc: "integer hex uppercase", input: `0xDEADBEEF`, @@ -58,6 +63,21 @@ func TestUnmarshal_Integers(t *testing.T) { input: `0b11010110`, expected: 0b11010110, }, + { + desc: "double underscore", + input: "12__3", + err: true, + }, + { + desc: "starts with underscore", + input: "_1", + err: true, + }, + { + desc: "ends with underscore", + input: "1_", + err: true, + }, } type doc struct { @@ -71,8 +91,12 @@ func TestUnmarshal_Integers(t *testing.T) { doc := doc{} err := toml.Unmarshal([]byte(`A = `+e.input), &doc) - require.NoError(t, err) - assert.Equal(t, e.expected, doc.A) + if e.err { + require.Error(t, err) + } else { + require.NoError(t, err) + assert.Equal(t, e.expected, doc.A) + } }) } } @@ -799,6 +823,33 @@ B = "data"`, } }, }, + { + desc: "mismatch types int to string", + input: `A = 42`, + gen: func() test { + type S struct { + A string + } + return test{ + target: &S{}, + err: true, + } + }, + }, + { + desc: "mismatch types array of int to interface with non-slice", + input: `A = [[42]]`, + skip: true, + gen: func() test { + type S struct { + A *string + } + return test{ + target: &S{}, + expected: &S{}, + } + }, + }, } for _, e := range examples { @@ -815,6 +866,9 @@ B = "data"`, } err := toml.Unmarshal([]byte(e.input), test.target) if test.err { + if err == nil { + t.Log("=>", test.target) + } require.Error(t, err) } else { require.NoError(t, err) @@ -1030,7 +1084,7 @@ world'`, if e.msg != "" { t.Log("\n" + de.String()) - require.Equal(t, e.msg, de.Error()) + require.Equal(t, "toml: "+e.msg, de.Error()) } }) }