From 2fddd4487e5395b2e66a06f2ab05387094456626 Mon Sep 17 00:00:00 2001
From: Patrick von Platen
Date: Sat, 2 May 2020 11:34:27 +0200
Subject: [PATCH] include reformer in generation tests

---
 tests/test_activations.py       |  2 ++
 tests/test_modeling_common.py   | 14 +++++++++++---
 tests/test_modeling_reformer.py |  8 +++++---
 3 files changed, 18 insertions(+), 6 deletions(-)

diff --git a/tests/test_activations.py b/tests/test_activations.py
index d6cbc1f9e564..79e9eec0184c 100644
--- a/tests/test_activations.py
+++ b/tests/test_activations.py
@@ -22,6 +22,8 @@ def test_get_activation(self):
         get_activation("swish")
         get_activation("relu")
         get_activation("tanh")
+        get_activation("gelu_new")
+        get_activation("gelu_fast")
         with self.assertRaises(KeyError):
             get_activation("bogus")
         with self.assertRaises(KeyError):
diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py
index 1d44572cf55d..7156da376f10 100644
--- a/tests/test_modeling_common.py
+++ b/tests/test_modeling_common.py
@@ -127,7 +127,7 @@ def test_attention_outputs(self):
         encoder_key_length = getattr(self.model_tester, "key_length", encoder_seq_length)
         chunk_length = getattr(self.model_tester, "chunk_length", None)
         if chunk_length is not None and hasattr(self.model_tester, "num_hashes"):
-            chunk_length = self.model_tester.chunk_length * config.num_hashes
+            encoder_seq_length = encoder_seq_length * self.model_tester.num_hashes

         for model_class in self.all_model_classes:
             config.output_attentions = True
@@ -144,7 +144,7 @@
             if chunk_length is not None:
                 self.assertListEqual(
                     list(attentions[0].shape[-4:]),
-                    [self.model_tester.num_attention_heads, chunk_length, encoder_seq_length, encoder_key_length],
+                    [self.model_tester.num_attention_heads, encoder_seq_length, chunk_length, encoder_key_length],
                 )
             else:
                 self.assertListEqual(
@@ -187,7 +187,7 @@
             if chunk_length is not None:
                 self.assertListEqual(
                     list(self_attentions[0].shape[-4:]),
-                    [self.model_tester.num_attention_heads, chunk_length, encoder_seq_length, encoder_key_length],
+                    [self.model_tester.num_attention_heads, encoder_seq_length, chunk_length, encoder_key_length],
                 )
             else:
                 self.assertListEqual(
@@ -648,9 +648,13 @@ def test_lm_head_model_random_no_beam_search_generate(self):
         config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
         input_ids = inputs_dict["input_ids"] if "input_ids" in inputs_dict else inputs_dict["inputs"]

+        # max length of input_ids should be < max_length
+        input_ids = input_ids[..., :10]
+
         # iterate over all generative models
         for model_class in self.all_generative_model_classes:
             model = model_class(config).to(torch_device)
+            model.eval()

             if config.bos_token_id is None:
                 # if bos token id is not defined, model needs input_ids
@@ -689,8 +693,12 @@ def test_lm_head_model_random_beam_search_generate(self):
             torch_device
         )

+        # max length of input_ids should be < max_length
+        input_ids = input_ids[..., :10]
+
         for model_class in self.all_generative_model_classes:
             model = model_class(config).to(torch_device)
+            model.eval()

             if config.bos_token_id is None:
                 # if bos token id is not defined mobel needs input_ids, num_return_sequences = 1
diff --git a/tests/test_modeling_reformer.py b/tests/test_modeling_reformer.py
index c7ed0e03e11e..460e698526cc 100644
--- a/tests/test_modeling_reformer.py
+++ b/tests/test_modeling_reformer.py
@@ -37,6 +37,7 @@
 @require_torch
 class ReformerLocalAttnModelTest(ModelTesterMixin, unittest.TestCase):
     all_model_classes = (ReformerModel, ReformerModelWithLMHead) if is_torch_available() else ()
+    all_generative_model_classes = (ReformerModelWithLMHead,) if is_torch_available() else ()
     test_pruning = False
     test_headmasking = False
     test_torchscript = False
@@ -46,7 +47,7 @@ def __init__(
             self,
             parent,
             batch_size=13,
-            seq_length=16,
+            seq_length=32,
             is_training=True,
             is_decoder=False,
             use_input_mask=True,
@@ -68,7 +69,7 @@ def __init__(
             self,
             axial_norm_std=1.0,
             layer_norm_eps=1e-12,
             axial_pos_embds=True,
-            axial_pos_shape=[4, 4],
+            axial_pos_shape=[4, 8],
             axial_pos_embds_dim=[16, 16],
             attn_layers=["local", "local", "local", "local"],
             pad_token_id=0,
@@ -344,6 +345,7 @@ def test_reformer_model_fp16_forward(self):
 @require_torch
 class ReformerLSHAttnModelTest(ModelTesterMixin, unittest.TestCase):
     all_model_classes = (ReformerModel, ReformerModelWithLMHead) if is_torch_available() else ()
+    all_generative_model_classes = (ReformerModelWithLMHead,) if is_torch_available() else ()
     test_pruning = False
     test_headmasking = False
     test_torchscript = False
@@ -377,7 +379,7 @@ def __init__(
             self,
             axial_norm_std=1.0,
             layer_norm_eps=1e-12,
             axial_pos_embds=True,
-            axial_pos_shape=[2, 8],
+            axial_pos_shape=[4, 8],
             axial_pos_embds_dim=[16, 48],
             attn_layers=["lsh", "lsh", "lsh", "lsh"],
             pad_token_id=0,
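
As a quick illustration of the test_attention_outputs change above (not part of the patch): when the model tester defines num_hashes (the LSH case), the expected query dimension of the returned attention probabilities becomes seq_length * num_hashes, while the chunk dimension stays at chunk_length. A minimal sketch of that shape check follows; batch_size=13 and seq_length=32 are the tester values from this patch, but num_heads, num_hashes, chunk_length, and key_length are assumed for illustration, and the tensor is a dummy stand-in rather than a real Reformer output.

    import torch

    # batch_size and seq_length match the tester defaults in this patch;
    # the remaining sizes are assumed values for illustration only.
    batch_size, seq_length = 13, 32
    num_heads, num_hashes, chunk_length, key_length = 2, 4, 8, 8

    # Dummy stand-in for the attention probabilities a chunked (LSH) attention
    # layer would return; only the shape matters here.
    attention_probs = torch.zeros(
        batch_size, num_heads, seq_length * num_hashes, chunk_length, key_length
    )

    # Mirrors the updated assertListEqual: the query dimension is scaled by
    # num_hashes and precedes chunk_length in the last four dimensions.
    assert list(attention_probs.shape[-4:]) == [
        num_heads, seq_length * num_hashes, chunk_length, key_length
    ]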