From 721fa81f2e14298a63077ac2c403dccd31f67732 Mon Sep 17 00:00:00 2001
From: Thomas Pelletier
Date: Wed, 10 Feb 2021 10:00:08 -0500
Subject: [PATCH] Support numbers

Parse TOML integers (decimal, hex, octal, binary) and floats into
int64/float64 values and hand them to the builder through IntValue (new)
and FloatValue. reflectbuild gains SetInt/SetFloat so the unmarshaler
can store those values in integer and float fields.

Underscore placement is validated on the raw token, before underscores
are stripped for strconv, and validation errors are returned to the
caller. The prefixed-integer scan stops at the end of the input instead
of indexing past it.
---
 internal/reflectbuild/reflectbuild.go |  48 ++++++++
 parser.go                             | 167 ++++++++++++++++++++++----
 unmarshal.go                          |  32 +++++
 unmarshal_test.go                     |   7 ++
 4 files changed, 232 insertions(+), 22 deletions(-)

diff --git a/internal/reflectbuild/reflectbuild.go b/internal/reflectbuild/reflectbuild.go
index 7199f9f..8225d23 100644
--- a/internal/reflectbuild/reflectbuild.go
+++ b/internal/reflectbuild/reflectbuild.go
@@ -210,6 +210,54 @@ func (b *Builder) SetBool(v bool) error {
 	return nil
 }
 
+func (b *Builder) SetFloat(n float64) error {
+	t := b.top()
+
+	err := checkKindFloat(t.Type())
+	if err != nil {
+		return err
+	}
+
+	t.SetFloat(n)
+	return nil
+}
+
+func (b *Builder) SetInt(n int64) error {
+	t := b.top()
+
+	err := checkKindInt(t.Type())
+	if err != nil {
+		return err
+	}
+
+	t.SetInt(n)
+	return nil
+}
+
+func checkKindInt(rt reflect.Type) error {
+	switch rt.Kind() {
+	case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
+		return nil
+	}
+
+	return IncorrectKindError{
+		Actual:   rt.Kind(),
+		Expected: reflect.Int,
+	}
+}
+
+func checkKindFloat(rt reflect.Type) error {
+	switch rt.Kind() {
+	case reflect.Float32, reflect.Float64:
+		return nil
+	}
+
+	return IncorrectKindError{
+		Actual:   rt.Kind(),
+		Expected: reflect.Float64,
+	}
+}
+
 func checkKind(rt reflect.Type, expected reflect.Kind) error {
 	if rt.Kind() != expected {
 		return IncorrectKindError{
diff --git a/parser.go b/parser.go
index e1293ec..f44facb 100644
--- a/parser.go
+++ b/parser.go
@@ -3,8 +3,11 @@ package toml
 import (
 	"bytes"
 	"encoding/hex"
+	"errors"
 	"fmt"
 	"math"
+	"strconv"
+	"strings"
 )
 
 type builder interface {
@@ -23,6 +26,7 @@ type builder interface {
 	StringValue(v []byte)
 	BoolValue(b bool)
 	FloatValue(n float64)
+	IntValue(n int64)
 }
 
 type parser struct {
@@ -605,11 +609,11 @@ func (p parser) parseIntOrFloatOrDateTime(b []byte) ([]byte, error) {
 		p.builder.FloatValue(math.NaN())
 		return b[3:], nil
 	case '+', '-':
-		return parseIntOrFloat(b)
+		return p.parseIntOrFloat(b)
 	}
 
 	if len(b) < 3 {
-		return parseIntOrFloat(b)
+		return p.parseIntOrFloat(b)
 	}
 	for idx, c := range b[:5] {
 		if c >= '0' && c <= '9' {
@@ -622,48 +626,58 @@ func (p parser) parseIntOrFloatOrDateTime(b []byte) ([]byte, error) {
 			return parseDateTime(b)
 		}
 	}
-	return parseIntOrFloat(b)
+	return p.parseIntOrFloat(b)
 }
 
 func parseDateTime(b []byte) ([]byte, error) {
-
+	panic("implement me")
 }
 
 func (p parser) parseIntOrFloat(b []byte) ([]byte, error) {
+	i := 0
 	r := b[0]
 	if r == '0' {
 		if len(b) >= 2 {
 			var isValidRune validRuneFn
+			var parseFn func([]byte) (int64, error)
 			switch b[1] {
 			case 'x':
 				isValidRune = isValidHexRune
+				parseFn = parseIntHex
 			case 'o':
 				isValidRune = isValidOctalRune
+				parseFn = parseIntOct
 			case 'b':
 				isValidRune = isValidBinaryRune
+				parseFn = parseIntBin
 			default:
 				if b[1] >= 'a' && b[1] <= 'z' || b[1] >= 'A' && b[1] <= 'Z' {
 					return nil, fmt.Errorf("unknown number base: %s. possible options are x (hex) o (octal) b (binary)", string(b[1]))
 				}
+				parseFn = parseIntDec
 			}
 
 			if isValidRune != nil {
-				b = b[2:]
+				i = 2
 				digitSeen := false
 				for {
-					if !isValidRune(b[0]) {
+					if i >= len(b) || !isValidRune(b[i]) {
 						break
 					}
 					digitSeen = true
-					b = b[1:]
+					i++
 				}
 
 				if !digitSeen {
 					return nil, fmt.Errorf("number needs at least one digit")
 				}
 
-				p.builder.IntValue()
-				return b, nil
+				v, err := parseFn(b[:i])
+				if err != nil {
+					return nil, err
+				}
+				p.builder.IntValue(v)
+				return b[i:], nil
 			}
 		}
 	}
@@ -687,31 +701,31 @@ func (p parser) parseIntOrFloat(b []byte) ([]byte, error) {
 	pointSeen := false
 	expSeen := false
 	digitSeen := false
-	for len(b) > 0 {
-		next := b[0]
+	for i < len(b) {
+		next := b[i]
 		if next == '.' {
 			if pointSeen {
 				return nil, fmt.Errorf("cannot have two dots in one float")
 			}
-			b = b[1:]
-			if len(b) > 0 && !isDigit(b[0]) {
+			i++
+			if i < len(b) && !isDigit(b[i]) {
 				return nil, fmt.Errorf("float cannot end with a dot")
 			}
 			pointSeen = true
 		} else if next == 'e' || next == 'E' {
 			expSeen = true
-			b = b[1:]
-			if len(b) == 0 {
+			i++
+			if i >= len(b) {
 				break
 			}
-			if b[0] == '+' || b[0] == '-' {
-				b = b[1:]
+			if b[i] == '+' || b[i] == '-' {
+				i++
 			}
 		} else if isDigit(next) {
 			digitSeen = true
-			b = b[1:]
+			i++
 		} else if next == '_' {
-			b = b[1:]
+			i++
 		} else {
 			break
 		}
@@ -724,17 +738,117 @@ func (p parser) parseIntOrFloat(b []byte) ([]byte, error) {
 		return nil, fmt.Errorf("no digit in that number")
 	}
 	if pointSeen || expSeen {
-		p.builder.FloatValue()
+		f, err := parseFloat(b[:i])
+		if err != nil {
+			return nil, err
+		}
+		p.builder.FloatValue(f)
 	} else {
-		p.builder.IntValue()
+		v, err := parseIntDec(b[:i])
+		if err != nil {
+			return nil, err
+		}
+		p.builder.IntValue(v)
 	}
-	return b, nil
+	return b[i:], nil
+}
+
+func parseFloat(b []byte) (float64, error) {
+	// TODO: inefficient
+	tok := string(b)
+	err := numberContainsInvalidUnderscore(tok)
+	if err != nil {
+		return 0, err
+	}
+	cleanedVal := cleanupNumberToken(tok)
+	return strconv.ParseFloat(cleanedVal, 64)
+}
+
+func parseIntHex(b []byte) (int64, error) {
+	tok := string(b)
+	err := hexNumberContainsInvalidUnderscore(tok)
+	if err != nil {
+		return 0, err
+	}
+	cleanedVal := cleanupNumberToken(tok)
+	return strconv.ParseInt(cleanedVal[2:], 16, 64)
+}
+
+func parseIntOct(b []byte) (int64, error) {
+	tok := string(b)
+	err := numberContainsInvalidUnderscore(tok)
+	if err != nil {
+		return 0, err
+	}
+	cleanedVal := cleanupNumberToken(tok)
+	return strconv.ParseInt(cleanedVal[2:], 8, 64)
+}
+
+func parseIntBin(b []byte) (int64, error) {
+	tok := string(b)
+	err := numberContainsInvalidUnderscore(tok)
+	if err != nil {
+		return 0, err
+	}
+	cleanedVal := cleanupNumberToken(tok)
+	return strconv.ParseInt(cleanedVal[2:], 2, 64)
+}
+
+func parseIntDec(b []byte) (int64, error) {
+	tok := string(b)
+	err := numberContainsInvalidUnderscore(tok)
+	if err != nil {
+		return 0, err
+	}
+	cleanedVal := cleanupNumberToken(tok)
+	return strconv.ParseInt(cleanedVal, 10, 64)
+}
+
+func numberContainsInvalidUnderscore(value string) error {
+	// For large numbers, you may use underscores between digits to enhance
+	// readability. Each underscore must be surrounded by at least one digit on
+	// each side.
+
+	hasBefore := false
+	for idx, r := range value {
+		if r == '_' {
+			if !hasBefore || idx+1 >= len(value) {
+				// can't end with an underscore
+				return errInvalidUnderscore
+			}
+		}
+		hasBefore = isDigitRune(r)
+	}
+	return nil
+}
+
+func hexNumberContainsInvalidUnderscore(value string) error {
+	hasBefore := false
+	for idx, r := range value {
+		if r == '_' {
+			if !hasBefore || idx+1 >= len(value) {
+				// can't end with an underscore
+				return errInvalidUnderscoreHex
+			}
+		}
+		hasBefore = isHexDigit(r)
+	}
+	return nil
+}
+
+func cleanupNumberToken(value string) string {
+	cleanedVal := strings.Replace(value, "_", "", -1)
+	return cleanedVal
 }
 
 func isDigit(r byte) bool {
 	return r >= '0' && r <= '9'
 }
 
+func isDigitRune(r rune) bool {
+	return r >= '0' && r <= '9'
+}
+
 var plusInf = math.Inf(1)
 var minusInf = math.Inf(-1)
 var nan = math.NaN()
@@ -748,6 +862,12 @@ func isValidHexRune(r byte) bool {
 		r == '_'
 }
 
+func isHexDigit(r rune) bool {
+	return isDigitRune(r) ||
+		(r >= 'a' && r <= 'f') ||
+		(r >= 'A' && r <= 'F')
+}
+
 func isValidOctalRune(r byte) bool {
 	return r >= '0' && r <= '7' || r == '_'
 }
@@ -775,3 +895,6 @@ func (u unexpectedCharacter) Error() string {
 	}
 	return fmt.Sprintf("expected %#U, not %#U", u.r, u.b[0])
 }
+
+var errInvalidUnderscore = errors.New("invalid use of _ in number")
+var errInvalidUnderscoreHex = errors.New("invalid use of _ in hex number")
diff --git a/unmarshal.go b/unmarshal.go
index 96ecc7d..6462999 100644
--- a/unmarshal.go
+++ b/unmarshal.go
@@ -137,6 +137,38 @@ func (u *unmarshaler) BoolValue(b bool) {
 	}
 }
 
+func (u *unmarshaler) FloatValue(n float64) {
+	if u.err != nil {
+		return
+	}
+	if u.builder.IsSlice() {
+		u.builder.Save()
+		u.err = u.builder.SliceAppend(reflect.ValueOf(n))
+		if u.err != nil {
+			return
+		}
+		u.builder.Load()
+	} else {
+		u.err = u.builder.SetFloat(n)
+	}
+}
+
+func (u *unmarshaler) IntValue(n int64) {
+	if u.err != nil {
+		return
+	}
+	if u.builder.IsSlice() {
+		u.builder.Save()
+		u.err = u.builder.SliceAppend(reflect.ValueOf(n))
+		if u.err != nil {
+			return
+		}
+		u.builder.Load()
+	} else {
+		u.err = u.builder.SetInt(n)
+	}
+}
+
 func (u *unmarshaler) SimpleKey(v []byte) {
 	if u.err != nil {
 		return
diff --git a/unmarshal_test.go b/unmarshal_test.go
index 96d4696..b5fd8f2 100644
--- a/unmarshal_test.go
+++ b/unmarshal_test.go
@@ -15,6 +15,13 @@ func TestUnmarshalSimple(t *testing.T) {
 	assert.Equal(t, "hello", x.Foo)
 }
 
+func TestUnmarshalInt(t *testing.T) {
+	x := struct{ Foo int }{}
+	err := toml.Unmarshal([]byte(`Foo = 42`), &x)
+	require.NoError(t, err)
+	assert.Equal(t, 42, x.Foo)
+}
+
 func TestUnmarshalNestedStructs(t *testing.T) {
 	x := struct{ Foo struct{ Bar string } }{}
 	err := toml.Unmarshal([]byte(`Foo.Bar = "hello"`), &x)