From 37714006b6168ee81ac9834e2b630f41e2705692 Mon Sep 17 00:00:00 2001 From: Thomas Pelletier Date: Thu, 8 Apr 2021 10:07:29 -0400 Subject: [PATCH] V2 Marshaler MVP (#495) --- README.md | 7 +- .../imported_tests/marshal_imported_test.go | 166 +++++ marshaler.go | 639 ++++++++++++++++++ marshaler_test.go | 215 ++++++ parser.go | 261 ------- targets.go | 9 +- toml_testgen_support_test.go | 19 +- unmarshaler.go | 15 + 8 files changed, 1058 insertions(+), 273 deletions(-) create mode 100644 internal/imported_tests/marshal_imported_test.go create mode 100644 marshaler.go create mode 100644 marshaler_test.go diff --git a/README.md b/README.md index a3cccfc..4083c81 100644 --- a/README.md +++ b/README.md @@ -27,7 +27,12 @@ Development branch. Use at your own risk. ### Marshal -- [ ] Minimal implementation +- [x] Minimal implementation +- [ ] Multiline strings +- [ ] Multiline arrays +- [ ] `inline` tag for tables +- [ ] Optional indentation +- [ ] Option to pick default quotes ### Document diff --git a/internal/imported_tests/marshal_imported_test.go b/internal/imported_tests/marshal_imported_test.go new file mode 100644 index 0000000..98ad71f --- /dev/null +++ b/internal/imported_tests/marshal_imported_test.go @@ -0,0 +1,166 @@ +package imported_tests + +// Those tests have been imported from v1, but adjust to match the new +// defaults of v2. + +import ( + "testing" + "time" + + "github.com/pelletier/go-toml/v2" + "github.com/stretchr/testify/require" +) + +func TestDocMarshal(t *testing.T) { + type testDoc struct { + Title string `toml:"title"` + BasicLists testDocBasicLists `toml:"basic_lists"` + SubDocPtrs []*testSubDoc `toml:"subdocptrs"` + BasicMap map[string]string `toml:"basic_map"` + Subdocs testDocSubs `toml:"subdoc"` + Basics testDocBasics `toml:"basic"` + SubDocList []testSubDoc `toml:"subdoclist"` + err int `toml:"shouldntBeHere"` + unexported int `toml:"shouldntBeHere"` + Unexported2 int `toml:"-"` + } + + var docData = testDoc{ + Title: "TOML Marshal Testing", + unexported: 0, + Unexported2: 0, + Basics: testDocBasics{ + Bool: true, + Date: time.Date(1979, 5, 27, 7, 32, 0, 0, time.UTC), + Float32: 123.4, + Float64: 123.456782132399, + Int: 5000, + Uint: 5001, + String: &biteMe, + unexported: 0, + }, + BasicLists: testDocBasicLists{ + Floats: []*float32{&float1, &float2, &float3}, + Bools: []bool{true, false, true}, + Dates: []time.Time{ + time.Date(1979, 5, 27, 7, 32, 0, 0, time.UTC), + time.Date(1980, 5, 27, 7, 32, 0, 0, time.UTC), + }, + Ints: []int{8001, 8001, 8002}, + Strings: []string{"One", "Two", "Three"}, + UInts: []uint{5002, 5003}, + }, + BasicMap: map[string]string{ + "one": "one", + "two": "two", + }, + Subdocs: testDocSubs{ + First: testSubDoc{"First", 0}, + Second: &subdoc, + }, + SubDocList: []testSubDoc{ + {"List.First", 0}, + {"List.Second", 0}, + }, + SubDocPtrs: []*testSubDoc{&subdoc}, + } + + marshalTestToml := `title = 'TOML Marshal Testing' +[basic_lists] +floats = [12.3, 45.6, 78.9] +bools = [true, false, true] +dates = [1979-05-27T07:32:00Z, 1980-05-27T07:32:00Z] +ints = [8001, 8001, 8002] +uints = [5002, 5003] +strings = ['One', 'Two', 'Three'] + +[[subdocptrs]] +name = 'Second' + +[basic_map] +one = 'one' +two = 'two' + +[subdoc] +[subdoc.second] +name = 'Second' + +[subdoc.first] +name = 'First' + + +[basic] +uint = 5001 +bool = true +float = 123.4 +float64 = 123.456782132399 +int = 5000 +string = 'Bite me' +date = 1979-05-27T07:32:00Z + +[[subdoclist]] +name = 'List.First' +[[subdoclist]] +name = 'List.Second' + +` + + result, err := toml.Marshal(docData) + require.NoError(t, err) + require.Equal(t, marshalTestToml, string(result)) +} + +func TestBasicMarshalQuotedKey(t *testing.T) { + result, err := toml.Marshal(quotedKeyMarshalTestData) + require.NoError(t, err) + + expected := `'Z.string-àéù' = 'Hello' +'Yfloat-𝟘' = 3.5 +['Xsubdoc-àéù'] +String2 = 'One' + +[['W.sublist-𝟘']] +String2 = 'Two' +[['W.sublist-𝟘']] +String2 = 'Three' + +` + + require.Equal(t, string(expected), string(result)) + +} + +func TestEmptyMarshal(t *testing.T) { + type emptyMarshalTestStruct struct { + Title string `toml:"title"` + Bool bool `toml:"bool"` + Int int `toml:"int"` + String string `toml:"string"` + StringList []string `toml:"stringlist"` + Ptr *basicMarshalTestStruct `toml:"ptr"` + Map map[string]string `toml:"map"` + } + + doc := emptyMarshalTestStruct{ + Title: "Placeholder", + Bool: false, + Int: 0, + String: "", + StringList: []string{}, + Ptr: nil, + Map: map[string]string{}, + } + result, err := toml.Marshal(doc) + require.NoError(t, err) + + expected := `title = 'Placeholder' +bool = false +int = 0 +string = '' +stringlist = [] +[map] + +` + + require.Equal(t, string(expected), string(result)) +} diff --git a/marshaler.go b/marshaler.go new file mode 100644 index 0000000..fa8f341 --- /dev/null +++ b/marshaler.go @@ -0,0 +1,639 @@ +package toml + +import ( + "bytes" + "errors" + "fmt" + "io" + "reflect" + "sort" + "strconv" + "strings" + "time" + "unicode/utf8" +) + +// Marshal serializes a Go value as a TOML document. +// +// It is a shortcut for Encoder.Encode() with the default options. +func Marshal(v interface{}) ([]byte, error) { + var buf bytes.Buffer + enc := NewEncoder(&buf) + err := enc.Encode(v) + if err != nil { + return nil, err + } + return buf.Bytes(), nil +} + +// Encoder writes a TOML document to an output stream. +type Encoder struct { + w io.Writer +} + +type encoderCtx struct { + // Current top-level key. + parentKey []string + + // Key that should be used for a KV. + key string + // Extra flag to account for the empty string + hasKey bool + + // Set to true to indicate that the encoder is inside a KV, so that all + // tables need to be inlined. + insideKv bool + + // Set to true to skip the first table header in an array table. + skipTableHeader bool +} + +func (ctx *encoderCtx) shiftKey() { + if ctx.hasKey { + ctx.parentKey = append(ctx.parentKey, ctx.key) + ctx.clearKey() + } +} + +func (ctx *encoderCtx) setKey(k string) { + ctx.key = k + ctx.hasKey = true +} + +func (ctx *encoderCtx) clearKey() { + ctx.key = "" + ctx.hasKey = false +} + +// NewEncoder returns a new Encoder that writes to w. +func NewEncoder(w io.Writer) *Encoder { + return &Encoder{ + w: w, + } +} + +// Encode writes a TOML representation of v to the stream. +// +// If v cannot be represented to TOML it returns an error. +// +// Encoding rules: +// +// 1. A top level slice containing only maps or structs is encoded as [[table +// array]]. +// +// 2. All slices not matching rule 1 are encoded as [array]. As a result, any +// map or struct they contain is encoded as an {inline table}. +// +// 3. Nil interfaces and nil pointers are not supported. +// +// 4. Keys in key-values always have one part. +// +// 5. Intermediate tables are always printed. +func (enc *Encoder) Encode(v interface{}) error { + var b []byte + var ctx encoderCtx + b, err := enc.encode(b, ctx, reflect.ValueOf(v)) + if err != nil { + return err + } + _, err = enc.w.Write(b) + return err +} + +func (enc *Encoder) encode(b []byte, ctx encoderCtx, v reflect.Value) ([]byte, error) { + switch i := v.Interface().(type) { + case time.Time: // TODO: add TextMarshaler + b = i.AppendFormat(b, time.RFC3339) + return b, nil + } + + // containers + switch v.Kind() { + case reflect.Map: + return enc.encodeMap(b, ctx, v) + case reflect.Struct: + return enc.encodeStruct(b, ctx, v) + case reflect.Slice: + return enc.encodeSlice(b, ctx, v) + case reflect.Interface: + if v.IsNil() { + return nil, errNilInterface + } + return enc.encode(b, ctx, v.Elem()) + case reflect.Ptr: + if v.IsNil() { + return enc.encode(b, ctx, reflect.Zero(v.Type().Elem())) + } + return enc.encode(b, ctx, v.Elem()) + } + + // values + var err error + switch v.Kind() { + case reflect.String: + b, err = enc.encodeString(b, v.String()) + case reflect.Float32: + b = strconv.AppendFloat(b, v.Float(), 'f', -1, 32) + case reflect.Float64: + b = strconv.AppendFloat(b, v.Float(), 'f', -1, 64) + case reflect.Bool: + if v.Bool() { + b = append(b, "true"...) + } else { + b = append(b, "false"...) + } + case reflect.Uint64, reflect.Uint32, reflect.Uint16, reflect.Uint8, reflect.Uint: + b = strconv.AppendUint(b, v.Uint(), 10) + case reflect.Int64, reflect.Int32, reflect.Int16, reflect.Int8, reflect.Int: + b = strconv.AppendInt(b, v.Int(), 10) + default: + err = fmt.Errorf("unsupported encode value kind: %s", v.Kind()) + } + if err != nil { + return nil, err + } + + return b, nil +} + +func isNil(v reflect.Value) bool { + switch v.Kind() { + case reflect.Ptr, reflect.Interface, reflect.Map: + return v.IsNil() + default: + return false + } +} + +func (enc *Encoder) encodeKv(b []byte, ctx encoderCtx, v reflect.Value) ([]byte, error) { + var err error + + if !ctx.hasKey { + panic("caller of encodeKv should have set the key in the context") + } + + if isNil(v) { + return b, nil + } + + b, err = enc.encodeKey(b, ctx.key) + if err != nil { + return nil, err + } + + b = append(b, " = "...) + + // create a copy of the context because the value of a KV shouldn't + // modify the global context. + subctx := ctx + subctx.insideKv = true + subctx.shiftKey() + + b, err = enc.encode(b, subctx, v) + if err != nil { + return nil, err + } + + return b, nil +} + +const literalQuote = '\'' + +func (enc *Encoder) encodeString(b []byte, v string) ([]byte, error) { + if needsQuoting(v) { + b = enc.encodeQuotedString(b, v) + } else { + b = enc.encodeLiteralString(b, v) + } + return b, nil +} + +func needsQuoting(v string) bool { + return strings.ContainsAny(v, "'\b\f\n\r\t") +} + +// caller should have checked that the string does not contain new lines or ' +func (enc *Encoder) encodeLiteralString(b []byte, v string) []byte { + b = append(b, literalQuote) + b = append(b, v...) + b = append(b, literalQuote) + return b +} + +func (enc *Encoder) encodeQuotedString(b []byte, v string) []byte { + const stringQuote = '"' + + b = append(b, stringQuote) + + for _, r := range v { + switch r { + case '\\': + b = append(b, `\\`...) + continue + case '"': + b = append(b, `\"`...) + continue + case '\b': + b = append(b, `\b`...) + continue + case '\f': + b = append(b, `\f`...) + continue + case '\n': + b = append(b, `\n`...) + continue + case '\r': + b = append(b, `\r`...) + continue + case '\t': + b = append(b, `\t`...) + continue + } + if r == 0x20 || r == 0x09 || r == 0x21 || (r >= 0x23 && r <= 0x5B) || (r >= 0x5D && r <= 0x7E) { + b = append(b, byte(r)) + } else if (r >= 0x80 && r <= 0xD7FF) || (r >= 0xE000 && r <= 0x10FFFF) { + l := utf8.RuneLen(r) + buf := make([]byte, l) + utf8.EncodeRune(buf, r) + b = append(b, buf...) + } else { + var h []byte + if r > 0xFFFF { + h = []byte(fmt.Sprintf("%08x", r)) + + } else { + h = []byte(fmt.Sprintf("%04x", r)) + } + b = append(b, `\u`...) + b = append(b, h...) + } + } + + b = append(b, stringQuote) + return b +} + +// called should have checked that the string is in A-Z / a-z / 0-9 / - / _ +func (enc *Encoder) encodeUnquotedKey(b []byte, v string) []byte { + return append(b, v...) +} + +func (enc *Encoder) encodeTableHeader(b []byte, key []string) ([]byte, error) { + if len(key) == 0 { + return b, nil + } + + b = append(b, '[') + + var err error + b, err = enc.encodeKey(b, key[0]) + if err != nil { + return nil, err + } + + for _, k := range key[1:] { + b = append(b, '.') + b, err = enc.encodeKey(b, k) + if err != nil { + return nil, err + } + } + + b = append(b, "]\n"...) + + return b, nil +} + +func (enc *Encoder) encodeKey(b []byte, k string) ([]byte, error) { + needsQuotation := false + cannotUseLiteral := false + + for _, c := range k { + if (c >= 'A' && c <= 'Z') || (c >= 'a' && c <= 'z') || (c >= '0' && c <= '9') || c == '-' || c == '_' { + continue + } + if c == '\n' { + return nil, fmt.Errorf("TOML does not support multiline keys") + } + if c == literalQuote { + cannotUseLiteral = true + } + needsQuotation = true + } + + if cannotUseLiteral { + b = enc.encodeQuotedString(b, k) + } else if needsQuotation { + b = enc.encodeLiteralString(b, k) + } else { + b = enc.encodeUnquotedKey(b, k) + } + + return b, nil +} + +func (enc *Encoder) encodeMap(b []byte, ctx encoderCtx, v reflect.Value) ([]byte, error) { + if v.Type().Key().Kind() != reflect.String { + return nil, fmt.Errorf("type '%s' not supported as map key", v.Type().Key().Kind()) + } + + t := table{} + + iter := v.MapRange() + for iter.Next() { + k := iter.Key().String() + v := iter.Value() + + if isNil(v) { + continue + } + + table, err := willConvertToTableOrArrayTable(v) + if err != nil { + return nil, err + } + + if table { + t.pushTable(k, v) + } else { + t.pushKV(k, v) + } + } + + sortEntriesByKey(t.kvs) + sortEntriesByKey(t.tables) + + return enc.encodeTable(b, ctx, t) +} + +func sortEntriesByKey(e []entry) { + sort.Slice(e, func(i, j int) bool { + return e[i].Key < e[j].Key + }) +} + +type entry struct { + Key string + Value reflect.Value +} + +type table struct { + kvs []entry + tables []entry +} + +func (t *table) pushKV(k string, v reflect.Value) { + t.kvs = append(t.kvs, entry{Key: k, Value: v}) +} + +func (t *table) pushTable(k string, v reflect.Value) { + t.tables = append(t.tables, entry{Key: k, Value: v}) +} + +func (t *table) hasKVs() bool { + return len(t.kvs) > 0 +} + +func (enc *Encoder) encodeStruct(b []byte, ctx encoderCtx, v reflect.Value) ([]byte, error) { + t := table{} + + // TODO: cache this? + typ := v.Type() + for i := 0; i < typ.NumField(); i++ { + fieldType := typ.Field(i) + + // only consider exported fields + if fieldType.PkgPath != "" { + continue + } + + k, ok := fieldType.Tag.Lookup("toml") + if !ok { + k = fieldType.Name + } + + // special field name to skip field + if k == "-" { + continue + } + + f := v.Field(i) + + if isNil(f) { + continue + } + + willConvert, err := willConvertToTableOrArrayTable(f) + if err != nil { + return nil, err + } + + if willConvert { + t.pushTable(k, f) + } else { + t.pushKV(k, f) + } + } + + return enc.encodeTable(b, ctx, t) +} + +func (enc *Encoder) encodeTable(b []byte, ctx encoderCtx, t table) ([]byte, error) { + var err error + + ctx.shiftKey() + + if ctx.insideKv { + b = append(b, '{') + + first := true + for _, kv := range t.kvs { + if first { + first = false + } else { + b = append(b, `, `...) + } + ctx.setKey(kv.Key) + b, err = enc.encodeKv(b, ctx, kv.Value) + if err != nil { + return nil, err + } + } + + for _, table := range t.tables { + if first { + first = false + } else { + b = append(b, `, `...) + } + ctx.setKey(table.Key) + b, err = enc.encode(b, ctx, table.Value) + if err != nil { + return nil, err + } + b = append(b, '\n') + } + + b = append(b, "}\n"...) + return b, nil + } + + if !ctx.skipTableHeader { + b, err = enc.encodeTableHeader(b, ctx.parentKey) + if err != nil { + return nil, err + } + } + ctx.skipTableHeader = false + + for _, kv := range t.kvs { + ctx.setKey(kv.Key) + b, err = enc.encodeKv(b, ctx, kv.Value) + if err != nil { + return nil, err + } + b = append(b, '\n') + } + + for _, table := range t.tables { + ctx.setKey(table.Key) + b, err = enc.encode(b, ctx, table.Value) + if err != nil { + return nil, err + } + b = append(b, '\n') + } + + return b, nil +} + +var errNilInterface = errors.New("nil interface not supported") +var errNilPointer = errors.New("nil pointer not supported") + +func willConvertToTable(v reflect.Value) (bool, error) { + switch v.Interface().(type) { + case time.Time: // TODO: add TextMarshaler + return false, nil + } + + t := v.Type() + switch t.Kind() { + case reflect.Map, reflect.Struct: + return true, nil + case reflect.Interface: + if v.IsNil() { + return false, errNilInterface + } + return willConvertToTable(v.Elem()) + case reflect.Ptr: + if v.IsNil() { + return false, nil + } + return willConvertToTable(v.Elem()) + default: + return false, nil + } +} + +func willConvertToTableOrArrayTable(v reflect.Value) (bool, error) { + t := v.Type() + + if t.Kind() == reflect.Interface { + if v.IsNil() { + return false, errNilInterface + } + return willConvertToTableOrArrayTable(v.Elem()) + } + + if t.Kind() == reflect.Slice { + if v.Len() == 0 { + // An empty slice should be a kv = []. + return false, nil + } + for i := 0; i < v.Len(); i++ { + t, err := willConvertToTable(v.Index(i)) + if err != nil { + return false, err + } + if !t { + return false, nil + } + } + return true, nil + } + + return willConvertToTable(v) +} + +func (enc *Encoder) encodeSlice(b []byte, ctx encoderCtx, v reflect.Value) ([]byte, error) { + if v.Len() == 0 { + b = append(b, "[]"...) + return b, nil + } + + allTables, err := willConvertToTableOrArrayTable(v) + if err != nil { + return nil, err + } + + if allTables { + return enc.encodeSliceAsArrayTable(b, ctx, v) + } + + return enc.encodeSliceAsArray(b, ctx, v) +} + +// caller should have checked that v is a slice that only contains values that +// encode into tables. +func (enc *Encoder) encodeSliceAsArrayTable(b []byte, ctx encoderCtx, v reflect.Value) ([]byte, error) { + if v.Len() == 0 { + return b, nil + } + + ctx.shiftKey() + + var err error + scratch := make([]byte, 0, 64) + scratch = append(scratch, "[["...) + for i, k := range ctx.parentKey { + if i > 0 { + scratch = append(scratch, '.') + } + scratch, err = enc.encodeKey(scratch, k) + if err != nil { + return nil, err + } + } + scratch = append(scratch, "]]\n"...) + ctx.skipTableHeader = true + + for i := 0; i < v.Len(); i++ { + b = append(b, scratch...) + b, err = enc.encode(b, ctx, v.Index(i)) + if err != nil { + return nil, err + } + } + return b, nil +} + +func (enc *Encoder) encodeSliceAsArray(b []byte, ctx encoderCtx, v reflect.Value) ([]byte, error) { + b = append(b, '[') + + var err error + first := true + for i := 0; i < v.Len(); i++ { + if !first { + b = append(b, ", "...) + } + first = false + + b, err = enc.encode(b, ctx, v.Index(i)) + if err != nil { + return nil, err + } + } + + b = append(b, ']') + return b, nil +} diff --git a/marshaler_test.go b/marshaler_test.go new file mode 100644 index 0000000..0ec5066 --- /dev/null +++ b/marshaler_test.go @@ -0,0 +1,215 @@ +package toml_test + +import ( + "bytes" + "encoding/json" + "strings" + "testing" + + "github.com/pelletier/go-toml/v2" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestMarshal(t *testing.T) { + examples := []struct { + desc string + v interface{} + expected string + err bool + }{ + { + desc: "simple map and string", + v: map[string]string{ + "hello": "world", + }, + expected: "hello = 'world'", + }, + { + desc: "map with new line in key", + v: map[string]string{ + "hel\nlo": "world", + }, + err: true, + }, + { + desc: `map with " in key`, + v: map[string]string{ + `hel"lo`: "world", + }, + expected: `'hel"lo' = 'world'`, + }, + { + desc: "map in map and string", + v: map[string]map[string]string{ + "table": { + "hello": "world", + }, + }, + expected: ` +[table] +hello = 'world'`, + }, + { + desc: "map in map in map and string", + v: map[string]map[string]map[string]string{ + "this": { + "is": { + "a": "test", + }, + }, + }, + expected: ` +[this] +[this.is] +a = 'test'`, + }, + { + // TODO: this test is flaky because output changes depending on + // the map iteration order. + desc: "map in map in map and string with values", + v: map[string]interface{}{ + "this": map[string]interface{}{ + "is": map[string]string{ + "a": "test", + }, + "also": "that", + }, + }, + expected: ` +[this] +also = 'that' +[this.is] +a = 'test'`, + }, + { + desc: "simple string array", + v: map[string][]string{ + "array": {"one", "two", "three"}, + }, + expected: `array = ['one', 'two', 'three']`, + }, + { + desc: "nested string arrays", + v: map[string][][]string{ + "array": {{"one", "two"}, {"three"}}, + }, + expected: `array = [['one', 'two'], ['three']]`, + }, + { + desc: "mixed strings and nested string arrays", + v: map[string][]interface{}{ + "array": {"a string", []string{"one", "two"}, "last"}, + }, + expected: `array = ['a string', ['one', 'two'], 'last']`, + }, + { + desc: "slice of maps", + v: map[string][]map[string]string{ + "top": { + {"map1.1": "v1.1"}, + {"map2.1": "v2.1"}, + }, + }, + expected: ` +[[top]] +'map1.1' = 'v1.1' +[[top]] +'map2.1' = 'v2.1' +`, + }, + { + desc: "map with two keys", + v: map[string]string{ + "key1": "value1", + "key2": "value2", + }, + expected: ` +key1 = 'value1' +key2 = 'value2'`, + }, + { + desc: "simple struct", + v: struct { + A string + }{ + A: "foo", + }, + expected: `A = 'foo'`, + }, + { + desc: "one level of structs within structs", + v: struct { + A interface{} + }{ + A: struct { + K1 string + K2 string + }{ + K1: "v1", + K2: "v2", + }, + }, + expected: ` +[A] +K1 = 'v1' +K2 = 'v2' +`, + }, + { + desc: "structs in slice with interfaces", + v: map[string]interface{}{ + "root": map[string]interface{}{ + "nested": []interface{}{ + map[string]interface{}{"name": "Bob"}, + map[string]interface{}{"name": "Alice"}, + }, + }, + }, + expected: ` +[root] +[[root.nested]] +name = 'Bob' +[[root.nested]] +name = 'Alice' +`, + }, + } + + for _, e := range examples { + t.Run(e.desc, func(t *testing.T) { + b, err := toml.Marshal(e.v) + if e.err { + require.Error(t, err) + } else { + require.NoError(t, err) + equalStringsIgnoreNewlines(t, e.expected, string(b)) + } + }) + } +} + +func equalStringsIgnoreNewlines(t *testing.T, expected string, actual string) { + t.Helper() + cutset := "\n" + assert.Equal(t, strings.Trim(expected, cutset), strings.Trim(actual, cutset)) +} + +func TestIssue436(t *testing.T) { + data := []byte(`{"a": [ { "b": { "c": "d" } } ]}`) + + var v interface{} + err := json.Unmarshal(data, &v) + require.NoError(t, err) + + var buf bytes.Buffer + err = toml.NewEncoder(&buf).Encode(v) + require.NoError(t, err) + + expected := ` +[[a]] +[a.b] +c = 'd' +` + equalStringsIgnoreNewlines(t, expected, buf.String()) +} diff --git a/parser.go b/parser.go index bb9a37e..89fadf6 100644 --- a/parser.go +++ b/parser.go @@ -4,7 +4,6 @@ import ( "bytes" "fmt" "strconv" - "time" "github.com/pelletier/go-toml/v2/internal/ast" ) @@ -793,266 +792,6 @@ func (p *parser) scanDateTime(b []byte) (ast.Reference, []byte, error) { }), b[i:], nil } -func (p *parser) parseDateTime(b []byte) ([]byte, error) { - // we know the first 2 are digits. - if b[2] == ':' { - return p.parseTime(b) - } - // This state accepts an offset date-time, a local date-time, or a local date. - // - // 1979-05-27T07:32:00Z - // 1979-05-27T00:32:00-07:00 - // 1979-05-27T00:32:00.999999-07:00 - // 1979-05-27 07:32:00Z - // 1979-05-27 00:32:00-07:00 - // 1979-05-27 00:32:00.999999-07:00 - // 1979-05-27T07:32:00 - // 1979-05-27T00:32:00.999999 - // 1979-05-27 07:32:00 - // 1979-05-27 00:32:00.999999 - // 1979-05-27 - - // date - - idx := 4 - - localDate := LocalDate{ - Year: digitsToInt(b[:idx]), - } - - for i := 0; i < 2; i++ { - // month - idx++ - if !isDigit(b[idx]) { - return nil, fmt.Errorf("invalid month digit in date: %c", b[idx]) - } - localDate.Month *= 10 - localDate.Month += time.Month(b[idx] - '0') - } - - idx++ - if b[idx] != '-' { - return nil, fmt.Errorf("expected - to separate month of a date, not %c", b[idx]) - } - - for i := 0; i < 2; i++ { - // day - idx++ - if !isDigit(b[idx]) { - return nil, fmt.Errorf("invalid day digit in date: %c", b[idx]) - } - localDate.Day *= 10 - localDate.Day += int(b[idx] - '0') - } - - idx++ - - if idx >= len(b) { - //p.builder.LocalDateValue(localDate) - // TODO - return nil, nil - } else if b[idx] != ' ' && b[idx] != 'T' { - //p.builder.LocalDateValue(localDate) - // TODO - return b[idx:], nil - } - - // check if there is a chance there is anything useful after - if b[idx] == ' ' && (((idx + 2) >= len(b)) || !isDigit(b[idx+1]) || !isDigit(b[idx+2])) { - //p.builder.LocalDateValue(localDate) - // TODO - return b[idx:], nil - } - - //idx++ // skip the T or ' ' - - // time - localTime := LocalTime{} - - for i := 0; i < 2; i++ { - idx++ - if !isDigit(b[idx]) { - return nil, fmt.Errorf("invalid hour digit in time: %c", b[idx]) - } - localTime.Hour *= 10 - localTime.Hour += int(b[idx] - '0') - } - - idx++ - if b[idx] != ':' { - return nil, fmt.Errorf("time hour/minute separator should be :, not %c", b[idx]) - } - - for i := 0; i < 2; i++ { - idx++ - if !isDigit(b[idx]) { - return nil, fmt.Errorf("invalid minute digit in time: %c", b[idx]) - } - localTime.Minute *= 10 - localTime.Minute += int(b[idx] - '0') - } - - idx++ - if b[idx] != ':' { - return nil, fmt.Errorf("time minute/second separator should be :, not %c", b[idx]) - } - - for i := 0; i < 2; i++ { - idx++ - if !isDigit(b[idx]) { - return nil, fmt.Errorf("invalid second digit in time: %c", b[idx]) - } - localTime.Second *= 10 - localTime.Second += int(b[idx] - '0') - } - - idx++ - if idx < len(b) && b[idx] == '.' { - idx++ - idx++ - if !isDigit(b[idx]) { - return nil, fmt.Errorf("expected at least one digit in time's fraction, not %c", b[idx]) - } - - for { - localTime.Nanosecond *= 10 - localTime.Nanosecond += int(b[idx] - '0') - idx++ - - if idx < len(b) { - break - } - - if !isDigit(b[idx]) { - break - } - } - } - - if idx >= len(b) || (b[idx] != 'Z' && b[idx] != '+' && b[idx] != '-') { - dt := LocalDateTime{ - Date: localDate, - Time: localTime, - } - //p.builder.LocalDateTimeValue(dt) - // TODO - dt = dt - return b[idx:], nil - } - - loc := time.UTC - - if b[idx] == 'Z' { - idx++ - } else { - start := idx - sign := 1 - if b[idx] == '-' { - sign = -1 - } - - hours := 0 - for i := 0; i < 2; i++ { - idx++ - if !isDigit(b[idx]) { - return nil, fmt.Errorf("invalid hour digit in time offset: %c", b[idx]) - } - hours *= 10 - hours += int(b[idx] - '0') - } - offset := hours * 60 * 60 - - idx++ - if b[idx] != ':' { - return nil, fmt.Errorf("time offset hour/minute separator should be :, not %c", b[idx]) - } - - minutes := 0 - for i := 0; i < 2; i++ { - idx++ - if !isDigit(b[idx]) { - return nil, fmt.Errorf("invalid minute digit in time offset: %c", b[idx]) - } - minutes *= 10 - minutes += int(b[idx] - '0') - } - offset += minutes * 60 - offset *= sign - idx++ - loc = time.FixedZone(string(b[start:idx]), offset) - } - dt := time.Date(localDate.Year, localDate.Month, localDate.Day, localTime.Hour, localTime.Minute, localTime.Second, localTime.Nanosecond, loc) - //p.builder.DateTimeValue(dt) - // TODO - dt = dt - return b[idx:], nil -} - -func (p *parser) parseTime(b []byte) ([]byte, error) { - localTime := LocalTime{} - - idx := 0 - - for i := 0; i < 2; i++ { - idx++ - if !isDigit(b[idx]) { - return nil, fmt.Errorf("invalid hour digit in time: %c", b[idx]) - } - localTime.Hour *= 10 - localTime.Hour += int(b[idx] - '0') - } - - idx++ - if b[idx] != ':' { - return nil, fmt.Errorf("time hour/minute separator should be :, not %c", b[idx]) - } - - for i := 0; i < 2; i++ { - idx++ - if !isDigit(b[idx]) { - return nil, fmt.Errorf("invalid minute digit in time: %c", b[idx]) - } - localTime.Minute *= 10 - localTime.Minute += int(b[idx] - '0') - } - - idx++ - if b[idx] != ':' { - return nil, fmt.Errorf("time minute/second separator should be :, not %c", b[idx]) - } - - for i := 0; i < 2; i++ { - idx++ - if !isDigit(b[idx]) { - return nil, fmt.Errorf("invalid second digit in time: %c", b[idx]) - } - localTime.Second *= 10 - localTime.Second += int(b[idx] - '0') - } - - idx++ - if idx < len(b) && b[idx] == '.' { - idx++ - idx++ - if !isDigit(b[idx]) { - return nil, fmt.Errorf("expected at least one digit in time's fraction, not %c", b[idx]) - } - - for { - localTime.Nanosecond *= 10 - localTime.Nanosecond += int(b[idx] - '0') - idx++ - if !isDigit(b[idx]) { - break - } - } - } - - //p.builder.LocalTimeValue(localTime) - // TODO - return b[idx:], nil -} - func (p *parser) scanIntOrFloat(b []byte) (ast.Reference, []byte, error) { i := 0 diff --git a/targets.go b/targets.go index d229020..a1e922f 100644 --- a/targets.go +++ b/targets.go @@ -356,6 +356,11 @@ func (d *decoder) scopeTableTarget(append bool, t target, name string) (target, case reflect.Struct: return scopeStruct(x, name) case reflect.Map: + if x.IsNil() { + t.set(reflect.MakeMap(x.Type())) + x = t.get() + } + return scopeMap(x, name) default: panic(fmt.Errorf("can't scope on a %s", x.Kind())) @@ -442,10 +447,6 @@ func (d *decoder) scopeArray(append bool, t target) (target, error) { } func scopeMap(v reflect.Value, name string) (target, bool, error) { - if v.IsNil() { - v.Set(reflect.MakeMap(v.Type())) - } - k := reflect.ValueOf(name) keyType := v.Type().Key() diff --git a/toml_testgen_support_test.go b/toml_testgen_support_test.go index 5d4b0a0..5edd25b 100644 --- a/toml_testgen_support_test.go +++ b/toml_testgen_support_test.go @@ -38,6 +38,15 @@ func testgenValid(t *testing.T, input string, jsonRef string) { refDoc := testgenBuildRefDoc(jsonRef) require.Equal(t, refDoc, doc) + + out, err := toml.Marshal(doc) + require.NoError(t, err) + + doc2 := map[string]interface{}{} + err = toml.Unmarshal(out, &doc2) + require.NoError(t, err) + + require.Equal(t, refDoc, doc2) } type testGenDescNode struct { @@ -121,13 +130,9 @@ func testGenTranslateDesc(input interface{}) interface{} { } } - var dest interface{} - if len(d) > 0 { - x := map[string]interface{}{} - for k, v := range d { - x[k] = testGenTranslateDesc(v) - } - dest = x + dest := map[string]interface{}{} + for k, v := range d { + dest[k] = testGenTranslateDesc(v) } return dest } diff --git a/unmarshaler.go b/unmarshaler.go index 1184e63..074d8ef 100644 --- a/unmarshaler.go +++ b/unmarshaler.go @@ -33,6 +33,9 @@ func NewDecoder(r io.Reader) *Decoder { // // When a TOML local date is decoded into a time.Time, its value is represented // in time.Local timezone. +// +// Empty tables decoded in an interface{} create an empty initialized +// map[string]interface{}. func (d *Decoder) Decode(v interface{}) error { b, err := ioutil.ReadAll(d.r) if err != nil { @@ -111,6 +114,18 @@ func (d *decoder) fromParser(p *parser, v interface{}) error { found = true case ast.Table: current, found, err = d.scopeWithKey(root, node.Key()) + if err == nil { + // In case this table points to an interface, + // make sure it at least holds something that + // looks like a table. Otherwise the information + // of a table is lost, and marshal cannot do the + // round trip. + v := current.get() + if v.Kind() == reflect.Interface && v.IsNil() { + newElement := reflect.MakeMap(mapStringInterfaceType) + current.set(newElement) + } + } case ast.ArrayTable: current, found, err = d.scopeWithArrayTable(root, node.Key()) default: