Passing rank and num_replicas to dist.get_sampler (#3137)
* passing custom rank to dist.get_sampler

* passing in num_replicas

* adding docstring
ShashankMosaicML authored and Chuck Tang committed May 16, 2024
1 parent 034b076 commit 203aa6c
Showing 1 changed file with 12 additions and 3 deletions.

composer/utils/dist.py
@@ -579,7 +579,14 @@ def initialize_dist(device: Union[str, Device], timeout: float = 300.0):
     dist.init_process_group(device_obj.dist_backend, timeout=timeout_timedelta)
 
 
-def get_sampler(dataset: torch.utils.data.Dataset, *, drop_last: bool = False, shuffle: bool = False):
+def get_sampler(
+    dataset: torch.utils.data.Dataset,
+    *,
+    drop_last: bool = False,
+    shuffle: bool = False,
+    num_replicas: Optional[int] = None,
+    rank: Optional[int] = None,
+):
     """Constructs a :class:`~torch.utils.data.distributed.DistributedSampler` for a dataset.
 
     The :class:`~torch.utils.data.distributed.DistributedSampler` assumes that each rank has a complete copy of the
@@ -595,6 +602,8 @@ def get_sampler(dataset: torch.utils.data.Dataset, *, drop_last: bool = False, s
         dataset (torch.utils.data.Dataset): The dataset.
         drop_last (bool): Whether to drop the last batch.
         shuffle (bool): Whether to shuffle the dataset.
+        num_replicas (int, optional): The number of replicas. If ``None``, defaults to the world size.
+        rank (int, optional): The rank. If ``None``, defaults to the global rank.
 
     Returns:
         torch.utils.data.distributed.DistributedSampler: The sampler.
@@ -603,8 +612,8 @@ def get_sampler(dataset: torch.utils.data.Dataset, *, drop_last: bool = False, s
         dataset,
         drop_last=drop_last,
         shuffle=shuffle,
-        num_replicas=get_world_size(),
-        rank=get_global_rank(),
+        num_replicas=get_world_size() if num_replicas is None else num_replicas,
+        rank=get_global_rank() if rank is None else rank,
     )
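
For context, here is a minimal usage sketch of the new keyword arguments. It is not part of the commit; the toy TensorDataset and the num_replicas=4 / rank=1 values are illustrative assumptions.

import torch
from torch.utils.data import DataLoader, TensorDataset

from composer.utils import dist

# A toy dataset of 100 rows; any torch Dataset works.
dataset = TensorDataset(torch.arange(100).unsqueeze(1))

# Unchanged default: shards by dist.get_world_size() and dist.get_global_rank().
default_sampler = dist.get_sampler(dataset, shuffle=True)

# New in this commit: shard as if there were 4 replicas, with this process
# acting as replica 1 (hypothetical values), e.g. when only a subset of
# ranks participates in data loading.
custom_sampler = dist.get_sampler(dataset, shuffle=True, num_replicas=4, rank=1)

loader = DataLoader(dataset, batch_size=8, sampler=custom_sampler)
for (batch,) in loader:
    pass  # this replica iterates over its own 25-element shard

With num_replicas=4 the underlying DistributedSampler splits the 100 indices into four shards of 25, and rank=1 selects the second shard; leaving both arguments as None preserves the previous behavior exactly.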


