Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/dstack/_internal/core/models/instances.py
Original file line number Diff line number Diff line change
Expand Up @@ -273,6 +273,7 @@ class InstanceTerminationReason(str, Enum):
NO_OFFERS = "no_offers"
MASTER_FAILED = "master_failed"
MAX_INSTANCES_LIMIT = "max_instances_limit"
FLEET_SPEC_MISMATCH = "fleet_spec_mismatch"
NO_BALANCE = "no_balance"
"""`NO_BALANCE` is used in dstack Sky."""

Expand Down
65 changes: 60 additions & 5 deletions src/dstack/_internal/server/background/pipeline_tasks/fleets.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,11 +43,13 @@
from dstack._internal.server.services.fleets import (
create_fleet_instance_model,
emit_fleet_status_change_event,
get_fleet_requirements,
get_fleet_spec,
get_next_instance_num,
is_fleet_empty,
is_fleet_in_use,
)
from dstack._internal.server.services.instances import instance_matches_constraints
from dstack._internal.server.services.locking import get_locker
from dstack._internal.server.services.pipelines import PipelineHinterProtocol
from dstack._internal.server.utils import sentry_utils
Expand Down Expand Up @@ -313,6 +315,7 @@ async def _refetch_locked_fleet_for_lock_decision(
FleetModel.consolidation_attempt,
FleetModel.last_consolidated_at,
FleetModel.last_processed_at,
FleetModel.created_at,
)
)
.execution_options(populate_existing=True)
Expand Down Expand Up @@ -538,25 +541,36 @@ def _consolidate_fleet_state_with_spec(
consolidation_instances: Sequence[InstanceModel],
) -> _ProcessResult:
result = _ProcessResult()
maintain_nodes_result = _maintain_fleet_nodes_in_min_max_range(

spec_mismatch_updates = _terminate_instances_not_matching_fleet_spec(
instances=consolidation_instances,
fleet_spec=consolidation_fleet_spec,
)
if spec_mismatch_updates:
result.instance_id_to_update_map.update(spec_mismatch_updates)

# Exclude spec-mismatched instances so min/max check sees only compatible instances.
effective_instances = [i for i in consolidation_instances if i.id not in spec_mismatch_updates]

maintain_nodes_result = _maintain_fleet_nodes_in_min_max_range(
instances=effective_instances,
fleet_spec=consolidation_fleet_spec,
)
if maintain_nodes_result.has_changes:
result.instance_id_to_update_map = maintain_nodes_result.instance_id_to_update_map
result.instance_id_to_update_map.update(maintain_nodes_result.instance_id_to_update_map)
result.new_instance_creates = maintain_nodes_result.new_instance_creates
if maintain_nodes_result.changes_required:
if len(spec_mismatch_updates) > 0 or maintain_nodes_result.changes_required:
result.fleet_update_map["consolidation_attempt"] = fleet_model.consolidation_attempt + 1
else:
# The fleet is consolidated with respect to nodes min/max.
# The fleet is consolidated with respect to spec and nodes min/max.
result.fleet_update_map["consolidation_attempt"] = 0
result.fleet_update_map["last_consolidated_at"] = NOW_PLACEHOLDER
return result


def _is_fleet_ready_for_consolidation(fleet_model: FleetModel) -> bool:
    """Return True once enough time has passed to retry fleet consolidation.

    The retry delay grows with the number of consolidation attempts. A fleet
    that has never been consolidated falls back to its creation time as the
    reference point.
    """
    retry_delay = _get_consolidation_retry_delay(fleet_model.consolidation_attempt)
    reference_time = fleet_model.last_consolidated_at or fleet_model.created_at
    elapsed = get_current_datetime() - reference_time
    return elapsed >= retry_delay

Expand All @@ -579,6 +593,47 @@ def _get_consolidation_retry_delay(consolidation_attempt: int) -> timedelta:
return _CONSOLIDATION_RETRY_DELAYS[-1]


def _terminate_instances_not_matching_fleet_spec(
    instances: Sequence[InstanceModel],
    fleet_spec: FleetSpec,
) -> dict[uuid.UUID, _InstanceUpdateMap]:
    """Build termination updates for instances incompatible with the fleet spec.

    Only instances that are safe to recycle are considered; each mismatched
    one is mapped to an update marking it TERMINATING with the
    FLEET_SPEC_MISMATCH reason.
    """
    updates: dict[uuid.UUID, _InstanceUpdateMap] = {}
    recyclable = (i for i in instances if _can_terminate_spec_mismatched_instance(i))
    for candidate in recyclable:
        if _instance_matches_fleet_spec(candidate, fleet_spec):
            continue
        updates[candidate.id] = {
            "status": InstanceStatus.TERMINATING,
            "termination_reason": InstanceTerminationReason.FLEET_SPEC_MISMATCH,
            "termination_reason_message": "Instance does not match fleet spec",
        }
    return updates


def _can_terminate_spec_mismatched_instance(instance: InstanceModel) -> bool:
    """Return True if the instance may be recycled on a fleet spec mismatch."""
    if instance.deleted:
        return False
    # Pending instances have not selected an offer yet, so InstancePipeline
    # will provision them using the current fleet spec. Recycle only instances
    # already tied to the old spec (idle or mid-provisioning).
    recyclable_statuses = (InstanceStatus.IDLE, InstanceStatus.PROVISIONING)
    return instance.status in recyclable_statuses


def _instance_matches_fleet_spec(instance: InstanceModel, fleet_spec: FleetSpec) -> bool:
    """Return True when the instance is compatible with the fleet spec.

    An instance without a stored offer has not been provisioned yet and will
    pick up the current (updated) spec on provisioning, so it always matches.
    """
    if instance.offer is None:
        return True
    merged = fleet_spec.merged_profile
    return instance_matches_constraints(
        instance,
        backend_types=merged.backends,
        regions=merged.regions,
        instance_types=merged.instance_types,
        zones=merged.availability_zones,
        requirements=get_fleet_requirements(fleet_spec),
    )


def _maintain_fleet_nodes_in_min_max_range(
instances: Sequence[InstanceModel],
fleet_spec: FleetSpec,
Expand Down
22 changes: 19 additions & 3 deletions src/dstack/_internal/server/services/fleets.py
Original file line number Diff line number Diff line change
Expand Up @@ -1126,8 +1126,9 @@ async def _update_fleet(

_check_can_update_fleet_spec(fleet_sensitive.spec, spec)

spec_json = spec.json()
fleet_model.spec = spec_json
fleet_model.spec = spec.json()
# Reset consolidation attempt so the next pipeline pass picks up the spec change promptly.
fleet_model.consolidation_attempt = 0

if (
fleet_sensitive.spec.configuration.ssh_config is not None
Expand Down Expand Up @@ -1240,7 +1241,22 @@ def _check_can_update_fleet_configuration(current: FleetConfiguration, new: Flee
# Current in-place update only persists `target`; FleetPipeline reconciles `min`/`max`.
#
# For `reservation` and `tags`, update affects only future provisioning.
_check_can_update_inner(current, new, ("nodes", "reservation", "tags"))
_check_can_update_inner(
current,
new,
(
"nodes",
"reservation",
"tags",
"resources",
"backends",
"regions",
"availability_zones",
"instance_types",
"spot_policy",
"max_price",
),
)
return

if new_ssh_config is None:
Expand Down
83 changes: 49 additions & 34 deletions src/dstack/_internal/server/services/instances.py
Original file line number Diff line number Diff line change
Expand Up @@ -357,6 +357,44 @@ def get_instance_ssh_private_keys(instance_model: InstanceModel) -> tuple[str, O
return host_private_key, proxy_private_keys[0]


def instance_matches_constraints(
    instance: InstanceModel,
    *,
    backend_types: Optional[List[BackendType]] = None,
    regions: Optional[List[str]] = None,
    instance_types: Optional[List[str]] = None,
    zones: Optional[List[str]] = None,
    requirements: Optional[Requirements] = None,
) -> bool:
    """Check if an instance matches the given provisioning constraints.

    A constraint set to ``None`` is unrestricted. Constraints derived from
    provisioning data are skipped entirely when the instance has no
    provisioning data yet. The ``requirements`` check fails when the
    instance has no stored offer to match against.
    """
    provisioning_data = get_instance_provisioning_data(instance)
    if provisioning_data is not None:
        if backend_types is not None:
            if provisioning_data.get_base_backend() not in backend_types:
                return False
        if regions is not None:
            # Region names are compared case-insensitively.
            allowed_regions = {r.lower() for r in regions}
            if provisioning_data.region.lower() not in allowed_regions:
                return False
        if instance_types is not None:
            allowed_types = {t.lower() for t in instance_types}
            if provisioning_data.instance_type.name.lower() not in allowed_types:
                return False
        zone = provisioning_data.availability_zone
        # An instance with no known zone passes the zone constraint.
        if zones is not None and zone is not None and zone not in zones:
            return False

    if requirements is None:
        return True
    if instance.offer is None:
        return False
    offer = InstanceOffer.__response__.parse_raw(instance.offer)
    catalog_item = offer_to_catalog_item(offer)
    query_filter = requirements_to_query_filter(requirements)
    return gpuhunt.matches(catalog_item, q=query_filter)


def filter_instances(
instances: List[InstanceModel],
profile: Profile,
Expand All @@ -368,9 +406,6 @@ def filter_instances(
volumes: Optional[List[List[Volume]]] = None,
shared: bool = False,
) -> List[InstanceModel]:
filtered_instances: List[InstanceModel] = []
candidates: List[InstanceModel] = []

backend_types = profile.backends
regions = profile.regions
zones = profile.availability_zones
Expand All @@ -383,6 +418,7 @@ def filter_instances(
v.provisioning_data.availability_zone
for v in mount_point_volumes
if v.provisioning_data is not None
and v.provisioning_data.availability_zone is not None
]
if zones is None:
zones = volume_zones
Expand All @@ -405,52 +441,31 @@ def filter_instances(
regions = [master_job_provisioning_data.region]
regions = [r for r in regions if r == master_job_provisioning_data.region]

if regions is not None:
regions = [r.lower() for r in regions]
instance_types = profile.instance_types
if instance_types is not None:
instance_types = [i.lower() for i in instance_types]

filtered_instances: List[InstanceModel] = []
for instance in instances:
if instance.unreachable:
continue
if instance.health.is_failure():
continue
if status is not None and instance.status != status:
continue
jpd = get_instance_provisioning_data(instance)
if jpd is not None:
if backend_types is not None and jpd.get_base_backend() not in backend_types:
continue
if regions is not None and jpd.region.lower() not in regions:
continue
if instance_types is not None and jpd.instance_type.name.lower() not in instance_types:
continue
if (
jpd.availability_zone is not None
and zones is not None
and jpd.availability_zone not in zones
):
continue
if instance.total_blocks is None:
# Still provisioning, we don't know yet if it shared or not
continue
if (instance.total_blocks > 1) != shared:
continue

candidates.append(instance)

if requirements is None:
return candidates

query_filter = requirements_to_query_filter(requirements)
for instance in candidates:
if instance.offer is None:
if not instance_matches_constraints(
instance,
backend_types=backend_types,
regions=regions,
instance_types=instance_types,
zones=zones,
requirements=requirements,
):
continue
offer = InstanceOffer.__response__.parse_raw(instance.offer)
catalog_item = offer_to_catalog_item(offer)
if gpuhunt.matches(catalog_item, query_filter):
filtered_instances.append(instance)
filtered_instances.append(instance)
return filtered_instances


Expand Down
Loading
Loading