diff --git a/internal/ast/ast.go b/internal/ast/ast.go index 25a89be..eed8675 100644 --- a/internal/ast/ast.go +++ b/internal/ast/ast.go @@ -158,11 +158,43 @@ func (n *Node) Key() []Node { // Guaranteed to be non-nil. // Panics if not called on a KeyValue node, or if the Children are malformed. func (n *Node) Value() *Node { - if n.Kind != KeyValue { - panic(fmt.Errorf("Key() should only be called on on a KeyValue, not %s", n.Kind)) - } + assertKind(KeyValue, n) if len(n.Children) < 2 { panic(fmt.Errorf("KeyValue should have at least two children, not %d", len(n.Children))) } return &n.Children[len(n.Children)-1] } + +// DecodeInteger parse the data of an Integer node and returns the represented +// int64, or an error. +// Panics if not called on an Integer node. +func (n *Node) DecodeInteger() (int64, error) { + assertKind(Integer, n) + if len(n.Data) > 2 && n.Data[0] == '0' { + switch n.Data[1] { + case 'x': + return parseIntHex(n.Data) + case 'b': + return parseIntBin(n.Data) + case 'o': + return parseIntOct(n.Data) + default: + return 0, fmt.Errorf("invalid base: '%c'", n.Data[1]) + } + } + return parseIntDec(n.Data) +} + +// DecodeFloat parse the data of a Float node and returns the represented +// float64, or an error. +// Panics if not called on an Float node. 
// parseFloat converts the raw bytes of a float token into a float64.
//
// Underscores used as digit separators are validated on the raw token and
// then removed before the text is handed to strconv.ParseFloat.
func parseFloat(b []byte) (float64, error) {
	// Signed NaN ("+nan" / "-nan") is recognised here rather than being passed
	// on to strconv.ParseFloat.
	if len(b) == 4 && (b[0] == '+' || b[0] == '-') && string(b[1:]) == "nan" {
		return math.NaN(), nil
	}

	// TODO: inefficient (the token is copied into a string, then copied again
	// when the underscores are stripped).
	tok := string(b)
	if err := numberContainsInvalidUnderscore(tok); err != nil {
		return 0, err
	}
	return strconv.ParseFloat(cleanupNumberToken(tok), 64)
}

// parseIntHex parses a hexadecimal integer token, including its two-byte
// base prefix (for example "0xdead_beef"), into an int64.
func parseIntHex(b []byte) (int64, error) {
	return parsePrefixedInt(b, 16, hexNumberContainsInvalidUnderscore)
}

// parseIntOct parses an octal integer token, including its two-byte base
// prefix (for example "0o755"), into an int64.
func parseIntOct(b []byte) (int64, error) {
	return parsePrefixedInt(b, 8, numberContainsInvalidUnderscore)
}

// parseIntBin parses a binary integer token, including its two-byte base
// prefix (for example "0b1101_0110"), into an int64.
func parseIntBin(b []byte) (int64, error) {
	return parsePrefixedInt(b, 2, numberContainsInvalidUnderscore)
}

// parseIntDec parses a decimal integer token (optionally signed, optionally
// using underscores between digits) into an int64.
func parseIntDec(b []byte) (int64, error) {
	tok := string(b)
	if err := numberContainsInvalidUnderscore(tok); err != nil {
		return 0, err
	}
	return strconv.ParseInt(cleanupNumberToken(tok), 10, 64)
}

// parsePrefixedInt is the shared implementation of the hex/octal/binary
// parsers. It validates the underscores of the raw token with validate,
// strips them, drops the first two bytes (the base prefix) and parses the
// remainder in the given base.
func parsePrefixedInt(b []byte, base int, validate func(string) error) (int64, error) {
	tok := string(b)
	// Validation has to look at the raw token: once the underscores are
	// removed there is nothing left to check.
	if err := validate(tok); err != nil {
		return 0, err
	}

	cleaned := cleanupNumberToken(tok)
	// A token too short to carry digits after the prefix yields an empty
	// string, which strconv.ParseInt reports as a syntax error instead of the
	// slice expression panicking.
	digits := ""
	if len(cleaned) > 2 {
		digits = cleaned[2:]
	}
	return strconv.ParseInt(digits, base, 64)
}

// numberContainsInvalidUnderscore returns errInvalidUnderscore when value
// contains an underscore that is not placed between two decimal digits.
//
// For large numbers, you may use underscores between digits to enhance
// readability. Each underscore must be surrounded by at least one digit on
// each side.
func numberContainsInvalidUnderscore(value string) error {
	return checkUnderscores(value, isDigitRune, errInvalidUnderscore)
}

// hexNumberContainsInvalidUnderscore returns errInvalidUnderscoreHex when
// value contains an underscore that is not placed between two hexadecimal
// digits.
func hexNumberContainsInvalidUnderscore(value string) error {
	return checkUnderscores(value, isHexDigit, errInvalidUnderscoreHex)
}

// checkUnderscores reports errInvalid if any underscore in value is at the
// start, at the end, or not directly surrounded by characters accepted by
// isDigit.
func checkUnderscores(value string, isDigit func(rune) bool, errInvalid error) error {
	for i := 0; i < len(value); i++ {
		if value[i] != '_' {
			continue
		}
		if i == 0 || i+1 >= len(value) {
			// can't start or end with an underscore
			return errInvalid
		}
		if !isDigit(rune(value[i-1])) || !isDigit(rune(value[i+1])) {
			return errInvalid
		}
	}
	return nil
}

// cleanupNumberToken returns value with every underscore removed.
func cleanupNumberToken(value string) string {
	return strings.ReplaceAll(value, "_", "")
}

// isHexDigit reports whether r is one of 0-9, a-f or A-F.
func isHexDigit(r rune) bool {
	return isDigitRune(r) ||
		(r >= 'a' && r <= 'f') ||
		(r >= 'A' && r <= 'F')
}

// isDigitRune reports whether r is one of 0-9.
func isDigitRune(r rune) bool {
	return r >= '0' && r <= '9'
}

// Errors returned when underscores are misused inside a number token.
var (
	errInvalidUnderscore    = errors.New("invalid use of _ in number")
	errInvalidUnderscoreHex = errors.New("invalid use of _ in hex number")
)
'0' { -// if len(b) >= 2 { -// var isValidRune validRuneFn -// var parseFn func([]byte) (int64, error) -// switch b[1] { -// case 'x': -// isValidRune = isValidHexRune -// parseFn = parseIntHex -// case 'o': -// isValidRune = isValidOctalRune -// parseFn = parseIntOct -// case 'b': -// isValidRune = isValidBinaryRune -// parseFn = parseIntBin -// default: -// if b[1] >= 'a' && b[1] <= 'z' || b[1] >= 'A' && b[1] <= 'Z' { -// return nil, fmt.Errorf("unknown number base: %s. possible options are x (hex) o (octal) b (binary)", string(b[1])) -// } -// parseFn = parseIntDec -// } -// -// if isValidRune != nil { -// i = 2 -// digitSeen := false -// for { -// if !isValidRune(b[i]) { -// break -// } -// digitSeen = true -// i++ -// } -// -// if !digitSeen { -// return nil, fmt.Errorf("number needs at least one digit") -// } -// -// v, err := parseFn(b[:i]) -// if err != nil { -// return nil, err -// } -// //p.builder.IntValue(v) -// // TODO -// v = v -// return b[i:], nil -// } -// } -// } -// -// if r == '+' || r == '-' { -// b = b[1:] -// if scanFollowsInf(b) { -// if r == '+' { -// //p.builder.FloatValue(plusInf) -// // TODO -// } else { -// //p.builder.FloatValue(minusInf) -// // TODO -// } -// return b, nil -// } -// if scanFollowsNan(b) { -// //p.builder.FloatValue(nan) -// // TODO -// return b, nil -// } -// } -// -// pointSeen := false -// expSeen := false -// digitSeen := false -// for i < len(b) { -// next := b[i] -// if next == '.' 
{ -// if pointSeen { -// return nil, fmt.Errorf("cannot have two dots in one float") -// } -// i++ -// if i < len(b) && !isDigit(b[i]) { -// return nil, fmt.Errorf("float cannot end with a dot") -// } -// pointSeen = true -// } else if next == 'e' || next == 'E' { -// expSeen = true -// i++ -// if i >= len(b) { -// break -// } -// if b[i] == '+' || b[i] == '-' { -// i++ -// } -// } else if isDigit(next) { -// digitSeen = true -// i++ -// } else if next == '_' { -// i++ -// } else { -// break -// } -// if pointSeen && !digitSeen { -// return nil, fmt.Errorf("cannot start float with a dot") -// } -// } -// -// if !digitSeen { -// return nil, fmt.Errorf("no digit in that number") -// } -// if pointSeen || expSeen { -// f, err := parseFloat(b[:i]) -// if err != nil { -// return nil, err -// } -// //p.builder.FloatValue(f) -// // TODO -// f = f -// } else { -// v, err := parseIntDec(b[:i]) -// if err != nil { -// return nil, err -// } -// //p.builder.IntValue(v) -// // TODO -// v = v -// } -// return b[i:], nil -//} - -func parseFloat(b []byte) (float64, error) { - // TODO: inefficient - tok := string(b) - err := numberContainsInvalidUnderscore(tok) - if err != nil { - return 0, err - } - cleanedVal := cleanupNumberToken(tok) - return strconv.ParseFloat(cleanedVal, 64) -} - -func parseIntHex(b []byte) (int64, error) { - tok := string(b) - cleanedVal := cleanupNumberToken(tok) - err := hexNumberContainsInvalidUnderscore(cleanedVal) - if err != nil { - return 0, nil - } - return strconv.ParseInt(cleanedVal[2:], 16, 64) -} - -func parseIntOct(b []byte) (int64, error) { - tok := string(b) - cleanedVal := cleanupNumberToken(tok) - err := numberContainsInvalidUnderscore(cleanedVal) - if err != nil { - return 0, err - } - return strconv.ParseInt(cleanedVal[2:], 8, 64) -} - -func parseIntBin(b []byte) (int64, error) { - tok := string(b) - cleanedVal := cleanupNumberToken(tok) - err := numberContainsInvalidUnderscore(cleanedVal) - if err != nil { - return 0, err - } - return 
strconv.ParseInt(cleanedVal[2:], 2, 64) -} - -func parseIntDec(b []byte) (int64, error) { - tok := string(b) - cleanedVal := cleanupNumberToken(tok) - err := numberContainsInvalidUnderscore(cleanedVal) - if err != nil { - return 0, err - } - return strconv.ParseInt(cleanedVal, 10, 64) -} - -func numberContainsInvalidUnderscore(value string) error { - // For large numbers, you may use underscores between digits to enhance - // readability. Each underscore must be surrounded by at least one digit on - // each side. - - hasBefore := false - for idx, r := range value { - if r == '_' { - if !hasBefore || idx+1 >= len(value) { - // can't end with an underscore - return errInvalidUnderscore - } - } - hasBefore = isDigitRune(r) - } - return nil -} - -func hexNumberContainsInvalidUnderscore(value string) error { - hasBefore := false - for idx, r := range value { - if r == '_' { - if !hasBefore || idx+1 >= len(value) { - // can't end with an underscore - return errInvalidUnderscoreHex - } - } - hasBefore = isHexDigit(r) - } - return nil -} - -func cleanupNumberToken(value string) string { - cleanedVal := strings.Replace(value, "_", "", -1) - return cleanedVal -} - func isDigit(r byte) bool { return r >= '0' && r <= '9' } -func isDigitRune(r rune) bool { - return r >= '0' && r <= '9' -} - -var plusInf = math.Inf(1) -var minusInf = math.Inf(-1) -var nan = math.NaN() - type validRuneFn func(r byte) bool func isValidHexRune(r byte) bool { @@ -1232,12 +1000,6 @@ func isValidHexRune(r byte) bool { r == '_' } -func isHexDigit(r rune) bool { - return isDigitRune(r) || - (r >= 'a' && r <= 'f') || - (r >= 'A' && r <= 'F') -} - func isValidOctalRune(r byte) bool { return r >= '0' && r <= '7' || r == '_' } @@ -1265,6 +1027,3 @@ func (u unexpectedCharacter) Error() string { } return fmt.Sprintf("expected %#U, not %#U", u.r, u.b[0]) } - -var errInvalidUnderscore = errors.New("invalid use of _ in number") -var errInvalidUnderscoreHex = errors.New("invalid use of _ in hex number") diff 
--git a/internal/unmarshaler/parser_test.go b/internal/unmarshaler/parser_test.go index 250da3e..b7aabfa 100644 --- a/internal/unmarshaler/parser_test.go +++ b/internal/unmarshaler/parser_test.go @@ -7,7 +7,7 @@ import ( "github.com/stretchr/testify/require" ) -func TestParser_Numbers(t *testing.T) { +func TestParser_AST_Numbers(t *testing.T) { examples := []struct { desc string input string diff --git a/internal/unmarshaler/targets.go b/internal/unmarshaler/targets.go index 6adaed6..af3c5c7 100644 --- a/internal/unmarshaler/targets.go +++ b/internal/unmarshaler/targets.go @@ -15,6 +15,12 @@ type target interface { // Store a boolean at the target setBool(v bool) error + // Store an int64 at the target + setInt64(v int64) error + + // Store a float64 at the target + setFloat64(v float64) error + // Creates a new value of the container's element type, and returns a // target to it. pushNew() (target, error) @@ -83,6 +89,38 @@ func (t valueTarget) setBool(v bool) error { return nil } +func (t valueTarget) setInt64(v int64) error { + f := t.get() + + switch f.Kind() { + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + // TODO: overflow checks + f.SetInt(v) + case reflect.Interface: + f.Set(reflect.ValueOf(v)) + default: + return fmt.Errorf("cannot assign int64 to a %s", f.String()) + } + + return nil +} + +func (t valueTarget) setFloat64(v float64) error { + f := t.get() + + switch f.Kind() { + case reflect.Float32, reflect.Float64: + // TODO: overflow checks + f.SetFloat(v) + case reflect.Interface: + f.Set(reflect.ValueOf(v)) + default: + return fmt.Errorf("cannot assign float64 to a %s", f.String()) + } + + return nil +} + func (t valueTarget) pushNew() (target, error) { f := t.get() diff --git a/internal/unmarshaler/unmarshaler.go b/internal/unmarshaler/unmarshaler.go index 6b2a665..ab63b89 100644 --- a/internal/unmarshaler/unmarshaler.go +++ b/internal/unmarshaler/unmarshaler.go @@ -80,6 +80,10 @@ func unmarshalValue(x target, node 
*ast.Node) error { return unmarshalString(x, node) case ast.Bool: return unmarshalBool(x, node) + case ast.Integer: + return unmarshalInteger(x, node) + case ast.Float: + return unmarshalFloat(x, node) case ast.Array: return unmarshalArray(x, node) case ast.InlineTable: @@ -100,6 +104,24 @@ func unmarshalBool(x target, node *ast.Node) error { return x.setBool(v) } +func unmarshalInteger(x target, node *ast.Node) error { + assertNode(ast.Integer, node) + v, err := node.DecodeInteger() + if err != nil { + return err + } + return x.setInt64(v) +} + +func unmarshalFloat(x target, node *ast.Node) error { + assertNode(ast.Float, node) + v, err := node.DecodeFloat() + if err != nil { + return err + } + return x.setFloat64(v) +} + func unmarshalInlineTable(x target, node *ast.Node) error { assertNode(ast.InlineTable, node) diff --git a/internal/unmarshaler/unmarshaler_test.go b/internal/unmarshaler/unmarshaler_test.go index d8212e7..4c45e4e 100644 --- a/internal/unmarshaler/unmarshaler_test.go +++ b/internal/unmarshaler/unmarshaler_test.go @@ -1,6 +1,7 @@ package unmarshaler import ( + "math" "testing" "github.com/stretchr/testify/assert" @@ -9,6 +10,164 @@ import ( "github.com/pelletier/go-toml/v2/internal/ast" ) +func TestUnmarshal_Integers(t *testing.T) { + examples := []struct { + desc string + input string + expected int64 + err bool + }{ + { + desc: "integer just digits", + input: `1234`, + expected: 1234, + }, + { + desc: "integer zero", + input: `0`, + expected: 0, + }, + { + desc: "integer sign", + input: `+99`, + expected: 99, + }, + { + desc: "integer hex uppercase", + input: `0xDEADBEEF`, + expected: 0xDEADBEEF, + }, + { + desc: "integer hex lowercase", + input: `0xdead_beef`, + expected: 0xDEADBEEF, + }, + { + desc: "integer octal", + input: `0o01234567`, + expected: 0o01234567, + }, + { + desc: "integer binary", + input: `0b11010110`, + expected: 0b11010110, + }, + } + + type doc struct { + A int64 + } + + for _, e := range examples { + t.Run(e.desc, func(t 
*testing.T) { + doc := doc{} + err := Unmarshal([]byte(`A = `+e.input), &doc) + require.NoError(t, err) + assert.Equal(t, e.expected, doc.A) + }) + } +} + +func TestUnmarshal_Floats(t *testing.T) { + examples := []struct { + desc string + input string + expected float64 + testFn func(t *testing.T, v float64) + err bool + }{ + + { + desc: "float pi", + input: `3.1415`, + expected: 3.1415, + }, + { + desc: "float negative", + input: `-0.01`, + expected: -0.01, + }, + { + desc: "float signed exponent", + input: `5e+22`, + expected: 5e+22, + }, + { + desc: "float exponent lowercase", + input: `1e06`, + expected: 1e06, + }, + { + desc: "float exponent uppercase", + input: `-2E-2`, + expected: -2e-2, + }, + { + desc: "float fractional with exponent", + input: `6.626e-34`, + expected: 6.626e-34, + }, + { + desc: "float underscores", + input: `224_617.445_991_228`, + expected: 224_617.445_991_228, + }, + { + desc: "inf", + input: `inf`, + expected: math.Inf(+1), + }, + { + desc: "inf negative", + input: `-inf`, + expected: math.Inf(-1), + }, + { + desc: "inf positive", + input: `+inf`, + expected: math.Inf(+1), + }, + { + desc: "nan", + input: `nan`, + testFn: func(t *testing.T, v float64) { + assert.True(t, math.IsNaN(v)) + }, + }, + { + desc: "nan negative", + input: `-nan`, + testFn: func(t *testing.T, v float64) { + assert.True(t, math.IsNaN(v)) + }, + }, + { + desc: "nan positive", + input: `+nan`, + testFn: func(t *testing.T, v float64) { + assert.True(t, math.IsNaN(v)) + }, + }, + } + + type doc struct { + A float64 + } + + for _, e := range examples { + t.Run(e.desc, func(t *testing.T) { + doc := doc{} + err := Unmarshal([]byte(`A = `+e.input), &doc) + require.NoError(t, err) + if e.testFn != nil { + e.testFn(t, doc.A) + } else { + assert.Equal(t, e.expected, doc.A) + } + }) + } +} + func TestUnmarshal(t *testing.T) { type test struct { target interface{}