Decode: convert table key to correct type (#741)

Fixes #740.
This commit is contained in:
Gregory Oschwald
2022-03-02 06:24:01 -08:00
committed by GitHub
parent 3f5d8a6b06
commit 3229a0abfb
2 changed files with 26 additions and 3 deletions
@@ -927,6 +927,29 @@ func TestUnmarshalMapWithTypedKey(t *testing.T) {
} }
} }
func TestUnmarshalTypeTableHeader(t *testing.T) {
testToml := []byte(`
[test]
a = 1
`)
type header string
var result map[header]map[string]int
err := toml.Unmarshal(testToml, &result)
if err != nil {
t.Errorf("Received unexpected error: %s", err)
return
}
expected := map[header]map[string]int{
"test": map[string]int{"a": 1},
}
if !reflect.DeepEqual(result, expected) {
t.Errorf("Bad unmarshal: expected %v, got %v", expected, result)
}
}
func TestUnmarshalNonPointer(t *testing.T) { func TestUnmarshalNonPointer(t *testing.T) {
a := 1 a := 1
err := toml.Unmarshal([]byte{}, a) err := toml.Unmarshal([]byte{}, a)
+3 -3
View File
@@ -412,9 +412,10 @@ func (d *decoder) handleKeyPart(key ast.Iterator, v reflect.Value, nextFn handle
elem = v.Elem() elem = v.Elem()
return d.handleKeyPart(key, elem, nextFn, makeFn) return d.handleKeyPart(key, elem, nextFn, makeFn)
case reflect.Map: case reflect.Map:
vt := v.Type()
// Create the key for the map element. For now assume it's a string. // Create the key for the map element. Convert to key type.
mk := reflect.ValueOf(string(key.Node().Data)) mk := reflect.ValueOf(string(key.Node().Data)).Convert(vt.Key())
// If the map does not exist, create it. // If the map does not exist, create it.
if v.IsNil() { if v.IsNil() {
@@ -431,7 +432,6 @@ func (d *decoder) handleKeyPart(key ast.Iterator, v reflect.Value, nextFn handle
// map[string]interface{} or a []interface{} depending on whether // map[string]interface{} or a []interface{} depending on whether
// this is the last part of the array table key. // this is the last part of the array table key.
vt := v.Type()
t := vt.Elem() t := vt.Elem()
if t.Kind() == reflect.Interface { if t.Kind() == reflect.Interface {
mv = makeFn() mv = makeFn()