|
10 | 10 |
|
11 | 11 | import ghstack.github |
12 | 12 |
|
13 | | -MAX_RETRIES = 5 |
14 | | -INITIAL_BACKOFF_SECONDS = 60 |
15 | | - |
16 | 13 |
|
17 | 14 | class RealGitHubEndpoint(ghstack.github.GitHubEndpoint): |
18 | 15 | """ |
@@ -57,19 +54,30 @@ def rest_endpoint(self) -> str: |
57 | 54 | # Passed to requests as 'cert'. |
58 | 55 | cert: Optional[Union[str, Tuple[str, str]]] |
59 | 56 |
|
| 57 | + # The maximum number of times to retry a request before giving up. |
| 58 | + max_retries: int |
| 59 | + |
| 60 | + # The initial backoff time to use, in seconds. We will double this |
| 61 | + # time for each retry. |
| 62 | + initial_backoff_seconds: int |
| 63 | + |
60 | 64 | def __init__( |
61 | 65 | self, |
62 | 66 | oauth_token: Optional[str], |
63 | 67 | github_url: str, |
64 | 68 | proxy: Optional[str] = None, |
65 | 69 | verify: Optional[str] = None, |
66 | 70 | cert: Optional[Union[str, Tuple[str, str]]] = None, |
| 71 | + max_retries: int = 5, |
| 72 | + initial_backoff_seconds: int = 60, |
67 | 73 | ): |
68 | 74 | self.oauth_token = oauth_token |
69 | 75 | self.proxy = proxy |
70 | 76 | self.github_url = github_url |
71 | 77 | self.verify = verify |
72 | 78 | self.cert = cert |
| 79 | + self.max_retries = max_retries |
| 80 | + self.initial_backoff_seconds = initial_backoff_seconds |
73 | 81 |
|
74 | 82 | def push_hook(self, refName: Sequence[str]) -> None: |
75 | 83 | pass |
@@ -160,8 +168,8 @@ def rest(self, method: str, path: str, **kwargs: Any) -> Any: |
160 | 168 |
|
161 | 169 | url = self.rest_endpoint.format(github_url=self.github_url) + "/" + path |
162 | 170 |
|
163 | | - backoff_seconds = INITIAL_BACKOFF_SECONDS |
164 | | - for attempt in range(0, MAX_RETRIES): |
| 171 | + backoff_seconds = self.initial_backoff_seconds |
| 172 | + for attempt in range(0, self.max_retries): |
165 | 173 | logging.debug("# {} {}".format(method, url)) |
166 | 174 | logging.debug("Request body:\n{}".format(json.dumps(kwargs, indent=1))) |
167 | 175 |
|
@@ -189,29 +197,41 @@ def rest(self, method: str, path: str, **kwargs: Any) -> Any: |
189 | 197 | if resp.status_code in (403, 429): |
190 | 198 | remaining_count = resp.headers.get("x-ratelimit-remaining") |
191 | 199 | reset_time = resp.headers.get("x-ratelimit-reset") |
| 200 | + more_attempts = attempt < (self.max_retries - 1) |
192 | 201 |
|
193 | 202 | if remaining_count == "0" and reset_time: |
194 | | - sleep_time = int(reset_time) - int(time.time()) |
195 | | - logging.warning( |
196 | | - f"Rate limit exceeded. Sleeping until reset in {sleep_time} seconds." |
197 | | - ) |
198 | | - time.sleep(sleep_time) |
199 | | - continue |
| 203 | + sleep_time = max(0, int(reset_time) - int(time.time())) |
| 204 | + if more_attempts and sleep_time > 0: |
| 205 | + logging.warning( |
| 206 | + f"Rate limit exceeded. Sleeping until reset in {sleep_time} seconds." |
| 207 | + ) |
| 208 | + time.sleep(sleep_time) |
| 209 | + continue |
| 210 | + else: |
| 211 | + raise RuntimeError(pretty_json) |
200 | 212 | else: |
201 | 213 | retry_after_seconds = resp.headers.get("retry-after") |
202 | 214 | if retry_after_seconds: |
203 | 215 | sleep_time = int(retry_after_seconds) |
204 | | - logging.warning( |
205 | | - f"Secondary rate limit hit. Sleeping for {sleep_time} seconds." |
206 | | - ) |
| 216 | + if more_attempts and sleep_time > 0: |
| 217 | + logging.warning( |
| 218 | + f"Secondary rate limit hit. Sleeping for {sleep_time} seconds." |
| 219 | + ) |
| 220 | + time.sleep(sleep_time) |
| 221 | + continue |
| 222 | + else: |
| 223 | + raise RuntimeError(pretty_json) |
207 | 224 | else: |
208 | 225 | sleep_time = backoff_seconds |
209 | | - logging.warning( |
210 | | - f"Secondary rate limit hit. Sleeping for {sleep_time} seconds (exponential backoff)." |
211 | | - ) |
212 | | - backoff_seconds *= 2 |
213 | | - time.sleep(sleep_time) |
214 | | - continue |
| 226 | + if more_attempts and sleep_time > 0: |
| 227 | + logging.warning( |
| 228 | + f"Secondary rate limit hit. Sleeping for {sleep_time} seconds (exponential backoff)." |
| 229 | + ) |
| 230 | + backoff_seconds *= 2 |
| 231 | + time.sleep(sleep_time) |
| 232 | + continue |
| 233 | + else: |
| 234 | + raise RuntimeError(pretty_json) |
215 | 235 |
|
216 | 236 | if resp.status_code == 404: |
217 | 237 | raise ghstack.github.NotFoundError( |
|
0 commit comments