Unmarshal ints and floats

This commit is contained in:
Thomas Pelletier
2021-03-14 18:06:34 -04:00
parent 9a1cfcdd8e
commit 590d674153
7 changed files with 368 additions and 245 deletions
+35 -3
View File
@@ -158,11 +158,43 @@ func (n *Node) Key() []Node {
// Guaranteed to be non-nil.
// Panics if not called on a KeyValue node, or if the Children are malformed.
func (n *Node) Value() *Node {
	// A single kind check is enough. assertKind panics with a message that
	// names both the expected and the actual kind, which is what the other
	// accessors on Node rely on as well.
	assertKind(KeyValue, n)
	if len(n.Children) < 2 {
		panic(fmt.Errorf("KeyValue should have at least two children, not %d", len(n.Children)))
	}
	// The value is stored as the last child.
	return &n.Children[len(n.Children)-1]
}
// DecodeInteger parses the data of an Integer node and returns the int64 it
// represents, or an error.
// Panics if not called on an Integer node.
func (n *Node) DecodeInteger() (int64, error) {
	assertKind(Integer, n)

	data := n.Data
	// Only tokens longer than two bytes that start with '0' are inspected for
	// a base prefix; everything else is handed to the decimal parser.
	if len(data) <= 2 || data[0] != '0' {
		return parseIntDec(data)
	}

	base := data[1]
	switch base {
	case 'x':
		return parseIntHex(data)
	case 'b':
		return parseIntBin(data)
	case 'o':
		return parseIntOct(data)
	}
	return 0, fmt.Errorf("invalid base: '%c'", base)
}
// DecodeFloat parses the data of a Float node and returns the float64 it
// represents, or an error.
// Panics if not called on a Float node.
func (n *Node) DecodeFloat() (float64, error) {
assertKind(Float, n)
return parseFloat(n.Data)
}
// assertKind panics with a descriptive error unless n.Kind equals k.
func assertKind(k Kind, n *Node) {
	if n.Kind == k {
		return
	}
	panic(fmt.Errorf("method was expecting a %s, not a %s", k, n.Kind))
}
+113
View File
@@ -0,0 +1,113 @@
package ast
import (
"errors"
"math"
"strconv"
"strings"
)
// parseFloat converts a float token to a float64.
//
// The signed spellings "+nan" and "-nan" are answered directly with NaN.
// Every other token is checked for misplaced underscores, stripped of its
// underscores and handed to strconv.ParseFloat.
func parseFloat(b []byte) (float64, error) {
	// TODO: inefficient
	tok := string(b)

	switch tok {
	case "+nan", "-nan":
		return math.NaN(), nil
	}

	if err := numberContainsInvalidUnderscore(tok); err != nil {
		return 0, err
	}
	return strconv.ParseFloat(cleanupNumberToken(tok), 64)
}
func parseIntHex(b []byte) (int64, error) {
tok := string(b)
cleanedVal := cleanupNumberToken(tok)
err := hexNumberContainsInvalidUnderscore(cleanedVal)
if err != nil {
return 0, nil
}
return strconv.ParseInt(cleanedVal[2:], 16, 64)
}
func parseIntOct(b []byte) (int64, error) {
tok := string(b)
cleanedVal := cleanupNumberToken(tok)
err := numberContainsInvalidUnderscore(cleanedVal)
if err != nil {
return 0, err
}
return strconv.ParseInt(cleanedVal[2:], 8, 64)
}
func parseIntBin(b []byte) (int64, error) {
tok := string(b)
cleanedVal := cleanupNumberToken(tok)
err := numberContainsInvalidUnderscore(cleanedVal)
if err != nil {
return 0, err
}
return strconv.ParseInt(cleanedVal[2:], 2, 64)
}
func parseIntDec(b []byte) (int64, error) {
tok := string(b)
cleanedVal := cleanupNumberToken(tok)
err := numberContainsInvalidUnderscore(cleanedVal)
if err != nil {
return 0, err
}
return strconv.ParseInt(cleanedVal, 10, 64)
}
// numberContainsInvalidUnderscore returns errInvalidUnderscore when value
// contains an underscore that is not directly surrounded by decimal digits,
// and nil otherwise.
func numberContainsInvalidUnderscore(value string) error {
	// For large numbers, you may use underscores between digits to enhance
	// readability. Each underscore must be surrounded by at least one digit on
	// each side.
	for i := 0; i < len(value); i++ {
		if value[i] != '_' {
			continue
		}
		// Reject a leading or trailing underscore before looking at the
		// neighbours, so the index expressions below stay in range.
		if i == 0 || i+1 == len(value) {
			return errInvalidUnderscore
		}
		if !isDigitRune(rune(value[i-1])) || !isDigitRune(rune(value[i+1])) {
			return errInvalidUnderscore
		}
	}
	return nil
}
// hexNumberContainsInvalidUnderscore returns errInvalidUnderscoreHex when
// value contains an underscore that is not directly surrounded by hexadecimal
// digits, and nil otherwise.
func hexNumberContainsInvalidUnderscore(value string) error {
	for i := 0; i < len(value); i++ {
		if value[i] != '_' {
			continue
		}
		// Reject a leading or trailing underscore before looking at the
		// neighbours, so the index expressions below stay in range.
		if i == 0 || i+1 == len(value) {
			return errInvalidUnderscoreHex
		}
		if !isHexDigit(rune(value[i-1])) || !isHexDigit(rune(value[i+1])) {
			return errInvalidUnderscoreHex
		}
	}
	return nil
}
// cleanupNumberToken returns value with every underscore removed.
func cleanupNumberToken(value string) string {
	return strings.ReplaceAll(value, "_", "")
}
// isHexDigit reports whether r is a decimal digit or a letter in the ranges
// a-f or A-F.
func isHexDigit(r rune) bool {
	switch {
	case isDigitRune(r):
		return true
	case 'a' <= r && r <= 'f':
		return true
	case 'A' <= r && r <= 'F':
		return true
	default:
		return false
	}
}
// isDigitRune reports whether r is one of the ASCII digits '0' through '9'.
func isDigitRune(r rune) bool {
	if r < '0' {
		return false
	}
	return r <= '9'
}
// Errors returned by the underscore validators for number tokens.
var (
	errInvalidUnderscore    = errors.New("invalid use of _ in number")
	errInvalidUnderscoreHex = errors.New("invalid use of _ in hex number")
)
-241
View File
@@ -3,15 +3,10 @@ package unmarshaler
import (
"bytes"
"encoding/hex"
"errors"
"fmt"
"math"
"strconv"
"strings"
"time"
"github.com/pelletier/go-toml/v2"
"github.com/pelletier/go-toml/v2/internal/ast"
)
@@ -234,8 +229,6 @@ func (p *parser) parseVal(b []byte) (ast.Node, []byte, error) {
b, err = p.parseIntOrFloatOrDateTime(&node, b)
return node, b, err
}
panic("parseVal not finished yet")
return ast.Node{}, nil, nil
}
func (p *parser) parseLiteralString(b []byte) ([]byte, []byte, error) {
@@ -994,235 +987,10 @@ func (p *parser) scanIntOrFloat(node *ast.Node, b []byte) ([]byte, error) {
return b[i:], nil
}
//func (p *parser) parseIntOrFloat(node *ast.Node, b []byte) ([]byte, error) {
// i := 0
// r := b[0]
// if r == '0' {
// if len(b) >= 2 {
// var isValidRune validRuneFn
// var parseFn func([]byte) (int64, error)
// switch b[1] {
// case 'x':
// isValidRune = isValidHexRune
// parseFn = parseIntHex
// case 'o':
// isValidRune = isValidOctalRune
// parseFn = parseIntOct
// case 'b':
// isValidRune = isValidBinaryRune
// parseFn = parseIntBin
// default:
// if b[1] >= 'a' && b[1] <= 'z' || b[1] >= 'A' && b[1] <= 'Z' {
// return nil, fmt.Errorf("unknown number base: %s. possible options are x (hex) o (octal) b (binary)", string(b[1]))
// }
// parseFn = parseIntDec
// }
//
// if isValidRune != nil {
// i = 2
// digitSeen := false
// for {
// if !isValidRune(b[i]) {
// break
// }
// digitSeen = true
// i++
// }
//
// if !digitSeen {
// return nil, fmt.Errorf("number needs at least one digit")
// }
//
// v, err := parseFn(b[:i])
// if err != nil {
// return nil, err
// }
// //p.builder.IntValue(v)
// // TODO
// v = v
// return b[i:], nil
// }
// }
// }
//
// if r == '+' || r == '-' {
// b = b[1:]
// if scanFollowsInf(b) {
// if r == '+' {
// //p.builder.FloatValue(plusInf)
// // TODO
// } else {
// //p.builder.FloatValue(minusInf)
// // TODO
// }
// return b, nil
// }
// if scanFollowsNan(b) {
// //p.builder.FloatValue(nan)
// // TODO
// return b, nil
// }
// }
//
// pointSeen := false
// expSeen := false
// digitSeen := false
// for i < len(b) {
// next := b[i]
// if next == '.' {
// if pointSeen {
// return nil, fmt.Errorf("cannot have two dots in one float")
// }
// i++
// if i < len(b) && !isDigit(b[i]) {
// return nil, fmt.Errorf("float cannot end with a dot")
// }
// pointSeen = true
// } else if next == 'e' || next == 'E' {
// expSeen = true
// i++
// if i >= len(b) {
// break
// }
// if b[i] == '+' || b[i] == '-' {
// i++
// }
// } else if isDigit(next) {
// digitSeen = true
// i++
// } else if next == '_' {
// i++
// } else {
// break
// }
// if pointSeen && !digitSeen {
// return nil, fmt.Errorf("cannot start float with a dot")
// }
// }
//
// if !digitSeen {
// return nil, fmt.Errorf("no digit in that number")
// }
// if pointSeen || expSeen {
// f, err := parseFloat(b[:i])
// if err != nil {
// return nil, err
// }
// //p.builder.FloatValue(f)
// // TODO
// f = f
// } else {
// v, err := parseIntDec(b[:i])
// if err != nil {
// return nil, err
// }
// //p.builder.IntValue(v)
// // TODO
// v = v
// }
// return b[i:], nil
//}
// parseFloat converts a float token to a float64. The token is checked for
// misplaced underscores, stripped of its underscores and handed to
// strconv.ParseFloat.
func parseFloat(b []byte) (float64, error) {
	// TODO: inefficient
	tok := string(b)
	if err := numberContainsInvalidUnderscore(tok); err != nil {
		return 0, err
	}
	return strconv.ParseFloat(cleanupNumberToken(tok), 64)
}
func parseIntHex(b []byte) (int64, error) {
tok := string(b)
cleanedVal := cleanupNumberToken(tok)
err := hexNumberContainsInvalidUnderscore(cleanedVal)
if err != nil {
return 0, nil
}
return strconv.ParseInt(cleanedVal[2:], 16, 64)
}
func parseIntOct(b []byte) (int64, error) {
tok := string(b)
cleanedVal := cleanupNumberToken(tok)
err := numberContainsInvalidUnderscore(cleanedVal)
if err != nil {
return 0, err
}
return strconv.ParseInt(cleanedVal[2:], 8, 64)
}
func parseIntBin(b []byte) (int64, error) {
tok := string(b)
cleanedVal := cleanupNumberToken(tok)
err := numberContainsInvalidUnderscore(cleanedVal)
if err != nil {
return 0, err
}
return strconv.ParseInt(cleanedVal[2:], 2, 64)
}
func parseIntDec(b []byte) (int64, error) {
tok := string(b)
cleanedVal := cleanupNumberToken(tok)
err := numberContainsInvalidUnderscore(cleanedVal)
if err != nil {
return 0, err
}
return strconv.ParseInt(cleanedVal, 10, 64)
}
// numberContainsInvalidUnderscore returns errInvalidUnderscore when value
// contains an underscore that is not directly surrounded by decimal digits,
// and nil otherwise.
func numberContainsInvalidUnderscore(value string) error {
	// For large numbers, you may use underscores between digits to enhance
	// readability. Each underscore must be surrounded by at least one digit on
	// each side.
	for i := 0; i < len(value); i++ {
		if value[i] != '_' {
			continue
		}
		// Reject a leading or trailing underscore before looking at the
		// neighbours, so the index expressions below stay in range.
		if i == 0 || i+1 == len(value) {
			return errInvalidUnderscore
		}
		if !isDigitRune(rune(value[i-1])) || !isDigitRune(rune(value[i+1])) {
			return errInvalidUnderscore
		}
	}
	return nil
}
// hexNumberContainsInvalidUnderscore returns errInvalidUnderscoreHex when
// value contains an underscore that is not directly surrounded by hexadecimal
// digits, and nil otherwise.
func hexNumberContainsInvalidUnderscore(value string) error {
	for i := 0; i < len(value); i++ {
		if value[i] != '_' {
			continue
		}
		// Reject a leading or trailing underscore before looking at the
		// neighbours, so the index expressions below stay in range.
		if i == 0 || i+1 == len(value) {
			return errInvalidUnderscoreHex
		}
		if !isHexDigit(rune(value[i-1])) || !isHexDigit(rune(value[i+1])) {
			return errInvalidUnderscoreHex
		}
	}
	return nil
}
// cleanupNumberToken returns value with every underscore removed.
func cleanupNumberToken(value string) string {
	return strings.ReplaceAll(value, "_", "")
}
// isDigit reports whether the byte r is one of the ASCII digits '0'
// through '9'.
func isDigit(r byte) bool {
	if r < '0' {
		return false
	}
	return r <= '9'
}
// isDigitRune reports whether r is one of the ASCII digits '0' through '9'.
func isDigitRune(r rune) bool {
	if r < '0' {
		return false
	}
	return r <= '9'
}
var plusInf = math.Inf(1)
var minusInf = math.Inf(-1)
var nan = math.NaN()
// validRuneFn is the signature of a predicate that takes a single byte and
// returns a bool, such as isValidOctalRune.
type validRuneFn func(r byte) bool
func isValidHexRune(r byte) bool {
@@ -1232,12 +1000,6 @@ func isValidHexRune(r byte) bool {
r == '_'
}
// isHexDigit reports whether r is a decimal digit or a letter in the ranges
// a-f or A-F.
func isHexDigit(r rune) bool {
	switch {
	case isDigitRune(r):
		return true
	case 'a' <= r && r <= 'f':
		return true
	case 'A' <= r && r <= 'F':
		return true
	default:
		return false
	}
}
// isValidOctalRune reports whether the byte r is an underscore or one of the
// digits '0' through '7'.
func isValidOctalRune(r byte) bool {
	if r == '_' {
		return true
	}
	return '0' <= r && r <= '7'
}
@@ -1265,6 +1027,3 @@ func (u unexpectedCharacter) Error() string {
}
return fmt.Sprintf("expected %#U, not %#U", u.r, u.b[0])
}
// Errors returned by the underscore validators for number tokens.
var (
	errInvalidUnderscore    = errors.New("invalid use of _ in number")
	errInvalidUnderscoreHex = errors.New("invalid use of _ in hex number")
)
+1 -1
View File
@@ -7,7 +7,7 @@ import (
"github.com/stretchr/testify/require"
)
func TestParser_Numbers(t *testing.T) {
func TestParser_AST_Numbers(t *testing.T) {
examples := []struct {
desc string
input string
+38
View File
@@ -15,6 +15,12 @@ type target interface {
// Store a boolean at the target
setBool(v bool) error
// Store an int64 at the target
setInt64(v int64) error
// Store a float64 at the target
setFloat64(v float64) error
// Creates a new value of the container's element type, and returns a
// target to it.
pushNew() (target, error)
@@ -83,6 +89,38 @@ func (t valueTarget) setBool(v bool) error {
return nil
}
// setInt64 stores v in the value returned by t.get().
//
// Signed integer kinds receive v only when it fits in the destination type;
// otherwise an error is returned. Interface kinds receive v boxed as an
// int64. Any other kind yields an error.
func (t valueTarget) setInt64(v int64) error {
	f := t.get()

	switch f.Kind() {
	case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
		// SetInt silently truncates values that do not fit, so reject them
		// before storing.
		if f.OverflowInt(v) {
			return fmt.Errorf("integer %d does not fit in a %s", v, f.Type())
		}
		f.SetInt(v)
	case reflect.Interface:
		f.Set(reflect.ValueOf(v))
	default:
		// Report the kind: Value.String yields the string contents, not the
		// type, when the destination is a string.
		return fmt.Errorf("cannot assign int64 to a %s", f.Kind())
	}

	return nil
}
// setFloat64 stores v in the value returned by t.get().
//
// Float kinds receive v only when it fits in the destination type; otherwise
// an error is returned. Interface kinds receive v boxed as a float64. Any
// other kind yields an error.
func (t valueTarget) setFloat64(v float64) error {
	f := t.get()

	switch f.Kind() {
	case reflect.Float32, reflect.Float64:
		// SetFloat does not report values that exceed a float32 destination,
		// so reject them before storing.
		if f.OverflowFloat(v) {
			return fmt.Errorf("float %v does not fit in a %s", v, f.Type())
		}
		f.SetFloat(v)
	case reflect.Interface:
		f.Set(reflect.ValueOf(v))
	default:
		// Report the kind: Value.String yields the string contents, not the
		// type, when the destination is a string.
		return fmt.Errorf("cannot assign float64 to a %s", f.Kind())
	}

	return nil
}
func (t valueTarget) pushNew() (target, error) {
f := t.get()
+22
View File
@@ -80,6 +80,10 @@ func unmarshalValue(x target, node *ast.Node) error {
return unmarshalString(x, node)
case ast.Bool:
return unmarshalBool(x, node)
case ast.Integer:
return unmarshalInteger(x, node)
case ast.Float:
return unmarshalFloat(x, node)
case ast.Array:
return unmarshalArray(x, node)
case ast.InlineTable:
@@ -100,6 +104,24 @@ func unmarshalBool(x target, node *ast.Node) error {
return x.setBool(v)
}
// unmarshalInteger decodes node with DecodeInteger and stores the resulting
// int64 in x through setInt64. A decoding error is returned unchanged.
// assertNode is called first with ast.Integer; presumably it panics on a kind
// mismatch (confirm against assertNode).
func unmarshalInteger(x target, node *ast.Node) error {
assertNode(ast.Integer, node)
v, err := node.DecodeInteger()
if err != nil {
return err
}
return x.setInt64(v)
}
// unmarshalFloat decodes node with DecodeFloat and stores the resulting
// float64 in x through setFloat64. A decoding error is returned unchanged.
// assertNode is called first with ast.Float; presumably it panics on a kind
// mismatch (confirm against assertNode).
func unmarshalFloat(x target, node *ast.Node) error {
assertNode(ast.Float, node)
v, err := node.DecodeFloat()
if err != nil {
return err
}
return x.setFloat64(v)
}
func unmarshalInlineTable(x target, node *ast.Node) error {
assertNode(ast.InlineTable, node)
+159
View File
@@ -1,6 +1,7 @@
package unmarshaler
import (
"math"
"testing"
"github.com/stretchr/testify/assert"
@@ -9,6 +10,164 @@ import (
"github.com/pelletier/go-toml/v2/internal/ast"
)
// TestUnmarshal_Integers unmarshals integer literals written in decimal,
// hexadecimal, octal and binary notation into an int64 struct field and
// compares the stored value with the expected one. Examples with err set are
// only required to fail.
func TestUnmarshal_Integers(t *testing.T) {
	examples := []struct {
		desc     string
		input    string
		expected int64
		err      bool
	}{
		{
			desc:     "integer just digits",
			input:    `1234`,
			expected: 1234,
		},
		{
			desc:     "integer zero",
			input:    `0`,
			expected: 0,
		},
		{
			desc:     "integer sign",
			input:    `+99`,
			expected: 99,
		},
		{
			desc:     "integer hex uppercase",
			input:    `0xDEADBEEF`,
			expected: 0xDEADBEEF,
		},
		{
			desc:     "integer hex lowercase",
			input:    `0xdead_beef`,
			expected: 0xDEADBEEF,
		},
		{
			desc:     "integer octal",
			input:    `0o01234567`,
			expected: 0o01234567,
		},
		{
			desc:     "integer binary",
			input:    `0b11010110`,
			expected: 0b11010110,
		},
	}

	type doc struct {
		A int64
	}

	for _, e := range examples {
		t.Run(e.desc, func(t *testing.T) {
			// Named d so the variable does not shadow the doc type.
			d := doc{}
			err := Unmarshal([]byte(`A = `+e.input), &d)
			if e.err {
				require.Error(t, err)
				return
			}
			require.NoError(t, err)
			assert.Equal(t, e.expected, d.A)
		})
	}
}
// TestUnmarshal_Floats unmarshals float literals, including exponents,
// underscores, infinities and NaN spellings, into a float64 struct field.
// Examples that provide testFn are checked with it (NaN never compares
// equal); the others are compared with expected. Examples with err set are
// only required to fail.
func TestUnmarshal_Floats(t *testing.T) {
	examples := []struct {
		desc     string
		input    string
		expected float64
		testFn   func(t *testing.T, v float64)
		err      bool
	}{
		{
			desc:     "float pi",
			input:    `3.1415`,
			expected: 3.1415,
		},
		{
			desc:     "float negative",
			input:    `-0.01`,
			expected: -0.01,
		},
		{
			desc:     "float signed exponent",
			input:    `5e+22`,
			expected: 5e+22,
		},
		{
			desc:     "float exponent lowercase",
			input:    `1e06`,
			expected: 1e06,
		},
		{
			desc:     "float exponent uppercase",
			input:    `-2E-2`,
			expected: -2e-2,
		},
		{
			desc:     "float fractional with exponent",
			input:    `6.626e-34`,
			expected: 6.626e-34,
		},
		{
			desc:     "float underscores",
			input:    `224_617.445_991_228`,
			expected: 224_617.445_991_228,
		},
		{
			desc:     "inf",
			input:    `inf`,
			expected: math.Inf(+1),
		},
		{
			desc:     "inf negative",
			input:    `-inf`,
			expected: math.Inf(-1),
		},
		{
			desc:     "inf positive",
			input:    `+inf`,
			expected: math.Inf(+1),
		},
		{
			desc:  "nan",
			input: `nan`,
			testFn: func(t *testing.T, v float64) {
				assert.True(t, math.IsNaN(v))
			},
		},
		{
			desc:  "nan negative",
			input: `-nan`,
			testFn: func(t *testing.T, v float64) {
				assert.True(t, math.IsNaN(v))
			},
		},
		{
			desc:  "nan positive",
			input: `+nan`,
			testFn: func(t *testing.T, v float64) {
				assert.True(t, math.IsNaN(v))
			},
		},
	}

	type doc struct {
		A float64
	}

	for _, e := range examples {
		t.Run(e.desc, func(t *testing.T) {
			// Named d so the variable does not shadow the doc type.
			d := doc{}
			err := Unmarshal([]byte(`A = `+e.input), &d)
			if e.err {
				require.Error(t, err)
				return
			}
			require.NoError(t, err)
			if e.testFn != nil {
				e.testFn(t, d.A)
				return
			}
			assert.Equal(t, e.expected, d.A)
		})
	}
}
func TestUnmarshal(t *testing.T) {
type test struct {
target interface{}