[Enhance] Supplementary notes for sync_random_seed (#7440)
* Supplementary Notes

* update

* update

* update
hhaAndroid authored Mar 18, 2022
1 parent 95f199c commit 86c7d8d
Showing 3 changed files with 29 additions and 9 deletions.
6 changes: 6 additions & 0 deletions mmdet/core/utils/dist_utils.py
@@ -162,6 +162,12 @@ def sync_random_seed(seed=None, device='cuda'):
because the seed should be identical across all processes
in the distributed group.
In distributed sampling, different ranks should sample non-overlapping
data from the dataset. Therefore, this function is used to make sure that
each rank shuffles the data indices in the same order based
on the same seed. Then different ranks can use different indices
to select non-overlapping data from the same data list.
Args:
seed (int, optional): The seed. Defaults to None.
device (str): The device where the seed will be placed. Defaults to 'cuda'.
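The docstring above says the helper must hand every rank the same seed. As a rough sketch of how that can be done, the snippet below picks a seed on rank 0 and broadcasts it to the other ranks with torch.distributed; it is an illustration of the idea (the name sync_random_seed_sketch and the fall-back when no process group is initialized are assumptions), not the verbatim mmdet implementation.

import numpy as np
import torch
import torch.distributed as dist


def sync_random_seed_sketch(seed=None, device='cuda'):
    """Pick a seed on rank 0 and share it with every other rank."""
    if seed is None:
        seed = np.random.randint(2**31)
    # Without an initialized process group there is nothing to synchronize.
    if not (dist.is_available() and dist.is_initialized()):
        return seed
    if dist.get_rank() == 0:
        random_num = torch.tensor(seed, dtype=torch.int32, device=device)
    else:
        random_num = torch.tensor(0, dtype=torch.int32, device=device)
    # After the broadcast, every rank holds rank 0's value.
    dist.broadcast(random_num, src=0)
    return random_num.item()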
14 changes: 11 additions & 3 deletions mmdet/datasets/samplers/distributed_sampler.py
@@ -17,15 +17,23 @@ def __init__(self,
seed=0):
super().__init__(
dataset, num_replicas=num_replicas, rank=rank, shuffle=shuffle)
# Must be the same across all workers. If None, will use a
# random seed shared among workers
# (require synchronization among all workers)

# In distributed sampling, different ranks should sample
# non-overlapping data from the dataset. Therefore, this function
# is used to make sure that each rank shuffles the data indices
# in the same order based on the same seed. Then different ranks
# can use different indices to select non-overlapping data from the
# same data list.
self.seed = sync_random_seed(seed)

def __iter__(self):
# deterministically shuffle based on epoch
if self.shuffle:
g = torch.Generator()
# When :attr:`shuffle=True`, this ensures all replicas
# use a different random ordering for each epoch.
# Otherwise, the next iteration of this sampler will
# yield the same ordering.
g.manual_seed(self.epoch + self.seed)
indices = torch.randperm(len(self.dataset), generator=g).tolist()
else:
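To see why a shared seed gives non-overlapping subsets, the snippet below is a minimal sketch: every rank shuffles the full index list with the same seed and then keeps a disjoint strided slice. The helper per_rank_indices is hypothetical; the strided slicing mirrors the partitioning scheme of torch.utils.data.DistributedSampler and is shown only for illustration.

import torch


def per_rank_indices(dataset_len, rank, num_replicas, epoch, seed):
    g = torch.Generator()
    # Every rank seeds the generator identically, so the shuffled
    # order is the same everywhere.
    g.manual_seed(epoch + seed)
    indices = torch.randperm(dataset_len, generator=g).tolist()
    # Each rank then keeps a disjoint strided slice of that shared order.
    return indices[rank:dataset_len:num_replicas]


# With 2 replicas, the two ranks see disjoint halves of the dataset.
print(per_rank_indices(8, rank=0, num_replicas=2, epoch=0, seed=42))
print(per_rank_indices(8, rank=1, num_replicas=2, epoch=0, seed=42))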
18 changes: 12 additions & 6 deletions mmdet/datasets/samplers/infinite_sampler.py
@@ -50,9 +50,12 @@ def __init__(self,
self.world_size = world_size
self.dataset = dataset
self.batch_size = batch_size
# Must be the same across all workers. If None, will use a
# random seed shared among workers
# (require synchronization among all workers)
# In distributed sampling, different ranks should sample
# non-overlapping data from the dataset. Therefore, this function
# is used to make sure that each rank shuffles the data indices
# in the same order based on the same seed. Then different ranks
# can use different indices to select non-overlapping data from the
# same data list.
self.seed = sync_random_seed(seed)
self.shuffle = shuffle

@@ -138,9 +141,12 @@ def __init__(self,
self.world_size = world_size
self.dataset = dataset
self.batch_size = batch_size
# Must be the same across all workers. If None, will use a
# random seed shared among workers
# (require synchronization among all workers)
# In distributed sampling, different ranks should sample
# non-overlapping data from the dataset. Therefore, this function
# is used to make sure that each rank shuffles the data indices
# in the same order based on the same seed. Then different ranks
# can use different indices to select non-overlapping data from the
# same data list.
self.seed = sync_random_seed(seed)
self.shuffle = shuffle
self.size = len(dataset)
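For the infinite samplers, the same shared seed drives an endless index stream that each rank slices by its own offset. The sketch below illustrates that pattern with itertools.islice; the helper names infinite_indices and indices_of_rank are illustrative assumptions, not the exact mmdet API.

import itertools

import torch


def infinite_indices(size, seed, shuffle=True):
    g = torch.Generator()
    g.manual_seed(seed)  # identical seed on every rank -> identical stream
    while True:
        if shuffle:
            yield from torch.randperm(size, generator=g).tolist()
        else:
            yield from range(size)


def indices_of_rank(size, seed, rank, world_size):
    # Every rank walks the same infinite stream but keeps only every
    # world_size-th element starting at its own offset, so the per-rank
    # subsets never overlap.
    yield from itertools.islice(
        infinite_indices(size, seed), rank, None, world_size)


# First four indices seen by each of two ranks.
for r in range(2):
    stream = indices_of_rank(6, seed=0, rank=r, world_size=2)
    print(list(itertools.islice(stream, 4)))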
