Skip to content

Commit d755197

Browse files
authored
Merge pull request Vinovest#14 from Vinovest/mj/optional-nested-structs
Optional nested structs
2 parents 0a8d738 + 1f123ff commit d755197

6 files changed

Lines changed: 239 additions & 75 deletions

File tree

README.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,9 @@ explains how to use `database/sql` along with sqlx.
2929

3030
## Changes compared to the original sqlx
3131

32+
* Better scanning in the case of outer joins. If a struct contains a nested
33+
struct pointer, it will no longer be a scan error.
34+
3235
* Made complex joins easier to scan by using the position of the field
3336
to help map duplicate column names into structs. See the [joins
3437
example](./examples/joins/main.go).

convert.go

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
package sqlx
2+
3+
import (
4+
_ "unsafe"
5+
)
6+
7+
//go:linkname convertAssign database/sql.convertAssign
8+
func convertAssign(dest, src interface{}) error

examples/generics/main.go

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,15 +15,17 @@ import (
1515
// docker run --name sqlxpg -p 5444:5432 -e POSTGRES_PASSWORD=password -d docker.io/postgres:17.4
1616

1717
const schema = `
18-
CREATE TABLE IF NOT EXISTS person (
18+
DROP TABLE IF EXISTS person;
19+
CREATE TABLE person (
1920
id SERIAL PRIMARY KEY,
2021
first_name text,
2122
last_name text,
2223
email text
2324
);
2425
TRUNCATE TABLE person;
2526
26-
CREATE TABLE IF NOT EXISTS place (
27+
DROP TABLE IF EXISTS place;
28+
CREATE TABLE place (
2729
country text,
2830
city text NULL,
2931
telcode integer

reflectx/reflect.go

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -235,8 +235,7 @@ func FieldByIndexes(v reflect.Value, indexes []int) reflect.Value {
235235
v = reflect.Indirect(v).Field(i)
236236
// if this is a pointer and it's nil, allocate a new value and set it
237237
if v.Kind() == reflect.Ptr && v.IsNil() {
238-
alloc := reflect.New(Deref(v.Type()))
239-
v.Set(alloc)
238+
v.Set(reflect.New(v.Type().Elem()))
240239
}
241240
if v.Kind() == reflect.Map && v.IsNil() {
242241
v.Set(reflect.MakeMap(v.Type()))

sqlx.go

Lines changed: 26 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -773,7 +773,7 @@ func (r *Rows) StructScan(dest interface{}) error {
773773
r.started = true
774774
}
775775

776-
err := fieldsByTraversal(v, r.fields, r.values, true)
776+
err := fieldsByTraversal(v, r.fields, r.values)
777777
if err != nil {
778778
return err
779779
}
@@ -990,7 +990,7 @@ func (r *Row) scanAny(dest interface{}, structOnly bool) error {
990990
}
991991
values := make([]interface{}, len(columns))
992992

993-
err = fieldsByTraversal(v, fields, values, true)
993+
err = fieldsByTraversal(v, fields, values)
994994
if err != nil {
995995
return err
996996
}
@@ -1165,7 +1165,7 @@ func scanAll(rows rowsi, dest interface{}, structOnly bool) error {
11651165
vp = reflect.New(base)
11661166
v = reflect.Indirect(vp)
11671167

1168-
err = fieldsByTraversal(v, fields, values, true)
1168+
err = fieldsByTraversal(v, fields, values)
11691169
if err != nil {
11701170
return err
11711171
}
@@ -1231,7 +1231,7 @@ func baseType(t reflect.Type, expected reflect.Kind) (reflect.Type, error) {
12311231
// when iterating over many rows. Empty traversals will get an interface pointer.
12321232
// Because of the necessity of requesting ptrs or values, it's considered a bit too
12331233
// specialized for inclusion in reflectx itself.
1234-
func fieldsByTraversal(v reflect.Value, traversals [][]int, values []interface{}, ptrs bool) error {
1234+
func fieldsByTraversal(v reflect.Value, traversals [][]int, values []interface{}) error {
12351235
v = reflect.Indirect(v)
12361236
if v.Kind() != reflect.Struct {
12371237
return errors.New("argument not a struct")
@@ -1240,23 +1240,37 @@ func fieldsByTraversal(v reflect.Value, traversals [][]int, values []interface{}
12401240
for i, traversal := range traversals {
12411241
if len(traversal) == 0 {
12421242
values[i] = new(interface{})
1243-
continue
1244-
}
1245-
f := reflectx.FieldByIndexes(v, traversal)
1246-
if ptrs {
1247-
values[i] = f.Addr().Interface()
1243+
} else if len(traversal) == 1 {
1244+
values[i] = reflectx.FieldByIndexes(v, traversal).Addr().Interface()
12481245
} else {
1249-
values[i] = f.Interface()
1246+
// reflectx.FieldByIndexes initializes pointer fields, including pointers to nested structs.
1247+
// Use optDest to delay it until the first non-NULL value is scanned into a field of a nested struct.
1248+
// That way we can support LEFT JOINs with optional nested structs.
1249+
values[i] = optDest(func() interface{} {
1250+
return reflectx.FieldByIndexes(v, traversal).Addr().Interface()
1251+
})
12501252
}
12511253
}
12521254
return nil
12531255
}
12541256

1255-
func missingFields(transversals [][]int) (field int, err error) {
1256-
for i, t := range transversals {
1257+
func missingFields(traversals [][]int) (field int, err error) {
1258+
for i, t := range traversals {
12571259
if len(t) == 0 {
12581260
return i, errors.New("missing field")
12591261
}
12601262
}
12611263
return 0, nil
12621264
}
1265+
1266+
// optDest will only forward the Scan to the nested value if
1267+
// the database value is not nil.
1268+
type optDest func() interface{}
1269+
1270+
// Scan implements sql.Scanner.
1271+
func (dest optDest) Scan(src interface{}) error {
1272+
if src == nil {
1273+
return nil
1274+
}
1275+
return convertAssign(dest(), src)
1276+
}

0 commit comments

Comments
 (0)