From e7d1a179aeaaa957dbb931060557b5beb4297920 Mon Sep 17 00:00:00 2001 From: x-hgg-x <39058530+x-hgg-x@users.noreply.github.com> Date: Mon, 4 May 2020 19:33:55 +0200 Subject: [PATCH] Support custom unmarshaler (#394) Co-authored-by: Thomas Pelletier --- marshal.go | 36 ++++++++++++++++++++++++ marshal_test.go | 75 +++++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 111 insertions(+) diff --git a/marshal.go b/marshal.go index e8b5db5..a6e9733 100644 --- a/marshal.go +++ b/marshal.go @@ -70,6 +70,7 @@ const ( var timeType = reflect.TypeOf(time.Time{}) var marshalerType = reflect.TypeOf(new(Marshaler)).Elem() +var unmarshalerType = reflect.TypeOf(new(Unmarshaler)).Elem() var textMarshalerType = reflect.TypeOf(new(encoding.TextMarshaler)).Elem() var textUnmarshalerType = reflect.TypeOf(new(encoding.TextUnmarshaler)).Elem() var localDateType = reflect.TypeOf(LocalDate{}) @@ -156,6 +157,14 @@ func callTextMarshaler(mval reflect.Value) ([]byte, error) { return mval.Interface().(encoding.TextMarshaler).MarshalText() } +func isCustomUnmarshaler(mtype reflect.Type) bool { + return mtype.Implements(unmarshalerType) +} + +func callCustomUnmarshaler(mval reflect.Value, tval interface{}) error { + return mval.Interface().(Unmarshaler).UnmarshalTOML(tval) +} + func isTextUnmarshaler(mtype reflect.Type) bool { return mtype.Implements(textUnmarshalerType) } @@ -170,6 +179,12 @@ type Marshaler interface { MarshalTOML() ([]byte, error) } +// Unmarshaler is the interface implemented by types that +// can unmarshal a TOML description of themselves. +type Unmarshaler interface { + UnmarshalTOML(interface{}) error +} + /* Marshal returns the TOML encoding of v. Behavior is similar to the Go json encoder, except that there is no concept of a Marshaler interface or MarshalTOML @@ -676,6 +691,17 @@ func (d *Decoder) valueFromTree(mtype reflect.Type, tval *Tree, mval1 *reflect.V if mtype.Kind() == reflect.Ptr { return d.unwrapPointer(mtype, tval, mval1) } + + // Check if pointer to value implements the Unmarshaler interface. + if mvalPtr := reflect.New(mtype); isCustomUnmarshaler(mvalPtr.Type()) { + d.visitor.visitAll() + + if err := callCustomUnmarshaler(mvalPtr, tval.ToMap()); err != nil { + return reflect.ValueOf(nil), fmt.Errorf("unmarshal toml: %v", err) + } + return mvalPtr.Elem(), nil + } + var mval reflect.Value switch mtype.Kind() { case reflect.Struct: @@ -1151,6 +1177,16 @@ func (s *visitorState) visit() { } } +func (s *visitorState) visitAll() { + if s.active { + for k := range s.keys { + if strings.HasPrefix(k, strings.Join(s.path, ".")) { + delete(s.keys, k) + } + } + } +} + func (s *visitorState) validate() error { if !s.active { return nil diff --git a/marshal_test.go b/marshal_test.go index 1a51755..107413e 100644 --- a/marshal_test.go +++ b/marshal_test.go @@ -3449,6 +3449,81 @@ func TestDecoderStrictValid(t *testing.T) { } } +type docUnmarshalTOML struct { + Decoded struct { + Key string + } +} + +func (d *docUnmarshalTOML) UnmarshalTOML(i interface{}) error { + if iMap, ok := i.(map[string]interface{}); !ok { + return fmt.Errorf("type assertion error: wants %T, have %T", map[string]interface{}{}, i) + } else if key, ok := iMap["key"]; !ok { + return fmt.Errorf("key '%s' not in map", "key") + } else if keyString, ok := key.(string); !ok { + return fmt.Errorf("type assertion error: wants %T, have %T", "", key) + } else { + d.Decoded.Key = keyString + } + return nil +} + +func TestDecoderStrictCustomUnmarshal(t *testing.T) { + input := `key = "ok"` + var doc docUnmarshalTOML + err := NewDecoder(bytes.NewReader([]byte(input))).Strict(true).Decode(&doc) + if err != nil { + t.Fatal("unexpected error:", err) + } + if doc.Decoded.Key != "ok" { + t.Errorf("Bad unmarshal: expected ok, got %v", doc.Decoded.Key) + } +} + +type parent struct { + Doc docUnmarshalTOML + DocPointer *docUnmarshalTOML +} + +func TestCustomUnmarshal(t *testing.T) { + input := ` +[Doc] + key = "ok1" +[DocPointer] + key = "ok2" +` + + var d parent + if err := Unmarshal([]byte(input), &d); err != nil { + t.Fatalf("unexpected err: %s", err.Error()) + } + if d.Doc.Decoded.Key != "ok1" { + t.Errorf("Bad unmarshal: expected ok, got %v", d.Doc.Decoded.Key) + } + if d.DocPointer.Decoded.Key != "ok2" { + t.Errorf("Bad unmarshal: expected ok, got %v", d.DocPointer.Decoded.Key) + } +} + +func TestCustomUnmarshalError(t *testing.T) { + input := ` +[Doc] + key = 1 +[DocPointer] + key = "ok2" +` + + expected := "(2, 1): unmarshal toml: type assertion error: wants string, have int64" + + var d parent + err := Unmarshal([]byte(input), &d) + if err == nil { + t.Error("expected error, got none") + } else if err.Error() != expected { + t.Errorf("expect err: %s, got: %s", expected, err.Error()) + } +} + type intWrapper struct { Value int }