Skip to content

Commit 7dffad2

Browse files
committed
refactor: add strict type annotations to medium-risk modules
1 parent 6170d95 commit 7dffad2

9 files changed

Lines changed: 99 additions & 59 deletions

File tree

launcher/checks.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,14 @@
11
import logging
22
import os
33
import sys
4+
from typing import Dict, List, Optional
45

56
from launcher.config import ACTIVE_AUTH_DIR, SAVED_AUTH_DIR
67

78
logger = logging.getLogger("CamoufoxLauncher")
89

910

10-
def ensure_auth_dirs_exist():
11+
def ensure_auth_dirs_exist() -> None:
1112
logger.info("正在检查并确保认证文件目录存在...")
1213
try:
1314
os.makedirs(ACTIVE_AUTH_DIR, exist_ok=True)
@@ -19,17 +20,19 @@ def ensure_auth_dirs_exist():
1920
sys.exit(1)
2021

2122

22-
def check_dependencies(launch_server, DefaultAddons):
23+
def check_dependencies(
24+
launch_server: Optional[bool], DefaultAddons: Optional[bool]
25+
) -> None:
2326
logger.info("--- 步骤 1: 检查依赖项 ---")
24-
required_modules = {}
27+
required_modules: Dict[str, str] = {}
2528
if launch_server is not None and DefaultAddons is not None:
2629
required_modules["camoufox"] = "camoufox (for server and addons)"
2730
elif launch_server is not None:
2831
required_modules["camoufox_server"] = "camoufox.server"
2932
logger.warning(
3033
" 'camoufox.server' 已导入,但 'camoufox.DefaultAddons' 未导入。排除插件功能可能受限。"
3134
)
32-
missing_py_modules = []
35+
missing_py_modules: List[str] = []
3336
dependencies_ok = True
3437
if required_modules:
3538
logger.info("正在检查 Python 模块:")

launcher/config.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
import os
33
import sys
44
from dataclasses import dataclass
5-
from typing import Optional
5+
from typing import Dict, Optional
66

77
from launcher.utils import get_proxy_from_gsettings
88

@@ -64,7 +64,9 @@ class LauncherConfig:
6464
internal_camoufox_os: str = "random"
6565

6666

67-
def determine_proxy_configuration(internal_camoufox_proxy_arg=None):
67+
def determine_proxy_configuration(
68+
internal_camoufox_proxy_arg: Optional[str] = None,
69+
) -> Dict[str, Optional[str]]:
6870
"""
6971
统一的代理配置确定函数
7072
按优先级顺序:命令行参数 > 环境变量 > 系统设置

launcher/utils.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
import subprocess
66
import sys
77
import threading
8+
from typing import List, Optional
89

910
logger = logging.getLogger("CamoufoxLauncher")
1011

@@ -22,8 +23,8 @@ def is_port_in_use(port: int, host: str = "0.0.0.0") -> bool:
2223
return True
2324

2425

25-
def find_pids_on_port(port: int) -> list[int]:
26-
pids = []
26+
def find_pids_on_port(port: int) -> List[int]:
27+
pids: List[int] = []
2728
system_platform = platform.system()
2829
command = ""
2930
try:
@@ -166,7 +167,7 @@ def kill_process_interactive(pid: int) -> bool:
166167
def input_with_timeout(prompt_message: str, timeout_seconds: int = 30) -> str:
167168
print(prompt_message, end="", flush=True)
168169
if sys.platform == "win32":
169-
user_input_container = [None]
170+
user_input_container: List[Optional[str]] = [None]
170171

171172
def get_input_in_thread():
172173
try:
@@ -190,13 +191,13 @@ def get_input_in_thread():
190191
return ""
191192

192193

193-
def get_proxy_from_gsettings():
194+
def get_proxy_from_gsettings() -> Optional[str]:
194195
"""
195196
Retrieves the proxy settings from GSettings on Linux systems.
196197
Returns a proxy string like "http://host:port" or None.
197198
"""
198199

199-
def _run_gsettings_command(command_parts: list[str]) -> str | None:
200+
def _run_gsettings_command(command_parts: List[str]) -> Optional[str]:
200201
"""Helper function to run gsettings command and return cleaned string output."""
201202
try:
202203
process_result = subprocess.run(

stream/cert_manager.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import datetime
22
from pathlib import Path
3+
from typing import Any, Tuple
34

45
from cryptography import x509
56
from cryptography.hazmat.backends import default_backend
@@ -9,7 +10,7 @@
910

1011

1112
class CertificateManager:
12-
def __init__(self, cert_dir="certs"):
13+
def __init__(self, cert_dir: str = "certs"):
1314
self.cert_dir = Path(cert_dir)
1415
self.cert_dir.mkdir(exist_ok=True)
1516

@@ -92,7 +93,7 @@ def _load_ca_cert(self):
9293
with open(self.ca_cert_path, "rb") as f:
9394
self.ca_cert = x509.load_pem_x509_certificate(f.read(), default_backend())
9495

95-
def get_domain_cert(self, domain):
96+
def get_domain_cert(self, domain: str) -> Tuple[Any, Any]:
9697
"""Get or generate a certificate for the specified domain"""
9798
cert_path = self.cert_dir / f"{domain}.crt"
9899
key_path = self.cert_dir / f"{domain}.key"
@@ -112,7 +113,7 @@ def get_domain_cert(self, domain):
112113
# Generate new certificate
113114
return self._generate_domain_cert(domain)
114115

115-
def _generate_domain_cert(self, domain):
116+
def _generate_domain_cert(self, domain: str) -> Tuple[Any, Any]:
116117
"""Generate a certificate for the specified domain signed by the CA"""
117118
# Generate private key
118119
private_key = rsa.generate_private_key(
@@ -152,7 +153,7 @@ def _generate_domain_cert(self, domain):
152153
.add_extension(
153154
x509.SubjectAlternativeName([x509.DNSName(domain)]), critical=False
154155
)
155-
.sign(self.ca_key, hashes.SHA256(), default_backend())
156+
.sign(self.ca_key, hashes.SHA256(), default_backend()) # type: ignore[arg-type]
156157
)
157158

158159
# Write certificate to file

stream/interceptors.py

Lines changed: 19 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import re
44
import sys
55
import zlib
6+
from typing import Any, Dict, Tuple, Union
67

78
from logging_utils.setup import ColoredFormatter
89

@@ -12,7 +13,7 @@ class HttpInterceptor:
1213
Class to intercept and process HTTP requests and responses
1314
"""
1415

15-
def __init__(self, log_dir="logs"):
16+
def __init__(self, log_dir: str = "logs"):
1617
self.log_dir = log_dir
1718
self.logger = logging.getLogger("http_interceptor")
1819
self.setup_logging()
@@ -36,7 +37,7 @@ def setup_logging():
3637
logging.getLogger("websockets").setLevel(logging.ERROR)
3738

3839
@staticmethod
39-
def should_intercept(host, path):
40+
def should_intercept(host: str, path: str):
4041
"""
4142
Determine if the request should be intercepted based on host and path
4243
"""
@@ -47,7 +48,9 @@ def should_intercept(host, path):
4748
# Add more conditions as needed
4849
return False
4950

50-
async def process_request(self, request_data, host, path):
51+
async def process_request(
52+
self, request_data: Union[int, bytes], host: str, path: str
53+
) -> Union[int, bytes]:
5154
"""
5255
Process the request data before sending to the server
5356
"""
@@ -63,7 +66,13 @@ async def process_request(self, request_data, host, path):
6366
# Not JSON or not UTF-8, just pass through
6467
return request_data
6568

66-
async def process_response(self, response_data, host, path, headers):
69+
async def process_response(
70+
self,
71+
response_data: Union[int, bytes],
72+
host: str,
73+
path: str,
74+
headers: Dict[Any, Any],
75+
) -> Dict[str, Any]:
6776
"""
6877
Process the response data before sending to the client
6978
"""
@@ -78,7 +87,7 @@ async def process_response(self, response_data, host, path, headers):
7887
except Exception as e:
7988
raise e
8089

81-
def parse_response(self, response_data):
90+
def parse_response(self, response_data: bytes) -> Dict[str, Any]:
8291
pattern = rb'\[\[\[null,.*?]],"model"]'
8392
matches = []
8493
for match_obj in re.finditer(pattern, response_data):
@@ -114,7 +123,7 @@ def parse_response(self, response_data):
114123

115124
return resp
116125

117-
def parse_toolcall_params(self, args):
126+
def parse_toolcall_params(self, args: Any) -> Dict[str, Any]:
118127
try:
119128
params = args[0]
120129
func_params = {}
@@ -140,13 +149,13 @@ def parse_toolcall_params(self, args):
140149
raise e
141150

142151
@staticmethod
143-
def _decompress_zlib_stream(compressed_stream):
152+
def _decompress_zlib_stream(compressed_stream: Union[bytearray, bytes]) -> bytes:
144153
decompressor = zlib.decompressobj(wbits=zlib.MAX_WBITS | 32) # zlib header
145154
decompressed = decompressor.decompress(compressed_stream)
146155
return decompressed
147156

148157
@staticmethod
149-
def _decode_chunked(response_body: bytes) -> tuple[bytes, bool]:
158+
def _decode_chunked(response_body: bytes) -> Tuple[bytes, bool]:
150159
chunked_data = bytearray()
151160
while True:
152161
# print(' '.join(format(x, '02x') for x in response_body))
@@ -165,7 +174,7 @@ def _decode_chunked(response_body: bytes) -> tuple[bytes, bool]:
165174
if length == 0:
166175
length_crlf_idx = response_body.find(b"0\r\n\r\n")
167176
if length_crlf_idx != -1:
168-
return chunked_data, True
177+
return bytes(chunked_data), True
169178

170179
if length + 2 > len(response_body):
171180
break
@@ -177,4 +186,4 @@ def _decode_chunked(response_body: bytes) -> tuple[bytes, bool]:
177186
break
178187

179188
response_body = response_body[length_crlf_idx + 2 + length + 2 :]
180-
return chunked_data, False
189+
return bytes(chunked_data), False

stream/main.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,15 @@
11
import argparse
22
import asyncio
33
import logging
4-
import multiprocessing
54
import sys
65
from pathlib import Path
6+
from typing import Any, Optional
77

88
from logging_utils import GridFormatter, set_source
99
from stream.proxy_server import ProxyServer
1010

1111

12-
def parse_args():
12+
def parse_args() -> argparse.Namespace:
1313
"""Parse command line arguments"""
1414
parser = argparse.ArgumentParser(
1515
description="HTTPS Proxy Server with SSL Inspection"
@@ -34,7 +34,7 @@ def parse_args():
3434
return parser.parse_args()
3535

3636

37-
async def main():
37+
async def main() -> None:
3838
"""Main entry point"""
3939
args = parse_args()
4040

@@ -83,7 +83,9 @@ async def main():
8383
sys.exit(1)
8484

8585

86-
async def builtin(queue: multiprocessing.Queue = None, port=None, proxy=None):
86+
async def builtin(
87+
queue: Optional[Any] = None, port: Optional[int] = None, proxy: Optional[str] = None
88+
) -> None:
8789
# Set up logging with GridFormatter for consistent output
8890
set_source("PROXY")
8991

stream/proxy_connector.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@
22
import ssl as ssl_module
33
import urllib.parse
44

5+
from typing import Any, Optional, Tuple
6+
57
from aiohttp import TCPConnector
68
from python_socks.async_.asyncio import Proxy
79

@@ -11,7 +13,7 @@ class ProxyConnector:
1113
Class to handle connections through different types of proxies
1214
"""
1315

14-
def __init__(self, proxy_url=None):
16+
def __init__(self, proxy_url: Optional[str] = None):
1517
self.proxy_url = proxy_url
1618
self.connector = None
1719

@@ -33,15 +35,18 @@ def _setup_connector(self):
3335
else:
3436
raise ValueError(f"Unsupported proxy type: {proxy_type}")
3537

36-
async def create_connection(self, host, port, ssl=None):
38+
async def create_connection(
39+
self, host: str, port: int, ssl: Optional[Any] = None
40+
) -> Tuple[Any, Any]:
3741
"""Create a connection to the target host through the proxy"""
3842
if not self.connector:
3943
# Direct connection without proxy
4044
reader, writer = await asyncio.open_connection(host, port, ssl=ssl)
4145
return reader, writer
4246

4347
# SOCKS proxy connection
44-
proxy = Proxy.from_url(self.proxy_url)
48+
assert self.proxy_url is not None # Type guard for pyright
49+
proxy = Proxy.from_url(self.proxy_url) # type: ignore[misc]
4550
sock = await proxy.connect(dest_host=host, dest_port=port)
4651
if ssl is None:
4752
reader, writer = await asyncio.open_connection(

0 commit comments

Comments
 (0)