diff --git a/marshal.go b/marshal.go index 27867fb..4155c40 100644 --- a/marshal.go +++ b/marshal.go @@ -91,23 +91,25 @@ func isPrimitive(mtype reflect.Type) bool { } } -// Check if the given marshal type maps to a Tree slice -func isTreeSlice(mtype reflect.Type) bool { +// Check if the given marshal type maps to a Tree slice or array +func isTreeSequence(mtype reflect.Type) bool { switch mtype.Kind() { - case reflect.Slice: - return !isOtherSlice(mtype) + case reflect.Ptr: + return isTreeSequence(mtype.Elem()) + case reflect.Slice, reflect.Array: + return isTree(mtype.Elem()) default: return false } } -// Check if the given marshal type maps to a non-Tree slice -func isOtherSlice(mtype reflect.Type) bool { +// Check if the given marshal type maps to a non-Tree slice or array +func isOtherSequence(mtype reflect.Type) bool { switch mtype.Kind() { case reflect.Ptr: - return isOtherSlice(mtype.Elem()) - case reflect.Slice: - return isPrimitive(mtype.Elem()) || isOtherSlice(mtype.Elem()) + return isOtherSequence(mtype.Elem()) + case reflect.Slice, reflect.Array: + return !isTreeSequence(mtype) default: return false } @@ -116,6 +118,8 @@ func isOtherSlice(mtype reflect.Type) bool { // Check if the given marshal type maps to a Tree func isTree(mtype reflect.Type) bool { switch mtype.Kind() { + case reflect.Ptr: + return isTree(mtype.Elem()) case reflect.Map: return true case reflect.Struct: @@ -406,9 +410,9 @@ func (e *Encoder) valueToToml(mtype reflect.Type, mval reflect.Value) (interface return callCustomMarshaler(mval) case isTree(mtype): return e.valueToTree(mtype, mval) - case isTreeSlice(mtype): + case isTreeSequence(mtype): return e.valueToTreeSlice(mtype, mval) - case isOtherSlice(mtype): + case isOtherSequence(mtype): return e.valueToOtherSlice(mtype, mval) default: switch mtype.Kind() { @@ -687,12 +691,12 @@ func (d *Decoder) valueFromToml(mtype reflect.Type, tval interface{}, mval1 *ref } return reflect.ValueOf(nil), fmt.Errorf("Can't convert %v(%T) to a tree", tval, tval) case []*Tree: - if isTreeSlice(mtype) { + if isTreeSequence(mtype) { return d.valueFromTreeSlice(mtype, t) } return reflect.ValueOf(nil), fmt.Errorf("Can't convert %v(%T) to trees", tval, tval) case []interface{}: - if isOtherSlice(mtype) { + if isOtherSequence(mtype) { return d.valueFromOtherSlice(mtype, t) } return reflect.ValueOf(nil), fmt.Errorf("Can't convert %v(%T) to a slice", tval, tval) diff --git a/marshal_test.go b/marshal_test.go index 30f5bc8..3887ab9 100644 --- a/marshal_test.go +++ b/marshal_test.go @@ -433,12 +433,36 @@ type tomlTypeCheckTest struct { func TestTypeChecks(t *testing.T) { tests := []tomlTypeCheckTest{ - {"integer", 2, 0}, + {"bool", true, 0}, + {"bool", false, 0}, + {"int", int(2), 0}, + {"int8", int8(2), 0}, + {"int16", int16(2), 0}, + {"int32", int32(2), 0}, + {"int64", int64(2), 0}, + {"uint", uint(2), 0}, + {"uint8", uint8(2), 0}, + {"uint16", uint16(2), 0}, + {"uint32", uint32(2), 0}, + {"uint64", uint64(2), 0}, + {"float32", float32(3.14), 0}, + {"float64", float64(3.14), 0}, + {"string", "lorem ipsum", 0}, {"time", time.Date(2015, 1, 1, 0, 0, 0, 0, time.UTC), 0}, {"stringlist", []string{"hello", "hi"}, 1}, + {"stringlistptr", &[]string{"hello", "hi"}, 1}, + {"stringarray", [2]string{"hello", "hi"}, 1}, + {"stringarrayptr", &[2]string{"hello", "hi"}, 1}, {"timelist", []time.Time{time.Date(2015, 1, 1, 0, 0, 0, 0, time.UTC)}, 1}, + {"timelistptr", &[]time.Time{time.Date(2015, 1, 1, 0, 0, 0, 0, time.UTC)}, 1}, + {"timearray", [1]time.Time{time.Date(2015, 1, 1, 0, 0, 0, 0, time.UTC)}, 1}, + {"timearrayptr", &[1]time.Time{time.Date(2015, 1, 1, 0, 0, 0, 0, time.UTC)}, 1}, {"objectlist", []tomlTypeCheckTest{}, 2}, + {"objectlistptr", &[]tomlTypeCheckTest{}, 2}, + {"objectarray", [2]tomlTypeCheckTest{{}, {}}, 2}, + {"objectlistptr", &[2]tomlTypeCheckTest{{}, {}}, 2}, {"object", tomlTypeCheckTest{}, 3}, + {"objectptr", &tomlTypeCheckTest{}, 3}, } for _, test := range tests { @@ -446,8 +470,8 @@ func TestTypeChecks(t *testing.T) { expected[test.typ] = true result := []bool{ isPrimitive(reflect.TypeOf(test.item)), - isOtherSlice(reflect.TypeOf(test.item)), - isTreeSlice(reflect.TypeOf(test.item)), + isOtherSequence(reflect.TypeOf(test.item)), + isTreeSequence(reflect.TypeOf(test.item)), isTree(reflect.TypeOf(test.item)), } if !reflect.DeepEqual(expected, result) { @@ -1699,3 +1723,58 @@ func TestTreeMarshal(t *testing.T) { }) } } + +func TestMarshalArrays(t *testing.T) { + cases := []struct { + Data interface{} + Expected string + }{ + { + Data: struct { + XY [2]int + }{ + XY: [2]int{1, 2}, + }, + Expected: `XY = [1,2] +`, + }, + { + Data: struct { + XY [1][2]int + }{ + XY: [1][2]int{{1, 2}}, + }, + Expected: `XY = [[1,2]] +`, + }, + { + Data: struct { + XY [1][]int + }{ + XY: [1][]int{{1, 2}}, + }, + Expected: `XY = [[1,2]] +`, + }, + { + Data: struct { + XY [][2]int + }{ + XY: [][2]int{{1, 2}}, + }, + Expected: `XY = [[1,2]] +`, + }, + } + for _, tc := range cases { + t.Run("", func(t *testing.T) { + result, err := Marshal(tc.Data) + if err != nil { + t.Fatal(err) + } + if !bytes.Equal(result, []byte(tc.Expected)) { + t.Errorf("Bad marshal: expected\n-----\n%s\n-----\ngot\n-----\n%s\n-----\n", []byte(tc.Expected), result) + } + }) + } +}