Skip to content

Commit

Permalink
Retry in case of 429 in v2
Browse files Browse the repository at this point in the history
  • Loading branch information
ekouts committed Jan 20, 2025
1 parent a1d7c16 commit bfc0ee8
Show file tree
Hide file tree
Showing 4 changed files with 94 additions and 13 deletions.
5 changes: 4 additions & 1 deletion firecrest/utilities.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,10 @@ def slurm_state_completed(state):
'TIMEOUT',
}
if state:
return all(s in completion_states for s in state.split(','))
# Make sure all the steps include one of the completion states
return all(
any(cs in s for cs in completion_states) for s in state.split(',')
)

return False

Expand Down
4 changes: 2 additions & 2 deletions firecrest/v2/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,5 +5,5 @@
# SPDX-License-Identifier: BSD-3-Clause
#

from firecrest.v2._async.Client import AsyncFirecrest
from firecrest.v2._sync.Client import Firecrest
from firecrest.v2._async.Client import AsyncFirecrest # noqa
from firecrest.v2._sync.Client import Firecrest # noqa
49 changes: 44 additions & 5 deletions firecrest/v2/_async/Client.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from typing import Any, Optional, List

from firecrest.utilities import (
slurm_state_completed, time_block
parse_retry_after, slurm_state_completed, time_block
)
from firecrest.FirecrestException import (
FirecrestException,
Expand Down Expand Up @@ -71,6 +71,45 @@ class AsyncFirecrest:

TOO_MANY_REQUESTS_CODE = 429

def _retry_requests(func):
async def wrapper(*args, **kwargs):
client = args[0]
num_retries = 0
resp = await func(*args, **kwargs)
while True:
if resp.status_code != client.TOO_MANY_REQUESTS_CODE:
break
elif (
client.num_retries_rate_limit is not None
and num_retries >= client.num_retries_rate_limit
):
client.log(
logging.DEBUG,
f"Rate limit is reached and the request has "
f"been retried already {num_retries} times"
)
break
else:
reset = resp.headers.get(
"Retry-After",
default=resp.headers.get(
"RateLimit-Reset", default=10
),
)
reset = parse_retry_after(reset, client.log)
client.log(
logging.INFO,
f"Rate limit is reached, will sleep for "
f"{reset} seconds and try again"
)
await asyncio.sleep(reset)
resp = await func(*args, **kwargs)
num_retries += 1

return resp

return wrapper

def __init__(
self,
firecrest_url: str,
Expand Down Expand Up @@ -129,7 +168,7 @@ def log(self, level: int, msg: Any) -> None:
if not self.disable_client_logging:
logger.log(level, msg)

# @_retry_requests # type: ignore
@_retry_requests # type: ignore
async def _get_request(
self,
endpoint,
Expand All @@ -151,7 +190,7 @@ async def _get_request(

return resp

# @_retry_requests # type: ignore
@_retry_requests # type: ignore
async def _post_request(
self, endpoint, additional_headers=None, params=None, data=None, files=None
) -> httpx.Response:
Expand All @@ -175,7 +214,7 @@ async def _post_request(

return resp

# @_retry_requests # type: ignore
@_retry_requests # type: ignore
async def _put_request(
self, endpoint, additional_headers=None, data=None
) -> httpx.Response:
Expand All @@ -194,7 +233,7 @@ async def _put_request(

return resp

# @_retry_requests # type: ignore
@_retry_requests # type: ignore
async def _delete_request(
self, endpoint, additional_headers=None, params=None, data=None
) -> httpx.Response:
Expand Down
49 changes: 44 additions & 5 deletions firecrest/v2/_sync/Client.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
from typing import Any, Optional, List

from firecrest.utilities import (
slurm_state_completed, time_block
parse_retry_after, slurm_state_completed, time_block
)
from firecrest.FirecrestException import (
FirecrestException,
Expand Down Expand Up @@ -70,6 +70,45 @@ class Firecrest:

TOO_MANY_REQUESTS_CODE = 429

def _retry_requests(func):
def wrapper(*args, **kwargs):
client = args[0]
num_retries = 0
resp = func(*args, **kwargs)
while True:
if resp.status_code != client.TOO_MANY_REQUESTS_CODE:
break
elif (
client.num_retries_rate_limit is not None
and num_retries >= client.num_retries_rate_limit
):
client.log(
logging.DEBUG,
f"Rate limit is reached and the request has "
f"been retried already {num_retries} times"
)
break
else:
reset = resp.headers.get(
"Retry-After",
default=resp.headers.get(
"RateLimit-Reset", default=10
),
)
reset = parse_retry_after(reset, client.log)
client.log(
logging.INFO,
f"Rate limit is reached, will sleep for "
f"{reset} seconds and try again"
)
time.sleep(reset)
resp = func(*args, **kwargs)
num_retries += 1

return resp

return wrapper

def __init__(
self,
firecrest_url: str,
Expand Down Expand Up @@ -128,7 +167,7 @@ def log(self, level: int, msg: Any) -> None:
if not self.disable_client_logging:
logger.log(level, msg)

# @_retry_requests # type: ignore
@_retry_requests # type: ignore
def _get_request(
self,
endpoint,
Expand All @@ -150,7 +189,7 @@ def _get_request(

return resp

# @_retry_requests # type: ignore
@_retry_requests # type: ignore
def _post_request(
self, endpoint, additional_headers=None, params=None, data=None, files=None
) -> httpx.Response:
Expand All @@ -174,7 +213,7 @@ def _post_request(

return resp

# @_retry_requests # type: ignore
@_retry_requests # type: ignore
def _put_request(
self, endpoint, additional_headers=None, data=None
) -> httpx.Response:
Expand All @@ -193,7 +232,7 @@ def _put_request(

return resp

# @_retry_requests # type: ignore
@_retry_requests # type: ignore
def _delete_request(
self, endpoint, additional_headers=None, params=None, data=None
) -> httpx.Response:
Expand Down

0 comments on commit bfc0ee8

Please sign in to comment.