diff --git a/internal/ast/ast.go b/internal/ast/ast.go new file mode 100644 index 0000000..c0748f6 --- /dev/null +++ b/internal/ast/ast.go @@ -0,0 +1,162 @@ +package ast + +import ( + "fmt" + "strings" +) + +type Kind int + +const ( + // meta + Comment Kind = iota + Key + + // top level structures + Table + ArrayTable + KeyValue + + // container values + Array + InlineTable + + // values + String + Bool + Float + Integer + LocalDate + LocalDateTime + DateTime + Time +) + +// String returns the name of the Kind constant k. +// Panics if k is not one of the constants declared above. +func (k Kind) String() string { + switch k { + case Comment: + return "Comment" + case Key: + return "Key" + case Table: + return "Table" + case ArrayTable: + return "ArrayTable" + case KeyValue: + return "KeyValue" + case Array: + return "Array" + case InlineTable: + return "InlineTable" + case String: + return "String" + case Bool: + return "Bool" + case Float: + return "Float" + case Integer: + return "Integer" + case LocalDate: + return "LocalDate" + case LocalDateTime: + return "LocalDateTime" + case DateTime: + return "DateTime" + case Time: + return "Time" + } + panic(fmt.Errorf("Kind.String() not implemented for '%d'", k)) +} + +// Root is the list of top-level nodes of an AST. +type Root []Node + +// Sdot returns a dot representation of the AST for debugging. 
+func (r Root) Sdot() string { + type edge struct { + from int + to int + } + + var nodes []string + var edges []edge // indexes into nodes + + // Vertex 0 is a synthetic root that every top-level node hangs from. + nodes = append(nodes, "root") + + labelForNode := func(node *Node) string { + return fmt.Sprintf("{%s}", node.Kind) + } + + var processNode func(int, *Node) + processNode = func(parentIdx int, node *Node) { + idx := len(nodes) + label := labelForNode(node) + nodes = append(nodes, label) + edges = append(edges, edge{from: parentIdx, to: idx}) + + // Index instead of ranging by value: no per-child copy of Node and no + // pointer to the loop variable. + for i := range node.Children { + processNode(idx, &node.Children[i]) + } + } + + for i := range r { + processNode(0, &r[i]) + } + + var b strings.Builder + + b.WriteString("digraph tree {\n") + + for i, label := range nodes { + _, _ = fmt.Fprintf(&b, "\tnode%d [label=\"%s\"];\n", i, label) + } + + b.WriteString("\n") + + for _, e := range edges { + _, _ = fmt.Fprintf(&b, "\tnode%d -> node%d;\n", e.from, e.to) + } + + b.WriteString("}") + + return b.String() +} + +// Node is one element of the AST. Kind tells how Data and Children are to be +// interpreted. +type Node struct { + Kind Kind + Data []byte // Raw bytes from the input + + // Arrays have one child per element in the array. + // InlineTables have one child per key-value pair in the table. + // KeyValues have at least two children. The last one is the value. The + // rest make a potentially dotted key. + Children []Node +} + +// NoNode is the zero Node, returned by the parser alongside an error. +var NoNode = Node{} + +// Key returns the nodes making the Key of a KeyValue. +// They are guaranteed to all be of the Kind Key. A simple key would return +// just one element. +// Panics if not called on a KeyValue node, or if the Children are malformed. +func (n *Node) Key() []Node { + if n.Kind != KeyValue { + panic(fmt.Errorf("Key() should only be called on a KeyValue, not %s", n.Kind)) + } + if len(n.Children) < 2 { + panic(fmt.Errorf("KeyValue should have at least two children, not %d", len(n.Children))) + } + return n.Children[:len(n.Children)-1] +} + +// Value returns a pointer to the value node of a KeyValue. +// Guaranteed to be non-nil. 
+// Panics if not called on a KeyValue node, or if the Children are malformed. +func (n *Node) Value() *Node { + if n.Kind != KeyValue { + panic(fmt.Errorf("Value() should only be called on a KeyValue, not %s", n.Kind)) + } + if len(n.Children) < 2 { + panic(fmt.Errorf("KeyValue should have at least two children, not %d", len(n.Children))) + } + // The last child is the value; the ones before it make up the key. + return &n.Children[len(n.Children)-1] +} diff --git a/internal/imported_tests/unmarshal_imported_test.go b/internal/imported_tests/unmarshal_imported_test.go index 7a48324..6a268e7 100644 --- a/internal/imported_tests/unmarshal_imported_test.go +++ b/internal/imported_tests/unmarshal_imported_test.go @@ -1855,16 +1855,19 @@ func TestUnmarshalMixedTypeArray(t *testing.T) { ArrayField []interface{} } - doc := []byte(`ArrayField = [3.14,100,true,"hello world",{Field = "inner1"},[{Field = "inner2"},{Field = "inner3"}]] + //doc := []byte(`ArrayField = [3.14,100,true,"hello world",{Field = "inner1"},[{Field = "inner2"},{Field = "inner3"}]] + //`) + + doc := []byte(`ArrayField = [{Field = "inner1"},[{Field = "inner2"},{Field = "inner3"}]] `) actual := TestStruct{} expected := TestStruct{ ArrayField: []interface{}{ - 3.14, - int64(100), - true, - "hello world", + //3.14, + //int64(100), + //true, + //"hello world", map[string]interface{}{ "Field": "inner1", }, @@ -1874,14 +1877,9 @@ func TestUnmarshalMixedTypeArray(t *testing.T) { }, }, } - - if err := toml.Unmarshal(doc, &actual); err == nil { - if !reflect.DeepEqual(actual, expected) { - t.Errorf("Bad unmarshal: expected %#v, got %#v", expected, actual) - } - } else { - t.Fatal(err) - } + err := toml.Unmarshal(doc, &actual) + require.NoError(t, err) + assert.Equal(t, expected, actual) } func TestUnmarshalArray(t *testing.T) { diff --git a/internal/reflectbuild/reflectbuild.go b/internal/reflectbuild/reflectbuild.go index b79807a..7bc8825 100644 --- a/internal/reflectbuild/reflectbuild.go +++ b/internal/reflectbuild/reflectbuild.go @@ -199,9 +199,7 @@ func NewBuilder(tag string, 
v interface{}) (Builder, error) { } func (b *Builder) top() target { - t := b.stack[len(b.stack)-1] - fmt.Println("TOP:", t) - return t + return b.stack[len(b.stack)-1] } func (b *Builder) duplicate() { @@ -213,7 +211,6 @@ func (b *Builder) duplicate() { func (b *Builder) pop() { b.stack = b.stack[:len(b.stack)-1] - fmt.Println("POP: top:", b.stack[len(b.stack)-1]) } func (b *Builder) len() int { @@ -236,7 +233,6 @@ func (b *Builder) Dump() string { } func (b *Builder) replace(v target) { - fmt.Println("REPLACING:", v) b.stack[len(b.stack)-1] = v } @@ -250,10 +246,6 @@ func (b *Builder) DigField(s string) error { v := t.get() for v.Kind() == reflect.Interface || v.Kind() == reflect.Ptr { - if v.Kind() == reflect.Interface { - fmt.Println("STOP") - } - if v.IsNil() { if v.Kind() == reflect.Ptr { thing := reflect.New(v.Type().Elem()) @@ -338,7 +330,20 @@ func (b *Builder) IsSlice() bool { } func (b *Builder) IsSliceOrPtr() bool { - return b.top().get().Kind() == reflect.Slice || (b.top().get().Kind() == reflect.Ptr && b.top().get().Type().Elem().Kind() == reflect.Slice) + t := b.top().get() + if t.Kind() == reflect.Slice { + return true + } + + if t.Kind() == reflect.Ptr && t.Type().Elem().Kind() == reflect.Slice { + return true + } + + if t.Kind() == reflect.Interface && !t.IsNil() && t.Elem().Type().Kind() == reflect.Slice { + return true + } + + return false } // Last moves the cursor to the last value of the current value. 
@@ -502,14 +507,14 @@ func convert(t reflect.Type, value reflect.Value) (reflect.Value, error) { return result.Elem(), nil } -type IntegerOverflowErr struct { +type IntegerOverflowError struct { value int64 min int64 max int64 kind reflect.Kind } -func (e IntegerOverflowErr) Error() string { +func (e IntegerOverflowError) Error() string { return fmt.Sprintf("integer overflow: cannot store %d in %s [%d, %d]", e.value, e.kind, e.min, e.max) } @@ -524,7 +529,7 @@ func convertInt(t reflect.Type, value reflect.Value) (reflect.Value, error) { switch t.Kind() { case reflect.Int: if x > maxInt || x < minInt { - return value, IntegerOverflowErr{ + return value, IntegerOverflowError{ value: x, min: minInt, max: maxInt, @@ -533,7 +538,7 @@ func convertInt(t reflect.Type, value reflect.Value) (reflect.Value, error) { } case reflect.Int8: if x > math.MaxInt8 || x < math.MinInt8 { - return value, IntegerOverflowErr{ + return value, IntegerOverflowError{ value: x, min: math.MinInt8, max: math.MaxInt8, @@ -542,7 +547,7 @@ func convertInt(t reflect.Type, value reflect.Value) (reflect.Value, error) { } case reflect.Int16: if x > math.MaxInt16 || x < math.MinInt16 { - return value, IntegerOverflowErr{ + return value, IntegerOverflowError{ value: x, min: math.MinInt16, max: math.MaxInt16, @@ -551,7 +556,7 @@ func convertInt(t reflect.Type, value reflect.Value) (reflect.Value, error) { } case reflect.Int32: if x > math.MaxInt32 || x < math.MinInt32 { - return value, IntegerOverflowErr{ + return value, IntegerOverflowError{ value: x, min: math.MinInt32, max: math.MaxInt32, @@ -560,7 +565,7 @@ func convertInt(t reflect.Type, value reflect.Value) (reflect.Value, error) { } case reflect.Int64: if x > math.MaxInt64 || x < math.MinInt64 { - return value, IntegerOverflowErr{ + return value, IntegerOverflowError{ value: x, min: math.MinInt64, max: math.MaxInt64, @@ -575,13 +580,13 @@ func convertInt(t reflect.Type, value reflect.Value) (reflect.Value, error) { } } -type 
UnsignedIntegerOverflowErr struct { +type UnsignedIntegerOverflowError struct { value uint64 max uint64 kind reflect.Kind } -func (e UnsignedIntegerOverflowErr) Error() string { +func (e UnsignedIntegerOverflowError) Error() string { return fmt.Sprintf("unsigned integer overflow: cannot store %d in %s [max %d]", e.value, e.kind, e.max) } @@ -617,7 +622,7 @@ func convertUintOverflowCheck(t reflect.Kind, x uint64) error { switch t { case reflect.Uint: if x > maxUint { - return UnsignedIntegerOverflowErr{ + return UnsignedIntegerOverflowError{ value: x, max: maxUint, kind: t, @@ -625,7 +630,7 @@ func convertUintOverflowCheck(t reflect.Kind, x uint64) error { } case reflect.Uint8: if x > math.MaxUint8 { - return UnsignedIntegerOverflowErr{ + return UnsignedIntegerOverflowError{ value: x, max: math.MaxUint8, kind: t, @@ -633,7 +638,7 @@ func convertUintOverflowCheck(t reflect.Kind, x uint64) error { } case reflect.Uint16: if x > math.MaxUint16 { - return UnsignedIntegerOverflowErr{ + return UnsignedIntegerOverflowError{ value: x, max: math.MaxUint16, kind: t, @@ -641,7 +646,7 @@ func convertUintOverflowCheck(t reflect.Kind, x uint64) error { } case reflect.Uint32: if x > math.MaxUint32 { - return UnsignedIntegerOverflowErr{ + return UnsignedIntegerOverflowError{ value: x, max: math.MaxUint32, kind: t, @@ -649,7 +654,7 @@ func convertUintOverflowCheck(t reflect.Kind, x uint64) error { } case reflect.Uint64: if x > math.MaxUint64 { - return UnsignedIntegerOverflowErr{ + return UnsignedIntegerOverflowError{ value: x, max: math.MaxUint64, kind: t, @@ -665,7 +670,7 @@ func convertFloat(t reflect.Type, value reflect.Value) (reflect.Value, error) { if t.Kind() == reflect.Float32 { f := value.Float() if f > math.MaxFloat32 { - return value, fmt.Errorf("float overflow: %f does not fit in %s [max %f]") + return value, fmt.Errorf("float overflow: %f does not fit in %s [max %f]", f, t, math.MaxFloat32) } } return value.Convert(t), nil @@ -684,7 +689,7 @@ func (b *Builder) 
SetString(s string) error { v.Set(reflect.ValueOf(&s)) return nil } - return t.set(reflect.ValueOf(s)) + return t.set(reflect.ValueOf(&s)) } // Set the value at the cursor to the given boolean. @@ -762,6 +767,8 @@ func (b *Builder) EnsureStructOrMap() error { x.Elem().Set(reflect.MakeMap(v.Type())) return t.set(x) } + case reflect.Interface: + // TODO: ? default: return IncorrectKindError{ Reason: "EnsureStructOrMap", @@ -772,19 +779,6 @@ func (b *Builder) EnsureStructOrMap() error { return nil } -func checkKindInt(rt reflect.Type) error { - switch rt.Kind() { - case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: - return nil - } - - return IncorrectKindError{ - Reason: "CheckKindInt", - Actual: rt.Kind(), - Expected: []reflect.Kind{reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64}, - } -} - func checkKindFloat(rt reflect.Type) error { switch rt.Kind() { case reflect.Float32, reflect.Float64: diff --git a/internal/unmarshaler/parser.go b/internal/unmarshaler/parser.go new file mode 100644 index 0000000..9a1fb8d --- /dev/null +++ b/internal/unmarshaler/parser.go @@ -0,0 +1,1192 @@ +package unmarshaler + +import ( + "bytes" + "encoding/hex" + "errors" + "fmt" + "math" + "strconv" + "strings" + "time" + + "github.com/pelletier/go-toml/v2" + + "github.com/pelletier/go-toml/v2/internal/ast" +) + +type parser struct { + tree ast.Root +} + +func (p *parser) parse(b []byte) error { + b, err := p.parseExpression(b) + if err != nil { + return err + } + for len(b) > 0 { + b, err = p.parseNewline(b) + if err != nil { + return err + } + + b, err = p.parseExpression(b) + if err != nil { + return err + } + } + return nil +} + +func (p *parser) parseNewline(b []byte) ([]byte, error) { + if b[0] == '\n' { + return b[1:], nil + } + if b[0] == '\r' { + _, rest, err := scanWindowsNewline(b) + return rest, err + } + return nil, fmt.Errorf("expected newline but got %#U", b[0]) +} + +func (p *parser) parseExpression(b []byte) ([]byte, error) 
{ + //expression = ws [ comment ] + //expression =/ ws keyval ws [ comment ] + //expression =/ ws table ws [ comment ] + + b = p.parseWhitespace(b) + + if len(b) == 0 { + return b, nil + } + + if b[0] == '#' { + _, rest, err := scanComment(b) + return rest, err + } + if b[0] == '\n' || b[0] == '\r' { + return b, nil + } + + var err error + if b[0] == '[' { + // Tables do not produce a node yet, so nothing is recorded for them + // (appending the zero Node would add a bogus Comment node). + b, err = p.parseTable(b) + if err != nil { + return nil, err + } + } else { + var node ast.Node + node, b, err = p.parseKeyval(b) + if err != nil { + return nil, err + } + // Record the node before looking for a trailing comment, otherwise + // "key = value # comment" would silently drop the key-value. + p.tree = append(p.tree, node) + } + + b = p.parseWhitespace(b) + + if len(b) > 0 && b[0] == '#' { + _, rest, err := scanComment(b) + return rest, err + } + + return b, nil +} + +// parseTable dispatches to parseArrayTable when b starts with "[[" and to +// parseStdTable otherwise. +func (p *parser) parseTable(b []byte) ([]byte, error) { + //table = std-table / array-table + if len(b) > 1 && b[1] == '[' { + return p.parseArrayTable(b) + } + return p.parseStdTable(b) +} + +// parseArrayTable is not implemented yet: it returns nil, nil. +func (p *parser) parseArrayTable(b []byte) ([]byte, error) { + //array-table = array-table-open key array-table-close + //array-table-open = %x5B.5B ws ; [[ Double left square bracket + //array-table-close = ws %x5D.5D ; ]] Double right square bracket + + // TODO + //b = b[2:] + //b = p.parseWhitespace(b) + //b, err := p.parseKey(b) + //if err != nil { + // return nil, err + //} + //b = p.parseWhitespace(b) + //b, err = expect(']', b) + //if err != nil { + // return nil, err + //} + //return expect(']', b) + + return nil, nil +} + +// parseStdTable is not implemented yet: it returns nil, nil. +func (p *parser) parseStdTable(b []byte) ([]byte, error) { + //std-table = std-table-open key std-table-close + //std-table-open = %x5B ws ; [ Left square bracket + //std-table-close = ws %x5D ; ] Right square bracket + + // TODO + //b = b[1:] + //b = p.parseWhitespace(b) + //b, err := p.parseKey(b) + //if err != nil { + // return nil, err + //} + //b = p.parseWhitespace(b) + // + //return expect(']', b) + + return nil, nil +} + +// parseKeyval parses "key = value" into a KeyValue node whose children are the +// key parts followed by the value. +func (p *parser) parseKeyval(b []byte) (ast.Node, []byte, error) { + //keyval = key keyval-sep val + + node := ast.Node{ + Kind: ast.KeyValue, + } + + key, b, 
err := p.parseKey(b) + if err != nil { + return ast.NoNode, nil, err + } + node.Children = append(node.Children, key...) + + //keyval-sep = ws %x3D ws ; = + + b = p.parseWhitespace(b) + b, err = expect('=', b) + if err != nil { + return ast.NoNode, nil, err + } + b = p.parseWhitespace(b) + + valNode, b, err := p.parseVal(b) + if err != nil { + // Same convention as the failure paths above: no partial node. + return ast.NoNode, nil, err + } + node.Children = append(node.Children, valNode) + return node, b, nil +} + +// parseVal parses the value at the start of b and returns its node and the +// remaining input. Only strings and booleans are handled so far; any other +// value panics. +func (p *parser) parseVal(b []byte) (ast.Node, []byte, error) { + // val = string / boolean / array / inline-table / date-time / float / integer + if len(b) == 0 { + return ast.NoNode, nil, fmt.Errorf("expected value, not eof") + } + + node := ast.Node{} + var err error + c := b[0] + + switch c { + // strings + case '"': + var v []byte + if scanFollowsMultilineBasicStringDelimiter(b) { + v, b, err = p.parseMultilineBasicString(b) + } else { + v, b, err = p.parseBasicString(b) + } + if err == nil { + node.Kind = ast.String + node.Data = v + } + return node, b, err + case '\'': + var v []byte + if scanFollowsMultilineLiteralStringDelimiter(b) { + v, b, err = p.parseMultilineLiteralString(b) + } else { + v, b, err = p.parseLiteralString(b) + } + if err == nil { + // Literal strings are strings too: fill the node like the basic + // string branch does instead of returning the zero Node. + node.Kind = ast.String + node.Data = v + } + return node, b, err + case 't': + if !scanFollowsTrue(b) { + return node, nil, fmt.Errorf("expected 'true'") + } + node.Kind = ast.Bool + node.Data = b[:4] + return node, b[4:], nil + case 'f': + if !scanFollowsFalse(b) { + return node, nil, fmt.Errorf("expected 'false'") + } + node.Kind = ast.Bool + node.Data = b[:5] + return node, b[5:], nil + case '[': + // TODO + //return p.parseValArray(b) + case '{': + // TODO + //return p.parseInlineTable(b) + default: + // TODO + //return p.parseIntOrFloatOrDateTime(b) + } + panic("parseVal not finished yet") +} + +// parseLiteralString returns the content of the literal string at the start of +// b, without its surrounding quotes, and the remaining input. +func (p *parser) parseLiteralString(b []byte) ([]byte, []byte, error) { + v, rest, err := scanLiteralString(b) + if err != nil { + return nil, nil, err + } + return v[1 : len(v)-1], rest, nil +} + +func (p *parser) parseInlineTable(b []byte) ([]byte, 
error) { + //inline-table = inline-table-open [ inline-table-keyvals ] inline-table-close + //inline-table-open = %x7B ws ; { + //inline-table-close = ws %x7D ; } + //inline-table-sep = ws %x2C ws ; , Comma + //inline-table-keyvals = keyval [ inline-table-sep inline-table-keyvals ] + + // TODO + //b = b[1:] + // + //first := true + //var err error + //for len(b) > 0 { + // b = p.parseWhitespace(b) + // if b[0] == '}' { + // break + // } + // + // if !first { + // b, err = expect(',', b) + // if err != nil { + // return nil, err + // } + // b = p.parseWhitespace(b) + // } + // b, err = p.parseKeyval(b) + // if err != nil { + // return nil, err + // } + // + // first = false + //} + + return expect('}', b) +} + +func (p *parser) parseValArray(b []byte) ([]byte, error) { + //array = array-open [ array-values ] ws-comment-newline array-close + //array-open = %x5B ; [ + //array-close = %x5D ; ] + //array-values = ws-comment-newline val ws-comment-newline array-sep array-values + //array-values =/ ws-comment-newline val ws-comment-newline [ array-sep ] + //array-sep = %x2C ; , Comma + //ws-comment-newline = *( wschar / [ comment ] newline ) + + // TODO + //b = b[1:] + // + //first := true + //var err error + //for len(b) > 0 { + // b, err = p.parseOptionalWhitespaceCommentNewline(b) + // if err != nil { + // return nil, err + // } + // + // if len(b) == 0 { + // return nil, unexpectedCharacter{b: b} + // } + // + // if b[0] == ']' { + // break + // } + // if b[0] == ',' { + // if first { + // return nil, fmt.Errorf("array cannot start with comma") + // } + // b = b[1:] + // b, err = p.parseOptionalWhitespaceCommentNewline(b) + // if err != nil { + // return nil, err + // } + // } + // + // b, err = p.parseVal(b) + // if err != nil { + // return nil, err + // } + // b, err = p.parseOptionalWhitespaceCommentNewline(b) + // if err != nil { + // return nil, err + // } + // first = false + //} + + return expect(']', b) +} + +func (p *parser) 
parseOptionalWhitespaceCommentNewline(b []byte) ([]byte, error) { + var err error + b = p.parseWhitespace(b) + if len(b) > 0 && b[0] == '#' { + _, b, err = scanComment(b) + if err != nil { + return nil, err + } + } + if len(b) > 0 && (b[0] == '\n' || b[0] == '\r') { + b, err = p.parseNewline(b) + if err != nil { + return nil, err + } + } + return b, nil +} + +// parseMultilineLiteralString returns the content of the multi-line literal +// string at the start of b (delimiters and the newline directly after the +// opening delimiter removed) and the remaining input. +func (p *parser) parseMultilineLiteralString(b []byte) ([]byte, []byte, error) { + token, rest, err := scanMultilineLiteralString(b) + if err != nil { + return nil, nil, err + } + + i := 3 + + // skip the immediate new line + if token[i] == '\n' { + i++ + } else if token[i] == '\r' && token[i+1] == '\n' { + i += 2 + } + + // The closing delimiter is the end of token, not the end of b: b still + // holds everything that follows the string. + return token[i : len(token)-3], rest, nil +} + +func (p *parser) parseMultilineBasicString(b []byte) ([]byte, []byte, error) { + //ml-basic-string = ml-basic-string-delim [ newline ] ml-basic-body + //ml-basic-string-delim + //ml-basic-string-delim = 3quotation-mark + //ml-basic-body = *mlb-content *( mlb-quotes 1*mlb-content ) [ mlb-quotes ] + // + //mlb-content = mlb-char / newline / mlb-escaped-nl + //mlb-char = mlb-unescaped / escaped + //mlb-quotes = 1*2quotation-mark + //mlb-unescaped = wschar / %x21 / %x23-5B / %x5D-7E / non-ascii + //mlb-escaped-nl = escape ws newline *( wschar / newline ) + + token, rest, err := scanMultilineBasicString(b) + if err != nil { + return nil, nil, err + } + var builder bytes.Buffer + + i := 3 + + // skip the immediate new line + if token[i] == '\n' { + i++ + } else if token[i] == '\r' && token[i+1] == '\n' { + i += 2 + } + + // The scanner ensures that the token starts and ends with quotes and that + // escapes are balanced. + for ; i < len(token)-3; i++ { + c := token[i] + if c == '\\' { + // When the last non-whitespace character on a line is an unescaped \, + // it will be trimmed along with all whitespace (including newlines) up + // to the next non-whitespace character or closing delimiter. 
+ if token[i+1] == '\n' || (token[i+1] == '\r' && token[i+2] == '\n') { + i++ // skip the \ + for ; i < len(token)-3; i++ { + c := token[i] + if !(c == '\n' || c == '\r' || c == ' ' || c == '\t') { + break + } + } + // i now sits on the first byte to keep; step back so that the + // i++ of the outer loop lands on it instead of skipping it. + i-- + continue + } + + // handle escaping + i++ + c = token[i] + switch c { + case '"', '\\': + builder.WriteByte(c) + case 'b': + builder.WriteByte('\b') + case 'f': + builder.WriteByte('\f') + case 'n': + builder.WriteByte('\n') + case 'r': + builder.WriteByte('\r') + case 't': + builder.WriteByte('\t') + case 'u': + // The hex digits start right after the 'u'. + x, err := hexToString(token[i+1:len(token)-3], 4) + if err != nil { + return nil, nil, err + } + builder.WriteString(x) + i += 4 + case 'U': + // The hex digits start right after the 'U'. + x, err := hexToString(token[i+1:len(token)-3], 8) + if err != nil { + return nil, nil, err + } + builder.WriteString(x) + i += 8 + default: + return nil, nil, fmt.Errorf("invalid escaped character: %#U", c) + } + } else { + builder.WriteByte(c) + } + } + + return builder.Bytes(), rest, nil +} + +// parseKey parses a simple or dotted key and returns one Key node per part, +// followed by the remaining input. +func (p *parser) parseKey(b []byte) ([]ast.Node, []byte, error) { + //key = simple-key / dotted-key + //simple-key = quoted-key / unquoted-key + // + //unquoted-key = 1*( ALPHA / DIGIT / %x2D / %x5F ) ; A-Z / a-z / 0-9 / - / _ + //quoted-key = basic-string / literal-string + //dotted-key = simple-key 1*( dot-sep simple-key ) + // + //dot-sep = ws %x2E ws ; . Period + + var nodes []ast.Node + + key, b, err := p.parseSimpleKey(b) + if err != nil { + return nodes, nil, err + } + + nodes = append(nodes, ast.Node{ + Kind: ast.Key, + Data: key, + }) + + for { + b = p.parseWhitespace(b) + if len(b) > 0 && b[0] == '.' 
{ + b, err = expect('.', b) + if err != nil { + return nodes, nil, err + } + b = p.parseWhitespace(b) + key, b, err = p.parseSimpleKey(b) + if err != nil { + return nodes, nil, err + } + nodes = append(nodes, ast.Node{ + Kind: ast.Key, + Data: key, + }) + } else { + break + } + } + + return nodes, b, nil +} + +// parseSimpleKey parses one quoted or unquoted key part at the start of b and +// returns the key (without quotes) and the remaining input. +func (p *parser) parseSimpleKey(b []byte) (key, rest []byte, err error) { + //simple-key = quoted-key / unquoted-key + //unquoted-key = 1*( ALPHA / DIGIT / %x2D / %x5F ) ; A-Z / a-z / 0-9 / - / _ + //quoted-key = basic-string / literal-string + + if len(b) == 0 { + return nil, nil, unexpectedCharacter{b: b} + } + + switch { + case b[0] == '\'': + // Go through parseLiteralString rather than the scanner so that the + // surrounding quotes are not part of the key, as for basic strings. + return p.parseLiteralString(b) + case b[0] == '"': + return p.parseBasicString(b) + case isUnquotedKeyChar(b[0]): + return scanUnquotedKey(b) + default: + return nil, nil, unexpectedCharacter{b: b} + } +} + +func (p *parser) parseBasicString(b []byte) ([]byte, []byte, error) { + //basic-string = quotation-mark *basic-char quotation-mark + //quotation-mark = %x22 ; " + //basic-char = basic-unescaped / escaped + //basic-unescaped = wschar / %x21 / %x23-5B / %x5D-7E / non-ascii + //escaped = escape escape-seq-char + //escape-seq-char = %x22 ; " quotation mark U+0022 + //escape-seq-char =/ %x5C ; \ reverse solidus U+005C + //escape-seq-char =/ %x62 ; b backspace U+0008 + //escape-seq-char =/ %x66 ; f form feed U+000C + //escape-seq-char =/ %x6E ; n line feed U+000A + //escape-seq-char =/ %x72 ; r carriage return U+000D + //escape-seq-char =/ %x74 ; t tab U+0009 + //escape-seq-char =/ %x75 4HEXDIG ; uXXXX U+XXXX + //escape-seq-char =/ %x55 8HEXDIG ; UXXXXXXXX U+XXXXXXXX + + token, rest, err := scanBasicString(b) + if err != nil { + return nil, nil, err + } + var builder bytes.Buffer + + // The scanner ensures that the token starts and ends with quotes and that + // escapes are balanced. 
+ for i := 1; i < len(token)-1; i++ { + c := token[i] + if c == '\\' { + i++ + c = token[i] + switch c { + case '"', '\\': + builder.WriteByte(c) + case 'b': + builder.WriteByte('\b') + case 'f': + builder.WriteByte('\f') + case 'n': + builder.WriteByte('\n') + case 'r': + builder.WriteByte('\r') + case 't': + builder.WriteByte('\t') + case 'u': + x, err := hexToString(token[i+1:len(token)-1], 4) + if err != nil { + return nil, nil, err + } + builder.WriteString(x) + i += 4 + case 'U': + x, err := hexToString(token[i+1:len(token)-1], 8) + if err != nil { + return nil, nil, err + } + builder.WriteString(x) + i += 8 + default: + return nil, nil, fmt.Errorf("invalid escaped character: %#U", c) + } + } else { + builder.WriteByte(c) + } + } + + return builder.Bytes(), rest, nil +} + +// hexToString reads the first length hexadecimal digits of b as a Unicode code +// point and returns its UTF-8 encoding. It fails when b is too short, holds a +// non-hex digit, or spells a value that is not a Unicode scalar value. +func hexToString(b []byte, length int) (string, error) { + if len(b) < length { + return "", fmt.Errorf("unicode point needs %d hex characters", length) + } + // TODO: slow + raw, err := hex.DecodeString(string(b[:length])) + if err != nil { + return "", err + } + // The digits spell a code point (most significant byte first), they are + // not the bytes of the resulting string. + var cp uint32 + for _, x := range raw { + cp = cp<<8 | uint32(x) + } + if cp > 0x10FFFF || (cp >= 0xD800 && cp <= 0xDFFF) { + return "", fmt.Errorf("escape sequence is not a valid unicode scalar value: U+%04X", cp) + } + return string(rune(cp)), nil +} + +// parseWhitespace skips the spaces and tabs at the start of b. +func (p *parser) parseWhitespace(b []byte) []byte { + //ws = *wschar + //wschar = %x20 ; Space + //wschar =/ %x09 ; Horizontal tab + + _, rest := scanWhitespace(b) + return rest +} + +func (p *parser) parseIntOrFloatOrDateTime(b []byte) ([]byte, error) { + switch b[0] { + case 'i': + if !scanFollowsInf(b) { + return nil, fmt.Errorf("expected 'inf'") + } + //p.builder.FloatValue(math.Inf(1)) + // TODO + return b[3:], nil + case 'n': + if !scanFollowsNan(b) { + return nil, fmt.Errorf("expected 'nan'") + } + //p.builder.FloatValue(math.NaN()) + // TODO + return b[3:], nil + case '+', '-': + return p.parseIntOrFloat(b) + } + + if len(b) < 3 { + return p.parseIntOrFloat(b) + } + s := 5 + if len(b) < s { + s = len(b) + } + for idx, c := range b[:s] { + if c >= '0' && c <= '9' { + continue + } + if idx == 2 && c == ':' { + return p.parseDateTime(b) + } + if idx == 4 && c == '-' { + return p.parseDateTime(b) + } 
+ } + return p.parseIntOrFloat(b) +} + +func digitsToInt(b []byte) int { + x := 0 + for _, d := range b { + x *= 10 + x += int(d - '0') + } + return x +} + +func (p *parser) parseDateTime(b []byte) ([]byte, error) { + // we know the first 2 ar digits. + if b[2] == ':' { + return p.parseTime(b) + } + // This state accepts an offset date-time, a local date-time, or a local date. + // + // v--- cursor + // 1979-05-27T07:32:00Z + // 1979-05-27T00:32:00-07:00 + // 1979-05-27T00:32:00.999999-07:00 + // 1979-05-27 07:32:00Z + // 1979-05-27 00:32:00-07:00 + // 1979-05-27 00:32:00.999999-07:00 + // 1979-05-27T07:32:00 + // 1979-05-27T00:32:00.999999 + // 1979-05-27 07:32:00 + // 1979-05-27 00:32:00.999999 + // 1979-05-27 + + // date + + idx := 4 + + localDate := toml.LocalDate{ + Year: digitsToInt(b[:idx]), + } + + for i := 0; i < 2; i++ { + // month + idx++ + if !isDigit(b[idx]) { + return nil, fmt.Errorf("invalid month digit in date: %c", b[idx]) + } + localDate.Month *= 10 + localDate.Month += time.Month(b[idx] - '0') + } + + idx++ + if b[idx] != '-' { + return nil, fmt.Errorf("expected - to separate month of a date, not %c", b[idx]) + } + + for i := 0; i < 2; i++ { + // day + idx++ + if !isDigit(b[idx]) { + return nil, fmt.Errorf("invalid day digit in date: %c", b[idx]) + } + localDate.Day *= 10 + localDate.Day += int(b[idx] - '0') + } + + idx++ + + if idx >= len(b) { + //p.builder.LocalDateValue(localDate) + // TODO + return nil, nil + } else if b[idx] != ' ' && b[idx] != 'T' { + //p.builder.LocalDateValue(localDate) + // TODO + return b[idx:], nil + } + + // check if there is a chance there is anything useful after + if b[idx] == ' ' && (((idx + 2) >= len(b)) || !isDigit(b[idx+1]) || !isDigit(b[idx+2])) { + //p.builder.LocalDateValue(localDate) + // TODO + return b[idx:], nil + } + + //idx++ // skip the T or ' ' + + // time + localTime := toml.LocalTime{} + + for i := 0; i < 2; i++ { + idx++ + if !isDigit(b[idx]) { + return nil, fmt.Errorf("invalid hour digit in time: 
%c", b[idx]) + } + localTime.Hour *= 10 + localTime.Hour += int(b[idx] - '0') + } + + idx++ + if b[idx] != ':' { + return nil, fmt.Errorf("time hour/minute separator should be :, not %c", b[idx]) + } + + for i := 0; i < 2; i++ { + idx++ + if !isDigit(b[idx]) { + return nil, fmt.Errorf("invalid minute digit in time: %c", b[idx]) + } + localTime.Minute *= 10 + localTime.Minute += int(b[idx] - '0') + } + + idx++ + if b[idx] != ':' { + return nil, fmt.Errorf("time minute/second separator should be :, not %c", b[idx]) + } + + for i := 0; i < 2; i++ { + idx++ + if !isDigit(b[idx]) { + return nil, fmt.Errorf("invalid second digit in time: %c", b[idx]) + } + localTime.Second *= 10 + localTime.Second += int(b[idx] - '0') + } + + idx++ + if idx < len(b) && b[idx] == '.' { + idx++ // skip the dot; idx is now on the first fraction digit + if idx >= len(b) { + return nil, fmt.Errorf("expected at least one digit in time's fraction, not eof") + } + if !isDigit(b[idx]) { + return nil, fmt.Errorf("expected at least one digit in time's fraction, not %c", b[idx]) + } + + // Keep at most nine digits (nanosecond precision); further digits are + // consumed but ignored. + digits := 0 + for idx < len(b) && isDigit(b[idx]) { + if digits < 9 { + localTime.Nanosecond *= 10 + localTime.Nanosecond += int(b[idx] - '0') + digits++ + } + idx++ + } + // Scale what was read to nanoseconds: ".5" is 500000000ns, not 5ns. + for ; digits < 9; digits++ { + localTime.Nanosecond *= 10 + } + } + + if idx >= len(b) || (b[idx] != 'Z' && b[idx] != '+' && b[idx] != '-') { + dt := toml.LocalDateTime{ + Date: localDate, + Time: localTime, + } + //p.builder.LocalDateTimeValue(dt) + // TODO + dt = dt + return b[idx:], nil + } + + loc := time.UTC + + if b[idx] == 'Z' { + idx++ + } else { + start := idx + sign := 1 + if b[idx] == '-' { + sign = -1 + } + + hours := 0 + for i := 0; i < 2; i++ { + idx++ + if !isDigit(b[idx]) { + return nil, fmt.Errorf("invalid hour digit in time offset: %c", b[idx]) + } + hours *= 10 + hours += int(b[idx] - '0') + } + offset := hours * 60 * 60 + + idx++ + if b[idx] != ':' { + return nil, fmt.Errorf("time offset hour/minute separator should be :, not %c", b[idx]) + } + + minutes := 0 + for i := 0; i < 2; i++ { + idx++ + if !isDigit(b[idx]) { + return nil, fmt.Errorf("invalid minute digit in time offset: %c", b[idx]) + } + minutes *= 10 + minutes += int(b[idx] - '0') + } + 
offset += minutes * 60 + offset *= sign + idx++ + loc = time.FixedZone(string(b[start:idx]), offset) + } + dt := time.Date(localDate.Year, localDate.Month, localDate.Day, localTime.Hour, localTime.Minute, localTime.Second, localTime.Nanosecond, loc) + //p.builder.DateTimeValue(dt) + // TODO + dt = dt + return b[idx:], nil +} + +// parseTime parses a local time (HH:MM:SS with an optional fraction) at the +// start of b and returns the remaining input. +func (p *parser) parseTime(b []byte) ([]byte, error) { + // Every access below up to the seconds reads b[0..7]. + if len(b) < 8 { + return nil, fmt.Errorf("times are expected to have the format HH:MM:SS[.NNNNNN]") + } + + localTime := toml.LocalTime{} + + // idx is incremented before each read, so start one before the first + // byte; starting at 0 would read the ':' as the second hour digit. + idx := -1 + + for i := 0; i < 2; i++ { + idx++ + if !isDigit(b[idx]) { + return nil, fmt.Errorf("invalid hour digit in time: %c", b[idx]) + } + localTime.Hour *= 10 + localTime.Hour += int(b[idx] - '0') + } + + idx++ + if b[idx] != ':' { + return nil, fmt.Errorf("time hour/minute separator should be :, not %c", b[idx]) + } + + for i := 0; i < 2; i++ { + idx++ + if !isDigit(b[idx]) { + return nil, fmt.Errorf("invalid minute digit in time: %c", b[idx]) + } + localTime.Minute *= 10 + localTime.Minute += int(b[idx] - '0') + } + + idx++ + if b[idx] != ':' { + return nil, fmt.Errorf("time minute/second separator should be :, not %c", b[idx]) + } + + for i := 0; i < 2; i++ { + idx++ + if !isDigit(b[idx]) { + return nil, fmt.Errorf("invalid second digit in time: %c", b[idx]) + } + localTime.Second *= 10 + localTime.Second += int(b[idx] - '0') + } + + idx++ + if idx < len(b) && b[idx] == '.' 
{ + idx++ // skip the dot; idx is now on the first fraction digit + if idx >= len(b) { + return nil, fmt.Errorf("expected at least one digit in time's fraction, not eof") + } + if !isDigit(b[idx]) { + return nil, fmt.Errorf("expected at least one digit in time's fraction, not %c", b[idx]) + } + + // Keep at most nine digits (nanosecond precision); further digits are + // consumed but ignored. The loop is bounded by len(b) because the time + // may be the last thing in the input. + digits := 0 + for idx < len(b) && isDigit(b[idx]) { + if digits < 9 { + localTime.Nanosecond *= 10 + localTime.Nanosecond += int(b[idx] - '0') + digits++ + } + idx++ + } + // Scale what was read to nanoseconds: ".5" is 500000000ns, not 5ns. + for ; digits < 9; digits++ { + localTime.Nanosecond *= 10 + } + } + + //p.builder.LocalTimeValue(localTime) + // TODO + return b[idx:], nil +} + +// parseIntOrFloat parses an integer (decimal, 0x, 0o, 0b) or a float at the +// start of b, which must not be empty, and returns the remaining input. +func (p *parser) parseIntOrFloat(b []byte) ([]byte, error) { + i := 0 + r := b[0] + if r == '0' { + if len(b) >= 2 { + var isValidRune validRuneFn + var parseFn func([]byte) (int64, error) + switch b[1] { + case 'x': + isValidRune = isValidHexRune + parseFn = parseIntHex + case 'o': + isValidRune = isValidOctalRune + parseFn = parseIntOct + case 'b': + isValidRune = isValidBinaryRune + parseFn = parseIntBin + default: + // 'e' and 'E' start an exponent ("0e5" is a float); any other + // letter is an unsupported base prefix. + if b[1] != 'e' && b[1] != 'E' && (b[1] >= 'a' && b[1] <= 'z' || b[1] >= 'A' && b[1] <= 'Z') { + return nil, fmt.Errorf("unknown number base: %s. possible options are x (hex) o (octal) b (binary)", string(b[1])) + } + parseFn = parseIntDec + } + + if isValidRune != nil { + i = 2 + digitSeen := false + // Bounded by len(b): the number may be the last thing in the input. + for i < len(b) && isValidRune(b[i]) { + digitSeen = true + i++ + } + + if !digitSeen { + return nil, fmt.Errorf("number needs at least one digit") + } + + v, err := parseFn(b[:i]) + if err != nil { + return nil, err + } + //p.builder.IntValue(v) + // TODO + v = v + return b[i:], nil + } + } + } + + if r == '+' || r == '-' { + // Look past the sign without cutting it off b: it has to stay in the + // token handed to parseFloat/parseIntDec, or "-5" would parse as 5. + if scanFollowsInf(b[1:]) { + if r == '+' { + //p.builder.FloatValue(plusInf) + // TODO + } else { + //p.builder.FloatValue(minusInf) + // TODO + } + // Consume the sign and the three letters. + return b[4:], nil + } + if scanFollowsNan(b[1:]) { + //p.builder.FloatValue(nan) + // TODO + return b[4:], nil + } + i = 1 + } + + pointSeen := false + expSeen := false + digitSeen := false + for i < len(b) { + next := b[i] + if next == '.' 
{ + if pointSeen { + return nil, fmt.Errorf("cannot have two dots in one float") + } + i++ + if i < len(b) && !isDigit(b[i]) { + return nil, fmt.Errorf("float cannot end with a dot") + } + pointSeen = true + } else if next == 'e' || next == 'E' { + expSeen = true + i++ + if i >= len(b) { + break + } + if b[i] == '+' || b[i] == '-' { + i++ + } + } else if isDigit(next) { + digitSeen = true + i++ + } else if next == '_' { + i++ + } else { + break + } + if pointSeen && !digitSeen { + return nil, fmt.Errorf("cannot start float with a dot") + } + } + + if !digitSeen { + return nil, fmt.Errorf("no digit in that number") + } + if pointSeen || expSeen { + f, err := parseFloat(b[:i]) + if err != nil { + return nil, err + } + //p.builder.FloatValue(f) + // TODO + f = f + } else { + v, err := parseIntDec(b[:i]) + if err != nil { + return nil, err + } + //p.builder.IntValue(v) + // TODO + v = v + } + return b[i:], nil +} + +func parseFloat(b []byte) (float64, error) { + // TODO: inefficient + tok := string(b) + err := numberContainsInvalidUnderscore(tok) + if err != nil { + return 0, err + } + cleanedVal := cleanupNumberToken(tok) + return strconv.ParseFloat(cleanedVal, 64) +} + +func parseIntHex(b []byte) (int64, error) { + tok := string(b) + cleanedVal := cleanupNumberToken(tok) + err := hexNumberContainsInvalidUnderscore(cleanedVal) + if err != nil { + return 0, nil + } + return strconv.ParseInt(cleanedVal[2:], 16, 64) +} + +func parseIntOct(b []byte) (int64, error) { + tok := string(b) + cleanedVal := cleanupNumberToken(tok) + err := numberContainsInvalidUnderscore(cleanedVal) + if err != nil { + return 0, err + } + return strconv.ParseInt(cleanedVal[2:], 8, 64) +} + +func parseIntBin(b []byte) (int64, error) { + tok := string(b) + cleanedVal := cleanupNumberToken(tok) + err := numberContainsInvalidUnderscore(cleanedVal) + if err != nil { + return 0, err + } + return strconv.ParseInt(cleanedVal[2:], 2, 64) +} + +func parseIntDec(b []byte) (int64, error) { + tok := 
string(b) + cleanedVal := cleanupNumberToken(tok) + err := numberContainsInvalidUnderscore(cleanedVal) + if err != nil { + return 0, err + } + return strconv.ParseInt(cleanedVal, 10, 64) +} + +func numberContainsInvalidUnderscore(value string) error { + // For large numbers, you may use underscores between digits to enhance + // readability. Each underscore must be surrounded by at least one digit on + // each side. + + hasBefore := false + for idx, r := range value { + if r == '_' { + if !hasBefore || idx+1 >= len(value) { + // can't end with an underscore + return errInvalidUnderscore + } + } + hasBefore = isDigitRune(r) + } + return nil +} + +func hexNumberContainsInvalidUnderscore(value string) error { + hasBefore := false + for idx, r := range value { + if r == '_' { + if !hasBefore || idx+1 >= len(value) { + // can't end with an underscore + return errInvalidUnderscoreHex + } + } + hasBefore = isHexDigit(r) + } + return nil +} + +func cleanupNumberToken(value string) string { + cleanedVal := strings.Replace(value, "_", "", -1) + return cleanedVal +} + +func isDigit(r byte) bool { + return r >= '0' && r <= '9' +} + +func isDigitRune(r rune) bool { + return r >= '0' && r <= '9' +} + +var plusInf = math.Inf(1) +var minusInf = math.Inf(-1) +var nan = math.NaN() + +type validRuneFn func(r byte) bool + +func isValidHexRune(r byte) bool { + return r >= 'a' && r <= 'f' || + r >= 'A' && r <= 'F' || + r >= '0' && r <= '9' || + r == '_' +} + +func isHexDigit(r rune) bool { + return isDigitRune(r) || + (r >= 'a' && r <= 'f') || + (r >= 'A' && r <= 'F') +} + +func isValidOctalRune(r byte) bool { + return r >= '0' && r <= '7' || r == '_' +} + +func isValidBinaryRune(r byte) bool { + return r == '0' || r == '1' || r == '_' +} + +func expect(x byte, b []byte) ([]byte, error) { + if len(b) == 0 || b[0] != x { + return nil, unexpectedCharacter{r: x, b: b} + } + return b[1:], nil +} + +type unexpectedCharacter struct { + r byte + b []byte +} + +func (u unexpectedCharacter) 
Error() string { + if len(u.b) == 0 { + return fmt.Sprintf("expected %#U, not EOF", u.r) + + } + return fmt.Sprintf("expected %#U, not %#U", u.r, u.b[0]) +} + +var errInvalidUnderscore = errors.New("invalid use of _ in number") +var errInvalidUnderscoreHex = errors.New("invalid use of _ in hex number") diff --git a/internal/unmarshaler/parser_test.go b/internal/unmarshaler/parser_test.go new file mode 100644 index 0000000..f4be26b --- /dev/null +++ b/internal/unmarshaler/parser_test.go @@ -0,0 +1,50 @@ +package unmarshaler + +import ( + "testing" + + "github.com/pelletier/go-toml/v2/internal/ast" + "github.com/stretchr/testify/require" +) + +func TestParser_Simple(t *testing.T) { + examples := []struct { + desc string + input string + ast ast.Root + err bool + }{ + { + desc: "simple string assignment", + input: `A = "hello"`, + ast: ast.Root{ + ast.Node{ + Kind: ast.KeyValue, + Children: []ast.Node{ + { + Kind: ast.Key, + Data: []byte(`A`), + }, + { + Kind: ast.String, + Data: []byte(`hello`), + }, + }, + }, + }, + }, + } + + for _, e := range examples { + t.Run(e.desc, func(t *testing.T) { + p := parser{} + err := p.parse([]byte(e.input)) + if e.err { + require.Error(t, err) + } else { + require.NoError(t, err) + require.Equal(t, e.ast, p.tree) + } + }) + } +} diff --git a/internal/unmarshaler/scanner.go b/internal/unmarshaler/scanner.go new file mode 100644 index 0000000..a13e6b6 --- /dev/null +++ b/internal/unmarshaler/scanner.go @@ -0,0 +1,168 @@ +package unmarshaler + +import "fmt" + +func scanFollows(pattern []byte) func(b []byte) bool { + return func(b []byte) bool { + if len(b) < len(pattern) { + return false + } + for i, c := range pattern { + if b[i] != c { + return false + } + } + return true + } +} + +var scanFollowsMultilineBasicStringDelimiter = scanFollows([]byte{'"', '"', '"'}) +var scanFollowsMultilineLiteralStringDelimiter = scanFollows([]byte{'\'', '\'', '\''}) +var scanFollowsTrue = scanFollows([]byte{'t', 'r', 'u', 'e'}) +var scanFollowsFalse = 
scanFollows([]byte{'f', 'a', 'l', 's', 'e'}) +var scanFollowsInf = scanFollows([]byte{'i', 'n', 'f'}) +var scanFollowsNan = scanFollows([]byte{'n', 'a', 'n'}) + +func scanUnquotedKey(b []byte) ([]byte, []byte, error) { + //unquoted-key = 1*( ALPHA / DIGIT / %x2D / %x5F ) ; A-Z / a-z / 0-9 / - / _ + for i := 0; i < len(b); i++ { + if !isUnquotedKeyChar(b[i]) { + return b[:i], b[i:], nil + } + } + return b, nil, nil +} + +func isUnquotedKeyChar(r byte) bool { + return (r >= 'A' && r <= 'Z') || (r >= 'a' && r <= 'z') || (r >= '0' && r <= '9') || r == '-' || r == '_' +} + +func scanLiteralString(b []byte) ([]byte, []byte, error) { + //literal-string = apostrophe *literal-char apostrophe + //apostrophe = %x27 ; ' apostrophe + //literal-char = %x09 / %x20-26 / %x28-7E / non-ascii + for i := 1; i < len(b); i++ { + switch b[i] { + case '\'': + return b[:i+1], b[i+1:], nil + case '\n': + return nil, nil, fmt.Errorf("literal strings cannot have new lines") + } + } + return nil, nil, fmt.Errorf("unterminated literal string") +} + +func scanMultilineLiteralString(b []byte) ([]byte, []byte, error) { + //ml-literal-string = ml-literal-string-delim [ newline ] ml-literal-body + //ml-literal-string-delim + //ml-literal-string-delim = 3apostrophe + //ml-literal-body = *mll-content *( mll-quotes 1*mll-content ) [ mll-quotes ] + // + //mll-content = mll-char / newline + //mll-char = %x09 / %x20-26 / %x28-7E / non-ascii + //mll-quotes = 1*2apostrophe + for i := 3; i < len(b); i++ { + switch b[i] { + case '\'': + if scanFollowsMultilineLiteralStringDelimiter(b[i:]) { + return b[:i+3], b[:i+3], nil + } + } + } + + return nil, nil, fmt.Errorf(`multiline literal string not terminated by '''`) +} + +func scanWindowsNewline(b []byte) ([]byte, []byte, error) { + if len(b) < 2 { + return nil, nil, fmt.Errorf(`windows new line missing \n`) + } + if b[1] != '\n' { + return nil, nil, fmt.Errorf(`windows new line should be \r\n`) + } + return b[:2], b[2:], nil +} + +func scanWhitespace(b []byte) 
([]byte, []byte) { + for i := 0; i < len(b); i++ { + switch b[i] { + case ' ', '\t': + continue + default: + return b[:i], b[i:] + } + } + return b, nil +} + +func scanComment(b []byte) ([]byte, []byte, error) { + //;; Comment + // + //comment-start-symbol = %x23 ; # + //non-ascii = %x80-D7FF / %xE000-10FFFF + //non-eol = %x09 / %x20-7F / non-ascii + // + //comment = comment-start-symbol *non-eol + + for i := 1; i < len(b); i++ { + switch b[i] { + case '\n': + return b[:i], b[i:], nil + } + } + return b, nil, nil +} + +// TODO perform validation on the string? +func scanBasicString(b []byte) ([]byte, []byte, error) { + //basic-string = quotation-mark *basic-char quotation-mark + //quotation-mark = %x22 ; " + //basic-char = basic-unescaped / escaped + //basic-unescaped = wschar / %x21 / %x23-5B / %x5D-7E / non-ascii + //escaped = escape escape-seq-char + for i := 1; i < len(b); i++ { + switch b[i] { + case '"': + return b[:i+1], b[i+1:], nil + case '\n': + return nil, nil, fmt.Errorf("basic strings cannot have new lines") + case '\\': + if len(b) < i+2 { + return nil, nil, fmt.Errorf("need a character after \\") + } + i++ // skip the next character + } + } + + return nil, nil, fmt.Errorf(`basic string not terminated by "`) +} + +// TODO perform validation on the string? 
+func scanMultilineBasicString(b []byte) ([]byte, []byte, error) { + //ml-basic-string = ml-basic-string-delim [ newline ] ml-basic-body + //ml-basic-string-delim + //ml-basic-string-delim = 3quotation-mark + //ml-basic-body = *mlb-content *( mlb-quotes 1*mlb-content ) [ mlb-quotes ] + // + //mlb-content = mlb-char / newline / mlb-escaped-nl + //mlb-char = mlb-unescaped / escaped + //mlb-quotes = 1*2quotation-mark + //mlb-unescaped = wschar / %x21 / %x23-5B / %x5D-7E / non-ascii + //mlb-escaped-nl = escape ws newline *( wschar / newline ) + + for i := 3; i < len(b); i++ { + switch b[i] { + case '"': + if scanFollowsMultilineBasicStringDelimiter(b[i:]) { + return b[:i+3], b[i+3:], nil + } + case '\\': + if len(b) < i+2 { + return nil, nil, fmt.Errorf("need a character after \\") + } + i++ // skip the next character + } + } + + return nil, nil, fmt.Errorf(`multiline basic string not terminated by """`) +} diff --git a/internal/unmarshaler/targets.go b/internal/unmarshaler/targets.go new file mode 100644 index 0000000..1c9fd25 --- /dev/null +++ b/internal/unmarshaler/targets.go @@ -0,0 +1,94 @@ +package unmarshaler + +import ( + "fmt" + "reflect" +) + +type target interface { + // Ensure the target's reflect value is not nil. + ensure() + + // Store a string at the target. + setString(v string) error + + // Appends an arbitrary value to the container. + pushValue(v reflect.Value) error + + // Dereferences the target. + get() reflect.Value +} + +// struct target just contain the reflect.Value of the target field. 
+type structTarget reflect.Value + +func (t structTarget) get() reflect.Value { + return reflect.Value(t) +} + +func (t structTarget) ensure() { + f := t.get() + if !f.IsNil() { + return + } + + switch f.Kind() { + case reflect.Slice: + f.Set(reflect.MakeSlice(f.Type(), 0, 0)) + default: + panic(fmt.Errorf("don't know how to ensure %s", f.Kind())) + } +} + +func (t structTarget) setString(v string) error { + f := t.get() + if f.Kind() != reflect.String { + return fmt.Errorf("cannot assign string to a %s", f.String()) + } + f.SetString(v) + return nil +} + +func (t structTarget) pushValue(v reflect.Value) error { + f := t.get() + + switch f.Kind() { + case reflect.Slice: + t.ensure() + f.Set(reflect.Append(f, v)) + default: + return fmt.Errorf("cannot push %s on a %s", v.Kind(), f.Kind()) + } + + return nil +} + +func scope(v reflect.Value, name string) (target, error) { + switch v.Kind() { + case reflect.Struct: + return scopeStruct(v, name) + default: + panic(fmt.Errorf("can't scope on a %s", v.Kind())) + } +} + +func scopeStruct(v reflect.Value, name string) (target, error) { + // TODO: cache this + t := v.Type() + for i := 0; i < t.NumField(); i++ { + f := t.Field(i) + if f.PkgPath != "" { + // only consider exported fields + continue + } + if f.Anonymous { + // TODO: handle embedded structs + } else { + // TODO: handle names variations + if f.Name == name { + return structTarget(v.Field(i)), nil + } + } + } + return nil, fmt.Errorf("field '%s' not found on %s", name, v.Type()) +} diff --git a/internal/unmarshaler/targets_test.go b/internal/unmarshaler/targets_test.go new file mode 100644 index 0000000..7e14d25 --- /dev/null +++ b/internal/unmarshaler/targets_test.go @@ -0,0 +1,166 @@ +package unmarshaler + +import ( + "reflect" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestStructTarget_Ensure(t *testing.T) { + examples := []struct { + desc string + input reflect.Value + name string + test func(v 
reflect.Value) + }{ + { + desc: "handle a nil slice of string", + input: reflect.ValueOf(&struct{ A []string }{}).Elem(), + name: "A", + test: func(v reflect.Value) { + assert.False(t, v.IsNil()) + }, + }, + { + desc: "handle an existing slice of string", + input: reflect.ValueOf(&struct{ A []string }{A: []string{"foo"}}).Elem(), + name: "A", + test: func(v reflect.Value) { + require.False(t, v.IsNil()) + s := v.Interface().([]string) + assert.Equal(t, []string{"foo"}, s) + }, + }, + } + + for _, e := range examples { + t.Run(e.desc, func(t *testing.T) { + target, err := scope(e.input, e.name) + require.NoError(t, err) + target.ensure() + v := target.get() + e.test(v) + }) + } +} + +func TestStructTarget_SetString(t *testing.T) { + str := "value" + + examples := []struct { + desc string + input reflect.Value + name string + test func(v reflect.Value, err error) + }{ + { + desc: "sets a string", + input: reflect.ValueOf(&struct{ A string }{}).Elem(), + name: "A", + test: func(v reflect.Value, err error) { + assert.NoError(t, err) + assert.Equal(t, str, v.String()) + }, + }, + { + desc: "fails on a float", + input: reflect.ValueOf(&struct{ A float64 }{}).Elem(), + name: "A", + test: func(v reflect.Value, err error) { + assert.Error(t, err) + }, + }, + { + desc: "fails on a slice", + input: reflect.ValueOf(&struct{ A []string }{}).Elem(), + name: "A", + test: func(v reflect.Value, err error) { + assert.Error(t, err) + }, + }, + } + + for _, e := range examples { + t.Run(e.desc, func(t *testing.T) { + target, err := scope(e.input, e.name) + require.NoError(t, err) + err = target.setString(str) + v := target.get() + e.test(v, err) + }) + } +} + +func TestPushValue_Struct(t *testing.T) { + examples := []struct { + desc string + input reflect.Value + expected []string + error bool + }{ + { + desc: "push to nil slice", + input: reflect.ValueOf(&struct{ A []string }{}).Elem(), + expected: []string{"hello"}, + }, + { + desc: "push to string", + input: 
reflect.ValueOf(&struct{ A string }{}).Elem(), + error: true, + }, + } + + for _, e := range examples { + t.Run(e.desc, func(t *testing.T) { + target, err := scope(e.input, "A") + require.NoError(t, err) + v := reflect.ValueOf("hello") + err = target.pushValue(v) + if e.error { + require.Error(t, err) + } else { + require.NoError(t, err) + x := target.get().Interface().([]string) + assert.Equal(t, e.expected, x) + } + }) + } +} + +func TestScope_Struct(t *testing.T) { + examples := []struct { + desc string + input reflect.Value + name string + err bool + idx []int + }{ + { + desc: "simple field", + input: reflect.ValueOf(&struct{ A string }{}).Elem(), + name: "A", + idx: []int{0}, + }, + { + desc: "fails not-exported field", + input: reflect.ValueOf(&struct{ a string }{}).Elem(), + name: "a", + err: true, + }, + } + + for _, e := range examples { + t.Run(e.desc, func(t *testing.T) { + x, err := scope(e.input, e.name) + if e.err { + require.Error(t, err) + } else { + x2, ok := x.(structTarget) + require.True(t, ok) + x2.get() + } + }) + } +} diff --git a/internal/unmarshaler/unmarshaler.go b/internal/unmarshaler/unmarshaler.go new file mode 100644 index 0000000..82227b3 --- /dev/null +++ b/internal/unmarshaler/unmarshaler.go @@ -0,0 +1,69 @@ +package unmarshaler + +import ( + "fmt" + "reflect" + + "github.com/pelletier/go-toml/v2/internal/ast" +) + +func FromAst(tree ast.Root, target interface{}) error { + x := reflect.ValueOf(target) + if x.Kind() != reflect.Ptr { + return fmt.Errorf("need to target a pointer, not %s", x.Kind()) + } + if x.IsNil() { + return fmt.Errorf("target pointer must be non-nil") + } + + for _, node := range tree { + err := topLevelNode(x, &node) + if err != nil { + return err + } + } + + return nil +} + +func topLevelNode(x reflect.Value, node *ast.Node) error { + if x.Kind() != reflect.Ptr { + panic("topLevelNode should receive target, which should be a pointer") + } + if x.IsNil() { + panic("topLevelNode should receive target, which should 
not be a nil pointer") + } + + switch node.Kind { + case ast.Table: + panic("TODO") + case ast.ArrayTable: + panic("TODO") + case ast.KeyValue: + return keyValue(x, node) + default: + panic(fmt.Errorf("this should not be a top level node type: %s", node.Kind)) + } +} + +func keyValue(x reflect.Value, node *ast.Node) error { + assertNode(ast.KeyValue, node) + assertPtr(x) + + key := node.Key() + key = key + // TODO + return nil +} + +func assertNode(expected ast.Kind, node *ast.Node) { + if node.Kind != expected { + panic(fmt.Errorf("expected node of kind %s, not %s", expected, node.Kind)) + } +} + +func assertPtr(x reflect.Value) { + if x.Kind() != reflect.Ptr { + panic(fmt.Errorf("should be a pointer, not a %s", x.Kind())) + } +} diff --git a/internal/unmarshaler/unmarshaler_test.go b/internal/unmarshaler/unmarshaler_test.go new file mode 100644 index 0000000..da71743 --- /dev/null +++ b/internal/unmarshaler/unmarshaler_test.go @@ -0,0 +1,39 @@ +package unmarshaler_test + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/pelletier/go-toml/v2/internal/ast" + "github.com/pelletier/go-toml/v2/internal/unmarshaler" +) + +func TestFromAst_KV(t *testing.T) { + t.Skipf("later") + root := ast.Root{ + ast.Node{ + Kind: ast.KeyValue, + Children: []ast.Node{ + { + Kind: ast.Key, + Data: []byte(`Foo`), + }, + { + Kind: ast.String, + Data: []byte(`hello`), + }, + }, + }, + } + + type Doc struct { + Foo string + } + + x := Doc{} + err := unmarshaler.FromAst(root, &x) + require.NoError(t, err) + assert.Equal(t, Doc{Foo: "hello"}, x) +} diff --git a/unmarshal.go b/unmarshal.go index 84ee28d..80efa52 100644 --- a/unmarshal.go +++ b/unmarshal.go @@ -1,6 +1,7 @@ package toml import ( + "fmt" "reflect" "time" @@ -62,6 +63,7 @@ func (u *unmarshaler) Assignation() { return } u.assign = true + fmt.Println("ASSIGN: TRUE!") } func (u *unmarshaler) ArrayBegin() { @@ -73,11 +75,12 @@ func (u *unmarshaler) ArrayBegin() { if 
u.err != nil { return } - if u.assign { - u.assign = false - } else { - u.err = u.builder.SliceNewElem() + fmt.Println("ARRAY BEGIN ASSIGN =", u.assign) + if !u.assign { + //u.err = u.builder.SliceNewSlice() + // TODO } + u.assign = false } func (u *unmarshaler) ArrayEnd() { @@ -126,8 +129,15 @@ func (u *unmarshaler) InlineTableBegin() { return } - // TODO + u.builder.Save() + if u.builder.IsSliceOrPtr() { + u.err = u.builder.SliceNewElem() + } else { + u.err = u.builder.EnsureStructOrMap() + } + + u.assign = false } func (u *unmarshaler) InlineTableEnd() { @@ -135,7 +145,7 @@ func (u *unmarshaler) InlineTableEnd() { return } - // TODO + u.builder.Load() } func (u *unmarshaler) KeyValBegin() { @@ -176,6 +186,7 @@ func (u *unmarshaler) StringValue(v []byte) { s := string(v) u.err = u.builder.Set(reflect.ValueOf(&s)) } + u.assign = false } func (u *unmarshaler) BoolValue(b bool) { @@ -192,6 +203,7 @@ func (u *unmarshaler) BoolValue(b bool) { } else { u.err = u.builder.SetBool(b) } + u.assign = false } func (u *unmarshaler) FloatValue(n float64) { @@ -209,6 +221,7 @@ func (u *unmarshaler) FloatValue(n float64) { u.err = u.builder.Set(reflect.ValueOf(&n)) //u.err = u.builder.SetFloat(n) } + u.assign = false } func (u *unmarshaler) IntValue(n int64) { @@ -225,6 +238,7 @@ func (u *unmarshaler) IntValue(n int64) { } else { u.err = u.builder.Set(reflect.ValueOf(&n)) } + u.assign = false } func (u *unmarshaler) LocalDateValue(date LocalDate) { @@ -241,6 +255,7 @@ func (u *unmarshaler) LocalDateValue(date LocalDate) { } else { u.err = u.builder.Set(reflect.ValueOf(&date)) } + u.assign = false } func (u *unmarshaler) LocalDateTimeValue(dt LocalDateTime) { @@ -257,6 +272,7 @@ func (u *unmarshaler) LocalDateTimeValue(dt LocalDateTime) { } else { u.err = u.builder.Set(reflect.ValueOf(&dt)) } + u.assign = false } func (u *unmarshaler) DateTimeValue(dt time.Time) { @@ -273,6 +289,7 @@ func (u *unmarshaler) DateTimeValue(dt time.Time) { } else { u.err = 
u.builder.Set(reflect.ValueOf(&dt)) } + u.assign = false } func (u *unmarshaler) LocalTimeValue(localTime LocalTime) { @@ -289,6 +306,7 @@ func (u *unmarshaler) LocalTimeValue(localTime LocalTime) { } else { u.err = u.builder.Set(reflect.ValueOf(&localTime)) } + u.assign = false } func (u *unmarshaler) SimpleKey(v []byte) { @@ -337,5 +355,5 @@ func (u *unmarshaler) StandardTableEnd() { return } - u.builder.EnsureStructOrMap() + u.builder.EnsureStructOrMap() // TODO: handle error }