Reflect write to embeded structs

This commit is contained in:
Thomas Pelletier
2021-02-10 18:34:54 -05:00
parent e2a07a3b92
commit 6e79ce63c2
3 changed files with 106 additions and 15 deletions
+93 -2
View File
@@ -8,6 +8,13 @@ import (
"strings" "strings"
) )
// fieldGetters are functions that given a struct return a specific field
// (likely captured in their scope)
type fieldGetter func(s reflect.Value) reflect.Value
// collection of fieldGetters for a given struct type
type structFieldGetters map[string]fieldGetter
// Builder wraps a value and provides method to modify its structure. // Builder wraps a value and provides method to modify its structure.
// It is a stateful object that keeps a cursor of what part of the object is // It is a stateful object that keeps a cursor of what part of the object is
// being modified. // being modified.
@@ -17,11 +24,89 @@ type Builder struct {
// Root is always a pointer to a non-nil value. // Root is always a pointer to a non-nil value.
// Cursor is the top of the stack. // Cursor is the top of the stack.
stack []reflect.Value stack []reflect.Value
// Struct field tag to use to retrieve name.
nameTag string
// Cache of functions to access specific fields.
fieldGettersCache map[reflect.Type]structFieldGetters
}
func copyAndAppend(s []int, i int) []int {
ns := make([]int, len(s)+1)
copy(ns, s)
ns[len(ns)-1] = i
return ns
}
func (b *Builder) getOrGenerateFieldGettersRecursive(m structFieldGetters, idx []int, s reflect.Type) {
for i := 0; i < s.NumField(); i++ {
f := s.Field(i)
if f.PkgPath != "" {
// only consider exported fields
continue
}
// TODO: handle embedded structs
if f.Anonymous {
b.getOrGenerateFieldGettersRecursive(m, copyAndAppend(idx, i), f.Type)
} else {
fieldName, ok := f.Tag.Lookup(b.nameTag)
if !ok {
fieldName = f.Name
}
if len(idx) == 0 {
m[fieldName] = makeFieldGetterByIndex(i)
} else {
m[fieldName] = makeFieldGetterByIndexes(copyAndAppend(idx, i))
}
}
}
if b.fieldGettersCache == nil {
b.fieldGettersCache = make(map[reflect.Type]structFieldGetters, 1)
}
b.fieldGettersCache[s] = m
}
func (b *Builder) getOrGenerateFieldGetters(s reflect.Type) structFieldGetters {
if s.Kind() != reflect.Struct {
panic("generateFieldGetters can only be called on a struct")
}
m, ok := b.fieldGettersCache[s]
if ok {
return m
}
m = make(structFieldGetters, s.NumField())
b.getOrGenerateFieldGettersRecursive(m, nil, s)
b.fieldGettersCache[s] = m
return m
}
func makeFieldGetterByIndex(idx int) fieldGetter {
return func(s reflect.Value) reflect.Value {
return s.Field(idx)
}
}
func makeFieldGetterByIndexes(idx []int) fieldGetter {
return func(s reflect.Value) reflect.Value {
return s.FieldByIndex(idx)
}
}
func (b *Builder) fieldGetter(t reflect.Type, s string) (fieldGetter, error) {
m := b.getOrGenerateFieldGetters(t)
g, ok := m[s]
if !ok {
return nil, fmt.Errorf("field '%s' not accessible on '%s'", s, t)
}
return g, nil
} }
// NewBuilder creates a Builder to construct v. // NewBuilder creates a Builder to construct v.
// If v is nil or not a pointer, an error will be returned. // If v is nil or not a pointer, an error will be returned.
func NewBuilder(v interface{}) (Builder, error) { func NewBuilder(tag string, v interface{}) (Builder, error) {
if v == nil { if v == nil {
return Builder{}, fmt.Errorf("cannot build a nil value") return Builder{}, fmt.Errorf("cannot build a nil value")
} }
@@ -34,6 +119,7 @@ func NewBuilder(v interface{}) (Builder, error) {
return Builder{ return Builder{
root: rv.Elem(), root: rv.Elem(),
stack: []reflect.Value{rv.Elem()}, stack: []reflect.Value{rv.Elem()},
nameTag: tag,
}, nil }, nil
} }
@@ -90,7 +176,12 @@ func (b *Builder) DigField(s string) error {
return err return err
} }
f := t.FieldByName(s) g, err := b.fieldGetter(t.Type(), s)
if err != nil {
return err
}
f := g(t)
if !f.IsValid() { if !f.IsValid() {
return FieldNotFoundError{FieldName: s, Struct: t} return FieldNotFoundError{FieldName: s, Struct: t}
} }
+10 -10
View File
@@ -11,17 +11,17 @@ import (
func TestNewBuilderSuccess(t *testing.T) { func TestNewBuilderSuccess(t *testing.T) {
x := struct{}{} x := struct{}{}
_, err := reflectbuild.NewBuilder(&x) _, err := reflectbuild.NewBuilder("", &x)
assert.NoError(t, err) assert.NoError(t, err)
} }
func TestNewBuilderNil(t *testing.T) { func TestNewBuilderNil(t *testing.T) {
_, err := reflectbuild.NewBuilder(nil) _, err := reflectbuild.NewBuilder("", nil)
assert.Error(t, err) assert.Error(t, err)
} }
func TestNewBuilderNonPtr(t *testing.T) { func TestNewBuilderNonPtr(t *testing.T) {
_, err := reflectbuild.NewBuilder(struct{}{}) _, err := reflectbuild.NewBuilder("", struct{}{})
assert.Error(t, err) assert.Error(t, err)
} }
@@ -29,7 +29,7 @@ func TestDigField(t *testing.T) {
x := struct { x := struct {
Field string Field string
}{} }{}
b, err := reflectbuild.NewBuilder(&x) b, err := reflectbuild.NewBuilder("", &x)
require.NoError(t, err) require.NoError(t, err)
assert.Error(t, b.DigField("oops")) assert.Error(t, b.DigField("oops"))
assert.NoError(t, b.DigField("Field")) assert.NoError(t, b.DigField("Field"))
@@ -41,7 +41,7 @@ func TestBack(t *testing.T) {
A string A string
B string B string
}{} }{}
b, err := reflectbuild.NewBuilder(&x) b, err := reflectbuild.NewBuilder("", &x)
require.NoError(t, err) require.NoError(t, err)
b.Save() b.Save()
assert.NoError(t, b.DigField("A")) assert.NoError(t, b.DigField("A"))
@@ -63,7 +63,7 @@ func TestReset(t *testing.T) {
A []string A []string
B string B string
}{} }{}
b, err := reflectbuild.NewBuilder(&x) b, err := reflectbuild.NewBuilder("", &x)
require.NoError(t, err) require.NoError(t, err)
require.NoError(t, b.DigField("A")) require.NoError(t, b.DigField("A"))
require.NoError(t, b.SliceNewElem()) require.NoError(t, b.SliceNewElem())
@@ -80,7 +80,7 @@ func TestSetString(t *testing.T) {
x := struct { x := struct {
Field string Field string
}{} }{}
b, err := reflectbuild.NewBuilder(&x) b, err := reflectbuild.NewBuilder("", &x)
require.NoError(t, err) require.NoError(t, err)
assert.Error(t, b.SetString("oops")) assert.Error(t, b.SetString("oops"))
require.NoError(t, b.DigField("Field")) require.NoError(t, b.DigField("Field"))
@@ -92,7 +92,7 @@ func TestSliceNewElem(t *testing.T) {
x := struct { x := struct {
Field []string Field []string
}{} }{}
b, err := reflectbuild.NewBuilder(&x) b, err := reflectbuild.NewBuilder("", &x)
require.NoError(t, err) require.NoError(t, err)
require.NoError(t, b.DigField("Field")) require.NoError(t, b.DigField("Field"))
b.Save() b.Save()
@@ -112,7 +112,7 @@ func TestSliceNewElemNested(t *testing.T) {
x := struct { x := struct {
Field [][]string Field [][]string
}{} }{}
b, err := reflectbuild.NewBuilder(&x) b, err := reflectbuild.NewBuilder("", &x)
require.NoError(t, err) require.NoError(t, err)
require.NoError(t, b.DigField("Field")) require.NoError(t, b.DigField("Field"))
@@ -159,7 +159,7 @@ func TestCursor(t *testing.T) {
x := struct { x := struct {
Field string Field string
}{} }{}
b, err := reflectbuild.NewBuilder(&x) b, err := reflectbuild.NewBuilder("", &x)
require.NoError(t, err) require.NoError(t, err)
assert.Equal(t, b.Cursor().Kind(), reflect.Struct) assert.Equal(t, b.Cursor().Kind(), reflect.Struct)
require.NoError(t, b.DigField("Field")) require.NoError(t, b.DigField("Field"))
+1 -1
View File
@@ -8,7 +8,7 @@ import (
func Unmarshal(data []byte, v interface{}) error { func Unmarshal(data []byte, v interface{}) error {
u := &unmarshaler{} u := &unmarshaler{}
u.builder, u.err = reflectbuild.NewBuilder(v) u.builder, u.err = reflectbuild.NewBuilder("toml", v)
if u.err == nil { if u.err == nil {
parseErr := parser{builder: u}.parse(data) parseErr := parser{builder: u}.parse(data)
if parseErr != nil { if parseErr != nil {