@@ -621,8 +621,7 @@ func (r *Rows) StructScan(dest interface{}) error {
621621 r .started = true
622622 }
623623
624- octx := newObjectContext ()
625- err := fieldsByTraversal (octx , v , r .fields , r .values )
624+ err := fieldsByTraversal (v , r .fields , r .values )
626625 if err != nil {
627626 return err
628627 }
@@ -782,9 +781,7 @@ func (r *Row) scanAny(dest interface{}, structOnly bool) error {
782781 }
783782 values := make ([]interface {}, len (columns ))
784783
785- octx := newObjectContext ()
786-
787- err = fieldsByTraversal (octx , v , fields , values )
784+ err = fieldsByTraversal (v , fields , values )
788785 if err != nil {
789786 return err
790787 }
@@ -951,14 +948,13 @@ func scanAll(rows rowsi, dest interface{}, structOnly bool) error {
951948 return fmt .Errorf ("missing destination name %s in %T" , columns [f ], dest )
952949 }
953950 values = make ([]interface {}, len (columns ))
954- octx := newObjectContext ()
955951
956952 for rows .Next () {
957953 // create a new struct type (which returns PtrTo) and indirect it
958954 vp = reflect .New (base )
959955 v = reflect .Indirect (vp )
960956
961- err = fieldsByTraversal (octx , v , fields , values )
957+ err = fieldsByTraversal (v , fields , values )
962958 if err != nil {
963959 return err
964960 }
@@ -1024,21 +1020,23 @@ func baseType(t reflect.Type, expected reflect.Kind) (reflect.Type, error) {
10241020// when iterating over many rows. Empty traversals will get an interface pointer.
10251021// Because of the necessity of requesting ptrs or values, it's considered a bit too
10261022// specialized for inclusion in reflectx itself.
1027- func fieldsByTraversal (octx * objectContext , v reflect.Value , traversals [][]int , values []interface {}) error {
1023+ func fieldsByTraversal (v reflect.Value , traversals [][]int , values []interface {}) error {
10281024 v = reflect .Indirect (v )
10291025 if v .Kind () != reflect .Struct {
10301026 return errors .New ("argument not a struct" )
10311027 }
10321028
1033- octx .NewRow (v )
1034-
10351029 for i , traversal := range traversals {
10361030 if len (traversal ) == 0 {
10371031 values [i ] = new (interface {})
1038- continue
1032+ } else if len (traversal ) == 1 {
1033+ values [i ] = reflectx .FieldByIndexes (v , traversal ).Addr ().Interface ()
1034+ } else {
1035+ traversal := traversal
1036+ values [i ] = optDest (func () interface {} {
1037+ return reflectx .FieldByIndexes (v , traversal ).Addr ().Interface ()
1038+ })
10391039 }
1040- f := octx .FieldForIndexes (traversal )
1041- values [i ] = f .Addr ().Interface ()
10421040 }
10431041 return nil
10441042}
@@ -1052,49 +1050,14 @@ func missingFields(traversals [][]int) (field int, err error) {
10521050 return 0 , nil
10531051}
10541052
1055- // objectContext provides a single layer to abstract away
1056- // nested struct scanning functionality
1057- type objectContext struct {
1058- value reflect.Value
1059- }
1060-
1061- func newObjectContext () * objectContext {
1062- return & objectContext {}
1063- }
1064-
1065- // NewRow updates the object reference.
1066- // This ensures all columns point to the same object
1067- func (o * objectContext ) NewRow (value reflect.Value ) {
1068- o .value = value
1069- }
1070-
1071- // FieldForIndexes returns the value for address. If the address is a nested struct,
1072- // a nestedFieldScanner is returned instead of the standard value reference
1073- func (o * objectContext ) FieldForIndexes (indexes []int ) reflect.Value {
1074- if len (indexes ) == 1 {
1075- return reflectx .FieldByIndexes (o .value , indexes )
1076- }
1077-
1078- obj := & nestedFieldScanner {
1079- parent : o ,
1080- indexes : indexes ,
1081- }
1082-
1083- return reflect .ValueOf (obj ).Elem ()
1084- }
1085-
1086- // nestedFieldScanner will only forward the Scan to the nested value if
1053+ // optDest will only forward the Scan to the nested value if
10871054// the database value is not nil.
1088- type nestedFieldScanner struct {
1089- parent * objectContext
1090- indexes []int
1091- }
1055+ type optDest func () interface {}
10921056
10931057// Scan implements sql.Scanner.
1094- func (o * nestedFieldScanner ) Scan (src interface {}) error {
1058+ func (dest optDest ) Scan (src interface {}) error {
10951059 if src == nil {
10961060 return nil
10971061 }
1098- dest := reflectx .FieldByIndexes (o .parent .value , o .indexes )
1099- return convertAssign (dest .Addr ().Interface (), src )
1062+ return convertAssign (dest (), src )
11001063}
0 commit comments