This repository has been archived by the owner on Dec 16, 2022. It is now read-only.

Remove scattering for multi-GPU training. #2200

Merged
87 commits merged on Jan 18, 2019
Commits
5d179a4
Transformer ELMo
brendan-ai2 Nov 21, 2018
2db75b4
wip
brendan-ai2 Dec 1, 2018
f7deed3
Add bidirectional transformer token embedder
brendan-ai2 Dec 4, 2018
c9de1ec
transformer elmo config template
brendan-ai2 Dec 4, 2018
634b4a2
MORE
brendan-ai2 Dec 4, 2018
e4a7b51
Works
brendan-ai2 Dec 5, 2018
9eb6e46
Add broken layer norm.
brendan-ai2 Dec 5, 2018
ac425a4
Address some more comments
brendan-ai2 Dec 5, 2018
f203cde
Merge branch 'lm_without_dataset_modifications_2' into lm_without_dat…
brendan-ai2 Dec 5, 2018
bde39fe
Fix for vidurj
brendan-ai2 Dec 5, 2018
4b3a81c
easy feedback
brendan-ai2 Dec 10, 2018
595b668
Fix norm issue
brendan-ai2 Dec 10, 2018
731e69c
Rename
brendan-ai2 Dec 10, 2018
4522f1c
Start and end tokens in reader
brendan-ai2 Dec 10, 2018
d091cc8
comment fix
brendan-ai2 Dec 10, 2018
971e600
fixes
brendan-ai2 Dec 10, 2018
24b763b
style
brendan-ai2 Dec 10, 2018
71e2cce
fix docs
brendan-ai2 Dec 10, 2018
01a111a
Merge branch 'master' into lm_without_dataset_modifications_2
brendan-ai2 Dec 10, 2018
975060a
Merge branch 'master' into lm_without_dataset_modifications_2
brendan-ai2 Dec 13, 2018
100f07f
Merge branch 'lm_without_dataset_modifications_2' into lm_without_dat…
brendan-ai2 Dec 13, 2018
5dcd700
cleanup
brendan-ai2 Dec 14, 2018
87e6241
Merge branch 'master' into lm_without_dataset_modifications_3
brendan-ai2 Dec 14, 2018
4b3ce38
Bidirectional fixture
brendan-ai2 Dec 14, 2018
1338afb
Test
brendan-ai2 Dec 14, 2018
fa86367
cleanup
brendan-ai2 Dec 14, 2018
7cd29aa
Merge branch 'master' into lm_without_dataset_modifications_3
brendan-ai2 Dec 14, 2018
f6e57d1
Merge branch 'lm_without_dataset_modifications_3' of github.com:brend…
brendan-ai2 Dec 14, 2018
6ec4a6e
works
brendan-ai2 Dec 16, 2018
d7c0208
Model file
brendan-ai2 Dec 16, 2018
c54534d
update parser config
brendan-ai2 Dec 16, 2018
16ca024
Merge branch 'lm_without_dataset_modifications_3' of github.com:brend…
brendan-ai2 Dec 16, 2018
a8e8eb6
fixes
brendan-ai2 Dec 16, 2018
53a283c
formatting
brendan-ai2 Dec 16, 2018
75f03fd
Type fixes
brendan-ai2 Dec 16, 2018
e8ad0c6
Renames
brendan-ai2 Dec 16, 2018
8cfe033
Merge branch 'master' into lm_without_dataset_modifications_3
brendan-ai2 Dec 16, 2018
13f6d83
Merge branch 'lm_without_dataset_modifications_3' of github.com:brend…
brendan-ai2 Dec 17, 2018
726cf13
Merge branch 'master' into lm_without_dataset_modifications_3
brendan-ai2 Dec 17, 2018
8dacca9
another test, jsonnet improvements
brendan-ai2 Dec 17, 2018
c6694b9
Merge branch 'lm_without_dataset_modifications_3' of github.com:brend…
brendan-ai2 Dec 17, 2018
9e617b2
Merge branch 'lm_without_dataset_modifications_3' of github.com:brend…
brendan-ai2 Dec 17, 2018
468300f
docs
brendan-ai2 Dec 17, 2018
30894d0
Merge branch 'lm_without_dataset_modifications_3' into lm_train_fixes
brendan-ai2 Dec 17, 2018
e79119f
Drop scatter
brendan-ai2 Dec 17, 2018
724cf89
Potentially works? On one shard at least.
brendan-ai2 Dec 17, 2018
219d026
Merge branch 'lm_train_fixes' of github.com:brendan-ai2/allennlp into…
brendan-ai2 Dec 17, 2018
8d81d3f
Fix
brendan-ai2 Dec 17, 2018
ae8be54
Merge branch 'lm_train_fixes' of github.com:brendan-ai2/allennlp into…
brendan-ai2 Dec 17, 2018
2cf180e
Added failing test case
matt-gardner Dec 18, 2018
f68e647
Merge branch 'master' into lm_train_fixes
brendan-ai2 Dec 18, 2018
e0d71c4
Merge branch 'master' into lm_train_fixes
brendan-ai2 Dec 18, 2018
2d06b4b
Merge branch 'lm_train_fixes' of github.com:brendan-ai2/allennlp into…
brendan-ai2 Dec 18, 2018
7662c0f
hacks
brendan-ai2 Dec 18, 2018
a48e494
more hacks
brendan-ai2 Dec 18, 2018
560f99f
respond to feedback
brendan-ai2 Dec 19, 2018
4da2ff3
Add todo
brendan-ai2 Dec 19, 2018
50c1e15
Merge branch 'master' into lm_without_dataset_modifications_3
brendan-ai2 Dec 19, 2018
f93b6dc
lint
brendan-ai2 Dec 19, 2018
4f667ab
Add todos
brendan-ai2 Dec 20, 2018
0cbb57a
Merge branch 'master' into lm_without_dataset_modifications_3
brendan-ai2 Dec 20, 2018
64bbfbd
Merge branch 'lm_without_dataset_modifications_3' into lm_train_fixes
brendan-ai2 Dec 21, 2018
255045a
Merge branch 'master' into lm_train_fixes
brendan-ai2 Dec 21, 2018
862b87b
cleanups
brendan-ai2 Dec 21, 2018
cc51bbc
Merge branch 'master' into lm_train_fixes
brendan-ai2 Dec 21, 2018
eb8419c
Fix batch size
brendan-ai2 Dec 21, 2018
4856e77
Try for more
brendan-ai2 Dec 22, 2018
e329e72
3k samples
brendan-ai2 Dec 22, 2018
9104044
2k
brendan-ai2 Dec 22, 2018
4a133b1
log grad stats and learning rate
brendan-ai2 Jan 11, 2019
86c76fb
Merge branch 'pr-2199' into lm_train_fixes
brendan-ai2 Jan 15, 2019
8844663
merge
brendan-ai2 Jan 17, 2019
f4726a6
fix
brendan-ai2 Jan 17, 2019
2323bbf
Fix
brendan-ai2 Jan 17, 2019
98629f1
merge
brendan-ai2 Jan 17, 2019
a68db07
drop some logging
brendan-ai2 Jan 17, 2019
6eda737
stash pop
brendan-ai2 Jan 17, 2019
63644f1
fixes
brendan-ai2 Jan 17, 2019
79ed01a
cleanup
brendan-ai2 Jan 17, 2019
d3e4921
Add todos
brendan-ai2 Jan 17, 2019
34b2adf
fixes
brendan-ai2 Jan 18, 2019
c7a5a96
merge
brendan-ai2 Jan 18, 2019
7c19f04
fixes, delete ScatterableList, scatter_kwargs, etc.
brendan-ai2 Jan 18, 2019
95f5804
More cleanup
brendan-ai2 Jan 18, 2019
ee5df46
cleanup
brendan-ai2 Jan 18, 2019
2e6f990
drop no-op changes
brendan-ai2 Jan 18, 2019
12e62d4
Merge branch 'master' into lm_train_fixes
brendan-ai2 Jan 18, 2019
13 changes: 8 additions & 5 deletions allennlp/commands/find_learning_rate.py
@@ -58,7 +58,7 @@
from allennlp.commands.subcommand import Subcommand
from allennlp.common.checks import ConfigurationError, check_for_gpu
from allennlp.common import Params, Tqdm
from allennlp.common.util import prepare_environment
from allennlp.common.util import prepare_environment, lazy_groups_of
from allennlp.data import Vocabulary, DataIterator
from allennlp.models import Model
from allennlp.training import Trainer
@@ -263,8 +263,11 @@ def search_learning_rate(trainer: Trainer,

trainer.model.train()

train_generator = trainer.iterator(trainer.train_data,
shuffle=trainer.shuffle)
num_gpus = len(trainer._cuda_devices) # pylint: disable=protected-access

raw_train_generator = trainer.iterator(trainer.train_data,
shuffle=trainer.shuffle)
train_generator = lazy_groups_of(raw_train_generator, num_gpus)
train_generator_tqdm = Tqdm.tqdm(train_generator,
total=num_batches)

@@ -276,7 +279,7 @@
else:
lr_update_factor = (end_lr / start_lr) ** (1.0 / num_batches)

for i, batch in enumerate(train_generator_tqdm):
for i, batch_group in enumerate(train_generator_tqdm):

if linear_steps:
current_lr = start_lr + (lr_update_factor * i)
@@ -287,7 +290,7 @@
param_group['lr'] = current_lr

trainer.optimizer.zero_grad()
loss = trainer.batch_loss(batch, for_training=True)
loss = trainer.batch_loss(batch_group, for_training=True)
loss.backward()
loss = loss.detach().cpu().item()

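The change above stops pulling single batches from the iterator and instead consumes one group of batches per step, one batch per GPU. As a point of reference, a minimal sketch of a grouping helper in the spirit of `lazy_groups_of` (the shipped implementation in `allennlp.common.util` may differ in detail):

```python
from itertools import islice
from typing import Iterable, Iterator, List, TypeVar

T = TypeVar("T")

def lazy_groups_of(iterable: Iterable[T], group_size: int) -> Iterator[List[T]]:
    # Lazily yield lists of up to `group_size` consecutive items; the final
    # group may be shorter if the underlying iterator runs out.
    iterator = iter(iterable)
    while True:
        group = list(islice(iterator, group_size))
        if not group:
            return
        yield group

# Illustrative stand-in for the batch stream: with 2 GPUs, each step now
# consumes a group of 2 batches instead of a single batch.
for batch_group in lazy_groups_of(range(5), group_size=2):
    print(batch_group)  # [0, 1], then [2, 3], then [4]
```

Because the grouping is lazy, batches still stream through as before; only the unit handed to `batch_loss` changes.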
95 changes: 0 additions & 95 deletions allennlp/common/util.py
@@ -1,7 +1,6 @@
"""
Various utilities that don't fit anywhere else.
"""
from ctypes import sizeof, c_void_p, c_int64, cast, py_object, c_uint64
from itertools import zip_longest, islice
from typing import Any, Callable, Dict, List, Tuple, TypeVar, Iterable, Iterator, Union
import importlib
@@ -14,8 +13,6 @@
import os
import re

from torch.nn.parallel._functions import Scatter

try:
import resource
except ImportError:
@@ -392,98 +389,6 @@ def from_list(strings):
# TODO(brendanr): Determine why mypy can't tell that this matches the Union.
return int(cuda_device) # type: ignore

class ScatterableList(list):
"""
A normal list, but one that should be scattered like a tensor.
"""

# Ensure pointers will fit in a torch.LongTensor. "64 bits ought to be enough for anybody."
assert sizeof(c_void_p) <= sizeof(c_int64)

def to_pointer_tensor(self) -> torch.LongTensor:
"""
Converts the elements to pointers, casts them to ``int64`` and then returns them in a tensor. This cast is
important as ``id`` gives back unsigned integers while ``torch.LongTensor`` is signed.

See:
https://github.com/python/cpython/blob/6ec5cf24b7f38ea72bb42d5cd60dca0d3ee332f9/Python/bltinmodule.c#L1118
https://github.com/python/cpython/blob/6ec5cf24b7f38ea72bb42d5cd60dca0d3ee332f9/Objects/longobject.c#L990
"""
pointers = [c_int64(id(element)).value for element in self]
return torch.LongTensor(pointers)

@classmethod
def from_pointer_tensor(cls, pointers: torch.LongTensor) -> list:
"""
The inverse of ``to_pointer_tensor`` except that a plain ``list`` is returned. Typically this will be
called on a single chunk of the scattered tensor.

Parameters
----------
pointers : ``torch.LongTensor``, required.
A tensor of shape (list_length,).
"""
return [cast(c_uint64(pointer.item()).value, py_object).value for pointer in pointers]

def scatter(inputs, target_gpus, dim=0):
"""
Slices tensors and ScatterableLists into approximately equal chunks and distributes them across given GPUs.
Duplicates references to objects that are not tensors or ScatterableLists.

Adapted from `scatter` at:
https://github.com/pytorch/pytorch/blob/1d406c04ae56255e58dcec85e3479bb2b3dbd75e/torch/nn/parallel/scatter_gather.py#L5-L30.

Please see the LICENSE and NOTICE files as well:
https://github.com/pytorch/pytorch/blob/1d406c04ae56255e58dcec85e3479bb2b3dbd75e/LICENSE
https://github.com/pytorch/pytorch/blob/1d406c04ae56255e58dcec85e3479bb2b3dbd75e/NOTICE
"""
def scatter_map(obj):
if isinstance(obj, torch.Tensor):
return Scatter.apply(target_gpus, None, dim, obj)
if isinstance(obj, ScatterableList):
# In order to have precisely the same method of scattering as PyTorch we scatter
# a tensor of pointers.
pointers = scatter_map(obj.to_pointer_tensor())
# Then we reconstruct the lists from the pointer tensors.
return [obj.from_pointer_tensor(chunk) for chunk in pointers]
if isinstance(obj, tuple) and obj:
return list(zip(*map(scatter_map, obj)))
if isinstance(obj, list) and obj:
return list(map(list, zip(*map(scatter_map, obj))))
if isinstance(obj, dict) and obj:
return list(map(type(obj), zip(*map(scatter_map, obj.items()))))
return [obj for _ in target_gpus]

# After scatter_map is called, a scatter_map cell will exist. This cell
# has a reference to the actual function scatter_map, which has references
# to a closure that has a reference to the scatter_map cell (because the
# fn is recursive). To avoid this reference cycle, we set the function to
# None, clearing the cell
try:
return scatter_map(inputs)
finally:
scatter_map = None

def scatter_kwargs(inputs, kwargs, target_gpus, dim=0):
"""Scatter with support for kwargs dictionary.

Adapted from `scatter_kwargs` at:
https://github.com/pytorch/pytorch/blob/1d406c04ae56255e58dcec85e3479bb2b3dbd75e/torch/nn/parallel/scatter_gather.py#L33-L43

Please see the LICENSE and NOTICE files as well:
https://github.com/pytorch/pytorch/blob/1d406c04ae56255e58dcec85e3479bb2b3dbd75e/LICENSE
https://github.com/pytorch/pytorch/blob/1d406c04ae56255e58dcec85e3479bb2b3dbd75e/NOTICE
"""
inputs = scatter(inputs, target_gpus, dim) if inputs else []
kwargs = scatter(kwargs, target_gpus, dim) if kwargs else []
if len(inputs) < len(kwargs):
inputs.extend([() for _ in range(len(kwargs) - len(inputs))])
elif len(kwargs) < len(inputs):
kwargs.extend([{} for _ in range(len(inputs) - len(kwargs))])
inputs = tuple(inputs)
kwargs = tuple(kwargs)
return inputs, kwargs

def get_frozen_and_tunable_parameter_names(model: torch.nn.Module) -> List:
frozen_parameter_names = []
tunable_parameter_names = []
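For readers unfamiliar with the deleted helpers: `ScatterableList` let non-tensor values survive PyTorch's tensor-only `Scatter` by packing CPython object addresses into a `LongTensor` and reinterpreting them on the other side. A condensed illustration of that round trip, adapted from the removed code above (it only works while the originals stay alive in the parent list, which is part of why the hack was worth deleting):

```python
from ctypes import c_int64, c_uint64, cast, py_object
import torch

metadata = [{"id": "q1"}, {"id": "q2"}, {"id": "q3"}]  # e.g. MetadataField values

# Pack object addresses into a signed LongTensor. id() is unsigned, so the
# c_int64 cast keeps large addresses from overflowing the tensor dtype.
pointers = torch.LongTensor([c_int64(id(obj)).value for obj in metadata])

# Pretend this slice is the chunk Scatter would hand to one GPU.
chunk = pointers[:2]

# Reinterpret each address as a Python object again. Safe only while the
# original objects are still referenced elsewhere (here, via `metadata`).
recovered = [cast(c_uint64(p.item()).value, py_object).value for p in chunk]
assert recovered == metadata[:2]
```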
5 changes: 2 additions & 3 deletions allennlp/data/fields/metadata_field.py
@@ -3,7 +3,6 @@

from overrides import overrides

from allennlp.common.util import ScatterableList
from allennlp.data.fields.field import DataArray, Field


@@ -61,8 +60,8 @@ def empty_field(self) -> 'MetadataField':

@classmethod
@overrides
def batch_tensors(cls, tensor_list: List[DataArray]) -> ScatterableList: # type: ignore
return ScatterableList(tensor_list)
def batch_tensors(cls, tensor_list: List[DataArray]) -> List[DataArray]: # type: ignore
return tensor_list


def __str__(self) -> str:
5 changes: 2 additions & 3 deletions allennlp/data/fields/production_rule_field.py
@@ -3,7 +3,6 @@
import torch
from overrides import overrides

from allennlp.common.util import ScatterableList
from allennlp.data.fields.field import Field
from allennlp.data.vocabulary import Vocabulary

@@ -114,9 +113,9 @@ def empty_field(self): # pylint: disable=no-self-use
return ProductionRuleField(rule='', is_global_rule=False)

@overrides
def batch_tensors(self, tensor_list: List[ProductionRule]) -> ScatterableList: # type: ignore
def batch_tensors(self, tensor_list: List[ProductionRule]) -> List[ProductionRule]: # type: ignore
# pylint: disable=no-self-use
return ScatterableList(tensor_list)
return tensor_list

def __str__(self) -> str:
return f"ProductionRuleField with rule: {self.rule} (is_global_rule: " \
3 changes: 3 additions & 0 deletions allennlp/data/iterators/bucket_iterator.py
@@ -124,6 +124,9 @@ def _create_batches(self, instances: Iterable[Instance], shuffle: bool) -> Itera
if excess:
batches.append(Batch(excess))

# TODO(brendanr): Add multi-GPU friendly grouping, i.e. group
# num_gpu batches together, shuffle and then expand the groups.
# This guards against imbalanced batches across GPUs.
move_to_front = self._biggest_batch_first and len(batches) > 1
if move_to_front:
# We'll actually pop the last _two_ batches, because the last one might not be full.
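A hypothetical sketch of the grouping the TODO above describes: keep runs of `num_gpus` consecutive (similarly padded) batches together, shuffle the groups rather than the individual batches, then flatten. The function name and placement are assumptions for illustration, not code from this PR:

```python
import random
from typing import List, Sequence, TypeVar

T = TypeVar("T")

def shuffle_in_gpu_groups(batches: Sequence[T], num_gpus: int, rng: random.Random) -> List[T]:
    # Group consecutive batches so that batches of similar size end up on
    # different GPUs within the same step, even after shuffling.
    groups = [list(batches[i:i + num_gpus]) for i in range(0, len(batches), num_gpus)]
    rng.shuffle(groups)
    return [batch for group in groups for batch in group]

# With batches sorted by padding length, each group of 2 stays balanced
# across 2 GPUs no matter how the groups are reordered.
print(shuffle_in_gpu_groups(list(range(6)), num_gpus=2, rng=random.Random(0)))
```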
2 changes: 2 additions & 0 deletions allennlp/data/iterators/data_iterator.py
@@ -125,6 +125,8 @@ def __call__(self,
tensor_dicts = self._cache[key]

if shuffle:
# TODO(brendanr): How can we handle this shuffle in a way
# that respects multi-GPU friendly grouping?
random.shuffle(tensor_dicts)
for tensor_dict in tensor_dicts:
if self._track_epoch:
8 changes: 4 additions & 4 deletions allennlp/tests/models/simple_tagger_test.py
@@ -64,8 +64,8 @@ def test_regularization(self):
training_batch = next(iterator(self.instances, num_epochs=1))
validation_batch = next(iterator(self.instances, num_epochs=1))

training_loss = trainer.batch_loss(training_batch, for_training=True).item()
validation_loss = trainer.batch_loss(validation_batch, for_training=False).item()
training_loss = trainer.batch_loss([training_batch], for_training=True).item()
validation_loss = trainer.batch_loss([validation_batch], for_training=False).item()

# Training loss should have the regularization penalty, but validation loss should not.
numpy.testing.assert_almost_equal(training_loss, validation_loss)
@@ -116,8 +116,8 @@ def test_regularization(self):
training_batch = next(self.iterator(self.instances, num_epochs=1))
validation_batch = next(self.iterator(self.instances, num_epochs=1))

training_loss = self.trainer.batch_loss(training_batch, for_training=True).data
validation_loss = self.trainer.batch_loss(validation_batch, for_training=False).data
training_loss = self.trainer.batch_loss([training_batch], for_training=True).data
validation_loss = self.trainer.batch_loss([validation_batch], for_training=False).data

# Training loss should have the regularization penalty, but validation loss should not.
assert (training_loss != validation_loss).all()
19 changes: 18 additions & 1 deletion allennlp/tests/training/trainer_test.py
@@ -20,7 +20,8 @@
from allennlp.common.params import Params
from allennlp.models.simple_tagger import SimpleTagger
from allennlp.data.iterators import BasicIterator
from allennlp.data.dataset_readers import SequenceTaggingDatasetReader
from allennlp.data.dataset_readers import SequenceTaggingDatasetReader, WikiTablesDatasetReader
from allennlp.models.archival import load_archive
from allennlp.models.model import Model


@@ -133,6 +134,22 @@ def forward(self, **kwargs) -> Dict[str, torch.Tensor]: # type: ignore # pylint
assert 'peak_gpu_1_memory_MB' in metrics
assert isinstance(metrics['peak_gpu_1_memory_MB'], int)

@pytest.mark.skipif(torch.cuda.device_count() < 2,
reason="Need multiple GPUs.")
def test_production_rule_field_with_multiple_gpus(self):
wikitables_dir = 'allennlp/tests/fixtures/data/wikitables/'
wikitables_reader = WikiTablesDatasetReader(tables_directory=wikitables_dir,
dpd_output_directory=wikitables_dir + 'dpd_output/')
instances = wikitables_reader.read(wikitables_dir + 'sample_data.examples')
archive_path = self.FIXTURES_ROOT / 'semantic_parsing' / 'wikitables' / 'serialization' / 'model.tar.gz'
model = load_archive(archive_path).model
model.cuda()

multigpu_iterator = BasicIterator(batch_size=4)
multigpu_iterator.index_with(model.vocab)
trainer = Trainer(model, self.optimizer, multigpu_iterator, instances, num_epochs=2, cuda_device=[0, 1])
trainer.train()

def test_trainer_can_resume_training(self):
trainer = Trainer(self.model, self.optimizer,
self.iterator, self.instances,
47 changes: 29 additions & 18 deletions allennlp/training/trainer.py
@@ -1,5 +1,6 @@

import logging
import math
import os
import time
import re
@@ -13,10 +14,10 @@
from allennlp.common import Params
from allennlp.common.checks import ConfigurationError
from allennlp.common.util import (dump_metrics, gpu_memory_mb, parse_cuda_device, peak_memory_mb,
get_frozen_and_tunable_parameter_names)
get_frozen_and_tunable_parameter_names, lazy_groups_of)
from allennlp.common.tqdm import Tqdm
from allennlp.data.instance import Instance
from allennlp.data.iterators.data_iterator import DataIterator
from allennlp.data.iterators.data_iterator import DataIterator, TensorDict
from allennlp.data.vocabulary import Vocabulary
from allennlp.models.model import Model
from allennlp.nn import util as nn_util
@@ -216,14 +217,16 @@ def __init__(self,
def rescale_gradients(self) -> Optional[float]:
return training_util.rescale_gradients(self.model, self._grad_norm)

def batch_loss(self, batch: torch.Tensor, for_training: bool) -> torch.Tensor:
def batch_loss(self, batch_group: List[TensorDict], for_training: bool) -> torch.Tensor:
"""
Does a forward pass on the given batch and returns the ``loss`` value in the result.
Does a forward pass on the given batches and returns the ``loss`` value in the result.
If ``for_training`` is `True` also applies regularization penalty.
"""
if self._multiple_gpu:
output_dict = training_util.data_parallel(batch, self.model, self._cuda_devices)
output_dict = training_util.data_parallel(batch_group, self.model, self._cuda_devices)
else:
assert len(batch_group) == 1
batch = batch_group[0]
batch = nn_util.move_to_device(batch, self._cuda_devices[0])
output_dict = self.model(**batch)

@@ -255,11 +258,14 @@ def _train_epoch(self, epoch: int) -> Dict[str, float]:
# Set the model to "train" mode.
self.model.train()

num_gpus = len(self._cuda_devices)

# Get tqdm for the training batches
train_generator = self.iterator(self.train_data,
num_epochs=1,
shuffle=self.shuffle)
num_training_batches = self.iterator.get_num_batches(self.train_data)
raw_train_generator = self.iterator(self.train_data,
num_epochs=1,
shuffle=self.shuffle)
train_generator = lazy_groups_of(raw_train_generator, num_gpus)
num_training_batches = math.ceil(self.iterator.get_num_batches(self.train_data)/num_gpus)
self._last_log = time.time()
last_save_time = time.time()

@@ -269,18 +275,20 @@

histogram_parameters = set(self.model.get_parameters_for_histogram_tensorboard_logging())


logger.info("Training")
train_generator_tqdm = Tqdm.tqdm(train_generator,
total=num_training_batches)
cumulative_batch_size = 0
for batch in train_generator_tqdm:
for batch_group in train_generator_tqdm:
batches_this_epoch += 1
self._batch_num_total += 1
batch_num_total = self._batch_num_total

self.optimizer.zero_grad()

loss = self.batch_loss(batch, for_training=True)
loss = self.batch_loss(batch_group, for_training=True)

if torch.isnan(loss):
raise ValueError("nan loss encountered")

@@ -329,7 +337,7 @@ def _train_epoch(self, epoch: int) -> Dict[str, float]:
self._tensorboard.log_histograms(self.model, histogram_parameters)

if self._log_batch_size_period:
cur_batch = training_util.get_batch_size(batch)
cur_batch = sum([training_util.get_batch_size(batch) for batch in batch_group])
cumulative_batch_size += cur_batch
if (batches_this_epoch - 1) % self._log_batch_size_period == 0:
average = cumulative_batch_size/batches_this_epoch
@@ -365,17 +373,20 @@ def _validation_loss(self) -> Tuple[float, int]:
else:
val_iterator = self.iterator

val_generator = val_iterator(self._validation_data,
num_epochs=1,
shuffle=False)
num_validation_batches = val_iterator.get_num_batches(self._validation_data)
num_gpus = len(self._cuda_devices)

raw_val_generator = val_iterator(self._validation_data,
num_epochs=1,
shuffle=False)
val_generator = lazy_groups_of(raw_val_generator, num_gpus)
num_validation_batches = math.ceil(val_iterator.get_num_batches(self._validation_data)/num_gpus)
val_generator_tqdm = Tqdm.tqdm(val_generator,
total=num_validation_batches)
batches_this_epoch = 0
val_loss = 0
for batch in val_generator_tqdm:
for batch_group in val_generator_tqdm:

loss = self.batch_loss(batch, for_training=False)
loss = self.batch_loss(batch_group, for_training=False)
if loss is not None:
# You shouldn't necessarily have to compute a loss for validation, so we allow for
# `loss` to be None. We need to be careful, though - `batches_this_epoch` is
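The reason no scattering is needed anymore: each `TensorDict` in a `batch_group` is moved to its own device in one piece, so non-tensor values such as `MetadataField` and `ProductionRuleField` outputs never have to be split. A rough sketch of what `training_util.data_parallel` presumably does with a batch group; the real AllenNLP helper may differ in how it moves batches and gathers losses:

```python
from typing import Any, Dict, List
import torch
from torch.nn.parallel import parallel_apply, replicate

def move_to_device(obj: Any, device: int) -> Any:
    # Recursively move tensors inside (possibly nested) dicts and lists;
    # anything else (strings, metadata objects, ...) passes through untouched.
    if isinstance(obj, torch.Tensor):
        return obj.cuda(device)
    if isinstance(obj, dict):
        return {key: move_to_device(value, device) for key, value in obj.items()}
    if isinstance(obj, list):
        return [move_to_device(item, device) for item in obj]
    return obj

def data_parallel(batch_group: List[Dict[str, Any]],
                  model: torch.nn.Module,
                  cuda_devices: List[int]) -> Dict[str, torch.Tensor]:
    used_devices = cuda_devices[:len(batch_group)]
    # One whole batch per device: nothing is sliced, so there is no scattering.
    moved = [move_to_device(batch, device) for batch, device in zip(batch_group, used_devices)]
    # One model replica per device, each called with its own batch as kwargs.
    replicas = replicate(model, used_devices)
    outputs = parallel_apply(replicas, [()] * len(moved), kwargs_tup=moved, devices=used_devices)
    # Average the per-GPU losses on the first device.
    losses = [output["loss"].cuda(used_devices[0]) for output in outputs]
    return {"loss": torch.stack(losses).mean()}
```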