From 160a4584a3e19b12dae8528f2d03b6aa44d81ce1 Mon Sep 17 00:00:00 2001
From: sanchit-gandhi
Date: Thu, 8 Jun 2023 11:42:45 +0100
Subject: [PATCH] [Whisper] Make tests faster

---
 tests/models/whisper/test_modeling_whisper.py | 161 +++++++++++++++++-
 1 file changed, 157 insertions(+), 4 deletions(-)

diff --git a/tests/models/whisper/test_modeling_whisper.py b/tests/models/whisper/test_modeling_whisper.py
index 3eee5ad4967c..2bb34aa1af7b 100644
--- a/tests/models/whisper/test_modeling_whisper.py
+++ b/tests/models/whisper/test_modeling_whisper.py
@@ -95,7 +95,7 @@ def __init__(
         self,
         parent,
         batch_size=2,
-        seq_length=1500,
+        seq_length=60,
         is_training=True,
         use_labels=False,
         vocab_size=200,
@@ -107,7 +107,7 @@ def __init__(
         hidden_dropout_prob=0.1,
         attention_probs_dropout_prob=0.1,
         max_position_embeddings=20,
-        max_source_positions=750,
+        max_source_positions=30,
         max_target_positions=40,
         bos_token_id=98,
         eos_token_id=98,
@@ -1538,7 +1538,7 @@ def __init__(
         self,
         parent,
         batch_size=2,
-        seq_length=3000,
+        seq_length=60,
         is_training=True,
         use_labels=True,
         hidden_size=16,
@@ -1549,7 +1549,7 @@ def __init__(
         hidden_dropout_prob=0.1,
         attention_probs_dropout_prob=0.1,
         max_position_embeddings=20,
-        max_source_positions=1500,
+        max_source_positions=30,
         num_mel_bins=80,
         num_conv_layers=1,
         suppress_tokens=None,
@@ -1731,3 +1731,156 @@ def test_model_common_attributes(self):
     # WhisperEncoder cannot resize token embeddings since it has no tokens embeddings
     def test_resize_tokens_embeddings(self):
         pass
+
+    @is_pt_flax_cross_test
+    def test_equivalence_pt_to_flax(self):
+        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+        init_shape = (1,) + inputs_dict["input_features"].shape[1:]
+
+        for model_class in self.all_model_classes:
+            with self.subTest(model_class.__name__):
+                fx_model_class_name = "Flax" + model_class.__name__
+
+                if not hasattr(transformers, fx_model_class_name):
+                    # no flax model exists for this class
+                    return
+
+                # Output all for aggressive testing
+                config.output_hidden_states = True
+                config.output_attentions = self.has_attentions
+
+                fx_model_class = getattr(transformers, fx_model_class_name)
+
+                # load PyTorch class
+                pt_model = model_class(config).eval()
+                # Flax models don't use the `use_cache` option and cache is not returned as a default.
+                # So we disable `use_cache` here for PyTorch model.
+                pt_model.config.use_cache = False
+
+                # load Flax class
+                fx_model = fx_model_class(config, input_shape=init_shape, dtype=jnp.float32)
+
+                # make sure only flax inputs are forward that actually exist in function args
+                fx_input_keys = inspect.signature(fx_model.__call__).parameters.keys()
+
+                # prepare inputs
+                pt_inputs = self._prepare_for_class(inputs_dict, model_class)
+
+                # remove function args that don't exist in Flax
+                pt_inputs = {k: v for k, v in pt_inputs.items() if k in fx_input_keys}
+
+                # send pytorch inputs to the correct device
+                pt_inputs = {
+                    k: v.to(device=torch_device) if isinstance(v, torch.Tensor) else v for k, v in pt_inputs.items()
+                }
+
+                # convert inputs to Flax
+                fx_inputs = {k: np.array(v) for k, v in pt_inputs.items() if torch.is_tensor(v)}
+
+                fx_state = convert_pytorch_state_dict_to_flax(pt_model.state_dict(), fx_model)
+                fx_model.params = fx_state
+
+                # send pytorch model to the correct device
+                pt_model.to(torch_device)
+
+                with torch.no_grad():
+                    pt_outputs = pt_model(**pt_inputs)
+                fx_outputs = fx_model(**fx_inputs)
+
+                fx_keys = tuple([k for k, v in fx_outputs.items() if v is not None])
+                pt_keys = tuple([k for k, v in pt_outputs.items() if v is not None])
+
+                self.assertEqual(fx_keys, pt_keys)
+                self.check_pt_flax_outputs(fx_outputs, pt_outputs, model_class)
+
+                with tempfile.TemporaryDirectory() as tmpdirname:
+                    pt_model.save_pretrained(tmpdirname)
+                    fx_model_loaded = fx_model_class.from_pretrained(tmpdirname, input_shape=init_shape, from_pt=True)
+
+                fx_outputs_loaded = fx_model_loaded(**fx_inputs)
+
+                fx_keys = tuple([k for k, v in fx_outputs_loaded.items() if v is not None])
+                pt_keys = tuple([k for k, v in pt_outputs.items() if v is not None])
+
+                self.assertEqual(fx_keys, pt_keys)
+                self.check_pt_flax_outputs(fx_outputs_loaded, pt_outputs, model_class)
+
+    @is_pt_flax_cross_test
+    def test_equivalence_flax_to_pt(self):
+        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+        init_shape = (1,) + inputs_dict["input_features"].shape[1:]
+
+        for model_class in self.all_model_classes:
+            with self.subTest(model_class.__name__):
+                fx_model_class_name = "Flax" + model_class.__name__
+
+                if not hasattr(transformers, fx_model_class_name):
+                    # no flax model exists for this class
+                    return
+
+                # Output all for aggressive testing
+                config.output_hidden_states = True
+                config.output_attentions = self.has_attentions
+
+                fx_model_class = getattr(transformers, fx_model_class_name)
+
+                # load PyTorch class
+                pt_model = model_class(config).eval()
+                # Flax models don't use the `use_cache` option and cache is not returned as a default.
+                # So we disable `use_cache` here for PyTorch model.
+                pt_model.config.use_cache = False
+
+                # load Flax class
+                fx_model = fx_model_class(config, input_shape=init_shape, dtype=jnp.float32)
+
+                # make sure only flax inputs are forward that actually exist in function args
+                fx_input_keys = inspect.signature(fx_model.__call__).parameters.keys()
+
+                # prepare inputs
+                pt_inputs = self._prepare_for_class(inputs_dict, model_class)
+
+                # remove function args that don't exist in Flax
+                pt_inputs = {k: v for k, v in pt_inputs.items() if k in fx_input_keys}
+
+                # send pytorch inputs to the correct device
+                pt_inputs = {
+                    k: v.to(device=torch_device) if isinstance(v, torch.Tensor) else v for k, v in pt_inputs.items()
+                }
+
+                # convert inputs to Flax
+                fx_inputs = {k: np.array(v) for k, v in pt_inputs.items() if torch.is_tensor(v)}
+
+                pt_model = load_flax_weights_in_pytorch_model(pt_model, fx_model.params)
+
+                # make sure weights are tied in PyTorch
+                pt_model.tie_weights()
+
+                # send pytorch model to the correct device
+                pt_model.to(torch_device)
+
+                with torch.no_grad():
+                    pt_outputs = pt_model(**pt_inputs)
+                fx_outputs = fx_model(**fx_inputs)
+
+                fx_keys = tuple([k for k, v in fx_outputs.items() if v is not None])
+                pt_keys = tuple([k for k, v in pt_outputs.items() if v is not None])
+
+                self.assertEqual(fx_keys, pt_keys)
+                self.check_pt_flax_outputs(fx_outputs, pt_outputs, model_class)
+
+                with tempfile.TemporaryDirectory() as tmpdirname:
+                    fx_model.save_pretrained(tmpdirname)
+                    pt_model_loaded = model_class.from_pretrained(tmpdirname, from_flax=True)
+
+                # send pytorch model to the correct device
+                pt_model_loaded.to(torch_device)
+                pt_model_loaded.eval()
+
+                with torch.no_grad():
+                    pt_outputs_loaded = pt_model_loaded(**pt_inputs)
+
+                fx_keys = tuple([k for k, v in fx_outputs.items() if v is not None])
+                pt_keys = tuple([k for k, v in pt_outputs_loaded.items() if v is not None])
+
+                self.assertEqual(fx_keys, pt_keys)
+                self.check_pt_flax_outputs(fx_outputs, pt_outputs_loaded, model_class)
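
Note on the paired value changes above (a standalone sketch, not part of the patch): `seq_length` and `max_source_positions` are reduced together because the Whisper encoder subsamples the mel-feature sequence with a stride-2, kernel-3, padding-1 convolution, so the encoder sees roughly half as many positions as there are input frames. The helper name below is hypothetical and the formula assumes that single stride-2 convolution, matching the `num_conv_layers=1` configured in the tester.

def subsampled_length(input_length: int, num_conv_layers: int = 1) -> int:
    # Output length after `num_conv_layers` stride-2 (kernel 3, padding 1) convolutions.
    for _ in range(num_conv_layers):
        input_length = (input_length - 1) // 2 + 1
    return input_length

# The old tester values map 1500 -> 750 and 3000 -> 1500; the new values map 60 -> 30,
# which is why both parameters change in lockstep in the hunks above.
for seq_length, max_source_positions in [(1500, 750), (3000, 1500), (60, 30)]:
    assert subsampled_length(seq_length) == max_source_positions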