Skip to content

Commit be35d15

Browse files
committed
feat: implement NamedQueryContext also on *sqlx.Tx (#26)
1 parent febe355 commit be35d15

3 files changed

Lines changed: 46 additions & 27 deletions

File tree

sqlx.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -266,6 +266,7 @@ type Queryable interface {
266266
NamedExecContext(context.Context, string, any) (sql.Result, error)
267267
MustExec(string, ...any) sql.Result
268268
NamedQuery(string, any) (*Rows, error)
269+
NamedQueryContext(context.Context, string, any) (*Rows, error)
269270
}
270271

271272
var _ Queryable = (*DB)(nil)

sqlx_context.go

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -378,6 +378,12 @@ func (tx *Tx) QueryRowxContext(ctx context.Context, query string, args ...any) *
378378
return &Row{rows: rows, err: err, options: tx.options, Mapper: tx.Mapper}
379379
}
380380

381+
// NamedQueryContext using this Tx.
382+
// Any named placeholder parameters are replaced with fields from arg.
383+
func (tx *Tx) NamedQueryContext(ctx context.Context, query string, arg any) (*Rows, error) {
384+
return NamedQueryContext(ctx, tx, query, arg)
385+
}
386+
381387
// NamedExecContext using this Tx.
382388
// Any named placeholder parameters are replaced with fields from arg.
383389
func (tx *Tx) NamedExecContext(ctx context.Context, query string, arg any) (sql.Result, error) {

sqlx_context_test.go

Lines changed: 39 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -487,7 +487,7 @@ func TestNamedQueryContext(t *testing.T) {
487487
`,
488488
}
489489

490-
RunWithSchemaContext(context.Background(), schema, t, func(ctx context.Context, db *DB, t *testing.T) {
490+
testWithQueryable := func(ctx context.Context, queryable Queryable, mapper **reflectx.Mapper, t *testing.T) {
491491
type Person struct {
492492
FirstName sql.NullString `db:"first_name"`
493493
LastName sql.NullString `db:"last_name"`
@@ -500,12 +500,12 @@ func TestNamedQueryContext(t *testing.T) {
500500
Email: sql.NullString{String: "ben@doe.com", Valid: true},
501501
}
502502

503-
_, err := db.NamedExecContext(ctx, `INSERT INTO person (first_name, last_name, email) VALUES (:first_name, :last_name, :email)`, p)
503+
_, err := queryable.NamedExecContext(ctx, `INSERT INTO person (first_name, last_name, email) VALUES (:first_name, :last_name, :email)`, p)
504504
require.NoError(t, err)
505505

506506
{
507507
p2 := &Person{}
508-
rows, err := db.NamedQueryContext(ctx, "SELECT * FROM person WHERE first_name=:first_name", p)
508+
rows, err := queryable.NamedQueryContext(ctx, "SELECT * FROM person WHERE first_name=:first_name", p)
509509
if err != nil {
510510
log.Fatal(err)
511511
}
@@ -527,7 +527,7 @@ func TestNamedQueryContext(t *testing.T) {
527527
// these are tests for #73; they verify that named queries work if you've
528528
// changed the db mapper. This code checks both NamedQuery "ad-hoc" style
529529
// queries and NamedStmt queries, which use different code paths internally.
530-
old := (*db).Mapper
530+
old := *mapper
531531

532532
type JSONPerson struct {
533533
FirstName sql.NullString `json:"FIRST"`
@@ -541,21 +541,21 @@ func TestNamedQueryContext(t *testing.T) {
541541
Email: sql.NullString{String: "ben@smith.com", Valid: true},
542542
}
543543

544-
db.Mapper = reflectx.NewMapperFunc("json", strings.ToUpper)
544+
*mapper = reflectx.NewMapperFunc("json", strings.ToUpper)
545545

546546
// prepare queries for case sensitivity to test our ToUpper function.
547547
// postgres and sqlite accept "", but mysql uses ``; since Go's multi-line
548548
// strings are `` we use "" by default and swap out for MySQL
549-
pdb := func(s string, db *DB) string {
550-
if db.DriverName() == "mysql" {
549+
pdb := func(s string, queryable Queryable) string {
550+
if queryable.DriverName() == "mysql" {
551551
return strings.Replace(s, `"`, "`", -1)
552552
}
553553
return s
554554
}
555555

556-
_, err = db.NamedExecContext(ctx, pdb(`INSERT INTO jsperson ("FIRST", last_name, "EMAIL") VALUES (:FIRST, :last_name, :EMAIL)`, db), jp)
556+
_, err = queryable.NamedExecContext(ctx, pdb(`INSERT INTO jsperson ("FIRST", last_name, "EMAIL") VALUES (:FIRST, :last_name, :EMAIL)`, queryable), jp)
557557
if err != nil {
558-
t.Fatal(err, db.DriverName())
558+
t.Fatal(err, queryable.DriverName())
559559
}
560560

561561
// Checks that a person pulled out of the db matches the one we put in
@@ -567,24 +567,24 @@ func TestNamedQueryContext(t *testing.T) {
567567
t.Error(err)
568568
}
569569
if jp.FirstName.String != "ben" {
570-
t.Errorf("Expected first name of `ben`, got `%s` (%s) ", jp.FirstName.String, db.DriverName())
570+
t.Errorf("Expected first name of `ben`, got `%s` (%s) ", jp.FirstName.String, queryable.DriverName())
571571
}
572572
if jp.LastName.String != "smith" {
573-
t.Errorf("Expected LastName of `smith`, got `%s` (%s)", jp.LastName.String, db.DriverName())
573+
t.Errorf("Expected LastName of `smith`, got `%s` (%s)", jp.LastName.String, queryable.DriverName())
574574
}
575575
if jp.Email.String != "ben@smith.com" {
576-
t.Errorf("Expected first name of `doe`, got `%s` (%s)", jp.Email.String, db.DriverName())
576+
t.Errorf("Expected first name of `doe`, got `%s` (%s)", jp.Email.String, queryable.DriverName())
577577
}
578578
}
579579
}
580580

581-
ns, err := db.PrepareNamed(pdb(`
581+
ns, err := queryable.PrepareNamed(pdb(`
582582
SELECT * FROM jsperson
583583
WHERE
584584
"FIRST"=:FIRST AND
585585
last_name=:last_name AND
586586
"EMAIL"=:EMAIL
587-
`, db))
587+
`, queryable))
588588
require.NoError(t, err)
589589

590590
rows, err := ns.QueryxContext(ctx, jp)
@@ -595,19 +595,19 @@ func TestNamedQueryContext(t *testing.T) {
595595

596596
// Check exactly the same thing, but with db.NamedQuery, which does not go
597597
// through the PrepareNamed/NamedStmt path.
598-
rows, err = db.NamedQueryContext(ctx, pdb(`
598+
rows, err = queryable.NamedQueryContext(ctx, pdb(`
599599
SELECT * FROM jsperson
600600
WHERE
601601
"FIRST"=:FIRST AND
602602
last_name=:last_name AND
603603
"EMAIL"=:EMAIL
604-
`, db), jp)
604+
`, queryable), jp)
605605
require.NoError(t, err)
606606

607607
check(t, rows)
608608
rows.Close()
609609

610-
db.Mapper = old
610+
*mapper = old
611611

612612
// Test nested structs
613613
type Place struct {
@@ -631,17 +631,17 @@ func TestNamedQueryContext(t *testing.T) {
631631
Email: sql.NullString{String: "ben@doe.com", Valid: true},
632632
}
633633

634-
_, err = db.NamedExecContext(ctx, `INSERT INTO place (id, name) VALUES (1, :name)`, pl)
634+
_, err = queryable.NamedExecContext(ctx, `INSERT INTO place (id, name) VALUES (1, :name)`, pl)
635635
require.NoError(t, err)
636636

637637
id := 1
638638
benDoe.Place.ID = id
639639

640-
_, err = db.NamedExecContext(ctx, `INSERT INTO placeperson (first_name, last_name, email, place_id) VALUES (:first_name, :last_name, :email, :place.id)`, benDoe)
640+
_, err = queryable.NamedExecContext(ctx, `INSERT INTO placeperson (first_name, last_name, email, place_id) VALUES (:first_name, :last_name, :email, :place.id)`, benDoe)
641641
require.NoError(t, err)
642642

643643
{
644-
rows, err = db.NamedQueryContext(ctx, `
644+
rows, err = queryable.NamedQueryContext(ctx, `
645645
SELECT
646646
first_name,
647647
last_name,
@@ -679,17 +679,17 @@ func TestNamedQueryContext(t *testing.T) {
679679
Name: sql.NullString{String: "the-house", Valid: true},
680680
}
681681

682-
_, err = db.NamedExecContext(ctx, `INSERT INTO place (id, name) VALUES (2, :name)`, pl)
682+
_, err = queryable.NamedExecContext(ctx, `INSERT INTO place (id, name) VALUES (2, :name)`, pl)
683683
require.NoError(t, err)
684684

685685
id = 2
686686
benDoe.Place.ID = id
687687

688-
_, err = db.NamedExecContext(ctx, `INSERT INTO placeperson (first_name, last_name, email, place_id) VALUES (:first_name, :last_name, :email, :place.id)`, benDoe)
688+
_, err = queryable.NamedExecContext(ctx, `INSERT INTO placeperson (first_name, last_name, email, place_id) VALUES (:first_name, :last_name, :email, :place.id)`, benDoe)
689689
require.NoError(t, err)
690690

691691
{
692-
rows, err = db.NamedQueryContext(ctx, `
692+
rows, err = queryable.NamedQueryContext(ctx, `
693693
SELECT
694694
place.id,
695695
place.name,
@@ -711,7 +711,7 @@ func TestNamedQueryContext(t *testing.T) {
711711
}
712712

713713
{
714-
rows, err = db.NamedQueryContext(ctx, `
714+
rows, err = queryable.NamedQueryContext(ctx, `
715715
SELECT
716716
place.id,
717717
place.name,
@@ -752,7 +752,7 @@ func TestNamedQueryContext(t *testing.T) {
752752
}
753753

754754
{
755-
rows, err = db.NamedQueryContext(ctx, `
755+
rows, err = queryable.NamedQueryContext(ctx, `
756756
SELECT
757757
place.id,
758758
place.name,
@@ -786,10 +786,10 @@ func TestNamedQueryContext(t *testing.T) {
786786
Notes: "this is a test person",
787787
}
788788

789-
_, err = db.NamedExecContext(ctx, `INSERT INTO persondetails (email, notes) VALUES (:email, :notes)`, details)
789+
_, err = queryable.NamedExecContext(ctx, `INSERT INTO persondetails (email, notes) VALUES (:email, :notes)`, details)
790790
require.NoError(t, err)
791791

792-
rows, err = db.NamedQueryContext(ctx, `
792+
rows, err = queryable.NamedQueryContext(ctx, `
793793
SELECT
794794
place.id,
795795
place.name,
@@ -817,6 +817,18 @@ func TestNamedQueryContext(t *testing.T) {
817817
assert.Equal(t, details.Notes, pp6.Owner.Details.Notes)
818818
}
819819
}
820+
}
821+
822+
RunWithSchemaContext(context.Background(), schema, t, func(ctx context.Context, db *DB, t *testing.T) {
823+
testWithQueryable(ctx, db, &db.Mapper, t)
824+
})
825+
826+
RunWithSchemaContext(context.Background(), schema, t, func(ctx context.Context, db *DB, t *testing.T) {
827+
tx := db.MustBegin()
828+
testWithQueryable(ctx, tx, &tx.Mapper, t)
829+
if err := tx.Rollback(); err != nil {
830+
t.Error(err)
831+
}
820832
})
821833
}
822834

0 commit comments

Comments
 (0)