From 32da85ab11695b140552ebce61d90f0f8635a8b0 Mon Sep 17 00:00:00 2001 From: Thomas Pelletier Date: Tue, 30 Mar 2021 21:43:57 -0400 Subject: [PATCH] Decoding error position tracking --- README.md | 2 +- decode.go | 12 ++--- errors.go | 10 +++- internal/unsafe/unsafe.go | 2 +- parser.go | 13 +++-- scanner.go | 18 +++---- unmarshaler.go | 17 +++++- unmarshaler_test.go | 110 ++++++++++++++++++++++++++++++++++---- 8 files changed, 150 insertions(+), 34 deletions(-) diff --git a/README.md b/README.md index 04d6b4b..d467703 100644 --- a/README.md +++ b/README.md @@ -17,7 +17,7 @@ Development branch. Probably does not work. - [x] Benchmark! - [x] Abstract AST. - [x] Original go-toml testgen tests pass. -- [ ] Track file position (line, column) for errors. +- [x] Track file position (line, column) for errors. - [ ] Attach comments to AST (gated by parser flag). - [ ] Benchmark again! diff --git a/decode.go b/decode.go index 58bca71..d4b8559 100644 --- a/decode.go +++ b/decode.go @@ -34,7 +34,7 @@ func parseLocalDate(b []byte) (LocalDate, error) { date := LocalDate{} if len(b) != 10 || b[4] != '-' || b[7] != '-' { - return date, fmt.Errorf("dates are expected to have the format YYYY-MM-DD") + return date, newDecodeError(b, "dates are expected to have the format YYYY-MM-DD") } var err error @@ -89,7 +89,7 @@ func parseDateTime(b []byte) (time.Time, error) { zone = time.UTC } else { if len(b) != 6 { - return time.Time{}, fmt.Errorf("invalid date-time timezone") + return time.Time{}, newDecodeError(b, "invalid date-time timezone") } direction := 1 switch b[0] { @@ -97,7 +97,7 @@ func parseDateTime(b []byte) (time.Time, error) { case '-': direction = -1 default: - return time.Time{}, fmt.Errorf("invalid timezone offset character") + return time.Time{}, newDecodeError(b[0:1], "invalid timezone offset character") } hours := digitsToInt(b[1:3]) @@ -107,7 +107,7 @@ func parseDateTime(b []byte) (time.Time, error) { } if len(b) > 0 { - return time.Time{}, fmt.Errorf("extra bytes at the end of the timezone") + return time.Time{}, newDecodeError(b, "extra bytes at the end of the timezone") } t := time.Date( @@ -166,14 +166,14 @@ func parseLocalTime(b []byte) (LocalTime, []byte, error) { return t, nil, err } if b[2] != ':' { - return t, nil, fmt.Errorf("expecting colon between hours and minutes") + return t, nil, newDecodeError(b[2:3], "expecting colon between hours and minutes") } t.Minute, err = parseDecimalDigits(b[3:5]) if err != nil { return t, nil, err } if b[5] != ':' { - return t, nil, fmt.Errorf("expecting colon between minutes and seconds") + return t, nil, newDecodeError(b[5:6], "expecting colon between minutes and seconds") } t.Second, err = parseDecimalDigits(b[6:8]) if err != nil { diff --git a/errors.go b/errors.go index 8d31594..cc0baba 100644 --- a/errors.go +++ b/errors.go @@ -129,8 +129,16 @@ func formatLineNumber(line int, width int) string { func linesOfContext(document []byte, highlight []byte, offset int, linesAround int) ([][]byte, [][]byte) { var beforeLines [][]byte for beforeOffset, lastOffset := offset, offset; beforeOffset >= 0 && len(beforeLines) <= linesAround; beforeOffset-- { + if beforeOffset == len(document) { + beforeLines = append(beforeLines, []byte{}) + continue + } if document[beforeOffset] == '\n' { - beforeLines = append(beforeLines, document[beforeOffset+1:lastOffset]) + if beforeOffset == lastOffset { + beforeLines = append(beforeLines, []byte{}) + } else { + beforeLines = append(beforeLines, document[beforeOffset+1:lastOffset]) + } lastOffset = beforeOffset } else if beforeOffset == 0 && beforeOffset != lastOffset { beforeLines = append(beforeLines, document[beforeOffset:lastOffset]) diff --git a/internal/unsafe/unsafe.go b/internal/unsafe/unsafe.go index 179fd1a..ce6b955 100644 --- a/internal/unsafe/unsafe.go +++ b/internal/unsafe/unsafe.go @@ -23,7 +23,7 @@ func SubsliceOffset(data []byte, subslice []byte) int { intoffset := int(offset) - if intoffset >= datap.Len { + if intoffset > datap.Len { panic(fmt.Errorf("slice offset (%d) is farther than data length (%d)", intoffset, datap.Len)) } diff --git a/parser.go b/parser.go index 31e69cf..bb9a37e 100644 --- a/parser.go +++ b/parser.go @@ -363,7 +363,7 @@ func (p *parser) parseValArray(b []byte) (ast.Reference, []byte, error) { } if len(b) == 0 { - return parent, nil, unexpectedCharacter{b: b} + return parent, nil, unexpectedCharacter{b: b} // TODO: should be unexpected EOF } if b[0] == ']' { @@ -590,7 +590,7 @@ func (p *parser) parseSimpleKey(b []byte) (key, rest []byte, err error) { //quoted-key = basic-string / literal-string if len(b) == 0 { - return nil, nil, unexpectedCharacter{b: b} + return nil, nil, unexpectedCharacter{b: b} // TODO: should be unexpected EOF } if b[0] == '\'' { @@ -600,7 +600,7 @@ func (p *parser) parseSimpleKey(b []byte) (key, rest []byte, err error) { } else if isUnquotedKeyChar(b[0]) { key, rest, err = scanUnquotedKey(b) } else { - err = unexpectedCharacter{b: b} + err = unexpectedCharacter{b: b} // TODO: should contain expected characters } return } @@ -1158,8 +1158,11 @@ func isValidBinaryRune(r byte) bool { } func expect(x byte, b []byte) ([]byte, error) { - if len(b) == 0 || b[0] != x { - return nil, unexpectedCharacter{r: x, b: b} + if len(b) == 0 { + return nil, newDecodeError(b[:0], "expecting %#U", x) + } + if b[0] != x { + return nil, newDecodeError(b[0:1], "expected character %U", x) } return b[1:], nil } diff --git a/scanner.go b/scanner.go index ba857c6..bad47bc 100644 --- a/scanner.go +++ b/scanner.go @@ -30,7 +30,7 @@ func scanUnquotedKey(b []byte) ([]byte, []byte, error) { return b[:i], b[i:], nil } } - return b, nil, nil + return b, b[len(b):], nil } func isUnquotedKeyChar(r byte) bool { @@ -46,10 +46,10 @@ func scanLiteralString(b []byte) ([]byte, []byte, error) { case '\'': return b[:i+1], b[i+1:], nil case '\n': - return nil, nil, fmt.Errorf("literal strings cannot have new lines") + return nil, nil, newDecodeError(b[i:i+1], "literal strings cannot have new lines") } } - return nil, nil, fmt.Errorf("unterminated literal string") + return nil, nil, newDecodeError(b[len(b):], "unterminated literal string") } func scanMultilineLiteralString(b []byte) ([]byte, []byte, error) { @@ -70,7 +70,7 @@ func scanMultilineLiteralString(b []byte) ([]byte, []byte, error) { } } - return nil, nil, fmt.Errorf(`multiline literal string not terminated by '''`) + return nil, nil, newDecodeError(b[len(b):], `multiline literal string not terminated by '''`) } func scanWindowsNewline(b []byte) ([]byte, []byte, error) { @@ -92,7 +92,7 @@ func scanWhitespace(b []byte) ([]byte, []byte) { return b[:i], b[i:] } } - return b, nil + return b, b[len(b):] } func scanComment(b []byte) ([]byte, []byte, error) { @@ -125,10 +125,10 @@ func scanBasicString(b []byte) ([]byte, []byte, error) { case '"': return b[:i+1], b[i+1:], nil case '\n': - return nil, nil, fmt.Errorf("basic strings cannot have new lines") + return nil, nil, newDecodeError(b[i:i+1], "basic strings cannot have new lines") case '\\': if len(b) < i+2 { - return nil, nil, fmt.Errorf("need a character after \\") + return nil, nil, newDecodeError(b[i:i+1], "need a character after \\") } i++ // skip the next character } @@ -158,11 +158,11 @@ func scanMultilineBasicString(b []byte) ([]byte, []byte, error) { } case '\\': if len(b) < i+2 { - return nil, nil, fmt.Errorf("need a character after \\") + return nil, nil, newDecodeError(b[len(b):], "need a character after \\") } i++ // skip the next character } } - return nil, nil, fmt.Errorf(`multiline basic string not terminated by """`) + return nil, nil, newDecodeError(b[len(b):], `multiline basic string not terminated by """`) } diff --git a/unmarshaler.go b/unmarshaler.go index b95ce5a..0f89703 100644 --- a/unmarshaler.go +++ b/unmarshaler.go @@ -268,11 +268,22 @@ func (d *decoder) unmarshalValue(x target, node ast.Node) error { return unmarshalLocalDateTime(x, node) case ast.DateTime: return unmarshalDateTime(x, node) + case ast.LocalDate: + return unmarshalLocalDate(x, node) default: panic(fmt.Errorf("unhandled unmarshalValue kind %s", node.Kind)) } } +func unmarshalLocalDate(x target, node ast.Node) error { + assertNode(ast.LocalDate, node) + v, err := parseLocalDate(node.Data) + if err != nil { + return err + } + return setDate(x, v) +} + func unmarshalLocalDateTime(x target, node ast.Node) error { assertNode(ast.LocalDateTime, node) v, rest, err := parseLocalDateTime(node.Data) @@ -280,7 +291,7 @@ func unmarshalLocalDateTime(x target, node ast.Node) error { return err } if len(rest) > 0 { - return fmt.Errorf("extra characters at the end of a local date time") + return newDecodeError(rest, "extra characters at the end of a local date time") } return setLocalDateTime(x, v) } @@ -302,6 +313,10 @@ func setDateTime(x target, v time.Time) error { return x.set(reflect.ValueOf(v)) } +func setDate(x target, v LocalDate) error { + return x.set(reflect.ValueOf(v)) +} + func unmarshalString(x target, node ast.Node) error { assertNode(ast.String, node) return setString(x, string(node.Data)) diff --git a/unmarshaler_test.go b/unmarshaler_test.go index d0c1aaa..dab8957 100644 --- a/unmarshaler_test.go +++ b/unmarshaler_test.go @@ -200,10 +200,10 @@ func TestUnmarshal(t *testing.T) { gen: func() test { m := map[string]interface{}{} return test{ - target: &m, + target: &m, expected: &map[string]interface{}{ "fruit": map[string]interface{}{ - "color": "yellow", + "color": "yellow", "flavor": "banana", }, }, @@ -217,7 +217,7 @@ func TestUnmarshal(t *testing.T) { gen: func() test { m := map[string]interface{}{} return test{ - target: &m, + target: &m, expected: &map[string]interface{}{ `"a"`: int64(1), `"b"`: int64(2), @@ -226,7 +226,7 @@ func TestUnmarshal(t *testing.T) { }, }, { - desc: "multiline basic string", + desc: "multiline basic string", input: `A = """\ Test"""`, gen: func() test { @@ -705,7 +705,6 @@ B = "data"`, } } - type Integer484 struct { Value int } @@ -726,7 +725,7 @@ type Config484 struct { Integers []Integer484 `toml:"integers"` } -func TestIssue484(t *testing.T) { +func TestIssue484(t *testing.T) { raw := []byte(`integers = ["1","2","3","100"]`) var cfg Config484 err := toml.Unmarshal(raw, &cfg) @@ -753,10 +752,101 @@ version = "0.1.0"`) require.NoError(t, err) a := m.A("package") expected := Slice458{ - map[string]interface {}{ - "dependencies": []interface {}{"regex"}, - "name":"decode", - "version":"0.1.0"}, + map[string]interface{}{ + "dependencies": []interface{}{"regex"}, + "name": "decode", + "version": "0.1.0"}, } assert.Equal(t, expected, a) } + +func TestUnmarshalDecodeErrors(t *testing.T) { + examples := []struct { + desc string + data string + msg string + }{ + { + desc: "int with wrong base", + data: `a = 0f2`, + }, + { + desc: "literal string with new lines", + data: `a = 'hello +world'`, + msg: `literal strings cannot have new lines`, + }, + { + desc: "unterminated literal string", + data: `a = 'hello`, + msg: `unterminated literal string`, + }, + { + desc: "unterminated multiline literal string", + data: `a = '''hello`, + msg: `multiline literal string not terminated by '''`, + }, + { + desc: "basic string with new lines", + data: `a = "hello +"`, + msg: `basic strings cannot have new lines`, + }, + { + desc: "basic string with unfinished escape", + data: `a = "hello \`, + msg: `need a character after \`, + }, + { + desc: "basic unfinished multiline string", + data: `a = """hello`, + msg: `multiline basic string not terminated by """`, + }, + { + desc: "basic unfinished escape in multiline string", + data: `a = """hello \`, + msg: `need a character after \`, + }, + { + desc: "malformed local date", + data: `a = 2021-033-0`, + msg: `dates are expected to have the format YYYY-MM-DD`, + }, + { + desc: "malformed tz", + data: `a = 2021-03-30 21:31:00+1`, + msg: `invalid date-time timezone`, + }, + { + desc: "malformed tz first char", + data: `a = 2021-03-30 21:31:00:1`, + msg: `extra characters at the end of a local date time`, + }, + { + desc: "bad char between hours and minutes", + data: `a = 2021-03-30 213:1:00`, + msg: `expecting colon between hours and minutes`, + }, + { + desc: "bad char between minutes and seconds", + data: `a = 2021-03-30 21:312:0`, + msg: `expecting colon between minutes and seconds`, + }, + } + + for _, e := range examples { + t.Run(e.desc, func(t *testing.T) { + m := map[string]interface{}{} + err := toml.Unmarshal([]byte(e.data), &m) + require.Error(t, err) + de, ok := err.(*toml.DecodeError) + if !ok { + t.Fatalf("err should have been a *toml.DecodeError, but got %s (%T)", err, err) + } + if e.msg != "" { + t.Log("\n" + de.String()) + require.Equal(t, e.msg, de.Error()) + } + }) + } +}