Skip to content

Commit

Permalink
Allow getting by ID in /api/project/_/fleets/get (#2200)
Browse files Browse the repository at this point in the history
  • Loading branch information
jvstme authored Jan 21, 2025
1 parent bc5f0ac commit e3c221f
Show file tree
Hide file tree
Showing 6 changed files with 111 additions and 14 deletions.
8 changes: 5 additions & 3 deletions src/dstack/_internal/server/routers/fleets.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,11 +75,13 @@ async def get_fleet(
user_project: Tuple[UserModel, ProjectModel] = Depends(ProjectMember()),
) -> Fleet:
"""
Returns a fleet given a fleet name.
Returns a fleet given `name` or `id`.
If given `name`, does not return deleted fleets.
If given `id`, returns deleted fleets.
"""
_, project = user_project
fleet = await fleets_services.get_fleet_by_name(
session=session, project=project, name=body.name
fleet = await fleets_services.get_fleet(
session=session, project=project, name=body.name, fleet_id=body.id
)
if fleet is None:
raise ResourceNotExistsError()
Expand Down
3 changes: 2 additions & 1 deletion src/dstack/_internal/server/schemas/fleets.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,8 @@ class ListFleetsRequest(CoreModel):


class GetFleetRequest(CoreModel):
name: str
name: Optional[str]
id: Optional[UUID] = None


class GetFleetPlanRequest(CoreModel):
Expand Down
35 changes: 30 additions & 5 deletions src/dstack/_internal/server/services/fleets.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,17 +179,42 @@ async def list_project_fleet_models(
return list(res.unique().scalars().all())


async def get_fleet_by_name(
session: AsyncSession, project: ProjectModel, name: str
async def get_fleet(
session: AsyncSession,
project: ProjectModel,
name: Optional[str],
fleet_id: Optional[uuid.UUID],
) -> Optional[Fleet]:
fleet_model = await get_project_fleet_model_by_name(
session=session, project=project, name=name
)
if fleet_id is not None:
fleet_model = await get_project_fleet_model_by_id(
session=session, project=project, fleet_id=fleet_id
)
elif name is not None:
fleet_model = await get_project_fleet_model_by_name(
session=session, project=project, name=name
)
else:
raise ServerClientError("name or id must be specified")
if fleet_model is None:
return None
return fleet_model_to_fleet(fleet_model)


async def get_project_fleet_model_by_id(
session: AsyncSession,
project: ProjectModel,
fleet_id: uuid.UUID,
) -> Optional[FleetModel]:
filters = [
FleetModel.id == fleet_id,
FleetModel.project_id == project.id,
]
res = await session.execute(
select(FleetModel).where(*filters).options(joinedload(FleetModel.instances))
)
return res.unique().scalar_one_or_none()


async def get_project_fleet_model_by_name(
session: AsyncSession,
project: ProjectModel,
Expand Down
2 changes: 2 additions & 0 deletions src/dstack/_internal/server/testing/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -418,6 +418,7 @@ async def create_fleet(
spec: Optional[FleetSpec] = None,
fleet_id: Optional[UUID] = None,
status: FleetStatus = FleetStatus.ACTIVE,
deleted: bool = False,
) -> FleetModel:
if fleet_id is None:
fleet_id = uuid.uuid4()
Expand All @@ -426,6 +427,7 @@ async def create_fleet(
fm = FleetModel(
id=fleet_id,
project=project,
deleted=deleted,
name=spec.configuration.name,
status=status,
created_at=created_at,
Expand Down
5 changes: 4 additions & 1 deletion src/dstack/api/server/_fleets.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,10 @@ def list(self, project_name: str) -> List[Fleet]:

def get(self, project_name: str, name: str) -> Fleet:
body = GetFleetRequest(name=name)
resp = self._request(f"/api/project/{project_name}/fleets/get", body=body.json())
resp = self._request(
f"/api/project/{project_name}/fleets/get",
body=body.json(exclude={"id"}), # `id` is not supported in pre-0.18.36 servers
)
return parse_obj_as(Fleet.__response__, resp.json())

def get_plan(
Expand Down
72 changes: 68 additions & 4 deletions src/tests/_internal/server/routers/test_fleets.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import json
from datetime import datetime, timezone
from unittest.mock import Mock, patch
from uuid import UUID
from uuid import UUID, uuid4

import pytest
from freezegun import freeze_time
Expand Down Expand Up @@ -183,7 +183,10 @@ async def test_returns_40x_if_not_authenticated(

@pytest.mark.asyncio
@pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True)
async def test_returns_fleet(self, test_db, session: AsyncSession, client: AsyncClient):
@pytest.mark.parametrize("deleted", [False, True])
async def test_returns_fleet_by_id(
self, test_db, session: AsyncSession, client: AsyncClient, deleted: bool
):
user = await create_user(session, global_role=GlobalRole.USER)
project = await create_project(session)
await add_project_member(
Expand All @@ -193,11 +196,12 @@ async def test_returns_fleet(self, test_db, session: AsyncSession, client: Async
session=session,
project=project,
created_at=datetime(2023, 1, 2, 3, 4, tzinfo=timezone.utc),
deleted=deleted,
)
response = await client.post(
f"/api/project/{project.name}/fleets/get",
headers=get_auth_headers(user.token),
json={"name": fleet.name},
json={"id": str(fleet.id)},
)
assert response.status_code == 200
assert response.json() == {
Expand All @@ -213,7 +217,67 @@ async def test_returns_fleet(self, test_db, session: AsyncSession, client: Async

@pytest.mark.asyncio
@pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True)
async def test_returns_400_if_fleet_does_not_exist(
async def test_returns_not_deleted_fleet_by_name(
self, test_db, session: AsyncSession, client: AsyncClient
):
user = await create_user(session, global_role=GlobalRole.USER)
project = await create_project(session)
await add_project_member(
session=session, project=project, user=user, project_role=ProjectRole.USER
)
active_fleet = await create_fleet(
session=session,
project=project,
created_at=datetime(2023, 1, 2, 3, 4, tzinfo=timezone.utc),
fleet_id=uuid4(),
)
deleted_fleet = await create_fleet(
session=session,
project=project,
created_at=datetime(2023, 1, 2, 3, 5, tzinfo=timezone.utc),
fleet_id=uuid4(),
deleted=True,
)
assert active_fleet.name == deleted_fleet.name
assert active_fleet.id != deleted_fleet.id
response = await client.post(
f"/api/project/{project.name}/fleets/get",
headers=get_auth_headers(user.token),
json={"name": active_fleet.name},
)
assert response.status_code == 200
assert response.json() == {
"id": str(active_fleet.id),
"name": active_fleet.name,
"project_name": project.name,
"spec": json.loads(active_fleet.spec),
"created_at": "2023-01-02T03:04:00+00:00",
"status": active_fleet.status.value,
"status_message": None,
"instances": [],
}

@pytest.mark.asyncio
@pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True)
async def test_not_returns_by_name_if_fleet_deleted(
self, test_db, session: AsyncSession, client: AsyncClient
):
user = await create_user(session, global_role=GlobalRole.USER)
project = await create_project(session)
await add_project_member(
session=session, project=project, user=user, project_role=ProjectRole.USER
)
fleet = await create_fleet(session=session, project=project, deleted=True)
response = await client.post(
f"/api/project/{project.name}/fleets/get",
headers=get_auth_headers(user.token),
json={"name": fleet.name},
)
assert response.status_code == 400

@pytest.mark.asyncio
@pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True)
async def test_not_returns_by_name_if_fleet_does_not_exist(
self, test_db, session: AsyncSession, client: AsyncClient
):
user = await create_user(session, global_role=GlobalRole.USER)
Expand Down

0 comments on commit e3c221f

Please sign in to comment.