diff --git a/table.go b/table.go index 3e96fd5..bc4622d 100644 --- a/table.go +++ b/table.go @@ -12,6 +12,48 @@ import ( "github.com/jackc/pgx/v5" ) +// AfterScanner is implemented by structs that need post-scan field hydration. +type AfterScanner interface { + AfterScan() error +} + +// AfterScanError is returned when one or more records fail AfterScan. +// The Errors map is keyed by the record's ID. +type AfterScanError[I ID] struct { + Errors map[I]error +} + +func (e *AfterScanError[I]) Error() string { + msg := fmt.Sprintf("after scan: %d error(s)", len(e.Errors)) + for id, err := range e.Errors { + msg += fmt.Sprintf("\n- %v: %v", id, err) + } + return msg +} + +func (e *AfterScanError[I]) Unwrap() []error { + errs := make([]error, 0, len(e.Errors)) + for id, err := range e.Errors { + errs = append(errs, fmt.Errorf("%v: %w", id, err)) + } + return errs +} + +func afterScanAll[T any, I ID](records []T, key func(T) I) error { + errs := make(map[I]error) + for _, r := range records { + if as, ok := any(r).(AfterScanner); ok { + if err := as.AfterScan(); err != nil { + errs[key(r)] = err + } + } + } + if len(errs) > 0 { + return &AfterScanError[I]{Errors: errs} + } + return nil +} + // ID is a comparable type used for record IDs. type ID comparable @@ -371,6 +413,11 @@ func (t *Table[T, P, I]) Get(ctx context.Context, where sq.Sqlizer, orderBy []st if err := t.Query.GetOne(ctx, q, record); err != nil { return nil, fmt.Errorf("get record: %w", err) } + if as, ok := any(record).(AfterScanner); ok { + if err := as.AfterScan(); err != nil { + return nil, fmt.Errorf("after scan: %w", err) + } + } return record, nil } @@ -382,8 +429,7 @@ func (t *Table[T, P, I]) List(ctx context.Context, where sq.Sqlizer, orderBy []s if err := t.Query.GetAll(ctx, q, &records); err != nil { return nil, err } - - return records, nil + return records, afterScanAll(records, P.GetID) } // ListPaged returns paginated records matching the condition. @@ -402,7 +448,7 @@ func (t *Table[T, P, I]) ListPaged(ctx context.Context, where sq.Sqlizer, page * return nil, nil, err } result = t.Paginator.PrepareResult(result, page) - return result, page, nil + return result, page, afterScanAll(result, P.GetID) } // Iter returns an iterator for records matching the condition. @@ -421,6 +467,12 @@ func (t *Table[T, P, I]) Iter(ctx context.Context, where sq.Sqlizer, orderBy []s yield(nil, err) return } + if as, ok := any(&record).(AfterScanner); ok { + if err := as.AfterScan(); err != nil { + yield(nil, fmt.Errorf("after scan: %w", err)) + return + } + } if !yield(&record, nil) { return } diff --git a/tests/pgkit_test.go b/tests/pgkit_test.go index c6c8530..bdedd1b 100644 --- a/tests/pgkit_test.go +++ b/tests/pgkit_test.go @@ -4,6 +4,7 @@ import ( "bufio" "bytes" "context" + "errors" "encoding/json" "fmt" "io" @@ -1145,6 +1146,72 @@ func TestSlogTracerBatchQuery(t *testing.T) { } } +func TestAfterScan(t *testing.T) { + truncateAllTables(t) + + ctx := t.Context() + + hookTable := pgkit.Table[AccountWithHook, *AccountWithHook, int64]{DB: DB, Name: "accounts", IDColumn: "id"} + failTable := pgkit.Table[AccountWithFailingHook, *AccountWithFailingHook, int64]{DB: DB, Name: "accounts", IDColumn: "id"} + plainTable := pgkit.Table[Account, *Account, int64]{DB: DB, Name: "accounts", IDColumn: "id"} + + // Seed data: "alice" and "bob" succeed AfterScan, "fail" triggers error. + require.NoError(t, hookTable.Save(ctx, &AccountWithHook{Account: Account{Name: "alice"}})) + require.NoError(t, hookTable.Save(ctx, &AccountWithHook{Account: Account{Name: "bob"}})) + require.NoError(t, hookTable.Save(ctx, &AccountWithHook{Account: Account{Name: "fail"}})) + + t.Run("Get", func(t *testing.T) { + acc, err := hookTable.Get(ctx, sq.Eq{"name": "alice"}, nil) + require.NoError(t, err) + assert.Equal(t, "ALICE", acc.UpperName) + }) + + t.Run("List", func(t *testing.T) { + accs, err := hookTable.List(ctx, nil, []string{"name"}) + require.NoError(t, err) + require.Len(t, accs, 3) + assert.Equal(t, "ALICE", accs[0].UpperName) + assert.Equal(t, "BOB", accs[1].UpperName) + assert.Equal(t, "FAIL", accs[2].UpperName) + }) + + t.Run("NoHook", func(t *testing.T) { + // Account does not implement AfterScanner — should work unchanged. + acc, err := plainTable.Get(ctx, sq.Eq{"name": "alice"}, nil) + require.NoError(t, err) + assert.Equal(t, "alice", acc.Name) + }) + + t.Run("GetErrorPropagation", func(t *testing.T) { + _, err := failTable.Get(ctx, sq.Eq{"name": "fail"}, nil) + require.ErrorContains(t, err, "after scan boom") + }) + + t.Run("ListPartialFailure", func(t *testing.T) { + // "alice" and "bob" succeed, "fail" fails — all three returned. + accs, err := failTable.List(ctx, nil, []string{"name"}) + require.Error(t, err) + require.Len(t, accs, 3, "all records returned despite error") + + var scanErr *pgkit.AfterScanError[int64] + require.True(t, errors.As(err, &scanErr)) + require.Len(t, scanErr.Errors, 1) + + // Error keyed by the record's ID, not slice index. + failAcc := accs[2] // "fail" sorted last (alice, bob, fail) + _, ok := scanErr.Errors[failAcc.ID] + assert.True(t, ok, "error keyed by record ID") + }) + + t.Run("UnwrapTransitive", func(t *testing.T) { + _, err := failTable.List(ctx, nil, []string{"name"}) + require.Error(t, err) + + // errors.Is works transitively via Unwrap() []error. + assert.True(t, errors.Is(err, errAfterScanBoom)) + }) +} + func getTracer(opts []tracer.Option) (*bytes.Buffer, *tracer.LogTracer) { buf := &bytes.Buffer{} handler := slog.NewJSONHandler(buf, nil) diff --git a/tests/schema_test.go b/tests/schema_test.go index 05454d8..57b5b5b 100644 --- a/tests/schema_test.go +++ b/tests/schema_test.go @@ -2,6 +2,7 @@ package pgkit_test import ( "fmt" + "strings" "time" "github.com/goware/pgkit/v2/dbtype" @@ -117,3 +118,30 @@ type Stat struct { Num dbtype.BigInt `db:"big_num"` // using NUMERIC(78,0) postgres datatype Rating dbtype.BigInt `db:"rating"` // using NUMERIC(78,0) postgres datatype } + +type AccountWithHook struct { + Account + UpperName string `db:"-"` +} + +func (a *AccountWithHook) AfterScan() error { + a.UpperName = strings.ToUpper(a.Name) + return nil +} + +func (a *AccountWithHook) DBTableName() string { return "accounts" } + +var errAfterScanBoom = fmt.Errorf("after scan boom") + +type AccountWithFailingHook struct { + Account +} + +func (a *AccountWithFailingHook) AfterScan() error { + if a.Name == "fail" { + return errAfterScanBoom + } + return nil +} + +func (a *AccountWithFailingHook) DBTableName() string { return "accounts" }