From 2b8e33f5039358ee5baf28ccd55ba059f4989d7b Mon Sep 17 00:00:00 2001 From: Oncilla Date: Tue, 28 Apr 2020 13:29:00 +0200 Subject: [PATCH] marshal: support encoding.TextMarshaler (#374) With this PR the encoder now supports encoding.TextMarshaler. Additionally, a bug is fixed, where the encoder does not notice a pointer field that implements the toml.Marshaler interface. fixes #373 --- marshal.go | 30 ++++++++++- marshal_test.go | 134 ++++++++++++++++++++++++++++++++++++++++++++---- 2 files changed, 152 insertions(+), 12 deletions(-) diff --git a/marshal.go b/marshal.go index 1045d3b..9c5957b 100644 --- a/marshal.go +++ b/marshal.go @@ -2,6 +2,7 @@ package toml import ( "bytes" + "encoding" "errors" "fmt" "io" @@ -69,6 +70,7 @@ const ( var timeType = reflect.TypeOf(time.Time{}) var marshalerType = reflect.TypeOf(new(Marshaler)).Elem() +var textMarshalerType = reflect.TypeOf(new(encoding.TextMarshaler)).Elem() var localDateType = reflect.TypeOf(LocalDate{}) var localTimeType = reflect.TypeOf(LocalTime{}) var localDateTimeType = reflect.TypeOf(LocalDateTime{}) @@ -89,12 +91,16 @@ func isPrimitive(mtype reflect.Type) bool { case reflect.String: return true case reflect.Struct: - return mtype == timeType || mtype == localDateType || mtype == localDateTimeType || mtype == localTimeType || isCustomMarshaler(mtype) + return isTimeType(mtype) || isCustomMarshaler(mtype) || isTextMarshaler(mtype) default: return false } } +func isTimeType(mtype reflect.Type) bool { + return mtype == timeType || mtype == localDateType || mtype == localDateTimeType || mtype == localTimeType +} + // Check if the given marshal type maps to a Tree slice or array func isTreeSequence(mtype reflect.Type) bool { switch mtype.Kind() { @@ -141,6 +147,14 @@ func callCustomMarshaler(mval reflect.Value) ([]byte, error) { return mval.Interface().(Marshaler).MarshalTOML() } +func isTextMarshaler(mtype reflect.Type) bool { + return mtype.Implements(textMarshalerType) && !isTimeType(mtype) +} + +func callTextMarshaler(mval reflect.Value) ([]byte, error) { + return mval.Interface().(encoding.TextMarshaler).MarshalText() +} + // Marshaler is the interface implemented by types that // can marshal themselves into valid TOML. type Marshaler interface { @@ -317,6 +331,9 @@ func (e *Encoder) marshal(v interface{}) ([]byte, error) { if isCustomMarshaler(mtype) { return callCustomMarshaler(sval) } + if isTextMarshaler(mtype) { + return callTextMarshaler(sval) + } t, err := e.valueToTree(mtype, sval) if err != nil { return []byte{}, err @@ -441,7 +458,14 @@ func (e *Encoder) valueToOtherSlice(mtype reflect.Type, mval reflect.Value) (int func (e *Encoder) valueToToml(mtype reflect.Type, mval reflect.Value) (interface{}, error) { e.line++ if mtype.Kind() == reflect.Ptr { - return e.valueToToml(mtype.Elem(), mval.Elem()) + switch { + case isCustomMarshaler(mtype): + return callCustomMarshaler(mval) + case isTextMarshaler(mtype): + return callTextMarshaler(mval) + default: + return e.valueToToml(mtype.Elem(), mval.Elem()) + } } if mtype.Kind() == reflect.Interface { return e.valueToToml(mval.Elem().Type(), mval.Elem()) @@ -449,6 +473,8 @@ func (e *Encoder) valueToToml(mtype reflect.Type, mval reflect.Value) (interface switch { case isCustomMarshaler(mtype): return callCustomMarshaler(mval) + case isTextMarshaler(mtype): + return callTextMarshaler(mval) case isTree(mtype): return e.valueToTree(mtype, mval) case isTreeSequence(mtype): diff --git a/marshal_test.go b/marshal_test.go index a8a5695..59cafcb 100644 --- a/marshal_test.go +++ b/marshal_test.go @@ -859,19 +859,19 @@ type customMarshalerParent struct { } type customMarshaler struct { - FirsName string - LastName string + FirstName string + LastName string } func (c customMarshaler) MarshalTOML() ([]byte, error) { - fullName := fmt.Sprintf("%s %s", c.FirsName, c.LastName) + fullName := fmt.Sprintf("%s %s", c.FirstName, c.LastName) return []byte(fullName), nil } -var customMarshalerData = customMarshaler{FirsName: "Sally", LastName: "Fields"} +var customMarshalerData = customMarshaler{FirstName: "Sally", LastName: "Fields"} var customMarshalerToml = []byte(`Sally Fields`) var nestedCustomMarshalerData = customMarshalerParent{ - Self: customMarshaler{FirsName: "Maiku", LastName: "Suteda"}, + Self: customMarshaler{FirstName: "Maiku", LastName: "Suteda"}, Friends: []customMarshaler{customMarshalerData}, } var nestedCustomMarshalerToml = []byte(`friends = ["Sally Fields"] @@ -889,14 +889,128 @@ func TestCustomMarshaler(t *testing.T) { } } -func TestNestedCustomMarshaler(t *testing.T) { - result, err := Marshal(nestedCustomMarshalerData) +type textMarshaler struct { + FirstName string + LastName string +} + +func (m textMarshaler) MarshalText() ([]byte, error) { + fullName := fmt.Sprintf("%s %s", m.FirstName, m.LastName) + return []byte(fullName), nil +} + +func TestTextMarshaler(t *testing.T) { + m := textMarshaler{FirstName: "Sally", LastName: "Fields"} + + result, err := Marshal(m) if err != nil { t.Fatal(err) } - expected := nestedCustomMarshalerToml - if !bytes.Equal(result, expected) { - t.Errorf("Bad nested custom marshaler: expected\n-----\n%s\n-----\ngot\n-----\n%s\n-----\n", expected, result) + expected := `Sally Fields` + if !bytes.Equal(result, []byte(expected)) { + t.Errorf("Bad text marshaler: expected\n-----\n%s\n-----\ngot\n-----\n%s\n-----\n", expected, result) + } +} + +func TestNestedTextMarshaler(t *testing.T) { + var parent = struct { + Self textMarshaler `toml:"me"` + Friends []textMarshaler `toml:"friends"` + Stranger *textMarshaler `toml:"stranger"` + }{ + Self: textMarshaler{FirstName: "Maiku", LastName: "Suteda"}, + Friends: []textMarshaler{textMarshaler{FirstName: "Sally", LastName: "Fields"}}, + Stranger: &textMarshaler{FirstName: "Earl", LastName: "Henson"}, + } + + result, err := Marshal(parent) + if err != nil { + t.Fatal(err) + } + expected := `friends = ["Sally Fields"] +me = "Maiku Suteda" +stranger = "Earl Henson" +` + if !bytes.Equal(result, []byte(expected)) { + t.Errorf("Bad nested text marshaler: expected\n-----\n%s\n-----\ngot\n-----\n%s\n-----\n", expected, result) + } +} + +type precedentMarshaler struct { + FirstName string + LastName string +} + +func (m precedentMarshaler) MarshalText() ([]byte, error) { + return []byte("shadowed"), nil +} + +func (m precedentMarshaler) MarshalTOML() ([]byte, error) { + fullName := fmt.Sprintf("%s %s", m.FirstName, m.LastName) + return []byte(fullName), nil +} + +func TestPrecedentMarshaler(t *testing.T) { + m := textMarshaler{FirstName: "Sally", LastName: "Fields"} + + result, err := Marshal(m) + if err != nil { + t.Fatal(err) + } + expected := `Sally Fields` + if !bytes.Equal(result, []byte(expected)) { + t.Errorf("Bad text marshaler: expected\n-----\n%s\n-----\ngot\n-----\n%s\n-----\n", expected, result) + } +} + +type customPointerMarshaler struct { + FirstName string + LastName string +} + +func (m *customPointerMarshaler) MarshalText() ([]byte, error) { + return []byte("hidden"), nil +} + +type textPointerMarshaler struct { + FirstName string + LastName string +} + +func (m *textPointerMarshaler) MarshalText() ([]byte, error) { + return []byte("hidden"), nil +} + +func TestPointerMarshaler(t *testing.T) { + var parent = struct { + Self customPointerMarshaler `toml:"me"` + Stranger *customPointerMarshaler `toml:"stranger"` + Friend textPointerMarshaler `toml:"friend"` + Fiend *textPointerMarshaler `toml:"fiend"` + }{ + Self: customPointerMarshaler{FirstName: "Maiku", LastName: "Suteda"}, + Stranger: &customPointerMarshaler{FirstName: "Earl", LastName: "Henson"}, + Friend: textPointerMarshaler{FirstName: "Sally", LastName: "Fields"}, + Fiend: &textPointerMarshaler{FirstName: "Casper", LastName: "Snider"}, + } + + result, err := Marshal(parent) + if err != nil { + t.Fatal(err) + } + expected := `fiend = "hidden" +stranger = "hidden" + +[friend] + FirstName = "Sally" + LastName = "Fields" + +[me] + FirstName = "Maiku" + LastName = "Suteda" +` + if !bytes.Equal(result, []byte(expected)) { + t.Errorf("Bad nested text marshaler: expected\n-----\n%s\n-----\ngot\n-----\n%s\n-----\n", expected, result) } }