Skip to content

Commit 9d6eb4a

Browse files
committed
Add NoRollback error type
Signed-off-by: Bryan Frimin <bryan@getprobo.com>
1 parent 8927359 commit 9d6eb4a

3 files changed

Lines changed: 235 additions & 0 deletions

File tree

pg/client.go

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -383,6 +383,21 @@ func (c *Client) WithTx(
383383
tx := &pgxTx{inner: innerTx, tracer: c.tracer}
384384

385385
if err := exec(ctx, tx); err != nil {
386+
if skipErr, ok := errors.AsType[*NoRollbackError](err); ok {
387+
if err2 := innerTx.Commit(ctx); err2 != nil {
388+
err = errors.Join(
389+
err,
390+
fmt.Errorf("cannot commit transaction: %w", err2),
391+
)
392+
}
393+
394+
if span != nil {
395+
recordError(span, err)
396+
}
397+
398+
return skipErr.Err
399+
}
400+
386401
if err2 := innerTx.Rollback(ctx); err2 != nil {
387402
err = errors.Join(
388403
err,

pg/client_test.go

Lines changed: 192 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ import (
77
"crypto/rand"
88
"crypto/x509"
99
"errors"
10+
"fmt"
1011
"io"
1112
"math/big"
1213
"testing"
@@ -639,6 +640,197 @@ func TestWithTx(t *testing.T) {
639640
)
640641
}
641642

643+
func TestNoRollback(t *testing.T) {
644+
t.Run(
645+
"nil returns nil",
646+
func(t *testing.T) {
647+
assert.Nil(t, pg.NoRollback(nil))
648+
},
649+
)
650+
651+
t.Run(
652+
"wraps error in NoRollbackError",
653+
func(t *testing.T) {
654+
sentinel := errors.New("inner")
655+
err := pg.NoRollback(sentinel)
656+
657+
var nrErr *pg.NoRollbackError
658+
require.ErrorAs(t, err, &nrErr)
659+
assert.Equal(t, sentinel, nrErr.Err)
660+
},
661+
)
662+
663+
t.Run(
664+
"Error delegates to inner",
665+
func(t *testing.T) {
666+
err := pg.NoRollback(errors.New("boom"))
667+
assert.Equal(t, "boom", err.Error())
668+
},
669+
)
670+
671+
t.Run(
672+
"Unwrap returns inner",
673+
func(t *testing.T) {
674+
sentinel := errors.New("inner")
675+
err := pg.NoRollback(sentinel)
676+
assert.ErrorIs(t, err, sentinel)
677+
},
678+
)
679+
680+
t.Run(
681+
"detectable through additional wrapping",
682+
func(t *testing.T) {
683+
sentinel := errors.New("root cause")
684+
wrapped := fmt.Errorf("context: %w", pg.NoRollback(sentinel))
685+
686+
var nrErr *pg.NoRollbackError
687+
require.ErrorAs(t, wrapped, &nrErr)
688+
assert.ErrorIs(t, wrapped, sentinel)
689+
},
690+
)
691+
}
692+
693+
func TestWithTx_NoRollback(t *testing.T) {
694+
client := newTestClient(t)
695+
ctx := context.Background()
696+
697+
setup := func(t *testing.T, table string) {
698+
t.Helper()
699+
err := client.WithConn(
700+
ctx,
701+
func(ctx context.Context, conn pg.Querier) error {
702+
_, err := conn.Exec(ctx, "DROP TABLE IF EXISTS "+table)
703+
if err != nil {
704+
return err
705+
}
706+
_, err = conn.Exec(ctx,
707+
"CREATE TABLE "+table+" (id serial PRIMARY KEY, name text NOT NULL)")
708+
return err
709+
},
710+
)
711+
require.NoError(t, err)
712+
713+
t.Cleanup(func() {
714+
_ = client.WithConn(
715+
context.Background(),
716+
func(ctx context.Context, conn pg.Querier) error {
717+
_, err := conn.Exec(ctx, "DROP TABLE IF EXISTS "+table)
718+
return err
719+
},
720+
)
721+
})
722+
}
723+
724+
t.Run(
725+
"commits and returns inner error",
726+
func(t *testing.T) {
727+
setup(t, "test_tx_norollback")
728+
729+
sentinel := errors.New("soft failure")
730+
err := client.WithTx(
731+
ctx,
732+
func(ctx context.Context, tx pg.Tx) error {
733+
_, err := tx.Exec(ctx,
734+
"INSERT INTO test_tx_norollback (name) VALUES ($1)", "committed")
735+
if err != nil {
736+
return err
737+
}
738+
return pg.NoRollback(sentinel)
739+
},
740+
)
741+
require.ErrorIs(t, err, sentinel)
742+
743+
var nrErr *pg.NoRollbackError
744+
assert.False(t, errors.As(err, &nrErr),
745+
"returned error must not be wrapped in NoRollbackError")
746+
747+
err = client.WithConn(
748+
ctx,
749+
func(ctx context.Context, conn pg.Querier) error {
750+
var count int
751+
err := conn.QueryRow(ctx,
752+
"SELECT count(*) FROM test_tx_norollback WHERE name = $1",
753+
"committed").Scan(&count)
754+
require.NoError(t, err)
755+
assert.Equal(t, 1, count)
756+
return nil
757+
},
758+
)
759+
require.NoError(t, err)
760+
},
761+
)
762+
763+
t.Run(
764+
"works through additional wrapping",
765+
func(t *testing.T) {
766+
setup(t, "test_tx_norollback_wrap")
767+
768+
sentinel := errors.New("root cause")
769+
err := client.WithTx(
770+
ctx,
771+
func(ctx context.Context, tx pg.Tx) error {
772+
_, err := tx.Exec(ctx,
773+
"INSERT INTO test_tx_norollback_wrap (name) VALUES ($1)", "kept")
774+
if err != nil {
775+
return err
776+
}
777+
return fmt.Errorf("extra context: %w", pg.NoRollback(sentinel))
778+
},
779+
)
780+
require.Error(t, err)
781+
782+
err = client.WithConn(
783+
ctx,
784+
func(ctx context.Context, conn pg.Querier) error {
785+
var count int
786+
err := conn.QueryRow(ctx,
787+
"SELECT count(*) FROM test_tx_norollback_wrap WHERE name = $1",
788+
"kept").Scan(&count)
789+
require.NoError(t, err)
790+
assert.Equal(t, 1, count)
791+
return nil
792+
},
793+
)
794+
require.NoError(t, err)
795+
},
796+
)
797+
798+
t.Run(
799+
"normal error still rolls back",
800+
func(t *testing.T) {
801+
setup(t, "test_tx_norollback_control")
802+
803+
sentinel := errors.New("hard failure")
804+
err := client.WithTx(
805+
ctx,
806+
func(ctx context.Context, tx pg.Tx) error {
807+
_, err := tx.Exec(ctx,
808+
"INSERT INTO test_tx_norollback_control (name) VALUES ($1)", "gone")
809+
if err != nil {
810+
return err
811+
}
812+
return sentinel
813+
},
814+
)
815+
require.ErrorIs(t, err, sentinel)
816+
817+
err = client.WithConn(
818+
ctx,
819+
func(ctx context.Context, conn pg.Querier) error {
820+
var count int
821+
err := conn.QueryRow(ctx,
822+
"SELECT count(*) FROM test_tx_norollback_control WHERE name = $1",
823+
"gone").Scan(&count)
824+
require.NoError(t, err)
825+
assert.Equal(t, 0, count)
826+
return nil
827+
},
828+
)
829+
require.NoError(t, err)
830+
},
831+
)
832+
}
833+
642834
func TestWithTx_QuerierMethods(t *testing.T) {
643835
client := newTestClient(t)
644836
ctx := context.Background()

pg/conn.go

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,34 @@ import (
2626
"go.opentelemetry.io/otel/trace"
2727
)
2828

29+
// NoRollbackError wraps an error to signal that WithTx should
30+
// commit the transaction instead of rolling back, while still
31+
// propagating the inner error to the caller.
32+
type NoRollbackError struct {
33+
Err error
34+
}
35+
36+
func (e *NoRollbackError) Error() string {
37+
return e.Err.Error()
38+
}
39+
40+
func (e *NoRollbackError) Unwrap() error {
41+
return e.Err
42+
}
43+
44+
// NoRollback wraps err so that WithTx commits the transaction
45+
// instead of rolling back. The inner error is still returned to the
46+
// caller after the commit.
47+
//
48+
// A nil err is returned as-is (no wrapping).
49+
func NoRollback(err error) error {
50+
if err == nil {
51+
return nil
52+
}
53+
54+
return &NoRollbackError{Err: err}
55+
}
56+
2957
type (
3058
// Querier represents something you can run SQL queries against.
3159
Querier interface {

0 commit comments

Comments
 (0)