diff --git a/marshal.go b/marshal.go index 6ab587e..a9ee1c1 100644 --- a/marshal.go +++ b/marshal.go @@ -777,7 +777,11 @@ func (d *Decoder) valueFromTree(mtype reflect.Type, tval *Tree, mval1 *reflect.V // Convert toml value to marshal struct/map slice, using marshal type func (d *Decoder) valueFromTreeSlice(mtype reflect.Type, tval []*Tree) (reflect.Value, error) { - mval := reflect.MakeSlice(mtype, len(tval), len(tval)) + mval, err := makeSliceOrArray(mtype, len(tval)) + if err != nil { + return mval, err + } + for i := 0; i < len(tval); i++ { d.visitor.push(strconv.Itoa(i)) val, err := d.valueFromTree(mtype.Elem(), tval[i], nil) @@ -792,7 +796,11 @@ func (d *Decoder) valueFromTreeSlice(mtype reflect.Type, tval []*Tree) (reflect. // Convert toml value to marshal primitive slice, using marshal type func (d *Decoder) valueFromOtherSlice(mtype reflect.Type, tval []interface{}) (reflect.Value, error) { - mval := reflect.MakeSlice(mtype, len(tval), len(tval)) + mval, err := makeSliceOrArray(mtype, len(tval)) + if err != nil { + return mval, err + } + for i := 0; i < len(tval); i++ { val, err := d.valueFromToml(mtype.Elem(), tval[i], nil) if err != nil { @@ -806,10 +814,14 @@ func (d *Decoder) valueFromOtherSlice(mtype reflect.Type, tval []interface{}) (r // Convert toml value to marshal primitive slice, using marshal type func (d *Decoder) valueFromOtherSliceI(mtype reflect.Type, tval interface{}) (reflect.Value, error) { val := reflect.ValueOf(tval) + length := val.Len() - lenght := val.Len() - mval := reflect.MakeSlice(mtype, lenght, lenght) - for i := 0; i < lenght; i++ { + mval, err := makeSliceOrArray(mtype, length) + if err != nil { + return mval, err + } + + for i := 0; i < length; i++ { val, err := d.valueFromToml(mtype.Elem(), val.Index(i).Interface(), nil) if err != nil { return mval, err @@ -819,6 +831,21 @@ func (d *Decoder) valueFromOtherSliceI(mtype reflect.Type, tval interface{}) (re return mval, nil } +// Create a new slice or a new array with specified length +func makeSliceOrArray(mtype reflect.Type, tLength int) (reflect.Value, error) { + var mval reflect.Value + switch mtype.Kind() { + case reflect.Slice: + mval = reflect.MakeSlice(mtype, tLength, tLength) + case reflect.Array: + mval = reflect.New(reflect.ArrayOf(mtype.Len(), mtype.Elem())).Elem() + if tLength > mtype.Len() { + return mval, fmt.Errorf("unmarshal: TOML array length (%v) exceeds destination array length (%v)", tLength, mtype.Len()) + } + } + return mval, nil +} + // Convert toml value to marshal value, using marshal type. When mval1 is non-nil // and the given type is a struct value, merge fields into it. func (d *Decoder) valueFromToml(mtype reflect.Type, tval interface{}, mval1 *reflect.Value) (reflect.Value, error) { @@ -972,7 +999,7 @@ func (d *Decoder) valueFromToml(mtype reflect.Type, tval interface{}, mval1 *ref ival := mval1.Elem() return d.valueFromToml(mval1.Elem().Type(), t, &ival) } - case reflect.Slice: + case reflect.Slice, reflect.Array: if isOtherSequence(mtype) && isOtherSequence(reflect.TypeOf(t)) { return d.valueFromOtherSliceI(mtype, t) } @@ -1048,11 +1075,7 @@ func tomlOptions(vf reflect.StructField, an annotation) tomlOpts { func isZero(val reflect.Value) bool { switch val.Type().Kind() { - case reflect.Map: - fallthrough - case reflect.Array: - fallthrough - case reflect.Slice: + case reflect.Slice, reflect.Array, reflect.Map: return val.Len() == 0 default: return reflect.DeepEqual(val.Interface(), reflect.Zero(val.Type()).Interface()) diff --git a/marshal_test.go b/marshal_test.go index 4fa9600..7d6c025 100644 --- a/marshal_test.go +++ b/marshal_test.go @@ -2259,7 +2259,7 @@ func TestUnmarshalPreservesUnexportedFields(t *testing.T) { [[slice1]] exported1 = "visible3" - + [[slice1]] exported1 = "visible4" @@ -3122,6 +3122,20 @@ type sliceStruct struct { StructSlicePtr *[]basicMarshalTestSubStruct ` toml:"struct_slice_ptr" ` } +type arrayStruct struct { + Slice [4]string ` toml:"str_slice" ` + SlicePtr *[4]string ` toml:"str_slice_ptr" ` + IntSlice [4]int ` toml:"int_slice" ` + IntSlicePtr *[4]int ` toml:"int_slice_ptr" ` + StructSlice [4]basicMarshalTestSubStruct ` toml:"struct_slice" ` + StructSlicePtr *[4]basicMarshalTestSubStruct ` toml:"struct_slice_ptr" ` +} + +type arrayTooSmallStruct struct { + Slice [1]string ` toml:"str_slice" ` + StructSlice [1]basicMarshalTestSubStruct ` toml:"struct_slice" ` +} + func TestUnmarshalSlice(t *testing.T) { tree, _ := LoadBytes(sliceTomlDemo) tree, _ = TreeFromMap(tree.ToMap()) @@ -3168,6 +3182,75 @@ func TestUnmarshalSliceFail2(t *testing.T) { } +func TestUnmarshalArray(t *testing.T) { + var tree *Tree + var err error + + tree, _ = LoadBytes(sliceTomlDemo) + var actual1 arrayStruct + err = tree.Unmarshal(&actual1) + if err != nil { + t.Error("shound not err", err) + } + + tree, _ = TreeFromMap(tree.ToMap()) + var actual2 arrayStruct + err = tree.Unmarshal(&actual2) + if err != nil { + t.Error("shound not err", err) + } + + expected := arrayStruct{ + Slice: [4]string{"Howdy", "Hey There"}, + SlicePtr: &[4]string{"Howdy", "Hey There"}, + IntSlice: [4]int{1, 2}, + IntSlicePtr: &[4]int{1, 2}, + StructSlice: [4]basicMarshalTestSubStruct{{"1"}, {"2"}}, + StructSlicePtr: &[4]basicMarshalTestSubStruct{{"1"}, {"2"}}, + } + if !reflect.DeepEqual(actual1, expected) { + t.Errorf("Bad unmarshal: expected %v, got %v", expected, actual1) + } + if !reflect.DeepEqual(actual2, expected) { + t.Errorf("Bad unmarshal: expected %v, got %v", expected, actual2) + } +} + +func TestUnmarshalArrayFail(t *testing.T) { + tree, _ := TreeFromMap(map[string]interface{}{ + "str_slice": []string{"Howdy", "Hey There"}, + }) + + var actual arrayTooSmallStruct + err := tree.Unmarshal(&actual) + if err.Error() != "(0, 0): unmarshal: TOML array length (2) exceeds destination array length (1)" { + t.Error("expect err:(0, 0): unmarshal: TOML array length (2) exceeds destination array length (1) but got ", err) + } +} + +func TestUnmarshalArrayFail2(t *testing.T) { + tree, _ := Load(`str_slice=["Howdy","Hey There"]`) + + var actual arrayTooSmallStruct + err := tree.Unmarshal(&actual) + if err.Error() != "(1, 1): unmarshal: TOML array length (2) exceeds destination array length (1)" { + t.Error("expect err:(1, 1): unmarshal: TOML array length (2) exceeds destination array length (1) but got ", err) + } +} + +func TestUnmarshalArrayFail3(t *testing.T) { + tree, _ := Load(`[[struct_slice]] +String2="1" +[[struct_slice]] +String2="2"`) + + var actual arrayTooSmallStruct + err := tree.Unmarshal(&actual) + if err.Error() != "(3, 1): unmarshal: TOML array length (2) exceeds destination array length (1)" { + t.Error("expect err:(3, 1): unmarshal: TOML array length (2) exceeds destination array length (1) but got ", err) + } +} + func TestDecoderStrict(t *testing.T) { input := ` [decoded]