Skip to content

Commit c9add67

Browse files
committed
增加定时和轮询任务
1 parent 22ce13d commit c9add67

3 files changed

Lines changed: 126 additions & 0 deletions

File tree

app/middleware/celery_beat.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
from celery.schedules import crontab
2+
from app.middleware import celery_app
3+
4+
celery_app.conf.beat_schedule = {
5+
"poll-sqlmap-tasks-every-5-seconds": {
6+
"task": "app.tasks.sqlmap_scheduler.poll_active_sqlmap_tasks",
7+
"schedule": 5.0,
8+
}
9+
}

app/tasks/sqlmap_scheduler.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
import asyncio
2+
from celery import shared_task
3+
from sqlalchemy import select
4+
5+
from app.database.database import AsyncSessionLocal
6+
from app.models.sqlmap_result import SqlmapScanPayload, ScanStatus
7+
from app.tasks.sqlmap_worker import poll_single_sqlmap_task
8+
9+
10+
# 轮询查询数据库中的正在运行状态的数据
11+
@shared_task
12+
def poll_active_sqlmap_tasks():
13+
async def _run():
14+
async with AsyncSessionLocal() as session:
15+
result = await session.execute(
16+
select(SqlmapScanPayload).where(
17+
SqlmapScanPayload.status.in_(
18+
[ScanStatus.pending, ScanStatus.running]
19+
)
20+
)
21+
)
22+
tasks = result.scalars().all()
23+
24+
for task in tasks:
25+
poll_single_sqlmap_task.delay(task.task_id)
26+
27+
asyncio.run(_run())

app/tasks/sqlmap_worker.py

Lines changed: 90 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,90 @@
1+
import asyncio
2+
import requests
3+
from celery import shared_task
4+
from datetime import datetime
5+
from sqlalchemy import select
6+
7+
from app.database.database import AsyncSessionLocal
8+
from app.models.sqlmap_result import (
9+
SqlmapScanPayload,
10+
SqlmapScanLog,
11+
ScanStatus,
12+
)
13+
import os
14+
15+
SQLMAP_API = os.getenv("SQLMAP_API")
16+
AUTH = (os.getenv("SQLMAP_USERNAME"), os.getenv("SQLMAP_PASSWORD"))
17+
18+
19+
def fetch_status(task_id: str) -> dict:
20+
r = requests.get(f"{SQLMAP_API}/scan/{task_id}/status", auth=AUTH, timeout=10)
21+
r.raise_for_status()
22+
return r.json()
23+
24+
25+
def fetch_log(task_id: str) -> dict:
26+
r = requests.get(f"{SQLMAP_API}/scan/{task_id}/log", auth=AUTH, timeout=10)
27+
r.raise_for_status()
28+
return r.json()
29+
30+
31+
@shared_task(
32+
bind=True,
33+
autoretry_for=(Exception,),
34+
retry_backoff=5,
35+
retry_kwargs={"max_retries": 3},
36+
)
37+
def poll_single_sqlmap_task(self, task_id: str):
38+
"""
39+
Worker:轮询单个 sqlmap 任务
40+
"""
41+
42+
async def _run():
43+
async with AsyncSessionLocal() as session:
44+
task = await session.scalar(
45+
select(SqlmapScanPayload).where(SqlmapScanPayload.task_id == task_id)
46+
)
47+
48+
if not task:
49+
return
50+
51+
# 已结束 → 不再轮询
52+
if task.status in (
53+
ScanStatus.success,
54+
ScanStatus.failed,
55+
ScanStatus.stopped,
56+
):
57+
return
58+
59+
# --- 调 sqlmap ---
60+
status_data = fetch_status(task_id)
61+
log_data = fetch_log(task_id)
62+
63+
sqlmap_status = status_data.get("status")
64+
65+
# --- 状态映射 ---
66+
if sqlmap_status == "running":
67+
task.status = ScanStatus.running
68+
69+
elif sqlmap_status == "terminated":
70+
task.status = ScanStatus.success
71+
task.finished_at = datetime.utcnow()
72+
73+
elif sqlmap_status == "error":
74+
task.status = ScanStatus.failed
75+
task.finished_at = datetime.utcnow()
76+
77+
# --- 写日志(全量 or 增量)---
78+
for item in log_data.get("log", []):
79+
session.add(
80+
SqlmapScanLog(
81+
task_id=task_id,
82+
level=item.get("level", "INFO"),
83+
message=item.get("message", ""),
84+
log_time=item.get("time"),
85+
)
86+
)
87+
88+
await session.commit()
89+
90+
asyncio.run(_run())

0 commit comments

Comments
 (0)