Skip to content

Commit ae0374c

Browse files
committed
异步httpx
1 parent 5409eae commit ae0374c

1 file changed

Lines changed: 50 additions & 54 deletions

File tree

app/tasks/sqlmap_worker.py

Lines changed: 50 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,16 @@
11
import os
22
from datetime import datetime
33

4-
import requests
4+
55
from celery import shared_task
6-
from fastapi import HTTPException
76

7+
from app.core.async_sqlmap_api import (
8+
async_get,
9+
async_post,
10+
async_fetch_sqlmap_status,
11+
async_fetch_sqlmap_logs,
12+
async_fetch_sqlmap_result,
13+
)
814
from app.database.celery_sync_database import SessionLocal
915
from app.models.sqlmap_result import (
1016
SqlmapScanPayload,
@@ -13,6 +19,8 @@
1319
SqlmapScanLog,
1420
)
1521
from app.core.sqlmap_core import celery_task_add
22+
import httpx
23+
import asyncio
1624

1725
SQLMAP_API = os.getenv("SQLMAP_API")
1826
AUTH = (os.getenv("SQLMAP_USERNAME"), os.getenv("SQLMAP_PASSWORD")) # Basic Auth
@@ -60,15 +68,8 @@ def normalize_sqlmap_result(raw: dict) -> dict:
6068
return result
6169

6270

63-
def fetch_sqlmap_logs(session, task: SqlmapScanPayload):
64-
resp = requests.get(
65-
f"{SQLMAP_API}/scan/{task.task_id}/log",
66-
auth=AUTH,
67-
)
68-
if not resp.ok:
69-
return
70-
71-
logs = resp.json().get("log", [])
71+
def fetch_sqlmap_logs(session, task: SqlmapScanPayload, logs_json: dict):
72+
logs = logs_json.get("log", [])
7273

7374
# 已存在日志(避免重复写)
7475
existing = {
@@ -94,15 +95,8 @@ def fetch_sqlmap_logs(session, task: SqlmapScanPayload):
9495
)
9596

9697

97-
def fetch_sqlmap_result(session, task_id: str):
98-
resp = requests.get(
99-
f"{SQLMAP_API}/scan/{task_id}/data",
100-
auth=AUTH,
101-
)
102-
if not resp.ok:
103-
return
104-
105-
data = resp.json().get("data", [])
98+
def fetch_sqlmap_result(session, task_id: str, result_json: dict):
99+
data = result_json.get("data", [])
106100

107101
result = SqlmapScanResult(
108102
target_url="",
@@ -119,13 +113,12 @@ def fetch_sqlmap_result(session, task_id: str):
119113
# 轮询运行状态任务
120114
@shared_task(
121115
bind=True,
122-
autoretry_for=(requests.RequestException,),
116+
autoretry_for=(httpx.RequestError,),
123117
retry_backoff=5,
124118
retry_kwargs={"max_retries": 3},
125119
)
126120
def poll_single_sqlmap_task(self, sqlmap_task_id: str):
127121
session = SessionLocal()
128-
129122
try:
130123
task = (
131124
session.query(SqlmapScanPayload)
@@ -135,42 +128,50 @@ def poll_single_sqlmap_task(self, sqlmap_task_id: str):
135128
if not task:
136129
return
137130

138-
# 查询扫描状态
139-
status_resp = requests.get(
140-
f"{SQLMAP_API}/scan/{sqlmap_task_id}/status",
141-
auth=AUTH,
142-
)
131+
# 异步查询扫描状态
132+
status_json = asyncio.run(async_fetch_sqlmap_status(sqlmap_task_id))
143133

144-
if status_resp.status_code != 200:
145-
task.status = ScanStatus.failed
146-
session.commit()
147-
return
134+
print(status_json)
148135

149-
status_json = status_resp.json()
150136
if not status_json.get("success"):
151137
task.status = ScanStatus.failed
152138
session.commit()
153139
return
154140

155141
sqlmap_status = status_json["status"]
156142

157-
# 状态同步
158143
if sqlmap_status == "running":
159144
task.status = ScanStatus.running
160145

146+
# 异步获取日志
147+
logs_json = asyncio.run(async_fetch_sqlmap_logs(sqlmap_task_id))
148+
fetch_sqlmap_logs(session, task, logs_json)
149+
150+
session.commit()
151+
152+
# 再次轮询
153+
self.apply_async(args=[sqlmap_task_id])
154+
return
155+
161156
elif sqlmap_status in ("terminated", "not running"):
162157
task.status = ScanStatus.success
163158
task.finished_at = datetime.utcnow()
164-
fetch_sqlmap_result(session, sqlmap_task_id)
159+
160+
logs_json = asyncio.run(async_fetch_sqlmap_logs(sqlmap_task_id))
161+
result_json = asyncio.run(async_fetch_sqlmap_result(sqlmap_task_id))
162+
163+
print(result_json)
164+
165+
fetch_sqlmap_logs(session, task, logs_json)
166+
fetch_sqlmap_result(session, sqlmap_task_id, result_json)
167+
168+
session.commit()
169+
return
165170

166171
elif sqlmap_status == "error":
167172
task.status = ScanStatus.failed
168173
task.finished_at = datetime.utcnow()
169-
170-
# 同步写入日志
171-
fetch_sqlmap_logs(session, task)
172-
173-
session.commit()
174+
session.commit()
174175

175176
finally:
176177
session.close()
@@ -186,26 +187,21 @@ def poll_single_sqlmap_task(self, sqlmap_task_id: str):
186187
def sqlmap_scan_task(self, payload: dict):
187188
session = SessionLocal()
188189
try:
189-
# 1. 创建 SQLMap 任务
190-
r = requests.get(f"{SQLMAP_API}/task/new", auth=AUTH, timeout=10)
191-
r.raise_for_status()
192-
sqlmap_task_id = r.json()["taskid"]
193-
194-
# 2. 启动扫描
195-
start = requests.post(
196-
f"{SQLMAP_API}/scan/{sqlmap_task_id}/start",
197-
json=payload,
198-
auth=AUTH,
199-
timeout=30,
190+
# 异步创建 SQLMap 任务
191+
task_json = asyncio.run(async_get("/task/new", timeout=10))
192+
sqlmap_task_id = task_json["taskid"]
193+
194+
# 异步启动扫描
195+
start_json = asyncio.run(
196+
async_post(f"/scan/{sqlmap_task_id}/start", json=payload, timeout=30)
200197
)
201-
start.raise_for_status()
202198

203-
# 3. 扫描启动成功后,调用 celery_task_add 写入 DB
199+
# 写入数据库
204200
celery_task_add(
205201
session=session,
206202
task_id=sqlmap_task_id,
207-
celery_task_id=self.request.id, # Celery 任务 ID
208-
scan_url=str(payload["url"]), # 转成 str,防止 HttpUrl 错误
203+
celery_task_id=self.request.id,
204+
scan_url=str(payload["url"]),
209205
status="running",
210206
scan_risk=payload.get("risk", 1),
211207
scan_level=payload.get("level", 1),

0 commit comments

Comments
 (0)