Skip to content

Commit f74541e

Browse files
committed
feat: add Insert and Update methods in Table
1 parent cf0cd56 commit f74541e

2 files changed

Lines changed: 299 additions & 0 deletions

File tree

table.go

Lines changed: 159 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,165 @@ type (
4343
}
4444
)
4545

46+
// Insert inserts one or more records. Sets CreatedAt and UpdatedAt timestamps if available.
47+
// Records are returned with their generated fields populated via RETURNING *.
48+
func (t *Table[T, P, I]) Insert(ctx context.Context, records ...P) error {
49+
switch len(records) {
50+
case 0:
51+
return nil
52+
case 1:
53+
return t.insertOne(ctx, records[0])
54+
default:
55+
return t.insertAll(ctx, records)
56+
}
57+
}
58+
59+
func (t *Table[T, P, I]) insertOne(ctx context.Context, record P) error {
60+
if record == nil {
61+
return fmt.Errorf("record is nil")
62+
}
63+
64+
if err := record.Validate(); err != nil {
65+
return fmt.Errorf("validate record: %w", err)
66+
}
67+
68+
now := time.Now().UTC()
69+
if row, ok := any(record).(hasSetCreatedAt); ok {
70+
row.SetCreatedAt(now)
71+
}
72+
if row, ok := any(record).(hasSetUpdatedAt); ok {
73+
row.SetUpdatedAt(now)
74+
}
75+
76+
q := t.SQL.
77+
InsertRecord(record).
78+
Into(t.Name).
79+
Suffix("RETURNING *")
80+
81+
if err := t.Query.GetOne(ctx, q, record); err != nil {
82+
return fmt.Errorf("insert record: %w", err)
83+
}
84+
85+
return nil
86+
}
87+
88+
func (t *Table[T, P, I]) insertAll(ctx context.Context, records []P) error {
89+
now := time.Now().UTC()
90+
91+
for i, r := range records {
92+
if r == nil {
93+
return fmt.Errorf("record with index=%d is nil", i)
94+
}
95+
96+
if err := r.Validate(); err != nil {
97+
return fmt.Errorf("validate record: %w", err)
98+
}
99+
100+
if row, ok := any(r).(hasSetCreatedAt); ok {
101+
row.SetCreatedAt(now)
102+
}
103+
if row, ok := any(r).(hasSetUpdatedAt); ok {
104+
row.SetUpdatedAt(now)
105+
}
106+
}
107+
108+
for start := 0; start < len(records); start += chunkSize {
109+
end := start + chunkSize
110+
if end > len(records) {
111+
end = len(records)
112+
}
113+
114+
chunk := records[start:end]
115+
q := t.SQL.
116+
InsertRecords(chunk).
117+
Into(t.Name).
118+
SuffixExpr(sq.Expr(" RETURNING *"))
119+
120+
if err := t.Query.GetAll(ctx, q, &chunk); err != nil {
121+
return fmt.Errorf("insert records: %w", err)
122+
}
123+
}
124+
125+
return nil
126+
}
127+
128+
// Update updates one or more records by their ID. Sets UpdatedAt timestamp if available.
129+
// Returns an error if any record has a zero ID.
130+
func (t *Table[T, P, I]) Update(ctx context.Context, records ...P) error {
131+
switch len(records) {
132+
case 0:
133+
return nil
134+
case 1:
135+
return t.updateOne(ctx, records[0])
136+
default:
137+
return t.updateAll(ctx, records)
138+
}
139+
}
140+
141+
func (t *Table[T, P, I]) updateOne(ctx context.Context, record P) error {
142+
if record == nil {
143+
return fmt.Errorf("record is nil")
144+
}
145+
146+
var zero I
147+
if record.GetID() == zero {
148+
return fmt.Errorf("update record: ID is zero")
149+
}
150+
151+
if err := record.Validate(); err != nil {
152+
return fmt.Errorf("validate record: %w", err)
153+
}
154+
155+
if row, ok := any(record).(hasSetUpdatedAt); ok {
156+
row.SetUpdatedAt(time.Now().UTC())
157+
}
158+
159+
q := t.SQL.UpdateRecord(record, sq.Eq{t.IDColumn: record.GetID()}, t.Name)
160+
if _, err := t.Query.Exec(ctx, q); err != nil {
161+
return fmt.Errorf("update record: %w", err)
162+
}
163+
164+
return nil
165+
}
166+
167+
func (t *Table[T, P, I]) updateAll(ctx context.Context, records []P) error {
168+
now := time.Now().UTC()
169+
170+
queries := make(Queries, 0, len(records))
171+
var zero I
172+
173+
for i, r := range records {
174+
if r == nil {
175+
return fmt.Errorf("record with index=%d is nil", i)
176+
}
177+
178+
if r.GetID() == zero {
179+
return fmt.Errorf("update record with index=%d: ID is zero", i)
180+
}
181+
182+
if err := r.Validate(); err != nil {
183+
return fmt.Errorf("validate record: %w", err)
184+
}
185+
186+
if row, ok := any(r).(hasSetUpdatedAt); ok {
187+
row.SetUpdatedAt(now)
188+
}
189+
190+
queries.Add(t.SQL.
191+
UpdateRecord(r, sq.Eq{t.IDColumn: r.GetID()}, t.Name).
192+
SuffixExpr(sq.Expr(" RETURNING *")),
193+
)
194+
}
195+
196+
for chunk := range slices.Chunk(queries, chunkSize) {
197+
if _, err := t.Query.BatchExec(ctx, chunk); err != nil {
198+
return fmt.Errorf("update records: %w", err)
199+
}
200+
}
201+
202+
return nil
203+
}
204+
46205
// Save inserts or updates given records. Auto-detects insert vs update by ID based on zerovalue of ID from GetID() method on record.
47206
func (t *Table[T, P, I]) Save(ctx context.Context, records ...P) error {
48207
switch len(records) {

tests/table_test.go

Lines changed: 140 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -323,6 +323,146 @@ func TestTable(t *testing.T) {
323323
})
324324
}
325325

326+
func TestInsert(t *testing.T) {
327+
truncateAllTables(t)
328+
329+
ctx := t.Context()
330+
db := initDB(DB)
331+
332+
t.Run("Insert single", func(t *testing.T) {
333+
account := &Account{Name: "Insert Account"}
334+
err := db.Accounts.Insert(ctx, account)
335+
require.NoError(t, err)
336+
require.NotZero(t, account.ID, "ID should be set after insert")
337+
require.NotZero(t, account.UpdatedAt, "UpdatedAt should be set")
338+
339+
// Verify in DB.
340+
got, err := db.Accounts.GetByID(ctx, account.ID)
341+
require.NoError(t, err)
342+
require.Equal(t, account.Name, got.Name)
343+
})
344+
345+
t.Run("Insert multiple", func(t *testing.T) {
346+
account := &Account{Name: "Insert Multiple Account"}
347+
err := db.Accounts.Insert(ctx, account)
348+
require.NoError(t, err)
349+
350+
articles := []*Article{
351+
{Author: "Author A", AccountID: account.ID},
352+
{Author: "Author B", AccountID: account.ID},
353+
{Author: "Author C", AccountID: account.ID},
354+
}
355+
err = db.Articles.Insert(ctx, articles...)
356+
require.NoError(t, err)
357+
358+
for _, a := range articles {
359+
require.NotZero(t, a.ID, "ID should be set after bulk insert")
360+
require.NotZero(t, a.UpdatedAt, "UpdatedAt should be set")
361+
}
362+
363+
// Verify all in DB.
364+
for _, a := range articles {
365+
got, err := db.Articles.GetByID(ctx, a.ID)
366+
require.NoError(t, err)
367+
require.Equal(t, a.Author, got.Author)
368+
}
369+
})
370+
371+
t.Run("Insert nil record", func(t *testing.T) {
372+
err := db.Accounts.Insert(ctx, nil)
373+
require.Error(t, err)
374+
})
375+
376+
t.Run("Insert invalid record", func(t *testing.T) {
377+
err := db.Accounts.Insert(ctx, &Account{Name: ""})
378+
require.Error(t, err, "should fail validation")
379+
})
380+
381+
t.Run("Insert zero records", func(t *testing.T) {
382+
err := db.Accounts.Insert(ctx)
383+
require.NoError(t, err, "inserting zero records should be a no-op")
384+
})
385+
}
386+
387+
func TestUpdate(t *testing.T) {
388+
truncateAllTables(t)
389+
390+
ctx := t.Context()
391+
db := initDB(DB)
392+
393+
t.Run("Update single", func(t *testing.T) {
394+
account := &Account{Name: "Before Update"}
395+
err := db.Accounts.Insert(ctx, account)
396+
require.NoError(t, err)
397+
398+
account.Name = "After Update"
399+
err = db.Accounts.Update(ctx, account)
400+
require.NoError(t, err)
401+
402+
got, err := db.Accounts.GetByID(ctx, account.ID)
403+
require.NoError(t, err)
404+
require.Equal(t, "After Update", got.Name)
405+
})
406+
407+
t.Run("Update multiple", func(t *testing.T) {
408+
account := &Account{Name: "Update Multiple Account"}
409+
err := db.Accounts.Insert(ctx, account)
410+
require.NoError(t, err)
411+
412+
articles := []*Article{
413+
{Author: "Original A", AccountID: account.ID},
414+
{Author: "Original B", AccountID: account.ID},
415+
}
416+
err = db.Articles.Insert(ctx, articles...)
417+
require.NoError(t, err)
418+
419+
articles[0].Author = "Updated A"
420+
articles[1].Author = "Updated B"
421+
err = db.Articles.Update(ctx, articles...)
422+
require.NoError(t, err)
423+
424+
for _, a := range articles {
425+
got, err := db.Articles.GetByID(ctx, a.ID)
426+
require.NoError(t, err)
427+
require.Equal(t, a.Author, got.Author)
428+
}
429+
})
430+
431+
t.Run("Update with zero ID fails", func(t *testing.T) {
432+
err := db.Accounts.Update(ctx, &Account{Name: "No ID"})
433+
require.Error(t, err, "should fail with zero ID")
434+
})
435+
436+
t.Run("Update multiple with zero ID fails", func(t *testing.T) {
437+
account := &Account{Name: "Update Zero ID Account"}
438+
err := db.Accounts.Insert(ctx, account)
439+
require.NoError(t, err)
440+
441+
err = db.Accounts.Update(ctx, account, &Account{Name: "No ID"})
442+
require.Error(t, err, "should fail when any record has zero ID")
443+
})
444+
445+
t.Run("Update nil record", func(t *testing.T) {
446+
err := db.Accounts.Update(ctx, nil)
447+
require.Error(t, err)
448+
})
449+
450+
t.Run("Update invalid record", func(t *testing.T) {
451+
account := &Account{Name: "Valid"}
452+
err := db.Accounts.Insert(ctx, account)
453+
require.NoError(t, err)
454+
455+
account.Name = ""
456+
err = db.Accounts.Update(ctx, account)
457+
require.Error(t, err, "should fail validation")
458+
})
459+
460+
t.Run("Update zero records", func(t *testing.T) {
461+
err := db.Accounts.Update(ctx)
462+
require.NoError(t, err, "updating zero records should be a no-op")
463+
})
464+
}
465+
326466
func TestLockForUpdates(t *testing.T) {
327467
truncateAllTables(t)
328468

0 commit comments

Comments
 (0)