diff --git a/src/dstack/_internal/core/models/instances.py b/src/dstack/_internal/core/models/instances.py
index 11a1aca51..8820e0bc1 100644
--- a/src/dstack/_internal/core/models/instances.py
+++ b/src/dstack/_internal/core/models/instances.py
@@ -273,6 +273,7 @@ class InstanceTerminationReason(str, Enum):
     NO_OFFERS = "no_offers"
     MASTER_FAILED = "master_failed"
     MAX_INSTANCES_LIMIT = "max_instances_limit"
+    FLEET_SPEC_MISMATCH = "fleet_spec_mismatch"
     NO_BALANCE = "no_balance"
     """`NO_BALANCE` is used in dstack Sky."""
diff --git a/src/dstack/_internal/server/background/pipeline_tasks/fleets.py b/src/dstack/_internal/server/background/pipeline_tasks/fleets.py
index a6ebca9f0..e2a4e6a55 100644
--- a/src/dstack/_internal/server/background/pipeline_tasks/fleets.py
+++ b/src/dstack/_internal/server/background/pipeline_tasks/fleets.py
@@ -43,11 +43,13 @@
 from dstack._internal.server.services.fleets import (
     create_fleet_instance_model,
     emit_fleet_status_change_event,
+    get_fleet_requirements,
     get_fleet_spec,
     get_next_instance_num,
     is_fleet_empty,
     is_fleet_in_use,
 )
+from dstack._internal.server.services.instances import instance_matches_constraints
 from dstack._internal.server.services.locking import get_locker
 from dstack._internal.server.services.pipelines import PipelineHinterProtocol
 from dstack._internal.server.utils import sentry_utils
@@ -313,6 +315,7 @@ async def _refetch_locked_fleet_for_lock_decision(
                 FleetModel.consolidation_attempt,
                 FleetModel.last_consolidated_at,
                 FleetModel.last_processed_at,
+                FleetModel.created_at,
             )
         )
         .execution_options(populate_existing=True)
@@ -538,17 +541,28 @@ def _consolidate_fleet_state_with_spec(
     consolidation_instances: Sequence[InstanceModel],
 ) -> _ProcessResult:
     result = _ProcessResult()
-    maintain_nodes_result = _maintain_fleet_nodes_in_min_max_range(
+
+    spec_mismatch_updates = _terminate_instances_not_matching_fleet_spec(
         instances=consolidation_instances,
         fleet_spec=consolidation_fleet_spec,
     )
+    if spec_mismatch_updates:
+        result.instance_id_to_update_map.update(spec_mismatch_updates)
+
+    # Exclude spec-mismatched instances so the min/max check sees only compatible instances.
+    effective_instances = [i for i in consolidation_instances if i.id not in spec_mismatch_updates]
+
+    maintain_nodes_result = _maintain_fleet_nodes_in_min_max_range(
+        instances=effective_instances,
+        fleet_spec=consolidation_fleet_spec,
+    )
     if maintain_nodes_result.has_changes:
-        result.instance_id_to_update_map = maintain_nodes_result.instance_id_to_update_map
+        result.instance_id_to_update_map.update(maintain_nodes_result.instance_id_to_update_map)
         result.new_instance_creates = maintain_nodes_result.new_instance_creates
-    if maintain_nodes_result.changes_required:
+    if len(spec_mismatch_updates) > 0 or maintain_nodes_result.changes_required:
         result.fleet_update_map["consolidation_attempt"] = fleet_model.consolidation_attempt + 1
     else:
-        # The fleet is consolidated with respect to nodes min/max.
+        # The fleet is consolidated with respect to spec and nodes min/max.
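+        # Resetting the attempt counter also resets the retry backoff that
+        # _is_fleet_ready_for_consolidation derives from consolidation_attempt.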
         result.fleet_update_map["consolidation_attempt"] = 0
         result.fleet_update_map["last_consolidated_at"] = NOW_PLACEHOLDER
     return result
@@ -556,7 +570,7 @@ def _is_fleet_ready_for_consolidation(fleet_model: FleetModel) -> bool:
     consolidation_retry_delay = _get_consolidation_retry_delay(fleet_model.consolidation_attempt)
-    last_consolidated_at = fleet_model.last_consolidated_at or fleet_model.last_processed_at
+    last_consolidated_at = fleet_model.last_consolidated_at or fleet_model.created_at
     duration_since_last_consolidation = get_current_datetime() - last_consolidated_at
     return duration_since_last_consolidation >= consolidation_retry_delay
@@ -579,6 +593,47 @@ def _get_consolidation_retry_delay(consolidation_attempt: int) -> timedelta:
     return _CONSOLIDATION_RETRY_DELAYS[-1]
 
 
+def _terminate_instances_not_matching_fleet_spec(
+    instances: Sequence[InstanceModel],
+    fleet_spec: FleetSpec,
+) -> dict[uuid.UUID, _InstanceUpdateMap]:
+    updates: dict[uuid.UUID, _InstanceUpdateMap] = {}
+    for instance in instances:
+        if not _can_terminate_spec_mismatched_instance(instance):
+            continue
+        if not _instance_matches_fleet_spec(instance, fleet_spec):
+            updates[instance.id] = {
+                "status": InstanceStatus.TERMINATING,
+                "termination_reason": InstanceTerminationReason.FLEET_SPEC_MISMATCH,
+                "termination_reason_message": "Instance does not match fleet spec",
+            }
+    return updates
+
+
+def _can_terminate_spec_mismatched_instance(instance: InstanceModel) -> bool:
+    if instance.deleted:
+        return False
+    # Pending instances have not selected an offer yet, so InstancePipeline will provision them
+    # using the current fleet spec. Recycle only instances already tied to the old spec.
+    return instance.status in (InstanceStatus.IDLE, InstanceStatus.PROVISIONING)
+
+
+def _instance_matches_fleet_spec(instance: InstanceModel, fleet_spec: FleetSpec) -> bool:
+    if instance.offer is None:
+        # Not yet provisioned — will be provisioned using the current (updated) spec.
+        return True
+    profile = fleet_spec.merged_profile
+    requirements = get_fleet_requirements(fleet_spec)
+    return instance_matches_constraints(
+        instance,
+        backend_types=profile.backends,
+        regions=profile.regions,
+        instance_types=profile.instance_types,
+        zones=profile.availability_zones,
+        requirements=requirements,
+    )
+
+
 def _maintain_fleet_nodes_in_min_max_range(
     instances: Sequence[InstanceModel],
     fleet_spec: FleetSpec,
diff --git a/src/dstack/_internal/server/services/fleets.py b/src/dstack/_internal/server/services/fleets.py
index e0cdaa081..5d6a551ce 100644
--- a/src/dstack/_internal/server/services/fleets.py
+++ b/src/dstack/_internal/server/services/fleets.py
@@ -1126,8 +1126,9 @@ async def _update_fleet(
 
     _check_can_update_fleet_spec(fleet_sensitive.spec, spec)
 
-    spec_json = spec.json()
-    fleet_model.spec = spec_json
+    fleet_model.spec = spec.json()
+    # Reset consolidation attempt so the next pipeline pass picks up the spec change promptly.
+    fleet_model.consolidation_attempt = 0
 
     if (
         fleet_sensitive.spec.configuration.ssh_config is not None
@@ -1240,7 +1241,22 @@ def _check_can_update_fleet_configuration(current: FleetConfiguration, new: Flee
         # Current in-place update only persists `target`; FleetPipeline reconciles `min`/`max`.
         #
         # For `reservation` and `tags`, update affects only future provisioning.
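+        # Provisioning fields (resources, backends, regions, etc.) can now be
+        # updated in place: FleetPipeline recycles existing instances that no
+        # longer match the updated spec.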
-        _check_can_update_inner(current, new, ("nodes", "reservation", "tags"))
+        _check_can_update_inner(
+            current,
+            new,
+            (
+                "nodes",
+                "reservation",
+                "tags",
+                "resources",
+                "backends",
+                "regions",
+                "availability_zones",
+                "instance_types",
+                "spot_policy",
+                "max_price",
+            ),
+        )
         return
 
     if new_ssh_config is None:
diff --git a/src/dstack/_internal/server/services/instances.py b/src/dstack/_internal/server/services/instances.py
index d54ec8b68..23b42520d 100644
--- a/src/dstack/_internal/server/services/instances.py
+++ b/src/dstack/_internal/server/services/instances.py
@@ -357,6 +357,44 @@ def get_instance_ssh_private_keys(instance_model: InstanceModel) -> tuple[str, O
     return host_private_key, proxy_private_keys[0]
 
 
+def instance_matches_constraints(
+    instance: InstanceModel,
+    *,
+    backend_types: Optional[List[BackendType]] = None,
+    regions: Optional[List[str]] = None,
+    instance_types: Optional[List[str]] = None,
+    zones: Optional[List[str]] = None,
+    requirements: Optional[Requirements] = None,
+) -> bool:
+    """Check if an instance matches the given provisioning constraints."""
+    jpd = get_instance_provisioning_data(instance)
+    if jpd is not None:
+        if backend_types is not None and jpd.get_base_backend() not in backend_types:
+            return False
+        if regions is not None and jpd.region.lower() not in [r.lower() for r in regions]:
+            return False
+        if instance_types is not None and jpd.instance_type.name.lower() not in [
+            i.lower() for i in instance_types
+        ]:
+            return False
+        if (
+            jpd.availability_zone is not None
+            and zones is not None
+            and jpd.availability_zone not in zones
+        ):
+            return False
+
+    if requirements is not None:
+        if instance.offer is None:
+            return False
+        offer = InstanceOffer.__response__.parse_raw(instance.offer)
+        catalog_item = offer_to_catalog_item(offer)
+        if not gpuhunt.matches(catalog_item, q=requirements_to_query_filter(requirements)):
+            return False
+
+    return True
+
+
 def filter_instances(
     instances: List[InstanceModel],
     profile: Profile,
@@ -368,9 +406,6 @@ def filter_instances(
     volumes: Optional[List[List[Volume]]] = None,
     shared: bool = False,
 ) -> List[InstanceModel]:
-    filtered_instances: List[InstanceModel] = []
-    candidates: List[InstanceModel] = []
-
     backend_types = profile.backends
     regions = profile.regions
     zones = profile.availability_zones
@@ -383,6 +418,7 @@ def filter_instances(
             v.provisioning_data.availability_zone
             for v in mount_point_volumes
             if v.provisioning_data is not None
+            and v.provisioning_data.availability_zone is not None
         ]
         if zones is None:
             zones = volume_zones
@@ -405,12 +441,9 @@ def filter_instances(
             regions = [master_job_provisioning_data.region]
         regions = [r for r in regions if r == master_job_provisioning_data.region]
 
-    if regions is not None:
-        regions = [r.lower() for r in regions]
     instance_types = profile.instance_types
-    if instance_types is not None:
-        instance_types = [i.lower() for i in instance_types]
 
+    filtered_instances: List[InstanceModel] = []
     for instance in instances:
         if instance.unreachable:
             continue
@@ -418,39 +451,21 @@ def filter_instances(
             continue
         if status is not None and instance.status != status:
             continue
-        jpd = get_instance_provisioning_data(instance)
-        if jpd is not None:
-            if backend_types is not None and jpd.get_base_backend() not in backend_types:
-                continue
-            if regions is not None and jpd.region.lower() not in regions:
-                continue
-            if instance_types is not None and jpd.instance_type.name.lower() not in instance_types:
-                continue
-            if (
-                jpd.availability_zone is not None
-                and zones is not None
-                and jpd.availability_zone not in zones
-            ):
-                continue
         if instance.total_blocks is None:
             # Still provisioning, we don't know yet if it shared or not
             continue
         if (instance.total_blocks > 1) != shared:
             continue
-
-        candidates.append(instance)
-
-    if requirements is None:
-        return candidates
-
-    query_filter = requirements_to_query_filter(requirements)
-    for instance in candidates:
-        if instance.offer is None:
+        if not instance_matches_constraints(
+            instance,
+            backend_types=backend_types,
+            regions=regions,
+            instance_types=instance_types,
+            zones=zones,
+            requirements=requirements,
+        ):
             continue
-        offer = InstanceOffer.__response__.parse_raw(instance.offer)
-        catalog_item = offer_to_catalog_item(offer)
-        if gpuhunt.matches(catalog_item, query_filter):
-            filtered_instances.append(instance)
+        filtered_instances.append(instance)
 
     return filtered_instances
diff --git a/src/tests/_internal/server/background/pipeline_tasks/test_fleets.py b/src/tests/_internal/server/background/pipeline_tasks/test_fleets.py
index 54e2c7a78..1acceeeec 100644
--- a/src/tests/_internal/server/background/pipeline_tasks/test_fleets.py
+++ b/src/tests/_internal/server/background/pipeline_tasks/test_fleets.py
@@ -7,6 +7,7 @@
 from sqlalchemy import select
 from sqlalchemy.ext.asyncio import AsyncSession
 
+from dstack._internal.core.models.backends.base import BackendType
 from dstack._internal.core.models.fleets import (
     FleetNodesSpec,
     FleetStatus,
@@ -1215,3 +1216,196 @@ async def test_consolidation_attempt_resets_when_no_changes(
         last_consolidated_at = fleet.last_consolidated_at
         assert last_consolidated_at
         assert last_consolidated_at > previous_last_consolidated_at
+
+    async def test_consolidation_terminates_idle_instances_not_matching_fleet_spec(
+        self, test_db, session: AsyncSession, worker: FleetWorker
+    ):
+        project = await create_project(session)
+        spec = get_fleet_spec()
+        spec.configuration.nodes = FleetNodesSpec(min=1, target=1, max=2)
+        spec.configuration.backends = [BackendType.AWS]
+        fleet = await create_fleet(
+            session=session,
+            project=project,
+            spec=spec,
+        )
+        matching_instance = await create_instance(
+            session=session,
+            project=project,
+            fleet=fleet,
+            status=InstanceStatus.IDLE,
+            backend=BackendType.AWS,
+            instance_num=0,
+        )
+        mismatched_instance = await create_instance(
+            session=session,
+            project=project,
+            fleet=fleet,
+            status=InstanceStatus.IDLE,
+            backend=BackendType.GCP,
+            instance_num=1,
+        )
+
+        fleet.lock_token = uuid.uuid4()
+        fleet.lock_expires_at = datetime(2025, 1, 2, 3, 4, tzinfo=timezone.utc)
+        await session.commit()
+
+        await worker.process(_fleet_to_pipeline_item(fleet))
+
+        await session.refresh(matching_instance)
+        await session.refresh(mismatched_instance)
+        assert matching_instance.status == InstanceStatus.IDLE
+        assert mismatched_instance.status == InstanceStatus.TERMINATING
+        assert (
+            mismatched_instance.termination_reason == InstanceTerminationReason.FLEET_SPEC_MISMATCH
+        )
+
+    async def test_consolidation_preserves_pending_instances_without_offer(
+        self, test_db, session: AsyncSession, worker: FleetWorker
+    ):
+        project = await create_project(session)
+        spec = get_fleet_spec()
+        spec.configuration.nodes = FleetNodesSpec(min=1, target=1, max=1)
+        spec.configuration.backends = [BackendType.AWS]
+        fleet = await create_fleet(
+            session=session,
+            project=project,
+            spec=spec,
+        )
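+        # No offer and no provisioning data: the instance has not been matched
+        # to an offer yet, so it cannot be judged against (or recycled for) the
+        # new spec.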
+        pending_instance = await create_instance(
+            session=session,
+            project=project,
+            fleet=fleet,
+            status=InstanceStatus.PENDING,
+            instance_num=0,
+            offer=None,
+            job_provisioning_data=None,
+        )
+
+        fleet.lock_token = uuid.uuid4()
+        fleet.lock_expires_at = datetime(2025, 1, 2, 3, 4, tzinfo=timezone.utc)
+        await session.commit()
+
+        await worker.process(_fleet_to_pipeline_item(fleet))
+
+        await session.refresh(pending_instance)
+        assert pending_instance.status == InstanceStatus.PENDING
+
+    async def test_consolidation_preserves_busy_instances_not_matching_fleet_spec(
+        self, test_db, session: AsyncSession, worker: FleetWorker
+    ):
+        project = await create_project(session)
+        spec = get_fleet_spec()
+        spec.configuration.nodes = FleetNodesSpec(min=1, target=1, max=1)
+        spec.configuration.backends = [BackendType.AWS]
+        fleet = await create_fleet(
+            session=session,
+            project=project,
+            spec=spec,
+        )
+        busy_instance = await create_instance(
+            session=session,
+            project=project,
+            fleet=fleet,
+            status=InstanceStatus.BUSY,
+            backend=BackendType.GCP,
+            instance_num=0,
+        )
+
+        fleet.lock_token = uuid.uuid4()
+        fleet.lock_expires_at = datetime(2025, 1, 2, 3, 4, tzinfo=timezone.utc)
+        await session.commit()
+
+        await worker.process(_fleet_to_pipeline_item(fleet))
+
+        await session.refresh(busy_instance)
+        assert busy_instance.status == InstanceStatus.BUSY
+
+    async def test_consolidation_creates_replacements_after_spec_mismatch_termination(
+        self, test_db, session: AsyncSession, worker: FleetWorker
+    ):
+        project = await create_project(session)
+        spec = get_fleet_spec()
+        spec.configuration.nodes = FleetNodesSpec(min=2, target=2, max=2)
+        spec.configuration.backends = [BackendType.AWS]
+        fleet = await create_fleet(
+            session=session,
+            project=project,
+            spec=spec,
+        )
+        instance1 = await create_instance(
+            session=session,
+            project=project,
+            fleet=fleet,
+            status=InstanceStatus.IDLE,
+            backend=BackendType.GCP,
+            instance_num=0,
+        )
+        instance2 = await create_instance(
+            session=session,
+            project=project,
+            fleet=fleet,
+            status=InstanceStatus.IDLE,
+            backend=BackendType.GCP,
+            instance_num=1,
+        )
+
+        fleet.lock_token = uuid.uuid4()
+        fleet.lock_expires_at = datetime(2025, 1, 2, 3, 4, tzinfo=timezone.utc)
+        await session.commit()
+
+        await worker.process(_fleet_to_pipeline_item(fleet))
+
+        await session.refresh(fleet)
+        await session.refresh(instance1)
+        await session.refresh(instance2)
+        assert instance1.status == InstanceStatus.TERMINATING
+        assert instance2.status == InstanceStatus.TERMINATING
+        # New replacement instances should be created to satisfy nodes.min=2
+        all_instances = (
+            (
+                await session.execute(
+                    select(InstanceModel).where(
+                        InstanceModel.fleet_id == fleet.id,
+                        InstanceModel.deleted == False,
+                    )
+                )
+            )
+            .scalars()
+            .all()
+        )
+        new_instances = [i for i in all_instances if i.status == InstanceStatus.PENDING]
+        assert len(new_instances) == 2
+        assert fleet.consolidation_attempt == 1
+
+    async def test_consolidation_preserves_instances_matching_fleet_spec(
+        self, test_db, session: AsyncSession, worker: FleetWorker
+    ):
+        project = await create_project(session)
+        spec = get_fleet_spec()
+        spec.configuration.nodes = FleetNodesSpec(min=1, target=1, max=1)
+        spec.configuration.backends = [BackendType.AWS]
+        fleet = await create_fleet(
+            session=session,
+            project=project,
+            spec=spec,
+        )
+        instance = await create_instance(
+            session=session,
+            project=project,
+            fleet=fleet,
+            status=InstanceStatus.IDLE,
+            backend=BackendType.AWS,
+            instance_num=0,
+        )
+
+        fleet.lock_token = uuid.uuid4()
+        fleet.lock_expires_at = datetime(2025, 1, 2, 3, 4, tzinfo=timezone.utc)
+        await session.commit()
+
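+        # Nothing to recycle and min/max already satisfied: the pipeline should
+        # make no changes and reset consolidation_attempt to 0.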
+        await worker.process(_fleet_to_pipeline_item(fleet))
+
+        await session.refresh(fleet)
+        await session.refresh(instance)
+        assert instance.status == InstanceStatus.IDLE
+        assert fleet.consolidation_attempt == 0
diff --git a/src/tests/_internal/server/routers/test_fleets.py b/src/tests/_internal/server/routers/test_fleets.py
index e7d603861..5c5cef8c6 100644
--- a/src/tests/_internal/server/routers/test_fleets.py
+++ b/src/tests/_internal/server/routers/test_fleets.py
@@ -2276,6 +2276,35 @@ async def test_returns_create_plan_for_existing_cloud_fleet_blocks_update(
         assert response_json["current_resource"]["id"] == str(fleet.id)
         assert response_json["action"] == "create"
 
+    @pytest.mark.asyncio
+    @pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True)
+    async def test_returns_update_plan_for_existing_cloud_fleet_provisioning_fields_update(
+        self, test_db, session: AsyncSession, client: AsyncClient
+    ):
+        user = await create_user(session=session, global_role=GlobalRole.USER)
+        project = await create_project(session=session, owner=user)
+        await add_project_member(
+            session=session, project=project, user=user, project_role=ProjectRole.USER
+        )
+        current_spec = get_fleet_spec(
+            conf=get_fleet_configuration(nodes=FleetNodesSpec(min=0, target=0, max=1))
+        )
+        spec = current_spec.copy(deep=True)
+        spec.configuration.backends = [BackendType.AWS]
+        spec.configuration.regions = ["us-east-1"]
+        fleet = await create_fleet(session=session, project=project, spec=current_spec)
+
+        response = await client.post(
+            f"/api/project/{project.name}/fleets/get_plan",
+            headers=get_auth_headers(user.token),
+            json={"spec": spec.dict()},
+        )
+
+        response_json = response.json()
+        assert response.status_code == 200, response_json
+        assert response_json["current_resource"]["id"] == str(fleet.id)
+        assert response_json["action"] == "update"
+
     @pytest.mark.asyncio
     @pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True)
     async def test_returns_create_plan_for_existing_fleet(