Skip to content

Commit 6eadede

Browse files
committed
Add workflows endpoint
1 parent d43a6a4 commit 6eadede

7 files changed

Lines changed: 482 additions & 87 deletions

File tree

metafold/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
from metafold.projects import ProjectsEndpoint
33
from metafold.assets import AssetsEndpoint
44
from metafold.jobs import JobsEndpoint
5+
from metafold.workflows import WorkflowsEndpoint
56
from metafold.auth import AuthProvider
67
from typing import Optional
78

@@ -17,6 +18,7 @@ class MetafoldClient(Client):
1718
projects: ProjectsEndpoint
1819
assets: AssetsEndpoint
1920
jobs: JobsEndpoint
21+
workflows: WorkflowsEndpoint
2022

2123
def __init__(
2224
self,
@@ -48,3 +50,4 @@ def __init__(
4850
self.projects = ProjectsEndpoint(self)
4951
self.assets = AssetsEndpoint(self)
5052
self.jobs = JobsEndpoint(self)
53+
self.workflows = WorkflowsEndpoint(self)

metafold/api.py

Lines changed: 18 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -2,20 +2,6 @@
22
from functools import wraps
33
from typing import Any, Callable, Optional, TypeVar, Union
44

5-
T = TypeVar("T")
6-
U = TypeVar("U")
7-
8-
9-
def optional(f: Callable[[T], U]) -> Callable[[Optional[T]], Optional[U]]:
10-
"""Decorator to generate converters that accept optional values."""
11-
@wraps(f)
12-
def decorator(v: Optional[T]) -> Optional[U]:
13-
if v is None:
14-
return v
15-
return f(v)
16-
17-
return decorator
18-
195

206
def asdatetime(s: Union[str, datetime]) -> datetime:
217
"""Parse Metafold API datetime.
@@ -38,3 +24,21 @@ def asdict(**kwargs: Any) -> dict[str, Any]:
3824
if v is not None:
3925
d[k] = v
4026
return d
27+
28+
29+
T = TypeVar("T")
30+
U = TypeVar("U")
31+
32+
33+
def optional(f: Callable[[T], U]) -> Callable[[Optional[T]], Optional[U]]:
34+
"""Decorator to generate converters that accept optional values."""
35+
@wraps(f)
36+
def decorator(v: Optional[T]) -> Optional[U]:
37+
if v is None:
38+
return v
39+
return f(v)
40+
41+
return decorator
42+
43+
44+
optional_datetime = optional(asdatetime)

metafold/client.py

Lines changed: 29 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,10 @@
11
from metafold.auth import AuthProvider
2+
from metafold.exceptions import PollTimeout
23
from requests import HTTPError, Response, Session
3-
from typing import Any, Callable, Optional
4+
from typing import Any, Callable, Optional, Union
45
from urllib.parse import urljoin
56
import platform
7+
import time
68

79

810
class Client:
@@ -74,3 +76,29 @@ def patch(self, url: str, *args: Any, **kwargs: Any) -> Response:
7476

7577
def delete(self, url: str, *args: Any, **kwargs: Any) -> Response:
7678
return self._request(self._session.delete, url, *args, **kwargs)
79+
80+
def poll(
81+
self, url: str,
82+
timeout: Union[int, float] = 120,
83+
every: Union[int, float] = 1,
84+
) -> Response:
85+
"""Poll the given URL in regular intervals.
86+
87+
Helpful for waiting on async processes given a status URL.
88+
89+
Args:
90+
timeout: Time in seconds to wait for a result.
91+
every: Frequency in seconds.
92+
93+
Returns:
94+
HTTP response.
95+
"""
96+
t0 = time.monotonic()
97+
r = self.get(url)
98+
while r.status_code == 202:
99+
elapsed = time.monotonic() - t0
100+
if elapsed >= timeout:
101+
raise PollTimeout(f"Polling timed out: {url}")
102+
time.sleep(1)
103+
r = self.get(url)
104+
return r

metafold/jobs.py

Lines changed: 67 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,49 @@
11
from attrs import field, frozen
22
from datetime import datetime
3-
from metafold.api import asdatetime, asdict, optional
3+
from metafold.api import asdatetime, asdict, optional, optional_datetime
44
from metafold.assets import Asset
55
from metafold.client import Client
66
from metafold.exceptions import PollTimeout
77
from requests import Response
8-
from typing import Any, Optional, Union
9-
import time
8+
from typing import Any, Optional, TypeAlias, TypedDict, Union
109

1110

1211
def _assets(v: list[Union[dict[str, Any], Asset]]) -> list[Asset]:
1312
return [a if isinstance(a, Asset) else Asset(**a) for a in v]
1413

1514

15+
AssetDict: TypeAlias = dict[str, Union[dict[str, Any], Asset]]
16+
17+
18+
def _assets_dict(v: AssetDict) -> dict[str, Asset]:
19+
return dict(
20+
(k, a if isinstance(a, Asset) else Asset(**a))
21+
for k, a in v.items()
22+
)
23+
24+
25+
class IODict(TypedDict):
26+
params: Optional[dict[str, Any]]
27+
assets: Optional[dict[str, AssetDict]]
28+
29+
30+
@frozen(kw_only=True)
31+
class IO:
32+
"""Job input/output.
33+
34+
Attributes:
35+
params: JSON-encoded parameter values.
36+
assets: Related assets.
37+
"""
38+
params: Optional[dict[str, Any]] = None
39+
assets: Optional[dict[str, Asset]] = field(
40+
converter=optional(_assets_dict), default=None)
41+
42+
@staticmethod
43+
def from_dict(d: IODict) -> "IO":
44+
return IO(params=d.get("params"), assets=d.get("assets"))
45+
46+
1647
@frozen(kw_only=True)
1748
class Job:
1849
"""Job resource.
@@ -21,27 +52,31 @@ class Job:
2152
id: Job ID.
2253
name: Job name.
2354
type: Job type.
24-
parameters: Job parameters.
55+
state: Job state. May be one of: pending, started, success, failure, or
56+
canceled.
2557
created: Job creation datetime.
58+
started: Job started datetime.
2659
finished: Job finished datetime.
27-
state: Job state. May be one of: pending, started, success, or failure.
28-
assets: List of generated asset resources.
29-
meta: Additional metadata generated by the job.
60+
error: Error message for failed jobs.
61+
inputs: Input assets and parameters.
62+
outputs: Output assets and parameters.
63+
assets: (Deprecated) List of generated asset resources.
64+
parameters: (Deprecated) Job parameters.
65+
meta: (Deprecated) Additional metadata generated by the job.
3066
"""
3167
id: str
3268
name: Optional[str] = None
3369
type: str
34-
parameters: dict[str, Any]
35-
created: datetime = field(converter=asdatetime)
36-
finished: Optional[datetime] = field(
37-
converter=lambda v: optional(asdatetime)(v),
38-
default=None,
39-
)
4070
state: str
41-
assets: Optional[list[Asset]] = field(
42-
converter=lambda v: optional(_assets)(v),
43-
default=None,
44-
)
71+
created: datetime = field(converter=asdatetime)
72+
started: Optional[datetime] = field(converter=optional_datetime, default=None)
73+
finished: Optional[datetime] = field(converter=optional_datetime, default=None)
74+
error: Optional[str] = None
75+
inputs: IO = field(converter=lambda v: v if isinstance(v, IO) else IO.from_dict(v))
76+
outputs: IO = field(converter=lambda v: v if isinstance(v, IO) else IO.from_dict(v))
77+
# NOTE(ryan): Deprecated
78+
assets: Optional[list[Asset]] = field(converter=optional(_assets), default=None)
79+
parameters: dict[str, Any]
4580
meta: dict[str, Any]
4681

4782

@@ -61,7 +96,8 @@ def list(
6196
6297
Args:
6398
sort: Sort string. For details on syntax see the Metafold API docs.
64-
Supported sorting fields are: "id", "name", or "created".
99+
Supported sorting fields are: "id", "name", "created", "started", or
100+
"finished".
65101
q: Query string. For details on syntax see the Metafold API docs.
66102
Supported search fields are: "id", "name", "type", and "state".
67103
project_id: Job project ID.
@@ -142,31 +178,8 @@ def run_status(
142178
r: Response = self._client.post(f"/projects/{project_id}/jobs", json=payload)
143179
return r.json()["link"]
144180

145-
def poll(
146-
self, url: str,
147-
timeout: Union[int, float] = 120,
148-
every: Union[int, float] = 1,
149-
) -> Response:
150-
"""Poll the given URL in regular intervals.
151-
152-
Helpful for waiting on job results given a status URL.
153-
154-
Args:
155-
timeout: Time in seconds to wait for a result.
156-
every: Frequency in seconds.
157-
158-
Returns:
159-
HTTP response.
160-
"""
161-
t0 = time.monotonic()
162-
r = self._client.get(url)
163-
while r.status_code == 202:
164-
elapsed = time.monotonic() - t0
165-
if elapsed >= timeout:
166-
raise PollTimeout("Job timed out")
167-
time.sleep(1)
168-
r = self._client.get(url)
169-
return r
181+
def poll(self, *args, **kwargs):
182+
return self._client.poll(*args, **kwargs)
170183

171184
def update(
172185
self, job_id: str,
@@ -188,3 +201,14 @@ def update(
188201
payload = asdict(name=name)
189202
r: Response = self._client.patch(url, data=payload)
190203
return Job(**r.json())
204+
205+
def delete(self, job_id: str, project_id: Optional[str] = None):
206+
"""Delete a job.
207+
208+
Args:
209+
job_id: ID of job to delete.
210+
project_id: Job project ID.
211+
"""
212+
project_id = self._client.project_id(project_id)
213+
url = f"/projects/{project_id}/jobs/{job_id}"
214+
self._client.delete(url)

metafold/workflows.py

Lines changed: 135 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,135 @@
1+
from attrs import field, frozen
2+
from datetime import datetime
3+
from metafold.api import asdatetime, asdict, optional_datetime
4+
from metafold.client import Client
5+
from metafold.exceptions import PollTimeout
6+
from requests import Response
7+
from typing import Optional, Union
8+
9+
10+
@frozen(kw_only=True)
11+
class Workflow:
12+
"""Workflow resource.
13+
14+
Attributes:
15+
id: Workflow ID.
16+
state: Workflow state. May be one of: pending, started, success, failure, or
17+
canceled.
18+
created: Workflow creation datetime.
19+
started: Workflow started datetime.
20+
finished: Workflow finished datetime.
21+
definition: Workflow definition string.
22+
"""
23+
id: str
24+
jobs: list[str] = field(factory=list)
25+
state: str
26+
created: datetime = field(converter=asdatetime)
27+
started: Optional[datetime] = field(converter=optional_datetime, default=None)
28+
finished: Optional[datetime] = field(converter=optional_datetime, default=None)
29+
definition: str
30+
31+
32+
class WorkflowsEndpoint:
33+
"""Metafold workflows endpoint."""
34+
35+
def __init__(self, client: Client) -> None:
36+
self._client = client
37+
38+
def list(
39+
self,
40+
sort: Optional[str] = None,
41+
q: Optional[str] = None,
42+
project_id: Optional[str] = None,
43+
) -> list[Workflow]:
44+
"""List jobs.
45+
46+
Args:
47+
sort: Sort string. For details on syntax see the Metafold API docs.
48+
Supported sorting fields are: "id", "created", "started", or "finished".
49+
q: Query string. For details on syntax see the Metafold API docs.
50+
Supported search fields are: "id" and "state".
51+
project_id: Workflow project ID.
52+
53+
Returns:
54+
List of job resources.
55+
"""
56+
project_id = self._client.project_id(project_id)
57+
url = f"/projects/{project_id}/workflows"
58+
payload = asdict(sort=sort, q=q)
59+
r: Response = self._client.get(url, params=payload)
60+
return [Workflow(**w) for w in r.json()]
61+
62+
def get(self, workflow_id: str, project_id: Optional[str] = None) -> Workflow:
63+
"""Get a workflow.
64+
65+
Args:
66+
workflow_id: ID of workflow to get.
67+
project_id: Workflow project ID.
68+
69+
Returns:
70+
Workflow resource.
71+
"""
72+
project_id = self._client.project_id(project_id)
73+
url = f"/projects/{project_id}/workflows/{workflow_id}"
74+
r: Response = self._client.get(url)
75+
return Workflow(**r.json())
76+
77+
def run(
78+
self, definition: str,
79+
parameters: Optional[dict[str, str]] = None,
80+
assets: Optional[dict[str, str]] = None,
81+
timeout: Union[int, float] = 120,
82+
project_id: Optional[str] = None,
83+
) -> Workflow:
84+
"""Dispatch a new workflow and wait for it to complete.
85+
86+
Workflow completion does not indicate success. Access the completed workflow's
87+
state to check for success/failure.
88+
89+
Args:
90+
definition: Workflow definition YAML.
91+
parameters: Parameter mapping for jobs in the definition.
92+
assets: Asset mapping for jobs in the definition.
93+
timeout: Time in seconds to wait for a result.
94+
project_id: Workflow project ID.
95+
96+
Returns:
97+
Completed workflow resource.
98+
"""
99+
project_id = self._client.project_id(project_id)
100+
payload = asdict(definition=definition, parameters=parameters, assets=assets)
101+
r: Response = self._client.post(f"/projects/{project_id}/workflows", json=payload)
102+
url = r.json()["link"]
103+
try:
104+
r: Response = self._client.poll(url, timeout)
105+
except PollTimeout as e:
106+
raise RuntimeError(
107+
f"Workflow failed to complete within {timeout} seconds"
108+
) from e
109+
return Workflow(**r.json())
110+
111+
def cancel(self, workflow_id: str, project_id: Optional[str] = None) -> Workflow:
112+
"""Cancel a running workflow.
113+
114+
Args:
115+
workflow_id: ID of workflow to cancel.
116+
project_id: Workflow project ID.
117+
118+
Returns:
119+
Workflow resource.
120+
"""
121+
project_id = self._client.project_id(project_id)
122+
url = f"/projects/{project_id}/workflows/{workflow_id}/cancel"
123+
r: Response = self._client.post(url)
124+
return Workflow(**r.json())
125+
126+
def delete(self, workflow_id: str, project_id: Optional[str] = None):
127+
"""Delete a workflow.
128+
129+
Args:
130+
workflow_id: ID of workflow to delete.
131+
project_id: Workflow project ID.
132+
"""
133+
project_id = self._client.project_id(project_id)
134+
url = f"/projects/{project_id}/workflows/{workflow_id}"
135+
self._client.delete(url)

0 commit comments

Comments
 (0)