Skip to content

Commit 211e63f

Browse files
committed
Tests: Implement in-memory worker pattern via simple WaitGroup
1 parent c5669d4 commit 211e63f

4 files changed

Lines changed: 106 additions & 77 deletions

File tree

table.go

Lines changed: 37 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -180,11 +180,43 @@ func (t *Table[T, PT, IDT]) WithTx(tx pgx.Tx) *Table[T, PT, IDT] {
180180
}
181181
}
182182

183-
// LockForUpdates locks and processes records using PostgreSQL's FOR UPDATE SKIP LOCKED pattern
184-
// for safe concurrent processing. Each record is processed exactly once across multiple workers.
185-
// Records are automatically updated after updateFn() completes. Keep updateFn() fast to avoid
186-
// holding the transaction. For long-running work, update status to "processing" and return early,
187-
// then process asynchronously. Use defer LockOneForUpdate() to update status to "completed" or "failed".
183+
// LockForUpdate locks and updates one record using PostgreSQL's FOR UPDATE SKIP LOCKED pattern
184+
// within a database transaction for safe concurrent processing. The record is processed exactly
185+
// once across multiple workers. The record is automatically updated after updateFn() completes.
186+
//
187+
// Keep updateFn() fast to avoid holding the transaction. For long-running work, update status
188+
// to "processing" and return early, then process asynchronously. Use defer LockForUpdate()
189+
// to update status to "completed" or "failed".
190+
//
191+
// Returns ErrNoRows if no matching records are available for locking.
192+
func (t *Table[T, PT, IDT]) LockForUpdate(ctx context.Context, where sq.Sqlizer, orderBy []string, updateFn func(record PT)) error {
193+
var noRows bool
194+
195+
err := t.LockForUpdates(ctx, where, orderBy, 1, func(records []PT) {
196+
if len(records) > 0 {
197+
updateFn(records[0])
198+
} else {
199+
noRows = true
200+
}
201+
})
202+
if err != nil {
203+
return fmt.Errorf("lock for update one: %w", err)
204+
}
205+
206+
if noRows {
207+
return ErrNoRows
208+
}
209+
210+
return nil
211+
}
212+
213+
// LockForUpdates locks and updates records using PostgreSQL's FOR UPDATE SKIP LOCKED pattern
214+
// within a database transaction for safe concurrent processing. Each record is processed exactly
215+
// once across multiple workers. Records are automatically updated after updateFn() completes.
216+
//
217+
// Keep updateFn() fast to avoid holding the transaction. For long-running work, update status
218+
// to "processing" and return early, then process asynchronously. Use defer LockForUpdate()
219+
// to update status to "completed" or "failed".
188220
func (t *Table[T, PT, IDT]) LockForUpdates(ctx context.Context, where sq.Sqlizer, orderBy []string, limit uint64, updateFn func(records []PT)) error {
189221
// Check if we're already in a transaction
190222
if t.DB.Query.Tx != nil {
@@ -227,31 +259,3 @@ func (t *Table[T, PT, IDT]) lockForUpdatesWithTx(ctx context.Context, pgTx pgx.T
227259

228260
return nil
229261
}
230-
231-
// LockForUpdate locks and processes one record using PostgreSQL's FOR UPDATE SKIP LOCKED pattern
232-
// for safe concurrent processing. The record is processed exactly once across multiple workers.
233-
// The record is automatically updated after updateFn() completes. Keep updateFn() fast to avoid
234-
// holding the transaction. For long-running work, update status to "processing" and return early,
235-
// then process asynchronously. Use defer LockForUpdate() to update status to "completed" or "failed".
236-
//
237-
// Returns ErrNoRows if no matching records are available for locking.
238-
func (t *Table[T, PT, IDT]) LockForUpdate(ctx context.Context, where sq.Sqlizer, orderBy []string, updateFn func(record PT)) error {
239-
var noRows bool
240-
241-
err := t.LockForUpdates(ctx, where, orderBy, 1, func(records []PT) {
242-
if len(records) > 0 {
243-
updateFn(records[0])
244-
} else {
245-
noRows = true
246-
}
247-
})
248-
if err != nil {
249-
return fmt.Errorf("lock for update one: %w", err)
250-
}
251-
252-
if noRows {
253-
return ErrNoRows
254-
}
255-
256-
return nil
257-
}

tests/database_test.go

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,8 @@ func (db *Database) WithTx(tx pgx.Tx) *Database {
4141
return initDB(pgkitDB)
4242
}
4343

44+
func (db *Database) Close() { db.DB.Conn.Close() }
45+
4446
type accountsTable struct {
4547
*pgkit.Table[Account, *Account, int64]
4648
}

tests/table_test.go

Lines changed: 3 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,7 @@
11
package pgkit_test
22

33
import (
4-
"context"
54
"fmt"
6-
"log"
7-
"math/rand"
85
"slices"
96
"sync"
107
"testing"
@@ -160,6 +157,7 @@ func TestLockForUpdates(t *testing.T) {
160157

161158
ctx := t.Context()
162159
db := initDB(DB)
160+
worker := &Worker{DB: db}
163161

164162
t.Run("TestLockForUpdates", func(t *testing.T) {
165163
// Create account.
@@ -216,7 +214,7 @@ func TestLockForUpdates(t *testing.T) {
216214
require.NoError(t, err, "lock for update")
217215

218216
for _, review := range processReviews {
219-
go processReviewAsynchronously(ctx, db, review)
217+
go worker.ProcessReview(ctx, review)
220218
}
221219

222220
for i, review := range processReviews {
@@ -233,50 +231,11 @@ func TestLockForUpdates(t *testing.T) {
233231
require.Equal(t, 100, len(uniqueIDs), "number of unique reviews picked up for processing should be 100")
234232

235233
// Wait for all reviews to be processed asynchronously.
236-
time.Sleep(2 * time.Second)
234+
worker.Wait()
237235

238236
// Double check there's no reviews stuck in "processing" status.
239237
count, err := db.Reviews.Count(ctx, sq.Eq{"status": ReviewStatusProcessing})
240238
require.NoError(t, err, "count reviews")
241239
require.Zero(t, count, "there should be no reviews stuck in 'processing' status")
242240
})
243241
}
244-
245-
func processReviewAsynchronously(ctx context.Context, db *Database, review *Review) (err error) {
246-
defer func() {
247-
// Always update status to "approved", "rejected" or "failed".
248-
noCtx := context.Background()
249-
err = db.Reviews.LockForUpdate(noCtx, sq.Eq{"id": review.ID}, []string{"id DESC"}, func(update *Review) {
250-
now := time.Now().UTC()
251-
update.ProcessedAt = &now
252-
if err != nil {
253-
update.Status = ReviewStatusFailed
254-
return
255-
}
256-
update.Status = review.Status
257-
})
258-
if err != nil {
259-
log.Printf("failed to save review: %v", err)
260-
}
261-
}()
262-
263-
// Simulate long-running work.
264-
select {
265-
case <-ctx.Done():
266-
return ctx.Err()
267-
case <-time.After(1 * time.Second):
268-
}
269-
270-
// Simulate external API call to an LLM.
271-
if rand.Intn(2) == 0 {
272-
return fmt.Errorf("failed to process review: <some underlying error>")
273-
}
274-
275-
review.Status = ReviewStatusApproved
276-
if rand.Intn(2) == 0 {
277-
review.Status = ReviewStatusRejected
278-
}
279-
now := time.Now().UTC()
280-
review.ProcessedAt = &now
281-
return nil
282-
}

tests/worker_test.go

Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,64 @@
1+
package pgkit_test
2+
3+
import (
4+
"context"
5+
"fmt"
6+
"log"
7+
"math/rand"
8+
"sync"
9+
"time"
10+
11+
sq "github.com/Masterminds/squirrel"
12+
)
13+
14+
type Worker struct {
15+
DB *Database
16+
17+
wg sync.WaitGroup
18+
}
19+
20+
func (w *Worker) Wait() {
21+
w.wg.Wait()
22+
}
23+
24+
func (w *Worker) ProcessReview(ctx context.Context, review *Review) (err error) {
25+
w.wg.Add(1)
26+
defer w.wg.Done()
27+
28+
defer func() {
29+
// Always update review status to "approved", "rejected" or "failed".
30+
noCtx := context.Background()
31+
err = w.DB.Reviews.LockForUpdate(noCtx, sq.Eq{"id": review.ID}, []string{"id DESC"}, func(update *Review) {
32+
now := time.Now().UTC()
33+
update.ProcessedAt = &now
34+
if err != nil {
35+
update.Status = ReviewStatusFailed
36+
return
37+
}
38+
update.Status = review.Status
39+
})
40+
if err != nil {
41+
log.Printf("failed to save review: %v", err)
42+
}
43+
}()
44+
45+
// Simulate long-running work.
46+
select {
47+
case <-ctx.Done():
48+
return ctx.Err()
49+
case <-time.After(1 * time.Second):
50+
}
51+
52+
// Simulate external API call to an LLM.
53+
if rand.Intn(2) == 0 {
54+
return fmt.Errorf("failed to process review: <some underlying error>")
55+
}
56+
57+
review.Status = ReviewStatusApproved
58+
if rand.Intn(2) == 0 {
59+
review.Status = ReviewStatusRejected
60+
}
61+
now := time.Now().UTC()
62+
review.ProcessedAt = &now
63+
return nil
64+
}

0 commit comments

Comments
 (0)