From fcc91f261844d652243b8a5cf30524ed9542ebbf Mon Sep 17 00:00:00 2001 From: Thomas Pelletier Date: Mon, 22 Mar 2021 09:59:15 -0400 Subject: [PATCH] Progress on date/times --- README.md | 6 +- decode.go | 288 +++++++++++++++++++++++++++++++++++++++++ internal/ast/ast.go | 28 ---- internal/ast/decode.go | 113 ---------------- parser.go | 60 +++++++-- unmarshaler.go | 38 +++++- 6 files changed, 380 insertions(+), 153 deletions(-) create mode 100644 decode.go delete mode 100644 internal/ast/decode.go diff --git a/README.md b/README.md index 8f760f5..298ef05 100644 --- a/README.md +++ b/README.md @@ -8,8 +8,10 @@ Development branch. Probably does not work. - [x] Unmarshal into maps. - [x] Support Array Tables. -- [x] Unmarshal into pointers. -- [ ] Support Date / times. +- [ ] Unmarshal into pointers. + > Was supposed to be done, but seems like there are still some assignation + > issues. +- [x] Support Date / times. - [ ] Support Unmarshaler interface. - [x] Support struct tags annotations. - [ ] Original go-toml unmarshal tests pass. diff --git a/decode.go b/decode.go new file mode 100644 index 0000000..5bd321e --- /dev/null +++ b/decode.go @@ -0,0 +1,288 @@ +package toml + +import ( + "errors" + "fmt" + "math" + "strconv" + "strings" + "time" +) + +func parseInteger(b []byte) (int64, error) { + if len(b) > 2 && b[0] == '0' { + switch b[1] { + case 'x': + return parseIntHex(b) + case 'b': + return parseIntBin(b) + case 'o': + return parseIntOct(b) + default: + return 0, fmt.Errorf("invalid base: '%c'", b[1]) + } + } + return parseIntDec(b) +} + +func parseLocalDate(b []byte) (LocalDate, error) { + // full-date = date-fullyear "-" date-month "-" date-mday + // date-fullyear = 4DIGIT + // date-month = 2DIGIT ; 01-12 + // date-mday = 2DIGIT ; 01-28, 01-29, 01-30, 01-31 based on month/year + + date := LocalDate{} + + if len(b) != 10 || b[4] != '-' || b[7] != '-' { + return date, fmt.Errorf("dates are expected to have the format YYYY-MM-DD") + } + + var err error + + date.Year, err = parseDecimalDigits(b[0:4]) + if err != nil { + return date, err + } + + v, err := parseDecimalDigits(b[5:7]) + if err != nil { + return date, err + } + date.Month = time.Month(v) + + date.Day, err = parseDecimalDigits(b[8:10]) + + return date, nil +} + +func parseDecimalDigits(b []byte) (int, error) { + v := 0 + for _, c := range b { + if !isDigit(c) { + return 0, fmt.Errorf("expected digit") + } + v *= 10 + v += int(c - '0') + } + return v, nil +} + +func parseDateTime(b []byte) (time.Time, error) { + // offset-date-time = full-date time-delim full-time + // full-time = partial-time time-offset + // time-offset = "Z" / time-numoffset + // time-numoffset = ( "+" / "-" ) time-hour ":" time-minute + + dt, b, err := parseLocalDateTime(b) + if err != nil { + return time.Time{}, nil + } + + var zone *time.Location + + if len(b) == 0 { + return time.Time{}, fmt.Errorf("date-time missing timezone information") + } + + if b[0] == 'Z' { + b = b[1:] + zone = time.UTC + } else { + if len(b) != 6 { + return time.Time{}, fmt.Errorf("invalid date-time timezone") + } + direction := 1 + switch b[0] { + case '+': + case '-': + direction = -1 + default: + return time.Time{}, fmt.Errorf("invalid timezone offset character") + } + + hours := digitsToInt(b[1:3]) + minutes := digitsToInt(b[4:6]) + seconds := direction * (hours*3600 + minutes*60) + zone = time.FixedZone("", seconds) + } + + if len(b) > 0 { + return time.Time{}, fmt.Errorf("extra bytes at the end of the timezone") + } + + t := time.Date( + dt.Date.Year, + dt.Date.Month, + dt.Date.Day, + dt.Time.Hour, + dt.Time.Minute, + dt.Time.Second, + dt.Time.Nanosecond, + zone) + + return t, nil +} + +func parseLocalDateTime(b []byte) (LocalDateTime, []byte, error) { + dt := LocalDateTime{} + + if len(b) < 11 { + return dt, nil, fmt.Errorf("local datetimes are expected to have the format YYYY-MM-DDTHH:MM:SS[.NNNNNN]") + } + + date, err := parseLocalDate(b[:10]) + if err != nil { + return dt, nil, err + } + dt.Date = date + + sep := b[10] + if sep != 'T' && sep != ' ' { + return dt, nil, fmt.Errorf("datetime separator is expected to be T or a space") + } + + t, rest, err := parseLocalTime(b[11:]) + if err != nil { + return dt, nil, err + } + dt.Time = t + + return dt, rest, nil +} + +// parseLocalTime is a bit different because it also returns the remaining +// []byte that is didn't need. This is to allow parseDateTime to parse those +// remaining bytes as a timezone. +func parseLocalTime(b []byte) (LocalTime, []byte, error) { + t := LocalTime{} + + if len(b) < 8 { + return t, nil, fmt.Errorf("times are expected to have the format HH:MM:SS[.NNNNNN]") + } + + var err error + t.Hour, err = parseDecimalDigits(b[0:2]) + if err != nil { + return t, nil, err + } + t.Minute, err = parseDecimalDigits(b[3:5]) + if err != nil { + return t, nil, err + } + t.Second, err = parseDecimalDigits(b[6:8]) + if err != nil { + return t, nil, err + } + + if len(b) >= 15 && b[8] == '.' { + t.Nanosecond, err = parseDecimalDigits(b[9:15]) + return t, b[15:], nil + } + + return t, b[8:], nil +} + +func parseFloat(b []byte) (float64, error) { + // TODO: inefficient + if len(b) == 4 && (b[0] == '+' || b[0] == '-') && b[1] == 'n' && b[2] == 'a' && b[3] == 'n' { + return math.NaN(), nil + } + + tok := string(b) + err := numberContainsInvalidUnderscore(tok) + if err != nil { + return 0, err + } + cleanedVal := cleanupNumberToken(tok) + return strconv.ParseFloat(cleanedVal, 64) +} + +func parseIntHex(b []byte) (int64, error) { + tok := string(b) + cleanedVal := cleanupNumberToken(tok) + err := hexNumberContainsInvalidUnderscore(cleanedVal) + if err != nil { + return 0, nil + } + return strconv.ParseInt(cleanedVal[2:], 16, 64) +} + +func parseIntOct(b []byte) (int64, error) { + tok := string(b) + cleanedVal := cleanupNumberToken(tok) + err := numberContainsInvalidUnderscore(cleanedVal) + if err != nil { + return 0, err + } + return strconv.ParseInt(cleanedVal[2:], 8, 64) +} + +func parseIntBin(b []byte) (int64, error) { + tok := string(b) + cleanedVal := cleanupNumberToken(tok) + err := numberContainsInvalidUnderscore(cleanedVal) + if err != nil { + return 0, err + } + return strconv.ParseInt(cleanedVal[2:], 2, 64) +} + +func parseIntDec(b []byte) (int64, error) { + tok := string(b) + cleanedVal := cleanupNumberToken(tok) + err := numberContainsInvalidUnderscore(cleanedVal) + if err != nil { + return 0, err + } + return strconv.ParseInt(cleanedVal, 10, 64) +} + +func numberContainsInvalidUnderscore(value string) error { + // For large numbers, you may use underscores between digits to enhance + // readability. Each underscore must be surrounded by at least one digit on + // each side. + + hasBefore := false + for idx, r := range value { + if r == '_' { + if !hasBefore || idx+1 >= len(value) { + // can't end with an underscore + return errInvalidUnderscore + } + } + hasBefore = isDigitRune(r) + } + return nil +} + +func hexNumberContainsInvalidUnderscore(value string) error { + hasBefore := false + for idx, r := range value { + if r == '_' { + if !hasBefore || idx+1 >= len(value) { + // can't end with an underscore + return errInvalidUnderscoreHex + } + } + hasBefore = isHexDigit(r) + } + return nil +} + +func cleanupNumberToken(value string) string { + cleanedVal := strings.Replace(value, "_", "", -1) + return cleanedVal +} + +func isHexDigit(r rune) bool { + return isDigitRune(r) || + (r >= 'a' && r <= 'f') || + (r >= 'A' && r <= 'F') +} + +func isDigitRune(r rune) bool { + return r >= '0' && r <= '9' +} + +var errInvalidUnderscore = errors.New("invalid use of _ in number") +var errInvalidUnderscoreHex = errors.New("invalid use of _ in hex number") diff --git a/internal/ast/ast.go b/internal/ast/ast.go index a8727e3..e927dd0 100644 --- a/internal/ast/ast.go +++ b/internal/ast/ast.go @@ -211,34 +211,6 @@ func (n *Node) Value() *Node { return &n.Children[len(n.Children)-1] } -// DecodeInteger parse the data of an Integer node and returns the represented -// int64, or an error. -// Panics if not called on an Integer node. -func (n *Node) DecodeInteger() (int64, error) { - assertKind(Integer, n) - if len(n.Data) > 2 && n.Data[0] == '0' { - switch n.Data[1] { - case 'x': - return parseIntHex(n.Data) - case 'b': - return parseIntBin(n.Data) - case 'o': - return parseIntOct(n.Data) - default: - return 0, fmt.Errorf("invalid base: '%c'", n.Data[1]) - } - } - return parseIntDec(n.Data) -} - -// DecodeFloat parse the data of a Float node and returns the represented -// float64, or an error. -// Panics if not called on an Float node. -func (n *Node) DecodeFloat() (float64, error) { - assertKind(Float, n) - return parseFloat(n.Data) -} - func assertKind(k Kind, n *Node) { if n.Kind != k { panic(fmt.Errorf("method was expecting a %s, not a %s", k, n.Kind)) diff --git a/internal/ast/decode.go b/internal/ast/decode.go deleted file mode 100644 index a27f04c..0000000 --- a/internal/ast/decode.go +++ /dev/null @@ -1,113 +0,0 @@ -package ast - -import ( - "errors" - "math" - "strconv" - "strings" -) - -func parseFloat(b []byte) (float64, error) { - // TODO: inefficient - if len(b) == 4 && (b[0] == '+' || b[0] == '-') && b[1] == 'n' && b[2] == 'a' && b[3] == 'n' { - return math.NaN(), nil - } - - tok := string(b) - err := numberContainsInvalidUnderscore(tok) - if err != nil { - return 0, err - } - cleanedVal := cleanupNumberToken(tok) - return strconv.ParseFloat(cleanedVal, 64) -} - -func parseIntHex(b []byte) (int64, error) { - tok := string(b) - cleanedVal := cleanupNumberToken(tok) - err := hexNumberContainsInvalidUnderscore(cleanedVal) - if err != nil { - return 0, nil - } - return strconv.ParseInt(cleanedVal[2:], 16, 64) -} - -func parseIntOct(b []byte) (int64, error) { - tok := string(b) - cleanedVal := cleanupNumberToken(tok) - err := numberContainsInvalidUnderscore(cleanedVal) - if err != nil { - return 0, err - } - return strconv.ParseInt(cleanedVal[2:], 8, 64) -} - -func parseIntBin(b []byte) (int64, error) { - tok := string(b) - cleanedVal := cleanupNumberToken(tok) - err := numberContainsInvalidUnderscore(cleanedVal) - if err != nil { - return 0, err - } - return strconv.ParseInt(cleanedVal[2:], 2, 64) -} - -func parseIntDec(b []byte) (int64, error) { - tok := string(b) - cleanedVal := cleanupNumberToken(tok) - err := numberContainsInvalidUnderscore(cleanedVal) - if err != nil { - return 0, err - } - return strconv.ParseInt(cleanedVal, 10, 64) -} - -func numberContainsInvalidUnderscore(value string) error { - // For large numbers, you may use underscores between digits to enhance - // readability. Each underscore must be surrounded by at least one digit on - // each side. - - hasBefore := false - for idx, r := range value { - if r == '_' { - if !hasBefore || idx+1 >= len(value) { - // can't end with an underscore - return errInvalidUnderscore - } - } - hasBefore = isDigitRune(r) - } - return nil -} - -func hexNumberContainsInvalidUnderscore(value string) error { - hasBefore := false - for idx, r := range value { - if r == '_' { - if !hasBefore || idx+1 >= len(value) { - // can't end with an underscore - return errInvalidUnderscoreHex - } - } - hasBefore = isHexDigit(r) - } - return nil -} - -func cleanupNumberToken(value string) string { - cleanedVal := strings.Replace(value, "_", "", -1) - return cleanedVal -} - -func isHexDigit(r rune) bool { - return isDigitRune(r) || - (r >= 'a' && r <= 'f') || - (r >= 'A' && r <= 'F') -} - -func isDigitRune(r rune) bool { - return r >= '0' && r <= '9' -} - -var errInvalidUnderscore = errors.New("invalid use of _ in number") -var errInvalidUnderscoreHex = errors.New("invalid use of _ in hex number") diff --git a/parser.go b/parser.go index 803e48a..289bb15 100644 --- a/parser.go +++ b/parser.go @@ -637,14 +637,11 @@ func (p *parser) parseIntOrFloatOrDateTime(node *ast.Node, b []byte) ([]byte, er s = len(b) } for idx, c := range b[:s] { - if c >= '0' && c <= '9' { + if isDigit(c) { continue } - if idx == 2 && c == ':' { - return p.parseDateTime(b) - } - if idx == 4 && c == '-' { - return p.parseDateTime(b) + if idx == 2 && c == ':' || (idx == 4 && c == '-') { + return p.scanDateTime(node, b) } } return p.scanIntOrFloat(node, b) @@ -659,14 +656,61 @@ func digitsToInt(b []byte) int { return x } +func (p *parser) scanDateTime(node *ast.Node, b []byte) ([]byte, error) { + // scans for contiguous characters in [0-9T:Z.+-], and up to one space if + // followed by a digit. + + hasTime := false + hasTz := false + seenSpace := false + + i := 0 + for ; i < len(b); i++ { + c := b[i] + if isDigit(c) || c == '-' { + } else if c == 'T' || c == ':' || c == '.' { + hasTime = true + continue + } else if c == '+' || c == '-' || c == 'Z' { + hasTz = true + } else if c == ' ' { + if !seenSpace && i+1 < len(b) && isDigit(b[i+1]) { + i += 2 + seenSpace = true + hasTime = true + } else { + break + } + } else { + break + } + } + + if hasTime { + if hasTz { + node.Kind = ast.DateTime + } else { + node.Kind = ast.LocalDateTime + } + } else { + if hasTz { + return nil, fmt.Errorf("possible DateTime cannot have a timezone but no time component") + } + node.Kind = ast.LocalDate + } + + node.Data = b[:i] + + return b[i:], nil +} + func (p *parser) parseDateTime(b []byte) ([]byte, error) { - // we know the first 2 ar digits. + // we know the first 2 are digits. if b[2] == ':' { return p.parseTime(b) } // This state accepts an offset date-time, a local date-time, or a local date. // - // v--- cursor // 1979-05-27T07:32:00Z // 1979-05-27T00:32:00-07:00 // 1979-05-27T00:32:00.999999-07:00 diff --git a/unmarshaler.go b/unmarshaler.go index 45f8fbc..536b103 100644 --- a/unmarshaler.go +++ b/unmarshaler.go @@ -4,6 +4,7 @@ import ( "fmt" "os" "reflect" + "time" "github.com/pelletier/go-toml/v2/internal/ast" ) @@ -166,11 +167,44 @@ func unmarshalValue(x target, node *ast.Node) error { return unmarshalArray(x, node) case ast.InlineTable: return unmarshalInlineTable(x, node) + case ast.LocalDateTime: + return unmarshalLocalDateTime(x, node) + case ast.DateTime: + return unmarshalDateTime(x, node) default: panic(fmt.Errorf("unhandled unmarshalValue kind %s", node.Kind)) } } +func unmarshalLocalDateTime(x target, node *ast.Node) error { + assertNode(ast.LocalDateTime, node) + v, rest, err := parseLocalDateTime(node.Data) + if err != nil { + return err + } + if len(rest) > 0 { + return fmt.Errorf("extra characters at the end of a local date time") + } + return setLocalDateTime(x, v) +} + +func unmarshalDateTime(x target, node *ast.Node) error { + assertNode(ast.DateTime, node) + v, err := parseDateTime(node.Data) + if err != nil { + return err + } + return setDateTime(x, v) +} + +func setLocalDateTime(x target, v LocalDateTime) error { + return x.set(reflect.ValueOf(v)) +} + +func setDateTime(x target, v time.Time) error { + return x.set(reflect.ValueOf(v)) +} + func unmarshalString(x target, node *ast.Node) error { assertNode(ast.String, node) return setString(x, string(node.Data)) @@ -184,7 +218,7 @@ func unmarshalBool(x target, node *ast.Node) error { func unmarshalInteger(x target, node *ast.Node) error { assertNode(ast.Integer, node) - v, err := node.DecodeInteger() + v, err := parseInteger(node.Data) if err != nil { return err } @@ -193,7 +227,7 @@ func unmarshalInteger(x target, node *ast.Node) error { func unmarshalFloat(x target, node *ast.Node) error { assertNode(ast.Float, node) - v, err := node.DecodeFloat() + v, err := parseFloat(node.Data) if err != nil { return err }