From 86375c4bd48f0b867ed3b24d9dd890b6cfe795c2 Mon Sep 17 00:00:00 2001
From: fxmarty <9808326+fxmarty@users.noreply.github.com>
Date: Thu, 8 Dec 2022 14:25:20 +0100
Subject: [PATCH] Fix flaky BetterTransformer test (#564)

* fix flaky test

* fix flaky
---
 tests/bettertransformer/test_bettertransformer_encoder.py  | 5 ++++-
 tests/bettertransformer/testing_bettertransformer_utils.py | 5 ++++-
 2 files changed, 8 insertions(+), 2 deletions(-)

diff --git a/tests/bettertransformer/test_bettertransformer_encoder.py b/tests/bettertransformer/test_bettertransformer_encoder.py
index bd1a8d78d7..980458bc51 100644
--- a/tests/bettertransformer/test_bettertransformer_encoder.py
+++ b/tests/bettertransformer/test_bettertransformer_encoder.py
@@ -309,7 +309,10 @@ def get_batch(batch_size, avg_seqlen, max_sequence_length, seqlen_stdev, vocab_s
     mean_tensor = torch.Tensor([avg_seqlen]).expand(batch_size)
     stdev_tensor = torch.Tensor([seqlen_stdev]).expand(batch_size)
     lengths = torch.normal(mean_tensor, stdev_tensor).to(torch.int)
-    lengths = torch.clamp(lengths, min=0, max=max_sequence_length)
+
+    # need at least a sequence length of 1 for BetterTransformer to work
+    lengths = torch.clamp(lengths, min=1, max=max_sequence_length)
+
     tokens = torch.full(
         (batch_size, max_sequence_length),
         pad_idx,
diff --git a/tests/bettertransformer/testing_bettertransformer_utils.py b/tests/bettertransformer/testing_bettertransformer_utils.py
index 7d616a3009..7fdbc82499 100644
--- a/tests/bettertransformer/testing_bettertransformer_utils.py
+++ b/tests/bettertransformer/testing_bettertransformer_utils.py
@@ -186,7 +186,10 @@ def get_batch(batch_size, avg_seqlen, max_sequence_length, seqlen_stdev, vocab_s
     mean_tensor = torch.Tensor([avg_seqlen]).expand(batch_size)
     stdev_tensor = torch.Tensor([seqlen_stdev]).expand(batch_size)
     lengths = torch.normal(mean_tensor, stdev_tensor).to(torch.int)
-    lengths = torch.clamp(lengths, min=0, max=max_sequence_length)
+
+    # need at least a sequence length of 1 for BetterTransformer to work
+    lengths = torch.clamp(lengths, min=1, max=max_sequence_length)
+
     tokens = torch.full(
         (batch_size, max_sequence_length),
         pad_idx,
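
Note on the fix (not part of the patch): get_batch draws per-row sequence
lengths from a Gaussian, so with clamp(min=0) a draw could round down to
zero and produce a row that is entirely padding, which, per the patch
comment, BetterTransformer cannot handle; clamping to min=1 guarantees
every row contains at least one real token. Below is a minimal standalone
sketch of just the sampling step; the parameter values (batch size, mean,
stdev) are illustrative assumptions, not the test's actual settings.

import torch

# Standalone sketch reproducing the length-sampling step from get_batch.
def sample_lengths(batch_size, avg_seqlen, max_sequence_length, seqlen_stdev, min_len):
    mean_tensor = torch.Tensor([avg_seqlen]).expand(batch_size)
    stdev_tensor = torch.Tensor([seqlen_stdev]).expand(batch_size)
    # Gaussian draws can be <= 0 after truncation to int, so the clamp floor
    # decides whether a row can end up with length 0 (i.e. all padding).
    lengths = torch.normal(mean_tensor, stdev_tensor).to(torch.int)
    return torch.clamp(lengths, min=min_len, max=max_sequence_length)

torch.manual_seed(0)
before = sample_lengths(512, 2, 16, 3, min_len=0)  # old behaviour
after = sample_lengths(512, 2, 16, 3, min_len=1)   # patched behaviour
print("all-padding rows with min=0:", (before == 0).sum().item())  # typically > 0
print("all-padding rows with min=1:", (after == 0).sum().item())   # always 0

With min=0, any zero-length row made the randomly generated batch invalid
for BetterTransformer, so the test failed only on unlucky draws, which is
why it appeared flaky rather than consistently broken.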