Support text Un/Marshaller for map keys (#863)
This commit is contained in:
+21
-5
@@ -577,11 +577,23 @@ func (enc *Encoder) encodeKey(b []byte, k string) []byte {
|
||||
}
|
||||
}
|
||||
|
||||
func (enc *Encoder) encodeMap(b []byte, ctx encoderCtx, v reflect.Value) ([]byte, error) {
|
||||
if v.Type().Key().Kind() != reflect.String {
|
||||
return nil, fmt.Errorf("toml: type %s is not supported as a map key", v.Type().Key().Kind())
|
||||
}
|
||||
func (enc *Encoder) keyToString(k reflect.Value) (string, error) {
|
||||
keyType := k.Type()
|
||||
switch {
|
||||
case keyType.Kind() == reflect.String:
|
||||
return k.String(), nil
|
||||
|
||||
case keyType.Implements(textMarshalerType):
|
||||
keyB, err := k.Interface().(encoding.TextMarshaler).MarshalText()
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("toml: error marshalling key %v from text: %w", k, err)
|
||||
}
|
||||
return string(keyB), nil
|
||||
}
|
||||
return "", fmt.Errorf("toml: type %s is not supported as a map key", keyType.Kind())
|
||||
}
|
||||
|
||||
func (enc *Encoder) encodeMap(b []byte, ctx encoderCtx, v reflect.Value) ([]byte, error) {
|
||||
var (
|
||||
t table
|
||||
emptyValueOptions valueOptions
|
||||
@@ -589,13 +601,17 @@ func (enc *Encoder) encodeMap(b []byte, ctx encoderCtx, v reflect.Value) ([]byte
|
||||
|
||||
iter := v.MapRange()
|
||||
for iter.Next() {
|
||||
k := iter.Key().String()
|
||||
v := iter.Value()
|
||||
|
||||
if isNil(v) {
|
||||
continue
|
||||
}
|
||||
|
||||
k, err := enc.keyToString(iter.Key())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if willConvertToTableOrArrayTable(ctx, v) {
|
||||
t.pushTable(k, v, emptyValueOptions)
|
||||
} else {
|
||||
|
||||
+68
-1
@@ -15,6 +15,21 @@ import (
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
type marshalTextKey struct {
|
||||
A string
|
||||
B string
|
||||
}
|
||||
|
||||
func (k marshalTextKey) MarshalText() ([]byte, error) {
|
||||
return []byte(k.A + "-" + k.B), nil
|
||||
}
|
||||
|
||||
type marshalBadTextKey struct{}
|
||||
|
||||
func (k marshalBadTextKey) MarshalText() ([]byte, error) {
|
||||
return nil, fmt.Errorf("error")
|
||||
}
|
||||
|
||||
func TestMarshal(t *testing.T) {
|
||||
someInt := 42
|
||||
|
||||
@@ -97,6 +112,53 @@ also = 'that'
|
||||
a = 'test'
|
||||
`,
|
||||
},
|
||||
{
|
||||
desc: `map with text key`,
|
||||
v: map[marshalTextKey]string{
|
||||
{A: "a", B: "1"}: "value 1",
|
||||
{A: "a", B: "2"}: "value 2",
|
||||
{A: "b", B: "1"}: "value 3",
|
||||
},
|
||||
expected: `a-1 = 'value 1'
|
||||
a-2 = 'value 2'
|
||||
b-1 = 'value 3'
|
||||
`,
|
||||
},
|
||||
{
|
||||
desc: `table with text key`,
|
||||
v: map[marshalTextKey]map[string]string{
|
||||
{A: "a", B: "1"}: {"value": "foo"},
|
||||
},
|
||||
expected: `[a-1]
|
||||
value = 'foo'
|
||||
`,
|
||||
},
|
||||
{
|
||||
desc: `map with ptr text key`,
|
||||
v: map[*marshalTextKey]string{
|
||||
{A: "a", B: "1"}: "value 1",
|
||||
{A: "a", B: "2"}: "value 2",
|
||||
{A: "b", B: "1"}: "value 3",
|
||||
},
|
||||
expected: `a-1 = 'value 1'
|
||||
a-2 = 'value 2'
|
||||
b-1 = 'value 3'
|
||||
`,
|
||||
},
|
||||
{
|
||||
desc: `map with bad text key`,
|
||||
v: map[marshalBadTextKey]string{
|
||||
{}: "value 1",
|
||||
},
|
||||
err: true,
|
||||
},
|
||||
{
|
||||
desc: `map with bad ptr text key`,
|
||||
v: map[*marshalBadTextKey]string{
|
||||
{}: "value 1",
|
||||
},
|
||||
err: true,
|
||||
},
|
||||
{
|
||||
desc: "simple string array",
|
||||
v: map[string][]string{
|
||||
@@ -487,9 +549,14 @@ foo = 42
|
||||
},
|
||||
{
|
||||
desc: "invalid map key",
|
||||
v: map[int]interface{}{},
|
||||
v: map[int]interface{}{1: "a"},
|
||||
err: true,
|
||||
},
|
||||
{
|
||||
desc: "invalid map key but empty",
|
||||
v: map[int]interface{}{},
|
||||
expected: "",
|
||||
},
|
||||
{
|
||||
desc: "unhandled type",
|
||||
v: struct {
|
||||
|
||||
+32
-11
@@ -417,7 +417,10 @@ func (d *decoder) handleKeyPart(key unstable.Iterator, v reflect.Value, nextFn h
|
||||
vt := v.Type()
|
||||
|
||||
// Create the key for the map element. Convert to key type.
|
||||
mk := reflect.ValueOf(string(key.Node().Data)).Convert(vt.Key())
|
||||
mk, err := d.keyFromData(vt.Key(), key.Node().Data)
|
||||
if err != nil {
|
||||
return reflect.Value{}, err
|
||||
}
|
||||
|
||||
// If the map does not exist, create it.
|
||||
if v.IsNil() {
|
||||
@@ -1009,6 +1012,31 @@ func (d *decoder) handleKeyValueInner(key unstable.Iterator, value *unstable.Nod
|
||||
return reflect.Value{}, d.handleValue(value, v)
|
||||
}
|
||||
|
||||
func (d *decoder) keyFromData(keyType reflect.Type, data []byte) (reflect.Value, error) {
|
||||
switch {
|
||||
case stringType.AssignableTo(keyType):
|
||||
return reflect.ValueOf(string(data)), nil
|
||||
|
||||
case stringType.ConvertibleTo(keyType):
|
||||
return reflect.ValueOf(string(data)).Convert(keyType), nil
|
||||
|
||||
case keyType.Implements(textUnmarshalerType):
|
||||
mk := reflect.New(keyType.Elem())
|
||||
if err := mk.Interface().(encoding.TextUnmarshaler).UnmarshalText(data); err != nil {
|
||||
return reflect.Value{}, fmt.Errorf("toml: error unmarshalling key type %s from text: %w", stringType, err)
|
||||
}
|
||||
return mk, nil
|
||||
|
||||
case reflect.PointerTo(keyType).Implements(textUnmarshalerType):
|
||||
mk := reflect.New(keyType)
|
||||
if err := mk.Interface().(encoding.TextUnmarshaler).UnmarshalText(data); err != nil {
|
||||
return reflect.Value{}, fmt.Errorf("toml: error unmarshalling key type %s from text: %w", stringType, err)
|
||||
}
|
||||
return mk.Elem(), nil
|
||||
}
|
||||
return reflect.Value{}, fmt.Errorf("toml: cannot convert map key of type %s to expected type %s", stringType, keyType)
|
||||
}
|
||||
|
||||
func (d *decoder) handleKeyValuePart(key unstable.Iterator, value *unstable.Node, v reflect.Value) (reflect.Value, error) {
|
||||
// contains the replacement for v
|
||||
var rv reflect.Value
|
||||
@@ -1019,16 +1047,9 @@ func (d *decoder) handleKeyValuePart(key unstable.Iterator, value *unstable.Node
|
||||
case reflect.Map:
|
||||
vt := v.Type()
|
||||
|
||||
mk := reflect.ValueOf(string(key.Node().Data))
|
||||
mkt := stringType
|
||||
|
||||
keyType := vt.Key()
|
||||
if !mkt.AssignableTo(keyType) {
|
||||
if !mkt.ConvertibleTo(keyType) {
|
||||
return reflect.Value{}, fmt.Errorf("toml: cannot convert map key of type %s to expected type %s", mkt, keyType)
|
||||
}
|
||||
|
||||
mk = mk.Convert(keyType)
|
||||
mk, err := d.keyFromData(vt.Key(), key.Node().Data)
|
||||
if err != nil {
|
||||
return reflect.Value{}, err
|
||||
}
|
||||
|
||||
// If the map does not exist, create it.
|
||||
|
||||
+127
-1
@@ -16,6 +16,27 @@ import (
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
type unmarshalTextKey struct {
|
||||
A string
|
||||
B string
|
||||
}
|
||||
|
||||
func (k *unmarshalTextKey) UnmarshalText(text []byte) error {
|
||||
parts := strings.Split(string(text), "-")
|
||||
if len(parts) != 2 {
|
||||
return fmt.Errorf("invalid text key: %s", text)
|
||||
}
|
||||
k.A = parts[0]
|
||||
k.B = parts[1]
|
||||
return nil
|
||||
}
|
||||
|
||||
type unmarshalBadTextKey struct{}
|
||||
|
||||
func (k *unmarshalBadTextKey) UnmarshalText(text []byte) error {
|
||||
return fmt.Errorf("error")
|
||||
}
|
||||
|
||||
func ExampleDecoder_DisallowUnknownFields() {
|
||||
type S struct {
|
||||
Key1 string
|
||||
@@ -315,6 +336,7 @@ func TestUnmarshal(t *testing.T) {
|
||||
target interface{}
|
||||
expected interface{}
|
||||
err bool
|
||||
assert func(t *testing.T, test test)
|
||||
}
|
||||
examples := []struct {
|
||||
skip bool
|
||||
@@ -350,6 +372,96 @@ func TestUnmarshal(t *testing.T) {
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "kv text key",
|
||||
input: `a-1 = "foo"`,
|
||||
gen: func() test {
|
||||
type doc = map[unmarshalTextKey]string
|
||||
|
||||
return test{
|
||||
target: &doc{},
|
||||
expected: &doc{{A: "a", B: "1"}: "foo"},
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "table text key",
|
||||
input: `["a-1"]
|
||||
foo = "bar"`,
|
||||
gen: func() test {
|
||||
type doc = map[unmarshalTextKey]map[string]string
|
||||
|
||||
return test{
|
||||
target: &doc{},
|
||||
expected: &doc{{A: "a", B: "1"}: map[string]string{"foo": "bar"}},
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "kv ptr text key",
|
||||
input: `a-1 = "foo"`,
|
||||
gen: func() test {
|
||||
type doc = map[*unmarshalTextKey]string
|
||||
|
||||
return test{
|
||||
target: &doc{},
|
||||
expected: &doc{{A: "a", B: "1"}: "foo"},
|
||||
assert: func(t *testing.T, test test) {
|
||||
// Despite the documentation:
|
||||
// Pointer variable equality is determined based on the equality of the
|
||||
// referenced values (as opposed to the memory addresses).
|
||||
// assert.Equal does not work properly with maps with pointer keys
|
||||
// https://github.com/stretchr/testify/issues/1143
|
||||
expected := make(map[unmarshalTextKey]string)
|
||||
for k, v := range *(test.expected.(*doc)) {
|
||||
expected[*k] = v
|
||||
}
|
||||
got := make(map[unmarshalTextKey]string)
|
||||
for k, v := range *(test.target.(*doc)) {
|
||||
got[*k] = v
|
||||
}
|
||||
assert.Equal(t, expected, got)
|
||||
},
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "kv bad text key",
|
||||
input: `a-1 = "foo"`,
|
||||
gen: func() test {
|
||||
type doc = map[unmarshalBadTextKey]string
|
||||
|
||||
return test{
|
||||
target: &doc{},
|
||||
err: true,
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "kv bad ptr text key",
|
||||
input: `a-1 = "foo"`,
|
||||
gen: func() test {
|
||||
type doc = map[*unmarshalBadTextKey]string
|
||||
|
||||
return test{
|
||||
target: &doc{},
|
||||
err: true,
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "table bad text key",
|
||||
input: `["a-1"]
|
||||
foo = "bar"`,
|
||||
gen: func() test {
|
||||
type doc = map[unmarshalBadTextKey]map[string]string
|
||||
|
||||
return test{
|
||||
target: &doc{},
|
||||
err: true,
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "time.time with negative zone",
|
||||
input: `a = 1979-05-27T00:32:00-07:00 `, // space intentional
|
||||
@@ -1521,6 +1633,16 @@ B = "data"`,
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "empty map into map with invalid key type",
|
||||
input: ``,
|
||||
gen: func() test {
|
||||
return test{
|
||||
target: &map[int]string{},
|
||||
expected: &map[int]string{},
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "into map with convertible key type",
|
||||
input: `A = "hello"`,
|
||||
@@ -1777,7 +1899,11 @@ B = "data"`,
|
||||
require.Error(t, err)
|
||||
} else {
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, test.expected, test.target)
|
||||
if test.assert != nil {
|
||||
test.assert(t, test)
|
||||
} else {
|
||||
assert.Equal(t, test.expected, test.target)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user