Playing with an AST

The idea is to build a lightweight AST as a first pass, then have the
unmarshaler and Document parser do what they need with it.
This commit is contained in:
Thomas Pelletier
2021-03-13 11:38:09 -05:00
parent 93a74fca35
commit 21d3e85fcc
11 changed files with 2009 additions and 59 deletions
File diff suppressed because it is too large Load Diff
+50
View File
@@ -0,0 +1,50 @@
package unmarshaler
import (
"testing"
"github.com/pelletier/go-toml/v2/internal/ast"
"github.com/stretchr/testify/require"
)
// TestParser_Simple runs the parser over small inputs and compares the tree it
// produces against a hand-written AST, or checks that parsing fails when the
// case expects an error.
func TestParser_Simple(t *testing.T) {
	type testCase struct {
		desc  string
		input string
		ast   ast.Root
		err   bool
	}

	cases := []testCase{
		{
			desc:  "simple string assignment",
			input: `A = "hello"`,
			ast: ast.Root{
				ast.Node{
					Kind: ast.KeyValue,
					Children: []ast.Node{
						{Kind: ast.Key, Data: []byte(`A`)},
						{Kind: ast.String, Data: []byte(`hello`)},
					},
				},
			},
		},
	}

	for _, tc := range cases {
		t.Run(tc.desc, func(t *testing.T) {
			var p parser
			err := p.parse([]byte(tc.input))
			if tc.err {
				require.Error(t, err)
				return
			}
			require.NoError(t, err)
			require.Equal(t, tc.ast, p.tree)
		})
	}
}
+168
View File
@@ -0,0 +1,168 @@
package unmarshaler
import "fmt"
// scanFollows builds a predicate that reports whether a byte slice starts with
// pattern. The returned function yields false when its input is shorter than
// pattern, and true for any input when pattern is empty.
func scanFollows(pattern []byte) func(b []byte) bool {
	return func(b []byte) bool {
		if len(pattern) > len(b) {
			return false
		}
		for i := range pattern {
			if pattern[i] != b[i] {
				return false
			}
		}
		return true
	}
}
// Prefix predicates for the fixed tokens the scanner looks for. Each one
// reports whether its input begins with the given byte sequence.
var (
	scanFollowsMultilineBasicStringDelimiter   = scanFollows([]byte(`"""`))
	scanFollowsMultilineLiteralStringDelimiter = scanFollows([]byte(`'''`))
	scanFollowsTrue                            = scanFollows([]byte(`true`))
	scanFollowsFalse                           = scanFollows([]byte(`false`))
	scanFollowsInf                             = scanFollows([]byte(`inf`))
	scanFollowsNan                             = scanFollows([]byte(`nan`))
)
// scanUnquotedKey splits b into its leading run of unquoted-key characters and
// whatever follows. When every byte of b belongs to the key, the remainder is
// nil. The returned error is always nil.
func scanUnquotedKey(b []byte) ([]byte, []byte, error) {
	// unquoted-key = 1*( ALPHA / DIGIT / %x2D / %x5F ) ; A-Z / a-z / 0-9 / - / _
	end := 0
	for end < len(b) && isUnquotedKeyChar(b[end]) {
		end++
	}
	if end == len(b) {
		return b, nil, nil
	}
	return b[:end], b[end:], nil
}
// isUnquotedKeyChar reports whether r is an ASCII letter, an ASCII digit, a
// dash or an underscore.
func isUnquotedKeyChar(r byte) bool {
	switch {
	case 'A' <= r && r <= 'Z':
		return true
	case 'a' <= r && r <= 'z':
		return true
	case '0' <= r && r <= '9':
		return true
	case r == '-', r == '_':
		return true
	}
	return false
}
// scanLiteralString returns the single-line literal string at the start of b,
// closing apostrophe included, together with the bytes that follow it.
//
// Scanning starts at index 1: the first byte is never inspected and is
// expected to be the opening apostrophe. An error is returned when a new line
// is found before the closing apostrophe, or when b ends first.
func scanLiteralString(b []byte) ([]byte, []byte, error) {
	// literal-string = apostrophe *literal-char apostrophe
	// apostrophe = %x27 ; ' apostrophe
	// literal-char = %x09 / %x20-26 / %x28-7E / non-ascii
	for pos := 1; pos < len(b); pos++ {
		c := b[pos]
		if c == '\'' {
			end := pos + 1
			return b[:end], b[end:], nil
		}
		if c == '\n' {
			return nil, nil, fmt.Errorf("literal strings cannot have new lines")
		}
	}
	return nil, nil, fmt.Errorf("unterminated literal string")
}
// scanMultilineLiteralString returns the multiline literal string at the start
// of b, closing ''' delimiter included, together with the bytes that follow
// it.
//
// Scanning starts at index 3: the first three bytes are never inspected and
// are expected to be the opening delimiter. The string ends at the first '''
// found after that. An error is returned when no closing delimiter is found.
func scanMultilineLiteralString(b []byte) ([]byte, []byte, error) {
	// ml-literal-string = ml-literal-string-delim [ newline ] ml-literal-body
	// ml-literal-string-delim
	// ml-literal-string-delim = 3apostrophe
	// ml-literal-body = *mll-content *( mll-quotes 1*mll-content ) [ mll-quotes ]
	//
	// mll-content = mll-char / newline
	// mll-char = %x09 / %x20-26 / %x28-7E / non-ascii
	// mll-quotes = 1*2apostrophe
	for i := 3; i < len(b); i++ {
		if b[i] == '\'' && scanFollowsMultilineLiteralStringDelimiter(b[i:]) {
			// The remainder starts right after the closing delimiter; it
			// must not repeat the token that was just scanned.
			return b[:i+3], b[i+3:], nil
		}
	}

	return nil, nil, fmt.Errorf(`multiline literal string not terminated by '''`)
}
// scanWindowsNewline returns the first two bytes of b and the rest of b, after
// checking that the second byte is a line feed. The first byte is not
// inspected. It fails when b holds fewer than two bytes or when the second
// byte is not '\n'.
func scanWindowsNewline(b []byte) ([]byte, []byte, error) {
	switch {
	case len(b) < 2:
		return nil, nil, fmt.Errorf(`windows new line missing \n`)
	case b[1] != '\n':
		return nil, nil, fmt.Errorf(`windows new line should be \r\n`)
	}
	return b[:2], b[2:], nil
}
// scanWhitespace splits b into its leading run of spaces and tabs and whatever
// follows. When b contains nothing but spaces and tabs (or is empty), the
// remainder is nil.
func scanWhitespace(b []byte) ([]byte, []byte) {
	for i, c := range b {
		if c != ' ' && c != '\t' {
			return b[:i], b[i:]
		}
	}
	return b, nil
}
// scanComment returns the comment at the start of b, up to but excluding the
// first line feed found after index 0, together with the bytes from that line
// feed onward. The first byte is never inspected; it is expected to be the
// comment start symbol. When no line feed is found, all of b is the comment
// and the remainder is nil. The returned error is always nil.
func scanComment(b []byte) ([]byte, []byte, error) {
	// ;; Comment
	//
	// comment-start-symbol = %x23 ; #
	// non-ascii = %x80-D7FF / %xE000-10FFFF
	// non-eol = %x09 / %x20-7F / non-ascii
	//
	// comment = comment-start-symbol *non-eol
	pos := 1
	for pos < len(b) && b[pos] != '\n' {
		pos++
	}
	if pos >= len(b) {
		return b, nil, nil
	}
	return b[:pos], b[pos:], nil
}
// scanBasicString returns the single-line basic string at the start of b,
// closing quotation mark included, together with the bytes that follow it.
//
// Scanning starts at index 1: the first byte is never inspected and is
// expected to be the opening quotation mark. The byte after a backslash is
// skipped without being looked at, so an escaped quotation mark does not end
// the string. An error is returned on a new line, on a backslash that is the
// last byte of b, or when b ends before the closing quotation mark.
//
// TODO perform validation on the string?
func scanBasicString(b []byte) ([]byte, []byte, error) {
	// basic-string = quotation-mark *basic-char quotation-mark
	// quotation-mark = %x22 ; "
	// basic-char = basic-unescaped / escaped
	// basic-unescaped = wschar / %x21 / %x23-5B / %x5D-7E / non-ascii
	// escaped = escape escape-seq-char
	pos := 1
	for pos < len(b) {
		c := b[pos]
		if c == '"' {
			end := pos + 1
			return b[:end], b[end:], nil
		}
		if c == '\n' {
			return nil, nil, fmt.Errorf("basic strings cannot have new lines")
		}
		if c == '\\' {
			if pos+2 > len(b) {
				return nil, nil, fmt.Errorf("need a character after \\")
			}
			pos++ // step over the escaped character
		}
		pos++
	}

	return nil, nil, fmt.Errorf(`basic string not terminated by "`)
}
// scanMultilineBasicString returns the multiline basic string at the start of
// b, closing """ delimiter included, together with the bytes that follow it.
//
// Scanning starts at index 3: the first three bytes are never inspected and
// are expected to be the opening delimiter. The byte after a backslash is
// skipped without being looked at. An error is returned on a backslash that is
// the last byte of b, or when no closing delimiter is found.
//
// TODO perform validation on the string?
func scanMultilineBasicString(b []byte) ([]byte, []byte, error) {
	// ml-basic-string = ml-basic-string-delim [ newline ] ml-basic-body
	// ml-basic-string-delim
	// ml-basic-string-delim = 3quotation-mark
	// ml-basic-body = *mlb-content *( mlb-quotes 1*mlb-content ) [ mlb-quotes ]
	//
	// mlb-content = mlb-char / newline / mlb-escaped-nl
	// mlb-char = mlb-unescaped / escaped
	// mlb-quotes = 1*2quotation-mark
	// mlb-unescaped = wschar / %x21 / %x23-5B / %x5D-7E / non-ascii
	// mlb-escaped-nl = escape ws newline *( wschar / newline )
	for pos := 3; pos < len(b); pos++ {
		c := b[pos]
		if c == '\\' {
			if pos+2 > len(b) {
				return nil, nil, fmt.Errorf("need a character after \\")
			}
			pos++ // step over the escaped character
			continue
		}
		if c == '"' && scanFollowsMultilineBasicStringDelimiter(b[pos:]) {
			end := pos + 3
			return b[:end], b[end:], nil
		}
	}

	return nil, nil, fmt.Errorf(`multiline basic string not terminated by """`)
}
+94
View File
@@ -0,0 +1,94 @@
package unmarshaler
import (
"fmt"
"reflect"
)
// target is a destination that decoded values can be written to.
type target interface {
	// get returns the reflect.Value behind the target.
	get() reflect.Value

	// ensure makes sure the reflect.Value behind the target is not nil.
	ensure()

	// setString stores the string v at the target.
	setString(v string) error

	// pushValue appends the arbitrary value v to the container behind the
	// target.
	pushValue(v reflect.Value) error
}
// structTarget is a target that simply wraps the reflect.Value of the struct
// field it writes to.
type structTarget reflect.Value
// get converts the target back into the reflect.Value it wraps.
func (t structTarget) get() reflect.Value {
	v := reflect.Value(t)
	return v
}
// ensure replaces a nil slice behind the target with an empty, non-nil slice
// of the same type. A value that is already non-nil is left untouched. It
// panics when the nil value is of any kind other than slice.
func (t structTarget) ensure() {
	f := t.get()
	if !f.IsNil() {
		return
	}

	if kind := f.Kind(); kind != reflect.Slice {
		panic(fmt.Errorf("don't know how to ensure %s", kind))
	}
	f.Set(reflect.MakeSlice(f.Type(), 0, 0))
}
// setString stores v in the value behind the target. It returns an error,
// leaving the value untouched, when that value is not of kind string.
func (t structTarget) setString(v string) error {
	f := t.get()
	if f.Kind() != reflect.String {
		// Report the kind, as pushValue does. Value.String on a non-string
		// value yields "<T Value>" rather than a readable type name.
		return fmt.Errorf("cannot assign string to a %s", f.Kind())
	}
	f.SetString(v)
	return nil
}
// pushValue appends v to the slice behind the target, first making sure the
// slice is not nil. It returns an error when the value behind the target is
// not a slice.
func (t structTarget) pushValue(v reflect.Value) error {
	f := t.get()
	if f.Kind() != reflect.Slice {
		return fmt.Errorf("cannot push %s on a %s", v.Kind(), f.Kind())
	}

	t.ensure()
	f.Set(reflect.Append(f, v))
	return nil
}
// scope returns the target called name inside v. Only struct values are
// supported; any other kind makes it panic.
func scope(v reflect.Value, name string) (target, error) {
	if v.Kind() == reflect.Struct {
		return scopeStruct(v, name)
	}
	panic(fmt.Errorf("can't scope on a %s", v.Kind()))
}
// scopeStruct looks for an exported, non-embedded field of the struct v whose
// Go name is exactly name, and returns it wrapped in a structTarget. It
// returns an error when no such field exists.
func scopeStruct(v reflect.Value, name string) (target, error) {
	// TODO: cache this
	typ := v.Type()
	for i, count := 0, typ.NumField(); i < count; i++ {
		field := typ.Field(i)
		switch {
		case field.PkgPath != "":
			// only consider exported fields
		case field.Anonymous:
			// TODO: handle embedded structs
		case field.Name == name:
			// TODO: handle names variations
			return structTarget(v.Field(i)), nil
		}
	}
	return nil, fmt.Errorf("field '%s' not found on %s", name, v.Type())
}
+166
View File
@@ -0,0 +1,166 @@
package unmarshaler
import (
"reflect"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// TestStructTarget_Ensure checks that ensure leaves a non-nil slice behind the
// target, both when the field starts out nil and when it already has content.
func TestStructTarget_Ensure(t *testing.T) {
	examples := []struct {
		desc  string
		input reflect.Value
		name  string
		// test receives the subtest's own *testing.T, so failures are
		// reported against the subtest and require's FailNow runs on the
		// goroutine of the test it belongs to.
		test func(t *testing.T, v reflect.Value)
	}{
		{
			desc:  "handle a nil slice of string",
			input: reflect.ValueOf(&struct{ A []string }{}).Elem(),
			name:  "A",
			test: func(t *testing.T, v reflect.Value) {
				assert.False(t, v.IsNil())
			},
		},
		{
			desc:  "handle an existing slice of string",
			input: reflect.ValueOf(&struct{ A []string }{A: []string{"foo"}}).Elem(),
			name:  "A",
			test: func(t *testing.T, v reflect.Value) {
				require.False(t, v.IsNil())
				s := v.Interface().([]string)
				assert.Equal(t, []string{"foo"}, s)
			},
		},
	}

	for _, e := range examples {
		t.Run(e.desc, func(t *testing.T) {
			target, err := scope(e.input, e.name)
			require.NoError(t, err)
			target.ensure()
			v := target.get()
			e.test(t, v)
		})
	}
}
// TestStructTarget_SetString checks that setString stores the value in a
// string field and fails on fields of other kinds.
func TestStructTarget_SetString(t *testing.T) {
	str := "value"

	examples := []struct {
		desc  string
		input reflect.Value
		name  string
		// test receives the subtest's own *testing.T so failures are
		// attributed to the subtest rather than to the parent test.
		test func(t *testing.T, v reflect.Value, err error)
	}{
		{
			desc:  "sets a string",
			input: reflect.ValueOf(&struct{ A string }{}).Elem(),
			name:  "A",
			test: func(t *testing.T, v reflect.Value, err error) {
				assert.NoError(t, err)
				assert.Equal(t, str, v.String())
			},
		},
		{
			desc:  "fails on a float",
			input: reflect.ValueOf(&struct{ A float64 }{}).Elem(),
			name:  "A",
			test: func(t *testing.T, v reflect.Value, err error) {
				assert.Error(t, err)
			},
		},
		{
			desc:  "fails on a slice",
			input: reflect.ValueOf(&struct{ A []string }{}).Elem(),
			name:  "A",
			test: func(t *testing.T, v reflect.Value, err error) {
				assert.Error(t, err)
			},
		},
	}

	for _, e := range examples {
		t.Run(e.desc, func(t *testing.T) {
			target, err := scope(e.input, e.name)
			require.NoError(t, err)
			err = target.setString(str)
			v := target.get()
			e.test(t, v, err)
		})
	}
}
// TestPushValue_Struct pushes the string "hello" onto field A of each input
// and checks either the resulting slice or that an error is reported.
func TestPushValue_Struct(t *testing.T) {
	type testCase struct {
		desc     string
		input    reflect.Value
		expected []string
		wantErr  bool
	}

	cases := []testCase{
		{
			desc:     "push to nil slice",
			input:    reflect.ValueOf(&struct{ A []string }{}).Elem(),
			expected: []string{"hello"},
		},
		{
			desc:    "push to string",
			input:   reflect.ValueOf(&struct{ A string }{}).Elem(),
			wantErr: true,
		},
	}

	for _, tc := range cases {
		t.Run(tc.desc, func(t *testing.T) {
			dst, err := scope(tc.input, "A")
			require.NoError(t, err)

			err = dst.pushValue(reflect.ValueOf("hello"))
			if tc.wantErr {
				require.Error(t, err)
				return
			}
			require.NoError(t, err)

			got := dst.get().Interface().([]string)
			assert.Equal(t, tc.expected, got)
		})
	}
}
// TestScope_Struct checks that scope resolves an exported struct field to a
// structTarget and reports an error for a field that is not exported.
func TestScope_Struct(t *testing.T) {
	examples := []struct {
		desc  string
		input reflect.Value
		name  string
		err   bool
		idx   []int
	}{
		{
			desc:  "simple field",
			input: reflect.ValueOf(&struct{ A string }{}).Elem(),
			name:  "A",
			idx:   []int{0},
		},
		{
			desc:  "fails not-exported field",
			input: reflect.ValueOf(&struct{ a string }{}).Elem(),
			name:  "a",
			err:   true,
		},
	}

	for _, e := range examples {
		t.Run(e.desc, func(t *testing.T) {
			x, err := scope(e.input, e.name)
			if e.err {
				require.Error(t, err)
				return
			}
			// Check the error explicitly so an unexpected failure is
			// reported as such, not as a failed type assertion below.
			require.NoError(t, err)
			x2, ok := x.(structTarget)
			require.True(t, ok)
			x2.get()
		})
	}
}
+69
View File
@@ -0,0 +1,69 @@
package unmarshaler
import (
"fmt"
"reflect"
"github.com/pelletier/go-toml/v2/internal/ast"
)
// FromAst applies every top-level node of tree, in order, to target and stops
// at the first error. target must be a non-nil pointer; otherwise an error is
// returned before any node is processed.
func FromAst(tree ast.Root, target interface{}) error {
	x := reflect.ValueOf(target)
	switch {
	case x.Kind() != reflect.Ptr:
		return fmt.Errorf("need to target a pointer, not %s", x.Kind())
	case x.IsNil():
		return fmt.Errorf("target pointer must be non-nil")
	}

	for _, node := range tree {
		if err := topLevelNode(x, &node); err != nil {
			return err
		}
	}
	return nil
}
// topLevelNode dispatches a top-level AST node to the function handling its
// kind. Only key-values are handled so far: tables and array tables panic with
// "TODO", and any other kind panics as well. x must be a non-nil pointer; the
// function panics otherwise.
func topLevelNode(x reflect.Value, node *ast.Node) error {
	switch {
	case x.Kind() != reflect.Ptr:
		panic("topLevelNode should receive target, which should be a pointer")
	case x.IsNil():
		panic("topLevelNode should receive target, which should not be a nil pointer")
	}

	switch kind := node.Kind; kind {
	case ast.Table, ast.ArrayTable:
		panic("TODO")
	case ast.KeyValue:
		return keyValue(x, node)
	default:
		panic(fmt.Errorf("this should not be a top level node type: %s", kind))
	}
}
// keyValue handles a key-value node targeting x. It panics when node is not of
// kind ast.KeyValue or when x is not a pointer. Nothing is stored yet: the
// function currently always returns nil.
func keyValue(x reflect.Value, node *ast.Node) error {
	assertNode(ast.KeyValue, node)
	assertPtr(x)

	// The key is fetched but not used yet. Discard it explicitly instead of
	// self-assigning it, which go vet reports.
	_ = node.Key()

	// TODO
	return nil
}
// assertNode panics unless node is of the expected kind.
func assertNode(expected ast.Kind, node *ast.Node) {
	if node.Kind == expected {
		return
	}
	panic(fmt.Errorf("expected node of kind %s, not %s", expected, node.Kind))
}
// assertPtr panics unless x is of kind pointer.
func assertPtr(x reflect.Value) {
	if x.Kind() == reflect.Ptr {
		return
	}
	panic(fmt.Errorf("should be a pointer, not a %s", x.Kind()))
}
+39
View File
@@ -0,0 +1,39 @@
package unmarshaler_test
import (
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/pelletier/go-toml/v2/internal/ast"
"github.com/pelletier/go-toml/v2/internal/unmarshaler"
)
// TestFromAst_KV unmarshals a single key-value node into a struct with a
// matching string field. It is skipped for now.
func TestFromAst_KV(t *testing.T) {
	t.Skipf("later")

	type Doc struct {
		Foo string
	}

	keyValue := ast.Node{
		Kind: ast.KeyValue,
		Children: []ast.Node{
			{Kind: ast.Key, Data: []byte(`Foo`)},
			{Kind: ast.String, Data: []byte(`hello`)},
		},
	}
	root := ast.Root{keyValue}

	var doc Doc
	err := unmarshaler.FromAst(root, &doc)
	require.NoError(t, err)
	assert.Equal(t, Doc{Foo: "hello"}, doc)
}