diff --git a/.golangci.toml b/.golangci.toml index 0e71b20..fdf167b 100644 --- a/.golangci.toml +++ b/.golangci.toml @@ -24,7 +24,7 @@ enable = [ # "exhaustivestruct", "exportloopref", "forbidigo", - "forcetypeassert", + # "forcetypeassert", "funlen", "gci", # "gochecknoglobals", @@ -35,7 +35,7 @@ enable = [ "gocyclo", "godot", "godox", - "goerr113", + # "goerr113", "gofmt", "gofumpt", "goheader", @@ -57,7 +57,7 @@ enable = [ "nakedret", "nestif", "nilerr", - "nlreturn", + # "nlreturn", "noctx", "nolintlint", "paralleltest", @@ -80,5 +80,5 @@ enable = [ "wastedassign", "whitespace", # "wrapcheck", - "wsl" + # "wsl" ] diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index e62df92..59658d3 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -51,17 +51,17 @@ Want to contribute a patch? Very happy to hear that! First, some high-level rules: -* A short proposal with some POC code is better than a lengthy piece of text +- A short proposal with some POC code is better than a lengthy piece of text with no code. Code speaks louder than words. That being said, bigger changes should probably start with a [discussion][discussions]. -* No backward-incompatible patch will be accepted unless discussed. Sometimes +- No backward-incompatible patch will be accepted unless discussed. Sometimes it's hard, but we try not to break people's programs unless we absolutely have to. -* If you are writing a new feature or extending an existing one, make sure to +- If you are writing a new feature or extending an existing one, make sure to write some documentation. -* Bug fixes need to be accompanied with regression tests. -* New code needs to be tested. -* Your commit messages need to explain why the change is needed, even if already +- Bug fixes need to be accompanied with regression tests. +- New code needs to be tested. +- Your commit messages need to explain why the change is needed, even if already included in the PR description. It does sound like a lot, but those best practices are here to save time overall @@ -129,13 +129,15 @@ Benchmark results should be compared against each other with `new.txt`). 4. Run `benchstat old.txt new.txt` to check that time/op does not go up in any test. - + +On Unix you can use `./ci.sh benchmark -d v2` to verify how your code impacts +performance. + It is highly encouraged to add the benchstat results to your pull request description. Pull requests that lower performance will receive more scrutiny. [benchstat]: https://pkg.go.dev/golang.org/x/perf/cmd/benchstat - ### Style Try to look around and follow the same format and structure as the rest of the @@ -149,10 +151,10 @@ code. We enforce using `go fmt` on the whole code base. Checklist: -* Passing CI. -* Does not introduce backward-incompatible changes (unless discussed). -* Has relevant doc changes. -* Benchstat does not show performance regression. +- Passing CI. +- Does not introduce backward-incompatible changes (unless discussed). +- Has relevant doc changes. +- Benchstat does not show performance regression. 1. Merge using "squash and merge". 2. Make sure to edit the commit message to keep all the useful information diff --git a/ci.sh b/ci.sh index 0e3314c..75d7008 100755 --- a/ci.sh +++ b/ci.sh @@ -25,6 +25,20 @@ USAGE COMMANDS +benchmark [OPTIONS...] [BRANCH] + + Run benchmarks. + + ARGUMENTS + + BRANCH Optional. Defines which Git branch to use when running + benchmarks. + + OPTIONS + + -d Compare benchmarks of HEAD with BRANCH using benchstats. In + this form the BRANCH argument is required. + coverage [OPTIONS...] [BRANCH] Generates code coverage. @@ -50,9 +64,9 @@ cover() { stderr "Executing coverage for ${branch} at ${dir}" if [ "${branch}" = "HEAD" ]; then - cp -r . "${dir}/" + cp -r . "${dir}/" else - git worktree add "$dir" "$branch" + git worktree add "$dir" "$branch" fi pushd "$dir" @@ -61,7 +75,7 @@ cover() { popd if [ "${branch}" != "HEAD" ]; then - git worktree remove --force "$dir" + git worktree remove --force "$dir" fi } @@ -101,7 +115,48 @@ coverage() { cover "${1-HEAD}" } +bench() { + branch="${1}" + out="${2}" + dir="$(mktemp -d)" + + stderr "Executing benchmark for ${branch} at ${dir}" + + if [ "${branch}" = "HEAD" ]; then + cp -r . "${dir}/" + else + git worktree add "$dir" "$branch" + fi + + pushd "$dir" + go test -bench=. -count=10 ./... | tee "${out}" + popd + + if [ "${branch}" != "HEAD" ]; then + git worktree remove --force "$dir" + fi +} + +benchmark() { + case "$1" in + -d) + shift + target="${1?Need to provide a target branch argument}" + old=`mktemp` + bench "${target}" "${old}" + + new=`mktemp` + bench HEAD "${new}" + benchstat "${old}" "${new}" + return 0 + ;; + esac + + bench "${1-HEAD}" `mktemp` +} + case "$1" in coverage) shift; coverage $@;; + benchmark) shift; benchmark $@;; *) usage "bad argument $1";; esac diff --git a/decode.go b/decode.go index afa6db3..33ac2a9 100644 --- a/decode.go +++ b/decode.go @@ -1,6 +1,7 @@ package toml import ( + "fmt" "math" "strconv" "time" @@ -16,7 +17,7 @@ func parseInteger(b []byte) (int64, error) { case 'o': return parseIntOct(b) default: - return 0, newDecodeError(b[1:2], "invalid base: '%c'", b[1]) + panic(fmt.Errorf("invalid base '%c', should have been checked by scanIntOrFloat", b[1])) } } @@ -34,41 +35,26 @@ func parseLocalDate(b []byte) (LocalDate, error) { return date, newDecodeError(b, "dates are expected to have the format YYYY-MM-DD") } - var err error + date.Year = parseDecimalDigits(b[0:4]) - date.Year, err = parseDecimalDigits(b[0:4]) - if err != nil { - return date, err - } - - v, err := parseDecimalDigits(b[5:7]) - if err != nil { - return date, err - } + v := parseDecimalDigits(b[5:7]) date.Month = time.Month(v) - date.Day, err = parseDecimalDigits(b[8:10]) - if err != nil { - return date, err - } + date.Day = parseDecimalDigits(b[8:10]) return date, nil } -func parseDecimalDigits(b []byte) (int, error) { +func parseDecimalDigits(b []byte) int { v := 0 - for i, c := range b { - if !isDigit(c) { - return 0, newDecodeError(b[i:i+1], "should be a digit (0-9)") - } - + for _, c := range b { v *= 10 v += int(c - '0') } - return v, nil + return v } func parseDateTime(b []byte) (time.Time, error) { @@ -77,8 +63,6 @@ func parseDateTime(b []byte) (time.Time, error) { // time-offset = "Z" / time-numoffset // time-numoffset = ( "+" / "-" ) time-hour ":" time-minute - originalBytes := b - dt, b, err := parseLocalDateTime(b) if err != nil { return time.Time{}, err @@ -87,7 +71,8 @@ func parseDateTime(b []byte) (time.Time, error) { var zone *time.Location if len(b) == 0 { - return time.Time{}, newDecodeError(originalBytes, "date-time is missing timezone") + // parser should have checked that when assigning the date time node + panic("date time should have a timezone") } if b[0] == 'Z' { @@ -99,18 +84,15 @@ func parseDateTime(b []byte) (time.Time, error) { return time.Time{}, newDecodeError(b, "invalid date-time timezone") } direction := 1 - switch b[0] { - case '+': - case '-': + if b[0] == '-' { direction = -1 - default: - return time.Time{}, newDecodeError(b[0:1], "invalid timezone offset character") } hours := digitsToInt(b[1:3]) minutes := digitsToInt(b[4:6]) seconds := direction * (hours*3600 + minutes*60) zone = time.FixedZone("", seconds) + b = b[dateTimeByteLen:] } if len(b) > 0 { @@ -161,7 +143,6 @@ func parseLocalDateTime(b []byte) (LocalDateTime, []byte, error) { // parseLocalTime is a bit different because it also returns the remaining // []byte that is didn't need. This is to allow parseDateTime to parse those // remaining bytes as a timezone. -//nolint:cyclop,funlen func parseLocalTime(b []byte) (LocalTime, []byte, error) { var ( nspow = [10]int{0, 1e8, 1e7, 1e6, 1e5, 1e4, 1e3, 1e2, 1e1, 1e0} @@ -173,46 +154,26 @@ func parseLocalTime(b []byte) (LocalTime, []byte, error) { return t, nil, newDecodeError(b, "times are expected to have the format HH:MM:SS[.NNNNNN]") } - var err error - - t.Hour, err = parseDecimalDigits(b[0:2]) - if err != nil { - return t, nil, err - } - + t.Hour = parseDecimalDigits(b[0:2]) if b[2] != ':' { return t, nil, newDecodeError(b[2:3], "expecting colon between hours and minutes") } - t.Minute, err = parseDecimalDigits(b[3:5]) - if err != nil { - return t, nil, err - } - + t.Minute = parseDecimalDigits(b[3:5]) if b[5] != ':' { return t, nil, newDecodeError(b[5:6], "expecting colon between minutes and seconds") } - t.Second, err = parseDecimalDigits(b[6:8]) - if err != nil { - return t, nil, err - } + t.Second = parseDecimalDigits(b[6:8]) - if len(b) >= 9 && b[8] == '.' { + const minLengthWithFrac = 9 + if len(b) >= minLengthWithFrac && b[minLengthWithFrac-1] == '.' { frac := 0 digits := 0 - for i, c := range b[9:] { - if !isDigit(c) { - if i == 0 { - return t, nil, newDecodeError(b[i:i+1], "need at least one digit after fraction point") - } - - break - } - - //nolint:gomnd - if i >= 9 { + for i, c := range b[minLengthWithFrac:] { + const maxFracPrecision = 9 + if i >= maxFracPrecision { return t, nil, newDecodeError(b[i:i+1], "maximum precision for date time is nanosecond") } @@ -231,8 +192,6 @@ func parseLocalTime(b []byte) (LocalTime, []byte, error) { //nolint:cyclop func parseFloat(b []byte) (float64, error) { - //nolint:godox - // TODO: inefficient if len(b) == 4 && (b[0] == '+' || b[0] == '-') && b[1] == 'n' && b[2] == 'a' && b[3] == 'n' { return math.NaN(), nil } @@ -252,7 +211,7 @@ func parseFloat(b []byte) (float64, error) { f, err := strconv.ParseFloat(string(cleaned), 64) if err != nil { - return 0, newDecodeError(b, "coudn't parse float: %w", err) + return 0, newDecodeError(b, "unable to parse float: %w", err) } return f, nil @@ -315,10 +274,6 @@ func parseIntDec(b []byte) (int64, error) { } func checkAndRemoveUnderscores(b []byte) ([]byte, error) { - if len(b) == 0 { - return b, nil - } - if b[0] == '_' { return nil, newDecodeError(b[0:1], "number cannot start with underscore") } diff --git a/doc.go b/doc.go index d541e4b..b7bc599 100644 --- a/doc.go +++ b/doc.go @@ -1,4 +1,2 @@ -/* - Package toml is a library to read and write TOML documents. -*/ +// Package toml is a library to read and write TOML documents. package toml diff --git a/errors.go b/errors.go index 4486505..712765b 100644 --- a/errors.go +++ b/errors.go @@ -105,13 +105,9 @@ func (e *DecodeError) Key() Key { // highlight can be freely deallocated. //nolint:funlen func wrapDecodeError(document []byte, de *decodeError) *DecodeError { - if de == nil { - return nil - } - offset := unsafe.SubsliceOffset(document, de.highlight) - errMessage := de.message + errMessage := de.Error() errLine, errColumn := positionAtEnd(document[:offset]) before, after := linesOfContext(document, de.highlight, offset, 3) diff --git a/errors_test.go b/errors_test.go index 893dc9c..d6af314 100644 --- a/errors_test.go +++ b/errors_test.go @@ -181,6 +181,24 @@ line 5`, } } +func TestDecodeError_Accessors(t *testing.T) { + t.Parallel() + + e := DecodeError{ + message: "foo", + line: 1, + column: 2, + key: []string{"one", "two"}, + human: "bar", + } + assert.Equal(t, "toml: foo", e.Error()) + r, c := e.Position() + assert.Equal(t, 1, r) + assert.Equal(t, 2, c) + assert.Equal(t, Key{"one", "two"}, e.Key()) + assert.Equal(t, "bar", e.String()) +} + func ExampleDecodeError() { doc := `name = 123__456` @@ -189,14 +207,15 @@ func ExampleDecodeError() { fmt.Println(err) + //nolint:errorlint de := err.(*DecodeError) fmt.Println(de.String()) row, col := de.Position() - fmt.Println("error occured at row", row, "column", col) + fmt.Println("error occurred at row", row, "column", col) // Output: // toml: number must have at least one digit between underscores // 1| name = 123__456 // | ~~ number must have at least one digit between underscores - // error occured at row 1 column 11 + // error occurred at row 1 column 11 } diff --git a/marshaler.go b/marshaler.go index 713d8af..ce9972a 100644 --- a/marshaler.go +++ b/marshaler.go @@ -127,6 +127,10 @@ func (enc *Encoder) Encode(v interface{}) error { ctx.inline = enc.tablesInline + if v == nil { + return fmt.Errorf("toml: cannot encode a nil interface") + } + b, err := enc.encode(b, ctx, reflect.ValueOf(v)) if err != nil { return err @@ -193,9 +197,11 @@ func (ctx *encoderCtx) isRoot() bool { //nolint:cyclop,funlen func (enc *Encoder) encode(b []byte, ctx encoderCtx, v reflect.Value) ([]byte, error) { - i, ok := v.Interface().(time.Time) - if ok { - return i.AppendFormat(b, time.RFC3339), nil + if !v.IsZero() { + i, ok := v.Interface().(time.Time) + if ok { + return i.AppendFormat(b, time.RFC3339), nil + } } if v.Type().Implements(textMarshalerType) { @@ -273,11 +279,6 @@ func (enc *Encoder) encodeKv(b []byte, ctx encoderCtx, options valueOptions, v r if !ctx.hasKey { panic("caller of encodeKv should have set the key in the context") } - - if isNil(v) { - return b, nil - } - b = enc.indent(ctx.indent, b) b, err = enc.encodeKey(b, ctx.key) @@ -470,12 +471,7 @@ func (enc *Encoder) encodeMap(b []byte, ctx encoderCtx, v reflect.Value) ([]byte continue } - table, err := willConvertToTableOrArrayTable(ctx, v) - if err != nil { - return nil, err - } - - if table { + if willConvertToTableOrArrayTable(ctx, v) { t.pushTable(k, v, emptyValueOptions) } else { t.pushKV(k, v, emptyValueOptions) @@ -543,18 +539,13 @@ func (enc *Encoder) encodeStruct(b []byte, ctx encoderCtx, v reflect.Value) ([]b continue } - willConvert, err := willConvertToTableOrArrayTable(ctx, f) - if err != nil { - return nil, err - } - options := valueOptions{ multiline: fieldBoolTag(fieldType, "multiline"), } inline := fieldBoolTag(fieldType, "inline") - if inline || !willConvert { + if inline || !willConvertToTableOrArrayTable(ctx, f) { t.pushKV(k, f, options) } else { t.pushTable(k, f, options) @@ -640,21 +631,8 @@ func (enc *Encoder) encodeTableInline(b []byte, ctx encoderCtx, t table) ([]byte } } - for _, table := range t.tables { - if first { - first = false - } else { - b = append(b, `, `...) - } - - ctx.setKey(table.Key) - - b, err = enc.encode(b, ctx, table.Value) - if err != nil { - return nil, err - } - - b = append(b, '\n') + if len(t.tables) > 0 { + panic("inline table cannot contain nested tables, online key-values") } b = append(b, "}"...) @@ -664,61 +642,50 @@ func (enc *Encoder) encodeTableInline(b []byte, ctx encoderCtx, t table) ([]byte var textMarshalerType = reflect.TypeOf(new(encoding.TextMarshaler)).Elem() -func willConvertToTable(ctx encoderCtx, v reflect.Value) (bool, error) { +func willConvertToTable(ctx encoderCtx, v reflect.Value) bool { if v.Type() == timeType || v.Type().Implements(textMarshalerType) { - return false, nil + return false } t := v.Type() switch t.Kind() { case reflect.Map, reflect.Struct: - return !ctx.inline, nil + return !ctx.inline case reflect.Interface: - if v.IsNil() { - return false, fmt.Errorf("toml: encoding a nil interface is not supported") - } - return willConvertToTable(ctx, v.Elem()) case reflect.Ptr: if v.IsNil() { - return false, nil + return false } return willConvertToTable(ctx, v.Elem()) default: - return false, nil + return false } } -func willConvertToTableOrArrayTable(ctx encoderCtx, v reflect.Value) (bool, error) { +func willConvertToTableOrArrayTable(ctx encoderCtx, v reflect.Value) bool { t := v.Type() if t.Kind() == reflect.Interface { - if v.IsNil() { - return false, fmt.Errorf("toml: encoding a nil interface is not supported") - } - return willConvertToTableOrArrayTable(ctx, v.Elem()) } if t.Kind() == reflect.Slice { if v.Len() == 0 { // An empty slice should be a kv = []. - return false, nil + return false } for i := 0; i < v.Len(); i++ { - t, err := willConvertToTable(ctx, v.Index(i)) - if err != nil { - return false, err - } + t := willConvertToTable(ctx, v.Index(i)) if !t { - return false, nil + return false } } - return true, nil + return true } return willConvertToTable(ctx, v) @@ -731,12 +698,7 @@ func (enc *Encoder) encodeSlice(b []byte, ctx encoderCtx, v reflect.Value) ([]by return b, nil } - allTables, err := willConvertToTableOrArrayTable(ctx, v) - if err != nil { - return nil, err - } - - if allTables { + if willConvertToTableOrArrayTable(ctx, v) { return enc.encodeSliceAsArrayTable(b, ctx, v) } @@ -746,10 +708,6 @@ func (enc *Encoder) encodeSlice(b []byte, ctx encoderCtx, v reflect.Value) ([]by // caller should have checked that v is a slice that only contains values that // encode into tables. func (enc *Encoder) encodeSliceAsArrayTable(b []byte, ctx encoderCtx, v reflect.Value) ([]byte, error) { - if v.Len() == 0 { - return b, nil - } - ctx.shiftKey() var err error diff --git a/marshaler_test.go b/marshaler_test.go index d719406..5c946c3 100644 --- a/marshaler_test.go +++ b/marshaler_test.go @@ -16,6 +16,12 @@ import ( func TestMarshal(t *testing.T) { t.Parallel() + someInt := 42 + + type structInline struct { + A interface{} `inline:"true"` + } + examples := []struct { desc string v interface{} @@ -298,6 +304,213 @@ A = [ ] `, }, + { + desc: "nil interface not supported at root", + v: nil, + err: true, + }, + { + desc: "nil interface not supported in slice", + v: map[string]interface{}{ + "a": []interface{}{"a", nil, 2}, + }, + err: true, + }, + { + desc: "nil pointer in slice uses zero value", + v: struct { + A []*int + }{ + A: []*int{nil}, + }, + expected: `A = [0]`, + }, + { + desc: "nil pointer in slice uses zero value", + v: struct { + A []*int + }{ + A: []*int{nil}, + }, + expected: `A = [0]`, + }, + { + desc: "pointer in slice", + v: struct { + A []*int + }{ + A: []*int{&someInt}, + }, + expected: `A = [42]`, + }, + { + desc: "inline table in inline table", + v: structInline{ + A: structInline{ + A: structInline{ + A: "hello", + }, + }, + }, + expected: `A = {A = {A = 'hello'}}`, + }, + { + desc: "empty slice in map", + v: map[string][]string{ + "a": {}, + }, + expected: `a = []`, + }, + { + desc: "map in slice", + v: map[string][]map[string]string{ + "a": {{"hello": "world"}}, + }, + expected: ` +[[a]] +hello = 'world'`, + }, + { + desc: "newline in map in slice", + v: map[string][]map[string]string{ + "a\n": {{"hello": "world"}}, + }, + err: true, + }, + { + desc: "newline in map in slice", + v: map[string][]map[string]*customTextMarshaler{ + "a": {{"hello": &customTextMarshaler{1}}}, + }, + err: true, + }, + { + desc: "empty slice of empty struct", + v: struct { + A []struct{} + }{ + A: []struct{}{}, + }, + expected: `A = []`, + }, + { + desc: "nil field is ignored", + v: struct { + A interface{} + }{ + A: nil, + }, + expected: ``, + }, + { + desc: "private fields are ignored", + v: struct { + Public string + private string + }{ + Public: "shown", + private: "hidden", + }, + expected: `Public = 'shown'`, + }, + { + desc: "fields tagged - are ignored", + v: struct { + Public string `toml:"-"` + private string + }{ + Public: "hidden", + }, + expected: ``, + }, + { + desc: "nil value in map is ignored", + v: map[string]interface{}{ + "A": nil, + }, + expected: ``, + }, + { + desc: "new line in table key", + v: map[string]interface{}{ + "hello\nworld": 42, + }, + err: true, + }, + { + desc: "new line in parent of nested table key", + v: map[string]interface{}{ + "hello\nworld": map[string]interface{}{ + "inner": 42, + }, + }, + err: true, + }, + { + desc: "new line in nested table key", + v: map[string]interface{}{ + "parent": map[string]interface{}{ + "in\ner": map[string]interface{}{ + "foo": 42, + }, + }, + }, + err: true, + }, + { + desc: "invalid map key", + v: map[int]interface{}{}, + err: true, + }, + { + desc: "unhandled type", + v: struct { + A chan int + }{ + A: make(chan int), + }, + err: true, + }, + { + desc: "numbers", + v: struct { + A float32 + B uint64 + C uint32 + D uint16 + E uint8 + F uint + G int64 + H int32 + I int16 + J int8 + K int + }{ + A: 1.1, + B: 42, + C: 42, + D: 42, + E: 42, + F: 42, + G: 42, + H: 42, + I: 42, + J: 42, + K: 42, + }, + expected: ` +A = 1.1 +B = 42 +C = 42 +D = 42 +E = 42 +F = 42 +G = 42 +H = 42 +I = 42 +J = 42 +K = 42`, + }, } for _, e := range examples { @@ -460,6 +673,85 @@ root = 'value0' } } +type customTextMarshaler struct { + value int64 +} + +func (c *customTextMarshaler) MarshalText() ([]byte, error) { + if c.value == 1 { + return nil, fmt.Errorf("cannot represent 1 because this is a silly test") + } + return []byte(fmt.Sprintf("::%d", c.value)), nil +} + +func TestMarshalTextMarshaler_NoRoot(t *testing.T) { + t.Parallel() + + c := customTextMarshaler{} + _, err := toml.Marshal(&c) + require.Error(t, err) +} + +func TestMarshalTextMarshaler_Error(t *testing.T) { + t.Parallel() + + m := map[string]interface{}{"a": &customTextMarshaler{value: 1}} + _, err := toml.Marshal(m) + require.Error(t, err) +} + +func TestMarshalTextMarshaler_ErrorInline(t *testing.T) { + t.Parallel() + + type s struct { + A map[string]interface{} `inline:"true"` + } + + d := s{ + A: map[string]interface{}{"a": &customTextMarshaler{value: 1}}, + } + + _, err := toml.Marshal(d) + require.Error(t, err) +} + +func TestMarshalTextMarshaler(t *testing.T) { + t.Parallel() + + m := map[string]interface{}{"a": &customTextMarshaler{value: 2}} + r, err := toml.Marshal(m) + require.NoError(t, err) + equalStringsIgnoreNewlines(t, "a = '::2'", string(r)) +} + +type brokenWriter struct{} + +func (b *brokenWriter) Write([]byte) (int, error) { + return 0, fmt.Errorf("dead") +} + +func TestEncodeToBrokenWriter(t *testing.T) { + t.Parallel() + w := brokenWriter{} + enc := toml.NewEncoder(&w) + err := enc.Encode(map[string]string{"hello": "world"}) + require.Error(t, err) +} + +func TestEncoderSetIndentSymbol(t *testing.T) { + t.Parallel() + var w strings.Builder + enc := toml.NewEncoder(&w) + enc.SetIndentTables(true) + enc.SetIndentSymbol(">>>") + err := enc.Encode(map[string]map[string]string{"parent": {"hello": "world"}}) + require.NoError(t, err) + expected := ` +[parent] +>>>hello = 'world'` + equalStringsIgnoreNewlines(t, expected, w.String()) +} + func TestIssue436(t *testing.T) { t.Parallel() diff --git a/parser.go b/parser.go index 5b6e3ba..eefbf03 100644 --- a/parser.go +++ b/parser.go @@ -2,7 +2,6 @@ package toml import ( "bytes" - "fmt" "strconv" "github.com/pelletier/go-toml/v2/internal/ast" @@ -77,7 +76,6 @@ func (p *parser) parseNewline(b []byte) ([]byte, error) { if b[0] == '\r' { _, rest, err := scanWindowsNewline(b) - return rest, err } @@ -206,6 +204,10 @@ func (p *parser) parseKeyval(b []byte) (ast.Reference, []byte, error) { b = p.parseWhitespace(b) + if len(b) == 0 { + return ast.Reference{}, nil, newDecodeError(b, "expected = after a key, but the document ends there") + } + b, err = expect('=', b) if err != nil { return ast.Reference{}, nil, err @@ -304,6 +306,7 @@ func atmost(b []byte, n int) []byte { if n >= len(b) { return b } + return b[:n] } @@ -397,8 +400,7 @@ func (p *parser) parseValArray(b []byte) (ast.Reference, []byte, error) { } if len(b) == 0 { - //nolint:godox - return parent, nil, unexpectedCharacter{b: b} // TODO: should be unexpected EOF + return parent, nil, newDecodeError(b, "array is incomplete") } if b[0] == ']' { @@ -562,7 +564,7 @@ func (p *parser) parseMultilineBasicString(b []byte) ([]byte, []byte, error) { case 't': builder.WriteByte('\t') case 'u': - x, err := hexToString(token[i+3:len(token)-3], 4) + x, err := hexToString(atmost(token[i+1:], 4), 4) if err != nil { return nil, nil, err } @@ -570,7 +572,7 @@ func (p *parser) parseMultilineBasicString(b []byte) ([]byte, []byte, error) { builder.WriteString(x) i += 4 case 'U': - x, err := hexToString(token[i+3:len(token)-3], 8) + x, err := hexToString(atmost(token[i+1:], 8), 8) if err != nil { return nil, nil, err } @@ -610,12 +612,7 @@ func (p *parser) parseKey(b []byte) (ast.Reference, []byte, error) { for { b = p.parseWhitespace(b) if len(b) > 0 && b[0] == '.' { - b, err = expect('.', b) - if err != nil { - return ref, nil, err - } - - b = p.parseWhitespace(b) + b = p.parseWhitespace(b[1:]) key, b, err = p.parseSimpleKey(b) if err != nil { @@ -639,8 +636,7 @@ func (p *parser) parseSimpleKey(b []byte) (key, rest []byte, err error) { // unquoted-key = 1*( ALPHA / DIGIT / %x2D / %x5F ) ; A-Z / a-z / 0-9 / - / _ // quoted-key = basic-string / literal-string if len(b) == 0 { - //nolint:godox - return nil, nil, unexpectedCharacter{b: b} // TODO: should be unexpected EOF + return nil, nil, newDecodeError(b, "key is incomplete") } switch { @@ -649,10 +645,10 @@ func (p *parser) parseSimpleKey(b []byte) (key, rest []byte, err error) { case b[0] == '"': return p.parseBasicString(b) case isUnquotedKeyChar(b[0]): - return scanUnquotedKey(b) + key, rest = scanUnquotedKey(b) + return key, rest, nil default: - //nolint:godox - return nil, nil, unexpectedCharacter{b: b} // TODO: should be unexpected EOF + return nil, nil, newDecodeError(b[0:1], "invalid character at start of key: %c", b[0]) } } @@ -825,11 +821,14 @@ byteLoop: c := b[i] switch { - case isDigit(c) || c == '-': + case isDigit(c): + case c == '-': + const offsetOfTz = 19 + if i == offsetOfTz { + hasTz = true + } case c == 'T' || c == ':' || c == '.': hasTime = true - - continue byteLoop case c == '+' || c == '-' || c == 'Z': hasTz = true case c == ' ': @@ -854,9 +853,6 @@ byteLoop: kind = ast.LocalDateTime } } else { - if hasTz { - return ast.Reference{}, nil, newDecodeError(b, "date-time has timezone but not time component") - } kind = ast.LocalDate } @@ -977,26 +973,9 @@ func isValidBinaryRune(r byte) bool { } func expect(x byte, b []byte) ([]byte, error) { - if len(b) == 0 { - return nil, newDecodeError(b[:0], "expecting %#U", x) - } - if b[0] != x { return nil, newDecodeError(b[0:1], "expected character %U", x) } return b[1:], nil } - -type unexpectedCharacter struct { - r byte - b []byte -} - -func (u unexpectedCharacter) Error() string { - if len(u.b) == 0 { - return fmt.Sprintf("expected %#U, not EOF", u.r) - } - - return fmt.Sprintf("expected %#U, not %#U", u.r, u.b[0]) -} diff --git a/scanner.go b/scanner.go index 047aef5..b203702 100644 --- a/scanner.go +++ b/scanner.go @@ -30,15 +30,15 @@ func scanFollowsNan(b []byte) bool { return scanFollows(b, `nan`) } -func scanUnquotedKey(b []byte) ([]byte, []byte, error) { +func scanUnquotedKey(b []byte) ([]byte, []byte) { // unquoted-key = 1*( ALPHA / DIGIT / %x2D / %x5F ) ; A-Z / a-z / 0-9 / - / _ for i := 0; i < len(b); i++ { if !isUnquotedKeyChar(b[i]) { - return b[:i], b[i:], nil + return b[:i], b[i:] } } - return b, b[len(b):], nil + return b, b[len(b):] } func isUnquotedKeyChar(r byte) bool { diff --git a/targets.go b/targets.go index fd3eaec..370892f 100644 --- a/targets.go +++ b/targets.go @@ -70,19 +70,19 @@ func (t interfaceTarget) set(v reflect.Value) { } func (t interfaceTarget) setString(v string) { - t.x.setString(v) + panic("interface targets should always go through set") } func (t interfaceTarget) setBool(v bool) { - t.x.setBool(v) + panic("interface targets should always go through set") } func (t interfaceTarget) setInt64(v int64) { - t.x.setInt64(v) + panic("interface targets should always go through set") } func (t interfaceTarget) setFloat64(v float64) { - t.x.setFloat64(v) + panic("interface targets should always go through set") } // mapTarget targets a specific key of a map. @@ -115,7 +115,6 @@ func (t mapTarget) setFloat64(v float64) { t.set(reflect.ValueOf(v)) } -//nolint:cyclop // makes sure that the value pointed at by t is indexable (Slice, Array), or // dereferences to an indexable (Ptr, Interface). func ensureValueIndexable(t target) error { @@ -193,7 +192,7 @@ const ( minInt = -maxInt - 1 ) -//nolint:funlen,gocognit,cyclop,gocyclo +//nolint:funlen,gocognit,cyclop func setInt64(t target, v int64) error { f := t.get() @@ -285,7 +284,6 @@ func setFloat64(t target, v float64) error { return nil } -//nolint:cyclop // Returns the element at idx of the value pointed at by target, or an error if // t does not point to an indexable. // If the target points to an Array and idx is out of bounds, it returns @@ -311,7 +309,6 @@ func elementAt(t target, idx int) target { case reflect.Interface: // This function is called after ensureValueIndexable, so it's // guaranteed that f contains an initialized slice. - ifaceElem := f.Elem() idx := ifaceElem.Len() newElem := reflect.New(ifaceElem.Type().Elem()).Elem() @@ -326,7 +323,6 @@ func elementAt(t target, idx int) target { } } -//nolint:cyclop func (d *decoder) scopeTableTarget(shouldAppend bool, t target, name string) (target, bool, error) { x := t.get() diff --git a/unmarshaler.go b/unmarshaler.go index 3d25a86..6a156ae 100644 --- a/unmarshaler.go +++ b/unmarshaler.go @@ -541,6 +541,27 @@ func (d *decoder) unmarshalArray(x target, node ast.Node) error { return err } + // Special work around when unmarshaling into an array. + // If the array is not addressable, for example when stored as a value in a + // map, calling elementAt in the inner function would fail. + // Instead, we allocate a new array that will be filled then inserted into + // the container. + // This problem does not exist with slices because they are addressable. + // There may be a better way of doing this, but it is not obvious to me + // with the target system. + if x.get().Kind() == reflect.Array { + container := x + newArrayPtr := reflect.New(x.get().Type()) + x = valueTarget(newArrayPtr.Elem()) + defer func() { + container.set(newArrayPtr.Elem()) + }() + } + + return d.unmarshalArrayInner(x, node) +} + +func (d *decoder) unmarshalArrayInner(x target, node ast.Node) error { idx := 0 it := node.Children() @@ -555,14 +576,13 @@ func (d *decoder) unmarshalArray(x target, node ast.Node) error { break } - err = d.unmarshalValue(v, n) + err := d.unmarshalValue(v, n) if err != nil { return err } idx++ } - return nil } diff --git a/unmarshaler_test.go b/unmarshaler_test.go index 338c0f5..6346b8e 100644 --- a/unmarshaler_test.go +++ b/unmarshaler_test.go @@ -14,6 +14,7 @@ import ( "github.com/stretchr/testify/require" ) +// nolint:funlen func TestUnmarshal_Integers(t *testing.T) { t.Parallel() @@ -239,6 +240,34 @@ func TestUnmarshal(t *testing.T) { } }, }, + { + desc: "time.time with negative zone", + input: `a = 1979-05-27T00:32:00-07:00 `, // space intentional + gen: func() test { + var v map[string]time.Time + + return test{ + target: &v, + expected: &map[string]time.Time{ + "a": time.Date(1979, 5, 27, 0, 32, 0, 0, time.FixedZone("", -7*3600)), + }, + } + }, + }, + { + desc: "time.time with positive zone", + input: `a = 1979-05-27T00:32:00+07:00`, + gen: func() test { + var v map[string]time.Time + + return test{ + target: &v, + expected: &map[string]time.Time{ + "a": time.Date(1979, 5, 27, 0, 32, 0, 0, time.FixedZone("", 7*3600)), + }, + } + }, + }, { desc: "issue 475 - space between dots in key", input: `fruit. color = "yellow" @@ -288,6 +317,73 @@ func TestUnmarshal(t *testing.T) { } }, }, + { + desc: "multiline literal string with windows newline", + input: "A = '''\r\nTest'''", + gen: func() test { + type doc struct { + A string + } + + return test{ + target: &doc{}, + expected: &doc{A: "Test"}, + } + }, + }, + { + desc: "multiline basic string with windows newline", + input: "A = \"\"\"\r\nTest\"\"\"", + gen: func() test { + type doc struct { + A string + } + + return test{ + target: &doc{}, + expected: &doc{A: "Test"}, + } + }, + }, + { + desc: "multiline basic string escapes", + input: `A = """ +\\\b\f\n\r\t\uffff\U0001D11E"""`, + gen: func() test { + type doc struct { + A string + } + + return test{ + target: &doc{}, + expected: &doc{A: "\\\b\f\n\r\t\uffff\U0001D11E"}, + } + }, + }, + { + desc: "basic string escapes", + input: `A = "\\\b\f\n\r\t\uffff\U0001D11E"`, + gen: func() test { + type doc struct { + A string + } + + return test{ + target: &doc{}, + expected: &doc{A: "\\\b\f\n\r\t\uffff\U0001D11E"}, + } + }, + }, + { + desc: "spaces around dotted keys", + input: "a . b = 1", + gen: func() test { + return test{ + target: &map[string]map[string]interface{}{}, + expected: &map[string]map[string]interface{}{"a": {"b": int64(1)}}, + } + }, + }, { desc: "kv bool true", input: `A = true`, @@ -721,6 +817,197 @@ B = "data"`, } }, }, + { + desc: "interface holding a string", + input: `A = "Hello"`, + gen: func() test { + type doc struct { + A interface{} + } + return test{ + target: &doc{}, + expected: &doc{ + A: "Hello", + }, + } + }, + }, + { + desc: "map of bools", + input: `A = true`, + gen: func() test { + return test{ + target: &map[string]bool{}, + expected: &map[string]bool{"A": true}, + } + }, + }, + { + desc: "map of int64", + input: `A = 42`, + gen: func() test { + return test{ + target: &map[string]int64{}, + expected: &map[string]int64{"A": 42}, + } + }, + }, + { + desc: "map of float64", + input: `A = 4.2`, + gen: func() test { + return test{ + target: &map[string]float64{}, + expected: &map[string]float64{"A": 4.2}, + } + }, + }, + { + desc: "array of int in map", + input: `A = [1,2,3]`, + gen: func() test { + return test{ + target: &map[string][3]int{}, + expected: &map[string][3]int{"A": {1, 2, 3}}, + } + }, + }, + { + desc: "array of int in map with too many elements", + input: `A = [1,2,3,4,5]`, + gen: func() test { + return test{ + target: &map[string][3]int{}, + expected: &map[string][3]int{"A": {1, 2, 3}}, + } + }, + }, + { + desc: "array of int in map with invalid element", + input: `A = [1,2,false]`, + gen: func() test { + return test{ + target: &map[string][3]int{}, + err: true, + } + }, + }, + { + desc: "nested arrays", + input: ` + [[A]] + [[A.B]] + C = 1 + [[A]] + [[A.B]] + C = 2`, + gen: func() test { + type leaf struct { + C int + } + type inner struct { + B [2]leaf + } + type s struct { + A [2]inner + } + return test{ + target: &s{}, + expected: &s{A: [2]inner{ + {B: [2]leaf{ + {C: 1}, + }}, + {B: [2]leaf{ + {C: 2}, + }}, + }}, + } + }, + }, + { + desc: "nested arrays too many", + input: ` + [[A]] + [[A.B]] + C = 1 + [[A.B]] + C = 2`, + gen: func() test { + type leaf struct { + C int + } + type inner struct { + B [1]leaf + } + type s struct { + A [1]inner + } + return test{ + target: &s{}, + err: true, + } + }, + }, + { + desc: "into map with invalid key type", + input: `A = "hello"`, + gen: func() test { + return test{ + target: &map[int]string{}, + err: true, + } + }, + }, + { + desc: "into map with convertible key type", + input: `A = "hello"`, + gen: func() test { + type foo string + return test{ + target: &map[foo]string{}, + expected: &map[foo]string{ + "A": "hello", + }, + } + }, + }, + { + desc: "array of int in struct", + input: `A = [1,2,3]`, + gen: func() test { + type s struct { + A [3]int + } + return test{ + target: &s{}, + expected: &s{A: [3]int{1, 2, 3}}, + } + }, + }, + { + desc: "array of int in struct", + input: `[A] + b = 42`, + gen: func() test { + type s struct { + A *map[string]interface{} + } + return test{ + target: &s{}, + expected: &s{A: &map[string]interface{}{"b": int64(42)}}, + } + }, + }, + { + desc: "assign bool to float", + input: `A = true`, + gen: func() test { + return test{ + target: &map[string]float64{}, + err: true, + } + }, + }, { desc: "interface holding a struct", input: `[A] @@ -877,6 +1164,82 @@ B = "data"`, } } +func TestUnmarshalOverflows(t *testing.T) { + examples := []struct { + t interface{} + errors []string + }{ + { + t: &map[string]int32{}, + errors: []string{`-2147483649`, `2147483649`}, + }, + { + t: &map[string]int16{}, + errors: []string{`-2147483649`, `2147483649`}, + }, + { + t: &map[string]int8{}, + errors: []string{`-2147483649`, `2147483649`}, + }, + { + t: &map[string]int{}, + errors: []string{`-19223372036854775808`, `9223372036854775808`}, + }, + { + t: &map[string]uint64{}, + errors: []string{`-1`, `18446744073709551616`}, + }, + { + t: &map[string]uint32{}, + errors: []string{`-1`, `18446744073709551616`}, + }, + { + t: &map[string]uint16{}, + errors: []string{`-1`, `18446744073709551616`}, + }, + { + t: &map[string]uint8{}, + errors: []string{`-1`, `18446744073709551616`}, + }, + { + t: &map[string]uint{}, + errors: []string{`-1`, `18446744073709551616`}, + }, + } + + for _, e := range examples { + e := e + for _, v := range e.errors { + v := v + t.Run(fmt.Sprintf("%T %s", e.t, v), func(t *testing.T) { + doc := "A = " + v + err := toml.Unmarshal([]byte(doc), e.t) + t.Log("input:", doc) + require.Error(t, err) + }) + } + t.Run(fmt.Sprintf("%T ok", e.t), func(t *testing.T) { + doc := "A = 1" + err := toml.Unmarshal([]byte(doc), e.t) + t.Log("input:", doc) + require.NoError(t, err) + }) + } +} + +func TestUnmarshalFloat32(t *testing.T) { + t.Run("fits", func(t *testing.T) { + doc := "A = 1.2" + err := toml.Unmarshal([]byte(doc), &map[string]float32{}) + require.NoError(t, err) + }) + t.Run("overflows", func(t *testing.T) { + doc := "A = 4.40282346638528859811704183484516925440e+38" + err := toml.Unmarshal([]byte(doc), &map[string]float32{}) + require.Error(t, err) + }) +} + type Integer484 struct { Value int } @@ -999,10 +1362,66 @@ func TestUnmarshalDecodeErrors(t *testing.T) { data string msg string }{ + { + desc: "local date with invalid digit", + data: `a = 20x1-05-21`, + }, + { + desc: "local time with fractional", + data: `a = 11:22:33.x`, + }, + { + desc: "local time frac precision too large", + data: `a = 2021-05-09T11:22:33.99999999999`, + }, + { + desc: "wrong time offset separator", + data: `a = 1979-05-27T00:32:00T07:00`, + }, + { + desc: "wrong time offset separator", + data: `a = 1979-05-27T00:32:00Z07:00`, + }, + { + desc: "float with double _", + data: `flt8 = 224_617.445_991__228`, + }, + { + desc: "float with double _", + data: `flt8 = 1..2`, + }, { desc: "int with wrong base", data: `a = 0f2`, }, + { + desc: "int hex with double underscore", + data: `a = 0xFFF__FFF`, + }, + { + desc: "int hex very large", + data: `a = 0xFFFFFFFFFFFFFFFFF`, + }, + { + desc: "int oct with double underscore", + data: `a = 0o777__77`, + }, + { + desc: "int oct very large", + data: `a = 0o77777777777777777777777`, + }, + { + desc: "int bin with double underscore", + data: `a = 0b111__111`, + }, + { + desc: "int bin very large", + data: `a = 0b11111111111111111111111111111111111111111111111111111111111111111111111111111`, + }, + { + desc: "int dec very large", + data: `a = 999999999999999999999999`, + }, { desc: "literal string with new lines", data: `a = 'hello @@ -1065,6 +1484,102 @@ world'`, data: `a = 2021-03-30 21:312:0`, msg: `expecting colon between minutes and seconds`, }, + { + desc: `binary with invalid digit`, + data: `a = 0bf`, + }, + { + desc: `invalid i in dec`, + data: `a = 0i`, + }, + { + desc: `invalid n in dec`, + data: `a = 0n`, + }, + { + desc: `invalid unquoted key`, + data: `a`, + }, + { + desc: "dt with tz has no time", + data: `a = 2021-03-30TZ`, + }, + { + desc: "invalid end of array table", + data: `[[a}`, + }, + { + desc: "invalid end of array table two", + data: `[[a]}`, + }, + { + desc: "eof after equal", + data: `a =`, + }, + { + desc: "invalid true boolean", + data: `a = trois`, + }, + { + desc: "invalid false boolean", + data: `a = faux`, + }, + { + desc: "inline table with incorrect separator", + data: `a = {b=1;}`, + }, + { + desc: "inline table with invalid value", + data: `a = {b=faux}`, + }, + { + desc: `incomplete array after whitespace`, + data: `a = [ `, + }, + { + desc: `array with comma first`, + data: `a = [ ,]`, + }, + { + desc: `array staring with incomplete newline`, + data: "a = [\r]", + }, + { + desc: `array with incomplete newline after comma`, + data: "a = [1,\r]", + }, + { + desc: `array with incomplete newline after value`, + data: "a = [1\r]", + }, + { + desc: `invalid unicode in basic multiline string`, + data: `A = """\u123"""`, + }, + { + desc: `invalid long unicode in basic multiline string`, + data: `A = """\U0001D11"""`, + }, + { + desc: `invalid unicode in basic string`, + data: `A = "\u123"`, + }, + { + desc: `invalid long unicode in basic string`, + data: `A = "\U0001D11"`, + }, + { + desc: `invalid escape char basic multiline string`, + data: `A = """\z"""`, + }, + { + desc: `invalid inf`, + data: `A = ick`, + }, + { + desc: `invalid nan`, + data: `A = non`, + }, } for _, e := range examples { @@ -1270,21 +1785,35 @@ bar = 42 t.Run(e.desc, func(t *testing.T) { t.Parallel() - r := strings.NewReader(e.input) - d := toml.NewDecoder(r) - d.SetStrict(true) - x := e.target - if x == nil { - x = &struct{}{} - } - err := d.Decode(x) + t.Run("strict", func(t *testing.T) { + r := strings.NewReader(e.input) + d := toml.NewDecoder(r) + d.SetStrict(true) + x := e.target + if x == nil { + x = &struct{}{} + } + err := d.Decode(x) - var tsm *toml.StrictMissingError - if errors.As(err, &tsm) { - equalStringsIgnoreNewlines(t, e.expected, tsm.String()) - } else { - t.Fatalf("err should have been a *toml.StrictMissingError, but got %s (%T)", err, err) - } + var tsm *toml.StrictMissingError + if errors.As(err, &tsm) { + equalStringsIgnoreNewlines(t, e.expected, tsm.String()) + } else { + t.Fatalf("err should have been a *toml.StrictMissingError, but got %s (%T)", err, err) + } + }) + + t.Run("default", func(t *testing.T) { + r := strings.NewReader(e.input) + d := toml.NewDecoder(r) + d.SetStrict(false) + x := e.target + if x == nil { + x = &struct{}{} + } + err := d.Decode(x) + require.NoError(t, err) + }) }) } }