From 5ff07661faa784ddbfb7eb77a1bb6b2345a50b58 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Alex=20-=20=E3=82=A2=E3=83=AC=E3=83=83=E3=82=AF=E3=82=B9?= Date: Wed, 8 Apr 2026 12:54:43 +0200 Subject: [PATCH 1/4] feat: add AfterScanner hook for post-scan field hydration Refs #43 --- table.go | 33 +++++++++++++++++++++++++++++++++ tests/pgkit_test.go | 43 +++++++++++++++++++++++++++++++++++++++++++ tests/schema_test.go | 23 +++++++++++++++++++++++ 3 files changed, 99 insertions(+) diff --git a/table.go b/table.go index 3e96fd5..ad85f1e 100644 --- a/table.go +++ b/table.go @@ -12,6 +12,22 @@ import ( "github.com/jackc/pgx/v5" ) +// AfterScanner is implemented by structs that need post-scan field hydration. +type AfterScanner interface { + AfterScan() error +} + +func afterScanAll[T any](records []T) error { + for i := range records { + if as, ok := any(records[i]).(AfterScanner); ok { + if err := as.AfterScan(); err != nil { + return fmt.Errorf("after scan: %w", err) + } + } + } + return nil +} + // ID is a comparable type used for record IDs. type ID comparable @@ -371,6 +387,11 @@ func (t *Table[T, P, I]) Get(ctx context.Context, where sq.Sqlizer, orderBy []st if err := t.Query.GetOne(ctx, q, record); err != nil { return nil, fmt.Errorf("get record: %w", err) } + if as, ok := any(record).(AfterScanner); ok { + if err := as.AfterScan(); err != nil { + return nil, fmt.Errorf("after scan: %w", err) + } + } return record, nil } @@ -382,6 +403,9 @@ func (t *Table[T, P, I]) List(ctx context.Context, where sq.Sqlizer, orderBy []s if err := t.Query.GetAll(ctx, q, &records); err != nil { return nil, err } + if err := afterScanAll(records); err != nil { + return nil, err + } return records, nil } @@ -402,6 +426,9 @@ func (t *Table[T, P, I]) ListPaged(ctx context.Context, where sq.Sqlizer, page * return nil, nil, err } result = t.Paginator.PrepareResult(result, page) + if err := afterScanAll(result); err != nil { + return nil, nil, err + } return result, page, nil } @@ -421,6 +448,12 @@ func (t *Table[T, P, I]) Iter(ctx context.Context, where sq.Sqlizer, orderBy []s yield(nil, err) return } + if as, ok := any(&record).(AfterScanner); ok { + if err := as.AfterScan(); err != nil { + yield(nil, fmt.Errorf("after scan: %w", err)) + return + } + } if !yield(&record, nil) { return } diff --git a/tests/pgkit_test.go b/tests/pgkit_test.go index c6c8530..f51e1e2 100644 --- a/tests/pgkit_test.go +++ b/tests/pgkit_test.go @@ -1145,6 +1145,49 @@ func TestSlogTracerBatchQuery(t *testing.T) { } } +func TestAfterScan(t *testing.T) { + truncateAllTables(t) + + ctx := t.Context() + + hookTable := pgkit.Table[AccountWithHook, *AccountWithHook, int64]{DB: DB, Name: "accounts", IDColumn: "id"} + failTable := pgkit.Table[AccountWithFailingHook, *AccountWithFailingHook, int64]{DB: DB, Name: "accounts", IDColumn: "id"} + plainTable := pgkit.Table[Account, *Account, int64]{DB: DB, Name: "accounts", IDColumn: "id"} + + // Seed data. + require.NoError(t, hookTable.Save(ctx, &AccountWithHook{Account: Account{Name: "alice"}})) + require.NoError(t, hookTable.Save(ctx, &AccountWithHook{Account: Account{Name: "bob"}})) + + t.Run("Get", func(t *testing.T) { + acc, err := hookTable.Get(ctx, sq.Eq{"name": "alice"}, nil) + require.NoError(t, err) + assert.Equal(t, "ALICE", acc.UpperName) + }) + + t.Run("List", func(t *testing.T) { + accs, err := hookTable.List(ctx, nil, []string{"name"}) + require.NoError(t, err) + require.Len(t, accs, 2) + assert.Equal(t, "ALICE", accs[0].UpperName) + assert.Equal(t, "BOB", accs[1].UpperName) + }) + + t.Run("NoHook", func(t *testing.T) { + // Account does not implement AfterScanner — should work unchanged. + acc, err := plainTable.Get(ctx, sq.Eq{"name": "alice"}, nil) + require.NoError(t, err) + assert.Equal(t, "alice", acc.Name) + }) + + t.Run("ErrorPropagation", func(t *testing.T) { + _, err := failTable.Get(ctx, sq.Eq{"name": "alice"}, nil) + require.ErrorContains(t, err, "after scan boom") + + _, err = failTable.List(ctx, nil, nil) + require.ErrorContains(t, err, "after scan boom") + }) +} + func getTracer(opts []tracer.Option) (*bytes.Buffer, *tracer.LogTracer) { buf := &bytes.Buffer{} handler := slog.NewJSONHandler(buf, nil) diff --git a/tests/schema_test.go b/tests/schema_test.go index 05454d8..884f364 100644 --- a/tests/schema_test.go +++ b/tests/schema_test.go @@ -2,6 +2,7 @@ package pgkit_test import ( "fmt" + "strings" "time" "github.com/goware/pgkit/v2/dbtype" @@ -117,3 +118,25 @@ type Stat struct { Num dbtype.BigInt `db:"big_num"` // using NUMERIC(78,0) postgres datatype Rating dbtype.BigInt `db:"rating"` // using NUMERIC(78,0) postgres datatype } + +type AccountWithHook struct { + Account + UpperName string `db:"-"` +} + +func (a *AccountWithHook) AfterScan() error { + a.UpperName = strings.ToUpper(a.Name) + return nil +} + +func (a *AccountWithHook) DBTableName() string { return "accounts" } + +type AccountWithFailingHook struct { + Account +} + +func (a *AccountWithFailingHook) AfterScan() error { + return fmt.Errorf("after scan boom") +} + +func (a *AccountWithFailingHook) DBTableName() string { return "accounts" } From 5d4fd530318d57f337fd2c74c0f40114c07e259a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Alex=20-=20=E3=82=A2=E3=83=AC=E3=83=83=E3=82=AF=E3=82=B9?= Date: Wed, 8 Apr 2026 13:46:06 +0200 Subject: [PATCH 2/4] feat: add AfterScanError for partial failure in List/ListPaged Returns all records even when some AfterScan calls fail. AfterScanError carries a map[int]error keyed by index and implements Unwrap() []error for errors.Is transitivity. Refs #43 --- table.go | 35 +++++++++++++++++++++++++---------- tests/pgkit_test.go | 35 +++++++++++++++++++++++++++++------ tests/schema_test.go | 7 ++++++- 3 files changed, 60 insertions(+), 17 deletions(-) diff --git a/table.go b/table.go index ad85f1e..5a8cb4d 100644 --- a/table.go +++ b/table.go @@ -17,14 +17,36 @@ type AfterScanner interface { AfterScan() error } +// AfterScanError is returned when one or more records fail AfterScan. +// The Errors map is keyed by the record's index in the returned slice. +type AfterScanError struct { + Errors map[int]error +} + +func (e *AfterScanError) Error() string { + return fmt.Sprintf("after scan: %d error(s)", len(e.Errors)) +} + +func (e *AfterScanError) Unwrap() []error { + errs := make([]error, 0, len(e.Errors)) + for _, err := range e.Errors { + errs = append(errs, err) + } + return errs +} + func afterScanAll[T any](records []T) error { + errs := make(map[int]error) for i := range records { if as, ok := any(records[i]).(AfterScanner); ok { if err := as.AfterScan(); err != nil { - return fmt.Errorf("after scan: %w", err) + errs[i] = err } } } + if len(errs) > 0 { + return &AfterScanError{Errors: errs} + } return nil } @@ -403,11 +425,7 @@ func (t *Table[T, P, I]) List(ctx context.Context, where sq.Sqlizer, orderBy []s if err := t.Query.GetAll(ctx, q, &records); err != nil { return nil, err } - if err := afterScanAll(records); err != nil { - return nil, err - } - - return records, nil + return records, afterScanAll(records) } // ListPaged returns paginated records matching the condition. @@ -426,10 +444,7 @@ func (t *Table[T, P, I]) ListPaged(ctx context.Context, where sq.Sqlizer, page * return nil, nil, err } result = t.Paginator.PrepareResult(result, page) - if err := afterScanAll(result); err != nil { - return nil, nil, err - } - return result, page, nil + return result, page, afterScanAll(result) } // Iter returns an iterator for records matching the condition. diff --git a/tests/pgkit_test.go b/tests/pgkit_test.go index f51e1e2..404a6e2 100644 --- a/tests/pgkit_test.go +++ b/tests/pgkit_test.go @@ -4,6 +4,7 @@ import ( "bufio" "bytes" "context" + "errors" "encoding/json" "fmt" "io" @@ -1154,9 +1155,10 @@ func TestAfterScan(t *testing.T) { failTable := pgkit.Table[AccountWithFailingHook, *AccountWithFailingHook, int64]{DB: DB, Name: "accounts", IDColumn: "id"} plainTable := pgkit.Table[Account, *Account, int64]{DB: DB, Name: "accounts", IDColumn: "id"} - // Seed data. + // Seed data: "alice" and "bob" succeed AfterScan, "fail" triggers error. require.NoError(t, hookTable.Save(ctx, &AccountWithHook{Account: Account{Name: "alice"}})) require.NoError(t, hookTable.Save(ctx, &AccountWithHook{Account: Account{Name: "bob"}})) + require.NoError(t, hookTable.Save(ctx, &AccountWithHook{Account: Account{Name: "fail"}})) t.Run("Get", func(t *testing.T) { acc, err := hookTable.Get(ctx, sq.Eq{"name": "alice"}, nil) @@ -1167,9 +1169,10 @@ func TestAfterScan(t *testing.T) { t.Run("List", func(t *testing.T) { accs, err := hookTable.List(ctx, nil, []string{"name"}) require.NoError(t, err) - require.Len(t, accs, 2) + require.Len(t, accs, 3) assert.Equal(t, "ALICE", accs[0].UpperName) assert.Equal(t, "BOB", accs[1].UpperName) + assert.Equal(t, "FAIL", accs[2].UpperName) }) t.Run("NoHook", func(t *testing.T) { @@ -1179,12 +1182,32 @@ func TestAfterScan(t *testing.T) { assert.Equal(t, "alice", acc.Name) }) - t.Run("ErrorPropagation", func(t *testing.T) { - _, err := failTable.Get(ctx, sq.Eq{"name": "alice"}, nil) + t.Run("GetErrorPropagation", func(t *testing.T) { + _, err := failTable.Get(ctx, sq.Eq{"name": "fail"}, nil) require.ErrorContains(t, err, "after scan boom") + }) - _, err = failTable.List(ctx, nil, nil) - require.ErrorContains(t, err, "after scan boom") + t.Run("ListPartialFailure", func(t *testing.T) { + // "alice" and "bob" succeed, "fail" fails — all three returned. + accs, err := failTable.List(ctx, nil, []string{"name"}) + require.Error(t, err) + require.Len(t, accs, 3, "all records returned despite error") + + var scanErr *pgkit.AfterScanError + require.True(t, errors.As(err, &scanErr)) + require.Len(t, scanErr.Errors, 1) + + // "fail" sorts to index 1 (alice=0, fail=1, bob would be... let me order: alice, bob, fail → index 2) + _, ok := scanErr.Errors[2] + assert.True(t, ok, "error keyed by index of failing record") + }) + + t.Run("UnwrapTransitive", func(t *testing.T) { + _, err := failTable.List(ctx, nil, []string{"name"}) + require.Error(t, err) + + // errors.Is works transitively via Unwrap() []error. + assert.True(t, errors.Is(err, errAfterScanBoom)) }) } diff --git a/tests/schema_test.go b/tests/schema_test.go index 884f364..57b5b5b 100644 --- a/tests/schema_test.go +++ b/tests/schema_test.go @@ -131,12 +131,17 @@ func (a *AccountWithHook) AfterScan() error { func (a *AccountWithHook) DBTableName() string { return "accounts" } +var errAfterScanBoom = fmt.Errorf("after scan boom") + type AccountWithFailingHook struct { Account } func (a *AccountWithFailingHook) AfterScan() error { - return fmt.Errorf("after scan boom") + if a.Name == "fail" { + return errAfterScanBoom + } + return nil } func (a *AccountWithFailingHook) DBTableName() string { return "accounts" } From 497cfa23798bf3ef83fbf84f8b0f7a0d2a222cb3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Alex=20-=20=E3=82=A2=E3=83=AC=E3=83=83=E3=82=AF=E3=82=B9?= Date: Wed, 8 Apr 2026 13:59:04 +0200 Subject: [PATCH 3/4] feat: make AfterScanError generic on record ID type Errors map keyed by record ID instead of slice index. Error message lists each failing ID and its error. Refs #43 --- table.go | 32 ++++++++++++++++++-------------- tests/pgkit_test.go | 9 +++++---- 2 files changed, 23 insertions(+), 18 deletions(-) diff --git a/table.go b/table.go index 5a8cb4d..f17c3f1 100644 --- a/table.go +++ b/table.go @@ -18,16 +18,20 @@ type AfterScanner interface { } // AfterScanError is returned when one or more records fail AfterScan. -// The Errors map is keyed by the record's index in the returned slice. -type AfterScanError struct { - Errors map[int]error +// The Errors map is keyed by the record's ID. +type AfterScanError[I ID] struct { + Errors map[I]error } -func (e *AfterScanError) Error() string { - return fmt.Sprintf("after scan: %d error(s)", len(e.Errors)) +func (e *AfterScanError[I]) Error() string { + msg := fmt.Sprintf("after scan: %d error(s)", len(e.Errors)) + for id, err := range e.Errors { + msg += fmt.Sprintf("\n- %v: %v", id, err) + } + return msg } -func (e *AfterScanError) Unwrap() []error { +func (e *AfterScanError[I]) Unwrap() []error { errs := make([]error, 0, len(e.Errors)) for _, err := range e.Errors { errs = append(errs, err) @@ -35,17 +39,17 @@ func (e *AfterScanError) Unwrap() []error { return errs } -func afterScanAll[T any](records []T) error { - errs := make(map[int]error) - for i := range records { - if as, ok := any(records[i]).(AfterScanner); ok { +func afterScanAll[T any, I ID](records []T, key func(T) I) error { + errs := make(map[I]error) + for _, r := range records { + if as, ok := any(r).(AfterScanner); ok { if err := as.AfterScan(); err != nil { - errs[i] = err + errs[key(r)] = err } } } if len(errs) > 0 { - return &AfterScanError{Errors: errs} + return &AfterScanError[I]{Errors: errs} } return nil } @@ -425,7 +429,7 @@ func (t *Table[T, P, I]) List(ctx context.Context, where sq.Sqlizer, orderBy []s if err := t.Query.GetAll(ctx, q, &records); err != nil { return nil, err } - return records, afterScanAll(records) + return records, afterScanAll(records, P.GetID) } // ListPaged returns paginated records matching the condition. @@ -444,7 +448,7 @@ func (t *Table[T, P, I]) ListPaged(ctx context.Context, where sq.Sqlizer, page * return nil, nil, err } result = t.Paginator.PrepareResult(result, page) - return result, page, afterScanAll(result) + return result, page, afterScanAll(result, P.GetID) } // Iter returns an iterator for records matching the condition. diff --git a/tests/pgkit_test.go b/tests/pgkit_test.go index 404a6e2..bdedd1b 100644 --- a/tests/pgkit_test.go +++ b/tests/pgkit_test.go @@ -1193,13 +1193,14 @@ func TestAfterScan(t *testing.T) { require.Error(t, err) require.Len(t, accs, 3, "all records returned despite error") - var scanErr *pgkit.AfterScanError + var scanErr *pgkit.AfterScanError[int64] require.True(t, errors.As(err, &scanErr)) require.Len(t, scanErr.Errors, 1) - // "fail" sorts to index 1 (alice=0, fail=1, bob would be... let me order: alice, bob, fail → index 2) - _, ok := scanErr.Errors[2] - assert.True(t, ok, "error keyed by index of failing record") + // Error keyed by the record's ID, not slice index. + failAcc := accs[2] // "fail" sorted last (alice, bob, fail) + _, ok := scanErr.Errors[failAcc.ID] + assert.True(t, ok, "error keyed by record ID") }) t.Run("UnwrapTransitive", func(t *testing.T) { From 740d77e98755cfd1d5b278033a7b94ad6ed09b73 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Alex=20-=20=E3=82=A2=E3=83=AC=E3=83=83=E3=82=AF=E3=82=B9?= Date: Wed, 8 Apr 2026 14:01:38 +0200 Subject: [PATCH 4/4] fix: include record ID in AfterScanError unwrapped errors Refs #43 --- table.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/table.go b/table.go index f17c3f1..bc4622d 100644 --- a/table.go +++ b/table.go @@ -33,8 +33,8 @@ func (e *AfterScanError[I]) Error() string { func (e *AfterScanError[I]) Unwrap() []error { errs := make([]error, 0, len(e.Errors)) - for _, err := range e.Errors { - errs = append(errs, err) + for id, err := range e.Errors { + errs = append(errs, fmt.Errorf("%v: %w", id, err)) } return errs }