marshal: support encoding.TextMarshaler (#374)

With this PR the encoder now supports encoding.TextMarshaler.
Additionally, a bug is fixed, where the encoder does not notice a pointer
field that implements the toml.Marshaler interface.

fixes #373
This commit is contained in:
Oncilla
2020-04-28 13:29:00 +02:00
committed by GitHub
parent d3c92c5999
commit 2b8e33f503
2 changed files with 152 additions and 12 deletions
+27 -1
View File
@@ -2,6 +2,7 @@ package toml
import ( import (
"bytes" "bytes"
"encoding"
"errors" "errors"
"fmt" "fmt"
"io" "io"
@@ -69,6 +70,7 @@ const (
var timeType = reflect.TypeOf(time.Time{}) var timeType = reflect.TypeOf(time.Time{})
var marshalerType = reflect.TypeOf(new(Marshaler)).Elem() var marshalerType = reflect.TypeOf(new(Marshaler)).Elem()
var textMarshalerType = reflect.TypeOf(new(encoding.TextMarshaler)).Elem()
var localDateType = reflect.TypeOf(LocalDate{}) var localDateType = reflect.TypeOf(LocalDate{})
var localTimeType = reflect.TypeOf(LocalTime{}) var localTimeType = reflect.TypeOf(LocalTime{})
var localDateTimeType = reflect.TypeOf(LocalDateTime{}) var localDateTimeType = reflect.TypeOf(LocalDateTime{})
@@ -89,12 +91,16 @@ func isPrimitive(mtype reflect.Type) bool {
case reflect.String: case reflect.String:
return true return true
case reflect.Struct: case reflect.Struct:
return mtype == timeType || mtype == localDateType || mtype == localDateTimeType || mtype == localTimeType || isCustomMarshaler(mtype) return isTimeType(mtype) || isCustomMarshaler(mtype) || isTextMarshaler(mtype)
default: default:
return false return false
} }
} }
func isTimeType(mtype reflect.Type) bool {
return mtype == timeType || mtype == localDateType || mtype == localDateTimeType || mtype == localTimeType
}
// Check if the given marshal type maps to a Tree slice or array // Check if the given marshal type maps to a Tree slice or array
func isTreeSequence(mtype reflect.Type) bool { func isTreeSequence(mtype reflect.Type) bool {
switch mtype.Kind() { switch mtype.Kind() {
@@ -141,6 +147,14 @@ func callCustomMarshaler(mval reflect.Value) ([]byte, error) {
return mval.Interface().(Marshaler).MarshalTOML() return mval.Interface().(Marshaler).MarshalTOML()
} }
func isTextMarshaler(mtype reflect.Type) bool {
return mtype.Implements(textMarshalerType) && !isTimeType(mtype)
}
func callTextMarshaler(mval reflect.Value) ([]byte, error) {
return mval.Interface().(encoding.TextMarshaler).MarshalText()
}
// Marshaler is the interface implemented by types that // Marshaler is the interface implemented by types that
// can marshal themselves into valid TOML. // can marshal themselves into valid TOML.
type Marshaler interface { type Marshaler interface {
@@ -317,6 +331,9 @@ func (e *Encoder) marshal(v interface{}) ([]byte, error) {
if isCustomMarshaler(mtype) { if isCustomMarshaler(mtype) {
return callCustomMarshaler(sval) return callCustomMarshaler(sval)
} }
if isTextMarshaler(mtype) {
return callTextMarshaler(sval)
}
t, err := e.valueToTree(mtype, sval) t, err := e.valueToTree(mtype, sval)
if err != nil { if err != nil {
return []byte{}, err return []byte{}, err
@@ -441,14 +458,23 @@ func (e *Encoder) valueToOtherSlice(mtype reflect.Type, mval reflect.Value) (int
func (e *Encoder) valueToToml(mtype reflect.Type, mval reflect.Value) (interface{}, error) { func (e *Encoder) valueToToml(mtype reflect.Type, mval reflect.Value) (interface{}, error) {
e.line++ e.line++
if mtype.Kind() == reflect.Ptr { if mtype.Kind() == reflect.Ptr {
switch {
case isCustomMarshaler(mtype):
return callCustomMarshaler(mval)
case isTextMarshaler(mtype):
return callTextMarshaler(mval)
default:
return e.valueToToml(mtype.Elem(), mval.Elem()) return e.valueToToml(mtype.Elem(), mval.Elem())
} }
}
if mtype.Kind() == reflect.Interface { if mtype.Kind() == reflect.Interface {
return e.valueToToml(mval.Elem().Type(), mval.Elem()) return e.valueToToml(mval.Elem().Type(), mval.Elem())
} }
switch { switch {
case isCustomMarshaler(mtype): case isCustomMarshaler(mtype):
return callCustomMarshaler(mval) return callCustomMarshaler(mval)
case isTextMarshaler(mtype):
return callTextMarshaler(mval)
case isTree(mtype): case isTree(mtype):
return e.valueToTree(mtype, mval) return e.valueToTree(mtype, mval)
case isTreeSequence(mtype): case isTreeSequence(mtype):
+123 -9
View File
@@ -859,19 +859,19 @@ type customMarshalerParent struct {
} }
type customMarshaler struct { type customMarshaler struct {
FirsName string FirstName string
LastName string LastName string
} }
func (c customMarshaler) MarshalTOML() ([]byte, error) { func (c customMarshaler) MarshalTOML() ([]byte, error) {
fullName := fmt.Sprintf("%s %s", c.FirsName, c.LastName) fullName := fmt.Sprintf("%s %s", c.FirstName, c.LastName)
return []byte(fullName), nil return []byte(fullName), nil
} }
var customMarshalerData = customMarshaler{FirsName: "Sally", LastName: "Fields"} var customMarshalerData = customMarshaler{FirstName: "Sally", LastName: "Fields"}
var customMarshalerToml = []byte(`Sally Fields`) var customMarshalerToml = []byte(`Sally Fields`)
var nestedCustomMarshalerData = customMarshalerParent{ var nestedCustomMarshalerData = customMarshalerParent{
Self: customMarshaler{FirsName: "Maiku", LastName: "Suteda"}, Self: customMarshaler{FirstName: "Maiku", LastName: "Suteda"},
Friends: []customMarshaler{customMarshalerData}, Friends: []customMarshaler{customMarshalerData},
} }
var nestedCustomMarshalerToml = []byte(`friends = ["Sally Fields"] var nestedCustomMarshalerToml = []byte(`friends = ["Sally Fields"]
@@ -889,14 +889,128 @@ func TestCustomMarshaler(t *testing.T) {
} }
} }
func TestNestedCustomMarshaler(t *testing.T) { type textMarshaler struct {
result, err := Marshal(nestedCustomMarshalerData) FirstName string
LastName string
}
func (m textMarshaler) MarshalText() ([]byte, error) {
fullName := fmt.Sprintf("%s %s", m.FirstName, m.LastName)
return []byte(fullName), nil
}
func TestTextMarshaler(t *testing.T) {
m := textMarshaler{FirstName: "Sally", LastName: "Fields"}
result, err := Marshal(m)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
expected := nestedCustomMarshalerToml expected := `Sally Fields`
if !bytes.Equal(result, expected) { if !bytes.Equal(result, []byte(expected)) {
t.Errorf("Bad nested custom marshaler: expected\n-----\n%s\n-----\ngot\n-----\n%s\n-----\n", expected, result) t.Errorf("Bad text marshaler: expected\n-----\n%s\n-----\ngot\n-----\n%s\n-----\n", expected, result)
}
}
func TestNestedTextMarshaler(t *testing.T) {
var parent = struct {
Self textMarshaler `toml:"me"`
Friends []textMarshaler `toml:"friends"`
Stranger *textMarshaler `toml:"stranger"`
}{
Self: textMarshaler{FirstName: "Maiku", LastName: "Suteda"},
Friends: []textMarshaler{textMarshaler{FirstName: "Sally", LastName: "Fields"}},
Stranger: &textMarshaler{FirstName: "Earl", LastName: "Henson"},
}
result, err := Marshal(parent)
if err != nil {
t.Fatal(err)
}
expected := `friends = ["Sally Fields"]
me = "Maiku Suteda"
stranger = "Earl Henson"
`
if !bytes.Equal(result, []byte(expected)) {
t.Errorf("Bad nested text marshaler: expected\n-----\n%s\n-----\ngot\n-----\n%s\n-----\n", expected, result)
}
}
type precedentMarshaler struct {
FirstName string
LastName string
}
func (m precedentMarshaler) MarshalText() ([]byte, error) {
return []byte("shadowed"), nil
}
func (m precedentMarshaler) MarshalTOML() ([]byte, error) {
fullName := fmt.Sprintf("%s %s", m.FirstName, m.LastName)
return []byte(fullName), nil
}
func TestPrecedentMarshaler(t *testing.T) {
m := textMarshaler{FirstName: "Sally", LastName: "Fields"}
result, err := Marshal(m)
if err != nil {
t.Fatal(err)
}
expected := `Sally Fields`
if !bytes.Equal(result, []byte(expected)) {
t.Errorf("Bad text marshaler: expected\n-----\n%s\n-----\ngot\n-----\n%s\n-----\n", expected, result)
}
}
type customPointerMarshaler struct {
FirstName string
LastName string
}
func (m *customPointerMarshaler) MarshalText() ([]byte, error) {
return []byte("hidden"), nil
}
type textPointerMarshaler struct {
FirstName string
LastName string
}
func (m *textPointerMarshaler) MarshalText() ([]byte, error) {
return []byte("hidden"), nil
}
func TestPointerMarshaler(t *testing.T) {
var parent = struct {
Self customPointerMarshaler `toml:"me"`
Stranger *customPointerMarshaler `toml:"stranger"`
Friend textPointerMarshaler `toml:"friend"`
Fiend *textPointerMarshaler `toml:"fiend"`
}{
Self: customPointerMarshaler{FirstName: "Maiku", LastName: "Suteda"},
Stranger: &customPointerMarshaler{FirstName: "Earl", LastName: "Henson"},
Friend: textPointerMarshaler{FirstName: "Sally", LastName: "Fields"},
Fiend: &textPointerMarshaler{FirstName: "Casper", LastName: "Snider"},
}
result, err := Marshal(parent)
if err != nil {
t.Fatal(err)
}
expected := `fiend = "hidden"
stranger = "hidden"
[friend]
FirstName = "Sally"
LastName = "Fields"
[me]
FirstName = "Maiku"
LastName = "Suteda"
`
if !bytes.Equal(result, []byte(expected)) {
t.Errorf("Bad nested text marshaler: expected\n-----\n%s\n-----\ngot\n-----\n%s\n-----\n", expected, result)
} }
} }