Commit

ref: finish #3733
williamFalcon committed Oct 3, 2020
1 parent ed1450a commit 274364d
Showing 2 changed files with 59 additions and 60 deletions.
83 changes: 41 additions & 42 deletions pytorch_lightning/accelerators/ddp_backend.py
@@ -12,22 +12,22 @@
# See the License for the specific language governing permissions and
# limitations under the License
import os
import torch.distributed as dist
import torch
import torch.distributed as torch_distrib
import subprocess
import sys
from os.path import abspath
from time import sleep
from typing import Optional

import numpy as np
import torch
import torch.distributed as torch_distrib
import torch.distributed as dist


from pytorch_lightning import _logger as log
from pytorch_lightning.utilities.distributed import find_free_network_port
from pytorch_lightning.accelerators.base_backend import Accelerator
from pytorch_lightning import _logger as log
from pytorch_lightning.utilities import AMPType
from pytorch_lightning.utilities.distributed import rank_zero_only
from pytorch_lightning.utilities import AMPType
from pytorch_lightning.utilities.seed import seed_everything
from pytorch_lightning.distributed.dist import LightningDistributed

@@ -47,6 +47,7 @@ def __init__(self, trainer):
super().__init__(trainer)
self.task_idx = None
self._has_spawned_children = False
self.interactive_ddp_procs = []
self.dist = LightningDistributed()

def setup(self, model):
@@ -57,7 +58,6 @@ def setup(self, model):
self._call_children_scripts()

def _call_children_scripts(self):

assert self.trainer.global_rank == 0
self._check_can_spawn_children()
self._has_spawned_children = True
@@ -104,11 +104,12 @@ def _call_children_scripts(self):

os.environ['WORLD_SIZE'] = f'{num_gpus * self.trainer.num_nodes}'

self.trainer.interactive_ddp_procs = []
self.interactive_ddp_procs = []
for local_rank in range(1, self.trainer.num_processes):
env_copy = os.environ.copy()
env_copy['LOCAL_RANK'] = f'{local_rank}'
env_copy['PL_DDP_PID'] = str(self.trainer.data_parallel_device_ids[local_rank])
env_copy['PL_GLOBAL_SEED'] = os.environ.get('PL_GLOBAL_SEED', None)

# start process
# if hydra is available and initialized, make sure to set the cwd correctly
@@ -117,7 +118,7 @@ def _call_children_scripts(self):
if HydraConfig.initialized():
cwd = get_original_cwd()
proc = subprocess.Popen(command, env=env_copy, cwd=cwd)
self.trainer.interactive_ddp_procs.append(proc)
self.interactive_ddp_procs.append(proc)

# starting all processes at once can cause issues
# with dataloaders delay between 1-10 seconds
@@ -126,12 +127,36 @@ def _call_children_scripts(self):

self.task_idx = 0

# wait for all the procs to start
sleep(2)

def train(self):
model = self.trainer.model
results = self.ddp_train(process_idx=self.task_idx, model=model, is_master=True)
del os.environ['WORLD_SIZE']
results = self.ddp_train(process_idx=self.task_idx, mp_queue=None, model=model, is_master=True)
if 'WORLD_SIZE' in os.environ:
del os.environ['WORLD_SIZE']
return results

def training_step(self, args):
if self.trainer.amp_backend == AMPType.NATIVE:
with torch.cuda.amp.autocast():
output = self.trainer.model(*args)
else:
output = self.trainer.model(*args)
return output

def validation_step(self, args):
output = self.training_step(args)
return output

def test_step(self, args):
output = self.training_step(args)
return output

def barrier(self, name: str = None):
if torch_distrib.is_initialized():
torch_distrib.barrier()

def _check_can_spawn_children(self):
if self._has_spawned_children:
raise RuntimeError(
@@ -145,17 +170,7 @@ def set_world_ranks(self, process_idx):
self.trainer.world_size = self.trainer.num_nodes * self.trainer.num_processes

def model_to_device(self, model, process_idx, is_master):
gpu_idx = process_idx

# when using ddp, the master process (proc 0) continues running as the main one
# this means that the local rank will always be 0
# (even if cuda visible devices has other visible gpus)
# this means that the master process needs to pull the 0th visible index as the device number
if is_master:
available_gpus = os.environ['CUDA_VISIBLE_DEVICES'].split(',')
gpu_idx = int(available_gpus[self.trainer.local_rank])

gpu_idx = int(os.environ.get('PL_DDP_PID', gpu_idx))
gpu_idx = int(os.environ.get('PL_DDP_PID', process_idx))

self.trainer.root_gpu = gpu_idx
torch.cuda.set_device(self.trainer.root_gpu)
@@ -165,25 +180,8 @@ def get_device_ids(self):
device_ids = [self.trainer.root_gpu]
return device_ids

def training_step(self, args):
if self.trainer.amp_backend == AMPType.NATIVE:
with torch.cuda.amp.autocast():
output = self.trainer.model(*args)
else:
output = self.trainer.model(*args)
return output

def validation_step(self, args):
output = self.training_step(args)
return output

def test_step(self, args):
output = self.training_step(args)
return output

def barrier(self, name: str = None):
if torch_distrib.is_initialized():
torch_distrib.barrier()
def on_train_end(self):
pass

def early_stopping_should_stop(self, pl_module):
stop = torch.tensor(int(self.trainer.should_stop), device=pl_module.device)
@@ -207,7 +205,7 @@ def ddp_train(self, process_idx, model, is_master=False, proc_offset=0):
Returns:
"""
seed = os.environ.get("PL_GLOBAL_SEED")
seed = os.environ.get("PL_GLOBAL_SEED", None)
if seed is not None:
seed_everything(int(seed))

@@ -268,6 +266,7 @@ def ddp_train(self, process_idx, model, is_master=False, proc_offset=0):
model = model.configure_ddp(model, device_ids)

# set up training routine
self.barrier('ddp_setup')
self.trainer.train_loop.setup_training(model)

# train or test
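Two behavioural changes in this file are worth noting: the accelerator now owns interactive_ddp_procs instead of stashing the list on the trainer, and each spawned child inherits the parent's seed through the PL_GLOBAL_SEED environment variable, which ddp_train reads back before training starts. Below is a minimal, self-contained sketch of that seed hand-off, outside of the accelerator; the inline child command is a stand-in for the re-invoked training script, not the command Lightning actually builds.

import os
import subprocess
import sys

from pytorch_lightning.utilities.seed import seed_everything

# Parent: seed this process and export the seed for any children we spawn.
seed_everything(1234)
os.environ['PL_GLOBAL_SEED'] = '1234'

# Child: re-read the seed from the environment before doing anything random,
# mirroring the guard at the top of ddp_train().
child_code = (
    "import os\n"
    "from pytorch_lightning.utilities.seed import seed_everything\n"
    "seed = os.environ.get('PL_GLOBAL_SEED')\n"
    "if seed is not None:\n"
    "    seed_everything(int(seed))\n"
    "print('child seeded with', seed)\n"
)

procs = []
for local_rank in range(1, 2):
    env_copy = os.environ.copy()
    env_copy['LOCAL_RANK'] = str(local_rank)
    procs.append(subprocess.Popen([sys.executable, '-c', child_code], env=env_copy))

for proc in procs:
    proc.wait()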
36 changes: 18 additions & 18 deletions tests/backends/test_ddp.py
@@ -37,21 +37,21 @@ def test_multi_gpu_model_ddp_test_only(tmpdir, cli_args):
assert result['status'] == 'complete'


# @pytest.mark.parametrize('cli_args', [
# pytest.param('--max_epochs 1 --gpus 2 --distributed_backend ddp'),
# ])
# @pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine")
# def test_multi_gpu_model_ddp_fit_test(tmpdir, cli_args):
# # call the script
# call_training_script(ddp_model, cli_args, 'fit_test', tmpdir, timeout=20)
#
# # load the results of the script
# result_path = os.path.join(tmpdir, 'ddp.result')
# result = torch.load(result_path)
#
# # verify the file wrote the expected outputs
# assert result['status'] == 'complete'
#
# model_outs = result['result']
# for out in model_outs:
# assert out['test_acc'] > 0.90
@pytest.mark.parametrize('cli_args', [
pytest.param('--max_epochs 1 --gpus 2 --distributed_backend ddp'),
])
@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine")
def test_multi_gpu_model_ddp_fit_test(tmpdir, cli_args):
# call the script
call_training_script(ddp_model, cli_args, 'fit_test', tmpdir, timeout=20)

# load the results of the script
result_path = os.path.join(tmpdir, 'ddp.result')
result = torch.load(result_path)

# verify the file wrote the expected outputs
assert result['status'] == 'complete'

model_outs = result['result']
for out in model_outs:
assert out['test_acc'] > 0.90
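The re-enabled test_multi_gpu_model_ddp_fit_test relies on the spawned training script saving its outcome as a ddp.result file that torch.load can read back. The helper script itself is not part of this diff, so the exact payload is an assumption; a minimal sketch of the expected write/read contract might look like this.

import os
import tempfile

import torch

tmpdir = tempfile.mkdtemp()

# Writer side: what the training script is assumed to persist when fit/test finishes.
result = {'status': 'complete', 'result': [{'test_acc': 0.93}]}
torch.save(result, os.path.join(tmpdir, 'ddp.result'))

# Reader side: mirrors the assertions made by the test above.
loaded = torch.load(os.path.join(tmpdir, 'ddp.result'))
assert loaded['status'] == 'complete'
for out in loaded['result']:
    assert out['test_acc'] > 0.90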
