Skip to content

Commit 943df88

Browse files
committed
Use Tx for savepoint
Signed-off-by: Bryan Frimin <bryan@getprobo.com>
1 parent ca42457 commit 943df88

4 files changed

Lines changed: 89 additions & 79 deletions

File tree

CHANGELOG.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
77

88
## [Unreleased]
99

10+
### Breaking Changes
11+
12+
- **pg**: `Tx.Savepoint` callback now receives `pg.Tx` instead of `pg.Querier`, enabling nested savepoints.
13+
1014
### Added
1115

1216
- **worker**: Generic task polling worker with concurrent processing, observability, and graceful shutdown.

pg/client.go

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -328,8 +328,9 @@ func (c *Client) WithConn(
328328
// }
329329
//
330330
// // Savepoint failure does not roll back the DELETE above.
331-
// if err := tx.Savepoint(ctx, func(ctx context.Context, q pg.Querier) error {
332-
// _, err := q.Exec(ctx, "INSERT INTO audit_log (...) VALUES (...)")
331+
// // The callback receives a Tx, so savepoints can be nested.
332+
// if err := tx.Savepoint(ctx, func(ctx context.Context, inner pg.Tx) error {
333+
// _, err := inner.Exec(ctx, "INSERT INTO audit_log (...) VALUES (...)")
333334
// return err
334335
// }); err != nil {
335336
// log.Warn("audit failed, continuing", "err", err)

pg/client_test.go

Lines changed: 73 additions & 73 deletions
Original file line numberDiff line numberDiff line change
@@ -411,14 +411,14 @@ func TestWithTx(t *testing.T) {
411411
return err
412412
}
413413

414-
return tx.Savepoint(
415-
ctx,
416-
func(ctx context.Context, conn pg.Querier) error {
417-
_, err := conn.Exec(ctx,
418-
"INSERT INTO test_tx_nested (name) VALUES ($1)", "inner")
419-
return err
420-
},
421-
)
414+
return tx.Savepoint(
415+
ctx,
416+
func(ctx context.Context, inner pg.Tx) error {
417+
_, err := inner.Exec(ctx,
418+
"INSERT INTO test_tx_nested (name) VALUES ($1)", "inner")
419+
return err
420+
},
421+
)
422422
},
423423
)
424424
require.NoError(t, err)
@@ -451,16 +451,16 @@ func TestWithTx(t *testing.T) {
451451
return err
452452
}
453453

454-
_ = tx.Savepoint(
455-
ctx,
456-
func(ctx context.Context, conn pg.Querier) error {
457-
if _, err := conn.Exec(ctx,
458-
"INSERT INTO test_tx_savepoint (name) VALUES ($1)", "inner_fail"); err != nil {
459-
return err
460-
}
461-
return errors.New("inner error")
462-
},
463-
)
454+
_ = tx.Savepoint(
455+
ctx,
456+
func(ctx context.Context, inner pg.Tx) error {
457+
if _, err := inner.Exec(ctx,
458+
"INSERT INTO test_tx_savepoint (name) VALUES ($1)", "inner_fail"); err != nil {
459+
return err
460+
}
461+
return errors.New("inner error")
462+
},
463+
)
464464

465465
return nil
466466
},
@@ -501,25 +501,25 @@ func TestWithTx(t *testing.T) {
501501
err := client.WithTx(
502502
ctx,
503503
func(ctx context.Context, tx pg.Tx) error {
504-
if err := tx.Savepoint(
505-
ctx,
506-
func(ctx context.Context, q pg.Querier) error {
507-
_, err := q.Exec(ctx,
508-
"INSERT INTO test_tx_multi_sp (name) VALUES ($1)", "sp1")
509-
return err
510-
},
511-
); err != nil {
504+
if err := tx.Savepoint(
505+
ctx,
506+
func(ctx context.Context, inner pg.Tx) error {
507+
_, err := inner.Exec(ctx,
508+
"INSERT INTO test_tx_multi_sp (name) VALUES ($1)", "sp1")
512509
return err
513-
}
510+
},
511+
); err != nil {
512+
return err
513+
}
514514

515-
return tx.Savepoint(
516-
ctx,
517-
func(ctx context.Context, q pg.Querier) error {
518-
_, err := q.Exec(ctx,
519-
"INSERT INTO test_tx_multi_sp (name) VALUES ($1)", "sp2")
520-
return err
521-
},
522-
)
515+
return tx.Savepoint(
516+
ctx,
517+
func(ctx context.Context, inner pg.Tx) error {
518+
_, err := inner.Exec(ctx,
519+
"INSERT INTO test_tx_multi_sp (name) VALUES ($1)", "sp2")
520+
return err
521+
},
522+
)
523523
},
524524
)
525525
require.NoError(t, err)
@@ -547,25 +547,25 @@ func TestWithTx(t *testing.T) {
547547
err := client.WithTx(
548548
ctx,
549549
func(ctx context.Context, tx pg.Tx) error {
550-
if err := tx.Savepoint(ctx, func(ctx context.Context, q pg.Querier) error {
551-
_, err := q.Exec(ctx,
552-
"INSERT INTO test_tx_sp_mixed (name) VALUES ($1)", "kept")
553-
return err
554-
}); err != nil {
555-
return err
556-
}
550+
if err := tx.Savepoint(ctx, func(ctx context.Context, inner pg.Tx) error {
551+
_, err := inner.Exec(ctx,
552+
"INSERT INTO test_tx_sp_mixed (name) VALUES ($1)", "kept")
553+
return err
554+
}); err != nil {
555+
return err
556+
}
557557

558-
_ = tx.Savepoint(
559-
ctx,
560-
func(ctx context.Context, q pg.Querier) error {
561-
_, err := q.Exec(ctx,
562-
"INSERT INTO test_tx_sp_mixed (name) VALUES ($1)", "discarded")
563-
if err != nil {
564-
return err
565-
}
566-
return errors.New("second savepoint fails")
567-
},
568-
)
558+
_ = tx.Savepoint(
559+
ctx,
560+
func(ctx context.Context, inner pg.Tx) error {
561+
_, err := inner.Exec(ctx,
562+
"INSERT INTO test_tx_sp_mixed (name) VALUES ($1)", "discarded")
563+
if err != nil {
564+
return err
565+
}
566+
return errors.New("second savepoint fails")
567+
},
568+
)
569569

570570
return nil
571571
},
@@ -613,12 +613,12 @@ func TestWithTx(t *testing.T) {
613613
return err
614614
}
615615

616-
return tx.Savepoint(
617-
ctx,
618-
func(ctx context.Context, q pg.Querier) error {
619-
return errors.New("savepoint failed")
620-
},
621-
)
616+
return tx.Savepoint(
617+
ctx,
618+
func(ctx context.Context, inner pg.Tx) error {
619+
return errors.New("savepoint failed")
620+
},
621+
)
622622
},
623623
)
624624
require.Error(t, err)
@@ -883,14 +883,14 @@ func TestWithTx_Tracing(t *testing.T) {
883883
if err != nil {
884884
return err
885885
}
886-
return tx.Savepoint(
887-
ctx,
888-
func(ctx context.Context, conn pg.Querier) error {
889-
_, err := conn.Exec(ctx,
890-
"INSERT INTO test_tx_trace_sp (name) VALUES ($1)", "inner")
891-
return err
892-
},
893-
)
886+
return tx.Savepoint(
887+
ctx,
888+
func(ctx context.Context, inner pg.Tx) error {
889+
_, err := inner.Exec(ctx,
890+
"INSERT INTO test_tx_trace_sp (name) VALUES ($1)", "inner")
891+
return err
892+
},
893+
)
894894
},
895895
)
896896
require.NoError(t, err)
@@ -917,12 +917,12 @@ func TestWithTx_Tracing(t *testing.T) {
917917
if err != nil {
918918
return err
919919
}
920-
_ = tx.Savepoint(
921-
ctx,
922-
func(ctx context.Context, conn pg.Querier) error {
923-
return errors.New("inner savepoint error")
924-
},
925-
)
920+
_ = tx.Savepoint(
921+
ctx,
922+
func(ctx context.Context, inner pg.Tx) error {
923+
return errors.New("inner savepoint error")
924+
},
925+
)
926926
return nil
927927
},
928928
)

pg/conn.go

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -39,11 +39,12 @@ type (
3939
}
4040

4141
// Tx represents an active database transaction. It extends
42-
// Querier with the ability to create savepoints.
42+
// Querier with the ability to create savepoints. The callback
43+
// receives a Tx so savepoints can be nested arbitrarily.
4344
Tx interface {
4445
Querier
4546

46-
Savepoint(context.Context, ExecFunc[Querier]) error
47+
Savepoint(context.Context, ExecFunc[Tx]) error
4748
}
4849

4950
pgxTx struct {
@@ -75,7 +76,9 @@ func (t *pgxTx) SendBatch(ctx context.Context, b *pgx.Batch) pgx.BatchResults {
7576
// Savepoint executes fn within a savepoint. If fn returns an error,
7677
// the savepoint is rolled back; otherwise, it is released. The outer
7778
// transaction remains active regardless of the savepoint outcome.
78-
func (t *pgxTx) Savepoint(ctx context.Context, fn ExecFunc[Querier]) error {
79+
//
80+
// The callback receives a Tx, allowing nested savepoints.
81+
func (t *pgxTx) Savepoint(ctx context.Context, fn ExecFunc[Tx]) error {
7982
var (
8083
rootSpan = trace.SpanFromContext(ctx)
8184
span trace.Span
@@ -100,7 +103,9 @@ func (t *pgxTx) Savepoint(ctx context.Context, fn ExecFunc[Querier]) error {
100103
return err
101104
}
102105

103-
if err := fn(ctx, sp); err != nil {
106+
spTx := &pgxTx{inner: sp, tracer: t.tracer}
107+
108+
if err := fn(ctx, spTx); err != nil {
104109
if err2 := sp.Rollback(ctx); err2 != nil {
105110
err = errors.Join(
106111
err,

0 commit comments

Comments
 (0)