Support custom unmarshaler (#394)

Co-authored-by: Thomas Pelletier <pelletier.thomas@gmail.com>
This commit is contained in:
x-hgg-x
2020-05-04 19:33:55 +02:00
committed by GitHub
parent 71a8bd4c61
commit e7d1a179ae
2 changed files with 111 additions and 0 deletions
+36
View File
@@ -70,6 +70,7 @@ const (
var timeType = reflect.TypeOf(time.Time{})
var marshalerType = reflect.TypeOf(new(Marshaler)).Elem()
var unmarshalerType = reflect.TypeOf(new(Unmarshaler)).Elem()
var textMarshalerType = reflect.TypeOf(new(encoding.TextMarshaler)).Elem()
var textUnmarshalerType = reflect.TypeOf(new(encoding.TextUnmarshaler)).Elem()
var localDateType = reflect.TypeOf(LocalDate{})
@@ -156,6 +157,14 @@ func callTextMarshaler(mval reflect.Value) ([]byte, error) {
return mval.Interface().(encoding.TextMarshaler).MarshalText()
}
func isCustomUnmarshaler(mtype reflect.Type) bool {
return mtype.Implements(unmarshalerType)
}
func callCustomUnmarshaler(mval reflect.Value, tval interface{}) error {
return mval.Interface().(Unmarshaler).UnmarshalTOML(tval)
}
func isTextUnmarshaler(mtype reflect.Type) bool {
return mtype.Implements(textUnmarshalerType)
}
@@ -170,6 +179,12 @@ type Marshaler interface {
MarshalTOML() ([]byte, error)
}
// Unmarshaler is the interface implemented by types that
// can unmarshal a TOML description of themselves.
type Unmarshaler interface {
UnmarshalTOML(interface{}) error
}
/*
Marshal returns the TOML encoding of v. Behavior is similar to the Go json
encoder, except that there is no concept of a Marshaler interface or MarshalTOML
@@ -676,6 +691,17 @@ func (d *Decoder) valueFromTree(mtype reflect.Type, tval *Tree, mval1 *reflect.V
if mtype.Kind() == reflect.Ptr {
return d.unwrapPointer(mtype, tval, mval1)
}
// Check if pointer to value implements the Unmarshaler interface.
if mvalPtr := reflect.New(mtype); isCustomUnmarshaler(mvalPtr.Type()) {
d.visitor.visitAll()
if err := callCustomUnmarshaler(mvalPtr, tval.ToMap()); err != nil {
return reflect.ValueOf(nil), fmt.Errorf("unmarshal toml: %v", err)
}
return mvalPtr.Elem(), nil
}
var mval reflect.Value
switch mtype.Kind() {
case reflect.Struct:
@@ -1151,6 +1177,16 @@ func (s *visitorState) visit() {
}
}
func (s *visitorState) visitAll() {
if s.active {
for k := range s.keys {
if strings.HasPrefix(k, strings.Join(s.path, ".")) {
delete(s.keys, k)
}
}
}
}
func (s *visitorState) validate() error {
if !s.active {
return nil