diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py index 7156da376f10..a9a2b7c35ddb 100644 --- a/tests/test_modeling_common.py +++ b/tests/test_modeling_common.py @@ -141,6 +141,7 @@ def test_attention_outputs(self): self.assertEqual(model.config.output_attentions, True) self.assertEqual(model.config.output_hidden_states, False) self.assertEqual(len(attentions), self.model_tester.num_hidden_layers) + if chunk_length is not None: self.assertListEqual( list(attentions[0].shape[-4:]), @@ -648,8 +649,8 @@ def test_lm_head_model_random_no_beam_search_generate(self): config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() input_ids = inputs_dict["input_ids"] if "input_ids" in inputs_dict else inputs_dict["inputs"] - # max length of input_ids should be < max_length - input_ids = input_ids[..., :10] + # make sure that input_ids is at most of size 15 + input_ids = input_ids[..., :15] # iterate over all generative models for model_class in self.all_generative_model_classes: @@ -693,8 +694,8 @@ def test_lm_head_model_random_beam_search_generate(self): torch_device ) - # max length of input_ids should be < max_length - input_ids = input_ids[..., :10] + # make sure that input_ids is at most of size 15 + input_ids = input_ids[..., :15] for model_class in self.all_generative_model_classes: model = model_class(config).to(torch_device) diff --git a/tests/test_modeling_reformer.py b/tests/test_modeling_reformer.py index 460e698526cc..c79b212a8c59 100644 --- a/tests/test_modeling_reformer.py +++ b/tests/test_modeling_reformer.py @@ -34,646 +34,371 @@ import torch -@require_torch -class ReformerLocalAttnModelTest(ModelTesterMixin, unittest.TestCase): - all_model_classes = (ReformerModel, ReformerModelWithLMHead) if is_torch_available() else () - all_generative_model_classes = (ReformerModelWithLMHead,) if is_torch_available() else () - test_pruning = False - test_headmasking = False - test_torchscript = False +class ReformerModelTester: + def __init__( + self, + parent, + batch_size=None, + seq_length=None, + is_training=None, + is_decoder=None, + use_input_mask=None, + vocab_size=None, + attention_head_size=None, + hidden_size=None, + num_attention_heads=None, + local_attn_chunk_length=None, + local_num_chunks_before=None, + local_num_chunks_after=None, + num_buckets=None, + num_hashes=1, + lsh_attn_chunk_length=None, + lsh_num_chunks_before=None, + lsh_num_chunks_after=None, + chunk_size_lm_head=None, + chunk_size_feed_forward=None, + feed_forward_size=None, + hidden_act=None, + hidden_dropout_prob=None, + local_attention_probs_dropout_prob=None, + lsh_attention_probs_dropout_prob=None, + max_position_embeddings=None, + initializer_range=None, + axial_norm_std=None, + layer_norm_eps=None, + axial_pos_embds=None, + axial_pos_shape=None, + axial_pos_embds_dim=None, + attn_layers=None, + pad_token_id=None, + eos_token_id=None, + scope=None, + hash_seed=None, + ): + self.parent = parent + self.batch_size = batch_size + self.seq_length = seq_length + self.is_training = is_training + self.is_decoder = is_decoder + self.use_input_mask = use_input_mask + self.vocab_size = vocab_size + self.attention_head_size = attention_head_size + self.hidden_size = hidden_size + self.num_attention_heads = num_attention_heads + self.num_hidden_layers = len(attn_layers) + self.local_attn_chunk_length = local_attn_chunk_length + self.local_num_chunks_after = local_num_chunks_after + self.local_num_chunks_before = local_num_chunks_before + self.num_hashes = num_hashes + self.num_buckets 
= tuple(num_buckets) if isinstance(num_buckets, list) else num_buckets + self.lsh_attn_chunk_length = lsh_attn_chunk_length + self.lsh_num_chunks_after = lsh_num_chunks_after + self.lsh_num_chunks_before = lsh_num_chunks_before + self.hidden_act = hidden_act + self.feed_forward_size = feed_forward_size + self.hidden_dropout_prob = hidden_dropout_prob + self.local_attention_probs_dropout_prob = local_attention_probs_dropout_prob + self.lsh_attention_probs_dropout_prob = lsh_attention_probs_dropout_prob + self.max_position_embeddings = max_position_embeddings + self.initializer_range = initializer_range + self.layer_norm_eps = layer_norm_eps + self.axial_pos_embds = axial_pos_embds + self.axial_pos_shape = tuple(axial_pos_shape) + self.axial_pos_embds_dim = tuple(axial_pos_embds_dim) + self.axial_norm_std = axial_norm_std + self.chunk_size_lm_head = chunk_size_lm_head + self.chunk_size_feed_forward = chunk_size_feed_forward + self.scope = scope + self.attn_layers = attn_layers + self.pad_token_id = pad_token_id + self.hash_seed = hash_seed + + attn_chunk_length = local_attn_chunk_length if local_attn_chunk_length is not None else lsh_attn_chunk_length + num_chunks_after = local_num_chunks_after if local_num_chunks_after is not None else lsh_num_chunks_after + num_chunks_before = local_num_chunks_before if local_num_chunks_before is not None else lsh_num_chunks_before + + self.encoder_seq_length = seq_length // attn_chunk_length + (self.seq_length % attn_chunk_length != 0) + self.key_length = (num_chunks_before + num_chunks_after + 1) * attn_chunk_length + self.chunk_length = attn_chunk_length + + def prepare_config_and_inputs(self): + input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size) + + input_mask = None + if self.use_input_mask: + input_mask = ids_tensor([self.batch_size, self.seq_length], vocab_size=2) + + config = ReformerConfig( + vocab_size=self.vocab_size, + hidden_size=self.hidden_size, + num_hidden_layers=self.num_hidden_layers, + num_attention_heads=self.num_attention_heads, + feed_forward_size=self.feed_forward_size, + hidden_act=self.hidden_act, + hidden_dropout_prob=self.hidden_dropout_prob, + local_attention_probs_dropout_prob=self.local_attention_probs_dropout_prob, + lsh_attention_probs_dropout_prob=self.lsh_attention_probs_dropout_prob, + max_position_embeddings=self.max_position_embeddings, + is_decoder=self.is_decoder, + axial_pos_embds=self.axial_pos_embds, + axial_pos_shape=self.axial_pos_shape, + axial_pos_embds_dim=self.axial_pos_embds_dim, + local_attn_chunk_length=self.local_attn_chunk_length, + local_num_chunks_after=self.local_num_chunks_after, + local_num_chunks_before=self.local_num_chunks_before, + num_hashes=self.num_hashes, + num_buckets=self.num_buckets, + lsh_attn_chunk_length=self.lsh_attn_chunk_length, + lsh_num_chunks_after=self.lsh_num_chunks_after, + lsh_num_chunks_before=self.lsh_num_chunks_before, + attn_layers=self.attn_layers, + pad_token_id=self.pad_token_id, + hash_seed=self.hash_seed, + ) - class ReformerLocalAttnModelTester(object): - def __init__( - self, - parent, - batch_size=13, - seq_length=32, - is_training=True, - is_decoder=False, - use_input_mask=True, - vocab_size=32, - attention_head_size=16, - hidden_size=32, - num_attention_heads=2, - local_attn_chunk_length=4, - local_num_chunks_before=1, - local_num_chunks_after=0, - chunk_size_lm_head=0, - chunk_size_feed_forward=0, - feed_forward_size=32, - hidden_act="gelu", - hidden_dropout_prob=0.1, - local_attention_probs_dropout_prob=0.1, - 
max_position_embeddings=512, - initializer_range=0.02, - axial_norm_std=1.0, - layer_norm_eps=1e-12, - axial_pos_embds=True, - axial_pos_shape=[4, 8], - axial_pos_embds_dim=[16, 16], - attn_layers=["local", "local", "local", "local"], - pad_token_id=0, - eos_token_id=2, - scope=None, - hash_seed=0, - ): - self.parent = parent - self.batch_size = batch_size - self.seq_length = seq_length - self.is_training = is_training - self.is_decoder = is_decoder - self.use_input_mask = use_input_mask - self.vocab_size = vocab_size - self.attention_head_size = attention_head_size - self.hidden_size = hidden_size - self.num_attention_heads = num_attention_heads - self.num_hidden_layers = len(attn_layers) - self.local_attn_chunk_length = local_attn_chunk_length - self.local_num_chunks_after = local_num_chunks_after - self.local_num_chunks_before = local_num_chunks_before - self.hidden_act = hidden_act - self.feed_forward_size = feed_forward_size - self.hidden_dropout_prob = hidden_dropout_prob - self.local_attention_probs_dropout_prob = local_attention_probs_dropout_prob - self.max_position_embeddings = max_position_embeddings - self.initializer_range = initializer_range - self.layer_norm_eps = layer_norm_eps - self.axial_pos_embds = axial_pos_embds - self.axial_pos_shape = tuple(axial_pos_shape) - self.axial_pos_embds_dim = tuple(axial_pos_embds_dim) - self.axial_norm_std = axial_norm_std - self.chunk_size_lm_head = chunk_size_lm_head - self.chunk_size_feed_forward = chunk_size_feed_forward - self.scope = scope - self.attn_layers = attn_layers - self.pad_token_id = pad_token_id - self.hash_seed = hash_seed - - self.encoder_seq_length = seq_length // local_attn_chunk_length + ( - self.seq_length % local_attn_chunk_length != 0 - ) - self.key_length = ( - self.local_num_chunks_before + self.local_num_chunks_after + 1 - ) * local_attn_chunk_length - self.chunk_length = local_attn_chunk_length - - def prepare_config_and_inputs(self): - input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size) - - input_mask = None - if self.use_input_mask: - input_mask = ids_tensor([self.batch_size, self.seq_length], vocab_size=2) - - config = ReformerConfig( - vocab_size=self.vocab_size, - hidden_size=self.hidden_size, - num_hidden_layers=self.num_hidden_layers, - num_attention_heads=self.num_attention_heads, - feed_forward_size=self.feed_forward_size, - hidden_act=self.hidden_act, - hidden_dropout_prob=self.hidden_dropout_prob, - local_attention_probs_dropout_prob=self.local_attention_probs_dropout_prob, - max_position_embeddings=self.max_position_embeddings, - is_decoder=self.is_decoder, - axial_pos_embds=self.axial_pos_embds, - axial_pos_shape=self.axial_pos_shape, - axial_pos_embds_dim=self.axial_pos_embds_dim, - local_attn_chunk_length=self.local_attn_chunk_length, - local_num_chunks_after=self.local_num_chunks_after, - local_num_chunks_before=self.local_num_chunks_before, - attn_layers=self.attn_layers, - pad_token_id=self.pad_token_id, - hash_seed=self.hash_seed, - ) - - return ( - config, - input_ids, - input_mask, - ) - - def check_loss_output(self, result): - self.parent.assertListEqual(list(result["loss"].size()), []) - - def create_and_check_reformer_model( - self, config, input_ids, input_mask, - ): - model = ReformerModel(config=config) - model.to(torch_device) - model.eval() - (sequence_output,) = model(input_ids, attention_mask=input_mask) - (sequence_output,) = model(input_ids) - - result = { - "sequence_output": sequence_output, - } - # 2 * hidden_size because we use reversible resnet 
layers - self.parent.assertListEqual( - list(result["sequence_output"].size()), [self.batch_size, self.seq_length, 2 * self.hidden_size], - ) - - def create_and_check_reformer_model_with_lm_backward( - self, config, input_ids, input_mask, - ): - model = ReformerModelWithLMHead(config=config) - model.to(torch_device) - model.eval() - loss = model(input_ids, attention_mask=input_mask, labels=input_ids)[0] - loss.backward() - - def create_and_check_reformer_with_lm( - self, config, input_ids, input_mask, - ): - model = ReformerModelWithLMHead(config=config) - model.to(torch_device) - model.eval() - loss, prediction_scores = model(input_ids, attention_mask=input_mask, labels=input_ids) - result = { - "loss": loss, - "prediction_scores": prediction_scores, - } - self.parent.assertListEqual( - list(result["prediction_scores"].size()), [self.batch_size, self.seq_length, self.vocab_size], - ) - self.check_loss_output(result) - - def create_and_check_reformer_model_with_attn_mask(self, config, input_ids, input_mask, is_decoder): - # no special position embeddings - config.axial_pos_embds = False - config.is_decoder = is_decoder - - model = ReformerModel(config=config) - model.to(torch_device) - model.eval() - # set all position encodings to zero so that postions don't matter - with torch.no_grad(): - embedding = model.embeddings.position_embeddings.embedding - embedding.weight = torch.nn.Parameter(torch.zeros(embedding.weight.shape).to(torch_device)) - embedding.weight.requires_grad = False - - half_seq_len = self.seq_length // 2 - roll = self.local_attn_chunk_length - roll = self.local_attn_chunk_length - half_input_ids = input_ids[:, :half_seq_len] - - # normal padded - attn_mask = torch.cat([torch.ones_like(half_input_ids), torch.zeros_like(half_input_ids)], dim=-1,) - input_ids_padded = torch.cat( - [half_input_ids, ids_tensor((self.batch_size, half_seq_len), self.vocab_size)], dim=-1, - ) - - # shifted padded - input_ids_roll = torch.cat( - [half_input_ids, ids_tensor((self.batch_size, half_seq_len), self.vocab_size)], dim=-1, - ) - input_ids_roll = torch.roll(input_ids_roll, roll, dims=-1) - attn_mask_roll = torch.roll(attn_mask, roll, dims=-1) - - # input_ids_padded_begin = torch.cat([torch.full_like(input_ids[:, :half_seq_len], self.pad_token_id), input_ids[:, :half_seq_len],], dim=-1) - - output_padded = model(input_ids_padded, attention_mask=attn_mask)[0][:, :half_seq_len] - output_padded_rolled = model(input_ids_roll, attention_mask=attn_mask_roll)[0][ - :, roll : half_seq_len + roll - ] - - self.parent.assertTrue(torch.allclose(output_padded, output_padded_rolled, atol=1e-3)) - - def create_and_check_reformer_layer_dropout_seed(self, config, input_ids, input_mask, is_decoder): - config.is_decoder = is_decoder - layer = ReformerLayer(config).to(torch_device) - layer.train() - shape = ( - self.batch_size, - self.seq_length, - config.hidden_size, - ) # Batch x SeqLen x hiddenSize - - # get random tensors - hidden_states = floats_tensor(shape) - prev_attn_output = floats_tensor(shape) - - # now the random seeds for attention and feed forward is initialized - # forward tensors with dropout - layer_outputs = layer(prev_attn_output, hidden_states, attention_mask=input_mask) - - next_attn_output = layer_outputs.attn_output - next_hidden_states = layer_outputs.hidden_states + return ( + config, + input_ids, + input_mask, + ) - torch.manual_seed(layer.attention_seed) - attn_outputs = layer.attention(hidden_states, attention_mask=input_mask) - self.parent.assertTrue( - 
torch.allclose(prev_attn_output + attn_outputs.hidden_states, next_attn_output, atol=1e-3,) - ) + def check_loss_output(self, result): + self.parent.assertListEqual(list(result["loss"].size()), []) - torch.manual_seed(layer.feed_forward_seed) - feed_forward_hidden_states = layer.feed_forward(next_attn_output) - self.parent.assertTrue( - torch.allclose(next_hidden_states, hidden_states + feed_forward_hidden_states, atol=1e-3,) - ) - - def create_and_check_reformer_feed_forward_chunking(self, config, input_ids, input_mask): - torch.manual_seed(0) - model = ReformerModel(config=config) - model.to(torch_device) - model.eval() - hidden_states_no_chunk = model(input_ids, attention_mask=input_mask)[0] - - config.chunk_size_lm_head = input_ids.shape[-1] - config.chunk_size_feed_forward = input_ids.shape[-1] - - torch.manual_seed(0) - model = ReformerModel(config=config) - model.to(torch_device) - model.eval() - - hidden_states_with_chunk = model(input_ids, attention_mask=input_mask)[0] - self.parent.assertTrue(torch.allclose(hidden_states_no_chunk, hidden_states_with_chunk, atol=1e-3)) - - def create_and_check_reformer_model_fp16_forward(self, config, input_ids, input_mask): - model = ReformerModel(config=config) - model.to(torch_device) - model.half() - model.eval() - output = model(input_ids, attention_mask=input_mask)[0] - self.parent.assertFalse(torch.isnan(output).any().item()) - - def prepare_config_and_inputs_for_common(self): - config_and_inputs = self.prepare_config_and_inputs() - (config, input_ids, input_mask,) = config_and_inputs - inputs_dict = {"input_ids": input_ids, "attention_mask": input_mask} - return config, inputs_dict + def create_and_check_reformer_model( + self, config, input_ids, input_mask, + ): + model = ReformerModel(config=config) + model.to(torch_device) + model.eval() + (sequence_output,) = model(input_ids, attention_mask=input_mask) + (sequence_output,) = model(input_ids) - def setUp(self): - self.model_tester = ReformerLocalAttnModelTest.ReformerLocalAttnModelTester(self) - self.config_tester = ConfigTester(self, config_class=ReformerConfig, hidden_size=37) + result = { + "sequence_output": sequence_output, + } + # 2 * hidden_size because we use reversible resnet layers + self.parent.assertListEqual( + list(result["sequence_output"].size()), [self.batch_size, self.seq_length, 2 * self.hidden_size], + ) - def test_config(self): - self.config_tester.run_common_tests() + def create_and_check_reformer_model_with_lm_backward( + self, config, input_ids, input_mask, + ): + model = ReformerModelWithLMHead(config=config) + model.to(torch_device) + model.eval() + loss = model(input_ids, attention_mask=input_mask, labels=input_ids)[0] + loss.backward() - def test_reformer_model(self): - config_and_inputs = self.model_tester.prepare_config_and_inputs() - self.model_tester.create_and_check_reformer_model(*config_and_inputs) + def create_and_check_reformer_with_lm( + self, config, input_ids, input_mask, + ): + model = ReformerModelWithLMHead(config=config) + model.to(torch_device) + model.eval() + loss, prediction_scores = model(input_ids, attention_mask=input_mask, labels=input_ids) + result = { + "loss": loss, + "prediction_scores": prediction_scores, + } + self.parent.assertListEqual( + list(result["prediction_scores"].size()), [self.batch_size, self.seq_length, self.vocab_size], + ) + self.check_loss_output(result) - def test_reformer_lm_model_backward(self): - config_and_inputs = self.model_tester.prepare_config_and_inputs() - 
        self.model_tester.create_and_check_reformer_model_with_lm_backward(*config_and_inputs)
+    def create_and_check_reformer_model_with_attn_mask(self, config, input_ids, input_mask, is_decoder):
+        # no special position embeddings
+        config.axial_pos_embds = False
+        config.is_decoder = is_decoder
-    def test_reformer_model_attn_masking(self):
-        config_and_inputs = self.model_tester.prepare_config_and_inputs()
-        self.model_tester.create_and_check_reformer_model_with_attn_mask(*config_and_inputs, True)
-        self.model_tester.create_and_check_reformer_model_with_attn_mask(*config_and_inputs, False)
+        if self.lsh_attn_chunk_length is not None:
+            # need to set the chunk length equal to the sequence length to be certain that chunking works
+            config.lsh_attn_chunk_length = self.seq_length
-    def test_reformer_with_lm(self):
-        config_and_inputs = self.model_tester.prepare_config_and_inputs()
-        self.model_tester.create_and_check_reformer_with_lm(*config_and_inputs)
+        model = ReformerModel(config=config)
+        model.to(torch_device)
+        model.eval()
+        # set all position encodings to zero so that positions don't matter
+        with torch.no_grad():
+            embedding = model.embeddings.position_embeddings.embedding
+            embedding.weight = torch.nn.Parameter(torch.zeros(embedding.weight.shape).to(torch_device))
+            embedding.weight.requires_grad = False
-    def test_reformer_layer_training_dropout(self):
-        config_and_inputs = self.model_tester.prepare_config_and_inputs()
-        self.model_tester.create_and_check_reformer_layer_dropout_seed(*config_and_inputs, True)
-        self.model_tester.create_and_check_reformer_layer_dropout_seed(*config_and_inputs, False)
+        half_seq_len = self.seq_length // 2
+        roll = self.chunk_length
-    def test_reformer_chunking_forward_equality(self):
-        config_and_inputs = self.model_tester.prepare_config_and_inputs()
-        self.model_tester.create_and_check_reformer_feed_forward_chunking(*config_and_inputs)
+        half_input_ids = input_ids[:, :half_seq_len]
-    @unittest.skipIf(torch_device == "cpu", "Cant do half precision")
-    def test_reformer_model_fp16_forward(self):
-        config_and_inputs = self.model_tester.prepare_config_and_inputs()
-        self.model_tester.create_and_check_reformer_model_fp16_forward(*config_and_inputs)
+        # normal padded
+        attn_mask = torch.cat([torch.ones_like(half_input_ids), torch.zeros_like(half_input_ids)], dim=-1,)
+        input_ids_padded = torch.cat(
+            [half_input_ids, ids_tensor((self.batch_size, half_seq_len), self.vocab_size)], dim=-1,
+        )
+        # shifted padded
+        input_ids_roll = torch.cat(
+            [half_input_ids, ids_tensor((self.batch_size, half_seq_len), self.vocab_size)], dim=-1,
+        )
+        input_ids_roll = torch.roll(input_ids_roll, roll, dims=-1)
+        attn_mask_roll = torch.roll(attn_mask, roll, dims=-1)
+
+        output_padded = model(input_ids_padded, attention_mask=attn_mask)[0][:, :half_seq_len]
+        output_padded_rolled = model(input_ids_roll, attention_mask=attn_mask_roll)[0][:, roll : half_seq_len + roll]
+
+        self.parent.assertTrue(torch.allclose(output_padded, output_padded_rolled, atol=1e-3))
+
+    def create_and_check_reformer_layer_dropout_seed(self, config, input_ids, input_mask, is_decoder):
+        config.is_decoder = is_decoder
+        layer = ReformerLayer(config).to(torch_device)
+        layer.train()
+        shape = (
+            self.batch_size,
+            self.seq_length,
+            config.hidden_size,
+        )  # Batch x SeqLen x hiddenSize
+
+        # get random tensors
+        hidden_states = floats_tensor(shape)
+        prev_attn_output = floats_tensor(shape)
+
+        # now the random seeds for attention and feed forward are initialized
+        # forward tensors with dropout
+        layer_outputs =
layer(prev_attn_output, hidden_states, attention_mask=input_mask) + + next_attn_output = layer_outputs.attn_output + next_hidden_states = layer_outputs.hidden_states + + torch.manual_seed(layer.attention_seed) + attn_outputs = layer.attention(hidden_states, attention_mask=input_mask) + self.parent.assertTrue( + torch.allclose(prev_attn_output + attn_outputs.hidden_states, next_attn_output, atol=1e-3,) + ) -@require_torch -class ReformerLSHAttnModelTest(ModelTesterMixin, unittest.TestCase): - all_model_classes = (ReformerModel, ReformerModelWithLMHead) if is_torch_available() else () - all_generative_model_classes = (ReformerModelWithLMHead,) if is_torch_available() else () - test_pruning = False - test_headmasking = False - test_torchscript = False + torch.manual_seed(layer.feed_forward_seed) + feed_forward_hidden_states = layer.feed_forward(next_attn_output) + self.parent.assertTrue( + torch.allclose(next_hidden_states, hidden_states + feed_forward_hidden_states, atol=1e-3,) + ) - class ReformerLSHAttnModelTester(object): - def __init__( - self, - parent, - batch_size=13, - seq_length=13, - use_input_mask=True, - is_training=False, - is_decoder=False, - vocab_size=32, - attention_head_size=16, - hidden_size=64, - num_attention_heads=2, - num_buckets=2, - num_hashes=4, - lsh_attn_chunk_length=4, - lsh_num_chunks_before=2, - lsh_num_chunks_after=3, - chunk_size_lm_head=5, - chunk_size_feed_forward=6, - feed_forward_size=32, - hidden_act="relu", - hidden_dropout_prob=0.1, - lsh_attention_probs_dropout_prob=0.1, - max_position_embeddings=512, - initializer_range=0.02, - axial_norm_std=1.0, - layer_norm_eps=1e-12, - axial_pos_embds=True, - axial_pos_shape=[4, 8], - axial_pos_embds_dim=[16, 48], - attn_layers=["lsh", "lsh", "lsh", "lsh"], - pad_token_id=0, - eos_token_id=2, - scope=None, - hash_seed=0, - ): - self.parent = parent - self.batch_size = batch_size - self.seq_length = seq_length - self.is_training = is_training - self.is_decoder = is_decoder - self.use_input_mask = use_input_mask - self.vocab_size = vocab_size - self.attention_head_size = attention_head_size - self.hidden_size = hidden_size - self.num_attention_heads = num_attention_heads - self.num_hashes = num_hashes - self.num_hidden_layers = len(attn_layers) - self.num_buckets = tuple(num_buckets) if isinstance(num_buckets, list) else num_buckets - self.lsh_attn_chunk_length = lsh_attn_chunk_length - self.lsh_num_chunks_after = lsh_num_chunks_after - self.lsh_num_chunks_before = lsh_num_chunks_before - self.hidden_act = hidden_act - self.feed_forward_size = feed_forward_size - self.hidden_dropout_prob = hidden_dropout_prob - self.lsh_attention_probs_dropout_prob = lsh_attention_probs_dropout_prob - self.max_position_embeddings = max_position_embeddings - self.initializer_range = initializer_range - self.layer_norm_eps = layer_norm_eps - self.chunk_size_lm_head = chunk_size_lm_head - self.chunk_size_feed_forward = chunk_size_feed_forward - self.scope = scope - self.attn_layers = attn_layers - self.hash_seed = hash_seed - self.pad_token_id = pad_token_id - self.axial_pos_embds = axial_pos_embds - self.axial_pos_shape = tuple(axial_pos_shape) - self.axial_pos_embds_dim = tuple(axial_pos_embds_dim) - self.axial_norm_std = axial_norm_std - - self.encoder_seq_length = seq_length // lsh_attn_chunk_length + (seq_length % lsh_attn_chunk_length != 0) - self.key_length = (self.lsh_num_chunks_before + self.lsh_num_chunks_after + 1) * lsh_attn_chunk_length - self.chunk_length = lsh_attn_chunk_length - - def prepare_config_and_inputs(self): - 
input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size) - - input_mask = None - if self.use_input_mask: - input_mask = ids_tensor([self.batch_size, self.seq_length], vocab_size=2) - - config = ReformerConfig( - vocab_size=self.vocab_size, - hidden_size=self.hidden_size, - num_hidden_layers=self.num_hidden_layers, - num_attention_heads=self.num_attention_heads, - feed_forward_size=self.feed_forward_size, - hidden_act=self.hidden_act, - hidden_dropout_prob=self.hidden_dropout_prob, - lsh_attention_probs_dropout_prob=self.lsh_attention_probs_dropout_prob, - max_position_embeddings=self.max_position_embeddings, - is_decoder=self.is_decoder, - axial_pos_embds=self.axial_pos_embds, - axial_pos_shape=self.axial_pos_shape, - axial_pos_embds_dim=self.axial_pos_embds_dim, - num_hashes=self.num_hashes, - num_buckets=self.num_buckets, - lsh_attn_chunk_length=self.lsh_attn_chunk_length, - lsh_num_chunks_after=self.lsh_num_chunks_after, - lsh_num_chunks_before=self.lsh_num_chunks_before, - attn_layers=self.attn_layers, - hash_seed=self.hash_seed, - pad_token_id=self.pad_token_id, - ) - - return ( - config, - input_ids, - input_mask, - ) - - def check_loss_output(self, result): - self.parent.assertListEqual(list(result["loss"].size()), []) - - def create_and_check_reformer_model( - self, config, input_ids, input_mask, - ): - model = ReformerModel(config=config) - model.to(torch_device) - model.eval() - (sequence_output,) = model(input_ids, attention_mask=input_mask) - (sequence_output,) = model(input_ids) - - result = { - "sequence_output": sequence_output, - } - # 2 * hidden_size because we use reversible resnet layers - self.parent.assertListEqual( - list(result["sequence_output"].size()), [self.batch_size, self.seq_length, 2 * self.hidden_size], - ) - - def create_and_check_reformer_model_with_lm_backward( - self, config, input_ids, input_mask, - ): - model = ReformerModelWithLMHead(config=config) - model.to(torch_device) - model.eval() - loss = model(input_ids, attention_mask=input_mask, labels=input_ids)[0] - loss.backward() - - def create_and_check_reformer_with_lm( - self, config, input_ids, input_mask, - ): - model = ReformerModelWithLMHead(config=config) - model.to(torch_device) - model.eval() - loss, prediction_scores = model(input_ids, attention_mask=input_mask, labels=input_ids) - result = { - "loss": loss, - "prediction_scores": prediction_scores, - } - self.parent.assertListEqual( - list(result["prediction_scores"].size()), [self.batch_size, self.seq_length, self.vocab_size], - ) - self.check_loss_output(result) - - def create_and_check_reformer_model_with_attn_mask(self, config, input_ids, input_mask, is_decoder): - # no special position embeddings - config.axial_pos_embds = False - config.is_decoder = is_decoder + def create_and_check_reformer_feed_forward_chunking(self, config, input_ids, input_mask): + torch.manual_seed(0) + model = ReformerModel(config=config) + model.to(torch_device) + model.eval() + hidden_states_no_chunk = model(input_ids, attention_mask=input_mask)[0] - # need to set chunk length equal sequence length to be certain that chunking works - config.lsh_attn_chunk_length = self.seq_length + config.chunk_size_lm_head = 1 + config.chunk_size_feed_forward = 1 - model = ReformerModel(config=config) - model.to(torch_device) - model.eval() - # set all position encodings to zero so that postions don't matter - - with torch.no_grad(): - embedding = model.embeddings.position_embeddings.embedding - embedding.weight = 
torch.nn.Parameter(torch.zeros(embedding.weight.shape).to(torch_device)) - embedding.weight.requires_grad = False - - half_seq_len = self.seq_length // 2 - roll = self.lsh_attn_chunk_length - roll = half_seq_len - half_input_ids = input_ids[:, :half_seq_len] - - # normal padded - attn_mask = torch.cat([torch.ones_like(half_input_ids), torch.zeros_like(half_input_ids)], dim=-1,) - input_ids_padded = torch.cat( - [half_input_ids, ids_tensor((self.batch_size, half_seq_len), self.vocab_size)], dim=-1, - ) - - # shifted padded - input_ids_roll = torch.cat( - [half_input_ids, ids_tensor((self.batch_size, half_seq_len), self.vocab_size)], dim=-1, - ) - input_ids_roll = torch.roll(input_ids_roll, roll, dims=-1) - attn_mask_roll = torch.roll(attn_mask, roll, dims=-1) - - output_padded = model(input_ids_padded, attention_mask=attn_mask)[0][:, :half_seq_len] - output_padded_rolled = model(input_ids_roll, attention_mask=attn_mask_roll)[0][ - :, roll : half_seq_len + roll - ] - - self.parent.assertTrue(torch.allclose(output_padded, output_padded_rolled, atol=1e-3)) - - def create_and_check_reformer_layer_dropout_seed(self, config, input_ids, input_mask, is_decoder): - config.is_decoder = is_decoder - layer = ReformerLayer(config).to(torch_device) - layer.train() - shape = ( - self.batch_size, - self.seq_length, - config.hidden_size, - ) # Batch x SeqLen x hiddenSize - - # get random tensors - hidden_states = floats_tensor(shape) - prev_attn_output = floats_tensor(shape) - - # now the random seeds for attention and feed forward is initialized - # forward tensors with dropout - layer_outputs = layer(prev_attn_output, hidden_states, attention_mask=input_mask) - - next_attn_output = layer_outputs.attn_output - next_hidden_states = layer_outputs.hidden_states + torch.manual_seed(0) + model = ReformerModel(config=config) + model.to(torch_device) + model.eval() - torch.manual_seed(layer.attention_seed) - attn_outputs = layer.attention(hidden_states, attention_mask=input_mask) - self.parent.assertTrue( - torch.allclose(prev_attn_output + attn_outputs.hidden_states, next_attn_output, atol=1e-3,) - ) + hidden_states_with_chunk = model(input_ids, attention_mask=input_mask)[0] + self.parent.assertTrue(torch.allclose(hidden_states_no_chunk, hidden_states_with_chunk, atol=1e-3)) + + def create_and_check_reformer_feed_backward_chunking(self, config, input_ids, input_mask): + if not self.is_training: + return + # disable dropout + config.hidden_dropout_prob = 0 + config.local_attention_probs_dropout_prob = 0 + config.lsh_attention_probs_dropout_prob = 0 + + torch.manual_seed(0) + model = ReformerModelWithLMHead(config=config) + model.to(torch_device) + model.train() + model.zero_grad() + loss_no_chunk, output_no_chunk = model(input_ids, labels=input_ids, attention_mask=input_mask)[:2] + loss_no_chunk.backward() + grad_slice_word_no_chunk = model.reformer.embeddings.word_embeddings.weight.grad[0, :5] + grad_slice_position_factor_1_no_chunk = model.reformer.embeddings.position_embeddings.weights[0][1, 0, -5:] + grad_slice_position_factor_2_no_chunk = model.reformer.embeddings.position_embeddings.weights[1][0, 1, :5] + + config.chunk_size_lm_head = 1 + config.chunk_size_feed_forward = 1 + + torch.manual_seed(0) + model = ReformerModelWithLMHead(config=config) + model.to(torch_device) + model.train() + model.zero_grad() + loss_chunk, output_chunk = model(input_ids, labels=input_ids, attention_mask=input_mask)[:2] + loss_chunk.backward() + grad_slice_word_chunk = model.reformer.embeddings.word_embeddings.weight.grad[0, :5] 
+ grad_slice_position_factor_1_chunk = model.reformer.embeddings.position_embeddings.weights[0][1, 0, -5:] + grad_slice_position_factor_2_chunk = model.reformer.embeddings.position_embeddings.weights[1][0, 1, :5] + self.parent.assertTrue(torch.allclose(loss_chunk, loss_no_chunk, atol=1e-3)) + self.parent.assertTrue(torch.allclose(grad_slice_word_no_chunk, grad_slice_word_chunk, atol=1e-3)) + self.parent.assertTrue( + torch.allclose(grad_slice_position_factor_1_chunk, grad_slice_position_factor_1_no_chunk, atol=1e-3) + ) + self.parent.assertTrue( + torch.allclose(grad_slice_position_factor_2_chunk, grad_slice_position_factor_2_no_chunk, atol=1e-3) + ) + + def create_and_check_reformer_random_seed(self, config, input_ids, input_mask): + layer = ReformerLayer(config).to(torch_device) + layer.train() + + shape = ( + self.batch_size, + self.seq_length, + config.hidden_size, + ) # Batch x SeqLen x hiddenSize + + hidden_states = floats_tensor(shape) + attn_output = floats_tensor(shape) + + seeds = [] + for _ in range(100): + layer_outputs = layer(attn_output, hidden_states, attention_mask=input_mask) + attn_output = layer_outputs.attn_output + hidden_states = layer_outputs.hidden_states + torch.manual_seed(layer.attention_seed) + seeds.append(layer.attention_seed) + self.parent.assertGreater(len(set(seeds)), 70) + + seeds = [] + for _ in range(100): + layer_outputs = layer(attn_output, hidden_states, attention_mask=input_mask) + attn_output = layer_outputs.attn_output + hidden_states = layer_outputs.hidden_states torch.manual_seed(layer.feed_forward_seed) - feed_forward_hidden_states = layer.feed_forward(next_attn_output) - self.parent.assertTrue( - torch.allclose(next_hidden_states, hidden_states + feed_forward_hidden_states, atol=1e-3,) - ) - - @slow - def create_and_check_reformer_random_seed(self, config, input_ids, input_mask): - layer = ReformerLayer(config).to(torch_device) - layer.train() - - shape = ( - self.batch_size, - self.seq_length, - config.hidden_size, - ) # Batch x SeqLen x hiddenSize - - hidden_states = floats_tensor(shape) - attn_output = floats_tensor(shape) - - seeds = [] - for _ in range(100): - layer_outputs = layer(attn_output, hidden_states, attention_mask=input_mask) - attn_output = layer_outputs.attn_output - hidden_states = layer_outputs.hidden_states - torch.manual_seed(layer.attention_seed) - seeds.append(layer.attention_seed) - self.parent.assertGreater(len(set(seeds)), 70) - - seeds = [] - for _ in range(100): - layer_outputs = layer(attn_output, hidden_states, attention_mask=input_mask) - attn_output = layer_outputs.attn_output - hidden_states = layer_outputs.hidden_states - torch.manual_seed(layer.feed_forward_seed) - seeds.append(layer.feed_forward_seed) - self.parent.assertGreater(len(set(seeds)), 70) - - def create_and_check_reformer_feed_backward_chunking(self, config, input_ids, input_mask): - torch.manual_seed(0) - model = ReformerModelWithLMHead(config=config) - model.to(torch_device) - model.train() - model.zero_grad() - loss_no_chunk = model(input_ids, labels=input_ids, attention_mask=input_mask)[0] - loss_no_chunk.backward() - grad_slice_word_no_chunk = model.reformer.embeddings.word_embeddings.weight.grad[0, :5] - grad_slice_position_factor_1_no_chunk = model.reformer.embeddings.position_embeddings.weights[0][1, 0, -5:] - grad_slice_position_factor_2_no_chunk = model.reformer.embeddings.position_embeddings.weights[1][0, 1, :5] - - config.chunk_size_lm_head = input_ids.shape[-1] - config.chunk_size_feed_forward = input_ids.shape[-1] - - 
torch.manual_seed(0) - model = ReformerModelWithLMHead(config=config) - model.to(torch_device) - model.train() - model.zero_grad() - loss_chunk = model(input_ids, labels=input_ids, attention_mask=input_mask)[0] - loss_chunk.backward() - grad_slice_word_chunk = model.reformer.embeddings.word_embeddings.weight.grad[0, :5] - grad_slice_position_factor_1_chunk = model.reformer.embeddings.position_embeddings.weights[0][1, 0, -5:] - grad_slice_position_factor_2_chunk = model.reformer.embeddings.position_embeddings.weights[1][0, 1, :5] - self.parent.assertTrue(torch.allclose(loss_chunk, loss_no_chunk, atol=1e-3)) - self.parent.assertTrue(torch.allclose(grad_slice_word_no_chunk, grad_slice_word_chunk, atol=1e-3)) - self.parent.assertTrue( - torch.allclose(grad_slice_position_factor_1_chunk, grad_slice_position_factor_1_no_chunk, atol=1e-3) - ) - self.parent.assertTrue( - torch.allclose(grad_slice_position_factor_2_chunk, grad_slice_position_factor_2_no_chunk, atol=1e-3) - ) - - def create_and_check_reformer_model_fp16_forward(self, config, input_ids, input_mask): - model = ReformerModel(config=config) - model.to(torch_device) - model.half() - model.eval() - output = model(input_ids, attention_mask=input_mask)[0] - self.parent.assertFalse(torch.isnan(output).any().item()) - - def create_and_check_reformer_model_fp16_generate(self, config, input_ids, input_mask): - model = ReformerModelWithLMHead(config=config) - model.to(torch_device) - model.half() - model.eval() - output = model.generate(input_ids, attention_mask=input_mask, do_sample=False) - self.parent.assertFalse(torch.isnan(output).any().item()) - - def prepare_config_and_inputs_for_common(self): - config_and_inputs = self.prepare_config_and_inputs() - (config, input_ids, input_mask,) = config_and_inputs - inputs_dict = {"input_ids": input_ids, "attention_mask": input_mask} - return config, inputs_dict + seeds.append(layer.feed_forward_seed) + self.parent.assertGreater(len(set(seeds)), 70) - def setUp(self): - self.model_tester = ReformerLSHAttnModelTest.ReformerLSHAttnModelTester(self) - self.config_tester = ConfigTester(self, config_class=ReformerConfig, hidden_size=37) + def create_and_check_reformer_model_fp16_forward(self, config, input_ids, input_mask): + model = ReformerModel(config=config) + model.to(torch_device) + model.half() + model.eval() + output = model(input_ids, attention_mask=input_mask)[0] + self.parent.assertFalse(torch.isnan(output).any().item()) + + def create_and_check_reformer_model_fp16_generate(self, config, input_ids, input_mask): + model = ReformerModelWithLMHead(config=config) + model.to(torch_device) + model.half() + model.eval() + output = model.generate(input_ids, attention_mask=input_mask, do_sample=False) + self.parent.assertFalse(torch.isnan(output).any().item()) + + def prepare_config_and_inputs_for_common(self): + config_and_inputs = self.prepare_config_and_inputs() + (config, input_ids, input_mask,) = config_and_inputs + inputs_dict = {"input_ids": input_ids, "attention_mask": input_mask} + return config, inputs_dict + + +class ReformerTesterMixin: + """ + Reformer Local and Reformer LSH run essentially the same tests + """ def test_config(self): self.config_tester.run_common_tests() @@ -700,6 +425,15 @@ def test_reformer_layer_training_dropout(self): self.model_tester.create_and_check_reformer_layer_dropout_seed(*config_and_inputs, True) self.model_tester.create_and_check_reformer_layer_dropout_seed(*config_and_inputs, False) + def test_reformer_chunking_forward_equality(self): + config_and_inputs = 
self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_reformer_feed_forward_chunking(*config_and_inputs) + + def test_reformer_chunking_backward_equality(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_reformer_feed_backward_chunking(*config_and_inputs) + + @slow def test_dropout_random_seed_is_changing(self): config_and_inputs = self.model_tester.prepare_config_and_inputs() self.model_tester.create_and_check_reformer_random_seed(*config_and_inputs) @@ -714,6 +448,54 @@ def test_reformer_model_fp16_generate(self): config_and_inputs = self.model_tester.prepare_config_and_inputs() self.model_tester.create_and_check_reformer_model_fp16_generate(*config_and_inputs) + +@require_torch +class ReformerLocalAttnModelTest(ModelTesterMixin, ReformerTesterMixin, unittest.TestCase): + all_model_classes = (ReformerModel, ReformerModelWithLMHead) if is_torch_available() else () + all_generative_model_classes = (ReformerModelWithLMHead,) if is_torch_available() else () + test_pruning = False + test_headmasking = False + test_torchscript = False + + def prepare_kwargs(self): + return { + "batch_size": 13, + "seq_length": 32, + "is_training": True, + "is_decoder": False, + "use_input_mask": True, + "vocab_size": 32, + "attention_head_size": 16, + "hidden_size": 32, + "num_attention_heads": 2, + "local_attn_chunk_length": 4, + "local_num_chunks_before": 1, + "local_num_chunks_after": 0, + "chunk_size_lm_head": 0, + "chunk_size_feed_forward": 0, + "feed_forward_size": 32, + "hidden_act": "gelu", + "hidden_dropout_prob": 0.1, + "local_attention_probs_dropout_prob": 0.1, + "max_position_embeddings": 512, + "initializer_range": 0.02, + "axial_norm_std": 1.0, + "layer_norm_eps": 1e-12, + "axial_pos_embds": True, + "axial_pos_shape": [4, 8], + "axial_pos_embds_dim": [16, 16], + "attn_layers": ["local", "local", "local", "local"], + "pad_token_id": 0, + "eos_token_id": 2, + "scope": None, + "hash_seed": 0, + } + + def setUp(self): + tester_kwargs = self.prepare_kwargs() + self.model_tester = ReformerModelTester(self, **tester_kwargs) + self.config_tester = ConfigTester(self, config_class=ReformerConfig, hidden_size=37) + @slow def test_model_from_pretrained(self): for model_name in list(REFORMER_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]: @@ -721,6 +503,56 @@ def test_model_from_pretrained(self): self.assertIsNotNone(model) +@require_torch +class ReformerLSHAttnModelTest(ModelTesterMixin, unittest.TestCase, ReformerTesterMixin): + all_model_classes = (ReformerModel, ReformerModelWithLMHead) if is_torch_available() else () + all_generative_model_classes = (ReformerModelWithLMHead,) if is_torch_available() else () + test_pruning = False + test_headmasking = False + test_torchscript = False + + def prepare_kwargs(self): + return { + "batch_size": 13, + "seq_length": 13, + "use_input_mask": True, + "is_training": False, + "is_decoder": False, + "vocab_size": 32, + "attention_head_size": 16, + "hidden_size": 64, + "num_attention_heads": 2, + "num_buckets": 2, + "num_hashes": 4, + "lsh_attn_chunk_length": 4, + "lsh_num_chunks_before": 2, + "lsh_num_chunks_after": 3, + "chunk_size_lm_head": 5, + "chunk_size_feed_forward": 6, + "feed_forward_size": 32, + "hidden_act": "relu", + "hidden_dropout_prob": 0.1, + "lsh_attention_probs_dropout_prob": 0.1, + "max_position_embeddings": 512, + "initializer_range": 0.02, + "axial_norm_std": 1.0, + "layer_norm_eps": 1e-12, + "axial_pos_embds": True, + "axial_pos_shape": [4, 8], + 
"axial_pos_embds_dim": [16, 48], + "attn_layers": ["lsh", "lsh", "lsh", "lsh"], + "pad_token_id": 0, + "eos_token_id": 2, + "scope": None, + "hash_seed": 0, + } + + def setUp(self): + tester_kwargs = self.prepare_kwargs() + self.model_tester = ReformerModelTester(self, **tester_kwargs) + self.config_tester = ConfigTester(self, config_class=ReformerConfig, hidden_size=37) + + @require_torch class ReformerIntegrationTests(unittest.TestCase): """