Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feat: Improve UX of pytorch-elastic plugin by configuring reasonable defaults #2543

Merged
merged 12 commits into from
Jul 21, 2024
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
from flytekit.core.pod_template import PodTemplate

from kubernetes.client import V1Container, V1PodSpec, V1Volume, V1VolumeMount, V1EmptyDirVolumeSource


def add_shared_mem_volume_to_pod_template(pod_template: PodTemplate) -> None:
"""Add shared memory volume and volume mount to the pod template."""
shm_volume = V1Volume(name="shm", empty_dir=V1EmptyDirVolumeSource(medium="Memory"))
shm_volume_mount = V1VolumeMount(name="shm", mount_path="/dev/shm")

if pod_template.pod_spec is None:
pod_template.pod_spec = V1PodSpec()

if pod_template.pod_spec.containers is None:
pod_template.pod_spec.containers = []

if pod_template.pod_spec.volumes is None:
pod_template.pod_spec.volumes = []
pod_template.pod_spec.volumes.append(shm_volume)

num_containers = len(pod_template.pod_spec.containers)
thomasjpfan marked this conversation as resolved.
Show resolved Hide resolved
if num_containers == 0:
pod_template.pod_spec.containers.append(V1Container(name="primary"))
elif num_containers == 1:
pass
else:
raise ValueError(
"When configuring a pod template with multiple containers, please set `increase_shared_mem=False` "
"in the task config and if required mount a volume to increase the shared memory size in the respective "
"container yourself."
)

if pod_template.pod_spec.containers[0].volume_mounts is None:
pod_template.pod_spec.containers[0].volume_mounts = []

pod_template.pod_spec.containers[0].volume_mounts.append(shm_volume_mount)
thomasjpfan marked this conversation as resolved.
Show resolved Hide resolved
28 changes: 27 additions & 1 deletion plugins/flytekit-kf-pytorch/flytekitplugins/kfpytorch/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,15 @@
import flytekit
from flytekit import PythonFunctionTask, Resources, lazy_module
from flytekit.configuration import SerializationSettings
from flytekit.core.pod_template import PodTemplate
from flytekit.core.resources import convert_resources_to_resource_model
from flytekit.exceptions.user import FlyteRecoverableException
from flytekit.extend import IgnoreOutputs, TaskPlugins
from flytekit.loggers import logger

from .error_handling import create_recoverable_error_file, is_recoverable_worker_error
from .pod_template import add_shared_mem_volume_to_pod_template


cloudpickle = lazy_module("cloudpickle")

Expand Down Expand Up @@ -103,13 +106,18 @@ class PyTorch(object):
worker: Configuration for the worker replica group.
run_policy: Configuration for the run policy.
num_workers: [DEPRECATED] This argument is deprecated. Use `worker.replicas` instead.
increase_shared_mem (bool): PyTorch uses shared memory to share data between processes. If torch multiprocessing is used
(e.g. for multithreaded data loaders) the default shared memory segment size that the container runs with might not be enough
and and one might have to increase the shared memory size. This option configures the task's pod template to mount
an `emptyDir` volume with medium `Memory` to to `/dev/shm`.
"""

master: Master = field(default_factory=lambda: Master())
worker: Worker = field(default_factory=lambda: Worker())
run_policy: Optional[RunPolicy] = field(default_factory=lambda: None)
# Support v0 config for backwards compatibility
num_workers: Optional[int] = None
increase_shared_mem: bool = True
thomasjpfan marked this conversation as resolved.
Show resolved Hide resolved


@dataclass
Expand All @@ -130,14 +138,23 @@ class Elastic(object):
max_restarts (int): Maximum number of worker group restarts before failing.
rdzv_configs (Dict[str, Any]): Additional rendezvous configs to pass to torch elastic, e.g. `{"timeout": 1200, "join_timeout": 900}`.
See `torch.distributed.launcher.api.LaunchConfig` and `torch.distributed.elastic.rendezvous.dynamic_rendezvous.create_handler`.

Default timeouts are set to 15 minutes to account for the fact that some workers might start faster than others: Some pods might
be assigned to a running node which might have the image in its cache while other workers might require a node scale up and image pull.

increase_shared_mem (bool): PyTorch uses shared memory to share data between processes. If torch multiprocessing is used
(e.g. for multithreaded data loaders) the default shared memory segment size that the container runs with might not be enough
and and one might have to increase the shared memory size. This option configures the task's pod template to mount
an `emptyDir` volume with medium `Memory` to to `/dev/shm`.
"""

nnodes: Union[int, str] = 1
nproc_per_node: int = 1
start_method: str = "spawn"
monitor_interval: int = 5
max_restarts: int = 0
rdzv_configs: Dict[str, Any] = field(default_factory=dict)
rdzv_configs: Dict[str, Any] = field(default_factory=lambda: {"timeout": 900, "join_timeout": 900})
increase_shared_mem: bool = True


class PyTorchFunctionTask(PythonFunctionTask[PyTorch]):
Expand Down Expand Up @@ -165,6 +182,10 @@ def __init__(self, task_config: PyTorch, task_function: Callable, **kwargs):
task_type_version=1,
**kwargs,
)
if self.task_config.increase_shared_mem:
if self.pod_template is None:
self.pod_template = PodTemplate()
add_shared_mem_volume_to_pod_template(self.pod_template)

def _convert_replica_spec(
self, replica_config: Union[Master, Worker]
Expand Down Expand Up @@ -299,6 +320,11 @@ def __init__(self, task_config: Elastic, task_function: Callable, **kwargs):
"""
self.rdzv_backend = "c10d"

if self.task_config.increase_shared_mem:
if self.pod_template is None:
self.pod_template = PodTemplate()
add_shared_mem_volume_to_pod_template(self.pod_template)

def _execute(self, **kwargs) -> Any:
"""
Execute the task function using torch distributed's `elastic_launch`.
Expand Down
9 changes: 9 additions & 0 deletions plugins/flytekit-kf-pytorch/tests/test_elastic_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,3 +187,12 @@ def wf(recoverable: bool):
else:
with pytest.raises(RuntimeError):
wf(recoverable=recoverable)


def test_default_timeouts():
"""Test that default timeouts are set for the elastic task."""
@task(task_config=Elastic(nnodes=1))
def test_task():
pass

assert test_task.task_config.rdzv_configs == {"join_timeout": 900, "timeout": 900}
109 changes: 109 additions & 0 deletions plugins/flytekit-kf-pytorch/tests/test_shared.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,109 @@
"""Test functionality that is shared between the pytorch and pytorch-elastic tasks."""

from contextlib import nullcontext
from typing import Union

import pytest
from flytekitplugins.kfpytorch.task import Elastic, PyTorch
from kubernetes.client import V1Container, V1EmptyDirVolumeSource, V1PodSpec, V1Volume, V1VolumeMount

from flytekit import PodTemplate, task


@pytest.mark.parametrize(
"task_config, pod_template, needs_shm_volume, raises",
[
# Test that by default shared memory volume is added
(PyTorch(num_workers=3), None, True, False),
(Elastic(nnodes=2, increase_shared_mem=True), None, True, False),
# Test disabling shared memory volume
(PyTorch(num_workers=3, increase_shared_mem=False), None, False, False),
(Elastic(nnodes=2, increase_shared_mem=False), None, False, False),
# Test that explicitly passed pod template does not break adding shm volume
(Elastic(nnodes=2, increase_shared_mem=True), PodTemplate(), True, False),
# Test that pod template with container does not break adding shm volume
(
Elastic(nnodes=2),
PodTemplate(
pod_spec=V1PodSpec(containers=[V1Container(name="primary")]),
),
True,
False,
),
# Test that pod template with volume/volume mount does not break adding shm volume
(
Elastic(nnodes=2),
PodTemplate(
pod_spec=V1PodSpec(
containers=[
V1Container(name="primary", volume_mounts=[V1VolumeMount(name="foo", mount_path="/bar")])
],
volumes=[V1Volume(name="foo")],
),
),
True,
False,
),
# Test that pod template with multiple containers raises an error
(
Elastic(nnodes=2),
PodTemplate(
pod_spec=V1PodSpec(
containers=[
V1Container(name="primary"),
V1Container(name="secondary"),
]
),
),
True,
True,
),
],
)
def test_task_shared_memory(
task_config: Union[Elastic, PyTorch], pod_template: PodTemplate, needs_shm_volume: bool, raises: bool
):
"""Test that the task pod template is configured with a shared memory volume if needed."""

expected_volume = V1Volume(name="shm", empty_dir=V1EmptyDirVolumeSource(medium="Memory"))
expected_volume_mount = V1VolumeMount(name="shm", mount_path="/dev/shm")

with pytest.raises(ValueError) if raises else nullcontext():

@task(
task_config=task_config,
pod_template=pod_template,
)
def test_task() -> None:
pass

if needs_shm_volume:
assert test_task.pod_template is not None
assert test_task.pod_template.pod_spec is not None
assert test_task.pod_template.pod_spec.volumes is not None
assert test_task.pod_template.pod_spec.containers is not None
assert test_task.pod_template.pod_spec.containers[0].volume_mounts is not None

assert any([v == expected_volume for v in test_task.pod_template.pod_spec.volumes])
assert any(
[v == expected_volume_mount for v in test_task.pod_template.pod_spec.containers[0].volume_mounts]
)

else:
no_pod_template = test_task.pod_template is None
thomasjpfan marked this conversation as resolved.
Show resolved Hide resolved
no_pod_spec = no_pod_template or test_task.pod_template.pod_spec is None
no_volumes = no_pod_spec or test_task.pod_template.pod_spec.volumes is None
no_containers = no_pod_spec or len(test_task.pod_template.pod_spec.containers) == 0
no_volume_mounts = no_containers or test_task.pod_template.pod_spec.containers[0].volume_mounts is None
empty_volume_mounts = (
no_volume_mounts or len(test_task.pod_template.pod_spec.containers[0].volume_mounts) == 0
)
no_shm_volume_condition = no_volumes or not any(
[v == expected_volume for v in test_task.pod_template.pod_spec.volumes]
)
no_shm_volume_mount_condition = empty_volume_mounts or not any(
[v == expected_volume_mount for v in test_task.pod_template.pod_spec.containers[0].volume_mounts]
)

assert no_shm_volume_condition
assert no_shm_volume_mount_condition
Loading