Skip to content

Commit

Permalink
fixing TPU tests (#2632)
Browse files Browse the repository at this point in the history
* init

* rename

* tpu_core_idx

* idx 8

* idxs

* @pl_multi_process_test

* assert

* assert

* deamon

* no close

* imort

* msg

* use_single_gpu

* dataset

* idx

* fix idx

* dataset

* format

* add pickable

* typo

* apex

* typo

* wip

* wip

* wip

* wip

* wip

* wip

* wip

* wip

* docs

* typo

* tests

* tests

* tests

* tests

* tests

* tests

* tests

* tests

* tests

* tests

* tests

* tests

* tests

* tests

* tests

* tests

* tests

* docs

* docs

* Apply suggestions from code review

Co-authored-by: Ethan Harris <ewah1g13@soton.ac.uk>
Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com>

* Apply suggestions from code review

Co-authored-by: Ethan Harris <ewah1g13@soton.ac.uk>

* docs

* Apply suggestions from code review

Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com>

Co-authored-by: Ethan Harris <ewah1g13@soton.ac.uk>
Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com>
  • Loading branch information
3 people authored Jul 27, 2020
1 parent 84c507c commit 0fe933e
Show file tree
Hide file tree
Showing 23 changed files with 339 additions and 192 deletions.
8 changes: 5 additions & 3 deletions .circleci/config.yml
Original file line number Diff line number Diff line change
Expand Up @@ -62,10 +62,11 @@ references:
# happened to the job in Kubernetes. If we try MAX_CHECKS times and
# still the job hasn't finished, give up and return the starting
# non-zero status code.
while [ $i -lt $MAX_CHECKS ]; do ((i++)); if kubectl get jobs $job_name -o jsonpath='Failed:{.status.failed}' | grep "Failed:1"; then status_code=1 && break; elif kubectl get jobs $job_name -o jsonpath='Succeeded:{.status.succeeded}' | grep "Succeeded:1" ; then status_code=0 && break; else echo "Job not finished yet"; fi; sleep 30; done && \
printf "Waiting for job to finish: " && \
while [ $i -lt $MAX_CHECKS ]; do ((i++)); if kubectl get jobs $job_name -o jsonpath='Failed:{.status.failed}' | grep "Failed:1"; then status_code=1 && break; elif kubectl get jobs $job_name -o jsonpath='Succeeded:{.status.succeeded}' | grep "Succeeded:1" ; then status_code=0 && break; else printf "."; fi; sleep $CHECK_SPEEP; done && \
echo "Done waiting. Job status code: $status_code" && \
# Allow time for logs to flush.
sleep 30 && \
sleep 10 && \
echo "JOB_NAME: $job_name" && \
gcloud logging read "resource.type=k8s_container resource.labels.project_id=$GOOGLE_PROJECT_ID resource.labels.location=$GOOGLE_COMPUTE_ZONE resource.labels.cluster_name=$GKE_CLUSTER resource.labels.namespace_name=default resource.labels.pod_name:$job_name" --limit 10000000 --order asc --format 'value(textPayload)' --project=$GOOGLE_PROJECT_ID > /tmp/full_output.txt && \
if grep -q '<?xml version="1.0" ?>' /tmp/full_output.txt ; then csplit /tmp/full_output.txt '/<?xml version="1.0" ?>/'; else mv /tmp/full_output.txt xx00; fi && \
Expand Down Expand Up @@ -101,7 +102,8 @@ jobs:
docker:
- image: circleci/python:3.7
environment:
- MAX_CHECKS: 60
- MAX_CHECKS: 240
- CHECK_SPEEP: 5
steps:
- checkout
- go/install
Expand Down
4 changes: 4 additions & 0 deletions .github/workflows/ci-testing.yml
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,10 @@ jobs:
# TODO: temporary fix till hanging jobs on macOS for py38 is resolved
- python-version: 3.8
os: macOS-10.15
# TODO: temporary fix till pyYaml can be installed, see: https://github.com/actions/setup-python/issues/114
- python-version: 3.7
os: ubuntu-18.04
requires: 'minimal'

# Timeout: https://stackoverflow.com/a/59076067/4521646
timeout-minutes: 25
Expand Down
12 changes: 7 additions & 5 deletions .github/workflows/tpu-testing.yml
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@ env:
GKE_CLUSTER: lightning-cluster
GKE_ZONE: us-central1-a
IMAGE: gcr.io/${{ secrets.GKE_PROJECT }}/tpu-testing-image
MAX_CHECKS: 240
CHECK_SPEEP: 5

jobs:
setup-build-publish-deploy:
Expand Down Expand Up @@ -82,17 +84,17 @@ jobs:
job_name=${job_name% created} && \
echo "Waiting on kubernetes job: $job_name in cluster: $GKE_CLUSTER" && \
i=0 && \
# 30 checks spaced 30s apart = 900s total.
max_checks=30 && \
# 60 checks spaced 30s apart = 900s total.
status_code=2 && \
# Check on the job periodically. Set the status code depending on what
# happened to the job in Kubernetes. If we try max_checks times and
# happened to the job in Kubernetes. If we try MAX_CHECKS times and
# still the job hasn't finished, give up and return the starting
# non-zero status code.
while [ $i -lt $max_checks ]; do ((i++)); if kubectl get jobs $job_name -o jsonpath='Failed:{.status.failed}' | grep "Failed:1"; then status_code=1 && break; elif kubectl get jobs $job_name -o jsonpath='Succeeded:{.status.succeeded}' | grep "Succeeded:1" ; then status_code=0 && break; else echo "Job not finished yet"; fi; sleep 30; done && \
printf "Waiting for job to finish: " && \
while [ $i -lt $MAX_CHECKS ]; do ((i++)); if kubectl get jobs $job_name -o jsonpath='Failed:{.status.failed}' | grep "Failed:1"; then status_code=1 && break; elif kubectl get jobs $job_name -o jsonpath='Succeeded:{.status.succeeded}' | grep "Succeeded:1" ; then status_code=0 && break; else printf "." ; fi; sleep $CHECK_SPEEP; done && \
echo "Done waiting. Job status code: $status_code" && \
# Allow time for logs to flush.
sleep 60 && \
sleep 10 && \
echo "JOB_NAME: $job_name" && \
echo "GKE_CLUSTER: $GKE_CLUSTER" && \
echo "GKE_ZONE: $GKE_ZONE" && \
Expand Down
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

- Fixed `weights_save_path` getting ignored when `logger=False` is passed to Trainer ([#2681](https://github.com/PyTorchLightning/pytorch-lightning/pull/2681))

- Fixed TPU multi-core and Float16 ([#2632](https://github.com/PyTorchLightning/pytorch-lightning/pull/2632))

## [0.8.5] - 2020-07-09

### Added
Expand Down
2 changes: 0 additions & 2 deletions docs/source/new-project.rst
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,6 @@
import torch
from torch.nn import functional as F
from torch.utils.data import DataLoader
from torchvision.datasets import MNIST
from torchvision import transforms


Quick Start
Expand Down
2 changes: 1 addition & 1 deletion pytorch_lightning/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,11 +54,11 @@
# We are not importing the rest of the lightning during the build process, as it may not be compiled yet
else:
from pytorch_lightning.core import LightningDataModule, LightningModule, data_loader
from pytorch_lightning.core.step_result import TrainResult, EvalResult
from pytorch_lightning.callbacks import Callback
from pytorch_lightning.trainer import Trainer
from pytorch_lightning.utilities.seed import seed_everything
from pytorch_lightning import metrics
from pytorch_lightning.core.step_result import TrainResult, EvalResult

__all__ = [
'Trainer',
Expand Down
19 changes: 10 additions & 9 deletions pytorch_lightning/accelerator_backends/ddp_spawn_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,26 +30,27 @@ class DDPSpawnBackend(object):

def __init__(self, trainer):
self.trainer = trainer
self.q = None
self.mp_queue = None

def setup(self):
self.trainer.set_random_port()

# pass in a state q
smp = mp.get_context('spawn')
self.q = smp.SimpleQueue()
self.mp_queue = smp.SimpleQueue()

def train(self, model, nprocs):
mp.spawn(self.ddp_train, nprocs=nprocs, args=(self.q, model,))
mp.spawn(self.ddp_train, nprocs=nprocs, args=(self.mp_queue, model,))

def teardown(self, model):
# restore main state with best weights
best_path = self.q.get()
results = self.q.get()
last_path = self.q.get()
best_path = self.mp_queue.get()
results = self.mp_queue.get()
last_path = self.mp_queue.get()

# transfer back the best path to the trainer
self.trainer.checkpoint_callback.best_model_path = best_path
# todo, pass also bets score

# load last weights
if last_path is not None and not self.trainer.testing:
Expand All @@ -59,13 +60,13 @@ def teardown(self, model):
self.trainer.model = model
return results

def ddp_train(self, process_idx, q, model, is_master=False, proc_offset=0):
def ddp_train(self, process_idx, mp_queue, model, is_master=False, proc_offset=0):
"""
Entry point for ddp
Args:
process_idx:
q:
mp_queue: multiprocessing queue
model:
is_master:
proc_offset:
Expand Down Expand Up @@ -166,7 +167,7 @@ def ddp_train(self, process_idx, q, model, is_master=False, proc_offset=0):
model = self.trainer.get_model()

# persist info in ddp_spawn
self.trainer.transfer_ddp_spawn_state_on_fit_end(model, q, results)
self.trainer.transfer_distrib_spawn_state_on_fit_end(model, mp_queue, results)

# clean up memory
torch.cuda.empty_cache()
5 changes: 3 additions & 2 deletions pytorch_lightning/accelerator_backends/gpu_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

import torch

from pytorch_lightning.core import LightningModule
try:
from apex import amp
except ImportError:
Expand Down Expand Up @@ -45,15 +46,15 @@ def setup(self, model):

# TODO: remove with dropping NVIDIA AMP support
native_amp_available = hasattr(torch.cuda, "amp") and hasattr(torch.cuda.amp, "autocast")
if self.trainer.use_amp and not native_amp_available:
if APEX_AVAILABLE and self.trainer.use_amp and not native_amp_available:
model = self._setup_nvidia_apex(model)
return model

def train(self, model):
results = self.trainer.run_pretrain_routine(model)
return results

def _setup_nvidia_apex(self, model):
def _setup_nvidia_apex(self, model: LightningModule):
model, optimizers = model.configure_apex(amp, model, self.trainer.optimizers, self.trainer.amp_level)
self.trainer.optimizers = optimizers
self.trainer.reinit_scheduler_properties(self.trainer.optimizers, self.trainer.lr_schedulers)
Expand Down
105 changes: 68 additions & 37 deletions pytorch_lightning/accelerator_backends/tpu_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,15 @@
# limitations under the License.

import os

import torch
import torch.multiprocessing as mp

from pytorch_lightning.core import LightningModule
from pytorch_lightning.utilities import rank_zero_info, rank_zero_only, rank_zero_warn
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning import _logger as log


try:
import torch_xla
import torch_xla.core.xla_model as xm
Expand All @@ -33,31 +37,52 @@ class TPUBackend(object):
def __init__(self, trainer):
self.trainer = trainer
self.start_method = None
self.mp_queue = None

def setup(self):
rank_zero_info(f'training on {self.trainer.tpu_cores} TPU cores')

if not XLA_AVAILABLE:
raise MisconfigurationException('No TPU devices found.')
raise MisconfigurationException('PyTorch XLA not installed.')

# see: https://discuss.pytorch.org/t/segfault-with-multiprocessing-queue/81292/2
self.start_method = 'fork'

# pass in a state q
smp = mp.get_context(self.start_method)
self.mp_queue = smp.SimpleQueue()

def teardown(self, model):
# restore main state with best weights
best_path = self.mp_queue.get()
results = self.mp_queue.get()
last_path = self.mp_queue.get()

# COLAB_GPU is an env var available by default in Colab environments.
self.start_method = 'fork' if self.trainer.on_colab_kaggle else 'spawn'
# transfer back the best path to the trainer
self.trainer.checkpoint_callback.best_model_path = best_path
# todo, pass also bets score

def teardown(self):
# load last weights
if last_path and not self.trainer.testing:
ckpt = torch.load(last_path, map_location=lambda storage, loc: storage)
model.load_state_dict(ckpt)

self.trainer.model = model

# when training completes, load the weights back in main process
self.__load_weights_on_main_process()
return results

def train(self, model):
def train(self, model: LightningModule):
self.trainer.model = model

# train
if self.trainer.tpu_id is not None:
self.tpu_train_in_process(self.trainer.tpu_id, model)
self.tpu_train_in_process(self.trainer.tpu_id, model, self.trainer, self.mp_queue)
else:
xmp.spawn(
self.tpu_train_in_process,
args=(model,),
args=(model, self.trainer, self.mp_queue),
nprocs=self.trainer.tpu_cores,
start_method=self.start_method
)
Expand All @@ -71,63 +96,69 @@ def __load_weights_on_main_process(self):

self.trainer.model = model

def tpu_train_in_process(self, tpu_core_idx, model):
def tpu_train_in_process(self, tpu_core_idx: int, model: LightningModule, trainer=None, mp_queue=None):
"""
Here we are inside each individual process
"""
if not self.trainer.testing:
self.trainer.setup('fit')
if not trainer:
trainer = self.trainer
if not trainer.testing:
trainer.setup('fit')
model.setup('fit')

# setup TPU training
self.__setup_tpu_training(model)
self.__setup_tpu_training(model, trainer)

# Run the pretrain routine
self.trainer.run_pretrain_routine(model)
results = trainer.run_pretrain_routine(model)

# save weights at the end of training
self.__save_end_of_training_weights(model)
self.__save_end_of_training_weights(model, trainer)

def __save_end_of_training_weights(self, model):
# persist info in spawn
trainer.transfer_distrib_spawn_state_on_fit_end(model, mp_queue, results)

def __save_end_of_training_weights(self, model: LightningModule, trainer):
# when training ends on these platforms dump weights to get out of the main process
if self.trainer.on_colab_kaggle:
if trainer.on_colab_kaggle:
rank_zero_warn('cleaning up... please do not interrupt')
self.trainer.save_spawn_weights(model)
trainer.save_spawn_weights(model)

def __setup_tpu_training(self, model):
def __setup_tpu_training(self, model: LightningModule, trainer):
# use the default device from the process
tpu_device = xm.xla_device()
# tpu_device = xm.xla_device()

# if given an ordinal device, use this as the device
if self.trainer.tpu_id is not None:
tpu_device = xm.xla_device(self.trainer.tpu_id)

if trainer.tpu_id is not None:
tpu_device = xm.xla_device(trainer.tpu_id)
else:
tpu_device = xm.xla_device()
# track the device and move model to it
self.trainer._device = tpu_device
model.to(self.trainer._device)
trainer._device = tpu_device
model.to(trainer._device)

# get the appropriate tpu ranks
self.trainer.tpu_local_core_rank = xm.get_local_ordinal()
self.trainer.tpu_global_core_rank = xm.get_ordinal()
trainer.tpu_local_core_rank = xm.get_local_ordinal()
trainer.tpu_global_core_rank = xm.get_ordinal()

# avoid duplicating progress bar
if self.trainer.tpu_global_core_rank != 0 and self.trainer.progress_bar_callback is not None:
self.trainer.progress_bar_callback.disable()
if trainer.tpu_global_core_rank != 0 and trainer.progress_bar_callback is not None:
trainer.progress_bar_callback.disable()

self.trainer.global_rank = self.trainer.tpu_local_core_rank
rank_zero_only.rank = self.trainer.global_rank
trainer.global_rank = trainer.tpu_local_core_rank
rank_zero_only.rank = trainer.global_rank

# CHOOSE OPTIMIZER
# allow for lr schedulers as well
optimizers, lr_schedulers, optimizer_frequencies = self.trainer.init_optimizers(model)
self.trainer.optimizers = optimizers
self.trainer.lr_schedulers = lr_schedulers
self.trainer.optimizer_frequencies = optimizer_frequencies
optimizers, lr_schedulers, optimizer_frequencies = trainer.init_optimizers(model)
trainer.optimizers = optimizers
trainer.lr_schedulers = lr_schedulers
trainer.optimizer_frequencies = optimizer_frequencies

# init 16 bit for TPU
if self.trainer.precision == 16:
if trainer.precision == 16:
os.environ['XLA_USE_BF16'] = str(1)

log.info(f'INIT TPU local core: {self.trainer.tpu_local_core_rank},'
f' global rank: {self.trainer.tpu_global_core_rank}')
log.info(f'INIT TPU local core: {trainer.tpu_local_core_rank},'
f' global rank: {trainer.tpu_global_core_rank}'
f' with XLA_USE_BF16={os.environ.get("XLA_USE_BF16")}')
6 changes: 5 additions & 1 deletion pytorch_lightning/core/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -305,5 +305,9 @@ def training_step(self, batch, batch_idx):
from pytorch_lightning.core.decorators import data_loader
from pytorch_lightning.core.lightning import LightningModule

__all__ = ['LightningDataModule', 'LightningModule', 'data_loader']
__all__ = [
'LightningDataModule',
'LightningModule',
'data_loader',
]
# __call__ = __all__
2 changes: 0 additions & 2 deletions pytorch_lightning/core/decorators.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
from functools import wraps
from typing import Callable

import torch

from pytorch_lightning.core.lightning import LightningModule
from pytorch_lightning.utilities import rank_zero_warn

Expand Down
Loading

0 comments on commit 0fe933e

Please sign in to comment.