Skip to content

Commit feae763

Browse files
committed
fix sql builder
1 parent c9a44f3 commit feae763

5 files changed

Lines changed: 122 additions & 93 deletions

File tree

qsql.go

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ import (
55
"database/sql"
66
"fmt"
77
"reflect"
8+
"strings"
89

910
"github.com/gwaylib/errors"
1011
"github.com/jmoiron/sqlx/reflectx"
@@ -23,7 +24,7 @@ func insertStruct(exec Execer, ctx context.Context, obj interface{}, tbName stri
2324
if err != nil {
2425
return nil, errors.As(err)
2526
}
26-
execSql := fmt.Sprintf(addObjSql, tbName, fields.Names, fields.Stmts)
27+
execSql := fmt.Sprintf(addObjSql, tbName, strings.Join(fields.Names, ", "), strings.Join(fields.Stmts, ", "))
2728
// log.Debugf("%s%+v", execSql, vals)
2829
result, err := exec.ExecContext(ctx, execSql, fields.Values...)
2930
if err != nil {

reflect.go

Lines changed: 39 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -24,26 +24,17 @@ var refxM = reflectx.NewMapperTagFunc("db", func(in string) string {
2424
})
2525

2626
// return is it a auto_increment field
27-
func travelStructField(f *reflectx.FieldInfo, v *reflect.Value, order *int, drvName *string, outputNames *[]byte, outputInputs *[]byte, outputVals *[]interface{}) *reflect.Value {
28-
*order += 1
27+
func _travelStructField(f *reflectx.FieldInfo, v *reflect.Value, drvName *string, fieldIdx *int, selectNames *[]string, stmtParams *[]string, scanVals *[]interface{}) *reflect.Value {
28+
*fieldIdx += 1
2929
switch v.Kind() {
3030
case reflect.Invalid:
3131
// nil value
3232
return nil
3333
case
3434
reflect.Bool,
35-
reflect.Int,
36-
reflect.Int8,
37-
reflect.Int16,
38-
reflect.Int32,
39-
reflect.Int64,
40-
reflect.Uint,
41-
reflect.Uint8,
42-
reflect.Uint16,
43-
reflect.Uint32,
44-
reflect.Uint64,
45-
reflect.Float32,
46-
reflect.Float64,
35+
reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64,
36+
reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64,
37+
reflect.Float32, reflect.Float64,
4738
reflect.String:
4839
// continue
4940
break
@@ -64,11 +55,9 @@ func travelStructField(f *reflectx.FieldInfo, v *reflect.Value, order *int, drvN
6455
continue
6556
}
6657
fieldVal := reflect.Indirect(*v).Field(i)
67-
autoFiled := travelStructField(
68-
child,
69-
&fieldVal,
70-
order, drvName,
71-
outputNames, outputInputs, outputVals,
58+
autoFiled := _travelStructField(
59+
child, &fieldVal, drvName,
60+
fieldIdx, selectNames, stmtParams, scanVals,
7261
)
7362
if autoFiled != nil {
7463
autoIncrement = autoFiled
@@ -101,35 +90,36 @@ func travelStructField(f *reflectx.FieldInfo, v *reflect.Value, order *int, drvN
10190
return v
10291
}
10392

104-
*outputVals = append(*outputVals, v.Interface())
10593
switch *drvName {
10694
case DRV_NAME_ORACLE, _DRV_NAME_OCI8:
107-
*order += 1
108-
*outputNames = append(*outputNames, []byte("\""+f.Name+"\",")...)
109-
*outputInputs = append(*outputInputs, []byte(fmt.Sprintf(":%s,", f.Name))...)
95+
*fieldIdx += 1
96+
*selectNames = append(*selectNames, "\""+f.Name+"\"")
97+
*stmtParams = append(*stmtParams, fmt.Sprintf(":%s", f.Name))
11098
case DRV_NAME_POSTGRES:
111-
*outputNames = append(*outputNames, []byte("\""+f.Name+"\",")...)
112-
*outputInputs = append(*outputInputs, []byte(fmt.Sprintf(":%d,", *order))...)
113-
*order += 1
99+
*selectNames = append(*selectNames, "\""+f.Name+"\"")
100+
*stmtParams = append(*stmtParams, fmt.Sprintf(":%d", *fieldIdx))
101+
*fieldIdx += 1
114102
case DRV_NAME_SQLSERVER, _DRV_NAME_MSSQL:
115-
*outputNames = append(*outputNames, []byte("["+f.Name+"],")...)
116-
*outputInputs = append(*outputInputs, []byte(fmt.Sprintf("@p%d,", *order))...)
117-
*order += 1
103+
*selectNames = append(*selectNames, "["+f.Name+"]")
104+
*stmtParams = append(*stmtParams, fmt.Sprintf("@p%d", *fieldIdx))
105+
*fieldIdx += 1
118106
case DRV_NAME_MYSQL:
119-
*order += 1
120-
*outputNames = append(*outputNames, []byte("`"+f.Name+"`,")...)
121-
*outputInputs = append(*outputInputs, []byte("?,")...)
107+
*fieldIdx += 1
108+
*selectNames = append(*selectNames, "`"+f.Name+"`")
109+
*stmtParams = append(*stmtParams, "?")
122110
default:
123-
*outputNames = append(*outputNames, []byte("\""+f.Name+"\",")...)
124-
*outputInputs = append(*outputInputs, []byte("?,")...)
111+
*selectNames = append(*selectNames, "\""+f.Name+"\"")
112+
*stmtParams = append(*stmtParams, "?")
125113
}
114+
*scanVals = append(*scanVals, v.Interface())
126115

116+
// recursive end by nil
127117
return nil
128118
}
129119

130120
type reflectInsertField struct {
131-
Names string
132-
Stmts string
121+
Names []string
122+
Stmts []string
133123
Values []interface{}
134124

135125
AutoIncrement *reflect.Value
@@ -154,13 +144,13 @@ func reflectInsertStruct(i interface{}, drvName string) (*reflectInsertField, er
154144

155145
tm := refxM.TypeMap(v.Type())
156146

157-
names := []byte{}
158-
inputs := []byte{}
159-
vals := []interface{}{}
147+
outputSelectNames := []string{}
148+
outputStmtParams := []string{}
149+
outputFieldVals := []interface{}{}
160150
var autoIncrement *reflect.Value
161151

162152
childrenLen := len(tm.Tree.Children)
163-
order := 0
153+
fieldIdx := 0
164154
for i := 0; i < childrenLen; i++ {
165155
field := tm.Tree.Children[i]
166156
if field == nil {
@@ -169,19 +159,23 @@ func reflectInsertStruct(i interface{}, drvName string) (*reflectInsertField, er
169159
}
170160

171161
fieldVal := v.Field(i)
172-
autoField := travelStructField(field, &fieldVal, &order, &drvName, &names, &inputs, &vals)
162+
autoField := _travelStructField(
163+
field, &fieldVal, &drvName,
164+
&fieldIdx,
165+
&outputSelectNames, &outputStmtParams, &outputFieldVals,
166+
)
173167
if autoField != nil {
174168
autoIncrement = autoField
175169
}
176170
}
177171

178-
if len(names) == 0 {
172+
if len(outputSelectNames) == 0 {
179173
panic("No public field in struct")
180174
}
181175
return &reflectInsertField{
182-
Names: string(names[:len(names)-1]),
183-
Stmts: string(inputs[:len(inputs)-1]),
184-
Values: vals,
176+
Names: outputSelectNames,
177+
Stmts: outputStmtParams,
178+
Values: outputFieldVals,
185179
AutoIncrement: autoIncrement,
186180
}, nil
187181
}

reflect_test.go

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ import (
44
"database/sql"
55
"fmt"
66
"reflect"
7+
"strings"
78
"testing"
89
"time"
910
)
@@ -50,10 +51,10 @@ func TestReflect(t *testing.T) {
5051
if err != nil {
5152
t.Fatal(err)
5253
}
53-
if refVal.Names != "`a`,`time`,`data`,`byte`,`dbdata`,`null_string`,`C`" {
54+
if strings.Join(refVal.Names, ",") != "`a`,`time`,`data`,`byte`,`dbdata`,`null_string`,`C`" {
5455
t.Fatal(refVal.Names)
5556
}
56-
if refVal.Stmts != "?,?,?,?,?,?,?" {
57+
if strings.Join(refVal.Stmts, ",") != "?,?,?,?,?,?,?" {
5758
t.Fatal(refVal.Stmts)
5859
}
5960
if len(refVal.Values) != 7 {
@@ -74,10 +75,10 @@ func TestReflect(t *testing.T) {
7475
if err != nil {
7576
t.Fatal(err)
7677
}
77-
if refVal.Names != "`id`,`a`,`C`" {
78+
if strings.Join(refVal.Names, ",") != "`id`,`a`,`C`" {
7879
t.Fatal(refVal.Names)
7980
}
80-
if refVal.Stmts != "?,?,?" {
81+
if strings.Join(refVal.Stmts, ",") != "?,?,?" {
8182
t.Fatal(refVal.Stmts)
8283
}
8384
if len(refVal.Values) != 3 {
@@ -96,10 +97,10 @@ func TestReflect(t *testing.T) {
9697
if err != nil {
9798
t.Fatal(err)
9899
}
99-
if refVal.Names != `"a","time","data","byte","dbdata","null_string","C","d","id","a","C","e"` {
100+
if strings.Join(refVal.Names, ",") != `"a","time","data","byte","dbdata","null_string","C","d","id","a","C","e"` {
100101
t.Fatal(refVal.Names)
101102
}
102-
if refVal.Stmts != ":a,:time,:data,:byte,:dbdata,:null_string,:C,:d,:id,:a,:C,:e" {
103+
if strings.Join(refVal.Stmts, ",") != ":a,:time,:data,:byte,:dbdata,:null_string,:C,:d,:id,:a,:C,:e" {
103104
t.Fatal(refVal.Stmts)
104105
}
105106
if fmt.Sprintf("%+v", refVal.Values) != `[100 0001-01-01 00:00:00 +0000 UTC [97 98 99] 0 {String: Valid:false} testing d 1 101 testing1 e]` {

sqlbuilder.go

Lines changed: 51 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -6,12 +6,19 @@ import (
66

77
type SqlBuilder struct {
88
drvName string
9-
buff strings.Builder
10-
args []interface{}
9+
10+
selectStr string
11+
fromBuff strings.Builder
12+
13+
args []interface{}
14+
15+
indent string
1116
}
1217

1318
func NewSqlBuilder(drvName ...string) *SqlBuilder {
14-
b := &SqlBuilder{}
19+
b := &SqlBuilder{
20+
indent: " ",
21+
}
1522
if len(drvName) > 0 {
1623
b.drvName = drvName[0]
1724
} else {
@@ -20,53 +27,74 @@ func NewSqlBuilder(drvName ...string) *SqlBuilder {
2027
return b
2128
}
2229

30+
func (b *SqlBuilder) Copy() *SqlBuilder {
31+
n := &SqlBuilder{
32+
drvName: b.drvName,
33+
selectStr: b.selectStr,
34+
args: make([]interface{}, len(b.args)),
35+
indent: b.indent,
36+
}
37+
copy(n.args, b.args)
38+
n.fromBuff.WriteString(b.fromBuff.String())
39+
return n
40+
}
41+
func (b *SqlBuilder) SetIndent(indent string) *SqlBuilder {
42+
b.indent = indent
43+
return b
44+
}
45+
2346
func (b *SqlBuilder) Add(key string, args ...interface{}) *SqlBuilder {
2447
if len(key) > 0 {
25-
b.buff.WriteString(key)
48+
b.fromBuff.WriteString(b.indent)
49+
b.fromBuff.WriteString(key)
2650
}
2751
if len(args) > 0 {
2852
b.args = append(b.args, args...)
2953
}
30-
b.buff.WriteString("\n")
3154
return b
3255
}
3356

34-
func (b *SqlBuilder) AddTab(key string, args ...interface{}) *SqlBuilder {
35-
b.buff.WriteString(" ")
36-
return b.Add(key, args)
37-
}
38-
func (b *SqlBuilder) AddTabOK(ok bool, key string, args ...interface{}) *SqlBuilder {
57+
func (b *SqlBuilder) AddIf(ok bool, key string, args ...interface{}) *SqlBuilder {
3958
if !ok {
4059
return b
4160
}
42-
return b.AddTab(key, args...)
61+
return b.Add(key, args...)
4362
}
4463

4564
func (b *SqlBuilder) In(in []interface{}) string {
4665
if len(in) == 0 {
47-
panic("need condition")
66+
panic("need arguments of in condition")
4867
}
4968
b.args = append(b.args, in...)
5069
return stmtIn(len(b.args)-1, len(in), b.drvName)
5170
}
5271

53-
func (b *SqlBuilder) Args() []interface{} {
54-
return b.args
55-
}
56-
57-
func (b *SqlBuilder) Select(column ...string) string {
58-
selectStr := "SELECT\n "
72+
func (b *SqlBuilder) Select(column ...string) *SqlBuilder {
5973
if len(column) > 0 {
60-
selectStr += strings.Join(column, ", ")
74+
b.selectStr = strings.Join(column, ", ")
6175
} else {
62-
selectStr += "*"
76+
b.selectStr = "*"
6377
}
64-
return selectStr + "\n" + b.buff.String()
78+
return b
6579
}
66-
func (b *SqlBuilder) SelectStruct(obj interface{}) string {
80+
func (b *SqlBuilder) SelectStruct(obj interface{}) *SqlBuilder {
6781
fields, err := reflectInsertStruct(obj, b.drvName)
6882
if err != nil {
6983
panic(err)
7084
}
71-
return "SELECT\n " + fields.Names + "\n" + b.buff.String()
85+
b.selectStr = strings.Join(fields.Names, ", ")
86+
return b
87+
}
88+
89+
func (b *SqlBuilder) String() string {
90+
return "SELECT" + b.indent + b.selectStr +
91+
b.fromBuff.String()
92+
}
93+
94+
func (b *SqlBuilder) Args() []interface{} {
95+
return b.args
96+
}
97+
98+
func (b *SqlBuilder) Sql() []interface{} {
99+
return append([]interface{}{b.String()}, b.args...)
72100
}

sqlbuilder_test.go

Lines changed: 23 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -6,24 +6,29 @@ import (
66
)
77

88
func TestSqlBuilder(t *testing.T) {
9-
sb := NewSqlBuilder(DRV_NAME_POSTGRES)
10-
sb.Add("FROM")
11-
sb.AddTab("tmp tb1")
12-
sb.AddTab("INNER JOIN tmp1 tb2 ON tb2.id=tb2.tmp_id")
13-
sb.Add("WHERE")
14-
sb.AddTab("1=1")
15-
sb.AddTabOK(false, "AND (1=?)", 0)
16-
sb.AddTabOK(true, "OR (tb1 IN ("+sb.In([]interface{}{1, 2})+"))")
17-
sb.Add("GROUP BY tb1.id")
18-
sb.Add("HAVING count(*)>?", 1)
19-
fmt.Println(sb.Select("count(*)"))
9+
bd := NewSqlBuilder(DRV_NAME_POSTGRES)
10+
bd.Select("count(*)")
11+
bd.Add("FROM")
12+
bd.Add("tmp tb1")
13+
bd.Add("INNER JOIN tmp1 tb2 ON tb2.id=tb2.tmp_id")
14+
bd.Add("WHERE")
15+
bd.Add("1=1")
16+
bd.AddIf(true, "AND (1=?)", 0)
17+
bd.AddIf(true, "OR (tb1 IN ("+bd.In([]interface{}{1, 2})+"))")
18+
bd.Add("GROUP BY tb1.id")
19+
bd.Add("HAVING count(*)>?", 1)
20+
fmt.Println(bd)
2021

21-
sb.Add("ORDER BY tb1.id DESC")
22-
sb.Add("OFFSET ?", 1)
23-
sb.Add("LIMIT ?", 1)
24-
fmt.Println(sb.Select("tb1.id", "count(*)"))
22+
bd1 := bd.Copy().Select("tb1.id", "count(*)")
23+
bd1.Add("ORDER BY tb1.id DESC")
24+
bd1.Add("OFFSET ?", 1)
25+
bd1.Add("LIMIT ?", 1)
26+
fmt.Println(bd1)
2527

26-
sb1 := NewSqlBuilder(DRV_NAME_POSTGRES)
27-
sb1.Add("FROM ("+sb.Select("tb1.id", "count(*)")+") tmp", sb.Args())
28-
fmt.Println(sb1.Select("*"))
28+
bd2 := NewSqlBuilder(DRV_NAME_POSTGRES)
29+
bd2.Select("*")
30+
bd2.Add("FROM ("+bd.String()+") tmp", bd1.Args())
31+
fmt.Println(bd2)
32+
33+
fmt.Println(bd2.Sql())
2934
}

0 commit comments

Comments
 (0)