From 4d821d0ab104e45f72c243194865ae71bb9bad36 Mon Sep 17 00:00:00 2001 From: Jvst Me Date: Mon, 20 Jan 2025 00:12:45 +0100 Subject: [PATCH] Add `/api/instances/list` Add an API method for listing instances with filtering and pagination. This method will be used in the new Instances page in the UI. It is similar to the deprecated `/api/pools/list_instances` method, except it allows filtering by fleet and not by pool. Also add the `fleet_id` and `fleet_name` fields to the instance model so that the Instances page in the UI can display fleet names and provide links to fleet details. --- src/dstack/_internal/core/models/pools.py | 2 + src/dstack/_internal/server/app.py | 2 + .../_internal/server/routers/instances.py | 45 +++ src/dstack/_internal/server/routers/pools.py | 3 +- .../_internal/server/schemas/instances.py | 15 + src/dstack/_internal/server/services/pools.py | 55 +++- src/dstack/_internal/server/testing/common.py | 3 +- .../_internal/server/routers/test_fleets.py | 4 + .../server/routers/test_instances.py | 277 ++++++++++++++++++ .../_internal/server/routers/test_pools.py | 4 + .../_internal/server/routers/test_runs.py | 2 + 11 files changed, 396 insertions(+), 16 deletions(-) create mode 100644 src/dstack/_internal/server/routers/instances.py create mode 100644 src/dstack/_internal/server/schemas/instances.py create mode 100644 src/tests/_internal/server/routers/test_instances.py diff --git a/src/dstack/_internal/core/models/pools.py b/src/dstack/_internal/core/models/pools.py index 0c55464bd..97f70031b 100644 --- a/src/dstack/_internal/core/models/pools.py +++ b/src/dstack/_internal/core/models/pools.py @@ -21,6 +21,8 @@ class Instance(CoreModel): backend: Optional[BackendType] = None instance_type: Optional[InstanceType] = None name: str + fleet_id: Optional[UUID] = None + fleet_name: Optional[str] = None instance_num: int pool_name: Optional[str] = None job_name: Optional[str] = None diff --git a/src/dstack/_internal/server/app.py 
b/src/dstack/_internal/server/app.py index 771497cfe..0da0d9f9b 100644 --- a/src/dstack/_internal/server/app.py +++ b/src/dstack/_internal/server/app.py @@ -24,6 +24,7 @@ backends, fleets, gateways, + instances, logs, metrics, pools, @@ -169,6 +170,7 @@ def register_routes(app: FastAPI, ui: bool = True): app.include_router(backends.project_router) app.include_router(fleets.root_router) app.include_router(fleets.project_router) + app.include_router(instances.root_router) app.include_router(repos.router) app.include_router(runs.root_router) app.include_router(runs.project_router) diff --git a/src/dstack/_internal/server/routers/instances.py b/src/dstack/_internal/server/routers/instances.py new file mode 100644 index 000000000..eda85d652 --- /dev/null +++ b/src/dstack/_internal/server/routers/instances.py @@ -0,0 +1,45 @@ +from typing import List + +from fastapi import APIRouter, Depends +from sqlalchemy.ext.asyncio import AsyncSession + +import dstack._internal.server.services.pools as pools +from dstack._internal.core.models.pools import Instance +from dstack._internal.server.db import get_session +from dstack._internal.server.models import UserModel +from dstack._internal.server.schemas.instances import ListInstancesRequest +from dstack._internal.server.security.permissions import Authenticated +from dstack._internal.server.utils.routers import get_base_api_additional_responses + +root_router = APIRouter( + prefix="/api/instances", + tags=["instances"], + responses=get_base_api_additional_responses(), +) + + +@root_router.post("/list") +async def list_instances( + body: ListInstancesRequest, + session: AsyncSession = Depends(get_session), + user: UserModel = Depends(Authenticated()), +) -> List[Instance]: + """ + Returns all instances visible to the user sorted by `created_at` (descending by default). + `project_names` and `fleet_ids` can be specified as filters. + + The results are paginated. 
To get the next page, pass `created_at` and `id` of + the last instance from the previous page as `prev_created_at` and `prev_id`. + """ + return await pools.list_user_pool_instances( + session=session, + user=user, + project_names=body.project_names, + fleet_ids=body.fleet_ids, + pool_name=None, + only_active=body.only_active, + prev_created_at=body.prev_created_at, + prev_id=body.prev_id, + limit=body.limit, + ascending=body.ascending, + ) diff --git a/src/dstack/_internal/server/routers/pools.py b/src/dstack/_internal/server/routers/pools.py index 9cf8358ea..1244e0edc 100644 --- a/src/dstack/_internal/server/routers/pools.py +++ b/src/dstack/_internal/server/routers/pools.py @@ -45,7 +45,8 @@ async def list_pool_instances( return await pools.list_user_pool_instances( session=session, user=user, - project_name=body.project_name, + project_names=[body.project_name] if body.project_name is not None else None, + fleet_ids=None, pool_name=body.pool_name, only_active=body.only_active, prev_created_at=body.prev_created_at, diff --git a/src/dstack/_internal/server/schemas/instances.py b/src/dstack/_internal/server/schemas/instances.py new file mode 100644 index 000000000..0b1d6ccaa --- /dev/null +++ b/src/dstack/_internal/server/schemas/instances.py @@ -0,0 +1,15 @@ +from datetime import datetime +from typing import Optional +from uuid import UUID + +from dstack._internal.core.models.common import CoreModel + + +class ListInstancesRequest(CoreModel): + project_names: Optional[list[str]] = None + fleet_ids: Optional[list[UUID]] = None + only_active: bool = False + prev_created_at: Optional[datetime] = None + prev_id: Optional[UUID] = None + limit: int = 1000 + ascending: bool = False diff --git a/src/dstack/_internal/server/services/pools.py b/src/dstack/_internal/server/services/pools.py index d88b36fc0..9ad07590e 100644 --- a/src/dstack/_internal/server/services/pools.py +++ b/src/dstack/_internal/server/services/pools.py @@ -1,5 +1,6 @@ import ipaddress import uuid 
+from collections.abc import Container, Iterable from datetime import datetime, timezone from typing import List, Optional @@ -69,7 +70,11 @@ async def list_project_pools(session: AsyncSession, project: ProjectModel) -> Li async def get_pool( - session: AsyncSession, project: ProjectModel, pool_name: str, select_deleted: bool = False + session: AsyncSession, + project: ProjectModel, + pool_name: str, + select_deleted: bool = False, + load_instance_fleets: bool = False, ) -> Optional[PoolModel]: filters = [ PoolModel.name == pool_name, @@ -77,29 +82,42 @@ async def get_pool( ] if not select_deleted: filters.append(PoolModel.deleted == False) - res = await session.scalars(select(PoolModel).where(*filters)) + query = select(PoolModel).where(*filters) + if load_instance_fleets: + query = query.options(joinedload(PoolModel.instances, InstanceModel.fleet)) + res = await session.scalars(query) return res.one_or_none() async def get_or_create_pool_by_name( - session: AsyncSession, project: ProjectModel, pool_name: Optional[str] + session: AsyncSession, + project: ProjectModel, + pool_name: Optional[str], + load_instance_fleets: bool = False, ) -> PoolModel: if pool_name is None: if project.default_pool_id is not None: - return await get_default_pool_or_error(session, project) - default_pool = await get_pool(session, project, DEFAULT_POOL_NAME) + return await get_default_pool_or_error(session, project, load_instance_fleets) + default_pool = await get_pool( + session, project, DEFAULT_POOL_NAME, load_instance_fleets=load_instance_fleets + ) if default_pool is not None: await set_default_pool(session, project, DEFAULT_POOL_NAME) return default_pool return await create_pool(session, project, DEFAULT_POOL_NAME) - pool = await get_pool(session, project, pool_name) + pool = await get_pool(session, project, pool_name, load_instance_fleets=load_instance_fleets) if pool is not None: return pool return await create_pool(session, project, pool_name) -async def 
get_default_pool_or_error(session: AsyncSession, project: ProjectModel) -> PoolModel: - res = await session.execute(select(PoolModel).where(PoolModel.id == project.default_pool_id)) +async def get_default_pool_or_error( + session: AsyncSession, project: ProjectModel, load_instance_fleets: bool = False +) -> PoolModel: + query = select(PoolModel).where(PoolModel.id == project.default_pool_id) + if load_instance_fleets: + query = query.options(joinedload(PoolModel.instances, InstanceModel.fleet)) + res = await session.execute(query) return res.scalar_one() @@ -201,11 +219,13 @@ async def show_pool_instances( session: AsyncSession, project: ProjectModel, pool_name: Optional[str] ) -> PoolInstances: if pool_name is not None: - pool = await get_pool(session, project, pool_name) + pool = await get_pool(session, project, pool_name, load_instance_fleets=True) if pool is None: raise ResourceNotExistsError("Pool not found") else: - pool = await get_or_create_pool_by_name(session, project, pool_name) + pool = await get_or_create_pool_by_name( + session, project, pool_name, load_instance_fleets=True + ) pool_instances = get_pool_instances(pool) instances = list(map(instance_model_to_instance, pool_instances)) return PoolInstances( @@ -223,6 +243,8 @@ def instance_model_to_instance(instance_model: InstanceModel) -> Instance: id=instance_model.id, project_name=instance_model.project.name, name=instance_model.name, + fleet_id=instance_model.fleet_id, + fleet_name=instance_model.fleet.name if instance_model.fleet else None, instance_num=instance_model.instance_num, status=instance_model.status, unreachable=instance_model.unreachable, @@ -478,6 +500,7 @@ def filter_pool_instances( async def list_pools_instance_models( session: AsyncSession, projects: List[ProjectModel], + fleet_ids: Optional[Iterable[uuid.UUID]], pool: Optional[PoolModel], only_active: bool, prev_created_at: Optional[datetime], @@ -488,6 +511,8 @@ async def list_pools_instance_models( filters: List = [ 
InstanceModel.project_id.in_(p.id for p in projects), ] + if fleet_ids is not None: + filters.append(InstanceModel.fleet_id.in_(fleet_ids)) if pool is not None: filters.append(InstanceModel.pool_id == pool.id) if only_active: @@ -533,7 +558,7 @@ async def list_pools_instance_models( .where(*filters) .order_by(*order_by) .limit(limit) - .options(joinedload(InstanceModel.pool)) + .options(joinedload(InstanceModel.pool), joinedload(InstanceModel.fleet)) ) instance_models = list(res.scalars().all()) return instance_models @@ -542,7 +567,8 @@ async def list_pools_instance_models( async def list_user_pool_instances( session: AsyncSession, user: UserModel, - project_name: Optional[str], + project_names: Optional[Container[str]], + fleet_ids: Optional[Iterable[uuid.UUID]], pool_name: Optional[str], only_active: bool, prev_created_at: Optional[datetime], @@ -558,8 +584,8 @@ async def list_user_pool_instances( return [] pool = None - if project_name is not None: - projects = [proj for proj in projects if proj.name == project_name] + if project_names is not None: + projects = [proj for proj in projects if proj.name in project_names] if len(projects) == 0: return [] if pool_name is not None: @@ -573,6 +599,7 @@ async def list_user_pool_instances( instance_models = await list_pools_instance_models( session=session, projects=projects, + fleet_ids=fleet_ids, pool=pool, only_active=only_active, prev_created_at=prev_created_at, diff --git a/src/dstack/_internal/server/testing/common.py b/src/dstack/_internal/server/testing/common.py index 2edfe7517..9cb6e0fa8 100644 --- a/src/dstack/_internal/server/testing/common.py +++ b/src/dstack/_internal/server/testing/common.py @@ -480,6 +480,7 @@ async def create_instance( region: str = "eu-west", remote_connection_info: Optional[RemoteConnectionInfo] = None, job_provisioning_data: Optional[JobProvisioningData] = None, + name: str = "test_instance", ) -> InstanceModel: if instance_id is None: instance_id = uuid.uuid4() @@ -544,7 +545,7 @@ 
async def create_instance( im = InstanceModel( id=instance_id, - name="test_instance", + name=name, instance_num=instance_num, pool=pool, fleet=fleet, diff --git a/src/tests/_internal/server/routers/test_fleets.py b/src/tests/_internal/server/routers/test_fleets.py index b4b94e4ac..af7e89efd 100644 --- a/src/tests/_internal/server/routers/test_fleets.py +++ b/src/tests/_internal/server/routers/test_fleets.py @@ -316,6 +316,8 @@ async def test_creates_fleet(self, test_db, session: AsyncSession, client: Async "id": "1b0e1b45-2f8c-4ab6-8010-a0d1a3e44e0e", "project_name": project.name, "name": f"{spec.configuration.name}-0", + "fleet_id": "1b0e1b45-2f8c-4ab6-8010-a0d1a3e44e0e", + "fleet_name": spec.configuration.name, "instance_num": 0, "job_name": None, "hostname": None, @@ -443,6 +445,8 @@ async def test_creates_ssh_fleet(self, test_db, session: AsyncSession, client: A }, }, "name": f"{spec.configuration.name}-0", + "fleet_id": "1b0e1b45-2f8c-4ab6-8010-a0d1a3e44e0e", + "fleet_name": spec.configuration.name, "instance_num": 0, "pool_name": None, "job_name": None, diff --git a/src/tests/_internal/server/routers/test_instances.py b/src/tests/_internal/server/routers/test_instances.py new file mode 100644 index 000000000..8b1b40229 --- /dev/null +++ b/src/tests/_internal/server/routers/test_instances.py @@ -0,0 +1,277 @@ +import datetime as dt +import uuid +from dataclasses import dataclass +from itertools import count + +import pytest +import pytest_asyncio +from httpx import AsyncClient +from sqlalchemy.ext.asyncio import AsyncSession + +from dstack._internal.core.models.instances import InstanceStatus +from dstack._internal.core.models.users import GlobalRole, ProjectRole +from dstack._internal.server.models import UserModel +from dstack._internal.server.services.projects import add_project_member +from dstack._internal.server.testing.common import ( + create_fleet, + create_instance, + create_pool, + create_project, + create_user, + get_auth_headers, + 
get_fleet_configuration, + get_fleet_spec, +) + + +@dataclass +class PreparedData: + users: list[UserModel] + + +SAMPLE_FLEET_IDS = [uuid.uuid4() for _ in range(3)] + + +@pytest.mark.asyncio +@pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True) +class TestListInstances: + @pytest_asyncio.fixture + async def data(self, session: AsyncSession) -> PreparedData: + users = [ + await create_user(session, name="user0", global_role=GlobalRole.ADMIN), + await create_user(session, name="user1", global_role=GlobalRole.USER), + await create_user(session, name="user2", global_role=GlobalRole.USER), + ] + projects = [ + await create_project(session, owner=users[0], name="project0"), + await create_project(session, owner=users[1], name="project1"), + await create_project(session, owner=users[2], name="project2"), + ] + await add_project_member( + session, project=projects[0], user=users[0], project_role=ProjectRole.ADMIN + ) + await add_project_member( + session, project=projects[1], user=users[1], project_role=ProjectRole.ADMIN + ) + await add_project_member( + session, project=projects[2], user=users[2], project_role=ProjectRole.ADMIN + ) + await add_project_member( + session, project=projects[2], user=users[1], project_role=ProjectRole.USER + ) + pools = [ + await create_pool(session, projects[0]), + await create_pool(session, projects[1]), + await create_pool(session, projects[2]), + ] + fleets = [ + await create_fleet( + session, + projects[0], + spec=get_fleet_spec(conf=get_fleet_configuration("fleet0")), + fleet_id=SAMPLE_FLEET_IDS[0], + ), + await create_fleet( + session, + projects[1], + spec=get_fleet_spec(conf=get_fleet_configuration("fleet1")), + fleet_id=SAMPLE_FLEET_IDS[1], + ), + await create_fleet( + session, + projects[2], + spec=get_fleet_spec(conf=get_fleet_configuration("fleet2")), + fleet_id=SAMPLE_FLEET_IDS[2], + ), + ] + _ = [ + await create_instance( + session=session, + project=projects[0], + pool=pools[0], + fleet=fleets[0], + 
created_at=dt.datetime(2024, 1, 1, tzinfo=dt.timezone.utc), + name="fleet0-0", + ), + await create_instance( + session=session, + project=projects[1], + pool=pools[1], + fleet=fleets[1], + created_at=dt.datetime(2024, 1, 2, tzinfo=dt.timezone.utc), + name="fleet1-0", + ), + await create_instance( + session=session, + project=projects[2], + pool=pools[2], + fleet=fleets[2], + created_at=dt.datetime(2024, 1, 3, tzinfo=dt.timezone.utc), + name="fleet2-0", + ), + await create_instance( + session=session, + project=projects[2], + pool=pools[2], + fleet=fleets[2], + created_at=dt.datetime(2024, 1, 4, tzinfo=dt.timezone.utc), + instance_num=1, + name="fleet2-1", + status=InstanceStatus.TERMINATED, + ), + ] + return PreparedData(users=users) + + @pytest.mark.parametrize( + ("user", "expected_instances"), + [ + pytest.param( + 0, + ["fleet0-0", "fleet1-0", "fleet2-0", "fleet2-1"], + id="global-admin", + ), + pytest.param( + 1, + ["fleet1-0", "fleet2-0", "fleet2-1"], + id="admin-in-one-project-user-in-other", + ), + pytest.param( + 2, + ["fleet2-0", "fleet2-1"], + id="project-admin", + ), + ], + ) + async def test_project_access( + self, user: int, expected_instances: list[str], data: PreparedData, client: AsyncClient + ) -> None: + resp = await client.post( + "/api/instances/list", + headers=get_auth_headers(data.users[user].token), + json={"ascending": True}, + ) + assert resp.status_code == 200 + instances = [instance["name"] for instance in resp.json()] + assert instances == expected_instances + + @pytest.mark.parametrize( + ("filters", "expected_instances"), + [ + pytest.param( + {"project_names": ["project1", "project2"]}, + ["fleet1-0", "fleet2-0", "fleet2-1"], + id="two-projects", + ), + pytest.param( + {"project_names": ["project1"]}, + ["fleet1-0"], + id="one-project", + ), + pytest.param( + {"project_names": ["project0"]}, + [], + id="forbidden-project", + ), + pytest.param( + {"project_names": ["nonexistent"]}, + [], + id="nonexistent-project", + ), + 
pytest.param( + {"fleet_ids": [str(SAMPLE_FLEET_IDS[1]), str(SAMPLE_FLEET_IDS[2])]}, + ["fleet1-0", "fleet2-0", "fleet2-1"], + id="two-fleets", + ), + pytest.param( + {"fleet_ids": [str(SAMPLE_FLEET_IDS[1])]}, + ["fleet1-0"], + id="one-fleet", + ), + pytest.param( + {"fleet_ids": [str(SAMPLE_FLEET_IDS[0])]}, + [], + id="forbidden-fleet", + ), + pytest.param( + {"fleet_ids": [str(uuid.uuid4())]}, + [], + id="nonexistent-fleet", + ), + pytest.param( + {"project_names": ["project1"], "fleet_ids": [str(SAMPLE_FLEET_IDS[1])]}, + ["fleet1-0"], + id="project-and-fleet-match", + ), + pytest.param( + {"project_names": ["project2"], "fleet_ids": [str(SAMPLE_FLEET_IDS[1])]}, + [], + id="project-and-fleet-no-match", + ), + pytest.param( + {"only_active": True, "project_names": ["project2"]}, + ["fleet2-0"], + id="only-active", + ), + ], + ) + async def test_filters( + self, + filters: dict, + expected_instances: list[str], + data: PreparedData, + client: AsyncClient, + ) -> None: + resp = await client.post( + "/api/instances/list", + headers=get_auth_headers(data.users[1].token), + json={"ascending": True, **filters}, + ) + assert resp.status_code == 200 + instances = [instance["name"] for instance in resp.json()] + assert instances == expected_instances + + @pytest.mark.parametrize( + ("is_ascending", "expected_pages"), + [ + pytest.param(True, [["fleet1-0", "fleet2-0"], ["fleet2-1"]], id="ascending"), + pytest.param(False, [["fleet2-1", "fleet2-0"], ["fleet1-0"]], id="descending"), + ], + ) + async def test_pagination( + self, + is_ascending: bool, + expected_pages: list[list[str]], + data: PreparedData, + client: AsyncClient, + ) -> None: + pages = [] + prev_id = None + prev_created_at = None + for page_no in count(): + if page_no == 10: + raise RuntimeError("Too many pages") + resp = await client.post( + "/api/instances/list", + headers=get_auth_headers(data.users[1].token), + json={ + "ascending": is_ascending, + "limit": 2, + "project_names": ["project1", "project2"], + 
"prev_id": prev_id, + "prev_created_at": prev_created_at, + }, + ) + assert resp.status_code == 200 + page = [] + for instance in resp.json(): + page.append(instance["name"]) + prev_id = instance["id"] + prev_created_at = instance["created"] + if not page: + break + pages.append(page) + assert pages == expected_pages + + async def test_not_authenticated(self, client: AsyncClient, data) -> None: + resp = await client.post("/api/instances/list", json={}) + assert resp.status_code == 403 diff --git a/src/tests/_internal/server/routers/test_pools.py b/src/tests/_internal/server/routers/test_pools.py index f9929fcc8..47d165609 100644 --- a/src/tests/_internal/server/routers/test_pools.py +++ b/src/tests/_internal/server/routers/test_pools.py @@ -321,6 +321,8 @@ async def test_show_pool(self, test_db, session: AsyncSession, client: AsyncClie "id": str(instance.id), "project_name": project.name, "name": "test_instance", + "fleet_id": None, + "fleet_name": None, "instance_num": 0, "job_name": None, "hostname": "running_instance.ip", @@ -490,6 +492,8 @@ async def test_remove_instance(self, test_db, session: AsyncSession, client: Asy "id": str(instance.id), "project_name": project.name, "name": "test_instance", + "fleet_id": None, + "fleet_name": None, "instance_num": 0, "job_name": None, "hostname": "running_instance.ip", diff --git a/src/tests/_internal/server/routers/test_runs.py b/src/tests/_internal/server/routers/test_runs.py index b04a1e427..e9a19ac16 100644 --- a/src/tests/_internal/server/routers/test_runs.py +++ b/src/tests/_internal/server/routers/test_runs.py @@ -1506,6 +1506,8 @@ async def test_creates_instance(self, test_db, session: AsyncSession, client: As "backend": None, "instance_type": None, "name": result["name"], + "fleet_id": None, + "fleet_name": None, "instance_num": 0, "job_name": None, "hostname": None,