Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 29 additions & 7 deletions src/dstack/_internal/server/services/fleets.py
Original file line number Diff line number Diff line change
Expand Up @@ -442,13 +442,25 @@ async def get_plan(

offers = []
if effective_spec.configuration.ssh_config is None:
offers_with_backends = await get_create_instance_offers(
project=project,
profile=effective_spec.merged_profile,
requirements=get_fleet_requirements(effective_spec),
fleet_spec=effective_spec,
blocks=effective_spec.configuration.blocks,
)
requirements = get_fleet_requirements(effective_spec)
if _is_elastic_cloud_fleet_spec(effective_spec):
offers_with_backends = await offers_services.get_offers_by_requirements(
project=project,
profile=effective_spec.merged_profile,
requirements=requirements,
multinode=(
effective_spec.configuration.placement == InstanceGroupPlacement.CLUSTER
),
blocks=effective_spec.configuration.blocks,
)
else:
offers_with_backends = await get_create_instance_offers(
project=project,
profile=effective_spec.merged_profile,
requirements=requirements,
fleet_spec=effective_spec,
blocks=effective_spec.configuration.blocks,
)
Comment thread
peterschmidt85 marked this conversation as resolved.
Outdated
offers = [offer for _, offer in offers_with_backends]

_remove_fleet_spec_sensitive_info(effective_spec)
Expand All @@ -468,6 +480,16 @@ async def get_plan(
return plan


def _is_elastic_cloud_fleet_spec(fleet_spec: FleetSpec) -> bool:
Comment thread
peterschmidt85 marked this conversation as resolved.
Outdated
nodes = fleet_spec.configuration.nodes
return (
fleet_spec.configuration.ssh_config is None
and nodes is not None
and nodes.min == 0
and nodes.target == 0
)


async def get_create_instance_offers(
project: ProjectModel,
profile: Profile,
Expand Down
73 changes: 73 additions & 0 deletions src/tests/_internal/server/routers/test_fleets.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from dstack._internal.core.models.common import EntityReference
from dstack._internal.core.models.fleets import (
FleetConfiguration,
FleetNodesSpec,
FleetStatus,
InstanceGroupPlacement,
SSHHostParams,
Expand Down Expand Up @@ -2028,6 +2029,78 @@ async def test_returns_create_plan_for_new_fleet(
"action": "create",
}

@pytest.mark.asyncio
@pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True)
async def test_returns_offers_for_elastic_container_backend_fleet(
self, test_db, session: AsyncSession, client: AsyncClient
):
user = await create_user(session=session, global_role=GlobalRole.USER)
project = await create_project(session=session, owner=user)
await add_project_member(
session=session, project=project, user=user, project_role=ProjectRole.USER
)
offer = get_instance_offer_with_availability(
backend=BackendType.RUNPOD,
region="US-OR-1",
price=0.7185,
)
spec = get_fleet_spec(
conf=get_fleet_configuration(nodes=FleetNodesSpec(min=0, target=0, max=1))
)
with patch("dstack._internal.server.services.backends.get_project_backends") as m:
backend_mock = Mock()
m.return_value = [backend_mock]
backend_mock.TYPE = BackendType.RUNPOD
backend_mock.compute.return_value.get_offers.return_value = [offer]
response = await client.post(
f"/api/project/{project.name}/fleets/get_plan",
headers=get_auth_headers(user.token),
json={"spec": spec.dict()},
)
backend_mock.compute.return_value.get_offers.assert_called_once()

response_json = response.json()
assert response.status_code == 200, response_json
assert response_json["offers"] == [json.loads(offer.json())]
assert response_json["total_offers"] == 1
assert response_json["max_offer_price"] == offer.price

@pytest.mark.asyncio
@pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True)
async def test_returns_no_offers_for_non_elastic_container_backend_fleet(
self, test_db, session: AsyncSession, client: AsyncClient
):
user = await create_user(session=session, global_role=GlobalRole.USER)
project = await create_project(session=session, owner=user)
await add_project_member(
session=session, project=project, user=user, project_role=ProjectRole.USER
)
offer = get_instance_offer_with_availability(
backend=BackendType.RUNPOD,
region="US-OR-1",
price=0.7185,
)
spec = get_fleet_spec(
conf=get_fleet_configuration(nodes=FleetNodesSpec(min=0, target=1, max=1))
)
with patch("dstack._internal.server.services.backends.get_project_backends") as m:
backend_mock = Mock()
m.return_value = [backend_mock]
backend_mock.TYPE = BackendType.RUNPOD
backend_mock.compute.return_value.get_offers.return_value = [offer]
response = await client.post(
f"/api/project/{project.name}/fleets/get_plan",
headers=get_auth_headers(user.token),
json={"spec": spec.dict()},
)
backend_mock.compute.return_value.get_offers.assert_called_once()

response_json = response.json()
assert response.status_code == 200, response_json
assert response_json["offers"] == []
assert response_json["total_offers"] == 0
assert response_json["max_offer_price"] is None

@pytest.mark.asyncio
@pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True)
async def test_returns_update_plan_for_existing_fleet(
Expand Down
Loading