encoder: support TextMarshaler (#522)

Fixes #521
This commit is contained in:
Thomas Pelletier
2021-04-22 10:13:41 -04:00
committed by GitHub
parent 2b1c52dddd
commit e443b4fdb8
3 changed files with 52 additions and 16 deletions
@@ -4,6 +4,7 @@ package imported_tests
// defaults of v2. // defaults of v2.
import ( import (
"fmt"
"testing" "testing"
"time" "time"
@@ -164,3 +165,34 @@ stringlist = []
require.Equal(t, string(expected), string(result)) require.Equal(t, string(expected), string(result))
} }
type textMarshaler struct {
FirstName string
LastName string
}
func (m textMarshaler) MarshalText() ([]byte, error) {
fullName := fmt.Sprintf("%s %s", m.FirstName, m.LastName)
return []byte(fullName), nil
}
func TestTextMarshaler(t *testing.T) {
type wrap struct {
TM textMarshaler
}
m := textMarshaler{FirstName: "Sally", LastName: "Fields"}
t.Run("at root", func(t *testing.T) {
_, err := toml.Marshal(m)
// in v2 we do not allow TextMarshaler at root
require.Error(t, err)
})
t.Run("leaf", func(t *testing.T) {
res, err := toml.Marshal(wrap{m})
require.NoError(t, err)
require.Equal(t, "TM = 'Sally Fields'\n", string(res))
})
}
@@ -612,16 +612,6 @@ func (x *IntOrString) MarshalTOML() ([]byte, error) {
return []byte(s), nil return []byte(s), nil
} }
type textMarshaler struct {
FirstName string
LastName string
}
func (m textMarshaler) MarshalText() ([]byte, error) {
fullName := fmt.Sprintf("%s %s", m.FirstName, m.LastName)
return []byte(fullName), nil
}
func TestUnmarshalTextMarshaler(t *testing.T) { func TestUnmarshalTextMarshaler(t *testing.T) {
var nested = struct { var nested = struct {
Friends textMarshaler `toml:"friends"` Friends textMarshaler `toml:"friends"`
+20 -6
View File
@@ -2,6 +2,7 @@ package toml
import ( import (
"bytes" "bytes"
"encoding"
"errors" "errors"
"fmt" "fmt"
"io" "io"
@@ -165,14 +166,27 @@ func (ctx *encoderCtx) isRoot() bool {
} }
var errUnsupportedValue = errors.New("unsupported encode value kind") var errUnsupportedValue = errors.New("unsupported encode value kind")
var errTextMarshalerCannotBeAtRoot = errors.New("type implementing TextMarshaler cannot be at root")
//nolint:cyclop //nolint:cyclop
func (enc *Encoder) encode(b []byte, ctx encoderCtx, v reflect.Value) ([]byte, error) { func (enc *Encoder) encode(b []byte, ctx encoderCtx, v reflect.Value) ([]byte, error) {
//nolint:gocritic,godox //nolint:gocritic,godox
switch i := v.Interface().(type) {
case time.Time: // TODO: add TextMarshaler
b = i.AppendFormat(b, time.RFC3339)
if v.Type() == timeType {
i := v.Interface().(time.Time)
b = i.AppendFormat(b, time.RFC3339)
return b, nil
}
if v.Type().Implements(textMarshalerType) {
if ctx.isRoot() {
return nil, errTextMarshalerCannotBeAtRoot
}
text, err := v.Interface().(encoding.TextMarshaler).MarshalText()
if err != nil {
return nil, err
}
b = enc.encodeString(b, string(text), ctx.options)
return b, nil return b, nil
} }
@@ -620,10 +634,10 @@ func (enc *Encoder) encodeTableInline(b []byte, ctx encoderCtx, t table) ([]byte
var errNilInterface = errors.New("nil interface not supported") var errNilInterface = errors.New("nil interface not supported")
var textMarshalerType = reflect.TypeOf(new(encoding.TextMarshaler)).Elem()
func willConvertToTable(ctx encoderCtx, v reflect.Value) (bool, error) { func willConvertToTable(ctx encoderCtx, v reflect.Value) (bool, error) {
//nolint:gocritic,godox if v.Type() == timeType || v.Type().Implements(textMarshalerType) {
switch v.Interface().(type) {
case time.Time: // TODO: add TextMarshaler
return false, nil return false, nil
} }