Skip to content

Commit bc702c1

Browse files
Merge pull request #104 from biosimulations/batch-processing
Batch processing
2 parents 4fd072c + 93ee79d commit bc702c1

10 files changed

Lines changed: 405 additions & 23 deletions

File tree

Lines changed: 170 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,170 @@
1+
from http import HTTPStatus
2+
from typing import Any, Optional, Union, cast
3+
4+
import httpx
5+
6+
from ...client import AuthenticatedClient, Client
7+
from ...types import Response, UNSET
8+
from ... import errors
9+
10+
from ...models.hpc_run import HpcRun
11+
from ...models.http_validation_error import HTTPValidationError
12+
from typing import cast
13+
14+
15+
def _get_kwargs(
16+
*,
17+
body: list[int],
18+
) -> dict[str, Any]:
19+
headers: dict[str, Any] = {}
20+
21+
_kwargs: dict[str, Any] = {
22+
"method": "get",
23+
"url": "/results/simulations/status/batch",
24+
}
25+
26+
_kwargs["json"] = body
27+
28+
headers["Content-Type"] = "application/json"
29+
30+
_kwargs["headers"] = headers
31+
return _kwargs
32+
33+
34+
def _parse_response(
35+
*, client: Union[AuthenticatedClient, Client], response: httpx.Response
36+
) -> Optional[Union[HTTPValidationError, list["HpcRun"]]]:
37+
if response.status_code == 200:
38+
response_200 = []
39+
_response_200 = response.json()
40+
for response_200_item_data in _response_200:
41+
response_200_item = HpcRun.from_dict(response_200_item_data)
42+
43+
response_200.append(response_200_item)
44+
45+
return response_200
46+
if response.status_code == 422:
47+
response_422 = HTTPValidationError.from_dict(response.json())
48+
49+
return response_422
50+
if client.raise_on_unexpected_status:
51+
raise errors.UnexpectedStatus(response.status_code, response.content)
52+
else:
53+
return None
54+
55+
56+
def _build_response(
57+
*, client: Union[AuthenticatedClient, Client], response: httpx.Response
58+
) -> Response[Union[HTTPValidationError, list["HpcRun"]]]:
59+
return Response(
60+
status_code=HTTPStatus(response.status_code),
61+
content=response.content,
62+
headers=response.headers,
63+
parsed=_parse_response(client=client, response=response),
64+
)
65+
66+
67+
def sync_detailed(
68+
*,
69+
client: Union[AuthenticatedClient, Client],
70+
body: list[int],
71+
) -> Response[Union[HTTPValidationError, list["HpcRun"]]]:
72+
"""Get simulation status records for a list of IDs
73+
74+
Args:
75+
body (list[int]):
76+
77+
Raises:
78+
errors.UnexpectedStatus: If the server returns an undocumented status code and Client.raise_on_unexpected_status is True.
79+
httpx.TimeoutException: If the request takes longer than Client.timeout.
80+
81+
Returns:
82+
Response[Union[HTTPValidationError, list['HpcRun']]]
83+
"""
84+
85+
kwargs = _get_kwargs(
86+
body=body,
87+
)
88+
89+
response = client.get_httpx_client().request(
90+
**kwargs,
91+
)
92+
93+
return _build_response(client=client, response=response)
94+
95+
96+
def sync(
97+
*,
98+
client: Union[AuthenticatedClient, Client],
99+
body: list[int],
100+
) -> Optional[Union[HTTPValidationError, list["HpcRun"]]]:
101+
"""Get simulation status records for a list of IDs
102+
103+
Args:
104+
body (list[int]):
105+
106+
Raises:
107+
errors.UnexpectedStatus: If the server returns an undocumented status code and Client.raise_on_unexpected_status is True.
108+
httpx.TimeoutException: If the request takes longer than Client.timeout.
109+
110+
Returns:
111+
Union[HTTPValidationError, list['HpcRun']]
112+
"""
113+
114+
return sync_detailed(
115+
client=client,
116+
body=body,
117+
).parsed
118+
119+
120+
async def asyncio_detailed(
121+
*,
122+
client: Union[AuthenticatedClient, Client],
123+
body: list[int],
124+
) -> Response[Union[HTTPValidationError, list["HpcRun"]]]:
125+
"""Get simulation status records for a list of IDs
126+
127+
Args:
128+
body (list[int]):
129+
130+
Raises:
131+
errors.UnexpectedStatus: If the server returns an undocumented status code and Client.raise_on_unexpected_status is True.
132+
httpx.TimeoutException: If the request takes longer than Client.timeout.
133+
134+
Returns:
135+
Response[Union[HTTPValidationError, list['HpcRun']]]
136+
"""
137+
138+
kwargs = _get_kwargs(
139+
body=body,
140+
)
141+
142+
response = await client.get_async_httpx_client().request(**kwargs)
143+
144+
return _build_response(client=client, response=response)
145+
146+
147+
async def asyncio(
148+
*,
149+
client: Union[AuthenticatedClient, Client],
150+
body: list[int],
151+
) -> Optional[Union[HTTPValidationError, list["HpcRun"]]]:
152+
"""Get simulation status records for a list of IDs
153+
154+
Args:
155+
body (list[int]):
156+
157+
Raises:
158+
errors.UnexpectedStatus: If the server returns an undocumented status code and Client.raise_on_unexpected_status is True.
159+
httpx.TimeoutException: If the request takes longer than Client.timeout.
160+
161+
Returns:
162+
Union[HTTPValidationError, list['HpcRun']]
163+
"""
164+
165+
return (
166+
await asyncio_detailed(
167+
client=client,
168+
body=body,
169+
)
170+
).parsed

compose_api/api/client/models/job_status.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,11 +2,16 @@
22

33

44
class JobStatus(str, Enum):
5+
CANCELLED = "cancelled"
56
COMPLETED = "completed"
67
FAILED = "failed"
8+
OUT_OF_MEMORY = "out_of_memory"
79
PENDING = "pending"
810
QUEUED = "queued"
911
RUNNING = "running"
12+
SUSPENDED = "suspended"
13+
TIMEOUT = "timeout"
14+
UNKNOWN = "unknown"
1015
WAITING = "waiting"
1116

1217
def __str__(self) -> str:

compose_api/api/client/models/registered_package.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,8 @@
1010
from typing import cast
1111

1212
if TYPE_CHECKING:
13-
from ..models.bi_graph_process import BiGraphProcess
1413
from ..models.bi_graph_step import BiGraphStep
14+
from ..models.bi_graph_process import BiGraphProcess
1515

1616

1717
T = TypeVar("T", bound="RegisteredPackage")
@@ -36,8 +36,8 @@ class RegisteredPackage:
3636
additional_properties: dict[str, Any] = _attrs_field(init=False, factory=dict)
3737

3838
def to_dict(self) -> dict[str, Any]:
39-
from ..models.bi_graph_process import BiGraphProcess
4039
from ..models.bi_graph_step import BiGraphStep
40+
from ..models.bi_graph_process import BiGraphProcess
4141

4242
database_id = self.database_id
4343

@@ -69,8 +69,8 @@ def to_dict(self) -> dict[str, Any]:
6969

7070
@classmethod
7171
def from_dict(cls: type[T], src_dict: Mapping[str, Any]) -> T:
72-
from ..models.bi_graph_process import BiGraphProcess
7372
from ..models.bi_graph_step import BiGraphStep
73+
from ..models.bi_graph_process import BiGraphProcess
7474

7575
d = dict(src_dict)
7676
database_id = d.pop("database_id")

compose_api/api/client/models/simulator_version.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,8 @@
1414
import datetime
1515

1616
if TYPE_CHECKING:
17-
from ..models.registered_package import RegisteredPackage
1817
from ..models.containerization_file_repr import ContainerizationFileRepr
18+
from ..models.registered_package import RegisteredPackage
1919

2020

2121
T = TypeVar("T", bound="SimulatorVersion")
@@ -40,8 +40,8 @@ class SimulatorVersion:
4040
additional_properties: dict[str, Any] = _attrs_field(init=False, factory=dict)
4141

4242
def to_dict(self) -> dict[str, Any]:
43-
from ..models.registered_package import RegisteredPackage
4443
from ..models.containerization_file_repr import ContainerizationFileRepr
44+
from ..models.registered_package import RegisteredPackage
4545

4646
singularity_def = self.singularity_def.to_dict()
4747

@@ -82,8 +82,8 @@ def to_dict(self) -> dict[str, Any]:
8282

8383
@classmethod
8484
def from_dict(cls: type[T], src_dict: Mapping[str, Any]) -> T:
85-
from ..models.registered_package import RegisteredPackage
8685
from ..models.containerization_file_repr import ContainerizationFileRepr
86+
from ..models.registered_package import RegisteredPackage
8787

8888
d = dict(src_dict)
8989
singularity_def = ContainerizationFileRepr.from_dict(d.pop("singularity_def"))

compose_api/api/routers/results.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,25 @@
2525
config = RouterConfig(router=APIRouter(), prefix="/results", dependencies=[])
2626

2727

28+
@config.router.get(
29+
path="/simulations/status/batch",
30+
response_model=list[HpcRun],
31+
operation_id="get-simulations-status-batch",
32+
tags=["Results"],
33+
dependencies=[Depends(get_database_service)],
34+
summary="Get simulation status records for a list of IDs",
35+
)
36+
async def get_simulations_status_batch(ids: list[int]) -> list[HpcRun]:
37+
db_service = get_database_service()
38+
if db_service is None:
39+
raise HTTPException(status_code=500, detail="Database service is not initialized")
40+
try:
41+
return await db_service.get_hpc_db().get_hpcruns_by_refs(ref_ids=ids, job_type=JobType.SIMULATION)
42+
except Exception as e:
43+
logger.exception(f"Error fetching batch simulation statuses for ids: {ids}.")
44+
raise HTTPException(status_code=500, detail=str(e)) from e
45+
46+
2847
@config.router.get(
2948
path="/simulation/status",
3049
response_model=HpcRun,

compose_api/api/spec/openapi_3_1_0_generated.yaml

Lines changed: 37 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
openapi: 3.1.0
22
info:
33
title: compose-api
4-
version: 0.3.3
4+
version: 0.3.8
55
paths:
66
/curated/copasi:
77
post:
@@ -124,6 +124,37 @@ paths:
124124
application/json:
125125
schema:
126126
$ref: '#/components/schemas/HTTPValidationError'
127+
/results/simulations/status/batch:
128+
get:
129+
tags:
130+
- Results
131+
summary: Get simulation status records for a list of IDs
132+
operationId: get-simulations-status-batch
133+
requestBody:
134+
content:
135+
application/json:
136+
schema:
137+
items:
138+
type: integer
139+
type: array
140+
title: Ids
141+
required: true
142+
responses:
143+
'200':
144+
description: Successful Response
145+
content:
146+
application/json:
147+
schema:
148+
items:
149+
$ref: '#/components/schemas/HpcRun'
150+
type: array
151+
title: Response Get-Simulations-Status-Batch
152+
'422':
153+
description: Validation Error
154+
content:
155+
application/json:
156+
schema:
157+
$ref: '#/components/schemas/HTTPValidationError'
127158
/results/simulation/status:
128159
get:
129160
tags:
@@ -449,6 +480,11 @@ components:
449480
- completed
450481
- failed
451482
- pending
483+
- cancelled
484+
- out_of_memory
485+
- suspended
486+
- timeout
487+
- unknown
452488
title: JobStatus
453489
JobType:
454490
type: string

compose_api/db/services/hpc_db.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,10 @@ async def insert_hpcrun(self, slurmjobid: int, job_type: JobType, ref_id: int, c
4646
async def get_hpcrun_by_ref(self, ref_id: int, job_type: JobType) -> HpcRun | None:
4747
pass
4848

49+
@abstractmethod
50+
async def get_hpcruns_by_refs(self, ref_ids: list[int], job_type: JobType) -> list[HpcRun]:
51+
pass
52+
4953
@abstractmethod
5054
async def get_hpcrun_by_slurmjobid(self, slurmjobid: int) -> HpcRun | None:
5155
pass
@@ -135,6 +139,15 @@ async def insert_hpcrun(self, slurmjobid: int, job_type: JobType, ref_id: int, c
135139
await session.flush()
136140
return orm_hpc_run.to_hpc_run()
137141

142+
@override
143+
async def get_hpcruns_by_refs(self, ref_ids: list[int], job_type: JobType) -> list[HpcRun]:
144+
async with self.async_session_maker() as session, session.begin():
145+
run_type = self._get_job_type_ref(job_type)
146+
stmt = select(ORMHpcRun).where(run_type.in_(ref_ids))
147+
result: Result[tuple[ORMHpcRun]] = await session.execute(stmt)
148+
orm_hpcruns = result.scalars().all()
149+
return [orm_hpcrun.to_hpc_run() for orm_hpcrun in orm_hpcruns]
150+
138151
@override
139152
async def get_hpcrun_by_slurmjobid(self, slurmjobid: int) -> HpcRun | None:
140153
async with self.async_session_maker() as session, session.begin():

0 commit comments

Comments
 (0)