diff --git a/mmdet/core/utils/dist_utils.py b/mmdet/core/utils/dist_utils.py
index 27ecbb004c5..8760774fd90 100644
--- a/mmdet/core/utils/dist_utils.py
+++ b/mmdet/core/utils/dist_utils.py
@@ -162,6 +162,12 @@ def sync_random_seed(seed=None, device='cuda'):
     because the seed should be identical across all processes in the
     distributed group.
 
+    In distributed sampling, different ranks should sample non-overlapping
+    data from the dataset. Therefore, this function is used to make sure
+    that each rank shuffles the data indices into the same order based
+    on the same seed. Then different ranks can use different indices
+    to select non-overlapping data from the same data list.
+
     Args:
         seed (int, Optional): The seed. Default to None.
         device (str): The device where the seed will be put on.
diff --git a/mmdet/datasets/samplers/distributed_sampler.py b/mmdet/datasets/samplers/distributed_sampler.py
index 3ed21bdb2c4..ab544a9c469 100644
--- a/mmdet/datasets/samplers/distributed_sampler.py
+++ b/mmdet/datasets/samplers/distributed_sampler.py
@@ -17,15 +17,23 @@ def __init__(self,
                  seed=0):
         super().__init__(
             dataset, num_replicas=num_replicas, rank=rank, shuffle=shuffle)
-        # Must be the same across all workers. If None, will use a
-        # random seed shared among workers
-        # (require synchronization among all workers)
+
+        # In distributed sampling, different ranks should sample
+        # non-overlapping data from the dataset. Therefore,
+        # sync_random_seed is used to make sure that each rank shuffles
+        # the data indices into the same order based on the same seed.
+        # Then different ranks can use different indices to select
+        # non-overlapping data from the same data list.
         self.seed = sync_random_seed(seed)
 
     def __iter__(self):
         # deterministically shuffle based on epoch
         if self.shuffle:
             g = torch.Generator()
+            # When :attr:`shuffle=True`, this ensures all replicas
+            # use a different random ordering for each epoch.
+            # Otherwise, the next iteration of this sampler will
+            # yield the same ordering.
             g.manual_seed(self.epoch + self.seed)
             indices = torch.randperm(len(self.dataset), generator=g).tolist()
         else:
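The two comment blocks above describe one mechanism from two sides: a seed shared via `sync_random_seed` makes every rank compute the identical permutation, and re-seeding with `epoch + seed` refreshes that permutation once per epoch. Below is a minimal sketch of how disjoint per-rank indices fall out of this; the helper name `rank_indices` and the concrete rank/world-size values are illustrative, not code from the patch.

```python
# Minimal sketch: every rank builds the *same* permutation from the
# shared seed, then keeps a disjoint, rank-dependent slice of it.
# (rank_indices and the values below are illustrative, not from mmdet.)
import torch


def rank_indices(dataset_len, seed, epoch, rank, world_size):
    g = torch.Generator()
    # Re-seeding with epoch + seed gives a new ordering every epoch;
    # with a fixed seed the permutation would repeat each iteration.
    g.manual_seed(epoch + seed)
    indices = torch.randperm(dataset_len, generator=g).tolist()
    # Every rank computed the identical `indices` list, so a strided
    # slice yields non-overlapping subsets that cover the dataset.
    return indices[rank::world_size]


# Two ranks sharing seed=42 never see the same sample in an epoch:
a = rank_indices(10, seed=42, epoch=0, rank=0, world_size=2)
b = rank_indices(10, seed=42, epoch=0, rank=1, world_size=2)
assert set(a).isdisjoint(b) and len(a) + len(b) == 10
```

For comparison, PyTorch's `DistributedSampler` first pads the index list so its length divides evenly by the number of replicas before slicing; the bare strided slice above skips that step, so per-rank slices can differ in length by one when the dataset size is not divisible.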
diff --git a/mmdet/datasets/samplers/infinite_sampler.py b/mmdet/datasets/samplers/infinite_sampler.py
index cfea01a345d..d42487e6ac0 100644
--- a/mmdet/datasets/samplers/infinite_sampler.py
+++ b/mmdet/datasets/samplers/infinite_sampler.py
@@ -50,9 +50,12 @@ def __init__(self,
         self.world_size = world_size
         self.dataset = dataset
         self.batch_size = batch_size
-        # Must be the same across all workers. If None, will use a
-        # random seed shared among workers
-        # (require synchronization among all workers)
+        # In distributed sampling, different ranks should sample
+        # non-overlapping data from the dataset. Therefore,
+        # sync_random_seed is used to make sure that each rank shuffles
+        # the data indices into the same order based on the same seed.
+        # Then different ranks can use different indices to select
+        # non-overlapping data from the same data list.
         self.seed = sync_random_seed(seed)
         self.shuffle = shuffle
 
@@ -138,9 +141,12 @@ def __init__(self,
         self.world_size = world_size
         self.dataset = dataset
         self.batch_size = batch_size
-        # Must be the same across all workers. If None, will use a
-        # random seed shared among workers
-        # (require synchronization among all workers)
+        # In distributed sampling, different ranks should sample
+        # non-overlapping data from the dataset. Therefore,
+        # sync_random_seed is used to make sure that each rank shuffles
+        # the data indices into the same order based on the same seed.
+        # Then different ranks can use different indices to select
+        # non-overlapping data from the same data list.
         self.seed = sync_random_seed(seed)
         self.shuffle = shuffle
         self.size = len(dataset)
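The synchronization that all three samplers depend on boils down to a rank-0 broadcast. The sketch below condenses the idea behind `sync_random_seed`; it assumes the default process group is already initialized and omits the single-process fallback of the real helper in `mmdet/core/utils/dist_utils.py`, so treat it as an illustration rather than the exact patched function.

```python
# Condensed sketch of the seed-synchronization idea behind
# sync_random_seed (assumes torch.distributed is initialized; the
# single-process fallback of the real helper is omitted).
import numpy as np
import torch
import torch.distributed as dist


def sync_random_seed_sketch(seed=None, device='cuda'):
    """Return a seed guaranteed to be identical on every rank."""
    if seed is not None:
        # An explicit seed is already the same everywhere.
        return seed
    if dist.get_rank() == 0:
        # Rank 0 draws the seed that every rank will use.
        random_num = torch.tensor(
            np.random.randint(2**31), dtype=torch.int32, device=device)
    else:
        # Other ranks allocate a placeholder tensor to receive it.
        random_num = torch.tensor(0, dtype=torch.int32, device=device)
    # After the broadcast, every rank holds rank 0's value.
    dist.broadcast(random_num, src=0)
    return random_num.item()
```

With the seed agreed on, the infinite samplers consume the shared permutation as a stream: each rank takes every `world_size`-th index starting at `rank` (effectively an `itertools.islice(indices, rank, None, world_size)` pattern), which is the streaming counterpart of the strided slicing shown earlier.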