@@ -6,6 +6,7 @@ package qsql
66import (
77 "context"
88 "database/sql"
9+ "fmt"
910 "io"
1011 "runtime/debug"
1112
@@ -19,6 +20,9 @@ const (
1920 DRV_NAME_POSTGRES = "postgres"
2021 DRV_NAME_SQLITE3 = "sqlite3"
2122 DRV_NAME_SQLSERVER = "sqlserver" // or "mssql"
23+
24+ _DRV_NAME_OCI8 = "oci8"
25+ _DRV_NAME_MSSQL = "mssql"
2226)
2327
2428var (
3135 REFLECT_DRV_NAME = DRV_NAME_MYSQL
3236)
3337
38+ func getDrvName (exec Execer , driverName ... string ) string {
39+ drvName := REFLECT_DRV_NAME
40+ db , ok := exec .(* DB )
41+ if ok {
42+ drvName = db .DriverName ()
43+ } else {
44+ drvNamesLen := len (driverName )
45+ if drvNamesLen > 0 {
46+ if drvNamesLen != 1 {
47+ panic (errors .New ("'drvName' expect only one argument" ).As (driverName ))
48+ }
49+ drvName = driverName [0 ]
50+ }
51+ }
52+ return drvName
53+ }
54+
55+ // Extend the where in stmt
56+ //
57+ // Example for the first input:
58+ // fmt.Sprintf("select * from table_name where in (%s)", qsql.StmtWhereIn(0,len(args))
59+ // Or
60+ // fmt.Sprintf("select * from table_name where in (%s)", qsql.StmtWhereIn(0,len(args), qsql.DRV_NAME_MYSQL)
61+ //
62+ // Example for the second input:
63+ // fmt.Sprintf("select * from table_name where id=? in (%s)", qsql.StmtWhereIn(1,len(args))
64+ //
65+ func StmtWhereIn (paramIdx , paramsLen int , driverName ... string ) string {
66+ drvName := getDrvName (nil , driverName ... )
67+ switch drvName {
68+ case DRV_NAME_ORACLE , _DRV_NAME_OCI8 :
69+ // *outputInputs = append(*outputInputs, []byte(fmt.Sprintf(":%s,", f.Name))...)
70+ panic ("unknow how to implemented" )
71+ case DRV_NAME_POSTGRES :
72+ result := []byte {}
73+ for i := 0 ; i < paramsLen ; i ++ {
74+ result = append (result , []byte (fmt .Sprintf (":%d," , paramIdx + i ))... )
75+ }
76+ if len (result ) > 0 {
77+ return string (result [:len (result )- 1 ]) // remove the last ','
78+ }
79+ return string (result )
80+ case DRV_NAME_SQLSERVER , _DRV_NAME_MSSQL :
81+ result := []byte {}
82+ for i := 0 ; i < paramsLen ; i ++ {
83+ result = append (result , []byte (fmt .Sprintf ("@p%d," , paramIdx + i ))... )
84+ }
85+ if len (result ) > 0 {
86+ return string (result [:len (result )- 1 ]) // remove the last ','
87+ }
88+ return string (result )
89+ default :
90+ resultLen := paramsLen * 2
91+ result := make ([]byte , resultLen )
92+ for i := 0 ; i < resultLen ; i += 2 {
93+ result [i ] = '?'
94+ result [i + 1 ] = ','
95+ }
96+ if len (result ) > 0 {
97+ return string (result [:len (result )- 1 ]) // remove the last ','
98+ }
99+ return string (result )
100+ }
101+ }
102+
34103type Execer interface {
35104 Exec (query string , args ... interface {}) (sql.Result , error )
36105 ExecContext (ctx context.Context , query string , args ... interface {}) (sql.Result , error )
@@ -129,11 +198,11 @@ func ExecMultiTxContext(tx *sql.Tx, ctx context.Context, mTx []*MultiTx) error {
129198
130199// Reflect one db data to the struct. the struct tag format like `db:"field_title"`, reference to: http://github.com/jmoiron/sqlx
131200// When you no set the REFLECT_DRV_NAME, you can point out with the drvName
132- func InsertStruct (exec Execer , obj interface {}, tbName string , drvNames ... string ) (sql.Result , error ) {
133- return insertStruct (exec , context .TODO (), obj , tbName , drvNames ... )
201+ func InsertStruct (exec Execer , obj interface {}, tbName string , drvName ... string ) (sql.Result , error ) {
202+ return insertStruct (exec , context .TODO (), obj , tbName , drvName ... )
134203}
135- func InsertStructContext (exec Execer , ctx context.Context , obj interface {}, tbName string , drvNames ... string ) (sql.Result , error ) {
136- return insertStruct (exec , ctx , obj , tbName , drvNames ... )
204+ func InsertStructContext (exec Execer , ctx context.Context , obj interface {}, tbName string , drvName ... string ) (sql.Result , error ) {
205+ return insertStruct (exec , ctx , obj , tbName , drvName ... )
137206}
138207
139208// A sql.Query implements
0 commit comments