Skip to content

Commit 5ff0766

Browse files
committed
feat: add AfterScanner hook for post-scan field hydration
Refs #43
1 parent 30bbcbe commit 5ff0766

3 files changed

Lines changed: 99 additions & 0 deletions

File tree

table.go

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,22 @@ import (
1212
"github.com/jackc/pgx/v5"
1313
)
1414

15+
// AfterScanner is implemented by structs that need post-scan field hydration.
16+
type AfterScanner interface {
17+
AfterScan() error
18+
}
19+
20+
func afterScanAll[T any](records []T) error {
21+
for i := range records {
22+
if as, ok := any(records[i]).(AfterScanner); ok {
23+
if err := as.AfterScan(); err != nil {
24+
return fmt.Errorf("after scan: %w", err)
25+
}
26+
}
27+
}
28+
return nil
29+
}
30+
1531
// ID is a comparable type used for record IDs.
1632
type ID comparable
1733

@@ -371,6 +387,11 @@ func (t *Table[T, P, I]) Get(ctx context.Context, where sq.Sqlizer, orderBy []st
371387
if err := t.Query.GetOne(ctx, q, record); err != nil {
372388
return nil, fmt.Errorf("get record: %w", err)
373389
}
390+
if as, ok := any(record).(AfterScanner); ok {
391+
if err := as.AfterScan(); err != nil {
392+
return nil, fmt.Errorf("after scan: %w", err)
393+
}
394+
}
374395

375396
return record, nil
376397
}
@@ -382,6 +403,9 @@ func (t *Table[T, P, I]) List(ctx context.Context, where sq.Sqlizer, orderBy []s
382403
if err := t.Query.GetAll(ctx, q, &records); err != nil {
383404
return nil, err
384405
}
406+
if err := afterScanAll(records); err != nil {
407+
return nil, err
408+
}
385409

386410
return records, nil
387411
}
@@ -402,6 +426,9 @@ func (t *Table[T, P, I]) ListPaged(ctx context.Context, where sq.Sqlizer, page *
402426
return nil, nil, err
403427
}
404428
result = t.Paginator.PrepareResult(result, page)
429+
if err := afterScanAll(result); err != nil {
430+
return nil, nil, err
431+
}
405432
return result, page, nil
406433
}
407434

@@ -421,6 +448,12 @@ func (t *Table[T, P, I]) Iter(ctx context.Context, where sq.Sqlizer, orderBy []s
421448
yield(nil, err)
422449
return
423450
}
451+
if as, ok := any(&record).(AfterScanner); ok {
452+
if err := as.AfterScan(); err != nil {
453+
yield(nil, fmt.Errorf("after scan: %w", err))
454+
return
455+
}
456+
}
424457
if !yield(&record, nil) {
425458
return
426459
}

tests/pgkit_test.go

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1145,6 +1145,49 @@ func TestSlogTracerBatchQuery(t *testing.T) {
11451145
}
11461146
}
11471147

1148+
func TestAfterScan(t *testing.T) {
1149+
truncateAllTables(t)
1150+
1151+
ctx := t.Context()
1152+
1153+
hookTable := pgkit.Table[AccountWithHook, *AccountWithHook, int64]{DB: DB, Name: "accounts", IDColumn: "id"}
1154+
failTable := pgkit.Table[AccountWithFailingHook, *AccountWithFailingHook, int64]{DB: DB, Name: "accounts", IDColumn: "id"}
1155+
plainTable := pgkit.Table[Account, *Account, int64]{DB: DB, Name: "accounts", IDColumn: "id"}
1156+
1157+
// Seed data.
1158+
require.NoError(t, hookTable.Save(ctx, &AccountWithHook{Account: Account{Name: "alice"}}))
1159+
require.NoError(t, hookTable.Save(ctx, &AccountWithHook{Account: Account{Name: "bob"}}))
1160+
1161+
t.Run("Get", func(t *testing.T) {
1162+
acc, err := hookTable.Get(ctx, sq.Eq{"name": "alice"}, nil)
1163+
require.NoError(t, err)
1164+
assert.Equal(t, "ALICE", acc.UpperName)
1165+
})
1166+
1167+
t.Run("List", func(t *testing.T) {
1168+
accs, err := hookTable.List(ctx, nil, []string{"name"})
1169+
require.NoError(t, err)
1170+
require.Len(t, accs, 2)
1171+
assert.Equal(t, "ALICE", accs[0].UpperName)
1172+
assert.Equal(t, "BOB", accs[1].UpperName)
1173+
})
1174+
1175+
t.Run("NoHook", func(t *testing.T) {
1176+
// Account does not implement AfterScanner — should work unchanged.
1177+
acc, err := plainTable.Get(ctx, sq.Eq{"name": "alice"}, nil)
1178+
require.NoError(t, err)
1179+
assert.Equal(t, "alice", acc.Name)
1180+
})
1181+
1182+
t.Run("ErrorPropagation", func(t *testing.T) {
1183+
_, err := failTable.Get(ctx, sq.Eq{"name": "alice"}, nil)
1184+
require.ErrorContains(t, err, "after scan boom")
1185+
1186+
_, err = failTable.List(ctx, nil, nil)
1187+
require.ErrorContains(t, err, "after scan boom")
1188+
})
1189+
}
1190+
11481191
func getTracer(opts []tracer.Option) (*bytes.Buffer, *tracer.LogTracer) {
11491192
buf := &bytes.Buffer{}
11501193
handler := slog.NewJSONHandler(buf, nil)

tests/schema_test.go

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ package pgkit_test
22

33
import (
44
"fmt"
5+
"strings"
56
"time"
67

78
"github.com/goware/pgkit/v2/dbtype"
@@ -117,3 +118,25 @@ type Stat struct {
117118
Num dbtype.BigInt `db:"big_num"` // using NUMERIC(78,0) postgres datatype
118119
Rating dbtype.BigInt `db:"rating"` // using NUMERIC(78,0) postgres datatype
119120
}
121+
122+
type AccountWithHook struct {
123+
Account
124+
UpperName string `db:"-"`
125+
}
126+
127+
func (a *AccountWithHook) AfterScan() error {
128+
a.UpperName = strings.ToUpper(a.Name)
129+
return nil
130+
}
131+
132+
func (a *AccountWithHook) DBTableName() string { return "accounts" }
133+
134+
type AccountWithFailingHook struct {
135+
Account
136+
}
137+
138+
func (a *AccountWithFailingHook) AfterScan() error {
139+
return fmt.Errorf("after scan boom")
140+
}
141+
142+
func (a *AccountWithFailingHook) DBTableName() string { return "accounts" }

0 commit comments

Comments
 (0)