diff --git a/allennlp/tests/training/trainer_test.py b/allennlp/tests/training/trainer_test.py
index 1dcb1b85c17..0a03a4b59e2 100644
--- a/allennlp/tests/training/trainer_test.py
+++ b/allennlp/tests/training/trainer_test.py
@@ -17,7 +17,8 @@
 from allennlp.common.params import Params
 from allennlp.models.simple_tagger import SimpleTagger
 from allennlp.data.iterators import BasicIterator
-from allennlp.data.dataset_readers import SequenceTaggingDatasetReader
+from allennlp.data.dataset_readers import SequenceTaggingDatasetReader, WikiTablesDatasetReader
+from allennlp.models.archival import load_archive
 from allennlp.models.model import Model
@@ -96,6 +97,11 @@ def test_trainer_can_run_cuda(self):
     @pytest.mark.skipif(torch.cuda.device_count() < 2, reason="Need multiple GPUs.")
     def test_trainer_can_run_multiple_gpu(self):
+        wikitables_dir = 'allennlp/tests/fixtures/data/wikitables/'
+        wikitables_reader = WikiTablesDatasetReader(tables_directory=wikitables_dir,
+                                                    dpd_output_directory=wikitables_dir + 'dpd_output/')
+        wikitables_instances = wikitables_reader.read(self.FIXTURES_ROOT / 'data' / 'wikitables' /
+                                                      'sample_data.examples')
         class MetaDataCheckWrapper(Model):
             """
@@ -129,6 +135,21 @@ def forward(self, **kwargs) -> Dict[str, torch.Tensor]:  # type: ignore # pylint
         assert 'peak_gpu_1_memory_MB' in metrics
         assert isinstance(metrics['peak_gpu_1_memory_MB'], float)

+    @pytest.mark.skipif(torch.cuda.device_count() < 2,
+                        reason="Need multiple GPUs.")
+    def test_production_rule_field_with_multiple_gpus(self):
+        wikitables_dir = 'allennlp/tests/fixtures/data/wikitables/'
+        wikitables_reader = WikiTablesDatasetReader(tables_directory=wikitables_dir,
+                                                    dpd_output_directory=wikitables_dir + 'dpd_output/')
+        instances = wikitables_reader.read(wikitables_dir + 'sample_data.examples')
+        archive_path = self.FIXTURES_ROOT / 'semantic_parsing' / 'wikitables' / 'serialization' / 'model.tar.gz'
+        model = load_archive(archive_path).model
+
+        multigpu_iterator = BasicIterator(batch_size=4)
+        multigpu_iterator.index_with(model.vocab)
+        trainer = Trainer(model, self.optimizer, multigpu_iterator, instances, num_epochs=2, cuda_device=[0, 1])
+        metrics = trainer.train()
+
     def test_trainer_can_resume_training(self):
         trainer = Trainer(self.model, self.optimizer, self.iterator, self.instances,