Skip to content

Commit

Permalink
Add /api/instances/list (#2199)
Browse files Browse the repository at this point in the history
Add an API method for listing instances with
filtering and pagination. This method will be used
in the new Instances page in the UI. It is similar
to the deprecated `/api/pools/list_instances`
method, except it allows filtering by fleet and
not by pool.

Also add the `fleet_id` and `fleet_name` fields to
the instance model so that the Instances page in
the UI can display fleet names and provide links
to fleet details.
  • Loading branch information
jvstme authored Jan 20, 2025
1 parent 91bdc80 commit bc5f0ac
Show file tree
Hide file tree
Showing 11 changed files with 396 additions and 16 deletions.
2 changes: 2 additions & 0 deletions src/dstack/_internal/core/models/pools.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@ class Instance(CoreModel):
backend: Optional[BackendType] = None
instance_type: Optional[InstanceType] = None
name: str
fleet_id: Optional[UUID] = None
fleet_name: Optional[str] = None
instance_num: int
pool_name: Optional[str] = None
job_name: Optional[str] = None
Expand Down
2 changes: 2 additions & 0 deletions src/dstack/_internal/server/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
backends,
fleets,
gateways,
instances,
logs,
metrics,
pools,
Expand Down Expand Up @@ -169,6 +170,7 @@ def register_routes(app: FastAPI, ui: bool = True):
app.include_router(backends.project_router)
app.include_router(fleets.root_router)
app.include_router(fleets.project_router)
app.include_router(instances.root_router)
app.include_router(repos.router)
app.include_router(runs.root_router)
app.include_router(runs.project_router)
Expand Down
45 changes: 45 additions & 0 deletions src/dstack/_internal/server/routers/instances.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
from typing import List

from fastapi import APIRouter, Depends
from sqlalchemy.ext.asyncio import AsyncSession

import dstack._internal.server.services.pools as pools
from dstack._internal.core.models.pools import Instance
from dstack._internal.server.db import get_session
from dstack._internal.server.models import UserModel
from dstack._internal.server.schemas.instances import ListInstancesRequest
from dstack._internal.server.security.permissions import Authenticated
from dstack._internal.server.utils.routers import get_base_api_additional_responses

root_router = APIRouter(
prefix="/api/instances",
tags=["instances"],
responses=get_base_api_additional_responses(),
)


@root_router.post("/list")
async def list_instances(
body: ListInstancesRequest,
session: AsyncSession = Depends(get_session),
user: UserModel = Depends(Authenticated()),
) -> List[Instance]:
"""
Returns all instances visible to user sorted by descending `created_at`.
`project_names` and `fleet_ids` can be specified as filters.
The results are paginated. To get the next page, pass `created_at` and `id` of
the last instance from the previous page as `prev_created_at` and `prev_id`.
"""
return await pools.list_user_pool_instances(
session=session,
user=user,
project_names=body.project_names,
fleet_ids=body.fleet_ids,
pool_name=None,
only_active=body.only_active,
prev_created_at=body.prev_created_at,
prev_id=body.prev_id,
limit=body.limit,
ascending=body.ascending,
)
3 changes: 2 additions & 1 deletion src/dstack/_internal/server/routers/pools.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,8 @@ async def list_pool_instances(
return await pools.list_user_pool_instances(
session=session,
user=user,
project_name=body.project_name,
project_names=[body.project_name] if body.project_name is not None else None,
fleet_ids=None,
pool_name=body.pool_name,
only_active=body.only_active,
prev_created_at=body.prev_created_at,
Expand Down
15 changes: 15 additions & 0 deletions src/dstack/_internal/server/schemas/instances.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
from datetime import datetime
from typing import Optional
from uuid import UUID

from dstack._internal.core.models.common import CoreModel


class ListInstancesRequest(CoreModel):
project_names: Optional[list[str]] = None
fleet_ids: Optional[list[UUID]] = None
only_active: bool = False
prev_created_at: Optional[datetime] = None
prev_id: Optional[UUID] = None
limit: int = 1000
ascending: bool = False
55 changes: 41 additions & 14 deletions src/dstack/_internal/server/services/pools.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import ipaddress
import uuid
from collections.abc import Container, Iterable
from datetime import datetime, timezone
from typing import List, Optional

Expand Down Expand Up @@ -69,37 +70,54 @@ async def list_project_pools(session: AsyncSession, project: ProjectModel) -> Li


async def get_pool(
session: AsyncSession, project: ProjectModel, pool_name: str, select_deleted: bool = False
session: AsyncSession,
project: ProjectModel,
pool_name: str,
select_deleted: bool = False,
load_instance_fleets: bool = False,
) -> Optional[PoolModel]:
filters = [
PoolModel.name == pool_name,
PoolModel.project_id == project.id,
]
if not select_deleted:
filters.append(PoolModel.deleted == False)
res = await session.scalars(select(PoolModel).where(*filters))
query = select(PoolModel).where(*filters)
if load_instance_fleets:
query = query.options(joinedload(PoolModel.instances, InstanceModel.fleet))
res = await session.scalars(query)
return res.one_or_none()


async def get_or_create_pool_by_name(
session: AsyncSession, project: ProjectModel, pool_name: Optional[str]
session: AsyncSession,
project: ProjectModel,
pool_name: Optional[str],
load_instance_fleets: bool = False,
) -> PoolModel:
if pool_name is None:
if project.default_pool_id is not None:
return await get_default_pool_or_error(session, project)
default_pool = await get_pool(session, project, DEFAULT_POOL_NAME)
return await get_default_pool_or_error(session, project, load_instance_fleets)
default_pool = await get_pool(
session, project, DEFAULT_POOL_NAME, load_instance_fleets=load_instance_fleets
)
if default_pool is not None:
await set_default_pool(session, project, DEFAULT_POOL_NAME)
return default_pool
return await create_pool(session, project, DEFAULT_POOL_NAME)
pool = await get_pool(session, project, pool_name)
pool = await get_pool(session, project, pool_name, load_instance_fleets=load_instance_fleets)
if pool is not None:
return pool
return await create_pool(session, project, pool_name)


async def get_default_pool_or_error(session: AsyncSession, project: ProjectModel) -> PoolModel:
res = await session.execute(select(PoolModel).where(PoolModel.id == project.default_pool_id))
async def get_default_pool_or_error(
session: AsyncSession, project: ProjectModel, load_instance_fleets: bool = False
) -> PoolModel:
query = select(PoolModel).where(PoolModel.id == project.default_pool_id)
if load_instance_fleets:
query = query.options(joinedload(PoolModel.instances, InstanceModel.fleet))
res = await session.execute(query)
return res.scalar_one()


Expand Down Expand Up @@ -201,11 +219,13 @@ async def show_pool_instances(
session: AsyncSession, project: ProjectModel, pool_name: Optional[str]
) -> PoolInstances:
if pool_name is not None:
pool = await get_pool(session, project, pool_name)
pool = await get_pool(session, project, pool_name, load_instance_fleets=True)
if pool is None:
raise ResourceNotExistsError("Pool not found")
else:
pool = await get_or_create_pool_by_name(session, project, pool_name)
pool = await get_or_create_pool_by_name(
session, project, pool_name, load_instance_fleets=True
)
pool_instances = get_pool_instances(pool)
instances = list(map(instance_model_to_instance, pool_instances))
return PoolInstances(
Expand All @@ -223,6 +243,8 @@ def instance_model_to_instance(instance_model: InstanceModel) -> Instance:
id=instance_model.id,
project_name=instance_model.project.name,
name=instance_model.name,
fleet_id=instance_model.fleet_id,
fleet_name=instance_model.fleet.name if instance_model.fleet else None,
instance_num=instance_model.instance_num,
status=instance_model.status,
unreachable=instance_model.unreachable,
Expand Down Expand Up @@ -478,6 +500,7 @@ def filter_pool_instances(
async def list_pools_instance_models(
session: AsyncSession,
projects: List[ProjectModel],
fleet_ids: Optional[Iterable[uuid.UUID]],
pool: Optional[PoolModel],
only_active: bool,
prev_created_at: Optional[datetime],
Expand All @@ -488,6 +511,8 @@ async def list_pools_instance_models(
filters: List = [
InstanceModel.project_id.in_(p.id for p in projects),
]
if fleet_ids is not None:
filters.append(InstanceModel.fleet_id.in_(fleet_ids))
if pool is not None:
filters.append(InstanceModel.pool_id == pool.id)
if only_active:
Expand Down Expand Up @@ -533,7 +558,7 @@ async def list_pools_instance_models(
.where(*filters)
.order_by(*order_by)
.limit(limit)
.options(joinedload(InstanceModel.pool))
.options(joinedload(InstanceModel.pool), joinedload(InstanceModel.fleet))
)
instance_models = list(res.scalars().all())
return instance_models
Expand All @@ -542,7 +567,8 @@ async def list_pools_instance_models(
async def list_user_pool_instances(
session: AsyncSession,
user: UserModel,
project_name: Optional[str],
project_names: Optional[Container[str]],
fleet_ids: Optional[Iterable[uuid.UUID]],
pool_name: Optional[str],
only_active: bool,
prev_created_at: Optional[datetime],
Expand All @@ -558,8 +584,8 @@ async def list_user_pool_instances(
return []

pool = None
if project_name is not None:
projects = [proj for proj in projects if proj.name == project_name]
if project_names is not None:
projects = [proj for proj in projects if proj.name in project_names]
if len(projects) == 0:
return []
if pool_name is not None:
Expand All @@ -573,6 +599,7 @@ async def list_user_pool_instances(
instance_models = await list_pools_instance_models(
session=session,
projects=projects,
fleet_ids=fleet_ids,
pool=pool,
only_active=only_active,
prev_created_at=prev_created_at,
Expand Down
3 changes: 2 additions & 1 deletion src/dstack/_internal/server/testing/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -480,6 +480,7 @@ async def create_instance(
region: str = "eu-west",
remote_connection_info: Optional[RemoteConnectionInfo] = None,
job_provisioning_data: Optional[JobProvisioningData] = None,
name: str = "test_instance",
) -> InstanceModel:
if instance_id is None:
instance_id = uuid.uuid4()
Expand Down Expand Up @@ -544,7 +545,7 @@ async def create_instance(

im = InstanceModel(
id=instance_id,
name="test_instance",
name=name,
instance_num=instance_num,
pool=pool,
fleet=fleet,
Expand Down
4 changes: 4 additions & 0 deletions src/tests/_internal/server/routers/test_fleets.py
Original file line number Diff line number Diff line change
Expand Up @@ -316,6 +316,8 @@ async def test_creates_fleet(self, test_db, session: AsyncSession, client: Async
"id": "1b0e1b45-2f8c-4ab6-8010-a0d1a3e44e0e",
"project_name": project.name,
"name": f"{spec.configuration.name}-0",
"fleet_id": "1b0e1b45-2f8c-4ab6-8010-a0d1a3e44e0e",
"fleet_name": spec.configuration.name,
"instance_num": 0,
"job_name": None,
"hostname": None,
Expand Down Expand Up @@ -443,6 +445,8 @@ async def test_creates_ssh_fleet(self, test_db, session: AsyncSession, client: A
},
},
"name": f"{spec.configuration.name}-0",
"fleet_id": "1b0e1b45-2f8c-4ab6-8010-a0d1a3e44e0e",
"fleet_name": spec.configuration.name,
"instance_num": 0,
"pool_name": None,
"job_name": None,
Expand Down
Loading

0 comments on commit bc5f0ac

Please sign in to comment.