diff --git a/marshaler.go b/marshaler.go index b3b177e..6ab1d82 100644 --- a/marshaler.go +++ b/marshaler.go @@ -577,11 +577,23 @@ func (enc *Encoder) encodeKey(b []byte, k string) []byte { } } -func (enc *Encoder) encodeMap(b []byte, ctx encoderCtx, v reflect.Value) ([]byte, error) { - if v.Type().Key().Kind() != reflect.String { - return nil, fmt.Errorf("toml: type %s is not supported as a map key", v.Type().Key().Kind()) - } +func (enc *Encoder) keyToString(k reflect.Value) (string, error) { + keyType := k.Type() + switch { + case keyType.Kind() == reflect.String: + return k.String(), nil + case keyType.Implements(textMarshalerType): + keyB, err := k.Interface().(encoding.TextMarshaler).MarshalText() + if err != nil { + return "", fmt.Errorf("toml: error marshalling key %v from text: %w", k, err) + } + return string(keyB), nil + } + return "", fmt.Errorf("toml: type %s is not supported as a map key", keyType.Kind()) +} + +func (enc *Encoder) encodeMap(b []byte, ctx encoderCtx, v reflect.Value) ([]byte, error) { var ( t table emptyValueOptions valueOptions @@ -589,13 +601,17 @@ func (enc *Encoder) encodeMap(b []byte, ctx encoderCtx, v reflect.Value) ([]byte iter := v.MapRange() for iter.Next() { - k := iter.Key().String() v := iter.Value() if isNil(v) { continue } + k, err := enc.keyToString(iter.Key()) + if err != nil { + return nil, err + } + if willConvertToTableOrArrayTable(ctx, v) { t.pushTable(k, v, emptyValueOptions) } else { diff --git a/marshaler_test.go b/marshaler_test.go index 0a9d744..d9e44f0 100644 --- a/marshaler_test.go +++ b/marshaler_test.go @@ -15,6 +15,21 @@ import ( "github.com/stretchr/testify/require" ) +type marshalTextKey struct { + A string + B string +} + +func (k marshalTextKey) MarshalText() ([]byte, error) { + return []byte(k.A + "-" + k.B), nil +} + +type marshalBadTextKey struct{} + +func (k marshalBadTextKey) MarshalText() ([]byte, error) { + return nil, fmt.Errorf("error") +} + func TestMarshal(t *testing.T) { someInt := 42 @@ -97,6 +112,53 @@ also = 'that' a = 'test' `, }, + { + desc: `map with text key`, + v: map[marshalTextKey]string{ + {A: "a", B: "1"}: "value 1", + {A: "a", B: "2"}: "value 2", + {A: "b", B: "1"}: "value 3", + }, + expected: `a-1 = 'value 1' +a-2 = 'value 2' +b-1 = 'value 3' +`, + }, + { + desc: `table with text key`, + v: map[marshalTextKey]map[string]string{ + {A: "a", B: "1"}: {"value": "foo"}, + }, + expected: `[a-1] +value = 'foo' +`, + }, + { + desc: `map with ptr text key`, + v: map[*marshalTextKey]string{ + {A: "a", B: "1"}: "value 1", + {A: "a", B: "2"}: "value 2", + {A: "b", B: "1"}: "value 3", + }, + expected: `a-1 = 'value 1' +a-2 = 'value 2' +b-1 = 'value 3' +`, + }, + { + desc: `map with bad text key`, + v: map[marshalBadTextKey]string{ + {}: "value 1", + }, + err: true, + }, + { + desc: `map with bad ptr text key`, + v: map[*marshalBadTextKey]string{ + {}: "value 1", + }, + err: true, + }, { desc: "simple string array", v: map[string][]string{ @@ -487,9 +549,14 @@ foo = 42 }, { desc: "invalid map key", - v: map[int]interface{}{}, + v: map[int]interface{}{1: "a"}, err: true, }, + { + desc: "invalid map key but empty", + v: map[int]interface{}{}, + expected: "", + }, { desc: "unhandled type", v: struct { diff --git a/unmarshaler.go b/unmarshaler.go index 5c19845..bab1121 100644 --- a/unmarshaler.go +++ b/unmarshaler.go @@ -417,7 +417,10 @@ func (d *decoder) handleKeyPart(key unstable.Iterator, v reflect.Value, nextFn h vt := v.Type() // Create the key for the map element. Convert to key type. - mk := reflect.ValueOf(string(key.Node().Data)).Convert(vt.Key()) + mk, err := d.keyFromData(vt.Key(), key.Node().Data) + if err != nil { + return reflect.Value{}, err + } // If the map does not exist, create it. if v.IsNil() { @@ -1009,6 +1012,31 @@ func (d *decoder) handleKeyValueInner(key unstable.Iterator, value *unstable.Nod return reflect.Value{}, d.handleValue(value, v) } +func (d *decoder) keyFromData(keyType reflect.Type, data []byte) (reflect.Value, error) { + switch { + case stringType.AssignableTo(keyType): + return reflect.ValueOf(string(data)), nil + + case stringType.ConvertibleTo(keyType): + return reflect.ValueOf(string(data)).Convert(keyType), nil + + case keyType.Implements(textUnmarshalerType): + mk := reflect.New(keyType.Elem()) + if err := mk.Interface().(encoding.TextUnmarshaler).UnmarshalText(data); err != nil { + return reflect.Value{}, fmt.Errorf("toml: error unmarshalling key type %s from text: %w", stringType, err) + } + return mk, nil + + case reflect.PointerTo(keyType).Implements(textUnmarshalerType): + mk := reflect.New(keyType) + if err := mk.Interface().(encoding.TextUnmarshaler).UnmarshalText(data); err != nil { + return reflect.Value{}, fmt.Errorf("toml: error unmarshalling key type %s from text: %w", stringType, err) + } + return mk.Elem(), nil + } + return reflect.Value{}, fmt.Errorf("toml: cannot convert map key of type %s to expected type %s", stringType, keyType) +} + func (d *decoder) handleKeyValuePart(key unstable.Iterator, value *unstable.Node, v reflect.Value) (reflect.Value, error) { // contains the replacement for v var rv reflect.Value @@ -1019,16 +1047,9 @@ func (d *decoder) handleKeyValuePart(key unstable.Iterator, value *unstable.Node case reflect.Map: vt := v.Type() - mk := reflect.ValueOf(string(key.Node().Data)) - mkt := stringType - - keyType := vt.Key() - if !mkt.AssignableTo(keyType) { - if !mkt.ConvertibleTo(keyType) { - return reflect.Value{}, fmt.Errorf("toml: cannot convert map key of type %s to expected type %s", mkt, keyType) - } - - mk = mk.Convert(keyType) + mk, err := d.keyFromData(vt.Key(), key.Node().Data) + if err != nil { + return reflect.Value{}, err } // If the map does not exist, create it. diff --git a/unmarshaler_test.go b/unmarshaler_test.go index 3c34425..1cc17d0 100644 --- a/unmarshaler_test.go +++ b/unmarshaler_test.go @@ -16,6 +16,27 @@ import ( "github.com/stretchr/testify/require" ) +type unmarshalTextKey struct { + A string + B string +} + +func (k *unmarshalTextKey) UnmarshalText(text []byte) error { + parts := strings.Split(string(text), "-") + if len(parts) != 2 { + return fmt.Errorf("invalid text key: %s", text) + } + k.A = parts[0] + k.B = parts[1] + return nil +} + +type unmarshalBadTextKey struct{} + +func (k *unmarshalBadTextKey) UnmarshalText(text []byte) error { + return fmt.Errorf("error") +} + func ExampleDecoder_DisallowUnknownFields() { type S struct { Key1 string @@ -315,6 +336,7 @@ func TestUnmarshal(t *testing.T) { target interface{} expected interface{} err bool + assert func(t *testing.T, test test) } examples := []struct { skip bool @@ -350,6 +372,96 @@ func TestUnmarshal(t *testing.T) { } }, }, + { + desc: "kv text key", + input: `a-1 = "foo"`, + gen: func() test { + type doc = map[unmarshalTextKey]string + + return test{ + target: &doc{}, + expected: &doc{{A: "a", B: "1"}: "foo"}, + } + }, + }, + { + desc: "table text key", + input: `["a-1"] +foo = "bar"`, + gen: func() test { + type doc = map[unmarshalTextKey]map[string]string + + return test{ + target: &doc{}, + expected: &doc{{A: "a", B: "1"}: map[string]string{"foo": "bar"}}, + } + }, + }, + { + desc: "kv ptr text key", + input: `a-1 = "foo"`, + gen: func() test { + type doc = map[*unmarshalTextKey]string + + return test{ + target: &doc{}, + expected: &doc{{A: "a", B: "1"}: "foo"}, + assert: func(t *testing.T, test test) { + // Despite the documentation: + // Pointer variable equality is determined based on the equality of the + // referenced values (as opposed to the memory addresses). + // assert.Equal does not work properly with maps with pointer keys + // https://github.com/stretchr/testify/issues/1143 + expected := make(map[unmarshalTextKey]string) + for k, v := range *(test.expected.(*doc)) { + expected[*k] = v + } + got := make(map[unmarshalTextKey]string) + for k, v := range *(test.target.(*doc)) { + got[*k] = v + } + assert.Equal(t, expected, got) + }, + } + }, + }, + { + desc: "kv bad text key", + input: `a-1 = "foo"`, + gen: func() test { + type doc = map[unmarshalBadTextKey]string + + return test{ + target: &doc{}, + err: true, + } + }, + }, + { + desc: "kv bad ptr text key", + input: `a-1 = "foo"`, + gen: func() test { + type doc = map[*unmarshalBadTextKey]string + + return test{ + target: &doc{}, + err: true, + } + }, + }, + { + desc: "table bad text key", + input: `["a-1"] +foo = "bar"`, + gen: func() test { + type doc = map[unmarshalBadTextKey]map[string]string + + return test{ + target: &doc{}, + err: true, + } + }, + }, { desc: "time.time with negative zone", input: `a = 1979-05-27T00:32:00-07:00 `, // space intentional @@ -1521,6 +1633,16 @@ B = "data"`, } }, }, + { + desc: "empty map into map with invalid key type", + input: ``, + gen: func() test { + return test{ + target: &map[int]string{}, + expected: &map[int]string{}, + } + }, + }, { desc: "into map with convertible key type", input: `A = "hello"`, @@ -1777,7 +1899,11 @@ B = "data"`, require.Error(t, err) } else { require.NoError(t, err) - assert.Equal(t, test.expected, test.target) + if test.assert != nil { + test.assert(t, test) + } else { + assert.Equal(t, test.expected, test.target) + } } }) }