Skip to content

Commit 6cd837a

Browse files
committed
feat(core): add typed Workflow/Step registry and builder
1 parent 586e69b commit 6cd837a

5 files changed

Lines changed: 301 additions & 0 deletions

File tree

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
from fastapi_cloudflow.core.arg import Arg, ArgExpr
2+
from fastapi_cloudflow.core.step import AssignStep, HttpStep, ModelAdapter, Step
3+
from fastapi_cloudflow.core.types import Context, RetryPolicy, WorkflowMeta
4+
from fastapi_cloudflow.core.workflow import (
5+
Registry,
6+
Workflow,
7+
WorkflowBuilder,
8+
get_registry,
9+
get_workflows,
10+
step,
11+
workflow,
12+
)
13+
14+
__all__ = [
15+
"Context",
16+
"WorkflowMeta",
17+
"RetryPolicy",
18+
"ArgExpr",
19+
"Arg",
20+
"Step",
21+
"AssignStep",
22+
"HttpStep",
23+
"ModelAdapter",
24+
"Workflow",
25+
"Registry",
26+
"WorkflowBuilder",
27+
"workflow",
28+
"get_registry",
29+
"get_workflows",
30+
"step",
31+
]

src/fastapi_cloudflow/core/arg.py

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
from __future__ import annotations
2+
3+
4+
class ArgExpr:
5+
def __init__(self, expr: str) -> None:
6+
self.expr = expr
7+
8+
@staticmethod
9+
def _coerce_expr(value: str | ArgExpr) -> str:
10+
if isinstance(value, ArgExpr):
11+
return value.expr
12+
return f'"{value}"'
13+
14+
def __truediv__(self, other: str | ArgExpr) -> ArgExpr:
15+
if isinstance(other, ArgExpr):
16+
right = other.expr
17+
return ArgExpr(f'{self.expr} + "/" + {right}')
18+
else:
19+
return ArgExpr(f'{self.expr} + "/{other}"')
20+
21+
def __add__(self, other: str | ArgExpr) -> ArgExpr:
22+
right = self._coerce_expr(other)
23+
return ArgExpr(f"{self.expr} + {right}")
24+
25+
def __str__(self) -> str:
26+
return f"${{{self.expr}}}"
27+
28+
29+
class Arg:
30+
@staticmethod
31+
def env(name: str) -> ArgExpr:
32+
return ArgExpr(f'sys.get_env("{name}")')
33+
34+
@staticmethod
35+
def param(path: str) -> ArgExpr:
36+
return ArgExpr(f"params.{path}")
37+
38+
@staticmethod
39+
def ctx(key: str) -> ArgExpr:
40+
return ArgExpr(f"ctx.{key}")

src/fastapi_cloudflow/core/step.py

Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,82 @@
1+
from collections.abc import Awaitable, Callable, Iterable
2+
from datetime import timedelta
3+
from typing import Any, TypeVar
4+
5+
from pydantic import BaseModel
6+
7+
from fastapi_cloudflow.core.arg import ArgExpr
8+
from fastapi_cloudflow.core.types import Context, RetryPolicy
9+
10+
InT = TypeVar("InT", bound=BaseModel)
11+
OutT = TypeVar("OutT", bound=BaseModel)
12+
13+
14+
class Step[InT: BaseModel, OutT: BaseModel]:
15+
name: str
16+
input_model: type[InT]
17+
output_model: type[OutT]
18+
retry: RetryPolicy | None
19+
timeout: timedelta | None
20+
tags: set[str]
21+
22+
def __init__(
23+
self,
24+
name: str,
25+
input_model: type[InT],
26+
output_model: type[OutT],
27+
fn: Callable[[Context, InT], Awaitable[OutT]] | None = None,
28+
retry: RetryPolicy | None = None,
29+
timeout: timedelta | None = None,
30+
tags: Iterable[str] = (),
31+
) -> None:
32+
self.name = name
33+
self.input_model = input_model
34+
self.output_model = output_model
35+
self.fn = fn
36+
self.retry = retry
37+
self.timeout = timeout
38+
self.tags = set(tags)
39+
40+
async def __call__(self, ctx: Context, data: InT) -> OutT:
41+
if self.fn is None:
42+
raise RuntimeError("Step is not callable. Is it a native step?")
43+
return await self.fn(ctx, data)
44+
45+
46+
class AssignStep(Step[InT, OutT]):
47+
def __init__(self, name: str, input_model: type[InT], output_model: type[OutT], expr: dict[str, Any]) -> None:
48+
super().__init__(name=name, input_model=input_model, output_model=output_model, fn=None)
49+
self.expr = expr
50+
51+
52+
class HttpStep(Step[InT, OutT]):
53+
def __init__(
54+
self,
55+
name: str,
56+
input_model: type[InT],
57+
output_model: type[OutT],
58+
method: str,
59+
url: str | ArgExpr,
60+
headers: dict[str, str | ArgExpr] | None = None,
61+
auth: dict[str, Any] | None = None,
62+
retry: RetryPolicy | None = None,
63+
timeout: timedelta | None = None,
64+
) -> None:
65+
super().__init__(
66+
name=name,
67+
input_model=input_model,
68+
output_model=output_model,
69+
fn=None,
70+
retry=retry,
71+
timeout=timeout,
72+
)
73+
self.method = method.upper()
74+
self.url = url
75+
self.headers = headers or {}
76+
self.auth = auth
77+
78+
79+
class ModelAdapter(Step[InT, OutT]):
80+
def __init__(self, name: str, input_model: type[InT], output_model: type[OutT], mapping: dict[str, Any]) -> None:
81+
super().__init__(name=name, input_model=input_model, output_model=output_model, fn=None)
82+
self.mapping = mapping
Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
from __future__ import annotations
2+
3+
from dataclasses import dataclass
4+
5+
from fastapi import Request
6+
7+
8+
@dataclass
9+
class WorkflowMeta:
10+
name: str | None = None
11+
step: str | None = None
12+
run_id: str | None = None
13+
14+
15+
@dataclass
16+
class Context:
17+
request: Request
18+
workflow: WorkflowMeta
19+
20+
21+
@dataclass
22+
class RetryPolicy:
23+
max_retries: int = 5
24+
initial_delay_s: float = 1.0
25+
max_delay_s: float = 30.0
26+
multiplier: float = 2.0
27+
predicate: str = "http.default_retry_predicate"
28+
29+
@staticmethod
30+
def idempotent_http() -> RetryPolicy:
31+
return RetryPolicy(
32+
max_retries=5,
33+
initial_delay_s=1.0,
34+
max_delay_s=30.0,
35+
multiplier=2.0,
36+
predicate="http.default_retry_predicate",
37+
)
Lines changed: 111 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,111 @@
1+
from __future__ import annotations
2+
3+
import inspect
4+
from collections.abc import Awaitable, Callable, Iterable
5+
from typing import Any, TypeVar, get_type_hints
6+
7+
from pydantic import BaseModel
8+
9+
from fastapi_cloudflow.core.step import Step
10+
from fastapi_cloudflow.core.types import Context
11+
12+
13+
class Workflow:
14+
def __init__(self, name: str, nodes: list[Step[Any, Any]]) -> None:
15+
self.name = name
16+
self.nodes = nodes
17+
18+
19+
class Registry:
20+
def __init__(self) -> None:
21+
self.steps: dict[str, Step[Any, Any]] = {}
22+
self.workflows: dict[str, Workflow] = {}
23+
24+
def register_step(self, step: Step[Any, Any]) -> None:
25+
if step.name in self.steps:
26+
raise ValueError(f"Step name collision: {step.name}")
27+
self.steps[step.name] = step
28+
29+
def register_workflow(self, workflow: Workflow) -> None:
30+
existing = self.workflows.get(workflow.name)
31+
if existing is not None:
32+
existing_names = [s.name for s in existing.nodes]
33+
new_names = [s.name for s in workflow.nodes]
34+
if existing_names == new_names:
35+
return
36+
raise ValueError(f"Workflow name collision: {workflow.name}")
37+
self.workflows[workflow.name] = workflow
38+
39+
def get_workflows(self) -> list[Workflow]:
40+
return list(self.workflows.values())
41+
42+
43+
class WorkflowBuilder:
44+
def __init__(self, name: str, nodes: list[Step[Any, Any]] | None = None) -> None:
45+
self.name = name
46+
self.nodes = nodes or []
47+
48+
def __rshift__(self, other: Step[Any, Any]) -> WorkflowBuilder:
49+
if self.nodes:
50+
prev = self.nodes[-1]
51+
if prev.output_model is not other.input_model:
52+
raise TypeError(
53+
f"Type mismatch: {prev.name} outputs {prev.output_model.__name__} "
54+
f"but {other.name} expects {other.input_model.__name__}"
55+
)
56+
return WorkflowBuilder(self.name, self.nodes + [other])
57+
58+
def build(self) -> Workflow:
59+
if not self.nodes:
60+
raise ValueError("Workflow has no steps")
61+
wf = Workflow(self.name, self.nodes)
62+
_REGISTRY.register_workflow(wf)
63+
return wf
64+
65+
66+
_REGISTRY = Registry()
67+
68+
69+
def workflow(name: str) -> WorkflowBuilder:
70+
return WorkflowBuilder(name)
71+
72+
73+
InT = TypeVar("InT", bound=BaseModel)
74+
OutT = TypeVar("OutT", bound=BaseModel)
75+
76+
77+
def step(
78+
*,
79+
name: str | None = None,
80+
retry: Any | None = None,
81+
timeout: Any | None = None,
82+
tags: Iterable[str] = (),
83+
):
84+
def decorator(fn: Callable[[Context, InT], Awaitable[OutT]]) -> Step[InT, OutT]:
85+
hints = get_type_hints(fn)
86+
sig = inspect.signature(fn)
87+
params = list(sig.parameters.values())
88+
if len(params) != 2:
89+
raise TypeError("@step function must accept exactly two positional parameters: (Context, InModel)")
90+
in_param = params[1].name
91+
in_model = hints.get(in_param)
92+
out_model = hints.get("return")
93+
if not (isinstance(in_model, type) and issubclass(in_model, BaseModel)):
94+
raise TypeError("@step function must type its second parameter as a Pydantic BaseModel subclass")
95+
if not (isinstance(out_model, type) and issubclass(out_model, BaseModel)):
96+
raise TypeError("@step function must return a Pydantic BaseModel subclass")
97+
base_name = getattr(fn, "__name__", "step")
98+
nm = name or base_name.replace("_", "-")
99+
s: Step[Any, Any] = Step(nm, in_model, out_model, fn=fn, retry=retry, timeout=timeout, tags=tags)
100+
_REGISTRY.register_step(s)
101+
return s
102+
103+
return decorator
104+
105+
106+
def get_registry() -> Registry:
107+
return _REGISTRY
108+
109+
110+
def get_workflows() -> list[Workflow]:
111+
return _REGISTRY.get_workflows()

0 commit comments

Comments
 (0)