From 574c4d316051a275a15ead957f330b8eacb4c8b2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Alex=20-=20=E3=82=A2=E3=83=AC=E3=83=83=E3=82=AF=E3=82=B9?= Date: Tue, 7 Apr 2026 13:03:05 +0200 Subject: [PATCH 1/2] feat: replace LockForUpdate with ClaimForUpdate, add Upsert ClaimForUpdate/ClaimOneForUpdate replace LockForUpdate/LockForUpdates with a cleaner API: per-record mutateFn that returns error (enabling rollback), and returns claimed records after commit for processing outside the transaction. Upsert adds INSERT ... ON CONFLICT support with DO NOTHING or DO UPDATE SET, column name validation, and batch chunking. --- table.go | 167 +++++++++++++++++++++++++++----------- tests/table_test.go | 188 +++++++++++++++++++++++++++++++++++++++---- tests/tables_test.go | 15 ++-- tests/worker_test.go | 11 +-- 4 files changed, 306 insertions(+), 75 deletions(-) diff --git a/table.go b/table.go index 3e96fd5..8a37add 100644 --- a/table.go +++ b/table.go @@ -6,6 +6,7 @@ import ( "fmt" "iter" "slices" + "strings" "time" sq "github.com/Masterminds/squirrel" @@ -139,6 +140,80 @@ func (t *Table[T, P, I]) insertAll(ctx context.Context, records []P) error { return nil } +// Upsert inserts records with ON CONFLICT handling. conflictColumns specify the unique +// constraint columns. If updateColumns is nil or empty, uses DO NOTHING. Otherwise, +// uses DO UPDATE SET to update the specified columns on conflict. +// +// Does not use RETURNING — with DO NOTHING the conflicting row may not be returned. +func (t *Table[T, P, I]) Upsert(ctx context.Context, conflictColumns []string, updateColumns []string, records ...P) error { + if len(records) == 0 { + return nil + } + if len(conflictColumns) == 0 { + return fmt.Errorf("upsert: conflictColumns must not be empty") + } + + now := time.Now().UTC() + for i, r := range records { + if r == nil { + return fmt.Errorf("record with index=%d is nil", i) + } + if err := r.Validate(); err != nil { + return fmt.Errorf("validate record: %w", err) + } + if row, ok := any(r).(HasSetCreatedAt); ok { + row.SetCreatedAt(now) + } + if row, ok := any(r).(HasSetUpdatedAt); ok { + row.SetUpdatedAt(now) + } + } + + // Validate column names against the first record's mapped columns. + cols, _, err := Map(records[0]) + if err != nil { + return fmt.Errorf("upsert: map record: %w", err) + } + colSet := make(map[string]struct{}, len(cols)) + for _, c := range cols { + colSet[c] = struct{}{} + } + for _, c := range conflictColumns { + if _, ok := colSet[c]; !ok { + return fmt.Errorf("upsert: invalid conflict column %q", c) + } + } + for _, c := range updateColumns { + if _, ok := colSet[c]; !ok { + return fmt.Errorf("upsert: invalid update column %q", c) + } + } + + // Build ON CONFLICT suffix. + var suffix string + if len(updateColumns) == 0 { + suffix = fmt.Sprintf("ON CONFLICT (%s) DO NOTHING", strings.Join(conflictColumns, ", ")) + } else { + sets := make([]string, len(updateColumns)) + for i, c := range updateColumns { + sets[i] = fmt.Sprintf("%s = EXCLUDED.%s", c, c) + } + suffix = fmt.Sprintf("ON CONFLICT (%s) DO UPDATE SET %s", strings.Join(conflictColumns, ", "), strings.Join(sets, ", ")) + } + + for start := 0; start < len(records); start += chunkSize { + end := min(start+chunkSize, len(records)) + chunk := records[start:end] + + q := t.SQL.InsertRecords(chunk).Into(t.Name).SuffixExpr(sq.Expr(suffix)) + if _, err := t.Query.Exec(ctx, q); err != nil { + return fmt.Errorf("upsert records: %w", err) + } + } + + return nil +} + // Update updates one or more records by their ID. Sets UpdatedAt timestamp if available. // Returns (true, nil) if at least one row was updated, (false, nil) if no rows matched. func (t *Table[T, P, I]) Update(ctx context.Context, records ...P) (bool, error) { @@ -536,61 +611,55 @@ func (t *Table[T, P, I]) WithTx(tx pgx.Tx) *Table[T, P, I] { } } -// LockForUpdate locks and updates one record using PostgreSQL's FOR UPDATE SKIP LOCKED pattern -// within a database transaction for safe concurrent processing. The record is processed exactly -// once across multiple workers. The record is automatically updated after updateFn() completes. +// ClaimForUpdate locks matching rows with FOR UPDATE SKIP LOCKED, calls mutateFn +// on each record, saves all records (committing the mutation), and returns the +// mutated records for processing outside the transaction. // -// Keep updateFn() fast to avoid holding the transaction. For long-running work, update status -// to "processing" and return early, then process asynchronously. Use defer LockForUpdate() -// to update status to "completed" or "failed". -// -// Returns ErrNoRows if no matching records are available for locking. -func (t *Table[T, P, I]) LockForUpdate(ctx context.Context, where sq.Sqlizer, orderBy []string, updateFn func(record P)) error { - var noRows bool +// If mutateFn returns an error, the transaction is rolled back and no records are returned. +func (t *Table[T, P, I]) ClaimForUpdate(ctx context.Context, where sq.Sqlizer, orderBy []string, limit uint64, mutateFn func(record P) error) ([]P, error) { + var claimed []P - err := t.LockForUpdates(ctx, where, orderBy, 1, func(records []P) { - if len(records) > 0 { - updateFn(records[0]) - } else { - noRows = true + claimWithTx := func(pgTx pgx.Tx) error { + records, err := t.claimForUpdateWithTx(ctx, pgTx, where, orderBy, limit, mutateFn) + if err != nil { + return err } - }) - if err != nil { - return err //nolint:wrapcheck - } - - if noRows { - return ErrNoRows + claimed = records + return nil } - return nil -} - -// LockForUpdates locks and updates records using PostgreSQL's FOR UPDATE SKIP LOCKED pattern -// within a database transaction for safe concurrent processing. Each record is processed exactly -// once across multiple workers. Records are automatically updated after updateFn() completes. -// -// Keep updateFn() fast to avoid holding the transaction. For long-running work, update status -// to "processing" and return early, then process asynchronously. Use defer LockForUpdate() -// to update status to "completed" or "failed". -func (t *Table[T, P, I]) LockForUpdates(ctx context.Context, where sq.Sqlizer, orderBy []string, limit uint64, updateFn func(records []P)) error { // Reuse existing transaction if available. if t.DB.Query.Tx != nil { - if err := t.lockForUpdatesWithTx(ctx, t.DB.Query.Tx, where, orderBy, limit, updateFn); err != nil { - return fmt.Errorf("lock for update (with tx): %w", err) + if err := claimWithTx(t.DB.Query.Tx); err != nil { + return nil, fmt.Errorf("claim for update (with tx): %w", err) } - return nil + return claimed, nil } - return pgx.BeginFunc(ctx, t.DB.Conn, func(pgTx pgx.Tx) error { - if err := t.lockForUpdatesWithTx(ctx, pgTx, where, orderBy, limit, updateFn); err != nil { - return fmt.Errorf("lock for update (new tx): %w", err) - } - return nil + err := pgx.BeginFunc(ctx, t.DB.Conn, func(pgTx pgx.Tx) error { + return claimWithTx(pgTx) }) + if err != nil { + return nil, fmt.Errorf("claim for update (new tx): %w", err) + } + + return claimed, nil } -func (t *Table[T, P, I]) lockForUpdatesWithTx(ctx context.Context, pgTx pgx.Tx, where sq.Sqlizer, orderBy []string, limit uint64, updateFn func(records []P)) error { +// ClaimOneForUpdate is ClaimForUpdate with limit=1. Returns the single record +// or ErrNoRows if nothing matched. +func (t *Table[T, P, I]) ClaimOneForUpdate(ctx context.Context, where sq.Sqlizer, orderBy []string, mutateFn func(record P) error) (P, error) { + records, err := t.ClaimForUpdate(ctx, where, orderBy, 1, mutateFn) + if err != nil { + return nil, err //nolint:wrapcheck + } + if len(records) == 0 { + return nil, ErrNoRows + } + return records[0], nil +} + +func (t *Table[T, P, I]) claimForUpdateWithTx(ctx context.Context, pgTx pgx.Tx, where sq.Sqlizer, orderBy []string, limit uint64, mutateFn func(record P) error) ([]P, error) { if len(orderBy) == 0 { orderBy = []string{t.IDColumn} } @@ -607,24 +676,28 @@ func (t *Table[T, P, I]) lockForUpdatesWithTx(ctx context.Context, pgTx pgx.Tx, var records []P if err := txQuery.GetAll(ctx, q, &records); err != nil { - return fmt.Errorf("select for update skip locked: %w", err) + return nil, fmt.Errorf("select for update skip locked: %w", err) } - updateFn(records) + for _, record := range records { + if err := mutateFn(record); err != nil { + return nil, fmt.Errorf("mutate record: %w", err) + } + } now := time.Now().UTC() for _, record := range records { if err := record.Validate(); err != nil { - return fmt.Errorf("validate record after update: %w", err) + return nil, fmt.Errorf("validate record after update: %w", err) } if row, ok := any(record).(HasSetUpdatedAt); ok { row.SetUpdatedAt(now) } q := t.SQL.UpdateRecord(record, sq.Eq{t.IDColumn: record.GetID()}, t.Name) if _, err := txQuery.Exec(ctx, q); err != nil { - return fmt.Errorf("update record: %w", err) + return nil, fmt.Errorf("update record: %w", err) } } - return nil + return records, nil } diff --git a/tests/table_test.go b/tests/table_test.go index b3356a6..1cfac8e 100644 --- a/tests/table_test.go +++ b/tests/table_test.go @@ -1,6 +1,7 @@ package pgkit_test import ( + "context" "fmt" "slices" "sync" @@ -647,16 +648,16 @@ func TestIter(t *testing.T) { require.Equal(t, total, count, "Iter should yield all rows") } -func TestLockForUpdates(t *testing.T) { +func TestClaimForUpdate(t *testing.T) { truncateAllTables(t) ctx := t.Context() db := initDB(DB) worker := &Worker{DB: db} - t.Run("TestLockForUpdates", func(t *testing.T) { + t.Run("concurrent dequeue", func(t *testing.T) { // Create account. - account := &Account{Name: "LockForUpdates Account"} + account := &Account{Name: "ClaimForUpdate Account"} err := db.Accounts.Save(ctx, account) require.NoError(t, err, "Create account failed") @@ -665,7 +666,7 @@ func TestLockForUpdates(t *testing.T) { err = db.Articles.Save(ctx, article) require.NoError(t, err, "Create article failed") - // Create 1000 reviews. + // Create 100 reviews. reviews := make([]*Review, 100) for i := range 100 { reviews[i] = &Review{ @@ -679,7 +680,7 @@ func TestLockForUpdates(t *testing.T) { require.NoError(t, err, "create review") var mu sync.Mutex - var allIDs []uint64 + var allReviews []*Review var wg sync.WaitGroup for range 10 { @@ -687,28 +688,31 @@ func TestLockForUpdates(t *testing.T) { go func() { defer wg.Done() - reviews, err := db.Reviews.DequeueForProcessing(ctx, 10) + claimed, err := db.Reviews.DequeueForProcessing(ctx, 10) assert.NoError(t, err, "dequeue reviews") - var localIDs []uint64 - for _, review := range reviews { - localIDs = append(localIDs, review.ID) + // Verify returned records have mutated status. + for _, review := range claimed { + assert.Equal(t, ReviewStatusProcessing, review.Status, "returned record should be mutated") worker.wg.Add(1) go worker.ProcessReview(ctx, review) } mu.Lock() - allIDs = append(allIDs, localIDs...) + allReviews = append(allReviews, claimed...) mu.Unlock() }() } wg.Wait() // Ensure that all reviews were picked up for processing exactly once. - uniqueIDs := slices.Clone(allIDs) - slices.Sort(uniqueIDs) - uniqueIDs = slices.Compact(uniqueIDs) - require.Equal(t, 100, len(uniqueIDs), "number of unique reviews picked up for processing should be 100") + allIDs := make([]uint64, len(allReviews)) + for i, r := range allReviews { + allIDs[i] = r.ID + } + slices.Sort(allIDs) + allIDs = slices.Compact(allIDs) + require.Equal(t, 100, len(allIDs), "number of unique reviews picked up for processing should be 100") // Wait for all reviews to be processed asynchronously. worker.Wait() @@ -718,4 +722,160 @@ func TestLockForUpdates(t *testing.T) { require.NoError(t, err, "count reviews") require.Zero(t, count, "there should be no reviews stuck in 'processing' status") }) + + t.Run("mutateFn error rolls back", func(t *testing.T) { + truncateAllTables(t) + + account := &Account{Name: "Rollback Account"} + err := db.Accounts.Save(ctx, account) + require.NoError(t, err) + + article := &Article{AccountID: account.ID, Author: "Author"} + err = db.Articles.Save(ctx, article) + require.NoError(t, err) + + for i := range 5 { + err := db.Reviews.Save(ctx, &Review{ + Comment: fmt.Sprintf("Rollback comment %d", i), + AccountID: account.ID, + ArticleID: article.ID, + Status: ReviewStatusPending, + }) + require.NoError(t, err) + } + + callCount := 0 + records, err := db.Reviews.ClaimForUpdate(ctx, sq.Eq{"status": ReviewStatusPending}, nil, 5, func(review *Review) error { + callCount++ + if callCount == 3 { + return fmt.Errorf("deliberate error on record 3") + } + review.Status = ReviewStatusProcessing + return nil + }) + require.Error(t, err, "should return mutateFn error") + require.Nil(t, records, "no records on error") + + // All reviews should still be pending (transaction rolled back). + count, err := db.Reviews.Count(ctx, sq.Eq{"status": ReviewStatusPending}) + require.NoError(t, err) + require.Equal(t, uint64(5), count, "all reviews should still be pending after rollback") + }) + + t.Run("empty result returns nil slice", func(t *testing.T) { + truncateAllTables(t) + + records, err := db.Reviews.ClaimForUpdate(ctx, sq.Eq{"status": ReviewStatusPending}, nil, 10, func(review *Review) error { + return nil + }) + require.NoError(t, err) + require.Empty(t, records) + }) + + t.Run("ClaimOneForUpdate no match returns ErrNoRows", func(t *testing.T) { + truncateAllTables(t) + + record, err := db.Reviews.ClaimOneForUpdate(ctx, sq.Eq{"status": ReviewStatusPending}, nil, func(review *Review) error { + return nil + }) + require.ErrorIs(t, err, pgkit.ErrNoRows) + require.Nil(t, record) + }) +} + +func TestUpsert(t *testing.T) { + truncateAllTables(t) + + ctx := t.Context() + db := initDB(DB) + + // Add a unique index on accounts.name for upsert testing. + _, err := DB.Conn.Exec(ctx, "CREATE UNIQUE INDEX IF NOT EXISTS accounts_name_unique ON accounts (name)") + require.NoError(t, err) + t.Cleanup(func() { + DB.Conn.Exec(context.Background(), "DROP INDEX IF EXISTS accounts_name_unique") //nolint:errcheck + }) + + t.Run("insert new record", func(t *testing.T) { + truncateAllTables(t) + + err := db.Accounts.Upsert(ctx, []string{"name"}, nil, &Account{Name: "Upsert New"}) + require.NoError(t, err) + + count, err := db.Accounts.Count(ctx, sq.Eq{"name": "Upsert New"}) + require.NoError(t, err) + require.Equal(t, uint64(1), count) + }) + + t.Run("DO NOTHING on conflict", func(t *testing.T) { + truncateAllTables(t) + + original := &Account{Name: "DoNothing"} + err := db.Accounts.Insert(ctx, original) + require.NoError(t, err) + + // Upsert with DO NOTHING — original should be preserved. + err = db.Accounts.Upsert(ctx, []string{"name"}, nil, &Account{Name: "DoNothing", Disabled: true}) + require.NoError(t, err) + + got, err := db.Accounts.GetByID(ctx, original.ID) + require.NoError(t, err) + require.False(t, got.Disabled, "original data should be preserved with DO NOTHING") + + count, err := db.Accounts.Count(ctx, sq.Eq{"name": "DoNothing"}) + require.NoError(t, err) + require.Equal(t, uint64(1), count, "should not create duplicate") + }) + + t.Run("DO UPDATE on conflict", func(t *testing.T) { + truncateAllTables(t) + + original := &Account{Name: "DoUpdate"} + err := db.Accounts.Insert(ctx, original) + require.NoError(t, err) + require.False(t, original.Disabled) + + // Upsert with DO UPDATE SET disabled — should update disabled column. + err = db.Accounts.Upsert(ctx, []string{"name"}, []string{"disabled"}, &Account{Name: "DoUpdate", Disabled: true}) + require.NoError(t, err) + + got, err := db.Accounts.GetByID(ctx, original.ID) + require.NoError(t, err) + require.True(t, got.Disabled, "disabled should be updated on conflict") + + count, err := db.Accounts.Count(ctx, sq.Eq{"name": "DoUpdate"}) + require.NoError(t, err) + require.Equal(t, uint64(1), count, "should not create duplicate") + }) + + t.Run("concurrent upserts no duplicates", func(t *testing.T) { + truncateAllTables(t) + + var wg sync.WaitGroup + for i := range 10 { + wg.Add(1) + go func() { + defer wg.Done() + err := db.Accounts.Upsert(ctx, []string{"name"}, []string{"disabled"}, &Account{Name: "ConcurrentUpsert", Disabled: i%2 == 0}) + assert.NoError(t, err) + }() + } + wg.Wait() + + count, err := db.Accounts.Count(ctx, sq.Eq{"name": "ConcurrentUpsert"}) + require.NoError(t, err) + require.Equal(t, uint64(1), count, "should have exactly one record after concurrent upserts") + }) + + t.Run("invalid conflict column", func(t *testing.T) { + err := db.Accounts.Upsert(ctx, []string{"nonexistent"}, nil, &Account{Name: "Invalid"}) + require.Error(t, err) + require.Contains(t, err.Error(), "invalid conflict column") + }) + + t.Run("invalid update column", func(t *testing.T) { + err := db.Accounts.Upsert(ctx, []string{"name"}, []string{"nonexistent"}, &Account{Name: "Invalid"}) + require.Error(t, err) + require.Contains(t, err.Error(), "invalid update column") + }) } diff --git a/tests/tables_test.go b/tests/tables_test.go index d734da0..a590303 100644 --- a/tests/tables_test.go +++ b/tests/tables_test.go @@ -22,7 +22,6 @@ type reviewsTable struct { } func (t *reviewsTable) DequeueForProcessing(ctx context.Context, limit uint64) ([]*Review, error) { - var dequeued []*Review where := sq.Eq{ "status": ReviewStatusPending, "deleted_at": nil, @@ -31,16 +30,14 @@ func (t *reviewsTable) DequeueForProcessing(ctx context.Context, limit uint64) ( "created_at ASC", } - err := t.LockForUpdates(ctx, where, orderBy, limit, func(reviews []*Review) { - now := time.Now().UTC() - for _, review := range reviews { - review.Status = ReviewStatusProcessing - review.ProcessedAt = &now - } - dequeued = reviews + now := time.Now().UTC() + dequeued, err := t.ClaimForUpdate(ctx, where, orderBy, limit, func(review *Review) error { + review.Status = ReviewStatusProcessing + review.ProcessedAt = &now + return nil }) if err != nil { - return nil, fmt.Errorf("lock for updates: %w", err) + return nil, fmt.Errorf("claim for update: %w", err) } return dequeued, nil diff --git a/tests/worker_test.go b/tests/worker_test.go index 6bc6416..893bddb 100644 --- a/tests/worker_test.go +++ b/tests/worker_test.go @@ -27,17 +27,18 @@ func (w *Worker) ProcessReview(ctx context.Context, review *Review) (err error) defer func() { // Always update review status to "approved", "rejected" or "failed". noCtx := context.Background() - err = w.DB.Reviews.LockForUpdate(noCtx, sq.Eq{"id": review.ID}, []string{"id DESC"}, func(update *Review) { + _, claimErr := w.DB.Reviews.ClaimOneForUpdate(noCtx, sq.Eq{"id": review.ID}, []string{"id DESC"}, func(update *Review) error { now := time.Now().UTC() update.ProcessedAt = &now if err != nil { update.Status = ReviewStatusFailed - return + } else { + update.Status = review.Status } - update.Status = review.Status + return nil }) - if err != nil { - log.Printf("failed to save review: %v", err) + if claimErr != nil { + log.Printf("failed to save review: %v", claimErr) } }() From d8cfea48af9d2031f602d7c09ecc255c4542d439 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Alex=20-=20=E3=82=A2=E3=83=AC=E3=83=83=E3=82=AF=E3=82=B9?= Date: Tue, 7 Apr 2026 13:16:57 +0200 Subject: [PATCH 2/2] fix: address review findings on ClaimForUpdate and Upsert MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - ClaimForUpdate: fix docstring to accurately describe the existing-tx path — records are returned within the tx scope, caller must commit. - Upsert: auto-include updated_at in DO UPDATE SET when the record implements HasSetUpdatedAt, consistent with Save/Update behavior. - Add test assertion for updated_at bump on conflict. --- table.go | 22 +++++++++++++++++++--- tests/table_test.go | 1 + 2 files changed, 20 insertions(+), 3 deletions(-) diff --git a/table.go b/table.go index 8a37add..d91df5f 100644 --- a/table.go +++ b/table.go @@ -189,6 +189,15 @@ func (t *Table[T, P, I]) Upsert(ctx context.Context, conflictColumns []string, u } } + // Auto-include updated_at in DO UPDATE when the record tracks update time. + if len(updateColumns) > 0 { + if _, ok := any(records[0]).(HasSetUpdatedAt); ok { + if _, exists := colSet["updated_at"]; exists && !slices.Contains(updateColumns, "updated_at") { + updateColumns = append(updateColumns, "updated_at") + } + } + } + // Build ON CONFLICT suffix. var suffix string if len(updateColumns) == 0 { @@ -612,10 +621,17 @@ func (t *Table[T, P, I]) WithTx(tx pgx.Tx) *Table[T, P, I] { } // ClaimForUpdate locks matching rows with FOR UPDATE SKIP LOCKED, calls mutateFn -// on each record, saves all records (committing the mutation), and returns the -// mutated records for processing outside the transaction. +// on each record, and saves all records within the transaction. +// +// When no existing transaction is present, ClaimForUpdate creates and commits its own +// transaction — records are returned after commit, safe for processing outside the tx. +// +// When called on a table bound to an existing transaction (via WithTx), the caller +// controls commit/rollback. Records are returned after the mutations are persisted +// within the tx but before commit — the caller must commit the tx to finalize. // -// If mutateFn returns an error, the transaction is rolled back and no records are returned. +// If mutateFn returns an error, the transaction is rolled back (or left for the caller +// to roll back in the WithTx case) and no records are returned. func (t *Table[T, P, I]) ClaimForUpdate(ctx context.Context, where sq.Sqlizer, orderBy []string, limit uint64, mutateFn func(record P) error) ([]P, error) { var claimed []P diff --git a/tests/table_test.go b/tests/table_test.go index 1cfac8e..63d8a2d 100644 --- a/tests/table_test.go +++ b/tests/table_test.go @@ -842,6 +842,7 @@ func TestUpsert(t *testing.T) { got, err := db.Accounts.GetByID(ctx, original.ID) require.NoError(t, err) require.True(t, got.Disabled, "disabled should be updated on conflict") + require.True(t, got.UpdatedAt.After(original.UpdatedAt), "updated_at should be bumped on conflict even though caller only listed 'disabled'") count, err := db.Accounts.Count(ctx, sq.Eq{"name": "DoUpdate"}) require.NoError(t, err)