Skip to content

Commit 07edb32

Browse files
authored
Expand infinite loop fix to have hardcoded guard rail for max requests (#2250)
* Expand infinite loop fix to have hardcoded guard rail * Add simple tests * Fix docstring typo
1 parent 98cb352 commit 07edb32

2 files changed

Lines changed: 78 additions & 4 deletions

File tree

optimade/client/client.py

Lines changed: 47 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -90,6 +90,12 @@ class OptimadeClient:
9090
max_attempts: int
9191
"""The maximum number of times to repeat a failed query before giving up."""
9292

93+
max_requests_per_provider: int = 2_000_000
94+
"""An upper limit guard rail to avoid infinite hanging of the client on malformed APIs.
95+
If available, a better value will be estimated for each API based on the total number of entries.
96+
97+
"""
98+
9399
use_async: bool
94100
"""Whether or not to make all requests asynchronously using asyncio."""
95101

@@ -957,9 +963,13 @@ async def _get_one_async(
957963
request_delay: float | None = None
958964

959965
results = QueryResults()
966+
number_of_requests = 0
967+
total_data_available: int | None = None
960968
try:
961969
async with self._http_client(headers=self.headers) as client: # type: ignore[union-attr,call-arg,misc]
962970
while next_url:
971+
number_of_requests += 1
972+
963973
attempts = 0
964974
try:
965975
if self.verbosity:
@@ -975,6 +985,21 @@ async def _get_one_async(
975985
if request_delay:
976986
request_delay = min(request_delay, 5)
977987

988+
# Compute the upper limit guard rail on pagination requests based on the number of entries in the entire db
989+
# and the chosen page limit
990+
if total_data_available is None:
991+
total_data_available = page_results["meta"].get(
992+
"data_available", 0
993+
)
994+
page_limit = len(page_results["data"])
995+
if total_data_available and total_data_available > 0:
996+
stopping_criteria = min(
997+
math.ceil(total_data_available / page_limit),
998+
self.max_requests_per_provider,
999+
)
1000+
else:
1001+
stopping_criteria = self.max_requests_per_provider
1002+
9781003
except RecoverableHTTPError:
9791004
attempts += 1
9801005
if attempts > self.max_attempts:
@@ -989,9 +1014,9 @@ async def _get_one_async(
9891014
if not paginate:
9901015
break
9911016

992-
if len(results.data) == 0:
1017+
if len(results.data) == 0 or number_of_requests > stopping_criteria:
9931018
if next_url:
994-
message = f"{base_url} unexpectedly stopped returning results. Stopping download."
1019+
message = f"Detected potential infinite loop for {base_url} (more than {stopping_criteria=} requests made). Stopping download."
9951020
results.errors.append(message)
9961021
if not self.silent:
9971022
self._progress.print(message)
@@ -1041,6 +1066,8 @@ def _get_one(
10411066
request_delay: float | None = None
10421067

10431068
results = QueryResults()
1069+
number_of_requests: int = 0
1070+
total_data_available: int | None = None
10441071
try:
10451072
with self._http_client() as client: # type: ignore[misc]
10461073
client.headers.update(self.headers)
@@ -1050,6 +1077,7 @@ def _get_one(
10501077
timeout = (self.http_timeout.connect, self.http_timeout.read)
10511078

10521079
while next_url:
1080+
number_of_requests += 1
10531081
attempts = 0
10541082
try:
10551083
if self.verbosity:
@@ -1059,6 +1087,21 @@ def _get_one(
10591087
r = client.get(next_url, timeout=timeout)
10601088
page_results, next_url = self._handle_response(r, _task)
10611089

1090+
# Compute the upper limit guard rail on pagination requests based on the number of entries in the entire db
1091+
# and the chosen page limit
1092+
if total_data_available is None:
1093+
total_data_available = page_results["meta"].get(
1094+
"data_available", 0
1095+
)
1096+
page_limit = len(page_results["data"])
1097+
if total_data_available and total_data_available > 0:
1098+
stopping_criteria = min(
1099+
math.ceil(total_data_available / page_limit),
1100+
self.max_requests_per_provider,
1101+
)
1102+
else:
1103+
stopping_criteria = self.max_requests_per_provider
1104+
10621105
request_delay = page_results["meta"].get("request_delay", None)
10631106
# Don't wait any longer than 5 seconds
10641107
if request_delay:
@@ -1075,9 +1118,9 @@ def _get_one(
10751118

10761119
results.update(page_results)
10771120

1078-
if len(results.data) == 0:
1121+
if len(results.data) == 0 or number_of_requests > stopping_criteria:
10791122
if next_url:
1080-
message = f"{base_url} unexpectedly stopped returning results. Stopping download."
1123+
message = f"Detected potential infinite loop for {base_url} (more than {stopping_criteria=} requests made). Stopping download."
10811124
results.errors.append(message)
10821125
if not self.silent:
10831126
self._progress.print(message)

tests/server/test_client.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -599,3 +599,34 @@ async def test_raw_get_one_async(async_http_client):
599599
paginate=False,
600600
)
601601
assert len(override[TEST_URL].data) == 1
602+
603+
604+
@pytest.mark.asyncio
605+
async def test_guardrail_async(async_http_client):
606+
"""Test the upper limit on requests guard rail."""
607+
cli = OptimadeClient(
608+
base_urls=[TEST_URL],
609+
http_client=async_http_client,
610+
use_async=True,
611+
)
612+
cli.max_requests_per_provider = 1
613+
result = await cli.get_one_async(
614+
endpoint="structures", base_url=TEST_URL, filter="", page_limit=5, paginate=True
615+
)
616+
assert len(result[TEST_URL].errors) == 1
617+
assert "infinite" in result[TEST_URL].errors[0]
618+
619+
620+
def test_guardrail_sync(http_client):
621+
"""Test the upper limit on requests guard rail."""
622+
cli = OptimadeClient(
623+
base_urls=[TEST_URL],
624+
http_client=http_client,
625+
use_async=False,
626+
)
627+
cli.max_requests_per_provider = 1
628+
result = cli.get_one(
629+
endpoint="structures", base_url=TEST_URL, filter="", page_limit=5, paginate=True
630+
)
631+
assert len(result[TEST_URL].errors) == 1
632+
assert "infinite" in result[TEST_URL].errors[0]

0 commit comments

Comments
 (0)