Skip to content

Commit

Permalink
block multi-GPU processing with PyTorch 1.2+ to avoid [pytorch/pytorc…
Browse files Browse the repository at this point in the history
  • Loading branch information
Emrys365 committed Nov 12, 2019
1 parent d9bcb3f commit df4e9c0
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 1 deletion.
7 changes: 6 additions & 1 deletion espnet/bin/asr_train.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)

"""Automatic speech recognition model training script."""

from distutils.version import LooseVersion
import logging
import multiprocessing as mp
import os
Expand All @@ -19,6 +19,8 @@
from espnet.utils.cli_utils import strtobool
from espnet.utils.training.batchfy import BATCH_COUNT_CHOICES

is_torch_1_2_plus = LooseVersion(torch.__version__) >= LooseVersion('1.2')


# NOTE: you need this func to generate our sphinx doc
def get_parser(parser=None, required=True):
Expand Down Expand Up @@ -314,6 +316,9 @@ def main(cmd_args):
else:
ngpu = len(p.stderr.decode().split('\n')) - 1
else:
if is_torch_1_2_plus:
assert args.ngpu == 1, "There are some bugs with multi-GPU processing in PyTorch 1.2+" \
" (see https://github.com/pytorch/pytorch/issues/21108)"
ngpu = args.ngpu
logging.info(f"ngpu: {ngpu}")

Expand Down
1 change: 1 addition & 0 deletions espnet/nets/pytorch_backend/e2e_tts_tacotron2.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,7 @@ def _make_masks(ilens, olens):
Args:
ilens (LongTensor or List): Batch of lengths (B,).
olens (LongTensor or List): Batch of lengths (B,).
Returns:
Tensor: Mask tensor indicating non-padded part.
Expand Down

0 comments on commit df4e9c0

Please sign in to comment.