From 2d35fceaa1f58ee1b832dc3b148ad7c0745bf0b3 Mon Sep 17 00:00:00 2001
From: William Falcon
Date: Thu, 1 Oct 2020 09:34:08 -0400
Subject: [PATCH 1/3] ref: part 3 of #3733

---
 .../accelerators/ddp_spawn_backend.py | 129 +++++++++++++++++-
 1 file changed, 125 insertions(+), 4 deletions(-)

diff --git a/pytorch_lightning/accelerators/ddp_spawn_backend.py b/pytorch_lightning/accelerators/ddp_spawn_backend.py
index fc2fc88563e1d..da2bad41903c5 100644
--- a/pytorch_lightning/accelerators/ddp_spawn_backend.py
+++ b/pytorch_lightning/accelerators/ddp_spawn_backend.py
@@ -12,15 +12,21 @@
 # See the License for the specific language governing permissions and
 # limitations under the License
 import os
+import re

 import torch
+import torch.distributed as torch_distrib
 import torch.multiprocessing as mp

+from pytorch_lightning import _logger as log
+from pytorch_lightning.accelerators.base_backend import Accelerator
+from pytorch_lightning.utilities import AMPType
+from pytorch_lightning.utilities.cloud_io import atomic_save
+from pytorch_lightning.utilities.distributed import rank_zero_only, rank_zero_warn
 from pytorch_lightning.utilities.distributed import find_free_network_port
-from pytorch_lightning.accelerators.ddp_base_backend import DDPBase


-class DDPSpawnBackend(DDPBase):
+class DDPSpawnBackend(Accelerator):

     def __init__(self, trainer, nprocs):
         super().__init__(trainer)
@@ -40,7 +46,7 @@ def train(self):
         model = self.trainer.model

         # train in children process
-        mp.spawn(self.ddp_train_tmp, nprocs=self.nprocs, args=(self.mp_queue, model,))
+        mp.spawn(self.ddp_train, nprocs=self.nprocs, args=(self.mp_queue, model,))

         # restore main state with best weights
         best_path = self.mp_queue.get()
@@ -64,12 +70,109 @@ def __recover_child_process_weights(self, model, best_path, last_path):

         self.trainer.model = model

+    def ddp_train(self, process_idx, mp_queue, model):
+        """
+        Entry point for ddp
+
+        Args:
+            process_idx:
+            mp_queue: multiprocessing queue
+            model:
+
+        Returns:
+
+        """
+        # show progressbar only on progress_rank 0
+        if (self.trainer.node_rank != 0 or process_idx != 0) and self.trainer.progress_bar_callback is not None:
+            self.trainer.progress_bar_callback.disable()
+
+        # determine which process we are and world size
+        self.set_world_ranks(process_idx)
+
+        # set warning rank
+        rank_zero_only.rank = self.trainer.global_rank
+
+        # set up server using proc 0's ip address
+        # try to init for 20 times at max in case ports are taken
+        # where to store ip_table
+        model.trainer = self.trainer
+        model.init_ddp_connection(
+            self.trainer.global_rank,
+            self.trainer.world_size,
+            self.trainer.is_slurm_managing_tasks
+        )
+
+        # call setup after the ddp process has connected
+        self.trainer.call_setup_hook(model)
+
+        # on world_size=0 let everyone know training is starting
+        if self.trainer.is_global_zero and not torch.distributed.is_initialized():
+            log.info('-' * 100)
+            log.info(f'distributed_backend={self.trainer.distributed_backend}')
+            log.info(f'All DDP processes registered. Starting ddp with {self.trainer.world_size} processes')
+            log.info('-' * 100)
+
+        # call sync_bn before .cuda(), configure_apex and configure_ddp
+        if self.trainer.sync_batchnorm:
+            model = model.configure_sync_batchnorm(model)
+
+        # move the model to the correct device
+        self.model_to_device(model, process_idx)
+
+        # CHOOSE OPTIMIZER
+        # allow for lr schedulers as well
+        self.setup_optimizers(model)
+
+        # set model properties before going into wrapper
+        self.trainer.model_connector.copy_trainer_model_properties(model)
+
+        # 16-bit
+        model = self.trainer.precision_connector.connect(model)
+
+        # device ids change depending on the DDP setup
+        device_ids = self.get_device_ids()
+
+        # allow user to configure ddp
+        model = model.configure_ddp(model, device_ids)
+
+        # set up training routine
+        self.trainer.train_loop.setup_training(model)
+
+        # train or test
+        results = self.train_or_test()
+
+        # get original model
+        model = self.trainer.get_model()
+
+        # persist info in ddp_spawn
+        self.transfer_distrib_spawn_state_on_fit_end(model, mp_queue, results)
+
+    def training_step(self, args):
+        if self.trainer.amp_backend == AMPType.NATIVE:
+            with torch.cuda.amp.autocast():
+                output = self.trainer.model(*args)
+        else:
+            output = self.trainer.model(*args)
+        return output
+
+    def validation_step(self, args):
+        output = self.training_step(args)
+        return output
+
+    def test_step(self, args):
+        output = self.training_step(args)
+        return output
+
+    def barrier(self, name: str = None):
+        if torch_distrib.is_initialized():
+            torch_distrib.barrier()
+
     def set_world_ranks(self, process_idx):
         self.trainer.local_rank = process_idx
         self.trainer.global_rank = self.trainer.node_rank * self.trainer.num_processes + process_idx
         self.trainer.world_size = self.trainer.num_nodes * self.trainer.num_processes

-    def model_to_device(self, model, process_idx, is_master):
+    def model_to_device(self, model, process_idx):
         gpu_idx = process_idx
         self.trainer.root_gpu = gpu_idx
         torch.cuda.set_device(self.trainer.root_gpu)
@@ -78,3 +181,21 @@ def get_device_ids(self):
         device_ids = [self.trainer.root_gpu]
         return device_ids
+
+    def transfer_distrib_spawn_state_on_fit_end(self, model, mp_queue, results):
+        best_model_path = None
+        if self.trainer.checkpoint_callback is not None:
+            best_model_path = self.trainer.checkpoint_callback.best_model_path
+
+        if self.trainer.global_rank == 0 and mp_queue is not None:
+            rank_zero_warn('cleaning up ddp environment...')
+            # todo, pass complete checkpoint as state dictionary
+            mp_queue.put(best_model_path)
+            mp_queue.put(results)
+
+            # save the last weights
+            last_path = None
+            if not self.trainer.testing and best_model_path is not None and len(best_model_path) > 0:
+                last_path = re.sub('.ckpt', '.tmp_end.ckpt', best_model_path)
+                atomic_save(model.state_dict(), last_path)
+            mp_queue.put(last_path)

From d60acb70bb70a2413ab49b539f102d94d73cff77 Mon Sep 17 00:00:00 2001
From: William Falcon
Date: Thu, 1 Oct 2020 09:36:04 -0400
Subject: [PATCH 2/3] ref: part 3 of #3733

---
 pytorch_lightning/accelerators/ddp_spawn_backend.py | 5 +++++
 1 file changed, 5 insertions(+)

diff --git a/pytorch_lightning/accelerators/ddp_spawn_backend.py b/pytorch_lightning/accelerators/ddp_spawn_backend.py
index da2bad41903c5..c97749812839b 100644
--- a/pytorch_lightning/accelerators/ddp_spawn_backend.py
+++ b/pytorch_lightning/accelerators/ddp_spawn_backend.py
@@ -24,6 +24,7 @@
 from pytorch_lightning.utilities.cloud_io import atomic_save
 from pytorch_lightning.utilities.distributed import rank_zero_only, rank_zero_warn
 from pytorch_lightning.utilities.distributed import find_free_network_port
+from pytorch_lightning.distributed.dist import LightningDistributed


 class DDPSpawnBackend(Accelerator):
@@ -32,6 +33,7 @@ def __init__(self, trainer, nprocs):
         super().__init__(trainer)
         self.mp_queue = None
         self.nprocs = nprocs
+        self.dist = LightningDistributed()

     def setup(self, model):
         os.environ['MASTER_PORT'] = os.environ.get('MASTER_PORT', str(find_free_network_port()))
@@ -167,6 +169,9 @@ def barrier(self, name: str = None):
         if torch_distrib.is_initialized():
             torch_distrib.barrier()

+    def broadcast(self, obj, src=0):
+        return self.dist.broadcast(obj)
+
     def set_world_ranks(self, process_idx):
         self.trainer.local_rank = process_idx
         self.trainer.global_rank = self.trainer.node_rank * self.trainer.num_processes + process_idx

From 9561bbc158a1614e324edb39f330ed455761626f Mon Sep 17 00:00:00 2001
From: William Falcon
Date: Thu, 1 Oct 2020 09:37:24 -0400
Subject: [PATCH 3/3] ref: part 3 of #3733

---
 .../accelerators/ddp_spawn_backend.py | 26 +++++++++----------
 1 file changed, 13 insertions(+), 13 deletions(-)

diff --git a/pytorch_lightning/accelerators/ddp_spawn_backend.py b/pytorch_lightning/accelerators/ddp_spawn_backend.py
index c97749812839b..0c3c42dae061c 100644
--- a/pytorch_lightning/accelerators/ddp_spawn_backend.py
+++ b/pytorch_lightning/accelerators/ddp_spawn_backend.py
@@ -59,19 +59,6 @@ def train(self):
         self.__recover_child_process_weights(model, best_path, last_path)
         return results

-    def __recover_child_process_weights(self, model, best_path, last_path):
-        # transfer back the best path to the trainer
-        if self.trainer.checkpoint_callback:
-            self.trainer.checkpoint_callback.best_model_path = best_path
-        # todo, pass also best score
-
-        # load last weights
-        if last_path is not None and not self.trainer.testing:
-            ckpt = torch.load(last_path, map_location=lambda storage, loc: storage)
-            model.load_state_dict(ckpt)
-
-        self.trainer.model = model
-
     def ddp_train(self, process_idx, mp_queue, model):
         """
         Entry point for ddp
@@ -187,6 +174,19 @@ def get_device_ids(self):
         device_ids = [self.trainer.root_gpu]
         return device_ids

+    def __recover_child_process_weights(self, model, best_path, last_path):
+        # transfer back the best path to the trainer
+        if self.trainer.checkpoint_callback:
+            self.trainer.checkpoint_callback.best_model_path = best_path
+        # todo, pass also best score
+
+        # load last weights
+        if last_path is not None and not self.trainer.testing:
+            ckpt = torch.load(last_path, map_location=lambda storage, loc: storage)
+            model.load_state_dict(ckpt)
+
+        self.trainer.model = model
+
     def transfer_distrib_spawn_state_on_fit_end(self, model, mp_queue, results):
         best_model_path = None
         if self.trainer.checkpoint_callback is not None: