From 86c7d8d375566064ac449d8ff7043db89f722b56 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Haian=20Huang=28=E6=B7=B1=E5=BA=A6=E7=9C=B8=29?= <1286304229@qq.com> Date: Fri, 18 Mar 2022 15:07:45 +0800 Subject: [PATCH] [Enchance] Supplementary notes of sync_random_seed (#7440) * Supplementary Notes * update * update * update --- mmdet/core/utils/dist_utils.py | 6 ++++++ mmdet/datasets/samplers/distributed_sampler.py | 14 +++++++++++--- mmdet/datasets/samplers/infinite_sampler.py | 18 ++++++++++++------ 3 files changed, 29 insertions(+), 9 deletions(-) diff --git a/mmdet/core/utils/dist_utils.py b/mmdet/core/utils/dist_utils.py index 27ecbb004c5..8760774fd90 100644 --- a/mmdet/core/utils/dist_utils.py +++ b/mmdet/core/utils/dist_utils.py @@ -162,6 +162,12 @@ def sync_random_seed(seed=None, device='cuda'): because the seed should be identical across all processes in the distributed group. + In distributed sampling, different ranks should sample non-overlapped + data in the dataset. Therefore, this function is used to make sure that + each rank shuffles the data indices in the same order based + on the same seed. Then different ranks could use different indices + to select non-overlapped data from the same data list. + Args: seed (int, Optional): The seed. Default to None. device (str): The device where the seed will be put on. diff --git a/mmdet/datasets/samplers/distributed_sampler.py b/mmdet/datasets/samplers/distributed_sampler.py index 3ed21bdb2c4..ab544a9c469 100644 --- a/mmdet/datasets/samplers/distributed_sampler.py +++ b/mmdet/datasets/samplers/distributed_sampler.py @@ -17,15 +17,23 @@ def __init__(self, seed=0): super().__init__( dataset, num_replicas=num_replicas, rank=rank, shuffle=shuffle) - # Must be the same across all workers. If None, will use a - # random seed shared among workers - # (require synchronization among all workers) + + # In distributed sampling, different ranks should sample + # non-overlapped data in the dataset. 
Therefore, sync_random_seed + # is used to make sure that each rank shuffles the data indices + # in the same order based on the same seed. Then different ranks + # could use different indices to select non-overlapped data from the + # same data list. self.seed = sync_random_seed(seed) def __iter__(self): # deterministically shuffle based on epoch if self.shuffle: g = torch.Generator() + # When :attr:`shuffle=True`, this ensures all replicas + # use a different random ordering for each epoch. + # Otherwise, the next iteration of this sampler will + # yield the same ordering. g.manual_seed(self.epoch + self.seed) indices = torch.randperm(len(self.dataset), generator=g).tolist() else: diff --git a/mmdet/datasets/samplers/infinite_sampler.py b/mmdet/datasets/samplers/infinite_sampler.py index cfea01a345d..d42487e6ac0 100644 --- a/mmdet/datasets/samplers/infinite_sampler.py +++ b/mmdet/datasets/samplers/infinite_sampler.py @@ -50,9 +50,12 @@ def __init__(self, self.world_size = world_size self.dataset = dataset self.batch_size = batch_size - # Must be the same across all workers. If None, will use a - # random seed shared among workers - # (require synchronization among all workers) + # In distributed sampling, different ranks should sample + # non-overlapped data in the dataset. Therefore, sync_random_seed + # is used to make sure that each rank shuffles the data indices + # in the same order based on the same seed. Then different ranks + # could use different indices to select non-overlapped data from the + # same data list. self.seed = sync_random_seed(seed) self.shuffle = shuffle @@ -138,9 +141,12 @@ def __init__(self, self.world_size = world_size self.dataset = dataset self.batch_size = batch_size - # Must be the same across all workers. If None, will use a - # random seed shared among workers - # (require synchronization among all workers) + # In distributed sampling, different ranks should sample + # non-overlapped data in the dataset.
Therefore, sync_random_seed + # is used to make sure that each rank shuffles the data indices + # in the same order based on the same seed. Then different ranks + # could use different indices to select non-overlapped data from the + # same data list. self.seed = sync_random_seed(seed) self.shuffle = shuffle self.size = len(dataset)