diff --git a/docs/docs/concepts/volumes.md b/docs/docs/concepts/volumes.md index 5b85dfbc0..884d39d60 100644 --- a/docs/docs/concepts/volumes.md +++ b/docs/docs/concepts/volumes.md @@ -236,8 +236,27 @@ volumes: Since persistence isn't guaranteed (instances may be interrupted or runs may occur on different instances), use instance volumes only for caching or with directories manually mounted to network storage. -> Instance volumes are currently supported for all backends except `runpod`, `vastai` and `kubernetes`, -> and can also be used with [SSH fleets](fleets.md#ssh). +!!! info "Backends" + Instance volumes are currently supported for all backends except `runpod`, `vastai` and `kubernetes`, and can also be used with [SSH fleets](fleets.md#ssh). + +??? info "Optional volumes" + If the volume is not critical for your workload, you can mark it as `optional`. + +
+ + ```yaml + type: task + + volumes: + - instance_path: /dstack-cache + path: /root/.cache/ + optional: true + ``` + + Configurations with optional volumes can run in any backend, but the volume is only mounted + if the selected backend supports it. + +
### Use instance volumes for caching diff --git a/src/dstack/_internal/core/models/volumes.py b/src/dstack/_internal/core/models/volumes.py index 7422ef426..c0c065f1d 100644 --- a/src/dstack/_internal/core/models/volumes.py +++ b/src/dstack/_internal/core/models/volumes.py @@ -136,6 +136,15 @@ def parse(cls, v: str) -> Self: class InstanceMountPoint(CoreModel): instance_path: Annotated[str, Field(description="The absolute path on the instance (host)")] path: Annotated[str, Field(description="The absolute path in the container")] + optional: Annotated[ + bool, + Field( + description=( + "Allow running without this volume" + " in backends that do not support instance volumes" + ), + ), + ] = False _validate_instance_path = validator("instance_path", allow_reuse=True)( _validate_mount_point_path diff --git a/src/dstack/_internal/server/background/tasks/process_submitted_jobs.py b/src/dstack/_internal/server/background/tasks/process_submitted_jobs.py index fbb88fc9e..5b485c1ce 100644 --- a/src/dstack/_internal/server/background/tasks/process_submitted_jobs.py +++ b/src/dstack/_internal/server/background/tasks/process_submitted_jobs.py @@ -64,7 +64,7 @@ ) from dstack._internal.server.services.runs import ( check_can_attach_run_volumes, - check_run_spec_has_instance_mounts, + check_run_spec_requires_instance_mounts, get_offer_volumes, get_run_volume_models, get_run_volumes, @@ -418,7 +418,7 @@ async def _run_job_on_new_instance( master_job_provisioning_data=master_job_provisioning_data, volumes=volumes, privileged=job.job_spec.privileged, - instance_mounts=check_run_spec_has_instance_mounts(run.run_spec), + instance_mounts=check_run_spec_requires_instance_mounts(run.run_spec), ) # Limit number of offers tried to prevent long-running processing # in case all offers fail. diff --git a/src/dstack/_internal/server/services/runs.py b/src/dstack/_internal/server/services/runs.py index dea0ff257..fc665a352 100644 --- a/src/dstack/_internal/server/services/runs.py +++ b/src/dstack/_internal/server/services/runs.py @@ -330,7 +330,7 @@ async def get_plan( multinode=jobs[0].job_spec.jobs_per_replica > 1, volumes=volumes, privileged=jobs[0].job_spec.privileged, - instance_mounts=check_run_spec_has_instance_mounts(run_spec), + instance_mounts=check_run_spec_requires_instance_mounts(run_spec), ) job_plans = [] @@ -897,9 +897,10 @@ def get_offer_mount_point_volume( raise ServerClientError("Failed to find an eligible volume for the mount point") -def check_run_spec_has_instance_mounts(run_spec: RunSpec) -> bool: +def check_run_spec_requires_instance_mounts(run_spec: RunSpec) -> bool: return any( - is_core_model_instance(mp, InstanceMountPoint) for mp in run_spec.configuration.volumes + is_core_model_instance(mp, InstanceMountPoint) and not mp.optional + for mp in run_spec.configuration.volumes ) diff --git a/src/dstack/api/server/_runs.py b/src/dstack/api/server/_runs.py index 2c2586d42..f4495f0e2 100644 --- a/src/dstack/api/server/_runs.py +++ b/src/dstack/api/server/_runs.py @@ -1,5 +1,5 @@ from datetime import datetime -from typing import List, Optional, Union +from typing import Any, List, Optional, Union from uuid import UUID from pydantic import parse_obj_as @@ -19,6 +19,7 @@ RunPlan, RunSpec, ) +from dstack._internal.core.models.volumes import InstanceMountPoint from dstack._internal.server.schemas.runs import ( ApplyRunPlanRequest, CreateInstanceRequest, @@ -122,32 +123,32 @@ def create_instance( def _get_run_spec_excludes(run_spec: RunSpec) -> Optional[dict]: spec_excludes: dict[str, set[str]] = {} - configuration_excludes: set[str] = set() + configuration_excludes: dict[str, Any] = {} profile_excludes: set[str] = set() configuration = run_spec.configuration profile = run_spec.profile # client >= 0.18.18 / server <= 0.18.17 compatibility tweak if not configuration.privileged: - configuration_excludes.add("privileged") + configuration_excludes["privileged"] = True # client >= 0.18.23 / server <= 0.18.22 compatibility tweak if configuration.type == "service" and configuration.gateway is None: - configuration_excludes.add("gateway") + configuration_excludes["gateway"] = True # client >= 0.18.30 / server <= 0.18.29 compatibility tweak if run_spec.configuration.user is None: - configuration_excludes.add("user") + configuration_excludes["user"] = True # client >= 0.18.30 / server <= 0.18.29 compatibility tweak if configuration.reservation is None: - configuration_excludes.add("reservation") + configuration_excludes["reservation"] = True if profile is not None and profile.reservation is None: profile_excludes.add("reservation") if configuration.idle_duration is None: - configuration_excludes.add("idle_duration") + configuration_excludes["idle_duration"] = True if profile is not None and profile.idle_duration is None: profile_excludes.add("idle_duration") # client >= 0.18.38 / server <= 0.18.37 compatibility tweak if configuration.stop_duration is None: - configuration_excludes.add("stop_duration") + configuration_excludes["stop_duration"] = True if profile is not None and profile.stop_duration is None: profile_excludes.add("stop_duration") # client >= 0.18.40 / server <= 0.18.39 compatibility tweak @@ -155,9 +156,14 @@ def _get_run_spec_excludes(run_spec: RunSpec) -> Optional[dict]: is_core_model_instance(configuration, ServiceConfiguration) and configuration.strip_prefix == STRIP_PREFIX_DEFAULT ): - configuration_excludes.add("strip_prefix") + configuration_excludes["strip_prefix"] = True if configuration.single_branch is None: - configuration_excludes.add("single_branch") + configuration_excludes["single_branch"] = True + if all( + not is_core_model_instance(v, InstanceMountPoint) or not v.optional + for v in configuration.volumes + ): + configuration_excludes["volumes"] = {"__all__": {"optional"}} if configuration_excludes: spec_excludes["configuration"] = configuration_excludes diff --git a/src/tests/_internal/server/background/tasks/test_process_submitted_jobs.py b/src/tests/_internal/server/background/tasks/test_process_submitted_jobs.py index c03005cac..89ab76f9d 100644 --- a/src/tests/_internal/server/background/tasks/test_process_submitted_jobs.py +++ b/src/tests/_internal/server/background/tasks/test_process_submitted_jobs.py @@ -302,6 +302,78 @@ async def test_fails_job_when_instance_mounts_and_no_offers_with_create_instance await session.refresh(pool) assert not pool.instances + @pytest.mark.asyncio + @pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True) + async def test_provisions_job_with_optional_instance_volume_not_attached( + self, + test_db, + session: AsyncSession, + ): + project = await create_project(session=session) + user = await create_user(session=session) + pool = await create_pool(session=session, project=project) + repo = await create_repo( + session=session, + project_id=project.id, + ) + run_spec = get_run_spec(run_name="test-run", repo_id=repo.name) + run_spec.configuration.volumes = [ + InstanceMountPoint(instance_path="/root/.cache", path="/cache", optional=True) + ] + run = await create_run( + session=session, + project=project, + repo=repo, + user=user, + run_name="test-run", + run_spec=run_spec, + ) + job = await create_job( + session=session, + run=run, + instance_assigned=True, + ) + offer = InstanceOfferWithAvailability( + backend=BackendType.RUNPOD, + instance=InstanceType( + name="instance", + resources=Resources(cpus=1, memory_mib=512, spot=False, gpus=[]), + ), + region="us", + price=1.0, + availability=InstanceAvailability.AVAILABLE, + ) + with patch("dstack._internal.server.services.backends.get_project_backends") as m: + backend_mock = Mock() + m.return_value = [backend_mock] + backend_mock.TYPE = BackendType.RUNPOD + backend_mock.compute.return_value.get_offers_cached.return_value = [offer] + backend_mock.compute.return_value.run_job.return_value = JobProvisioningData( + backend=offer.backend, + instance_type=offer.instance, + instance_id="instance_id", + hostname="1.1.1.1", + internal_ip=None, + region=offer.region, + price=offer.price, + username="ubuntu", + ssh_port=22, + ssh_proxy=None, + dockerized=False, + backend_data=None, + ) + await process_submitted_jobs() + + await session.refresh(job) + assert job is not None + assert job.status == JobStatus.PROVISIONING + + await session.refresh(pool) + instance_offer = InstanceOfferWithAvailability.parse_raw(pool.instances[0].offer) + assert offer == instance_offer + pool_job_provisioning_data = pool.instances[0].job_provisioning_data + assert pool_job_provisioning_data == job.job_provisioning_data + @pytest.mark.asyncio @pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True) async def test_fails_job_when_no_capacity(self, test_db, session: AsyncSession): @@ -412,7 +484,8 @@ async def test_assigns_job_to_instance_with_volumes(self, test_db, session: Asyn run_spec = get_run_spec(run_name="test-run", repo_id=repo.name) run_spec.configuration.volumes = [ VolumeMountPoint(name=volume.name, path="/volume"), - InstanceMountPoint(instance_path="/root/.cache", path="/cache"), + InstanceMountPoint(instance_path="/root/.data", path="/data"), + InstanceMountPoint(instance_path="/root/.cache", path="/cache", optional=True), ] run = await create_run( session=session, diff --git a/src/tests/_internal/server/services/runner/test_client.py b/src/tests/_internal/server/services/runner/test_client.py index dd89a4193..dd40a12d6 100644 --- a/src/tests/_internal/server/services/runner/test_client.py +++ b/src/tests/_internal/server/services/runner/test_client.py @@ -175,7 +175,9 @@ def test_submit(self, client: ShimClient, adapter: requests_mock.Adapter): "device_name": "/dev/sdv", } ], - "instance_mounts": [{"instance_path": "/mnt/nfs/home", "path": "/home"}], + "instance_mounts": [ + {"instance_path": "/mnt/nfs/home", "path": "/home", "optional": False} + ], } self.assert_request(adapter, 0, "POST", "/api/submit", expected_request) @@ -341,7 +343,9 @@ def test_submit_task(self, client: ShimClient, adapter: requests_mock.Adapter): } ], "volume_mounts": [{"name": "vol", "path": "/vol"}], - "instance_mounts": [{"instance_path": "/mnt/nfs/home", "path": "/home"}], + "instance_mounts": [ + {"instance_path": "/mnt/nfs/home", "path": "/home", "optional": False} + ], "host_ssh_user": "dstack", "host_ssh_keys": ["host_key"], "container_ssh_keys": ["project_key", "user_key"],