diff --git a/internal/reflectbuild/reflectbuild.go b/internal/reflectbuild/reflectbuild.go index 18898b7..d79175a 100644 --- a/internal/reflectbuild/reflectbuild.go +++ b/internal/reflectbuild/reflectbuild.go @@ -8,6 +8,13 @@ import ( "strings" ) +// fieldGetters are functions that given a struct return a specific field +// (likely captured in their scope) +type fieldGetter func(s reflect.Value) reflect.Value + +// collection of fieldGetters for a given struct type +type structFieldGetters map[string]fieldGetter + // Builder wraps a value and provides method to modify its structure. // It is a stateful object that keeps a cursor of what part of the object is // being modified. @@ -17,11 +24,89 @@ type Builder struct { // Root is always a pointer to a non-nil value. // Cursor is the top of the stack. stack []reflect.Value + // Struct field tag to use to retrieve name. + nameTag string + // Cache of functions to access specific fields. + fieldGettersCache map[reflect.Type]structFieldGetters +} + +func copyAndAppend(s []int, i int) []int { + ns := make([]int, len(s)+1) + copy(ns, s) + ns[len(ns)-1] = i + return ns +} + +func (b *Builder) getOrGenerateFieldGettersRecursive(m structFieldGetters, idx []int, s reflect.Type) { + for i := 0; i < s.NumField(); i++ { + f := s.Field(i) + if f.PkgPath != "" { + // only consider exported fields + continue + } + // TODO: handle embedded structs + if f.Anonymous { + b.getOrGenerateFieldGettersRecursive(m, copyAndAppend(idx, i), f.Type) + } else { + fieldName, ok := f.Tag.Lookup(b.nameTag) + if !ok { + fieldName = f.Name + } + + if len(idx) == 0 { + m[fieldName] = makeFieldGetterByIndex(i) + } else { + m[fieldName] = makeFieldGetterByIndexes(copyAndAppend(idx, i)) + } + } + } + + if b.fieldGettersCache == nil { + b.fieldGettersCache = make(map[reflect.Type]structFieldGetters, 1) + } + + b.fieldGettersCache[s] = m +} + +func (b *Builder) getOrGenerateFieldGetters(s reflect.Type) structFieldGetters { + if s.Kind() != reflect.Struct { + panic("generateFieldGetters can only be called on a struct") + } + m, ok := b.fieldGettersCache[s] + if ok { + return m + } + + m = make(structFieldGetters, s.NumField()) + b.getOrGenerateFieldGettersRecursive(m, nil, s) + b.fieldGettersCache[s] = m + return m +} + +func makeFieldGetterByIndex(idx int) fieldGetter { + return func(s reflect.Value) reflect.Value { + return s.Field(idx) + } +} + +func makeFieldGetterByIndexes(idx []int) fieldGetter { + return func(s reflect.Value) reflect.Value { + return s.FieldByIndex(idx) + } +} + +func (b *Builder) fieldGetter(t reflect.Type, s string) (fieldGetter, error) { + m := b.getOrGenerateFieldGetters(t) + g, ok := m[s] + if !ok { + return nil, fmt.Errorf("field '%s' not accessible on '%s'", s, t) + } + return g, nil } // NewBuilder creates a Builder to construct v. // If v is nil or not a pointer, an error will be returned. -func NewBuilder(v interface{}) (Builder, error) { +func NewBuilder(tag string, v interface{}) (Builder, error) { if v == nil { return Builder{}, fmt.Errorf("cannot build a nil value") } @@ -32,8 +117,9 @@ func NewBuilder(v interface{}) (Builder, error) { } return Builder{ - root: rv.Elem(), - stack: []reflect.Value{rv.Elem()}, + root: rv.Elem(), + stack: []reflect.Value{rv.Elem()}, + nameTag: tag, }, nil } @@ -90,7 +176,12 @@ func (b *Builder) DigField(s string) error { return err } - f := t.FieldByName(s) + g, err := b.fieldGetter(t.Type(), s) + if err != nil { + return err + } + + f := g(t) if !f.IsValid() { return FieldNotFoundError{FieldName: s, Struct: t} } diff --git a/internal/reflectbuild/reflectbuild_test.go b/internal/reflectbuild/reflectbuild_test.go index 433a341..9c80dfe 100644 --- a/internal/reflectbuild/reflectbuild_test.go +++ b/internal/reflectbuild/reflectbuild_test.go @@ -11,17 +11,17 @@ import ( func TestNewBuilderSuccess(t *testing.T) { x := struct{}{} - _, err := reflectbuild.NewBuilder(&x) + _, err := reflectbuild.NewBuilder("", &x) assert.NoError(t, err) } func TestNewBuilderNil(t *testing.T) { - _, err := reflectbuild.NewBuilder(nil) + _, err := reflectbuild.NewBuilder("", nil) assert.Error(t, err) } func TestNewBuilderNonPtr(t *testing.T) { - _, err := reflectbuild.NewBuilder(struct{}{}) + _, err := reflectbuild.NewBuilder("", struct{}{}) assert.Error(t, err) } @@ -29,7 +29,7 @@ func TestDigField(t *testing.T) { x := struct { Field string }{} - b, err := reflectbuild.NewBuilder(&x) + b, err := reflectbuild.NewBuilder("", &x) require.NoError(t, err) assert.Error(t, b.DigField("oops")) assert.NoError(t, b.DigField("Field")) @@ -41,7 +41,7 @@ func TestBack(t *testing.T) { A string B string }{} - b, err := reflectbuild.NewBuilder(&x) + b, err := reflectbuild.NewBuilder("", &x) require.NoError(t, err) b.Save() assert.NoError(t, b.DigField("A")) @@ -63,7 +63,7 @@ func TestReset(t *testing.T) { A []string B string }{} - b, err := reflectbuild.NewBuilder(&x) + b, err := reflectbuild.NewBuilder("", &x) require.NoError(t, err) require.NoError(t, b.DigField("A")) require.NoError(t, b.SliceNewElem()) @@ -80,7 +80,7 @@ func TestSetString(t *testing.T) { x := struct { Field string }{} - b, err := reflectbuild.NewBuilder(&x) + b, err := reflectbuild.NewBuilder("", &x) require.NoError(t, err) assert.Error(t, b.SetString("oops")) require.NoError(t, b.DigField("Field")) @@ -92,7 +92,7 @@ func TestSliceNewElem(t *testing.T) { x := struct { Field []string }{} - b, err := reflectbuild.NewBuilder(&x) + b, err := reflectbuild.NewBuilder("", &x) require.NoError(t, err) require.NoError(t, b.DigField("Field")) b.Save() @@ -112,7 +112,7 @@ func TestSliceNewElemNested(t *testing.T) { x := struct { Field [][]string }{} - b, err := reflectbuild.NewBuilder(&x) + b, err := reflectbuild.NewBuilder("", &x) require.NoError(t, err) require.NoError(t, b.DigField("Field")) @@ -159,7 +159,7 @@ func TestCursor(t *testing.T) { x := struct { Field string }{} - b, err := reflectbuild.NewBuilder(&x) + b, err := reflectbuild.NewBuilder("", &x) require.NoError(t, err) assert.Equal(t, b.Cursor().Kind(), reflect.Struct) require.NoError(t, b.DigField("Field")) diff --git a/unmarshal.go b/unmarshal.go index 6462999..da2cf63 100644 --- a/unmarshal.go +++ b/unmarshal.go @@ -8,7 +8,7 @@ import ( func Unmarshal(data []byte, v interface{}) error { u := &unmarshaler{} - u.builder, u.err = reflectbuild.NewBuilder(v) + u.builder, u.err = reflectbuild.NewBuilder("toml", v) if u.err == nil { parseErr := parser{builder: u}.parse(data) if parseErr != nil {