diff --git a/parser.go b/parser.go new file mode 100644 index 0000000..f9fa173 --- /dev/null +++ b/parser.go @@ -0,0 +1 @@ +package toml diff --git a/scanner.go b/scanner.go index 99d9f0a..4786b26 100644 --- a/scanner.go +++ b/scanner.go @@ -20,8 +20,6 @@ var scanFollowsMultilineBasicStringDelimiter = scanFollows([]byte{'"', '"', '"'} var scanFollowsMultilineLiteralStringDelimiter = scanFollows([]byte{'\'', '\'', '\''}) var scanFollowsTrue = scanFollows([]byte{'t', 'r', 'u', 'e'}) var scanFollowsFalse = scanFollows([]byte{'f', 'a', 'l', 's', 'e'}) -var scanFollowsArrayTableBegin = scanFollows([]byte{arrayOrTableBegin, arrayOrTableBegin}) -var scanFollowsArrayTableEnd = scanFollows([]byte{arrayOrTableEnd, arrayOrTableEnd}) func scanNewline(b []byte) ([]byte, []byte, error) { if len(b) == 0 { @@ -42,75 +40,6 @@ func scanNewline(b []byte) ([]byte, []byte, error) { return nil, nil, unexpectedCharacter{b: b} } -const ( - dot = '.' - equal = '=' - comma = ',' - inlineTableBegin = '{' - inlineTableEnd = '}' - comment = '#' - arrayOrTableBegin = '[' - arrayOrTableEnd = ']' -) - -// scan returns a []byte containing the next lexical token, bytes left, and an error. -// -// eof is signaled by an empty token and nil error. -func scan(b []byte) ([]byte, []byte, error) { - if len(b) == 0 { - return b, b, nil - } - - switch b[0] { - case dot, equal, inlineTableBegin, inlineTableEnd, comma: - return b[:1], b[1:], nil - case '"': - if scanFollowsMultilineBasicStringDelimiter(b) { - return scanMultilineBasicString(b) - } - return scanBasicString(b) - case '\'': - if scanFollowsMultilineLiteralStringDelimiter(b) { - return scanMultilineLiteralString(b) - } - return scanLiteralString(b) - case comment: - return scanComment(b) - case ' ', '\t': - data, rest := scanWhitespace(b) - return data, rest, nil - case '\r': - return scanWindowsNewline(b) - case '\n': - return b[:1], b[1:], nil - case 't': - if scanFollowsTrue(b) { - return b[:4], b[4:], nil - } - case 'f': - if scanFollowsFalse(b) { - return b[:5], b[5:], nil - } - case arrayOrTableBegin: - if scanFollowsArrayTableBegin(b) { - return b[:2], b[2:], nil - } - return b[:1], b[1:], nil - case arrayOrTableEnd: - if scanFollowsArrayTableEnd(b) { - return b[:2], b[2:], nil - } - return b[:1], b[1:], nil - } - - if isUnquotedKeyChar(b[0]) { - return scanUnquotedKey(b) - } - - // TODO: numbers, date-time - panic("unhandled scan") -} - func scanUnquotedKey(b []byte) ([]byte, []byte, error) { //unquoted-key = 1*( ALPHA / DIGIT / %x2D / %x5F ) ; A-Z / a-z / 0-9 / - / _ for i := 0; i < len(b); i++ { diff --git a/toml_test.go b/toml_test.go deleted file mode 100644 index 846e380..0000000 --- a/toml_test.go +++ /dev/null @@ -1,105 +0,0 @@ -package toml - -import ( - "fmt" - "testing" - - "github.com/stretchr/testify/require" -) - -var inputs = []string{ - ` #foo`, - `#foo`, - `#`, - "\n\n\n", - "#one\n # two \n", - `a = false`, - `abc = false`, - ` abc = false # foo`, - `'abc' = false`, - `"foo bar" = false`, - `"hello\tworld" = false`, - `"hello \u1234 foo" = false`, - `a.b.c = false`, - `a."b".c = true`, - `a = "foo"`, - `b = 'sample thingy'`, - `a = []`, - `b = ["foo"]`, - `c = [[[]]]`, - `d = ["foo","bar"]`, - `d = ["foo", "test"]`, - `d = {}`, - `e = {f = "bar"}`, - `[foo]`, - `[ test ]`, - `[ "hello".world ]`, - `[test] - a = false`, - `[[foo]]`, -} - -func TestScan(t *testing.T) { - for i, data := range inputs { - t.Run(fmt.Sprintf("example %d", i), func(t *testing.T) { - fmt.Printf("input:\n\t`%s`\n", data) - b := []byte(data) - var token []byte - var err error - for len(b) > 0 { - token, b, err = scan(b) - require.NoError(t, err) - fmt.Printf("token => '%s'\n", string(token)) - } - }) - } -} - -func TestParse(t *testing.T) { - for i, data := range inputs { - t.Run(fmt.Sprintf("example %d", i), func(t *testing.T) { - fmt.Printf("input:\n\t`%s`\n", data) - b := []byte(data) - err := parse(b) - require.NoError(t, err) - }) - } -} - -//type noopParser struct { -//} -// -//func (n noopParser) ArrayTableBegin() {} -//func (n noopParser) ArrayTableEnd() {} -//func (n noopParser) StandardTableBegin() {} -//func (n noopParser) StandardTableEnd() {} -//func (n noopParser) InlineTableSeparator() {} -//func (n noopParser) InlineTableBegin() {} -//func (n noopParser) InlineTableEnd() {} -//func (n noopParser) ArraySeparator() {} -//func (n noopParser) ArrayBegin() {} -//func (n noopParser) ArrayEnd() {} -//func (n noopParser) Whitespace(b []byte) {} -//func (n noopParser) Comment(b []byte) {} -//func (n noopParser) UnquotedKey(b []byte) {} -//func (n noopParser) LiteralString(b []byte) {} -//func (n noopParser) BasicString(b []byte) {} -//func (n noopParser) Dot(b []byte) {} -//func (n noopParser) Boolean(b []byte) {} -//func (n noopParser) Equal(b []byte) {} - -// -//func BenchmarkParseAll(b *testing.B) { -// b.ReportAllocs() -// -// for i := 0; i < b.N; i++ { -// for _, data := range inputs { -// p := noopParser{} -// l := lexer{parser: &p, data: []byte(data)} -// err := l.run() -// if err != nil { -// b.Fatalf("error: %s", err) -// } -// } -// } -//} diff --git a/toml.go b/unmarshal.go similarity index 63% rename from toml.go rename to unmarshal.go index ffe1b1a..fcaeb95 100644 --- a/toml.go +++ b/unmarshal.go @@ -1,23 +1,147 @@ package toml import ( + "bytes" "encoding/hex" "fmt" - "strings" + "reflect" ) -func parse(b []byte) error { - b, err := parseExpression(b) +func Unmarshal(data []byte, v interface{}) error { + if v == nil { + return fmt.Errorf("cannot unmarshal to nil target") + } + rv := reflect.ValueOf(v) + if rv.Kind() != reflect.Ptr { + return fmt.Errorf("can only marshal to pointer, not %s", rv.Kind()) + } + + u := &unmarshaler{stack: []reflect.Value{rv.Elem()}} + parseErr := parser{builder: u}.parse(data) + if parseErr != nil { + return parseErr + } + return u.err +} + +type unmarshaler struct { + // Each stack frame is a pointer to the root object that should be + // considered when settings values. + // It at least contains the root object passed to Unmarshal. + stack []reflect.Value + + // First error that appeared during the construction of the object. + // When set all callbacks are no-ops. + err error + + // State that indicates the parser is processing a [table] name. If false + // keys are interpreted as part of a key-value. + parsingTable bool +} + +func (u *unmarshaler) KeyValBegin() { + u.push(u.top()) +} + +func (u *unmarshaler) KeyValEnd() { + u.pop() +} + +func getOrCreateChild(parent reflect.Value, key string) (reflect.Value, error) { + if parent.Type().Kind() != reflect.Struct { + return reflect.Value{}, fmt.Errorf("value of type '%s' cannot have children", parent) + } + f := parent.FieldByName(key) + if !f.IsValid() { + // TODO: implement alternative names + return reflect.Value{}, fmt.Errorf("field '%s' not found", key) + } + // TODO create things + return f, nil +} + +func (u *unmarshaler) top() reflect.Value { + return u.stack[len(u.stack)-1] +} + +func (u *unmarshaler) push(v reflect.Value) { + u.stack = append(u.stack, v) +} + +func (u *unmarshaler) pop() { + u.stack = u.stack[:len(u.stack)-1] +} + +func (u *unmarshaler) replace(v reflect.Value) { + u.stack[len(u.stack)-1] = v +} + +func (u *unmarshaler) StringValue(v []byte) { + if u.err != nil { + return + } + u.top().SetString(string(v)) +} + +func (u *unmarshaler) SimpleKey(v []byte) { + if u.err != nil { + return + } + + target, err := getOrCreateChild(u.top(), string(v)) + if err != nil { + u.err = err + return + } + + u.replace(target) +} + +func (u *unmarshaler) StandardTableBegin() { + if u.err != nil { + return + } + + // tables are only top-level + u.stack = u.stack[:1] +} + +func (u *unmarshaler) StandardTableEnd() { + if u.err != nil { + return + } + + panic("implement me") +} + +type builder interface { + SimpleKey(v []byte) + + StandardTableBegin() + StandardTableEnd() + + KeyValBegin() + KeyValEnd() + + StringValue(v []byte) +} + +type parser struct { + builder builder +} + +func (p parser) parse(b []byte) error { + b, err := p.parseExpression(b) if err != nil { return err } for len(b) > 0 { - b, err = parseNewline(b) + b, err = p.parseNewline(b) if err != nil { return err } - b, err = parseExpression(b) + b, err = p.parseExpression(b) if err != nil { return err } @@ -25,7 +149,7 @@ func parse(b []byte) error { return nil } -func parseNewline(b []byte) ([]byte, error) { +func (p parser) parseNewline(b []byte) ([]byte, error) { if b[0] == '\n' { return b[1:], nil } @@ -36,12 +160,12 @@ func parseNewline(b []byte) ([]byte, error) { return nil, fmt.Errorf("expected newline but got %#U", b[0]) } -func parseExpression(b []byte) ([]byte, error) { +func (p parser) parseExpression(b []byte) ([]byte, error) { //expression = ws [ comment ] //expression =/ ws keyval ws [ comment ] //expression =/ ws table ws [ comment ] - b = parseWhitespace(b) + b = p.parseWhitespace(b) if len(b) == 0 { return b, nil @@ -58,15 +182,15 @@ func parseExpression(b []byte) ([]byte, error) { var err error if b[0] == '[' { - b, err = parseTable(b) + b, err = p.parseTable(b) } else { - b, err = parseKeyval(b) + b, err = p.parseKeyval(b) } if err != nil { return nil, err } - b = parseWhitespace(b) + b = p.parseWhitespace(b) if len(b) > 0 && b[0] == '#' { _, rest, err := scanComment(b) @@ -76,26 +200,26 @@ func parseExpression(b []byte) ([]byte, error) { return b, nil } -func parseTable(b []byte) ([]byte, error) { +func (p parser) parseTable(b []byte) ([]byte, error) { //table = std-table / array-table if len(b) > 1 && b[1] == '[' { - return parseArrayTable(b) + return p.parseArrayTable(b) } - return parseStdTable(b) + return p.parseStdTable(b) } -func parseArrayTable(b []byte) ([]byte, error) { +func (p parser) parseArrayTable(b []byte) ([]byte, error) { //array-table = array-table-open key array-table-close //array-table-open = %x5B.5B ws ; [[ Double left square bracket //array-table-close = ws %x5D.5D ; ]] Double right square bracket b = b[2:] - b = parseWhitespace(b) - b, err := parseKey(b) + b = p.parseWhitespace(b) + b, err := p.parseKey(b) if err != nil { return nil, err } - b = parseWhitespace(b) + b = p.parseWhitespace(b) b, err = expect(']', b) if err != nil { return nil, err @@ -103,42 +227,49 @@ func parseArrayTable(b []byte) ([]byte, error) { return expect(']', b) } -func parseStdTable(b []byte) ([]byte, error) { +func (p parser) parseStdTable(b []byte) ([]byte, error) { //std-table = std-table-open key std-table-close //std-table-open = %x5B ws ; [ Left square bracket //std-table-close = ws %x5D ; ] Right square bracket + p.builder.StandardTableBegin() + defer p.builder.StandardTableEnd() + b = b[1:] - b = parseWhitespace(b) - b, err := parseKey(b) + b = p.parseWhitespace(b) + b, err := p.parseKey(b) if err != nil { return nil, err } - b = parseWhitespace(b) + b = p.parseWhitespace(b) + return expect(']', b) } -func parseKeyval(b []byte) ([]byte, error) { +func (p parser) parseKeyval(b []byte) ([]byte, error) { //keyval = key keyval-sep val - b, err := parseKey(b) + p.builder.KeyValBegin() + defer p.builder.KeyValEnd() + + b, err := p.parseKey(b) if err != nil { return nil, err } //keyval-sep = ws %x3D ws ; = - b = parseWhitespace(b) + b = p.parseWhitespace(b) b, err = expect('=', b) if err != nil { return nil, err } - b = parseWhitespace(b) + b = p.parseWhitespace(b) - return parseVal(b) + return p.parseVal(b) } -func parseVal(b []byte) ([]byte, error) { +func (p parser) parseVal(b []byte) ([]byte, error) { // val = string / boolean / array / inline-table / date-time / float / integer if len(b) == 0 { return nil, fmt.Errorf("expected value, not eof") @@ -150,15 +281,19 @@ func parseVal(b []byte) ([]byte, error) { switch c { // strings case '"': + var v []byte if scanFollowsMultilineBasicStringDelimiter(b) { - _, b, err = parseMultilineBasicString(b) + v, b, err = p.parseMultilineBasicString(b) } else { - _, b, err = parseBasicString(b) + v, b, err = p.parseBasicString(b) + } + if err == nil { + p.builder.StringValue(v) } return b, err case '\'': if scanFollowsMultilineLiteralStringDelimiter(b) { - _, b, err = parseMultilineLiteralString(b) + _, b, err = p.parseMultilineLiteralString(b) } else { _, b, err = scanLiteralString(b) } @@ -174,9 +309,9 @@ func parseVal(b []byte) ([]byte, error) { } return b[5:], nil case '[': - return parseValArray(b) + return p.parseValArray(b) case '{': - return parseInlineTable(b) + return p.parseInlineTable(b) // TODO date-time @@ -188,7 +323,7 @@ func parseVal(b []byte) ([]byte, error) { } } -func parseInlineTable(b []byte) ([]byte, error) { +func (p parser) parseInlineTable(b []byte) ([]byte, error) { //inline-table = inline-table-open [ inline-table-keyvals ] inline-table-close //inline-table-open = %x7B ws ; { //inline-table-close = ws %x7D ; } @@ -200,7 +335,7 @@ func parseInlineTable(b []byte) ([]byte, error) { first := true var err error for len(b) > 0 { - b = parseWhitespace(b) + b = p.parseWhitespace(b) if b[0] == '}' { break } @@ -210,16 +345,19 @@ func parseInlineTable(b []byte) ([]byte, error) { if err != nil { return nil, err } - b = parseWhitespace(b) + b = p.parseWhitespace(b) + } + b, err = p.parseKeyval(b) + if err != nil { + return nil, err } - b, err = parseKeyval(b) first = false } return expect('}', b) } -func parseValArray(b []byte) ([]byte, error) { +func (p parser) parseValArray(b []byte) ([]byte, error) { //array = array-open [ array-values ] ws-comment-newline array-close //array-open = %x5B ; [ //array-close = %x5D ; ] @@ -233,7 +371,7 @@ func parseValArray(b []byte) ([]byte, error) { first := true var err error for len(b) > 0 { - b, err = parseOptionalWhitespaceCommentNewline(b) + b, err = p.parseOptionalWhitespaceCommentNewline(b) if err != nil { return nil, err } @@ -250,17 +388,17 @@ func parseValArray(b []byte) ([]byte, error) { return nil, fmt.Errorf("array cannot start with comma") } b = b[1:] - b, err = parseOptionalWhitespaceCommentNewline(b) + b, err = p.parseOptionalWhitespaceCommentNewline(b) if err != nil { return nil, err } } - b, err = parseVal(b) + b, err = p.parseVal(b) if err != nil { return nil, err } - b, err = parseOptionalWhitespaceCommentNewline(b) + b, err = p.parseOptionalWhitespaceCommentNewline(b) if err != nil { return nil, err } @@ -270,9 +408,9 @@ func parseValArray(b []byte) ([]byte, error) { return expect(']', b) } -func parseOptionalWhitespaceCommentNewline(b []byte) ([]byte, error) { +func (p parser) parseOptionalWhitespaceCommentNewline(b []byte) ([]byte, error) { var err error - b = parseWhitespace(b) + b = p.parseWhitespace(b) if len(b) > 0 && b[0] == '#' { _, b, err = scanComment(b) if err != nil { @@ -280,7 +418,7 @@ func parseOptionalWhitespaceCommentNewline(b []byte) ([]byte, error) { } } if len(b) > 0 && (b[0] == '\n' || b[0] == '\r') { - b, err = parseNewline(b) + b, err = p.parseNewline(b) if err != nil { return nil, err } @@ -288,7 +426,7 @@ func parseOptionalWhitespaceCommentNewline(b []byte) ([]byte, error) { return b, nil } -func parseMultilineLiteralString(b []byte) (string, []byte, error) { +func (p parser) parseMultilineLiteralString(b []byte) (string, []byte, error) { token, rest, err := scanMultilineLiteralString(b) if err != nil { return "", nil, err @@ -306,7 +444,7 @@ func parseMultilineLiteralString(b []byte) (string, []byte, error) { return string(token[i : len(b)-3]), rest, err } -func parseMultilineBasicString(b []byte) (string, []byte, error) { +func (p parser) parseMultilineBasicString(b []byte) ([]byte, []byte, error) { //ml-basic-string = ml-basic-string-delim [ newline ] ml-basic-body //ml-basic-string-delim //ml-basic-string-delim = 3quotation-mark @@ -320,9 +458,9 @@ func parseMultilineBasicString(b []byte) (string, []byte, error) { token, rest, err := scanMultilineBasicString(b) if err != nil { - return "", nil, err + return nil, nil, err } - var builder strings.Builder + var builder bytes.Buffer i := 3 @@ -371,29 +509,29 @@ func parseMultilineBasicString(b []byte) (string, []byte, error) { case 'u': x, err := hexToString(token[i+3:len(token)-3], 4) if err != nil { - return "", nil, err + return nil, nil, err } builder.WriteString(x) i += 4 case 'U': x, err := hexToString(token[i+3:len(token)-3], 8) if err != nil { - return "", nil, err + return nil, nil, err } builder.WriteString(x) i += 8 default: - return "", nil, fmt.Errorf("invalid escaped character: %#U", c) + return nil, nil, fmt.Errorf("invalid escaped character: %#U", c) } } else { builder.WriteByte(c) } } - return builder.String(), rest, nil + return builder.Bytes(), rest, nil } -func parseKey(b []byte) ([]byte, error) { +func (p parser) parseKey(b []byte) ([]byte, error) { //key = simple-key / dotted-key //simple-key = quoted-key / unquoted-key // @@ -403,20 +541,20 @@ func parseKey(b []byte) ([]byte, error) { // //dot-sep = ws %x2E ws ; . Period - b, err := parseSimpleKey(b) + b, err := p.parseSimpleKey(b) if err != nil { return nil, err } for { - b = parseWhitespace(b) + b = p.parseWhitespace(b) if len(b) > 0 && b[0] == '.' { b, err = expect('.', b) if err != nil { return nil, err } - b = parseWhitespace(b) - b, err = parseSimpleKey(b) + b = p.parseWhitespace(b) + b, err = p.parseSimpleKey(b) if err != nil { return nil, err } @@ -428,7 +566,7 @@ func parseKey(b []byte) ([]byte, error) { return b, nil } -func parseSimpleKey(b []byte) ([]byte, error) { +func (p parser) parseSimpleKey(b []byte) (rest []byte, err error) { //simple-key = quoted-key / unquoted-key //unquoted-key = 1*( ALPHA / DIGIT / %x2D / %x5F ) ; A-Z / a-z / 0-9 / - / _ //quoted-key = basic-string / literal-string @@ -437,24 +575,21 @@ func parseSimpleKey(b []byte) ([]byte, error) { return nil, unexpectedCharacter{b: b} } + var v []byte if b[0] == '\'' { - _, rest, err := scanLiteralString(b) - return rest, err + v, rest, err = scanLiteralString(b) + } else if b[0] == '"' { + v, rest, err = p.parseBasicString(b) + } else if isUnquotedKeyChar(b[0]) { + v, rest, err = scanUnquotedKey(b) + } else { + return nil, unexpectedCharacter{b: b} } - if b[0] == '"' { - _, rest, err := parseBasicString(b) - return rest, err - } - - if isUnquotedKeyChar(b[0]) { - _, rest, err := scanUnquotedKey(b) - return rest, err - } - - return nil, unexpectedCharacter{b: b} + p.builder.SimpleKey(v) + return } -func parseBasicString(b []byte) (string, []byte, error) { +func (p parser) parseBasicString(b []byte) ([]byte, []byte, error) { //basic-string = quotation-mark *basic-char quotation-mark //quotation-mark = %x22 ; " //basic-char = basic-unescaped / escaped @@ -472,9 +607,9 @@ func parseBasicString(b []byte) (string, []byte, error) { token, rest, err := scanBasicString(b) if err != nil { - return "", nil, err + return nil, nil, err } - var builder strings.Builder + var builder bytes.Buffer // The scanner ensures that the token starts and ends with quotes and that // escapes are balanced. @@ -499,26 +634,26 @@ func parseBasicString(b []byte) (string, []byte, error) { case 'u': x, err := hexToString(token[i+1:len(token)-1], 4) if err != nil { - return "", nil, err + return nil, nil, err } builder.WriteString(x) i += 4 case 'U': x, err := hexToString(token[i+1:len(token)-1], 8) if err != nil { - return "", nil, err + return nil, nil, err } builder.WriteString(x) i += 8 default: - return "", nil, fmt.Errorf("invalid escaped character: %#U", c) + return nil, nil, fmt.Errorf("invalid escaped character: %#U", c) } } else { builder.WriteByte(c) } } - return builder.String(), rest, nil + return builder.Bytes(), rest, nil } func hexToString(b []byte, length int) (string, error) { @@ -533,7 +668,7 @@ func hexToString(b []byte, length int) (string, error) { return string(b), nil } -func parseWhitespace(b []byte) []byte { +func (p parser) parseWhitespace(b []byte) []byte { //ws = *wschar //wschar = %x20 ; Space //wschar =/ %x09 ; Horizontal tab diff --git a/unmarshal_test.go b/unmarshal_test.go new file mode 100644 index 0000000..ad116ba --- /dev/null +++ b/unmarshal_test.go @@ -0,0 +1,34 @@ +package toml + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestUnmarshalSimple(t *testing.T) { + x := struct{ Foo string }{} + err := Unmarshal([]byte(`Foo = "hello"`), &x) + require.NoError(t, err) + assert.Equal(t, "hello", x.Foo) +} + +func TestUnmarshalNestedStructs(t *testing.T) { + x := struct{ Foo struct{ Bar string } }{} + err := Unmarshal([]byte(`Foo.Bar = "hello"`), &x) + require.NoError(t, err) + assert.Equal(t, "hello", x.Foo.Bar) +} + +func TestUnmarshalNestedStructsMultipleExpressions(t *testing.T) { + x := struct { + A struct{ B string } + C string + }{} + err := Unmarshal([]byte(`A.B = "hello" +C = "test"`), &x) + require.NoError(t, err) + assert.Equal(t, "hello", x.A.B) + assert.Equal(t, "test", x.C) +}