From 9ccd9bbc7ab0c4282e338a9a71c85b4bd4cba2b1 Mon Sep 17 00:00:00 2001 From: x-hgg-x <39058530+x-hgg-x@users.noreply.github.com> Date: Mon, 4 May 2020 21:05:45 +0200 Subject: [PATCH] Fix unmarshaler error when a custom marshaler function is defined (#383) Fixes #382 --- marshal.go | 30 +++++++++++++++++++++++++++--- marshal_test.go | 49 ++++++++++++++++++++++++++++++++++++++++++++++++- 2 files changed, 75 insertions(+), 4 deletions(-) diff --git a/marshal.go b/marshal.go index a6e9733..278608a 100644 --- a/marshal.go +++ b/marshal.go @@ -93,7 +93,7 @@ func isPrimitive(mtype reflect.Type) bool { case reflect.String: return true case reflect.Struct: - return isTimeType(mtype) || isCustomMarshaler(mtype) || isTextMarshaler(mtype) + return isTimeType(mtype) default: return false } @@ -115,6 +115,30 @@ func isTreeSequence(mtype reflect.Type) bool { } } +// Check if the given marshal type maps to a slice or array of a custom marshaler type +func isCustomMarshalerSequence(mtype reflect.Type) bool { + switch mtype.Kind() { + case reflect.Ptr: + return isCustomMarshalerSequence(mtype.Elem()) + case reflect.Slice, reflect.Array: + return isCustomMarshaler(mtype.Elem()) || isCustomMarshaler(reflect.New(mtype.Elem()).Type()) + default: + return false + } +} + +// Check if the given marshal type maps to a slice or array of a text marshaler type +func isTextMarshalerSequence(mtype reflect.Type) bool { + switch mtype.Kind() { + case reflect.Ptr: + return isTextMarshalerSequence(mtype.Elem()) + case reflect.Slice, reflect.Array: + return isTextMarshaler(mtype.Elem()) || isTextMarshaler(reflect.New(mtype.Elem()).Type()) + default: + return false + } +} + // Check if the given marshal type maps to a non-Tree slice or array func isOtherSequence(mtype reflect.Type) bool { switch mtype.Kind() { @@ -516,10 +540,10 @@ func (e *Encoder) valueToToml(mtype reflect.Type, mval reflect.Value) (interface return callTextMarshaler(mval) case isTree(mtype): return e.valueToTree(mtype, mval) + case isOtherSequence(mtype), isCustomMarshalerSequence(mtype), isTextMarshalerSequence(mtype): + return e.valueToOtherSlice(mtype, mval) case isTreeSequence(mtype): return e.valueToTreeSlice(mtype, mval) - case isOtherSequence(mtype): - return e.valueToOtherSlice(mtype, mval) default: switch mtype.Kind() { case reflect.Bool: diff --git a/marshal_test.go b/marshal_test.go index 107413e..64dfa76 100644 --- a/marshal_test.go +++ b/marshal_test.go @@ -911,6 +911,9 @@ var nestedCustomMarshalerData = customMarshalerParent{ var nestedCustomMarshalerToml = []byte(`friends = ["Sally Fields"] me = "Maiku Suteda" `) +var nestedCustomMarshalerTomlForUnmarshal = []byte(`[friends] +FirstName = "Sally" +LastName = "Fields"`) func TestCustomMarshaler(t *testing.T) { result, err := Marshal(customMarshalerData) @@ -946,6 +949,26 @@ func TestTextMarshaler(t *testing.T) { } } +func TestUnmarshalTextMarshaler(t *testing.T) { + var nested = struct { + Friends textMarshaler `toml:"friends"` + }{} + + var expected = struct { + Friends textMarshaler `toml:"friends"` + }{ + Friends: textMarshaler{FirstName: "Sally", LastName: "Fields"}, + } + + err := Unmarshal(nestedCustomMarshalerTomlForUnmarshal, &nested) + if err != nil { + t.Fatal(err) + } + if !reflect.DeepEqual(nested, expected) { + t.Errorf("Bad unmarshal: expected %v, got %v", expected, nested) + } +} + func TestNestedTextMarshaler(t *testing.T) { var parent = struct { Self textMarshaler `toml:"me"` @@ -1002,7 +1025,7 @@ type customPointerMarshaler struct { LastName string } -func (m *customPointerMarshaler) MarshalText() ([]byte, error) { +func (m *customPointerMarshaler) MarshalTOML() ([]byte, error) { return []byte("hidden"), nil } @@ -1048,6 +1071,30 @@ stranger = "hidden" } } +func TestPointerCustomMarshalerSequence(t *testing.T) { + var customPointerMarshalerSlice *[]*customPointerMarshaler + var customPointerMarshalerArray *[2]*customPointerMarshaler + + if !isCustomMarshalerSequence(reflect.TypeOf(customPointerMarshalerSlice)) { + t.Errorf("error: should be a sequence of custom marshaler interfaces") + } + if !isCustomMarshalerSequence(reflect.TypeOf(customPointerMarshalerArray)) { + t.Errorf("error: should be a sequence of custom marshaler interfaces") + } +} + +func TestPointerTextMarshalerSequence(t *testing.T) { + var textPointerMarshalerSlice *[]*textPointerMarshaler + var textPointerMarshalerArray *[2]*textPointerMarshaler + + if !isTextMarshalerSequence(reflect.TypeOf(textPointerMarshalerSlice)) { + t.Errorf("error: should be a sequence of text marshaler interfaces") + } + if !isTextMarshalerSequence(reflect.TypeOf(textPointerMarshalerArray)) { + t.Errorf("error: should be a sequence of text marshaler interfaces") + } +} + var commentTestToml = []byte(` # it's a comment on type [postgres]