diff --git a/table.go b/table.go index 3e96fd5..d91df5f 100644 --- a/table.go +++ b/table.go @@ -6,6 +6,7 @@ import ( "fmt" "iter" "slices" + "strings" "time" sq "github.com/Masterminds/squirrel" @@ -139,6 +140,89 @@ func (t *Table[T, P, I]) insertAll(ctx context.Context, records []P) error { return nil } +// Upsert inserts records with ON CONFLICT handling. conflictColumns specify the unique +// constraint columns. If updateColumns is nil or empty, uses DO NOTHING. Otherwise, +// uses DO UPDATE SET to update the specified columns on conflict. +// +// Does not use RETURNING — with DO NOTHING the conflicting row may not be returned. +func (t *Table[T, P, I]) Upsert(ctx context.Context, conflictColumns []string, updateColumns []string, records ...P) error { + if len(records) == 0 { + return nil + } + if len(conflictColumns) == 0 { + return fmt.Errorf("upsert: conflictColumns must not be empty") + } + + now := time.Now().UTC() + for i, r := range records { + if r == nil { + return fmt.Errorf("record with index=%d is nil", i) + } + if err := r.Validate(); err != nil { + return fmt.Errorf("validate record: %w", err) + } + if row, ok := any(r).(HasSetCreatedAt); ok { + row.SetCreatedAt(now) + } + if row, ok := any(r).(HasSetUpdatedAt); ok { + row.SetUpdatedAt(now) + } + } + + // Validate column names against the first record's mapped columns. + cols, _, err := Map(records[0]) + if err != nil { + return fmt.Errorf("upsert: map record: %w", err) + } + colSet := make(map[string]struct{}, len(cols)) + for _, c := range cols { + colSet[c] = struct{}{} + } + for _, c := range conflictColumns { + if _, ok := colSet[c]; !ok { + return fmt.Errorf("upsert: invalid conflict column %q", c) + } + } + for _, c := range updateColumns { + if _, ok := colSet[c]; !ok { + return fmt.Errorf("upsert: invalid update column %q", c) + } + } + + // Auto-include updated_at in DO UPDATE when the record tracks update time. + if len(updateColumns) > 0 { + if _, ok := any(records[0]).(HasSetUpdatedAt); ok { + if _, exists := colSet["updated_at"]; exists && !slices.Contains(updateColumns, "updated_at") { + updateColumns = append(updateColumns, "updated_at") + } + } + } + + // Build ON CONFLICT suffix. + var suffix string + if len(updateColumns) == 0 { + suffix = fmt.Sprintf("ON CONFLICT (%s) DO NOTHING", strings.Join(conflictColumns, ", ")) + } else { + sets := make([]string, len(updateColumns)) + for i, c := range updateColumns { + sets[i] = fmt.Sprintf("%s = EXCLUDED.%s", c, c) + } + suffix = fmt.Sprintf("ON CONFLICT (%s) DO UPDATE SET %s", strings.Join(conflictColumns, ", "), strings.Join(sets, ", ")) + } + + for start := 0; start < len(records); start += chunkSize { + end := min(start+chunkSize, len(records)) + chunk := records[start:end] + + q := t.SQL.InsertRecords(chunk).Into(t.Name).SuffixExpr(sq.Expr(suffix)) + if _, err := t.Query.Exec(ctx, q); err != nil { + return fmt.Errorf("upsert records: %w", err) + } + } + + return nil +} + // Update updates one or more records by their ID. Sets UpdatedAt timestamp if available. // Returns (true, nil) if at least one row was updated, (false, nil) if no rows matched. func (t *Table[T, P, I]) Update(ctx context.Context, records ...P) (bool, error) { @@ -536,61 +620,62 @@ func (t *Table[T, P, I]) WithTx(tx pgx.Tx) *Table[T, P, I] { } } -// LockForUpdate locks and updates one record using PostgreSQL's FOR UPDATE SKIP LOCKED pattern -// within a database transaction for safe concurrent processing. The record is processed exactly -// once across multiple workers. The record is automatically updated after updateFn() completes. +// ClaimForUpdate locks matching rows with FOR UPDATE SKIP LOCKED, calls mutateFn +// on each record, and saves all records within the transaction. +// +// When no existing transaction is present, ClaimForUpdate creates and commits its own +// transaction — records are returned after commit, safe for processing outside the tx. // -// Keep updateFn() fast to avoid holding the transaction. For long-running work, update status -// to "processing" and return early, then process asynchronously. Use defer LockForUpdate() -// to update status to "completed" or "failed". +// When called on a table bound to an existing transaction (via WithTx), the caller +// controls commit/rollback. Records are returned after the mutations are persisted +// within the tx but before commit — the caller must commit the tx to finalize. // -// Returns ErrNoRows if no matching records are available for locking. -func (t *Table[T, P, I]) LockForUpdate(ctx context.Context, where sq.Sqlizer, orderBy []string, updateFn func(record P)) error { - var noRows bool +// If mutateFn returns an error, the transaction is rolled back (or left for the caller +// to roll back in the WithTx case) and no records are returned. +func (t *Table[T, P, I]) ClaimForUpdate(ctx context.Context, where sq.Sqlizer, orderBy []string, limit uint64, mutateFn func(record P) error) ([]P, error) { + var claimed []P - err := t.LockForUpdates(ctx, where, orderBy, 1, func(records []P) { - if len(records) > 0 { - updateFn(records[0]) - } else { - noRows = true + claimWithTx := func(pgTx pgx.Tx) error { + records, err := t.claimForUpdateWithTx(ctx, pgTx, where, orderBy, limit, mutateFn) + if err != nil { + return err } - }) - if err != nil { - return err //nolint:wrapcheck - } - - if noRows { - return ErrNoRows + claimed = records + return nil } - return nil -} - -// LockForUpdates locks and updates records using PostgreSQL's FOR UPDATE SKIP LOCKED pattern -// within a database transaction for safe concurrent processing. Each record is processed exactly -// once across multiple workers. Records are automatically updated after updateFn() completes. -// -// Keep updateFn() fast to avoid holding the transaction. For long-running work, update status -// to "processing" and return early, then process asynchronously. Use defer LockForUpdate() -// to update status to "completed" or "failed". -func (t *Table[T, P, I]) LockForUpdates(ctx context.Context, where sq.Sqlizer, orderBy []string, limit uint64, updateFn func(records []P)) error { // Reuse existing transaction if available. if t.DB.Query.Tx != nil { - if err := t.lockForUpdatesWithTx(ctx, t.DB.Query.Tx, where, orderBy, limit, updateFn); err != nil { - return fmt.Errorf("lock for update (with tx): %w", err) + if err := claimWithTx(t.DB.Query.Tx); err != nil { + return nil, fmt.Errorf("claim for update (with tx): %w", err) } - return nil + return claimed, nil } - return pgx.BeginFunc(ctx, t.DB.Conn, func(pgTx pgx.Tx) error { - if err := t.lockForUpdatesWithTx(ctx, pgTx, where, orderBy, limit, updateFn); err != nil { - return fmt.Errorf("lock for update (new tx): %w", err) - } - return nil + err := pgx.BeginFunc(ctx, t.DB.Conn, func(pgTx pgx.Tx) error { + return claimWithTx(pgTx) }) + if err != nil { + return nil, fmt.Errorf("claim for update (new tx): %w", err) + } + + return claimed, nil +} + +// ClaimOneForUpdate is ClaimForUpdate with limit=1. Returns the single record +// or ErrNoRows if nothing matched. +func (t *Table[T, P, I]) ClaimOneForUpdate(ctx context.Context, where sq.Sqlizer, orderBy []string, mutateFn func(record P) error) (P, error) { + records, err := t.ClaimForUpdate(ctx, where, orderBy, 1, mutateFn) + if err != nil { + return nil, err //nolint:wrapcheck + } + if len(records) == 0 { + return nil, ErrNoRows + } + return records[0], nil } -func (t *Table[T, P, I]) lockForUpdatesWithTx(ctx context.Context, pgTx pgx.Tx, where sq.Sqlizer, orderBy []string, limit uint64, updateFn func(records []P)) error { +func (t *Table[T, P, I]) claimForUpdateWithTx(ctx context.Context, pgTx pgx.Tx, where sq.Sqlizer, orderBy []string, limit uint64, mutateFn func(record P) error) ([]P, error) { if len(orderBy) == 0 { orderBy = []string{t.IDColumn} } @@ -607,24 +692,28 @@ func (t *Table[T, P, I]) lockForUpdatesWithTx(ctx context.Context, pgTx pgx.Tx, var records []P if err := txQuery.GetAll(ctx, q, &records); err != nil { - return fmt.Errorf("select for update skip locked: %w", err) + return nil, fmt.Errorf("select for update skip locked: %w", err) } - updateFn(records) + for _, record := range records { + if err := mutateFn(record); err != nil { + return nil, fmt.Errorf("mutate record: %w", err) + } + } now := time.Now().UTC() for _, record := range records { if err := record.Validate(); err != nil { - return fmt.Errorf("validate record after update: %w", err) + return nil, fmt.Errorf("validate record after update: %w", err) } if row, ok := any(record).(HasSetUpdatedAt); ok { row.SetUpdatedAt(now) } q := t.SQL.UpdateRecord(record, sq.Eq{t.IDColumn: record.GetID()}, t.Name) if _, err := txQuery.Exec(ctx, q); err != nil { - return fmt.Errorf("update record: %w", err) + return nil, fmt.Errorf("update record: %w", err) } } - return nil + return records, nil } diff --git a/tests/table_test.go b/tests/table_test.go index b3356a6..63d8a2d 100644 --- a/tests/table_test.go +++ b/tests/table_test.go @@ -1,6 +1,7 @@ package pgkit_test import ( + "context" "fmt" "slices" "sync" @@ -647,16 +648,16 @@ func TestIter(t *testing.T) { require.Equal(t, total, count, "Iter should yield all rows") } -func TestLockForUpdates(t *testing.T) { +func TestClaimForUpdate(t *testing.T) { truncateAllTables(t) ctx := t.Context() db := initDB(DB) worker := &Worker{DB: db} - t.Run("TestLockForUpdates", func(t *testing.T) { + t.Run("concurrent dequeue", func(t *testing.T) { // Create account. - account := &Account{Name: "LockForUpdates Account"} + account := &Account{Name: "ClaimForUpdate Account"} err := db.Accounts.Save(ctx, account) require.NoError(t, err, "Create account failed") @@ -665,7 +666,7 @@ func TestLockForUpdates(t *testing.T) { err = db.Articles.Save(ctx, article) require.NoError(t, err, "Create article failed") - // Create 1000 reviews. + // Create 100 reviews. reviews := make([]*Review, 100) for i := range 100 { reviews[i] = &Review{ @@ -679,7 +680,7 @@ func TestLockForUpdates(t *testing.T) { require.NoError(t, err, "create review") var mu sync.Mutex - var allIDs []uint64 + var allReviews []*Review var wg sync.WaitGroup for range 10 { @@ -687,28 +688,31 @@ func TestLockForUpdates(t *testing.T) { go func() { defer wg.Done() - reviews, err := db.Reviews.DequeueForProcessing(ctx, 10) + claimed, err := db.Reviews.DequeueForProcessing(ctx, 10) assert.NoError(t, err, "dequeue reviews") - var localIDs []uint64 - for _, review := range reviews { - localIDs = append(localIDs, review.ID) + // Verify returned records have mutated status. + for _, review := range claimed { + assert.Equal(t, ReviewStatusProcessing, review.Status, "returned record should be mutated") worker.wg.Add(1) go worker.ProcessReview(ctx, review) } mu.Lock() - allIDs = append(allIDs, localIDs...) + allReviews = append(allReviews, claimed...) mu.Unlock() }() } wg.Wait() // Ensure that all reviews were picked up for processing exactly once. - uniqueIDs := slices.Clone(allIDs) - slices.Sort(uniqueIDs) - uniqueIDs = slices.Compact(uniqueIDs) - require.Equal(t, 100, len(uniqueIDs), "number of unique reviews picked up for processing should be 100") + allIDs := make([]uint64, len(allReviews)) + for i, r := range allReviews { + allIDs[i] = r.ID + } + slices.Sort(allIDs) + allIDs = slices.Compact(allIDs) + require.Equal(t, 100, len(allIDs), "number of unique reviews picked up for processing should be 100") // Wait for all reviews to be processed asynchronously. worker.Wait() @@ -718,4 +722,161 @@ func TestLockForUpdates(t *testing.T) { require.NoError(t, err, "count reviews") require.Zero(t, count, "there should be no reviews stuck in 'processing' status") }) + + t.Run("mutateFn error rolls back", func(t *testing.T) { + truncateAllTables(t) + + account := &Account{Name: "Rollback Account"} + err := db.Accounts.Save(ctx, account) + require.NoError(t, err) + + article := &Article{AccountID: account.ID, Author: "Author"} + err = db.Articles.Save(ctx, article) + require.NoError(t, err) + + for i := range 5 { + err := db.Reviews.Save(ctx, &Review{ + Comment: fmt.Sprintf("Rollback comment %d", i), + AccountID: account.ID, + ArticleID: article.ID, + Status: ReviewStatusPending, + }) + require.NoError(t, err) + } + + callCount := 0 + records, err := db.Reviews.ClaimForUpdate(ctx, sq.Eq{"status": ReviewStatusPending}, nil, 5, func(review *Review) error { + callCount++ + if callCount == 3 { + return fmt.Errorf("deliberate error on record 3") + } + review.Status = ReviewStatusProcessing + return nil + }) + require.Error(t, err, "should return mutateFn error") + require.Nil(t, records, "no records on error") + + // All reviews should still be pending (transaction rolled back). + count, err := db.Reviews.Count(ctx, sq.Eq{"status": ReviewStatusPending}) + require.NoError(t, err) + require.Equal(t, uint64(5), count, "all reviews should still be pending after rollback") + }) + + t.Run("empty result returns nil slice", func(t *testing.T) { + truncateAllTables(t) + + records, err := db.Reviews.ClaimForUpdate(ctx, sq.Eq{"status": ReviewStatusPending}, nil, 10, func(review *Review) error { + return nil + }) + require.NoError(t, err) + require.Empty(t, records) + }) + + t.Run("ClaimOneForUpdate no match returns ErrNoRows", func(t *testing.T) { + truncateAllTables(t) + + record, err := db.Reviews.ClaimOneForUpdate(ctx, sq.Eq{"status": ReviewStatusPending}, nil, func(review *Review) error { + return nil + }) + require.ErrorIs(t, err, pgkit.ErrNoRows) + require.Nil(t, record) + }) +} + +func TestUpsert(t *testing.T) { + truncateAllTables(t) + + ctx := t.Context() + db := initDB(DB) + + // Add a unique index on accounts.name for upsert testing. + _, err := DB.Conn.Exec(ctx, "CREATE UNIQUE INDEX IF NOT EXISTS accounts_name_unique ON accounts (name)") + require.NoError(t, err) + t.Cleanup(func() { + DB.Conn.Exec(context.Background(), "DROP INDEX IF EXISTS accounts_name_unique") //nolint:errcheck + }) + + t.Run("insert new record", func(t *testing.T) { + truncateAllTables(t) + + err := db.Accounts.Upsert(ctx, []string{"name"}, nil, &Account{Name: "Upsert New"}) + require.NoError(t, err) + + count, err := db.Accounts.Count(ctx, sq.Eq{"name": "Upsert New"}) + require.NoError(t, err) + require.Equal(t, uint64(1), count) + }) + + t.Run("DO NOTHING on conflict", func(t *testing.T) { + truncateAllTables(t) + + original := &Account{Name: "DoNothing"} + err := db.Accounts.Insert(ctx, original) + require.NoError(t, err) + + // Upsert with DO NOTHING — original should be preserved. + err = db.Accounts.Upsert(ctx, []string{"name"}, nil, &Account{Name: "DoNothing", Disabled: true}) + require.NoError(t, err) + + got, err := db.Accounts.GetByID(ctx, original.ID) + require.NoError(t, err) + require.False(t, got.Disabled, "original data should be preserved with DO NOTHING") + + count, err := db.Accounts.Count(ctx, sq.Eq{"name": "DoNothing"}) + require.NoError(t, err) + require.Equal(t, uint64(1), count, "should not create duplicate") + }) + + t.Run("DO UPDATE on conflict", func(t *testing.T) { + truncateAllTables(t) + + original := &Account{Name: "DoUpdate"} + err := db.Accounts.Insert(ctx, original) + require.NoError(t, err) + require.False(t, original.Disabled) + + // Upsert with DO UPDATE SET disabled — should update disabled column. + err = db.Accounts.Upsert(ctx, []string{"name"}, []string{"disabled"}, &Account{Name: "DoUpdate", Disabled: true}) + require.NoError(t, err) + + got, err := db.Accounts.GetByID(ctx, original.ID) + require.NoError(t, err) + require.True(t, got.Disabled, "disabled should be updated on conflict") + require.True(t, got.UpdatedAt.After(original.UpdatedAt), "updated_at should be bumped on conflict even though caller only listed 'disabled'") + + count, err := db.Accounts.Count(ctx, sq.Eq{"name": "DoUpdate"}) + require.NoError(t, err) + require.Equal(t, uint64(1), count, "should not create duplicate") + }) + + t.Run("concurrent upserts no duplicates", func(t *testing.T) { + truncateAllTables(t) + + var wg sync.WaitGroup + for i := range 10 { + wg.Add(1) + go func() { + defer wg.Done() + err := db.Accounts.Upsert(ctx, []string{"name"}, []string{"disabled"}, &Account{Name: "ConcurrentUpsert", Disabled: i%2 == 0}) + assert.NoError(t, err) + }() + } + wg.Wait() + + count, err := db.Accounts.Count(ctx, sq.Eq{"name": "ConcurrentUpsert"}) + require.NoError(t, err) + require.Equal(t, uint64(1), count, "should have exactly one record after concurrent upserts") + }) + + t.Run("invalid conflict column", func(t *testing.T) { + err := db.Accounts.Upsert(ctx, []string{"nonexistent"}, nil, &Account{Name: "Invalid"}) + require.Error(t, err) + require.Contains(t, err.Error(), "invalid conflict column") + }) + + t.Run("invalid update column", func(t *testing.T) { + err := db.Accounts.Upsert(ctx, []string{"name"}, []string{"nonexistent"}, &Account{Name: "Invalid"}) + require.Error(t, err) + require.Contains(t, err.Error(), "invalid update column") + }) } diff --git a/tests/tables_test.go b/tests/tables_test.go index d734da0..a590303 100644 --- a/tests/tables_test.go +++ b/tests/tables_test.go @@ -22,7 +22,6 @@ type reviewsTable struct { } func (t *reviewsTable) DequeueForProcessing(ctx context.Context, limit uint64) ([]*Review, error) { - var dequeued []*Review where := sq.Eq{ "status": ReviewStatusPending, "deleted_at": nil, @@ -31,16 +30,14 @@ func (t *reviewsTable) DequeueForProcessing(ctx context.Context, limit uint64) ( "created_at ASC", } - err := t.LockForUpdates(ctx, where, orderBy, limit, func(reviews []*Review) { - now := time.Now().UTC() - for _, review := range reviews { - review.Status = ReviewStatusProcessing - review.ProcessedAt = &now - } - dequeued = reviews + now := time.Now().UTC() + dequeued, err := t.ClaimForUpdate(ctx, where, orderBy, limit, func(review *Review) error { + review.Status = ReviewStatusProcessing + review.ProcessedAt = &now + return nil }) if err != nil { - return nil, fmt.Errorf("lock for updates: %w", err) + return nil, fmt.Errorf("claim for update: %w", err) } return dequeued, nil diff --git a/tests/worker_test.go b/tests/worker_test.go index 6bc6416..893bddb 100644 --- a/tests/worker_test.go +++ b/tests/worker_test.go @@ -27,17 +27,18 @@ func (w *Worker) ProcessReview(ctx context.Context, review *Review) (err error) defer func() { // Always update review status to "approved", "rejected" or "failed". noCtx := context.Background() - err = w.DB.Reviews.LockForUpdate(noCtx, sq.Eq{"id": review.ID}, []string{"id DESC"}, func(update *Review) { + _, claimErr := w.DB.Reviews.ClaimOneForUpdate(noCtx, sq.Eq{"id": review.ID}, []string{"id DESC"}, func(update *Review) error { now := time.Now().UTC() update.ProcessedAt = &now if err != nil { update.Status = ReviewStatusFailed - return + } else { + update.Status = review.Status } - update.Status = review.Status + return nil }) - if err != nil { - log.Printf("failed to save review: %v", err) + if claimErr != nil { + log.Printf("failed to save review: %v", claimErr) } }()