Reflect write to embeded structs

This commit is contained in:
Thomas Pelletier
2021-02-10 18:34:54 -05:00
parent e2a07a3b92
commit 6e79ce63c2
3 changed files with 106 additions and 15 deletions
+95 -4
View File
@@ -8,6 +8,13 @@ import (
"strings"
)
// fieldGetters are functions that given a struct return a specific field
// (likely captured in their scope)
type fieldGetter func(s reflect.Value) reflect.Value
// collection of fieldGetters for a given struct type
type structFieldGetters map[string]fieldGetter
// Builder wraps a value and provides method to modify its structure.
// It is a stateful object that keeps a cursor of what part of the object is
// being modified.
@@ -17,11 +24,89 @@ type Builder struct {
// Root is always a pointer to a non-nil value.
// Cursor is the top of the stack.
stack []reflect.Value
// Struct field tag to use to retrieve name.
nameTag string
// Cache of functions to access specific fields.
fieldGettersCache map[reflect.Type]structFieldGetters
}
func copyAndAppend(s []int, i int) []int {
ns := make([]int, len(s)+1)
copy(ns, s)
ns[len(ns)-1] = i
return ns
}
func (b *Builder) getOrGenerateFieldGettersRecursive(m structFieldGetters, idx []int, s reflect.Type) {
for i := 0; i < s.NumField(); i++ {
f := s.Field(i)
if f.PkgPath != "" {
// only consider exported fields
continue
}
// TODO: handle embedded structs
if f.Anonymous {
b.getOrGenerateFieldGettersRecursive(m, copyAndAppend(idx, i), f.Type)
} else {
fieldName, ok := f.Tag.Lookup(b.nameTag)
if !ok {
fieldName = f.Name
}
if len(idx) == 0 {
m[fieldName] = makeFieldGetterByIndex(i)
} else {
m[fieldName] = makeFieldGetterByIndexes(copyAndAppend(idx, i))
}
}
}
if b.fieldGettersCache == nil {
b.fieldGettersCache = make(map[reflect.Type]structFieldGetters, 1)
}
b.fieldGettersCache[s] = m
}
func (b *Builder) getOrGenerateFieldGetters(s reflect.Type) structFieldGetters {
if s.Kind() != reflect.Struct {
panic("generateFieldGetters can only be called on a struct")
}
m, ok := b.fieldGettersCache[s]
if ok {
return m
}
m = make(structFieldGetters, s.NumField())
b.getOrGenerateFieldGettersRecursive(m, nil, s)
b.fieldGettersCache[s] = m
return m
}
func makeFieldGetterByIndex(idx int) fieldGetter {
return func(s reflect.Value) reflect.Value {
return s.Field(idx)
}
}
func makeFieldGetterByIndexes(idx []int) fieldGetter {
return func(s reflect.Value) reflect.Value {
return s.FieldByIndex(idx)
}
}
func (b *Builder) fieldGetter(t reflect.Type, s string) (fieldGetter, error) {
m := b.getOrGenerateFieldGetters(t)
g, ok := m[s]
if !ok {
return nil, fmt.Errorf("field '%s' not accessible on '%s'", s, t)
}
return g, nil
}
// NewBuilder creates a Builder to construct v.
// If v is nil or not a pointer, an error will be returned.
func NewBuilder(v interface{}) (Builder, error) {
func NewBuilder(tag string, v interface{}) (Builder, error) {
if v == nil {
return Builder{}, fmt.Errorf("cannot build a nil value")
}
@@ -32,8 +117,9 @@ func NewBuilder(v interface{}) (Builder, error) {
}
return Builder{
root: rv.Elem(),
stack: []reflect.Value{rv.Elem()},
root: rv.Elem(),
stack: []reflect.Value{rv.Elem()},
nameTag: tag,
}, nil
}
@@ -90,7 +176,12 @@ func (b *Builder) DigField(s string) error {
return err
}
f := t.FieldByName(s)
g, err := b.fieldGetter(t.Type(), s)
if err != nil {
return err
}
f := g(t)
if !f.IsValid() {
return FieldNotFoundError{FieldName: s, Struct: t}
}
+10 -10
View File
@@ -11,17 +11,17 @@ import (
func TestNewBuilderSuccess(t *testing.T) {
x := struct{}{}
_, err := reflectbuild.NewBuilder(&x)
_, err := reflectbuild.NewBuilder("", &x)
assert.NoError(t, err)
}
func TestNewBuilderNil(t *testing.T) {
_, err := reflectbuild.NewBuilder(nil)
_, err := reflectbuild.NewBuilder("", nil)
assert.Error(t, err)
}
func TestNewBuilderNonPtr(t *testing.T) {
_, err := reflectbuild.NewBuilder(struct{}{})
_, err := reflectbuild.NewBuilder("", struct{}{})
assert.Error(t, err)
}
@@ -29,7 +29,7 @@ func TestDigField(t *testing.T) {
x := struct {
Field string
}{}
b, err := reflectbuild.NewBuilder(&x)
b, err := reflectbuild.NewBuilder("", &x)
require.NoError(t, err)
assert.Error(t, b.DigField("oops"))
assert.NoError(t, b.DigField("Field"))
@@ -41,7 +41,7 @@ func TestBack(t *testing.T) {
A string
B string
}{}
b, err := reflectbuild.NewBuilder(&x)
b, err := reflectbuild.NewBuilder("", &x)
require.NoError(t, err)
b.Save()
assert.NoError(t, b.DigField("A"))
@@ -63,7 +63,7 @@ func TestReset(t *testing.T) {
A []string
B string
}{}
b, err := reflectbuild.NewBuilder(&x)
b, err := reflectbuild.NewBuilder("", &x)
require.NoError(t, err)
require.NoError(t, b.DigField("A"))
require.NoError(t, b.SliceNewElem())
@@ -80,7 +80,7 @@ func TestSetString(t *testing.T) {
x := struct {
Field string
}{}
b, err := reflectbuild.NewBuilder(&x)
b, err := reflectbuild.NewBuilder("", &x)
require.NoError(t, err)
assert.Error(t, b.SetString("oops"))
require.NoError(t, b.DigField("Field"))
@@ -92,7 +92,7 @@ func TestSliceNewElem(t *testing.T) {
x := struct {
Field []string
}{}
b, err := reflectbuild.NewBuilder(&x)
b, err := reflectbuild.NewBuilder("", &x)
require.NoError(t, err)
require.NoError(t, b.DigField("Field"))
b.Save()
@@ -112,7 +112,7 @@ func TestSliceNewElemNested(t *testing.T) {
x := struct {
Field [][]string
}{}
b, err := reflectbuild.NewBuilder(&x)
b, err := reflectbuild.NewBuilder("", &x)
require.NoError(t, err)
require.NoError(t, b.DigField("Field"))
@@ -159,7 +159,7 @@ func TestCursor(t *testing.T) {
x := struct {
Field string
}{}
b, err := reflectbuild.NewBuilder(&x)
b, err := reflectbuild.NewBuilder("", &x)
require.NoError(t, err)
assert.Equal(t, b.Cursor().Kind(), reflect.Struct)
require.NoError(t, b.DigField("Field"))
+1 -1
View File
@@ -8,7 +8,7 @@ import (
func Unmarshal(data []byte, v interface{}) error {
u := &unmarshaler{}
u.builder, u.err = reflectbuild.NewBuilder(v)
u.builder, u.err = reflectbuild.NewBuilder("toml", v)
if u.err == nil {
parseErr := parser{builder: u}.parse(data)
if parseErr != nil {