[Whisper] Make tests faster (#24105)
sanchit-gandhi authored Jun 20, 2023
1 parent f924df3 · commit 6c13444
Showing 1 changed file with 157 additions and 4 deletions.
161 changes: 157 additions & 4 deletions tests/models/whisper/test_modeling_whisper.py
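The speed-up comes almost entirely from shrinking the dummy inputs the model testers generate: the encoder sequence length drops from 1500 (3000 in the encoder-only tester) to 60 frames, with max_source_positions reduced to match. As a rough illustration of why this matters, here is a minimal sketch; the (batch, num_mel_bins, seq_length) layout follows what Whisper's feature extractor produces, and the numbers are illustrative, not from the commit:

import torch

batch_size, num_mel_bins = 2, 80
for seq_length in (3000, 60):  # old vs. new tester setting
    input_features = torch.rand(batch_size, num_mel_bins, seq_length)
    # 3000 frames -> 480,000 elements per batch; 60 frames -> 9,600
    print(seq_length, input_features.numel())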
@@ -95,7 +95,7 @@ def __init__(
self,
parent,
batch_size=2,
- seq_length=1500,
+ seq_length=60,
is_training=True,
use_labels=False,
vocab_size=200,
@@ -107,7 +107,7 @@ def __init__(
hidden_dropout_prob=0.1,
attention_probs_dropout_prob=0.1,
max_position_embeddings=20,
- max_source_positions=750,
+ max_source_positions=30,
max_target_positions=40,
bos_token_id=98,
eos_token_id=98,
@@ -1538,7 +1538,7 @@ def __init__(
self,
parent,
batch_size=2,
- seq_length=3000,
+ seq_length=60,
is_training=True,
use_labels=True,
hidden_size=16,
@@ -1549,7 +1549,7 @@ def __init__(
hidden_dropout_prob=0.1,
attention_probs_dropout_prob=0.1,
max_position_embeddings=20,
- max_source_positions=1500,
+ max_source_positions=30,
num_mel_bins=80,
num_conv_layers=1,
suppress_tokens=None,
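Note that seq_length and max_source_positions move together: Whisper's convolutional frontend downsamples the mel features before they reach the encoder, so the number of encoder positions must cover the subsampled length. A minimal sketch of that relationship, assuming the stride-2, kernel-3, padding-1 convolution Whisper uses (the helper name is illustrative):

def subsampled_length(seq_length: int, num_conv_layers: int = 1, stride: int = 2) -> int:
    # Each stride-2 conv (kernel 3, padding 1) roughly halves the frame count.
    for _ in range(num_conv_layers):
        seq_length = (seq_length - 1) // stride + 1
    return seq_length

assert subsampled_length(60) == 30      # new setting fits max_source_positions=30
assert subsampled_length(3000) == 1500  # old setting needed max_source_positions=1500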
@@ -1731,3 +1731,156 @@ def test_model_common_attributes(self):
# WhisperEncoder cannot resize token embeddings since it has no token embeddings
def test_resize_tokens_embeddings(self):
pass
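From here the hunk is purely additive: two PT↔Flax cross-framework equivalence tests, overridden from the common test mixin so the Flax model can be instantiated with an input_shape matching the reduced dummy features (the common versions presumably rely on Flax Whisper's full-length default), as the methods below show.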

@is_pt_flax_cross_test
def test_equivalence_pt_to_flax(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
init_shape = (1,) + inputs_dict["input_features"].shape[1:]

for model_class in self.all_model_classes:
with self.subTest(model_class.__name__):
fx_model_class_name = "Flax" + model_class.__name__

if not hasattr(transformers, fx_model_class_name):
# no flax model exists for this class
return

# Output all for aggressive testing
config.output_hidden_states = True
config.output_attentions = self.has_attentions

fx_model_class = getattr(transformers, fx_model_class_name)

# load PyTorch class
pt_model = model_class(config).eval()
# Flax models don't use the `use_cache` option, and the cache is not returned by default.
# So we disable `use_cache` for the PyTorch model here.
pt_model.config.use_cache = False

# load Flax class
fx_model = fx_model_class(config, input_shape=init_shape, dtype=jnp.float32)

# make sure we only forward flax inputs that actually exist in the function args
fx_input_keys = inspect.signature(fx_model.__call__).parameters.keys()

# prepare inputs
pt_inputs = self._prepare_for_class(inputs_dict, model_class)

# remove function args that don't exist in Flax
pt_inputs = {k: v for k, v in pt_inputs.items() if k in fx_input_keys}

# send pytorch inputs to the correct device
pt_inputs = {
k: v.to(device=torch_device) if isinstance(v, torch.Tensor) else v for k, v in pt_inputs.items()
}

# convert inputs to Flax
fx_inputs = {k: np.array(v) for k, v in pt_inputs.items() if torch.is_tensor(v)}

fx_state = convert_pytorch_state_dict_to_flax(pt_model.state_dict(), fx_model)
fx_model.params = fx_state

# send pytorch model to the correct device
pt_model.to(torch_device)

with torch.no_grad():
pt_outputs = pt_model(**pt_inputs)
fx_outputs = fx_model(**fx_inputs)

fx_keys = tuple([k for k, v in fx_outputs.items() if v is not None])
pt_keys = tuple([k for k, v in pt_outputs.items() if v is not None])

self.assertEqual(fx_keys, pt_keys)
self.check_pt_flax_outputs(fx_outputs, pt_outputs, model_class)

with tempfile.TemporaryDirectory() as tmpdirname:
pt_model.save_pretrained(tmpdirname)
fx_model_loaded = fx_model_class.from_pretrained(tmpdirname, input_shape=init_shape, from_pt=True)

fx_outputs_loaded = fx_model_loaded(**fx_inputs)

fx_keys = tuple([k for k, v in fx_outputs_loaded.items() if v is not None])
pt_keys = tuple([k for k, v in pt_outputs.items() if v is not None])

self.assertEqual(fx_keys, pt_keys)
self.check_pt_flax_outputs(fx_outputs_loaded, pt_outputs, model_class)
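Both tests delegate the numerical comparison to check_pt_flax_outputs from the common test mixin. Its core check reduces to something like the following simplified sketch (not the actual helper, which also recurses through nested outputs and builds descriptive error messages):

import numpy as np

def assert_close(fx_tensor, pt_tensor, tol=1e-5):
    # Flax returns jax arrays; PyTorch tensors may live on an accelerator.
    fx = np.asarray(fx_tensor)
    pt = pt_tensor.detach().cpu().numpy()
    assert fx.shape == pt.shape, f"shape mismatch: {fx.shape} vs {pt.shape}"
    np.testing.assert_allclose(fx, pt, atol=tol)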

@is_pt_flax_cross_test
def test_equivalence_flax_to_pt(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
init_shape = (1,) + inputs_dict["input_features"].shape[1:]

for model_class in self.all_model_classes:
with self.subTest(model_class.__name__):
fx_model_class_name = "Flax" + model_class.__name__

if not hasattr(transformers, fx_model_class_name):
# no flax model exists for this class
return

# Output all for aggressive testing
config.output_hidden_states = True
config.output_attentions = self.has_attentions

fx_model_class = getattr(transformers, fx_model_class_name)

# load PyTorch class
pt_model = model_class(config).eval()
# Flax models don't use the `use_cache` option, and the cache is not returned by default.
# So we disable `use_cache` for the PyTorch model here.
pt_model.config.use_cache = False

# load Flax class
fx_model = fx_model_class(config, input_shape=init_shape, dtype=jnp.float32)

# make sure we only forward flax inputs that actually exist in the function args
fx_input_keys = inspect.signature(fx_model.__call__).parameters.keys()

# prepare inputs
pt_inputs = self._prepare_for_class(inputs_dict, model_class)

# remove function args that don't exist in Flax
pt_inputs = {k: v for k, v in pt_inputs.items() if k in fx_input_keys}

# send pytorch inputs to the correct device
pt_inputs = {
k: v.to(device=torch_device) if isinstance(v, torch.Tensor) else v for k, v in pt_inputs.items()
}

# convert inputs to Flax
fx_inputs = {k: np.array(v) for k, v in pt_inputs.items() if torch.is_tensor(v)}

pt_model = load_flax_weights_in_pytorch_model(pt_model, fx_model.params)

# make sure weights are tied in PyTorch
pt_model.tie_weights()

# send pytorch model to the correct device
pt_model.to(torch_device)

with torch.no_grad():
pt_outputs = pt_model(**pt_inputs)
fx_outputs = fx_model(**fx_inputs)

fx_keys = tuple([k for k, v in fx_outputs.items() if v is not None])
pt_keys = tuple([k for k, v in pt_outputs.items() if v is not None])

self.assertEqual(fx_keys, pt_keys)
self.check_pt_flax_outputs(fx_outputs, pt_outputs, model_class)

with tempfile.TemporaryDirectory() as tmpdirname:
fx_model.save_pretrained(tmpdirname)
pt_model_loaded = model_class.from_pretrained(tmpdirname, from_flax=True)

# send pytorch model to the correct device
pt_model_loaded.to(torch_device)
pt_model_loaded.eval()

with torch.no_grad():
pt_outputs_loaded = pt_model_loaded(**pt_inputs)

fx_keys = tuple([k for k, v in fx_outputs.items() if v is not None])
pt_keys = tuple([k for k, v in pt_outputs_loaded.items() if v is not None])

self.assertEqual(fx_keys, pt_keys)
self.check_pt_flax_outputs(fx_outputs, pt_outputs_loaded, model_class)
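Both methods carry the is_pt_flax_cross_test decorator, so they are skipped unless cross-framework tests are explicitly enabled. The gate is roughly equivalent to this sketch (assuming the RUN_PT_FLAX_CROSS_TESTS environment variable the test suite uses for these decorators):

import os
import unittest

def is_pt_flax_cross_test(test_case):
    # Skip PT/Flax cross tests unless explicitly enabled via the environment.
    if os.environ.get("RUN_PT_FLAX_CROSS_TESTS", "false").lower() not in ("1", "true", "yes"):
        return unittest.skip("test requires RUN_PT_FLAX_CROSS_TESTS=1")(test_case)
    return test_case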
