Skip to content

Commit a24b286

Browse files
committed
修改为同步数据库写入
1 parent 6173722 commit a24b286

2 files changed

Lines changed: 83 additions & 90 deletions

File tree

app/tasks/sqlmap_scheduler.py

Lines changed: 12 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,27 +1,20 @@
1-
import asyncio
2-
from celery import shared_task
31
from sqlalchemy import select
4-
5-
from app.database.database import AsyncSessionLocal
2+
from app.middleware.celery_app import celery_app
3+
from app.database.celery_sync_database import SessionLocal
64
from app.models.sqlmap_result import SqlmapScanPayload, ScanStatus
75
from app.tasks.sqlmap_worker import poll_single_sqlmap_task
86

97

10-
# 轮询查询数据库中的正在运行状态的数据
11-
@shared_task
8+
@celery_app.task(name="app.tasks.sqlmap_scheduler.poll_active_sqlmap_tasks")
129
def poll_active_sqlmap_tasks():
13-
async def _run():
14-
async with AsyncSessionLocal() as session:
15-
result = await session.execute(
16-
select(SqlmapScanPayload).where(
17-
SqlmapScanPayload.status.in_(
18-
[ScanStatus.pending, ScanStatus.running]
19-
)
20-
)
10+
with SessionLocal() as session:
11+
tasks = (
12+
session.query(SqlmapScanPayload)
13+
.filter(
14+
SqlmapScanPayload.status.in_([ScanStatus.pending, ScanStatus.running])
2115
)
22-
tasks = result.scalars().all()
23-
24-
for task in tasks:
25-
poll_single_sqlmap_task.delay(task.task_id)
16+
.all()
17+
)
2618

27-
asyncio.run(_run())
19+
for task in tasks:
20+
poll_single_sqlmap_task.delay(task.task_id)

app/tasks/sqlmap_worker.py

Lines changed: 71 additions & 71 deletions
Original file line numberDiff line numberDiff line change
@@ -1,90 +1,90 @@
1-
import asyncio
21
import requests
3-
from celery import shared_task
42
from datetime import datetime
5-
from sqlalchemy import select
3+
from sqlalchemy.exc import SQLAlchemyError
64

7-
from app.database.database import AsyncSessionLocal
5+
from app.middleware.celery_app import celery_app
6+
from app.database.celery_sync_database import SessionLocal
87
from app.models.sqlmap_result import (
98
SqlmapScanPayload,
109
SqlmapScanLog,
1110
ScanStatus,
11+
SqlmapScanResult,
1212
)
1313
import os
1414

1515
SQLMAP_API = os.getenv("SQLMAP_API")
16-
AUTH = (os.getenv("SQLMAP_USERNAME"), os.getenv("SQLMAP_PASSWORD"))
16+
AUTH = (os.getenv("SQLMAP_USERNAME"), os.getenv("SQLMAP_PASSWORD")) # Basic Auth
1717

1818

19-
def fetch_status(task_id: str) -> dict:
20-
r = requests.get(f"{SQLMAP_API}/scan/{task_id}/status", auth=AUTH, timeout=10)
21-
r.raise_for_status()
22-
return r.json()
23-
24-
25-
def fetch_log(task_id: str) -> dict:
26-
r = requests.get(f"{SQLMAP_API}/scan/{task_id}/log", auth=AUTH, timeout=10)
27-
r.raise_for_status()
28-
return r.json()
29-
30-
31-
@shared_task(
19+
@celery_app.task(
3220
bind=True,
3321
autoretry_for=(Exception,),
3422
retry_backoff=5,
3523
retry_kwargs={"max_retries": 3},
24+
name="app.tasks.sqlmap_worker.poll_single_sqlmap_task",
3625
)
3726
def poll_single_sqlmap_task(self, task_id: str):
38-
"""
39-
Worker:轮询单个 sqlmap 任务
40-
"""
41-
42-
async def _run():
43-
async with AsyncSessionLocal() as session:
44-
task = await session.scalar(
45-
select(SqlmapScanPayload).where(SqlmapScanPayload.task_id == task_id)
46-
)
47-
48-
if not task:
49-
return
50-
51-
# 已结束 → 不再轮询
52-
if task.status in (
53-
ScanStatus.success,
54-
ScanStatus.failed,
55-
ScanStatus.stopped,
56-
):
57-
return
58-
59-
# --- 调 sqlmap ---
60-
status_data = fetch_status(task_id)
61-
log_data = fetch_log(task_id)
62-
63-
sqlmap_status = status_data.get("status")
64-
65-
# --- 状态映射 ---
66-
if sqlmap_status == "running":
67-
task.status = ScanStatus.running
68-
69-
elif sqlmap_status == "terminated":
70-
task.status = ScanStatus.success
71-
task.finished_at = datetime.utcnow()
72-
73-
elif sqlmap_status == "error":
74-
task.status = ScanStatus.failed
75-
task.finished_at = datetime.utcnow()
76-
77-
# --- 写日志(全量 or 增量)---
78-
for item in log_data.get("log", []):
79-
session.add(
80-
SqlmapScanLog(
81-
task_id=task_id,
82-
level=item.get("level", "INFO"),
83-
message=item.get("message", ""),
84-
log_time=item.get("time"),
85-
)
86-
)
87-
88-
await session.commit()
89-
90-
asyncio.run(_run())
27+
session = SessionLocal()
28+
try:
29+
task = (
30+
session.query(SqlmapScanPayload)
31+
.filter(SqlmapScanPayload.task_id == task_id)
32+
.first()
33+
)
34+
35+
if not task:
36+
return
37+
38+
# 查询 sqlmap task 状态
39+
resp = requests.get(
40+
f"{SQLMAP_API}/scan/{task_id}/status",
41+
timeout=10,
42+
auth=AUTH,
43+
)
44+
resp.raise_for_status()
45+
status_data = resp.json()
46+
47+
status = status_data.get("status")
48+
49+
if status == "running":
50+
task.status = ScanStatus.running
51+
session.commit()
52+
return
53+
54+
if status != "terminated":
55+
return
56+
57+
# 获取扫描结果
58+
result_resp = requests.get(
59+
f"{SQLMAP_API}/scan/{task_id}/data",
60+
timeout=30,
61+
auth=AUTH,
62+
)
63+
result_resp.raise_for_status()
64+
data = result_resp.json()
65+
66+
# 解析 sqlmap 返回
67+
scan_result = SqlmapScanResult(
68+
target_url=task.scan_url,
69+
dbms=data.get("dbms"),
70+
vulnerable=bool(data.get("data")),
71+
injection_points=data.get("data"),
72+
dump_data=data.get("dump"),
73+
raw_output=data.get("raw"),
74+
command=data.get("command", ""),
75+
started_at=datetime.utcnow(),
76+
finished_at=datetime.utcnow(),
77+
)
78+
79+
session.add(scan_result)
80+
task.status = ScanStatus.success
81+
82+
session.commit()
83+
84+
except Exception:
85+
session.rollback()
86+
task.status = ScanStatus.failed
87+
session.commit()
88+
raise
89+
finally:
90+
session.close()

0 commit comments

Comments
 (0)