Skip to content

Commit

Permalink
big test refactor
Browse files — browse the repository at this point in the history
  • Loading branch information
patrickvonplaten committed May 4, 2020
1 parent d84253e commit c113844
Show file tree
Hide file tree
Showing 2 changed files with 456 additions and 623 deletions.
tests/test_modeling_common.py — 9 changes: 5 additions & 4 deletions
Original file line number · Diff line number · Diff line change
Expand Up @@ -141,6 +141,7 @@ def test_attention_outputs(self):
self.assertEqual(model.config.output_attentions, True)
self.assertEqual(model.config.output_hidden_states, False)
self.assertEqual(len(attentions), self.model_tester.num_hidden_layers)

if chunk_length is not None:
self.assertListEqual(
list(attentions[0].shape[-4:]),
Expand Down Expand Up @@ -648,8 +649,8 @@ def test_lm_head_model_random_no_beam_search_generate(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
input_ids = inputs_dict["input_ids"] if "input_ids" in inputs_dict else inputs_dict["inputs"]

# max length of input_ids should be < max_length
input_ids = input_ids[..., :10]
# make sure that input_ids is at most of size 15
input_ids = input_ids[..., :15]

# iterate over all generative models
for model_class in self.all_generative_model_classes:
Expand Down Expand Up @@ -693,8 +694,8 @@ def test_lm_head_model_random_beam_search_generate(self):
torch_device
)

# max length of input_ids should be < max_length
input_ids = input_ids[..., :10]
# make sure that input_ids is at most of size 15
input_ids = input_ids[..., :15]

for model_class in self.all_generative_model_classes:
model = model_class(config).to(torch_device)
Expand Down
Loading

0 comments on commit c113844

Please sign in to comment.