Provide Marshaler interface (#151)

The toml.Marhshaler interface allows marshalling custom objects implementing
the interface. Design based off json.Marshaler.
This commit is contained in:
Carolyn Van Slyck
2017-04-04 20:41:05 -05:00
committed by Thomas Pelletier
parent e32a2e0474
commit fe206efb84
3 changed files with 72 additions and 1 deletions
+21 -1
View File
@@ -33,6 +33,7 @@ type tomlOpts struct {
} }
var timeType = reflect.TypeOf(time.Time{}) var timeType = reflect.TypeOf(time.Time{})
var marshalerType = reflect.TypeOf(new(Marshaler)).Elem()
// Check if the given marshall type maps to a TomlTree primitive // Check if the given marshall type maps to a TomlTree primitive
func isPrimitive(mtype reflect.Type) bool { func isPrimitive(mtype reflect.Type) bool {
@@ -50,7 +51,7 @@ func isPrimitive(mtype reflect.Type) bool {
case reflect.String: case reflect.String:
return true return true
case reflect.Struct: case reflect.Struct:
return mtype == timeType return mtype == timeType || isCustomMarshaler(mtype)
default: default:
return false return false
} }
@@ -90,6 +91,20 @@ func isTree(mtype reflect.Type) bool {
} }
} }
func isCustomMarshaler(mtype reflect.Type) bool {
return mtype.Implements(marshalerType)
}
func callCustomMarshaler(mval reflect.Value) ([]byte, error) {
return mval.Interface().(Marshaler).MarshalTOML()
}
// Marshaler is the interface implemented by types that
// can marshal themselves into valid TOML.
type Marshaler interface {
MarshalTOML() ([]byte, error)
}
/* /*
Marshal returns the TOML encoding of v. Behavior is similar to the Go json Marshal returns the TOML encoding of v. Behavior is similar to the Go json
encoder, except that there is no concept of a Marshaler interface or MarshalTOML encoder, except that there is no concept of a Marshaler interface or MarshalTOML
@@ -106,6 +121,9 @@ func Marshal(v interface{}) ([]byte, error) {
return []byte{}, errors.New("Only a struct can be marshaled to TOML") return []byte{}, errors.New("Only a struct can be marshaled to TOML")
} }
sval := reflect.ValueOf(v) sval := reflect.ValueOf(v)
if isCustomMarshaler(mtype) {
return callCustomMarshaler(sval)
}
t, err := valueToTree(mtype, sval) t, err := valueToTree(mtype, sval)
if err != nil { if err != nil {
return []byte{}, err return []byte{}, err
@@ -178,6 +196,8 @@ func valueToToml(mtype reflect.Type, mval reflect.Value) (interface{}, error) {
return valueToToml(mtype.Elem(), mval.Elem()) return valueToToml(mtype.Elem(), mval.Elem())
} }
switch { switch {
case isCustomMarshaler(mtype):
return callCustomMarshaler(mval)
case isTree(mtype): case isTree(mtype):
return valueToTree(mtype, mval) return valueToTree(mtype, mval)
case isTreeSlice(mtype): case isTreeSlice(mtype):
+48
View File
@@ -3,6 +3,7 @@ package toml
import ( import (
"bytes" "bytes"
"encoding/json" "encoding/json"
"fmt"
"io/ioutil" "io/ioutil"
"reflect" "reflect"
"testing" "testing"
@@ -533,3 +534,50 @@ func TestNestedUnmarshal(t *testing.T) {
t.Errorf("Bad nested unmarshal: expected %v, got %v", expected, result) t.Errorf("Bad nested unmarshal: expected %v, got %v", expected, result)
} }
} }
type customMarshalerParent struct {
Self customMarshaler `toml:"me"`
Friends []customMarshaler `toml:"friends"`
}
type customMarshaler struct {
FirsName string
LastName string
}
func (c customMarshaler) MarshalTOML() ([]byte, error) {
fullName := fmt.Sprintf("%s %s", c.FirsName, c.LastName)
return []byte(fullName), nil
}
var customMarshalerData = customMarshaler{FirsName: "Sally", LastName: "Fields"}
var customMarshalerToml = []byte(`Sally Fields`)
var nestedCustomMarshalerData = customMarshalerParent{
Self: customMarshaler{FirsName: "Maiku", LastName: "Suteda"},
Friends: []customMarshaler{customMarshalerData},
}
var nestedCustomMarshalerToml = []byte(`friends = ["Sally Fields"]
me = "Maiku Suteda"
`)
func TestCustomMarshaler(t *testing.T) {
result, err := Marshal(customMarshalerData)
if err != nil {
t.Fatal(err)
}
expected := customMarshalerToml
if !bytes.Equal(result, expected) {
t.Errorf("Bad custom marshaler: expected\n-----\n%s\n-----\ngot\n-----\n%s\n-----\n", expected, result)
}
}
func TestNestedCustomMarshaler(t *testing.T) {
result, err := Marshal(nestedCustomMarshalerData)
if err != nil {
t.Fatal(err)
}
expected := nestedCustomMarshalerToml
if !bytes.Equal(result, expected) {
t.Errorf("Bad nested custom marshaler: expected\n-----\n%s\n-----\ngot\n-----\n%s\n-----\n", expected, result)
}
}
+3
View File
@@ -52,6 +52,9 @@ func tomlValueStringRepresentation(v interface{}) (string, error) {
return strconv.FormatFloat(value, 'f', -1, 32), nil return strconv.FormatFloat(value, 'f', -1, 32), nil
case string: case string:
return "\"" + encodeTomlString(value) + "\"", nil return "\"" + encodeTomlString(value) + "\"", nil
case []byte:
b, _ := v.([]byte)
return tomlValueStringRepresentation(string(b))
case bool: case bool:
if value { if value {
return "true", nil return "true", nil