Playing with an AST
Idea would be to build a light AST as a first pass, then have the unmarshaler and Document parser do what they need with it.
This commit is contained in:
@@ -0,0 +1,162 @@
|
||||
package ast
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
)
|
||||
|
||||
type Kind int
|
||||
|
||||
const (
|
||||
// meta
|
||||
Comment Kind = iota
|
||||
Key
|
||||
|
||||
// top level structures
|
||||
Table
|
||||
ArrayTable
|
||||
KeyValue
|
||||
|
||||
// containers values
|
||||
Array
|
||||
InlineTable
|
||||
|
||||
// values
|
||||
String
|
||||
Bool
|
||||
Float
|
||||
Integer
|
||||
LocalDate
|
||||
LocalDateTime
|
||||
DateTime
|
||||
Time
|
||||
)
|
||||
|
||||
func (k Kind) String() string {
|
||||
switch k {
|
||||
case Comment:
|
||||
return "Comment"
|
||||
case Key:
|
||||
return "Key"
|
||||
case Table:
|
||||
return "Table"
|
||||
case ArrayTable:
|
||||
return "ArrayTable"
|
||||
case KeyValue:
|
||||
return "KeyValue"
|
||||
case Array:
|
||||
return "Array"
|
||||
case InlineTable:
|
||||
return "InlineTable"
|
||||
case String:
|
||||
return "String"
|
||||
case Bool:
|
||||
return "Bool"
|
||||
case Float:
|
||||
return "Float"
|
||||
case Integer:
|
||||
return "Integer"
|
||||
case LocalDate:
|
||||
return "LocalDate"
|
||||
case LocalDateTime:
|
||||
return "LocalDateTime"
|
||||
case DateTime:
|
||||
return "DateTime"
|
||||
case Time:
|
||||
return "Time"
|
||||
}
|
||||
panic(fmt.Errorf("Kind.String() not implemented for '%d'", k))
|
||||
}
|
||||
|
||||
type Root []Node
|
||||
|
||||
// Dot returns a dot representation of the AST for debugging.
|
||||
func (r Root) Sdot() string {
|
||||
type edge struct {
|
||||
from int
|
||||
to int
|
||||
}
|
||||
|
||||
var nodes []string
|
||||
var edges []edge // indexes into nodes
|
||||
|
||||
nodes = append(nodes, "root")
|
||||
|
||||
labelForNode := func(node *Node) string {
|
||||
return fmt.Sprintf("{%s}", node.Kind)
|
||||
}
|
||||
|
||||
var processNode func(int, *Node)
|
||||
processNode = func(parentIdx int, node *Node) {
|
||||
idx := len(nodes)
|
||||
label := labelForNode(node)
|
||||
nodes = append(nodes, label)
|
||||
edges = append(edges, edge{from: parentIdx, to: idx})
|
||||
|
||||
for _, c := range node.Children {
|
||||
processNode(idx, &c)
|
||||
}
|
||||
}
|
||||
|
||||
for _, n := range r {
|
||||
processNode(0, &n)
|
||||
}
|
||||
|
||||
var b strings.Builder
|
||||
|
||||
b.WriteString("digraph tree {\n")
|
||||
|
||||
for i, label := range nodes {
|
||||
_, _ = fmt.Fprintf(&b, "\tnode%d [label=\"%s\"];\n", i, label)
|
||||
}
|
||||
|
||||
b.WriteString("\n")
|
||||
|
||||
for _, e := range edges {
|
||||
_, _ = fmt.Fprintf(&b, "\tnode%d -> node%d;\n", e.from, e.to)
|
||||
}
|
||||
|
||||
b.WriteString("}")
|
||||
|
||||
return b.String()
|
||||
}
|
||||
|
||||
type Node struct {
|
||||
Kind Kind
|
||||
Data []byte // Raw bytes from the input
|
||||
|
||||
// Arrays have one child per element in the array.
|
||||
// InlineTables have one child per key-value pair in the table.
|
||||
// KeyValues have at least two children. The last one is the value. The
|
||||
// rest make a potentially dotted key.
|
||||
Children []Node
|
||||
}
|
||||
|
||||
var NoNode = Node{}
|
||||
|
||||
// Key returns the nodes making the Key of a KeyValue.
|
||||
// They are guaranteed to be all be of the Kind Key. A simple key would return
|
||||
// just one element.
|
||||
// Panics if not called on a KeyValue node, or if the Children are malformed.
|
||||
func (n *Node) Key() []Node {
|
||||
if n.Kind != KeyValue {
|
||||
panic(fmt.Errorf("Key() should only be called on on a KeyValue, not %s", n.Kind))
|
||||
}
|
||||
if len(n.Children) < 2 {
|
||||
panic(fmt.Errorf("KeyValue should have at least two children, not %d", len(n.Children)))
|
||||
}
|
||||
return n.Children[:len(n.Children)-1]
|
||||
}
|
||||
|
||||
// Value returns a pointer to the value node of a KeyValue.
|
||||
// Guaranteed to be non-nil.
|
||||
// Panics if not called on a KeyValue node, or if the Children are malformed.
|
||||
func (n *Node) Value() *Node {
|
||||
if n.Kind != KeyValue {
|
||||
panic(fmt.Errorf("Key() should only be called on on a KeyValue, not %s", n.Kind))
|
||||
}
|
||||
if len(n.Children) < 2 {
|
||||
panic(fmt.Errorf("KeyValue should have at least two children, not %d", len(n.Children)))
|
||||
}
|
||||
return &n.Children[len(n.Children)-1]
|
||||
}
|
||||
@@ -1855,16 +1855,19 @@ func TestUnmarshalMixedTypeArray(t *testing.T) {
|
||||
ArrayField []interface{}
|
||||
}
|
||||
|
||||
doc := []byte(`ArrayField = [3.14,100,true,"hello world",{Field = "inner1"},[{Field = "inner2"},{Field = "inner3"}]]
|
||||
//doc := []byte(`ArrayField = [3.14,100,true,"hello world",{Field = "inner1"},[{Field = "inner2"},{Field = "inner3"}]]
|
||||
//`)
|
||||
|
||||
doc := []byte(`ArrayField = [{Field = "inner1"},[{Field = "inner2"},{Field = "inner3"}]]
|
||||
`)
|
||||
|
||||
actual := TestStruct{}
|
||||
expected := TestStruct{
|
||||
ArrayField: []interface{}{
|
||||
3.14,
|
||||
int64(100),
|
||||
true,
|
||||
"hello world",
|
||||
//3.14,
|
||||
//int64(100),
|
||||
//true,
|
||||
//"hello world",
|
||||
map[string]interface{}{
|
||||
"Field": "inner1",
|
||||
},
|
||||
@@ -1874,14 +1877,9 @@ func TestUnmarshalMixedTypeArray(t *testing.T) {
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
if err := toml.Unmarshal(doc, &actual); err == nil {
|
||||
if !reflect.DeepEqual(actual, expected) {
|
||||
t.Errorf("Bad unmarshal: expected %#v, got %#v", expected, actual)
|
||||
}
|
||||
} else {
|
||||
t.Fatal(err)
|
||||
}
|
||||
err := toml.Unmarshal(doc, &actual)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, expected, actual)
|
||||
}
|
||||
|
||||
func TestUnmarshalArray(t *testing.T) {
|
||||
|
||||
@@ -199,9 +199,7 @@ func NewBuilder(tag string, v interface{}) (Builder, error) {
|
||||
}
|
||||
|
||||
func (b *Builder) top() target {
|
||||
t := b.stack[len(b.stack)-1]
|
||||
fmt.Println("TOP:", t)
|
||||
return t
|
||||
return b.stack[len(b.stack)-1]
|
||||
}
|
||||
|
||||
func (b *Builder) duplicate() {
|
||||
@@ -213,7 +211,6 @@ func (b *Builder) duplicate() {
|
||||
|
||||
func (b *Builder) pop() {
|
||||
b.stack = b.stack[:len(b.stack)-1]
|
||||
fmt.Println("POP: top:", b.stack[len(b.stack)-1])
|
||||
}
|
||||
|
||||
func (b *Builder) len() int {
|
||||
@@ -236,7 +233,6 @@ func (b *Builder) Dump() string {
|
||||
}
|
||||
|
||||
func (b *Builder) replace(v target) {
|
||||
fmt.Println("REPLACING:", v)
|
||||
b.stack[len(b.stack)-1] = v
|
||||
}
|
||||
|
||||
@@ -250,10 +246,6 @@ func (b *Builder) DigField(s string) error {
|
||||
v := t.get()
|
||||
|
||||
for v.Kind() == reflect.Interface || v.Kind() == reflect.Ptr {
|
||||
if v.Kind() == reflect.Interface {
|
||||
fmt.Println("STOP")
|
||||
}
|
||||
|
||||
if v.IsNil() {
|
||||
if v.Kind() == reflect.Ptr {
|
||||
thing := reflect.New(v.Type().Elem())
|
||||
@@ -338,7 +330,20 @@ func (b *Builder) IsSlice() bool {
|
||||
}
|
||||
|
||||
func (b *Builder) IsSliceOrPtr() bool {
|
||||
return b.top().get().Kind() == reflect.Slice || (b.top().get().Kind() == reflect.Ptr && b.top().get().Type().Elem().Kind() == reflect.Slice)
|
||||
t := b.top().get()
|
||||
if t.Kind() == reflect.Slice {
|
||||
return true
|
||||
}
|
||||
|
||||
if t.Kind() == reflect.Ptr && t.Type().Elem().Kind() == reflect.Slice {
|
||||
return true
|
||||
}
|
||||
|
||||
if t.Kind() == reflect.Interface && !t.IsNil() && t.Elem().Type().Kind() == reflect.Slice {
|
||||
return true
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// Last moves the cursor to the last value of the current value.
|
||||
@@ -502,14 +507,14 @@ func convert(t reflect.Type, value reflect.Value) (reflect.Value, error) {
|
||||
return result.Elem(), nil
|
||||
}
|
||||
|
||||
type IntegerOverflowErr struct {
|
||||
type IntegerOverflowError struct {
|
||||
value int64
|
||||
min int64
|
||||
max int64
|
||||
kind reflect.Kind
|
||||
}
|
||||
|
||||
func (e IntegerOverflowErr) Error() string {
|
||||
func (e IntegerOverflowError) Error() string {
|
||||
return fmt.Sprintf("integer overflow: cannot store %d in %s [%d, %d]", e.value, e.kind, e.min, e.max)
|
||||
}
|
||||
|
||||
@@ -524,7 +529,7 @@ func convertInt(t reflect.Type, value reflect.Value) (reflect.Value, error) {
|
||||
switch t.Kind() {
|
||||
case reflect.Int:
|
||||
if x > maxInt || x < minInt {
|
||||
return value, IntegerOverflowErr{
|
||||
return value, IntegerOverflowError{
|
||||
value: x,
|
||||
min: minInt,
|
||||
max: maxInt,
|
||||
@@ -533,7 +538,7 @@ func convertInt(t reflect.Type, value reflect.Value) (reflect.Value, error) {
|
||||
}
|
||||
case reflect.Int8:
|
||||
if x > math.MaxInt8 || x < math.MinInt8 {
|
||||
return value, IntegerOverflowErr{
|
||||
return value, IntegerOverflowError{
|
||||
value: x,
|
||||
min: math.MinInt8,
|
||||
max: math.MaxInt8,
|
||||
@@ -542,7 +547,7 @@ func convertInt(t reflect.Type, value reflect.Value) (reflect.Value, error) {
|
||||
}
|
||||
case reflect.Int16:
|
||||
if x > math.MaxInt16 || x < math.MinInt16 {
|
||||
return value, IntegerOverflowErr{
|
||||
return value, IntegerOverflowError{
|
||||
value: x,
|
||||
min: math.MinInt16,
|
||||
max: math.MaxInt16,
|
||||
@@ -551,7 +556,7 @@ func convertInt(t reflect.Type, value reflect.Value) (reflect.Value, error) {
|
||||
}
|
||||
case reflect.Int32:
|
||||
if x > math.MaxInt32 || x < math.MinInt32 {
|
||||
return value, IntegerOverflowErr{
|
||||
return value, IntegerOverflowError{
|
||||
value: x,
|
||||
min: math.MinInt32,
|
||||
max: math.MaxInt32,
|
||||
@@ -560,7 +565,7 @@ func convertInt(t reflect.Type, value reflect.Value) (reflect.Value, error) {
|
||||
}
|
||||
case reflect.Int64:
|
||||
if x > math.MaxInt64 || x < math.MinInt64 {
|
||||
return value, IntegerOverflowErr{
|
||||
return value, IntegerOverflowError{
|
||||
value: x,
|
||||
min: math.MinInt64,
|
||||
max: math.MaxInt64,
|
||||
@@ -575,13 +580,13 @@ func convertInt(t reflect.Type, value reflect.Value) (reflect.Value, error) {
|
||||
}
|
||||
}
|
||||
|
||||
type UnsignedIntegerOverflowErr struct {
|
||||
type UnsignedIntegerOverflowError struct {
|
||||
value uint64
|
||||
max uint64
|
||||
kind reflect.Kind
|
||||
}
|
||||
|
||||
func (e UnsignedIntegerOverflowErr) Error() string {
|
||||
func (e UnsignedIntegerOverflowError) Error() string {
|
||||
return fmt.Sprintf("unsigned integer overflow: cannot store %d in %s [max %d]", e.value, e.kind, e.max)
|
||||
}
|
||||
|
||||
@@ -617,7 +622,7 @@ func convertUintOverflowCheck(t reflect.Kind, x uint64) error {
|
||||
switch t {
|
||||
case reflect.Uint:
|
||||
if x > maxUint {
|
||||
return UnsignedIntegerOverflowErr{
|
||||
return UnsignedIntegerOverflowError{
|
||||
value: x,
|
||||
max: maxUint,
|
||||
kind: t,
|
||||
@@ -625,7 +630,7 @@ func convertUintOverflowCheck(t reflect.Kind, x uint64) error {
|
||||
}
|
||||
case reflect.Uint8:
|
||||
if x > math.MaxUint8 {
|
||||
return UnsignedIntegerOverflowErr{
|
||||
return UnsignedIntegerOverflowError{
|
||||
value: x,
|
||||
max: math.MaxUint8,
|
||||
kind: t,
|
||||
@@ -633,7 +638,7 @@ func convertUintOverflowCheck(t reflect.Kind, x uint64) error {
|
||||
}
|
||||
case reflect.Uint16:
|
||||
if x > math.MaxUint16 {
|
||||
return UnsignedIntegerOverflowErr{
|
||||
return UnsignedIntegerOverflowError{
|
||||
value: x,
|
||||
max: math.MaxUint16,
|
||||
kind: t,
|
||||
@@ -641,7 +646,7 @@ func convertUintOverflowCheck(t reflect.Kind, x uint64) error {
|
||||
}
|
||||
case reflect.Uint32:
|
||||
if x > math.MaxUint32 {
|
||||
return UnsignedIntegerOverflowErr{
|
||||
return UnsignedIntegerOverflowError{
|
||||
value: x,
|
||||
max: math.MaxUint32,
|
||||
kind: t,
|
||||
@@ -649,7 +654,7 @@ func convertUintOverflowCheck(t reflect.Kind, x uint64) error {
|
||||
}
|
||||
case reflect.Uint64:
|
||||
if x > math.MaxUint64 {
|
||||
return UnsignedIntegerOverflowErr{
|
||||
return UnsignedIntegerOverflowError{
|
||||
value: x,
|
||||
max: math.MaxUint64,
|
||||
kind: t,
|
||||
@@ -665,7 +670,7 @@ func convertFloat(t reflect.Type, value reflect.Value) (reflect.Value, error) {
|
||||
if t.Kind() == reflect.Float32 {
|
||||
f := value.Float()
|
||||
if f > math.MaxFloat32 {
|
||||
return value, fmt.Errorf("float overflow: %f does not fit in %s [max %f]")
|
||||
return value, fmt.Errorf("float overflow: %f does not fit in %s [max %f]", f, t, math.MaxFloat32)
|
||||
}
|
||||
}
|
||||
return value.Convert(t), nil
|
||||
@@ -684,7 +689,7 @@ func (b *Builder) SetString(s string) error {
|
||||
v.Set(reflect.ValueOf(&s))
|
||||
return nil
|
||||
}
|
||||
return t.set(reflect.ValueOf(s))
|
||||
return t.set(reflect.ValueOf(&s))
|
||||
}
|
||||
|
||||
// Set the value at the cursor to the given boolean.
|
||||
@@ -762,6 +767,8 @@ func (b *Builder) EnsureStructOrMap() error {
|
||||
x.Elem().Set(reflect.MakeMap(v.Type()))
|
||||
return t.set(x)
|
||||
}
|
||||
case reflect.Interface:
|
||||
// TODO: ?
|
||||
default:
|
||||
return IncorrectKindError{
|
||||
Reason: "EnsureStructOrMap",
|
||||
@@ -772,19 +779,6 @@ func (b *Builder) EnsureStructOrMap() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func checkKindInt(rt reflect.Type) error {
|
||||
switch rt.Kind() {
|
||||
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
|
||||
return nil
|
||||
}
|
||||
|
||||
return IncorrectKindError{
|
||||
Reason: "CheckKindInt",
|
||||
Actual: rt.Kind(),
|
||||
Expected: []reflect.Kind{reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64},
|
||||
}
|
||||
}
|
||||
|
||||
func checkKindFloat(rt reflect.Type) error {
|
||||
switch rt.Kind() {
|
||||
case reflect.Float32, reflect.Float64:
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,50 @@
|
||||
package unmarshaler
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/pelletier/go-toml/v2/internal/ast"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestParser_Simple(t *testing.T) {
|
||||
examples := []struct {
|
||||
desc string
|
||||
input string
|
||||
ast ast.Root
|
||||
err bool
|
||||
}{
|
||||
{
|
||||
desc: "simple string assignment",
|
||||
input: `A = "hello"`,
|
||||
ast: ast.Root{
|
||||
ast.Node{
|
||||
Kind: ast.KeyValue,
|
||||
Children: []ast.Node{
|
||||
{
|
||||
Kind: ast.Key,
|
||||
Data: []byte(`A`),
|
||||
},
|
||||
{
|
||||
Kind: ast.String,
|
||||
Data: []byte(`hello`),
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, e := range examples {
|
||||
t.Run(e.desc, func(t *testing.T) {
|
||||
p := parser{}
|
||||
err := p.parse([]byte(e.input))
|
||||
if e.err {
|
||||
require.Error(t, err)
|
||||
} else {
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, e.ast, p.tree)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,168 @@
|
||||
package unmarshaler
|
||||
|
||||
import "fmt"
|
||||
|
||||
func scanFollows(pattern []byte) func(b []byte) bool {
|
||||
return func(b []byte) bool {
|
||||
if len(b) < len(pattern) {
|
||||
return false
|
||||
}
|
||||
for i, c := range pattern {
|
||||
if b[i] != c {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
var scanFollowsMultilineBasicStringDelimiter = scanFollows([]byte{'"', '"', '"'})
|
||||
var scanFollowsMultilineLiteralStringDelimiter = scanFollows([]byte{'\'', '\'', '\''})
|
||||
var scanFollowsTrue = scanFollows([]byte{'t', 'r', 'u', 'e'})
|
||||
var scanFollowsFalse = scanFollows([]byte{'f', 'a', 'l', 's', 'e'})
|
||||
var scanFollowsInf = scanFollows([]byte{'i', 'n', 'f'})
|
||||
var scanFollowsNan = scanFollows([]byte{'n', 'a', 'n'})
|
||||
|
||||
func scanUnquotedKey(b []byte) ([]byte, []byte, error) {
|
||||
//unquoted-key = 1*( ALPHA / DIGIT / %x2D / %x5F ) ; A-Z / a-z / 0-9 / - / _
|
||||
for i := 0; i < len(b); i++ {
|
||||
if !isUnquotedKeyChar(b[i]) {
|
||||
return b[:i], b[i:], nil
|
||||
}
|
||||
}
|
||||
return b, nil, nil
|
||||
}
|
||||
|
||||
func isUnquotedKeyChar(r byte) bool {
|
||||
return (r >= 'A' && r <= 'Z') || (r >= 'a' && r <= 'z') || (r >= '0' && r <= '9') || r == '-' || r == '_'
|
||||
}
|
||||
|
||||
func scanLiteralString(b []byte) ([]byte, []byte, error) {
|
||||
//literal-string = apostrophe *literal-char apostrophe
|
||||
//apostrophe = %x27 ; ' apostrophe
|
||||
//literal-char = %x09 / %x20-26 / %x28-7E / non-ascii
|
||||
for i := 1; i < len(b); i++ {
|
||||
switch b[i] {
|
||||
case '\'':
|
||||
return b[:i+1], b[i+1:], nil
|
||||
case '\n':
|
||||
return nil, nil, fmt.Errorf("literal strings cannot have new lines")
|
||||
}
|
||||
}
|
||||
return nil, nil, fmt.Errorf("unterminated literal string")
|
||||
}
|
||||
|
||||
func scanMultilineLiteralString(b []byte) ([]byte, []byte, error) {
|
||||
//ml-literal-string = ml-literal-string-delim [ newline ] ml-literal-body
|
||||
//ml-literal-string-delim
|
||||
//ml-literal-string-delim = 3apostrophe
|
||||
//ml-literal-body = *mll-content *( mll-quotes 1*mll-content ) [ mll-quotes ]
|
||||
//
|
||||
//mll-content = mll-char / newline
|
||||
//mll-char = %x09 / %x20-26 / %x28-7E / non-ascii
|
||||
//mll-quotes = 1*2apostrophe
|
||||
for i := 3; i < len(b); i++ {
|
||||
switch b[i] {
|
||||
case '\'':
|
||||
if scanFollowsMultilineLiteralStringDelimiter(b[i:]) {
|
||||
return b[:i+3], b[:i+3], nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil, nil, fmt.Errorf(`multiline literal string not terminated by '''`)
|
||||
}
|
||||
|
||||
func scanWindowsNewline(b []byte) ([]byte, []byte, error) {
|
||||
if len(b) < 2 {
|
||||
return nil, nil, fmt.Errorf(`windows new line missing \n`)
|
||||
}
|
||||
if b[1] != '\n' {
|
||||
return nil, nil, fmt.Errorf(`windows new line should be \r\n`)
|
||||
}
|
||||
return b[:2], b[2:], nil
|
||||
}
|
||||
|
||||
func scanWhitespace(b []byte) ([]byte, []byte) {
|
||||
for i := 0; i < len(b); i++ {
|
||||
switch b[i] {
|
||||
case ' ', '\t':
|
||||
continue
|
||||
default:
|
||||
return b[:i], b[i:]
|
||||
}
|
||||
}
|
||||
return b, nil
|
||||
}
|
||||
|
||||
func scanComment(b []byte) ([]byte, []byte, error) {
|
||||
//;; Comment
|
||||
//
|
||||
//comment-start-symbol = %x23 ; #
|
||||
//non-ascii = %x80-D7FF / %xE000-10FFFF
|
||||
//non-eol = %x09 / %x20-7F / non-ascii
|
||||
//
|
||||
//comment = comment-start-symbol *non-eol
|
||||
|
||||
for i := 1; i < len(b); i++ {
|
||||
switch b[i] {
|
||||
case '\n':
|
||||
return b[:i], b[i:], nil
|
||||
}
|
||||
}
|
||||
return b, nil, nil
|
||||
}
|
||||
|
||||
// TODO perform validation on the string?
|
||||
func scanBasicString(b []byte) ([]byte, []byte, error) {
|
||||
//basic-string = quotation-mark *basic-char quotation-mark
|
||||
//quotation-mark = %x22 ; "
|
||||
//basic-char = basic-unescaped / escaped
|
||||
//basic-unescaped = wschar / %x21 / %x23-5B / %x5D-7E / non-ascii
|
||||
//escaped = escape escape-seq-char
|
||||
for i := 1; i < len(b); i++ {
|
||||
switch b[i] {
|
||||
case '"':
|
||||
return b[:i+1], b[i+1:], nil
|
||||
case '\n':
|
||||
return nil, nil, fmt.Errorf("basic strings cannot have new lines")
|
||||
case '\\':
|
||||
if len(b) < i+2 {
|
||||
return nil, nil, fmt.Errorf("need a character after \\")
|
||||
}
|
||||
i++ // skip the next character
|
||||
}
|
||||
}
|
||||
|
||||
return nil, nil, fmt.Errorf(`basic string not terminated by "`)
|
||||
}
|
||||
|
||||
// TODO perform validation on the string?
|
||||
func scanMultilineBasicString(b []byte) ([]byte, []byte, error) {
|
||||
//ml-basic-string = ml-basic-string-delim [ newline ] ml-basic-body
|
||||
//ml-basic-string-delim
|
||||
//ml-basic-string-delim = 3quotation-mark
|
||||
//ml-basic-body = *mlb-content *( mlb-quotes 1*mlb-content ) [ mlb-quotes ]
|
||||
//
|
||||
//mlb-content = mlb-char / newline / mlb-escaped-nl
|
||||
//mlb-char = mlb-unescaped / escaped
|
||||
//mlb-quotes = 1*2quotation-mark
|
||||
//mlb-unescaped = wschar / %x21 / %x23-5B / %x5D-7E / non-ascii
|
||||
//mlb-escaped-nl = escape ws newline *( wschar / newline )
|
||||
|
||||
for i := 3; i < len(b); i++ {
|
||||
switch b[i] {
|
||||
case '"':
|
||||
if scanFollowsMultilineBasicStringDelimiter(b[i:]) {
|
||||
return b[:i+3], b[i+3:], nil
|
||||
}
|
||||
case '\\':
|
||||
if len(b) < i+2 {
|
||||
return nil, nil, fmt.Errorf("need a character after \\")
|
||||
}
|
||||
i++ // skip the next character
|
||||
}
|
||||
}
|
||||
|
||||
return nil, nil, fmt.Errorf(`multiline basic string not terminated by """`)
|
||||
}
|
||||
@@ -0,0 +1,94 @@
|
||||
package unmarshaler
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"reflect"
|
||||
)
|
||||
|
||||
type target interface {
|
||||
// Ensure the target's reflect value is not nil.
|
||||
ensure()
|
||||
|
||||
// Store a string at the target.
|
||||
setString(v string) error
|
||||
|
||||
// Appends an arbitrary value to the container.
|
||||
pushValue(v reflect.Value) error
|
||||
|
||||
// Dereferences the target.
|
||||
get() reflect.Value
|
||||
}
|
||||
|
||||
// struct target just contain the reflect.Value of the target field.
|
||||
type structTarget reflect.Value
|
||||
|
||||
func (t structTarget) get() reflect.Value {
|
||||
return reflect.Value(t)
|
||||
}
|
||||
|
||||
func (t structTarget) ensure() {
|
||||
f := t.get()
|
||||
if !f.IsNil() {
|
||||
return
|
||||
}
|
||||
|
||||
switch f.Kind() {
|
||||
case reflect.Slice:
|
||||
f.Set(reflect.MakeSlice(f.Type(), 0, 0))
|
||||
default:
|
||||
panic(fmt.Errorf("don't know how to ensure %s", f.Kind()))
|
||||
}
|
||||
}
|
||||
|
||||
func (t structTarget) setString(v string) error {
|
||||
f := t.get()
|
||||
if f.Kind() != reflect.String {
|
||||
return fmt.Errorf("cannot assign string to a %s", f.String())
|
||||
}
|
||||
f.SetString(v)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (t structTarget) pushValue(v reflect.Value) error {
|
||||
f := t.get()
|
||||
|
||||
switch f.Kind() {
|
||||
case reflect.Slice:
|
||||
t.ensure()
|
||||
f.Set(reflect.Append(f, v))
|
||||
default:
|
||||
return fmt.Errorf("cannot push %s on a %s", v.Kind(), f.Kind())
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func scope(v reflect.Value, name string) (target, error) {
|
||||
switch v.Kind() {
|
||||
case reflect.Struct:
|
||||
return scopeStruct(v, name)
|
||||
default:
|
||||
panic(fmt.Errorf("can't scope on a %s", v.Kind()))
|
||||
}
|
||||
}
|
||||
|
||||
func scopeStruct(v reflect.Value, name string) (target, error) {
|
||||
// TODO: cache this
|
||||
t := v.Type()
|
||||
for i := 0; i < t.NumField(); i++ {
|
||||
f := t.Field(i)
|
||||
if f.PkgPath != "" {
|
||||
// only consider exported fields
|
||||
continue
|
||||
}
|
||||
if f.Anonymous {
|
||||
// TODO: handle embedded structs
|
||||
} else {
|
||||
// TODO: handle names variations
|
||||
if f.Name == name {
|
||||
return structTarget(v.Field(i)), nil
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil, fmt.Errorf("field '%s' not found on %s", name, v.Type())
|
||||
}
|
||||
@@ -0,0 +1,166 @@
|
||||
package unmarshaler
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestStructTarget_Ensure(t *testing.T) {
|
||||
examples := []struct {
|
||||
desc string
|
||||
input reflect.Value
|
||||
name string
|
||||
test func(v reflect.Value)
|
||||
}{
|
||||
{
|
||||
desc: "handle a nil slice of string",
|
||||
input: reflect.ValueOf(&struct{ A []string }{}).Elem(),
|
||||
name: "A",
|
||||
test: func(v reflect.Value) {
|
||||
assert.False(t, v.IsNil())
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "handle an existing slice of string",
|
||||
input: reflect.ValueOf(&struct{ A []string }{A: []string{"foo"}}).Elem(),
|
||||
name: "A",
|
||||
test: func(v reflect.Value) {
|
||||
require.False(t, v.IsNil())
|
||||
s := v.Interface().([]string)
|
||||
assert.Equal(t, []string{"foo"}, s)
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, e := range examples {
|
||||
t.Run(e.desc, func(t *testing.T) {
|
||||
target, err := scope(e.input, e.name)
|
||||
require.NoError(t, err)
|
||||
target.ensure()
|
||||
v := target.get()
|
||||
e.test(v)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestStructTarget_SetString(t *testing.T) {
|
||||
str := "value"
|
||||
|
||||
examples := []struct {
|
||||
desc string
|
||||
input reflect.Value
|
||||
name string
|
||||
test func(v reflect.Value, err error)
|
||||
}{
|
||||
{
|
||||
desc: "sets a string",
|
||||
input: reflect.ValueOf(&struct{ A string }{}).Elem(),
|
||||
name: "A",
|
||||
test: func(v reflect.Value, err error) {
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, str, v.String())
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "fails on a float",
|
||||
input: reflect.ValueOf(&struct{ A float64 }{}).Elem(),
|
||||
name: "A",
|
||||
test: func(v reflect.Value, err error) {
|
||||
assert.Error(t, err)
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "fails on a slice",
|
||||
input: reflect.ValueOf(&struct{ A []string }{}).Elem(),
|
||||
name: "A",
|
||||
test: func(v reflect.Value, err error) {
|
||||
assert.Error(t, err)
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, e := range examples {
|
||||
t.Run(e.desc, func(t *testing.T) {
|
||||
target, err := scope(e.input, e.name)
|
||||
require.NoError(t, err)
|
||||
err = target.setString(str)
|
||||
v := target.get()
|
||||
e.test(v, err)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestPushValue_Struct(t *testing.T) {
|
||||
examples := []struct {
|
||||
desc string
|
||||
input reflect.Value
|
||||
expected []string
|
||||
error bool
|
||||
}{
|
||||
{
|
||||
desc: "push to nil slice",
|
||||
input: reflect.ValueOf(&struct{ A []string }{}).Elem(),
|
||||
expected: []string{"hello"},
|
||||
},
|
||||
{
|
||||
desc: "push to string",
|
||||
input: reflect.ValueOf(&struct{ A string }{}).Elem(),
|
||||
error: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, e := range examples {
|
||||
t.Run(e.desc, func(t *testing.T) {
|
||||
target, err := scope(e.input, "A")
|
||||
require.NoError(t, err)
|
||||
v := reflect.ValueOf("hello")
|
||||
err = target.pushValue(v)
|
||||
if e.error {
|
||||
require.Error(t, err)
|
||||
} else {
|
||||
require.NoError(t, err)
|
||||
x := target.get().Interface().([]string)
|
||||
assert.Equal(t, e.expected, x)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestScope_Struct(t *testing.T) {
|
||||
examples := []struct {
|
||||
desc string
|
||||
input reflect.Value
|
||||
name string
|
||||
err bool
|
||||
idx []int
|
||||
}{
|
||||
{
|
||||
desc: "simple field",
|
||||
input: reflect.ValueOf(&struct{ A string }{}).Elem(),
|
||||
name: "A",
|
||||
idx: []int{0},
|
||||
},
|
||||
{
|
||||
desc: "fails not-exported field",
|
||||
input: reflect.ValueOf(&struct{ a string }{}).Elem(),
|
||||
name: "a",
|
||||
err: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, e := range examples {
|
||||
t.Run(e.desc, func(t *testing.T) {
|
||||
x, err := scope(e.input, e.name)
|
||||
if e.err {
|
||||
require.Error(t, err)
|
||||
} else {
|
||||
x2, ok := x.(structTarget)
|
||||
require.True(t, ok)
|
||||
x2.get()
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,69 @@
|
||||
package unmarshaler
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"reflect"
|
||||
|
||||
"github.com/pelletier/go-toml/v2/internal/ast"
|
||||
)
|
||||
|
||||
func FromAst(tree ast.Root, target interface{}) error {
|
||||
x := reflect.ValueOf(target)
|
||||
if x.Kind() != reflect.Ptr {
|
||||
return fmt.Errorf("need to target a pointer, not %s", x.Kind())
|
||||
}
|
||||
if x.IsNil() {
|
||||
return fmt.Errorf("target pointer must be non-nil")
|
||||
}
|
||||
|
||||
for _, node := range tree {
|
||||
err := topLevelNode(x, &node)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func topLevelNode(x reflect.Value, node *ast.Node) error {
|
||||
if x.Kind() != reflect.Ptr {
|
||||
panic("topLevelNode should receive target, which should be a pointer")
|
||||
}
|
||||
if x.IsNil() {
|
||||
panic("topLevelNode should receive target, which should not be a nil pointer")
|
||||
}
|
||||
|
||||
switch node.Kind {
|
||||
case ast.Table:
|
||||
panic("TODO")
|
||||
case ast.ArrayTable:
|
||||
panic("TODO")
|
||||
case ast.KeyValue:
|
||||
return keyValue(x, node)
|
||||
default:
|
||||
panic(fmt.Errorf("this should not be a top level node type: %s", node.Kind))
|
||||
}
|
||||
}
|
||||
|
||||
func keyValue(x reflect.Value, node *ast.Node) error {
|
||||
assertNode(ast.KeyValue, node)
|
||||
assertPtr(x)
|
||||
|
||||
key := node.Key()
|
||||
key = key
|
||||
// TODO
|
||||
return nil
|
||||
}
|
||||
|
||||
func assertNode(expected ast.Kind, node *ast.Node) {
|
||||
if node.Kind != expected {
|
||||
panic(fmt.Errorf("expected node of kind %s, not %s", expected, node.Kind))
|
||||
}
|
||||
}
|
||||
|
||||
func assertPtr(x reflect.Value) {
|
||||
if x.Kind() != reflect.Ptr {
|
||||
panic(fmt.Errorf("should be a pointer, not a %s", x.Kind()))
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,39 @@
|
||||
package unmarshaler_test
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/pelletier/go-toml/v2/internal/ast"
|
||||
"github.com/pelletier/go-toml/v2/internal/unmarshaler"
|
||||
)
|
||||
|
||||
func TestFromAst_KV(t *testing.T) {
|
||||
t.Skipf("later")
|
||||
root := ast.Root{
|
||||
ast.Node{
|
||||
Kind: ast.KeyValue,
|
||||
Children: []ast.Node{
|
||||
{
|
||||
Kind: ast.Key,
|
||||
Data: []byte(`Foo`),
|
||||
},
|
||||
{
|
||||
Kind: ast.String,
|
||||
Data: []byte(`hello`),
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
type Doc struct {
|
||||
Foo string
|
||||
}
|
||||
|
||||
x := Doc{}
|
||||
err := unmarshaler.FromAst(root, &x)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, Doc{Foo: "hello"}, x)
|
||||
}
|
||||
+25
-7
@@ -1,6 +1,7 @@
|
||||
package toml
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"reflect"
|
||||
"time"
|
||||
|
||||
@@ -62,6 +63,7 @@ func (u *unmarshaler) Assignation() {
|
||||
return
|
||||
}
|
||||
u.assign = true
|
||||
fmt.Println("ASSIGN: TRUE!")
|
||||
}
|
||||
|
||||
func (u *unmarshaler) ArrayBegin() {
|
||||
@@ -73,11 +75,12 @@ func (u *unmarshaler) ArrayBegin() {
|
||||
if u.err != nil {
|
||||
return
|
||||
}
|
||||
if u.assign {
|
||||
u.assign = false
|
||||
} else {
|
||||
u.err = u.builder.SliceNewElem()
|
||||
fmt.Println("ARRAY BEGIN ASSIGN =", u.assign)
|
||||
if !u.assign {
|
||||
//u.err = u.builder.SliceNewSlice()
|
||||
// TODO
|
||||
}
|
||||
u.assign = false
|
||||
}
|
||||
|
||||
func (u *unmarshaler) ArrayEnd() {
|
||||
@@ -126,8 +129,15 @@ func (u *unmarshaler) InlineTableBegin() {
|
||||
return
|
||||
}
|
||||
|
||||
// TODO
|
||||
u.builder.Save()
|
||||
|
||||
if u.builder.IsSliceOrPtr() {
|
||||
u.err = u.builder.SliceNewElem()
|
||||
} else {
|
||||
u.err = u.builder.EnsureStructOrMap()
|
||||
}
|
||||
|
||||
u.assign = false
|
||||
}
|
||||
|
||||
func (u *unmarshaler) InlineTableEnd() {
|
||||
@@ -135,7 +145,7 @@ func (u *unmarshaler) InlineTableEnd() {
|
||||
return
|
||||
}
|
||||
|
||||
// TODO
|
||||
u.builder.Load()
|
||||
}
|
||||
|
||||
func (u *unmarshaler) KeyValBegin() {
|
||||
@@ -176,6 +186,7 @@ func (u *unmarshaler) StringValue(v []byte) {
|
||||
s := string(v)
|
||||
u.err = u.builder.Set(reflect.ValueOf(&s))
|
||||
}
|
||||
u.assign = false
|
||||
}
|
||||
|
||||
func (u *unmarshaler) BoolValue(b bool) {
|
||||
@@ -192,6 +203,7 @@ func (u *unmarshaler) BoolValue(b bool) {
|
||||
} else {
|
||||
u.err = u.builder.SetBool(b)
|
||||
}
|
||||
u.assign = false
|
||||
}
|
||||
|
||||
func (u *unmarshaler) FloatValue(n float64) {
|
||||
@@ -209,6 +221,7 @@ func (u *unmarshaler) FloatValue(n float64) {
|
||||
u.err = u.builder.Set(reflect.ValueOf(&n))
|
||||
//u.err = u.builder.SetFloat(n)
|
||||
}
|
||||
u.assign = false
|
||||
}
|
||||
|
||||
func (u *unmarshaler) IntValue(n int64) {
|
||||
@@ -225,6 +238,7 @@ func (u *unmarshaler) IntValue(n int64) {
|
||||
} else {
|
||||
u.err = u.builder.Set(reflect.ValueOf(&n))
|
||||
}
|
||||
u.assign = false
|
||||
}
|
||||
|
||||
func (u *unmarshaler) LocalDateValue(date LocalDate) {
|
||||
@@ -241,6 +255,7 @@ func (u *unmarshaler) LocalDateValue(date LocalDate) {
|
||||
} else {
|
||||
u.err = u.builder.Set(reflect.ValueOf(&date))
|
||||
}
|
||||
u.assign = false
|
||||
}
|
||||
|
||||
func (u *unmarshaler) LocalDateTimeValue(dt LocalDateTime) {
|
||||
@@ -257,6 +272,7 @@ func (u *unmarshaler) LocalDateTimeValue(dt LocalDateTime) {
|
||||
} else {
|
||||
u.err = u.builder.Set(reflect.ValueOf(&dt))
|
||||
}
|
||||
u.assign = false
|
||||
}
|
||||
|
||||
func (u *unmarshaler) DateTimeValue(dt time.Time) {
|
||||
@@ -273,6 +289,7 @@ func (u *unmarshaler) DateTimeValue(dt time.Time) {
|
||||
} else {
|
||||
u.err = u.builder.Set(reflect.ValueOf(&dt))
|
||||
}
|
||||
u.assign = false
|
||||
}
|
||||
|
||||
func (u *unmarshaler) LocalTimeValue(localTime LocalTime) {
|
||||
@@ -289,6 +306,7 @@ func (u *unmarshaler) LocalTimeValue(localTime LocalTime) {
|
||||
} else {
|
||||
u.err = u.builder.Set(reflect.ValueOf(&localTime))
|
||||
}
|
||||
u.assign = false
|
||||
}
|
||||
|
||||
func (u *unmarshaler) SimpleKey(v []byte) {
|
||||
@@ -337,5 +355,5 @@ func (u *unmarshaler) StandardTableEnd() {
|
||||
return
|
||||
}
|
||||
|
||||
u.builder.EnsureStructOrMap()
|
||||
u.builder.EnsureStructOrMap() // TODO: handle error
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user