Skip to content

Commit 7ce075c

Browse files
committed
add StmtWhereIn
1 parent c289693 commit 7ce075c

5 files changed

Lines changed: 116 additions & 23 deletions

File tree

README.md

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -181,6 +181,16 @@ if err := qsql.QueryElem(mdb, &count, "SELECT count(*) FROM a WHERE id = ?", id)
181181
// ...
182182
}
183183
```
184+
## Extend the where in stmt
185+
// Example for the first input:
186+
mdb := db.GetCache("master")
187+
args:=[]int{1,2,3}
188+
mdb.Query(fmt.Sprintf("select * from table_name where in (%s)", qsql.StmtWhereIn(0,len(args))))
189+
// Or
190+
mdb.Query(fmt.Sprintf("select * from table_name where in (%s)", qsql.StmtWhereIn(0,len(args), qsql.DRV_NAME_MYSQL))
191+
192+
// Example for the second input:
193+
mdb.Query(fmt.Sprintf("select * from table_name where id=? in (%s)", qsql.StmtWhereIn(1,len(args)))
184194

185195
## Mass query.
186196
```text

api.go

Lines changed: 73 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ package qsql
66
import (
77
"context"
88
"database/sql"
9+
"fmt"
910
"io"
1011
"runtime/debug"
1112

@@ -19,6 +20,9 @@ const (
1920
DRV_NAME_POSTGRES = "postgres"
2021
DRV_NAME_SQLITE3 = "sqlite3"
2122
DRV_NAME_SQLSERVER = "sqlserver" // or "mssql"
23+
24+
_DRV_NAME_OCI8 = "oci8"
25+
_DRV_NAME_MSSQL = "mssql"
2226
)
2327

2428
var (
@@ -31,6 +35,71 @@ var (
3135
REFLECT_DRV_NAME = DRV_NAME_MYSQL
3236
)
3337

38+
func getDrvName(exec Execer, driverName ...string) string {
39+
drvName := REFLECT_DRV_NAME
40+
db, ok := exec.(*DB)
41+
if ok {
42+
drvName = db.DriverName()
43+
} else {
44+
drvNamesLen := len(driverName)
45+
if drvNamesLen > 0 {
46+
if drvNamesLen != 1 {
47+
panic(errors.New("'drvName' expect only one argument").As(driverName))
48+
}
49+
drvName = driverName[0]
50+
}
51+
}
52+
return drvName
53+
}
54+
55+
// Extend the where in stmt
56+
//
57+
// Example for the first input:
58+
// fmt.Sprintf("select * from table_name where in (%s)", qsql.StmtWhereIn(0,len(args))
59+
// Or
60+
// fmt.Sprintf("select * from table_name where in (%s)", qsql.StmtWhereIn(0,len(args), qsql.DRV_NAME_MYSQL)
61+
//
62+
// Example for the second input:
63+
// fmt.Sprintf("select * from table_name where id=? in (%s)", qsql.StmtWhereIn(1,len(args))
64+
//
65+
func StmtWhereIn(paramIdx, paramsLen int, driverName ...string) string {
66+
drvName := getDrvName(nil, driverName...)
67+
switch drvName {
68+
case DRV_NAME_ORACLE, _DRV_NAME_OCI8:
69+
// *outputInputs = append(*outputInputs, []byte(fmt.Sprintf(":%s,", f.Name))...)
70+
panic("unknow how to implemented")
71+
case DRV_NAME_POSTGRES:
72+
result := []byte{}
73+
for i := 0; i < paramsLen; i++ {
74+
result = append(result, []byte(fmt.Sprintf(":%d,", paramIdx+i))...)
75+
}
76+
if len(result) > 0 {
77+
return string(result[:len(result)-1]) // remove the last ','
78+
}
79+
return string(result)
80+
case DRV_NAME_SQLSERVER, _DRV_NAME_MSSQL:
81+
result := []byte{}
82+
for i := 0; i < paramsLen; i++ {
83+
result = append(result, []byte(fmt.Sprintf("@p%d,", paramIdx+i))...)
84+
}
85+
if len(result) > 0 {
86+
return string(result[:len(result)-1]) // remove the last ','
87+
}
88+
return string(result)
89+
default:
90+
resultLen := paramsLen * 2
91+
result := make([]byte, resultLen)
92+
for i := 0; i < resultLen; i += 2 {
93+
result[i] = '?'
94+
result[i+1] = ','
95+
}
96+
if len(result) > 0 {
97+
return string(result[:len(result)-1]) // remove the last ','
98+
}
99+
return string(result)
100+
}
101+
}
102+
34103
type Execer interface {
35104
Exec(query string, args ...interface{}) (sql.Result, error)
36105
ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error)
@@ -129,11 +198,11 @@ func ExecMultiTxContext(tx *sql.Tx, ctx context.Context, mTx []*MultiTx) error {
129198

130199
// Reflect one db data to the struct. the struct tag format like `db:"field_title"`, reference to: http://github.com/jmoiron/sqlx
131200
// When you no set the REFLECT_DRV_NAME, you can point out with the drvName
132-
func InsertStruct(exec Execer, obj interface{}, tbName string, drvNames ...string) (sql.Result, error) {
133-
return insertStruct(exec, context.TODO(), obj, tbName, drvNames...)
201+
func InsertStruct(exec Execer, obj interface{}, tbName string, drvName ...string) (sql.Result, error) {
202+
return insertStruct(exec, context.TODO(), obj, tbName, drvName...)
134203
}
135-
func InsertStructContext(exec Execer, ctx context.Context, obj interface{}, tbName string, drvNames ...string) (sql.Result, error) {
136-
return insertStruct(exec, ctx, obj, tbName, drvNames...)
204+
func InsertStructContext(exec Execer, ctx context.Context, obj interface{}, tbName string, drvName ...string) (sql.Result, error) {
205+
return insertStruct(exec, ctx, obj, tbName, drvName...)
137206
}
138207

139208
// A sql.Query implements

api_test.go

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
package qsql
2+
3+
import "testing"
4+
5+
func TestStmtWhereIn(t *testing.T) {
6+
sqlite3Output := StmtWhereIn(0, 3, DRV_NAME_SQLITE3)
7+
if sqlite3Output != "?,?,?" {
8+
t.Fatalf("expect '?,?,?', but: %s", sqlite3Output)
9+
}
10+
pgOutput := StmtWhereIn(0, 3, DRV_NAME_POSTGRES)
11+
if pgOutput != ":0,:1,:2" {
12+
t.Fatalf("expect ':0,:1,:2', but: %s", pgOutput)
13+
}
14+
pgOutput1 := StmtWhereIn(1, 3, DRV_NAME_POSTGRES)
15+
if pgOutput1 != ":1,:2,:3" {
16+
t.Fatalf("expect ':1,:2,:3', but: %s", pgOutput1)
17+
}
18+
msOutput := StmtWhereIn(0, 3, DRV_NAME_SQLSERVER)
19+
if msOutput != "@p0,@p1,@p2" {
20+
t.Fatalf("expect '@p0,@p1,@p2', but@ %s", msOutput)
21+
}
22+
msOutput1 := StmtWhereIn(1, 3, DRV_NAME_SQLSERVER)
23+
if msOutput1 != "@p1,@p2,@p3" {
24+
t.Fatalf("expect '@p1,@p2,@p3', but@ %s", msOutput1)
25+
}
26+
}

qsql.go

Lines changed: 2 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -25,20 +25,8 @@ const (
2525

2626
// field flag like: `db:"name"`
2727
// more: github.com/jmoiron/sqlx
28-
func insertStruct(exec Execer, ctx context.Context, obj interface{}, tbName string, drvNames ...string) (sql.Result, error) {
29-
drvName := REFLECT_DRV_NAME
30-
db, ok := exec.(*DB)
31-
if ok {
32-
drvName = db.DriverName()
33-
} else {
34-
drvNamesLen := len(drvNames)
35-
if drvNamesLen > 0 {
36-
if drvNamesLen != 0 {
37-
panic(errors.New("'drvNames' expect only one argument").As(drvNames))
38-
}
39-
drvName = drvNames[0]
40-
}
41-
}
28+
func insertStruct(exec Execer, ctx context.Context, obj interface{}, tbName string, driverName ...string) (sql.Result, error) {
29+
drvName := getDrvName(exec, driverName...)
4230

4331
fields, err := reflectInsertStruct(obj, drvName)
4432
if err != nil {

struct.go

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -206,20 +206,20 @@ func travelStructField(f *reflectx.FieldInfo, v *reflect.Value, order *int, drvN
206206
}
207207

208208
*outputVals = append(*outputVals, v.Interface())
209-
switch {
210-
case strings.Index(*drvName, "oracle") > -1, strings.Index(*drvName, "oci8") > -1:
209+
switch *drvName {
210+
case DRV_NAME_ORACLE, _DRV_NAME_OCI8:
211211
*order += 1
212212
*outputNames = append(*outputNames, []byte("\""+f.Name+"\",")...)
213213
*outputInputs = append(*outputInputs, []byte(fmt.Sprintf(":%s,", f.Name))...)
214-
case strings.Index(*drvName, "postgres") > -1:
214+
case DRV_NAME_POSTGRES:
215215
*outputNames = append(*outputNames, []byte("\""+f.Name+"\",")...)
216216
*outputInputs = append(*outputInputs, []byte(fmt.Sprintf(":%d,", *order))...)
217217
*order += 1
218-
case strings.Index(*drvName, "sqlserver") > -1, strings.Index(*drvName, "mssql") > -1:
218+
case DRV_NAME_SQLSERVER, _DRV_NAME_MSSQL:
219219
*outputNames = append(*outputNames, []byte("["+f.Name+"],")...)
220220
*outputInputs = append(*outputInputs, []byte(fmt.Sprintf("@p%d,", *order))...)
221221
*order += 1
222-
case strings.Index(*drvName, "mysql") > -1:
222+
case DRV_NAME_MYSQL:
223223
*order += 1
224224
*outputNames = append(*outputNames, []byte("`"+f.Name+"`,")...)
225225
*outputInputs = append(*outputInputs, []byte("?,")...)

0 commit comments

Comments
 (0)