Commit b6fb321

Implement running jobs pipeline (#3657)

* Add running jobs pipeline fetcher scaffold
* Prototype job pipeline for provisioning and pulling states
* Treat job_model as read-only
* Finish running jobs pipeline worker
* Wire pipeline
* Restore TODOs and simplify code
* Set TERMINATED_DUE_TO_UTILIZATION_POLICY
* Set INACTIVITY_DURATION_EXCEEDED
* Extract _handle_instance_unreachable
* Unify jobs pipelines patterns
* Add context and apply to fleet pipeline
* Describe Typical worker structure
* Move unlock/processed updates inside _apply_process_result
* Add FIXME: Race condition when checking len(fleet_model.instances) == 0
* Fix stale fleet_model read
* Clean up pipeline tests
* Fix empty fleet select
* Fix missing az restriction for clusters in submitted_jobs
* Add deprecated note
* Pass instance_project_ssh_private_key
* Fix missing pipeline tests

1 parent 2a3c77f commit b6fb321

18 files changed

Lines changed: 3993 additions & 390 deletions

File tree

contributing/PIPELINES.md

Lines changed: 49 additions & 0 deletions
@@ -37,6 +37,55 @@ Brief checklist for implementing a new pipeline:

8. Register the pipeline in `PipelineManager` and hint fetch from services after commit via `pipeline_hinter.hint_fetch(Model.__name__)`.

9. Add minimum tests: fetch eligibility/order, successful unlock path, stale lock token path, and related lock contention retry path.

## Typical worker structure

Most workers are easiest to reason about when `process()` is split into three phases:

1. Load/refetch: open a short DB session, refetch the locked main row by `id + lock_token`, lock any required related rows, and gather any extra data needed for processing.
2. Process: do the heavy work outside DB sessions and build result objects or update maps instead of mutating detached ORM models.
3. Apply: open a short DB session, guard the main update by `id + lock_token`, resolve time placeholders, apply related updates, emit events, and unlock rows.

A dedicated context object is often useful for the load step when the worker needs multiple loaded models, related lock metadata, or derived values that should be passed cleanly into processing and apply. For very small pipelines, a direct load -> process -> apply flow may still be clearer.
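The "resolve time placeholders" step in the apply phase can be sketched as follows. The sentinel value and helper signature here are illustrative, not the actual `resolve_now_placeholders` implementation; the point is that every deferred timestamp is filled in from a single clock read taken inside the apply session:

```python
from datetime import datetime, timezone

# Illustrative sentinel; the real pipeline helper may use a different marker.
NOW_PLACEHOLDER = "__now__"


def resolve_now_placeholders(update_map: dict, *, now: datetime) -> None:
    # Replace every NOW sentinel with the single timestamp taken at the
    # start of the apply session, so all updated columns agree.
    for key, value in update_map.items():
        if value == NOW_PLACEHOLDER:
            update_map[key] = now


# Phase 2 builds the map with placeholders; phase 3 resolves them.
update_map = {"status": "done", "last_processed_at": NOW_PLACEHOLDER}
resolve_now_placeholders(update_map, now=datetime.now(timezone.utc))
```

This keeps the processing phase free of wall-clock reads, so timestamps reflect when the write happened rather than when processing started.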
Workers can share one context type and one apply function across all states even if the processing logic differs by state:

```python
async def process(item):
    context = await _load_process_context(item)
    if context is None:
        return
    result = await _process_item(context)
    await _apply_process_result(item, context, result)
```

Sometimes state-specific helpers are the cleanest option; even then, they can share a common apply phase if all states write results in the same general shape:

```python
async def process(item):
    if item.status == Status.PENDING:
        context = await _load_pending_context(item)
    elif item.status == Status.RUNNING:
        context = await _load_running_context(item)
    else:
        return
    if context is None:
        return
    result = await _process_item(context)
    await _apply_process_result(item, context, result)
```

If different states have materially different write-side behavior, different apply paths are fine as well. This commonly happens when one state does a normal guarded update while another does delete-or-cleanup work with different related updates:

```python
async def process(item):
    if item.to_be_deleted:
        await _process_to_be_deleted_item(item)
    elif item.status == Status.SUBMITTED:
        await _process_submitted_item(item)
```

It's ok not to force all pipelines into one exact shape.

## Implementation patterns

**Guarded apply by lock token**
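The pattern this heading names can be reduced to plain SQL: the UPDATE's WHERE clause re-checks the lock token captured at fetch time, so a worker whose lock was stolen writes nothing. A minimal stand-alone sketch using stdlib `sqlite3` and `rowcount` (the real code uses SQLAlchemy with `.returning(...)`; table and column names here are illustrative):

```python
import sqlite3

conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE fleets (id TEXT PRIMARY KEY, lock_token TEXT, status TEXT)")
conn.execute("INSERT INTO fleets VALUES ('f1', 'tok-1', 'submitted')")


def guarded_apply(conn, fleet_id: str, lock_token: str, new_status: str) -> bool:
    # Guard the write by id + lock_token and unlock in the same statement;
    # rowcount == 0 means the lock token changed after processing started.
    cur = conn.execute(
        "UPDATE fleets SET status = ?, lock_token = NULL "
        "WHERE id = ? AND lock_token = ?",
        (new_status, fleet_id, lock_token),
    )
    return cur.rowcount == 1


stale = guarded_apply(conn, "f1", "stale-token", "terminated")  # stale token: no rows written
fresh = guarded_apply(conn, "f1", "tok-1", "terminated")  # matching token: update applied
```

On the stale path the worker logs the mismatch and bails out without applying related updates, which is exactly the `log_lock_token_changed_after_processing` branch in the fleet worker below.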

src/dstack/_internal/server/background/pipeline_tasks/__init__.py

Lines changed: 2 additions & 0 deletions
```diff
@@ -5,6 +5,7 @@
 from dstack._internal.server.background.pipeline_tasks.fleets import FleetPipeline
 from dstack._internal.server.background.pipeline_tasks.gateways import GatewayPipeline
 from dstack._internal.server.background.pipeline_tasks.instances import InstancePipeline
+from dstack._internal.server.background.pipeline_tasks.jobs_running import JobRunningPipeline
 from dstack._internal.server.background.pipeline_tasks.jobs_terminating import (
     JobTerminatingPipeline,
 )
@@ -23,6 +24,7 @@ def __init__(self) -> None:
             ComputeGroupPipeline(),
             FleetPipeline(),
             GatewayPipeline(),
+            JobRunningPipeline(),
             JobTerminatingPipeline(),
             InstancePipeline(),
             PlacementGroupPipeline(),
```

src/dstack/_internal/server/background/pipeline_tasks/fleets.py

Lines changed: 155 additions & 117 deletions
```diff
@@ -196,106 +196,23 @@ def __init__(
 
     @sentry_utils.instrument_named_task("pipeline_tasks.FleetWorker.process")
     async def process(self, item: PipelineItem):
-        async with get_session_ctx() as session:
-            res = await session.execute(
-                select(FleetModel)
-                .where(
-                    FleetModel.id == item.id,
-                    FleetModel.lock_token == item.lock_token,
-                )
-                .options(joinedload(FleetModel.project))
-                .options(
-                    selectinload(FleetModel.instances.and_(InstanceModel.deleted == False))
-                    .joinedload(InstanceModel.jobs)
-                    .load_only(JobModel.id),
-                )
-                .options(
-                    selectinload(
-                        FleetModel.runs.and_(RunModel.status.not_in(RunStatus.finished_statuses()))
-                    ).load_only(RunModel.status)
-                )
-            )
-            fleet_model = res.unique().scalar_one_or_none()
-            if fleet_model is None:
-                log_lock_token_mismatch(logger, item)
-                return
-
-            # Lock instance only if consolidation is needed.
-            locked_instance_ids: set[uuid.UUID] = set()
-            consolidation_fleet_spec = _get_fleet_spec_if_ready_for_consolidation(fleet_model)
-            consolidation_instances = None
-            if consolidation_fleet_spec is not None:
-                consolidation_instances = await _lock_fleet_instances_for_consolidation(
-                    session=session,
-                    item=item,
-                )
-                if consolidation_instances is None:
-                    return
-                locked_instance_ids = {instance.id for instance in consolidation_instances}
-
+        process_context = await _load_process_context(item)
+        if process_context is None:
+            return
         result = await _process_fleet(
-            fleet_model,
-            consolidation_fleet_spec=consolidation_fleet_spec,
-            consolidation_instances=consolidation_instances,
-        )
-        fleet_update_map = _FleetUpdateMap()
-        fleet_update_map.update(result.fleet_update_map)
-        set_processed_update_map_fields(fleet_update_map)
-        set_unlock_update_map_fields(fleet_update_map)
-        instance_update_rows = _build_instance_update_rows(
-            result.instance_id_to_update_map,
-            unlock_instance_ids=locked_instance_ids,
+            process_context.fleet_model,
+            consolidation_fleet_spec=process_context.consolidation_fleet_spec,
+            consolidation_instances=process_context.consolidation_instances,
         )
+        await _apply_process_result(item, process_context, result)
 
-        async with get_session_ctx() as session:
-            now = get_current_datetime()
-            resolve_now_placeholders(fleet_update_map, now=now)
-            resolve_now_placeholders(instance_update_rows, now=now)
-            res = await session.execute(
-                update(FleetModel)
-                .where(
-                    FleetModel.id == fleet_model.id,
-                    FleetModel.lock_token == fleet_model.lock_token,
-                )
-                .values(**fleet_update_map)
-                .returning(FleetModel.id)
-            )
-            updated_ids = list(res.scalars().all())
-            if len(updated_ids) == 0:
-                log_lock_token_changed_after_processing(logger, item)
-                if locked_instance_ids:
-                    await _unlock_fleet_locked_instances(
-                        session=session,
-                        item=item,
-                        locked_instance_ids=locked_instance_ids,
-                    )
-                # TODO: Clean up fleet.
-                return
-
-            if fleet_update_map.get("deleted"):
-                await session.execute(
-                    update(PlacementGroupModel)
-                    .where(PlacementGroupModel.fleet_id == item.id)
-                    .values(fleet_deleted=True)
-                )
-            if instance_update_rows:
-                await session.execute(
-                    update(InstanceModel),
-                    instance_update_rows,
-                )
-            if len(result.new_instance_creates) > 0:
-                await _create_missing_fleet_instances(
-                    session=session,
-                    fleet_model=fleet_model,
-                    new_instance_creates=result.new_instance_creates,
-                )
-            emit_fleet_status_change_event(
-                session=session,
-                fleet_model=fleet_model,
-                old_status=fleet_model.status,
-                new_status=fleet_update_map.get("status", fleet_model.status),
-                status_message=fleet_update_map.get("status_message", fleet_model.status_message),
-            )
+
+@dataclass
+class _ProcessContext:
+    fleet_model: FleetModel
+    consolidation_fleet_spec: Optional[FleetSpec]
+    consolidation_instances: Optional[list[InstanceModel]]
+    locked_instance_ids: set[uuid.UUID] = field(default_factory=set)
 
 
 class _FleetUpdateMap(ItemUpdateMap, total=False):
```
```diff
@@ -318,6 +235,83 @@ class _InstanceUpdateMap(ItemUpdateMap, total=False):
     id: uuid.UUID
 
 
+@dataclass
+class _ProcessResult:
+    fleet_update_map: _FleetUpdateMap = field(default_factory=_FleetUpdateMap)
+    instance_id_to_update_map: dict[uuid.UUID, _InstanceUpdateMap] = field(default_factory=dict)
+    new_instance_creates: list["_NewInstanceCreate"] = field(default_factory=list)
+
+
+class _NewInstanceCreate(TypedDict):
+    id: uuid.UUID
+    instance_num: int
+
+
+@dataclass
+class _MaintainNodesResult:
+    instance_id_to_update_map: dict[uuid.UUID, _InstanceUpdateMap] = field(default_factory=dict)
+    new_instance_creates: list[_NewInstanceCreate] = field(default_factory=list)
+    changes_required: bool = False
+
+    @property
+    def has_changes(self) -> bool:
+        return len(self.instance_id_to_update_map) > 0 or len(self.new_instance_creates) > 0
+
+
+async def _load_process_context(item: PipelineItem) -> Optional[_ProcessContext]:
+    async with get_session_ctx() as session:
+        fleet_model = await _refetch_locked_fleet(session=session, item=item)
+        if fleet_model is None:
+            log_lock_token_mismatch(logger, item)
+            return None
+
+        consolidation_fleet_spec = _get_fleet_spec_if_ready_for_consolidation(fleet_model)
+        consolidation_instances = None
+        if consolidation_fleet_spec is not None:
+            consolidation_instances = await _lock_fleet_instances_for_consolidation(
+                session=session,
+                item=item,
+            )
+            if consolidation_instances is None:
+                return None
+
+        return _ProcessContext(
+            fleet_model=fleet_model,
+            consolidation_fleet_spec=consolidation_fleet_spec,
+            consolidation_instances=consolidation_instances,
+            locked_instance_ids=(
+                set()
+                if consolidation_instances is None
+                else {i.id for i in consolidation_instances}
+            ),
+        )
+
+
+async def _refetch_locked_fleet(
+    session: AsyncSession,
+    item: PipelineItem,
+) -> Optional[FleetModel]:
+    res = await session.execute(
+        select(FleetModel)
+        .where(
+            FleetModel.id == item.id,
+            FleetModel.lock_token == item.lock_token,
+        )
+        .options(joinedload(FleetModel.project))
+        .options(
+            selectinload(FleetModel.instances.and_(InstanceModel.deleted == False))
+            .joinedload(InstanceModel.jobs)
+            .load_only(JobModel.id),
+        )
+        .options(
+            selectinload(
+                FleetModel.runs.and_(RunModel.status.not_in(RunStatus.finished_statuses()))
+            ).load_only(RunModel.status)
+        )
+    )
+    return res.unique().scalar_one_or_none()
+
+
 def _get_fleet_spec_if_ready_for_consolidation(fleet_model: FleetModel) -> Optional[FleetSpec]:
     if fleet_model.status == FleetStatus.TERMINATING:
         return None
```
```diff
@@ -398,27 +392,71 @@ async def _lock_fleet_instances_for_consolidation(
     return locked_instance_models
 
 
-@dataclass
-class _ProcessResult:
-    fleet_update_map: _FleetUpdateMap = field(default_factory=_FleetUpdateMap)
-    instance_id_to_update_map: dict[uuid.UUID, _InstanceUpdateMap] = field(default_factory=dict)
-    new_instance_creates: list["_NewInstanceCreate"] = field(default_factory=list)
-
-
-class _NewInstanceCreate(TypedDict):
-    id: uuid.UUID
-    instance_num: int
-
-
-@dataclass
-class _MaintainNodesResult:
-    instance_id_to_update_map: dict[uuid.UUID, _InstanceUpdateMap] = field(default_factory=dict)
-    new_instance_creates: list[_NewInstanceCreate] = field(default_factory=list)
-    changes_required: bool = False
+async def _apply_process_result(
+    item: PipelineItem,
+    context: _ProcessContext,
+    result: "_ProcessResult",
+) -> None:
+    fleet_update_map = _FleetUpdateMap()
+    fleet_update_map.update(result.fleet_update_map)
+    set_processed_update_map_fields(fleet_update_map)
+    set_unlock_update_map_fields(fleet_update_map)
+    instance_update_rows = _build_instance_update_rows(
+        result.instance_id_to_update_map,
+        unlock_instance_ids=context.locked_instance_ids,
+    )
 
-    @property
-    def has_changes(self) -> bool:
-        return len(self.instance_id_to_update_map) > 0 or len(self.new_instance_creates) > 0
+    async with get_session_ctx() as session:
+        now = get_current_datetime()
+        resolve_now_placeholders(fleet_update_map, now=now)
+        resolve_now_placeholders(instance_update_rows, now=now)
+        res = await session.execute(
+            update(FleetModel)
+            .where(
+                FleetModel.id == context.fleet_model.id,
+                FleetModel.lock_token == context.fleet_model.lock_token,
+            )
+            .values(**fleet_update_map)
+            .returning(FleetModel.id)
+        )
+        updated_ids = list(res.scalars().all())
+        if len(updated_ids) == 0:
+            log_lock_token_changed_after_processing(logger, item)
+            if context.locked_instance_ids:
+                await _unlock_fleet_locked_instances(
+                    session=session,
+                    item=item,
+                    locked_instance_ids=context.locked_instance_ids,
+                )
+            # TODO: Clean up fleet.
+            return
+
+        if fleet_update_map.get("deleted"):
+            await session.execute(
+                update(PlacementGroupModel)
+                .where(PlacementGroupModel.fleet_id == context.fleet_model.id)
+                .values(fleet_deleted=True)
+            )
+        if instance_update_rows:
+            await session.execute(
+                update(InstanceModel),
+                instance_update_rows,
+            )
+        if len(result.new_instance_creates) > 0:
+            await _create_missing_fleet_instances(
+                session=session,
+                fleet_model=context.fleet_model,
+                new_instance_creates=result.new_instance_creates,
+            )
+        emit_fleet_status_change_event(
+            session=session,
+            fleet_model=context.fleet_model,
+            old_status=context.fleet_model.status,
+            new_status=fleet_update_map.get("status", context.fleet_model.status),
+            status_message=fleet_update_map.get(
+                "status_message", context.fleet_model.status_message
+            ),
+        )
 
 
 async def _process_fleet(
```

src/dstack/_internal/server/background/pipeline_tasks/instances/__init__.py

Lines changed: 4 additions & 2 deletions
```diff
@@ -238,8 +238,7 @@ async def process(self, item: InstancePipelineItem):
         process_context = await _process_terminating_item(item)
         if process_context is None:
             return
-        set_processed_update_map_fields(process_context.result.instance_update_map)
-        set_unlock_update_map_fields(process_context.result.instance_update_map)
+
         await _apply_process_result(
             item=item,
             instance_model=process_context.instance_model,
@@ -376,6 +375,9 @@ async def _apply_process_result(
     instance_model: InstanceModel,
     result: ProcessResult,
 ) -> None:
+    set_processed_update_map_fields(result.instance_update_map)
+    set_unlock_update_map_fields(result.instance_update_map)
+
     async with get_session_ctx() as session:
         if result.health_check_create is not None:
             session.add(InstanceHealthCheckModel(**result.health_check_create))
```
