From e29a498ed5b47ea382244b8f6c2caf5a14821481 Mon Sep 17 00:00:00 2001 From: Oncilla Date: Mon, 4 May 2020 18:49:37 +0200 Subject: [PATCH] unmarshal: support encoding.TextUnmarshaler (#375) * unmarshal: support encoding.TextUnmarshaler This PR adds support for decoding fields of primitive types that implement encoding.TextUnmarshaler by calling the custom method. Fields in anonymous structs are not supported at this point. Co-authored-by: Lorenz Bauer --- marshal.go | 23 ++++++++++++++++ marshal_test.go | 72 +++++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 95 insertions(+) diff --git a/marshal.go b/marshal.go index 9c5957b..6ab587e 100644 --- a/marshal.go +++ b/marshal.go @@ -71,6 +71,7 @@ const ( var timeType = reflect.TypeOf(time.Time{}) var marshalerType = reflect.TypeOf(new(Marshaler)).Elem() var textMarshalerType = reflect.TypeOf(new(encoding.TextMarshaler)).Elem() +var textUnmarshalerType = reflect.TypeOf(new(encoding.TextUnmarshaler)).Elem() var localDateType = reflect.TypeOf(LocalDate{}) var localTimeType = reflect.TypeOf(LocalTime{}) var localDateTimeType = reflect.TypeOf(LocalDateTime{}) @@ -155,6 +156,14 @@ func callTextMarshaler(mval reflect.Value) ([]byte, error) { return mval.Interface().(encoding.TextMarshaler).MarshalText() } +func isTextUnmarshaler(mtype reflect.Type) bool { + return mtype.Implements(textUnmarshalerType) +} + +func callTextUnmarshaler(mval reflect.Value, text []byte) error { + return mval.Interface().(encoding.TextUnmarshaler).UnmarshalText(text) +} + // Marshaler is the interface implemented by types that // can marshal themselves into valid TOML. type Marshaler interface { @@ -866,6 +875,14 @@ func (d *Decoder) valueFromToml(mtype reflect.Type, tval interface{}, mval1 *ref return reflect.ValueOf(nil), fmt.Errorf("Can't convert %v(%T) to a slice", tval, tval) default: d.visitor.visit() + // Check if pointer to value implements the encoding.TextUnmarshaler. + if mvalPtr := reflect.New(mtype); isTextUnmarshaler(mvalPtr.Type()) && !isTimeType(mtype) { + if err := d.unmarshalText(tval, mvalPtr); err != nil { + return reflect.ValueOf(nil), fmt.Errorf("unmarshal text: %v", err) + } + return mvalPtr.Elem(), nil + } + switch mtype.Kind() { case reflect.Bool, reflect.Struct: val := reflect.ValueOf(tval) @@ -983,6 +1000,12 @@ func (d *Decoder) unwrapPointer(mtype reflect.Type, tval interface{}, mval1 *ref return mval, nil } +func (d *Decoder) unmarshalText(tval interface{}, mval reflect.Value) error { + var buf bytes.Buffer + fmt.Fprint(&buf, tval) + return callTextUnmarshaler(mval, buf.Bytes()) +} + func tomlOptions(vf reflect.StructField, an annotation) tomlOpts { tag := vf.Tag.Get(an.tag) parse := strings.Split(tag, ",") diff --git a/marshal_test.go b/marshal_test.go index 59cafcb..4fa9600 100644 --- a/marshal_test.go +++ b/marshal_test.go @@ -7,6 +7,7 @@ import ( "io/ioutil" "os" "reflect" + "strconv" "strings" "testing" "time" @@ -3226,3 +3227,74 @@ func TestDecoderStrictValid(t *testing.T) { t.Fatal("unexpected error:", err) } } + +type intWrapper struct { + Value int +} + +func (w *intWrapper) UnmarshalText(text []byte) error { + var err error + if w.Value, err = strconv.Atoi(string(text)); err == nil { + return nil + } + if b, err := strconv.ParseBool(string(text)); err == nil { + if b { + w.Value = 1 + } + return nil + } + if f, err := strconv.ParseFloat(string(text), 32); err == nil { + w.Value = int(f) + return nil + } + return fmt.Errorf("unsupported: %s", text) +} + +func TestTextUnmarshal(t *testing.T) { + var doc struct { + UnixTime intWrapper + Version *intWrapper + + Bool intWrapper + Int intWrapper + Float intWrapper + } + + input := ` +UnixTime = "12" +Version = "42" +Bool = true +Int = 21 +Float = 2.0 +` + + if err := Unmarshal([]byte(input), &doc); err != nil { + t.Fatalf("unexpected err: %s", err.Error()) + } + if doc.UnixTime.Value != 12 { + t.Fatalf("expected UnixTime: 12 got: %d", doc.UnixTime.Value) + } + if doc.Version.Value != 42 { + t.Fatalf("expected Version: 42 got: %d", doc.Version.Value) + } + if doc.Bool.Value != 1 { + t.Fatalf("expected Bool: 1 got: %d", doc.Bool.Value) + } + if doc.Int.Value != 21 { + t.Fatalf("expected Int: 21 got: %d", doc.Int.Value) + } + if doc.Float.Value != 2 { + t.Fatalf("expected Float: 2 got: %d", doc.Float.Value) + } +} + +func TestTextUnmarshalError(t *testing.T) { + var doc struct { + Failer intWrapper + } + + input := `Failer = "hello"` + if err := Unmarshal([]byte(input), &doc); err == nil { + t.Fatalf("expected err, got none") + } +}