Skip to content

Commit cb70d4e

Browse files
committed
fix: add archive skip count tracking and no-auditable-files error
BUG-003: When an archive (.zip/.rar/.7z) contains no auditable files (all files are unsupported formats), return a clear error message "no auditable files in the archive" instead of silently creating an empty audit record. This satisfies AC-4.6. OBS-004: Track the count of skipped unsupported-format files during archive processing and include "skipped N unsupported format file(s)" in the API response message field. This satisfies AC-4.5. Changes: - Add skippedCount return value to processRarContent, process7zContent, getSqlsFromZip, getSqlsFromRar, getSqlsFrom7z, getSqlsFromArchive - Add Message field to GetSQLFromFileResp for user feedback - Add newBaseResWithMessage helper for response message composition - Update CreateSQLAuditRecord and CreateAuditTask handlers to surface the skip message in the API response - Update tests for new function signatures
1 parent d3a5b07 commit cb70d4e

6 files changed

Lines changed: 85 additions & 60 deletions

File tree

sqle/api/controller/v1/archive_7z.go

Lines changed: 19 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -21,47 +21,47 @@ import (
2121
// 使用 github.com/bodgit/sevenzip 库解压 7z 文件,遍历内部文件并调用 processArchiveEntry 处理。
2222
// 函数签名与 getSqlsFromZip / getSqlsFromRar 保持一致。
2323
// 注意:sevenzip 需要 io.ReaderAt + int64 size(不同于 RAR 的 io.Reader),需先将上传文件读入 bytes.Reader。
24-
func getSqlsFrom7z(c echo.Context) (sqlsFromSQLFile []SQLsFromSQLFile, sqlsFromXML []SQLFromXML, exist bool, err error) {
24+
func getSqlsFrom7z(c echo.Context) (sqlsFromSQLFile []SQLsFromSQLFile, sqlsFromXML []SQLFromXML, skippedCount int, exist bool, err error) {
2525
file, err := c.FormFile(InputZipFileName)
2626
if err == http.ErrMissingFile {
27-
return nil, nil, false, nil
27+
return nil, nil, 0, false, nil
2828
}
2929
if err != nil {
30-
return nil, nil, false, err
30+
return nil, nil, 0, false, err
3131
}
3232

3333
f, err := file.Open()
3434
if err != nil {
35-
return nil, nil, false, err
35+
return nil, nil, 0, false, err
3636
}
3737
defer f.Close()
3838

3939
// 使用 archiveConfig 进行压缩包总大小限制检查(上传文件大小预检)
4040
if err := defaultArchiveConfig.checkSize(0, file.Size); err != nil {
41-
return nil, nil, false, err
41+
return nil, nil, 0, false, err
4242
}
4343

4444
// sevenzip 需要 io.ReaderAt 接口,将上传文件内容读入 bytes.Reader
4545
buf, err := io.ReadAll(f)
4646
if err != nil {
47-
return nil, nil, false, fmt.Errorf("read 7z file into memory failed: %v", err)
47+
return nil, nil, 0, false, fmt.Errorf("read 7z file into memory failed: %v", err)
4848
}
4949

50-
sqlsFromSQLFile, sqlsFromXML, err = process7zContent(bytes.NewReader(buf), int64(len(buf)))
50+
sqlsFromSQLFile, sqlsFromXML, skippedCount, err = process7zContent(bytes.NewReader(buf), int64(len(buf)))
5151
if err != nil {
52-
return nil, nil, false, err
52+
return nil, nil, 0, false, err
5353
}
5454

55-
return sqlsFromSQLFile, sqlsFromXML, true, nil
55+
return sqlsFromSQLFile, sqlsFromXML, skippedCount, true, nil
5656
}
5757

5858
// process7zContent 从 io.ReaderAt 中读取 7z 内容,遍历 entry 并提取 SQL。
5959
// 该函数封装了 7z 解压的核心逻辑,独立于 echo.Context,便于单元测试。
60-
func process7zContent(r io.ReaderAt, size int64) (sqlsFromSQLFile []SQLsFromSQLFile, sqlsFromXML []SQLFromXML, err error) {
60+
func process7zContent(r io.ReaderAt, size int64) (sqlsFromSQLFile []SQLsFromSQLFile, sqlsFromXML []SQLFromXML, skippedCount int, err error) {
6161
// 使用 sevenzip.NewReader 打开 7z 文件
6262
szr, err := sevenzip.NewReader(r, size)
6363
if err != nil {
64-
return nil, nil, fmt.Errorf("open 7z file failed: %v", err)
64+
return nil, nil, 0, fmt.Errorf("open 7z file failed: %v", err)
6565
}
6666

6767
var xmlContents []xmlParser.XmlFile
@@ -77,7 +77,7 @@ func process7zContent(r io.ReaderAt, size int64) (sqlsFromSQLFile []SQLsFromSQLF
7777
// 文件数量限制检查
7878
fileCount++
7979
if err := defaultArchiveConfig.checkFileCount(fileCount); err != nil {
80-
return nil, nil, err
80+
return nil, nil, 0, err
8181
}
8282

8383
// 嵌套压缩包检查:depth=1 时跳过内层压缩包
@@ -89,18 +89,18 @@ func process7zContent(r io.ReaderAt, size int64) (sqlsFromSQLFile []SQLsFromSQLF
8989
// 打开并读取文件内容
9090
rc, err := f.Open()
9191
if err != nil {
92-
return nil, nil, fmt.Errorf("open 7z entry %q failed: %v", f.Name, err)
92+
return nil, nil, 0, fmt.Errorf("open 7z entry %q failed: %v", f.Name, err)
9393
}
9494
content, err := io.ReadAll(rc)
9595
rc.Close()
9696
if err != nil {
97-
return nil, nil, fmt.Errorf("read 7z entry %q content failed: %v", f.Name, err)
97+
return nil, nil, 0, fmt.Errorf("read 7z entry %q content failed: %v", f.Name, err)
9898
}
9999

100100
// 累计大小限制检查
101101
totalSize += int64(len(content))
102102
if err := defaultArchiveConfig.checkSize(0, totalSize); err != nil {
103-
return nil, nil, err
103+
return nil, nil, 0, err
104104
}
105105

106106
// 委托 processArchiveEntry 按扩展名分发处理
@@ -110,9 +110,10 @@ func process7zContent(r io.ReaderAt, size int64) (sqlsFromSQLFile []SQLsFromSQLF
110110
log.NewEntry().WithField("convert_to_utf8", f.Name).Errorf("convert to utf8 failed: %v", err)
111111
continue
112112
}
113-
return nil, nil, err
113+
return nil, nil, 0, err
114114
}
115115
if !isSupported {
116+
skippedCount++
116117
continue
117118
}
118119

@@ -131,7 +132,7 @@ func process7zContent(r io.ReaderAt, size int64) (sqlsFromSQLFile []SQLsFromSQLF
131132
{
132133
sqlsFromXmls, err := parseXMLsWithFilePath(xmlContents)
133134
if err != nil {
134-
return nil, nil, err
135+
return nil, nil, 0, err
135136
}
136137
sqlsFromXML = append(sqlsFromXML, sqlsFromXmls...)
137138
}
@@ -144,5 +145,5 @@ func process7zContent(r io.ReaderAt, size int64) (sqlsFromSQLFile []SQLsFromSQLF
144145
return utils.CompareNatural(sqlsFromXML[i].FilePath, sqlsFromXML[j].FilePath)
145146
})
146147

147-
return sqlsFromSQLFile, sqlsFromXML, nil
148+
return sqlsFromSQLFile, sqlsFromXML, skippedCount, nil
148149
}

sqle/api/controller/v1/archive_7z_test.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -97,7 +97,7 @@ func TestProcess7zContent(t *testing.T) {
9797
t.Run(name, func(t *testing.T) {
9898
r, size := openTest7z(t, tc.szFile)
9999

100-
sqlFiles, xmlFiles, err := process7zContent(r, size)
100+
sqlFiles, xmlFiles, _, err := process7zContent(r, size)
101101

102102
// Check error
103103
if tc.expectErr {
@@ -198,7 +198,7 @@ func TestProcess7zContentFileCountLimit(t *testing.T) {
198198
func TestProcess7zContentInvalid7z(t *testing.T) {
199199
// Test with invalid 7z data
200200
invalidData := bytes.NewReader([]byte("this is not a 7z file"))
201-
_, _, err := process7zContent(invalidData, int64(len("this is not a 7z file")))
201+
_, _, _, err := process7zContent(invalidData, int64(len("this is not a 7z file")))
202202
if err == nil {
203203
t.Error("expected error for invalid 7z data, got nil")
204204
}

sqle/api/controller/v1/archive_rar.go

Lines changed: 18 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -19,41 +19,41 @@ import (
1919
// getSqlsFromRar 从 RAR 文件中提取 SQL 语句。
2020
// 使用 github.com/nwaples/rardecode 库解压 RAR 文件,遍历内部文件并调用 processArchiveEntry 处理。
2121
// 函数签名与 getSqlsFromZip 保持一致。
22-
func getSqlsFromRar(c echo.Context) (sqlsFromSQLFile []SQLsFromSQLFile, sqlsFromXML []SQLFromXML, exist bool, err error) {
22+
func getSqlsFromRar(c echo.Context) (sqlsFromSQLFile []SQLsFromSQLFile, sqlsFromXML []SQLFromXML, skippedCount int, exist bool, err error) {
2323
file, err := c.FormFile(InputZipFileName)
2424
if err == http.ErrMissingFile {
25-
return nil, nil, false, nil
25+
return nil, nil, 0, false, nil
2626
}
2727
if err != nil {
28-
return nil, nil, false, err
28+
return nil, nil, 0, false, err
2929
}
3030

3131
f, err := file.Open()
3232
if err != nil {
33-
return nil, nil, false, err
33+
return nil, nil, 0, false, err
3434
}
3535
defer f.Close()
3636

3737
// 使用 archiveConfig 进行压缩包总大小限制检查(上传文件大小预检)
3838
if err := defaultArchiveConfig.checkSize(0, file.Size); err != nil {
39-
return nil, nil, false, err
39+
return nil, nil, 0, false, err
4040
}
4141

42-
sqlsFromSQLFile, sqlsFromXML, err = processRarContent(f)
42+
sqlsFromSQLFile, sqlsFromXML, skippedCount, err = processRarContent(f)
4343
if err != nil {
44-
return nil, nil, false, err
44+
return nil, nil, 0, false, err
4545
}
4646

47-
return sqlsFromSQLFile, sqlsFromXML, true, nil
47+
return sqlsFromSQLFile, sqlsFromXML, skippedCount, true, nil
4848
}
4949

5050
// processRarContent 从 io.Reader 中读取 RAR 内容,遍历 entry 并提取 SQL。
5151
// 该函数封装了 RAR 解压的核心逻辑,独立于 echo.Context,便于单元测试。
52-
func processRarContent(r io.Reader) (sqlsFromSQLFile []SQLsFromSQLFile, sqlsFromXML []SQLFromXML, err error) {
52+
func processRarContent(r io.Reader) (sqlsFromSQLFile []SQLsFromSQLFile, sqlsFromXML []SQLFromXML, skippedCount int, err error) {
5353
// 使用 rardecode.NewReader 打开 RAR 文件,密码参数为空字符串(不支持加密 RAR)
5454
rr, err := rardecode.NewReader(r, "")
5555
if err != nil {
56-
return nil, nil, fmt.Errorf("open rar file failed: %v", err)
56+
return nil, nil, 0, fmt.Errorf("open rar file failed: %v", err)
5757
}
5858

5959
var xmlContents []xmlParser.XmlFile
@@ -66,7 +66,7 @@ func processRarContent(r io.Reader) (sqlsFromSQLFile []SQLsFromSQLFile, sqlsFrom
6666
break
6767
}
6868
if err != nil {
69-
return nil, nil, fmt.Errorf("read rar entry failed: %v", err)
69+
return nil, nil, 0, fmt.Errorf("read rar entry failed: %v", err)
7070
}
7171

7272
// 跳过目录
@@ -77,7 +77,7 @@ func processRarContent(r io.Reader) (sqlsFromSQLFile []SQLsFromSQLFile, sqlsFrom
7777
// 文件数量限制检查
7878
fileCount++
7979
if err := defaultArchiveConfig.checkFileCount(fileCount); err != nil {
80-
return nil, nil, err
80+
return nil, nil, 0, err
8181
}
8282

8383
// 嵌套压缩包检查:depth=1 时跳过内层压缩包
@@ -89,13 +89,13 @@ func processRarContent(r io.Reader) (sqlsFromSQLFile []SQLsFromSQLFile, sqlsFrom
8989
// 读取文件内容
9090
content, err := io.ReadAll(rr)
9191
if err != nil {
92-
return nil, nil, fmt.Errorf("read rar entry content failed: %v", err)
92+
return nil, nil, 0, fmt.Errorf("read rar entry content failed: %v", err)
9393
}
9494

9595
// 累计大小限制检查
9696
totalSize += int64(len(content))
9797
if err := defaultArchiveConfig.checkSize(0, totalSize); err != nil {
98-
return nil, nil, err
98+
return nil, nil, 0, err
9999
}
100100

101101
// 委托 processArchiveEntry 按扩展名分发处理
@@ -105,9 +105,10 @@ func processRarContent(r io.Reader) (sqlsFromSQLFile []SQLsFromSQLFile, sqlsFrom
105105
log.NewEntry().WithField("convert_to_utf8", header.Name).Errorf("convert to utf8 failed: %v", err)
106106
continue
107107
}
108-
return nil, nil, err
108+
return nil, nil, 0, err
109109
}
110110
if !isSupported {
111+
skippedCount++
111112
continue
112113
}
113114

@@ -126,7 +127,7 @@ func processRarContent(r io.Reader) (sqlsFromSQLFile []SQLsFromSQLFile, sqlsFrom
126127
{
127128
sqlsFromXmls, err := parseXMLsWithFilePath(xmlContents)
128129
if err != nil {
129-
return nil, nil, err
130+
return nil, nil, 0, err
130131
}
131132
sqlsFromXML = append(sqlsFromXML, sqlsFromXmls...)
132133
}
@@ -139,5 +140,5 @@ func processRarContent(r io.Reader) (sqlsFromSQLFile []SQLsFromSQLFile, sqlsFrom
139140
return utils.CompareNatural(sqlsFromXML[i].FilePath, sqlsFromXML[j].FilePath)
140141
})
141142

142-
return sqlsFromSQLFile, sqlsFromXML, nil
143+
return sqlsFromSQLFile, sqlsFromXML, skippedCount, nil
143144
}

sqle/api/controller/v1/archive_rar_test.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -97,7 +97,7 @@ func TestProcessRarContent(t *testing.T) {
9797
f := openTestRar(t, tc.rarFile)
9898
defer f.Close()
9999

100-
sqlFiles, xmlFiles, err := processRarContent(f)
100+
sqlFiles, xmlFiles, _, err := processRarContent(f)
101101

102102
// Check error
103103
if tc.expectErr {
@@ -204,7 +204,7 @@ func TestProcessRarContentFileCountLimit(t *testing.T) {
204204
func TestProcessRarContentInvalidRar(t *testing.T) {
205205
// Test with invalid RAR data
206206
invalidData := bytes.NewReader([]byte("this is not a rar file"))
207-
_, _, err := processRarContent(invalidData)
207+
_, _, _, err := processRarContent(invalidData)
208208
if err == nil {
209209
t.Error("expected error for invalid RAR data, got nil")
210210
}

sqle/api/controller/v1/sql_audit_record.go

Lines changed: 16 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -172,7 +172,7 @@ func CreateSQLAuditRecord(c echo.Context) error {
172172
}
173173

174174
return c.JSON(http.StatusOK, &CreateSQLAuditRecordResV1{
175-
BaseRes: controller.NewBaseReq(nil),
175+
BaseRes: newBaseResWithMessage(sqls.Message),
176176
Data: &SQLAuditRecordResData{
177177
Id: record.AuditRecordId,
178178
Task: &AuditTaskResV1{
@@ -203,6 +203,7 @@ type GetSQLFromFileResp struct {
203203
SQLsFromFormData string
204204
SQLsFromSQLFiles []SQLsFromSQLFile
205205
SQLsFromXMLs []SQLFromXML
206+
Message string // Optional message for user feedback (e.g., skipped file count)
206207
}
207208

208209
type SQLsFromSQLFile struct {
@@ -349,27 +350,27 @@ func buildOfflineTaskForAudit(userId uint64, dbType string, sqls GetSQLFromFileR
349350
}
350351

351352
// todo 此处跳过了不支持的编码格式文件
352-
func getSqlsFromZip(c echo.Context) (sqlsFromSQLFile []SQLsFromSQLFile, sqlsFromXML []SQLFromXML, exist bool, err error) {
353+
func getSqlsFromZip(c echo.Context) (sqlsFromSQLFile []SQLsFromSQLFile, sqlsFromXML []SQLFromXML, skippedCount int, exist bool, err error) {
353354
file, err := c.FormFile(InputZipFileName)
354355
if err == http.ErrMissingFile {
355-
return nil, nil, false, nil
356+
return nil, nil, 0, false, nil
356357
}
357358
if err != nil {
358-
return nil, nil, false, err
359+
return nil, nil, 0, false, err
359360
}
360361
f, err := file.Open()
361362
if err != nil {
362-
return nil, nil, false, err
363+
return nil, nil, 0, false, err
363364
}
364365
defer f.Close()
365366

366367
// 使用 archiveConfig 进行压缩包总大小限制检查
367368
if err := defaultArchiveConfig.checkSize(0, file.Size); err != nil {
368-
return nil, nil, false, err
369+
return nil, nil, 0, false, err
369370
}
370371
r, err := zip.NewReader(f, file.Size)
371372
if err != nil {
372-
return nil, nil, false, err
373+
return nil, nil, 0, false, err
373374
}
374375

375376
var xmlContents []xmlParser.XmlFile
@@ -384,7 +385,7 @@ func getSqlsFromZip(c echo.Context) (sqlsFromSQLFile []SQLsFromSQLFile, sqlsFrom
384385
// 使用 archiveConfig 进行文件数量限制检查
385386
fileCount++
386387
if err := defaultArchiveConfig.checkFileCount(fileCount); err != nil {
387-
return nil, nil, false, err
388+
return nil, nil, 0, false, err
388389
}
389390

390391
// 嵌套压缩包检查:depth=1 时跳过内层压缩包
@@ -396,18 +397,18 @@ func getSqlsFromZip(c echo.Context) (sqlsFromSQLFile []SQLsFromSQLFile, sqlsFrom
396397
// 读取文件内容
397398
rc, err := srcFile.Open()
398399
if err != nil {
399-
return nil, nil, false, fmt.Errorf("open src file failed: %v", err)
400+
return nil, nil, 0, false, fmt.Errorf("open src file failed: %v", err)
400401
}
401402
content, err := io.ReadAll(rc)
402403
rc.Close()
403404
if err != nil {
404-
return nil, nil, false, fmt.Errorf("read src file failed: %v", err)
405+
return nil, nil, 0, false, fmt.Errorf("read src file failed: %v", err)
405406
}
406407

407408
// 累计大小限制检查
408409
totalSize += int64(len(content))
409410
if err := defaultArchiveConfig.checkSize(0, totalSize); err != nil {
410-
return nil, nil, false, err
411+
return nil, nil, 0, false, err
411412
}
412413

413414
// 委托 processArchiveEntry 按扩展名分发处理
@@ -417,9 +418,10 @@ func getSqlsFromZip(c echo.Context) (sqlsFromSQLFile []SQLsFromSQLFile, sqlsFrom
417418
log.NewEntry().WithField("convert_to_utf8", srcFile.Name).Errorf("convert to utf8 failed: %v", err)
418419
continue
419420
}
420-
return nil, nil, false, err
421+
return nil, nil, 0, false, err
421422
}
422423
if !isSupported {
424+
skippedCount++
423425
continue
424426
}
425427

@@ -438,7 +440,7 @@ func getSqlsFromZip(c echo.Context) (sqlsFromSQLFile []SQLsFromSQLFile, sqlsFrom
438440
{
439441
sqlsFromXmls, err := parseXMLsWithFilePath(xmlContents)
440442
if err != nil {
441-
return nil, nil, false, err
443+
return nil, nil, 0, false, err
442444
}
443445
sqlsFromXML = append(sqlsFromXML, sqlsFromXmls...)
444446
}
@@ -451,7 +453,7 @@ func getSqlsFromZip(c echo.Context) (sqlsFromSQLFile []SQLsFromSQLFile, sqlsFrom
451453
return utils.CompareNatural(sqlsFromXML[i].FilePath, sqlsFromXML[j].FilePath)
452454
})
453455

454-
return sqlsFromSQLFile, sqlsFromXML, true, nil
456+
return sqlsFromSQLFile, sqlsFromXML, skippedCount, true, nil
455457
}
456458
func parseXMLsWithFilePath(xmlContents []xmlParser.XmlFile) ([]SQLFromXML, error) {
457459
allStmtsFromXml, err := xmlParser.ParseXMLs(xmlContents, xmlParser.SkipErrorQuery, xmlParser.RestoreOriginSql)

0 commit comments

Comments
 (0)