Added helper methods for torch.distributed.broadcast #189

Merged
11 commits merged on Jan 4, 2022
45 changes: 44 additions & 1 deletion in composer/utils/ddp.py
@@ -6,7 +6,7 @@
import os
import warnings
from contextlib import contextmanager, nullcontext
from typing import TYPE_CHECKING, Callable, ContextManager, List, Optional, Sequence, TypeVar, Union, cast
from typing import TYPE_CHECKING, Any, Callable, ContextManager, List, Optional, Sequence, TypeVar, Union, cast

import torch
import torch.distributed as dist
@@ -141,6 +141,49 @@ def all_reduce(
"The mosaic trainer will automatically do this for you.")


def broadcast(tensor: torch.Tensor, src: int) -> None:
    """Broadcasts the tensor to the whole group.

    ``tensor`` must have the same number of elements in all processes participating in the collective.
    See :meth:`torch.distributed.broadcast`.

    Args:
        tensor (torch.Tensor): Data to be sent if ``src`` is the rank of the current process,
            and tensor to be used to save received data otherwise.
        src (int): Source rank.
    """
    if dist.is_available() and dist.is_initialized():
        dist.broadcast(tensor, src)
        return
    world_size = get_world_size()
    if world_size == 1:
        return
    raise RuntimeError(f"Since the world_size({world_size}) > 1, please configure DDP to use ddp.broadcast(). "
                       "The mosaic trainer will automatically do this for you.")


def broadcast_object_list(object_list: List[Any], src: int = 0) -> None:
    """Broadcasts picklable objects in ``object_list`` to the whole group.

    Similar to :meth:`broadcast`, but Python objects can be passed in.
    Note that all objects in ``object_list`` must be picklable in order to be broadcasted.
    See :meth:`torch.distributed.broadcast_object_list`.

    Args:
        object_list (List[Any]): List of input objects to broadcast.
            Each object must be picklable. Only objects on the ``src`` rank will be broadcast,
            but each rank must provide lists of equal sizes.
        src (int, optional): Source rank (default: ``0``).
    """
    if dist.is_available() and dist.is_initialized():
        dist.broadcast_object_list(object_list, src)
        # torch.distributed overwrites the entries of object_list on non-src ranks
        # with the objects broadcast from the src rank
        return
    world_size = get_world_size()
    if world_size == 1:
        return
    raise RuntimeError(f"Since the world_size({world_size}) > 1, please configure DDP to use ddp.broadcast_object_list(). "
                       "The mosaic trainer will automatically do this for you.")


def all_gather(tensor: torch.Tensor) -> Sequence[torch.Tensor]:
"""all_gather collects a tensor from each rank, and returns a sequence of tensors indexed by rank

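
A minimal usage sketch of the two new helpers, assuming the process group has already been initialized (for example by the Mosaic trainer or torch.distributed.init_process_group) and that they are importable as composer.utils.ddp; the rank check relies on torch.distributed.get_rank() rather than any composer-specific helper.

# Usage sketch (assumes an initialized process group and the composer.utils.ddp import path).
import torch
import torch.distributed as dist

from composer.utils import ddp

rank = dist.get_rank() if dist.is_available() and dist.is_initialized() else 0

# Broadcast a tensor from rank 0 to every rank. All ranks must pass a tensor
# of the same shape; non-source ranks receive the data in place.
weights = torch.arange(4, dtype=torch.float32) if rank == 0 else torch.zeros(4)
ddp.broadcast(weights, src=0)

# Broadcast a list of picklable objects from rank 0. Every rank must pass a
# list of the same length; entries on non-source ranks are overwritten in place.
hparams = [{"lr": 0.1}] if rank == 0 else [None]
ddp.broadcast_object_list(hparams, src=0)
# hparams[0] now holds the rank-0 dict on every rank.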