|
1 | | -import asyncio |
2 | 1 | import requests |
3 | | -from celery import shared_task |
4 | 2 | from datetime import datetime |
5 | | -from sqlalchemy import select |
| 3 | +from sqlalchemy.exc import SQLAlchemyError |
6 | 4 |
|
7 | | -from app.database.database import AsyncSessionLocal |
| 5 | +from app.middleware.celery_app import celery_app |
| 6 | +from app.database.celery_sync_database import SessionLocal |
8 | 7 | from app.models.sqlmap_result import ( |
9 | 8 | SqlmapScanPayload, |
10 | 9 | SqlmapScanLog, |
11 | 10 | ScanStatus, |
| 11 | + SqlmapScanResult, |
12 | 12 | ) |
13 | 13 | import os |
14 | 14 |
|
15 | 15 | SQLMAP_API = os.getenv("SQLMAP_API") |
16 | | -AUTH = (os.getenv("SQLMAP_USERNAME"), os.getenv("SQLMAP_PASSWORD")) |
| 16 | +AUTH = (os.getenv("SQLMAP_USERNAME"), os.getenv("SQLMAP_PASSWORD")) # Basic Auth |
17 | 17 |
|
18 | 18 |
|
19 | | -def fetch_status(task_id: str) -> dict: |
20 | | - r = requests.get(f"{SQLMAP_API}/scan/{task_id}/status", auth=AUTH, timeout=10) |
21 | | - r.raise_for_status() |
22 | | - return r.json() |
23 | | - |
24 | | - |
25 | | -def fetch_log(task_id: str) -> dict: |
26 | | - r = requests.get(f"{SQLMAP_API}/scan/{task_id}/log", auth=AUTH, timeout=10) |
27 | | - r.raise_for_status() |
28 | | - return r.json() |
29 | | - |
30 | | - |
31 | | -@shared_task( |
| 19 | +@celery_app.task( |
32 | 20 | bind=True, |
33 | 21 | autoretry_for=(Exception,), |
34 | 22 | retry_backoff=5, |
35 | 23 | retry_kwargs={"max_retries": 3}, |
| 24 | + name="app.tasks.sqlmap_worker.poll_single_sqlmap_task", |
36 | 25 | ) |
37 | 26 | def poll_single_sqlmap_task(self, task_id: str): |
38 | | - """ |
39 | | - Worker:轮询单个 sqlmap 任务 |
40 | | - """ |
41 | | - |
42 | | - async def _run(): |
43 | | - async with AsyncSessionLocal() as session: |
44 | | - task = await session.scalar( |
45 | | - select(SqlmapScanPayload).where(SqlmapScanPayload.task_id == task_id) |
46 | | - ) |
47 | | - |
48 | | - if not task: |
49 | | - return |
50 | | - |
51 | | - # 已结束 → 不再轮询 |
52 | | - if task.status in ( |
53 | | - ScanStatus.success, |
54 | | - ScanStatus.failed, |
55 | | - ScanStatus.stopped, |
56 | | - ): |
57 | | - return |
58 | | - |
59 | | - # --- 调 sqlmap --- |
60 | | - status_data = fetch_status(task_id) |
61 | | - log_data = fetch_log(task_id) |
62 | | - |
63 | | - sqlmap_status = status_data.get("status") |
64 | | - |
65 | | - # --- 状态映射 --- |
66 | | - if sqlmap_status == "running": |
67 | | - task.status = ScanStatus.running |
68 | | - |
69 | | - elif sqlmap_status == "terminated": |
70 | | - task.status = ScanStatus.success |
71 | | - task.finished_at = datetime.utcnow() |
72 | | - |
73 | | - elif sqlmap_status == "error": |
74 | | - task.status = ScanStatus.failed |
75 | | - task.finished_at = datetime.utcnow() |
76 | | - |
77 | | - # --- 写日志(全量 or 增量)--- |
78 | | - for item in log_data.get("log", []): |
79 | | - session.add( |
80 | | - SqlmapScanLog( |
81 | | - task_id=task_id, |
82 | | - level=item.get("level", "INFO"), |
83 | | - message=item.get("message", ""), |
84 | | - log_time=item.get("time"), |
85 | | - ) |
86 | | - ) |
87 | | - |
88 | | - await session.commit() |
89 | | - |
90 | | - asyncio.run(_run()) |
| 27 | + session = SessionLocal() |
| 28 | + try: |
| 29 | + task = ( |
| 30 | + session.query(SqlmapScanPayload) |
| 31 | + .filter(SqlmapScanPayload.task_id == task_id) |
| 32 | + .first() |
| 33 | + ) |
| 34 | + |
| 35 | + if not task: |
| 36 | + return |
| 37 | + |
| 38 | + # 查询 sqlmap task 状态 |
| 39 | + resp = requests.get( |
| 40 | + f"{SQLMAP_API}/scan/{task_id}/status", |
| 41 | + timeout=10, |
| 42 | + auth=AUTH, |
| 43 | + ) |
| 44 | + resp.raise_for_status() |
| 45 | + status_data = resp.json() |
| 46 | + |
| 47 | + status = status_data.get("status") |
| 48 | + |
| 49 | + if status == "running": |
| 50 | + task.status = ScanStatus.running |
| 51 | + session.commit() |
| 52 | + return |
| 53 | + |
| 54 | + if status != "terminated": |
| 55 | + return |
| 56 | + |
| 57 | + # 获取扫描结果 |
| 58 | + result_resp = requests.get( |
| 59 | + f"{SQLMAP_API}/scan/{task_id}/data", |
| 60 | + timeout=30, |
| 61 | + auth=AUTH, |
| 62 | + ) |
| 63 | + result_resp.raise_for_status() |
| 64 | + data = result_resp.json() |
| 65 | + |
| 66 | + # 解析 sqlmap 返回 |
| 67 | + scan_result = SqlmapScanResult( |
| 68 | + target_url=task.scan_url, |
| 69 | + dbms=data.get("dbms"), |
| 70 | + vulnerable=bool(data.get("data")), |
| 71 | + injection_points=data.get("data"), |
| 72 | + dump_data=data.get("dump"), |
| 73 | + raw_output=data.get("raw"), |
| 74 | + command=data.get("command", ""), |
| 75 | + started_at=datetime.utcnow(), |
| 76 | + finished_at=datetime.utcnow(), |
| 77 | + ) |
| 78 | + |
| 79 | + session.add(scan_result) |
| 80 | + task.status = ScanStatus.success |
| 81 | + |
| 82 | + session.commit() |
| 83 | + |
| 84 | + except Exception: |
| 85 | + session.rollback() |
| 86 | + task.status = ScanStatus.failed |
| 87 | + session.commit() |
| 88 | + raise |
| 89 | + finally: |
| 90 | + session.close() |
0 commit comments