[Whisper] Make tests faster #24105

Merged · 1 commit · Jun 20, 2023
161 changes: 157 additions & 4 deletions tests/models/whisper/test_modeling_whisper.py
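The parameter changes in the hunks below shrink the dummy inputs the testers build, while keeping seq_length and max_source_positions in the same ratio as before. A minimal sketch of that relationship, assuming the usual factor-of-2 convolutional downsampling in the Whisper encoder (the helper function is illustrative only, not part of the diff):

```python
# Illustrative only: the Whisper encoder's strided conv halves the mel-feature length,
# so max_source_positions needs to cover seq_length // 2 positions.
def downsampled_positions(seq_length: int, stride: int = 2) -> int:
    return seq_length // stride

# Values taken from the diff below:
assert downsampled_positions(1500) == 750   # old WhisperModelTester defaults
assert downsampled_positions(60) == 30      # new, much smaller defaults
assert downsampled_positions(3000) == 1500  # old encoder-tester defaults (also now 60 -> 30)
```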
@@ -95,7 +95,7 @@ def __init__(
        self,
        parent,
        batch_size=2,
-       seq_length=1500,
+       seq_length=60,
        is_training=True,
        use_labels=False,
        vocab_size=200,
@@ -107,7 +107,7 @@ def __init__(
        hidden_dropout_prob=0.1,
        attention_probs_dropout_prob=0.1,
        max_position_embeddings=20,
-       max_source_positions=750,
+       max_source_positions=30,
        max_target_positions=40,
        bos_token_id=98,
        eos_token_id=98,
@@ -1538,7 +1538,7 @@ def __init__(
        self,
        parent,
        batch_size=2,
-       seq_length=3000,
+       seq_length=60,
        is_training=True,
        use_labels=True,
        hidden_size=16,
@@ -1549,7 +1549,7 @@ def __init__(
        hidden_dropout_prob=0.1,
        attention_probs_dropout_prob=0.1,
        max_position_embeddings=20,
-       max_source_positions=1500,
+       max_source_positions=30,
        num_mel_bins=80,
        num_conv_layers=1,
        suppress_tokens=None,
@@ -1731,3 +1731,156 @@ def test_model_common_attributes(self):
    # WhisperEncoder cannot resize token embeddings since it has no token embeddings
    def test_resize_tokens_embeddings(self):
        pass

    @is_pt_flax_cross_test
    def test_equivalence_pt_to_flax(self):
        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
        init_shape = (1,) + inputs_dict["input_features"].shape[1:]

        for model_class in self.all_model_classes:
            with self.subTest(model_class.__name__):
                fx_model_class_name = "Flax" + model_class.__name__

                if not hasattr(transformers, fx_model_class_name):
                    # no flax model exists for this class
                    return

                # Output all for aggressive testing
                config.output_hidden_states = True
                config.output_attentions = self.has_attentions

                fx_model_class = getattr(transformers, fx_model_class_name)

                # load PyTorch class
                pt_model = model_class(config).eval()
                # Flax models don't use the `use_cache` option and cache is not returned as a default.
                # So we disable `use_cache` here for PyTorch model.
                pt_model.config.use_cache = False

                # load Flax class
                fx_model = fx_model_class(config, input_shape=init_shape, dtype=jnp.float32)
Review comment from the PR author (Contributor):
We have to override this method to ensure that we init the Flax weights with the downsampled sequence length correctly (e.g. pass input_shape=init_shape).

Review comment from a Collaborator:
I am a bit confused here. If you look at FlaxWhisperModelTest, there is no such override to pass input_shape; however, FlaxWhisperModelTester uses the same low numbers as this PR.

Why don't we need to pass init_shape in FlaxWhisperModelTest?


                # make sure only flax inputs are forwarded that actually exist in function args
                fx_input_keys = inspect.signature(fx_model.__call__).parameters.keys()

                # prepare inputs
                pt_inputs = self._prepare_for_class(inputs_dict, model_class)

                # remove function args that don't exist in Flax
                pt_inputs = {k: v for k, v in pt_inputs.items() if k in fx_input_keys}

                # send pytorch inputs to the correct device
                pt_inputs = {
                    k: v.to(device=torch_device) if isinstance(v, torch.Tensor) else v for k, v in pt_inputs.items()
                }

                # convert inputs to Flax
                fx_inputs = {k: np.array(v) for k, v in pt_inputs.items() if torch.is_tensor(v)}

                fx_state = convert_pytorch_state_dict_to_flax(pt_model.state_dict(), fx_model)
                fx_model.params = fx_state

                # send pytorch model to the correct device
                pt_model.to(torch_device)

                with torch.no_grad():
                    pt_outputs = pt_model(**pt_inputs)
                fx_outputs = fx_model(**fx_inputs)

                fx_keys = tuple([k for k, v in fx_outputs.items() if v is not None])
                pt_keys = tuple([k for k, v in pt_outputs.items() if v is not None])

                self.assertEqual(fx_keys, pt_keys)
                self.check_pt_flax_outputs(fx_outputs, pt_outputs, model_class)

                with tempfile.TemporaryDirectory() as tmpdirname:
                    pt_model.save_pretrained(tmpdirname)
                    fx_model_loaded = fx_model_class.from_pretrained(tmpdirname, input_shape=init_shape, from_pt=True)

                    fx_outputs_loaded = fx_model_loaded(**fx_inputs)

                    fx_keys = tuple([k for k, v in fx_outputs_loaded.items() if v is not None])
                    pt_keys = tuple([k for k, v in pt_outputs.items() if v is not None])

                    self.assertEqual(fx_keys, pt_keys)
                    self.check_pt_flax_outputs(fx_outputs_loaded, pt_outputs, model_class)
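As context for the review exchange above about passing input_shape=init_shape: a minimal sketch of the shape bookkeeping, using the tester defaults changed in this PR. The commented-out model call is illustrative only, not part of the diff.

```python
import numpy as np

# Dummy mel features with the new, smaller tester defaults from this PR:
# (batch_size, num_mel_bins, seq_length) = (2, 80, 60) instead of length-3000 features.
dummy_features = np.zeros((2, 80, 60), dtype=np.float32)

# Same computation as in the overridden tests: drop the batch dimension and re-add a
# singleton batch, so the Flax weights are initialised for the shortened feature length
# the dummy inputs actually have, rather than whatever default shape the Flax class
# would otherwise pick.
init_shape = (1,) + dummy_features.shape[1:]
assert init_shape == (1, 80, 60)

# fx_model = FlaxWhisperModel(config, input_shape=init_shape, dtype=jnp.float32)  # as in the tests
```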

    @is_pt_flax_cross_test
    def test_equivalence_flax_to_pt(self):
        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
        init_shape = (1,) + inputs_dict["input_features"].shape[1:]

        for model_class in self.all_model_classes:
            with self.subTest(model_class.__name__):
                fx_model_class_name = "Flax" + model_class.__name__

                if not hasattr(transformers, fx_model_class_name):
                    # no flax model exists for this class
                    return

                # Output all for aggressive testing
                config.output_hidden_states = True
                config.output_attentions = self.has_attentions

                fx_model_class = getattr(transformers, fx_model_class_name)

                # load PyTorch class
                pt_model = model_class(config).eval()
                # Flax models don't use the `use_cache` option and cache is not returned as a default.
                # So we disable `use_cache` here for PyTorch model.
                pt_model.config.use_cache = False

                # load Flax class
                fx_model = fx_model_class(config, input_shape=init_shape, dtype=jnp.float32)

                # make sure only flax inputs are forwarded that actually exist in function args
                fx_input_keys = inspect.signature(fx_model.__call__).parameters.keys()

                # prepare inputs
                pt_inputs = self._prepare_for_class(inputs_dict, model_class)

                # remove function args that don't exist in Flax
                pt_inputs = {k: v for k, v in pt_inputs.items() if k in fx_input_keys}

                # send pytorch inputs to the correct device
                pt_inputs = {
                    k: v.to(device=torch_device) if isinstance(v, torch.Tensor) else v for k, v in pt_inputs.items()
                }

                # convert inputs to Flax
                fx_inputs = {k: np.array(v) for k, v in pt_inputs.items() if torch.is_tensor(v)}

                pt_model = load_flax_weights_in_pytorch_model(pt_model, fx_model.params)

                # make sure weights are tied in PyTorch
                pt_model.tie_weights()

                # send pytorch model to the correct device
                pt_model.to(torch_device)

                with torch.no_grad():
                    pt_outputs = pt_model(**pt_inputs)
                fx_outputs = fx_model(**fx_inputs)

                fx_keys = tuple([k for k, v in fx_outputs.items() if v is not None])
                pt_keys = tuple([k for k, v in pt_outputs.items() if v is not None])

                self.assertEqual(fx_keys, pt_keys)
                self.check_pt_flax_outputs(fx_outputs, pt_outputs, model_class)

                with tempfile.TemporaryDirectory() as tmpdirname:
                    fx_model.save_pretrained(tmpdirname)
                    pt_model_loaded = model_class.from_pretrained(tmpdirname, from_flax=True)

                    # send pytorch model to the correct device
                    pt_model_loaded.to(torch_device)
                    pt_model_loaded.eval()

                    with torch.no_grad():
                        pt_outputs_loaded = pt_model_loaded(**pt_inputs)

                    fx_keys = tuple([k for k, v in fx_outputs.items() if v is not None])
                    pt_keys = tuple([k for k, v in pt_outputs_loaded.items() if v is not None])

                    self.assertEqual(fx_keys, pt_keys)
                    self.check_pt_flax_outputs(fx_outputs, pt_outputs_loaded, model_class)