diff --git a/src/dstack/_internal/core/models/runs.py b/src/dstack/_internal/core/models/runs.py index 56f27b24e..4d542fc0e 100644 --- a/src/dstack/_internal/core/models/runs.py +++ b/src/dstack/_internal/core/models/runs.py @@ -134,6 +134,7 @@ class JobTerminationReason(str, Enum): FAILED_TO_START_DUE_TO_NO_CAPACITY = "failed_to_start_due_to_no_capacity" INTERRUPTED_BY_NO_CAPACITY = "interrupted_by_no_capacity" INSTANCE_UNREACHABLE = "instance_unreachable" + INSTANCE_ACCESS_REVOKED = "instance_access_revoked" WAITING_INSTANCE_LIMIT_EXCEEDED = "waiting_instance_limit_exceeded" WAITING_RUNNER_LIMIT_EXCEEDED = "waiting_runner_limit_exceeded" TERMINATED_BY_USER = "terminated_by_user" @@ -158,6 +159,7 @@ def to_status(self) -> JobStatus: self.FAILED_TO_START_DUE_TO_NO_CAPACITY: JobStatus.FAILED, self.INTERRUPTED_BY_NO_CAPACITY: JobStatus.FAILED, self.INSTANCE_UNREACHABLE: JobStatus.FAILED, + self.INSTANCE_ACCESS_REVOKED: JobStatus.FAILED, self.WAITING_INSTANCE_LIMIT_EXCEEDED: JobStatus.FAILED, self.WAITING_RUNNER_LIMIT_EXCEEDED: JobStatus.FAILED, self.TERMINATED_BY_USER: JobStatus.TERMINATED, @@ -196,6 +198,7 @@ def to_error(self) -> Optional[str]: # handled and shown in status_message. error_mapping = { JobTerminationReason.INSTANCE_UNREACHABLE: "instance unreachable", + JobTerminationReason.INSTANCE_ACCESS_REVOKED: "instance access revoked", JobTerminationReason.WAITING_INSTANCE_LIMIT_EXCEEDED: "waiting instance limit exceeded", JobTerminationReason.WAITING_RUNNER_LIMIT_EXCEEDED: "waiting runner limit exceeded", JobTerminationReason.VOLUME_ERROR: "volume error", diff --git a/src/dstack/_internal/server/background/pipeline_tasks/jobs_running.py b/src/dstack/_internal/server/background/pipeline_tasks/jobs_running.py index 7e19b6cec..c813ee93c 100644 --- a/src/dstack/_internal/server/background/pipeline_tasks/jobs_running.py +++ b/src/dstack/_internal/server/background/pipeline_tasks/jobs_running.py @@ -6,7 +6,7 @@ from typing import Dict, Iterable, Literal, Optional, Sequence, Union import httpx -from sqlalchemy import and_, func, or_, select, update +from sqlalchemy import and_, exists, func, or_, select, update from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.orm import aliased, contains_eager, joinedload, load_only @@ -51,7 +51,9 @@ from dstack._internal.server.background.pipeline_tasks.common import get_provisioning_timeout from dstack._internal.server.db import get_db, get_session_ctx from dstack._internal.server.models import ( + ExportedFleetModel, FleetModel, + ImportModel, InstanceModel, JobModel, ProbeModel, @@ -309,6 +311,7 @@ class _ProcessContext: job: Job job_submission: JobSubmission job_provisioning_data: Optional[JobProvisioningData] + instance_access_revoked: bool server_ssh_private_keys: Optional[tuple[str, Optional[str]]] = None @property @@ -374,6 +377,7 @@ async def _load_process_context(item: JobRunningPipelineItem) -> Optional[_Proce ) run = run_model_to_run(run_model, include_sensitive=True) job = find_job(run.jobs, job_model.replica_num, job_model.job_num) + instance_access_revoked = await _is_instance_access_revoked(session, job_model) job_submission = job_model_to_job_submission(job_model) server_ssh_private_keys = get_instance_ssh_private_keys(get_or_error(job_model.instance)) return _ProcessContext( @@ -383,12 +387,24 @@ async def _load_process_context(item: JobRunningPipelineItem) -> Optional[_Proce job=job, job_submission=job_submission, job_provisioning_data=job_submission.job_provisioning_data, + instance_access_revoked=instance_access_revoked, server_ssh_private_keys=server_ssh_private_keys, ) async def _process_running_job(context: _ProcessContext) -> _ProcessResult: result = _ProcessResult() + if context.instance_access_revoked: + _terminate_job( + job_model=context.job_model, + job_update_map=result.job_update_map, + termination_reason=JobTerminationReason.INSTANCE_ACCESS_REVOKED, + termination_reason_message=( + "The instance is no longer imported into the job's project" + ), + ) + return result + if context.job_provisioning_data is None: logger.error("%s: job_provisioning_data of an active job is None", fmt(context.job_model)) _terminate_job( @@ -559,6 +575,22 @@ async def _fetch_run_model( return res.unique().scalar_one() +async def _is_instance_access_revoked(session: AsyncSession, job_model: JobModel) -> bool: + if job_model.instance is None or job_model.instance.project_id == job_model.project_id: + return False + return not ( + await session.execute( + select( + exists().where( + ImportModel.project_id == job_model.project_id, + ImportModel.export_id == ExportedFleetModel.export_id, + ExportedFleetModel.fleet_id == job_model.instance.fleet_id, + ) + ) + ) + ).scalar() + + async def _process_provisioning_status( context: _ProcessContext, startup_context: _StartupContext, diff --git a/src/tests/_internal/server/background/pipeline_tasks/test_running_jobs.py b/src/tests/_internal/server/background/pipeline_tasks/test_running_jobs.py index 676196f87..a3423f8cf 100644 --- a/src/tests/_internal/server/background/pipeline_tasks/test_running_jobs.py +++ b/src/tests/_internal/server/background/pipeline_tasks/test_running_jobs.py @@ -1995,6 +1995,105 @@ async def test_registers_service_replica_in_gateway_when_running_on_imported_ins ssh_head_proxy_private_key=None, ) + @pytest.mark.parametrize("job_status", [JobStatus.RUNNING, JobStatus.PULLING]) + async def test_terminates_job_when_instance_access_revoked( + self, + test_db, + session: AsyncSession, + worker: JobRunningWorker, + job_status: JobStatus, + ): + user = await create_user(session=session) + exporter_project = await create_project(session=session, name="exporter", owner=user) + importer_project = await create_project(session=session, name="importer", owner=user) + fleet = await create_fleet(session=session, project=exporter_project) + instance = await create_instance( + session=session, + project=exporter_project, + status=InstanceStatus.BUSY, + fleet=fleet, + ) + repo = await create_repo(session=session, project_id=importer_project.id) + run = await create_run( + session=session, + project=importer_project, + repo=repo, + user=user, + ) + job = await create_job( + session=session, + run=run, + status=job_status, + job_provisioning_data=get_job_provisioning_data(dockerized=True), + instance=instance, + instance_assigned=True, + ) + # No export created -> the import link no longer exists -> access revoked + + await _process_job(session, worker, job) + + await session.refresh(job) + assert job.status == JobStatus.TERMINATING + assert job.termination_reason == JobTerminationReason.INSTANCE_ACCESS_REVOKED + events = await list_events(session) + assert len(events) == 1 + assert events[0].message == ( + f"Job status changed {job_status.upper()} -> TERMINATING." + " Termination reason: INSTANCE_ACCESS_REVOKED" + " (The instance is no longer imported into the job's project)" + ) + + @pytest.mark.parametrize("job_status", [JobStatus.RUNNING, JobStatus.PULLING]) + async def test_does_not_terminate_job_when_instance_access_is_valid( + self, + test_db, + session: AsyncSession, + worker: JobRunningWorker, + ssh_tunnel_mock: Mock, + runner_client_mock: Mock, + job_status: JobStatus, + ): + user = await create_user(session=session) + exporter_project = await create_project(session=session, name="exporter", owner=user) + importer_project = await create_project(session=session, name="importer", owner=user) + fleet = await create_fleet(session=session, project=exporter_project) + instance = await create_instance( + session=session, + project=exporter_project, + status=InstanceStatus.BUSY, + fleet=fleet, + ) + await create_export( + session=session, + exporter_project=exporter_project, + importer_projects=[importer_project], + exported_fleets=[fleet], + ) + repo = await create_repo(session=session, project_id=importer_project.id) + run = await create_run( + session=session, + project=importer_project, + repo=repo, + user=user, + ) + job = await create_job( + session=session, + run=run, + status=job_status, + job_provisioning_data=get_job_provisioning_data(dockerized=False), + instance=instance, + instance_assigned=True, + ) + runner_client_mock.pull.return_value = PullResponse( + job_states=[], job_logs=[], runner_logs=[], last_updated=0 + ) + + await _process_job(session, worker, job) + + await session.refresh(job) + assert job.status == job_status + assert job.termination_reason is None + async def test_apply_skips_probe_insert_when_lock_token_changes_after_processing( self, test_db,