Skip to content

Commit 390e355

Browse files
committed
fix stmt where in
1 parent 7ce075c commit 390e355

6 files changed

Lines changed: 136 additions & 90 deletions

File tree

README.md

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -185,12 +185,12 @@ if err := qsql.QueryElem(mdb, &count, "SELECT count(*) FROM a WHERE id = ?", id)
185185
// Example for the first input:
186186
mdb := db.GetCache("master")
187187
args:=[]int{1,2,3}
188-
mdb.Query(fmt.Sprintf("select * from table_name where in (%s)", qsql.StmtWhereIn(0,len(args))))
188+
mdb.Query(fmt.Sprintf("select * from table_name where in (%s)", qsql.StmtWhereIn(0,len(args))), qsql.SliceToArgs(args)...)
189189
// Or
190-
mdb.Query(fmt.Sprintf("select * from table_name where in (%s)", qsql.StmtWhereIn(0,len(args), qsql.DRV_NAME_MYSQL))
190+
mdb.Query(fmt.Sprintf("select * from table_name where in (%s)", qsql.StmtWhereIn(0,len(args), qsql.DRV_NAME_MYSQL), qsql.SliceToArgs(args)...)
191191

192192
// Example for the second input:
193-
mdb.Query(fmt.Sprintf("select * from table_name where id=? in (%s)", qsql.StmtWhereIn(1,len(args)))
193+
mdb.Query(fmt.Sprintf("select * from table_name where id=? in (%s)", qsql.StmtWhereIn(1,len(args)), id, qsql.SliceToArgs(args)...)
194194

195195
## Mass query.
196196
```text

api.go

Lines changed: 0 additions & 87 deletions
Original file line numberDiff line numberDiff line change
@@ -6,100 +6,13 @@ package qsql
66
import (
77
"context"
88
"database/sql"
9-
"fmt"
109
"io"
1110
"runtime/debug"
1211

1312
"github.com/gwaylib/errors"
1413
"github.com/gwaylib/log"
1514
)
1615

17-
const (
18-
DRV_NAME_MYSQL = "mysql"
19-
DRV_NAME_ORACLE = "oracle" // or "oci8"
20-
DRV_NAME_POSTGRES = "postgres"
21-
DRV_NAME_SQLITE3 = "sqlite3"
22-
DRV_NAME_SQLSERVER = "sqlserver" // or "mssql"
23-
24-
_DRV_NAME_OCI8 = "oci8"
25-
_DRV_NAME_MSSQL = "mssql"
26-
)
27-
28-
var (
29-
// Whe reflect the QueryStruct, InsertStruct, it need set the Driver first.
30-
// For example:
31-
// func init(){
32-
// qsql.REFLECT_DRV_NAME = qsql.DEV_NAME_SQLITE3
33-
// }
34-
// Default is using the mysql driver.
35-
REFLECT_DRV_NAME = DRV_NAME_MYSQL
36-
)
37-
38-
func getDrvName(exec Execer, driverName ...string) string {
39-
drvName := REFLECT_DRV_NAME
40-
db, ok := exec.(*DB)
41-
if ok {
42-
drvName = db.DriverName()
43-
} else {
44-
drvNamesLen := len(driverName)
45-
if drvNamesLen > 0 {
46-
if drvNamesLen != 1 {
47-
panic(errors.New("'drvName' expect only one argument").As(driverName))
48-
}
49-
drvName = driverName[0]
50-
}
51-
}
52-
return drvName
53-
}
54-
55-
// Extend the where in stmt
56-
//
57-
// Example for the first input:
58-
// fmt.Sprintf("select * from table_name where in (%s)", qsql.StmtWhereIn(0,len(args))
59-
// Or
60-
// fmt.Sprintf("select * from table_name where in (%s)", qsql.StmtWhereIn(0,len(args), qsql.DRV_NAME_MYSQL)
61-
//
62-
// Example for the second input:
63-
// fmt.Sprintf("select * from table_name where id=? in (%s)", qsql.StmtWhereIn(1,len(args))
64-
//
65-
func StmtWhereIn(paramIdx, paramsLen int, driverName ...string) string {
66-
drvName := getDrvName(nil, driverName...)
67-
switch drvName {
68-
case DRV_NAME_ORACLE, _DRV_NAME_OCI8:
69-
// *outputInputs = append(*outputInputs, []byte(fmt.Sprintf(":%s,", f.Name))...)
70-
panic("unknow how to implemented")
71-
case DRV_NAME_POSTGRES:
72-
result := []byte{}
73-
for i := 0; i < paramsLen; i++ {
74-
result = append(result, []byte(fmt.Sprintf(":%d,", paramIdx+i))...)
75-
}
76-
if len(result) > 0 {
77-
return string(result[:len(result)-1]) // remove the last ','
78-
}
79-
return string(result)
80-
case DRV_NAME_SQLSERVER, _DRV_NAME_MSSQL:
81-
result := []byte{}
82-
for i := 0; i < paramsLen; i++ {
83-
result = append(result, []byte(fmt.Sprintf("@p%d,", paramIdx+i))...)
84-
}
85-
if len(result) > 0 {
86-
return string(result[:len(result)-1]) // remove the last ','
87-
}
88-
return string(result)
89-
default:
90-
resultLen := paramsLen * 2
91-
result := make([]byte, resultLen)
92-
for i := 0; i < resultLen; i += 2 {
93-
result[i] = '?'
94-
result[i+1] = ','
95-
}
96-
if len(result) > 0 {
97-
return string(result[:len(result)-1]) // remove the last ','
98-
}
99-
return string(result)
100-
}
101-
}
102-
10316
type Execer interface {
10417
Exec(query string, args ...interface{}) (sql.Result, error)
10518
ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error)

drv.go

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
package qsql
2+
3+
import "github.com/gwaylib/errors"
4+
5+
const (
6+
DRV_NAME_MYSQL = "mysql"
7+
DRV_NAME_ORACLE = "oracle" // or "oci8"
8+
DRV_NAME_POSTGRES = "postgres"
9+
DRV_NAME_SQLITE3 = "sqlite3"
10+
DRV_NAME_SQLSERVER = "sqlserver" // or "mssql"
11+
12+
_DRV_NAME_OCI8 = "oci8"
13+
_DRV_NAME_MSSQL = "mssql"
14+
)
15+
16+
var (
17+
// Whe reflect the QueryStruct, InsertStruct, it need set the Driver first.
18+
// For example:
19+
// func init(){
20+
// qsql.REFLECT_DRV_NAME = qsql.DEV_NAME_SQLITE3
21+
// }
22+
// Default is using the mysql driver.
23+
REFLECT_DRV_NAME = DRV_NAME_MYSQL
24+
)
25+
26+
func getDrvName(exec Execer, driverName ...string) string {
27+
drvName := REFLECT_DRV_NAME
28+
db, ok := exec.(*DB)
29+
if ok {
30+
drvName = db.DriverName()
31+
} else {
32+
drvNamesLen := len(driverName)
33+
if drvNamesLen > 0 {
34+
if drvNamesLen != 1 {
35+
panic(errors.New("'drvName' expect only one argument").As(driverName))
36+
}
37+
drvName = driverName[0]
38+
}
39+
}
40+
return drvName
41+
}

example/qsql.go

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,7 @@ func main() {
7272
if len(users) != 2 {
7373
panic("expect len==2")
7474
}
75+
7576
// query elememt data
7677
pwd := ""
7778
if err := qsql.QueryElem(mdb, &pwd, "SELECT passwd FROM user WHERE username=?", "t1"); err != nil {
@@ -89,6 +90,19 @@ func main() {
8990
}
9091
fmt.Printf("ids:%+v\n", ids)
9192

93+
// query where in
94+
whereIn := []string{"t1", "t2"}
95+
whereInCount := 0
96+
if err := qsql.QueryElem(mdb, &whereInCount,
97+
fmt.Sprintf("SELECT COUNT(*) FROM user WHERE username in (%s)", qsql.StmtWhereIn(0, len(whereIn))),
98+
qsql.SliceToArgs(whereIn)...,
99+
); err != nil {
100+
panic(err)
101+
}
102+
if whereInCount != 2 {
103+
panic("expect count of whereIn is 2")
104+
}
105+
92106
// query data in string
93107
// table type
94108
titles, data, err := qsql.QueryPageArr(mdb, "SELECT * FROM user LIMIT 10")

where.go

Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,63 @@
1+
package qsql
2+
3+
import (
4+
"fmt"
5+
"reflect"
6+
)
7+
8+
// Extend the where in stmt
9+
//
10+
// Example for the first input:
11+
// fmt.Sprintf("select * from table_name where in (%s)", qsql.StmtWhereIn(0,len(args))
12+
// Or
13+
// fmt.Sprintf("select * from table_name where in (%s)", qsql.StmtWhereIn(0,len(args), qsql.DRV_NAME_MYSQL)
14+
//
15+
// Example for the second input:
16+
// fmt.Sprintf("select * from table_name where id=? in (%s)", qsql.StmtWhereIn(1,len(args))
17+
//
18+
func StmtWhereIn(paramIdx, paramsLen int, driverName ...string) string {
19+
drvName := getDrvName(nil, driverName...)
20+
switch drvName {
21+
case DRV_NAME_ORACLE, _DRV_NAME_OCI8:
22+
// *outputInputs = append(*outputInputs, []byte(fmt.Sprintf(":%s,", f.Name))...)
23+
panic("unknow how to implemented")
24+
case DRV_NAME_POSTGRES:
25+
result := []byte{}
26+
for i := 0; i < paramsLen; i++ {
27+
result = append(result, []byte(fmt.Sprintf(":%d,", paramIdx+i))...)
28+
}
29+
if len(result) > 0 {
30+
return string(result[:len(result)-1]) // remove the last ','
31+
}
32+
return string(result)
33+
case DRV_NAME_SQLSERVER, _DRV_NAME_MSSQL:
34+
result := []byte{}
35+
for i := 0; i < paramsLen; i++ {
36+
result = append(result, []byte(fmt.Sprintf("@p%d,", paramIdx+i))...)
37+
}
38+
if len(result) > 0 {
39+
return string(result[:len(result)-1]) // remove the last ','
40+
}
41+
return string(result)
42+
default:
43+
resultLen := paramsLen * 2
44+
result := make([]byte, resultLen)
45+
for i := 0; i < resultLen; i += 2 {
46+
result[i] = '?'
47+
result[i+1] = ','
48+
}
49+
if len(result) > 0 {
50+
return string(result[:len(result)-1]) // remove the last ','
51+
}
52+
return string(result)
53+
}
54+
}
55+
56+
func SliceToArgs(arr interface{}) []interface{} {
57+
val := reflect.ValueOf(arr)
58+
result := make([]interface{}, val.Len())
59+
for i := len(result) - 1; i > -1; i-- {
60+
result[i] = val.Index(i).Interface()
61+
}
62+
return result
63+
}

api_test.go renamed to where_test.go

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,3 +24,18 @@ func TestStmtWhereIn(t *testing.T) {
2424
t.Fatalf("expect '@p1,@p2,@p3', but@ %s", msOutput1)
2525
}
2626
}
27+
28+
func TestStmtSliceToArr(t *testing.T) {
29+
in := []string{"a", "b", "c"}
30+
out := SliceToArgs(in)
31+
if len(out) != 3 {
32+
t.Fatalf("expect 3, but:%d", len(out))
33+
}
34+
arg1, ok := out[0].(string)
35+
if !ok {
36+
t.Fatal("expect string, but not")
37+
}
38+
if arg1 != "a" {
39+
t.Fatalf("expect 'a' , but: %s ", arg1)
40+
}
41+
}

0 commit comments

Comments
 (0)