Support numbers

This commit is contained in:
Thomas Pelletier
2021-02-10 10:00:08 -05:00
parent f6a13d6e05
commit 721fa81f2e
4 changed files with 232 additions and 22 deletions
+48
View File
@@ -210,6 +210,54 @@ func (b *Builder) SetBool(v bool) error {
return nil return nil
} }
func (b *Builder) SetFloat(n float64) error {
t := b.top()
err := checkKindFloat(t.Type())
if err != nil {
return err
}
t.SetFloat(n)
return nil
}
func (b *Builder) SetInt(n int64) error {
t := b.top()
err := checkKindInt(t.Type())
if err != nil {
return err
}
t.SetInt(n)
return nil
}
func checkKindInt(rt reflect.Type) error {
switch rt.Kind() {
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
return nil
}
return IncorrectKindError{
Actual: rt.Kind(),
Expected: reflect.Int,
}
}
func checkKindFloat(rt reflect.Type) error {
switch rt.Kind() {
case reflect.Float32, reflect.Float64:
return nil
}
return IncorrectKindError{
Actual: rt.Kind(),
Expected: reflect.Float64,
}
}
func checkKind(rt reflect.Type, expected reflect.Kind) error { func checkKind(rt reflect.Type, expected reflect.Kind) error {
if rt.Kind() != expected { if rt.Kind() != expected {
return IncorrectKindError{ return IncorrectKindError{
+145 -22
View File
@@ -3,8 +3,11 @@ package toml
import ( import (
"bytes" "bytes"
"encoding/hex" "encoding/hex"
"errors"
"fmt" "fmt"
"math" "math"
"strconv"
"strings"
) )
type builder interface { type builder interface {
@@ -23,6 +26,7 @@ type builder interface {
StringValue(v []byte) StringValue(v []byte)
BoolValue(b bool) BoolValue(b bool)
FloatValue(n float64) FloatValue(n float64)
IntValue(n int64)
} }
type parser struct { type parser struct {
@@ -605,11 +609,11 @@ func (p parser) parseIntOrFloatOrDateTime(b []byte) ([]byte, error) {
p.builder.FloatValue(math.NaN()) p.builder.FloatValue(math.NaN())
return b[3:], nil return b[3:], nil
case '+', '-': case '+', '-':
return parseIntOrFloat(b) return p.parseIntOrFloat(b)
} }
if len(b) < 3 { if len(b) < 3 {
return parseIntOrFloat(b) return p.parseIntOrFloat(b)
} }
for idx, c := range b[:5] { for idx, c := range b[:5] {
if c >= '0' && c <= '9' { if c >= '0' && c <= '9' {
@@ -622,48 +626,58 @@ func (p parser) parseIntOrFloatOrDateTime(b []byte) ([]byte, error) {
return parseDateTime(b) return parseDateTime(b)
} }
} }
return parseIntOrFloat(b) return p.parseIntOrFloat(b)
} }
func parseDateTime(b []byte) ([]byte, error) { func parseDateTime(b []byte) ([]byte, error) {
panic("implement me")
} }
func (p parser) parseIntOrFloat(b []byte) ([]byte, error) { func (p parser) parseIntOrFloat(b []byte) ([]byte, error) {
i := 0
r := b[0] r := b[0]
if r == '0' { if r == '0' {
if len(b) >= 2 { if len(b) >= 2 {
var isValidRune validRuneFn var isValidRune validRuneFn
var parseFn func([]byte) (int64, error)
switch b[1] { switch b[1] {
case 'x': case 'x':
isValidRune = isValidHexRune isValidRune = isValidHexRune
parseFn = parseIntHex
case 'o': case 'o':
isValidRune = isValidOctalRune isValidRune = isValidOctalRune
parseFn = parseIntOct
case 'b': case 'b':
isValidRune = isValidBinaryRune isValidRune = isValidBinaryRune
parseFn = parseIntBin
default: default:
if b[1] >= 'a' && b[1] <= 'z' || b[1] >= 'A' && b[1] <= 'Z' { if b[1] >= 'a' && b[1] <= 'z' || b[1] >= 'A' && b[1] <= 'Z' {
return nil, fmt.Errorf("unknown number base: %s. possible options are x (hex) o (octal) b (binary)", string(b[1])) return nil, fmt.Errorf("unknown number base: %s. possible options are x (hex) o (octal) b (binary)", string(b[1]))
} }
parseFn = parseIntDec
} }
if isValidRune != nil { if isValidRune != nil {
b = b[2:] i = 2
digitSeen := false digitSeen := false
for { for {
if !isValidRune(b[0]) { if !isValidRune(b[i]) {
break break
} }
digitSeen = true digitSeen = true
b = b[1:] i++
} }
if !digitSeen { if !digitSeen {
return nil, fmt.Errorf("number needs at least one digit") return nil, fmt.Errorf("number needs at least one digit")
} }
p.builder.IntValue() v, err := parseFn(b[:i])
return b, nil if err != nil {
return nil, err
}
p.builder.IntValue(v)
return b[i:], nil
} }
} }
} }
@@ -687,31 +701,31 @@ func (p parser) parseIntOrFloat(b []byte) ([]byte, error) {
pointSeen := false pointSeen := false
expSeen := false expSeen := false
digitSeen := false digitSeen := false
for len(b) > 0 { for i < len(b) {
next := b[0] next := b[i]
if next == '.' { if next == '.' {
if pointSeen { if pointSeen {
return nil, fmt.Errorf("cannot have two dots in one float") return nil, fmt.Errorf("cannot have two dots in one float")
} }
b = b[1:] i++
if len(b) > 0 && !isDigit(b[0]) { if i < len(b) && !isDigit(b[i]) {
return nil, fmt.Errorf("float cannot end with a dot") return nil, fmt.Errorf("float cannot end with a dot")
} }
pointSeen = true pointSeen = true
} else if next == 'e' || next == 'E' { } else if next == 'e' || next == 'E' {
expSeen = true expSeen = true
b = b[1:] i++
if len(b) == 0 { if i >= len(b) {
break break
} }
if b[0] == '+' || b[0] == '-' { if b[i] == '+' || b[i] == '-' {
b = b[1:] i++
} }
} else if isDigit(next) { } else if isDigit(next) {
digitSeen = true digitSeen = true
b = b[1:] i++
} else if next == '_' { } else if next == '_' {
b = b[1:] i++
} else { } else {
break break
} }
@@ -724,17 +738,117 @@ func (p parser) parseIntOrFloat(b []byte) ([]byte, error) {
return nil, fmt.Errorf("no digit in that number") return nil, fmt.Errorf("no digit in that number")
} }
if pointSeen || expSeen { if pointSeen || expSeen {
p.builder.FloatValue() f, err := parseFloat(b[:i])
if err != nil {
return nil, err
}
p.builder.FloatValue(f)
} else { } else {
p.builder.IntValue() v, err := parseIntDec(b[:i])
if err != nil {
return nil, err
}
p.builder.IntValue(v)
} }
return b, nil return b[i:], nil
}
func parseFloat(b []byte) (float64, error) {
// TODO: inefficient
tok := string(b)
err := numberContainsInvalidUnderscore(tok)
if err != nil {
return 0, err
}
cleanedVal := cleanupNumberToken(tok)
return strconv.ParseFloat(cleanedVal, 64)
}
func parseIntHex(b []byte) (int64, error) {
tok := string(b)
cleanedVal := cleanupNumberToken(tok)
err := hexNumberContainsInvalidUnderscore(cleanedVal)
if err != nil {
return 0, nil
}
return strconv.ParseInt(cleanedVal[2:], 16, 64)
}
func parseIntOct(b []byte) (int64, error) {
tok := string(b)
cleanedVal := cleanupNumberToken(tok)
err := numberContainsInvalidUnderscore(cleanedVal)
if err != nil {
return 0, err
}
return strconv.ParseInt(cleanedVal[2:], 8, 64)
}
func parseIntBin(b []byte) (int64, error) {
tok := string(b)
cleanedVal := cleanupNumberToken(tok)
err := numberContainsInvalidUnderscore(cleanedVal)
if err != nil {
return 0, err
}
return strconv.ParseInt(cleanedVal[2:], 2, 64)
}
func parseIntDec(b []byte) (int64, error) {
tok := string(b)
cleanedVal := cleanupNumberToken(tok)
err := numberContainsInvalidUnderscore(cleanedVal)
if err != nil {
return 0, err
}
return strconv.ParseInt(cleanedVal, 10, 64)
}
func numberContainsInvalidUnderscore(value string) error {
// For large numbers, you may use underscores between digits to enhance
// readability. Each underscore must be surrounded by at least one digit on
// each side.
hasBefore := false
for idx, r := range value {
if r == '_' {
if !hasBefore || idx+1 >= len(value) {
// can't end with an underscore
return errInvalidUnderscore
}
}
hasBefore = isDigitRune(r)
}
return nil
}
func hexNumberContainsInvalidUnderscore(value string) error {
hasBefore := false
for idx, r := range value {
if r == '_' {
if !hasBefore || idx+1 >= len(value) {
// can't end with an underscore
return errInvalidUnderscoreHex
}
}
hasBefore = isHexDigit(r)
}
return nil
}
func cleanupNumberToken(value string) string {
cleanedVal := strings.Replace(value, "_", "", -1)
return cleanedVal
} }
func isDigit(r byte) bool { func isDigit(r byte) bool {
return r >= '0' && r <= '9' return r >= '0' && r <= '9'
} }
func isDigitRune(r rune) bool {
return r >= '0' && r <= '9'
}
var plusInf = math.Inf(1) var plusInf = math.Inf(1)
var minusInf = math.Inf(-1) var minusInf = math.Inf(-1)
var nan = math.NaN() var nan = math.NaN()
@@ -748,6 +862,12 @@ func isValidHexRune(r byte) bool {
r == '_' r == '_'
} }
func isHexDigit(r rune) bool {
return isDigitRune(r) ||
(r >= 'a' && r <= 'f') ||
(r >= 'A' && r <= 'F')
}
func isValidOctalRune(r byte) bool { func isValidOctalRune(r byte) bool {
return r >= '0' && r <= '7' || r == '_' return r >= '0' && r <= '7' || r == '_'
} }
@@ -775,3 +895,6 @@ func (u unexpectedCharacter) Error() string {
} }
return fmt.Sprintf("expected %#U, not %#U", u.r, u.b[0]) return fmt.Sprintf("expected %#U, not %#U", u.r, u.b[0])
} }
var errInvalidUnderscore = errors.New("invalid use of _ in number")
var errInvalidUnderscoreHex = errors.New("invalid use of _ in hex number")
+32
View File
@@ -137,6 +137,38 @@ func (u *unmarshaler) BoolValue(b bool) {
} }
} }
func (u *unmarshaler) FloatValue(n float64) {
if u.err != nil {
return
}
if u.builder.IsSlice() {
u.builder.Save()
u.err = u.builder.SliceAppend(reflect.ValueOf(n))
if u.err != nil {
return
}
u.builder.Load()
} else {
u.err = u.builder.SetFloat(n)
}
}
func (u *unmarshaler) IntValue(n int64) {
if u.err != nil {
return
}
if u.builder.IsSlice() {
u.builder.Save()
u.err = u.builder.SliceAppend(reflect.ValueOf(n))
if u.err != nil {
return
}
u.builder.Load()
} else {
u.err = u.builder.SetInt(n)
}
}
func (u *unmarshaler) SimpleKey(v []byte) { func (u *unmarshaler) SimpleKey(v []byte) {
if u.err != nil { if u.err != nil {
return return
+7
View File
@@ -15,6 +15,13 @@ func TestUnmarshalSimple(t *testing.T) {
assert.Equal(t, "hello", x.Foo) assert.Equal(t, "hello", x.Foo)
} }
func TestUnmarshalInt(t *testing.T) {
x := struct{ Foo int }{}
err := toml.Unmarshal([]byte(`Foo = 42`), &x)
require.NoError(t, err)
assert.Equal(t, 42, x.Foo)
}
func TestUnmarshalNestedStructs(t *testing.T) { func TestUnmarshalNestedStructs(t *testing.T) {
x := struct{ Foo struct{ Bar string } }{} x := struct{ Foo struct{ Bar string } }{}
err := toml.Unmarshal([]byte(`Foo.Bar = "hello"`), &x) err := toml.Unmarshal([]byte(`Foo.Bar = "hello"`), &x)