Skip to content

Commit 2f07bba

Browse files
authored
Merge pull request #6 from Vinovest/mj/single-row
Validate single row count for Row.Scan
2 parents 5b23c20 + 329a30d commit 2f07bba

3 files changed

Lines changed: 62 additions & 9 deletions

File tree

named_test.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -276,7 +276,7 @@ func TestNamedQueries(t *testing.T) {
276276

277277
for _, p := range sls {
278278
dest := Person{}
279-
err = db.Get(&dest, db.Rebind("SELECT * FROM person WHERE email=?"), p.Email)
279+
err = db.Get(&dest, db.Rebind("SELECT * FROM person WHERE email=? limit 1"), p.Email)
280280
test.Error(err)
281281
if dest.Email != p.Email {
282282
t.Errorf("expected %s, got %s", p.Email, dest.Email)

sqlx.go

Lines changed: 19 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,12 @@ import (
1515
"github.com/vinovest/sqlx/reflectx"
1616
)
1717

18+
// ErrMultiRows is returned by functions which are expected to work with result sets
19+
// that only contain a single row but multiple rows were returned.
20+
// This typically indicates an issue with the query such as a missing join criteria or
21+
// limit condition or the use of Get(...) when Select(...) was intended.
22+
var ErrMultiRows = errors.New("sql: multiple rows returned")
23+
1824
// Although the NameMapper is convenient, in practice it should not
1925
// be relied on except for application code. If you are writing a library
2026
// that uses sqlx, you should be aware that the name mappings you expect
@@ -175,6 +181,7 @@ type Row struct {
175181

176182
// Scan is a fixed implementation of sql.Row.Scan, which does not discard the
177183
// underlying error from the internal rows object if it exists.
184+
// Returns ErrMultiRows if the result set contains more than one row.
178185
func (r *Row) Scan(dest ...interface{}) error {
179186
if r.err != nil {
180187
return r.err
@@ -206,10 +213,16 @@ func (r *Row) Scan(dest ...interface{}) error {
206213
}
207214
return sql.ErrNoRows
208215
}
209-
err := r.rows.Scan(dest...)
210-
if err != nil {
216+
if err := r.rows.Scan(dest...); err != nil {
217+
return err
218+
}
219+
220+
if r.rows.Next() {
221+
return ErrMultiRows
222+
} else if err := r.rows.Err(); err != nil {
211223
return err
212224
}
225+
213226
// Make sure the query can be processed to completion with no errors.
214227
if err := r.rows.Close(); err != nil {
215228
return err
@@ -352,7 +365,7 @@ func (db *DB) Select(dest interface{}, query string, args ...interface{}) error
352365

353366
// Get using this DB.
354367
// Any placeholder parameters are replaced with supplied args.
355-
// An error is returned if the result set is empty.
368+
// An error is returned if the result set is empty or contains more than one row.
356369
func (db *DB) Get(dest interface{}, query string, args ...interface{}) error {
357370
return Get(db, dest, query, args...)
358371
}
@@ -483,7 +496,7 @@ func (tx *Tx) QueryRowx(query string, args ...interface{}) *Row {
483496

484497
// Get within a transaction.
485498
// Any placeholder parameters are replaced with supplied args.
486-
// An error is returned if the result set is empty.
499+
// An error is returned if the result set is empty or contains more than one row.
487500
func (tx *Tx) Get(dest interface{}, query string, args ...interface{}) error {
488501
return Get(tx, dest, query, args...)
489502
}
@@ -551,7 +564,7 @@ func (s *Stmt) Select(dest interface{}, args ...interface{}) error {
551564

552565
// Get using the prepared statement.
553566
// Any placeholder parameters are replaced with supplied args.
554-
// An error is returned if the result set is empty.
567+
// An error is returned if the result set is empty or contains more than one row.
555568
func (s *Stmt) Get(dest interface{}, args ...interface{}) error {
556569
return Get(&qStmt{s}, dest, "", args...)
557570
}
@@ -727,7 +740,7 @@ func Select(q Queryer, dest interface{}, query string, args ...interface{}) erro
727740
// to dest. If dest is scannable, the result must only have one column. Otherwise,
728741
// StructScan is used. Get will return sql.ErrNoRows like row.Scan would.
729742
// Any placeholder parameters are replaced with supplied args.
730-
// An error is returned if the result set is empty.
743+
// An error is returned if the result set is empty or contains more than one row.
731744
func Get(q Queryer, dest interface{}, query string, args ...interface{}) error {
732745
r := q.QueryRowx(query, args...)
733746
return r.scanAny(dest, false)

sqlx_test.go

Lines changed: 42 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -74,7 +74,7 @@ func ConnectAll() {
7474
if TestMysql {
7575
mysqldb, err = Connect("mysql", mydsn)
7676
if err != nil {
77-
fmt.Printf("Disabling MySQL tests:\n %v", err)
77+
fmt.Printf("Disabling MySQL tests:\n %v\n", err)
7878
TestMysql = false
7979
}
8080
} else {
@@ -84,7 +84,7 @@ func ConnectAll() {
8484
if TestSqlite {
8585
sldb, err = Connect("sqlite3", sqdsn)
8686
if err != nil {
87-
fmt.Printf("Disabling SQLite:\n %v", err)
87+
fmt.Printf("Disabling SQLite:\n %v\n", err)
8888
TestSqlite = false
8989
}
9090
} else {
@@ -1728,6 +1728,46 @@ func TestEmbeddedLiterals(t *testing.T) {
17281728
})
17291729
}
17301730

1731+
// TestGet tests to ensure that Get behaves correctly for
1732+
// single row and multi row results.
1733+
func TestGet(t *testing.T) {
1734+
var schema = Schema{
1735+
create: `CREATE TABLE tst (v integer);`,
1736+
drop: `drop table tst;`,
1737+
}
1738+
1739+
RunWithSchema(schema, t, func(db *DB, t *testing.T, now string) {
1740+
for _, v := range []int{1, 2} {
1741+
_, err := db.Exec(db.Rebind("INSERT INTO tst (v) VALUES (?)"), v)
1742+
if err != nil {
1743+
t.Error(err)
1744+
}
1745+
}
1746+
1747+
tests := []struct {
1748+
name string
1749+
val int
1750+
err bool
1751+
}{
1752+
{"multi-rows", 1, true},
1753+
{"single-row", 2, false},
1754+
}
1755+
for _, tc := range tests {
1756+
t.Run(tc.name, func(t *testing.T) {
1757+
var v int
1758+
err := db.Get(&v, db.Rebind("SELECT v FROM tst WHERE v >= ?"), tc.val)
1759+
if tc.err {
1760+
if err == nil {
1761+
t.Error("expected error but got nil")
1762+
}
1763+
} else if err != nil {
1764+
t.Error("unexpected error:", err)
1765+
}
1766+
})
1767+
}
1768+
})
1769+
}
1770+
17311771
func BenchmarkBindStruct(b *testing.B) {
17321772
b.StopTimer()
17331773
q1 := `INSERT INTO foo (a, b, c, d) VALUES (:name, :age, :first, :last)`

0 commit comments

Comments
 (0)