Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
58 changes: 55 additions & 3 deletions table.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,48 @@ import (
"github.com/jackc/pgx/v5"
)

// AfterScanner is implemented by structs that need post-scan field hydration.
type AfterScanner interface {
AfterScan() error
}

// AfterScanError is returned when one or more records fail AfterScan.
// The Errors map is keyed by the record's ID.
type AfterScanError[I ID] struct {
Errors map[I]error
}

func (e *AfterScanError[I]) Error() string {
msg := fmt.Sprintf("after scan: %d error(s)", len(e.Errors))
for id, err := range e.Errors {
msg += fmt.Sprintf("\n- %v: %v", id, err)
}
return msg
}

func (e *AfterScanError[I]) Unwrap() []error {
errs := make([]error, 0, len(e.Errors))
for id, err := range e.Errors {
errs = append(errs, fmt.Errorf("%v: %w", id, err))
}
return errs
}

func afterScanAll[T any, I ID](records []T, key func(T) I) error {
errs := make(map[I]error)
for _, r := range records {
if as, ok := any(r).(AfterScanner); ok {
if err := as.AfterScan(); err != nil {
errs[key(r)] = err
}
}
}
if len(errs) > 0 {
return &AfterScanError[I]{Errors: errs}
}
return nil
}

// ID is a comparable type used for record IDs.
type ID comparable

Expand Down Expand Up @@ -371,6 +413,11 @@ func (t *Table[T, P, I]) Get(ctx context.Context, where sq.Sqlizer, orderBy []st
if err := t.Query.GetOne(ctx, q, record); err != nil {
return nil, fmt.Errorf("get record: %w", err)
}
if as, ok := any(record).(AfterScanner); ok {
if err := as.AfterScan(); err != nil {
return nil, fmt.Errorf("after scan: %w", err)
}
}

return record, nil
}
Expand All @@ -382,8 +429,7 @@ func (t *Table[T, P, I]) List(ctx context.Context, where sq.Sqlizer, orderBy []s
if err := t.Query.GetAll(ctx, q, &records); err != nil {
return nil, err
}

return records, nil
return records, afterScanAll(records, P.GetID)
}

// ListPaged returns paginated records matching the condition.
Expand All @@ -402,7 +448,7 @@ func (t *Table[T, P, I]) ListPaged(ctx context.Context, where sq.Sqlizer, page *
return nil, nil, err
}
result = t.Paginator.PrepareResult(result, page)
return result, page, nil
return result, page, afterScanAll(result, P.GetID)
}

// Iter returns an iterator for records matching the condition.
Expand All @@ -421,6 +467,12 @@ func (t *Table[T, P, I]) Iter(ctx context.Context, where sq.Sqlizer, orderBy []s
yield(nil, err)
return
}
if as, ok := any(&record).(AfterScanner); ok {
if err := as.AfterScan(); err != nil {
yield(nil, fmt.Errorf("after scan: %w", err))
return
}
}
if !yield(&record, nil) {
return
}
Expand Down
67 changes: 67 additions & 0 deletions tests/pgkit_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"bufio"
"bytes"
"context"
"errors"
"encoding/json"
"fmt"
"io"
Expand Down Expand Up @@ -1145,6 +1146,72 @@ func TestSlogTracerBatchQuery(t *testing.T) {
}
}

func TestAfterScan(t *testing.T) {
truncateAllTables(t)

ctx := t.Context()

hookTable := pgkit.Table[AccountWithHook, *AccountWithHook, int64]{DB: DB, Name: "accounts", IDColumn: "id"}
failTable := pgkit.Table[AccountWithFailingHook, *AccountWithFailingHook, int64]{DB: DB, Name: "accounts", IDColumn: "id"}
plainTable := pgkit.Table[Account, *Account, int64]{DB: DB, Name: "accounts", IDColumn: "id"}

// Seed data: "alice" and "bob" succeed AfterScan, "fail" triggers error.
require.NoError(t, hookTable.Save(ctx, &AccountWithHook{Account: Account{Name: "alice"}}))
require.NoError(t, hookTable.Save(ctx, &AccountWithHook{Account: Account{Name: "bob"}}))
require.NoError(t, hookTable.Save(ctx, &AccountWithHook{Account: Account{Name: "fail"}}))

t.Run("Get", func(t *testing.T) {
acc, err := hookTable.Get(ctx, sq.Eq{"name": "alice"}, nil)
require.NoError(t, err)
assert.Equal(t, "ALICE", acc.UpperName)
})

t.Run("List", func(t *testing.T) {
accs, err := hookTable.List(ctx, nil, []string{"name"})
require.NoError(t, err)
require.Len(t, accs, 3)
assert.Equal(t, "ALICE", accs[0].UpperName)
assert.Equal(t, "BOB", accs[1].UpperName)
assert.Equal(t, "FAIL", accs[2].UpperName)
})

t.Run("NoHook", func(t *testing.T) {
// Account does not implement AfterScanner — should work unchanged.
acc, err := plainTable.Get(ctx, sq.Eq{"name": "alice"}, nil)
require.NoError(t, err)
assert.Equal(t, "alice", acc.Name)
})

t.Run("GetErrorPropagation", func(t *testing.T) {
_, err := failTable.Get(ctx, sq.Eq{"name": "fail"}, nil)
require.ErrorContains(t, err, "after scan boom")
})

t.Run("ListPartialFailure", func(t *testing.T) {
// "alice" and "bob" succeed, "fail" fails — all three returned.
accs, err := failTable.List(ctx, nil, []string{"name"})
require.Error(t, err)
require.Len(t, accs, 3, "all records returned despite error")

var scanErr *pgkit.AfterScanError[int64]
require.True(t, errors.As(err, &scanErr))
require.Len(t, scanErr.Errors, 1)

// Error keyed by the record's ID, not slice index.
failAcc := accs[2] // "fail" sorted last (alice, bob, fail)
_, ok := scanErr.Errors[failAcc.ID]
assert.True(t, ok, "error keyed by record ID")
})

t.Run("UnwrapTransitive", func(t *testing.T) {
_, err := failTable.List(ctx, nil, []string{"name"})
require.Error(t, err)

// errors.Is works transitively via Unwrap() []error.
assert.True(t, errors.Is(err, errAfterScanBoom))
})
}

func getTracer(opts []tracer.Option) (*bytes.Buffer, *tracer.LogTracer) {
buf := &bytes.Buffer{}
handler := slog.NewJSONHandler(buf, nil)
Expand Down
28 changes: 28 additions & 0 deletions tests/schema_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package pgkit_test

import (
"fmt"
"strings"
"time"

"github.com/goware/pgkit/v2/dbtype"
Expand Down Expand Up @@ -117,3 +118,30 @@ type Stat struct {
Num dbtype.BigInt `db:"big_num"` // using NUMERIC(78,0) postgres datatype
Rating dbtype.BigInt `db:"rating"` // using NUMERIC(78,0) postgres datatype
}

type AccountWithHook struct {
Account
UpperName string `db:"-"`
}

func (a *AccountWithHook) AfterScan() error {
a.UpperName = strings.ToUpper(a.Name)
return nil
}

func (a *AccountWithHook) DBTableName() string { return "accounts" }

var errAfterScanBoom = fmt.Errorf("after scan boom")

type AccountWithFailingHook struct {
Account
}

func (a *AccountWithFailingHook) AfterScan() error {
if a.Name == "fail" {
return errAfterScanBoom
}
return nil
}

func (a *AccountWithFailingHook) DBTableName() string { return "accounts" }
Loading