diff --git a/internal/imported_tests/marshal_imported_test.go b/internal/imported_tests/marshal_imported_test.go index 98ad71f..578cf57 100644 --- a/internal/imported_tests/marshal_imported_test.go +++ b/internal/imported_tests/marshal_imported_test.go @@ -4,6 +4,7 @@ package imported_tests // defaults of v2. import ( + "fmt" "testing" "time" @@ -164,3 +165,34 @@ stringlist = [] require.Equal(t, string(expected), string(result)) } + +type textMarshaler struct { + FirstName string + LastName string +} + +func (m textMarshaler) MarshalText() ([]byte, error) { + fullName := fmt.Sprintf("%s %s", m.FirstName, m.LastName) + return []byte(fullName), nil +} + +func TestTextMarshaler(t *testing.T) { + type wrap struct { + TM textMarshaler + } + + m := textMarshaler{FirstName: "Sally", LastName: "Fields"} + + t.Run("at root", func(t *testing.T) { + _, err := toml.Marshal(m) + // in v2 we do not allow TextMarshaler at root + require.Error(t, err) + }) + + t.Run("leaf", func(t *testing.T) { + res, err := toml.Marshal(wrap{m}) + require.NoError(t, err) + + require.Equal(t, "TM = 'Sally Fields'\n", string(res)) + }) +} diff --git a/internal/imported_tests/unmarshal_imported_test.go b/internal/imported_tests/unmarshal_imported_test.go index 8dbfec4..d3f54d8 100644 --- a/internal/imported_tests/unmarshal_imported_test.go +++ b/internal/imported_tests/unmarshal_imported_test.go @@ -612,16 +612,6 @@ func (x *IntOrString) MarshalTOML() ([]byte, error) { return []byte(s), nil } -type textMarshaler struct { - FirstName string - LastName string -} - -func (m textMarshaler) MarshalText() ([]byte, error) { - fullName := fmt.Sprintf("%s %s", m.FirstName, m.LastName) - return []byte(fullName), nil -} - func TestUnmarshalTextMarshaler(t *testing.T) { var nested = struct { Friends textMarshaler `toml:"friends"` diff --git a/marshaler.go b/marshaler.go index 2491015..06f11a1 100644 --- a/marshaler.go +++ b/marshaler.go @@ -2,6 +2,7 @@ package toml import ( "bytes" + "encoding" "errors" "fmt" "io" @@ -165,14 +166,27 @@ func (ctx *encoderCtx) isRoot() bool { } var errUnsupportedValue = errors.New("unsupported encode value kind") +var errTextMarshalerCannotBeAtRoot = errors.New("type implementing TextMarshaler cannot be at root") //nolint:cyclop func (enc *Encoder) encode(b []byte, ctx encoderCtx, v reflect.Value) ([]byte, error) { //nolint:gocritic,godox - switch i := v.Interface().(type) { - case time.Time: // TODO: add TextMarshaler - b = i.AppendFormat(b, time.RFC3339) + if v.Type() == timeType { + i := v.Interface().(time.Time) + b = i.AppendFormat(b, time.RFC3339) + return b, nil + } + if v.Type().Implements(textMarshalerType) { + if ctx.isRoot() { + return nil, errTextMarshalerCannotBeAtRoot + } + + text, err := v.Interface().(encoding.TextMarshaler).MarshalText() + if err != nil { + return nil, err + } + b = enc.encodeString(b, string(text), ctx.options) return b, nil } @@ -620,10 +634,10 @@ func (enc *Encoder) encodeTableInline(b []byte, ctx encoderCtx, t table) ([]byte var errNilInterface = errors.New("nil interface not supported") +var textMarshalerType = reflect.TypeOf(new(encoding.TextMarshaler)).Elem() + func willConvertToTable(ctx encoderCtx, v reflect.Value) (bool, error) { - //nolint:gocritic,godox - switch v.Interface().(type) { - case time.Time: // TODO: add TextMarshaler + if v.Type() == timeType || v.Type().Implements(textMarshalerType) { return false, nil }