diff --git a/docs/docs/concepts/volumes.md b/docs/docs/concepts/volumes.md
index 5b85dfbc0..884d39d60 100644
--- a/docs/docs/concepts/volumes.md
+++ b/docs/docs/concepts/volumes.md
@@ -236,8 +236,27 @@ volumes:
Since persistence isn't guaranteed (instances may be interrupted or runs may occur on different instances), use instance
volumes only for caching or with directories manually mounted to network storage.
-> Instance volumes are currently supported for all backends except `runpod`, `vastai` and `kubernetes`,
-> and can also be used with [SSH fleets](fleets.md#ssh).
+!!! info "Backends"
+ Instance volumes are currently supported for all backends except `runpod`, `vastai` and `kubernetes`, and can also be used with [SSH fleets](fleets.md#ssh).
+
+??? info "Optional volumes"
+ If the volume is not critical for your workload, you can mark it as `optional`.
+
+
+
+ ```yaml
+ type: task
+
+ volumes:
+ - instance_path: /dstack-cache
+ path: /root/.cache/
+ optional: true
+ ```
+
+ Configurations with optional volumes can run in any backend, but the volume is only mounted
+ if the selected backend supports instance volumes.
+
+
### Use instance volumes for caching
diff --git a/src/dstack/_internal/core/models/volumes.py b/src/dstack/_internal/core/models/volumes.py
index 7422ef426..c0c065f1d 100644
--- a/src/dstack/_internal/core/models/volumes.py
+++ b/src/dstack/_internal/core/models/volumes.py
@@ -136,6 +136,15 @@ def parse(cls, v: str) -> Self:
class InstanceMountPoint(CoreModel):
instance_path: Annotated[str, Field(description="The absolute path on the instance (host)")]
path: Annotated[str, Field(description="The absolute path in the container")]
+ optional: Annotated[
+ bool,
+ Field(
+ description=(
+ "Allow running without this volume"
+ " in backends that do not support instance volumes"
+ ),
+ ),
+ ] = False
_validate_instance_path = validator("instance_path", allow_reuse=True)(
_validate_mount_point_path
diff --git a/src/dstack/_internal/server/background/tasks/process_submitted_jobs.py b/src/dstack/_internal/server/background/tasks/process_submitted_jobs.py
index fbb88fc9e..5b485c1ce 100644
--- a/src/dstack/_internal/server/background/tasks/process_submitted_jobs.py
+++ b/src/dstack/_internal/server/background/tasks/process_submitted_jobs.py
@@ -64,7 +64,7 @@
)
from dstack._internal.server.services.runs import (
check_can_attach_run_volumes,
- check_run_spec_has_instance_mounts,
+ check_run_spec_requires_instance_mounts,
get_offer_volumes,
get_run_volume_models,
get_run_volumes,
@@ -418,7 +418,7 @@ async def _run_job_on_new_instance(
master_job_provisioning_data=master_job_provisioning_data,
volumes=volumes,
privileged=job.job_spec.privileged,
- instance_mounts=check_run_spec_has_instance_mounts(run.run_spec),
+ instance_mounts=check_run_spec_requires_instance_mounts(run.run_spec),
)
# Limit number of offers tried to prevent long-running processing
# in case all offers fail.
diff --git a/src/dstack/_internal/server/services/runs.py b/src/dstack/_internal/server/services/runs.py
index dea0ff257..fc665a352 100644
--- a/src/dstack/_internal/server/services/runs.py
+++ b/src/dstack/_internal/server/services/runs.py
@@ -330,7 +330,7 @@ async def get_plan(
multinode=jobs[0].job_spec.jobs_per_replica > 1,
volumes=volumes,
privileged=jobs[0].job_spec.privileged,
- instance_mounts=check_run_spec_has_instance_mounts(run_spec),
+ instance_mounts=check_run_spec_requires_instance_mounts(run_spec),
)
job_plans = []
@@ -897,9 +897,10 @@ def get_offer_mount_point_volume(
raise ServerClientError("Failed to find an eligible volume for the mount point")
-def check_run_spec_has_instance_mounts(run_spec: RunSpec) -> bool:
+def check_run_spec_requires_instance_mounts(run_spec: RunSpec) -> bool:
return any(
- is_core_model_instance(mp, InstanceMountPoint) for mp in run_spec.configuration.volumes
+ is_core_model_instance(mp, InstanceMountPoint) and not mp.optional
+ for mp in run_spec.configuration.volumes
)
diff --git a/src/dstack/api/server/_runs.py b/src/dstack/api/server/_runs.py
index 2c2586d42..f4495f0e2 100644
--- a/src/dstack/api/server/_runs.py
+++ b/src/dstack/api/server/_runs.py
@@ -1,5 +1,5 @@
from datetime import datetime
-from typing import List, Optional, Union
+from typing import Any, List, Optional, Union
from uuid import UUID
from pydantic import parse_obj_as
@@ -19,6 +19,7 @@
RunPlan,
RunSpec,
)
+from dstack._internal.core.models.volumes import InstanceMountPoint
from dstack._internal.server.schemas.runs import (
ApplyRunPlanRequest,
CreateInstanceRequest,
@@ -122,32 +123,32 @@ def create_instance(
def _get_run_spec_excludes(run_spec: RunSpec) -> Optional[dict]:
spec_excludes: dict[str, set[str]] = {}
- configuration_excludes: set[str] = set()
+ configuration_excludes: dict[str, Any] = {}
profile_excludes: set[str] = set()
configuration = run_spec.configuration
profile = run_spec.profile
# client >= 0.18.18 / server <= 0.18.17 compatibility tweak
if not configuration.privileged:
- configuration_excludes.add("privileged")
+ configuration_excludes["privileged"] = True
# client >= 0.18.23 / server <= 0.18.22 compatibility tweak
if configuration.type == "service" and configuration.gateway is None:
- configuration_excludes.add("gateway")
+ configuration_excludes["gateway"] = True
# client >= 0.18.30 / server <= 0.18.29 compatibility tweak
if run_spec.configuration.user is None:
- configuration_excludes.add("user")
+ configuration_excludes["user"] = True
# client >= 0.18.30 / server <= 0.18.29 compatibility tweak
if configuration.reservation is None:
- configuration_excludes.add("reservation")
+ configuration_excludes["reservation"] = True
if profile is not None and profile.reservation is None:
profile_excludes.add("reservation")
if configuration.idle_duration is None:
- configuration_excludes.add("idle_duration")
+ configuration_excludes["idle_duration"] = True
if profile is not None and profile.idle_duration is None:
profile_excludes.add("idle_duration")
# client >= 0.18.38 / server <= 0.18.37 compatibility tweak
if configuration.stop_duration is None:
- configuration_excludes.add("stop_duration")
+ configuration_excludes["stop_duration"] = True
if profile is not None and profile.stop_duration is None:
profile_excludes.add("stop_duration")
# client >= 0.18.40 / server <= 0.18.39 compatibility tweak
@@ -155,9 +156,14 @@ def _get_run_spec_excludes(run_spec: RunSpec) -> Optional[dict]:
is_core_model_instance(configuration, ServiceConfiguration)
and configuration.strip_prefix == STRIP_PREFIX_DEFAULT
):
- configuration_excludes.add("strip_prefix")
+ configuration_excludes["strip_prefix"] = True
if configuration.single_branch is None:
- configuration_excludes.add("single_branch")
+ configuration_excludes["single_branch"] = True
+ if all(
+ not is_core_model_instance(v, InstanceMountPoint) or not v.optional
+ for v in configuration.volumes
+ ):
+ configuration_excludes["volumes"] = {"__all__": {"optional"}}
if configuration_excludes:
spec_excludes["configuration"] = configuration_excludes
diff --git a/src/tests/_internal/server/background/tasks/test_process_submitted_jobs.py b/src/tests/_internal/server/background/tasks/test_process_submitted_jobs.py
index c03005cac..89ab76f9d 100644
--- a/src/tests/_internal/server/background/tasks/test_process_submitted_jobs.py
+++ b/src/tests/_internal/server/background/tasks/test_process_submitted_jobs.py
@@ -302,6 +302,78 @@ async def test_fails_job_when_instance_mounts_and_no_offers_with_create_instance
await session.refresh(pool)
assert not pool.instances
+ @pytest.mark.asyncio
+ @pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True)
+ async def test_provisions_job_with_optional_instance_volume_not_attached(
+ self,
+ test_db,
+ session: AsyncSession,
+ ):
+ project = await create_project(session=session)
+ user = await create_user(session=session)
+ pool = await create_pool(session=session, project=project)
+ repo = await create_repo(
+ session=session,
+ project_id=project.id,
+ )
+ run_spec = get_run_spec(run_name="test-run", repo_id=repo.name)
+ run_spec.configuration.volumes = [
+ InstanceMountPoint(instance_path="/root/.cache", path="/cache", optional=True)
+ ]
+ run = await create_run(
+ session=session,
+ project=project,
+ repo=repo,
+ user=user,
+ run_name="test-run",
+ run_spec=run_spec,
+ )
+ job = await create_job(
+ session=session,
+ run=run,
+ instance_assigned=True,
+ )
+ offer = InstanceOfferWithAvailability(
+ backend=BackendType.RUNPOD,
+ instance=InstanceType(
+ name="instance",
+ resources=Resources(cpus=1, memory_mib=512, spot=False, gpus=[]),
+ ),
+ region="us",
+ price=1.0,
+ availability=InstanceAvailability.AVAILABLE,
+ )
+ with patch("dstack._internal.server.services.backends.get_project_backends") as m:
+ backend_mock = Mock()
+ m.return_value = [backend_mock]
+ backend_mock.TYPE = BackendType.RUNPOD
+ backend_mock.compute.return_value.get_offers_cached.return_value = [offer]
+ backend_mock.compute.return_value.run_job.return_value = JobProvisioningData(
+ backend=offer.backend,
+ instance_type=offer.instance,
+ instance_id="instance_id",
+ hostname="1.1.1.1",
+ internal_ip=None,
+ region=offer.region,
+ price=offer.price,
+ username="ubuntu",
+ ssh_port=22,
+ ssh_proxy=None,
+ dockerized=False,
+ backend_data=None,
+ )
+ await process_submitted_jobs()
+
+ await session.refresh(job)
+ assert job is not None
+ assert job.status == JobStatus.PROVISIONING
+
+ await session.refresh(pool)
+ instance_offer = InstanceOfferWithAvailability.parse_raw(pool.instances[0].offer)
+ assert offer == instance_offer
+ pool_job_provisioning_data = pool.instances[0].job_provisioning_data
+ assert pool_job_provisioning_data == job.job_provisioning_data
+
@pytest.mark.asyncio
@pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True)
async def test_fails_job_when_no_capacity(self, test_db, session: AsyncSession):
@@ -412,7 +484,8 @@ async def test_assigns_job_to_instance_with_volumes(self, test_db, session: Asyn
run_spec = get_run_spec(run_name="test-run", repo_id=repo.name)
run_spec.configuration.volumes = [
VolumeMountPoint(name=volume.name, path="/volume"),
- InstanceMountPoint(instance_path="/root/.cache", path="/cache"),
+ InstanceMountPoint(instance_path="/root/.data", path="/data"),
+ InstanceMountPoint(instance_path="/root/.cache", path="/cache", optional=True),
]
run = await create_run(
session=session,
diff --git a/src/tests/_internal/server/services/runner/test_client.py b/src/tests/_internal/server/services/runner/test_client.py
index dd89a4193..dd40a12d6 100644
--- a/src/tests/_internal/server/services/runner/test_client.py
+++ b/src/tests/_internal/server/services/runner/test_client.py
@@ -175,7 +175,9 @@ def test_submit(self, client: ShimClient, adapter: requests_mock.Adapter):
"device_name": "/dev/sdv",
}
],
- "instance_mounts": [{"instance_path": "/mnt/nfs/home", "path": "/home"}],
+ "instance_mounts": [
+ {"instance_path": "/mnt/nfs/home", "path": "/home", "optional": False}
+ ],
}
self.assert_request(adapter, 0, "POST", "/api/submit", expected_request)
@@ -341,7 +343,9 @@ def test_submit_task(self, client: ShimClient, adapter: requests_mock.Adapter):
}
],
"volume_mounts": [{"name": "vol", "path": "/vol"}],
- "instance_mounts": [{"instance_path": "/mnt/nfs/home", "path": "/home"}],
+ "instance_mounts": [
+ {"instance_path": "/mnt/nfs/home", "path": "/home", "optional": False}
+ ],
"host_ssh_user": "dstack",
"host_ssh_keys": ["host_key"],
"container_ssh_keys": ["project_key", "user_key"],