unmarshal: support encoding.TextUnmarshaler (#375)

* unmarshal: support encoding.TextUnmarshaler

This PR adds support for decoding fields of primitive types that implement
encoding.TextUnmarshaler by calling the custom method.

Fields in anonymous structs are not supported at this point.

Co-authored-by: Lorenz Bauer <lmb@users.noreply.github.com>
This commit is contained in:
Oncilla
2020-05-04 18:49:37 +02:00
committed by GitHub
parent 2b8e33f503
commit e29a498ed5
2 changed files with 95 additions and 0 deletions
+23
View File
@@ -71,6 +71,7 @@ const (
var timeType = reflect.TypeOf(time.Time{}) var timeType = reflect.TypeOf(time.Time{})
var marshalerType = reflect.TypeOf(new(Marshaler)).Elem() var marshalerType = reflect.TypeOf(new(Marshaler)).Elem()
var textMarshalerType = reflect.TypeOf(new(encoding.TextMarshaler)).Elem() var textMarshalerType = reflect.TypeOf(new(encoding.TextMarshaler)).Elem()
var textUnmarshalerType = reflect.TypeOf(new(encoding.TextUnmarshaler)).Elem()
var localDateType = reflect.TypeOf(LocalDate{}) var localDateType = reflect.TypeOf(LocalDate{})
var localTimeType = reflect.TypeOf(LocalTime{}) var localTimeType = reflect.TypeOf(LocalTime{})
var localDateTimeType = reflect.TypeOf(LocalDateTime{}) var localDateTimeType = reflect.TypeOf(LocalDateTime{})
@@ -155,6 +156,14 @@ func callTextMarshaler(mval reflect.Value) ([]byte, error) {
return mval.Interface().(encoding.TextMarshaler).MarshalText() return mval.Interface().(encoding.TextMarshaler).MarshalText()
} }
func isTextUnmarshaler(mtype reflect.Type) bool {
return mtype.Implements(textUnmarshalerType)
}
func callTextUnmarshaler(mval reflect.Value, text []byte) error {
return mval.Interface().(encoding.TextUnmarshaler).UnmarshalText(text)
}
// Marshaler is the interface implemented by types that // Marshaler is the interface implemented by types that
// can marshal themselves into valid TOML. // can marshal themselves into valid TOML.
type Marshaler interface { type Marshaler interface {
@@ -866,6 +875,14 @@ func (d *Decoder) valueFromToml(mtype reflect.Type, tval interface{}, mval1 *ref
return reflect.ValueOf(nil), fmt.Errorf("Can't convert %v(%T) to a slice", tval, tval) return reflect.ValueOf(nil), fmt.Errorf("Can't convert %v(%T) to a slice", tval, tval)
default: default:
d.visitor.visit() d.visitor.visit()
// Check if pointer to value implements the encoding.TextUnmarshaler.
if mvalPtr := reflect.New(mtype); isTextUnmarshaler(mvalPtr.Type()) && !isTimeType(mtype) {
if err := d.unmarshalText(tval, mvalPtr); err != nil {
return reflect.ValueOf(nil), fmt.Errorf("unmarshal text: %v", err)
}
return mvalPtr.Elem(), nil
}
switch mtype.Kind() { switch mtype.Kind() {
case reflect.Bool, reflect.Struct: case reflect.Bool, reflect.Struct:
val := reflect.ValueOf(tval) val := reflect.ValueOf(tval)
@@ -983,6 +1000,12 @@ func (d *Decoder) unwrapPointer(mtype reflect.Type, tval interface{}, mval1 *ref
return mval, nil return mval, nil
} }
func (d *Decoder) unmarshalText(tval interface{}, mval reflect.Value) error {
var buf bytes.Buffer
fmt.Fprint(&buf, tval)
return callTextUnmarshaler(mval, buf.Bytes())
}
func tomlOptions(vf reflect.StructField, an annotation) tomlOpts { func tomlOptions(vf reflect.StructField, an annotation) tomlOpts {
tag := vf.Tag.Get(an.tag) tag := vf.Tag.Get(an.tag)
parse := strings.Split(tag, ",") parse := strings.Split(tag, ",")
+72
View File
@@ -7,6 +7,7 @@ import (
"io/ioutil" "io/ioutil"
"os" "os"
"reflect" "reflect"
"strconv"
"strings" "strings"
"testing" "testing"
"time" "time"
@@ -3226,3 +3227,74 @@ func TestDecoderStrictValid(t *testing.T) {
t.Fatal("unexpected error:", err) t.Fatal("unexpected error:", err)
} }
} }
type intWrapper struct {
Value int
}
func (w *intWrapper) UnmarshalText(text []byte) error {
var err error
if w.Value, err = strconv.Atoi(string(text)); err == nil {
return nil
}
if b, err := strconv.ParseBool(string(text)); err == nil {
if b {
w.Value = 1
}
return nil
}
if f, err := strconv.ParseFloat(string(text), 32); err == nil {
w.Value = int(f)
return nil
}
return fmt.Errorf("unsupported: %s", text)
}
func TestTextUnmarshal(t *testing.T) {
var doc struct {
UnixTime intWrapper
Version *intWrapper
Bool intWrapper
Int intWrapper
Float intWrapper
}
input := `
UnixTime = "12"
Version = "42"
Bool = true
Int = 21
Float = 2.0
`
if err := Unmarshal([]byte(input), &doc); err != nil {
t.Fatalf("unexpected err: %s", err.Error())
}
if doc.UnixTime.Value != 12 {
t.Fatalf("expected UnixTime: 12 got: %d", doc.UnixTime.Value)
}
if doc.Version.Value != 42 {
t.Fatalf("expected Version: 42 got: %d", doc.Version.Value)
}
if doc.Bool.Value != 1 {
t.Fatalf("expected Bool: 1 got: %d", doc.Bool.Value)
}
if doc.Int.Value != 21 {
t.Fatalf("expected Int: 21 got: %d", doc.Int.Value)
}
if doc.Float.Value != 2 {
t.Fatalf("expected Float: 2 got: %d", doc.Float.Value)
}
}
func TestTextUnmarshalError(t *testing.T) {
var doc struct {
Failer intWrapper
}
input := `Failer = "hello"`
if err := Unmarshal([]byte(input), &doc); err == nil {
t.Fatalf("expected err, got none")
}
}