include reformer in generation tests
patrickvonplaten committed May 2, 2020
1 parent a6272e5 commit 2fddd44
Showing 3 changed files with 18 additions and 6 deletions.
tests/test_activations.py (2 additions, 0 deletions)
@@ -22,6 +22,8 @@ def test_get_activation(self):
         get_activation("swish")
         get_activation("relu")
         get_activation("tanh")
+        get_activation("gelu_new")
+        get_activation("gelu_fast")
         with self.assertRaises(KeyError):
             get_activation("bogus")
         with self.assertRaises(KeyError):
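As a side note on the hunk above: `get_activation` looks up an activation function by name and returns a callable. A minimal sketch of exercising the two newly tested names, assuming `transformers.activations.get_activation` behaves as the test expects (this snippet is illustrative, not part of the commit):

```python
import torch
from transformers.activations import get_activation  # same helper the test imports

x = torch.randn(2, 4)
for name in ("gelu_new", "gelu_fast"):
    act = get_activation(name)  # returns the activation callable, or raises KeyError for unknown names
    y = act(x)                  # applies the activation element-wise
    print(name, y.shape)        # output keeps the input shape
```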
tests/test_modeling_common.py (11 additions, 3 deletions)
@@ -127,7 +127,7 @@ def test_attention_outputs(self):
         encoder_key_length = getattr(self.model_tester, "key_length", encoder_seq_length)
         chunk_length = getattr(self.model_tester, "chunk_length", None)
         if chunk_length is not None and hasattr(self.model_tester, "num_hashes"):
-            chunk_length = self.model_tester.chunk_length * config.num_hashes
+            encoder_seq_length = encoder_seq_length * self.model_tester.num_hashes

         for model_class in self.all_model_classes:
             config.output_attentions = True
@@ -144,7 +144,7 @@ def test_attention_outputs(self):
             if chunk_length is not None:
                 self.assertListEqual(
                     list(attentions[0].shape[-4:]),
-                    [self.model_tester.num_attention_heads, chunk_length, encoder_seq_length, encoder_key_length],
+                    [self.model_tester.num_attention_heads, encoder_seq_length, chunk_length, encoder_key_length],
                 )
             else:
                 self.assertListEqual(
@@ -187,7 +187,7 @@ def test_attention_outputs(self):
             if chunk_length is not None:
                 self.assertListEqual(
                     list(self_attentions[0].shape[-4:]),
-                    [self.model_tester.num_attention_heads, chunk_length, encoder_seq_length, encoder_key_length],
+                    [self.model_tester.num_attention_heads, encoder_seq_length, chunk_length, encoder_key_length],
                 )
             else:
                 self.assertListEqual(
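To make the two shape assertions above concrete: for chunked (local/LSH) attention the test now expects the last four attention dimensions in the order [num_attention_heads, encoder_seq_length, chunk_length, encoder_key_length], with encoder_seq_length already scaled by num_hashes for LSH models. A tiny stand-alone illustration with made-up sizes (not taken from the commit):

```python
import torch

# Hypothetical sizes chosen only to illustrate the assertion above.
batch_size, num_attention_heads = 2, 4
encoder_seq_length, chunk_length, encoder_key_length = 32, 8, 16
num_hashes = 2

# For LSH attention the sequence length is first scaled by the number of hashes.
encoder_seq_length = encoder_seq_length * num_hashes

# A dummy tensor standing in for attentions[0]; only the trailing dims matter here.
attentions = torch.zeros(batch_size, num_attention_heads, encoder_seq_length, chunk_length, encoder_key_length)
assert list(attentions.shape[-4:]) == [num_attention_heads, encoder_seq_length, chunk_length, encoder_key_length]
```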
@@ -648,9 +648,13 @@ def test_lm_head_model_random_no_beam_search_generate(self):
         config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
         input_ids = inputs_dict["input_ids"] if "input_ids" in inputs_dict else inputs_dict["inputs"]

+        # max length of input_ids should be < max_length
+        input_ids = input_ids[..., :10]
+
         # iterate over all generative models
         for model_class in self.all_generative_model_classes:
             model = model_class(config).to(torch_device)
+            model.eval()

             if config.bos_token_id is None:
                 # if bos token id is not defined, model needs input_ids
@@ -689,8 +693,12 @@ def test_lm_head_model_random_beam_search_generate(self):
             torch_device
         )

+        # max length of input_ids should be < max_length
+        input_ids = input_ids[..., :10]
+
         for model_class in self.all_generative_model_classes:
             model = model_class(config).to(torch_device)
+            model.eval()

             if config.bos_token_id is None:
                 # if bos token id is not defined, model needs input_ids, num_return_sequences = 1
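The two generation hunks above do two things: they trim the prompt to 10 tokens so it stays below the generation max_length, and they put the model into eval mode (Reformer's axial position embeddings only accept shorter-than-configured sequences outside of training). A rough stand-alone sketch of the same pattern; the toy ReformerConfig values below are assumptions picked for this illustration, not taken from the test:

```python
import torch
from transformers import ReformerConfig, ReformerModelWithLMHead

# Tiny, assumed config: hidden_size must equal sum(axial_pos_embds_dim), and the
# product of axial_pos_shape (4 * 8 = 32) bounds the longest generated sequence.
config = ReformerConfig(
    vocab_size=320,
    hidden_size=32,
    num_attention_heads=2,
    attn_layers=["local", "local"],
    local_attn_chunk_length=4,
    axial_pos_shape=[4, 8],
    axial_pos_embds_dim=[16, 16],
    is_decoder=True,
    pad_token_id=0,
    eos_token_id=2,
)

model = ReformerModelWithLMHead(config)
model.eval()  # eval mode, as in the updated test

input_ids = torch.randint(0, config.vocab_size, (1, 32))
input_ids = input_ids[..., :10]  # keep the prompt shorter than max_length
with torch.no_grad():
    output = model.generate(input_ids, max_length=20, do_sample=False)
print(output.shape)  # (1, <= 20)
```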
tests/test_modeling_reformer.py (5 additions, 3 deletions)
@@ -37,6 +37,7 @@
 @require_torch
 class ReformerLocalAttnModelTest(ModelTesterMixin, unittest.TestCase):
     all_model_classes = (ReformerModel, ReformerModelWithLMHead) if is_torch_available() else ()
+    all_generative_model_classes = (ReformerModelWithLMHead,) if is_torch_available() else ()
     test_pruning = False
     test_headmasking = False
     test_torchscript = False
@@ -46,7 +47,7 @@ def __init__(
             self,
             parent,
             batch_size=13,
-            seq_length=16,
+            seq_length=32,
             is_training=True,
             is_decoder=False,
             use_input_mask=True,
@@ -68,7 +69,7 @@ def __init__(
             axial_norm_std=1.0,
             layer_norm_eps=1e-12,
             axial_pos_embds=True,
-            axial_pos_shape=[4, 4],
+            axial_pos_shape=[4, 8],
             axial_pos_embds_dim=[16, 16],
             attn_layers=["local", "local", "local", "local"],
             pad_token_id=0,
@@ -344,6 +345,7 @@ def test_reformer_model_fp16_forward(self):
 @require_torch
 class ReformerLSHAttnModelTest(ModelTesterMixin, unittest.TestCase):
     all_model_classes = (ReformerModel, ReformerModelWithLMHead) if is_torch_available() else ()
+    all_generative_model_classes = (ReformerModelWithLMHead,) if is_torch_available() else ()
     test_pruning = False
     test_headmasking = False
     test_torchscript = False
@@ -377,7 +379,7 @@ def __init__(
             axial_norm_std=1.0,
             layer_norm_eps=1e-12,
             axial_pos_embds=True,
-            axial_pos_shape=[2, 8],
+            axial_pos_shape=[4, 8],
             axial_pos_embds_dim=[16, 48],
             attn_layers=["lsh", "lsh", "lsh", "lsh"],
             pad_token_id=0,
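The seq_length and axial_pos_shape updates above are coupled: with axial position embeddings enabled, the training-time sequence length must equal the product of axial_pos_shape, and the axial embedding dims must sum to the model's hidden size. A small consistency-check sketch with illustrative numbers (the hidden size here is an assumption, not read from the test):

```python
from math import prod

# Values mirroring the updated local-attention tester above; hidden_size is assumed.
seq_length = 32
axial_pos_shape = [4, 8]        # 4 * 8 == 32 == seq_length (required during training)
axial_pos_embds_dim = [16, 16]  # must sum to hidden_size
hidden_size = 32

assert prod(axial_pos_shape) == seq_length
assert sum(axial_pos_embds_dim) == hidden_size
```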
