Fix flaky BetterTransformer test (#564)
* fix flaky test

* fix flaky
fxmarty authored Dec 8, 2022
1 parent 1588a2e commit 86375c4
Showing 2 changed files with 8 additions and 2 deletions.
tests/bettertransformer/test_bettertransformer_encoder.py (4 additions & 1 deletion)

@@ -309,7 +309,10 @@ def get_batch(batch_size, avg_seqlen, max_sequence_length, seqlen_stdev, vocab_s
     mean_tensor = torch.Tensor([avg_seqlen]).expand(batch_size)
     stdev_tensor = torch.Tensor([seqlen_stdev]).expand(batch_size)
     lengths = torch.normal(mean_tensor, stdev_tensor).to(torch.int)
-    lengths = torch.clamp(lengths, min=0, max=max_sequence_length)
+
+    # need at least a sequence length of 1 for BetterTransformer to work
+    lengths = torch.clamp(lengths, min=1, max=max_sequence_length)
+
     tokens = torch.full(
         (batch_size, max_sequence_length),
         pad_idx,
tests/bettertransformer/testing_bettertransformer_utils.py (4 additions & 1 deletion)

@@ -186,7 +186,10 @@ def get_batch(batch_size, avg_seqlen, max_sequence_length, seqlen_stdev, vocab_s
     mean_tensor = torch.Tensor([avg_seqlen]).expand(batch_size)
     stdev_tensor = torch.Tensor([seqlen_stdev]).expand(batch_size)
     lengths = torch.normal(mean_tensor, stdev_tensor).to(torch.int)
-    lengths = torch.clamp(lengths, min=0, max=max_sequence_length)
+
+    # need at least a sequence length of 1 for BetterTransformer to work
+    lengths = torch.clamp(lengths, min=1, max=max_sequence_length)
+
     tokens = torch.full(
         (batch_size, max_sequence_length),
         pad_idx,
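
For context (not part of the commit itself): get_batch draws per-sequence lengths from a normal distribution, so a sample below 1 could truncate and clamp to a length of 0, i.e. an all-padding sequence that BetterTransformer cannot process; clamping to min=1 guarantees at least one real token per row. The same helper is duplicated in both test files, hence the identical change twice. A minimal standalone sketch of the failure mode, with illustrative parameter values that may differ from the test's real ones:

import torch

# Illustrative values only, chosen so a zero-length sample is likely.
batch_size, avg_seqlen, seqlen_stdev, max_sequence_length = 512, 10, 10, 64

torch.manual_seed(0)
mean_tensor = torch.Tensor([avg_seqlen]).expand(batch_size)
stdev_tensor = torch.Tensor([seqlen_stdev]).expand(batch_size)
lengths = torch.normal(mean_tensor, stdev_tensor).to(torch.int)

# Old behavior: a sample below 1 ends up as length 0 after truncation and
# clamping, producing an empty, fully padded row.
flaky = torch.clamp(lengths, min=0, max=max_sequence_length)
# Fixed behavior: every sequence keeps at least one real token.
fixed = torch.clamp(lengths, min=1, max=max_sequence_length)

print("zero-length rows before fix:", (flaky == 0).sum().item())  # > 0 on many seeds
print("zero-length rows after fix: ", (fixed == 0).sum().item())  # always 0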
