From d79a6db0ca0d126e26e18ed2e2cd01214e953835 Mon Sep 17 00:00:00 2001 From: Vojtech Vitek Date: Mon, 13 Oct 2025 16:09:54 +0200 Subject: [PATCH 01/34] Implement generic Table for basic CRUD operations Co-authored-by: David Sedlacek --- table.go | 196 +++++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 196 insertions(+) create mode 100644 table.go diff --git a/table.go b/table.go new file mode 100644 index 0000000..023b3c6 --- /dev/null +++ b/table.go @@ -0,0 +1,196 @@ +package pgkit + +import ( + "context" + "fmt" + "time" + + sq "github.com/Masterminds/squirrel" + "github.com/jackc/pgx/v5" +) + +// Table provides basic CRUD operations for database records. +// Records must implement GetID() and Validate() methods. +type Table[T any, PT interface { + *T // Enforce T is a pointer; and thus all methods are defined on a pointer receiver. + GetID() IDT + Validate() error +}, IDT comparable] struct { + *DB + Name string + IDColumn string +} + +type hasUpdatedAt interface { + SetUpdatedAt(time.Time) +} + +type hasDeletedAt interface { + SetDeletedAt(time.Time) +} + +// Save inserts or updates a record. Auto-detects insert vs update by ID. +func (t *Table[T, PT, ID]) Save(ctx context.Context, record PT) error { + if err := record.Validate(); err != nil { + return err //nolint:wrapcheck + } + + if row, ok := any(record).(hasUpdatedAt); ok { + row.SetUpdatedAt(time.Now().UTC()) + } + + // Insert + var zero ID + if record.GetID() == zero { + q := t.SQL.InsertRecord(record).Into(t.Name).Suffix("RETURNING *") + if err := t.Query.GetOne(ctx, q, record); err != nil { + return fmt.Errorf("insert records: %w", err) + } + + return nil + } + + // Update + q := t.SQL.UpdateRecord(record, sq.Eq{t.IDColumn: record.GetID()}, t.Name) + if _, err := t.Query.Exec(ctx, q); err != nil { + return fmt.Errorf("update record: %w", err) + } + + return nil +} + +// SaveAll saves multiple records sequentially. 
+func (t *Table[T, PT, ID]) SaveAll(ctx context.Context, records []PT) error { + for _, record := range records { + if err := t.Save(ctx, record); err != nil { + return err + } + } + + return nil +} + +// GetOne returns the first record matching the condition. +func (t *Table[T, PT, ID]) GetOne(ctx context.Context, cond sq.Sqlizer, orderBy []string) (PT, error) { + cond = t.appendDeletedAtNULL(cond) + + if len(orderBy) == 0 { + orderBy = []string{t.IDColumn} + } + + dest := new(T) + + q := t.SQL. + Select("*"). + From(t.Name). + Where(cond). + Limit(1). + OrderBy(orderBy...) + + if err := t.Query.GetOne(ctx, q, dest); err != nil { + return nil, err + } + + return dest, nil +} + +// GetAll returns all records matching the condition. +func (t *Table[T, PT, ID]) GetAll(ctx context.Context, cond sq.Sqlizer, orderBy []string) ([]PT, error) { + if len(orderBy) == 0 { + orderBy = []string{t.IDColumn} + } + + cond = t.appendDeletedAtNULL(cond) + + q := t.SQL. + Select("*"). + From(t.Name). + Where(cond). + OrderBy(orderBy...) + + var dest []PT + if err := t.Query.GetAll(ctx, q, &dest); err != nil { + return nil, err + } + + return dest, nil +} + +// GetByID returns a record by its ID. +func (t *Table[T, PT, ID]) GetByID(ctx context.Context, id uint64) (PT, error) { + return t.GetOne(ctx, t.appendDeletedAtNULL(sq.Eq{t.IDColumn: id}), []string{t.IDColumn}) +} + +// GetByIDs returns records by their IDs. +func (t *Table[T, PT, ID]) GetByIDs(ctx context.Context, ids []uint64) ([]PT, error) { + return t.GetAll(ctx, t.appendDeletedAtNULL(sq.Eq{t.IDColumn: ids}), nil) +} + +// Count returns the number of matching records. +func (t *Table[T, PT, ID]) Count(ctx context.Context, cond sq.Sqlizer) (uint64, error) { + cond = t.appendDeletedAtNULL(cond) + + var count uint64 + q := t.SQL. + Select("COUNT(1)"). + From(t.Name). 
+ Where(cond) + + if err := t.Query.GetOne(ctx, q, &count); err != nil { + return 0, fmt.Errorf("get one: %w", err) + } + + return count, nil +} + +// DeleteByID deletes a record by ID. Uses soft delete if deleted_at column exists. +func (t *Table[T, PT, ID]) DeleteByID(ctx context.Context, id uint64) error { + resource, err := t.GetByID(ctx, id) + if err != nil { + return err + } + + // Soft delete. + if row, ok := any(resource).(hasDeletedAt); ok { + row.SetDeletedAt(time.Now().UTC()) + return t.Save(ctx, resource) + } + + // Hard delete for tables without timestamps + return t.HardDeleteByID(ctx, id) +} + +// HardDeleteByID permanently deletes a record by ID. +func (t *Table[T, PT, ID]) HardDeleteByID(ctx context.Context, id uint64) error { + _, err := t.SQL.Delete(t.Name).Where(sq.Eq{t.IDColumn: id}).Exec() + return err +} + +func (t *Table[T, PT, ID]) appendDeletedAtNULL(cond sq.Sqlizer) sq.Sqlizer { + var zero PT + if _, ok := any(zero).(hasDeletedAt); ok { + condDeletedAt := sq.Eq{"deleted_at": nil} + if cond == nil { + cond = condDeletedAt + } else { + cond = sq.And{ + cond, + condDeletedAt, + } + } + } + + return cond +} + +// WithTx returns a table instance bound to the given transaction. +func (t *Table[T, TP, ID]) WithTx(tx pgx.Tx) *Table[T, TP, ID] { + return &Table[T, TP, ID]{ + DB: &DB{ + Conn: t.DB.Conn, + SQL: t.DB.SQL, + Query: t.DB.TxQuery(tx), + }, + Name: t.Name, + } +} From c42ba49e83e69e018bf56c94f0e95dea09c1a964 Mon Sep 17 00:00:00 2001 From: Vojtech Vitek Date: Mon, 13 Oct 2025 16:12:36 +0200 Subject: [PATCH 02/34] Remove automatic 'deleted_at NULL' condition --- table.go | 27 ++------------------------- 1 file changed, 2 insertions(+), 25 deletions(-) diff --git a/table.go b/table.go index 023b3c6..5ba74a7 100644 --- a/table.go +++ b/table.go @@ -72,8 +72,6 @@ func (t *Table[T, PT, ID]) SaveAll(ctx context.Context, records []PT) error { // GetOne returns the first record matching the condition. 
func (t *Table[T, PT, ID]) GetOne(ctx context.Context, cond sq.Sqlizer, orderBy []string) (PT, error) { - cond = t.appendDeletedAtNULL(cond) - if len(orderBy) == 0 { orderBy = []string{t.IDColumn} } @@ -100,8 +98,6 @@ func (t *Table[T, PT, ID]) GetAll(ctx context.Context, cond sq.Sqlizer, orderBy orderBy = []string{t.IDColumn} } - cond = t.appendDeletedAtNULL(cond) - q := t.SQL. Select("*"). From(t.Name). @@ -118,18 +114,16 @@ func (t *Table[T, PT, ID]) GetAll(ctx context.Context, cond sq.Sqlizer, orderBy // GetByID returns a record by its ID. func (t *Table[T, PT, ID]) GetByID(ctx context.Context, id uint64) (PT, error) { - return t.GetOne(ctx, t.appendDeletedAtNULL(sq.Eq{t.IDColumn: id}), []string{t.IDColumn}) + return t.GetOne(ctx, sq.Eq{t.IDColumn: id}, []string{t.IDColumn}) } // GetByIDs returns records by their IDs. func (t *Table[T, PT, ID]) GetByIDs(ctx context.Context, ids []uint64) ([]PT, error) { - return t.GetAll(ctx, t.appendDeletedAtNULL(sq.Eq{t.IDColumn: ids}), nil) + return t.GetAll(ctx, sq.Eq{t.IDColumn: ids}, nil) } // Count returns the number of matching records. func (t *Table[T, PT, ID]) Count(ctx context.Context, cond sq.Sqlizer) (uint64, error) { - cond = t.appendDeletedAtNULL(cond) - var count uint64 q := t.SQL. Select("COUNT(1)"). @@ -166,23 +160,6 @@ func (t *Table[T, PT, ID]) HardDeleteByID(ctx context.Context, id uint64) error return err } -func (t *Table[T, PT, ID]) appendDeletedAtNULL(cond sq.Sqlizer) sq.Sqlizer { - var zero PT - if _, ok := any(zero).(hasDeletedAt); ok { - condDeletedAt := sq.Eq{"deleted_at": nil} - if cond == nil { - cond = condDeletedAt - } else { - cond = sq.And{ - cond, - condDeletedAt, - } - } - } - - return cond -} - // WithTx returns a table instance bound to the given transaction. 
func (t *Table[T, TP, ID]) WithTx(tx pgx.Tx) *Table[T, TP, ID] { return &Table[T, TP, ID]{ From 89f19c23d179079e20539852af584e50dcdb3e7b Mon Sep 17 00:00:00 2001 From: Vojtech Vitek Date: Mon, 13 Oct 2025 16:13:51 +0200 Subject: [PATCH 03/34] Implement .LockForUpdate() using PostgreSQL's FOR UPDATE SKIP LOCKED pattern --- table.go | 38 ++++++++++++++++++++++++++++++++++++++ 1 file changed, 38 insertions(+) diff --git a/table.go b/table.go index 5ba74a7..ec5e4a1 100644 --- a/table.go +++ b/table.go @@ -171,3 +171,41 @@ func (t *Table[T, TP, ID]) WithTx(tx pgx.Tx) *Table[T, TP, ID] { Name: t.Name, } } + +// LockForUpdate locks 0..limit records using PostgreSQL's FOR UPDATE SKIP LOCKED pattern +// for safe concurrent processing where each record is processed exactly once. +// Complete updateFn() quickly to avoid holding the transaction. For long-running work: +// update status to "processing" and return early, then process asynchronously. +func (t *Table[T, TP, ID]) LockForUpdate(ctx context.Context, cond sq.Sqlizer, orderBy []string, limit uint64, updateFn func([]TP)) error { + return pgx.BeginFunc(ctx, t.DB.Conn, func(pgTx pgx.Tx) error { + if len(orderBy) == 0 { + orderBy = []string{t.IDColumn} + } + + tx := t.WithTx(pgTx) + + q := tx.SQL. + Select("*"). + From(t.Name). + Where(cond). + OrderBy(orderBy...). + Limit(limit). 
+ Suffix("FOR UPDATE SKIP LOCKED") + + var records []TP + if err := tx.Query.GetAll(ctx, q, &records); err != nil { + return fmt.Errorf("select for update skip locked: %w", err) + } + + updateFn(records) + + for _, record := range records { + q := tx.SQL.UpdateRecord(record, sq.Eq{t.IDColumn: record.GetID()}, t.Name) + if _, err := tx.Query.Exec(ctx, q); err != nil { + return fmt.Errorf("update record: %w", err) + } + } + + return nil + }) +} From a5d9aed4cc31517e6e1c69548bbdd9b8d22fb1b9 Mon Sep 17 00:00:00 2001 From: Vojtech Vitek Date: Mon, 13 Oct 2025 19:47:30 +0200 Subject: [PATCH 04/34] Fix generic ID --- table.go | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/table.go b/table.go index ec5e4a1..34235f3 100644 --- a/table.go +++ b/table.go @@ -113,32 +113,32 @@ func (t *Table[T, PT, ID]) GetAll(ctx context.Context, cond sq.Sqlizer, orderBy } // GetByID returns a record by its ID. -func (t *Table[T, PT, ID]) GetByID(ctx context.Context, id uint64) (PT, error) { +func (t *Table[T, PT, ID]) GetByID(ctx context.Context, id ID) (PT, error) { return t.GetOne(ctx, sq.Eq{t.IDColumn: id}, []string{t.IDColumn}) } // GetByIDs returns records by their IDs. -func (t *Table[T, PT, ID]) GetByIDs(ctx context.Context, ids []uint64) ([]PT, error) { +func (t *Table[T, PT, ID]) GetByIDs(ctx context.Context, ids []ID) ([]PT, error) { return t.GetAll(ctx, sq.Eq{t.IDColumn: ids}, nil) } // Count returns the number of matching records. -func (t *Table[T, PT, ID]) Count(ctx context.Context, cond sq.Sqlizer) (uint64, error) { - var count uint64 +func (t *Table[T, PT, ID]) Count(ctx context.Context, cond sq.Sqlizer) (ID, error) { + var count ID q := t.SQL. Select("COUNT(1)"). From(t.Name). Where(cond) if err := t.Query.GetOne(ctx, q, &count); err != nil { - return 0, fmt.Errorf("get one: %w", err) + return count, fmt.Errorf("get one: %w", err) } return count, nil } // DeleteByID deletes a record by ID. Uses soft delete if deleted_at column exists. 
-func (t *Table[T, PT, ID]) DeleteByID(ctx context.Context, id uint64) error { +func (t *Table[T, PT, ID]) DeleteByID(ctx context.Context, id ID) error { resource, err := t.GetByID(ctx, id) if err != nil { return err @@ -155,7 +155,7 @@ func (t *Table[T, PT, ID]) DeleteByID(ctx context.Context, id uint64) error { } // HardDeleteByID permanently deletes a record by ID. -func (t *Table[T, PT, ID]) HardDeleteByID(ctx context.Context, id uint64) error { +func (t *Table[T, PT, ID]) HardDeleteByID(ctx context.Context, id ID) error { _, err := t.SQL.Delete(t.Name).Where(sq.Eq{t.IDColumn: id}).Exec() return err } From f12a47b69e22a9ed6a219ee0050a7e7c52cb982e Mon Sep 17 00:00:00 2001 From: Vojtech Vitek Date: Mon, 13 Oct 2025 19:51:34 +0200 Subject: [PATCH 05/34] Fix naming of generic types --- table.go | 53 +++++++++++++++++++++++++++-------------------------- 1 file changed, 27 insertions(+), 26 deletions(-) diff --git a/table.go b/table.go index 34235f3..320e23d 100644 --- a/table.go +++ b/table.go @@ -21,26 +21,26 @@ type Table[T any, PT interface { IDColumn string } -type hasUpdatedAt interface { +type hasSetUpdatedAt interface { SetUpdatedAt(time.Time) } -type hasDeletedAt interface { +type hasSetDeletedAt interface { SetDeletedAt(time.Time) } // Save inserts or updates a record. Auto-detects insert vs update by ID. 
-func (t *Table[T, PT, ID]) Save(ctx context.Context, record PT) error { +func (t *Table[T, PT, IDT]) Save(ctx context.Context, record PT) error { if err := record.Validate(); err != nil { return err //nolint:wrapcheck } - if row, ok := any(record).(hasUpdatedAt); ok { + if row, ok := any(record).(hasSetUpdatedAt); ok { row.SetUpdatedAt(time.Now().UTC()) } // Insert - var zero ID + var zero IDT if record.GetID() == zero { q := t.SQL.InsertRecord(record).Into(t.Name).Suffix("RETURNING *") if err := t.Query.GetOne(ctx, q, record); err != nil { @@ -60,9 +60,9 @@ func (t *Table[T, PT, ID]) Save(ctx context.Context, record PT) error { } // SaveAll saves multiple records sequentially. -func (t *Table[T, PT, ID]) SaveAll(ctx context.Context, records []PT) error { - for _, record := range records { - if err := t.Save(ctx, record); err != nil { +func (t *Table[T, PT, IDT]) SaveAll(ctx context.Context, records []PT) error { + for i := range records { + if err := t.Save(ctx, records[i]); err != nil { return err } } @@ -71,7 +71,7 @@ func (t *Table[T, PT, ID]) SaveAll(ctx context.Context, records []PT) error { } // GetOne returns the first record matching the condition. -func (t *Table[T, PT, ID]) GetOne(ctx context.Context, cond sq.Sqlizer, orderBy []string) (PT, error) { +func (t *Table[T, PT, IDT]) GetOne(ctx context.Context, cond sq.Sqlizer, orderBy []string) (PT, error) { if len(orderBy) == 0 { orderBy = []string{t.IDColumn} } @@ -93,7 +93,7 @@ func (t *Table[T, PT, ID]) GetOne(ctx context.Context, cond sq.Sqlizer, orderBy } // GetAll returns all records matching the condition. 
-func (t *Table[T, PT, ID]) GetAll(ctx context.Context, cond sq.Sqlizer, orderBy []string) ([]PT, error) { +func (t *Table[T, PT, IDT]) GetAll(ctx context.Context, cond sq.Sqlizer, orderBy []string) ([]PT, error) { if len(orderBy) == 0 { orderBy = []string{t.IDColumn} } @@ -113,18 +113,18 @@ func (t *Table[T, PT, ID]) GetAll(ctx context.Context, cond sq.Sqlizer, orderBy } // GetByID returns a record by its ID. -func (t *Table[T, PT, ID]) GetByID(ctx context.Context, id ID) (PT, error) { +func (t *Table[T, PT, IDT]) GetByID(ctx context.Context, id IDT) (PT, error) { return t.GetOne(ctx, sq.Eq{t.IDColumn: id}, []string{t.IDColumn}) } // GetByIDs returns records by their IDs. -func (t *Table[T, PT, ID]) GetByIDs(ctx context.Context, ids []ID) ([]PT, error) { +func (t *Table[T, PT, IDT]) GetByIDs(ctx context.Context, ids []IDT) ([]PT, error) { return t.GetAll(ctx, sq.Eq{t.IDColumn: ids}, nil) } // Count returns the number of matching records. -func (t *Table[T, PT, ID]) Count(ctx context.Context, cond sq.Sqlizer) (ID, error) { - var count ID +func (t *Table[T, PT, IDT]) Count(ctx context.Context, cond sq.Sqlizer) (uint64, error) { + var count uint64 q := t.SQL. Select("COUNT(1)"). From(t.Name). @@ -137,32 +137,33 @@ func (t *Table[T, PT, ID]) Count(ctx context.Context, cond sq.Sqlizer) (ID, erro return count, nil } -// DeleteByID deletes a record by ID. Uses soft delete if deleted_at column exists. -func (t *Table[T, PT, ID]) DeleteByID(ctx context.Context, id ID) error { - resource, err := t.GetByID(ctx, id) +// DeleteByID deletes a record by ID. Uses soft delete if .SetDeletedAt() method exists. +func (t *Table[T, PT, IDT]) DeleteByID(ctx context.Context, id IDT) error { + record, err := t.GetByID(ctx, id) if err != nil { return err } // Soft delete. 
- if row, ok := any(resource).(hasDeletedAt); ok { + if row, ok := any(record).(hasSetDeletedAt); ok { row.SetDeletedAt(time.Now().UTC()) - return t.Save(ctx, resource) + return t.Save(ctx, record) } - // Hard delete for tables without timestamps + // Hard delete for tables without timestamps. return t.HardDeleteByID(ctx, id) } // HardDeleteByID permanently deletes a record by ID. -func (t *Table[T, PT, ID]) HardDeleteByID(ctx context.Context, id ID) error { - _, err := t.SQL.Delete(t.Name).Where(sq.Eq{t.IDColumn: id}).Exec() +func (t *Table[T, PT, IDT]) HardDeleteByID(ctx context.Context, id IDT) error { + q := t.SQL.Delete(t.Name).Where(sq.Eq{t.IDColumn: id}) + _, err := t.Query.Exec(ctx, q) return err } // WithTx returns a table instance bound to the given transaction. -func (t *Table[T, TP, ID]) WithTx(tx pgx.Tx) *Table[T, TP, ID] { - return &Table[T, TP, ID]{ +func (t *Table[T, PT, IDT]) WithTx(tx pgx.Tx) *Table[T, PT, IDT] { + return &Table[T, PT, IDT]{ DB: &DB{ Conn: t.DB.Conn, SQL: t.DB.SQL, @@ -176,7 +177,7 @@ func (t *Table[T, TP, ID]) WithTx(tx pgx.Tx) *Table[T, TP, ID] { // for safe concurrent processing where each record is processed exactly once. // Complete updateFn() quickly to avoid holding the transaction. For long-running work: // update status to "processing" and return early, then process asynchronously. -func (t *Table[T, TP, ID]) LockForUpdate(ctx context.Context, cond sq.Sqlizer, orderBy []string, limit uint64, updateFn func([]TP)) error { +func (t *Table[T, PT, IDT]) LockForUpdate(ctx context.Context, cond sq.Sqlizer, orderBy []string, limit uint64, updateFn func([]PT)) error { return pgx.BeginFunc(ctx, t.DB.Conn, func(pgTx pgx.Tx) error { if len(orderBy) == 0 { orderBy = []string{t.IDColumn} @@ -192,7 +193,7 @@ func (t *Table[T, TP, ID]) LockForUpdate(ctx context.Context, cond sq.Sqlizer, o Limit(limit). 
Suffix("FOR UPDATE SKIP LOCKED") - var records []TP + var records []PT if err := tx.Query.GetAll(ctx, q, &records); err != nil { return fmt.Errorf("select for update skip locked: %w", err) } From 3569547e1cf874acc244edd6c9ae3581cd9e8472 Mon Sep 17 00:00:00 2001 From: Vojtech Vitek Date: Mon, 13 Oct 2025 22:18:10 +0200 Subject: [PATCH 06/34] Add tests for simple CRUD, complex transactions and LockForUpdate() --- tests/database_test.go | 54 ++++++++ tests/schema_test.go | 90 ++++++++++--- tests/table_test.go | 223 +++++++++++++++++++++++++++++++ tests/testdata/pgkit_test_db.sql | 30 +++-- 4 files changed, 367 insertions(+), 30 deletions(-) create mode 100644 tests/database_test.go create mode 100644 tests/table_test.go diff --git a/tests/database_test.go b/tests/database_test.go new file mode 100644 index 0000000..d236678 --- /dev/null +++ b/tests/database_test.go @@ -0,0 +1,54 @@ +package pgkit_test + +import ( + "context" + + "github.com/goware/pgkit/v2" + "github.com/jackc/pgx/v5" +) + +type Database struct { + *pgkit.DB + + Accounts *accountsTable + Articles *articlesTable + Reviews *reviewsTable +} + +func initDB(db *pgkit.DB) *Database { + return &Database{ + DB: db, + Accounts: &accountsTable{Table: &pgkit.Table[Account, *Account, int64]{DB: db, Name: "accounts", IDColumn: "id"}}, + Articles: &articlesTable{Table: &pgkit.Table[Article, *Article, uint64]{DB: db, Name: "articles", IDColumn: "id"}}, + Reviews: &reviewsTable{Table: &pgkit.Table[Review, *Review, uint64]{DB: db, Name: "reviews", IDColumn: "id"}}, + } +} + +func (db *Database) BeginTx(ctx context.Context, fn func(tx *Database) error) error { + return pgx.BeginFunc(ctx, db.Conn, func(pgTx pgx.Tx) error { + tx := db.WithTxQuery(pgTx) + return fn(tx) + }) +} + +func (db *Database) WithTxQuery(tx pgx.Tx) *Database { + pgkitDB := &pgkit.DB{ + Conn: db.Conn, + SQL: db.SQL, + Query: db.TxQuery(tx), + } + + return initDB(pgkitDB) +} + +type accountsTable struct { + *pgkit.Table[Account, *Account, int64] +} + 
+type articlesTable struct { + *pgkit.Table[Article, *Article, uint64] +} + +type reviewsTable struct { + *pgkit.Table[Review, *Review, uint64] +} diff --git a/tests/schema_test.go b/tests/schema_test.go index 3c40bf5..04fe1f9 100644 --- a/tests/schema_test.go +++ b/tests/schema_test.go @@ -1,6 +1,7 @@ package pgkit_test import ( + "fmt" "time" "github.com/goware/pgkit/v2/dbtype" @@ -11,19 +12,83 @@ type Account struct { Name string `db:"name"` Disabled bool `db:"disabled"` CreatedAt time.Time `db:"created_at,omitempty"` // ,omitempty will rely on postgres DEFAULT + UpdatedAt time.Time `db:"created_at,omitempty"` // ,omitempty will rely on postgres DEFAULT } -func (a *Account) DBTableName() string { - return "accounts" +func (a *Account) DBTableName() string { return "accounts" } +func (a *Account) GetID() int64 { return a.ID } +func (a *Account) SetUpdatedAt(t time.Time) { a.UpdatedAt = t } + +func (a *Account) Validate() error { + if a.Name == "" { + return fmt.Errorf("name is required") + } + + return nil +} + +type Article struct { + ID uint64 `db:"id,omitempty"` + Author string `db:"author"` + Alias *string `db:"alias"` + Content Content `db:"content"` // using JSONB postgres datatype + AccountID int64 `db:"account_id"` + CreatedAt time.Time `db:"created_at,omitempty"` // ,omitempty will rely on postgres DEFAULT + UpdatedAt time.Time `db:"created_at,omitempty"` // ,omitempty will rely on postgres DEFAULT + DeletedAt *time.Time `db:"deleted_at"` +} + +func (a *Article) GetID() uint64 { return a.ID } +func (a *Article) SetUpdatedAt(t time.Time) { a.UpdatedAt = t } +func (a *Article) SetDeletedAt(t time.Time) { a.DeletedAt = &t } + +func (a *Article) Validate() error { + if a.Author == "" { + return fmt.Errorf("author is required") + } + + return nil +} + +type Content struct { + Title string `json:"title"` + Body string `json:"body"` + Views int64 `json:"views"` } type Review struct { - ID int64 `db:"id,omitempty"` - Name string `db:"name"` - Comments string 
`db:"comments"` - CreatedAt time.Time `db:"created_at"` // if unset, will store Go zero-value + ID uint64 `db:"id,omitempty"` + Comment string `db:"comment"` + Status ReviewStatus `db:"status"` + Sentiment int64 `db:"sentiment"` + AccountID int64 `db:"account_id"` + ArticleID uint64 `db:"article_id"` + CreatedAt time.Time `db:"created_at,omitempty"` // ,omitempty will rely on postgres DEFAULT + UpdatedAt time.Time `db:"created_at,omitempty"` // ,omitempty will rely on postgres DEFAULT + DeletedAt *time.Time `db:"deleted_at"` } +func (r *Review) GetID() uint64 { return r.ID } +func (r *Review) SetUpdatedAt(t time.Time) { r.UpdatedAt = t } +func (r *Review) SetDeletedAt(t time.Time) { r.DeletedAt = &t } + +func (r *Review) Validate() error { + if len(r.Comment) < 3 { + return fmt.Errorf("comment too short") + } + + return nil +} + +type ReviewStatus int64 + +const ( + ReviewStatusPending ReviewStatus = iota + ReviewStatusProcessing + ReviewStatusApproved + ReviewStatusRejected +) + type Log struct { ID int64 `db:"id,omitempty"` Message string `db:"message"` @@ -38,16 +103,3 @@ type Stat struct { Num dbtype.BigInt `db:"big_num"` // using NUMERIC(78,0) postgres datatype Rating dbtype.BigInt `db:"rating"` // using NUMERIC(78,0) postgres datatype } - -type Article struct { - ID int64 `db:"id,omitempty"` - Author string `db:"author"` - Alias *string `db:"alias"` - Content Content `db:"content"` // using JSONB postgres datatype -} - -type Content struct { - Title string `json:"title"` - Body string `json:"body"` - Views int64 `json:"views"` -} diff --git a/tests/table_test.go b/tests/table_test.go new file mode 100644 index 0000000..16e4385 --- /dev/null +++ b/tests/table_test.go @@ -0,0 +1,223 @@ +package pgkit_test + +import ( + "context" + "fmt" + "slices" + "sync" + "testing" + "time" + + sq "github.com/Masterminds/squirrel" + "github.com/stretchr/testify/require" +) + +func TestTable(t *testing.T) { + truncateAllTables(t) + + ctx := t.Context() + db := initDB(DB) + + 
t.Run("Simple CRUD", func(t *testing.T) { + account := &Account{ + Name: "Save Account", + } + + // Create. + err := db.Accounts.Save(ctx, account) + require.NoError(t, err, "Create failed") + require.NotZero(t, account.ID, "ID should be set") + require.NotZero(t, account.CreatedAt, "CreatedAt should be set") + require.NotZero(t, account.UpdatedAt, "UpdatedAt should be set") + + // Check count. + count, err := db.Accounts.Count(ctx, nil) + require.NoError(t, err, "FindAll failed") + require.Equal(t, uint64(1), count, "Expected 1 account") + + // Read from DB & check for equality. + accountCheck, err := db.Accounts.GetByID(ctx, account.ID) + require.NoError(t, err, "FindByID failed") + require.Equal(t, account.ID, accountCheck.ID, "account ID should match") + require.Equal(t, account.Name, accountCheck.Name, "account name should match") + + // Update. + account.Name = "Updated account" + err = db.Accounts.Save(ctx, account) + require.NoError(t, err, "Save failed") + + // Read from DB & check for equality again. + accountCheck, err = db.Accounts.GetByID(ctx, account.ID) + require.NoError(t, err, "FindByID failed") + require.Equal(t, account.ID, accountCheck.ID, "account ID should match") + require.Equal(t, account.Name, accountCheck.Name, "account name should match") + + // Check count again. + count, err = db.Accounts.Count(ctx, nil) + require.NoError(t, err, "FindAll failed") + require.Equal(t, uint64(1), count, "Expected 1 account") + }) + + t.Run("Complex Transaction", func(t *testing.T) { + t.Parallel() + ctx := t.Context() + + err := db.BeginTx(ctx, func(tx *Database) error { + // Create account. + account := &Account{Name: "Complex Transaction Account"} + err := tx.Accounts.Save(ctx, account) + require.NoError(t, err, "Create account failed") + + articles := []*Article{ + {Author: "First", AccountID: account.ID}, + {Author: "Second", AccountID: account.ID}, + {Author: "Third", AccountID: account.ID}, + } + + // Save articles (3x insert). 
+ err = tx.Articles.SaveAll(ctx, articles) + require.NoError(t, err, "SaveAll failed") + + for _, article := range articles { + require.NotZero(t, article.ID, "ID should be set") + require.NotZero(t, article.CreatedAt, "CreatedAt should be set") + require.NotZero(t, article.UpdatedAt, "UpdatedAt should be set") + } + + firstArticle := articles[0] + + // Save articles (3x update, 1x insert). + articles = append(articles, &Article{Author: "Fourth", AccountID: account.ID}) + err = tx.Articles.SaveAll(ctx, articles) + require.NoError(t, err, "SaveAll failed") + + for _, article := range articles { + require.NotZero(t, article.ID, "ID should be set") + require.NotZero(t, article.CreatedAt, "CreatedAt should be set") + require.NotZero(t, article.UpdatedAt, "UpdatedAt should be set") + } + require.Equal(t, firstArticle.ID, articles[0].ID, "First article ID should be the same") + + // Verify we can load all articles with .GetById() + for _, article := range articles { + articleCheck, err := tx.Articles.GetByID(ctx, article.ID) + require.NoError(t, err, "GetByID failed") + require.Equal(t, article.ID, articleCheck.ID, "Article ID should match") + require.Equal(t, article.Author, articleCheck.Author, "Article Author should match") + require.Equal(t, article.AccountID, articleCheck.AccountID, "Article AccountID should match") + require.Equal(t, article.CreatedAt, articleCheck.CreatedAt, "Article CreatedAt should match") + //require.Equal(t, article.UpdatedAt, articleCheck.UpdatedAt, "Article UpdatedAt should match") + //require.NotEqual(t, article.UpdatedAt, articleCheck.UpdatedAt, "Article UpdatedAt shouldn't match") // The .SaveAll() aboe updates the timestamp. 
+ require.Equal(t, article.DeletedAt, articleCheck.DeletedAt, "Article DeletedAt should match") + } + + // Verify we can load all articles with .GetByIDs() + articleIDs := make([]uint64, len(articles)) + for _, article := range articles { + articleIDs = append(articleIDs, article.ID) + } + articlesCheck, err := tx.Articles.GetByIDs(ctx, articleIDs) + require.NoError(t, err, "GetByIDs failed") + require.Equal(t, len(articles), len(articlesCheck), "Number of articles should match") + for i, _ := range articlesCheck { + require.Equal(t, articles[i].ID, articlesCheck[i].ID, "Article ID should match") + require.Equal(t, articles[i].Author, articlesCheck[i].Author, "Article Author should match") + require.Equal(t, articles[i].AccountID, articlesCheck[i].AccountID, "Article AccountID should match") + require.Equal(t, articles[i].CreatedAt, articlesCheck[i].CreatedAt, "Article CreatedAt should match") + //require.Equal(t, articles[i].UpdatedAt, articlesCheck[i].UpdatedAt, "Article UpdatedAt should match") + require.Equal(t, articles[i].DeletedAt, articlesCheck[i].DeletedAt, "Article DeletedAt should match") + } + + // Soft-delete first article. + err = tx.Articles.DeleteByID(ctx, firstArticle.ID) + require.NoError(t, err, "DeleteByID failed") + + // Check if article is soft-deleted. + article, err := tx.Articles.GetByID(ctx, firstArticle.ID) + require.NoError(t, err, "GetByID failed") + require.Equal(t, firstArticle.ID, article.ID, "DeletedAt should be set") + require.NotNil(t, article.DeletedAt, "DeletedAt should be set") + + // Hard-delete first article. + err = tx.Articles.HardDeleteByID(ctx, firstArticle.ID) + require.NoError(t, err, "HardDeleteByID failed") + + // Check if article is hard-deleted. 
+ article, err = tx.Articles.GetByID(ctx, firstArticle.ID) + require.Error(t, err, "article was not hard-deleted") + require.Nil(t, article, "article is not nil") + + return nil + }) + require.NoError(t, err, "SaveTx transaction failed") + }) +} + +func TestLockForUpdate(t *testing.T) { + truncateAllTables(t) + + ctx := t.Context() + db := initDB(DB) + + t.Run("TestLockForUpdate", func(t *testing.T) { + // Create account. + account := &Account{Name: "LockForUpdate Account"} + err := db.Accounts.Save(ctx, account) + require.NoError(t, err, "Create account failed") + + // Create article. + article := &Article{AccountID: account.ID, Author: "Author", Content: Content{Title: "Title", Body: "Body"}} + err = db.Articles.Save(ctx, article) + require.NoError(t, err, "Create article failed") + + // Create 1000 reviews. + reviews := make([]*Review, 100) + for i := range 100 { + reviews[i] = &Review{ + Comment: fmt.Sprintf("Test comment %d", i), + AccountID: account.ID, + ArticleID: article.ID, + Status: ReviewStatusPending, + } + } + err = db.Reviews.SaveAll(ctx, reviews) + require.NoError(t, err, "create review") + + cond := sq.Eq{ + "status": ReviewStatusPending, + "deleted_at": nil, + } + orderBy := []string{"created_at ASC"} + + var uniqueIDs [][]uint64 = make([][]uint64, 10) + var wg sync.WaitGroup + + for range 10 { + wg.Go(func() { + + err := db.Reviews.LockForUpdate(ctx, cond, orderBy, 10, func(reviews []*Review) { + for _, review := range reviews { + review.Status = ReviewStatusProcessing + go processReviewAsynchronously(ctx, db, review) + } + }) + require.NoError(t, err, "lock for update") + + }) + } + wg.Wait() + + ids := slices.Concat(uniqueIDs...) + slices.Sort(ids) + ids = slices.Compact(ids) + + require.Equal(t, 100, len(ids), "number of processed unique reviews should be 100") + }) +} + +// TODO: defer() save status (success/failure) or put back to queue for processing. 
+func processReviewAsynchronously(ctx context.Context, db *Database, review *Review) { + time.Sleep(1 * time.Second) + review.Status = ReviewStatusApproved + db.Reviews.Save(ctx, review) +} diff --git a/tests/testdata/pgkit_test_db.sql b/tests/testdata/pgkit_test_db.sql index a55dbf8..fb68ee5 100644 --- a/tests/testdata/pgkit_test_db.sql +++ b/tests/testdata/pgkit_test_db.sql @@ -6,12 +6,27 @@ CREATE TABLE accounts ( created_at TIMESTAMP WITH TIME ZONE DEFAULT CURRENT_TIMESTAMP NOT NULL ); +CREATE TABLE articles ( + id SERIAL PRIMARY KEY, + author VARCHAR(80) NOT NULL, + alias VARCHAR(80), + content JSONB, + account_id INTEGER NOT NULL REFERENCES accounts(id), + created_at TIMESTAMP WITHOUT TIME ZONE DEFAULT CURRENT_TIMESTAMP NOT NULL, + updated_at TIMESTAMP WITHOUT TIME ZONE DEFAULT CURRENT_TIMESTAMP NOT NULL, + deleted_at TIMESTAMP WITHOUT TIME ZONE NULL +); + CREATE TABLE reviews ( id SERIAL PRIMARY KEY, - -- article_id integer, - name VARCHAR(80), - comments TEXT, - created_at TIMESTAMP WITHOUT TIME ZONE + article_id INTEGER REFERENCES articles(id), + account_id INTEGER NOT NULL REFERENCES accounts(id), + comment TEXT, + status SMALLINT, + sentiment SMALLINT, + created_at TIMESTAMP WITHOUT TIME ZONE DEFAULT CURRENT_TIMESTAMP NOT NULL, + updated_at TIMESTAMP WITHOUT TIME ZONE DEFAULT CURRENT_TIMESTAMP NOT NULL, + deleted_at TIMESTAMP WITHOUT TIME ZONE NULL ); CREATE TABLE logs ( @@ -27,10 +42,3 @@ CREATE TABLE stats ( big_num NUMERIC(78,0) NOT NULL, -- representing a big.Int runtime type rating NUMERIC(78,0) NULL -- representing a nullable big.Int runtime type ); - -CREATE TABLE articles ( - id SERIAL PRIMARY KEY, - author VARCHAR(80) NOT NULL, - alias VARCHAR(80), - content JSONB -); From 1fc30a9f0ddf493a5c4d1fc5b4f17222cb465004 Mon Sep 17 00:00:00 2001 From: Vojtech Vitek Date: Tue, 14 Oct 2025 14:01:56 +0200 Subject: [PATCH 07/34] Fix TestRecordsWithJSONStruct test, since schema changed --- tests/pgkit_test.go | 12 ++++++++++-- tests/schema_test.go | 19 
++++++++++--------- tests/testdata/pgkit_test_db.sql | 1 + 3 files changed, 21 insertions(+), 11 deletions(-) diff --git a/tests/pgkit_test.go b/tests/pgkit_test.go index 281f387..c6c8530 100644 --- a/tests/pgkit_test.go +++ b/tests/pgkit_test.go @@ -307,8 +307,16 @@ func TestRecordsWithJSONB(t *testing.T) { func TestRecordsWithJSONStruct(t *testing.T) { truncateTable(t, "articles") + account := &Account{ + Name: "TestRecordsWithJSONStruct", + } + err := DB.Query.QueryRow(context.Background(), DB.SQL.InsertRecord(account).Suffix(`RETURNING "id"`)).Scan(&account.ID) + assert.NoError(t, err) + assert.True(t, account.ID > 0) + article := &Article{ - Author: "Gary", + AccountID: account.ID, + Author: "Gary", Content: Content{ Title: "How to cook pizza", Body: "flour+water+salt+yeast+cheese", @@ -319,7 +327,7 @@ func TestRecordsWithJSONStruct(t *testing.T) { cols, _, err := pgkit.Map(article) assert.NoError(t, err) sort.Strings(cols) - assert.Equal(t, []string{"alias", "author", "content"}, cols) + assert.Equal(t, []string{"account_id", "alias", "author", "content", "deleted_at"}, cols) // Insert record q1 := DB.SQL.InsertRecord(article, "articles") diff --git a/tests/schema_test.go b/tests/schema_test.go index 04fe1f9..c0c392e 100644 --- a/tests/schema_test.go +++ b/tests/schema_test.go @@ -57,15 +57,16 @@ type Content struct { } type Review struct { - ID uint64 `db:"id,omitempty"` - Comment string `db:"comment"` - Status ReviewStatus `db:"status"` - Sentiment int64 `db:"sentiment"` - AccountID int64 `db:"account_id"` - ArticleID uint64 `db:"article_id"` - CreatedAt time.Time `db:"created_at,omitempty"` // ,omitempty will rely on postgres DEFAULT - UpdatedAt time.Time `db:"created_at,omitempty"` // ,omitempty will rely on postgres DEFAULT - DeletedAt *time.Time `db:"deleted_at"` + ID uint64 `db:"id,omitempty"` + Comment string `db:"comment"` + Status ReviewStatus `db:"status"` + Sentiment int64 `db:"sentiment"` + AccountID int64 `db:"account_id"` + ArticleID uint64 
`db:"article_id"` + ProcessedAt *time.Time `db:"processed_at"` + CreatedAt time.Time `db:"created_at,omitempty"` // ,omitempty will rely on postgres DEFAULT + UpdatedAt time.Time `db:"created_at,omitempty"` // ,omitempty will rely on postgres DEFAULT + DeletedAt *time.Time `db:"deleted_at"` } func (r *Review) GetID() uint64 { return r.ID } diff --git a/tests/testdata/pgkit_test_db.sql b/tests/testdata/pgkit_test_db.sql index fb68ee5..c4301e5 100644 --- a/tests/testdata/pgkit_test_db.sql +++ b/tests/testdata/pgkit_test_db.sql @@ -24,6 +24,7 @@ CREATE TABLE reviews ( comment TEXT, status SMALLINT, sentiment SMALLINT, + processed_at TIMESTAMP WITHOUT TIME ZONE NULL, created_at TIMESTAMP WITHOUT TIME ZONE DEFAULT CURRENT_TIMESTAMP NOT NULL, updated_at TIMESTAMP WITHOUT TIME ZONE DEFAULT CURRENT_TIMESTAMP NOT NULL, deleted_at TIMESTAMP WITHOUT TIME ZONE NULL From bf7f7bebc65de6fe6dc110360ed694ddf8076fca Mon Sep 17 00:00:00 2001 From: Vojtech Vitek Date: Tue, 14 Oct 2025 14:05:41 +0200 Subject: [PATCH 08/34] Fix LockForUpdate test --- tests/table_test.go | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/tests/table_test.go b/tests/table_test.go index 16e4385..7de071f 100644 --- a/tests/table_test.go +++ b/tests/table_test.go @@ -196,9 +196,13 @@ func TestLockForUpdate(t *testing.T) { wg.Go(func() { err := db.Reviews.LockForUpdate(ctx, cond, orderBy, 10, func(reviews []*Review) { - for _, review := range reviews { + now := time.Now().UTC() + for i, review := range reviews { review.Status = ReviewStatusProcessing + review.ProcessedAt = &now go processReviewAsynchronously(ctx, db, review) + + uniqueIDs[i] = append(uniqueIDs[i], review.ID) } }) require.NoError(t, err, "lock for update") From feb70f313be63b84ea3f214949f0e38de4fe5d0a Mon Sep 17 00:00:00 2001 From: Vojtech Vitek Date: Tue, 14 Oct 2025 14:55:36 +0200 Subject: [PATCH 09/34] Don't rely on wg.Go(), a feature from Go 1.25 --- tests/table_test.go | 6 ++++-- 1 file changed, 4 insertions(+), 2 
deletions(-) diff --git a/tests/table_test.go b/tests/table_test.go index 7de071f..679f36e 100644 --- a/tests/table_test.go +++ b/tests/table_test.go @@ -193,7 +193,9 @@ func TestLockForUpdate(t *testing.T) { var wg sync.WaitGroup for range 10 { - wg.Go(func() { + wg.Add(1) + go func() { + defer wg.Done() err := db.Reviews.LockForUpdate(ctx, cond, orderBy, 10, func(reviews []*Review) { now := time.Now().UTC() @@ -207,7 +209,7 @@ func TestLockForUpdate(t *testing.T) { }) require.NoError(t, err, "lock for update") - }) + }() } wg.Wait() From b218fb5ea3976eeff2d1842a78dd540d97cf1755 Mon Sep 17 00:00:00 2001 From: Vojtech Vitek Date: Fri, 17 Oct 2025 17:29:16 +0200 Subject: [PATCH 10/34] LockForUpdate that can pass transaction to update fn --- table.go | 41 +++++++++++++++++++++++++++++++++++------ 1 file changed, 35 insertions(+), 6 deletions(-) diff --git a/table.go b/table.go index 320e23d..c36b915 100644 --- a/table.go +++ b/table.go @@ -173,11 +173,12 @@ func (t *Table[T, PT, IDT]) WithTx(tx pgx.Tx) *Table[T, PT, IDT] { } } -// LockForUpdate locks 0..limit records using PostgreSQL's FOR UPDATE SKIP LOCKED pattern -// for safe concurrent processing where each record is processed exactly once. -// Complete updateFn() quickly to avoid holding the transaction. For long-running work: -// update status to "processing" and return early, then process asynchronously. -func (t *Table[T, PT, IDT]) LockForUpdate(ctx context.Context, cond sq.Sqlizer, orderBy []string, limit uint64, updateFn func([]PT)) error { +// LockForUpdate locks and processes records using PostgreSQL's FOR UPDATE SKIP LOCKED pattern +// for safe concurrent processing. Each record is processed exactly once across multiple workers. +// Records are automatically updated after updateFn() completes. Complete updateFn() quickly to avoid +// holding the transaction. 
For long-running work, update status to "processing" and return early, +// then process asynchronously and update status to "completed" or "failed" when done. +func (t *Table[T, PT, IDT]) LockForUpdate(ctx context.Context, cond sq.Sqlizer, orderBy []string, limit uint64, updateFn func(pgTx pgx.Tx, records []PT)) error { return pgx.BeginFunc(ctx, t.DB.Conn, func(pgTx pgx.Tx) error { if len(orderBy) == 0 { orderBy = []string{t.IDColumn} @@ -198,7 +199,7 @@ func (t *Table[T, PT, IDT]) LockForUpdate(ctx context.Context, cond sq.Sqlizer, return fmt.Errorf("select for update skip locked: %w", err) } - updateFn(records) + updateFn(pgTx, records) for _, record := range records { q := tx.SQL.UpdateRecord(record, sq.Eq{t.IDColumn: record.GetID()}, t.Name) @@ -210,3 +211,31 @@ func (t *Table[T, PT, IDT]) LockForUpdate(ctx context.Context, cond sq.Sqlizer, return nil }) } + +// LockOneForUpdate locks and processes one record using PostgreSQL's FOR UPDATE SKIP LOCKED pattern +// for safe concurrent processing. The record is processed exactly once across multiple workers. +// Records are automatically updated after updateFn() completes. Complete updateFn() quickly to avoid +// holding the transaction. For long-running work, update status to "processing" and return early, +// then process asynchronously and update status to "completed" or "failed" when done. +// +// Returns ErrNoRows if no records match the condition. 
+func (t *Table[T, PT, IDT]) LockOneForUpdate(ctx context.Context, cond sq.Sqlizer, orderBy []string, updateFn func(pgTx pgx.Tx, record PT)) error { + var noRows bool + + err := t.LockForUpdate(ctx, cond, orderBy, 1, func(pgTx pgx.Tx, records []PT) { + if len(records) > 0 { + updateFn(pgTx, records[0]) + } else { + noRows = true + } + }) + if err != nil { + return fmt.Errorf("lock for update one: %w", err) + } + + if noRows { + return ErrNoRows + } + + return nil +} From 61898a166f6e81bb6eecae353ef61af498295bb2 Mon Sep 17 00:00:00 2001 From: Vojtech Vitek Date: Fri, 17 Oct 2025 22:10:18 +0200 Subject: [PATCH 11/34] Refactor LockForUpdate() to reuse tx if possible --- table.go | 73 ++++++++++++++++++++++++------------------ tests/database_test.go | 4 +-- 2 files changed, 43 insertions(+), 34 deletions(-) diff --git a/table.go b/table.go index c36b915..62c32d6 100644 --- a/table.go +++ b/table.go @@ -175,56 +175,65 @@ func (t *Table[T, PT, IDT]) WithTx(tx pgx.Tx) *Table[T, PT, IDT] { // LockForUpdate locks and processes records using PostgreSQL's FOR UPDATE SKIP LOCKED pattern // for safe concurrent processing. Each record is processed exactly once across multiple workers. -// Records are automatically updated after updateFn() completes. Complete updateFn() quickly to avoid +// Records are automatically updated after updateFn() completes. Keep updateFn() fast to avoid // holding the transaction. For long-running work, update status to "processing" and return early, -// then process asynchronously and update status to "completed" or "failed" when done. -func (t *Table[T, PT, IDT]) LockForUpdate(ctx context.Context, cond sq.Sqlizer, orderBy []string, limit uint64, updateFn func(pgTx pgx.Tx, records []PT)) error { +// then process asynchronously. Use defer LockOneForUpdate() to update status to "completed" or "failed". 
+func (t *Table[T, PT, IDT]) LockForUpdate(ctx context.Context, cond sq.Sqlizer, orderBy []string, limit uint64, updateFn func(records []PT)) error { + // Check if we're already in a transaction + if t.DB.Query.Tx != nil { + return t.lockForUpdateWithTx(ctx, t.DB.Query.Tx, cond, orderBy, limit, updateFn) + } + return pgx.BeginFunc(ctx, t.DB.Conn, func(pgTx pgx.Tx) error { - if len(orderBy) == 0 { - orderBy = []string{t.IDColumn} - } + return t.lockForUpdateWithTx(ctx, pgTx, cond, orderBy, limit, updateFn) + }) +} - tx := t.WithTx(pgTx) +func (t *Table[T, PT, IDT]) lockForUpdateWithTx(ctx context.Context, pgTx pgx.Tx, cond sq.Sqlizer, orderBy []string, limit uint64, updateFn func(records []PT)) error { + if len(orderBy) == 0 { + orderBy = []string{t.IDColumn} + } - q := tx.SQL. - Select("*"). - From(t.Name). - Where(cond). - OrderBy(orderBy...). - Limit(limit). - Suffix("FOR UPDATE SKIP LOCKED") + q := t.SQL. + Select("*"). + From(t.Name). + Where(cond). + OrderBy(orderBy...). + Limit(limit). 
+ Suffix("FOR UPDATE SKIP LOCKED") - var records []PT - if err := tx.Query.GetAll(ctx, q, &records); err != nil { - return fmt.Errorf("select for update skip locked: %w", err) - } + txQuery := t.DB.TxQuery(pgTx) + + var records []PT + if err := txQuery.GetAll(ctx, q, &records); err != nil { + return fmt.Errorf("select for update skip locked: %w", err) + } - updateFn(pgTx, records) + updateFn(records) - for _, record := range records { - q := tx.SQL.UpdateRecord(record, sq.Eq{t.IDColumn: record.GetID()}, t.Name) - if _, err := tx.Query.Exec(ctx, q); err != nil { - return fmt.Errorf("update record: %w", err) - } + for _, record := range records { + q := t.SQL.UpdateRecord(record, sq.Eq{t.IDColumn: record.GetID()}, t.Name) + if _, err := txQuery.Exec(ctx, q); err != nil { + return fmt.Errorf("update record: %w", err) } + } - return nil - }) + return nil } // LockOneForUpdate locks and processes one record using PostgreSQL's FOR UPDATE SKIP LOCKED pattern // for safe concurrent processing. The record is processed exactly once across multiple workers. -// Records are automatically updated after updateFn() completes. Complete updateFn() quickly to avoid +// The record is automatically updated after updateFn() completes. Keep updateFn() fast to avoid // holding the transaction. For long-running work, update status to "processing" and return early, -// then process asynchronously and update status to "completed" or "failed" when done. +// then process asynchronously. Use defer LockOneForUpdate() to update status to "completed" or "failed". // -// Returns ErrNoRows if no records match the condition. -func (t *Table[T, PT, IDT]) LockOneForUpdate(ctx context.Context, cond sq.Sqlizer, orderBy []string, updateFn func(pgTx pgx.Tx, record PT)) error { +// Returns ErrNoRows if no matching records are available for locking. 
+func (t *Table[T, PT, IDT]) LockOneForUpdate(ctx context.Context, cond sq.Sqlizer, orderBy []string, updateFn func(record PT)) error { var noRows bool - err := t.LockForUpdate(ctx, cond, orderBy, 1, func(pgTx pgx.Tx, records []PT) { + err := t.LockForUpdate(ctx, cond, orderBy, 1, func(records []PT) { if len(records) > 0 { - updateFn(pgTx, records[0]) + updateFn(records[0]) } else { noRows = true } diff --git a/tests/database_test.go b/tests/database_test.go index d236678..803b2dd 100644 --- a/tests/database_test.go +++ b/tests/database_test.go @@ -26,12 +26,12 @@ func initDB(db *pgkit.DB) *Database { func (db *Database) BeginTx(ctx context.Context, fn func(tx *Database) error) error { return pgx.BeginFunc(ctx, db.Conn, func(pgTx pgx.Tx) error { - tx := db.WithTxQuery(pgTx) + tx := db.WithTx(pgTx) return fn(tx) }) } -func (db *Database) WithTxQuery(tx pgx.Tx) *Database { +func (db *Database) WithTx(tx pgx.Tx) *Database { pgkitDB := &pgkit.DB{ Conn: db.Conn, SQL: db.SQL, From a1f7de5ffe06e71ad19a1d0cb796e36d7c5b5069 Mon Sep 17 00:00:00 2001 From: Vojtech Vitek Date: Sat, 18 Oct 2025 11:22:56 +0200 Subject: [PATCH 12/34] Simplify data models further .Save() - variadic arg instead of .Save() and .SaveAll() .List() - renamed from .GetAll() .Get() - renamed from .GetOne() .LockForUpdates() - renamed from .LockForUpdate() .LockForUpdate() - renamed from .LockOneForUpdate() --- table.go | 67 +++++++++++++++++++++++++-------------------- tests/table_test.go | 27 +++++++++++------- 2 files changed, 54 insertions(+), 40 deletions(-) diff --git a/table.go b/table.go index 62c32d6..f7eb909 100644 --- a/table.go +++ b/table.go @@ -29,8 +29,14 @@ type hasSetDeletedAt interface { SetDeletedAt(time.Time) } -// Save inserts or updates a record. Auto-detects insert vs update by ID. -func (t *Table[T, PT, IDT]) Save(ctx context.Context, record PT) error { +// Save inserts or updates given records. Auto-detects insert vs update by ID. 
+func (t *Table[T, PT, IDT]) Save(ctx context.Context, records ...PT) error { + if len(records) != 1 { + return t.saveAll(ctx, records) + } + + record := records[0] + if err := record.Validate(); err != nil { return err //nolint:wrapcheck } @@ -59,8 +65,9 @@ func (t *Table[T, PT, IDT]) Save(ctx context.Context, record PT) error { return nil } -// SaveAll saves multiple records sequentially. -func (t *Table[T, PT, IDT]) SaveAll(ctx context.Context, records []PT) error { +// saveAll saves multiple records sequentially. +// TODO: This can be likely optimized to use a batch insert. +func (t *Table[T, PT, IDT]) saveAll(ctx context.Context, records []PT) error { for i := range records { if err := t.Save(ctx, records[i]); err != nil { return err @@ -70,30 +77,30 @@ func (t *Table[T, PT, IDT]) SaveAll(ctx context.Context, records []PT) error { return nil } -// GetOne returns the first record matching the condition. -func (t *Table[T, PT, IDT]) GetOne(ctx context.Context, cond sq.Sqlizer, orderBy []string) (PT, error) { +// Get returns the first record matching the condition. +func (t *Table[T, PT, IDT]) Get(ctx context.Context, where sq.Sqlizer, orderBy []string) (PT, error) { if len(orderBy) == 0 { orderBy = []string{t.IDColumn} } - dest := new(T) + record := new(T) q := t.SQL. Select("*"). From(t.Name). - Where(cond). + Where(where). Limit(1). OrderBy(orderBy...) - if err := t.Query.GetOne(ctx, q, dest); err != nil { + if err := t.Query.GetOne(ctx, q, record); err != nil { return nil, err } - return dest, nil + return record, nil } -// GetAll returns all records matching the condition. -func (t *Table[T, PT, IDT]) GetAll(ctx context.Context, cond sq.Sqlizer, orderBy []string) ([]PT, error) { +// List returns all records matching the condition. 
+func (t *Table[T, PT, IDT]) List(ctx context.Context, where sq.Sqlizer, orderBy []string) ([]PT, error) { if len(orderBy) == 0 { orderBy = []string{t.IDColumn} } @@ -101,34 +108,34 @@ func (t *Table[T, PT, IDT]) GetAll(ctx context.Context, cond sq.Sqlizer, orderBy q := t.SQL. Select("*"). From(t.Name). - Where(cond). + Where(where). OrderBy(orderBy...) - var dest []PT - if err := t.Query.GetAll(ctx, q, &dest); err != nil { + var records []PT + if err := t.Query.GetAll(ctx, q, &records); err != nil { return nil, err } - return dest, nil + return records, nil } // GetByID returns a record by its ID. func (t *Table[T, PT, IDT]) GetByID(ctx context.Context, id IDT) (PT, error) { - return t.GetOne(ctx, sq.Eq{t.IDColumn: id}, []string{t.IDColumn}) + return t.Get(ctx, sq.Eq{t.IDColumn: id}, []string{t.IDColumn}) } // GetByIDs returns records by their IDs. func (t *Table[T, PT, IDT]) GetByIDs(ctx context.Context, ids []IDT) ([]PT, error) { - return t.GetAll(ctx, sq.Eq{t.IDColumn: ids}, nil) + return t.List(ctx, sq.Eq{t.IDColumn: ids}, nil) } // Count returns the number of matching records. -func (t *Table[T, PT, IDT]) Count(ctx context.Context, cond sq.Sqlizer) (uint64, error) { +func (t *Table[T, PT, IDT]) Count(ctx context.Context, where sq.Sqlizer) (uint64, error) { var count uint64 q := t.SQL. Select("COUNT(1)"). From(t.Name). - Where(cond) + Where(where) if err := t.Query.GetOne(ctx, q, &count); err != nil { return count, fmt.Errorf("get one: %w", err) @@ -173,23 +180,23 @@ func (t *Table[T, PT, IDT]) WithTx(tx pgx.Tx) *Table[T, PT, IDT] { } } -// LockForUpdate locks and processes records using PostgreSQL's FOR UPDATE SKIP LOCKED pattern +// LockForUpdates locks and processes records using PostgreSQL's FOR UPDATE SKIP LOCKED pattern // for safe concurrent processing. Each record is processed exactly once across multiple workers. // Records are automatically updated after updateFn() completes. Keep updateFn() fast to avoid // holding the transaction. 
For long-running work, update status to "processing" and return early, // then process asynchronously. Use defer LockOneForUpdate() to update status to "completed" or "failed". -func (t *Table[T, PT, IDT]) LockForUpdate(ctx context.Context, cond sq.Sqlizer, orderBy []string, limit uint64, updateFn func(records []PT)) error { +func (t *Table[T, PT, IDT]) LockForUpdates(ctx context.Context, where sq.Sqlizer, orderBy []string, limit uint64, updateFn func(records []PT)) error { // Check if we're already in a transaction if t.DB.Query.Tx != nil { - return t.lockForUpdateWithTx(ctx, t.DB.Query.Tx, cond, orderBy, limit, updateFn) + return t.lockForUpdatesWithTx(ctx, t.DB.Query.Tx, where, orderBy, limit, updateFn) } return pgx.BeginFunc(ctx, t.DB.Conn, func(pgTx pgx.Tx) error { - return t.lockForUpdateWithTx(ctx, pgTx, cond, orderBy, limit, updateFn) + return t.lockForUpdatesWithTx(ctx, pgTx, where, orderBy, limit, updateFn) }) } -func (t *Table[T, PT, IDT]) lockForUpdateWithTx(ctx context.Context, pgTx pgx.Tx, cond sq.Sqlizer, orderBy []string, limit uint64, updateFn func(records []PT)) error { +func (t *Table[T, PT, IDT]) lockForUpdatesWithTx(ctx context.Context, pgTx pgx.Tx, where sq.Sqlizer, orderBy []string, limit uint64, updateFn func(records []PT)) error { if len(orderBy) == 0 { orderBy = []string{t.IDColumn} } @@ -197,7 +204,7 @@ func (t *Table[T, PT, IDT]) lockForUpdateWithTx(ctx context.Context, pgTx pgx.Tx q := t.SQL. Select("*"). From(t.Name). - Where(cond). + Where(where). OrderBy(orderBy...). Limit(limit). Suffix("FOR UPDATE SKIP LOCKED") @@ -221,17 +228,17 @@ func (t *Table[T, PT, IDT]) lockForUpdateWithTx(ctx context.Context, pgTx pgx.Tx return nil } -// LockOneForUpdate locks and processes one record using PostgreSQL's FOR UPDATE SKIP LOCKED pattern +// LockForUpdate locks and processes one record using PostgreSQL's FOR UPDATE SKIP LOCKED pattern // for safe concurrent processing. The record is processed exactly once across multiple workers. 
// The record is automatically updated after updateFn() completes. Keep updateFn() fast to avoid // holding the transaction. For long-running work, update status to "processing" and return early, -// then process asynchronously. Use defer LockOneForUpdate() to update status to "completed" or "failed". +// then process asynchronously. Use defer LockForUpdate() to update status to "completed" or "failed". // // Returns ErrNoRows if no matching records are available for locking. -func (t *Table[T, PT, IDT]) LockOneForUpdate(ctx context.Context, cond sq.Sqlizer, orderBy []string, updateFn func(record PT)) error { +func (t *Table[T, PT, IDT]) LockForUpdate(ctx context.Context, where sq.Sqlizer, orderBy []string, updateFn func(record PT)) error { var noRows bool - err := t.LockForUpdate(ctx, cond, orderBy, 1, func(records []PT) { + err := t.LockForUpdates(ctx, where, orderBy, 1, func(records []PT) { if len(records) > 0 { updateFn(records[0]) } else { diff --git a/tests/table_test.go b/tests/table_test.go index 679f36e..7c9a89c 100644 --- a/tests/table_test.go +++ b/tests/table_test.go @@ -3,6 +3,7 @@ package pgkit_test import ( "context" "fmt" + "log" "slices" "sync" "testing" @@ -75,8 +76,8 @@ func TestTable(t *testing.T) { } // Save articles (3x insert). - err = tx.Articles.SaveAll(ctx, articles) - require.NoError(t, err, "SaveAll failed") + err = tx.Articles.Save(ctx, articles...) + require.NoError(t, err, "Save failed") for _, article := range articles { require.NotZero(t, article.ID, "ID should be set") @@ -88,8 +89,8 @@ func TestTable(t *testing.T) { // Save articles (3x update, 1x insert). articles = append(articles, &Article{Author: "Fourth", AccountID: account.ID}) - err = tx.Articles.SaveAll(ctx, articles) - require.NoError(t, err, "SaveAll failed") + err = tx.Articles.Save(ctx, articles...) 
+ require.NoError(t, err, "Save failed")
 for _, article := range articles {
 require.NotZero(t, article.ID, "ID should be set")
@@ -107,7 +108,7 @@
 require.Equal(t, article.AccountID, articleCheck.AccountID, "Article AccountID should match")
 require.Equal(t, article.CreatedAt, articleCheck.CreatedAt, "Article CreatedAt should match")
 //require.Equal(t, article.UpdatedAt, articleCheck.UpdatedAt, "Article UpdatedAt should match")
- //require.NotEqual(t, article.UpdatedAt, articleCheck.UpdatedAt, "Article UpdatedAt shouldn't match") // The .SaveAll() aboe updates the timestamp.
+ //require.NotEqual(t, article.UpdatedAt, articleCheck.UpdatedAt, "Article UpdatedAt shouldn't match") // The .Save() above updates the timestamp.
 require.Equal(t, article.DeletedAt, articleCheck.DeletedAt, "Article DeletedAt should match")
 }
@@ -180,14 +181,17 @@
 Status: ReviewStatusPending,
 }
 }
- err = db.Reviews.SaveAll(ctx, reviews)
+ err = db.Reviews.Save(ctx, reviews...)
require.NoError(t, err, "create review") - cond := sq.Eq{ + where := sq.Eq{ "status": ReviewStatusPending, "deleted_at": nil, } - orderBy := []string{"created_at ASC"} + orderBy := []string{ + "created_at ASC", + } + limit := uint64(10) var uniqueIDs [][]uint64 = make([][]uint64, 10) var wg sync.WaitGroup @@ -197,7 +201,7 @@ func TestLockForUpdate(t *testing.T) { go func() { defer wg.Done() - err := db.Reviews.LockForUpdate(ctx, cond, orderBy, 10, func(reviews []*Review) { + err := db.Reviews.LockForUpdates(ctx, where, orderBy, limit, func(reviews []*Review) { now := time.Now().UTC() for i, review := range reviews { review.Status = ReviewStatusProcessing @@ -225,5 +229,8 @@ func TestLockForUpdate(t *testing.T) { func processReviewAsynchronously(ctx context.Context, db *Database, review *Review) { time.Sleep(1 * time.Second) review.Status = ReviewStatusApproved - db.Reviews.Save(ctx, review) + err := db.Reviews.Save(ctx, review) + if err != nil { + log.Printf("failed to save review: %v", err) + } } From 25465622fb5c1e475e62a6979b822a93cb2055bf Mon Sep 17 00:00:00 2001 From: Vojtech Vitek Date: Sat, 18 Oct 2025 12:28:22 +0200 Subject: [PATCH 13/34] Improve tests for async processing --- tests/schema_test.go | 1 + tests/table_test.go | 82 ++++++++++++++++++++++++++++++++++---------- 2 files changed, 65 insertions(+), 18 deletions(-) diff --git a/tests/schema_test.go b/tests/schema_test.go index c0c392e..a8da6de 100644 --- a/tests/schema_test.go +++ b/tests/schema_test.go @@ -88,6 +88,7 @@ const ( ReviewStatusProcessing ReviewStatusApproved ReviewStatusRejected + ReviewStatusFailed ) type Log struct { diff --git a/tests/table_test.go b/tests/table_test.go index 7c9a89c..856d8e2 100644 --- a/tests/table_test.go +++ b/tests/table_test.go @@ -4,6 +4,7 @@ import ( "context" "fmt" "log" + "math/rand" "slices" "sync" "testing" @@ -154,15 +155,15 @@ func TestTable(t *testing.T) { }) } -func TestLockForUpdate(t *testing.T) { +func TestLockForUpdates(t *testing.T) { 
truncateAllTables(t) ctx := t.Context() db := initDB(DB) - t.Run("TestLockForUpdate", func(t *testing.T) { + t.Run("TestLockForUpdates", func(t *testing.T) { // Create account. - account := &Account{Name: "LockForUpdate Account"} + account := &Account{Name: "LockForUpdates Account"} err := db.Accounts.Save(ctx, account) require.NoError(t, err, "Create account failed") @@ -193,7 +194,7 @@ func TestLockForUpdate(t *testing.T) { } limit := uint64(10) - var uniqueIDs [][]uint64 = make([][]uint64, 10) + var processedIDs [][]uint64 = make([][]uint64, 10) var wg sync.WaitGroup for range 10 { @@ -201,36 +202,81 @@ func TestLockForUpdate(t *testing.T) { go func() { defer wg.Done() + var processReviews []*Review + err := db.Reviews.LockForUpdates(ctx, where, orderBy, limit, func(reviews []*Review) { now := time.Now().UTC() - for i, review := range reviews { + for _, review := range reviews { review.Status = ReviewStatusProcessing review.ProcessedAt = &now - go processReviewAsynchronously(ctx, db, review) - - uniqueIDs[i] = append(uniqueIDs[i], review.ID) } + + processReviews = reviews }) require.NoError(t, err, "lock for update") + for _, review := range processReviews { + go processReviewAsynchronously(ctx, db, review) + } + + for i, review := range processReviews { + processedIDs[i] = append(processedIDs[i], review.ID) + } }() } wg.Wait() - ids := slices.Concat(uniqueIDs...) - slices.Sort(ids) - ids = slices.Compact(ids) + // Ensure that all reviews were picked up for processing exactly once. + uniqueIDs := slices.Concat(processedIDs...) + slices.Sort(uniqueIDs) + uniqueIDs = slices.Compact(uniqueIDs) + require.Equal(t, 100, len(uniqueIDs), "number of unique reviews picked up for processing should be 100") - require.Equal(t, 100, len(ids), "number of processed unique reviews should be 100") + // Wait for all reviews to be processed asynchronously. + time.Sleep(2 * time.Second) + + // Double check there's no reviews stuck in "processing" status. 
+ count, err := db.Reviews.Count(ctx, sq.Eq{"status": ReviewStatusProcessing}) + require.NoError(t, err, "count reviews") + require.Zero(t, count, "there should be no reviews stuck in 'processing' status") }) } -// TODO: defer() save status (success/failure) or put back to queue for processing. -func processReviewAsynchronously(ctx context.Context, db *Database, review *Review) { - time.Sleep(1 * time.Second) +func processReviewAsynchronously(ctx context.Context, db *Database, review *Review) (err error) { + defer func() { + // Always update status to "approved", "rejected" or "failed". + noCtx := context.Background() + err = db.Reviews.LockForUpdate(noCtx, sq.Eq{"id": review.ID}, []string{"id DESC"}, func(update *Review) { + now := time.Now().UTC() + update.ProcessedAt = &now + if err != nil { + update.Status = ReviewStatusFailed + return + } + update.Status = review.Status + }) + if err != nil { + log.Printf("failed to save review: %v", err) + } + }() + + // Simulate long-running work. + select { + case <-ctx.Done(): + return ctx.Err() + case <-time.After(1 * time.Second): + } + + // Simulate external API call to an LLM. 
+ if rand.Intn(2) == 0 { + return fmt.Errorf("failed to process review: ") + } + review.Status = ReviewStatusApproved - err := db.Reviews.Save(ctx, review) - if err != nil { - log.Printf("failed to save review: %v", err) + if rand.Intn(2) == 0 { + review.Status = ReviewStatusRejected } + now := time.Now().UTC() + review.ProcessedAt = &now + return nil } From f9917e9873f047a8d7cf8e668ea73d1bef6b4dba Mon Sep 17 00:00:00 2001 From: Vojtech Vitek Date: Sat, 18 Oct 2025 20:34:52 +0200 Subject: [PATCH 14/34] Tests: Implement in-memory worker pattern via simple WaitGroup --- table.go | 70 ++++++++++++++++++++++-------------------- tests/database_test.go | 2 ++ tests/table_test.go | 47 ++-------------------------- tests/worker_test.go | 64 ++++++++++++++++++++++++++++++++++++++ 4 files changed, 106 insertions(+), 77 deletions(-) create mode 100644 tests/worker_test.go diff --git a/table.go b/table.go index f7eb909..bf0d999 100644 --- a/table.go +++ b/table.go @@ -180,11 +180,43 @@ func (t *Table[T, PT, IDT]) WithTx(tx pgx.Tx) *Table[T, PT, IDT] { } } -// LockForUpdates locks and processes records using PostgreSQL's FOR UPDATE SKIP LOCKED pattern -// for safe concurrent processing. Each record is processed exactly once across multiple workers. -// Records are automatically updated after updateFn() completes. Keep updateFn() fast to avoid -// holding the transaction. For long-running work, update status to "processing" and return early, -// then process asynchronously. Use defer LockOneForUpdate() to update status to "completed" or "failed". +// LockForUpdate locks and updates one record using PostgreSQL's FOR UPDATE SKIP LOCKED pattern +// within a database transaction for safe concurrent processing. The record is processed exactly +// once across multiple workers. The record is automatically updated after updateFn() completes. +// +// Keep updateFn() fast to avoid holding the transaction. 
For long-running work, update status +// to "processing" and return early, then process asynchronously. Use defer LockForUpdate() +// to update status to "completed" or "failed". +// +// Returns ErrNoRows if no matching records are available for locking. +func (t *Table[T, PT, IDT]) LockForUpdate(ctx context.Context, where sq.Sqlizer, orderBy []string, updateFn func(record PT)) error { + var noRows bool + + err := t.LockForUpdates(ctx, where, orderBy, 1, func(records []PT) { + if len(records) > 0 { + updateFn(records[0]) + } else { + noRows = true + } + }) + if err != nil { + return fmt.Errorf("lock for update one: %w", err) + } + + if noRows { + return ErrNoRows + } + + return nil +} + +// LockForUpdates locks and updates records using PostgreSQL's FOR UPDATE SKIP LOCKED pattern +// within a database transaction for safe concurrent processing. Each record is processed exactly +// once across multiple workers. Records are automatically updated after updateFn() completes. +// +// Keep updateFn() fast to avoid holding the transaction. For long-running work, update status +// to "processing" and return early, then process asynchronously. Use defer LockForUpdate() +// to update status to "completed" or "failed". func (t *Table[T, PT, IDT]) LockForUpdates(ctx context.Context, where sq.Sqlizer, orderBy []string, limit uint64, updateFn func(records []PT)) error { // Check if we're already in a transaction if t.DB.Query.Tx != nil { @@ -227,31 +259,3 @@ func (t *Table[T, PT, IDT]) lockForUpdatesWithTx(ctx context.Context, pgTx pgx.T return nil } - -// LockForUpdate locks and processes one record using PostgreSQL's FOR UPDATE SKIP LOCKED pattern -// for safe concurrent processing. The record is processed exactly once across multiple workers. -// The record is automatically updated after updateFn() completes. Keep updateFn() fast to avoid -// holding the transaction. For long-running work, update status to "processing" and return early, -// then process asynchronously. 
Use defer LockForUpdate() to update status to "completed" or "failed". -// -// Returns ErrNoRows if no matching records are available for locking. -func (t *Table[T, PT, IDT]) LockForUpdate(ctx context.Context, where sq.Sqlizer, orderBy []string, updateFn func(record PT)) error { - var noRows bool - - err := t.LockForUpdates(ctx, where, orderBy, 1, func(records []PT) { - if len(records) > 0 { - updateFn(records[0]) - } else { - noRows = true - } - }) - if err != nil { - return fmt.Errorf("lock for update one: %w", err) - } - - if noRows { - return ErrNoRows - } - - return nil -} diff --git a/tests/database_test.go b/tests/database_test.go index 803b2dd..e97f9db 100644 --- a/tests/database_test.go +++ b/tests/database_test.go @@ -41,6 +41,8 @@ func (db *Database) WithTx(tx pgx.Tx) *Database { return initDB(pgkitDB) } +func (db *Database) Close() { db.DB.Conn.Close() } + type accountsTable struct { *pgkit.Table[Account, *Account, int64] } diff --git a/tests/table_test.go b/tests/table_test.go index 856d8e2..f96a5c3 100644 --- a/tests/table_test.go +++ b/tests/table_test.go @@ -1,10 +1,7 @@ package pgkit_test import ( - "context" "fmt" - "log" - "math/rand" "slices" "sync" "testing" @@ -160,6 +157,7 @@ func TestLockForUpdates(t *testing.T) { ctx := t.Context() db := initDB(DB) + worker := &Worker{DB: db} t.Run("TestLockForUpdates", func(t *testing.T) { // Create account. @@ -216,7 +214,7 @@ func TestLockForUpdates(t *testing.T) { require.NoError(t, err, "lock for update") for _, review := range processReviews { - go processReviewAsynchronously(ctx, db, review) + go worker.ProcessReview(ctx, review) } for i, review := range processReviews { @@ -233,7 +231,7 @@ func TestLockForUpdates(t *testing.T) { require.Equal(t, 100, len(uniqueIDs), "number of unique reviews picked up for processing should be 100") // Wait for all reviews to be processed asynchronously. - time.Sleep(2 * time.Second) + worker.Wait() // Double check there's no reviews stuck in "processing" status. 
count, err := db.Reviews.Count(ctx, sq.Eq{"status": ReviewStatusProcessing}) @@ -241,42 +239,3 @@ func TestLockForUpdates(t *testing.T) { require.Zero(t, count, "there should be no reviews stuck in 'processing' status") }) } - -func processReviewAsynchronously(ctx context.Context, db *Database, review *Review) (err error) { - defer func() { - // Always update status to "approved", "rejected" or "failed". - noCtx := context.Background() - err = db.Reviews.LockForUpdate(noCtx, sq.Eq{"id": review.ID}, []string{"id DESC"}, func(update *Review) { - now := time.Now().UTC() - update.ProcessedAt = &now - if err != nil { - update.Status = ReviewStatusFailed - return - } - update.Status = review.Status - }) - if err != nil { - log.Printf("failed to save review: %v", err) - } - }() - - // Simulate long-running work. - select { - case <-ctx.Done(): - return ctx.Err() - case <-time.After(1 * time.Second): - } - - // Simulate external API call to an LLM. - if rand.Intn(2) == 0 { - return fmt.Errorf("failed to process review: ") - } - - review.Status = ReviewStatusApproved - if rand.Intn(2) == 0 { - review.Status = ReviewStatusRejected - } - now := time.Now().UTC() - review.ProcessedAt = &now - return nil -} diff --git a/tests/worker_test.go b/tests/worker_test.go new file mode 100644 index 0000000..711a3af --- /dev/null +++ b/tests/worker_test.go @@ -0,0 +1,64 @@ +package pgkit_test + +import ( + "context" + "fmt" + "log" + "math/rand" + "sync" + "time" + + sq "github.com/Masterminds/squirrel" +) + +type Worker struct { + DB *Database + + wg sync.WaitGroup +} + +func (w *Worker) Wait() { + w.wg.Wait() +} + +func (w *Worker) ProcessReview(ctx context.Context, review *Review) (err error) { + w.wg.Add(1) + defer w.wg.Done() + + defer func() { + // Always update review status to "approved", "rejected" or "failed". 
+ noCtx := context.Background() + err = w.DB.Reviews.LockForUpdate(noCtx, sq.Eq{"id": review.ID}, []string{"id DESC"}, func(update *Review) { + now := time.Now().UTC() + update.ProcessedAt = &now + if err != nil { + update.Status = ReviewStatusFailed + return + } + update.Status = review.Status + }) + if err != nil { + log.Printf("failed to save review: %v", err) + } + }() + + // Simulate long-running work. + select { + case <-ctx.Done(): + return ctx.Err() + case <-time.After(1 * time.Second): + } + + // Simulate external API call to an LLM. + if rand.Intn(2) == 0 { + return fmt.Errorf("failed to process review: ") + } + + review.Status = ReviewStatusApproved + if rand.Intn(2) == 0 { + review.Status = ReviewStatusRejected + } + now := time.Now().UTC() + review.ProcessedAt = &now + return nil +} From 73ba584cd766add7416b64cf4283f086a2d4df59 Mon Sep 17 00:00:00 2001 From: Vojtech Vitek Date: Sat, 18 Oct 2025 21:00:47 +0200 Subject: [PATCH 15/34] A better "dequeue" abstraction defined on reviews table --- tests/database_test.go | 14 ++----------- tests/table_test.go | 34 ++++++------------------------ tests/tables_test.go | 47 ++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 55 insertions(+), 40 deletions(-) create mode 100644 tests/tables_test.go diff --git a/tests/database_test.go b/tests/database_test.go index e97f9db..daefa63 100644 --- a/tests/database_test.go +++ b/tests/database_test.go @@ -41,16 +41,6 @@ func (db *Database) WithTx(tx pgx.Tx) *Database { return initDB(pgkitDB) } -func (db *Database) Close() { db.DB.Conn.Close() } - -type accountsTable struct { - *pgkit.Table[Account, *Account, int64] -} - -type articlesTable struct { - *pgkit.Table[Article, *Article, uint64] -} - -type reviewsTable struct { - *pgkit.Table[Review, *Review, uint64] +func (db *Database) Close() { + db.DB.Conn.Close() } diff --git a/tests/table_test.go b/tests/table_test.go index f96a5c3..b8a6f42 100644 --- a/tests/table_test.go +++ b/tests/table_test.go @@ -5,7 +5,6 @@ 
import ( "slices" "sync" "testing" - "time" sq "github.com/Masterminds/squirrel" "github.com/stretchr/testify/require" @@ -183,16 +182,7 @@ func TestLockForUpdates(t *testing.T) { err = db.Reviews.Save(ctx, reviews...) require.NoError(t, err, "create review") - where := sq.Eq{ - "status": ReviewStatusPending, - "deleted_at": nil, - } - orderBy := []string{ - "created_at ASC", - } - limit := uint64(10) - - var processedIDs [][]uint64 = make([][]uint64, 10) + var ids [][]uint64 = make([][]uint64, 10) var wg sync.WaitGroup for range 10 { @@ -200,32 +190,20 @@ func TestLockForUpdates(t *testing.T) { go func() { defer wg.Done() - var processReviews []*Review + reviews, err := db.Reviews.DequeueForProcessing(ctx, 10) + require.NoError(t, err, "dequeue reviews") - err := db.Reviews.LockForUpdates(ctx, where, orderBy, limit, func(reviews []*Review) { - now := time.Now().UTC() - for _, review := range reviews { - review.Status = ReviewStatusProcessing - review.ProcessedAt = &now - } - - processReviews = reviews - }) - require.NoError(t, err, "lock for update") - - for _, review := range processReviews { + for i, review := range reviews { go worker.ProcessReview(ctx, review) - } - for i, review := range processReviews { - processedIDs[i] = append(processedIDs[i], review.ID) + ids[i] = append(ids[i], review.ID) } }() } wg.Wait() // Ensure that all reviews were picked up for processing exactly once. - uniqueIDs := slices.Concat(processedIDs...) + uniqueIDs := slices.Concat(ids...) 
slices.Sort(uniqueIDs) uniqueIDs = slices.Compact(uniqueIDs) require.Equal(t, 100, len(uniqueIDs), "number of unique reviews picked up for processing should be 100") diff --git a/tests/tables_test.go b/tests/tables_test.go new file mode 100644 index 0000000..3815656 --- /dev/null +++ b/tests/tables_test.go @@ -0,0 +1,47 @@ +package pgkit_test + +import ( + "context" + "fmt" + "time" + + sq "github.com/Masterminds/squirrel" + "github.com/goware/pgkit/v2" +) + +type accountsTable struct { + *pgkit.Table[Account, *Account, int64] +} + +type articlesTable struct { + *pgkit.Table[Article, *Article, uint64] +} + +type reviewsTable struct { + *pgkit.Table[Review, *Review, uint64] +} + +func (w *reviewsTable) DequeueForProcessing(ctx context.Context, limit uint64) ([]*Review, error) { + var dequeued []*Review + where := sq.Eq{ + "status": ReviewStatusPending, + "deleted_at": nil, + } + orderBy := []string{ + "created_at ASC", + } + + err := w.LockForUpdates(ctx, where, orderBy, limit, func(reviews []*Review) { + now := time.Now().UTC() + for _, review := range reviews { + review.Status = ReviewStatusProcessing + review.ProcessedAt = &now + } + dequeued = reviews + }) + if err != nil { + return nil, fmt.Errorf("lock for updates: %w", err) + } + + return dequeued, nil +} From 6c9b2004a1069c2e71944180af3e1c1c6bf61fcb Mon Sep 17 00:00:00 2001 From: Vojtech Vitek Date: Mon, 20 Oct 2025 11:33:41 +0200 Subject: [PATCH 16/34] Rename GetByIDs() to ListByIDs() --- table.go | 4 ++-- tests/table_test.go | 6 +++--- tests/tables_test.go | 4 ++-- 3 files changed, 7 insertions(+), 7 deletions(-) diff --git a/table.go b/table.go index bf0d999..4ec104b 100644 --- a/table.go +++ b/table.go @@ -124,8 +124,8 @@ func (t *Table[T, PT, IDT]) GetByID(ctx context.Context, id IDT) (PT, error) { return t.Get(ctx, sq.Eq{t.IDColumn: id}, []string{t.IDColumn}) } -// GetByIDs returns records by their IDs. 
-func (t *Table[T, PT, IDT]) GetByIDs(ctx context.Context, ids []IDT) ([]PT, error) { +// ListByIDs returns records by their IDs. +func (t *Table[T, PT, IDT]) ListByIDs(ctx context.Context, ids []IDT) ([]PT, error) { return t.List(ctx, sq.Eq{t.IDColumn: ids}, nil) } diff --git a/tests/table_test.go b/tests/table_test.go index b8a6f42..4b6b1d6 100644 --- a/tests/table_test.go +++ b/tests/table_test.go @@ -109,13 +109,13 @@ func TestTable(t *testing.T) { require.Equal(t, article.DeletedAt, articleCheck.DeletedAt, "Article DeletedAt should match") } - // Verify we can load all articles with .GetByIDs() + // Verify we can load all articles with .ListByIDs() articleIDs := make([]uint64, len(articles)) for _, article := range articles { articleIDs = append(articleIDs, article.ID) } - articlesCheck, err := tx.Articles.GetByIDs(ctx, articleIDs) - require.NoError(t, err, "GetByIDs failed") + articlesCheck, err := tx.Articles.ListByIDs(ctx, articleIDs) + require.NoError(t, err, "ListByIDs failed") require.Equal(t, len(articles), len(articlesCheck), "Number of articles should match") for i, _ := range articlesCheck { require.Equal(t, articles[i].ID, articlesCheck[i].ID, "Article ID should match") diff --git a/tests/tables_test.go b/tests/tables_test.go index 3815656..d734da0 100644 --- a/tests/tables_test.go +++ b/tests/tables_test.go @@ -21,7 +21,7 @@ type reviewsTable struct { *pgkit.Table[Review, *Review, uint64] } -func (w *reviewsTable) DequeueForProcessing(ctx context.Context, limit uint64) ([]*Review, error) { +func (t *reviewsTable) DequeueForProcessing(ctx context.Context, limit uint64) ([]*Review, error) { var dequeued []*Review where := sq.Eq{ "status": ReviewStatusPending, @@ -31,7 +31,7 @@ func (w *reviewsTable) DequeueForProcessing(ctx context.Context, limit uint64) ( "created_at ASC", } - err := w.LockForUpdates(ctx, where, orderBy, limit, func(reviews []*Review) { + err := t.LockForUpdates(ctx, where, orderBy, limit, func(reviews []*Review) { now := 
time.Now().UTC() for _, review := range reviews { review.Status = ReviewStatusProcessing From 77cb1573ae8c739991eef93a657a1ba2afda88c0 Mon Sep 17 00:00:00 2001 From: Vojtech Vitek Date: Tue, 21 Oct 2025 13:20:37 +0200 Subject: [PATCH 17/34] Fix updated_at field, thanks @shunkakinoki --- table.go | 2 +- tests/schema_test.go | 6 +++--- tests/testdata/pgkit_test_db.sql | 3 ++- 3 files changed, 6 insertions(+), 5 deletions(-) diff --git a/table.go b/table.go index 4ec104b..e81b5be 100644 --- a/table.go +++ b/table.go @@ -12,7 +12,7 @@ import ( // Table provides basic CRUD operations for database records. // Records must implement GetID() and Validate() methods. type Table[T any, PT interface { - *T // Enforce T is a pointer; and thus all methods are defined on a pointer receiver. + *T // Enforce T is a pointer. GetID() IDT Validate() error }, IDT comparable] struct { diff --git a/tests/schema_test.go b/tests/schema_test.go index a8da6de..dd2c3ee 100644 --- a/tests/schema_test.go +++ b/tests/schema_test.go @@ -12,7 +12,7 @@ type Account struct { Name string `db:"name"` Disabled bool `db:"disabled"` CreatedAt time.Time `db:"created_at,omitempty"` // ,omitempty will rely on postgres DEFAULT - UpdatedAt time.Time `db:"created_at,omitempty"` // ,omitempty will rely on postgres DEFAULT + UpdatedAt time.Time `db:"updated_at,omitempty"` // ,omitempty will rely on postgres DEFAULT } func (a *Account) DBTableName() string { return "accounts" } @@ -34,7 +34,7 @@ type Article struct { Content Content `db:"content"` // using JSONB postgres datatype AccountID int64 `db:"account_id"` CreatedAt time.Time `db:"created_at,omitempty"` // ,omitempty will rely on postgres DEFAULT - UpdatedAt time.Time `db:"created_at,omitempty"` // ,omitempty will rely on postgres DEFAULT + UpdatedAt time.Time `db:"updated_at,omitempty"` // ,omitempty will rely on postgres DEFAULT DeletedAt *time.Time `db:"deleted_at"` } @@ -65,7 +65,7 @@ type Review struct { ArticleID uint64 `db:"article_id"` ProcessedAt 
*time.Time `db:"processed_at"` CreatedAt time.Time `db:"created_at,omitempty"` // ,omitempty will rely on postgres DEFAULT - UpdatedAt time.Time `db:"created_at,omitempty"` // ,omitempty will rely on postgres DEFAULT + UpdatedAt time.Time `db:"updated_at,omitempty"` // ,omitempty will rely on postgres DEFAULT DeletedAt *time.Time `db:"deleted_at"` } diff --git a/tests/testdata/pgkit_test_db.sql b/tests/testdata/pgkit_test_db.sql index c4301e5..406d655 100644 --- a/tests/testdata/pgkit_test_db.sql +++ b/tests/testdata/pgkit_test_db.sql @@ -3,7 +3,8 @@ CREATE TABLE accounts ( name VARCHAR(255), disabled BOOLEAN, new_column_not_in_code BOOLEAN, -- test for backward-compatible migrations, see https://github.com/goware/pgkit/issues/13 - created_at TIMESTAMP WITH TIME ZONE DEFAULT CURRENT_TIMESTAMP NOT NULL + created_at TIMESTAMP WITH TIME ZONE DEFAULT CURRENT_TIMESTAMP NOT NULL, + updated_at TIMESTAMP WITHOUT TIME ZONE DEFAULT CURRENT_TIMESTAMP NOT NULL ); CREATE TABLE articles ( From 82b1e7b361ce75ccc538b57021eaf772be7eed25 Mon Sep 17 00:00:00 2001 From: Vojtech Vitek Date: Tue, 21 Oct 2025 13:34:50 +0200 Subject: [PATCH 18/34] PR feedback: Improve error annotations --- table.go | 37 +++++++++++++++++++++++-------------- 1 file changed, 23 insertions(+), 14 deletions(-) diff --git a/table.go b/table.go index e81b5be..4abcd65 100644 --- a/table.go +++ b/table.go @@ -38,7 +38,7 @@ func (t *Table[T, PT, IDT]) Save(ctx context.Context, records ...PT) error { record := records[0] if err := record.Validate(); err != nil { - return err //nolint:wrapcheck + return fmt.Errorf("save: validate record: %w", err) } if row, ok := any(record).(hasSetUpdatedAt); ok { @@ -50,7 +50,7 @@ func (t *Table[T, PT, IDT]) Save(ctx context.Context, records ...PT) error { if record.GetID() == zero { q := t.SQL.InsertRecord(record).Into(t.Name).Suffix("RETURNING *") if err := t.Query.GetOne(ctx, q, record); err != nil { - return fmt.Errorf("insert records: %w", err) + return fmt.Errorf("save: 
insert record: %w", err) } return nil @@ -59,14 +59,14 @@ func (t *Table[T, PT, IDT]) Save(ctx context.Context, records ...PT) error { // Update q := t.SQL.UpdateRecord(record, sq.Eq{t.IDColumn: record.GetID()}, t.Name) if _, err := t.Query.Exec(ctx, q); err != nil { - return fmt.Errorf("update record: %w", err) + return fmt.Errorf("save: update record: %w", err) } return nil } -// saveAll saves multiple records sequentially. -// TODO: This can be likely optimized to use a batch insert. +// saveAll saves multiple records. +// TODO: This function can be likely optimized to use a batch insert query. func (t *Table[T, PT, IDT]) saveAll(ctx context.Context, records []PT) error { for i := range records { if err := t.Save(ctx, records[i]); err != nil { @@ -93,7 +93,7 @@ func (t *Table[T, PT, IDT]) Get(ctx context.Context, where sq.Sqlizer, orderBy [ OrderBy(orderBy...) if err := t.Query.GetOne(ctx, q, record); err != nil { - return nil, err + return nil, fmt.Errorf("get record: %w", err) } return record, nil @@ -138,7 +138,7 @@ func (t *Table[T, PT, IDT]) Count(ctx context.Context, where sq.Sqlizer) (uint64 Where(where) if err := t.Query.GetOne(ctx, q, &count); err != nil { - return count, fmt.Errorf("get one: %w", err) + return count, fmt.Errorf("count: %w", err) } return count, nil @@ -148,13 +148,15 @@ func (t *Table[T, PT, IDT]) Count(ctx context.Context, where sq.Sqlizer) (uint64 func (t *Table[T, PT, IDT]) DeleteByID(ctx context.Context, id IDT) error { record, err := t.GetByID(ctx, id) if err != nil { - return err + return fmt.Errorf("delete: %w", err) } // Soft delete. if row, ok := any(record).(hasSetDeletedAt); ok { row.SetDeletedAt(time.Now().UTC()) - return t.Save(ctx, record) + if err := t.Save(ctx, record); err != nil { + return fmt.Errorf("soft delete: %w", err) + } } // Hard delete for tables without timestamps. 
@@ -164,8 +166,10 @@ func (t *Table[T, PT, IDT]) DeleteByID(ctx context.Context, id IDT) error { // HardDeleteByID permanently deletes a record by ID. func (t *Table[T, PT, IDT]) HardDeleteByID(ctx context.Context, id IDT) error { q := t.SQL.Delete(t.Name).Where(sq.Eq{t.IDColumn: id}) - _, err := t.Query.Exec(ctx, q) - return err + if _, err := t.Query.Exec(ctx, q); err != nil { + return fmt.Errorf("hard delete: %w", err) + } + return nil } // WithTx returns a table instance bound to the given transaction. @@ -200,7 +204,7 @@ func (t *Table[T, PT, IDT]) LockForUpdate(ctx context.Context, where sq.Sqlizer, } }) if err != nil { - return fmt.Errorf("lock for update one: %w", err) + return err //nolint:wrapcheck } if noRows { @@ -220,11 +224,16 @@ func (t *Table[T, PT, IDT]) LockForUpdate(ctx context.Context, where sq.Sqlizer, func (t *Table[T, PT, IDT]) LockForUpdates(ctx context.Context, where sq.Sqlizer, orderBy []string, limit uint64, updateFn func(records []PT)) error { // Check if we're already in a transaction if t.DB.Query.Tx != nil { - return t.lockForUpdatesWithTx(ctx, t.DB.Query.Tx, where, orderBy, limit, updateFn) + if err := t.lockForUpdatesWithTx(ctx, t.DB.Query.Tx, where, orderBy, limit, updateFn); err != nil { + return fmt.Errorf("lock for update (with tx): %w", err) + } } return pgx.BeginFunc(ctx, t.DB.Conn, func(pgTx pgx.Tx) error { - return t.lockForUpdatesWithTx(ctx, pgTx, where, orderBy, limit, updateFn) + if err := t.lockForUpdatesWithTx(ctx, pgTx, where, orderBy, limit, updateFn); err != nil { + return fmt.Errorf("lock for update (new tx): %w", err) + } + return nil }) } From 8bd8f1f4216ceec398896162a3627242fcef06f3 Mon Sep 17 00:00:00 2001 From: Vojtech Vitek Date: Tue, 21 Oct 2025 13:38:53 +0200 Subject: [PATCH 19/34] Fix tests --- table.go | 1 + 1 file changed, 1 insertion(+) diff --git a/table.go b/table.go index 4abcd65..e916788 100644 --- a/table.go +++ b/table.go @@ -157,6 +157,7 @@ func (t *Table[T, PT, IDT]) DeleteByID(ctx 
context.Context, id IDT) error { if err := t.Save(ctx, record); err != nil { return fmt.Errorf("soft delete: %w", err) } + return nil } // Hard delete for tables without timestamps. From f28591efbcf2521358e8bd42b015709d62c63f29 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?David=20Sedl=C3=A1=C4=8Dek?= Date: Fri, 24 Oct 2025 17:23:22 +0200 Subject: [PATCH 20/34] Save multiple (#33) * save multiple * pr comments * return error instead of nil on nil record --- table.go | 99 ++++++++++++++++++++++++++++++++++++++++----- tests/table_test.go | 43 ++++++++++++++++++-- 2 files changed, 128 insertions(+), 14 deletions(-) diff --git a/table.go b/table.go index e916788..201b6c6 100644 --- a/table.go +++ b/table.go @@ -3,6 +3,7 @@ package pgkit import ( "context" "fmt" + "slices" "time" sq "github.com/Masterminds/squirrel" @@ -21,6 +22,10 @@ type Table[T any, PT interface { IDColumn string } +type hasSetCreatedAt interface { + SetCreatedAt(time.Time) +} + type hasSetUpdatedAt interface { SetUpdatedAt(time.Time) } @@ -29,16 +34,25 @@ type hasSetDeletedAt interface { SetDeletedAt(time.Time) } -// Save inserts or updates given records. Auto-detects insert vs update by ID. +// Save inserts or updates given records. Auto-detects insert vs update by ID based on zerovalue of ID from GetID() method on record. 
func (t *Table[T, PT, IDT]) Save(ctx context.Context, records ...PT) error { - if len(records) != 1 { + switch len(records) { + case 0: + return nil + case 1: + return t.saveOne(ctx, records[0]) + default: return t.saveAll(ctx, records) } +} - record := records[0] +func (t *Table[T, PT, IDT]) saveOne(ctx context.Context, record PT) error { + if record == nil { + return fmt.Errorf("record is nil") + } if err := record.Validate(); err != nil { - return fmt.Errorf("save: validate record: %w", err) + return fmt.Errorf("validate record: %w", err) } if row, ok := any(record).(hasSetUpdatedAt); ok { @@ -48,7 +62,11 @@ func (t *Table[T, PT, IDT]) Save(ctx context.Context, records ...PT) error { // Insert var zero IDT if record.GetID() == zero { - q := t.SQL.InsertRecord(record).Into(t.Name).Suffix("RETURNING *") + q := t.SQL. + InsertRecord(record). + Into(t.Name). + Suffix("RETURNING *") + if err := t.Query.GetOne(ctx, q, record); err != nil { return fmt.Errorf("save: insert record: %w", err) } @@ -65,12 +83,73 @@ func (t *Table[T, PT, IDT]) Save(ctx context.Context, records ...PT) error { return nil } -// saveAll saves multiple records. -// TODO: This function can be likely optimized to use a batch insert query. 
+const chunkSize = 1000 + func (t *Table[T, PT, IDT]) saveAll(ctx context.Context, records []PT) error { - for i := range records { - if err := t.Save(ctx, records[i]); err != nil { - return err + now := time.Now().UTC() + + insertRecords := make([]PT, 0) + insertIndices := make([]int, 0) // keep track of original indices, so we can update the records with IDs in passed slice + + updateQueries := make(Queries, 0) + + for i, r := range records { + if r == nil { + return fmt.Errorf("record with index=%d is nil", i) + } + + if err := r.Validate(); err != nil { + return fmt.Errorf("validate record: %w", err) + } + + if row, ok := any(r).(hasSetUpdatedAt); ok { + row.SetUpdatedAt(now) + } + + var zero IDT + if r.GetID() == zero { + if row, ok := any(r).(hasSetCreatedAt); ok { + row.SetCreatedAt(now) + } + + insertRecords = append(insertRecords, r) + insertIndices = append(insertIndices, i) // remember index + } else { + updateQueries.Add(t.SQL. + UpdateRecord(r, sq.Eq{"id": r.GetID()}, t.Name). + SuffixExpr(sq.Expr(" RETURNING *")), + ) + } + } + + // Handle inserts in chunks, has to be done manually, slices.Chunk does not return index :/ + for start := 0; start < len(insertRecords); start += chunkSize { + end := start + chunkSize + if end > len(insertRecords) { + end = len(insertRecords) + } + + chunk := insertRecords[start:end] + q := t.SQL. + InsertRecords(chunk). + Into(t.Name). 
+ SuffixExpr(sq.Expr(" RETURNING *")) + + if err := t.Query.GetAll(ctx, q, &chunk); err != nil { + return fmt.Errorf("insert records: %w", err) + } + + // update original slice + for i, rr := range chunk { + records[insertIndices[start+i]] = rr + } + } + + if len(updateQueries) > 0 { + for chunk := range slices.Chunk(updateQueries, chunkSize) { + if _, err := t.Query.BatchExec(ctx, chunk); err != nil { + return fmt.Errorf("update records: %w", err) + } } } diff --git a/tests/table_test.go b/tests/table_test.go index 4b6b1d6..0c88798 100644 --- a/tests/table_test.go +++ b/tests/table_test.go @@ -56,6 +56,41 @@ func TestTable(t *testing.T) { require.Equal(t, uint64(1), count, "Expected 1 account") }) + t.Run("Save multiple", func(t *testing.T) { + t.Parallel() + // Create account. + account := &Account{Name: "Save Multiple Account"} + err := db.Accounts.Save(ctx, account) + require.NoError(t, err, "Create account failed") + articles := []*Article{ + {Author: "FirstNew", AccountID: account.ID}, + {Author: "SecondNew", AccountID: account.ID}, + {ID: 10001, Author: "FirstOld", AccountID: account.ID}, + {ID: 10002, Author: "SecondOld", AccountID: account.ID}, + } + err = db.Articles.Save(ctx, articles...) + require.NoError(t, err, "Save articles") + require.NotZero(t, articles[0].ID, "ID should be set") + require.NotZero(t, articles[1].ID, "ID should be set") + require.Equal(t, uint64(10001), articles[2].ID, "ID should be same") + require.Equal(t, uint64(10002), articles[3].ID, "ID should be same") + // test update for multiple records + updateArticles := []*Article{ + articles[0], + articles[1], + } + updateArticles[0].Author = "Updated Author Name 1" + updateArticles[1].Author = "Updated Author Name 2" + err = db.Articles.Save(ctx, updateArticles...) 
+ require.NoError(t, err, "Save articles") + updateArticle0, err := db.Articles.GetByID(ctx, articles[0].ID) + require.NoError(t, err, "Get By ID") + require.Equal(t, updateArticles[0].Author, updateArticle0.Author, "Author should be same") + updateArticle1, err := db.Articles.GetByID(ctx, articles[1].ID) + require.NoError(t, err, "Get By ID") + require.Equal(t, updateArticles[1].Author, updateArticle1.Author, "Author should be same") + }) + t.Run("Complex Transaction", func(t *testing.T) { t.Parallel() ctx := t.Context() @@ -104,8 +139,8 @@ func TestTable(t *testing.T) { require.Equal(t, article.Author, articleCheck.Author, "Article Author should match") require.Equal(t, article.AccountID, articleCheck.AccountID, "Article AccountID should match") require.Equal(t, article.CreatedAt, articleCheck.CreatedAt, "Article CreatedAt should match") - //require.Equal(t, article.UpdatedAt, articleCheck.UpdatedAt, "Article UpdatedAt should match") - //require.NotEqual(t, article.UpdatedAt, articleCheck.UpdatedAt, "Article UpdatedAt shouldn't match") // The .Save() aboe updates the timestamp. + // require.Equal(t, article.UpdatedAt, articleCheck.UpdatedAt, "Article UpdatedAt should match") + // require.NotEqual(t, article.UpdatedAt, articleCheck.UpdatedAt, "Article UpdatedAt shouldn't match") // The .Save() aboe updates the timestamp. 
require.Equal(t, article.DeletedAt, articleCheck.DeletedAt, "Article DeletedAt should match") } @@ -117,12 +152,12 @@ func TestTable(t *testing.T) { articlesCheck, err := tx.Articles.ListByIDs(ctx, articleIDs) require.NoError(t, err, "ListByIDs failed") require.Equal(t, len(articles), len(articlesCheck), "Number of articles should match") - for i, _ := range articlesCheck { + for i := range articlesCheck { require.Equal(t, articles[i].ID, articlesCheck[i].ID, "Article ID should match") require.Equal(t, articles[i].Author, articlesCheck[i].Author, "Article Author should match") require.Equal(t, articles[i].AccountID, articlesCheck[i].AccountID, "Article AccountID should match") require.Equal(t, articles[i].CreatedAt, articlesCheck[i].CreatedAt, "Article CreatedAt should match") - //require.Equal(t, articles[i].UpdatedAt, articlesCheck[i].UpdatedAt, "Article UpdatedAt should match") + // require.Equal(t, articles[i].UpdatedAt, articlesCheck[i].UpdatedAt, "Article UpdatedAt should match") require.Equal(t, articles[i].DeletedAt, articlesCheck[i].DeletedAt, "Article DeletedAt should match") } From 320b0f3f097a2aefb82fdadd976665504309c2b8 Mon Sep 17 00:00:00 2001 From: Alex Guerrieri Date: Wed, 10 Dec 2025 13:21:20 +0100 Subject: [PATCH 21/34] Add iterator method for accounts and update tests --- table.go | 53 ++++++++++++++++++++++++++++++++------------- tests/table_test.go | 9 ++++++++ 2 files changed, 47 insertions(+), 15 deletions(-) diff --git a/table.go b/table.go index 201b6c6..cfd2fe7 100644 --- a/table.go +++ b/table.go @@ -2,7 +2,9 @@ package pgkit import ( "context" + "errors" "fmt" + "iter" "slices" "time" @@ -156,20 +158,25 @@ func (t *Table[T, PT, IDT]) saveAll(ctx context.Context, records []PT) error { return nil } -// Get returns the first record matching the condition. -func (t *Table[T, PT, IDT]) Get(ctx context.Context, where sq.Sqlizer, orderBy []string) (PT, error) { +// getListQuery builds a base select query for listing records. 
+func (t *Table[T, PT, IDT]) getListQuery(where sq.Sqlizer, orderBy []string) sq.SelectBuilder { if len(orderBy) == 0 { orderBy = []string{t.IDColumn} } - record := new(T) - q := t.SQL. Select("*"). From(t.Name). Where(where). - Limit(1). OrderBy(orderBy...) + return q +} + +// Get returns the first record matching the condition. +func (t *Table[T, PT, IDT]) Get(ctx context.Context, where sq.Sqlizer, orderBy []string) (PT, error) { + record := new(T) + + q := t.getListQuery(where, orderBy).Limit(1) if err := t.Query.GetOne(ctx, q, record); err != nil { return nil, fmt.Errorf("get record: %w", err) @@ -180,16 +187,7 @@ func (t *Table[T, PT, IDT]) Get(ctx context.Context, where sq.Sqlizer, orderBy [ // List returns all records matching the condition. func (t *Table[T, PT, IDT]) List(ctx context.Context, where sq.Sqlizer, orderBy []string) ([]PT, error) { - if len(orderBy) == 0 { - orderBy = []string{t.IDColumn} - } - - q := t.SQL. - Select("*"). - From(t.Name). - Where(where). - OrderBy(orderBy...) - + q := t.getListQuery(where, orderBy) var records []PT if err := t.Query.GetAll(ctx, q, &records); err != nil { return nil, err @@ -198,6 +196,31 @@ func (t *Table[T, PT, IDT]) List(ctx context.Context, where sq.Sqlizer, orderBy return records, nil } +// Iter returns an iterator for records matching the condition. +func (t *Table[T, PT, IDT]) Iter(ctx context.Context, where sq.Sqlizer, orderBy []string) (iter.Seq2[PT, error], error) { + q := t.getListQuery(where, orderBy) + rows, err := t.Query.QueryRows(ctx, q) + if err != nil { + return nil, fmt.Errorf("query rows: %w", err) + } + + return func(yield func(PT, error) bool) { + defer rows.Close() + for rows.Next() { + var record T + if err := t.Query.Scan.ScanOne(&record, rows); err != nil { + if !errors.Is(err, pgx.ErrNoRows) { + yield(nil, err) + } + return + } + if !yield(&record, nil) { + return + } + } + }, nil +} + // GetByID returns a record by its ID. 
func (t *Table[T, PT, IDT]) GetByID(ctx context.Context, id IDT) (PT, error) { return t.Get(ctx, sq.Eq{t.IDColumn: id}, []string{t.IDColumn}) diff --git a/tests/table_test.go b/tests/table_test.go index 0c88798..fa5a458 100644 --- a/tests/table_test.go +++ b/tests/table_test.go @@ -54,6 +54,15 @@ func TestTable(t *testing.T) { count, err = db.Accounts.Count(ctx, nil) require.NoError(t, err, "FindAll failed") require.Equal(t, uint64(1), count, "Expected 1 account") + + // Iterate all accounts. + iter, err := db.Accounts.Iter(ctx, nil, nil) + require.NoError(t, err, "Iter failed") + var accounts []Account + for account, err := range iter { + require.NoError(t, err, "Iter error") + accounts = append(accounts, *account) + } }) t.Run("Save multiple", func(t *testing.T) { From b2ce7a79274bbbc2f609068cbf5c4af6e672a438 Mon Sep 17 00:00:00 2001 From: Alex Guerrieri Date: Thu, 5 Mar 2026 16:37:43 +0100 Subject: [PATCH 22/34] Rename types for clarity --- table.go | 86 ++++++++++++++++++++++++++++++-------------------------- 1 file changed, 46 insertions(+), 40 deletions(-) diff --git a/table.go b/table.go index cfd2fe7..739de9c 100644 --- a/table.go +++ b/table.go @@ -12,32 +12,38 @@ import ( "github.com/jackc/pgx/v5" ) -// Table provides basic CRUD operations for database records. -// Records must implement GetID() and Validate() methods. -type Table[T any, PT interface { +// ID is a comparable type used for record IDs. +type ID comparable + +// Records must be a pointer with the methods defined on the pointer. +type Record[T any, I ID] interface { *T // Enforce T is a pointer. - GetID() IDT + GetID() I Validate() error -}, IDT comparable] struct { +} + +// Table provides basic CRUD operations for database records. 
+type Table[T any, P Record[T, I], I ID] struct { *DB Name string IDColumn string } -type hasSetCreatedAt interface { - SetCreatedAt(time.Time) -} - -type hasSetUpdatedAt interface { - SetUpdatedAt(time.Time) -} - -type hasSetDeletedAt interface { - SetDeletedAt(time.Time) -} +// helpers for setting timestamp fields +type ( + hasSetCreatedAt interface { + SetCreatedAt(time.Time) + } + hasSetUpdatedAt interface { + SetUpdatedAt(time.Time) + } + hasSetDeletedAt interface { + SetDeletedAt(time.Time) + } +) // Save inserts or updates given records. Auto-detects insert vs update by ID based on zerovalue of ID from GetID() method on record. -func (t *Table[T, PT, IDT]) Save(ctx context.Context, records ...PT) error { +func (t *Table[T, P, I]) Save(ctx context.Context, records ...P) error { switch len(records) { case 0: return nil @@ -48,7 +54,7 @@ func (t *Table[T, PT, IDT]) Save(ctx context.Context, records ...PT) error { } } -func (t *Table[T, PT, IDT]) saveOne(ctx context.Context, record PT) error { +func (t *Table[T, P, I]) saveOne(ctx context.Context, record P) error { if record == nil { return fmt.Errorf("record is nil") } @@ -62,7 +68,7 @@ func (t *Table[T, PT, IDT]) saveOne(ctx context.Context, record PT) error { } // Insert - var zero IDT + var zero I if record.GetID() == zero { q := t.SQL. InsertRecord(record). 
@@ -87,10 +93,10 @@ func (t *Table[T, PT, IDT]) saveOne(ctx context.Context, record PT) error { const chunkSize = 1000 -func (t *Table[T, PT, IDT]) saveAll(ctx context.Context, records []PT) error { +func (t *Table[T, P, I]) saveAll(ctx context.Context, records []P) error { now := time.Now().UTC() - insertRecords := make([]PT, 0) + insertRecords := make([]P, 0) insertIndices := make([]int, 0) // keep track of original indices, so we can update the records with IDs in passed slice updateQueries := make(Queries, 0) @@ -108,7 +114,7 @@ func (t *Table[T, PT, IDT]) saveAll(ctx context.Context, records []PT) error { row.SetUpdatedAt(now) } - var zero IDT + var zero I if r.GetID() == zero { if row, ok := any(r).(hasSetCreatedAt); ok { row.SetCreatedAt(now) @@ -159,7 +165,7 @@ func (t *Table[T, PT, IDT]) saveAll(ctx context.Context, records []PT) error { } // getListQuery builds a base select query for listing records. -func (t *Table[T, PT, IDT]) getListQuery(where sq.Sqlizer, orderBy []string) sq.SelectBuilder { +func (t *Table[T, P, I]) getListQuery(where sq.Sqlizer, orderBy []string) sq.SelectBuilder { if len(orderBy) == 0 { orderBy = []string{t.IDColumn} } @@ -173,7 +179,7 @@ func (t *Table[T, PT, IDT]) getListQuery(where sq.Sqlizer, orderBy []string) sq. } // Get returns the first record matching the condition. -func (t *Table[T, PT, IDT]) Get(ctx context.Context, where sq.Sqlizer, orderBy []string) (PT, error) { +func (t *Table[T, P, I]) Get(ctx context.Context, where sq.Sqlizer, orderBy []string) (P, error) { record := new(T) q := t.getListQuery(where, orderBy).Limit(1) @@ -186,9 +192,9 @@ func (t *Table[T, PT, IDT]) Get(ctx context.Context, where sq.Sqlizer, orderBy [ } // List returns all records matching the condition. 
-func (t *Table[T, PT, IDT]) List(ctx context.Context, where sq.Sqlizer, orderBy []string) ([]PT, error) { +func (t *Table[T, P, I]) List(ctx context.Context, where sq.Sqlizer, orderBy []string) ([]P, error) { q := t.getListQuery(where, orderBy) - var records []PT + var records []P if err := t.Query.GetAll(ctx, q, &records); err != nil { return nil, err } @@ -197,14 +203,14 @@ func (t *Table[T, PT, IDT]) List(ctx context.Context, where sq.Sqlizer, orderBy } // Iter returns an iterator for records matching the condition. -func (t *Table[T, PT, IDT]) Iter(ctx context.Context, where sq.Sqlizer, orderBy []string) (iter.Seq2[PT, error], error) { +func (t *Table[T, P, I]) Iter(ctx context.Context, where sq.Sqlizer, orderBy []string) (iter.Seq2[P, error], error) { q := t.getListQuery(where, orderBy) rows, err := t.Query.QueryRows(ctx, q) if err != nil { return nil, fmt.Errorf("query rows: %w", err) } - return func(yield func(PT, error) bool) { + return func(yield func(P, error) bool) { defer rows.Close() for rows.Next() { var record T @@ -222,17 +228,17 @@ func (t *Table[T, PT, IDT]) Iter(ctx context.Context, where sq.Sqlizer, orderBy } // GetByID returns a record by its ID. -func (t *Table[T, PT, IDT]) GetByID(ctx context.Context, id IDT) (PT, error) { +func (t *Table[T, P, I]) GetByID(ctx context.Context, id I) (P, error) { return t.Get(ctx, sq.Eq{t.IDColumn: id}, []string{t.IDColumn}) } // ListByIDs returns records by their IDs. -func (t *Table[T, PT, IDT]) ListByIDs(ctx context.Context, ids []IDT) ([]PT, error) { +func (t *Table[T, P, I]) ListByIDs(ctx context.Context, ids []I) ([]P, error) { return t.List(ctx, sq.Eq{t.IDColumn: ids}, nil) } // Count returns the number of matching records. -func (t *Table[T, PT, IDT]) Count(ctx context.Context, where sq.Sqlizer) (uint64, error) { +func (t *Table[T, P, I]) Count(ctx context.Context, where sq.Sqlizer) (uint64, error) { var count uint64 q := t.SQL. Select("COUNT(1)"). 
@@ -247,7 +253,7 @@ func (t *Table[T, PT, IDT]) Count(ctx context.Context, where sq.Sqlizer) (uint64 } // DeleteByID deletes a record by ID. Uses soft delete if .SetDeletedAt() method exists. -func (t *Table[T, PT, IDT]) DeleteByID(ctx context.Context, id IDT) error { +func (t *Table[T, P, I]) DeleteByID(ctx context.Context, id I) error { record, err := t.GetByID(ctx, id) if err != nil { return fmt.Errorf("delete: %w", err) @@ -267,7 +273,7 @@ func (t *Table[T, PT, IDT]) DeleteByID(ctx context.Context, id IDT) error { } // HardDeleteByID permanently deletes a record by ID. -func (t *Table[T, PT, IDT]) HardDeleteByID(ctx context.Context, id IDT) error { +func (t *Table[T, P, I]) HardDeleteByID(ctx context.Context, id I) error { q := t.SQL.Delete(t.Name).Where(sq.Eq{t.IDColumn: id}) if _, err := t.Query.Exec(ctx, q); err != nil { return fmt.Errorf("hard delete: %w", err) @@ -276,8 +282,8 @@ func (t *Table[T, PT, IDT]) HardDeleteByID(ctx context.Context, id IDT) error { } // WithTx returns a table instance bound to the given transaction. -func (t *Table[T, PT, IDT]) WithTx(tx pgx.Tx) *Table[T, PT, IDT] { - return &Table[T, PT, IDT]{ +func (t *Table[T, P, I]) WithTx(tx pgx.Tx) *Table[T, P, I] { + return &Table[T, P, I]{ DB: &DB{ Conn: t.DB.Conn, SQL: t.DB.SQL, @@ -296,10 +302,10 @@ func (t *Table[T, PT, IDT]) WithTx(tx pgx.Tx) *Table[T, PT, IDT] { // to update status to "completed" or "failed". // // Returns ErrNoRows if no matching records are available for locking. 
-func (t *Table[T, PT, IDT]) LockForUpdate(ctx context.Context, where sq.Sqlizer, orderBy []string, updateFn func(record PT)) error { +func (t *Table[T, P, I]) LockForUpdate(ctx context.Context, where sq.Sqlizer, orderBy []string, updateFn func(record P)) error { var noRows bool - err := t.LockForUpdates(ctx, where, orderBy, 1, func(records []PT) { + err := t.LockForUpdates(ctx, where, orderBy, 1, func(records []P) { if len(records) > 0 { updateFn(records[0]) } else { @@ -324,7 +330,7 @@ func (t *Table[T, PT, IDT]) LockForUpdate(ctx context.Context, where sq.Sqlizer, // Keep updateFn() fast to avoid holding the transaction. For long-running work, update status // to "processing" and return early, then process asynchronously. Use defer LockForUpdate() // to update status to "completed" or "failed". -func (t *Table[T, PT, IDT]) LockForUpdates(ctx context.Context, where sq.Sqlizer, orderBy []string, limit uint64, updateFn func(records []PT)) error { +func (t *Table[T, P, I]) LockForUpdates(ctx context.Context, where sq.Sqlizer, orderBy []string, limit uint64, updateFn func(records []P)) error { // Check if we're already in a transaction if t.DB.Query.Tx != nil { if err := t.lockForUpdatesWithTx(ctx, t.DB.Query.Tx, where, orderBy, limit, updateFn); err != nil { @@ -340,7 +346,7 @@ func (t *Table[T, PT, IDT]) LockForUpdates(ctx context.Context, where sq.Sqlizer }) } -func (t *Table[T, PT, IDT]) lockForUpdatesWithTx(ctx context.Context, pgTx pgx.Tx, where sq.Sqlizer, orderBy []string, limit uint64, updateFn func(records []PT)) error { +func (t *Table[T, P, I]) lockForUpdatesWithTx(ctx context.Context, pgTx pgx.Tx, where sq.Sqlizer, orderBy []string, limit uint64, updateFn func(records []P)) error { if len(orderBy) == 0 { orderBy = []string{t.IDColumn} } @@ -355,7 +361,7 @@ func (t *Table[T, PT, IDT]) lockForUpdatesWithTx(ctx context.Context, pgTx pgx.T txQuery := t.DB.TxQuery(pgTx) - var records []PT + var records []P if err := txQuery.GetAll(ctx, q, &records); err != 
nil { return fmt.Errorf("select for update skip locked: %w", err) } From 95252fc5489065c22d52d2c8bbf156c5b3498552 Mon Sep 17 00:00:00 2001 From: Alex Guerrieri Date: Thu, 5 Mar 2026 17:28:05 +0100 Subject: [PATCH 23/34] add back truncateAllTables --- tests/helpers_test.go | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/tests/helpers_test.go b/tests/helpers_test.go index f25ced6..cde9604 100644 --- a/tests/helpers_test.go +++ b/tests/helpers_test.go @@ -8,6 +8,14 @@ import ( "github.com/stretchr/testify/assert" ) +func truncateAllTables(t *testing.T) { + truncateTable(t, "accounts") + truncateTable(t, "reviews") + truncateTable(t, "logs") + truncateTable(t, "stats") + truncateTable(t, "articles") +} + func truncateTable(t *testing.T, tableName string) { _, err := DB.Conn.Exec(context.Background(), fmt.Sprintf(`TRUNCATE TABLE %q CASCADE`, tableName)) assert.NoError(t, err) From a3fdad7dc8a2e2513e1e89774fe8ffb229b32935 Mon Sep 17 00:00:00 2001 From: Alex Guerrieri Date: Fri, 6 Mar 2026 17:51:02 +0100 Subject: [PATCH 24/34] Add IDColumn to table context and enhance WithTx tests --- table.go | 3 ++- tests/table_test.go | 29 +++++++++++++++++++++++++++++ 2 files changed, 31 insertions(+), 1 deletion(-) diff --git a/table.go b/table.go index 739de9c..5f736d3 100644 --- a/table.go +++ b/table.go @@ -289,7 +289,8 @@ func (t *Table[T, P, I]) WithTx(tx pgx.Tx) *Table[T, P, I] { SQL: t.DB.SQL, Query: t.DB.TxQuery(tx), }, - Name: t.Name, + Name: t.Name, + IDColumn: t.IDColumn, } } diff --git a/tests/table_test.go b/tests/table_test.go index fa5a458..a97540f 100644 --- a/tests/table_test.go +++ b/tests/table_test.go @@ -7,6 +7,7 @@ import ( "testing" sq "github.com/Masterminds/squirrel" + "github.com/jackc/pgx/v5" "github.com/stretchr/testify/require" ) @@ -193,6 +194,34 @@ func TestTable(t *testing.T) { }) require.NoError(t, err, "SaveTx transaction failed") }) + + t.Run("WithTx keeps IDColumn", func(t *testing.T) { + ctx := t.Context() + + account := &Account{Name: 
"WithTx IDColumn Account"} + err := db.Accounts.Save(ctx, account) + require.NoError(t, err, "create account failed") + + article := &Article{AccountID: account.ID, Author: "WithTx author"} + err = db.Articles.Save(ctx, article) + require.NoError(t, err, "create article failed") + + err = pgx.BeginFunc(ctx, db.Conn, func(pgTx pgx.Tx) error { + txTable := db.Articles.Table.WithTx(pgTx) + if err := txTable.HardDeleteByID(ctx, article.ID); err != nil { + return err + } + + _, err := txTable.GetByID(ctx, article.ID) + require.Error(t, err, "article should be deleted inside tx") + + return nil + }) + require.NoError(t, err, "WithTx HardDeleteByID failed") + + _, err = db.Articles.GetByID(ctx, article.ID) + require.Error(t, err, "article should be deleted") + }) } func TestLockForUpdates(t *testing.T) { From cf0cd568f6fe9abb6b11601041d75047ee491d8b Mon Sep 17 00:00:00 2001 From: Alex Guerrieri Date: Mon, 9 Mar 2026 14:19:35 +0100 Subject: [PATCH 25/34] Add pagination support to Table with ListPaged and WithPaginator methods --- table.go | 32 +++++++++++++-- tests/table_test.go | 99 +++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 127 insertions(+), 4 deletions(-) diff --git a/table.go b/table.go index 5f736d3..a98a31b 100644 --- a/table.go +++ b/table.go @@ -25,8 +25,9 @@ type Record[T any, I ID] interface { // Table provides basic CRUD operations for database records. type Table[T any, P Record[T, I], I ID] struct { *DB - Name string - IDColumn string + Name string + IDColumn string + Paginator Paginator[P] } // helpers for setting timestamp fields @@ -202,6 +203,18 @@ func (t *Table[T, P, I]) List(ctx context.Context, where sq.Sqlizer, orderBy []s return records, nil } +// ListPaged returns paginated records matching the condition. 
+func (t *Table[T, P, I]) ListPaged(ctx context.Context, where sq.Sqlizer, page *Page) ([]P, *Page, error) { + q := t.SQL.Select("*").From(t.Name).Where(where) + + result, q := t.Paginator.PrepareQuery(q, page) + if err := t.Query.GetAll(ctx, q, &result); err != nil { + return nil, nil, err + } + result = t.Paginator.PrepareResult(result, page) + return result, page, nil +} + // Iter returns an iterator for records matching the condition. func (t *Table[T, P, I]) Iter(ctx context.Context, where sq.Sqlizer, orderBy []string) (iter.Seq2[P, error], error) { q := t.getListQuery(where, orderBy) @@ -281,6 +294,16 @@ func (t *Table[T, P, I]) HardDeleteByID(ctx context.Context, id I) error { return nil } +// WithPaginator returns a table instance with the given paginator. +func (t *Table[T, P, I]) WithPaginator(opts ...PaginatorOption) *Table[T, P, I] { + return &Table[T, P, I]{ + DB: t.DB, + Name: t.Name, + IDColumn: t.IDColumn, + Paginator: NewPaginator[P](opts...), + } +} + // WithTx returns a table instance bound to the given transaction. func (t *Table[T, P, I]) WithTx(tx pgx.Tx) *Table[T, P, I] { return &Table[T, P, I]{ @@ -289,8 +312,9 @@ func (t *Table[T, P, I]) WithTx(tx pgx.Tx) *Table[T, P, I] { SQL: t.DB.SQL, Query: t.DB.TxQuery(tx), }, - Name: t.Name, - IDColumn: t.IDColumn, + Name: t.Name, + IDColumn: t.IDColumn, + Paginator: t.Paginator, } } diff --git a/tests/table_test.go b/tests/table_test.go index a97540f..bf3aa70 100644 --- a/tests/table_test.go +++ b/tests/table_test.go @@ -7,6 +7,7 @@ import ( "testing" sq "github.com/Masterminds/squirrel" + "github.com/goware/pgkit/v2" "github.com/jackc/pgx/v5" "github.com/stretchr/testify/require" ) @@ -195,6 +196,104 @@ func TestTable(t *testing.T) { require.NoError(t, err, "SaveTx transaction failed") }) + t.Run("ListPaged", func(t *testing.T) { + ctx := t.Context() + + account := &Account{Name: "ListPaged Account"} + err := db.Accounts.Save(ctx, account) + require.NoError(t, err) + + // Create 15 articles. 
+ for i := range 15 { + err := db.Articles.Save(ctx, &Article{ + AccountID: account.ID, + Author: fmt.Sprintf("Author %02d", i), + }) + require.NoError(t, err) + } + + // Default paginator (page size 10). + page := pgkit.NewPage(0, 1) + results, retPage, err := db.Articles.ListPaged(ctx, sq.Eq{"account_id": account.ID}, page) + require.NoError(t, err) + require.Len(t, results, 10) + require.True(t, retPage.More, "should have more pages") + + // Second page. + page2 := pgkit.NewPage(0, 2) + results2, retPage2, err := db.Articles.ListPaged(ctx, sq.Eq{"account_id": account.ID}, page2) + require.NoError(t, err) + require.Len(t, results2, 5) + require.False(t, retPage2.More, "should not have more pages") + + // No overlap between pages. + for _, r1 := range results { + for _, r2 := range results2 { + require.NotEqual(t, r1.ID, r2.ID, "pages should not overlap") + } + } + }) + + t.Run("WithPaginator", func(t *testing.T) { + ctx := t.Context() + + account := &Account{Name: "WithPaginator Account"} + err := db.Accounts.Save(ctx, account) + require.NoError(t, err) + + for i := range 10 { + err := db.Articles.Save(ctx, &Article{ + AccountID: account.ID, + Author: fmt.Sprintf("PagAuthor %02d", i), + }) + require.NoError(t, err) + } + + // Use a custom paginator with page size 3. + pagedTable := db.Articles.Table.WithPaginator(pgkit.WithDefaultSize(3), pgkit.WithMaxSize(5)) + + page := pgkit.NewPage(0, 1) + results, retPage, err := pagedTable.ListPaged(ctx, sq.Eq{"account_id": account.ID}, page) + require.NoError(t, err) + require.Len(t, results, 3, "should return 3 records with custom paginator") + require.True(t, retPage.More) + + // Request size larger than max should be capped. 
+ bigPage := pgkit.NewPage(100, 1) + results, _, err = pagedTable.ListPaged(ctx, sq.Eq{"account_id": account.ID}, bigPage) + require.NoError(t, err) + require.Len(t, results, 5, "should be capped at max size 5") + }) + + t.Run("WithTx preserves Paginator", func(t *testing.T) { + ctx := t.Context() + + account := &Account{Name: "WithTx Paginator Account"} + err := db.Accounts.Save(ctx, account) + require.NoError(t, err) + + for i := range 5 { + err := db.Articles.Save(ctx, &Article{ + AccountID: account.ID, + Author: fmt.Sprintf("TxPag %02d", i), + }) + require.NoError(t, err) + } + + pagedTable := db.Articles.Table.WithPaginator(pgkit.WithDefaultSize(2)) + + err = pgx.BeginFunc(ctx, db.Conn, func(pgTx pgx.Tx) error { + txTable := pagedTable.WithTx(pgTx) + page := pgkit.NewPage(0, 1) + results, retPage, err := txTable.ListPaged(ctx, sq.Eq{"account_id": account.ID}, page) + require.NoError(t, err) + require.Len(t, results, 2, "paginator should be preserved through WithTx") + require.True(t, retPage.More) + return nil + }) + require.NoError(t, err) + }) + t.Run("WithTx keeps IDColumn", func(t *testing.T) { ctx := t.Context() From f74541ec1882d8b6f2daa114572999620410520c Mon Sep 17 00:00:00 2001 From: Alex Guerrieri Date: Tue, 17 Mar 2026 21:31:54 +0100 Subject: [PATCH 26/34] feat: add Insert and Update methods in Table --- table.go | 159 ++++++++++++++++++++++++++++++++++++++++++++ tests/table_test.go | 140 ++++++++++++++++++++++++++++++++++++++ 2 files changed, 299 insertions(+) diff --git a/table.go b/table.go index a98a31b..a120c23 100644 --- a/table.go +++ b/table.go @@ -43,6 +43,165 @@ type ( } ) +// Insert inserts one or more records. Sets CreatedAt and UpdatedAt timestamps if available. +// Records are returned with their generated fields populated via RETURNING *. 
+func (t *Table[T, P, I]) Insert(ctx context.Context, records ...P) error { + switch len(records) { + case 0: + return nil + case 1: + return t.insertOne(ctx, records[0]) + default: + return t.insertAll(ctx, records) + } +} + +func (t *Table[T, P, I]) insertOne(ctx context.Context, record P) error { + if record == nil { + return fmt.Errorf("record is nil") + } + + if err := record.Validate(); err != nil { + return fmt.Errorf("validate record: %w", err) + } + + now := time.Now().UTC() + if row, ok := any(record).(hasSetCreatedAt); ok { + row.SetCreatedAt(now) + } + if row, ok := any(record).(hasSetUpdatedAt); ok { + row.SetUpdatedAt(now) + } + + q := t.SQL. + InsertRecord(record). + Into(t.Name). + Suffix("RETURNING *") + + if err := t.Query.GetOne(ctx, q, record); err != nil { + return fmt.Errorf("insert record: %w", err) + } + + return nil +} + +func (t *Table[T, P, I]) insertAll(ctx context.Context, records []P) error { + now := time.Now().UTC() + + for i, r := range records { + if r == nil { + return fmt.Errorf("record with index=%d is nil", i) + } + + if err := r.Validate(); err != nil { + return fmt.Errorf("validate record: %w", err) + } + + if row, ok := any(r).(hasSetCreatedAt); ok { + row.SetCreatedAt(now) + } + if row, ok := any(r).(hasSetUpdatedAt); ok { + row.SetUpdatedAt(now) + } + } + + for start := 0; start < len(records); start += chunkSize { + end := start + chunkSize + if end > len(records) { + end = len(records) + } + + chunk := records[start:end] + q := t.SQL. + InsertRecords(chunk). + Into(t.Name). + SuffixExpr(sq.Expr(" RETURNING *")) + + if err := t.Query.GetAll(ctx, q, &chunk); err != nil { + return fmt.Errorf("insert records: %w", err) + } + } + + return nil +} + +// Update updates one or more records by their ID. Sets UpdatedAt timestamp if available. +// Returns an error if any record has a zero ID. 
+func (t *Table[T, P, I]) Update(ctx context.Context, records ...P) error { + switch len(records) { + case 0: + return nil + case 1: + return t.updateOne(ctx, records[0]) + default: + return t.updateAll(ctx, records) + } +} + +func (t *Table[T, P, I]) updateOne(ctx context.Context, record P) error { + if record == nil { + return fmt.Errorf("record is nil") + } + + var zero I + if record.GetID() == zero { + return fmt.Errorf("update record: ID is zero") + } + + if err := record.Validate(); err != nil { + return fmt.Errorf("validate record: %w", err) + } + + if row, ok := any(record).(hasSetUpdatedAt); ok { + row.SetUpdatedAt(time.Now().UTC()) + } + + q := t.SQL.UpdateRecord(record, sq.Eq{t.IDColumn: record.GetID()}, t.Name) + if _, err := t.Query.Exec(ctx, q); err != nil { + return fmt.Errorf("update record: %w", err) + } + + return nil +} + +func (t *Table[T, P, I]) updateAll(ctx context.Context, records []P) error { + now := time.Now().UTC() + + queries := make(Queries, 0, len(records)) + var zero I + + for i, r := range records { + if r == nil { + return fmt.Errorf("record with index=%d is nil", i) + } + + if r.GetID() == zero { + return fmt.Errorf("update record with index=%d: ID is zero", i) + } + + if err := r.Validate(); err != nil { + return fmt.Errorf("validate record: %w", err) + } + + if row, ok := any(r).(hasSetUpdatedAt); ok { + row.SetUpdatedAt(now) + } + + queries.Add(t.SQL. + UpdateRecord(r, sq.Eq{t.IDColumn: r.GetID()}, t.Name). + SuffixExpr(sq.Expr(" RETURNING *")), + ) + } + + for chunk := range slices.Chunk(queries, chunkSize) { + if _, err := t.Query.BatchExec(ctx, chunk); err != nil { + return fmt.Errorf("update records: %w", err) + } + } + + return nil +} + // Save inserts or updates given records. Auto-detects insert vs update by ID based on zerovalue of ID from GetID() method on record. 
func (t *Table[T, P, I]) Save(ctx context.Context, records ...P) error { switch len(records) { diff --git a/tests/table_test.go b/tests/table_test.go index bf3aa70..a6a5f3b 100644 --- a/tests/table_test.go +++ b/tests/table_test.go @@ -323,6 +323,146 @@ func TestTable(t *testing.T) { }) } +func TestInsert(t *testing.T) { + truncateAllTables(t) + + ctx := t.Context() + db := initDB(DB) + + t.Run("Insert single", func(t *testing.T) { + account := &Account{Name: "Insert Account"} + err := db.Accounts.Insert(ctx, account) + require.NoError(t, err) + require.NotZero(t, account.ID, "ID should be set after insert") + require.NotZero(t, account.UpdatedAt, "UpdatedAt should be set") + + // Verify in DB. + got, err := db.Accounts.GetByID(ctx, account.ID) + require.NoError(t, err) + require.Equal(t, account.Name, got.Name) + }) + + t.Run("Insert multiple", func(t *testing.T) { + account := &Account{Name: "Insert Multiple Account"} + err := db.Accounts.Insert(ctx, account) + require.NoError(t, err) + + articles := []*Article{ + {Author: "Author A", AccountID: account.ID}, + {Author: "Author B", AccountID: account.ID}, + {Author: "Author C", AccountID: account.ID}, + } + err = db.Articles.Insert(ctx, articles...) + require.NoError(t, err) + + for _, a := range articles { + require.NotZero(t, a.ID, "ID should be set after bulk insert") + require.NotZero(t, a.UpdatedAt, "UpdatedAt should be set") + } + + // Verify all in DB. 
+ for _, a := range articles { + got, err := db.Articles.GetByID(ctx, a.ID) + require.NoError(t, err) + require.Equal(t, a.Author, got.Author) + } + }) + + t.Run("Insert nil record", func(t *testing.T) { + err := db.Accounts.Insert(ctx, nil) + require.Error(t, err) + }) + + t.Run("Insert invalid record", func(t *testing.T) { + err := db.Accounts.Insert(ctx, &Account{Name: ""}) + require.Error(t, err, "should fail validation") + }) + + t.Run("Insert zero records", func(t *testing.T) { + err := db.Accounts.Insert(ctx) + require.NoError(t, err, "inserting zero records should be a no-op") + }) +} + +func TestUpdate(t *testing.T) { + truncateAllTables(t) + + ctx := t.Context() + db := initDB(DB) + + t.Run("Update single", func(t *testing.T) { + account := &Account{Name: "Before Update"} + err := db.Accounts.Insert(ctx, account) + require.NoError(t, err) + + account.Name = "After Update" + err = db.Accounts.Update(ctx, account) + require.NoError(t, err) + + got, err := db.Accounts.GetByID(ctx, account.ID) + require.NoError(t, err) + require.Equal(t, "After Update", got.Name) + }) + + t.Run("Update multiple", func(t *testing.T) { + account := &Account{Name: "Update Multiple Account"} + err := db.Accounts.Insert(ctx, account) + require.NoError(t, err) + + articles := []*Article{ + {Author: "Original A", AccountID: account.ID}, + {Author: "Original B", AccountID: account.ID}, + } + err = db.Articles.Insert(ctx, articles...) + require.NoError(t, err) + + articles[0].Author = "Updated A" + articles[1].Author = "Updated B" + err = db.Articles.Update(ctx, articles...) 
+ require.NoError(t, err) + + for _, a := range articles { + got, err := db.Articles.GetByID(ctx, a.ID) + require.NoError(t, err) + require.Equal(t, a.Author, got.Author) + } + }) + + t.Run("Update with zero ID fails", func(t *testing.T) { + err := db.Accounts.Update(ctx, &Account{Name: "No ID"}) + require.Error(t, err, "should fail with zero ID") + }) + + t.Run("Update multiple with zero ID fails", func(t *testing.T) { + account := &Account{Name: "Update Zero ID Account"} + err := db.Accounts.Insert(ctx, account) + require.NoError(t, err) + + err = db.Accounts.Update(ctx, account, &Account{Name: "No ID"}) + require.Error(t, err, "should fail when any record has zero ID") + }) + + t.Run("Update nil record", func(t *testing.T) { + err := db.Accounts.Update(ctx, nil) + require.Error(t, err) + }) + + t.Run("Update invalid record", func(t *testing.T) { + account := &Account{Name: "Valid"} + err := db.Accounts.Insert(ctx, account) + require.NoError(t, err) + + account.Name = "" + err = db.Accounts.Update(ctx, account) + require.Error(t, err, "should fail validation") + }) + + t.Run("Update zero records", func(t *testing.T) { + err := db.Accounts.Update(ctx) + require.NoError(t, err, "updating zero records should be a no-op") + }) +} + func TestLockForUpdates(t *testing.T) { truncateAllTables(t) From 2fa36057f1c9a087257e3dabe60bab59fcd9be3f Mon Sep 17 00:00:00 2001 From: Alex Guerrieri Date: Mon, 23 Mar 2026 11:28:13 +0100 Subject: [PATCH 27/34] feat: add RestoreByID and export timestamp interfaces - Add RestoreByID method that clears DeletedAt by passing zero time - Export HasSetCreatedAt, HasSetUpdatedAt, HasSetDeletedAt interfaces with godoc explaining the contract for each lifecycle hook - Update SetDeletedAt implementations to treat zero time as restore (nil) --- table.go | 81 ++++++++++++++++++++++++++++++++------------ tests/schema_test.go | 16 +++++++-- tests/table_test.go | 13 +++++++ 3 files changed, 86 insertions(+), 24 deletions(-) diff --git a/table.go 
b/table.go index a120c23..1e98bc1 100644 --- a/table.go +++ b/table.go @@ -30,18 +30,34 @@ type Table[T any, P Record[T, I], I ID] struct { Paginator Paginator[P] } -// helpers for setting timestamp fields -type ( - hasSetCreatedAt interface { - SetCreatedAt(time.Time) - } - hasSetUpdatedAt interface { - SetUpdatedAt(time.Time) - } - hasSetDeletedAt interface { - SetDeletedAt(time.Time) - } -) +// HasSetCreatedAt is implemented by records that track creation time. +// Insert will automatically call SetCreatedAt with the current UTC time. +type HasSetCreatedAt interface { + SetCreatedAt(time.Time) +} + +// HasSetUpdatedAt is implemented by records that track update time. +// Insert, Update, and Save will automatically call SetUpdatedAt with the current UTC time. +type HasSetUpdatedAt interface { + SetUpdatedAt(time.Time) +} + +// HasSetDeletedAt is implemented by records that support soft delete. +// DeleteByID will call SetDeletedAt with the current UTC time to soft-delete, +// and RestoreByID will call SetDeletedAt with a zero time.Time{} to restore. +// +// Implementations should treat a zero time as a restore (clear the timestamp): +// +// func (r *MyRecord) SetDeletedAt(t time.Time) { +// if t.IsZero() { +// r.DeletedAt = nil // restore: clear the timestamp +// return +// } +// r.DeletedAt = &t // soft delete: set the timestamp +// } +type HasSetDeletedAt interface { + SetDeletedAt(time.Time) +} // Insert inserts one or more records. Sets CreatedAt and UpdatedAt timestamps if available. // Records are returned with their generated fields populated via RETURNING *. 
@@ -66,10 +82,10 @@ func (t *Table[T, P, I]) insertOne(ctx context.Context, record P) error { } now := time.Now().UTC() - if row, ok := any(record).(hasSetCreatedAt); ok { + if row, ok := any(record).(HasSetCreatedAt); ok { row.SetCreatedAt(now) } - if row, ok := any(record).(hasSetUpdatedAt); ok { + if row, ok := any(record).(HasSetUpdatedAt); ok { row.SetUpdatedAt(now) } @@ -97,10 +113,10 @@ func (t *Table[T, P, I]) insertAll(ctx context.Context, records []P) error { return fmt.Errorf("validate record: %w", err) } - if row, ok := any(r).(hasSetCreatedAt); ok { + if row, ok := any(r).(HasSetCreatedAt); ok { row.SetCreatedAt(now) } - if row, ok := any(r).(hasSetUpdatedAt); ok { + if row, ok := any(r).(HasSetUpdatedAt); ok { row.SetUpdatedAt(now) } } @@ -152,7 +168,7 @@ func (t *Table[T, P, I]) updateOne(ctx context.Context, record P) error { return fmt.Errorf("validate record: %w", err) } - if row, ok := any(record).(hasSetUpdatedAt); ok { + if row, ok := any(record).(HasSetUpdatedAt); ok { row.SetUpdatedAt(time.Now().UTC()) } @@ -183,7 +199,7 @@ func (t *Table[T, P, I]) updateAll(ctx context.Context, records []P) error { return fmt.Errorf("validate record: %w", err) } - if row, ok := any(r).(hasSetUpdatedAt); ok { + if row, ok := any(r).(HasSetUpdatedAt); ok { row.SetUpdatedAt(now) } @@ -223,7 +239,7 @@ func (t *Table[T, P, I]) saveOne(ctx context.Context, record P) error { return fmt.Errorf("validate record: %w", err) } - if row, ok := any(record).(hasSetUpdatedAt); ok { + if row, ok := any(record).(HasSetUpdatedAt); ok { row.SetUpdatedAt(time.Now().UTC()) } @@ -270,13 +286,13 @@ func (t *Table[T, P, I]) saveAll(ctx context.Context, records []P) error { return fmt.Errorf("validate record: %w", err) } - if row, ok := any(r).(hasSetUpdatedAt); ok { + if row, ok := any(r).(HasSetUpdatedAt); ok { row.SetUpdatedAt(now) } var zero I if r.GetID() == zero { - if row, ok := any(r).(hasSetCreatedAt); ok { + if row, ok := any(r).(HasSetCreatedAt); ok { row.SetCreatedAt(now) 
} @@ -432,7 +448,7 @@ func (t *Table[T, P, I]) DeleteByID(ctx context.Context, id I) error { } // Soft delete. - if row, ok := any(record).(hasSetDeletedAt); ok { + if row, ok := any(record).(HasSetDeletedAt); ok { row.SetDeletedAt(time.Now().UTC()) if err := t.Save(ctx, record); err != nil { return fmt.Errorf("soft delete: %w", err) @@ -444,6 +460,27 @@ func (t *Table[T, P, I]) DeleteByID(ctx context.Context, id I) error { return t.HardDeleteByID(ctx, id) } +// RestoreByID restores a soft-deleted record by ID by clearing its DeletedAt timestamp. +// Returns an error if the record does not implement .SetDeletedAt(). +func (t *Table[T, P, I]) RestoreByID(ctx context.Context, id I) error { + record, err := t.GetByID(ctx, id) + if err != nil { + return fmt.Errorf("restore: %w", err) + } + + row, ok := any(record).(HasSetDeletedAt) + if !ok { + return fmt.Errorf("restore: record does not support soft delete") + } + + row.SetDeletedAt(time.Time{}) + if err := t.Save(ctx, record); err != nil { + return fmt.Errorf("restore: %w", err) + } + + return nil +} + // HardDeleteByID permanently deletes a record by ID. 
func (t *Table[T, P, I]) HardDeleteByID(ctx context.Context, id I) error { q := t.SQL.Delete(t.Name).Where(sq.Eq{t.IDColumn: id}) diff --git a/tests/schema_test.go b/tests/schema_test.go index dd2c3ee..05454d8 100644 --- a/tests/schema_test.go +++ b/tests/schema_test.go @@ -40,7 +40,13 @@ type Article struct { func (a *Article) GetID() uint64 { return a.ID } func (a *Article) SetUpdatedAt(t time.Time) { a.UpdatedAt = t } -func (a *Article) SetDeletedAt(t time.Time) { a.DeletedAt = &t } +func (a *Article) SetDeletedAt(t time.Time) { + if t.IsZero() { + a.DeletedAt = nil + return + } + a.DeletedAt = &t +} func (a *Article) Validate() error { if a.Author == "" { @@ -71,7 +77,13 @@ type Review struct { func (r *Review) GetID() uint64 { return r.ID } func (r *Review) SetUpdatedAt(t time.Time) { r.UpdatedAt = t } -func (r *Review) SetDeletedAt(t time.Time) { r.DeletedAt = &t } +func (r *Review) SetDeletedAt(t time.Time) { + if t.IsZero() { + r.DeletedAt = nil + return + } + r.DeletedAt = &t +} func (r *Review) Validate() error { if len(r.Comment) < 3 { diff --git a/tests/table_test.go b/tests/table_test.go index a6a5f3b..a5afecf 100644 --- a/tests/table_test.go +++ b/tests/table_test.go @@ -182,6 +182,19 @@ func TestTable(t *testing.T) { require.Equal(t, firstArticle.ID, article.ID, "DeletedAt should be set") require.NotNil(t, article.DeletedAt, "DeletedAt should be set") + // Restore first article. + err = tx.Articles.RestoreByID(ctx, firstArticle.ID) + require.NoError(t, err, "RestoreByID failed") + + // Check if article is restored. + article, err = tx.Articles.GetByID(ctx, firstArticle.ID) + require.NoError(t, err, "GetByID failed after restore") + require.Nil(t, article.DeletedAt, "DeletedAt should be nil after restore") + + // Soft-delete again for the hard-delete test. + err = tx.Articles.DeleteByID(ctx, firstArticle.ID) + require.NoError(t, err, "DeleteByID failed") + // Hard-delete first article. 
err = tx.Articles.HardDeleteByID(ctx, firstArticle.ID) require.NoError(t, err, "HardDeleteByID failed") From 24ad27f0c8aebdcb7ea174b45392779c14a3d346 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Alex=20-=20=E3=82=A2=E3=83=AC=E3=83=83=E3=82=AF=E3=82=B9?= Date: Fri, 3 Apr 2026 11:16:23 +0200 Subject: [PATCH 28/34] fix: LockForUpdates missing return, saveAll hardcoded ID column, Iter missing rows.Err, saveOne missing SetCreatedAt - LockForUpdates: add missing return after existing-tx branch to prevent double execution of updateFn via a second transaction - saveAll: use t.IDColumn instead of hardcoded "id" for update WHERE clause - Iter: check rows.Err() after iteration loop to surface driver errors - saveOne: call SetCreatedAt on insert path (consistent with saveAll) - Use min() builtin for chunk bounds --- table.go | 25 ++++++++++++++----------- 1 file changed, 14 insertions(+), 11 deletions(-) diff --git a/table.go b/table.go index 1e98bc1..348c4e3 100644 --- a/table.go +++ b/table.go @@ -122,10 +122,7 @@ func (t *Table[T, P, I]) insertAll(ctx context.Context, records []P) error { } for start := 0; start < len(records); start += chunkSize { - end := start + chunkSize - if end > len(records) { - end = len(records) - } + end := min(start+chunkSize, len(records)) chunk := records[start:end] q := t.SQL. @@ -239,13 +236,18 @@ func (t *Table[T, P, I]) saveOne(ctx context.Context, record P) error { return fmt.Errorf("validate record: %w", err) } + now := time.Now().UTC() if row, ok := any(record).(HasSetUpdatedAt); ok { - row.SetUpdatedAt(time.Now().UTC()) + row.SetUpdatedAt(now) } // Insert var zero I if record.GetID() == zero { + if row, ok := any(record).(HasSetCreatedAt); ok { + row.SetCreatedAt(now) + } + q := t.SQL. InsertRecord(record). Into(t.Name). @@ -300,7 +302,7 @@ func (t *Table[T, P, I]) saveAll(ctx context.Context, records []P) error { insertIndices = append(insertIndices, i) // remember index } else { updateQueries.Add(t.SQL. 
- UpdateRecord(r, sq.Eq{"id": r.GetID()}, t.Name). + UpdateRecord(r, sq.Eq{t.IDColumn: r.GetID()}, t.Name). SuffixExpr(sq.Expr(" RETURNING *")), ) } @@ -308,10 +310,7 @@ func (t *Table[T, P, I]) saveAll(ctx context.Context, records []P) error { // Handle inserts in chunks, has to be done manually, slices.Chunk does not return index :/ for start := 0; start < len(insertRecords); start += chunkSize { - end := start + chunkSize - if end > len(insertRecords) { - end = len(insertRecords) - } + end := min(start+chunkSize, len(insertRecords)) chunk := insertRecords[start:end] q := t.SQL. @@ -412,6 +411,9 @@ func (t *Table[T, P, I]) Iter(ctx context.Context, where sq.Sqlizer, orderBy []s return } } + if err := rows.Err(); err != nil { + yield(nil, err) + } }, nil } @@ -552,11 +554,12 @@ func (t *Table[T, P, I]) LockForUpdate(ctx context.Context, where sq.Sqlizer, or // to "processing" and return early, then process asynchronously. Use defer LockForUpdate() // to update status to "completed" or "failed". func (t *Table[T, P, I]) LockForUpdates(ctx context.Context, where sq.Sqlizer, orderBy []string, limit uint64, updateFn func(records []P)) error { - // Check if we're already in a transaction + // Reuse existing transaction if available. 
if t.DB.Query.Tx != nil { if err := t.lockForUpdatesWithTx(ctx, t.DB.Query.Tx, where, orderBy, limit, updateFn); err != nil { return fmt.Errorf("lock for update (with tx): %w", err) } + return nil } return pgx.BeginFunc(ctx, t.DB.Conn, func(pgTx pgx.Tx) error { From c7fd6eae038299a393ffa7bebfa49e2ea33f86be Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Alex=20-=20=E3=82=A2=E3=83=AC=E3=83=83=E3=82=AF=E3=82=B9?= Date: Fri, 3 Apr 2026 11:31:58 +0200 Subject: [PATCH 29/34] fix(tests): race conditions in TestLockForUpdates and wrong ListByIDs slice - Move wg.Add(1) before goroutine dispatch in ProcessReview to prevent worker.Wait() returning early before all goroutines register - Replace shared ids[][]uint64 slice (data race across goroutines) with per-worker local slices merged under mutex - Fix articleIDs slice: make([]uint64, len) + append produced leading zeros, changed to make([]uint64, 0, len) --- tests/table_test.go | 18 ++++++++++++------ tests/worker_test.go | 1 - 2 files changed, 12 insertions(+), 7 deletions(-) diff --git a/tests/table_test.go b/tests/table_test.go index a5afecf..3b2db44 100644 --- a/tests/table_test.go +++ b/tests/table_test.go @@ -156,7 +156,7 @@ func TestTable(t *testing.T) { } // Verify we can load all articles with .ListByIDs() - articleIDs := make([]uint64, len(articles)) + articleIDs := make([]uint64, 0, len(articles)) for _, article := range articles { articleIDs = append(articleIDs, article.ID) } @@ -507,7 +507,8 @@ func TestLockForUpdates(t *testing.T) { err = db.Reviews.Save(ctx, reviews...) 
require.NoError(t, err, "create review") - var ids [][]uint64 = make([][]uint64, 10) + var mu sync.Mutex + var allIDs []uint64 var wg sync.WaitGroup for range 10 { @@ -518,17 +519,22 @@ func TestLockForUpdates(t *testing.T) { reviews, err := db.Reviews.DequeueForProcessing(ctx, 10) require.NoError(t, err, "dequeue reviews") - for i, review := range reviews { + var localIDs []uint64 + for _, review := range reviews { + localIDs = append(localIDs, review.ID) + worker.wg.Add(1) go worker.ProcessReview(ctx, review) - - ids[i] = append(ids[i], review.ID) } + + mu.Lock() + allIDs = append(allIDs, localIDs...) + mu.Unlock() }() } wg.Wait() // Ensure that all reviews were picked up for processing exactly once. - uniqueIDs := slices.Concat(ids...) + uniqueIDs := slices.Clone(allIDs) slices.Sort(uniqueIDs) uniqueIDs = slices.Compact(uniqueIDs) require.Equal(t, 100, len(uniqueIDs), "number of unique reviews picked up for processing should be 100") diff --git a/tests/worker_test.go b/tests/worker_test.go index 711a3af..6bc6416 100644 --- a/tests/worker_test.go +++ b/tests/worker_test.go @@ -22,7 +22,6 @@ func (w *Worker) Wait() { } func (w *Worker) ProcessReview(ctx context.Context, review *Review) (err error) { - w.wg.Add(1) defer w.wg.Done() defer func() { From d753ee8913090f34d543e584d6e0d6f2e200d930 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Alex=20-=20=E3=82=A2=E3=83=AC=E3=83=83=E3=82=AF=E3=82=B9?= Date: Fri, 3 Apr 2026 11:32:04 +0200 Subject: [PATCH 30/34] fix: ListPaged nil page panic Normalize nil page before passing to PrepareQuery/PrepareResult. PrepareQuery creates a local &Page{} but the caller's pointer stays nil, causing PrepareResult to panic on page.More assignment. 
--- table.go | 3 +++ 1 file changed, 3 insertions(+) diff --git a/table.go b/table.go index 348c4e3..c1ab134 100644 --- a/table.go +++ b/table.go @@ -379,6 +379,9 @@ func (t *Table[T, P, I]) List(ctx context.Context, where sq.Sqlizer, orderBy []s // ListPaged returns paginated records matching the condition. func (t *Table[T, P, I]) ListPaged(ctx context.Context, where sq.Sqlizer, page *Page) ([]P, *Page, error) { + if page == nil { + page = &Page{} + } q := t.SQL.Select("*").From(t.Name).Where(where) result, q := t.Paginator.PrepareQuery(q, page) From 70cdd93e6eb49c3e74189c9b817b1aead14a10f2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Alex=20-=20=E3=82=A2=E3=83=AC=E3=83=83=E3=82=AF=E3=82=B9?= Date: Fri, 3 Apr 2026 11:47:40 +0200 Subject: [PATCH 31/34] fix: LockForUpdates validate/timestamp, zero-value Paginator, ListPaged ordering - lockForUpdatesWithTx: call Validate() and SetUpdatedAt() on records after updateFn, matching Insert/Update/Save behavior - Page.SetDefaults: fall back to DefaultPageSize/MaxPageSize when PaginatorSettings has zero values (fixes zero-value Paginator capping page size to 0) - ListPaged: add IDColumn fallback ordering when no sort is configured, ensuring deterministic pagination - TestLockForUpdates: use assert instead of require in goroutine (require.FailNow is unsafe off the test goroutine) --- page.go | 19 ++++++++++++------- table.go | 11 +++++++++++ tests/table_test.go | 3 ++- 3 files changed, 25 insertions(+), 8 deletions(-) diff --git a/page.go b/page.go index d6fcd76..7bf03e1 100644 --- a/page.go +++ b/page.go @@ -93,16 +93,21 @@ func NewPage(size, page uint32, sort ...Sort) *Page { func (p *Page) SetDefaults(o *PaginatorSettings) { if o == nil { - o = &PaginatorSettings{ - DefaultSize: DefaultPageSize, - MaxSize: MaxPageSize, - } + o = &PaginatorSettings{} + } + defaultSize := o.DefaultSize + if defaultSize == 0 { + defaultSize = DefaultPageSize + } + maxSize := o.MaxSize + if maxSize == 0 { + maxSize = MaxPageSize } if p.Size == 0 
{ - p.Size = o.DefaultSize + p.Size = defaultSize } - if p.Size > o.MaxSize { - p.Size = o.MaxSize + if p.Size > maxSize { + p.Size = maxSize } if p.Page == 0 { p.Page = 1 diff --git a/table.go b/table.go index c1ab134..1f6822d 100644 --- a/table.go +++ b/table.go @@ -382,6 +382,10 @@ func (t *Table[T, P, I]) ListPaged(ctx context.Context, where sq.Sqlizer, page * if page == nil { page = &Page{} } + // Ensure deterministic ordering for stable pagination. + if len(page.Sort) == 0 && page.Column == "" && len(t.Paginator.settings.Sort) == 0 { + page.Sort = []Sort{{Column: t.IDColumn, Order: Asc}} + } q := t.SQL.Select("*").From(t.Name).Where(where) result, q := t.Paginator.PrepareQuery(q, page) @@ -595,7 +599,14 @@ func (t *Table[T, P, I]) lockForUpdatesWithTx(ctx context.Context, pgTx pgx.Tx, updateFn(records) + now := time.Now().UTC() for _, record := range records { + if err := record.Validate(); err != nil { + return fmt.Errorf("validate record after update: %w", err) + } + if row, ok := any(record).(HasSetUpdatedAt); ok { + row.SetUpdatedAt(now) + } q := t.SQL.UpdateRecord(record, sq.Eq{t.IDColumn: record.GetID()}, t.Name) if _, err := txQuery.Exec(ctx, q); err != nil { return fmt.Errorf("update record: %w", err) diff --git a/tests/table_test.go b/tests/table_test.go index 3b2db44..128f616 100644 --- a/tests/table_test.go +++ b/tests/table_test.go @@ -9,6 +9,7 @@ import ( sq "github.com/Masterminds/squirrel" "github.com/goware/pgkit/v2" "github.com/jackc/pgx/v5" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -517,7 +518,7 @@ func TestLockForUpdates(t *testing.T) { defer wg.Done() reviews, err := db.Reviews.DequeueForProcessing(ctx, 10) - require.NoError(t, err, "dequeue reviews") + assert.NoError(t, err, "dequeue reviews") var localIDs []uint64 for _, review := range reviews { From ecbed78fa38ceecf1471bb50be6771e343872b80 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Alex=20-=20=E3=82=A2=E3=83=AC=E3=83=83=E3=82=AF=E3=82=B9?= Date: 
Fri, 3 Apr 2026 11:58:36 +0200 Subject: [PATCH 32/34] fix: PrepareRaw empty ORDER BY and unbound limit/offset params - Skip ORDER BY clause when no sort columns are configured - Always inject pgx.NamedArgs for limit/offset even when args is empty --- page.go | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/page.go b/page.go index 7bf03e1..35621e7 100644 --- a/page.go +++ b/page.go @@ -259,18 +259,22 @@ func (p Paginator[T]) PrepareRaw(q string, args []any, page *Page) ([]T, string, limit, offset := page.Limit(), page.Offset() - q = q + " ORDER BY " + strings.Join(p.getOrder(page), ", ") + if order := p.getOrder(page); len(order) > 0 { + q = q + " ORDER BY " + strings.Join(order, ", ") + } q = q + " LIMIT @limit OFFSET @offset" - for i, arg := range args { + injected := false + for _, arg := range args { if existing, ok := arg.(pgx.NamedArgs); ok { existing["limit"] = limit + 1 existing["offset"] = offset + injected = true break } - if i == len(args)-1 { - args = append(args, pgx.NamedArgs{"limit": limit + 1, "offset": offset}) - } + } + if !injected { + args = append(args, pgx.NamedArgs{"limit": limit + 1, "offset": offset}) } return make([]T, 0, limit+1), q, args From 63e0dbc677545a80c4672e872bd4fcde9a41fcd4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Alex=20-=20=E3=82=A2=E3=83=AC=E3=83=83=E3=82=AF=E3=82=B9?= Date: Fri, 3 Apr 2026 12:20:29 +0200 Subject: [PATCH 33/34] feat: Update, DeleteByID, HardDeleteByID return (bool, error) Return true when at least one row was affected, false when zero rows matched. Callers can now distinguish "success" from "not found" without reimplementing methods with RowsAffected checks. Breaking change: signatures go from error to (bool, error). 
--- table.go | 64 ++++++++++------- tests/table_test.go | 166 ++++++++++++++++++++++++++++++++++++++++---- 2 files changed, 193 insertions(+), 37 deletions(-) diff --git a/table.go b/table.go index 1f6822d..a0b39c7 100644 --- a/table.go +++ b/table.go @@ -139,11 +139,11 @@ func (t *Table[T, P, I]) insertAll(ctx context.Context, records []P) error { } // Update updates one or more records by their ID. Sets UpdatedAt timestamp if available. -// Returns an error if any record has a zero ID. -func (t *Table[T, P, I]) Update(ctx context.Context, records ...P) error { +// Returns (true, nil) if at least one row was updated, (false, nil) if no rows matched. +func (t *Table[T, P, I]) Update(ctx context.Context, records ...P) (bool, error) { switch len(records) { case 0: - return nil + return false, nil case 1: return t.updateOne(ctx, records[0]) default: @@ -151,18 +151,18 @@ func (t *Table[T, P, I]) Update(ctx context.Context, records ...P) error { } } -func (t *Table[T, P, I]) updateOne(ctx context.Context, record P) error { +func (t *Table[T, P, I]) updateOne(ctx context.Context, record P) (bool, error) { if record == nil { - return fmt.Errorf("record is nil") + return false, fmt.Errorf("record is nil") } var zero I if record.GetID() == zero { - return fmt.Errorf("update record: ID is zero") + return false, fmt.Errorf("update record: ID is zero") } if err := record.Validate(); err != nil { - return fmt.Errorf("validate record: %w", err) + return false, fmt.Errorf("validate record: %w", err) } if row, ok := any(record).(HasSetUpdatedAt); ok { @@ -170,14 +170,15 @@ func (t *Table[T, P, I]) updateOne(ctx context.Context, record P) error { } q := t.SQL.UpdateRecord(record, sq.Eq{t.IDColumn: record.GetID()}, t.Name) - if _, err := t.Query.Exec(ctx, q); err != nil { - return fmt.Errorf("update record: %w", err) + tag, err := t.Query.Exec(ctx, q) + if err != nil { + return false, fmt.Errorf("update record: %w", err) } - return nil + return tag.RowsAffected() > 0, nil } -func 
(t *Table[T, P, I]) updateAll(ctx context.Context, records []P) error { +func (t *Table[T, P, I]) updateAll(ctx context.Context, records []P) (bool, error) { now := time.Now().UTC() queries := make(Queries, 0, len(records)) @@ -185,15 +186,15 @@ func (t *Table[T, P, I]) updateAll(ctx context.Context, records []P) error { for i, r := range records { if r == nil { - return fmt.Errorf("record with index=%d is nil", i) + return false, fmt.Errorf("record with index=%d is nil", i) } if r.GetID() == zero { - return fmt.Errorf("update record with index=%d: ID is zero", i) + return false, fmt.Errorf("update record with index=%d: ID is zero", i) } if err := r.Validate(); err != nil { - return fmt.Errorf("validate record: %w", err) + return false, fmt.Errorf("validate record: %w", err) } if row, ok := any(r).(HasSetUpdatedAt); ok { @@ -206,13 +207,20 @@ func (t *Table[T, P, I]) updateAll(ctx context.Context, records []P) error { ) } + var affected bool for chunk := range slices.Chunk(queries, chunkSize) { - if _, err := t.Query.BatchExec(ctx, chunk); err != nil { - return fmt.Errorf("update records: %w", err) + tags, err := t.Query.BatchExec(ctx, chunk) + if err != nil { + return false, fmt.Errorf("update records: %w", err) + } + for _, tag := range tags { + if tag.RowsAffected() > 0 { + affected = true + } } } - return nil + return affected, nil } // Save inserts or updates given records. Auto-detects insert vs update by ID based on zerovalue of ID from GetID() method on record. @@ -450,19 +458,23 @@ func (t *Table[T, P, I]) Count(ctx context.Context, where sq.Sqlizer) (uint64, e } // DeleteByID deletes a record by ID. Uses soft delete if .SetDeletedAt() method exists. -func (t *Table[T, P, I]) DeleteByID(ctx context.Context, id I) error { +// Returns (true, nil) if a row was deleted, (false, nil) if no row matched. 
+func (t *Table[T, P, I]) DeleteByID(ctx context.Context, id I) (bool, error) { record, err := t.GetByID(ctx, id) if err != nil { - return fmt.Errorf("delete: %w", err) + if errors.Is(err, ErrNoRows) { + return false, nil + } + return false, fmt.Errorf("delete: %w", err) } // Soft delete. if row, ok := any(record).(HasSetDeletedAt); ok { row.SetDeletedAt(time.Now().UTC()) if err := t.Save(ctx, record); err != nil { - return fmt.Errorf("soft delete: %w", err) + return false, fmt.Errorf("soft delete: %w", err) } - return nil + return true, nil } // Hard delete for tables without timestamps. @@ -491,12 +503,14 @@ func (t *Table[T, P, I]) RestoreByID(ctx context.Context, id I) error { } // HardDeleteByID permanently deletes a record by ID. -func (t *Table[T, P, I]) HardDeleteByID(ctx context.Context, id I) error { +// Returns (true, nil) if a row was deleted, (false, nil) if no row matched. +func (t *Table[T, P, I]) HardDeleteByID(ctx context.Context, id I) (bool, error) { q := t.SQL.Delete(t.Name).Where(sq.Eq{t.IDColumn: id}) - if _, err := t.Query.Exec(ctx, q); err != nil { - return fmt.Errorf("hard delete: %w", err) + tag, err := t.Query.Exec(ctx, q) + if err != nil { + return false, fmt.Errorf("hard delete: %w", err) } - return nil + return tag.RowsAffected() > 0, nil } // WithPaginator returns a table instance with the given paginator. diff --git a/tests/table_test.go b/tests/table_test.go index 128f616..52f0a77 100644 --- a/tests/table_test.go +++ b/tests/table_test.go @@ -174,8 +174,9 @@ func TestTable(t *testing.T) { } // Soft-delete first article. - err = tx.Articles.DeleteByID(ctx, firstArticle.ID) + ok, err := tx.Articles.DeleteByID(ctx, firstArticle.ID) require.NoError(t, err, "DeleteByID failed") + require.True(t, ok, "DeleteByID should return true for existing record") // Check if article is soft-deleted. 
article, err := tx.Articles.GetByID(ctx, firstArticle.ID) @@ -193,12 +194,14 @@ func TestTable(t *testing.T) { require.Nil(t, article.DeletedAt, "DeletedAt should be nil after restore") // Soft-delete again for the hard-delete test. - err = tx.Articles.DeleteByID(ctx, firstArticle.ID) + ok, err = tx.Articles.DeleteByID(ctx, firstArticle.ID) require.NoError(t, err, "DeleteByID failed") + require.True(t, ok, "DeleteByID should return true for existing record") // Hard-delete first article. - err = tx.Articles.HardDeleteByID(ctx, firstArticle.ID) + ok, err = tx.Articles.HardDeleteByID(ctx, firstArticle.ID) require.NoError(t, err, "HardDeleteByID failed") + require.True(t, ok, "HardDeleteByID should return true for existing record") // Check if article is hard-deleted. article, err = tx.Articles.GetByID(ctx, firstArticle.ID) @@ -321,11 +324,13 @@ func TestTable(t *testing.T) { err = pgx.BeginFunc(ctx, db.Conn, func(pgTx pgx.Tx) error { txTable := db.Articles.Table.WithTx(pgTx) - if err := txTable.HardDeleteByID(ctx, article.ID); err != nil { + ok, err := txTable.HardDeleteByID(ctx, article.ID) + if err != nil { return err } + require.True(t, ok, "HardDeleteByID should return true for existing record") - _, err := txTable.GetByID(ctx, article.ID) + _, err = txTable.GetByID(ctx, article.ID) require.Error(t, err, "article should be deleted inside tx") return nil @@ -410,8 +415,9 @@ func TestUpdate(t *testing.T) { require.NoError(t, err) account.Name = "After Update" - err = db.Accounts.Update(ctx, account) + ok, err := db.Accounts.Update(ctx, account) require.NoError(t, err) + require.True(t, ok, "should return true for existing record") got, err := db.Accounts.GetByID(ctx, account.ID) require.NoError(t, err) @@ -432,8 +438,9 @@ func TestUpdate(t *testing.T) { articles[0].Author = "Updated A" articles[1].Author = "Updated B" - err = db.Articles.Update(ctx, articles...) + ok, err := db.Articles.Update(ctx, articles...) 
require.NoError(t, err) + require.True(t, ok, "should return true when records exist") for _, a := range articles { got, err := db.Articles.GetByID(ctx, a.ID) @@ -443,7 +450,7 @@ func TestUpdate(t *testing.T) { }) t.Run("Update with zero ID fails", func(t *testing.T) { - err := db.Accounts.Update(ctx, &Account{Name: "No ID"}) + _, err := db.Accounts.Update(ctx, &Account{Name: "No ID"}) require.Error(t, err, "should fail with zero ID") }) @@ -452,12 +459,12 @@ func TestUpdate(t *testing.T) { err := db.Accounts.Insert(ctx, account) require.NoError(t, err) - err = db.Accounts.Update(ctx, account, &Account{Name: "No ID"}) + _, err = db.Accounts.Update(ctx, account, &Account{Name: "No ID"}) require.Error(t, err, "should fail when any record has zero ID") }) t.Run("Update nil record", func(t *testing.T) { - err := db.Accounts.Update(ctx, nil) + _, err := db.Accounts.Update(ctx, nil) require.Error(t, err) }) @@ -467,13 +474,148 @@ func TestUpdate(t *testing.T) { require.NoError(t, err) account.Name = "" - err = db.Accounts.Update(ctx, account) + _, err = db.Accounts.Update(ctx, account) require.Error(t, err, "should fail validation") }) t.Run("Update zero records", func(t *testing.T) { - err := db.Accounts.Update(ctx) + ok, err := db.Accounts.Update(ctx) require.NoError(t, err, "updating zero records should be a no-op") + require.False(t, ok, "should return false for zero records") + }) + + t.Run("Update non-existent record returns false", func(t *testing.T) { + account := &Account{ID: 999999, Name: "Ghost"} + ok, err := db.Accounts.Update(ctx, account) + require.NoError(t, err) + require.False(t, ok, "should return false for non-existent record") + }) +} + +func TestDeleteByID(t *testing.T) { + truncateAllTables(t) + + ctx := t.Context() + db := initDB(DB) + + account := &Account{Name: "DeleteByID Account"} + err := db.Accounts.Insert(ctx, account) + require.NoError(t, err) + + t.Run("soft delete existing returns true", func(t *testing.T) { + article := &Article{Author: 
"Author", AccountID: account.ID} + err := db.Articles.Insert(ctx, article) + require.NoError(t, err) + + ok, err := db.Articles.DeleteByID(ctx, article.ID) + require.NoError(t, err) + require.True(t, ok) + + got, err := db.Articles.GetByID(ctx, article.ID) + require.NoError(t, err) + require.NotNil(t, got.DeletedAt) + }) + + t.Run("soft delete missing returns false", func(t *testing.T) { + ok, err := db.Articles.DeleteByID(ctx, 999999) + require.NoError(t, err) + require.False(t, ok) + }) + + t.Run("soft delete already-deleted returns false", func(t *testing.T) { + article := &Article{Author: "Double Delete", AccountID: account.ID} + err := db.Articles.Insert(ctx, article) + require.NoError(t, err) + + ok, err := db.Articles.DeleteByID(ctx, article.ID) + require.NoError(t, err) + require.True(t, ok) + + // Hard-delete so GetByID returns ErrNoRows. + _, err = db.Articles.HardDeleteByID(ctx, article.ID) + require.NoError(t, err) + + ok, err = db.Articles.DeleteByID(ctx, article.ID) + require.NoError(t, err) + require.False(t, ok) + }) + + t.Run("hard delete existing returns true", func(t *testing.T) { + // Account has no SetDeletedAt — DeleteByID falls through to hard delete. 
+ acct := &Account{Name: "HardDelete via DeleteByID"} + err := db.Accounts.Insert(ctx, acct) + require.NoError(t, err) + + ok, err := db.Accounts.DeleteByID(ctx, acct.ID) + require.NoError(t, err) + require.True(t, ok) + + _, err = db.Accounts.GetByID(ctx, acct.ID) + require.Error(t, err) + }) + + t.Run("hard delete missing returns false", func(t *testing.T) { + ok, err := db.Accounts.DeleteByID(ctx, 999999) + require.NoError(t, err) + require.False(t, ok) + }) +} + +func TestHardDeleteByID(t *testing.T) { + truncateAllTables(t) + + ctx := t.Context() + db := initDB(DB) + + t.Run("existing record returns true", func(t *testing.T) { + account := &Account{Name: "HardDelete Found"} + err := db.Accounts.Insert(ctx, account) + require.NoError(t, err) + + ok, err := db.Accounts.HardDeleteByID(ctx, account.ID) + require.NoError(t, err) + require.True(t, ok) + + _, err = db.Accounts.GetByID(ctx, account.ID) + require.Error(t, err) + }) + + t.Run("missing record returns false", func(t *testing.T) { + ok, err := db.Accounts.HardDeleteByID(ctx, 999999) + require.NoError(t, err) + require.False(t, ok) + }) + + t.Run("hard delete on soft-deletable record", func(t *testing.T) { + account := &Account{Name: "HardDelete Soft-Deletable"} + err := db.Accounts.Insert(ctx, account) + require.NoError(t, err) + + article := &Article{Author: "Soft-Deletable", AccountID: account.ID} + err = db.Articles.Insert(ctx, article) + require.NoError(t, err) + + // HardDeleteByID bypasses soft delete even on soft-deletable records. 
+ ok, err := db.Articles.HardDeleteByID(ctx, article.ID) + require.NoError(t, err) + require.True(t, ok) + + _, err = db.Articles.GetByID(ctx, article.ID) + require.Error(t, err, "should be permanently deleted") + }) + + t.Run("double hard delete returns false on second call", func(t *testing.T) { + account := &Account{Name: "Double HardDelete"} + err := db.Accounts.Insert(ctx, account) + require.NoError(t, err) + + ok, err := db.Accounts.HardDeleteByID(ctx, account.ID) + require.NoError(t, err) + require.True(t, ok) + + ok, err = db.Accounts.HardDeleteByID(ctx, account.ID) + require.NoError(t, err) + require.False(t, ok) }) } From 0cbca295f79c79566581d8711a11ac537556df6e Mon Sep 17 00:00:00 2001 From: "Vojtech Vitek (golang.cz)" Date: Fri, 3 Apr 2026 14:41:52 +0200 Subject: [PATCH 34/34] NOTICE: Experimental. Table and its methods are subject to change. --- table.go | 1 + 1 file changed, 1 insertion(+) diff --git a/table.go b/table.go index a0b39c7..e0d591e 100644 --- a/table.go +++ b/table.go @@ -23,6 +23,7 @@ type Record[T any, I ID] interface { } // Table provides basic CRUD operations for database records. +// NOTICE: Experimental. Table and its methods are subject to change. type Table[T any, P Record[T, I], I ID] struct { *DB Name string