diff --git a/README.md b/README.md index 4ef303a..6831deb 100644 --- a/README.md +++ b/README.md @@ -3,7 +3,7 @@ Go library for the [TOML](https://github.com/mojombo/toml) format. This library supports TOML version -[v0.5.0](https://github.com/toml-lang/toml/blob/master/versions/en/toml-v0.5.0.md) +[v1.0.0-rc.1](https://github.com/toml-lang/toml/blob/master/versions/en/toml-v1.0.0-rc.1.md) [![GoDoc](https://godoc.org/github.com/pelletier/go-toml?status.svg)](http://godoc.org/github.com/pelletier/go-toml) [![license](https://img.shields.io/github/license/pelletier/go-toml.svg)](https://github.com/pelletier/go-toml/blob/master/LICENSE) @@ -18,7 +18,7 @@ Go-toml provides the following features for using data parsed from TOML document * Load TOML documents from files and string data * Easily navigate TOML structure using Tree -* Mashaling and unmarshaling to and from data structures +* Marshaling and unmarshaling to and from data structures * Line & column position data for all parsed elements * [Query support similar to JSON-Path](query/) * Syntax errors contain line and column numbers @@ -74,7 +74,7 @@ Or use a query: q, _ := query.Compile("$..[user,password]") results := q.Execute(config) for ii, item := range results.Values() { - fmt.Println("Query result %d: %v", ii, item) + fmt.Printf("Query result %d: %v\n", ii, item) } ``` @@ -87,7 +87,7 @@ The documentation and additional examples are available at Go-toml provides two handy command line tools: -* `tomll`: Reads TOML files and lint them. +* `tomll`: Reads TOML files and lints them. ``` go install github.com/pelletier/go-toml/cmd/tomll @@ -99,9 +99,9 @@ Go-toml provides two handy command line tools: go install github.com/pelletier/go-toml/cmd/tomljson tomljson --help ``` - + * `jsontoml`: Reads a JSON file and outputs a TOML representation. - + ``` go install github.com/pelletier/go-toml/cmd/jsontoml jsontoml --help diff --git a/go.mod b/go.mod index 07a258b..c7faa6b 100644 --- a/go.mod +++ b/go.mod @@ -5,5 +5,5 @@ go 1.12 require ( github.com/BurntSushi/toml v0.3.1 github.com/davecgh/go-spew v1.1.1 - gopkg.in/yaml.v2 v2.2.8 + gopkg.in/yaml.v2 v2.3.0 ) diff --git a/go.sum b/go.sum index b30bc76..6f35647 100644 --- a/go.sum +++ b/go.sum @@ -15,3 +15,5 @@ gopkg.in/yaml.v2 v2.2.7 h1:VUgggvou5XRW9mHwD/yXxIYSMtY0zoKQf/v226p2nyo= gopkg.in/yaml.v2 v2.2.7/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= gopkg.in/yaml.v2 v2.2.8 h1:obN1ZagJSUGI0Ek/LBmuj4SNLPfIny3KsKFopxRdj10= gopkg.in/yaml.v2 v2.2.8/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= +gopkg.in/yaml.v2 v2.3.0 h1:clyUAQHOM3G0M3f5vQj7LuJrETvjVot3Z5el9nffUtU= +gopkg.in/yaml.v2 v2.3.0/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= diff --git a/keysparsing.go b/keysparsing.go index e923bc4..e091500 100644 --- a/keysparsing.go +++ b/keysparsing.go @@ -5,7 +5,6 @@ package toml import ( "errors" "fmt" - "unicode" ) // Convert the bare key group string to an array. @@ -109,5 +108,5 @@ func parseKey(key string) ([]string, error) { } func isValidBareChar(r rune) bool { - return isAlphanumeric(r) || r == '-' || unicode.IsNumber(r) + return isAlphanumeric(r) || r == '-' || isDigit(r) } diff --git a/lexer.go b/lexer.go index 0bde0a1..425e847 100644 --- a/lexer.go +++ b/lexer.go @@ -26,7 +26,7 @@ type tomlLexer struct { currentTokenStart int currentTokenStop int tokens []token - depth int + brackets []rune line int col int endbufferLine int @@ -123,6 +123,8 @@ func (l *tomlLexer) lexVoid() tomlLexStateFn { for { next := l.peek() switch next { + case '}': // after '{' + return l.lexRightCurlyBrace case '[': return l.lexTableKey case '#': @@ -140,10 +142,6 @@ func (l *tomlLexer) lexVoid() tomlLexStateFn { l.skip() } - if l.depth > 0 { - return l.lexRvalue - } - if isKeyStartChar(next) { return l.lexKey } @@ -167,10 +165,8 @@ func (l *tomlLexer) lexRvalue() tomlLexStateFn { case '=': return l.lexEqual case '[': - l.depth++ return l.lexLeftBracket case ']': - l.depth-- return l.lexRightBracket case '{': return l.lexLeftCurlyBrace @@ -188,12 +184,10 @@ func (l *tomlLexer) lexRvalue() tomlLexStateFn { fallthrough case '\n': l.skip() - if l.depth == 0 { - return l.lexVoid + if len(l.brackets) > 0 && l.brackets[len(l.brackets)-1] == '[' { + return l.lexRvalue } - return l.lexRvalue - case '_': - return l.errorf("cannot start number with underscore") + return l.lexVoid } if l.follow("true") { @@ -236,10 +230,6 @@ func (l *tomlLexer) lexRvalue() tomlLexStateFn { return l.lexNumber } - if isAlphanumeric(next) { - return l.lexKey - } - return l.errorf("no value can start with %c", next) } @@ -250,12 +240,17 @@ func (l *tomlLexer) lexRvalue() tomlLexStateFn { func (l *tomlLexer) lexLeftCurlyBrace() tomlLexStateFn { l.next() l.emit(tokenLeftCurlyBrace) + l.brackets = append(l.brackets, '{') return l.lexVoid } func (l *tomlLexer) lexRightCurlyBrace() tomlLexStateFn { l.next() l.emit(tokenRightCurlyBrace) + if len(l.brackets) == 0 || l.brackets[len(l.brackets)-1] != '{' { + return l.errorf("cannot have '}' here") + } + l.brackets = l.brackets[:len(l.brackets)-1] return l.lexRvalue } @@ -302,6 +297,9 @@ func (l *tomlLexer) lexEqual() tomlLexStateFn { func (l *tomlLexer) lexComma() tomlLexStateFn { l.next() l.emit(tokenComma) + if len(l.brackets) > 0 && l.brackets[len(l.brackets)-1] == '{' { + return l.lexVoid + } return l.lexRvalue } @@ -332,7 +330,26 @@ func (l *tomlLexer) lexKey() tomlLexStateFn { } else if r == '\n' { return l.errorf("keys cannot contain new lines") } else if isSpace(r) { - break + str := " " + // skip trailing whitespace + l.next() + for r = l.peek(); isSpace(r); r = l.peek() { + str += string(r) + l.next() + } + // break loop if not a dot + if r != '.' { + break + } + str += "." + // skip trailing whitespace after dot + l.next() + for r = l.peek(); isSpace(r); r = l.peek() { + str += string(r) + l.next() + } + growingString += str + continue } else if r == '.' { // skip } else if !isValidBareChar(r) { @@ -361,6 +378,7 @@ func (l *tomlLexer) lexComment(previousState tomlLexStateFn) tomlLexStateFn { func (l *tomlLexer) lexLeftBracket() tomlLexStateFn { l.next() l.emit(tokenLeftBracket) + l.brackets = append(l.brackets, '[') return l.lexRvalue } @@ -543,7 +561,6 @@ func (l *tomlLexer) lexString() tomlLexStateFn { } str, err := l.lexStringAsString(terminator, discardLeadingNewLine, acceptNewLines) - if err != nil { return l.errorf(err.Error()) } @@ -615,6 +632,10 @@ func (l *tomlLexer) lexInsideTableKey() tomlLexStateFn { func (l *tomlLexer) lexRightBracket() tomlLexStateFn { l.next() l.emit(tokenRightBracket) + if len(l.brackets) == 0 || l.brackets[len(l.brackets)-1] != '[' { + return l.errorf("cannot have ']' here") + } + l.brackets = l.brackets[:len(l.brackets)-1] return l.lexRvalue } diff --git a/lexer_test.go b/lexer_test.go index ff826e9..225a52a 100644 --- a/lexer_test.go +++ b/lexer_test.go @@ -8,7 +8,7 @@ import ( func testFlow(t *testing.T, input string, expectedFlow []token) { tokens := lexToml([]byte(input)) if !reflect.DeepEqual(tokens, expectedFlow) { - t.Fatal("Different flows. Expected\n", expectedFlow, "\nGot:\n", tokens) + t.Fatalf("Different flows.\nExpected:\n%v\nGot:\n%v", expectedFlow, tokens) } } @@ -22,11 +22,20 @@ func TestValidKeyGroup(t *testing.T) { } func TestNestedQuotedUnicodeKeyGroup(t *testing.T) { - testFlow(t, `[ j . "ʞ" . l ]`, []token{ + testFlow(t, `[ j . "ʞ" . l . 'ɯ' ]`, []token{ {Position{1, 1}, tokenLeftBracket, "["}, - {Position{1, 2}, tokenKeyGroup, ` j . "ʞ" . l `}, - {Position{1, 15}, tokenRightBracket, "]"}, - {Position{1, 16}, tokenEOF, ""}, + {Position{1, 2}, tokenKeyGroup, ` j . "ʞ" . l . 'ɯ' `}, + {Position{1, 21}, tokenRightBracket, "]"}, + {Position{1, 22}, tokenEOF, ""}, + }) +} + +func TestNestedQuotedUnicodeKeyAssign(t *testing.T) { + testFlow(t, ` j . "ʞ" . l . 'ɯ' = 3`, []token{ + {Position{1, 2}, tokenKey, `j . "ʞ" . l . 'ɯ'`}, + {Position{1, 20}, tokenEqual, "="}, + {Position{1, 22}, tokenInteger, "3"}, + {Position{1, 23}, tokenEOF, ""}, }) } @@ -105,9 +114,9 @@ func TestBasicKeyWithUppercaseMix(t *testing.T) { } func TestBasicKeyWithInternationalCharacters(t *testing.T) { - testFlow(t, "héllÖ", []token{ - {Position{1, 1}, tokenKey, "héllÖ"}, - {Position{1, 6}, tokenEOF, ""}, + testFlow(t, "'héllÖ'", []token{ + {Position{1, 1}, tokenKey, "'héllÖ'"}, + {Position{1, 8}, tokenEOF, ""}, }) } @@ -698,6 +707,7 @@ func TestUnicodeString(t *testing.T) { {Position{1, 22}, tokenEOF, ""}, }) } + func TestEscapeInString(t *testing.T) { testFlow(t, `foo = "\b\f\/"`, []token{ {Position{1, 1}, tokenKey, "foo"}, @@ -772,6 +782,16 @@ func TestLexUnknownRvalue(t *testing.T) { }) } +func TestLexInlineTableEmpty(t *testing.T) { + testFlow(t, `foo = {}`, []token{ + {Position{1, 1}, tokenKey, "foo"}, + {Position{1, 5}, tokenEqual, "="}, + {Position{1, 7}, tokenLeftCurlyBrace, "{"}, + {Position{1, 8}, tokenRightCurlyBrace, "}"}, + {Position{1, 9}, tokenEOF, ""}, + }) +} + func TestLexInlineTableBareKey(t *testing.T) { testFlow(t, `foo = { bar = "baz" }`, []token{ {Position{1, 1}, tokenKey, "foo"}, @@ -798,6 +818,116 @@ func TestLexInlineTableBareKeyDash(t *testing.T) { }) } +func TestLexInlineTableBareKeyInArray(t *testing.T) { + testFlow(t, `foo = [{ -bar_ = "baz" }]`, []token{ + {Position{1, 1}, tokenKey, "foo"}, + {Position{1, 5}, tokenEqual, "="}, + {Position{1, 7}, tokenLeftBracket, "["}, + {Position{1, 8}, tokenLeftCurlyBrace, "{"}, + {Position{1, 10}, tokenKey, "-bar_"}, + {Position{1, 16}, tokenEqual, "="}, + {Position{1, 19}, tokenString, "baz"}, + {Position{1, 24}, tokenRightCurlyBrace, "}"}, + {Position{1, 25}, tokenRightBracket, "]"}, + {Position{1, 26}, tokenEOF, ""}, + }) +} + +func TestLexInlineTableError1(t *testing.T) { + testFlow(t, `foo = { 123 = 0 ]`, []token{ + {Position{1, 1}, tokenKey, "foo"}, + {Position{1, 5}, tokenEqual, "="}, + {Position{1, 7}, tokenLeftCurlyBrace, "{"}, + {Position{1, 9}, tokenKey, "123"}, + {Position{1, 13}, tokenEqual, "="}, + {Position{1, 15}, tokenInteger, "0"}, + {Position{1, 17}, tokenRightBracket, "]"}, + {Position{1, 18}, tokenError, "cannot have ']' here"}, + }) +} + +func TestLexInlineTableError2(t *testing.T) { + testFlow(t, `foo = { 123 = 0 }}`, []token{ + {Position{1, 1}, tokenKey, "foo"}, + {Position{1, 5}, tokenEqual, "="}, + {Position{1, 7}, tokenLeftCurlyBrace, "{"}, + {Position{1, 9}, tokenKey, "123"}, + {Position{1, 13}, tokenEqual, "="}, + {Position{1, 15}, tokenInteger, "0"}, + {Position{1, 17}, tokenRightCurlyBrace, "}"}, + {Position{1, 18}, tokenRightCurlyBrace, "}"}, + {Position{1, 19}, tokenError, "cannot have '}' here"}, + }) +} + +func TestLexInlineTableDottedKey1(t *testing.T) { + testFlow(t, `foo = { a = 0, 123.45abc = 0 }`, []token{ + {Position{1, 1}, tokenKey, "foo"}, + {Position{1, 5}, tokenEqual, "="}, + {Position{1, 7}, tokenLeftCurlyBrace, "{"}, + {Position{1, 9}, tokenKey, "a"}, + {Position{1, 11}, tokenEqual, "="}, + {Position{1, 13}, tokenInteger, "0"}, + {Position{1, 14}, tokenComma, ","}, + {Position{1, 16}, tokenKey, "123.45abc"}, + {Position{1, 26}, tokenEqual, "="}, + {Position{1, 28}, tokenInteger, "0"}, + {Position{1, 30}, tokenRightCurlyBrace, "}"}, + {Position{1, 31}, tokenEOF, ""}, + }) +} + +func TestLexInlineTableDottedKey2(t *testing.T) { + testFlow(t, `foo = { a = 0, '123'.'45abc' = 0 }`, []token{ + {Position{1, 1}, tokenKey, "foo"}, + {Position{1, 5}, tokenEqual, "="}, + {Position{1, 7}, tokenLeftCurlyBrace, "{"}, + {Position{1, 9}, tokenKey, "a"}, + {Position{1, 11}, tokenEqual, "="}, + {Position{1, 13}, tokenInteger, "0"}, + {Position{1, 14}, tokenComma, ","}, + {Position{1, 16}, tokenKey, "'123'.'45abc'"}, + {Position{1, 30}, tokenEqual, "="}, + {Position{1, 32}, tokenInteger, "0"}, + {Position{1, 34}, tokenRightCurlyBrace, "}"}, + {Position{1, 35}, tokenEOF, ""}, + }) +} + +func TestLexInlineTableDottedKey3(t *testing.T) { + testFlow(t, `foo = { a = 0, "123"."45ʎǝʞ" = 0 }`, []token{ + {Position{1, 1}, tokenKey, "foo"}, + {Position{1, 5}, tokenEqual, "="}, + {Position{1, 7}, tokenLeftCurlyBrace, "{"}, + {Position{1, 9}, tokenKey, "a"}, + {Position{1, 11}, tokenEqual, "="}, + {Position{1, 13}, tokenInteger, "0"}, + {Position{1, 14}, tokenComma, ","}, + {Position{1, 16}, tokenKey, `"123"."45ʎǝʞ"`}, + {Position{1, 30}, tokenEqual, "="}, + {Position{1, 32}, tokenInteger, "0"}, + {Position{1, 34}, tokenRightCurlyBrace, "}"}, + {Position{1, 35}, tokenEOF, ""}, + }) +} + +func TestLexInlineTableBareKeyWithComma(t *testing.T) { + testFlow(t, `foo = { -bar1 = "baz", -bar_ = "baz" }`, []token{ + {Position{1, 1}, tokenKey, "foo"}, + {Position{1, 5}, tokenEqual, "="}, + {Position{1, 7}, tokenLeftCurlyBrace, "{"}, + {Position{1, 9}, tokenKey, "-bar1"}, + {Position{1, 15}, tokenEqual, "="}, + {Position{1, 18}, tokenString, "baz"}, + {Position{1, 22}, tokenComma, ","}, + {Position{1, 24}, tokenKey, "-bar_"}, + {Position{1, 30}, tokenEqual, "="}, + {Position{1, 33}, tokenString, "baz"}, + {Position{1, 38}, tokenRightCurlyBrace, "}"}, + {Position{1, 39}, tokenEOF, ""}, + }) +} + func TestLexInlineTableBareKeyUnderscore(t *testing.T) { testFlow(t, `foo = { _bar = "baz" }`, []token{ {Position{1, 1}, tokenKey, "foo"}, diff --git a/marshal.go b/marshal.go index 0832630..db5a7b4 100644 --- a/marshal.go +++ b/marshal.go @@ -2,6 +2,7 @@ package toml import ( "bytes" + "encoding" "errors" "fmt" "io" @@ -69,6 +70,9 @@ const ( var timeType = reflect.TypeOf(time.Time{}) var marshalerType = reflect.TypeOf(new(Marshaler)).Elem() +var unmarshalerType = reflect.TypeOf(new(Unmarshaler)).Elem() +var textMarshalerType = reflect.TypeOf(new(encoding.TextMarshaler)).Elem() +var textUnmarshalerType = reflect.TypeOf(new(encoding.TextUnmarshaler)).Elem() var localDateType = reflect.TypeOf(LocalDate{}) var localTimeType = reflect.TypeOf(LocalTime{}) var localDateTimeType = reflect.TypeOf(LocalDateTime{}) @@ -89,12 +93,16 @@ func isPrimitive(mtype reflect.Type) bool { case reflect.String: return true case reflect.Struct: - return mtype == timeType || mtype == localDateType || mtype == localDateTimeType || mtype == localTimeType || isCustomMarshaler(mtype) + return isTimeType(mtype) default: return false } } +func isTimeType(mtype reflect.Type) bool { + return mtype == timeType || mtype == localDateType || mtype == localDateTimeType || mtype == localTimeType +} + // Check if the given marshal type maps to a Tree slice or array func isTreeSequence(mtype reflect.Type) bool { switch mtype.Kind() { @@ -107,6 +115,30 @@ func isTreeSequence(mtype reflect.Type) bool { } } +// Check if the given marshal type maps to a slice or array of a custom marshaler type +func isCustomMarshalerSequence(mtype reflect.Type) bool { + switch mtype.Kind() { + case reflect.Ptr: + return isCustomMarshalerSequence(mtype.Elem()) + case reflect.Slice, reflect.Array: + return isCustomMarshaler(mtype.Elem()) || isCustomMarshaler(reflect.New(mtype.Elem()).Type()) + default: + return false + } +} + +// Check if the given marshal type maps to a slice or array of a text marshaler type +func isTextMarshalerSequence(mtype reflect.Type) bool { + switch mtype.Kind() { + case reflect.Ptr: + return isTextMarshalerSequence(mtype.Elem()) + case reflect.Slice, reflect.Array: + return isTextMarshaler(mtype.Elem()) || isTextMarshaler(reflect.New(mtype.Elem()).Type()) + default: + return false + } +} + // Check if the given marshal type maps to a non-Tree slice or array func isOtherSequence(mtype reflect.Type) bool { switch mtype.Kind() { @@ -141,12 +173,42 @@ func callCustomMarshaler(mval reflect.Value) ([]byte, error) { return mval.Interface().(Marshaler).MarshalTOML() } +func isTextMarshaler(mtype reflect.Type) bool { + return mtype.Implements(textMarshalerType) && !isTimeType(mtype) +} + +func callTextMarshaler(mval reflect.Value) ([]byte, error) { + return mval.Interface().(encoding.TextMarshaler).MarshalText() +} + +func isCustomUnmarshaler(mtype reflect.Type) bool { + return mtype.Implements(unmarshalerType) +} + +func callCustomUnmarshaler(mval reflect.Value, tval interface{}) error { + return mval.Interface().(Unmarshaler).UnmarshalTOML(tval) +} + +func isTextUnmarshaler(mtype reflect.Type) bool { + return mtype.Implements(textUnmarshalerType) +} + +func callTextUnmarshaler(mval reflect.Value, text []byte) error { + return mval.Interface().(encoding.TextUnmarshaler).UnmarshalText(text) +} + // Marshaler is the interface implemented by types that // can marshal themselves into valid TOML. type Marshaler interface { MarshalTOML() ([]byte, error) } +// Unmarshaler is the interface implemented by types that +// can unmarshal a TOML description of themselves. +type Unmarshaler interface { + UnmarshalTOML(interface{}) error +} + /* Marshal returns the TOML encoding of v. Behavior is similar to the Go json encoder, except that there is no concept of a Marshaler interface or MarshalTOML @@ -195,17 +257,19 @@ type Encoder struct { col int order marshalOrder promoteAnon bool + indentation string } // NewEncoder returns a new encoder that writes to w. func NewEncoder(w io.Writer) *Encoder { return &Encoder{ - w: w, - encOpts: encOptsDefaults, - annotation: annotationDefault, - line: 0, - col: 1, - order: OrderAlphabetical, + w: w, + encOpts: encOptsDefaults, + annotation: annotationDefault, + line: 0, + col: 1, + order: OrderAlphabetical, + indentation: " ", } } @@ -257,6 +321,12 @@ func (e *Encoder) Order(ord marshalOrder) *Encoder { return e } +// Indentation allows to change indentation when marshalling. +func (e *Encoder) Indentation(indent string) *Encoder { + e.indentation = indent + return e +} + // SetTagName allows changing default tag "toml" func (e *Encoder) SetTagName(v string) *Encoder { e.tag = v @@ -295,6 +365,13 @@ func (e *Encoder) PromoteAnonymous(promote bool) *Encoder { } func (e *Encoder) marshal(v interface{}) ([]byte, error) { + // Check if indentation is valid + for _, char := range e.indentation { + if !isSpace(char) { + return []byte{}, fmt.Errorf("invalid indentation: must only contains space or tab characters") + } + } + mtype := reflect.TypeOf(v) if mtype == nil { return []byte{}, errors.New("nil cannot be marshaled to TOML") @@ -317,13 +394,16 @@ func (e *Encoder) marshal(v interface{}) ([]byte, error) { if isCustomMarshaler(mtype) { return callCustomMarshaler(sval) } + if isTextMarshaler(mtype) { + return callTextMarshaler(sval) + } t, err := e.valueToTree(mtype, sval) if err != nil { return []byte{}, err } var buf bytes.Buffer - _, err = t.writeToOrdered(&buf, "", "", 0, e.arraysOneElementPerLine, e.order, false) + _, err = t.writeToOrdered(&buf, "", "", 0, e.arraysOneElementPerLine, e.order, e.indentation, false) return buf.Bytes(), err } @@ -356,7 +436,7 @@ func (e *Encoder) valueToTree(mtype reflect.Type, mval reflect.Value) (*Tree, er if tree, ok := val.(*Tree); ok && mtypef.Anonymous && !opts.nameFromTag && !e.promoteAnon { e.appendTree(tval, tree) } else { - tval.SetWithOptions(opts.name, SetOptions{ + tval.SetPathWithOptions([]string{opts.name}, SetOptions{ Comment: opts.comment, Commented: opts.commented, Multiline: opts.multiline, @@ -395,13 +475,13 @@ func (e *Encoder) valueToTree(mtype reflect.Type, mval reflect.Value) (*Tree, er return nil, err } if e.quoteMapKeys { - keyStr, err := tomlValueStringRepresentation(key.String(), "", "", e.arraysOneElementPerLine) + keyStr, err := tomlValueStringRepresentation(key.String(), "", "", e.order, e.arraysOneElementPerLine) if err != nil { return nil, err } tval.SetPath([]string{keyStr}, val) } else { - tval.Set(key.String(), val) + tval.SetPath([]string{key.String()}, val) } } } @@ -423,9 +503,6 @@ func (e *Encoder) valueToTreeSlice(mtype reflect.Type, mval reflect.Value) ([]*T // Convert given marshal slice to slice of toml values func (e *Encoder) valueToOtherSlice(mtype reflect.Type, mval reflect.Value) (interface{}, error) { - if mtype.Elem().Kind() == reflect.Interface { - return nil, fmt.Errorf("marshal can't handle []interface{}") - } tval := make([]interface{}, mval.Len(), mval.Len()) for i := 0; i < mval.Len(); i++ { val, err := e.valueToToml(mtype.Elem(), mval.Index(i)) @@ -441,7 +518,14 @@ func (e *Encoder) valueToOtherSlice(mtype reflect.Type, mval reflect.Value) (int func (e *Encoder) valueToToml(mtype reflect.Type, mval reflect.Value) (interface{}, error) { e.line++ if mtype.Kind() == reflect.Ptr { - return e.valueToToml(mtype.Elem(), mval.Elem()) + switch { + case isCustomMarshaler(mtype): + return callCustomMarshaler(mval) + case isTextMarshaler(mtype): + return callTextMarshaler(mval) + default: + return e.valueToToml(mtype.Elem(), mval.Elem()) + } } if mtype.Kind() == reflect.Interface { return e.valueToToml(mval.Elem().Type(), mval.Elem()) @@ -449,12 +533,14 @@ func (e *Encoder) valueToToml(mtype reflect.Type, mval reflect.Value) (interface switch { case isCustomMarshaler(mtype): return callCustomMarshaler(mval) + case isTextMarshaler(mtype): + return callTextMarshaler(mval) case isTree(mtype): return e.valueToTree(mtype, mval) + case isOtherSequence(mtype), isCustomMarshalerSequence(mtype), isTextMarshalerSequence(mtype): + return e.valueToOtherSlice(mtype, mval) case isTreeSequence(mtype): return e.valueToTreeSlice(mtype, mval) - case isOtherSequence(mtype): - return e.valueToOtherSlice(mtype, mval) default: switch mtype.Kind() { case reflect.Bool: @@ -543,6 +629,8 @@ type Decoder struct { tval *Tree encOpts tagName string + strict bool + visitor visitorState } // NewDecoder returns a new decoder that reads from r. @@ -573,6 +661,13 @@ func (d *Decoder) SetTagName(v string) *Decoder { return d } +// Strict allows changing to strict decoding. Any fields that are found in the +// input data and do not have a corresponding struct member cause an error. +func (d *Decoder) Strict(strict bool) *Decoder { + d.strict = strict + return d +} + func (d *Decoder) unmarshal(v interface{}) error { mtype := reflect.TypeOf(v) if mtype == nil { @@ -596,10 +691,17 @@ func (d *Decoder) unmarshal(v interface{}) error { vv := reflect.ValueOf(v).Elem() + if d.strict { + d.visitor = newVisitorState(d.tval) + } + sval, err := d.valueFromTree(elem, d.tval, &vv) if err != nil { return err } + if err := d.visitor.validate(); err != nil { + return err + } reflect.ValueOf(v).Elem().Set(sval) return nil } @@ -610,6 +712,17 @@ func (d *Decoder) valueFromTree(mtype reflect.Type, tval *Tree, mval1 *reflect.V if mtype.Kind() == reflect.Ptr { return d.unwrapPointer(mtype, tval, mval1) } + + // Check if pointer to value implements the Unmarshaler interface. + if mvalPtr := reflect.New(mtype); isCustomUnmarshaler(mvalPtr.Type()) { + d.visitor.visitAll() + + if err := callCustomUnmarshaler(mvalPtr, tval.ToMap()); err != nil { + return reflect.ValueOf(nil), fmt.Errorf("unmarshal toml: %v", err) + } + return mvalPtr.Elem(), nil + } + var mval reflect.Value switch mtype.Kind() { case reflect.Struct: @@ -641,18 +754,21 @@ func (d *Decoder) valueFromTree(mtype reflect.Type, tval *Tree, mval1 *reflect.V found := false if tval != nil { for _, key := range keysToTry { - exists := tval.Has(key) + exists := tval.HasPath([]string{key}) if !exists { continue } - val := tval.Get(key) + + d.visitor.push(key) + val := tval.GetPath([]string{key}) fval := mval.Field(i) mvalf, err := d.valueFromToml(mtypef.Type, val, &fval) if err != nil { - return mval, formatError(err, tval.GetPosition(key)) + return mval, formatError(err, tval.GetPositionPath([]string{key})) } mval.Field(i).Set(mvalf) found = true + d.visitor.pop() break } } @@ -662,32 +778,42 @@ func (d *Decoder) valueFromTree(mtype reflect.Type, tval *Tree, mval1 *reflect.V var val interface{} var err error switch mvalf.Kind() { - case reflect.Bool: - val, err = strconv.ParseBool(opts.defaultValue) - if err != nil { - return mval.Field(i), err - } - case reflect.Int: - val, err = strconv.Atoi(opts.defaultValue) - if err != nil { - return mval.Field(i), err - } case reflect.String: val = opts.defaultValue + case reflect.Bool: + val, err = strconv.ParseBool(opts.defaultValue) + case reflect.Uint: + val, err = strconv.ParseUint(opts.defaultValue, 10, 0) + case reflect.Uint8: + val, err = strconv.ParseUint(opts.defaultValue, 10, 8) + case reflect.Uint16: + val, err = strconv.ParseUint(opts.defaultValue, 10, 16) + case reflect.Uint32: + val, err = strconv.ParseUint(opts.defaultValue, 10, 32) + case reflect.Uint64: + val, err = strconv.ParseUint(opts.defaultValue, 10, 64) + case reflect.Int: + val, err = strconv.ParseInt(opts.defaultValue, 10, 0) + case reflect.Int8: + val, err = strconv.ParseInt(opts.defaultValue, 10, 8) + case reflect.Int16: + val, err = strconv.ParseInt(opts.defaultValue, 10, 16) + case reflect.Int32: + val, err = strconv.ParseInt(opts.defaultValue, 10, 32) case reflect.Int64: val, err = strconv.ParseInt(opts.defaultValue, 10, 64) - if err != nil { - return mval.Field(i), err - } + case reflect.Float32: + val, err = strconv.ParseFloat(opts.defaultValue, 32) case reflect.Float64: val, err = strconv.ParseFloat(opts.defaultValue, 64) - if err != nil { - return mval.Field(i), err - } default: - return mval.Field(i), fmt.Errorf("unsuported field type for default option") + return mvalf, fmt.Errorf("unsupported field type for default option") } - mval.Field(i).Set(reflect.ValueOf(val)) + + if err != nil { + return mvalf, err + } + mvalf.Set(reflect.ValueOf(val).Convert(mvalf.Type())) } // save the old behavior above and try to check structs @@ -696,7 +822,8 @@ func (d *Decoder) valueFromTree(mtype reflect.Type, tval *Tree, mval1 *reflect.V if !mtypef.Anonymous { tmpTval = nil } - v, err := d.valueFromTree(mtypef.Type, tmpTval, nil) + fval := mval.Field(i) + v, err := d.valueFromTree(mtypef.Type, tmpTval, &fval) if err != nil { return v, err } @@ -707,13 +834,15 @@ func (d *Decoder) valueFromTree(mtype reflect.Type, tval *Tree, mval1 *reflect.V case reflect.Map: mval = reflect.MakeMap(mtype) for _, key := range tval.Keys() { + d.visitor.push(key) // TODO: path splits key val := tval.GetPath([]string{key}) mvalf, err := d.valueFromToml(mtype.Elem(), val, nil) if err != nil { - return mval, formatError(err, tval.GetPosition(key)) + return mval, formatError(err, tval.GetPositionPath([]string{key})) } mval.SetMapIndex(reflect.ValueOf(key).Convert(mtype.Key()), mvalf) + d.visitor.pop() } } return mval, nil @@ -721,20 +850,30 @@ func (d *Decoder) valueFromTree(mtype reflect.Type, tval *Tree, mval1 *reflect.V // Convert toml value to marshal struct/map slice, using marshal type func (d *Decoder) valueFromTreeSlice(mtype reflect.Type, tval []*Tree) (reflect.Value, error) { - mval := reflect.MakeSlice(mtype, len(tval), len(tval)) + mval, err := makeSliceOrArray(mtype, len(tval)) + if err != nil { + return mval, err + } + for i := 0; i < len(tval); i++ { + d.visitor.push(strconv.Itoa(i)) val, err := d.valueFromTree(mtype.Elem(), tval[i], nil) if err != nil { return mval, err } mval.Index(i).Set(val) + d.visitor.pop() } return mval, nil } // Convert toml value to marshal primitive slice, using marshal type func (d *Decoder) valueFromOtherSlice(mtype reflect.Type, tval []interface{}) (reflect.Value, error) { - mval := reflect.MakeSlice(mtype, len(tval), len(tval)) + mval, err := makeSliceOrArray(mtype, len(tval)) + if err != nil { + return mval, err + } + for i := 0; i < len(tval); i++ { val, err := d.valueFromToml(mtype.Elem(), tval[i], nil) if err != nil { @@ -748,10 +887,14 @@ func (d *Decoder) valueFromOtherSlice(mtype reflect.Type, tval []interface{}) (r // Convert toml value to marshal primitive slice, using marshal type func (d *Decoder) valueFromOtherSliceI(mtype reflect.Type, tval interface{}) (reflect.Value, error) { val := reflect.ValueOf(tval) + length := val.Len() - lenght := val.Len() - mval := reflect.MakeSlice(mtype, lenght, lenght) - for i := 0; i < lenght; i++ { + mval, err := makeSliceOrArray(mtype, length) + if err != nil { + return mval, err + } + + for i := 0; i < length; i++ { val, err := d.valueFromToml(mtype.Elem(), val.Index(i).Interface(), nil) if err != nil { return mval, err @@ -761,6 +904,21 @@ func (d *Decoder) valueFromOtherSliceI(mtype reflect.Type, tval interface{}) (re return mval, nil } +// Create a new slice or a new array with specified length +func makeSliceOrArray(mtype reflect.Type, tLength int) (reflect.Value, error) { + var mval reflect.Value + switch mtype.Kind() { + case reflect.Slice: + mval = reflect.MakeSlice(mtype, tLength, tLength) + case reflect.Array: + mval = reflect.New(reflect.ArrayOf(mtype.Len(), mtype.Elem())).Elem() + if tLength > mtype.Len() { + return mval, fmt.Errorf("unmarshal: TOML array length (%v) exceeds destination array length (%v)", tLength, mtype.Len()) + } + } + return mval, nil +} + // Convert toml value to marshal value, using marshal type. When mval1 is non-nil // and the given type is a struct value, merge fields into it. func (d *Decoder) valueFromToml(mtype reflect.Type, tval interface{}, mval1 *reflect.Value) (reflect.Value, error) { @@ -802,6 +960,7 @@ func (d *Decoder) valueFromToml(mtype reflect.Type, tval interface{}, mval1 *ref } return reflect.ValueOf(nil), fmt.Errorf("Can't convert %v(%T) to trees", tval, tval) case []interface{}: + d.visitor.visit() if isOtherSequence(mtype) { return d.valueFromOtherSlice(mtype, t) } @@ -815,6 +974,15 @@ func (d *Decoder) valueFromToml(mtype reflect.Type, tval interface{}, mval1 *ref } return reflect.ValueOf(nil), fmt.Errorf("Can't convert %v(%T) to a slice", tval, tval) default: + d.visitor.visit() + // Check if pointer to value implements the encoding.TextUnmarshaler. + if mvalPtr := reflect.New(mtype); isTextUnmarshaler(mvalPtr.Type()) && !isTimeType(mtype) { + if err := d.unmarshalText(tval, mvalPtr); err != nil { + return reflect.ValueOf(nil), fmt.Errorf("unmarshal text: %v", err) + } + return mvalPtr.Elem(), nil + } + switch mtype.Kind() { case reflect.Bool, reflect.Struct: val := reflect.ValueOf(tval) @@ -865,34 +1033,34 @@ func (d *Decoder) valueFromToml(mtype reflect.Type, tval interface{}, mval1 *ref } return reflect.ValueOf(d), nil } - if !val.Type().ConvertibleTo(mtype) { + if !val.Type().ConvertibleTo(mtype) || val.Kind() == reflect.Float64 { return reflect.ValueOf(nil), fmt.Errorf("Can't convert %v(%T) to %v", tval, tval, mtype.String()) } - if reflect.Indirect(reflect.New(mtype)).OverflowInt(val.Convert(mtype).Int()) { + if reflect.Indirect(reflect.New(mtype)).OverflowInt(val.Convert(reflect.TypeOf(int64(0))).Int()) { return reflect.ValueOf(nil), fmt.Errorf("%v(%T) would overflow %v", tval, tval, mtype.String()) } return val.Convert(mtype), nil case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr: val := reflect.ValueOf(tval) - if !val.Type().ConvertibleTo(mtype) { + if !val.Type().ConvertibleTo(mtype) || val.Kind() == reflect.Float64 { return reflect.ValueOf(nil), fmt.Errorf("Can't convert %v(%T) to %v", tval, tval, mtype.String()) } if val.Convert(reflect.TypeOf(int(1))).Int() < 0 { return reflect.ValueOf(nil), fmt.Errorf("%v(%T) is negative so does not fit in %v", tval, tval, mtype.String()) } - if reflect.Indirect(reflect.New(mtype)).OverflowUint(uint64(val.Convert(mtype).Uint())) { + if reflect.Indirect(reflect.New(mtype)).OverflowUint(val.Convert(reflect.TypeOf(uint64(0))).Uint()) { return reflect.ValueOf(nil), fmt.Errorf("%v(%T) would overflow %v", tval, tval, mtype.String()) } return val.Convert(mtype), nil case reflect.Float32, reflect.Float64: val := reflect.ValueOf(tval) - if !val.Type().ConvertibleTo(mtype) { + if !val.Type().ConvertibleTo(mtype) || val.Kind() == reflect.Int64 { return reflect.ValueOf(nil), fmt.Errorf("Can't convert %v(%T) to %v", tval, tval, mtype.String()) } - if reflect.Indirect(reflect.New(mtype)).OverflowFloat(val.Convert(mtype).Float()) { + if reflect.Indirect(reflect.New(mtype)).OverflowFloat(val.Convert(reflect.TypeOf(float64(0))).Float()) { return reflect.ValueOf(nil), fmt.Errorf("%v(%T) would overflow %v", tval, tval, mtype.String()) } @@ -904,7 +1072,7 @@ func (d *Decoder) valueFromToml(mtype reflect.Type, tval interface{}, mval1 *ref ival := mval1.Elem() return d.valueFromToml(mval1.Elem().Type(), t, &ival) } - case reflect.Slice: + case reflect.Slice, reflect.Array: if isOtherSequence(mtype) && isOtherSequence(reflect.TypeOf(t)) { return d.valueFromOtherSliceI(mtype, t) } @@ -932,6 +1100,12 @@ func (d *Decoder) unwrapPointer(mtype reflect.Type, tval interface{}, mval1 *ref return mval, nil } +func (d *Decoder) unmarshalText(tval interface{}, mval reflect.Value) error { + var buf bytes.Buffer + fmt.Fprint(&buf, tval) + return callTextUnmarshaler(mval, buf.Bytes()) +} + func tomlOptions(vf reflect.StructField, an annotation) tomlOpts { tag := vf.Tag.Get(an.tag) parse := strings.Split(tag, ",") @@ -974,11 +1148,7 @@ func tomlOptions(vf reflect.StructField, an annotation) tomlOpts { func isZero(val reflect.Value) bool { switch val.Type().Kind() { - case reflect.Map: - fallthrough - case reflect.Array: - fallthrough - case reflect.Slice: + case reflect.Slice, reflect.Array, reflect.Map: return val.Len() == 0 default: return reflect.DeepEqual(val.Interface(), reflect.Zero(val.Type()).Interface()) @@ -991,3 +1161,80 @@ func formatError(err error, pos Position) error { } return fmt.Errorf("%s: %s", pos, err) } + +// visitorState keeps track of which keys were unmarshaled. +type visitorState struct { + tree *Tree + path []string + keys map[string]struct{} + active bool +} + +func newVisitorState(tree *Tree) visitorState { + path, result := []string{}, map[string]struct{}{} + insertKeys(path, result, tree) + return visitorState{ + tree: tree, + path: path[:0], + keys: result, + active: true, + } +} + +func (s *visitorState) push(key string) { + if s.active { + s.path = append(s.path, key) + } +} + +func (s *visitorState) pop() { + if s.active { + s.path = s.path[:len(s.path)-1] + } +} + +func (s *visitorState) visit() { + if s.active { + delete(s.keys, strings.Join(s.path, ".")) + } +} + +func (s *visitorState) visitAll() { + if s.active { + for k := range s.keys { + if strings.HasPrefix(k, strings.Join(s.path, ".")) { + delete(s.keys, k) + } + } + } +} + +func (s *visitorState) validate() error { + if !s.active { + return nil + } + undecoded := make([]string, 0, len(s.keys)) + for key := range s.keys { + undecoded = append(undecoded, key) + } + sort.Strings(undecoded) + if len(undecoded) > 0 { + return fmt.Errorf("undecoded keys: %q", undecoded) + } + return nil +} + +func insertKeys(path []string, m map[string]struct{}, tree *Tree) { + for k, v := range tree.values { + switch node := v.(type) { + case []*Tree: + for i, item := range node { + insertKeys(append(path, k, strconv.Itoa(i)), m, item) + } + case *Tree: + insertKeys(append(path, k), m, node) + case *tomlValue: + m[strings.Join(append(path, k), ".")] = struct{}{} + } + } +} diff --git a/marshal_test.go b/marshal_test.go index 04c90ef..c5a440d 100644 --- a/marshal_test.go +++ b/marshal_test.go @@ -7,6 +7,7 @@ import ( "io/ioutil" "os" "reflect" + "strconv" "strings" "testing" "time" @@ -30,7 +31,7 @@ var basicTestData = basicMarshalTestStruct{ SubList: []basicMarshalTestSubStruct{{"Two"}, {"Three"}}, } -var basicTestToml = []byte(`Ystrlist = ["Howdy","Hey There"] +var basicTestToml = []byte(`Ystrlist = ["Howdy", "Hey There"] Zstring = "Hello" [[Wsublist]] @@ -43,8 +44,21 @@ Zstring = "Hello" String2 = "One" `) +var basicTestTomlCustomIndentation = []byte(`Ystrlist = ["Howdy", "Hey There"] +Zstring = "Hello" + +[[Wsublist]] + String2 = "Two" + +[[Wsublist]] + String2 = "Three" + +[Xsubdoc] + String2 = "One" +`) + var basicTestTomlOrdered = []byte(`Zstring = "Hello" -Ystrlist = ["Howdy","Hey There"] +Ystrlist = ["Howdy", "Hey There"] [Xsubdoc] String2 = "One" @@ -68,12 +82,12 @@ var marshalTestToml = []byte(`title = "TOML Marshal Testing" uint = 5001 [basic_lists] - bools = [true,false,true] - dates = [1979-05-27T07:32:00Z,1980-05-27T07:32:00Z] - floats = [12.3,45.6,78.9] - ints = [8001,8001,8002] - strings = ["One","Two","Three"] - uints = [5002,5003] + bools = [true, false, true] + dates = [1979-05-27T07:32:00Z, 1980-05-27T07:32:00Z] + floats = [12.3, 45.6, 78.9] + ints = [8001, 8001, 8002] + strings = ["One", "Two", "Three"] + uints = [5002, 5003] [basic_map] one = "one" @@ -100,12 +114,12 @@ var marshalTestToml = []byte(`title = "TOML Marshal Testing" var marshalOrderPreserveToml = []byte(`title = "TOML Marshal Testing" [basic_lists] - floats = [12.3,45.6,78.9] - bools = [true,false,true] - dates = [1979-05-27T07:32:00Z,1980-05-27T07:32:00Z] - ints = [8001,8001,8002] - uints = [5002,5003] - strings = ["One","Two","Three"] + floats = [12.3, 45.6, 78.9] + bools = [true, false, true] + dates = [1979-05-27T07:32:00Z, 1980-05-27T07:32:00Z] + ints = [8001, 8001, 8002] + uints = [5002, 5003] + strings = ["One", "Two", "Three"] [[subdocptrs]] name = "Second" @@ -206,6 +220,26 @@ func TestBasicMarshal(t *testing.T) { } } +func TestBasicMarshalCustomIndentation(t *testing.T) { + var result bytes.Buffer + err := NewEncoder(&result).Indentation("\t").Encode(basicTestData) + if err != nil { + t.Fatal(err) + } + expected := basicTestTomlCustomIndentation + if !bytes.Equal(result.Bytes(), expected) { + t.Errorf("Bad marshal: expected\n-----\n%s\n-----\ngot\n-----\n%s\n-----\n", expected, result.Bytes()) + } +} + +func TestBasicMarshalWrongIndentation(t *testing.T) { + var result bytes.Buffer + err := NewEncoder(&result).Indentation(" \n").Encode(basicTestData) + if err.Error() != "invalid indentation: must only contains space or tab characters" { + t.Error("expect err:invalid indentation: must only contains space or tab characters but got:", err) + } +} + func TestBasicMarshalOrdered(t *testing.T) { var result bytes.Buffer err := NewEncoder(&result).Order(OrderPreserve).Encode(basicTestData) @@ -253,6 +287,59 @@ func TestBasicUnmarshal(t *testing.T) { } } +type quotedKeyMarshalTestStruct struct { + String string `toml:"Z.string-àéù"` + Float float64 `toml:"Yfloat-𝟘"` + Sub basicMarshalTestSubStruct `toml:"Xsubdoc-àéù"` + SubList []basicMarshalTestSubStruct `toml:"W.sublist-𝟘"` +} + +var quotedKeyMarshalTestData = quotedKeyMarshalTestStruct{ + String: "Hello", + Float: 3.5, + Sub: basicMarshalTestSubStruct{"One"}, + SubList: []basicMarshalTestSubStruct{{"Two"}, {"Three"}}, +} + +var quotedKeyMarshalTestToml = []byte(`"Yfloat-𝟘" = 3.5 +"Z.string-àéù" = "Hello" + +[["W.sublist-𝟘"]] + String2 = "Two" + +[["W.sublist-𝟘"]] + String2 = "Three" + +["Xsubdoc-àéù"] + String2 = "One" +`) + +func TestBasicMarshalQuotedKey(t *testing.T) { + result, err := Marshal(quotedKeyMarshalTestData) + if err != nil { + t.Fatal(err) + } + expected := quotedKeyMarshalTestToml + if !bytes.Equal(result, expected) { + t.Errorf("Bad marshal: expected\n-----\n%s\n-----\ngot\n-----\n%s\n-----\n", expected, result) + } +} + +func TestBasicUnmarshalQuotedKey(t *testing.T) { + tree, err := LoadBytes(quotedKeyMarshalTestToml) + if err != nil { + t.Fatal(err) + } + + var q quotedKeyMarshalTestStruct + tree.Unmarshal(&q) + fmt.Println(q) + + if !reflect.DeepEqual(quotedKeyMarshalTestData, q) { + t.Errorf("Bad unmarshal: expected\n-----\n%v\n-----\ngot\n-----\n%v\n-----\n", quotedKeyMarshalTestData, q) + } +} + type testDoc struct { Title string `toml:"title"` BasicLists testDocBasicLists `toml:"basic_lists"` @@ -826,8 +913,8 @@ var nestedTestData = nestedMarshalTestStruct{ StringPtr: &strPtr2, } -var nestedTestToml = []byte(`String = [["Five","Six"],["One","Two"]] -StringPtr = [["Three","Four"]] +var nestedTestToml = []byte(`String = [["Five", "Six"], ["One", "Two"]] +StringPtr = [["Three", "Four"]] `) func TestNestedMarshal(t *testing.T) { @@ -859,24 +946,27 @@ type customMarshalerParent struct { } type customMarshaler struct { - FirsName string - LastName string + FirstName string + LastName string } func (c customMarshaler) MarshalTOML() ([]byte, error) { - fullName := fmt.Sprintf("%s %s", c.FirsName, c.LastName) + fullName := fmt.Sprintf("%s %s", c.FirstName, c.LastName) return []byte(fullName), nil } -var customMarshalerData = customMarshaler{FirsName: "Sally", LastName: "Fields"} +var customMarshalerData = customMarshaler{FirstName: "Sally", LastName: "Fields"} var customMarshalerToml = []byte(`Sally Fields`) var nestedCustomMarshalerData = customMarshalerParent{ - Self: customMarshaler{FirsName: "Maiku", LastName: "Suteda"}, + Self: customMarshaler{FirstName: "Maiku", LastName: "Suteda"}, Friends: []customMarshaler{customMarshalerData}, } var nestedCustomMarshalerToml = []byte(`friends = ["Sally Fields"] me = "Maiku Suteda" `) +var nestedCustomMarshalerTomlForUnmarshal = []byte(`[friends] +FirstName = "Sally" +LastName = "Fields"`) func TestCustomMarshaler(t *testing.T) { result, err := Marshal(customMarshalerData) @@ -889,14 +979,172 @@ func TestCustomMarshaler(t *testing.T) { } } -func TestNestedCustomMarshaler(t *testing.T) { - result, err := Marshal(nestedCustomMarshalerData) +type textMarshaler struct { + FirstName string + LastName string +} + +func (m textMarshaler) MarshalText() ([]byte, error) { + fullName := fmt.Sprintf("%s %s", m.FirstName, m.LastName) + return []byte(fullName), nil +} + +func TestTextMarshaler(t *testing.T) { + m := textMarshaler{FirstName: "Sally", LastName: "Fields"} + + result, err := Marshal(m) if err != nil { t.Fatal(err) } - expected := nestedCustomMarshalerToml - if !bytes.Equal(result, expected) { - t.Errorf("Bad nested custom marshaler: expected\n-----\n%s\n-----\ngot\n-----\n%s\n-----\n", expected, result) + expected := `Sally Fields` + if !bytes.Equal(result, []byte(expected)) { + t.Errorf("Bad text marshaler: expected\n-----\n%s\n-----\ngot\n-----\n%s\n-----\n", expected, result) + } +} + +func TestUnmarshalTextMarshaler(t *testing.T) { + var nested = struct { + Friends textMarshaler `toml:"friends"` + }{} + + var expected = struct { + Friends textMarshaler `toml:"friends"` + }{ + Friends: textMarshaler{FirstName: "Sally", LastName: "Fields"}, + } + + err := Unmarshal(nestedCustomMarshalerTomlForUnmarshal, &nested) + if err != nil { + t.Fatal(err) + } + if !reflect.DeepEqual(nested, expected) { + t.Errorf("Bad unmarshal: expected %v, got %v", expected, nested) + } +} + +func TestNestedTextMarshaler(t *testing.T) { + var parent = struct { + Self textMarshaler `toml:"me"` + Friends []textMarshaler `toml:"friends"` + Stranger *textMarshaler `toml:"stranger"` + }{ + Self: textMarshaler{FirstName: "Maiku", LastName: "Suteda"}, + Friends: []textMarshaler{textMarshaler{FirstName: "Sally", LastName: "Fields"}}, + Stranger: &textMarshaler{FirstName: "Earl", LastName: "Henson"}, + } + + result, err := Marshal(parent) + if err != nil { + t.Fatal(err) + } + expected := `friends = ["Sally Fields"] +me = "Maiku Suteda" +stranger = "Earl Henson" +` + if !bytes.Equal(result, []byte(expected)) { + t.Errorf("Bad nested text marshaler: expected\n-----\n%s\n-----\ngot\n-----\n%s\n-----\n", expected, result) + } +} + +type precedentMarshaler struct { + FirstName string + LastName string +} + +func (m precedentMarshaler) MarshalText() ([]byte, error) { + return []byte("shadowed"), nil +} + +func (m precedentMarshaler) MarshalTOML() ([]byte, error) { + fullName := fmt.Sprintf("%s %s", m.FirstName, m.LastName) + return []byte(fullName), nil +} + +func TestPrecedentMarshaler(t *testing.T) { + m := textMarshaler{FirstName: "Sally", LastName: "Fields"} + + result, err := Marshal(m) + if err != nil { + t.Fatal(err) + } + expected := `Sally Fields` + if !bytes.Equal(result, []byte(expected)) { + t.Errorf("Bad text marshaler: expected\n-----\n%s\n-----\ngot\n-----\n%s\n-----\n", expected, result) + } +} + +type customPointerMarshaler struct { + FirstName string + LastName string +} + +func (m *customPointerMarshaler) MarshalTOML() ([]byte, error) { + return []byte("hidden"), nil +} + +type textPointerMarshaler struct { + FirstName string + LastName string +} + +func (m *textPointerMarshaler) MarshalText() ([]byte, error) { + return []byte("hidden"), nil +} + +func TestPointerMarshaler(t *testing.T) { + var parent = struct { + Self customPointerMarshaler `toml:"me"` + Stranger *customPointerMarshaler `toml:"stranger"` + Friend textPointerMarshaler `toml:"friend"` + Fiend *textPointerMarshaler `toml:"fiend"` + }{ + Self: customPointerMarshaler{FirstName: "Maiku", LastName: "Suteda"}, + Stranger: &customPointerMarshaler{FirstName: "Earl", LastName: "Henson"}, + Friend: textPointerMarshaler{FirstName: "Sally", LastName: "Fields"}, + Fiend: &textPointerMarshaler{FirstName: "Casper", LastName: "Snider"}, + } + + result, err := Marshal(parent) + if err != nil { + t.Fatal(err) + } + expected := `fiend = "hidden" +stranger = "hidden" + +[friend] + FirstName = "Sally" + LastName = "Fields" + +[me] + FirstName = "Maiku" + LastName = "Suteda" +` + if !bytes.Equal(result, []byte(expected)) { + t.Errorf("Bad nested text marshaler: expected\n-----\n%s\n-----\ngot\n-----\n%s\n-----\n", expected, result) + } +} + +func TestPointerCustomMarshalerSequence(t *testing.T) { + var customPointerMarshalerSlice *[]*customPointerMarshaler + var customPointerMarshalerArray *[2]*customPointerMarshaler + + if !isCustomMarshalerSequence(reflect.TypeOf(customPointerMarshalerSlice)) { + t.Errorf("error: should be a sequence of custom marshaler interfaces") + } + if !isCustomMarshalerSequence(reflect.TypeOf(customPointerMarshalerArray)) { + t.Errorf("error: should be a sequence of custom marshaler interfaces") + } +} + +func TestPointerTextMarshalerSequence(t *testing.T) { + var textPointerMarshalerSlice *[]*textPointerMarshaler + var textPointerMarshalerArray *[2]*textPointerMarshaler + + if !isTextMarshalerSequence(reflect.TypeOf(textPointerMarshalerSlice)) { + t.Errorf("error: should be a sequence of text marshaler interfaces") + } + if !isTextMarshalerSequence(reflect.TypeOf(textPointerMarshalerArray)) { + t.Errorf("error: should be a sequence of text marshaler interfaces") } } @@ -1243,7 +1491,7 @@ type structArrayNoTag struct { func TestMarshalArray(t *testing.T) { expected := []byte(` [A] - B = [1,2,3] + B = [1, 2, 3] C = [1] `) @@ -1740,6 +1988,58 @@ func TestMarshalSlicePointer(t *testing.T) { } } +func TestMarshalNestedArrayInlineTables(t *testing.T) { + type table struct { + Value1 int `toml:"ZValue1"` + Value2 int `toml:"YValue2"` + Value3 int `toml:"XValue3"` + } + + type nestedTable struct { + Table table + } + + nestedArray := struct { + Simple [][]table + SimplePointer *[]*[]table + Nested [][]nestedTable + NestedPointer *[]*[]nestedTable + }{ + Simple: [][]table{{{Value1: 1}, {Value1: 10}}}, + SimplePointer: &[]*[]table{{{Value2: 2}}}, + Nested: [][]nestedTable{{{Table: table{Value3: 3}}}}, + NestedPointer: &[]*[]nestedTable{{{Table: table{Value3: -3}}}}, + } + + expectedPreserve := `Simple = [[{ ZValue1 = 1, YValue2 = 0, XValue3 = 0 }, { ZValue1 = 10, YValue2 = 0, XValue3 = 0 }]] +SimplePointer = [[{ ZValue1 = 0, YValue2 = 2, XValue3 = 0 }]] +Nested = [[{ Table = { ZValue1 = 0, YValue2 = 0, XValue3 = 3 } }]] +NestedPointer = [[{ Table = { ZValue1 = 0, YValue2 = 0, XValue3 = -3 } }]] +` + + expectedAlphabetical := `Nested = [[{ Table = { XValue3 = 3, YValue2 = 0, ZValue1 = 0 } }]] +NestedPointer = [[{ Table = { XValue3 = -3, YValue2 = 0, ZValue1 = 0 } }]] +Simple = [[{ XValue3 = 0, YValue2 = 0, ZValue1 = 1 }, { XValue3 = 0, YValue2 = 0, ZValue1 = 10 }]] +SimplePointer = [[{ XValue3 = 0, YValue2 = 2, ZValue1 = 0 }]] +` + + var bufPreserve bytes.Buffer + if err := NewEncoder(&bufPreserve).Order(OrderPreserve).Encode(nestedArray); err != nil { + t.Fatalf("unexpected error: %s", err.Error()) + } + if !bytes.Equal(bufPreserve.Bytes(), []byte(expectedPreserve)) { + t.Errorf("Bad marshal: expected\n-----\n%s\n-----\ngot\n-----\n%s\n-----\n", expectedPreserve, bufPreserve.String()) + } + + var bufAlphabetical bytes.Buffer + if err := NewEncoder(&bufAlphabetical).Order(OrderAlphabetical).Encode(nestedArray); err != nil { + t.Fatalf("unexpected error: %s", err.Error()) + } + if !bytes.Equal(bufAlphabetical.Bytes(), []byte(expectedAlphabetical)) { + t.Errorf("Bad marshal: expected\n-----\n%s\n-----\ngot\n-----\n%s\n-----\n", expectedAlphabetical, bufAlphabetical.String()) + } +} + type testDuration struct { Nanosec time.Duration `toml:"nanosec"` Microsec1 time.Duration `toml:"microsec1"` @@ -1875,21 +2175,95 @@ func TestUnmarshalCamelCaseKey(t *testing.T) { } } +func TestUnmarshalNegativeUint(t *testing.T) { + type check struct{ U uint } + + tree, _ := Load("u = -1") + err := tree.Unmarshal(&check{}) + if err.Error() != "(1, 1): -1(int64) is negative so does not fit in uint" { + t.Error("expect err:(1, 1): -1(int64) is negative so does not fit in uint but got:", err) + } +} + +func TestUnmarshalCheckConversionFloatInt(t *testing.T) { + type conversionCheck struct { + U uint + I int + F float64 + } + + treeU, _ := Load("u = 1e300") + treeI, _ := Load("i = 1e300") + treeF, _ := Load("f = 9223372036854775806") + + errU := treeU.Unmarshal(&conversionCheck{}) + errI := treeI.Unmarshal(&conversionCheck{}) + errF := treeF.Unmarshal(&conversionCheck{}) + + if errU.Error() != "(1, 1): Can't convert 1e+300(float64) to uint" { + t.Error("expect err:(1, 1): Can't convert 1e+300(float64) to uint but got:", errU) + } + if errI.Error() != "(1, 1): Can't convert 1e+300(float64) to int" { + t.Error("expect err:(1, 1): Can't convert 1e+300(float64) to int but got:", errI) + } + if errF.Error() != "(1, 1): Can't convert 9223372036854775806(int64) to float64" { + t.Error("expect err:(1, 1): Can't convert 9223372036854775806(int64) to float64 but got:", errF) + } +} + +func TestUnmarshalOverflow(t *testing.T) { + type overflow struct { + U8 uint8 + I8 int8 + F32 float32 + } + + treeU8, _ := Load("u8 = 300") + treeI8, _ := Load("i8 = 300") + treeF32, _ := Load("f32 = 1e300") + + errU8 := treeU8.Unmarshal(&overflow{}) + errI8 := treeI8.Unmarshal(&overflow{}) + errF32 := treeF32.Unmarshal(&overflow{}) + + if errU8.Error() != "(1, 1): 300(int64) would overflow uint8" { + t.Error("expect err:(1, 1): 300(int64) would overflow uint8 but got:", errU8) + } + if errI8.Error() != "(1, 1): 300(int64) would overflow int8" { + t.Error("expect err:(1, 1): 300(int64) would overflow int8 but got:", errI8) + } + if errF32.Error() != "(1, 1): 1e+300(float64) would overflow float32" { + t.Error("expect err:(1, 1): 1e+300(float64) would overflow float32 but got:", errF32) + } +} + func TestUnmarshalDefault(t *testing.T) { type EmbeddedStruct struct { StringField string `default:"c"` } + type aliasUint uint + var doc struct { StringField string `default:"a"` BoolField bool `default:"true"` - IntField int `default:"1"` - Int64Field int64 `default:"2"` - Float64Field float64 `default:"3.1"` + UintField uint `default:"1"` + Uint8Field uint8 `default:"8"` + Uint16Field uint16 `default:"16"` + Uint32Field uint32 `default:"32"` + Uint64Field uint64 `default:"64"` + IntField int `default:"-1"` + Int8Field int8 `default:"-8"` + Int16Field int16 `default:"-16"` + Int32Field int32 `default:"-32"` + Int64Field int64 `default:"-64"` + Float32Field float32 `default:"32.1"` + Float64Field float64 `default:"64.1"` NonEmbeddedStruct struct { StringField string `default:"b"` } EmbeddedStruct + AliasUintField aliasUint `default:"1000"` } err := Unmarshal([]byte(``), &doc) @@ -1902,14 +2276,41 @@ func TestUnmarshalDefault(t *testing.T) { if doc.StringField != "a" { t.Errorf("StringField should be \"a\", not %s", doc.StringField) } - if doc.IntField != 1 { - t.Errorf("IntField should be 1, not %d", doc.IntField) + if doc.UintField != 1 { + t.Errorf("UintField should be 1, not %d", doc.UintField) } - if doc.Int64Field != 2 { - t.Errorf("Int64Field should be 2, not %d", doc.Int64Field) + if doc.Uint8Field != 8 { + t.Errorf("Uint8Field should be 8, not %d", doc.Uint8Field) } - if doc.Float64Field != 3.1 { - t.Errorf("Float64Field should be 3.1, not %f", doc.Float64Field) + if doc.Uint16Field != 16 { + t.Errorf("Uint16Field should be 16, not %d", doc.Uint16Field) + } + if doc.Uint32Field != 32 { + t.Errorf("Uint32Field should be 32, not %d", doc.Uint32Field) + } + if doc.Uint64Field != 64 { + t.Errorf("Uint64Field should be 64, not %d", doc.Uint64Field) + } + if doc.IntField != -1 { + t.Errorf("IntField should be -1, not %d", doc.IntField) + } + if doc.Int8Field != -8 { + t.Errorf("Int8Field should be -8, not %d", doc.Int8Field) + } + if doc.Int16Field != -16 { + t.Errorf("Int16Field should be -16, not %d", doc.Int16Field) + } + if doc.Int32Field != -32 { + t.Errorf("Int32Field should be -32, not %d", doc.Int32Field) + } + if doc.Int64Field != -64 { + t.Errorf("Int64Field should be -64, not %d", doc.Int64Field) + } + if doc.Float32Field != 32.1 { + t.Errorf("Float32Field should be 32.1, not %f", doc.Float32Field) + } + if doc.Float64Field != 64.1 { + t.Errorf("Float64Field should be 64.1, not %f", doc.Float64Field) } if doc.NonEmbeddedStruct.StringField != "b" { t.Errorf("StringField should be \"b\", not %s", doc.NonEmbeddedStruct.StringField) @@ -1917,6 +2318,9 @@ func TestUnmarshalDefault(t *testing.T) { if doc.EmbeddedStruct.StringField != "c" { t.Errorf("StringField should be \"c\", not %s", doc.EmbeddedStruct.StringField) } + if doc.AliasUintField != 1000 { + t.Errorf("AliasUintField should be 1000, not %d", doc.AliasUintField) + } } func TestUnmarshalDefaultFailureBool(t *testing.T) { @@ -2144,7 +2548,7 @@ func TestUnmarshalPreservesUnexportedFields(t *testing.T) { [[slice1]] exported1 = "visible3" - + [[slice1]] exported1 = "visible4" @@ -2250,7 +2654,7 @@ func TestMarshalArrays(t *testing.T) { }{ XY: [2]int{1, 2}, }, - Expected: `XY = [1,2] + Expected: `XY = [1, 2] `, }, { @@ -2259,7 +2663,7 @@ func TestMarshalArrays(t *testing.T) { }{ XY: [1][2]int{{1, 2}}, }, - Expected: `XY = [[1,2]] + Expected: `XY = [[1, 2]] `, }, { @@ -2268,7 +2672,7 @@ func TestMarshalArrays(t *testing.T) { }{ XY: [1][]int{{1, 2}}, }, - Expected: `XY = [[1,2]] + Expected: `XY = [[1, 2]] `, }, { @@ -2277,7 +2681,7 @@ func TestMarshalArrays(t *testing.T) { }{ XY: [][2]int{{1, 2}}, }, - Expected: `XY = [[1,2]] + Expected: `XY = [[1, 2]] `, }, } @@ -2695,11 +3099,7 @@ func TestMarshalInterface(t *testing.T) { InterfacePointerField *interface{} } - type ShouldNotSupportStruct struct { - InterfaceArray []interface{} - } - - expected := []byte(`ArrayField = [1,2,3] + expected := []byte(`ArrayField = [1, 2, 3] InterfacePointerField = "hello world" PrimitiveField = "string" @@ -2749,11 +3149,6 @@ PrimitiveField = "string" } else { t.Fatal(err) } - - // according to the toml standard, data types of array may not be mixed - if _, err := Marshal(ShouldNotSupportStruct{[]interface{}{1, "a", true}}); err == nil { - t.Errorf("Should not support []interface{} marshaling") - } } func TestUnmarshalToNilInterface(t *testing.T) { @@ -3007,6 +3402,20 @@ type sliceStruct struct { StructSlicePtr *[]basicMarshalTestSubStruct ` toml:"struct_slice_ptr" ` } +type arrayStruct struct { + Slice [4]string ` toml:"str_slice" ` + SlicePtr *[4]string ` toml:"str_slice_ptr" ` + IntSlice [4]int ` toml:"int_slice" ` + IntSlicePtr *[4]int ` toml:"int_slice_ptr" ` + StructSlice [4]basicMarshalTestSubStruct ` toml:"struct_slice" ` + StructSlicePtr *[4]basicMarshalTestSubStruct ` toml:"struct_slice_ptr" ` +} + +type arrayTooSmallStruct struct { + Slice [1]string ` toml:"str_slice" ` + StructSlice [1]basicMarshalTestSubStruct ` toml:"struct_slice" ` +} + func TestUnmarshalSlice(t *testing.T) { tree, _ := LoadBytes(sliceTomlDemo) tree, _ = TreeFromMap(tree.ToMap()) @@ -3052,3 +3461,390 @@ func TestUnmarshalSliceFail2(t *testing.T) { } } + +func TestMarshalMixedTypeArray(t *testing.T) { + type InnerStruct struct { + IntField int + StrField string + } + + type TestStruct struct { + ArrayField []interface{} + } + + expected := []byte(`ArrayField = [3.14, 100, true, "hello world", { IntField = 100, StrField = "inner1" }, [{ IntField = 200, StrField = "inner2" }, { IntField = 300, StrField = "inner3" }]] +`) + + if result, err := Marshal(TestStruct{ + ArrayField: []interface{}{ + 3.14, + 100, + true, + "hello world", + InnerStruct{ + IntField: 100, + StrField: "inner1", + }, + []InnerStruct{ + {IntField: 200, StrField: "inner2"}, + {IntField: 300, StrField: "inner3"}, + }, + }, + }); err == nil { + if !bytes.Equal(result, expected) { + t.Errorf("Bad marshal: expected\n----\n%s\n----\ngot\n----\n%s\n----\n", expected, result) + } + } else { + t.Fatal(err) + } +} + +func TestUnmarshalMixedTypeArray(t *testing.T) { + type TestStruct struct { + ArrayField []interface{} + } + + toml := []byte(`ArrayField = [3.14,100,true,"hello world",{Field = "inner1"},[{Field = "inner2"},{Field = "inner3"}]] +`) + + actual := TestStruct{} + expected := TestStruct{ + ArrayField: []interface{}{ + 3.14, + int64(100), + true, + "hello world", + map[string]interface{}{ + "Field": "inner1", + }, + []map[string]interface{}{ + {"Field": "inner2"}, + {"Field": "inner3"}, + }, + }, + } + + if err := Unmarshal(toml, &actual); err == nil { + if !reflect.DeepEqual(actual, expected) { + t.Errorf("Bad unmarshal: expected %#v, got %#v", expected, actual) + } + } else { + t.Fatal(err) + } +} + +func TestUnmarshalArray(t *testing.T) { + var tree *Tree + var err error + + tree, _ = LoadBytes(sliceTomlDemo) + var actual1 arrayStruct + err = tree.Unmarshal(&actual1) + if err != nil { + t.Error("shound not err", err) + } + + tree, _ = TreeFromMap(tree.ToMap()) + var actual2 arrayStruct + err = tree.Unmarshal(&actual2) + if err != nil { + t.Error("shound not err", err) + } + + expected := arrayStruct{ + Slice: [4]string{"Howdy", "Hey There"}, + SlicePtr: &[4]string{"Howdy", "Hey There"}, + IntSlice: [4]int{1, 2}, + IntSlicePtr: &[4]int{1, 2}, + StructSlice: [4]basicMarshalTestSubStruct{{"1"}, {"2"}}, + StructSlicePtr: &[4]basicMarshalTestSubStruct{{"1"}, {"2"}}, + } + if !reflect.DeepEqual(actual1, expected) { + t.Errorf("Bad unmarshal: expected %v, got %v", expected, actual1) + } + if !reflect.DeepEqual(actual2, expected) { + t.Errorf("Bad unmarshal: expected %v, got %v", expected, actual2) + } +} + +func TestUnmarshalArrayFail(t *testing.T) { + tree, _ := TreeFromMap(map[string]interface{}{ + "str_slice": []string{"Howdy", "Hey There"}, + }) + + var actual arrayTooSmallStruct + err := tree.Unmarshal(&actual) + if err.Error() != "(0, 0): unmarshal: TOML array length (2) exceeds destination array length (1)" { + t.Error("expect err:(0, 0): unmarshal: TOML array length (2) exceeds destination array length (1) but got ", err) + } +} + +func TestUnmarshalArrayFail2(t *testing.T) { + tree, _ := Load(`str_slice=["Howdy","Hey There"]`) + + var actual arrayTooSmallStruct + err := tree.Unmarshal(&actual) + if err.Error() != "(1, 1): unmarshal: TOML array length (2) exceeds destination array length (1)" { + t.Error("expect err:(1, 1): unmarshal: TOML array length (2) exceeds destination array length (1) but got ", err) + } +} + +func TestUnmarshalArrayFail3(t *testing.T) { + tree, _ := Load(`[[struct_slice]] +String2="1" +[[struct_slice]] +String2="2"`) + + var actual arrayTooSmallStruct + err := tree.Unmarshal(&actual) + if err.Error() != "(3, 1): unmarshal: TOML array length (2) exceeds destination array length (1)" { + t.Error("expect err:(3, 1): unmarshal: TOML array length (2) exceeds destination array length (1) but got ", err) + } +} + +func TestDecoderStrict(t *testing.T) { + input := ` +[decoded] + key = "" + +[undecoded] + key = "" + + [undecoded.inner] + key = "" + + [[undecoded.array]] + key = "" + + [[undecoded.array]] + key = "" + +` + var doc struct { + Decoded struct { + Key string + } + } + + expected := `undecoded keys: ["undecoded.array.0.key" "undecoded.array.1.key" "undecoded.inner.key" "undecoded.key"]` + + err := NewDecoder(bytes.NewReader([]byte(input))).Strict(true).Decode(&doc) + if err == nil { + t.Error("expected error, got none") + } else if err.Error() != expected { + t.Errorf("expect err: %s, got: %s", expected, err.Error()) + } + + if err := NewDecoder(bytes.NewReader([]byte(input))).Decode(&doc); err != nil { + t.Errorf("unexpected err: %s", err) + } + + var m map[string]interface{} + if err := NewDecoder(bytes.NewReader([]byte(input))).Decode(&m); err != nil { + t.Errorf("unexpected err: %s", err) + } +} + +func TestDecoderStrictValid(t *testing.T) { + input := ` +[decoded] + key = "" +` + var doc struct { + Decoded struct { + Key string + } + } + + err := NewDecoder(bytes.NewReader([]byte(input))).Strict(true).Decode(&doc) + if err != nil { + t.Fatal("unexpected error:", err) + } +} + +type docUnmarshalTOML struct { + Decoded struct { + Key string + } +} + +func (d *docUnmarshalTOML) UnmarshalTOML(i interface{}) error { + if iMap, ok := i.(map[string]interface{}); !ok { + return fmt.Errorf("type assertion error: wants %T, have %T", map[string]interface{}{}, i) + } else if key, ok := iMap["key"]; !ok { + return fmt.Errorf("key '%s' not in map", "key") + } else if keyString, ok := key.(string); !ok { + return fmt.Errorf("type assertion error: wants %T, have %T", "", key) + } else { + d.Decoded.Key = keyString + } + return nil +} + +func TestDecoderStrictCustomUnmarshal(t *testing.T) { + input := `key = "ok"` + var doc docUnmarshalTOML + err := NewDecoder(bytes.NewReader([]byte(input))).Strict(true).Decode(&doc) + if err != nil { + t.Fatal("unexpected error:", err) + } + if doc.Decoded.Key != "ok" { + t.Errorf("Bad unmarshal: expected ok, got %v", doc.Decoded.Key) + } +} + +type parent struct { + Doc docUnmarshalTOML + DocPointer *docUnmarshalTOML +} + +func TestCustomUnmarshal(t *testing.T) { + input := ` +[Doc] + key = "ok1" +[DocPointer] + key = "ok2" +` + + var d parent + if err := Unmarshal([]byte(input), &d); err != nil { + t.Fatalf("unexpected err: %s", err.Error()) + } + if d.Doc.Decoded.Key != "ok1" { + t.Errorf("Bad unmarshal: expected ok, got %v", d.Doc.Decoded.Key) + } + if d.DocPointer.Decoded.Key != "ok2" { + t.Errorf("Bad unmarshal: expected ok, got %v", d.DocPointer.Decoded.Key) + } +} + +func TestCustomUnmarshalError(t *testing.T) { + input := ` +[Doc] + key = 1 +[DocPointer] + key = "ok2" +` + + expected := "(2, 1): unmarshal toml: type assertion error: wants string, have int64" + + var d parent + err := Unmarshal([]byte(input), &d) + if err == nil { + t.Error("expected error, got none") + } else if err.Error() != expected { + t.Errorf("expect err: %s, got: %s", expected, err.Error()) + } +} + +type intWrapper struct { + Value int +} + +func (w *intWrapper) UnmarshalText(text []byte) error { + var err error + if w.Value, err = strconv.Atoi(string(text)); err == nil { + return nil + } + if b, err := strconv.ParseBool(string(text)); err == nil { + if b { + w.Value = 1 + } + return nil + } + if f, err := strconv.ParseFloat(string(text), 32); err == nil { + w.Value = int(f) + return nil + } + return fmt.Errorf("unsupported: %s", text) +} + +func TestTextUnmarshal(t *testing.T) { + var doc struct { + UnixTime intWrapper + Version *intWrapper + + Bool intWrapper + Int intWrapper + Float intWrapper + } + + input := ` +UnixTime = "12" +Version = "42" +Bool = true +Int = 21 +Float = 2.0 +` + + if err := Unmarshal([]byte(input), &doc); err != nil { + t.Fatalf("unexpected err: %s", err.Error()) + } + if doc.UnixTime.Value != 12 { + t.Fatalf("expected UnixTime: 12 got: %d", doc.UnixTime.Value) + } + if doc.Version.Value != 42 { + t.Fatalf("expected Version: 42 got: %d", doc.Version.Value) + } + if doc.Bool.Value != 1 { + t.Fatalf("expected Bool: 1 got: %d", doc.Bool.Value) + } + if doc.Int.Value != 21 { + t.Fatalf("expected Int: 21 got: %d", doc.Int.Value) + } + if doc.Float.Value != 2 { + t.Fatalf("expected Float: 2 got: %d", doc.Float.Value) + } +} + +func TestTextUnmarshalError(t *testing.T) { + var doc struct { + Failer intWrapper + } + + input := `Failer = "hello"` + if err := Unmarshal([]byte(input), &doc); err == nil { + t.Fatalf("expected err, got none") + } +} + +// issue406 +func TestPreserveNotEmptyField(t *testing.T) { + toml := []byte(`Field1 = "ccc"`) + type Inner struct { + InnerField1 string + InnerField2 int + } + type TestStruct struct { + Field1 string + Field2 int + Field3 Inner + } + + actual := TestStruct{ + "aaa", + 100, + Inner{ + "bbb", + 200, + }, + } + + expected := TestStruct{ + "ccc", + 100, + Inner{ + "bbb", + 200, + }, + } + + err := Unmarshal(toml, &actual) + if err != nil { + t.Fatal(err) + } + + if !reflect.DeepEqual(actual, expected) { + t.Errorf("Bad unmarshal: expected %+v, got %+v", expected, actual) + } +} diff --git a/parser.go b/parser.go index 1b344fe..7bf40bb 100644 --- a/parser.go +++ b/parser.go @@ -158,6 +158,11 @@ func (p *tomlParser) parseGroup() tomlParserStateFn { if err := p.tree.createSubTree(keys, startToken.Position); err != nil { p.raiseError(key, "%s", err) } + destTree := p.tree.GetPath(keys) + if target, ok := destTree.(*Tree); ok && target != nil && target.inline { + p.raiseError(key, "could not re-define exist inline table or its sub-table : %s", + strings.Join(keys, ".")) + } p.assume(tokenRightBracket) p.currentTable = keys return p.parseStart @@ -201,6 +206,11 @@ func (p *tomlParser) parseAssign() tomlParserStateFn { strings.Join(tableKey, ".")) } + if targetNode.inline { + p.raiseError(key, "could not add key or sub-table to exist inline table or its sub-table : %s", + strings.Join(tableKey, ".")) + } + // assign value to the found table keyVal := parsedKey[len(parsedKey)-1] localKey := []string{keyVal} @@ -411,12 +421,13 @@ Loop: if tokenIsComma(previous) { p.raiseError(previous, "trailing comma at the end of inline table") } + tree.inline = true return tree } func (p *tomlParser) parseArray() interface{} { var array []interface{} - arrayType := reflect.TypeOf(nil) + arrayType := reflect.TypeOf(newTree()) for { follow := p.peek() if follow == nil || follow.typ == tokenEOF { @@ -427,11 +438,8 @@ func (p *tomlParser) parseArray() interface{} { break } val := p.parseRvalue() - if arrayType == nil { - arrayType = reflect.TypeOf(val) - } if reflect.TypeOf(val) != arrayType { - p.raiseError(follow, "mixed types in array") + arrayType = nil } array = append(array, val) follow = p.peek() @@ -445,6 +453,12 @@ func (p *tomlParser) parseArray() interface{} { p.getToken() } } + + // if the array is a mixed-type array or its length is 0, + // don't convert it to a table array + if len(array) <= 0 { + arrayType = nil + } // An array of Trees is actually an array of inline // tables, which is a shorthand for a table array. If the // array was not converted from []interface{} to []*Tree, diff --git a/parser_test.go b/parser_test.go index 4c5a65e..5e96b84 100644 --- a/parser_test.go +++ b/parser_test.go @@ -239,7 +239,8 @@ func TestLocalDateTime(t *testing.T) { Minute: 32, Second: 0, Nanosecond: 0, - }}, + }, + }, }) } @@ -257,7 +258,8 @@ func TestLocalDateTimeNano(t *testing.T) { Minute: 32, Second: 0, Nanosecond: 999999000, - }}, + }, + }, }) } @@ -486,18 +488,6 @@ func TestNestedEmptyArrays(t *testing.T) { }) } -func TestArrayMixedTypes(t *testing.T) { - _, err := Load("a = [42, 16.0]") - if err.Error() != "(1, 10): mixed types in array" { - t.Error("Bad error message:", err.Error()) - } - - _, err = Load("a = [42, \"hello\"]") - if err.Error() != "(1, 11): mixed types in array" { - t.Error("Bad error message:", err.Error()) - } -} - func TestArrayNestedStrings(t *testing.T) { tree, err := Load("data = [ [\"gamma\", \"delta\"], [\"Foo\"] ]") assertTree(t, tree, err, map[string]interface{}{ @@ -677,7 +667,7 @@ func TestInlineTableUnterminated(t *testing.T) { func TestInlineTableCommaExpected(t *testing.T) { _, err := Load("foo = {hello = 53 test = foo}") - if err.Error() != "(1, 19): comma expected between fields in inline table" { + if err.Error() != "(1, 19): unexpected token type in inline table: no value can start with t" { t.Error("Bad error message:", err.Error()) } } @@ -691,7 +681,7 @@ func TestInlineTableCommaStart(t *testing.T) { func TestInlineTableDoubleComma(t *testing.T) { _, err := Load("foo = {hello = 53,, foo = 17}") - if err.Error() != "(1, 19): need field between two commas in inline table" { + if err.Error() != "(1, 19): unexpected token type in inline table: keys cannot contain , character" { t.Error("Bad error message:", err.Error()) } } @@ -703,6 +693,34 @@ func TestInlineTableTrailingComma(t *testing.T) { } } +func TestAddKeyToInlineTable(t *testing.T) { + _, err := Load("type = { name = \"Nail\" }\ntype.edible = false") + if err.Error() != "(2, 1): could not add key or sub-table to exist inline table or its sub-table : type" { + t.Error("Bad error message:", err.Error()) + } +} + +func TestAddSubTableToInlineTable(t *testing.T) { + _, err := Load("a = { b = \"c\" }\na.d.e = \"f\"") + if err.Error() != "(2, 1): could not add key or sub-table to exist inline table or its sub-table : a.d" { + t.Error("Bad error message:", err.Error()) + } +} + +func TestAddKeyToSubTableOfInlineTable(t *testing.T) { + _, err := Load("a = { b = { c = \"d\" } }\na.b.e = \"f\"") + if err.Error() != "(2, 1): could not add key or sub-table to exist inline table or its sub-table : a.b" { + t.Error("Bad error message:", err.Error()) + } +} + +func TestReDefineInlineTable(t *testing.T) { + _, err := Load("a = { b = \"c\" }\n[a]\n d = \"e\"") + if err.Error() != "(2, 2): could not re-define exist inline table or its sub-table : a" { + t.Error("Bad error message:", err.Error()) + } +} + func TestDuplicateGroups(t *testing.T) { _, err := Load("[foo]\na=2\n[foo]b=3") if err.Error() != "(3, 2): duplicated tables" { @@ -900,13 +918,11 @@ func TestTomlValueStringRepresentation(t *testing.T) { {"hello world", "\"hello world\""}, {"\b\t\n\f\r\"\\", "\"\\b\\t\\n\\f\\r\\\"\\\\\""}, {"\x05", "\"\\u0005\""}, - {time.Date(1979, time.May, 27, 7, 32, 0, 0, time.UTC), - "1979-05-27T07:32:00Z"}, - {[]interface{}{"gamma", "delta"}, - "[\"gamma\",\"delta\"]"}, + {time.Date(1979, time.May, 27, 7, 32, 0, 0, time.UTC), "1979-05-27T07:32:00Z"}, + {[]interface{}{"gamma", "delta"}, "[\"gamma\", \"delta\"]"}, {nil, ""}, } { - result, err := tomlValueStringRepresentation(item.Value, "", "", false) + result, err := tomlValueStringRepresentation(item.Value, "", "", OrderAlphabetical, false) if err != nil { t.Errorf("Test %d - unexpected error: %s", idx, err) } @@ -1033,7 +1049,7 @@ func TestInvalidFloatParsing(t *testing.T) { } _, err = Load("a=_1_2") - if err.Error() != "(1, 3): cannot start number with underscore" { + if err.Error() != "(1, 3): no value can start with _" { t.Error("Bad error message:", err.Error()) } } @@ -1097,11 +1113,10 @@ The quick brown \ the lazy dog.""" str3 = """\ - The quick brown \ - fox jumps over \ - the lazy dog.\ - """`) - + The quick brown \` + " " + ` + fox jumps over \` + " " + ` + the lazy dog.\` + " " + ` + """`) if err != nil { t.Fatalf("unexpected error: %v", err) } diff --git a/query/README.md b/query/README.md new file mode 100644 index 0000000..75b3759 --- /dev/null +++ b/query/README.md @@ -0,0 +1,201 @@ +# Query package + +## Overview + +Package query performs JSONPath-like queries on a TOML document. + +The query path implementation is based loosely on the JSONPath specification: +http://goessner.net/articles/JsonPath/. + +The idea behind a query path is to allow quick access to any element, or set +of elements within TOML document, with a single expression. + +```go +result, err := query.CompileAndExecute("$.foo.bar.baz", tree) +``` + +This is roughly equivalent to: + +```go +next := tree.Get("foo") +if next != nil { + next = next.Get("bar") + if next != nil { + next = next.Get("baz") + } +} +result := next +``` + +err is nil if any parsing exception occurs. + +If no node in the tree matches the query, result will simply contain an empty list of +items. + +As illustrated above, the query path is much more efficient, especially since +the structure of the TOML file can vary. Rather than making assumptions about +a document's structure, a query allows the programmer to make structured +requests into the document, and get zero or more values as a result. + +## Query syntax + +The syntax of a query begins with a root token, followed by any number +sub-expressions: + +``` +$ + Root of the TOML tree. This must always come first. +.name + Selects child of this node, where 'name' is a TOML key + name. +['name'] + Selects child of this node, where 'name' is a string + containing a TOML key name. +[index] + Selcts child array element at 'index'. +..expr + Recursively selects all children, filtered by an a union, + index, or slice expression. +..* + Recursive selection of all nodes at this point in the + tree. +.* + Selects all children of the current node. +[expr,expr] + Union operator - a logical 'or' grouping of two or more + sub-expressions: index, key name, or filter. +[start:end:step] + Slice operator - selects array elements from start to + end-1, at the given step. All three arguments are + optional. +[?(filter)] + Named filter expression - the function 'filter' is + used to filter children at this node. +``` + +## Query Indexes And Slices + +Index expressions perform no bounds checking, and will contribute no +values to the result set if the provided index or index range is invalid. +Negative indexes represent values from the end of the array, counting backwards. + +```go +// select the last index of the array named 'foo' +query.CompileAndExecute("$.foo[-1]", tree) +``` + +Slice expressions are supported, by using ':' to separate a start/end index pair. + +```go +// select up to the first five elements in the array +query.CompileAndExecute("$.foo[0:5]", tree) +``` + +Slice expressions also allow negative indexes for the start and stop +arguments. + +```go +// select all array elements except the last one. +query.CompileAndExecute("$.foo[0:-1]", tree) +``` + +Slice expressions may have an optional stride/step parameter: + +```go +// select every other element +query.CompileAndExecute("$.foo[0::2]", tree) +``` + +Slice start and end parameters are also optional: + +```go +// these are all equivalent and select all the values in the array +query.CompileAndExecute("$.foo[:]", tree) +query.CompileAndExecute("$.foo[::]", tree) +query.CompileAndExecute("$.foo[::1]", tree) +query.CompileAndExecute("$.foo[0:]", tree) +query.CompileAndExecute("$.foo[0::]", tree) +query.CompileAndExecute("$.foo[0::1]", tree) +``` + +## Query Filters + +Query filters are used within a Union [,] or single Filter [] expression. +A filter only allows nodes that qualify through to the next expression, +and/or into the result set. + +```go +// returns children of foo that are permitted by the 'bar' filter. +query.CompileAndExecute("$.foo[?(bar)]", tree) +``` + +There are several filters provided with the library: + +``` +tree + Allows nodes of type Tree. +int + Allows nodes of type int64. +float + Allows nodes of type float64. +string + Allows nodes of type string. +time + Allows nodes of type time.Time. +bool + Allows nodes of type bool. +``` + +## Query Results + +An executed query returns a Result object. This contains the nodes +in the TOML tree that qualify the query expression. Position information +is also available for each value in the set. + +```go +// display the results of a query +results := query.CompileAndExecute("$.foo.bar.baz", tree) +for idx, value := results.Values() { + fmt.Println("%v: %v", results.Positions()[idx], value) +} +``` + +## Compiled Queries + +Queries may be executed directly on a Tree object, or compiled ahead +of time and executed discretely. The former is more convenient, but has the +penalty of having to recompile the query expression each time. + +```go +// basic query +results := query.CompileAndExecute("$.foo.bar.baz", tree) + +// compiled query +query, err := toml.Compile("$.foo.bar.baz") +results := query.Execute(tree) + +// run the compiled query again on a different tree +moreResults := query.Execute(anotherTree) +``` + +## User Defined Query Filters + +Filter expressions may also be user defined by using the SetFilter() +function on the Query object. The function must return true/false, which +signifies if the passed node is kept or discarded, respectively. + +```go +// create a query that references a user-defined filter +query, _ := query.Compile("$[?(bazOnly)]") + +// define the filter, and assign it to the query +query.SetFilter("bazOnly", func(node interface{}) bool{ + if tree, ok := node.(*Tree); ok { + return tree.Has("baz") + } + return false // reject all other node types +}) + +// run the query +query.Execute(tree) +``` diff --git a/query/doc.go b/query/doc.go index ed63c11..d0efb21 100644 --- a/query/doc.go +++ b/query/doc.go @@ -25,7 +25,7 @@ // items. // // As illustrated above, the query path is much more efficient, especially since -// the structure of the TOML file can vary. Rather than making assumptions about +// the structure of the TOML file can vary. Rather than making assumptions about // a document's structure, a query allows the programmer to make structured // requests into the document, and get zero or more values as a result. // @@ -35,7 +35,7 @@ // sub-expressions: // // $ -// Root of the TOML tree. This must always come first. +// Root of the TOML tree. This must always come first. // .name // Selects child of this node, where 'name' is a TOML key // name. @@ -57,7 +57,7 @@ // sub-expressions: index, key name, or filter. // [start:end:step] // Slice operator - selects array elements from start to -// end-1, at the given step. All three arguments are +// end-1, at the given step. All three arguments are // optional. // [?(filter)] // Named filter expression - the function 'filter' is @@ -80,25 +80,23 @@ // Slice expressions also allow negative indexes for the start and stop // arguments. // -// // select all array elements. +// // select all array elements except the last one. // query.CompileAndExecute("$.foo[0:-1]", tree) // // Slice expressions may have an optional stride/step parameter: // // // select every other element -// query.CompileAndExecute("$.foo[0:-1:2]", tree) +// query.CompileAndExecute("$.foo[0::2]", tree) // // Slice start and end parameters are also optional: // // // these are all equivalent and select all the values in the array // query.CompileAndExecute("$.foo[:]", tree) -// query.CompileAndExecute("$.foo[0:]", tree) -// query.CompileAndExecute("$.foo[:-1]", tree) -// query.CompileAndExecute("$.foo[0:-1:]", tree) +// query.CompileAndExecute("$.foo[::]", tree) // query.CompileAndExecute("$.foo[::1]", tree) +// query.CompileAndExecute("$.foo[0:]", tree) +// query.CompileAndExecute("$.foo[0::]", tree) // query.CompileAndExecute("$.foo[0::1]", tree) -// query.CompileAndExecute("$.foo[:-1:1]", tree) -// query.CompileAndExecute("$.foo[0:-1:1]", tree) // // Query Filters // @@ -126,8 +124,8 @@ // // Query Results // -// An executed query returns a Result object. This contains the nodes -// in the TOML tree that qualify the query expression. Position information +// An executed query returns a Result object. This contains the nodes +// in the TOML tree that qualify the query expression. Position information // is also available for each value in the set. // // // display the results of a query @@ -139,7 +137,7 @@ // Compiled Queries // // Queries may be executed directly on a Tree object, or compiled ahead -// of time and executed discretely. The former is more convenient, but has the +// of time and executed discretely. The former is more convenient, but has the // penalty of having to recompile the query expression each time. // // // basic query @@ -155,7 +153,7 @@ // User Defined Query Filters // // Filter expressions may also be user defined by using the SetFilter() -// function on the Query object. The function must return true/false, which +// function on the Query object. The function must return true/false, which // signifies if the passed node is kept or discarded, respectively. // // // create a query that references a user-defined filter @@ -166,7 +164,7 @@ // if tree, ok := node.(*Tree); ok { // return tree.Has("baz") // } -// return false // reject all other node types +// return false // reject all other node types // }) // // // run the query diff --git a/query/match.go b/query/match.go index d7bb15a..37b43da 100644 --- a/query/match.go +++ b/query/match.go @@ -2,6 +2,8 @@ package query import ( "fmt" + "reflect" + "github.com/pelletier/go-toml" ) @@ -44,16 +46,16 @@ func newMatchKeyFn(name string) *matchKeyFn { func (f *matchKeyFn) call(node interface{}, ctx *queryContext) { if array, ok := node.([]*toml.Tree); ok { for _, tree := range array { - item := tree.Get(f.Name) + item := tree.GetPath([]string{f.Name}) if item != nil { - ctx.lastPosition = tree.GetPosition(f.Name) + ctx.lastPosition = tree.GetPositionPath([]string{f.Name}) f.next.call(item, ctx) } } } else if tree, ok := node.(*toml.Tree); ok { - item := tree.Get(f.Name) + item := tree.GetPath([]string{f.Name}) if item != nil { - ctx.lastPosition = tree.GetPosition(f.Name) + ctx.lastPosition = tree.GetPositionPath([]string{f.Name}) f.next.call(item, ctx) } } @@ -70,53 +72,130 @@ func newMatchIndexFn(idx int) *matchIndexFn { } func (f *matchIndexFn) call(node interface{}, ctx *queryContext) { - if arr, ok := node.([]interface{}); ok { - if f.Idx < len(arr) && f.Idx >= 0 { - if treesArray, ok := node.([]*toml.Tree); ok { - if len(treesArray) > 0 { - ctx.lastPosition = treesArray[0].Position() - } - } - f.next.call(arr[f.Idx], ctx) + v := reflect.ValueOf(node) + if v.Kind() == reflect.Slice { + if v.Len() == 0 { + return + } + + // Manage negative values + idx := f.Idx + if idx < 0 { + idx += v.Len() + } + if 0 <= idx && idx < v.Len() { + callNextIndexSlice(f.next, node, ctx, v.Index(idx).Interface()) } } } +func callNextIndexSlice(next pathFn, node interface{}, ctx *queryContext, value interface{}) { + if treesArray, ok := node.([]*toml.Tree); ok { + ctx.lastPosition = treesArray[0].Position() + } + next.call(value, ctx) +} + // filter by slicing type matchSliceFn struct { matchBase - Start, End, Step int + Start, End, Step *int } -func newMatchSliceFn(start, end, step int) *matchSliceFn { - return &matchSliceFn{Start: start, End: end, Step: step} +func newMatchSliceFn() *matchSliceFn { + return &matchSliceFn{} +} + +func (f *matchSliceFn) setStart(start int) *matchSliceFn { + f.Start = &start + return f +} + +func (f *matchSliceFn) setEnd(end int) *matchSliceFn { + f.End = &end + return f +} + +func (f *matchSliceFn) setStep(step int) *matchSliceFn { + f.Step = &step + return f } func (f *matchSliceFn) call(node interface{}, ctx *queryContext) { - if arr, ok := node.([]interface{}); ok { - // adjust indexes for negative values, reverse ordering - realStart, realEnd := f.Start, f.End - if realStart < 0 { - realStart = len(arr) + realStart + v := reflect.ValueOf(node) + if v.Kind() == reflect.Slice { + if v.Len() == 0 { + return } - if realEnd < 0 { - realEnd = len(arr) + realEnd + + var start, end, step int + + // Initialize step + if f.Step != nil { + step = *f.Step + } else { + step = 1 } - if realEnd < realStart { - realEnd, realStart = realStart, realEnd // swap - } - // loop and gather - for idx := realStart; idx < realEnd; idx += f.Step { - if treesArray, ok := node.([]*toml.Tree); ok { - if len(treesArray) > 0 { - ctx.lastPosition = treesArray[0].Position() - } + + // Initialize start + if f.Start != nil { + start = *f.Start + // Manage negative values + if start < 0 { + start += v.Len() + } + // Manage out of range values + start = max(start, 0) + start = min(start, v.Len()-1) + } else if step > 0 { + start = 0 + } else { + start = v.Len() - 1 + } + + // Initialize end + if f.End != nil { + end = *f.End + // Manage negative values + if end < 0 { + end += v.Len() + } + // Manage out of range values + end = max(end, -1) + end = min(end, v.Len()) + } else if step > 0 { + end = v.Len() + } else { + end = -1 + } + + // Loop on values + if step > 0 { + for idx := start; idx < end; idx += step { + callNextIndexSlice(f.next, node, ctx, v.Index(idx).Interface()) + } + } else { + for idx := start; idx > end; idx += step { + callNextIndexSlice(f.next, node, ctx, v.Index(idx).Interface()) } - f.next.call(arr[idx], ctx) } } } +func min(a, b int) int { + if a < b { + return a + } + return b +} + +func max(a, b int) int { + if a > b { + return a + } + return b +} + // match anything type matchAnyFn struct { matchBase @@ -129,8 +208,8 @@ func newMatchAnyFn() *matchAnyFn { func (f *matchAnyFn) call(node interface{}, ctx *queryContext) { if tree, ok := node.(*toml.Tree); ok { for _, k := range tree.Keys() { - v := tree.Get(k) - ctx.lastPosition = tree.GetPosition(k) + v := tree.GetPath([]string{k}) + ctx.lastPosition = tree.GetPositionPath([]string{k}) f.next.call(v, ctx) } } @@ -168,8 +247,8 @@ func (f *matchRecursiveFn) call(node interface{}, ctx *queryContext) { var visit func(tree *toml.Tree) visit = func(tree *toml.Tree) { for _, k := range tree.Keys() { - v := tree.Get(k) - ctx.lastPosition = tree.GetPosition(k) + v := tree.GetPath([]string{k}) + ctx.lastPosition = tree.GetPositionPath([]string{k}) f.next.call(v, ctx) switch node := v.(type) { case *toml.Tree: @@ -207,9 +286,9 @@ func (f *matchFilterFn) call(node interface{}, ctx *queryContext) { switch castNode := node.(type) { case *toml.Tree: for _, k := range castNode.Keys() { - v := castNode.Get(k) + v := castNode.GetPath([]string{k}) if fn(v) { - ctx.lastPosition = castNode.GetPosition(k) + ctx.lastPosition = castNode.GetPositionPath([]string{k}) f.next.call(v, ctx) } } diff --git a/query/match_test.go b/query/match_test.go index 429b8f6..47472c1 100644 --- a/query/match_test.go +++ b/query/match_test.go @@ -2,8 +2,10 @@ package query import ( "fmt" - "github.com/pelletier/go-toml" + "strconv" "testing" + + "github.com/pelletier/go-toml" ) // dump path tree to a string @@ -19,8 +21,17 @@ func pathString(root pathFn) string { result += fmt.Sprintf("{%d}", fn.Idx) result += pathString(fn.next) case *matchSliceFn: - result += fmt.Sprintf("{%d:%d:%d}", - fn.Start, fn.End, fn.Step) + startString, endString, stepString := "nil", "nil", "nil" + if fn.Start != nil { + startString = strconv.Itoa(*fn.Start) + } + if fn.End != nil { + endString = strconv.Itoa(*fn.End) + } + if fn.Step != nil { + stepString = strconv.Itoa(*fn.Step) + } + result += fmt.Sprintf("{%s:%s:%s}", startString, endString, stepString) result += pathString(fn.next) case *matchAnyFn: result += "{}" @@ -110,7 +121,7 @@ func TestPathSliceStart(t *testing.T) { assertPath(t, "$[123:]", buildPath( - newMatchSliceFn(123, maxInt, 1), + newMatchSliceFn().setStart(123), )) } @@ -118,7 +129,7 @@ func TestPathSliceStartEnd(t *testing.T) { assertPath(t, "$[123:456]", buildPath( - newMatchSliceFn(123, 456, 1), + newMatchSliceFn().setStart(123).setEnd(456), )) } @@ -126,7 +137,7 @@ func TestPathSliceStartEndColon(t *testing.T) { assertPath(t, "$[123:456:]", buildPath( - newMatchSliceFn(123, 456, 1), + newMatchSliceFn().setStart(123).setEnd(456), )) } @@ -134,7 +145,7 @@ func TestPathSliceStartStep(t *testing.T) { assertPath(t, "$[123::7]", buildPath( - newMatchSliceFn(123, maxInt, 7), + newMatchSliceFn().setStart(123).setStep(7), )) } @@ -142,7 +153,7 @@ func TestPathSliceEndStep(t *testing.T) { assertPath(t, "$[:456:7]", buildPath( - newMatchSliceFn(0, 456, 7), + newMatchSliceFn().setEnd(456).setStep(7), )) } @@ -150,7 +161,7 @@ func TestPathSliceStep(t *testing.T) { assertPath(t, "$[::7]", buildPath( - newMatchSliceFn(0, maxInt, 7), + newMatchSliceFn().setStep(7), )) } @@ -158,7 +169,7 @@ func TestPathSliceAll(t *testing.T) { assertPath(t, "$[123:456:7]", buildPath( - newMatchSliceFn(123, 456, 7), + newMatchSliceFn().setStart(123).setEnd(456).setStep(7), )) } diff --git a/query/parser.go b/query/parser.go index 5f69b70..be27d35 100644 --- a/query/parser.go +++ b/query/parser.go @@ -203,12 +203,13 @@ loop: // labeled loop for easy breaking func (p *queryParser) parseSliceExpr() queryParserStateFn { // init slice to grab all elements - start, end, step := 0, maxInt, 1 + var start, end, step *int = nil, nil, nil // parse optional start tok := p.getToken() if tok.typ == tokenInteger { - start = tok.Int() + v := tok.Int() + start = &v tok = p.getToken() } if tok.typ != tokenColon { @@ -218,11 +219,12 @@ func (p *queryParser) parseSliceExpr() queryParserStateFn { // parse optional end tok = p.getToken() if tok.typ == tokenInteger { - end = tok.Int() + v := tok.Int() + end = &v tok = p.getToken() } if tok.typ == tokenRightBracket { - p.query.appendPath(newMatchSliceFn(start, end, step)) + p.query.appendPath(&matchSliceFn{Start: start, End: end, Step: step}) return p.parseMatchExpr } if tok.typ != tokenColon { @@ -232,17 +234,18 @@ func (p *queryParser) parseSliceExpr() queryParserStateFn { // parse optional step tok = p.getToken() if tok.typ == tokenInteger { - step = tok.Int() - if step < 0 { - return p.parseError(tok, "step must be a positive value") + v := tok.Int() + if v == 0 { + return p.parseError(tok, "step cannot be zero") } + step = &v tok = p.getToken() } if tok.typ != tokenRightBracket { return p.parseError(tok, "expected ']'") } - p.query.appendPath(newMatchSliceFn(start, end, step)) + p.query.appendPath(&matchSliceFn{Start: start, End: end, Step: step}) return p.parseMatchExpr } diff --git a/query/parser_test.go b/query/parser_test.go index af93276..91d3f70 100644 --- a/query/parser_test.go +++ b/query/parser_test.go @@ -78,6 +78,19 @@ func assertValue(t *testing.T, result, ref interface{}) { } } +func assertParseError(t *testing.T, query string, errString string) { + _, err := Compile(query) + if err == nil { + t.Error("error should be non-nil") + return + } + if err.Error() != errString { + t.Errorf("error does not match") + t.Log("test:", err.Error()) + t.Log("ref: ", errString) + } +} + func assertQueryPositions(t *testing.T, tomlDoc string, query string, ref []interface{}) { tree, err := toml.Load(tomlDoc) if err != nil { @@ -128,54 +141,213 @@ func TestQueryKeyString(t *testing.T) { }) } -func TestQueryIndex(t *testing.T) { +func TestQueryKeyUnicodeString(t *testing.T) { assertQueryPositions(t, - "[foo]\na = [1,2,3,4,5,6,7,8,9,0]", - "$.foo.a[5]", + "['f𝟘.o']\na = 42", + "$['f𝟘.o']['a']", []interface{}{ queryTestNode{ - int64(6), toml.Position{2, 1}, + int64(42), toml.Position{2, 1}, }, }) } +func TestQueryIndexError1(t *testing.T) { + assertParseError(t, "$.foo.a[5", "(1, 10): expected ',' or ']', not ''") +} + +func TestQueryIndexError2(t *testing.T) { + assertParseError(t, "$.foo.a[]", "(1, 9): expected union sub expression, not ']', 0") +} + +func TestQueryIndex(t *testing.T) { + assertQueryPositions(t, + "[foo]\na = [0,1,2,3,4,5,6,7,8,9]", + "$.foo.a[5]", + []interface{}{ + queryTestNode{int64(5), toml.Position{2, 1}}, + }) +} + +func TestQueryIndexNegative(t *testing.T) { + assertQueryPositions(t, + "[foo]\na = [0,1,2,3,4,5,6,7,8,9]", + "$.foo.a[-2]", + []interface{}{ + queryTestNode{int64(8), toml.Position{2, 1}}, + }) +} + +func TestQueryIndexWrong(t *testing.T) { + assertQueryPositions(t, + "[foo]\na = [0,1,2,3,4,5,6,7,8,9]", + "$.foo.a[99]", + []interface{}{}) +} + +func TestQueryIndexEmpty(t *testing.T) { + assertQueryPositions(t, + "[foo]\na = []", + "$.foo.a[5]", + []interface{}{}) +} + +func TestQueryIndexTree(t *testing.T) { + assertQueryPositions(t, + "[[foo]]\na = [0,1,2,3,4,5,6,7,8,9]\n[[foo]]\nb = 3", + "$.foo[1].b", + []interface{}{ + queryTestNode{int64(3), toml.Position{4, 1}}, + }) +} + +func TestQuerySliceError1(t *testing.T) { + assertParseError(t, "$.foo.a[3:?]", "(1, 11): expected ']' or ':'") +} + +func TestQuerySliceError2(t *testing.T) { + assertParseError(t, "$.foo.a[:::]", "(1, 11): expected ']'") +} + +func TestQuerySliceError3(t *testing.T) { + assertParseError(t, "$.foo.a[::0]", "(1, 11): step cannot be zero") +} + func TestQuerySliceRange(t *testing.T) { assertQueryPositions(t, - "[foo]\na = [1,2,3,4,5,6,7,8,9,0]", - "$.foo.a[0:5]", + "[foo]\na = [0,1,2,3,4,5,6,7,8,9]", + "$.foo.a[:5]", []interface{}{ - queryTestNode{ - int64(1), toml.Position{2, 1}, - }, - queryTestNode{ - int64(2), toml.Position{2, 1}, - }, - queryTestNode{ - int64(3), toml.Position{2, 1}, - }, - queryTestNode{ - int64(4), toml.Position{2, 1}, - }, - queryTestNode{ - int64(5), toml.Position{2, 1}, - }, + queryTestNode{int64(0), toml.Position{2, 1}}, + queryTestNode{int64(1), toml.Position{2, 1}}, + queryTestNode{int64(2), toml.Position{2, 1}}, + queryTestNode{int64(3), toml.Position{2, 1}}, + queryTestNode{int64(4), toml.Position{2, 1}}, }) } func TestQuerySliceStep(t *testing.T) { assertQueryPositions(t, - "[foo]\na = [1,2,3,4,5,6,7,8,9,0]", + "[foo]\na = [0,1,2,3,4,5,6,7,8,9]", "$.foo.a[0:5:2]", + []interface{}{ + queryTestNode{int64(0), toml.Position{2, 1}}, + queryTestNode{int64(2), toml.Position{2, 1}}, + queryTestNode{int64(4), toml.Position{2, 1}}, + }) +} + +func TestQuerySliceStartNegative(t *testing.T) { + assertQueryPositions(t, + "[foo]\na = [0,1,2,3,4,5,6,7,8,9]", + "$.foo.a[-3:]", + []interface{}{ + queryTestNode{int64(7), toml.Position{2, 1}}, + queryTestNode{int64(8), toml.Position{2, 1}}, + queryTestNode{int64(9), toml.Position{2, 1}}, + }) +} + +func TestQuerySliceEndNegative(t *testing.T) { + assertQueryPositions(t, + "[foo]\na = [0,1,2,3,4,5,6,7,8,9]", + "$.foo.a[:-6]", + []interface{}{ + queryTestNode{int64(0), toml.Position{2, 1}}, + queryTestNode{int64(1), toml.Position{2, 1}}, + queryTestNode{int64(2), toml.Position{2, 1}}, + queryTestNode{int64(3), toml.Position{2, 1}}, + }) +} + +func TestQuerySliceStepNegative(t *testing.T) { + assertQueryPositions(t, + "[foo]\na = [0,1,2,3,4,5,6,7,8,9]", + "$.foo.a[::-2]", + []interface{}{ + queryTestNode{int64(9), toml.Position{2, 1}}, + queryTestNode{int64(7), toml.Position{2, 1}}, + queryTestNode{int64(5), toml.Position{2, 1}}, + queryTestNode{int64(3), toml.Position{2, 1}}, + queryTestNode{int64(1), toml.Position{2, 1}}, + }) +} + +func TestQuerySliceStartOverRange(t *testing.T) { + assertQueryPositions(t, + "[foo]\na = [0,1,2,3,4,5,6,7,8,9]", + "$.foo.a[-99:3]", + []interface{}{ + queryTestNode{int64(0), toml.Position{2, 1}}, + queryTestNode{int64(1), toml.Position{2, 1}}, + queryTestNode{int64(2), toml.Position{2, 1}}, + }) +} + +func TestQuerySliceStartOverRangeNegative(t *testing.T) { + assertQueryPositions(t, + "[foo]\na = [0,1,2,3,4,5,6,7,8,9]", + "$.foo.a[99:7:-1]", + []interface{}{ + queryTestNode{int64(9), toml.Position{2, 1}}, + queryTestNode{int64(8), toml.Position{2, 1}}, + }) +} + +func TestQuerySliceEndOverRange(t *testing.T) { + assertQueryPositions(t, + "[foo]\na = [0,1,2,3,4,5,6,7,8,9]", + "$.foo.a[7:99]", + []interface{}{ + queryTestNode{int64(7), toml.Position{2, 1}}, + queryTestNode{int64(8), toml.Position{2, 1}}, + queryTestNode{int64(9), toml.Position{2, 1}}, + }) +} + +func TestQuerySliceEndOverRangeNegative(t *testing.T) { + assertQueryPositions(t, + "[foo]\na = [0,1,2,3,4,5,6,7,8,9]", + "$.foo.a[2:-99:-1]", + []interface{}{ + queryTestNode{int64(2), toml.Position{2, 1}}, + queryTestNode{int64(1), toml.Position{2, 1}}, + queryTestNode{int64(0), toml.Position{2, 1}}, + }) +} + +func TestQuerySliceWrongRange(t *testing.T) { + assertQueryPositions(t, + "[foo]\na = [0,1,2,3,4,5,6,7,8,9]", + "$.foo.a[5:3]", + []interface{}{}) +} + +func TestQuerySliceWrongRangeNegative(t *testing.T) { + assertQueryPositions(t, + "[foo]\na = [0,1,2,3,4,5,6,7,8,9]", + "$.foo.a[3:5:-1]", + []interface{}{}) +} + +func TestQuerySliceEmpty(t *testing.T) { + assertQueryPositions(t, + "[foo]\na = []", + "$.foo.a[5:]", + []interface{}{}) +} + +func TestQuerySliceTree(t *testing.T) { + assertQueryPositions(t, + "[[foo]]\na='nok'\n[[foo]]\na = [0,1,2,3,4,5,6,7,8,9]\n[[foo]]\na='ok'\nb = 3", + "$.foo[1:].a", []interface{}{ queryTestNode{ - int64(1), toml.Position{2, 1}, - }, - queryTestNode{ - int64(3), toml.Position{2, 1}, - }, - queryTestNode{ - int64(5), toml.Position{2, 1}, - }, + []interface{}{ + int64(0), int64(1), int64(2), int64(3), int64(4), + int64(5), int64(6), int64(7), int64(8), int64(9)}, + toml.Position{4, 1}}, + queryTestNode{"ok", toml.Position{6, 1}}, }) } @@ -265,12 +437,8 @@ func TestQueryRecursionAll(t *testing.T) { "b": int64(2), }, toml.Position{1, 1}, }, - queryTestNode{ - int64(1), toml.Position{2, 1}, - }, - queryTestNode{ - int64(2), toml.Position{3, 1}, - }, + queryTestNode{int64(1), toml.Position{2, 1}}, + queryTestNode{int64(2), toml.Position{3, 1}}, queryTestNode{ map[string]interface{}{ "foo": map[string]interface{}{ @@ -285,12 +453,8 @@ func TestQueryRecursionAll(t *testing.T) { "b": int64(4), }, toml.Position{4, 1}, }, - queryTestNode{ - int64(3), toml.Position{5, 1}, - }, - queryTestNode{ - int64(4), toml.Position{6, 1}, - }, + queryTestNode{int64(3), toml.Position{5, 1}}, + queryTestNode{int64(4), toml.Position{6, 1}}, queryTestNode{ map[string]interface{}{ "foo": map[string]interface{}{ @@ -305,12 +469,8 @@ func TestQueryRecursionAll(t *testing.T) { "b": int64(6), }, toml.Position{7, 1}, }, - queryTestNode{ - int64(5), toml.Position{8, 1}, - }, - queryTestNode{ - int64(6), toml.Position{9, 1}, - }, + queryTestNode{int64(5), toml.Position{8, 1}}, + queryTestNode{int64(6), toml.Position{9, 1}}, }) } @@ -358,59 +518,30 @@ func TestQueryFilterFn(t *testing.T) { assertQueryPositions(t, string(buff), "$..[?(int)]", []interface{}{ - queryTestNode{ - int64(8001), toml.Position{13, 1}, - }, - queryTestNode{ - int64(8001), toml.Position{13, 1}, - }, - queryTestNode{ - int64(8002), toml.Position{13, 1}, - }, - queryTestNode{ - int64(5000), toml.Position{14, 1}, - }, + queryTestNode{int64(8001), toml.Position{13, 1}}, + queryTestNode{int64(8001), toml.Position{13, 1}}, + queryTestNode{int64(8002), toml.Position{13, 1}}, + queryTestNode{int64(5000), toml.Position{14, 1}}, }) assertQueryPositions(t, string(buff), "$..[?(string)]", []interface{}{ - queryTestNode{ - "TOML Example", toml.Position{3, 1}, - }, - queryTestNode{ - "Tom Preston-Werner", toml.Position{6, 1}, - }, - queryTestNode{ - "GitHub", toml.Position{7, 1}, - }, - queryTestNode{ - "GitHub Cofounder & CEO\nLikes tater tots and beer.", - toml.Position{8, 1}, - }, - queryTestNode{ - "192.168.1.1", toml.Position{12, 1}, - }, - queryTestNode{ - "10.0.0.1", toml.Position{21, 3}, - }, - queryTestNode{ - "eqdc10", toml.Position{22, 3}, - }, - queryTestNode{ - "10.0.0.2", toml.Position{25, 3}, - }, - queryTestNode{ - "eqdc10", toml.Position{26, 3}, - }, + queryTestNode{"TOML Example", toml.Position{3, 1}}, + queryTestNode{"Tom Preston-Werner", toml.Position{6, 1}}, + queryTestNode{"GitHub", toml.Position{7, 1}}, + queryTestNode{"GitHub Cofounder & CEO\nLikes tater tots and beer.", toml.Position{8, 1}}, + queryTestNode{"192.168.1.1", toml.Position{12, 1}}, + queryTestNode{"10.0.0.1", toml.Position{21, 3}}, + queryTestNode{"eqdc10", toml.Position{22, 3}}, + queryTestNode{"10.0.0.2", toml.Position{25, 3}}, + queryTestNode{"eqdc10", toml.Position{26, 3}}, }) assertQueryPositions(t, string(buff), "$..[?(float)]", - []interface{}{ - queryTestNode{ - 4e-08, toml.Position{30, 1}, - }, + []interface{}{ + queryTestNode{4e-08, toml.Position{30, 1}}, }) tv, _ := time.Parse(time.RFC3339, "1979-05-27T07:32:00Z") @@ -471,16 +602,12 @@ func TestQueryFilterFn(t *testing.T) { assertQueryPositions(t, string(buff), "$..[?(time)]", []interface{}{ - queryTestNode{ - tv, toml.Position{9, 1}, - }, + queryTestNode{tv, toml.Position{9, 1}}, }) assertQueryPositions(t, string(buff), "$..[?(bool)]", []interface{}{ - queryTestNode{ - true, toml.Position{15, 1}, - }, + queryTestNode{true, toml.Position{15, 1}}, }) } diff --git a/query/query_test.go b/query/query_test.go index 903a8dc..87d1351 100644 --- a/query/query_test.go +++ b/query/query_test.go @@ -7,25 +7,26 @@ import ( "github.com/pelletier/go-toml" ) -func assertArrayContainsInAnyOrder(t *testing.T, array []interface{}, objects ...interface{}) { +func assertArrayContainsInOrder(t *testing.T, array []interface{}, objects ...interface{}) { if len(array) != len(objects) { t.Fatalf("array contains %d objects but %d are expected", len(array), len(objects)) } - for _, o := range objects { - found := false - for _, a := range array { - if a == o { - found = true - break - } - } - if !found { - t.Fatal(o, "not found in array", array) + for i := 0; i < len(array); i++ { + if array[i] != objects[i] { + t.Fatalf("wanted '%s', have '%s'", objects[i], array[i]) } } } +func checkQuery(t *testing.T, tree *toml.Tree, query string, objects ...interface{}) { + results, err := CompileAndExecute(query, tree) + if err != nil { + t.Fatal("unexpected error:", err) + } + assertArrayContainsInOrder(t, results.Values(), objects...) +} + func TestQueryExample(t *testing.T) { config, _ := toml.Load(` [[book]] @@ -37,16 +38,18 @@ func TestQueryExample(t *testing.T) { [[book]] title = "Neuromancer" author = "William Gibson" - `) - authors, err := CompileAndExecute("$.book.author", config) - if err != nil { - t.Fatal("unexpected error:", err) - } - names := authors.Values() - if len(names) != 3 { - t.Fatalf("query should return 3 names but returned %d", len(names)) - } - assertArrayContainsInAnyOrder(t, names, "Stephen King", "Ernest Hemmingway", "William Gibson") + `) + + checkQuery(t, config, "$.book.author", "Stephen King", "Ernest Hemmingway", "William Gibson") + + checkQuery(t, config, "$.book[0].author", "Stephen King") + checkQuery(t, config, "$.book[-1].author", "William Gibson") + checkQuery(t, config, "$.book[1:].author", "Ernest Hemmingway", "William Gibson") + checkQuery(t, config, "$.book[-1:].author", "William Gibson") + checkQuery(t, config, "$.book[::2].author", "Stephen King", "William Gibson") + checkQuery(t, config, "$.book[::-1].author", "William Gibson", "Ernest Hemmingway", "Stephen King") + checkQuery(t, config, "$.book[:].author", "Stephen King", "Ernest Hemmingway", "William Gibson") + checkQuery(t, config, "$.book[::].author", "Stephen King", "Ernest Hemmingway", "William Gibson") } func TestQueryReadmeExample(t *testing.T) { @@ -56,16 +59,7 @@ user = "pelletier" password = "mypassword" `) - query, err := Compile("$..[user,password]") - if err != nil { - t.Fatal("unexpected error:", err) - } - results := query.Execute(config) - values := results.Values() - if len(values) != 2 { - t.Fatalf("query should return 2 values but returned %d", len(values)) - } - assertArrayContainsInAnyOrder(t, values, "pelletier", "mypassword") + checkQuery(t, config, "$..[user,password]", "pelletier", "mypassword") } func TestQueryPathNotPresent(t *testing.T) { diff --git a/query/tokens.go b/query/tokens.go index 9ae579d..098c856 100644 --- a/query/tokens.go +++ b/query/tokens.go @@ -2,9 +2,9 @@ package query import ( "fmt" - "github.com/pelletier/go-toml" "strconv" - "unicode" + + "github.com/pelletier/go-toml" ) // Define tokens @@ -92,11 +92,11 @@ func isSpace(r rune) bool { } func isAlphanumeric(r rune) bool { - return unicode.IsLetter(r) || r == '_' + return 'a' <= r && r <= 'z' || 'A' <= r && r <= 'Z' || r == '_' } func isDigit(r rune) bool { - return unicode.IsNumber(r) + return '0' <= r && r <= '9' } func isHexDigit(r rune) bool { diff --git a/token.go b/token.go index 36a3fc8..6af4ec4 100644 --- a/token.go +++ b/token.go @@ -1,9 +1,6 @@ package toml -import ( - "fmt" - "unicode" -) +import "fmt" // Define tokens type tokenType int @@ -112,7 +109,7 @@ func isSpace(r rune) bool { } func isAlphanumeric(r rune) bool { - return unicode.IsLetter(r) || r == '_' + return 'a' <= r && r <= 'z' || 'A' <= r && r <= 'Z' || r == '_' } func isKeyChar(r rune) bool { @@ -127,7 +124,7 @@ func isKeyStartChar(r rune) bool { } func isDigit(r rune) bool { - return unicode.IsNumber(r) + return '0' <= r && r <= '9' } func isHexDigit(r rune) bool { diff --git a/toml.go b/toml.go index fdc74d8..d2a6e3e 100644 --- a/toml.go +++ b/toml.go @@ -23,6 +23,7 @@ type Tree struct { values map[string]interface{} // string -> *tomlValue, *Tree, []*Tree comment string commented bool + inline bool position Position } @@ -414,6 +415,7 @@ func (t *Tree) createSubTree(keys []string, pos Position) error { if !exists { tree := newTreeWithPosition(Position{Line: t.position.Line + i, Col: t.position.Col}) tree.position = pos + tree.inline = subtree.inline subtree.values[intermediateKey] = tree nextTree = tree } diff --git a/toml_testgen_test.go b/toml_testgen_test.go index 688ae51..2306926 100644 --- a/toml_testgen_test.go +++ b/toml_testgen_test.go @@ -5,21 +5,6 @@ import ( "testing" ) -func TestInvalidArrayMixedTypesArraysAndInts(t *testing.T) { - input := `arrays-and-ints = [1, ["Arrays are not integers."]]` - testgenInvalid(t, input) -} - -func TestInvalidArrayMixedTypesIntsAndFloats(t *testing.T) { - input := `ints-and-floats = [1, 1.1]` - testgenInvalid(t, input) -} - -func TestInvalidArrayMixedTypesStringsAndInts(t *testing.T) { - input := `strings-and-ints = ["hi", 42]` - testgenInvalid(t, input) -} - func TestInvalidDatetimeMalformedNoLeads(t *testing.T) { input := `no-leads = 1987-7-05T17:45:00Z` testgenInvalid(t, input) diff --git a/tomltree_create_test.go b/tomltree_create_test.go index 3465a10..228d3dc 100644 --- a/tomltree_create_test.go +++ b/tomltree_create_test.go @@ -105,7 +105,7 @@ func TestTreeCreateToTreeInvalidTableGroupType(t *testing.T) { } func TestRoundTripArrayOfTables(t *testing.T) { - orig := "\n[[stuff]]\n name = \"foo\"\n things = [\"a\",\"b\"]\n" + orig := "\n[[stuff]]\n name = \"foo\"\n things = [\"a\", \"b\"]\n" tree, err := Load(orig) if err != nil { t.Fatalf("unexpected error: %s", err) diff --git a/tomltree_write.go b/tomltree_write.go index 9acc2f3..2d6487e 100644 --- a/tomltree_write.go +++ b/tomltree_write.go @@ -103,7 +103,30 @@ func encodeTomlString(value string) string { return b.String() } -func tomlValueStringRepresentation(v interface{}, commented string, indent string, arraysOneElementPerLine bool) (string, error) { +func tomlTreeStringRepresentation(t *Tree, ord marshalOrder) (string, error) { + var orderedVals []sortNode + switch ord { + case OrderPreserve: + orderedVals = sortByLines(t) + default: + orderedVals = sortAlphabetical(t) + } + + var values []string + for _, node := range orderedVals { + k := node.key + v := t.values[k] + + repr, err := tomlValueStringRepresentation(v, "", "", ord, false) + if err != nil { + return "", err + } + values = append(values, quoteKeyIfNeeded(k)+" = "+repr) + } + return "{ " + strings.Join(values, ", ") + " }", nil +} + +func tomlValueStringRepresentation(v interface{}, commented string, indent string, ord marshalOrder, arraysOneElementPerLine bool) (string, error) { // this interface check is added to dereference the change made in the writeTo function. // That change was made to allow this function to see formatting options. tv, ok := v.(*tomlValue) @@ -140,7 +163,7 @@ func tomlValueStringRepresentation(v interface{}, commented string, indent strin return "\"" + encodeTomlString(value) + "\"", nil case []byte: b, _ := v.([]byte) - return tomlValueStringRepresentation(string(b), commented, indent, arraysOneElementPerLine) + return tomlValueStringRepresentation(string(b), commented, indent, ord, arraysOneElementPerLine) case bool: if value { return "true", nil @@ -154,6 +177,8 @@ func tomlValueStringRepresentation(v interface{}, commented string, indent strin return value.String(), nil case LocalTime: return value.String(), nil + case *Tree: + return tomlTreeStringRepresentation(value, ord) case nil: return "", nil } @@ -164,7 +189,7 @@ func tomlValueStringRepresentation(v interface{}, commented string, indent strin var values []string for i := 0; i < rv.Len(); i++ { item := rv.Index(i).Interface() - itemRepr, err := tomlValueStringRepresentation(item, commented, indent, arraysOneElementPerLine) + itemRepr, err := tomlValueStringRepresentation(item, commented, indent, ord, arraysOneElementPerLine) if err != nil { return "", err } @@ -187,7 +212,7 @@ func tomlValueStringRepresentation(v interface{}, commented string, indent strin return stringBuffer.String(), nil } - return "[" + strings.Join(values, ",") + "]", nil + return "[" + strings.Join(values, ", ") + "]", nil } return "", fmt.Errorf("unsupported value type %T: %v", v, v) } @@ -282,10 +307,10 @@ func sortAlphabetical(t *Tree) (vals []sortNode) { } func (t *Tree) writeTo(w io.Writer, indent, keyspace string, bytesCount int64, arraysOneElementPerLine bool) (int64, error) { - return t.writeToOrdered(w, indent, keyspace, bytesCount, arraysOneElementPerLine, OrderAlphabetical, false) + return t.writeToOrdered(w, indent, keyspace, bytesCount, arraysOneElementPerLine, OrderAlphabetical, " ", false) } -func (t *Tree) writeToOrdered(w io.Writer, indent, keyspace string, bytesCount int64, arraysOneElementPerLine bool, ord marshalOrder, parentCommented bool) (int64, error) { +func (t *Tree) writeToOrdered(w io.Writer, indent, keyspace string, bytesCount int64, arraysOneElementPerLine bool, ord marshalOrder, indentString string, parentCommented bool) (int64, error) { var orderedVals []sortNode switch ord { @@ -301,7 +326,7 @@ func (t *Tree) writeToOrdered(w io.Writer, indent, keyspace string, bytesCount i k := node.key v := t.values[k] - combinedKey := k + combinedKey := quoteKeyIfNeeded(k) if keyspace != "" { combinedKey = keyspace + "." + combinedKey } @@ -335,7 +360,7 @@ func (t *Tree) writeToOrdered(w io.Writer, indent, keyspace string, bytesCount i if err != nil { return bytesCount, err } - bytesCount, err = node.writeToOrdered(w, indent+" ", combinedKey, bytesCount, arraysOneElementPerLine, ord, parentCommented || t.commented || tv.commented) + bytesCount, err = node.writeToOrdered(w, indent+indentString, combinedKey, bytesCount, arraysOneElementPerLine, ord, indentString, parentCommented || t.commented || tv.commented) if err != nil { return bytesCount, err } @@ -351,7 +376,7 @@ func (t *Tree) writeToOrdered(w io.Writer, indent, keyspace string, bytesCount i return bytesCount, err } - bytesCount, err = subTree.writeToOrdered(w, indent+" ", combinedKey, bytesCount, arraysOneElementPerLine, ord, parentCommented || t.commented || subTree.commented) + bytesCount, err = subTree.writeToOrdered(w, indent+indentString, combinedKey, bytesCount, arraysOneElementPerLine, ord, indentString, parentCommented || t.commented || subTree.commented) if err != nil { return bytesCount, err } @@ -368,7 +393,7 @@ func (t *Tree) writeToOrdered(w io.Writer, indent, keyspace string, bytesCount i if parentCommented || t.commented || v.commented { commented = "# " } - repr, err := tomlValueStringRepresentation(v, commented, indent, arraysOneElementPerLine) + repr, err := tomlValueStringRepresentation(v, commented, indent, ord, arraysOneElementPerLine) if err != nil { return bytesCount, err }