|
7 | 7 | "crypto/rand" |
8 | 8 | "crypto/x509" |
9 | 9 | "errors" |
| 10 | + "fmt" |
10 | 11 | "io" |
11 | 12 | "math/big" |
12 | 13 | "testing" |
@@ -639,6 +640,197 @@ func TestWithTx(t *testing.T) { |
639 | 640 | ) |
640 | 641 | } |
641 | 642 |
|
| 643 | +func TestNoRollback(t *testing.T) { |
| 644 | + t.Run( |
| 645 | + "nil returns nil", |
| 646 | + func(t *testing.T) { |
| 647 | + assert.Nil(t, pg.NoRollback(nil)) |
| 648 | + }, |
| 649 | + ) |
| 650 | + |
| 651 | + t.Run( |
| 652 | + "wraps error in NoRollbackError", |
| 653 | + func(t *testing.T) { |
| 654 | + sentinel := errors.New("inner") |
| 655 | + err := pg.NoRollback(sentinel) |
| 656 | + |
| 657 | + var nrErr *pg.NoRollbackError |
| 658 | + require.ErrorAs(t, err, &nrErr) |
| 659 | + assert.Equal(t, sentinel, nrErr.Err) |
| 660 | + }, |
| 661 | + ) |
| 662 | + |
| 663 | + t.Run( |
| 664 | + "Error delegates to inner", |
| 665 | + func(t *testing.T) { |
| 666 | + err := pg.NoRollback(errors.New("boom")) |
| 667 | + assert.Equal(t, "boom", err.Error()) |
| 668 | + }, |
| 669 | + ) |
| 670 | + |
| 671 | + t.Run( |
| 672 | + "Unwrap returns inner", |
| 673 | + func(t *testing.T) { |
| 674 | + sentinel := errors.New("inner") |
| 675 | + err := pg.NoRollback(sentinel) |
| 676 | + assert.ErrorIs(t, err, sentinel) |
| 677 | + }, |
| 678 | + ) |
| 679 | + |
| 680 | + t.Run( |
| 681 | + "detectable through additional wrapping", |
| 682 | + func(t *testing.T) { |
| 683 | + sentinel := errors.New("root cause") |
| 684 | + wrapped := fmt.Errorf("context: %w", pg.NoRollback(sentinel)) |
| 685 | + |
| 686 | + var nrErr *pg.NoRollbackError |
| 687 | + require.ErrorAs(t, wrapped, &nrErr) |
| 688 | + assert.ErrorIs(t, wrapped, sentinel) |
| 689 | + }, |
| 690 | + ) |
| 691 | +} |
| 692 | + |
| 693 | +func TestWithTx_NoRollback(t *testing.T) { |
| 694 | + client := newTestClient(t) |
| 695 | + ctx := context.Background() |
| 696 | + |
| 697 | + setup := func(t *testing.T, table string) { |
| 698 | + t.Helper() |
| 699 | + err := client.WithConn( |
| 700 | + ctx, |
| 701 | + func(ctx context.Context, conn pg.Querier) error { |
| 702 | + _, err := conn.Exec(ctx, "DROP TABLE IF EXISTS "+table) |
| 703 | + if err != nil { |
| 704 | + return err |
| 705 | + } |
| 706 | + _, err = conn.Exec(ctx, |
| 707 | + "CREATE TABLE "+table+" (id serial PRIMARY KEY, name text NOT NULL)") |
| 708 | + return err |
| 709 | + }, |
| 710 | + ) |
| 711 | + require.NoError(t, err) |
| 712 | + |
| 713 | + t.Cleanup(func() { |
| 714 | + _ = client.WithConn( |
| 715 | + context.Background(), |
| 716 | + func(ctx context.Context, conn pg.Querier) error { |
| 717 | + _, err := conn.Exec(ctx, "DROP TABLE IF EXISTS "+table) |
| 718 | + return err |
| 719 | + }, |
| 720 | + ) |
| 721 | + }) |
| 722 | + } |
| 723 | + |
| 724 | + t.Run( |
| 725 | + "commits and returns inner error", |
| 726 | + func(t *testing.T) { |
| 727 | + setup(t, "test_tx_norollback") |
| 728 | + |
| 729 | + sentinel := errors.New("soft failure") |
| 730 | + err := client.WithTx( |
| 731 | + ctx, |
| 732 | + func(ctx context.Context, tx pg.Tx) error { |
| 733 | + _, err := tx.Exec(ctx, |
| 734 | + "INSERT INTO test_tx_norollback (name) VALUES ($1)", "committed") |
| 735 | + if err != nil { |
| 736 | + return err |
| 737 | + } |
| 738 | + return pg.NoRollback(sentinel) |
| 739 | + }, |
| 740 | + ) |
| 741 | + require.ErrorIs(t, err, sentinel) |
| 742 | + |
| 743 | + var nrErr *pg.NoRollbackError |
| 744 | + assert.False(t, errors.As(err, &nrErr), |
| 745 | + "returned error must not be wrapped in NoRollbackError") |
| 746 | + |
| 747 | + err = client.WithConn( |
| 748 | + ctx, |
| 749 | + func(ctx context.Context, conn pg.Querier) error { |
| 750 | + var count int |
| 751 | + err := conn.QueryRow(ctx, |
| 752 | + "SELECT count(*) FROM test_tx_norollback WHERE name = $1", |
| 753 | + "committed").Scan(&count) |
| 754 | + require.NoError(t, err) |
| 755 | + assert.Equal(t, 1, count) |
| 756 | + return nil |
| 757 | + }, |
| 758 | + ) |
| 759 | + require.NoError(t, err) |
| 760 | + }, |
| 761 | + ) |
| 762 | + |
| 763 | + t.Run( |
| 764 | + "works through additional wrapping", |
| 765 | + func(t *testing.T) { |
| 766 | + setup(t, "test_tx_norollback_wrap") |
| 767 | + |
| 768 | + sentinel := errors.New("root cause") |
| 769 | + err := client.WithTx( |
| 770 | + ctx, |
| 771 | + func(ctx context.Context, tx pg.Tx) error { |
| 772 | + _, err := tx.Exec(ctx, |
| 773 | + "INSERT INTO test_tx_norollback_wrap (name) VALUES ($1)", "kept") |
| 774 | + if err != nil { |
| 775 | + return err |
| 776 | + } |
| 777 | + return fmt.Errorf("extra context: %w", pg.NoRollback(sentinel)) |
| 778 | + }, |
| 779 | + ) |
| 780 | + require.Error(t, err) |
| 781 | + |
| 782 | + err = client.WithConn( |
| 783 | + ctx, |
| 784 | + func(ctx context.Context, conn pg.Querier) error { |
| 785 | + var count int |
| 786 | + err := conn.QueryRow(ctx, |
| 787 | + "SELECT count(*) FROM test_tx_norollback_wrap WHERE name = $1", |
| 788 | + "kept").Scan(&count) |
| 789 | + require.NoError(t, err) |
| 790 | + assert.Equal(t, 1, count) |
| 791 | + return nil |
| 792 | + }, |
| 793 | + ) |
| 794 | + require.NoError(t, err) |
| 795 | + }, |
| 796 | + ) |
| 797 | + |
| 798 | + t.Run( |
| 799 | + "normal error still rolls back", |
| 800 | + func(t *testing.T) { |
| 801 | + setup(t, "test_tx_norollback_control") |
| 802 | + |
| 803 | + sentinel := errors.New("hard failure") |
| 804 | + err := client.WithTx( |
| 805 | + ctx, |
| 806 | + func(ctx context.Context, tx pg.Tx) error { |
| 807 | + _, err := tx.Exec(ctx, |
| 808 | + "INSERT INTO test_tx_norollback_control (name) VALUES ($1)", "gone") |
| 809 | + if err != nil { |
| 810 | + return err |
| 811 | + } |
| 812 | + return sentinel |
| 813 | + }, |
| 814 | + ) |
| 815 | + require.ErrorIs(t, err, sentinel) |
| 816 | + |
| 817 | + err = client.WithConn( |
| 818 | + ctx, |
| 819 | + func(ctx context.Context, conn pg.Querier) error { |
| 820 | + var count int |
| 821 | + err := conn.QueryRow(ctx, |
| 822 | + "SELECT count(*) FROM test_tx_norollback_control WHERE name = $1", |
| 823 | + "gone").Scan(&count) |
| 824 | + require.NoError(t, err) |
| 825 | + assert.Equal(t, 0, count) |
| 826 | + return nil |
| 827 | + }, |
| 828 | + ) |
| 829 | + require.NoError(t, err) |
| 830 | + }, |
| 831 | + ) |
| 832 | +} |
| 833 | + |
642 | 834 | func TestWithTx_QuerierMethods(t *testing.T) { |
643 | 835 | client := newTestClient(t) |
644 | 836 | ctx := context.Background() |
|
0 commit comments