Commit
clean reformer test and refactor code
patrickvonplaten committed Apr 5, 2020
1 parent 7e094aa commit 60e5d9c
Showing 2 changed files with 6 additions and 10 deletions.
src/transformers/modeling_reformer.py: 10 changes (2 additions, 8 deletions)
@@ -98,8 +98,8 @@ def forward(

         ticker, undo_ticker = self._get_ticker_and_undo_ticker(sequence_length, buckets)

-        query_key_vectors = self._gather_by_expansion(query_key_vectors.repeat(1, 1, self.num_hashes, 1), ticker)
-        value_vectors = self._gather_by_expansion(value_vectors.repeat(1, 1, self.num_hashes, 1), ticker)
+        query_key_vectors = self._gather_by_expansion(query_key_vectors, ticker)
+        value_vectors = self._gather_by_expansion(value_vectors, ticker)

         # q_info = ticker

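The two call sites above no longer repeat the query-key and value vectors across hashes before gathering, which suggests the per-hash repeat moved inside _gather_by_expansion itself. A minimal sketch of that refactor, assuming the vectors have shape (batch, num_heads, seq_len, head_dim) and the ticker has shape (batch, num_heads, num_hashes * seq_len); the attribute name self.attention_head_size is an assumption, not taken from this commit:

import torch

def _gather_by_expansion(self, vectors, idxs):
    # Sketch only, not the code from this commit.
    # Repeat across the hash dimension once, inside the helper,
    # instead of at every call site.
    vectors = vectors.repeat(1, 1, self.num_hashes, 1)
    # Expand the ticker indices to address full head-size vectors, then gather
    # the per-hash reordering along the sequence dimension.
    expanded_idxs = idxs.unsqueeze(-1).expand(-1, -1, -1, self.attention_head_size)
    return torch.gather(vectors, 2, expanded_idxs)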
@@ -339,8 +339,6 @@ def forward(self, hidden_states, input_tensor):
         hidden_states = self.dropout(hidden_states)
         # residual connection
         output = (hidden_states + input_tensor)
-        # Uncomment here if testing only LSHSelfAttentionLayer
-        # output = hidden_states
         return output


@@ -358,10 +356,6 @@ def forward(
         head_mask=None,
     ):
         norm_hidden_states = self.layer_norm(hidden_states)
-        # TODO: Remove comments later
-        # Uncomment here if testing only LSHSelfAttentionLayer
-        # norm_hidden_states = hidden_states
-
         self_attention_outputs = self.self_attention(
             norm_hidden_states, head_mask
         )
tests/test_modeling_reformer.py: 6 changes (4 additions, 2 deletions)
@@ -250,7 +250,7 @@ def test_lsh_layer(self):
         # Remove residual connection in ReformerSelfOutput to test this layer only
         # Remove layer norm in ReformerAttention to test this layer only
         config = ReformerConfig()
-        shape = (1, 7, config.hidden_size)  # Batch x SeqLen x hiddenSize
+        shape = (2, 7, config.hidden_size)  # Batch x SeqLen x hiddenSize
         np_input = np.random.rand(*shape)

         trax_utils = TraxUtils(shape)
@@ -260,7 +260,9 @@ def test_lsh_layer(self):
         hf_input = torch.tensor(np_input, dtype=torch.float)
         hf_layer = ReformerAttention(config)
         self._set_layer_weights_in_torch(trax_weights, hf_layer, config.hidden_size)
-        hf_output = hf_layer(hf_input, hf_input)[0]
+
+        hf_attention_all_heads = hf_layer.self_attention(hf_input)[0]
+        hf_output = hf_layer.output(hf_attention_all_heads, torch.zeros_like(hf_input))

         trax_torch_output = torch.tensor(np.asarray(trax_output))
         self.assertTrue(torch.allclose(hf_output, trax_torch_output, atol=1e-3))
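With the layer norm and residual connection now always active in the model code, the test isolates the LSH attention sub-layer by driving the two sub-modules directly: hf_layer.self_attention receives the raw (un-normalized) input, and hf_layer.output receives a zero tensor as its residual input, so the hidden_states + input_tensor step in ReformerSelfOutput adds nothing. That is what makes the removed "Uncomment here if testing only LSHSelfAttentionLayer" hacks unnecessary. A hypothetical helper (the function name is assumed, not part of the test file) capturing the same pattern:

import torch

def isolated_lsh_attention_output(hf_layer, hf_input):
    # Feed the raw input straight to the attention module, bypassing the
    # layer norm applied in ReformerAttention.forward.
    attention_all_heads = hf_layer.self_attention(hf_input)[0]
    # A zero residual turns the `hidden_states + input_tensor` step in
    # ReformerSelfOutput into a no-op, leaving only that module's own
    # dropout (and any projection) of the attention output.
    return hf_layer.output(attention_all_heads, torch.zeros_like(hf_input))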