11import os
22from datetime import datetime
33
4- import requests
4+
55from celery import shared_task
6- from fastapi import HTTPException
76
7+ from app .core .async_sqlmap_api import (
8+ async_get ,
9+ async_post ,
10+ async_fetch_sqlmap_status ,
11+ async_fetch_sqlmap_logs ,
12+ async_fetch_sqlmap_result ,
13+ )
814from app .database .celery_sync_database import SessionLocal
915from app .models .sqlmap_result import (
1016 SqlmapScanPayload ,
1319 SqlmapScanLog ,
1420)
1521from app .core .sqlmap_core import celery_task_add
22+ import httpx
23+ import asyncio
1624
1725SQLMAP_API = os .getenv ("SQLMAP_API" )
1826AUTH = (os .getenv ("SQLMAP_USERNAME" ), os .getenv ("SQLMAP_PASSWORD" )) # Basic Auth
@@ -60,15 +68,8 @@ def normalize_sqlmap_result(raw: dict) -> dict:
6068 return result
6169
6270
63- def fetch_sqlmap_logs (session , task : SqlmapScanPayload ):
64- resp = requests .get (
65- f"{ SQLMAP_API } /scan/{ task .task_id } /log" ,
66- auth = AUTH ,
67- )
68- if not resp .ok :
69- return
70-
71- logs = resp .json ().get ("log" , [])
71+ def fetch_sqlmap_logs (session , task : SqlmapScanPayload , logs_json : dict ):
72+ logs = logs_json .get ("log" , [])
7273
7374 # 已存在日志(避免重复写)
7475 existing = {
@@ -94,15 +95,8 @@ def fetch_sqlmap_logs(session, task: SqlmapScanPayload):
9495 )
9596
9697
97- def fetch_sqlmap_result (session , task_id : str ):
98- resp = requests .get (
99- f"{ SQLMAP_API } /scan/{ task_id } /data" ,
100- auth = AUTH ,
101- )
102- if not resp .ok :
103- return
104-
105- data = resp .json ().get ("data" , [])
98+ def fetch_sqlmap_result (session , task_id : str , result_json : dict ):
99+ data = result_json .get ("data" , [])
106100
107101 result = SqlmapScanResult (
108102 target_url = "" ,
@@ -119,13 +113,12 @@ def fetch_sqlmap_result(session, task_id: str):
119113# 轮询运行状态任务
120114@shared_task (
121115 bind = True ,
122- autoretry_for = (requests . RequestException ,),
116+ autoretry_for = (httpx . RequestError ,),
123117 retry_backoff = 5 ,
124118 retry_kwargs = {"max_retries" : 3 },
125119)
126120def poll_single_sqlmap_task (self , sqlmap_task_id : str ):
127121 session = SessionLocal ()
128-
129122 try :
130123 task = (
131124 session .query (SqlmapScanPayload )
@@ -135,42 +128,50 @@ def poll_single_sqlmap_task(self, sqlmap_task_id: str):
135128 if not task :
136129 return
137130
138- # 查询扫描状态
139- status_resp = requests .get (
140- f"{ SQLMAP_API } /scan/{ sqlmap_task_id } /status" ,
141- auth = AUTH ,
142- )
131+ # 异步查询扫描状态
132+ status_json = asyncio .run (async_fetch_sqlmap_status (sqlmap_task_id ))
143133
144- if status_resp .status_code != 200 :
145- task .status = ScanStatus .failed
146- session .commit ()
147- return
134+ print (status_json )
148135
149- status_json = status_resp .json ()
150136 if not status_json .get ("success" ):
151137 task .status = ScanStatus .failed
152138 session .commit ()
153139 return
154140
155141 sqlmap_status = status_json ["status" ]
156142
157- # 状态同步
158143 if sqlmap_status == "running" :
159144 task .status = ScanStatus .running
160145
146+ # 异步获取日志
147+ logs_json = asyncio .run (async_fetch_sqlmap_logs (sqlmap_task_id ))
148+ fetch_sqlmap_logs (session , task , logs_json )
149+
150+ session .commit ()
151+
152+ # 再次轮询
153+ self .apply_async (args = [sqlmap_task_id ])
154+ return
155+
161156 elif sqlmap_status in ("terminated" , "not running" ):
162157 task .status = ScanStatus .success
163158 task .finished_at = datetime .utcnow ()
164- fetch_sqlmap_result (session , sqlmap_task_id )
159+
160+ logs_json = asyncio .run (async_fetch_sqlmap_logs (sqlmap_task_id ))
161+ result_json = asyncio .run (async_fetch_sqlmap_result (sqlmap_task_id ))
162+
163+ print (result_json )
164+
165+ fetch_sqlmap_logs (session , task , logs_json )
166+ fetch_sqlmap_result (session , sqlmap_task_id , result_json )
167+
168+ session .commit ()
169+ return
165170
166171 elif sqlmap_status == "error" :
167172 task .status = ScanStatus .failed
168173 task .finished_at = datetime .utcnow ()
169-
170- # 同步写入日志
171- fetch_sqlmap_logs (session , task )
172-
173- session .commit ()
174+ session .commit ()
174175
175176 finally :
176177 session .close ()
@@ -186,26 +187,21 @@ def poll_single_sqlmap_task(self, sqlmap_task_id: str):
186187def sqlmap_scan_task (self , payload : dict ):
187188 session = SessionLocal ()
188189 try :
189- # 1. 创建 SQLMap 任务
190- r = requests .get (f"{ SQLMAP_API } /task/new" , auth = AUTH , timeout = 10 )
191- r .raise_for_status ()
192- sqlmap_task_id = r .json ()["taskid" ]
193-
194- # 2. 启动扫描
195- start = requests .post (
196- f"{ SQLMAP_API } /scan/{ sqlmap_task_id } /start" ,
197- json = payload ,
198- auth = AUTH ,
199- timeout = 30 ,
190+ # 异步创建 SQLMap 任务
191+ task_json = asyncio .run (async_get ("/task/new" , timeout = 10 ))
192+ sqlmap_task_id = task_json ["taskid" ]
193+
194+ # 异步启动扫描
195+ start_json = asyncio .run (
196+ async_post (f"/scan/{ sqlmap_task_id } /start" , json = payload , timeout = 30 )
200197 )
201- start .raise_for_status ()
202198
203- # 3. 扫描启动成功后,调用 celery_task_add 写入 DB
199+ # 写入数据库
204200 celery_task_add (
205201 session = session ,
206202 task_id = sqlmap_task_id ,
207- celery_task_id = self .request .id , # Celery 任务 ID
208- scan_url = str (payload ["url" ]), # 转成 str,防止 HttpUrl 错误
203+ celery_task_id = self .request .id ,
204+ scan_url = str (payload ["url" ]),
209205 status = "running" ,
210206 scan_risk = payload .get ("risk" , 1 ),
211207 scan_level = payload .get ("level" , 1 ),
0 commit comments