Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 23 additions & 9 deletions src/dstack/_internal/cli/commands/fleet.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
)
from dstack._internal.cli.utils.fleet import get_fleets_table, print_fleets_table
from dstack._internal.core.errors import CLIError, ResourceNotExistsError
from dstack._internal.core.models.common import EntityReference
from dstack._internal.utils.json_utils import pydantic_orjson_dumps_with_indent


Expand Down Expand Up @@ -49,6 +50,7 @@ def _register(self):
)
delete_parser.add_argument(
"name",
type=EntityReference.parse,
help="The name of the fleet",
).completer = FleetNameCompleter() # type: ignore[attr-defined]
delete_parser.add_argument(
Expand All @@ -73,6 +75,7 @@ def _register(self):
"name",
nargs="?",
metavar="NAME",
type=EntityReference.parse,
help="The name of the fleet",
).completer = FleetNameCompleter() # type: ignore[attr-defined]
name_group.add_argument(
Expand Down Expand Up @@ -112,35 +115,43 @@ def _list(self, args: argparse.Namespace):
pass

def _delete(self, args: argparse.Namespace):
if args.name.project is not None:
console.print(
"The [code]<project>/<fleet>[/] format is not supported for fleet names."
" Can only delete fleets or instances owned by the current project"
)
exit(1)
name = args.name.name

try:
self.api.client.fleets.get(project_name=self.api.project, name=args.name)
self.api.client.fleets.get(project_name=self.api.project, name=name)
except ResourceNotExistsError:
console.print(f"Fleet [code]{args.name}[/] does not exist")
console.print(f"Fleet [code]{name}[/] does not exist")
exit(1)

if not args.instances:
if not args.yes and not confirm_ask(f"Delete the fleet [code]{args.name}[/]?"):
if not args.yes and not confirm_ask(f"Delete the fleet [code]{name}[/]?"):
console.print("\nExiting...")
return

with console.status("Deleting fleet..."):
self.api.client.fleets.delete(project_name=self.api.project, names=[args.name])
self.api.client.fleets.delete(project_name=self.api.project, names=[name])

console.print(f"Fleet [code]{args.name}[/] deleted")
console.print(f"Fleet [code]{name}[/] deleted")
return

if not args.yes and not confirm_ask(
f"Delete the fleet [code]{args.name}[/] instances [code]{args.instances}[/]?"
f"Delete the fleet [code]{name}[/] instances [code]{args.instances}[/]?"
):
console.print("\nExiting...")
return

with console.status("Deleting fleet instances..."):
self.api.client.fleets.delete_instances(
project_name=self.api.project, name=args.name, instance_nums=args.instances
project_name=self.api.project, name=name, instance_nums=args.instances
)

console.print(f"Fleet [code]{args.name}[/] instances deleted")
console.print(f"Fleet [code]{name}[/] instances deleted")

def _get(self, args: argparse.Namespace):
# TODO: Implement non-json output format
Expand All @@ -157,7 +168,10 @@ def _get(self, args: argparse.Namespace):
project_name=self.api.project, fleet_id=fleet_id
)
else:
fleet = self.api.client.fleets.get(project_name=self.api.project, name=args.name)
fleet = self.api.client.fleets.get(
project_name=args.name.project or self.api.project,
name=args.name.name,
)
except ResourceNotExistsError:
console.print(f"Fleet [code]{args.name or args.id}[/] not found")
exit(1)
Expand Down
7 changes: 6 additions & 1 deletion src/dstack/_internal/core/models/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,12 +160,17 @@ class EntityReference(CoreModel):
def parse(cls, v: Union[str, "EntityReference"]) -> "EntityReference":
if isinstance(v, EntityReference):
return v
invalid_ref_error = ValueError(
"Invalid entity reference. Only `<name>` or `<project>/<name>` formats are allowed"
)
parts = v.split("/")
if any(len(part) == 0 for part in parts):
raise invalid_ref_error
if len(parts) == 1:
return cls(project=None, name=parts[0])
if len(parts) == 2:
return cls(project=parts[0], name=parts[1])
raise ValueError("Invalid entity reference. Only `<project>/<name>` format is allowed")
raise invalid_ref_error

def format(self) -> str:
if self.project is None:
Expand Down
27 changes: 27 additions & 0 deletions src/tests/_internal/core/models/test_common.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
import pytest

from dstack._internal.core.models.common import EntityReference


class TestEntityReferenceParse:
@pytest.mark.parametrize(
"value, expected",
[
("fleet", EntityReference(project=None, name="fleet")),
("project/fleet", EntityReference(project="project", name="fleet")),
(
EntityReference(project="proj", name="fleet"),
EntityReference(project="proj", name="fleet"),
),
],
)
def test_valid(self, value, expected):
assert EntityReference.parse(value) == expected

@pytest.mark.parametrize(
"value",
["", "/name", "name/", "/", "a/b/c"],
)
def test_invalid(self, value: str):
with pytest.raises(ValueError, match="Invalid entity reference"):
EntityReference.parse(value)
Loading