33import re
44import sys
55import zlib
6+ from typing import Any , Dict , Tuple , Union
67
78from logging_utils .setup import ColoredFormatter
89
@@ -12,7 +13,7 @@ class HttpInterceptor:
1213 Class to intercept and process HTTP requests and responses
1314 """
1415
15- def __init__ (self , log_dir = "logs" ):
16+ def __init__ (self , log_dir : str = "logs" ):
1617 self .log_dir = log_dir
1718 self .logger = logging .getLogger ("http_interceptor" )
1819 self .setup_logging ()
@@ -36,7 +37,7 @@ def setup_logging():
3637 logging .getLogger ("websockets" ).setLevel (logging .ERROR )
3738
3839 @staticmethod
39- def should_intercept (host , path ):
40+ def should_intercept (host : str , path : str ):
4041 """
4142 Determine if the request should be intercepted based on host and path
4243 """
@@ -47,7 +48,9 @@ def should_intercept(host, path):
4748 # Add more conditions as needed
4849 return False
4950
50- async def process_request (self , request_data , host , path ):
51+ async def process_request (
52+ self , request_data : Union [int , bytes ], host : str , path : str
53+ ) -> Union [int , bytes ]:
5154 """
5255 Process the request data before sending to the server
5356 """
@@ -63,7 +66,13 @@ async def process_request(self, request_data, host, path):
6366 # Not JSON or not UTF-8, just pass through
6467 return request_data
6568
66- async def process_response (self , response_data , host , path , headers ):
69+ async def process_response (
70+ self ,
71+ response_data : Union [int , bytes ],
72+ host : str ,
73+ path : str ,
74+ headers : Dict [Any , Any ],
75+ ) -> Dict [str , Any ]:
6776 """
6877 Process the response data before sending to the client
6978 """
@@ -78,7 +87,7 @@ async def process_response(self, response_data, host, path, headers):
7887 except Exception as e :
7988 raise e
8089
81- def parse_response (self , response_data ) :
90+ def parse_response (self , response_data : bytes ) -> Dict [ str , Any ] :
8291 pattern = rb'\[\[\[null,.*?]],"model"]'
8392 matches = []
8493 for match_obj in re .finditer (pattern , response_data ):
@@ -114,7 +123,7 @@ def parse_response(self, response_data):
114123
115124 return resp
116125
117- def parse_toolcall_params (self , args ) :
126+ def parse_toolcall_params (self , args : Any ) -> Dict [ str , Any ] :
118127 try :
119128 params = args [0 ]
120129 func_params = {}
@@ -140,13 +149,13 @@ def parse_toolcall_params(self, args):
140149 raise e
141150
142151 @staticmethod
143- def _decompress_zlib_stream (compressed_stream ) :
152+ def _decompress_zlib_stream (compressed_stream : Union [ bytearray , bytes ]) -> bytes :
144153 decompressor = zlib .decompressobj (wbits = zlib .MAX_WBITS | 32 ) # zlib header
145154 decompressed = decompressor .decompress (compressed_stream )
146155 return decompressed
147156
148157 @staticmethod
149- def _decode_chunked (response_body : bytes ) -> tuple [bytes , bool ]:
158+ def _decode_chunked (response_body : bytes ) -> Tuple [bytes , bool ]:
150159 chunked_data = bytearray ()
151160 while True :
152161 # print(' '.join(format(x, '02x') for x in response_body))
@@ -165,7 +174,7 @@ def _decode_chunked(response_body: bytes) -> tuple[bytes, bool]:
165174 if length == 0 :
166175 length_crlf_idx = response_body .find (b"0\r \n \r \n " )
167176 if length_crlf_idx != - 1 :
168- return chunked_data , True
177+ return bytes ( chunked_data ) , True
169178
170179 if length + 2 > len (response_body ):
171180 break
@@ -177,4 +186,4 @@ def _decode_chunked(response_body: bytes) -> tuple[bytes, bool]:
177186 break
178187
179188 response_body = response_body [length_crlf_idx + 2 + length + 2 :]
180- return chunked_data , False
189+ return bytes ( chunked_data ) , False
0 commit comments