55from datetime import datetime , timedelta
66from json .decoder import JSONDecodeError
77import sys
8- from typing import TYPE_CHECKING , Any , Callable
8+ from typing import TYPE_CHECKING , Any , Callable , cast
99
1010from aiohttp import ClientSession
1111from aiohttp .client_exceptions import ClientResponseError
@@ -68,6 +68,7 @@ def __init__(
6868 ) -> None :
6969 """Initialize."""
7070 self ._refresh_token_callbacks : list [Callable [..., None ]] = []
71+ self ._request_retries = request_retries
7172 self .session : ClientSession = session
7273
7374 # These will get filled in after initial authentication:
@@ -79,16 +80,7 @@ def __init__(
7980 self .user_id : int | None = None
8081 self .websocket : WebsocketClient | None = None
8182
82- # Implement a version of the request coroutine, but with backoff/retry logic:
83- self .async_request = backoff .on_exception (
84- backoff .expo ,
85- ClientResponseError ,
86- jitter = backoff .random_jitter ,
87- logger = LOGGER ,
88- max_tries = request_retries ,
89- on_backoff = self ._async_handle_on_backoff ,
90- on_giveup = self ._handle_on_giveup ,
91- )(self ._async_request )
83+ self .async_request = self ._wrap_request_method (self ._request_retries )
9284
9385 @classmethod
9486 async def async_from_auth (
@@ -264,6 +256,29 @@ def _handle_on_giveup(_: dict[str, Any]) -> None:
264256 err = err_info [1 ].with_traceback (err_info [2 ]) # type: ignore
265257 raise RequestError (err ) from err
266258
259+ def _wrap_request_method (self , request_retries : int ) -> Callable :
260+ """Wrap the request method in backoff/retry logic."""
261+ return cast (
262+ Callable ,
263+ backoff .on_exception (
264+ backoff .expo ,
265+ ClientResponseError ,
266+ jitter = backoff .random_jitter ,
267+ logger = LOGGER ,
268+ max_tries = request_retries ,
269+ on_backoff = self ._async_handle_on_backoff ,
270+ on_giveup = self ._handle_on_giveup ,
271+ )(self ._async_request ),
272+ )
273+
274+ def disable_request_retries (self ) -> None :
275+ """Disable the request retry mechanism."""
276+ self .async_request = self ._wrap_request_method (1 )
277+
278+ def enable_request_retries (self ) -> None :
279+ """Enable the request retry mechanism."""
280+ self .async_request = self ._wrap_request_method (self ._request_retries )
281+
267282 def add_refresh_token_callback (
268283 self , callback : Callable [..., None ]
269284 ) -> Callable [..., None ]:
0 commit comments