Skip to content

Commit 6b6cfb9

Browse files
authored
Implement cloud fleet in-place update for provisioning fields (#3775)
* Implement cloud fleet in-place update for provisioning fields * Fix consolidation not running * Adjust termination_reason_message
1 parent d061685 commit 6b6cfb9

6 files changed

Lines changed: 352 additions & 42 deletions

File tree

src/dstack/_internal/core/models/instances.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -273,6 +273,7 @@ class InstanceTerminationReason(str, Enum):
273273
NO_OFFERS = "no_offers"
274274
MASTER_FAILED = "master_failed"
275275
MAX_INSTANCES_LIMIT = "max_instances_limit"
276+
FLEET_SPEC_MISMATCH = "fleet_spec_mismatch"
276277
NO_BALANCE = "no_balance"
277278
"""`NO_BALANCE` is used in dstack Sky."""
278279

src/dstack/_internal/server/background/pipeline_tasks/fleets.py

Lines changed: 60 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -43,11 +43,13 @@
4343
from dstack._internal.server.services.fleets import (
4444
create_fleet_instance_model,
4545
emit_fleet_status_change_event,
46+
get_fleet_requirements,
4647
get_fleet_spec,
4748
get_next_instance_num,
4849
is_fleet_empty,
4950
is_fleet_in_use,
5051
)
52+
from dstack._internal.server.services.instances import instance_matches_constraints
5153
from dstack._internal.server.services.locking import get_locker
5254
from dstack._internal.server.services.pipelines import PipelineHinterProtocol
5355
from dstack._internal.server.utils import sentry_utils
@@ -313,6 +315,7 @@ async def _refetch_locked_fleet_for_lock_decision(
313315
FleetModel.consolidation_attempt,
314316
FleetModel.last_consolidated_at,
315317
FleetModel.last_processed_at,
318+
FleetModel.created_at,
316319
)
317320
)
318321
.execution_options(populate_existing=True)
@@ -538,25 +541,36 @@ def _consolidate_fleet_state_with_spec(
538541
consolidation_instances: Sequence[InstanceModel],
539542
) -> _ProcessResult:
540543
result = _ProcessResult()
541-
maintain_nodes_result = _maintain_fleet_nodes_in_min_max_range(
544+
545+
spec_mismatch_updates = _terminate_instances_not_matching_fleet_spec(
542546
instances=consolidation_instances,
543547
fleet_spec=consolidation_fleet_spec,
544548
)
549+
if spec_mismatch_updates:
550+
result.instance_id_to_update_map.update(spec_mismatch_updates)
551+
552+
# Exclude spec-mismatched instances so min/max check sees only compatible instances.
553+
effective_instances = [i for i in consolidation_instances if i.id not in spec_mismatch_updates]
554+
555+
maintain_nodes_result = _maintain_fleet_nodes_in_min_max_range(
556+
instances=effective_instances,
557+
fleet_spec=consolidation_fleet_spec,
558+
)
545559
if maintain_nodes_result.has_changes:
546-
result.instance_id_to_update_map = maintain_nodes_result.instance_id_to_update_map
560+
result.instance_id_to_update_map.update(maintain_nodes_result.instance_id_to_update_map)
547561
result.new_instance_creates = maintain_nodes_result.new_instance_creates
548-
if maintain_nodes_result.changes_required:
562+
if len(spec_mismatch_updates) > 0 or maintain_nodes_result.changes_required:
549563
result.fleet_update_map["consolidation_attempt"] = fleet_model.consolidation_attempt + 1
550564
else:
551-
# The fleet is consolidated with respect to nodes min/max.
565+
# The fleet is consolidated with respect to spec and nodes min/max.
552566
result.fleet_update_map["consolidation_attempt"] = 0
553567
result.fleet_update_map["last_consolidated_at"] = NOW_PLACEHOLDER
554568
return result
555569

556570

557571
def _is_fleet_ready_for_consolidation(fleet_model: FleetModel) -> bool:
    """Return True once the backoff delay since the last consolidation has elapsed.

    The delay grows with ``consolidation_attempt`` (see
    ``_get_consolidation_retry_delay``). A fleet that has never been
    consolidated falls back to its creation time as the reference point.
    """
    retry_delay = _get_consolidation_retry_delay(fleet_model.consolidation_attempt)
    reference_time = fleet_model.last_consolidated_at or fleet_model.created_at
    return get_current_datetime() - reference_time >= retry_delay
562576

@@ -579,6 +593,47 @@ def _get_consolidation_retry_delay(consolidation_attempt: int) -> timedelta:
579593
return _CONSOLIDATION_RETRY_DELAYS[-1]
580594

581595

596+
def _terminate_instances_not_matching_fleet_spec(
    instances: Sequence[InstanceModel],
    fleet_spec: FleetSpec,
) -> dict[uuid.UUID, _InstanceUpdateMap]:
    """Build termination updates for instances that no longer match the fleet spec.

    Only instances eligible for recycling (see
    ``_can_terminate_spec_mismatched_instance``) are considered. Returns a map
    from instance id to the update dict that moves the instance into the
    ``TERMINATING`` state with a ``FLEET_SPEC_MISMATCH`` reason.
    """
    mismatch_update: _InstanceUpdateMap = {
        "status": InstanceStatus.TERMINATING,
        "termination_reason": InstanceTerminationReason.FLEET_SPEC_MISMATCH,
        "termination_reason_message": "Instance does not match fleet spec",
    }
    return {
        instance.id: dict(mismatch_update)
        for instance in instances
        if _can_terminate_spec_mismatched_instance(instance)
        and not _instance_matches_fleet_spec(instance, fleet_spec)
    }
611+
612+
613+
def _can_terminate_spec_mismatched_instance(instance: InstanceModel) -> bool:
    """Return True if the instance may be recycled on a fleet spec mismatch.

    Deleted instances are never touched. Pending instances have not selected
    an offer yet, so InstancePipeline will provision them using the current
    fleet spec — only instances already tied to the old spec (idle or
    provisioning) are eligible for recycling.
    """
    if instance.deleted:
        return False
    recyclable_statuses = (InstanceStatus.IDLE, InstanceStatus.PROVISIONING)
    return instance.status in recyclable_statuses
619+
620+
621+
def _instance_matches_fleet_spec(instance: InstanceModel, fleet_spec: FleetSpec) -> bool:
    """Check whether an instance satisfies the fleet spec's provisioning constraints.

    An instance without an offer has not been provisioned yet and will be
    provisioned using the current (updated) spec, so it is treated as a match.
    """
    if instance.offer is None:
        return True
    merged_profile = fleet_spec.merged_profile
    return instance_matches_constraints(
        instance,
        backend_types=merged_profile.backends,
        regions=merged_profile.regions,
        instance_types=merged_profile.instance_types,
        zones=merged_profile.availability_zones,
        requirements=get_fleet_requirements(fleet_spec),
    )
635+
636+
582637
def _maintain_fleet_nodes_in_min_max_range(
583638
instances: Sequence[InstanceModel],
584639
fleet_spec: FleetSpec,

src/dstack/_internal/server/services/fleets.py

Lines changed: 19 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1126,8 +1126,9 @@ async def _update_fleet(
11261126

11271127
_check_can_update_fleet_spec(fleet_sensitive.spec, spec)
11281128

1129-
spec_json = spec.json()
1130-
fleet_model.spec = spec_json
1129+
fleet_model.spec = spec.json()
1130+
# Reset consolidation attempt so the next pipeline pass picks up the spec change promptly.
1131+
fleet_model.consolidation_attempt = 0
11311132

11321133
if (
11331134
fleet_sensitive.spec.configuration.ssh_config is not None
@@ -1240,7 +1241,22 @@ def _check_can_update_fleet_configuration(current: FleetConfiguration, new: Flee
12401241
# Current in-place update only persists `target`; FleetPipeline reconciles `min`/`max`.
12411242
#
12421243
# For `reservation` and `tags`, update affects only future provisioning.
1243-
_check_can_update_inner(current, new, ("nodes", "reservation", "tags"))
1244+
_check_can_update_inner(
1245+
current,
1246+
new,
1247+
(
1248+
"nodes",
1249+
"reservation",
1250+
"tags",
1251+
"resources",
1252+
"backends",
1253+
"regions",
1254+
"availability_zones",
1255+
"instance_types",
1256+
"spot_policy",
1257+
"max_price",
1258+
),
1259+
)
12441260
return
12451261

12461262
if new_ssh_config is None:

src/dstack/_internal/server/services/instances.py

Lines changed: 49 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -357,6 +357,44 @@ def get_instance_ssh_private_keys(instance_model: InstanceModel) -> tuple[str, O
357357
return host_private_key, proxy_private_keys[0]
358358

359359

360+
def instance_matches_constraints(
    instance: InstanceModel,
    *,
    backend_types: Optional[List[BackendType]] = None,
    regions: Optional[List[str]] = None,
    instance_types: Optional[List[str]] = None,
    zones: Optional[List[str]] = None,
    requirements: Optional[Requirements] = None,
) -> bool:
    """Check if an instance matches the given provisioning constraints.

    A constraint that is ``None`` is not enforced. Backend/region/type/zone
    checks apply only when provisioning data is available; region and
    instance-type comparisons are case-insensitive, zone comparison is exact.
    The ``requirements`` check needs the instance's stored offer — an
    instance without one is rejected when requirements are given.
    """
    jpd = get_instance_provisioning_data(instance)
    if jpd is not None:
        if backend_types is not None and jpd.get_base_backend() not in backend_types:
            return False
        if regions is not None:
            allowed_regions = {r.lower() for r in regions}
            if jpd.region.lower() not in allowed_regions:
                return False
        if instance_types is not None:
            allowed_types = {t.lower() for t in instance_types}
            if jpd.instance_type.name.lower() not in allowed_types:
                return False
        # Unknown availability zone passes the zone filter.
        if zones is not None and jpd.availability_zone is not None:
            if jpd.availability_zone not in zones:
                return False

    if requirements is not None:
        if instance.offer is None:
            return False
        offer = InstanceOffer.__response__.parse_raw(instance.offer)
        catalog_item = offer_to_catalog_item(offer)
        query_filter = requirements_to_query_filter(requirements)
        if not gpuhunt.matches(catalog_item, q=query_filter):
            return False

    return True
396+
397+
360398
def filter_instances(
361399
instances: List[InstanceModel],
362400
profile: Profile,
@@ -368,9 +406,6 @@ def filter_instances(
368406
volumes: Optional[List[List[Volume]]] = None,
369407
shared: bool = False,
370408
) -> List[InstanceModel]:
371-
filtered_instances: List[InstanceModel] = []
372-
candidates: List[InstanceModel] = []
373-
374409
backend_types = profile.backends
375410
regions = profile.regions
376411
zones = profile.availability_zones
@@ -383,6 +418,7 @@ def filter_instances(
383418
v.provisioning_data.availability_zone
384419
for v in mount_point_volumes
385420
if v.provisioning_data is not None
421+
and v.provisioning_data.availability_zone is not None
386422
]
387423
if zones is None:
388424
zones = volume_zones
@@ -405,52 +441,31 @@ def filter_instances(
405441
regions = [master_job_provisioning_data.region]
406442
regions = [r for r in regions if r == master_job_provisioning_data.region]
407443

408-
if regions is not None:
409-
regions = [r.lower() for r in regions]
410444
instance_types = profile.instance_types
411-
if instance_types is not None:
412-
instance_types = [i.lower() for i in instance_types]
413445

446+
filtered_instances: List[InstanceModel] = []
414447
for instance in instances:
415448
if instance.unreachable:
416449
continue
417450
if instance.health.is_failure():
418451
continue
419452
if status is not None and instance.status != status:
420453
continue
421-
jpd = get_instance_provisioning_data(instance)
422-
if jpd is not None:
423-
if backend_types is not None and jpd.get_base_backend() not in backend_types:
424-
continue
425-
if regions is not None and jpd.region.lower() not in regions:
426-
continue
427-
if instance_types is not None and jpd.instance_type.name.lower() not in instance_types:
428-
continue
429-
if (
430-
jpd.availability_zone is not None
431-
and zones is not None
432-
and jpd.availability_zone not in zones
433-
):
434-
continue
435454
if instance.total_blocks is None:
436455
# Still provisioning, we don't know yet if it shared or not
437456
continue
438457
if (instance.total_blocks > 1) != shared:
439458
continue
440-
441-
candidates.append(instance)
442-
443-
if requirements is None:
444-
return candidates
445-
446-
query_filter = requirements_to_query_filter(requirements)
447-
for instance in candidates:
448-
if instance.offer is None:
459+
if not instance_matches_constraints(
460+
instance,
461+
backend_types=backend_types,
462+
regions=regions,
463+
instance_types=instance_types,
464+
zones=zones,
465+
requirements=requirements,
466+
):
449467
continue
450-
offer = InstanceOffer.__response__.parse_raw(instance.offer)
451-
catalog_item = offer_to_catalog_item(offer)
452-
if gpuhunt.matches(catalog_item, query_filter):
453-
filtered_instances.append(instance)
468+
filtered_instances.append(instance)
454469
return filtered_instances
455470

456471

0 commit comments

Comments
 (0)