[generate] move max time tests (#35962)
* move max time tests to their right place

* move test to the right place
gante authored Jan 29, 2025
1 parent 4d1d489 commit 4d3b107
Showing 5 changed files with 42 additions and 171 deletions.
41 changes: 41 additions & 0 deletions tests/generation/test_utils.py
@@ -16,6 +16,7 @@

import collections
import copy
import datetime
import gc
import inspect
import tempfile
@@ -4249,6 +4250,46 @@ def test_assisted_generation_early_exit(self):
        decoded_assisted = tokenizer.batch_decode(outputs_assisted, skip_special_tokens=True)
        self.assertEqual(decoded_assisted, [expected_output])

    @slow
    def test_max_time(self):
        tokenizer = GPT2Tokenizer.from_pretrained("openai-community/gpt2")
        model = GPT2LMHeadModel.from_pretrained("openai-community/gpt2")
        model.to(torch_device)

        torch.manual_seed(0)
        tokenized = tokenizer("Today is a nice day and", return_tensors="pt", return_token_type_ids=True)
        input_ids = tokenized.input_ids.to(torch_device)

        MAX_TIME = 0.1
        MAX_LENGTH = 64

        # sampling on
        start = datetime.datetime.now()
        model.generate(input_ids, do_sample=True, max_time=MAX_TIME, max_length=MAX_LENGTH)
        duration = datetime.datetime.now() - start
        self.assertGreater(duration, datetime.timedelta(seconds=MAX_TIME))
        self.assertLess(duration, datetime.timedelta(seconds=1.5 * MAX_TIME))

        # sampling off
        start = datetime.datetime.now()
        model.generate(input_ids, do_sample=False, max_time=MAX_TIME, max_length=MAX_LENGTH)
        duration = datetime.datetime.now() - start
        self.assertGreater(duration, datetime.timedelta(seconds=MAX_TIME))
        self.assertLess(duration, datetime.timedelta(seconds=1.5 * MAX_TIME))

        # beam search
        start = datetime.datetime.now()
        model.generate(input_ids, do_sample=False, num_beams=2, max_time=MAX_TIME, max_length=MAX_LENGTH)
        duration = datetime.datetime.now() - start
        self.assertGreater(duration, datetime.timedelta(seconds=MAX_TIME))
        self.assertLess(duration, datetime.timedelta(seconds=1.5 * MAX_TIME))

        # sanity check: no time limit
        start = datetime.datetime.now()
        model.generate(input_ids, do_sample=False, max_time=None, max_length=MAX_LENGTH)
        duration = datetime.datetime.now() - start
        self.assertGreater(duration, datetime.timedelta(seconds=1.5 * MAX_TIME))


@require_torch
class TokenHealingTestCase(unittest.TestCase):
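Note: `max_time` is enforced inside `generate` through the stopping-criteria machinery, with the time budget checked between decoding steps; that is why the test above expects the duration to land just past `MAX_TIME` yet under `1.5 * MAX_TIME`. For reference, a minimal sketch of the equivalent explicit form using the public `MaxTimeCriteria` and `StoppingCriteriaList` classes (checkpoint and prompt taken from the test above):

```python
from transformers import (
    GPT2LMHeadModel,
    GPT2Tokenizer,
    MaxTimeCriteria,
    StoppingCriteriaList,
)

tokenizer = GPT2Tokenizer.from_pretrained("openai-community/gpt2")
model = GPT2LMHeadModel.from_pretrained("openai-community/gpt2")
inputs = tokenizer("Today is a nice day and", return_tensors="pt")

# Roughly what `model.generate(..., max_time=0.1, max_length=64)` resolves to:
# a wall-clock stopping criterion evaluated after each decoding step.
outputs = model.generate(
    **inputs,
    stopping_criteria=StoppingCriteriaList([MaxTimeCriteria(max_time=0.1)]),
    max_length=64,
)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```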
45 changes: 1 addition & 44 deletions tests/models/codegen/test_modeling_codegen.py
@@ -14,12 +14,11 @@
# limitations under the License.


import datetime
import unittest

from transformers import CodeGenConfig, is_torch_available
from transformers.file_utils import cached_property
from transformers.testing_utils import backend_manual_seed, is_flaky, require_torch, slow, torch_device
from transformers.testing_utils import backend_manual_seed, require_torch, slow, torch_device

from ...generation.test_utils import GenerationTesterMixin
from ...test_configuration_common import ConfigTester
@@ -493,45 +492,3 @@ def test_codegen_sample(self):
        self.assertTrue(
            all(output_seq_strs[idx] != output_seq_tt_strs[idx] for idx in range(len(output_seq_tt_strs)))
        ) # token_type_ids should change output

    @is_flaky(max_attempts=3, description="measure of timing is somehow flaky.")
    @slow
    def test_codegen_sample_max_time(self):
        tokenizer = self.cached_tokenizer
        model = self.cached_model
        model.to(torch_device)

        torch.manual_seed(0)
        tokenized = tokenizer("Today is a nice day and", return_tensors="pt", return_token_type_ids=True)
        input_ids = tokenized.input_ids.to(torch_device)

        MAX_TIME = 0.05

        start = datetime.datetime.now()
        model.generate(input_ids, do_sample=True, max_time=MAX_TIME, max_length=256)
        duration = datetime.datetime.now() - start
        self.assertGreater(duration, datetime.timedelta(seconds=MAX_TIME))
        self.assertLess(duration, datetime.timedelta(seconds=2 * MAX_TIME))

        start = datetime.datetime.now()
        model.generate(input_ids, do_sample=False, max_time=MAX_TIME, max_length=256)
        duration = datetime.datetime.now() - start
        self.assertGreater(duration, datetime.timedelta(seconds=MAX_TIME))
        self.assertLess(duration, datetime.timedelta(seconds=2 * MAX_TIME))

        start = datetime.datetime.now()
        model.generate(input_ids, do_sample=False, num_beams=2, max_time=MAX_TIME, max_length=256)
        duration = datetime.datetime.now() - start
        self.assertGreater(duration, datetime.timedelta(seconds=MAX_TIME))
        self.assertLess(duration, datetime.timedelta(seconds=2 * MAX_TIME))

        start = datetime.datetime.now()
        model.generate(input_ids, do_sample=True, num_beams=2, max_time=MAX_TIME, max_length=256)
        duration = datetime.datetime.now() - start
        self.assertGreater(duration, datetime.timedelta(seconds=MAX_TIME))
        self.assertLess(duration, datetime.timedelta(seconds=2 * MAX_TIME))

        start = datetime.datetime.now()
        model.generate(input_ids, do_sample=False, max_time=None, max_length=256)
        duration = datetime.datetime.now() - start
        self.assertGreater(duration, datetime.timedelta(seconds=2 * MAX_TIME))
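Of the removed copies, only this CodeGen variant was wrapped in `@is_flaky`, acknowledging that wall-clock assertions can fail spuriously on a loaded machine. For reference, a minimal sketch of that retry pattern with the `is_flaky` helper imported above from `transformers.testing_utils` (the test class and method names here are hypothetical):

```python
import unittest

from transformers.testing_utils import is_flaky, slow


class ExampleTimingTests(unittest.TestCase):  # hypothetical class for illustration
    @is_flaky(max_attempts=3, description="measure of timing is somehow flaky.")
    @slow
    def test_sample_max_time(self):
        # On failure, `is_flaky` reruns the test up to `max_attempts` times
        # before reporting it as failed.
        ...
```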
43 changes: 0 additions & 43 deletions tests/models/gpt2/test_modeling_gpt2.py
@@ -14,7 +14,6 @@
# limitations under the License.


import datetime
import math
import unittest

@@ -828,48 +827,6 @@ def test_gpt2_sample(self):
            all(output_seq_strs[idx] != output_seq_tt_strs[idx] for idx in range(len(output_seq_tt_strs)))
        ) # token_type_ids should change output

    @slow
    def test_gpt2_sample_max_time(self):
        tokenizer = GPT2Tokenizer.from_pretrained("openai-community/gpt2")
        model = GPT2LMHeadModel.from_pretrained("openai-community/gpt2")
        model.to(torch_device)

        torch.manual_seed(0)
        tokenized = tokenizer("Today is a nice day and", return_tensors="pt", return_token_type_ids=True)
        input_ids = tokenized.input_ids.to(torch_device)

        MAX_TIME = 0.5

        start = datetime.datetime.now()
        model.generate(input_ids, do_sample=True, max_time=MAX_TIME, max_length=256)
        duration = datetime.datetime.now() - start
        self.assertGreater(duration, datetime.timedelta(seconds=MAX_TIME))
        self.assertLess(duration, datetime.timedelta(seconds=1.5 * MAX_TIME))

        start = datetime.datetime.now()
        model.generate(input_ids, do_sample=False, max_time=MAX_TIME, max_length=256)
        duration = datetime.datetime.now() - start
        self.assertGreater(duration, datetime.timedelta(seconds=MAX_TIME))
        self.assertLess(duration, datetime.timedelta(seconds=1.5 * MAX_TIME))

        start = datetime.datetime.now()
        model.generate(input_ids, do_sample=False, num_beams=2, max_time=MAX_TIME, max_length=256)
        duration = datetime.datetime.now() - start
        self.assertGreater(duration, datetime.timedelta(seconds=MAX_TIME))
        self.assertLess(duration, datetime.timedelta(seconds=1.5 * MAX_TIME))

        start = datetime.datetime.now()
        model.generate(input_ids, do_sample=True, num_beams=2, max_time=MAX_TIME, max_length=256)
        duration = datetime.datetime.now() - start
        self.assertGreater(duration, datetime.timedelta(seconds=MAX_TIME))
        self.assertLess(duration, datetime.timedelta(seconds=1.5 * MAX_TIME))

        start = datetime.datetime.now()
        model.generate(input_ids, do_sample=False, max_time=None, max_length=256)
        duration = datetime.datetime.now() - start
        self.assertGreater(duration, datetime.timedelta(seconds=MAX_TIME))
        self.assertLess(duration, datetime.timedelta(seconds=1.5 * MAX_TIME))

    @slow
    def test_contrastive_search_gpt2(self):
        article = (
42 changes: 0 additions & 42 deletions tests/models/gptj/test_modeling_gptj.py
@@ -14,7 +14,6 @@
# limitations under the License.


import datetime
import unittest

from transformers import GPTJConfig, is_torch_available
@@ -546,47 +545,6 @@ def test_gptj_sample(self):
            all(output_seq_strs[idx] != output_seq_tt_strs[idx] for idx in range(len(output_seq_tt_strs)))
        ) # token_type_ids should change output

    @slow
    def test_gptj_sample_max_time(self):
        tokenizer = AutoTokenizer.from_pretrained("anton-l/gpt-j-tiny-random")
        model = GPTJForCausalLM.from_pretrained("anton-l/gpt-j-tiny-random")
        model.to(torch_device)

        torch.manual_seed(0)
        tokenized = tokenizer("Today is a nice day and", return_tensors="pt", return_token_type_ids=True)
        input_ids = tokenized.input_ids.to(torch_device)

        MAX_TIME = 0.5

        start = datetime.datetime.now()
        model.generate(input_ids, do_sample=True, max_time=MAX_TIME, max_length=256)
        duration = datetime.datetime.now() - start
        self.assertGreater(duration, datetime.timedelta(seconds=MAX_TIME))
        self.assertLess(duration, datetime.timedelta(seconds=1.5 * MAX_TIME))

        start = datetime.datetime.now()
        model.generate(input_ids, do_sample=False, max_time=MAX_TIME, max_length=256)
        duration = datetime.datetime.now() - start
        self.assertGreater(duration, datetime.timedelta(seconds=MAX_TIME))
        self.assertLess(duration, datetime.timedelta(seconds=1.5 * MAX_TIME))

        start = datetime.datetime.now()
        model.generate(input_ids, do_sample=False, num_beams=2, max_time=MAX_TIME, max_length=256)
        duration = datetime.datetime.now() - start
        self.assertGreater(duration, datetime.timedelta(seconds=MAX_TIME))
        self.assertLess(duration, datetime.timedelta(seconds=1.5 * MAX_TIME))

        start = datetime.datetime.now()
        model.generate(input_ids, do_sample=True, num_beams=2, max_time=MAX_TIME, max_length=256)
        duration = datetime.datetime.now() - start
        self.assertGreater(duration, datetime.timedelta(seconds=MAX_TIME))
        self.assertLess(duration, datetime.timedelta(seconds=1.5 * MAX_TIME))

        start = datetime.datetime.now()
        model.generate(input_ids, do_sample=False, max_time=None, max_length=256)
        duration = datetime.datetime.now() - start
        self.assertGreater(duration, datetime.timedelta(seconds=1.5 * MAX_TIME))

    @tooslow
    def test_contrastive_search_gptj(self):
        article = (
42 changes: 0 additions & 42 deletions tests/models/xglm/test_modeling_xglm.py
@@ -13,7 +13,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import datetime
import math
import unittest

@@ -434,47 +433,6 @@ def test_xglm_sample(self):
        ]
        self.assertIn(output_str, EXPECTED_OUTPUT_STRS)

    @slow
    def test_xglm_sample_max_time(self):
        tokenizer = XGLMTokenizer.from_pretrained("facebook/xglm-564M")
        model = XGLMForCausalLM.from_pretrained("facebook/xglm-564M")
        model.to(torch_device)

        torch.manual_seed(0)
        tokenized = tokenizer("Today is a nice day and", return_tensors="pt")
        input_ids = tokenized.input_ids.to(torch_device)

        MAX_TIME = 0.15

        start = datetime.datetime.now()
        model.generate(input_ids, do_sample=True, max_time=MAX_TIME, max_length=256)
        duration = datetime.datetime.now() - start
        self.assertGreater(duration, datetime.timedelta(seconds=MAX_TIME))
        self.assertLess(duration, datetime.timedelta(seconds=1.5 * MAX_TIME))

        start = datetime.datetime.now()
        model.generate(input_ids, do_sample=False, max_time=MAX_TIME, max_length=256)
        duration = datetime.datetime.now() - start
        self.assertGreater(duration, datetime.timedelta(seconds=MAX_TIME))
        self.assertLess(duration, datetime.timedelta(seconds=1.5 * MAX_TIME))

        start = datetime.datetime.now()
        model.generate(input_ids, do_sample=False, num_beams=2, max_time=MAX_TIME, max_length=256)
        duration = datetime.datetime.now() - start
        self.assertGreater(duration, datetime.timedelta(seconds=MAX_TIME))
        self.assertLess(duration, datetime.timedelta(seconds=1.5 * MAX_TIME))

        start = datetime.datetime.now()
        model.generate(input_ids, do_sample=True, num_beams=2, max_time=MAX_TIME, max_length=256)
        duration = datetime.datetime.now() - start
        self.assertGreater(duration, datetime.timedelta(seconds=MAX_TIME))
        self.assertLess(duration, datetime.timedelta(seconds=1.5 * MAX_TIME))

        start = datetime.datetime.now()
        model.generate(input_ids, do_sample=False, max_time=None, max_length=256)
        duration = datetime.datetime.now() - start
        self.assertGreater(duration, datetime.timedelta(seconds=1.25 * MAX_TIME))

    @require_torch_accelerator
    @require_torch_fp16
    def test_batched_nan_fp16(self):
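All five variants of the test assert the same window: generation must run for at least `MAX_TIME` and must finish shortly after it. The overshoot stays small because the budget is evaluated once per generated token. A simplified, hypothetical restatement of the time check (the real implementation is `MaxTimeCriteria` in `transformers.generation.stopping_criteria`):

```python
import time


class SketchedMaxTime:
    """Simplified stand-in for MaxTimeCriteria, for illustration only."""

    def __init__(self, max_time: float):
        self.max_time = max_time
        self.initial_timestamp = time.time()

    def __call__(self, input_ids, scores, **kwargs) -> bool:
        # Called once per generated token; the first check after the budget
        # elapses stops generation, so total runtime overshoots `max_time`
        # by at most roughly one decoding step.
        return time.time() - self.initial_timestamp > self.max_time
```

Under this model, the `1.5 * MAX_TIME` and `2 * MAX_TIME` upper bounds in the assertions simply allow about one decoding step of slack plus scheduling noise.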
