diff --git a/src/dstack/_internal/cli/services/configurators/fleet.py b/src/dstack/_internal/cli/services/configurators/fleet.py index fe1dd4c0c..0e3dbdd80 100644 --- a/src/dstack/_internal/cli/services/configurators/fleet.py +++ b/src/dstack/_internal/cli/services/configurators/fleet.py @@ -34,9 +34,10 @@ InstanceGroupPlacement, ) from dstack._internal.core.models.instances import InstanceStatus, SSHKey -from dstack._internal.core.services.diff import diff_models +from dstack._internal.core.services.diff import copy_model, diff_models from dstack._internal.utils.common import local_time from dstack._internal.utils.logging import get_logger +from dstack._internal.utils.nested_list import NestedList, NestedListItem from dstack._internal.utils.ssh import convert_ssh_key_to_pem, generate_public_key, pkey_from_str from dstack.api.utils import load_profile @@ -85,14 +86,10 @@ def _apply_plan(self, plan: FleetPlan, command_args: argparse.Namespace): ) confirm_message += "Create the fleet?" else: + effective_spec = plan.get_effective_spec() + diff = _render_fleet_spec_diff(plan.current_resource.spec, effective_spec) action_message += f"Found fleet [code]{plan.spec.configuration.name}[/]." - if plan.action == ApplyAction.CREATE: - delete_fleet_name = plan.current_resource.name - action_message += ( - " Configuration changes detected. Cannot update the fleet in-place" - ) - confirm_message += "Re-create the fleet?" - elif plan.current_resource.spec == plan.effective_spec: + if plan.current_resource.spec == effective_spec: if command_args.yes and not command_args.force: # --force is required only with --yes, # otherwise we may ask for force apply interactively. @@ -103,8 +100,26 @@ def _apply_plan(self, plan: FleetPlan, command_args: argparse.Namespace): delete_fleet_name = plan.current_resource.name action_message += " No configuration changes detected." confirm_message += "Re-create the fleet?" + elif plan.action == ApplyAction.CREATE: + delete_fleet_name = plan.current_resource.name + if diff is not None: + # TODO: Highlight only the fields that block in-place update instead of + # showing the full detected diff here. + action_message += ( + f" Detected changes that [error]cannot[/] be updated in-place:\n{diff}" + ) + else: + action_message += ( + " Configuration changes detected. Cannot update the fleet in-place." + ) + confirm_message += "Re-create the fleet?" else: - action_message += " Configuration changes detected." + if diff is not None: + action_message += ( + f" Detected changes that [code]can[/] be updated in-place:\n{diff}" + ) + else: + action_message += " Configuration changes detected." confirm_message += "Update the fleet in-place?" console.print(action_message) @@ -357,6 +372,44 @@ def _resolve_ssh_key(ssh_key_path: Optional[str]) -> Optional[SSHKey]: exit() +def _render_fleet_spec_diff(old_spec: FleetSpec, new_spec: FleetSpec) -> Optional[str]: + old_spec = copy_model(old_spec) + new_spec = copy_model(new_spec) + changed_spec_fields = list(diff_models(old_spec, new_spec)) + if not changed_spec_fields: + return None + + nested_list = NestedList() + for spec_field in changed_spec_fields: + if spec_field == "merged_profile": + continue + if spec_field == "configuration": + item = NestedListItem( + "Configuration properties:", + children=[ + NestedListItem(field) + for field in diff_models(old_spec.configuration, new_spec.configuration) + ], + ) + elif spec_field == "profile": + item = NestedListItem( + "Profile properties:", + children=[ + NestedListItem(field) + for field in diff_models(old_spec.profile, new_spec.profile) + ], + ) + elif spec_field == "configuration_path": + item = NestedListItem("Configuration path") + else: + item = NestedListItem(spec_field.replace("_", " ").capitalize()) + nested_list.children.append(item) + + if not nested_list.children: + return None + return nested_list.render() + + def _print_plan_header(plan: FleetPlan): def th(s: str) -> str: return f"[bold]{s}[/bold]" diff --git a/src/dstack/_internal/cli/services/configurators/run.py b/src/dstack/_internal/cli/services/configurators/run.py index 578c4a9eb..98de46ed4 100644 --- a/src/dstack/_internal/cli/services/configurators/run.py +++ b/src/dstack/_internal/cli/services/configurators/run.py @@ -152,6 +152,8 @@ def apply_configuration( confirm_message = "Stop and override the run?" elif not run_plan.current_resource.status.is_finished(): stop_run_name = run_plan.current_resource.run_spec.run_name + # TODO: Highlight only the fields that block in-place update instead of + # showing the full detected diff here. console.print( f"Active run [code]{conf.name}[/] already exists." f" Detected changes that [error]cannot[/] be updated in-place:\n{diff}" diff --git a/src/dstack/_internal/server/background/pipeline_tasks/fleets.py b/src/dstack/_internal/server/background/pipeline_tasks/fleets.py index 02b079905..a6ebca9f0 100644 --- a/src/dstack/_internal/server/background/pipeline_tasks/fleets.py +++ b/src/dstack/_internal/server/background/pipeline_tasks/fleets.py @@ -63,7 +63,7 @@ def __init__( workers_num: int = 10, queue_lower_limit_factor: float = 0.5, queue_upper_limit_factor: float = 2.0, - min_processing_interval: timedelta = timedelta(seconds=30), + min_processing_interval: timedelta = timedelta(seconds=15), lock_timeout: timedelta = timedelta(seconds=20), heartbeat_trigger: timedelta = timedelta(seconds=10), *, diff --git a/src/dstack/_internal/server/services/fleets.py b/src/dstack/_internal/server/services/fleets.py index 003ecd906..e0cdaa081 100644 --- a/src/dstack/_internal/server/services/fleets.py +++ b/src/dstack/_internal/server/services/fleets.py @@ -1219,26 +1219,36 @@ def _check_can_update_inner(current: M, new: M, updatable_fields: tuple[str, ... return diff -@_check_can_update("configuration", "configuration_path") +@_check_can_update("configuration", "configuration_path", "merged_profile") def _check_can_update_fleet_spec(current: FleetSpec, new: FleetSpec, diff: ModelDiff): + # Allow `merged_profile` only to absorb derived changes from supported configuration updates + # such as `configuration.reservation` and `configuration.tags`. + # Direct `profile` updates are still not in-place updatable. if "configuration" in diff: _check_can_update_fleet_configuration(current.configuration, new.configuration) -@_check_can_update("ssh_config") -def _check_can_update_fleet_configuration( - current: FleetConfiguration, new: FleetConfiguration, diff: ModelDiff -): +def _check_can_update_fleet_configuration(current: FleetConfiguration, new: FleetConfiguration): + diff = diff_models(current, new) + current_ssh_config = current.ssh_config + new_ssh_config = new.ssh_config + if current_ssh_config is None: + if new_ssh_config is not None: + raise ServerClientError("Fleet type changed from Cloud to SSH, cannot update") + # TODO: Support best-effort `nodes.target` apply semantics: + # create missing instances and terminate extra idle instances. + # Current in-place update only persists `target`; FleetPipeline reconciles `min`/`max`. + # + # For `reservation` and `tags`, update affects only future provisioning. + _check_can_update_inner(current, new, ("nodes", "reservation", "tags")) + return + + if new_ssh_config is None: + raise ServerClientError("Fleet type changed from SSH to Cloud, cannot update") + + _check_can_update_inner(current, new, ("ssh_config",)) if "ssh_config" in diff: - current_ssh_config = current.ssh_config - new_ssh_config = new.ssh_config - if current_ssh_config is None: - if new_ssh_config is not None: - raise ServerClientError("Fleet type changed from Cloud to SSH, cannot update") - elif new_ssh_config is None: - raise ServerClientError("Fleet type changed from SSH to Cloud, cannot update") - else: - _check_can_update_ssh_config(current_ssh_config, new_ssh_config) + _check_can_update_ssh_config(current_ssh_config, new_ssh_config) @_check_can_update("hosts") diff --git a/src/tests/_internal/cli/services/configurators/test_fleet.py b/src/tests/_internal/cli/services/configurators/test_fleet.py index a14b5c7ac..f1d0bfe22 100644 --- a/src/tests/_internal/cli/services/configurators/test_fleet.py +++ b/src/tests/_internal/cli/services/configurators/test_fleet.py @@ -1,13 +1,129 @@ import argparse -from typing import List, Tuple +from datetime import datetime, timezone +from textwrap import dedent +from typing import List, Optional, Tuple from unittest.mock import Mock +from uuid import uuid4 import pytest +from rich.console import Console -from dstack._internal.cli.services.configurators.fleet import FleetConfigurator +import dstack._internal.cli.services.configurators.fleet as fleet_configurator_module +from dstack._internal.cli.services.configurators.fleet import ( + FleetConfigurator, + _render_fleet_spec_diff, +) from dstack._internal.core.errors import ConfigurationError +from dstack._internal.core.models.common import ApplyAction from dstack._internal.core.models.envs import Env -from dstack._internal.core.models.fleets import FleetConfiguration +from dstack._internal.core.models.fleets import ( + Fleet, + FleetConfiguration, + FleetNodesSpec, + FleetPlan, + FleetSpec, + FleetStatus, + InstanceGroupPlacement, +) +from dstack._internal.core.models.profiles import Profile + + +def create_conf() -> FleetConfiguration: + return FleetConfiguration.parse_obj({"ssh_config": {"hosts": ["1.2.3.4"]}}) + + +def apply_args( + conf: FleetConfiguration, args: List[str] +) -> Tuple[FleetConfiguration, argparse.Namespace]: + parser = argparse.ArgumentParser() + configurator = FleetConfigurator(Mock()) + configurator.register_args(parser) + conf = conf.copy(deep=True) + configurator_args = parser.parse_args(args) + configurator.apply_args(conf, configurator_args) + return conf, configurator_args + + +def get_cloud_fleet_spec( + *, + name: str = "test-fleet", + nodes: Optional[FleetNodesSpec] = None, + placement: Optional[InstanceGroupPlacement] = None, +) -> FleetSpec: + if nodes is None: + nodes = FleetNodesSpec(min=0, target=0, max=1) + return FleetSpec( + configuration=FleetConfiguration( + name=name, + nodes=nodes, + placement=placement, + ), + configuration_path="fleet.dstack.yml", + profile=Profile(), + ) + + +def get_ssh_fleet_spec( + *, + name: str = "test-fleet", + hosts: Optional[list[str]] = None, +) -> FleetSpec: + if hosts is None: + hosts = ["10.0.0.100"] + return FleetSpec( + configuration=FleetConfiguration.parse_obj( + { + "name": name, + "ssh_config": {"hosts": hosts}, + } + ), + configuration_path="fleet.dstack.yml", + profile=Profile(), + ) + + +def create_fleet_plan( + *, + current_spec: FleetSpec, + spec: FleetSpec, + action: ApplyAction, +) -> FleetPlan: + return FleetPlan( + project_name="test-project", + user="test-user", + spec=spec, + effective_spec=spec, + current_resource=Fleet( + id=uuid4(), + name=current_spec.configuration.name or "test-fleet", + project_name="test-project", + spec=current_spec, + created_at=datetime.now(timezone.utc), + status=FleetStatus.ACTIVE, + instances=[], + ), + offers=[], + total_offers=0, + action=action, + ) + + +def get_command_args() -> argparse.Namespace: + return argparse.Namespace( + yes=False, + force=False, + detach=False, + ) + + +def patch_console_and_confirm( + monkeypatch: pytest.MonkeyPatch, +) -> tuple[Console, Mock]: + console = Console(record=True, force_terminal=False, color_system=None, width=120) + confirm_ask = Mock(return_value=False) + monkeypatch.setattr(fleet_configurator_module, "console", console) + monkeypatch.setattr(fleet_configurator_module, "confirm_ask", confirm_ask) + return console, confirm_ask class TestFleetConfigurator: @@ -39,17 +155,93 @@ def test_env_value_from_environ_not_set(self, monkeypatch: pytest.MonkeyPatch): apply_args(conf, ["--env", "FROM_ENV"]) -def create_conf() -> FleetConfiguration: - return FleetConfiguration.parse_obj({"ssh_config": {"hosts": ["1.2.3.4"]}}) +class TestApplyPlanMessages: + def test_prints_in_place_update_diff(self, monkeypatch: pytest.MonkeyPatch): + console, confirm_ask = patch_console_and_confirm(monkeypatch) + current_spec = get_cloud_fleet_spec(nodes=FleetNodesSpec(min=0, target=0, max=1)) + spec = get_cloud_fleet_spec(nodes=FleetNodesSpec(min=1, target=1, max=1)) + plan = create_fleet_plan( + current_spec=current_spec, + spec=spec, + action=ApplyAction.UPDATE, + ) + FleetConfigurator(Mock())._apply_plan(plan, get_command_args()) -def apply_args( - conf: FleetConfiguration, args: List[str] -) -> Tuple[FleetConfiguration, argparse.Namespace]: - parser = argparse.ArgumentParser() - configurator = FleetConfigurator(Mock()) - configurator.register_args(parser) - conf = conf.copy(deep=True) - configurator_args = parser.parse_args(args) - configurator.apply_args(conf, configurator_args) - return conf, configurator_args + output = console.export_text() + assert "Found fleet test-fleet." in output + assert "Detected changes that can be updated in-place:" in output + assert "- Configuration properties:" in output + assert " - nodes" in output + confirm_ask.assert_called_once_with("Update the fleet in-place?") + + def test_prints_recreate_diff(self, monkeypatch: pytest.MonkeyPatch): + console, confirm_ask = patch_console_and_confirm(monkeypatch) + current_spec = get_cloud_fleet_spec(placement=InstanceGroupPlacement.ANY) + spec = get_cloud_fleet_spec(placement=InstanceGroupPlacement.CLUSTER) + plan = create_fleet_plan( + current_spec=current_spec, + spec=spec, + action=ApplyAction.CREATE, + ) + + FleetConfigurator(Mock())._apply_plan(plan, get_command_args()) + + output = console.export_text() + assert "Found fleet test-fleet." in output + assert "Detected changes that cannot be updated in-place:" in output + assert "- Configuration properties:" in output + assert " - placement" in output + confirm_ask.assert_called_once_with("Re-create the fleet?") + + def test_prints_no_diff_message(self, monkeypatch: pytest.MonkeyPatch): + console, confirm_ask = patch_console_and_confirm(monkeypatch) + spec = get_cloud_fleet_spec() + plan = create_fleet_plan( + current_spec=spec, + spec=spec.copy(deep=True), + action=ApplyAction.UPDATE, + ) + + FleetConfigurator(Mock())._apply_plan(plan, get_command_args()) + + output = console.export_text() + assert "Found fleet test-fleet." in output + assert "No configuration changes detected." in output + assert "Detected changes that" not in output + confirm_ask.assert_called_once_with("Re-create the fleet?") + + +class TestRenderFleetSpecDiff: + def test_renders_cloud_nodes_change(self): + old = get_cloud_fleet_spec(nodes=FleetNodesSpec(min=0, target=0, max=1)) + new = get_cloud_fleet_spec(nodes=FleetNodesSpec(min=1, target=1, max=1)) + + assert ( + _render_fleet_spec_diff(old, new) + == dedent( + """ + - Configuration properties: + - nodes + """ + ).lstrip() + ) + + def test_renders_ssh_hosts_change(self): + old = get_ssh_fleet_spec(hosts=["10.0.0.100"]) + new = get_ssh_fleet_spec(hosts=["10.0.0.100", "10.0.0.101"]) + + assert ( + _render_fleet_spec_diff(old, new) + == dedent( + """ + - Configuration properties: + - ssh_config + """ + ).lstrip() + ) + + def test_no_diff(self): + spec = get_cloud_fleet_spec() + + assert _render_fleet_spec_diff(spec, spec.copy(deep=True)) is None diff --git a/src/tests/_internal/server/routers/test_fleets.py b/src/tests/_internal/server/routers/test_fleets.py index 3443f05a8..e7d603861 100644 --- a/src/tests/_internal/server/routers/test_fleets.py +++ b/src/tests/_internal/server/routers/test_fleets.py @@ -1436,6 +1436,108 @@ async def test_updates_ssh_fleet(self, test_db, session: AsyncSession, client: A assert instance.status == InstanceStatus.PENDING assert instance.remote_connection_info is not None + @pytest.mark.asyncio + @pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True) + async def test_updates_cloud_fleet_nodes_in_place_when_fleet_in_use( + self, test_db, session: AsyncSession, client: AsyncClient + ): + user = await create_user(session, global_role=GlobalRole.USER) + project = await create_project(session) + await add_project_member( + session=session, project=project, user=user, project_role=ProjectRole.USER + ) + current_spec = get_fleet_spec( + conf=get_fleet_configuration(nodes=FleetNodesSpec(min=0, target=0, max=2)) + ) + fleet = await create_fleet(session=session, project=project, spec=current_spec) + repo = await create_repo(session=session, project_id=project.id) + run = await create_run(session=session, project=project, repo=repo, user=user, fleet=fleet) + job = await create_job(session=session, run=run, fleet=fleet) + instance = await create_instance( + session=session, + project=project, + fleet=fleet, + job=job, + status=InstanceStatus.BUSY, + instance_num=0, + ) + spec = current_spec.copy(deep=True) + spec.configuration.nodes = FleetNodesSpec(min=1, target=1, max=3) + + response = await client.post( + f"/api/project/{project.name}/fleets/apply", + headers=get_auth_headers(user.token), + json={ + "plan": { + "spec": spec.dict(), + "current_resource": _fleet_model_to_json_dict(fleet), + }, + "force": False, + }, + ) + + response_json = response.json() + assert response.status_code == 200, response_json + assert response_json["id"] == str(fleet.id) + assert response_json["spec"]["configuration"]["nodes"] == {"min": 1, "max": 3} + + await session.refresh(fleet) + await session.refresh(instance) + assert json.loads(fleet.spec)["configuration"]["nodes"] == {"min": 1, "max": 3} + assert instance.status == InstanceStatus.BUSY + + @pytest.mark.asyncio + @pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True) + async def test_updates_cloud_fleet_nodes_target_without_changing_instance_count( + self, test_db, session: AsyncSession, client: AsyncClient + ): + user = await create_user(session, global_role=GlobalRole.USER) + project = await create_project(session) + await add_project_member( + session=session, project=project, user=user, project_role=ProjectRole.USER + ) + current_spec = get_fleet_spec( + conf=get_fleet_configuration(nodes=FleetNodesSpec(min=0, target=0, max=1)) + ) + fleet = await create_fleet(session=session, project=project, spec=current_spec) + spec = current_spec.copy(deep=True) + spec.configuration.nodes = FleetNodesSpec(min=0, target=1, max=1) + + response = await client.post( + f"/api/project/{project.name}/fleets/apply", + headers=get_auth_headers(user.token), + json={ + "plan": { + "spec": spec.dict(), + "current_resource": _fleet_model_to_json_dict(fleet), + }, + "force": False, + }, + ) + + response_json = response.json() + assert response.status_code == 200, response_json + assert response_json["id"] == str(fleet.id) + assert response_json["spec"]["configuration"]["nodes"] == { + "min": 0, + "target": 1, + "max": 1, + } + + await session.refresh(fleet) + assert json.loads(fleet.spec)["configuration"]["nodes"] == { + "min": 0, + "target": 1, + "max": 1, + } + res = await session.execute( + select(InstanceModel).where( + InstanceModel.fleet_id == fleet.id, + InstanceModel.deleted == False, + ) + ) + assert list(res.scalars().all()) == [] + @pytest.mark.asyncio @pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True) @freeze_time(datetime(2023, 1, 2, 3, 4, tzinfo=timezone.utc)) @@ -2118,6 +2220,62 @@ async def test_returns_update_plan_for_existing_fleet( "action": "update", } + @pytest.mark.asyncio + @pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True) + async def test_returns_update_plan_for_existing_cloud_fleet_nodes_update( + self, test_db, session: AsyncSession, client: AsyncClient + ): + user = await create_user(session=session, global_role=GlobalRole.USER) + project = await create_project(session=session, owner=user) + await add_project_member( + session=session, project=project, user=user, project_role=ProjectRole.USER + ) + current_spec = get_fleet_spec( + conf=get_fleet_configuration(nodes=FleetNodesSpec(min=0, target=0, max=1)) + ) + spec = current_spec.copy(deep=True) + spec.configuration.nodes = FleetNodesSpec(min=1, target=1, max=1) + fleet = await create_fleet(session=session, project=project, spec=current_spec) + + response = await client.post( + f"/api/project/{project.name}/fleets/get_plan", + headers=get_auth_headers(user.token), + json={"spec": spec.dict()}, + ) + + response_json = response.json() + assert response.status_code == 200, response_json + assert response_json["current_resource"]["id"] == str(fleet.id) + assert response_json["action"] == "update" + + @pytest.mark.asyncio + @pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True) + async def test_returns_create_plan_for_existing_cloud_fleet_blocks_update( + self, test_db, session: AsyncSession, client: AsyncClient + ): + user = await create_user(session=session, global_role=GlobalRole.USER) + project = await create_project(session=session, owner=user) + await add_project_member( + session=session, project=project, user=user, project_role=ProjectRole.USER + ) + current_spec = get_fleet_spec( + conf=get_fleet_configuration(nodes=FleetNodesSpec(min=0, target=0, max=1)) + ) + spec = current_spec.copy(deep=True) + spec.configuration.blocks = 2 + fleet = await create_fleet(session=session, project=project, spec=current_spec) + + response = await client.post( + f"/api/project/{project.name}/fleets/get_plan", + headers=get_auth_headers(user.token), + json={"spec": spec.dict()}, + ) + + response_json = response.json() + assert response.status_code == 200, response_json + assert response_json["current_resource"]["id"] == str(fleet.id) + assert response_json["action"] == "create" + @pytest.mark.asyncio @pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True) async def test_returns_create_plan_for_existing_fleet(