Playing with an AST

Idea would be to build a light AST as a first pass, then have the
unmarshaler and Document parser do what they need with it.
This commit is contained in:
Thomas Pelletier
2021-03-13 11:38:09 -05:00
parent 93a74fca35
commit 21d3e85fcc
11 changed files with 2009 additions and 59 deletions
+162
View File
@@ -0,0 +1,162 @@
package ast
import (
"fmt"
"strings"
)
type Kind int
const (
// meta
Comment Kind = iota
Key
// top level structures
Table
ArrayTable
KeyValue
// containers values
Array
InlineTable
// values
String
Bool
Float
Integer
LocalDate
LocalDateTime
DateTime
Time
)
func (k Kind) String() string {
switch k {
case Comment:
return "Comment"
case Key:
return "Key"
case Table:
return "Table"
case ArrayTable:
return "ArrayTable"
case KeyValue:
return "KeyValue"
case Array:
return "Array"
case InlineTable:
return "InlineTable"
case String:
return "String"
case Bool:
return "Bool"
case Float:
return "Float"
case Integer:
return "Integer"
case LocalDate:
return "LocalDate"
case LocalDateTime:
return "LocalDateTime"
case DateTime:
return "DateTime"
case Time:
return "Time"
}
panic(fmt.Errorf("Kind.String() not implemented for '%d'", k))
}
type Root []Node
// Dot returns a dot representation of the AST for debugging.
func (r Root) Sdot() string {
type edge struct {
from int
to int
}
var nodes []string
var edges []edge // indexes into nodes
nodes = append(nodes, "root")
labelForNode := func(node *Node) string {
return fmt.Sprintf("{%s}", node.Kind)
}
var processNode func(int, *Node)
processNode = func(parentIdx int, node *Node) {
idx := len(nodes)
label := labelForNode(node)
nodes = append(nodes, label)
edges = append(edges, edge{from: parentIdx, to: idx})
for _, c := range node.Children {
processNode(idx, &c)
}
}
for _, n := range r {
processNode(0, &n)
}
var b strings.Builder
b.WriteString("digraph tree {\n")
for i, label := range nodes {
_, _ = fmt.Fprintf(&b, "\tnode%d [label=\"%s\"];\n", i, label)
}
b.WriteString("\n")
for _, e := range edges {
_, _ = fmt.Fprintf(&b, "\tnode%d -> node%d;\n", e.from, e.to)
}
b.WriteString("}")
return b.String()
}
type Node struct {
Kind Kind
Data []byte // Raw bytes from the input
// Arrays have one child per element in the array.
// InlineTables have one child per key-value pair in the table.
// KeyValues have at least two children. The last one is the value. The
// rest make a potentially dotted key.
Children []Node
}
var NoNode = Node{}
// Key returns the nodes making the Key of a KeyValue.
// They are guaranteed to be all be of the Kind Key. A simple key would return
// just one element.
// Panics if not called on a KeyValue node, or if the Children are malformed.
func (n *Node) Key() []Node {
if n.Kind != KeyValue {
panic(fmt.Errorf("Key() should only be called on on a KeyValue, not %s", n.Kind))
}
if len(n.Children) < 2 {
panic(fmt.Errorf("KeyValue should have at least two children, not %d", len(n.Children)))
}
return n.Children[:len(n.Children)-1]
}
// Value returns a pointer to the value node of a KeyValue.
// Guaranteed to be non-nil.
// Panics if not called on a KeyValue node, or if the Children are malformed.
func (n *Node) Value() *Node {
if n.Kind != KeyValue {
panic(fmt.Errorf("Key() should only be called on on a KeyValue, not %s", n.Kind))
}
if len(n.Children) < 2 {
panic(fmt.Errorf("KeyValue should have at least two children, not %d", len(n.Children)))
}
return &n.Children[len(n.Children)-1]
}
@@ -1855,16 +1855,19 @@ func TestUnmarshalMixedTypeArray(t *testing.T) {
ArrayField []interface{}
}
doc := []byte(`ArrayField = [3.14,100,true,"hello world",{Field = "inner1"},[{Field = "inner2"},{Field = "inner3"}]]
//doc := []byte(`ArrayField = [3.14,100,true,"hello world",{Field = "inner1"},[{Field = "inner2"},{Field = "inner3"}]]
//`)
doc := []byte(`ArrayField = [{Field = "inner1"},[{Field = "inner2"},{Field = "inner3"}]]
`)
actual := TestStruct{}
expected := TestStruct{
ArrayField: []interface{}{
3.14,
int64(100),
true,
"hello world",
//3.14,
//int64(100),
//true,
//"hello world",
map[string]interface{}{
"Field": "inner1",
},
@@ -1874,14 +1877,9 @@ func TestUnmarshalMixedTypeArray(t *testing.T) {
},
},
}
if err := toml.Unmarshal(doc, &actual); err == nil {
if !reflect.DeepEqual(actual, expected) {
t.Errorf("Bad unmarshal: expected %#v, got %#v", expected, actual)
}
} else {
t.Fatal(err)
}
err := toml.Unmarshal(doc, &actual)
require.NoError(t, err)
assert.Equal(t, expected, actual)
}
func TestUnmarshalArray(t *testing.T) {
+33 -39
View File
@@ -199,9 +199,7 @@ func NewBuilder(tag string, v interface{}) (Builder, error) {
}
func (b *Builder) top() target {
t := b.stack[len(b.stack)-1]
fmt.Println("TOP:", t)
return t
return b.stack[len(b.stack)-1]
}
func (b *Builder) duplicate() {
@@ -213,7 +211,6 @@ func (b *Builder) duplicate() {
func (b *Builder) pop() {
b.stack = b.stack[:len(b.stack)-1]
fmt.Println("POP: top:", b.stack[len(b.stack)-1])
}
func (b *Builder) len() int {
@@ -236,7 +233,6 @@ func (b *Builder) Dump() string {
}
func (b *Builder) replace(v target) {
fmt.Println("REPLACING:", v)
b.stack[len(b.stack)-1] = v
}
@@ -250,10 +246,6 @@ func (b *Builder) DigField(s string) error {
v := t.get()
for v.Kind() == reflect.Interface || v.Kind() == reflect.Ptr {
if v.Kind() == reflect.Interface {
fmt.Println("STOP")
}
if v.IsNil() {
if v.Kind() == reflect.Ptr {
thing := reflect.New(v.Type().Elem())
@@ -338,7 +330,20 @@ func (b *Builder) IsSlice() bool {
}
func (b *Builder) IsSliceOrPtr() bool {
return b.top().get().Kind() == reflect.Slice || (b.top().get().Kind() == reflect.Ptr && b.top().get().Type().Elem().Kind() == reflect.Slice)
t := b.top().get()
if t.Kind() == reflect.Slice {
return true
}
if t.Kind() == reflect.Ptr && t.Type().Elem().Kind() == reflect.Slice {
return true
}
if t.Kind() == reflect.Interface && !t.IsNil() && t.Elem().Type().Kind() == reflect.Slice {
return true
}
return false
}
// Last moves the cursor to the last value of the current value.
@@ -502,14 +507,14 @@ func convert(t reflect.Type, value reflect.Value) (reflect.Value, error) {
return result.Elem(), nil
}
type IntegerOverflowErr struct {
type IntegerOverflowError struct {
value int64
min int64
max int64
kind reflect.Kind
}
func (e IntegerOverflowErr) Error() string {
func (e IntegerOverflowError) Error() string {
return fmt.Sprintf("integer overflow: cannot store %d in %s [%d, %d]", e.value, e.kind, e.min, e.max)
}
@@ -524,7 +529,7 @@ func convertInt(t reflect.Type, value reflect.Value) (reflect.Value, error) {
switch t.Kind() {
case reflect.Int:
if x > maxInt || x < minInt {
return value, IntegerOverflowErr{
return value, IntegerOverflowError{
value: x,
min: minInt,
max: maxInt,
@@ -533,7 +538,7 @@ func convertInt(t reflect.Type, value reflect.Value) (reflect.Value, error) {
}
case reflect.Int8:
if x > math.MaxInt8 || x < math.MinInt8 {
return value, IntegerOverflowErr{
return value, IntegerOverflowError{
value: x,
min: math.MinInt8,
max: math.MaxInt8,
@@ -542,7 +547,7 @@ func convertInt(t reflect.Type, value reflect.Value) (reflect.Value, error) {
}
case reflect.Int16:
if x > math.MaxInt16 || x < math.MinInt16 {
return value, IntegerOverflowErr{
return value, IntegerOverflowError{
value: x,
min: math.MinInt16,
max: math.MaxInt16,
@@ -551,7 +556,7 @@ func convertInt(t reflect.Type, value reflect.Value) (reflect.Value, error) {
}
case reflect.Int32:
if x > math.MaxInt32 || x < math.MinInt32 {
return value, IntegerOverflowErr{
return value, IntegerOverflowError{
value: x,
min: math.MinInt32,
max: math.MaxInt32,
@@ -560,7 +565,7 @@ func convertInt(t reflect.Type, value reflect.Value) (reflect.Value, error) {
}
case reflect.Int64:
if x > math.MaxInt64 || x < math.MinInt64 {
return value, IntegerOverflowErr{
return value, IntegerOverflowError{
value: x,
min: math.MinInt64,
max: math.MaxInt64,
@@ -575,13 +580,13 @@ func convertInt(t reflect.Type, value reflect.Value) (reflect.Value, error) {
}
}
type UnsignedIntegerOverflowErr struct {
type UnsignedIntegerOverflowError struct {
value uint64
max uint64
kind reflect.Kind
}
func (e UnsignedIntegerOverflowErr) Error() string {
func (e UnsignedIntegerOverflowError) Error() string {
return fmt.Sprintf("unsigned integer overflow: cannot store %d in %s [max %d]", e.value, e.kind, e.max)
}
@@ -617,7 +622,7 @@ func convertUintOverflowCheck(t reflect.Kind, x uint64) error {
switch t {
case reflect.Uint:
if x > maxUint {
return UnsignedIntegerOverflowErr{
return UnsignedIntegerOverflowError{
value: x,
max: maxUint,
kind: t,
@@ -625,7 +630,7 @@ func convertUintOverflowCheck(t reflect.Kind, x uint64) error {
}
case reflect.Uint8:
if x > math.MaxUint8 {
return UnsignedIntegerOverflowErr{
return UnsignedIntegerOverflowError{
value: x,
max: math.MaxUint8,
kind: t,
@@ -633,7 +638,7 @@ func convertUintOverflowCheck(t reflect.Kind, x uint64) error {
}
case reflect.Uint16:
if x > math.MaxUint16 {
return UnsignedIntegerOverflowErr{
return UnsignedIntegerOverflowError{
value: x,
max: math.MaxUint16,
kind: t,
@@ -641,7 +646,7 @@ func convertUintOverflowCheck(t reflect.Kind, x uint64) error {
}
case reflect.Uint32:
if x > math.MaxUint32 {
return UnsignedIntegerOverflowErr{
return UnsignedIntegerOverflowError{
value: x,
max: math.MaxUint32,
kind: t,
@@ -649,7 +654,7 @@ func convertUintOverflowCheck(t reflect.Kind, x uint64) error {
}
case reflect.Uint64:
if x > math.MaxUint64 {
return UnsignedIntegerOverflowErr{
return UnsignedIntegerOverflowError{
value: x,
max: math.MaxUint64,
kind: t,
@@ -665,7 +670,7 @@ func convertFloat(t reflect.Type, value reflect.Value) (reflect.Value, error) {
if t.Kind() == reflect.Float32 {
f := value.Float()
if f > math.MaxFloat32 {
return value, fmt.Errorf("float overflow: %f does not fit in %s [max %f]")
return value, fmt.Errorf("float overflow: %f does not fit in %s [max %f]", f, t, math.MaxFloat32)
}
}
return value.Convert(t), nil
@@ -684,7 +689,7 @@ func (b *Builder) SetString(s string) error {
v.Set(reflect.ValueOf(&s))
return nil
}
return t.set(reflect.ValueOf(s))
return t.set(reflect.ValueOf(&s))
}
// Set the value at the cursor to the given boolean.
@@ -762,6 +767,8 @@ func (b *Builder) EnsureStructOrMap() error {
x.Elem().Set(reflect.MakeMap(v.Type()))
return t.set(x)
}
case reflect.Interface:
// TODO: ?
default:
return IncorrectKindError{
Reason: "EnsureStructOrMap",
@@ -772,19 +779,6 @@ func (b *Builder) EnsureStructOrMap() error {
return nil
}
func checkKindInt(rt reflect.Type) error {
switch rt.Kind() {
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
return nil
}
return IncorrectKindError{
Reason: "CheckKindInt",
Actual: rt.Kind(),
Expected: []reflect.Kind{reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64},
}
}
func checkKindFloat(rt reflect.Type) error {
switch rt.Kind() {
case reflect.Float32, reflect.Float64:
File diff suppressed because it is too large Load Diff
+50
View File
@@ -0,0 +1,50 @@
package unmarshaler
import (
"testing"
"github.com/pelletier/go-toml/v2/internal/ast"
"github.com/stretchr/testify/require"
)
func TestParser_Simple(t *testing.T) {
examples := []struct {
desc string
input string
ast ast.Root
err bool
}{
{
desc: "simple string assignment",
input: `A = "hello"`,
ast: ast.Root{
ast.Node{
Kind: ast.KeyValue,
Children: []ast.Node{
{
Kind: ast.Key,
Data: []byte(`A`),
},
{
Kind: ast.String,
Data: []byte(`hello`),
},
},
},
},
},
}
for _, e := range examples {
t.Run(e.desc, func(t *testing.T) {
p := parser{}
err := p.parse([]byte(e.input))
if e.err {
require.Error(t, err)
} else {
require.NoError(t, err)
require.Equal(t, e.ast, p.tree)
}
})
}
}
+168
View File
@@ -0,0 +1,168 @@
package unmarshaler
import "fmt"
func scanFollows(pattern []byte) func(b []byte) bool {
return func(b []byte) bool {
if len(b) < len(pattern) {
return false
}
for i, c := range pattern {
if b[i] != c {
return false
}
}
return true
}
}
var scanFollowsMultilineBasicStringDelimiter = scanFollows([]byte{'"', '"', '"'})
var scanFollowsMultilineLiteralStringDelimiter = scanFollows([]byte{'\'', '\'', '\''})
var scanFollowsTrue = scanFollows([]byte{'t', 'r', 'u', 'e'})
var scanFollowsFalse = scanFollows([]byte{'f', 'a', 'l', 's', 'e'})
var scanFollowsInf = scanFollows([]byte{'i', 'n', 'f'})
var scanFollowsNan = scanFollows([]byte{'n', 'a', 'n'})
func scanUnquotedKey(b []byte) ([]byte, []byte, error) {
//unquoted-key = 1*( ALPHA / DIGIT / %x2D / %x5F ) ; A-Z / a-z / 0-9 / - / _
for i := 0; i < len(b); i++ {
if !isUnquotedKeyChar(b[i]) {
return b[:i], b[i:], nil
}
}
return b, nil, nil
}
func isUnquotedKeyChar(r byte) bool {
return (r >= 'A' && r <= 'Z') || (r >= 'a' && r <= 'z') || (r >= '0' && r <= '9') || r == '-' || r == '_'
}
func scanLiteralString(b []byte) ([]byte, []byte, error) {
//literal-string = apostrophe *literal-char apostrophe
//apostrophe = %x27 ; ' apostrophe
//literal-char = %x09 / %x20-26 / %x28-7E / non-ascii
for i := 1; i < len(b); i++ {
switch b[i] {
case '\'':
return b[:i+1], b[i+1:], nil
case '\n':
return nil, nil, fmt.Errorf("literal strings cannot have new lines")
}
}
return nil, nil, fmt.Errorf("unterminated literal string")
}
func scanMultilineLiteralString(b []byte) ([]byte, []byte, error) {
//ml-literal-string = ml-literal-string-delim [ newline ] ml-literal-body
//ml-literal-string-delim
//ml-literal-string-delim = 3apostrophe
//ml-literal-body = *mll-content *( mll-quotes 1*mll-content ) [ mll-quotes ]
//
//mll-content = mll-char / newline
//mll-char = %x09 / %x20-26 / %x28-7E / non-ascii
//mll-quotes = 1*2apostrophe
for i := 3; i < len(b); i++ {
switch b[i] {
case '\'':
if scanFollowsMultilineLiteralStringDelimiter(b[i:]) {
return b[:i+3], b[:i+3], nil
}
}
}
return nil, nil, fmt.Errorf(`multiline literal string not terminated by '''`)
}
func scanWindowsNewline(b []byte) ([]byte, []byte, error) {
if len(b) < 2 {
return nil, nil, fmt.Errorf(`windows new line missing \n`)
}
if b[1] != '\n' {
return nil, nil, fmt.Errorf(`windows new line should be \r\n`)
}
return b[:2], b[2:], nil
}
func scanWhitespace(b []byte) ([]byte, []byte) {
for i := 0; i < len(b); i++ {
switch b[i] {
case ' ', '\t':
continue
default:
return b[:i], b[i:]
}
}
return b, nil
}
func scanComment(b []byte) ([]byte, []byte, error) {
//;; Comment
//
//comment-start-symbol = %x23 ; #
//non-ascii = %x80-D7FF / %xE000-10FFFF
//non-eol = %x09 / %x20-7F / non-ascii
//
//comment = comment-start-symbol *non-eol
for i := 1; i < len(b); i++ {
switch b[i] {
case '\n':
return b[:i], b[i:], nil
}
}
return b, nil, nil
}
// TODO perform validation on the string?
func scanBasicString(b []byte) ([]byte, []byte, error) {
//basic-string = quotation-mark *basic-char quotation-mark
//quotation-mark = %x22 ; "
//basic-char = basic-unescaped / escaped
//basic-unescaped = wschar / %x21 / %x23-5B / %x5D-7E / non-ascii
//escaped = escape escape-seq-char
for i := 1; i < len(b); i++ {
switch b[i] {
case '"':
return b[:i+1], b[i+1:], nil
case '\n':
return nil, nil, fmt.Errorf("basic strings cannot have new lines")
case '\\':
if len(b) < i+2 {
return nil, nil, fmt.Errorf("need a character after \\")
}
i++ // skip the next character
}
}
return nil, nil, fmt.Errorf(`basic string not terminated by "`)
}
// TODO perform validation on the string?
func scanMultilineBasicString(b []byte) ([]byte, []byte, error) {
//ml-basic-string = ml-basic-string-delim [ newline ] ml-basic-body
//ml-basic-string-delim
//ml-basic-string-delim = 3quotation-mark
//ml-basic-body = *mlb-content *( mlb-quotes 1*mlb-content ) [ mlb-quotes ]
//
//mlb-content = mlb-char / newline / mlb-escaped-nl
//mlb-char = mlb-unescaped / escaped
//mlb-quotes = 1*2quotation-mark
//mlb-unescaped = wschar / %x21 / %x23-5B / %x5D-7E / non-ascii
//mlb-escaped-nl = escape ws newline *( wschar / newline )
for i := 3; i < len(b); i++ {
switch b[i] {
case '"':
if scanFollowsMultilineBasicStringDelimiter(b[i:]) {
return b[:i+3], b[i+3:], nil
}
case '\\':
if len(b) < i+2 {
return nil, nil, fmt.Errorf("need a character after \\")
}
i++ // skip the next character
}
}
return nil, nil, fmt.Errorf(`multiline basic string not terminated by """`)
}
+94
View File
@@ -0,0 +1,94 @@
package unmarshaler
import (
"fmt"
"reflect"
)
type target interface {
// Ensure the target's reflect value is not nil.
ensure()
// Store a string at the target.
setString(v string) error
// Appends an arbitrary value to the container.
pushValue(v reflect.Value) error
// Dereferences the target.
get() reflect.Value
}
// struct target just contain the reflect.Value of the target field.
type structTarget reflect.Value
func (t structTarget) get() reflect.Value {
return reflect.Value(t)
}
func (t structTarget) ensure() {
f := t.get()
if !f.IsNil() {
return
}
switch f.Kind() {
case reflect.Slice:
f.Set(reflect.MakeSlice(f.Type(), 0, 0))
default:
panic(fmt.Errorf("don't know how to ensure %s", f.Kind()))
}
}
func (t structTarget) setString(v string) error {
f := t.get()
if f.Kind() != reflect.String {
return fmt.Errorf("cannot assign string to a %s", f.String())
}
f.SetString(v)
return nil
}
func (t structTarget) pushValue(v reflect.Value) error {
f := t.get()
switch f.Kind() {
case reflect.Slice:
t.ensure()
f.Set(reflect.Append(f, v))
default:
return fmt.Errorf("cannot push %s on a %s", v.Kind(), f.Kind())
}
return nil
}
func scope(v reflect.Value, name string) (target, error) {
switch v.Kind() {
case reflect.Struct:
return scopeStruct(v, name)
default:
panic(fmt.Errorf("can't scope on a %s", v.Kind()))
}
}
func scopeStruct(v reflect.Value, name string) (target, error) {
// TODO: cache this
t := v.Type()
for i := 0; i < t.NumField(); i++ {
f := t.Field(i)
if f.PkgPath != "" {
// only consider exported fields
continue
}
if f.Anonymous {
// TODO: handle embedded structs
} else {
// TODO: handle names variations
if f.Name == name {
return structTarget(v.Field(i)), nil
}
}
}
return nil, fmt.Errorf("field '%s' not found on %s", name, v.Type())
}
+166
View File
@@ -0,0 +1,166 @@
package unmarshaler
import (
"reflect"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestStructTarget_Ensure(t *testing.T) {
examples := []struct {
desc string
input reflect.Value
name string
test func(v reflect.Value)
}{
{
desc: "handle a nil slice of string",
input: reflect.ValueOf(&struct{ A []string }{}).Elem(),
name: "A",
test: func(v reflect.Value) {
assert.False(t, v.IsNil())
},
},
{
desc: "handle an existing slice of string",
input: reflect.ValueOf(&struct{ A []string }{A: []string{"foo"}}).Elem(),
name: "A",
test: func(v reflect.Value) {
require.False(t, v.IsNil())
s := v.Interface().([]string)
assert.Equal(t, []string{"foo"}, s)
},
},
}
for _, e := range examples {
t.Run(e.desc, func(t *testing.T) {
target, err := scope(e.input, e.name)
require.NoError(t, err)
target.ensure()
v := target.get()
e.test(v)
})
}
}
func TestStructTarget_SetString(t *testing.T) {
str := "value"
examples := []struct {
desc string
input reflect.Value
name string
test func(v reflect.Value, err error)
}{
{
desc: "sets a string",
input: reflect.ValueOf(&struct{ A string }{}).Elem(),
name: "A",
test: func(v reflect.Value, err error) {
assert.NoError(t, err)
assert.Equal(t, str, v.String())
},
},
{
desc: "fails on a float",
input: reflect.ValueOf(&struct{ A float64 }{}).Elem(),
name: "A",
test: func(v reflect.Value, err error) {
assert.Error(t, err)
},
},
{
desc: "fails on a slice",
input: reflect.ValueOf(&struct{ A []string }{}).Elem(),
name: "A",
test: func(v reflect.Value, err error) {
assert.Error(t, err)
},
},
}
for _, e := range examples {
t.Run(e.desc, func(t *testing.T) {
target, err := scope(e.input, e.name)
require.NoError(t, err)
err = target.setString(str)
v := target.get()
e.test(v, err)
})
}
}
func TestPushValue_Struct(t *testing.T) {
examples := []struct {
desc string
input reflect.Value
expected []string
error bool
}{
{
desc: "push to nil slice",
input: reflect.ValueOf(&struct{ A []string }{}).Elem(),
expected: []string{"hello"},
},
{
desc: "push to string",
input: reflect.ValueOf(&struct{ A string }{}).Elem(),
error: true,
},
}
for _, e := range examples {
t.Run(e.desc, func(t *testing.T) {
target, err := scope(e.input, "A")
require.NoError(t, err)
v := reflect.ValueOf("hello")
err = target.pushValue(v)
if e.error {
require.Error(t, err)
} else {
require.NoError(t, err)
x := target.get().Interface().([]string)
assert.Equal(t, e.expected, x)
}
})
}
}
func TestScope_Struct(t *testing.T) {
examples := []struct {
desc string
input reflect.Value
name string
err bool
idx []int
}{
{
desc: "simple field",
input: reflect.ValueOf(&struct{ A string }{}).Elem(),
name: "A",
idx: []int{0},
},
{
desc: "fails not-exported field",
input: reflect.ValueOf(&struct{ a string }{}).Elem(),
name: "a",
err: true,
},
}
for _, e := range examples {
t.Run(e.desc, func(t *testing.T) {
x, err := scope(e.input, e.name)
if e.err {
require.Error(t, err)
} else {
x2, ok := x.(structTarget)
require.True(t, ok)
x2.get()
}
})
}
}
+69
View File
@@ -0,0 +1,69 @@
package unmarshaler
import (
"fmt"
"reflect"
"github.com/pelletier/go-toml/v2/internal/ast"
)
func FromAst(tree ast.Root, target interface{}) error {
x := reflect.ValueOf(target)
if x.Kind() != reflect.Ptr {
return fmt.Errorf("need to target a pointer, not %s", x.Kind())
}
if x.IsNil() {
return fmt.Errorf("target pointer must be non-nil")
}
for _, node := range tree {
err := topLevelNode(x, &node)
if err != nil {
return err
}
}
return nil
}
func topLevelNode(x reflect.Value, node *ast.Node) error {
if x.Kind() != reflect.Ptr {
panic("topLevelNode should receive target, which should be a pointer")
}
if x.IsNil() {
panic("topLevelNode should receive target, which should not be a nil pointer")
}
switch node.Kind {
case ast.Table:
panic("TODO")
case ast.ArrayTable:
panic("TODO")
case ast.KeyValue:
return keyValue(x, node)
default:
panic(fmt.Errorf("this should not be a top level node type: %s", node.Kind))
}
}
func keyValue(x reflect.Value, node *ast.Node) error {
assertNode(ast.KeyValue, node)
assertPtr(x)
key := node.Key()
key = key
// TODO
return nil
}
func assertNode(expected ast.Kind, node *ast.Node) {
if node.Kind != expected {
panic(fmt.Errorf("expected node of kind %s, not %s", expected, node.Kind))
}
}
func assertPtr(x reflect.Value) {
if x.Kind() != reflect.Ptr {
panic(fmt.Errorf("should be a pointer, not a %s", x.Kind()))
}
}
+39
View File
@@ -0,0 +1,39 @@
package unmarshaler_test
import (
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/pelletier/go-toml/v2/internal/ast"
"github.com/pelletier/go-toml/v2/internal/unmarshaler"
)
func TestFromAst_KV(t *testing.T) {
t.Skipf("later")
root := ast.Root{
ast.Node{
Kind: ast.KeyValue,
Children: []ast.Node{
{
Kind: ast.Key,
Data: []byte(`Foo`),
},
{
Kind: ast.String,
Data: []byte(`hello`),
},
},
},
}
type Doc struct {
Foo string
}
x := Doc{}
err := unmarshaler.FromAst(root, &x)
require.NoError(t, err)
assert.Equal(t, Doc{Foo: "hello"}, x)
}
+25 -7
View File
@@ -1,6 +1,7 @@
package toml
import (
"fmt"
"reflect"
"time"
@@ -62,6 +63,7 @@ func (u *unmarshaler) Assignation() {
return
}
u.assign = true
fmt.Println("ASSIGN: TRUE!")
}
func (u *unmarshaler) ArrayBegin() {
@@ -73,11 +75,12 @@ func (u *unmarshaler) ArrayBegin() {
if u.err != nil {
return
}
if u.assign {
u.assign = false
} else {
u.err = u.builder.SliceNewElem()
fmt.Println("ARRAY BEGIN ASSIGN =", u.assign)
if !u.assign {
//u.err = u.builder.SliceNewSlice()
// TODO
}
u.assign = false
}
func (u *unmarshaler) ArrayEnd() {
@@ -126,8 +129,15 @@ func (u *unmarshaler) InlineTableBegin() {
return
}
// TODO
u.builder.Save()
if u.builder.IsSliceOrPtr() {
u.err = u.builder.SliceNewElem()
} else {
u.err = u.builder.EnsureStructOrMap()
}
u.assign = false
}
func (u *unmarshaler) InlineTableEnd() {
@@ -135,7 +145,7 @@ func (u *unmarshaler) InlineTableEnd() {
return
}
// TODO
u.builder.Load()
}
func (u *unmarshaler) KeyValBegin() {
@@ -176,6 +186,7 @@ func (u *unmarshaler) StringValue(v []byte) {
s := string(v)
u.err = u.builder.Set(reflect.ValueOf(&s))
}
u.assign = false
}
func (u *unmarshaler) BoolValue(b bool) {
@@ -192,6 +203,7 @@ func (u *unmarshaler) BoolValue(b bool) {
} else {
u.err = u.builder.SetBool(b)
}
u.assign = false
}
func (u *unmarshaler) FloatValue(n float64) {
@@ -209,6 +221,7 @@ func (u *unmarshaler) FloatValue(n float64) {
u.err = u.builder.Set(reflect.ValueOf(&n))
//u.err = u.builder.SetFloat(n)
}
u.assign = false
}
func (u *unmarshaler) IntValue(n int64) {
@@ -225,6 +238,7 @@ func (u *unmarshaler) IntValue(n int64) {
} else {
u.err = u.builder.Set(reflect.ValueOf(&n))
}
u.assign = false
}
func (u *unmarshaler) LocalDateValue(date LocalDate) {
@@ -241,6 +255,7 @@ func (u *unmarshaler) LocalDateValue(date LocalDate) {
} else {
u.err = u.builder.Set(reflect.ValueOf(&date))
}
u.assign = false
}
func (u *unmarshaler) LocalDateTimeValue(dt LocalDateTime) {
@@ -257,6 +272,7 @@ func (u *unmarshaler) LocalDateTimeValue(dt LocalDateTime) {
} else {
u.err = u.builder.Set(reflect.ValueOf(&dt))
}
u.assign = false
}
func (u *unmarshaler) DateTimeValue(dt time.Time) {
@@ -273,6 +289,7 @@ func (u *unmarshaler) DateTimeValue(dt time.Time) {
} else {
u.err = u.builder.Set(reflect.ValueOf(&dt))
}
u.assign = false
}
func (u *unmarshaler) LocalTimeValue(localTime LocalTime) {
@@ -289,6 +306,7 @@ func (u *unmarshaler) LocalTimeValue(localTime LocalTime) {
} else {
u.err = u.builder.Set(reflect.ValueOf(&localTime))
}
u.assign = false
}
func (u *unmarshaler) SimpleKey(v []byte) {
@@ -337,5 +355,5 @@ func (u *unmarshaler) StandardTableEnd() {
return
}
u.builder.EnsureStructOrMap()
u.builder.EnsureStructOrMap() // TODO: handle error
}