Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
181 changes: 135 additions & 46 deletions table.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"fmt"
"iter"
"slices"
"strings"
"time"

sq "github.com/Masterminds/squirrel"
Expand Down Expand Up @@ -139,6 +140,89 @@ func (t *Table[T, P, I]) insertAll(ctx context.Context, records []P) error {
return nil
}

// Upsert inserts records with ON CONFLICT handling. conflictColumns specify the unique
// constraint columns. If updateColumns is nil or empty, uses DO NOTHING. Otherwise,
// uses DO UPDATE SET to update the specified columns on conflict.
//
// Does not use RETURNING — with DO NOTHING the conflicting row may not be returned.
func (t *Table[T, P, I]) Upsert(ctx context.Context, conflictColumns []string, updateColumns []string, records ...P) error {
if len(records) == 0 {
return nil
}
if len(conflictColumns) == 0 {
return fmt.Errorf("upsert: conflictColumns must not be empty")
}

now := time.Now().UTC()
for i, r := range records {
if r == nil {
return fmt.Errorf("record with index=%d is nil", i)
}
if err := r.Validate(); err != nil {
return fmt.Errorf("validate record: %w", err)
}
if row, ok := any(r).(HasSetCreatedAt); ok {
row.SetCreatedAt(now)
}
if row, ok := any(r).(HasSetUpdatedAt); ok {
row.SetUpdatedAt(now)
}
}

// Validate column names against the first record's mapped columns.
cols, _, err := Map(records[0])
if err != nil {
return fmt.Errorf("upsert: map record: %w", err)
}
colSet := make(map[string]struct{}, len(cols))
for _, c := range cols {
colSet[c] = struct{}{}
}
for _, c := range conflictColumns {
if _, ok := colSet[c]; !ok {
return fmt.Errorf("upsert: invalid conflict column %q", c)
}
}
for _, c := range updateColumns {
if _, ok := colSet[c]; !ok {
return fmt.Errorf("upsert: invalid update column %q", c)
}
}

// Auto-include updated_at in DO UPDATE when the record tracks update time.
if len(updateColumns) > 0 {
if _, ok := any(records[0]).(HasSetUpdatedAt); ok {
if _, exists := colSet["updated_at"]; exists && !slices.Contains(updateColumns, "updated_at") {
updateColumns = append(updateColumns, "updated_at")
}
}
}

// Build ON CONFLICT suffix.
var suffix string
if len(updateColumns) == 0 {
suffix = fmt.Sprintf("ON CONFLICT (%s) DO NOTHING", strings.Join(conflictColumns, ", "))
} else {
sets := make([]string, len(updateColumns))
for i, c := range updateColumns {
sets[i] = fmt.Sprintf("%s = EXCLUDED.%s", c, c)
}
suffix = fmt.Sprintf("ON CONFLICT (%s) DO UPDATE SET %s", strings.Join(conflictColumns, ", "), strings.Join(sets, ", "))
}

for start := 0; start < len(records); start += chunkSize {
end := min(start+chunkSize, len(records))
chunk := records[start:end]

q := t.SQL.InsertRecords(chunk).Into(t.Name).SuffixExpr(sq.Expr(suffix))
if _, err := t.Query.Exec(ctx, q); err != nil {
return fmt.Errorf("upsert records: %w", err)
}
}

return nil
}

// Update updates one or more records by their ID. Sets UpdatedAt timestamp if available.
// Returns (true, nil) if at least one row was updated, (false, nil) if no rows matched.
func (t *Table[T, P, I]) Update(ctx context.Context, records ...P) (bool, error) {
Expand Down Expand Up @@ -536,61 +620,62 @@ func (t *Table[T, P, I]) WithTx(tx pgx.Tx) *Table[T, P, I] {
}
}

// LockForUpdate locks and updates one record using PostgreSQL's FOR UPDATE SKIP LOCKED pattern
// within a database transaction for safe concurrent processing. The record is processed exactly
// once across multiple workers. The record is automatically updated after updateFn() completes.
// ClaimForUpdate locks matching rows with FOR UPDATE SKIP LOCKED, calls mutateFn
// on each record, and saves all records within the transaction.
//
// When no existing transaction is present, ClaimForUpdate creates and commits its own
// transaction — records are returned after commit, safe for processing outside the tx.
//
// Keep updateFn() fast to avoid holding the transaction. For long-running work, update status
// to "processing" and return early, then process asynchronously. Use defer LockForUpdate()
// to update status to "completed" or "failed".
// When called on a table bound to an existing transaction (via WithTx), the caller
// controls commit/rollback. Records are returned after the mutations are persisted
// within the tx but before commit — the caller must commit the tx to finalize.
//
// Returns ErrNoRows if no matching records are available for locking.
func (t *Table[T, P, I]) LockForUpdate(ctx context.Context, where sq.Sqlizer, orderBy []string, updateFn func(record P)) error {
var noRows bool
// If mutateFn returns an error, the transaction is rolled back (or left for the caller
// to roll back in the WithTx case) and no records are returned.
func (t *Table[T, P, I]) ClaimForUpdate(ctx context.Context, where sq.Sqlizer, orderBy []string, limit uint64, mutateFn func(record P) error) ([]P, error) {
var claimed []P

err := t.LockForUpdates(ctx, where, orderBy, 1, func(records []P) {
if len(records) > 0 {
updateFn(records[0])
} else {
noRows = true
claimWithTx := func(pgTx pgx.Tx) error {
records, err := t.claimForUpdateWithTx(ctx, pgTx, where, orderBy, limit, mutateFn)
if err != nil {
return err
}
})
if err != nil {
return err //nolint:wrapcheck
}

if noRows {
return ErrNoRows
claimed = records
return nil
}

return nil
}

// LockForUpdates locks and updates records using PostgreSQL's FOR UPDATE SKIP LOCKED pattern
// within a database transaction for safe concurrent processing. Each record is processed exactly
// once across multiple workers. Records are automatically updated after updateFn() completes.
//
// Keep updateFn() fast to avoid holding the transaction. For long-running work, update status
// to "processing" and return early, then process asynchronously. Use defer LockForUpdate()
// to update status to "completed" or "failed".
func (t *Table[T, P, I]) LockForUpdates(ctx context.Context, where sq.Sqlizer, orderBy []string, limit uint64, updateFn func(records []P)) error {
// Reuse existing transaction if available.
if t.DB.Query.Tx != nil {
if err := t.lockForUpdatesWithTx(ctx, t.DB.Query.Tx, where, orderBy, limit, updateFn); err != nil {
return fmt.Errorf("lock for update (with tx): %w", err)
if err := claimWithTx(t.DB.Query.Tx); err != nil {
return nil, fmt.Errorf("claim for update (with tx): %w", err)
}
return nil
return claimed, nil
}

return pgx.BeginFunc(ctx, t.DB.Conn, func(pgTx pgx.Tx) error {
if err := t.lockForUpdatesWithTx(ctx, pgTx, where, orderBy, limit, updateFn); err != nil {
return fmt.Errorf("lock for update (new tx): %w", err)
}
return nil
err := pgx.BeginFunc(ctx, t.DB.Conn, func(pgTx pgx.Tx) error {
return claimWithTx(pgTx)
})
if err != nil {
return nil, fmt.Errorf("claim for update (new tx): %w", err)
}

return claimed, nil
}

// ClaimOneForUpdate is ClaimForUpdate with limit=1. Returns the single record
// or ErrNoRows if nothing matched.
func (t *Table[T, P, I]) ClaimOneForUpdate(ctx context.Context, where sq.Sqlizer, orderBy []string, mutateFn func(record P) error) (P, error) {
records, err := t.ClaimForUpdate(ctx, where, orderBy, 1, mutateFn)
if err != nil {
return nil, err //nolint:wrapcheck
}
if len(records) == 0 {
return nil, ErrNoRows
}
return records[0], nil
}

func (t *Table[T, P, I]) lockForUpdatesWithTx(ctx context.Context, pgTx pgx.Tx, where sq.Sqlizer, orderBy []string, limit uint64, updateFn func(records []P)) error {
func (t *Table[T, P, I]) claimForUpdateWithTx(ctx context.Context, pgTx pgx.Tx, where sq.Sqlizer, orderBy []string, limit uint64, mutateFn func(record P) error) ([]P, error) {
if len(orderBy) == 0 {
orderBy = []string{t.IDColumn}
}
Expand All @@ -607,24 +692,28 @@ func (t *Table[T, P, I]) lockForUpdatesWithTx(ctx context.Context, pgTx pgx.Tx,

var records []P
if err := txQuery.GetAll(ctx, q, &records); err != nil {
return fmt.Errorf("select for update skip locked: %w", err)
return nil, fmt.Errorf("select for update skip locked: %w", err)
}

updateFn(records)
for _, record := range records {
if err := mutateFn(record); err != nil {
return nil, fmt.Errorf("mutate record: %w", err)
}
}

now := time.Now().UTC()
for _, record := range records {
if err := record.Validate(); err != nil {
return fmt.Errorf("validate record after update: %w", err)
return nil, fmt.Errorf("validate record after update: %w", err)
}
if row, ok := any(record).(HasSetUpdatedAt); ok {
row.SetUpdatedAt(now)
}
q := t.SQL.UpdateRecord(record, sq.Eq{t.IDColumn: record.GetID()}, t.Name)
if _, err := txQuery.Exec(ctx, q); err != nil {
return fmt.Errorf("update record: %w", err)
return nil, fmt.Errorf("update record: %w", err)
}
}

return nil
return records, nil
}
Loading
Loading