Support custom unmarshaler (#394)
Co-authored-by: Thomas Pelletier <pelletier.thomas@gmail.com>
This commit is contained in:
+36
@@ -70,6 +70,7 @@ const (
|
||||
|
||||
var timeType = reflect.TypeOf(time.Time{})
|
||||
var marshalerType = reflect.TypeOf(new(Marshaler)).Elem()
|
||||
var unmarshalerType = reflect.TypeOf(new(Unmarshaler)).Elem()
|
||||
var textMarshalerType = reflect.TypeOf(new(encoding.TextMarshaler)).Elem()
|
||||
var textUnmarshalerType = reflect.TypeOf(new(encoding.TextUnmarshaler)).Elem()
|
||||
var localDateType = reflect.TypeOf(LocalDate{})
|
||||
@@ -156,6 +157,14 @@ func callTextMarshaler(mval reflect.Value) ([]byte, error) {
|
||||
return mval.Interface().(encoding.TextMarshaler).MarshalText()
|
||||
}
|
||||
|
||||
func isCustomUnmarshaler(mtype reflect.Type) bool {
|
||||
return mtype.Implements(unmarshalerType)
|
||||
}
|
||||
|
||||
func callCustomUnmarshaler(mval reflect.Value, tval interface{}) error {
|
||||
return mval.Interface().(Unmarshaler).UnmarshalTOML(tval)
|
||||
}
|
||||
|
||||
func isTextUnmarshaler(mtype reflect.Type) bool {
|
||||
return mtype.Implements(textUnmarshalerType)
|
||||
}
|
||||
@@ -170,6 +179,12 @@ type Marshaler interface {
|
||||
MarshalTOML() ([]byte, error)
|
||||
}
|
||||
|
||||
// Unmarshaler is the interface implemented by types that
|
||||
// can unmarshal a TOML description of themselves.
|
||||
type Unmarshaler interface {
|
||||
UnmarshalTOML(interface{}) error
|
||||
}
|
||||
|
||||
/*
|
||||
Marshal returns the TOML encoding of v. Behavior is similar to the Go json
|
||||
encoder, except that there is no concept of a Marshaler interface or MarshalTOML
|
||||
@@ -676,6 +691,17 @@ func (d *Decoder) valueFromTree(mtype reflect.Type, tval *Tree, mval1 *reflect.V
|
||||
if mtype.Kind() == reflect.Ptr {
|
||||
return d.unwrapPointer(mtype, tval, mval1)
|
||||
}
|
||||
|
||||
// Check if pointer to value implements the Unmarshaler interface.
|
||||
if mvalPtr := reflect.New(mtype); isCustomUnmarshaler(mvalPtr.Type()) {
|
||||
d.visitor.visitAll()
|
||||
|
||||
if err := callCustomUnmarshaler(mvalPtr, tval.ToMap()); err != nil {
|
||||
return reflect.ValueOf(nil), fmt.Errorf("unmarshal toml: %v", err)
|
||||
}
|
||||
return mvalPtr.Elem(), nil
|
||||
}
|
||||
|
||||
var mval reflect.Value
|
||||
switch mtype.Kind() {
|
||||
case reflect.Struct:
|
||||
@@ -1151,6 +1177,16 @@ func (s *visitorState) visit() {
|
||||
}
|
||||
}
|
||||
|
||||
func (s *visitorState) visitAll() {
|
||||
if s.active {
|
||||
for k := range s.keys {
|
||||
if strings.HasPrefix(k, strings.Join(s.path, ".")) {
|
||||
delete(s.keys, k)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *visitorState) validate() error {
|
||||
if !s.active {
|
||||
return nil
|
||||
|
||||
@@ -3449,6 +3449,81 @@ func TestDecoderStrictValid(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
type docUnmarshalTOML struct {
|
||||
Decoded struct {
|
||||
Key string
|
||||
}
|
||||
}
|
||||
|
||||
func (d *docUnmarshalTOML) UnmarshalTOML(i interface{}) error {
|
||||
if iMap, ok := i.(map[string]interface{}); !ok {
|
||||
return fmt.Errorf("type assertion error: wants %T, have %T", map[string]interface{}{}, i)
|
||||
} else if key, ok := iMap["key"]; !ok {
|
||||
return fmt.Errorf("key '%s' not in map", "key")
|
||||
} else if keyString, ok := key.(string); !ok {
|
||||
return fmt.Errorf("type assertion error: wants %T, have %T", "", key)
|
||||
} else {
|
||||
d.Decoded.Key = keyString
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func TestDecoderStrictCustomUnmarshal(t *testing.T) {
|
||||
input := `key = "ok"`
|
||||
var doc docUnmarshalTOML
|
||||
err := NewDecoder(bytes.NewReader([]byte(input))).Strict(true).Decode(&doc)
|
||||
if err != nil {
|
||||
t.Fatal("unexpected error:", err)
|
||||
}
|
||||
if doc.Decoded.Key != "ok" {
|
||||
t.Errorf("Bad unmarshal: expected ok, got %v", doc.Decoded.Key)
|
||||
}
|
||||
}
|
||||
|
||||
type parent struct {
|
||||
Doc docUnmarshalTOML
|
||||
DocPointer *docUnmarshalTOML
|
||||
}
|
||||
|
||||
func TestCustomUnmarshal(t *testing.T) {
|
||||
input := `
|
||||
[Doc]
|
||||
key = "ok1"
|
||||
[DocPointer]
|
||||
key = "ok2"
|
||||
`
|
||||
|
||||
var d parent
|
||||
if err := Unmarshal([]byte(input), &d); err != nil {
|
||||
t.Fatalf("unexpected err: %s", err.Error())
|
||||
}
|
||||
if d.Doc.Decoded.Key != "ok1" {
|
||||
t.Errorf("Bad unmarshal: expected ok, got %v", d.Doc.Decoded.Key)
|
||||
}
|
||||
if d.DocPointer.Decoded.Key != "ok2" {
|
||||
t.Errorf("Bad unmarshal: expected ok, got %v", d.DocPointer.Decoded.Key)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCustomUnmarshalError(t *testing.T) {
|
||||
input := `
|
||||
[Doc]
|
||||
key = 1
|
||||
[DocPointer]
|
||||
key = "ok2"
|
||||
`
|
||||
|
||||
expected := "(2, 1): unmarshal toml: type assertion error: wants string, have int64"
|
||||
|
||||
var d parent
|
||||
err := Unmarshal([]byte(input), &d)
|
||||
if err == nil {
|
||||
t.Error("expected error, got none")
|
||||
} else if err.Error() != expected {
|
||||
t.Errorf("expect err: %s, got: %s", expected, err.Error())
|
||||
}
|
||||
}
|
||||
|
||||
type intWrapper struct {
|
||||
Value int
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user