From e3c221fd0ed42be808cbab590305bf2d75273eb5 Mon Sep 17 00:00:00 2001 From: jvstme <36324149+jvstme@users.noreply.github.com> Date: Tue, 21 Jan 2025 08:32:30 +0000 Subject: [PATCH] Allow getting by ID in `/api/project/_/fleets/get` (#2200) --- src/dstack/_internal/server/routers/fleets.py | 8 ++- src/dstack/_internal/server/schemas/fleets.py | 3 +- .../_internal/server/services/fleets.py | 35 +++++++-- src/dstack/_internal/server/testing/common.py | 2 + src/dstack/api/server/_fleets.py | 5 +- .../_internal/server/routers/test_fleets.py | 72 +++++++++++++++++-- 6 files changed, 111 insertions(+), 14 deletions(-) diff --git a/src/dstack/_internal/server/routers/fleets.py b/src/dstack/_internal/server/routers/fleets.py index 54853327d..67d21761f 100644 --- a/src/dstack/_internal/server/routers/fleets.py +++ b/src/dstack/_internal/server/routers/fleets.py @@ -75,11 +75,13 @@ async def get_fleet( user_project: Tuple[UserModel, ProjectModel] = Depends(ProjectMember()), ) -> Fleet: """ - Returns a fleet given a fleet name. + Returns a fleet given `name` or `id`. + If given `name`, does not return deleted fleets. + If given `id`, returns deleted fleets. """ _, project = user_project - fleet = await fleets_services.get_fleet_by_name( - session=session, project=project, name=body.name + fleet = await fleets_services.get_fleet( + session=session, project=project, name=body.name, fleet_id=body.id ) if fleet is None: raise ResourceNotExistsError() diff --git a/src/dstack/_internal/server/schemas/fleets.py b/src/dstack/_internal/server/schemas/fleets.py index 03ec04621..792b07c98 100644 --- a/src/dstack/_internal/server/schemas/fleets.py +++ b/src/dstack/_internal/server/schemas/fleets.py @@ -18,7 +18,8 @@ class ListFleetsRequest(CoreModel): class GetFleetRequest(CoreModel): - name: str + name: Optional[str] + id: Optional[UUID] = None class GetFleetPlanRequest(CoreModel): diff --git a/src/dstack/_internal/server/services/fleets.py b/src/dstack/_internal/server/services/fleets.py index 61630959b..da93ce50e 100644 --- a/src/dstack/_internal/server/services/fleets.py +++ b/src/dstack/_internal/server/services/fleets.py @@ -179,17 +179,42 @@ async def list_project_fleet_models( return list(res.unique().scalars().all()) -async def get_fleet_by_name( - session: AsyncSession, project: ProjectModel, name: str +async def get_fleet( + session: AsyncSession, + project: ProjectModel, + name: Optional[str], + fleet_id: Optional[uuid.UUID], ) -> Optional[Fleet]: - fleet_model = await get_project_fleet_model_by_name( - session=session, project=project, name=name - ) + if fleet_id is not None: + fleet_model = await get_project_fleet_model_by_id( + session=session, project=project, fleet_id=fleet_id + ) + elif name is not None: + fleet_model = await get_project_fleet_model_by_name( + session=session, project=project, name=name + ) + else: + raise ServerClientError("name or id must be specified") if fleet_model is None: return None return fleet_model_to_fleet(fleet_model) +async def get_project_fleet_model_by_id( + session: AsyncSession, + project: ProjectModel, + fleet_id: uuid.UUID, +) -> Optional[FleetModel]: + filters = [ + FleetModel.id == fleet_id, + FleetModel.project_id == project.id, + ] + res = await session.execute( + select(FleetModel).where(*filters).options(joinedload(FleetModel.instances)) + ) + return res.unique().scalar_one_or_none() + + async def get_project_fleet_model_by_name( session: AsyncSession, project: ProjectModel, diff --git a/src/dstack/_internal/server/testing/common.py b/src/dstack/_internal/server/testing/common.py index 9cb6e0fa8..d070bce2c 100644 --- a/src/dstack/_internal/server/testing/common.py +++ b/src/dstack/_internal/server/testing/common.py @@ -418,6 +418,7 @@ async def create_fleet( spec: Optional[FleetSpec] = None, fleet_id: Optional[UUID] = None, status: FleetStatus = FleetStatus.ACTIVE, + deleted: bool = False, ) -> FleetModel: if fleet_id is None: fleet_id = uuid.uuid4() @@ -426,6 +427,7 @@ async def create_fleet( fm = FleetModel( id=fleet_id, project=project, + deleted=deleted, name=spec.configuration.name, status=status, created_at=created_at, diff --git a/src/dstack/api/server/_fleets.py b/src/dstack/api/server/_fleets.py index 5e9195518..e89a4e2ce 100644 --- a/src/dstack/api/server/_fleets.py +++ b/src/dstack/api/server/_fleets.py @@ -20,7 +20,10 @@ def list(self, project_name: str) -> List[Fleet]: def get(self, project_name: str, name: str) -> Fleet: body = GetFleetRequest(name=name) - resp = self._request(f"/api/project/{project_name}/fleets/get", body=body.json()) + resp = self._request( + f"/api/project/{project_name}/fleets/get", + body=body.json(exclude={"id"}), # `id` is not supported in pre-0.18.36 servers + ) return parse_obj_as(Fleet.__response__, resp.json()) def get_plan( diff --git a/src/tests/_internal/server/routers/test_fleets.py b/src/tests/_internal/server/routers/test_fleets.py index 78720fb18..fe0d344c3 100644 --- a/src/tests/_internal/server/routers/test_fleets.py +++ b/src/tests/_internal/server/routers/test_fleets.py @@ -1,7 +1,7 @@ import json from datetime import datetime, timezone from unittest.mock import Mock, patch -from uuid import UUID +from uuid import UUID, uuid4 import pytest from freezegun import freeze_time @@ -183,7 +183,10 @@ async def test_returns_40x_if_not_authenticated( @pytest.mark.asyncio @pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True) - async def test_returns_fleet(self, test_db, session: AsyncSession, client: AsyncClient): + @pytest.mark.parametrize("deleted", [False, True]) + async def test_returns_fleet_by_id( + self, test_db, session: AsyncSession, client: AsyncClient, deleted: bool + ): user = await create_user(session, global_role=GlobalRole.USER) project = await create_project(session) await add_project_member( @@ -193,11 +196,12 @@ async def test_returns_fleet(self, test_db, session: AsyncSession, client: Async session=session, project=project, created_at=datetime(2023, 1, 2, 3, 4, tzinfo=timezone.utc), + deleted=deleted, ) response = await client.post( f"/api/project/{project.name}/fleets/get", headers=get_auth_headers(user.token), - json={"name": fleet.name}, + json={"id": str(fleet.id)}, ) assert response.status_code == 200 assert response.json() == { @@ -213,7 +217,67 @@ async def test_returns_fleet(self, test_db, session: AsyncSession, client: Async @pytest.mark.asyncio @pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True) - async def test_returns_400_if_fleet_does_not_exist( + async def test_returns_not_deleted_fleet_by_name( + self, test_db, session: AsyncSession, client: AsyncClient + ): + user = await create_user(session, global_role=GlobalRole.USER) + project = await create_project(session) + await add_project_member( + session=session, project=project, user=user, project_role=ProjectRole.USER + ) + active_fleet = await create_fleet( + session=session, + project=project, + created_at=datetime(2023, 1, 2, 3, 4, tzinfo=timezone.utc), + fleet_id=uuid4(), + ) + deleted_fleet = await create_fleet( + session=session, + project=project, + created_at=datetime(2023, 1, 2, 3, 5, tzinfo=timezone.utc), + fleet_id=uuid4(), + deleted=True, + ) + assert active_fleet.name == deleted_fleet.name + assert active_fleet.id != deleted_fleet.id + response = await client.post( + f"/api/project/{project.name}/fleets/get", + headers=get_auth_headers(user.token), + json={"name": active_fleet.name}, + ) + assert response.status_code == 200 + assert response.json() == { + "id": str(active_fleet.id), + "name": active_fleet.name, + "project_name": project.name, + "spec": json.loads(active_fleet.spec), + "created_at": "2023-01-02T03:04:00+00:00", + "status": active_fleet.status.value, + "status_message": None, + "instances": [], + } + + @pytest.mark.asyncio + @pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True) + async def test_not_returns_by_name_if_fleet_deleted( + self, test_db, session: AsyncSession, client: AsyncClient + ): + user = await create_user(session, global_role=GlobalRole.USER) + project = await create_project(session) + await add_project_member( + session=session, project=project, user=user, project_role=ProjectRole.USER + ) + fleet = await create_fleet(session=session, project=project, deleted=True) + response = await client.post( + f"/api/project/{project.name}/fleets/get", + headers=get_auth_headers(user.token), + json={"name": fleet.name}, + ) + assert response.status_code == 400 + + @pytest.mark.asyncio + @pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True) + async def test_not_returns_by_name_if_fleet_does_not_exist( self, test_db, session: AsyncSession, client: AsyncClient ): user = await create_user(session, global_role=GlobalRole.USER)