diff --git a/internal/imported_tests/unmarshal_imported_test.go b/internal/imported_tests/unmarshal_imported_test.go index 1f1d894..37dc444 100644 --- a/internal/imported_tests/unmarshal_imported_test.go +++ b/internal/imported_tests/unmarshal_imported_test.go @@ -927,6 +927,29 @@ func TestUnmarshalMapWithTypedKey(t *testing.T) { } } +func TestUnmarshalTypeTableHeader(t *testing.T) { + testToml := []byte(` + [test] + a = 1 + `) + + type header string + var result map[header]map[string]int + err := toml.Unmarshal(testToml, &result) + if err != nil { + t.Errorf("Received unexpected error: %s", err) + return + } + + expected := map[header]map[string]int{ + "test": map[string]int{"a": 1}, + } + + if !reflect.DeepEqual(result, expected) { + t.Errorf("Bad unmarshal: expected %v, got %v", expected, result) + } +} + func TestUnmarshalNonPointer(t *testing.T) { a := 1 err := toml.Unmarshal([]byte{}, a) diff --git a/unmarshaler.go b/unmarshaler.go index ba997a7..a333419 100644 --- a/unmarshaler.go +++ b/unmarshaler.go @@ -412,9 +412,10 @@ func (d *decoder) handleKeyPart(key ast.Iterator, v reflect.Value, nextFn handle elem = v.Elem() return d.handleKeyPart(key, elem, nextFn, makeFn) case reflect.Map: + vt := v.Type() - // Create the key for the map element. For now assume it's a string. - mk := reflect.ValueOf(string(key.Node().Data)) + // Create the key for the map element. Convert to key type. + mk := reflect.ValueOf(string(key.Node().Data)).Convert(vt.Key()) // If the map does not exist, create it. if v.IsNil() { @@ -431,7 +432,6 @@ func (d *decoder) handleKeyPart(key ast.Iterator, v reflect.Value, nextFn handle // map[string]interface{} or a []interface{} depending on whether // this is the last part of the array table key. - vt := v.Type() t := vt.Elem() if t.Kind() == reflect.Interface { mv = makeFn()