|
14 | 14 | from ..logging_config import get_logger |
15 | 15 | from ..path_fix import PathFixRoute |
16 | 16 | from ..key_store import upsert_account_keys_in_store |
17 | | -from ..wechat_decrypt import WeChatDatabaseDecryptor, decrypt_wechat_databases |
| 17 | +from ..wechat_decrypt import WeChatDatabaseDecryptor, decrypt_wechat_databases, scan_account_databases_from_path |
18 | 18 |
|
19 | 19 | logger = get_logger(__name__) |
20 | 20 |
|
@@ -79,6 +79,8 @@ async def decrypt_databases(request: DecryptRequest): |
79 | 79 | "account_results": results.get("account_results", {}), |
80 | 80 | } |
81 | 81 |
|
| 82 | + except HTTPException: |
| 83 | + raise |
82 | 84 | except Exception as e: |
83 | 85 | logger.error(f"解密API异常: {str(e)}") |
84 | 86 | raise HTTPException(status_code=500, detail=str(e)) |
@@ -126,44 +128,17 @@ async def generate_progress(): |
126 | 128 | yield _sse({"type": "scanning", "message": "正在扫描数据库文件..."}) |
127 | 129 | await asyncio.sleep(0) |
128 | 130 |
|
129 | | - account_name = "unknown_account" |
130 | | - path_parts = storage_path.parts |
131 | | - account_patterns = ["wxid_"] |
132 | | - for part in path_parts: |
133 | | - for pattern in account_patterns: |
134 | | - if part.startswith(pattern): |
135 | | - parts = part.split("_") |
136 | | - if len(parts) >= 3: |
137 | | - account_name = "_".join(parts[:-1]) |
138 | | - else: |
139 | | - account_name = part |
140 | | - break |
141 | | - if account_name != "unknown_account": |
142 | | - break |
143 | | - |
144 | | - if account_name == "unknown_account": |
145 | | - for part in reversed(path_parts): |
146 | | - if part != "db_storage" and len(part) > 3: |
147 | | - account_name = part |
148 | | - break |
149 | | - |
150 | | - databases: list[dict] = [] |
151 | | - for root, _dirs, files in os.walk(storage_path): |
152 | | - if "db_storage" not in str(root): |
153 | | - continue |
154 | | - for file_name in files: |
155 | | - if not file_name.endswith(".db"): |
156 | | - continue |
157 | | - if file_name in ["key_info.db"]: |
158 | | - continue |
159 | | - db_path = os.path.join(root, file_name) |
160 | | - databases.append({"path": db_path, "name": file_name, "account": account_name}) |
161 | | - |
162 | | - if not databases: |
163 | | - yield _sse({"type": "error", "message": "未找到微信数据库文件!请检查 db_storage_path 是否正确"}) |
| 131 | + scan_result = scan_account_databases_from_path(p) |
| 132 | + if scan_result["status"] == "error": |
| 133 | + payload = {"type": "error", "message": scan_result["message"]} |
| 134 | + detected_accounts = scan_result.get("detected_accounts") or [] |
| 135 | + if detected_accounts: |
| 136 | + payload["detected_accounts"] = detected_accounts |
| 137 | + yield _sse(payload) |
164 | 138 | return |
165 | 139 |
|
166 | | - account_databases = {account_name: databases} |
| 140 | + account_databases = scan_result.get("account_databases", {}) |
| 141 | + account_sources = scan_result.get("account_sources", {}) |
167 | 142 | total_databases = sum(len(dbs) for dbs in account_databases.values()) |
168 | 143 |
|
169 | 144 | yield _sse({"type": "start", "total": total_databases, "message": f"开始解密 {total_databases} 个数据库"}) |
@@ -193,12 +168,9 @@ async def generate_progress(): |
193 | 168 |
|
194 | 169 | # Save a hint for later UI (same as non-stream endpoint). |
195 | 170 | try: |
196 | | - source_db_storage_path = p |
197 | | - wxid_dir = "" |
198 | | - if storage_path.name.lower() == "db_storage": |
199 | | - wxid_dir = str(storage_path.parent) |
200 | | - else: |
201 | | - wxid_dir = str(storage_path) |
| 171 | + source_info = account_sources.get(account, {}) |
| 172 | + source_db_storage_path = str(source_info.get("db_storage_path") or p) |
| 173 | + wxid_dir = str(source_info.get("wxid_dir") or "") |
202 | 174 | (account_output_dir / "_source.json").write_text( |
203 | 175 | json.dumps({"db_storage_path": source_db_storage_path, "wxid_dir": wxid_dir}, ensure_ascii=False, indent=2), |
204 | 176 | encoding="utf-8", |
|
0 commit comments