Skip to content

Commit ae6bef6

Browse files
committed
feat: add AfterScan hook for post-scan field hydration
Records implementing HasAfterScan get called after every scan from the database (Get, List, ListPaged, Insert RETURNING, Iter, etc.). Use this to populate computed fields not stored in the DB.
1 parent 2fa3605 commit ae6bef6

1 file changed

Lines changed: 30 additions & 0 deletions

File tree

table.go

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,13 @@ type HasSetUpdatedAt interface {
4242
SetUpdatedAt(time.Time)
4343
}
4444

45+
// HasAfterScan is implemented by records that need post-scan hydration.
46+
// Called after every scan from the database (Get, List, ListPaged, Insert RETURNING, etc.).
47+
// Use this to populate computed fields that are not stored in the database.
48+
type HasAfterScan interface {
49+
AfterScan()
50+
}
51+
4552
// HasSetDeletedAt is implemented by records that support soft delete.
4653
// DeleteByID will call SetDeletedAt with the current UTC time to soft-delete,
4754
// and RestoreByID will call SetDeletedAt with a zero time.Time{} to restore.
@@ -97,6 +104,7 @@ func (t *Table[T, P, I]) insertOne(ctx context.Context, record P) error {
97104
if err := t.Query.GetOne(ctx, q, record); err != nil {
98105
return fmt.Errorf("insert record: %w", err)
99106
}
107+
callAfterScan(record)
100108

101109
return nil
102110
}
@@ -136,6 +144,7 @@ func (t *Table[T, P, I]) insertAll(ctx context.Context, records []P) error {
136144
if err := t.Query.GetAll(ctx, q, &chunk); err != nil {
137145
return fmt.Errorf("insert records: %w", err)
138146
}
147+
callAfterScanAll(chunk)
139148
}
140149

141150
return nil
@@ -254,6 +263,7 @@ func (t *Table[T, P, I]) saveOne(ctx context.Context, record P) error {
254263
if err := t.Query.GetOne(ctx, q, record); err != nil {
255264
return fmt.Errorf("save: insert record: %w", err)
256265
}
266+
callAfterScan(record)
257267

258268
return nil
259269
}
@@ -322,6 +332,7 @@ func (t *Table[T, P, I]) saveAll(ctx context.Context, records []P) error {
322332
if err := t.Query.GetAll(ctx, q, &chunk); err != nil {
323333
return fmt.Errorf("insert records: %w", err)
324334
}
335+
callAfterScanAll(chunk)
325336

326337
// update original slice
327338
for i, rr := range chunk {
@@ -363,6 +374,7 @@ func (t *Table[T, P, I]) Get(ctx context.Context, where sq.Sqlizer, orderBy []st
363374
if err := t.Query.GetOne(ctx, q, record); err != nil {
364375
return nil, fmt.Errorf("get record: %w", err)
365376
}
377+
callAfterScan(record)
366378

367379
return record, nil
368380
}
@@ -374,6 +386,7 @@ func (t *Table[T, P, I]) List(ctx context.Context, where sq.Sqlizer, orderBy []s
374386
if err := t.Query.GetAll(ctx, q, &records); err != nil {
375387
return nil, err
376388
}
389+
callAfterScanAll(records)
377390

378391
return records, nil
379392
}
@@ -386,6 +399,7 @@ func (t *Table[T, P, I]) ListPaged(ctx context.Context, where sq.Sqlizer, page *
386399
if err := t.Query.GetAll(ctx, q, &result); err != nil {
387400
return nil, nil, err
388401
}
402+
callAfterScanAll(result)
389403
result = t.Paginator.PrepareResult(result, page)
390404
return result, page, nil
391405
}
@@ -408,6 +422,7 @@ func (t *Table[T, P, I]) Iter(ctx context.Context, where sq.Sqlizer, orderBy []s
408422
}
409423
return
410424
}
425+
callAfterScan(&record)
411426
if !yield(&record, nil) {
412427
return
413428
}
@@ -567,6 +582,20 @@ func (t *Table[T, P, I]) LockForUpdates(ctx context.Context, where sq.Sqlizer, o
567582
})
568583
}
569584

585+
func callAfterScan[P any](record P) {
586+
if row, ok := any(record).(HasAfterScan); ok {
587+
row.AfterScan()
588+
}
589+
}
590+
591+
func callAfterScanAll[P any](records []P) {
592+
for _, r := range records {
593+
if row, ok := any(r).(HasAfterScan); ok {
594+
row.AfterScan()
595+
}
596+
}
597+
}
598+
570599
func (t *Table[T, P, I]) lockForUpdatesWithTx(ctx context.Context, pgTx pgx.Tx, where sq.Sqlizer, orderBy []string, limit uint64, updateFn func(records []P)) error {
571600
if len(orderBy) == 0 {
572601
orderBy = []string{t.IDColumn}
@@ -586,6 +615,7 @@ func (t *Table[T, P, I]) lockForUpdatesWithTx(ctx context.Context, pgTx pgx.Tx,
586615
if err := txQuery.GetAll(ctx, q, &records); err != nil {
587616
return fmt.Errorf("select for update skip locked: %w", err)
588617
}
618+
callAfterScanAll(records)
589619

590620
updateFn(records)
591621

0 commit comments

Comments
 (0)