From 75b1a97122976e07d4dda61527fa0ea8b69dec40 Mon Sep 17 00:00:00 2001 From: Yeonsil Yoon Date: Mon, 16 Dec 2024 12:08:44 -0800 Subject: [PATCH] Revert "Update transformers tests generation util v4.45.2 (#1441)" (#1614) This reverts commit 2ba520a9f12467bfeb25a8fde4dee7caf27c0067. --- conftest.py | 141 - .../habana/transformers/generation/utils.py | 29 +- .../transformers/models/bart/modeling_bart.py | 4 +- pyproject.toml | 9 - .../generation/test_framework_agnostic.py | 43 +- .../tests/generation/test_utils.py | 3995 +++++++---------- 6 files changed, 1600 insertions(+), 2621 deletions(-) diff --git a/conftest.py b/conftest.py index 5775644c48..71cb6bb7ca 100644 --- a/conftest.py +++ b/conftest.py @@ -1,88 +1,3 @@ -# coding=utf-8 -# Copyright 2020 The HuggingFace Team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -# tests directory-specific settings - this file is run automatically -# by pytest before any tests are run -import doctest -import sys -import warnings -from os.path import abspath, dirname, join - -import _pytest -import pytest -from transformers.testing_utils import HfDoctestModule, HfDocTestParser - - -NOT_DEVICE_TESTS = { - "test_tokenization", - "test_processor", - "test_processing", - "test_beam_constraints", - "test_configuration_utils", - "test_data_collator", - "test_trainer_callback", - "test_trainer_utils", - "test_feature_extraction", - "test_image_processing", - "test_image_processor", - "test_image_transforms", - "test_optimization", - "test_retrieval", - "test_config", - "test_from_pretrained_no_checkpoint", - "test_keep_in_fp32_modules", - "test_gradient_checkpointing_backward_compatibility", - "test_gradient_checkpointing_enable_disable", - "test_save_load_fast_init_from_base", - "test_fast_init_context_manager", - "test_fast_init_tied_embeddings", - "test_save_load_fast_init_to_base", - "test_torch_save_load", - "test_initialization", - "test_forward_signature", - "test_model_get_set_embeddings", - "test_model_main_input_name", - "test_correct_missing_keys", - "test_tie_model_weights", - "test_can_use_safetensors", - "test_load_save_without_tied_weights", - "test_tied_weights_keys", - "test_model_weights_reload_no_missing_tied_weights", - "test_pt_tf_model_equivalence", - "test_mismatched_shapes_have_properly_initialized_weights", - "test_matched_shapes_have_loaded_weights_when_some_mismatched_shapes_exist", - "test_model_is_small", - "test_tf_from_pt_safetensors", - "test_flax_from_pt_safetensors", - "ModelTest::test_pipeline_", # None of the pipeline tests from PipelineTesterMixin (of which XxxModelTest inherits from) are running on device - "ModelTester::test_pipeline_", - "/repo_utils/", - "/utils/", - "/agents/", -} - -# allow having multiple repository checkouts and not needing to remember to rerun -# `pip install -e '.[dev]'` when switching between checkouts and running tests. 
-git_repo_path = abspath(join(dirname(__file__), "src")) -sys.path.insert(1, git_repo_path) - -# silence FutureWarning warnings in tests since often we can't act on them until -# they become normal warnings - i.e. the tests still need to test the current functionality -warnings.simplefilter(action="ignore", category=FutureWarning) - - class Secret: """ Taken from: https://stackoverflow.com/a/67393351 @@ -98,47 +13,9 @@ def __str___(self): return "*******" -def pytest_configure(config): - config.addinivalue_line( - "markers", "is_pt_tf_cross_test: mark test to run only when PT and TF interactions are tested" - ) - config.addinivalue_line( - "markers", "is_pt_flax_cross_test: mark test to run only when PT and FLAX interactions are tested" - ) - config.addinivalue_line("markers", "is_pipeline_test: mark test to run only when pipelines are tested") - config.addinivalue_line("markers", "is_staging_test: mark test to run only in the staging environment") - config.addinivalue_line("markers", "accelerate_tests: mark test that require accelerate") - config.addinivalue_line("markers", "agent_tests: mark the agent tests that are run on their specific schedule") - config.addinivalue_line("markers", "not_device_test: mark the tests always running on cpu") - - -def pytest_collection_modifyitems(items): - for item in items: - if any(test_name in item.nodeid for test_name in NOT_DEVICE_TESTS): - item.add_marker(pytest.mark.not_device_test) - - def pytest_addoption(parser): parser.addoption("--token", action="store", default=None) - from transformers.testing_utils import pytest_addoption_shared - - pytest_addoption_shared(parser) - - -def pytest_terminal_summary(terminalreporter): - from transformers.testing_utils import pytest_terminal_summary_main - - make_reports = terminalreporter.config.getoption("--make-reports") - if make_reports: - pytest_terminal_summary_main(terminalreporter, id=make_reports) - - -def pytest_sessionfinish(session, exitstatus): - # If no tests are collected, pytest exists with code 5, which makes the CI fail. - if exitstatus == 5: - session.exitstatus = 0 - def pytest_generate_tests(metafunc): # This is called for every test. Only get/set command line arguments @@ -146,21 +23,3 @@ def pytest_generate_tests(metafunc): option_value = Secret(metafunc.config.option.token) if "token" in metafunc.fixturenames: metafunc.parametrize("token", [option_value]) - - -# Doctest custom flag to ignore output. -IGNORE_RESULT = doctest.register_optionflag("IGNORE_RESULT") - -OutputChecker = doctest.OutputChecker - - -class CustomOutputChecker(OutputChecker): - def check_output(self, want, got, optionflags): - if IGNORE_RESULT & optionflags: - return True - return OutputChecker.check_output(self, want, got, optionflags) - - -doctest.OutputChecker = CustomOutputChecker -_pytest.doctest.DoctestModule = HfDoctestModule -doctest.DocTestParser = HfDocTestParser diff --git a/optimum/habana/transformers/generation/utils.py b/optimum/habana/transformers/generation/utils.py index d81e0d179a..68b445c1b2 100644 --- a/optimum/habana/transformers/generation/utils.py +++ b/optimum/habana/transformers/generation/utils.py @@ -211,20 +211,19 @@ def _prepare_decoder_input_ids_for_generation( # 2. 
`decoder_start_token_id` must have shape (batch_size, 1) if device is None: device = self.device - if decoder_start_token_id.ndim == 1: - if decoder_start_token_id.shape[0] != batch_size: - raise ValueError( - f"`decoder_start_token_id` expected to have length {batch_size} but got {decoder_start_token_id.shape[0]}" + if token_idx is None: + if decoder_start_token_id.ndim == 1: + if decoder_start_token_id.shape[0] != batch_size: + raise ValueError( + f"`decoder_start_token_id` expected to have length {batch_size} but got {decoder_start_token_id.shape[0]}" + ) + decoder_start_token_id = decoder_start_token_id.view(-1, 1) + else: + decoder_start_token_id = ( + torch.ones((batch_size, 1), dtype=torch.long, device=device) * decoder_start_token_id ) - decoder_start_token_id = decoder_start_token_id.view(-1, 1) else: - decoder_start_token_id = ( - torch.ones((batch_size, 1), dtype=torch.long, device=device) * decoder_start_token_id - ) - - if token_idx is not None: - # creating padded decoder_input_ids to achieve static shapes. - # Later new tokens once generated are copied in to decoder_input_ids based on token_idx + # creating padded decoder_input_ids to achieve static shapes. Later new tokens once generated are copied in to decoder_input_ids based on token_idx max_length = max_new_tokens + 1 if max_new_tokens is not None else self.generation_config.max_length decoder_start_token_id = ( torch.ones((batch_size, 1), dtype=torch.long, device=device) * decoder_start_token_id @@ -3040,8 +3039,7 @@ def expand_if_needed(tensor, new_size, value, dim=-1): if self.generation_config.early_stopping: num_eos_tokens.add_(beam_tokens[0:num_beams].eq(self.config.eos_token_id).sum()) - if self.config.eos_token_id is not None: - beam_scores.add_(torch.where(beam_tokens.eq(self.config.eos_token_id), float("-inf"), 0.0)) + beam_scores.add_(torch.where(beam_tokens.eq(self.config.eos_token_id), float("-inf"), 0.0)) beam_scores = beam_scores.view(batch_size, -1).unsqueeze(0) _, selected = torch.topk(beam_scores, k=num_beams, dim=-1, largest=True, sorted=True) offset = torch.arange(0, torch.numel(beam_scores), beam_scores.shape[-1]).unsqueeze(-1) @@ -3213,9 +3211,6 @@ def move(obj, device): if not output_scores: sequence_outputs["sequence_scores"] = None - if self.generation_config.static_shapes: - raise NotImplementedError("sequence_scores is not implemented for static_shapes") - if self.config.is_encoder_decoder: return GenerateBeamEncoderDecoderOutput( sequences=sequence_outputs["sequences"], diff --git a/optimum/habana/transformers/models/bart/modeling_bart.py b/optimum/habana/transformers/models/bart/modeling_bart.py index 08ea48e1a5..3e5f822cb1 100644 --- a/optimum/habana/transformers/models/bart/modeling_bart.py +++ b/optimum/habana/transformers/models/bart/modeling_bart.py @@ -458,9 +458,7 @@ def gaudi_BartDecoder_forward( # past_key_values_length past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0 - tensor_past_key_values_length = ( - token_idx - 1 if (use_cache and token_idx is not None) else torch.tensor(past_key_values_length) - ) + tensor_past_key_values_length = token_idx - 1 if use_cache else torch.tensor(past_key_values_length) if inputs_embeds is None: inputs_embeds = self.embed_tokens(input) diff --git a/pyproject.toml b/pyproject.toml index f53b25d1c0..b7896da5e8 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -41,12 +41,3 @@ skip-magic-trailing-comma = false # Like Black, automatically detect the appropriate line ending. 
line-ending = "auto" - -[tool.pytest.ini_options] -addopts = "--doctest-glob='**/*.md'" -doctest_optionflags="NUMBER NORMALIZE_WHITESPACE ELLIPSIS" -markers = [ - "flash_attn_test: marks tests related to flash attention (deselect with '-m \"not flash_attn_test\"')", - "bitsandbytes: select (or deselect with `not`) bitsandbytes integration tests", - "generate: marks tests that use the GenerationTesterMixin" -] diff --git a/tests/transformers/tests/generation/test_framework_agnostic.py b/tests/transformers/tests/generation/test_framework_agnostic.py index 906a90a95a..7fcc4de752 100644 --- a/tests/transformers/tests/generation/test_framework_agnostic.py +++ b/tests/transformers/tests/generation/test_framework_agnostic.py @@ -3,12 +3,8 @@ """ import numpy as np -import pytest from transformers import AutoTokenizer -from transformers.testing_utils import slow - - -torch_device = "hpu" +from transformers.testing_utils import slow, torch_device class GenerationIntegrationTestsMixin: @@ -50,8 +46,6 @@ def test_validate_generation_inputs(self): valid_model_kwargs = {"attention_mask": create_tensor_fn(np.zeros_like(input_ids))} model.generate(input_ids, **valid_model_kwargs) - # TODO [gustavo] Enable this test to Optimum-habana - @pytest.mark.xfail def test_custom_logits_processor(self): model_cls = self.framework_dependent_parameters["AutoModelForSeq2SeqLM"] logits_processor_list_cls = self.framework_dependent_parameters["LogitsProcessorList"] @@ -72,8 +66,6 @@ def test_custom_logits_processor(self): bart_model.config.min_length = None bart_model.generate(input_ids, logits_processor=logits_processor) - # TODO [gustavo] Enable this test to Optimum-habana - @pytest.mark.xfail def test_max_new_tokens_encoder_decoder(self): model_cls = self.framework_dependent_parameters["AutoModelForSeq2SeqLM"] return_tensors = self.framework_dependent_parameters["return_tensors"] @@ -230,8 +222,6 @@ def test_transition_scores_greedy_search_normalized(self): ) self.assertTrue(np.allclose(transition_scores, expected_scores, atol=1e-3)) - # TODO [gustavo] Enable this test to Optimum-habana - @pytest.mark.xfail def test_transition_scores_beam_search_encoder_decoder(self): model_cls = self.framework_dependent_parameters["AutoModelForSeq2SeqLM"] return_tensors = self.framework_dependent_parameters["return_tensors"] @@ -267,8 +257,6 @@ def test_transition_scores_beam_search_encoder_decoder(self): self.assertTrue(np.allclose(np.sum(transition_scores, axis=-1), outputs.sequences_scores, atol=1e-3)) - # TODO [gustavo] Enable this test to Optimum-habana - @pytest.mark.xfail def test_transition_scores_beam_search_encoder_decoder_with_eos(self): model_cls = self.framework_dependent_parameters["AutoModelForSeq2SeqLM"] return_tensors = self.framework_dependent_parameters["return_tensors"] @@ -303,8 +291,6 @@ def test_transition_scores_beam_search_encoder_decoder_with_eos(self): self.assertTrue(np.allclose(np.sum(transition_scores, axis=-1), outputs.sequences_scores, atol=1e-3)) - # TODO [gustavo] Enable this test to Optimum-habana - @pytest.mark.xfail def test_transition_scores_beam_search_decoder_only(self): model_cls = self.framework_dependent_parameters["AutoModelForCausalLM"] return_tensors = self.framework_dependent_parameters["return_tensors"] @@ -342,8 +328,6 @@ def test_transition_scores_beam_search_decoder_only(self): self.assertTrue(np.allclose(np.sum(transition_scores, axis=-1), outputs.sequences_scores, atol=1e-3)) - # TODO [gustavo] Enable this test to Optimum-habana - @pytest.mark.xfail def 
test_transition_scores_beam_sample_encoder_decoder(self): model_cls = self.framework_dependent_parameters["AutoModelForSeq2SeqLM"] return_tensors = self.framework_dependent_parameters["return_tensors"] @@ -381,7 +365,6 @@ def test_transition_scores_beam_sample_encoder_decoder(self): self.assertTrue(np.allclose(np.sum(transition_scores, axis=-1), outputs.sequences_scores, atol=1e-3)) @slow - @pytest.mark.skip("Not Implemented: sequence_scores is not implemented for static_shapes") def test_transition_scores_early_stopping(self): # This is an aggressive test that makes sure that `beam_search's` # transition scores are computed correctly for varying `num_return_sequences`, `num_beams` and `batch_size > 1` @@ -417,8 +400,6 @@ def test_transition_scores_early_stopping(self): self.assertTrue(np.allclose(np.sum(transition_scores, axis=-1), outputs.sequences_scores)) - # TODO [gustavo] Enable this test to Optimum-habana - @pytest.mark.xfail def test_encoder_decoder_generate_attention_mask(self): model_cls = self.framework_dependent_parameters["AutoModelForSeq2SeqLM"] return_tensors = self.framework_dependent_parameters["return_tensors"] @@ -520,8 +501,6 @@ def test_generate_too_many_encoder_kwargs(self): with self.assertRaises(ValueError): model.generate(input_ids=input_ids, inputs_embeds=input_ids) - # TODO [gustavo] Enable this test to Optimum-habana - @pytest.mark.xfail def test_generate_input_features_as_encoder_kwarg(self): model_cls = self.framework_dependent_parameters["AutoModelForSpeechSeq2Seq"] floats_tensor = self.framework_dependent_parameters["floats_tensor"] @@ -563,8 +542,6 @@ def test_generate_pixel_values_as_encoder_kwarg(self): self.assertTrue(np.array_equal(output_sequences, output_sequences_kwargs)) self.assertEqual(output_sequences.shape, (2, 5)) - # TODO [gustavo] Enable this test to Optimum-habana - @pytest.mark.xfail def test_generate_encoder_outputs_attention_mask(self): model_cls = self.framework_dependent_parameters["AutoModelForSpeechSeq2Seq"] floats_tensor = self.framework_dependent_parameters["floats_tensor"] @@ -599,6 +576,7 @@ def test_eos_token_id_int_and_list_greedy_search(self): "do_sample": False, "num_beams": 1, } + expectation = 13 tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-gpt2") text = """Hello, my dog is cute and""" @@ -608,7 +586,6 @@ def test_eos_token_id_int_and_list_greedy_search(self): model = model.to(torch_device) tokens = tokens.to(torch_device) - expectation = model.config.max_length # static shape should give max_length eos_token_id = 873 generated_tokens = model.generate(**tokens, eos_token_id=eos_token_id, **generation_kwargs) self.assertTrue(expectation == len(generated_tokens[0])) @@ -628,6 +605,7 @@ def test_eos_token_id_int_and_list_contrastive_search(self): "penalty_alpha": 0.6, "top_k": 4, } + expectation = 17 tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-gpt2") text = """Hello, my dog is cute and""" @@ -637,7 +615,6 @@ def test_eos_token_id_int_and_list_contrastive_search(self): model = model.to(torch_device) tokens = tokens.to(torch_device) - expectation = model.config.max_length # static shape should give max_length eos_token_id = 225 generated_tokens = model.generate(**tokens, eos_token_id=eos_token_id, **generation_kwargs) self.assertTrue(expectation == len(generated_tokens[0])) @@ -646,8 +623,6 @@ def test_eos_token_id_int_and_list_contrastive_search(self): generated_tokens = model.generate(**tokens, eos_token_id=eos_token_id, **generation_kwargs) 
self.assertTrue(expectation == len(generated_tokens[0])) - # TODO [gustavo] Enable this test to Optimum-habana - @pytest.mark.xfail def test_eos_token_id_int_and_list_beam_search(self): model_cls = self.framework_dependent_parameters["AutoModelForCausalLM"] return_tensors = self.framework_dependent_parameters["return_tensors"] @@ -673,10 +648,7 @@ def test_eos_token_id_int_and_list_beam_search(self): padded_correct_condition = expectation < len(generated_tokens[0]) and all( token == model.config.pad_token_id for token in generated_tokens[0][expectation:] ) - static_shape_condition = expectation < len(generated_tokens[0]) and all( - token == eos_token_id for token in generated_tokens[0][expectation:] - ) - self.assertTrue(unpadded_correct_condition or padded_correct_condition or static_shape_condition) + self.assertTrue(unpadded_correct_condition or padded_correct_condition) eos_token_id = [873, 198] generated_tokens = model.generate(**tokens, eos_token_id=eos_token_id, **generation_kwargs) @@ -684,13 +656,8 @@ def test_eos_token_id_int_and_list_beam_search(self): padded_correct_condition = expectation < len(generated_tokens[0]) and all( token == model.config.pad_token_id for token in generated_tokens[0][expectation:] ) - static_shape_condition = expectation < len(generated_tokens[0]) and all( - token in eos_token_id for token in generated_tokens[0][expectation:] - ) - self.assertTrue(unpadded_correct_condition or padded_correct_condition or static_shape_condition) + self.assertTrue(unpadded_correct_condition or padded_correct_condition) - # TODO [gustavo] Enable this test to Optimum-habana - @pytest.mark.xfail def test_generate_vision2text_conditioning(self): model_cls = self.framework_dependent_parameters["AutoModelForVision2Seq"] floats_tensor = self.framework_dependent_parameters["floats_tensor"] diff --git a/tests/transformers/tests/generation/test_utils.py b/tests/transformers/tests/generation/test_utils.py index 954bcd14d5..512935e9dd 100644 --- a/tests/transformers/tests/generation/test_utils.py +++ b/tests/transformers/tests/generation/test_utils.py @@ -14,27 +14,14 @@ # limitations under the License. 
-import copy import inspect -import tempfile import unittest import warnings import numpy as np import pytest -from parameterized import parameterized -from transformers import is_torch_available, pipeline, set_seed -from transformers.testing_utils import ( - is_flaky, - require_accelerate, - require_auto_gptq, - require_quanto, - require_torch, - require_torch_gpu, - require_torch_multi_accelerator, - require_torch_multi_gpu, - slow, -) +from transformers import is_torch_available, pipeline +from transformers.testing_utils import require_torch, slow from optimum.habana.checkpoint_utils import model_is_optimized from optimum.habana.transformers.modeling_utils import adapt_transformers_to_gaudi @@ -45,50 +32,54 @@ if is_torch_available(): import torch - import torch.nn.functional as F from transformers import ( AutoModelForCausalLM, AutoModelForSeq2SeqLM, AutoModelForSpeechSeq2Seq, AutoModelForVision2Seq, - AutoProcessor, AutoTokenizer, - BartForCausalLM, BartForConditionalGeneration, BartTokenizer, GPT2LMHeadModel, GPT2Tokenizer, ImageGPTForCausalImageModeling, + PreTrainedModel, SpeechEncoderDecoderModel, - T5ForConditionalGeneration, ) - from transformers.cache_utils import DynamicCache, EncoderDecoderCache, QuantoQuantizedCache, StaticCache from transformers.generation import ( BeamSampleDecoderOnlyOutput, BeamSampleEncoderDecoderOutput, BeamSearchDecoderOnlyOutput, BeamSearchEncoderDecoderOutput, + BeamSearchScorer, + ConstrainedBeamSearchScorer, DisjunctiveConstraint, + ForcedBOSTokenLogitsProcessor, + ForcedEOSTokenLogitsProcessor, GenerateBeamDecoderOnlyOutput, GenerateBeamEncoderDecoderOutput, GenerateDecoderOnlyOutput, GenerateEncoderDecoderOutput, - GenerationConfig, GreedySearchDecoderOnlyOutput, GreedySearchEncoderDecoderOutput, + HammingDiversityLogitsProcessor, LogitsProcessorList, MaxLengthCriteria, MinLengthLogitsProcessor, + NoBadWordsLogitsProcessor, + NoRepeatNGramLogitsProcessor, PhrasalConstraint, - PromptLookupCandidateGenerator, + RepetitionPenaltyLogitsProcessor, SampleDecoderOnlyOutput, SampleEncoderDecoderOutput, StoppingCriteria, StoppingCriteriaList, - WatermarkDetector, - WatermarkingConfig, + TemperatureLogitsWarper, + TopKLogitsWarper, + TopPLogitsWarper, ) - from transformers.generation.utils import _speculative_sampling + from transformers.generation.candidate_generator import AssistedCandidateGenerator, CandidateGenerator + from transformers.generation.streamers import BaseStreamer torch_device = "hpu" adapt_transformers_to_gaudi() @@ -100,84 +91,116 @@ class GenerationTesterMixin: input_name = "input_ids" max_new_tokens = 3 + def _update_default_model_kwargs(self, model_kwargs): + model_kwargs["limit_hpu_graphs"] = False + model_kwargs["reuse_cache"] = False + model_kwargs["bucket_size"] = -1 + def _get_input_ids_and_config(self, batch_size=2): config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() - # TODO: @raushan or @gante, use `model.main_input_name` as the main input instead of relyinn on `input_ids` - input_ids = inputs_dict.pop(self.input_name)[:batch_size, :] - inputs_dict.pop("attention_mask", None) - - # we don't want encoder-decoder models to start from filled decoder ids - inputs_dict.pop("decoder_input_ids", None) - inputs_dict.pop("decoder_attention_mask", None) + input_ids = inputs_dict[self.input_name] # cut to half length & take max batch_size 3 sequence_length = input_ids.shape[-1] // 2 input_ids = input_ids[:batch_size, :sequence_length] - # we'll set cache use in each test differently - 
inputs_dict.pop("use_cache", None) - - inputs_dict = { - k: v[:batch_size, ...] - for k, v in inputs_dict.items() - if "head_mask" not in k and isinstance(v, torch.Tensor) - } + # generate max 3 tokens + max_length = input_ids.shape[-1] + 3 if config.eos_token_id is not None and config.pad_token_id is None: # hack to allow generate for models such as GPT2 as is done in `generate()` if isinstance(config.eos_token_id, int): config.eos_token_id = [config.eos_token_id] config.pad_token_id = config.eos_token_id[0] - - if self.has_attentions: - attention_mask = torch.ones_like(input_ids, dtype=torch.long) - else: + # TransfoXL has no attention mask + if "transfoxl" in config.__class__.__name__.lower(): attention_mask = None - - # It is important set the eos_token_id to None to ensure that no sequences - # shorter than `max_length` can be generated - config.eos_token_id = None - config.forced_eos_token_id = None - - return config, input_ids, attention_mask, inputs_dict - - def _get_logits_processor_kwargs(self, do_sample=False, config=None): - logits_processor_kwargs = { + else: + attention_mask = torch.ones_like(input_ids, dtype=torch.long)[:batch_size, :sequence_length] + + return config, input_ids, attention_mask, max_length + + @staticmethod + def _get_logits_processor_and_kwargs( + input_length, + eos_token_id, + forced_bos_token_id=None, + forced_eos_token_id=None, + max_length=None, + diversity_penalty=None, + ): + process_kwargs = { + "min_length": input_length + 1 if max_length is None else max_length - 1, "bad_words_ids": [[1, 0]], + "no_repeat_ngram_size": 2, "repetition_penalty": 1.2, - "remove_invalid_values": True, } - if do_sample: - logits_processor_kwargs.update( - { - "top_k": 10, - "top_p": 0.7, - "temperature": 0.7, - } + logits_processor = LogitsProcessorList( + ( + [ + HammingDiversityLogitsProcessor(diversity_penalty, num_beams=2, num_beam_groups=2), + ] + if diversity_penalty is not None + else [] ) - # TODO (joao, raushan): see this comment for a long-term fix - # https://github.com/huggingface/transformers/pull/33593#issuecomment-2361824264) - # This is a band-aid for VLM models, to ensure they don't generate image/video tokens which would cause them - # to crash. On pretrained models this isn't a risk, as they are trained to not generate these tokens. 
- if config is not None: - image_token_index = config.image_token_index if hasattr(config, "image_token_index") else None - video_token_index = config.video_token_index if hasattr(config, "video_token_index") else None - if image_token_index is not None and image_token_index < config.get_text_config().vocab_size: - logits_processor_kwargs["bad_words_ids"].append([image_token_index]) - if video_token_index is not None and video_token_index < config.get_text_config().vocab_size: - logits_processor_kwargs["bad_words_ids"].append([video_token_index]) - - return logits_processor_kwargs - - def _get_beam_kwargs(self, num_return_sequences=1): + + ( + [ + MinLengthLogitsProcessor(process_kwargs["min_length"], eos_token_id), + ] + if eos_token_id is not None + else [] + ) + + ( + [ + ForcedBOSTokenLogitsProcessor(forced_bos_token_id), + ] + if forced_bos_token_id is not None + else [] + ) + + ( + [ForcedEOSTokenLogitsProcessor(max_length, forced_eos_token_id)] + if forced_eos_token_id is not None + else [] + ) + + [ + NoBadWordsLogitsProcessor(process_kwargs["bad_words_ids"], eos_token_id), + NoRepeatNGramLogitsProcessor(process_kwargs["no_repeat_ngram_size"]), + RepetitionPenaltyLogitsProcessor(process_kwargs["repetition_penalty"]), + ] + ) + return process_kwargs, logits_processor + + @staticmethod + def _get_warper_and_kwargs(num_beams): + warp_kwargs = {"top_k": 10, "top_p": 0.7, "temperature": 0.7} + logits_warper = LogitsProcessorList( + [ + TemperatureLogitsWarper(warp_kwargs["temperature"]), + TopKLogitsWarper(top_k=warp_kwargs["top_k"], min_tokens_to_keep=(2 if num_beams > 1 else 1)), + TopPLogitsWarper(top_p=warp_kwargs["top_p"], min_tokens_to_keep=(2 if num_beams > 1 else 1)), + ] + ) + return warp_kwargs, logits_warper + + @staticmethod + def _get_beam_scorer_and_kwargs(batch_size, max_length, num_return_sequences=1): beam_kwargs = { "early_stopping": False, "length_penalty": 2.0, "num_beams": 2, "num_return_sequences": num_return_sequences, } - return beam_kwargs + beam_scorer = BeamSearchScorer( + batch_size=batch_size, + num_beams=beam_kwargs["num_beams"], + device=torch_device, + length_penalty=beam_kwargs["length_penalty"], + do_early_stopping=beam_kwargs["early_stopping"], + num_beam_hyps_to_keep=num_return_sequences, + ) + return beam_kwargs, beam_scorer - def _get_diverse_beam_kwargs(self, num_return_sequences=1): + @staticmethod + def _get_diverse_beam_scorer_and_kwargs(batch_size, max_length, num_return_sequences=1): beam_kwargs = { "early_stopping": False, "length_penalty": 2.0, @@ -186,46 +209,93 @@ def _get_diverse_beam_kwargs(self, num_return_sequences=1): "num_beam_groups": 2, # one beam per group "diversity_penalty": 2.0, } - return beam_kwargs + beam_scorer = BeamSearchScorer( + batch_size=batch_size, + num_beams=beam_kwargs["num_beams"], + device=torch_device, + length_penalty=beam_kwargs["length_penalty"], + do_early_stopping=beam_kwargs["early_stopping"], + num_beam_hyps_to_keep=num_return_sequences, + num_beam_groups=beam_kwargs["num_beam_groups"], + ) + return beam_kwargs, beam_scorer - def _get_constrained_beam_kwargs(self, num_return_sequences=1): + @staticmethod + def _get_constrained_beam_scorer_and_kwargs(batch_size, max_length, constraints, num_return_sequences=1): beam_kwargs = { "early_stopping": False, "length_penalty": 2.0, "num_beams": num_return_sequences * 4, "num_return_sequences": num_return_sequences, } - return beam_kwargs + beam_scorer = ConstrainedBeamSearchScorer( + batch_size=batch_size, + constraints=constraints, + 
num_beams=beam_kwargs["num_beams"], + device=torch_device, + length_penalty=beam_kwargs["length_penalty"], + do_early_stopping=beam_kwargs["early_stopping"], + num_beam_hyps_to_keep=num_return_sequences, + ) + return beam_kwargs, beam_scorer + + @staticmethod + def _get_encoder_outputs( + model, input_ids, attention_mask, output_attentions=None, output_hidden_states=None, num_interleave=1 + ): + encoder = model.get_encoder() + encoder_outputs = encoder( + input_ids, + attention_mask=attention_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + ) + encoder_outputs["last_hidden_state"] = encoder_outputs.last_hidden_state.repeat_interleave( + num_interleave, dim=0 + ) + input_ids = torch.zeros_like(input_ids[:, :1]) + model._get_decoder_start_token_id() + attention_mask = None + return encoder_outputs, input_ids, attention_mask + + @staticmethod + def _get_static_shapes(): + return False def _greedy_generate( self, model, input_ids, attention_mask, - inputs_dict, + max_length, output_scores=False, - output_logits=False, output_attentions=False, output_hidden_states=False, return_dict_in_generate=False, - use_cache=True, ): - logits_processor_kwargs = self._get_logits_processor_kwargs(do_sample=False, config=model.config) + if model.config.is_encoder_decoder: + max_length = 4 + logits_process_kwargs, logits_processor = self._get_logits_processor_and_kwargs( + input_ids.shape[-1], + eos_token_id=model.config.eos_token_id, + forced_bos_token_id=model.config.forced_bos_token_id, + forced_eos_token_id=model.config.forced_eos_token_id, + max_length=max_length, + ) + model_kwargs = {"attention_mask": attention_mask} if attention_mask is not None else {} + model.generation_config.static_shapes = self._get_static_shapes() output_generate = model.generate( input_ids, do_sample=False, num_beams=1, - max_new_tokens=self.max_new_tokens, + max_length=max_length, output_attentions=output_attentions, output_hidden_states=output_hidden_states, output_scores=output_scores, - output_logits=output_logits, return_dict_in_generate=return_dict_in_generate, - use_cache=use_cache, - **logits_processor_kwargs, + remove_invalid_values=True, + **logits_process_kwargs, **model_kwargs, - **inputs_dict, ) return output_generate @@ -235,33 +305,35 @@ def _sample_generate( model, input_ids, attention_mask, - inputs_dict, + max_length, num_return_sequences, + logits_processor, + logits_warper, + logits_warper_kwargs, + process_kwargs, output_scores=False, - output_logits=False, output_attentions=False, output_hidden_states=False, return_dict_in_generate=False, - use_cache=True, ): torch.manual_seed(0) - logits_processor_kwargs = self._get_logits_processor_kwargs(do_sample=True, config=model.config) model_kwargs = {"attention_mask": attention_mask} if attention_mask is not None else {} + self._update_default_model_kwargs(model_kwargs) + model.generation_config.static_shapes = self._get_static_shapes() output_generate = model.generate( input_ids, do_sample=True, num_beams=1, - max_new_tokens=self.max_new_tokens, + max_length=max_length, num_return_sequences=num_return_sequences, output_scores=output_scores, - output_logits=output_logits, output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict_in_generate=return_dict_in_generate, - use_cache=use_cache, - **logits_processor_kwargs, + remove_invalid_values=True, + **logits_warper_kwargs, + **process_kwargs, **model_kwargs, - **inputs_dict, ) return output_generate @@ -271,31 +343,31 @@ def 
_beam_search_generate( model, input_ids, attention_mask, - inputs_dict, + max_length, + beam_scorer, beam_kwargs, + logits_processor, + logits_process_kwargs, output_scores=False, - output_logits=False, output_attentions=False, output_hidden_states=False, return_dict_in_generate=False, - use_cache=True, ): - logits_processor_kwargs = self._get_logits_processor_kwargs(do_sample=False, config=model.config) model_kwargs = {"attention_mask": attention_mask} if attention_mask is not None else {} + self._update_default_model_kwargs(model_kwargs) + model.generation_config.static_shapes = self._get_static_shapes() output_generate = model.generate( input_ids, do_sample=False, - max_new_tokens=self.max_new_tokens, + max_length=max_length, output_scores=output_scores, - output_logits=output_logits, output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict_in_generate=return_dict_in_generate, - use_cache=use_cache, + remove_invalid_values=True, **beam_kwargs, - **logits_processor_kwargs, + **logits_process_kwargs, **model_kwargs, - **inputs_dict, ) return output_generate @@ -305,34 +377,32 @@ def _beam_sample_generate( model, input_ids, attention_mask, - inputs_dict, + max_length, + beam_scorer, beam_kwargs, + logits_warper, + logits_warper_kwargs, output_scores=False, - output_logits=False, output_attentions=False, output_hidden_states=False, return_dict_in_generate=False, - use_cache=True, ): torch.manual_seed(0) - logits_processor_kwargs = self._get_logits_processor_kwargs(do_sample=True, config=model.config) model_kwargs = {"attention_mask": attention_mask} if attention_mask is not None else {} + self._update_default_model_kwargs(model_kwargs) output_generate = model.generate( input_ids, do_sample=True, - max_new_tokens=self.max_new_tokens, + max_length=max_length, output_scores=output_scores, - output_logits=output_logits, output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict_in_generate=return_dict_in_generate, - use_cache=use_cache, + remove_invalid_values=True, **beam_kwargs, - **logits_processor_kwargs, + **logits_warper_kwargs, **model_kwargs, - **inputs_dict, ) - return output_generate def _group_beam_search_generate( @@ -340,31 +410,30 @@ def _group_beam_search_generate( model, input_ids, attention_mask, - inputs_dict, + max_length, + beam_scorer, beam_kwargs, + logits_processor, + logits_process_kwargs, output_scores=False, - output_logits=False, output_attentions=False, output_hidden_states=False, return_dict_in_generate=False, - use_cache=True, ): - logits_processor_kwargs = self._get_logits_processor_kwargs(do_sample=False, config=model.config) model_kwargs = {"attention_mask": attention_mask} if attention_mask is not None else {} + self._update_default_model_kwargs(model_kwargs) output_generate = model.generate( input_ids, do_sample=False, - max_new_tokens=self.max_new_tokens, + max_length=max_length, output_scores=output_scores, - output_logits=output_logits, output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict_in_generate=return_dict_in_generate, - use_cache=use_cache, + remove_invalid_values=True, **beam_kwargs, - **logits_processor_kwargs, + **logits_process_kwargs, **model_kwargs, - **inputs_dict, ) return output_generate @@ -374,33 +443,33 @@ def _constrained_beam_search_generate( model, input_ids, attention_mask, - inputs_dict, + max_length, + constrained_beam_scorer, constraints, beam_kwargs, + logits_processor, + logits_process_kwargs, output_scores=False, - 
output_logits=False, output_attentions=False, output_hidden_states=False, return_dict_in_generate=False, - use_cache=True, ): - logits_processor_kwargs = self._get_logits_processor_kwargs(do_sample=False, config=model.config) model_kwargs = {"attention_mask": attention_mask} if attention_mask is not None else {} + self._update_default_model_kwargs(model_kwargs) + model.generation_config.static_shapes = self._get_static_shapes() output_generate = model.generate( input_ids, do_sample=False, - max_new_tokens=self.max_new_tokens, + max_length=max_length, output_scores=output_scores, - output_logits=output_logits, output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict_in_generate=return_dict_in_generate, + remove_invalid_values=True, constraints=constraints, - use_cache=use_cache, **beam_kwargs, - **logits_processor_kwargs, + **logits_process_kwargs, **model_kwargs, - **inputs_dict, ) return output_generate @@ -410,72 +479,76 @@ def _contrastive_generate( model, input_ids, attention_mask, - inputs_dict, + max_length, output_scores=False, - output_logits=False, output_attentions=False, output_hidden_states=False, return_dict_in_generate=False, - use_cache=True, ): contrastive_search_kwargs = { "penalty_alpha": 0.6, "top_k": 5, } - logits_processor_kwargs = self._get_logits_processor_kwargs(do_sample=False, config=model.config) + if model.config.is_encoder_decoder: + max_length = 4 + logits_process_kwargs, logits_processor = self._get_logits_processor_and_kwargs( + input_ids.shape[-1], + eos_token_id=model.config.eos_token_id, + forced_bos_token_id=model.config.forced_bos_token_id, + forced_eos_token_id=model.config.forced_eos_token_id, + max_length=max_length, + ) + model_kwargs = {"attention_mask": attention_mask} if attention_mask is not None else {} + self._update_default_model_kwargs(model_kwargs) + model.generation_config.static_shapes = self._get_static_shapes() output_generate = model.generate( input_ids, do_sample=False, num_beams=1, - max_new_tokens=self.max_new_tokens, + max_length=max_length, output_attentions=output_attentions, output_hidden_states=output_hidden_states, output_scores=output_scores, - output_logits=output_logits, return_dict_in_generate=return_dict_in_generate, - use_cache=use_cache, - **logits_processor_kwargs, + remove_invalid_values=True, + **logits_process_kwargs, **model_kwargs, **contrastive_search_kwargs, - **inputs_dict, ) return output_generate - @pytest.mark.generate def test_greedy_generate(self): + # check `generate()` and `greedy_search()` are equal for model_class in self.all_generative_model_classes: - config, input_ids, attention_mask, inputs_dict = self._get_input_ids_and_config() - + config, input_ids, attention_mask, max_length = self._get_input_ids_and_config() + # test old generation output for backwards compatibility model = model_class(config).to(torch_device).eval() output_generate = self._greedy_generate( - model=model, input_ids=input_ids, attention_mask=attention_mask, inputs_dict=inputs_dict + model=model, input_ids=input_ids, attention_mask=attention_mask, max_length=max_length ) - if model.config.is_encoder_decoder: self.assertTrue(output_generate.shape[-1] == self.max_new_tokens + 1) else: self.assertTrue(output_generate.shape[-1] == self.max_new_tokens + input_ids.shape[-1]) - @pytest.mark.generate def test_greedy_generate_dict_outputs(self): for model_class in self.all_generative_model_classes: - config, input_ids, attention_mask, inputs_dict = self._get_input_ids_and_config() - + # disable cache + 
config, input_ids, attention_mask, max_length = self._get_input_ids_and_config() + config.use_cache = False model = model_class(config).to(torch_device).eval() output_generate = self._greedy_generate( model=model, input_ids=input_ids, attention_mask=attention_mask, - inputs_dict=inputs_dict, + max_length=max_length, output_scores=True, - output_logits=True, output_hidden_states=True, - output_attentions=self.has_attentions, + output_attentions=True, return_dict_in_generate=True, - use_cache=False, ) if model.config.is_encoder_decoder: @@ -491,50 +564,58 @@ def test_greedy_generate_dict_outputs(self): self._check_outputs(output_generate, input_ids, model.config) - @pytest.mark.generate def test_greedy_generate_dict_outputs_use_cache(self): for model_class in self.all_generative_model_classes: - config, input_ids, attention_mask, inputs_dict = self._get_input_ids_and_config() + # enable cache + config, input_ids, attention_mask, max_length = self._get_input_ids_and_config() if not hasattr(config, "use_cache"): - self.skipTest(reason=f"{model_class.__name__} doesn't support caching") - if any(model_name in model_class.__name__.lower() for model_name in ["rwkv"]): - self.skipTest(reason="Won't fix: model with non-standard dictionary output shapes") + # only relevant if model has "use_cache" + return + config.use_cache = True config.is_decoder = True model = model_class(config).to(torch_device).eval() output_generate = self._greedy_generate( model=model, input_ids=input_ids, attention_mask=attention_mask, - inputs_dict=inputs_dict, + max_length=max_length, output_scores=True, - output_logits=True, output_hidden_states=True, - output_attentions=self.has_attentions, + output_attentions=True, return_dict_in_generate=True, - use_cache=True, ) - if model.config.is_encoder_decoder: - self.assertTrue(output_generate.sequences.shape[-1] == self.max_new_tokens + 1) - else: - self.assertTrue(output_generate.sequences.shape[-1] == self.max_new_tokens + input_ids.shape[-1]) - self._check_outputs(output_generate, input_ids, model.config, use_cache=True) - @pytest.mark.generate def test_sample_generate(self): for model_class in self.all_generative_model_classes: - config, input_ids, attention_mask, inputs_dict = self._get_input_ids_and_config() - + config, input_ids, attention_mask, max_length = self._get_input_ids_and_config() model = model_class(config).to(torch_device).eval() + + if model.config.is_encoder_decoder: + max_length = 4 + + process_kwargs, logits_processor = self._get_logits_processor_and_kwargs( + input_ids.shape[-1], + model.config.eos_token_id, + forced_bos_token_id=model.config.forced_bos_token_id, + forced_eos_token_id=model.config.forced_eos_token_id, + max_length=max_length, + ) + logits_warper_kwargs, logits_warper = self._get_warper_and_kwargs(num_beams=2) + output_generate = self._sample_generate( model=model, input_ids=input_ids, attention_mask=attention_mask, - inputs_dict=inputs_dict, + max_length=max_length, num_return_sequences=1, + logits_processor=logits_processor, + logits_warper=logits_warper, + logits_warper_kwargs=logits_warper_kwargs, + process_kwargs=process_kwargs, ) if model.config.is_encoder_decoder: @@ -542,24 +623,38 @@ def test_sample_generate(self): else: self.assertTrue(output_generate.shape[-1] == self.max_new_tokens + input_ids.shape[-1]) - @pytest.mark.generate def test_sample_generate_dict_output(self): for model_class in self.all_generative_model_classes: - config, input_ids, attention_mask, inputs_dict = self._get_input_ids_and_config() - + # disable cache 
+ config, input_ids, attention_mask, max_length = self._get_input_ids_and_config() + config.use_cache = False model = model_class(config).to(torch_device).eval() + if model.config.is_encoder_decoder: + max_length = 4 + + process_kwargs, logits_processor = self._get_logits_processor_and_kwargs( + input_ids.shape[-1], + model.config.eos_token_id, + forced_bos_token_id=model.config.forced_bos_token_id, + forced_eos_token_id=model.config.forced_eos_token_id, + max_length=max_length, + ) + logits_warper_kwargs, logits_warper = self._get_warper_and_kwargs(num_beams=1) + output_generate = self._sample_generate( model=model, input_ids=input_ids, attention_mask=attention_mask, - inputs_dict=inputs_dict, + max_length=max_length, num_return_sequences=2, + logits_processor=logits_processor, + logits_warper=logits_warper, + logits_warper_kwargs=logits_warper_kwargs, + process_kwargs=process_kwargs, output_scores=True, - output_logits=True, output_hidden_states=True, - output_attentions=self.has_attentions, + output_attentions=True, return_dict_in_generate=True, - use_cache=False, ) if model.config.is_encoder_decoder: @@ -575,20 +670,38 @@ def test_sample_generate_dict_output(self): self._check_outputs(output_generate, input_ids, model.config, num_return_sequences=2) - @pytest.mark.generate def test_beam_search_generate(self): for model_class in self.all_generative_model_classes: - config, input_ids, attention_mask, inputs_dict = self._get_input_ids_and_config() + config, input_ids, attention_mask, max_length = self._get_input_ids_and_config() + + # It is important set set the eos_token_id to None to ensure that no sequences + # shorter than `max_length` can be generated which could lead to flaky circle ci + # failures if the top `num_return_sequences` beams are all shorter than the longest beam + config.eos_token_id = None + config.forced_eos_token_id = None model = model_class(config).to(torch_device).eval() + if model.config.is_encoder_decoder: + max_length = 4 + + logits_process_kwargs, logits_processor = self._get_logits_processor_and_kwargs( + input_ids.shape[-1], + config.eos_token_id, + config.forced_bos_token_id, + config.forced_eos_token_id, + max_length, + ) + beam_kwargs, beam_scorer = self._get_beam_scorer_and_kwargs(input_ids.shape[0], max_length) - beam_kwargs = self._get_beam_kwargs() output_generate = self._beam_search_generate( model=model, input_ids=input_ids, attention_mask=attention_mask, - inputs_dict=inputs_dict, + max_length=max_length, + beam_scorer=beam_scorer, beam_kwargs=beam_kwargs, + logits_process_kwargs=logits_process_kwargs, + logits_processor=logits_processor, ) if model.config.is_encoder_decoder: @@ -596,26 +709,72 @@ def test_beam_search_generate(self): else: self.assertTrue(output_generate.shape[-1] == self.max_new_tokens + input_ids.shape[-1]) - @pytest.mark.generate + if model.config.is_encoder_decoder: + max_length = 4 + beam_kwargs, beam_scorer = self._get_beam_scorer_and_kwargs(input_ids.shape[0], max_length) + + output_generate = self._beam_search_generate( + model=model, + input_ids=input_ids, + attention_mask=attention_mask, + max_length=max_length, + beam_scorer=beam_scorer, + beam_kwargs=beam_kwargs, + logits_process_kwargs=logits_process_kwargs, + logits_processor=logits_processor, + ) + if model.config.is_encoder_decoder: + self.assertTrue(output_generate.shape[-1] == self.max_new_tokens + 1) + else: + self.assertTrue(output_generate.shape[-1] == self.max_new_tokens + input_ids.shape[-1]) + def test_beam_search_generate_dict_output(self): for model_class 
in self.all_generative_model_classes: - config, input_ids, attention_mask, inputs_dict = self._get_input_ids_and_config() + config, input_ids, attention_mask, max_length = self._get_input_ids_and_config() + + # disable cache + config.use_cache = False + + # It is important set set the eos_token_id to None to ensure that no sequences + # shorter than `max_length` can be generated which could lead to flaky circle ci + # failures if the top `num_return_sequences` beams are all shorter than the longest beam + config.eos_token_id = None + config.forced_eos_token_id = None model = model_class(config).to(torch_device).eval() - beam_kwargs = self._get_beam_kwargs() + if model.config.is_encoder_decoder: + max_length = 4 + + logits_process_kwargs, logits_processor = self._get_logits_processor_and_kwargs( + input_ids.shape[-1], + config.eos_token_id, + config.forced_bos_token_id, + config.forced_eos_token_id, + max_length, + ) + beam_kwargs, beam_scorer = self._get_beam_scorer_and_kwargs(input_ids.shape[0], max_length) output_generate = self._beam_search_generate( model=model, input_ids=input_ids, attention_mask=attention_mask, - inputs_dict=inputs_dict, + max_length=max_length, + beam_scorer=beam_scorer, beam_kwargs=beam_kwargs, + logits_process_kwargs=logits_process_kwargs, + logits_processor=logits_processor, output_scores=True, - output_logits=True, output_hidden_states=True, - output_attentions=self.has_attentions, + output_attentions=True, return_dict_in_generate=True, - use_cache=False, ) + if model.config.is_encoder_decoder: + self.assertIsInstance(output_generate, BeamSearchEncoderDecoderOutput) + else: + self.assertIsInstance(output_generate, BeamSearchDecoderOnlyOutput) + + self.assertTrue(output_generate["sequences_scores"].shape == (output_generate["sequences"].shape[0],)) + self.assertTrue((output_generate["sequences_scores"] < 0).all().item()) + if model.config.is_encoder_decoder: self.assertTrue(output_generate.sequences.shape[-1] == self.max_new_tokens + 1) self.assertIsInstance(output_generate, GenerateBeamEncoderDecoderOutput) @@ -631,139 +790,148 @@ def test_beam_search_generate_dict_output(self): output_generate, input_ids, model.config, num_return_sequences=beam_kwargs["num_beams"] ) - @pytest.mark.generate def test_beam_search_generate_dict_outputs_use_cache(self): for model_class in self.all_generative_model_classes: # enable cache - config, input_ids, attention_mask, inputs_dict = self._get_input_ids_and_config() + config, input_ids, attention_mask, max_length = self._get_input_ids_and_config() + + # It is important set set the eos_token_id to None to ensure that no sequences + # shorter than `max_length` can be generated which could lead to flaky circle ci + # failures if the top `num_return_sequences` beams are all shorter than the longest beam + config.eos_token_id = None + config.forced_eos_token_id = None if not hasattr(config, "use_cache"): - self.skipTest(reason=f"{model_class.__name__} doesn't support caching") - if any(model_name in model_class.__name__.lower() for model_name in ["rwkv"]): - self.skipTest(reason="Won't fix: model with non-standard dictionary output shapes") + # only relevant if model has "use_cache" + return model = model_class(config).to(torch_device).eval() - beam_kwargs = self._get_beam_kwargs() + if model.config.is_encoder_decoder: + max_length = 4 + + logits_process_kwargs, logits_processor = self._get_logits_processor_and_kwargs( + input_ids.shape[-1], + config.eos_token_id, + config.forced_bos_token_id, + config.forced_eos_token_id, + 
max_length, + ) + + beam_kwargs, beam_scorer = self._get_beam_scorer_and_kwargs(input_ids.shape[0], max_length) + config.use_cache = True config.is_decoder = True model = model_class(config).to(torch_device).eval() output_generate = self._beam_search_generate( model=model, input_ids=input_ids, attention_mask=attention_mask, - inputs_dict=inputs_dict, + max_length=max_length, + beam_scorer=beam_scorer, beam_kwargs=beam_kwargs, + logits_process_kwargs=logits_process_kwargs, + logits_processor=logits_processor, output_scores=True, - output_logits=True, output_hidden_states=True, - output_attentions=self.has_attentions, + output_attentions=True, return_dict_in_generate=True, - use_cache=True, ) if model.config.is_encoder_decoder: self.assertTrue(output_generate.sequences.shape[-1] == self.max_new_tokens + 1) else: self.assertTrue(output_generate.sequences.shape[-1] == self.max_new_tokens + input_ids.shape[-1]) - self._check_outputs( - output_generate, input_ids, model.config, use_cache=True, num_return_sequences=beam_kwargs["num_beams"] + output_generate, input_ids, model.config, use_cache=True, num_return_sequences=beam_scorer.num_beams ) - @require_accelerate - @require_torch_multi_accelerator - @pytest.mark.generate - def test_model_parallel_beam_search(self): + @pytest.mark.skip("Beam search sampling is not supported by optimum-habana yet") + def test_beam_sample_generate(self): for model_class in self.all_generative_model_classes: - if "xpu" in torch_device: - return unittest.skip(reason="device_map='auto' does not work with XPU devices") - - if model_class._no_split_modules is None: - continue + config, input_ids, attention_mask, max_length = self._get_input_ids_and_config() - config, input_ids, attention_mask, inputs_dict = self._get_input_ids_and_config() + # It is important set set the eos_token_id to None to ensure that no sequences + # shorter than `max_length` can be generated which could lead to flaky circle ci + # failures if the top `num_return_sequences` beams are all shorter than the longest beam + config.eos_token_id = None + config.forced_eos_token_id = None - model = model_class(config).eval() - with tempfile.TemporaryDirectory() as tmp_dir: - model.cpu().save_pretrained(tmp_dir) - new_model = model_class.from_pretrained(tmp_dir, device_map="auto") + logits_warper_kwargs, logits_warper = self._get_warper_and_kwargs(num_beams=1) - new_model.generate( - input_ids, - attention_mask=attention_mask, - max_new_tokens=self.max_new_tokens, - num_beams=2, - **inputs_dict, - ) + model = model_class(config).to(torch_device).eval() - @pytest.mark.skip("Beam search sampling is not supported by optimum-habana yet") - @pytest.mark.generate - def test_beam_sample_generate(self): - for model_class in self.all_generative_model_classes: - config, input_ids, attention_mask, inputs_dict = self._get_input_ids_and_config() + # check `generate()` and `beam_search()` are equal + if model.config.is_encoder_decoder: + max_length = 4 + beam_kwargs, beam_scorer = self._get_beam_scorer_and_kwargs(input_ids.shape[0], max_length) - model = model_class(config).to(torch_device).eval() - beam_kwargs = self._get_beam_kwargs() output_generate = self._beam_sample_generate( model=model, input_ids=input_ids, attention_mask=attention_mask, - inputs_dict=inputs_dict, + max_length=max_length, + beam_scorer=beam_scorer, beam_kwargs=beam_kwargs, + logits_warper=logits_warper, + logits_warper_kwargs=logits_warper_kwargs, ) if model.config.is_encoder_decoder: self.assertTrue(output_generate.shape[-1] == 
self.max_new_tokens + 1) else: self.assertTrue(output_generate.shape[-1] == self.max_new_tokens + input_ids.shape[-1]) - - # for VLMs inputs embeds won't match input ids unless images are encoded and merged with ids properly - # no quick fix available, since obtaining image embeddings step is very model-specific - if any(name in model.__class__.__name__.lower() for name in ("blip", "llava", "paligemma")): - prepare_inputs_for_generation_args = set( - inspect.signature(model.prepare_inputs_for_generation).parameters + if "inputs_embeds" in set(inspect.signature(model.prepare_inputs_for_generation).parameters): + input_embeds = model.get_input_embeddings()(input_ids) + beam_kwargs.update({"inputs_embeds": input_embeds}) + output_generate2 = self._beam_sample_generate( + model=model, + input_ids=None, + attention_mask=attention_mask, + beam_kwargs=beam_kwargs, + logits_warper_kwargs=logits_warper_kwargs, ) - # `inputs_embeds` input is well supported when `cache_positions` is used, because it means the modeling - # code is up to date with our most recent standards - if ( - "inputs_embeds" in prepare_inputs_for_generation_args - and "cache_positions" in prepare_inputs_for_generation_args - ): - input_embeds = model.get_input_embeddings()(input_ids) - beam_kwargs.update({"inputs_embeds": input_embeds}) - output_generate2 = self._beam_sample_generate( - model=model, - input_ids=None, - attention_mask=attention_mask, - inputs_dict={}, - beam_kwargs=beam_kwargs, - ) - torch.testing.assert_close(output_generate[:, input_embeds.shape[1] :], output_generate2) + torch.testing.assert_close(output_generate[:, input_embeds.shape[1] :], output_generate2) @pytest.mark.skip("Beam search sampling is not supported by optimum-habana yet") - @pytest.mark.generate def test_beam_sample_generate_dict_output(self): for model_class in self.all_generative_model_classes: - config, input_ids, attention_mask, inputs_dict = self._get_input_ids_and_config() + config, input_ids, attention_mask, max_length = self._get_input_ids_and_config() + + # disable cache + config.use_cache = False + + # It is important set set the eos_token_id to None to ensure that no sequences + # shorter than `max_length` can be generated which could lead to flaky circle ci + # failures if the top `num_return_sequences` beams are all shorter than the longest beam + config.eos_token_id = None + config.forced_eos_token_id = None model = model_class(config).to(torch_device).eval() - beam_kwargs = self._get_beam_kwargs() + logits_warper_kwargs, logits_warper = self._get_warper_and_kwargs(num_beams=1) + + if model.config.is_encoder_decoder: + max_length = 4 + beam_kwargs, beam_scorer = self._get_beam_scorer_and_kwargs(input_ids.shape[0], max_length) output_generate = self._beam_sample_generate( model=model, input_ids=input_ids, attention_mask=attention_mask, - inputs_dict=inputs_dict, + max_length=max_length, + beam_scorer=beam_scorer, beam_kwargs=beam_kwargs, + logits_warper=logits_warper, + logits_warper_kwargs=logits_warper_kwargs, output_scores=True, - output_logits=True, output_hidden_states=True, - output_attentions=self.has_attentions, + output_attentions=True, return_dict_in_generate=True, - use_cache=False, ) + self.assertTrue(output_generate["sequences_scores"].shape == (output_generate["sequences"].shape[0],)) + self.assertTrue((output_generate["sequences_scores"] < 0).all().item()) + if model.config.is_encoder_decoder: self.assertTrue(output_generate.sequences.shape[-1] == self.max_new_tokens + 1) self.assertIsInstance(output_generate, 
GenerateBeamEncoderDecoderOutput) @@ -779,131 +947,192 @@ def test_beam_sample_generate_dict_output(self): output_generate, input_ids, model.config, num_return_sequences=beam_kwargs["num_beams"] ) - @pytest.mark.generate def test_generate_without_input_ids(self): - config, _, _, _ = self._get_input_ids_and_config() + config, _, _, max_length = self._get_input_ids_and_config() # if no bos token id => cannot generate from None if config.bos_token_id is None: - self.skipTest(reason="bos_token_id is None") - - # hack in case they are equal, otherwise the attn mask will be [0] - if config.bos_token_id == config.pad_token_id: - config.pad_token_id = None + return for model_class in self.all_generative_model_classes: model = model_class(config).to(torch_device) model.eval() - output_ids_generate = model.generate( - do_sample=False, max_new_tokens=self.max_new_tokens, remove_invalid_values=True - ) + output_ids_generate = model.generate(do_sample=False, max_length=max_length, remove_invalid_values=True) self.assertIsNotNone(output_ids_generate) @pytest.mark.skip("Group beam search is not supported by optimum-habana") - @pytest.mark.generate def test_group_beam_search_generate(self): for model_class in self.all_generative_model_classes: - config, input_ids, attention_mask, inputs_dict = self._get_input_ids_and_config() + config, input_ids, attention_mask, max_length = self._get_input_ids_and_config() + + # It is important set set the eos_token_id to None to ensure that no sequences + # shorter than `max_length` can be generated which could lead to flaky circle ci + # failures if the top `num_return_sequences` beams are all shorter than the longest beam + config.eos_token_id = None + config.forced_eos_token_id = None model = model_class(config).to(torch_device).eval() + if model.config.is_encoder_decoder: + max_length = 4 + + logits_process_kwargs, logits_processor = self._get_logits_processor_and_kwargs( + input_ids.shape[-1], + config.eos_token_id, + config.forced_bos_token_id, + config.forced_eos_token_id, + max_length, + diversity_penalty=2.0, + ) + # check `generate()` and `group_beam_search()` are equal - beam_kwargs = self._get_diverse_beam_kwargs() - output_generate = self._group_beam_search_generate( + beam_kwargs, beam_scorer = self._get_diverse_beam_scorer_and_kwargs(input_ids.shape[0], max_length) + output_generate, output_group_beam_search = self._group_beam_search_generate( model=model, input_ids=input_ids, attention_mask=attention_mask, - inputs_dict=inputs_dict, + max_length=max_length, + beam_scorer=beam_scorer, beam_kwargs=beam_kwargs, + logits_processor=logits_processor, + logits_process_kwargs=logits_process_kwargs, ) - if model.config.is_encoder_decoder: - self.assertTrue(output_generate.shape[-1] == self.max_new_tokens + 1) - else: - self.assertTrue(output_generate.shape[-1] == self.max_new_tokens + input_ids.shape[-1]) + self.assertListEqual(output_generate.tolist(), output_group_beam_search.tolist()) - # check `group_beam_search` for higher than 1 `num_return_sequences` + # check `generate()` and `group_beam_search()` are equal for `num_return_sequences` num_return_sequences = 2 - beam_kwargs = self._get_diverse_beam_kwargs(num_return_sequences=num_return_sequences) - output_generate = self._group_beam_search_generate( + if model.config.is_encoder_decoder: + max_length = 4 + beam_kwargs, beam_scorer = self._get_diverse_beam_scorer_and_kwargs( + input_ids.shape[0], max_length, num_return_sequences=num_return_sequences + ) + output_generate, output_group_beam_search = 
self._group_beam_search_generate( model=model, input_ids=input_ids, attention_mask=attention_mask, - inputs_dict=inputs_dict, + max_length=max_length, + beam_scorer=beam_scorer, beam_kwargs=beam_kwargs, + logits_processor=logits_processor, + logits_process_kwargs=logits_process_kwargs, ) - if model.config.is_encoder_decoder: - self.assertTrue(output_generate.shape[-1] == self.max_new_tokens + 1) - else: - self.assertTrue(output_generate.shape[-1] == self.max_new_tokens + input_ids.shape[-1]) + self.assertListEqual(output_generate.tolist(), output_group_beam_search.tolist()) @pytest.mark.skip("Group beam search is not supported by optimum-habana") - @pytest.mark.generate def test_group_beam_search_generate_dict_output(self): for model_class in self.all_generative_model_classes: - config, input_ids, attention_mask, inputs_dict = self._get_input_ids_and_config() + config, input_ids, attention_mask, max_length = self._get_input_ids_and_config() + config.use_cache = False + + # It is important set set the eos_token_id to None to ensure that no sequences + # shorter than `max_length` can be generated which could lead to flaky circle ci + # failures if the top `num_return_sequences` beams are all shorter than the longest beam + config.eos_token_id = None + config.forced_eos_token_id = None model = model_class(config).to(torch_device).eval() - beam_kwargs = self._get_diverse_beam_kwargs() - output_generate = self._group_beam_search_generate( + if model.config.is_encoder_decoder: + max_length = 4 + + logits_process_kwargs, logits_processor = self._get_logits_processor_and_kwargs( + input_ids.shape[-1], + config.eos_token_id, + config.forced_bos_token_id, + config.forced_eos_token_id, + max_length, + diversity_penalty=2.0, + ) + + num_return_sequences = 1 + beam_kwargs, beam_scorer = self._get_diverse_beam_scorer_and_kwargs( + input_ids.shape[0], max_length, num_return_sequences=num_return_sequences + ) + output_generate, output_group_beam_search = self._group_beam_search_generate( model=model, input_ids=input_ids, attention_mask=attention_mask, - inputs_dict=inputs_dict, + max_length=max_length, + beam_scorer=beam_scorer, beam_kwargs=beam_kwargs, + logits_processor=logits_processor, + logits_process_kwargs=logits_process_kwargs, output_scores=True, - output_logits=True, output_hidden_states=True, - output_attentions=self.has_attentions, + output_attentions=True, return_dict_in_generate=True, - use_cache=False, ) if model.config.is_encoder_decoder: - self.assertTrue(output_generate.sequences.shape[-1] == self.max_new_tokens + 1) - self.assertIsInstance(output_generate, GenerateBeamEncoderDecoderOutput) - # Retrocompatibility check + self.assertIsInstance(output_group_beam_search, BeamSearchEncoderDecoderOutput) self.assertIsInstance(output_generate, BeamSearchEncoderDecoderOutput) else: - self.assertTrue(output_generate.sequences.shape[-1] == self.max_new_tokens + input_ids.shape[-1]) - self.assertIsInstance(output_generate, GenerateBeamDecoderOnlyOutput) - # Retrocompatibility check + self.assertIsInstance(output_group_beam_search, BeamSearchDecoderOnlyOutput) self.assertIsInstance(output_generate, BeamSearchDecoderOnlyOutput) - self._check_outputs( - output_generate, input_ids, model.config, num_return_sequences=beam_kwargs["num_beams"] + self.assertListEqual(output_generate.sequences.tolist(), output_group_beam_search.sequences.tolist()) + self.assertTrue( + torch.allclose( + output_generate["sequences_scores"], output_group_beam_search["sequences_scores"], atol=1e-3 + ) ) + 
self.assertTrue(output_generate["sequences_scores"].shape == (output_generate["sequences"].shape[0],)) + self.assertTrue((output_generate["sequences_scores"] < 0).all().item()) + + for output in (output_group_beam_search, output_generate): + self._check_outputs( + output, input_ids, model.config, num_return_sequences=num_return_sequences * beam_scorer.num_beams + ) - # TODO: @gante - @is_flaky() - @pytest.mark.generate def test_constrained_beam_search_generate(self): for model_class in self.all_generative_model_classes: - config, input_ids, attention_mask, inputs_dict = self._get_input_ids_and_config() + config, input_ids, attention_mask, max_length = self._get_input_ids_and_config() + + # It is important set set the eos_token_id to None to ensure that no sequences + # shorter than `max_length` can be generated which could lead to flaky circle ci + # failures if the top `num_return_sequences` beams are all shorter than the longest beam + config.eos_token_id = None + config.forced_eos_token_id = None model = model_class(config).to(torch_device).eval() + max_length = 20 + logits_process_kwargs, logits_processor = self._get_logits_processor_and_kwargs( + input_ids.shape[-1], + config.eos_token_id, + config.forced_bos_token_id, + config.forced_eos_token_id, + max_length, + ) + + # check `generate()` and `constrained_beam_search()` are equal # Sample constraints - min_id = 3 - max_id = config.get_text_config(decoder=True).vocab_size + if not input_ids.dtype == torch.float32: + min_id = torch.min(input_ids) + 3 + max_id = torch.max(input_ids) + else: + # otherwise this throws an error for Speech2TextModel since its inputs are floating points + min_id = 3 + max_id = 100 force_tokens = torch.randint(min_id, max_id, (1, 2)).tolist()[0] constraints = [ PhrasalConstraint(force_tokens), ] - beam_kwargs = self._get_constrained_beam_kwargs() + beam_kwargs, beam_scorer = self._get_constrained_beam_scorer_and_kwargs( + input_ids.shape[0], max_length, constraints, num_return_sequences=1 + ) output_generate = self._constrained_beam_search_generate( model=model, input_ids=input_ids, attention_mask=attention_mask, - inputs_dict=inputs_dict, + max_length=max_length, + constrained_beam_scorer=beam_scorer, constraints=constraints, beam_kwargs=beam_kwargs, + logits_processor=logits_processor, + logits_process_kwargs=logits_process_kwargs, ) - - if model.config.is_encoder_decoder: - self.assertTrue(output_generate.shape[-1] == self.max_new_tokens + 1) - else: - self.assertTrue(output_generate.shape[-1] == self.max_new_tokens + input_ids.shape[-1]) + self.assertTrue(output_generate.shape[-1] == max_length) for generation_output in output_generate: self._check_sequence_inside_sequence(force_tokens, generation_output) @@ -915,63 +1144,86 @@ def test_constrained_beam_search_generate(self): PhrasalConstraint(force_tokens), ] - beam_kwargs = self._get_constrained_beam_kwargs(num_return_sequences=2) + num_return_sequences = 2 + max_length = 20 + + beam_kwargs, beam_scorer = self._get_constrained_beam_scorer_and_kwargs( + input_ids.shape[0], max_length, constraints, num_return_sequences=num_return_sequences + ) output_generate = self._constrained_beam_search_generate( model=model, input_ids=input_ids, attention_mask=attention_mask, - inputs_dict=inputs_dict, + max_length=max_length, + constrained_beam_scorer=beam_scorer, constraints=constraints, beam_kwargs=beam_kwargs, + logits_processor=logits_processor, + logits_process_kwargs=logits_process_kwargs, ) - - if model.config.is_encoder_decoder: - 
self.assertTrue(output_generate.shape[-1] == self.max_new_tokens + 1) - else: - self.assertTrue(output_generate.shape[-1] == self.max_new_tokens + input_ids.shape[-1]) + self.assertTrue(output_generate.shape[-1] == max_length) for generation_output in output_generate: self._check_sequence_inside_sequence(force_tokens, generation_output) - @pytest.mark.generate def test_constrained_beam_search_generate_dict_output(self): for model_class in self.all_generative_model_classes: - config, input_ids, attention_mask, inputs_dict = self._get_input_ids_and_config() + config, input_ids, attention_mask, max_length = self._get_input_ids_and_config() + + # disable cache + config.use_cache = False + + # It is important set set the eos_token_id to None to ensure that no sequences + # shorter than `max_length` can be generated which could lead to flaky circle ci + # failures if the top `num_return_sequences` beams are all shorter than the longest beam + config.eos_token_id = None + config.forced_eos_token_id = None model = model_class(config).to(torch_device).eval() + if model.config.is_encoder_decoder: + max_length = 20 + + logits_process_kwargs, logits_processor = self._get_logits_processor_and_kwargs( + input_ids.shape[-1], + config.eos_token_id, + config.forced_bos_token_id, + config.forced_eos_token_id, + max_length, + ) # Sample constraints min_id = 3 - max_id = model.config.get_text_config(decoder=True).vocab_size + max_id = model.config.vocab_size force_tokens = torch.randint(min_id, max_id, (1, 2)).tolist()[0] constraints = [ PhrasalConstraint(force_tokens), ] - beam_kwargs = self._get_constrained_beam_kwargs() + beam_kwargs, beam_scorer = self._get_constrained_beam_scorer_and_kwargs( + input_ids.shape[0], max_length, constraints, num_return_sequences=1 + ) output_generate = self._constrained_beam_search_generate( model=model, input_ids=input_ids, attention_mask=attention_mask, - inputs_dict=inputs_dict, + max_length=max_length, + constrained_beam_scorer=beam_scorer, constraints=constraints, beam_kwargs=beam_kwargs, + logits_processor=logits_processor, + logits_process_kwargs=logits_process_kwargs, output_scores=True, - output_logits=True, output_hidden_states=True, - output_attentions=self.has_attentions, + output_attentions=True, return_dict_in_generate=True, - use_cache=False, ) - + self.assertTrue(output_generate.sequences.shape[-1] == max_length) if model.config.is_encoder_decoder: - self.assertTrue(output_generate.sequences.shape[-1] == self.max_new_tokens + 1) self.assertIsInstance(output_generate, GenerateBeamEncoderDecoderOutput) # Retrocompatibility check self.assertIsInstance(output_generate, BeamSearchEncoderDecoderOutput) else: - self.assertTrue(output_generate.sequences.shape[-1] == self.max_new_tokens + input_ids.shape[-1]) self.assertIsInstance(output_generate, GenerateBeamDecoderOnlyOutput) # Retrocompatibility check self.assertIsInstance(output_generate, BeamSearchDecoderOnlyOutput) @@ -980,52 +1232,47 @@ def test_constrained_beam_search_generate_dict_output(self): output_generate, input_ids, model.config, num_return_sequences=beam_kwargs["num_beams"] ) - @pytest.mark.generate + self.assertTrue(output_generate["sequences_scores"].shape == (output_generate["sequences"].shape[0],)) + self.assertTrue((output_generate["sequences_scores"] < 0).all().item()) + def test_contrastive_generate(self): + # check `generate()` and `contrastive_search()` are equal for model_class in self.all_generative_model_classes: - if model_class._is_stateful: - self.skipTest(reason="Stateful models don't 
support contrastive search generation") - # won't fix: FSMT and Reformer have a different cache variable type (and format). if any(model_name in model_class.__name__.lower() for model_name in ["fsmt", "reformer"]): - self.skipTest(reason="Won't fix: old model with different cache format") + return - config, input_ids, attention_mask, inputs_dict = self._get_input_ids_and_config() + config, input_ids, attention_mask, max_length = self._get_input_ids_and_config() # NOTE: contrastive search only works with cache on at the moment. if not hasattr(config, "use_cache"): - self.skipTest(reason=f"{model_class.__name__} doesn't support caching") + return + config.use_cache = True config.is_decoder = True # test old generation output for backwards compatibility model = model_class(config).to(torch_device).eval() output_generate = self._contrastive_generate( - model=model, - input_ids=input_ids, - attention_mask=attention_mask, - inputs_dict=inputs_dict, - use_cache=True, + model=model, input_ids=input_ids, attention_mask=attention_mask, max_length=max_length ) if model.config.is_encoder_decoder: self.assertTrue(output_generate.shape[-1] == self.max_new_tokens + 1) else: self.assertTrue(output_generate.shape[-1] == self.max_new_tokens + input_ids.shape[-1]) - @pytest.mark.generate def test_contrastive_generate_dict_outputs_use_cache(self): for model_class in self.all_generative_model_classes: - if model_class._is_stateful: - self.skipTest(reason="Stateful models don't support contrastive search generation") - # won't fix: FSMT and Reformer have a different cache variable type (and format). if any(model_name in model_class.__name__.lower() for model_name in ["fsmt", "reformer"]): - self.skipTest(reason="Won't fix: old model with different cache format") + return - config, input_ids, attention_mask, inputs_dict = self._get_input_ids_and_config() + # enable cache + config, input_ids, attention_mask, max_length = self._get_input_ids_and_config() # NOTE: contrastive search only works with cache on at the moment. 
if not hasattr(config, "use_cache"): - self.skipTest(reason=f"{model_class.__name__} doesn't support caching") + return + config.use_cache = True config.is_decoder = True model = model_class(config).to(torch_device).eval() @@ -1033,40 +1280,36 @@ def test_contrastive_generate_dict_outputs_use_cache(self): model=model, input_ids=input_ids, attention_mask=attention_mask, - inputs_dict=inputs_dict, + max_length=max_length, output_scores=True, - output_logits=True, output_hidden_states=True, - output_attentions=self.has_attentions, + output_attentions=True, return_dict_in_generate=True, - use_cache=True, ) if model.config.is_encoder_decoder: self.assertTrue(output_generate.sequences.shape[-1] == self.max_new_tokens + 1) else: self.assertTrue(output_generate.sequences.shape[-1] == self.max_new_tokens + input_ids.shape[-1]) - self._check_outputs(output_generate, input_ids, model.config, use_cache=True) - @pytest.mark.generate def test_contrastive_generate_low_memory(self): # Check that choosing 'low_memory' does not change the model output for model_class in self.all_generative_model_classes: - if model_class._is_stateful: - self.skipTest(reason="Stateful models don't support contrastive search generation") - - if any(model_name in model_class.__name__.lower() for model_name in ["fsmt", "reformer", "speech2text"]): - self.skipTest(reason="Won't fix: old model with different cache format") - if any(model_name in model_class.__name__.lower() for model_name in ["gptbigcode"]): - self.skipTest(reason="TODO: fix me") + # won't fix: FSMT, Reformer, gptbigcode, and speech2text have a different cache variable type (and format). + if any( + model_name in model_class.__name__.lower() + for model_name in ["fsmt", "reformer", "gptbigcode", "speech2text"] + ): + return - config, input_ids, attention_mask, inputs_dict = self._get_input_ids_and_config(batch_size=1) + config, input_ids, attention_mask, max_length = self._get_input_ids_and_config(batch_size=1) # NOTE: contrastive search only works with cache on at the moment. if not hasattr(config, "use_cache"): - self.skipTest(reason=f"{model_class.__name__} doesn't support caching") + return + config.use_cache = True config.is_decoder = True # test output equality of low versus high memory @@ -1077,10 +1320,8 @@ def test_contrastive_generate_low_memory(self): top_k=4, penalty_alpha=0.6, low_memory=True, - max_new_tokens=self.max_new_tokens, + max_length=max_length, attention_mask=attention_mask, - **inputs_dict, - use_cache=True, ) high_output = model.generate( @@ -1088,10 +1329,8 @@ def test_contrastive_generate_low_memory(self): top_k=4, penalty_alpha=0.6, low_memory=False, - max_new_tokens=self.max_new_tokens, + max_length=max_length, attention_mask=attention_mask, - **inputs_dict, - use_cache=True, ) self.assertListEqual(low_output.tolist(), high_output.tolist()) @@ -1138,75 +1377,89 @@ def test_contrastive_generate_dynamic_shapes(self): ) self.assertListEqual(dynamic_output.tolist(), static_output.tolist()) - @pytest.mark.generate - @unittest.skip("Started to break with https://github.com/huggingface/transformers/pull/33703") - def test_beam_search_low_memory(self): - # Check that choosing 'low_memory' does not change the model output + # TODO [sasarkar] it is supported now. Enable this test, or delete it if its not applicable + @pytest.mark.skip(reason="Assisted decoding not yet supported by optimum-habana") + @slow # TODO(Joao): remove this. Some models (e.g. data2vec, xcom, roberta) have an error rate between 1 and 10%. 
+ def test_assisted_decoding_matches_greedy_search(self): + # This test ensures that the assisted generation does not introduce output changes over greedy search. + # It breaks the pattern in the tests above, for multiple reasons: + # - assisted_decoding, contrarily to the other methods, can't be called on its own (e.g. needs to + # prepare the assistant encoder outputs in the main generate body); + # - assisted_decoding does not support `use_cache = False` + # - assisted_decoding does not support `batch_size > 1` + for model_class in self.all_generative_model_classes: - if model_class._is_stateful: - self.skipTest(reason="May fix in the future: need custom cache handling") + # won't fix: FSMT and Reformer have a different cache variable type (and format). if any(model_name in model_class.__name__.lower() for model_name in ["fsmt", "reformer"]): - self.skipTest(reason="Won't fix: old model with different cache format") + return + # may fix in the future: the following models fail with assisted decoding, and need model-specific fixes if any( model_name in model_class.__name__.lower() - for model_name in [ - "ctrl", - "gptbigcode", - "transo_xl", - "xlnet", - "cpm", - "jamba", - ] + for model_name in ["bigbirdpegasus", "led", "mega", "speech2text", "git", "prophetnet"] ): - self.skipTest(reason="May fix in the future: need model-specific fixes") - config, input_ids, _, _ = self._get_input_ids_and_config(batch_size=2) - # batch_size=1 is ok, but batch_size>1 will cause non-identical output + return - config.use_cache = True - config.is_decoder = True + # This for loop is a naive and temporary effort to make the test less flaky. + failed = 0 + for i in range(10): + # enable cache + config, input_ids, attention_mask, max_length = self._get_input_ids_and_config(batch_size=1) - # test output equality of low versus high memory - model = model_class(config).to(torch_device).eval() + # NOTE: assisted generation only works with cache on at the moment. + if not hasattr(config, "use_cache"): + return - low_output = model.generate( - input_ids, - max_new_tokens=8, - num_beams=5, - early_stopping=True, - low_memory=True, - use_cache=True, - ) + config.use_cache = True + config.is_decoder = True + model = model_class(config).to(torch_device).eval() + output_greedy = model.generate( + input_ids, + attention_mask=attention_mask, + max_length=max_length, + num_beams=1, + do_sample=False, + output_scores=True, + output_hidden_states=True, + output_attentions=True, + return_dict_in_generate=True, + ) + # Note: with assisted generate, if the same model is used as assistant, then all assistant tokens will + # be correct + output_assisted = model.generate( + input_ids, + attention_mask=attention_mask, + max_length=max_length, + num_beams=1, + do_sample=False, + assistant_model=model, + output_scores=True, + output_hidden_states=True, + output_attentions=True, + return_dict_in_generate=True, + ) - high_output = model.generate( - input_ids, - max_new_tokens=8, - num_beams=5, - early_stopping=True, - low_memory=False, - use_cache=True, - ) - self.assertListEqual(low_output.tolist(), high_output.tolist()) + try: + self.assertListEqual(output_greedy.sequences.tolist(), output_assisted.sequences.tolist()) - @pytest.mark.generate - @parameterized.expand([("random",), ("same",)]) - @is_flaky() # Read NOTE (1) below. If there are API issues, all attempts will fail. 
- def test_assisted_decoding_matches_greedy_search(self, assistant_type): - # This test ensures that the assisted generation does not introduce output changes over greedy search. - # NOTE (1): The sentence above is true most of the time, there is a tiny difference in the logits due to matmul - # shape differences -- and it may result in a different output. The input shape difference happens in the - # main model, that runs the forward pass with several candidates at once (as opposed to generating one token at - # a time). See https://github.com/huggingface/transformers/issues/25420#issuecomment-1775317535 for more info. - # NOTE (2): It breaks the pattern in the tests above, for multiple reasons: - # - assisted_decoding, contrarily to the other methods, can't be called on its own (e.g. needs to - # prepare the assistant encoder outputs in the main generate body); - # - assisted_decoding does not support `use_cache = False` - # - assisted_decoding does not support `batch_size > 1` + for output in (output_greedy, output_assisted): + self._check_outputs(output, input_ids, model.config, use_cache=True) + except AssertionError: + failed += 1 + if failed > 1: + self.assertListEqual(output_greedy.sequences.tolist(), output_assisted.sequences.tolist()) + + for output in (output_greedy, output_assisted): + self._check_outputs(output, input_ids, model.config, use_cache=True) + # TODO [sasarkar] it is supported now. Enable this test, or delete it if its not applicable + @pytest.mark.skip(reason="Assisted decoding not yet supported by optimum-habana") + def test_assisted_decoding_sample(self): + # In this test we don't check assisted vs non-assisted output -- seeded assisted decoding with sample will not + # match sample for the same seed, as the forward pass does not return the exact same logits (due to matmul with + # different shapes, see https://github.com/huggingface/transformers/issues/25420#issuecomment-1775317535). for model_class in self.all_generative_model_classes: - if model_class._is_stateful: - self.skipTest(reason="Stateful models don't support assisted generation") if any(model_name in model_class.__name__.lower() for model_name in ["fsmt", "reformer"]): - self.skipTest(reason="Won't fix: old model with different cache format") + self.skipTest("Won't fix: old model with different cache format") if any( model_name in model_class.__name__.lower() for model_name in [ @@ -1220,15 +1473,16 @@ def test_assisted_decoding_matches_greedy_search(self, assistant_type): "clvp", ] ): - self.skipTest(reason="May fix in the future: need model-specific fixes") + self.skipTest("May fix in the future: need model-specific fixes") # enable cache - config, input_ids, attention_mask, inputs_dict = self._get_input_ids_and_config(batch_size=1) + config, input_ids, attention_mask, _ = self._get_input_ids_and_config(batch_size=1) # NOTE: assisted generation only works with cache on at the moment. 
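As a companion to the notes above, a hedged, self-contained sketch of the invariant that `test_assisted_decoding_matches_greedy_search` checks: greedy decoding with an `assistant_model` should reproduce plain greedy decoding, up to the rare numerical caveat described in NOTE (1). Using the model as its own assistant, as the test does, makes every candidate token a guaranteed match; the checkpoint name is only illustrative.

```python
# Sketch only: assisted greedy decoding should match plain greedy decoding.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("distilgpt2")
model = AutoModelForCausalLM.from_pretrained("distilgpt2").eval()
inputs = tokenizer("It might be possible to", return_tensors="pt")

with torch.no_grad():
    greedy = model.generate(**inputs, do_sample=False, max_new_tokens=8)
    # Same model as assistant: all candidate tokens are accepted, so equality is expected.
    assisted = model.generate(**inputs, do_sample=False, max_new_tokens=8, assistant_model=model)

assert greedy.tolist() == assisted.tolist()
```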
if not hasattr(config, "use_cache"): - self.skipTest(reason=f"{model_class.__name__} doesn't support caching") + self.skipTest("This model doesn't support caching") + config.use_cache = True config.is_decoder = True model = model_class(config).to(torch_device).eval() # Sets assisted generation arguments such that: @@ -1237,253 +1491,503 @@ def test_assisted_decoding_matches_greedy_search(self, assistant_type): # the assistant model is correct # c) there are at least two forward passes in the main model, to ensure the input preparation of # the main model is correct - generation_kwargs = { - "eos_token_id": -1, # see a) - "max_new_tokens": 4, # see c) - "num_beams": 1, - "do_sample": False, - "output_scores": True, - "output_logits": True, - "output_hidden_states": True, - "output_attentions": self.has_attentions, - "return_dict_in_generate": True, - "use_cache": True, - } - output_greedy = model.generate( - input_ids, attention_mask=attention_mask, **generation_kwargs, **inputs_dict - ) - - # test with the same assistant model or randomly init one - # in the first case all candidate tokens are accepted, in the second none is accepted - # case when some are accepted and some not is hard to reproduce, so let's hope this catches most errors :) - if assistant_type == "random": - assistant_model = model_class(config).to(torch_device).eval() - else: - assistant_model = model + assistant_model = model assistant_model.generation_config.num_assistant_tokens = 2 # see b) assistant_model.generation_config.num_assistant_tokens_schedule = "constant" # see b) - generation_kwargs.update({"assistant_model": assistant_model}) - output_assisted = model.generate( - input_ids, attention_mask=attention_mask, **generation_kwargs, **inputs_dict - ) - - # The two outputs must match and their shape must be as expected - - self.assertListEqual(output_greedy.sequences.tolist(), output_assisted.sequences.tolist()) - for output in (output_greedy, output_assisted): - self._check_outputs(output, input_ids, model.config, use_cache=True) - - @is_flaky() - @pytest.mark.generate - def test_prompt_lookup_decoding_matches_greedy_search(self): - # This test ensures that the prompt lookup generation does not introduce output changes over greedy search. - # This test is mostly a copy of test_assisted_decoding_matches_greedy_search - - for model_class in self.all_generative_model_classes: - if model_class._is_stateful: - self.skipTest(reason="Stateful models don't support assisted generation") - if any(model_name in model_class.__name__.lower() for model_name in ["fsmt", "reformer"]): - self.skipTest(reason="Won't fix: old model with different cache format") - if any( - model_name in model_class.__name__.lower() - for model_name in [ - "bigbirdpegasus", - "led", - "mega", - "speech2text", - "git", - "prophetnet", - "seamlessm4t", - "clvp", - ] - ): - self.skipTest(reason="May fix in the future: need model-specific fixes") - - # enable cache - config, input_ids, attention_mask, inputs_dict = self._get_input_ids_and_config(batch_size=1) - - # NOTE: assisted generation only works with cache on at the moment. 
- if not hasattr(config, "use_cache"): - self.skipTest(reason=f"{model_class.__name__} doesn't support caching") - - config.is_decoder = True - model = model_class(config).to(torch_device).eval() - # Sets assisted generation arguments such that: - # a) no EOS is generated, to ensure generation doesn't break early - # b) the prompt lookup tries to give the model 2 tokens, to ensure the input preparation of - # prompt lookup is correct - # c) there are at least two forward passes in the main model, to ensure the input preparation of - # the main model is correct generation_kwargs = { "eos_token_id": -1, # see a) "max_new_tokens": 4, # see c) "num_beams": 1, - "do_sample": False, + "do_sample": True, + "assistant_model": assistant_model, "output_scores": True, - "output_logits": True, "output_hidden_states": True, - "output_attentions": self.has_attentions, + "output_attentions": True, "return_dict_in_generate": True, - "use_cache": True, } - output_greedy = model.generate( - input_ids, attention_mask=attention_mask, **generation_kwargs, **inputs_dict - ) - - generation_kwargs.update({"prompt_lookup_num_tokens": 2}) # see b) - output_prompt_lookup = model.generate( - input_ids, attention_mask=attention_mask, **generation_kwargs, **inputs_dict - ) - - # The two outputs must match and their shape must be as expected + ####################################################################### + # Monkey patch assisted decoding function till SW issue is resolved + import copy + from types import MethodType + from typing import List, Optional, Union + + from transformers.generation.utils import ( + GenerateDecoderOnlyOutput, + _crop_past_key_values, + _prepare_attention_mask, + _prepare_token_type_ids, + _split_model_outputs, + ) + + def _speculative_sampling( + candidate_input_ids, + candidate_logits, + candidate_length, + new_logits, + last_assistant_token_is_eos, + max_matches, + ): + """ + Applies sampling as in the speculative decoding paper (https://arxiv.org/pdf/2211.17192.pdf, algorithm 1). Returns + the selected tokens, as well as the number of candidate matches. + + NOTE: Unless otherwise stated, the variable names match those in the paper. + """ + new_candidate_input_ids = candidate_input_ids[:, -candidate_length:] + # Gets the probabilities from the logits. q_i and p_i denote the assistant and model probabilities of the tokens + # selected by the assistant, respectively. + q = candidate_logits.softmax(dim=-1) + q_i = q[:, torch.arange(candidate_length), new_candidate_input_ids.squeeze()].squeeze(0, 1) + p = new_logits.softmax(dim=-1) + p_i = p[:, torch.arange(candidate_length), new_candidate_input_ids.squeeze()].squeeze(0, 1) + probability_ratio = p_i / q_i + + # When probability_ratio > 1 (i.e. q_i(x) < p_i(x), or "assistant probability of the candidate token is smaller + # than the model probability for the same token"), keep the token. Otherwise reject with p = 1 - probability_ratio + # (= keep with p = probability_ratio). Keep all the tokens until the first rejection + r_i = torch.rand_like(probability_ratio) + is_accepted = r_i <= probability_ratio + n_matches = ((~is_accepted).cumsum(dim=-1) < 1).sum() # this is `n` in algorithm 1 + + # Ensure we don't generate beyond max_len or an EOS token (not in algorithm 1, but needed for correct behavior) + if last_assistant_token_is_eos and n_matches == candidate_length: + # Output length is assumed to be `n_matches + 1`. 
Since we won't generate another token with the target model + # due to acceptance on EOS we fix `n_matches` + n_matches -= 1 + valid_tokens = new_candidate_input_ids[:, : n_matches + 1] + else: + n_matches = min(n_matches, max_matches) + + # Next token selection: if there is a rejection, adjust the distribution from the main model before sampling. + gamma = min(candidate_logits.shape[1], max_matches) + p_n_plus_1 = p[:, n_matches, :] + if n_matches < gamma: + q_n_plus_1 = q[:, n_matches, :] + p_prime = torch.clamp((p_n_plus_1 - q_n_plus_1), min=0) + p_prime.div_(p_prime.sum()) + else: + p_prime = p_n_plus_1 + t = torch.multinomial(p_prime, num_samples=1).squeeze(1)[None, :] + + # The selected tokens include the matches (if any) plus the next sampled tokens + if n_matches > 0: + valid_tokens = torch.cat((new_candidate_input_ids[:, :n_matches], t), dim=-1) + else: + valid_tokens = t + + return valid_tokens, n_matches + + def assisted_decoding( + self, + input_ids: torch.LongTensor, + assistant_model: Optional["PreTrainedModel"] = None, + candidate_generator: Optional["CandidateGenerator"] = None, + do_sample: bool = False, + logits_processor: Optional[LogitsProcessorList] = None, + logits_warper: Optional[LogitsProcessorList] = None, + stopping_criteria: Optional[StoppingCriteriaList] = None, + pad_token_id: Optional[int] = None, + eos_token_id: Optional[Union[int, List[int]]] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + output_scores: Optional[bool] = None, + return_dict_in_generate: Optional[bool] = None, + synced_gpus: bool = False, + streamer: Optional["BaseStreamer"] = None, + **model_kwargs, + ): + r""" + Generates sequences of token ids for models with a language modeling head using **greedy decoding** or + **sample** (depending on `do_sample`), assisted by candidate sequences. Assisted generation is an example of a + candidate decoding strategy. Can be used for text-decoder, text-to-text, speech-to-text, and vision-to-text + models. + + + + In most cases, you do not need to call [`~generation.GenerationMixin.candidate_decoding`] directly. Use + generate() instead. For an overview of generation strategies and code examples, check the [following + guide](../generation_strategies). + + + + Parameters: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + The sequence used as a prompt for the generation. + candidate_generator (`CandidateGenerator`, *optional*): + A derived instance of [`CandidateGenerator`] that defines how candidate sequences are generated. For + more information, the documentation of [`CandidateGenerator`] should be read. Only one of `assistant_model` or `candidate_generator` should be passed as input to this function. + assistant_model (`PreTrainedModel`, *optional*): + An assistant model that can be used to accelerate generation. The assistant model must have the exact + same tokenizer. The acceleration is achieved when forecasting candidate tokens with the assistent model + is much faster than running generation with the model you're calling generate from. As such, the + assistant model should be much smaller. + do_sample (`bool`, *optional*, defaults to `False`): + Whether or not to use sampling ; use greedy decoding otherwise. + logits_processor (`LogitsProcessorList`, *optional*): + An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsProcessor`] + used to modify the prediction scores of the language modeling head applied at each generation step. 
+ logits_warper (`LogitsProcessorList`, *optional*): + An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsWarper`] used + to warp the prediction score distribution of the language modeling head applied before multinomial + sampling at each generation step. + stopping_criteria (`StoppingCriteriaList`, *optional*): + An instance of [`StoppingCriteriaList`]. List of instances of class derived from [`StoppingCriteria`] + used to tell if the generation loop should stop. + pad_token_id (`int`, *optional*): + The id of the *padding* token. + eos_token_id (`Union[int, List[int]]`, *optional*): + The id of the *end-of-sequence* token. Optionally, use a list to set multiple *end-of-sequence* tokens. + output_attentions (`bool`, *optional*, defaults to `False`): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more details. + output_hidden_states (`bool`, *optional*, defaults to `False`): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors + for more details. + output_scores (`bool`, *optional*, defaults to `False`): + Whether or not to return the prediction scores. See `scores` under returned tensors for more details. + return_dict_in_generate (`bool`, *optional*, defaults to `False`): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. + synced_gpus (`bool`, *optional*, defaults to `False`): + Whether to continue running the while loop until max_length (needed for ZeRO stage 3) + streamer (`BaseStreamer`, *optional*): + Streamer object that will be used to stream the generated sequences. Generated tokens are passed + through `streamer.put(token_ids)` and the streamer is responsible for any further processing. + model_kwargs: + Additional model specific keyword arguments will be forwarded to the `forward` function of the model. + If model is an encoder-decoder model the kwargs should include `encoder_outputs`. + + Return: + [`~generation.GenerateDecoderOnlyOutput`], [`~generation.GenerateEncoderDecoderOutput`] or + `torch.LongTensor`: A `torch.LongTensor` containing the generated tokens (default behaviour) or a + [`~generation.GenerateDecoderOnlyOutput`] if `model.config.is_encoder_decoder=False` and + `return_dict_in_generate=True` or a [`~generation.GenerateEncoderDecoderOutput`] if + `model.config.is_encoder_decoder=True`. + + Examples: + + ```python + >>> from transformers import ( + ... AutoTokenizer, + ... AutoModelForCausalLM, + ... LogitsProcessorList, + ... MinLengthLogitsProcessor, + ... StoppingCriteriaList, + ... MaxLengthCriteria, + ... ) + + >>> tokenizer = AutoTokenizer.from_pretrained("gpt2") + >>> model = AutoModelForCausalLM.from_pretrained("gpt2") + >>> assistant_model = AutoModelForCausalLM.from_pretrained("distilgpt2") + >>> # set pad_token_id to eos_token_id because GPT2 does not have a PAD token + >>> model.generation_config.pad_token_id = model.generation_config.eos_token_id + >>> input_prompt = "It might be possible to" + >>> input_ids = tokenizer(input_prompt, return_tensors="pt").input_ids + >>> # instantiate logits processors + >>> logits_processor = LogitsProcessorList( + ... [ + ... MinLengthLogitsProcessor(10, eos_token_id=model.generation_config.eos_token_id), + ... ] + ... ) + >>> stopping_criteria = StoppingCriteriaList([MaxLengthCriteria(max_length=20)]) + >>> outputs = model.assisted_decoding( + ... input_ids, + ... assistant_model=assistant_model, + ... 
logits_processor=logits_processor, + ... stopping_criteria=stopping_criteria, + ... ) + >>> tokenizer.batch_decode(outputs, skip_special_tokens=True) + ["It might be possible to get a better understanding of the nature of the problem, but it's not"] + ```""" + # handling deprecated arguments + if (assistant_model is None) == (candidate_generator is None): + raise ValueError( + "One (and only one) of `assistant_model` and `candidate_generator` should be defined." + ) - self.assertListEqual(output_greedy.sequences.tolist(), output_prompt_lookup.sequences.tolist()) - for output in (output_greedy, output_prompt_lookup): - self._check_outputs(output, input_ids, model.config, use_cache=True) + if assistant_model is not None: + candidate_generator = AssistedCandidateGenerator( + input_ids=input_ids, + assistant_model=assistant_model, + logits_processor=logits_processor, + model_kwargs=model_kwargs, + eos_token_id=eos_token_id, + ) + warnings.warn( + "Passing `assistant_model` to `assisted_decoding` is deprecated and will be removed in v4.38. " + "Pass the `candidate_generator` argument instead.", + FutureWarning, + ) - @pytest.mark.generate - def test_dola_decoding_sample(self): - # TODO (joao): investigate skips, try to reduce incompatibilities - for model_class in self.all_generative_model_classes: - if model_class._is_stateful: - self.skipTest(reason="Stateful models don't support DoLa decoding") + # init values + logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList() + logits_warper = logits_warper if logits_warper is not None else LogitsProcessorList() + stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList() + pad_token_id = pad_token_id if pad_token_id is not None else self.generation_config.pad_token_id + eos_token_id = eos_token_id if eos_token_id is not None else self.generation_config.eos_token_id + if eos_token_id is not None and pad_token_id is None: + raise ValueError("If `eos_token_id` is defined, make sure that `pad_token_id` is defined.") + if isinstance(eos_token_id, int): + eos_token_id = [eos_token_id] + eos_token_id_tensor = ( + torch.tensor(eos_token_id).to(input_ids.device) if eos_token_id is not None else None + ) + output_scores = output_scores if output_scores is not None else self.generation_config.output_scores + output_attentions = ( + output_attentions if output_attentions is not None else self.generation_config.output_attentions + ) + output_hidden_states = ( + output_hidden_states + if output_hidden_states is not None + else self.generation_config.output_hidden_states + ) + return_dict_in_generate = ( + return_dict_in_generate + if return_dict_in_generate is not None + else self.generation_config.return_dict_in_generate + ) - if any(model_name in model_class.__name__.lower() for model_name in ["reformer"]): - self.skipTest("Skip Reformer as the lm_head input size is 2 * hidden size, adopted from Rev Nets.") + # init attention / hidden states / scores tuples + scores = () if (return_dict_in_generate and output_scores) else None + decoder_attentions = () if (return_dict_in_generate and output_attentions) else None + cross_attentions = () if (return_dict_in_generate and output_attentions) else None + decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None - if any(model_name in model_class.__name__.lower() for model_name in ["marian", "mbart", "pegasus"]): - self.skipTest("DoLa is not supported for models that don't return layerwise hidden states") + # 
if model is an encoder-decoder, retrieve encoder attention weights and hidden states + if return_dict_in_generate and self.config.is_encoder_decoder: + encoder_attentions = ( + model_kwargs["encoder_outputs"].get("attentions") if output_attentions else None + ) + encoder_hidden_states = ( + model_kwargs["encoder_outputs"].get("hidden_states") if output_hidden_states else None + ) - # enable cache if the model is not openai-gpt, xlnet, cpm, or xlm - config, input_ids, attention_mask, inputs_dict = self._get_input_ids_and_config() + # keep track of which sequences are already finished + unfinished_sequences = input_ids.new(input_ids.shape[0]).fill_(1) + + # other auxiliary variables + max_len = stopping_criteria[0].max_length + + this_peer_finished = False # used by synced_gpus only + while True: + if synced_gpus: + # Under synced_gpus the `forward` call must continue until all gpus complete their sequence. + # The following logic allows an early break if all peers finished generating their sequence + this_peer_finished_flag = torch.tensor(0.0 if this_peer_finished else 1.0).to(input_ids.device) + # send 0.0 if we finished, 1.0 otherwise + torch.dist.all_reduce(this_peer_finished_flag, op=torch.dist.ReduceOp.SUM) + # did all peers finish? the reduced sum will be 0.0 then + if this_peer_finished_flag.item() == 0.0: + break + + cur_len = input_ids.shape[-1] + + # 1. Fetch candidate sequences from a `CandidateGenerator` + candidate_input_ids, candidate_logits = candidate_generator.get_candidates(input_ids) + candidate_input_ids = candidate_input_ids.to(self.device) + if candidate_logits is not None: + candidate_logits = candidate_logits.to(self.device) + + candidate_length = candidate_input_ids.shape[1] - input_ids.shape[1] + last_assistant_token_is_eos = ( + ~candidate_input_ids[:, -1] + .tile(eos_token_id_tensor.shape[0], 1) + .ne(eos_token_id_tensor.unsqueeze(1)) + .prod(dim=0) + .bool() + ) - # Encoder-decoder models are not supported - if config.is_encoder_decoder: - self.skipTest("DoLa is not supported for encoder-decoder models") - config.is_decoder = True - model = model_class(config).to(torch_device).eval() + # 2. Use the original model to obtain the next token logits given the candidate sequence. We obtain + # `candidate_length + 1` relevant logits from this process: in the event that all candidates are correct, + # we use this forward pass to also pick the subsequent logits in the original model. - if model.get_output_embeddings() is None: - self.skipTest("DoLa is not supported for models that don't have output embeddings") - # Sets dola generation arguments such that: - # a) no EOS is generated, to ensure generation doesn't break early - # b) there are at least two forward passes in the main model, to ensure the input preparation of - # the main model is correct - generation_kwargs = { - "eos_token_id": -1, # see a) - "max_new_tokens": 4, # see b) - "num_beams": 1, - "do_sample": True, - "output_scores": True, - "output_logits": True, - "output_hidden_states": True, - "output_attentions": self.has_attentions, - "return_dict_in_generate": True, - "use_cache": hasattr(config, "use_cache"), # Some models don't support the cache - } - generation_kwargs.update({"dola_layers": "low"}) - model_kwargs = {"attention_mask": attention_mask} if attention_mask is not None else {} - output_dola = model.generate(input_ids, **model_kwargs, **generation_kwargs, **inputs_dict) - self._check_outputs(output_dola, input_ids, model.config, use_cache=hasattr(config, "use_cache")) + # 2.1. 
Prepare the model inputs + candidate_kwargs = copy.copy(model_kwargs) + candidate_kwargs = _prepare_attention_mask( + candidate_kwargs, candidate_input_ids.shape[1], self.config.is_encoder_decoder + ) + candidate_kwargs = _prepare_token_type_ids(candidate_kwargs, candidate_input_ids.shape[1]) - @pytest.mark.generate - def test_assisted_decoding_sample(self): - # In this test we don't check assisted vs non-assisted output -- seeded assisted decoding with sample will not - # match sample for the same seed, as the forward pass does not return the exact same logits (due to matmul with - # different shapes, see https://github.com/huggingface/transformers/issues/25420#issuecomment-1775317535). - for model_class in self.all_generative_model_classes: - if model_class._is_stateful: - self.skipTest(reason="Stateful models don't support assisted generation") - if any(model_name in model_class.__name__.lower() for model_name in ["fsmt", "reformer"]): - self.skipTest(reason="Won't fix: old model with different cache format") - if any( - model_name in model_class.__name__.lower() - for model_name in [ - "bigbirdpegasus", - "led", - "mega", - "speech2text", - "git", - "prophetnet", - "seamlessm4t", - "clvp", - ] - ): - self.skipTest(reason="May fix in the future: need model-specific fixes") + model_inputs = self.prepare_inputs_for_generation(candidate_input_ids, **candidate_kwargs) - # enable cache - config, input_ids, attention_mask, inputs_dict = self._get_input_ids_and_config(batch_size=1) + # 2.2. Run a forward pass on the candidate sequence + outputs = self( + **model_inputs, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + ) - # NOTE: assisted generation only works with cache on at the moment. - if not hasattr(config, "use_cache"): - self.skipTest(reason=f"{model_class.__name__} doesn't support caching") + # 2.3. Process the new logits + new_logits = outputs.logits[:, -candidate_length - 1 :] # excludes the input prompt if present + if len(logits_processor) > 0: + for i in range(candidate_length + 1): + new_logits[:, i, :] = logits_processor( + candidate_input_ids[:, : cur_len + i], new_logits[:, i, :] + ) + if len(logits_warper) > 0: + for i in range(candidate_length + 1): + new_logits[:, i, :] = logits_warper( + candidate_input_ids[:, : cur_len + i], new_logits[:, i, :] + ) - config.is_decoder = True - model = model_class(config).to(torch_device).eval() - # Sets assisted generation arguments such that: - # a) no EOS is generated, to ensure generation doesn't break early - # b) the assistant model always generates two tokens when it is called, to ensure the input preparation of - # the assistant model is correct - # c) there are at least two forward passes in the main model, to ensure the input preparation of - # the main model is correct - assistant_model = model - assistant_model.generation_config.num_assistant_tokens = 2 # see b) - assistant_model.generation_config.num_assistant_tokens_schedule = "constant" # see b) - generation_kwargs = { - "eos_token_id": -1, # see a) - "max_new_tokens": 4, # see c) - "num_beams": 1, - "do_sample": True, - "assistant_model": assistant_model, - "output_scores": True, - "output_logits": True, - "output_hidden_states": True, - "output_attentions": self.has_attentions, - "return_dict_in_generate": True, - "use_cache": True, - } - output_assisted = model.generate( - input_ids, attention_mask=attention_mask, **generation_kwargs, **inputs_dict - ) + # 3. Select the accepted tokens. 
There are two possible cases: + # Case 1: `do_sample=True` and we have logits for the candidates (originally from speculative decoding) + # 👉 Apply algorithm 1 from the speculative decoding paper (https://arxiv.org/pdf/2211.17192.pdf). + max_matches = max_len - cur_len - 1 + if do_sample and candidate_logits is not None: + valid_tokens, n_matches = _speculative_sampling( + candidate_input_ids, + candidate_logits, + candidate_length, + new_logits, + last_assistant_token_is_eos, + max_matches, + ) - self._check_outputs(output_assisted, input_ids, config, use_cache=True) + # Case 2: all other cases (originally from assisted generation) 👉 Compare the tokens selected from the + # original model logits with the candidate tokens. We can keep the candidate tokens until the first + # mismatch, or until the max length is reached. + else: + if do_sample: + probs = new_logits.softmax(dim=-1) + selected_tokens = torch.multinomial(probs[0, :, :], num_samples=1).squeeze(1)[None, :] + else: + selected_tokens = new_logits.argmax(dim=-1) + + candidate_new_tokens = candidate_input_ids[:, cur_len:] + n_matches = ((~(candidate_new_tokens == selected_tokens[:, :-1])).cumsum(dim=-1) < 1).sum() + + # Ensure we don't generate beyond max_len or an EOS token + if last_assistant_token_is_eos and n_matches == candidate_length: + n_matches -= 1 + n_matches = min(n_matches, max_matches) + valid_tokens = selected_tokens[:, : n_matches + 1] + + # 4. Update variables according to the number of matching assistant tokens. Remember: the token generated + # by the model after the last candidate match is also valid, as it is generated from a correct sequence. + # Because of this last token, assisted generation search reduces to a normal greedy search/sample if there + # is no match. + + # 4.1. Get the valid continuation, after the matching tokens + input_ids = torch.cat((input_ids, valid_tokens), dim=-1) + if streamer is not None: + streamer.put(valid_tokens.cpu()) + new_cur_len = input_ids.shape[-1] + + # 4.2. Discard past key values relative to unused assistant tokens + new_cache_size = new_cur_len - 1 + outputs.past_key_values = _crop_past_key_values(self, outputs.past_key_values, new_cache_size) + + # 5. Update the candidate generation strategy if needed + candidate_generator.update_candidate_strategy(input_ids, new_logits, n_matches) + + if synced_gpus and this_peer_finished: + continue # don't waste resources running the code we don't need + + # Store scores, attentions and hidden_states when required + # Assistant: modified to append one tuple element per token, as in the other generation methods. 
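To make the acceptance rule in step 3 easier to follow, here is a toy, self-contained sketch of the greedy branch (Case 2) using random tensors in place of real logits; the sampling branch (Case 1) only replaces the argmax comparison with the `r_i <= p_i / q_i` test from the speculative-decoding paper.

```python
# Toy illustration of Case 2: keep candidate tokens until the first disagreement
# with the main model's argmax, plus one "bonus" token produced in the same forward pass.
import torch

torch.manual_seed(0)
vocab_size, candidate_length = 16, 4
candidate_new_tokens = torch.randint(0, vocab_size, (1, candidate_length))  # assistant's proposals
new_logits = torch.randn(1, candidate_length + 1, vocab_size)               # main model, one extra position

selected_tokens = new_logits.argmax(dim=-1)  # what plain greedy decoding would pick
# Length of the initial prefix where assistant and main model agree (same formula as above).
n_matches = int(((~(candidate_new_tokens == selected_tokens[:, :-1])).cumsum(dim=-1) < 1).sum())
valid_tokens = selected_tokens[:, : n_matches + 1]
print(n_matches, valid_tokens.tolist())
```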
+ if return_dict_in_generate: + if output_scores: + scores += tuple(new_logits[:, i, :] for i in range(n_matches + 1)) + + if "past_key_values" not in model_kwargs: + added_len = new_cur_len + else: + added_len = n_matches + 1 + + if output_attentions: + if self.config.is_encoder_decoder: + cross_attentions = _split_model_outputs( + cross_attentions, outputs.cross_attentions, cur_len, added_len + ) + decoder_attentions = _split_model_outputs( + decoder_attentions, + outputs.decoder_attentions, + cur_len, + added_len, + is_decoder_attention=True, + ) + else: + decoder_attentions = _split_model_outputs( + decoder_attentions, + outputs.attentions, + cur_len, + added_len, + is_decoder_attention=True, + ) + if output_hidden_states: + if self.config.is_encoder_decoder: + decoder_hidden_states = _split_model_outputs( + decoder_hidden_states, outputs.decoder_hidden_states, cur_len, added_len + ) + else: + decoder_hidden_states = _split_model_outputs( + decoder_hidden_states, outputs.hidden_states, cur_len, added_len + ) + + model_kwargs = self._update_model_kwargs_for_generation( + outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder + ) - @pytest.mark.generate - def test_prompt_lookup_decoding_stops_at_eos(self): - # This test ensures that the prompt lookup generation stops at eos token and does not suggest more tokens - # (see https://github.com/huggingface/transformers/pull/31301) + # if eos_token was found in one sentence, set sentence to finished + if eos_token_id_tensor is not None: + unfinished_sequences = unfinished_sequences.mul( + input_ids[:, -1] + .tile(eos_token_id_tensor.shape[0], 1) + .ne(eos_token_id_tensor.unsqueeze(1)) + .prod(dim=0) + ) - # The main idea is to have an ngram (unigram in our case) that is repeated twice in the input ids. - # First time at the very end, so input ends with the unigrams, and second any arbitrary location. - # Also, we need an EOS token which will be injected just after the arbitrary located ngram. - # We verify that PLD will not copy and propose candidated that contain an EOS token, even if there are overlapping ngrams - # in input ids. Otherwise a proposed EOS along with the trailing (ngrams-1) tokens might be accepted by the target model. 
- # That seems as if the model "generated" and EOS but didn't stop from user's perspective + # stop when each sentence is finished + if unfinished_sequences.max() == 0: + this_peer_finished = True + + # stop if we exceed the maximum length + if stopping_criteria(input_ids, scores): + this_peer_finished = True + + if this_peer_finished and not synced_gpus: + break + + if streamer is not None: + streamer.end() + + if return_dict_in_generate: + if self.config.is_encoder_decoder: + return GenerateEncoderDecoderOutput( + sequences=input_ids, + scores=scores, + encoder_attentions=encoder_attentions, + encoder_hidden_states=encoder_hidden_states, + decoder_attentions=decoder_attentions, + cross_attentions=cross_attentions, + decoder_hidden_states=decoder_hidden_states, + past_key_values=model_kwargs.get("past_key_values"), + ) + else: + return GenerateDecoderOnlyOutput( + sequences=input_ids, + scores=scores, + attentions=decoder_attentions, + hidden_states=decoder_hidden_states, + past_key_values=model_kwargs.get("past_key_values"), + ) + else: + return input_ids - input_ids = torch.randint(1, 50, (1, 10), device=torch_device) # generate inputs in range from 1-50 - arbitrary_ngram = 51 # this is the arbitrary ngram, specifically chosen OOV to prevent flaky tests - input_ids[:, 3] = arbitrary_ngram # set pre-eos to arbitrary_ngram which is for sure not present in inputs - input_ids[:, -1] = arbitrary_ngram # put arbitrary_ngram in the end for the necessary match to happen + model.assisted_decoding = MethodType(assisted_decoding, model) - eos_token_id = torch.tensor([0], device=torch_device) - input_ids[:, 4] = eos_token_id # inject eos-token-id in input ids so that it is located after arbitrary_ngram + ####################################################################### - # init cand geenerator with max_matching_ngram_size=1 to match per-token - candidate_generator = PromptLookupCandidateGenerator( - eos_token_id=eos_token_id, num_output_tokens=4, max_matching_ngram_size=1 - ) - output_prompt_lookup = candidate_generator.get_candidates(input_ids)[0] + output_assisted = model.generate(input_ids, attention_mask=attention_mask, **generation_kwargs) - # PLD shouldn't propose any new tokens based on eos-match - self.assertTrue(output_prompt_lookup.shape[-1] == 10) + self._check_outputs(output_assisted, input_ids, model.config, use_cache=True) - @pytest.mark.generate def test_generate_with_head_masking(self): """Test designed for encoder-decoder models to ensure the attention head masking is used.""" attention_names = ["encoder_attentions", "decoder_attentions", "cross_attentions"] for model_class in self.all_generative_model_classes: - config, input_ids, attention_mask, inputs_dict = self._get_input_ids_and_config() + config, input_ids, attention_mask, max_length = self._get_input_ids_and_config() # We want to test only encoder-decoder models if not config.is_encoder_decoder: continue @@ -1509,93 +2013,60 @@ def test_generate_with_head_masking(self): input_ids, attention_mask=attention_mask, num_beams=1, - output_attentions=self.has_attentions, + output_attentions=True, return_dict_in_generate=True, remove_invalid_values=True, **{name: mask}, - **inputs_dict, ) # We check the state of decoder_attentions and cross_attentions just from the last step attn_weights = out[attn_name] if attn_name == attention_names[0] else out[attn_name][-1] self.assertEqual(sum([w.sum().item() for w in attn_weights]), 0.0) - @pytest.mark.generate def test_left_padding_compatibility(self): - # NOTE: left-padding results 
in small numerical differences. This is expected. - # See https://github.com/huggingface/transformers/issues/25420#issuecomment-1775317535 - - # First, filter out models that don't support left padding - # - The model must have generative capabilities - if len(self.all_generative_model_classes) == 0: - self.skipTest(reason="No generative architecture available for this model.") + # The check done in this test is fairly difficult -- depending on the model architecture, passing the right + # position index for the position embeddings can still result in a different output, due to numerical masking. + # On the other hand, for some types of position embeddings, an incorrect position index can have a minimal + # impact on the output. + # There are two tricks employed to check whether left-padding compatibility is in place: + # 1 - To reduce the negative impact of the numerical attention mask on a correct position index, we set the + # padding size to 1. + # 2 - To reduce the chance of false positives (i.e. passing when it should be failing), we run the check + # multiple times with random inputs, and it has to pass with all of them. + # NOTE: because of 2), there is some chance of false positives in this test. - # - The model must support padding - if not self.has_attentions: - self.skipTest(reason="This model doesn't support padding.") - - # - The model must be a decoder-only architecture (encoder-based architectures use right-padding) - decoder_only_classes = [] for model_class in self.all_generative_model_classes: config, _, _, _ = self._get_input_ids_and_config() if config.is_encoder_decoder: - continue - else: - decoder_only_classes.append(model_class) - if len(decoder_only_classes) == 0: - self.skipTest(reason="No decoder-only architecture available for this model.") - - # - Decoder-only architectures derived from encoder-decoder models could support it in theory, but we haven't - # added support for it yet. We skip these models for now. - has_encoder_attributes = any( - attr_name - for attr_name in config.to_dict().keys() - if attr_name.startswith("encoder") and attr_name != "encoder_no_repeat_ngram_size" - ) - if has_encoder_attributes: - self.skipTest( - reason="The decoder-only derived from encoder-decoder models are not expected to support left-padding." 
- ) - - # Then, test left-padding - def _prepare_model_kwargs(input_ids, attention_mask, signature): - model_kwargs = {"input_ids": input_ids, "attention_mask": attention_mask} - if "position_ids" in signature: - position_ids = torch.cumsum(attention_mask, dim=-1) - 1 - position_ids.masked_fill_(attention_mask == 0, 1) - model_kwargs["position_ids"] = position_ids - if "cache_position" in signature: - cache_position = torch.arange(input_ids.shape[-1], device=torch_device) - model_kwargs["cache_position"] = cache_position - return model_kwargs - - for model_class in decoder_only_classes: - config, input_ids, attention_mask, _ = self._get_input_ids_and_config() + continue # skip for encoder-decoder models -- they don't need left-padding compatibility model = model_class(config).to(torch_device).eval() signature = inspect.signature(model.forward).parameters.keys() - # no cache as some models require special cache classes to be init outside forward - model.generation_config.use_cache = False - - # Without padding - model_kwargs = _prepare_model_kwargs(input_ids, attention_mask, signature) - next_logits_wo_padding = model(**model_kwargs).logits[:, -1, :] - - # With left-padding (length 32) - # can hardcode pad_token to be 0 as we'll do attn masking anyway - pad_token_id = ( - config.get_text_config().pad_token_id if config.get_text_config().pad_token_id is not None else 0 - ) - pad_size = (input_ids.shape[0], 32) - padding = torch.ones(pad_size, dtype=input_ids.dtype, device=torch_device) * pad_token_id - padded_input_ids = torch.cat((padding, input_ids), dim=1) - padded_attention_mask = torch.cat((torch.zeros_like(padding), attention_mask), dim=1) - model_kwargs = _prepare_model_kwargs(padded_input_ids, padded_attention_mask, signature) - next_logits_with_padding = model(**model_kwargs).logits[:, -1, :] - - # They should result in very similar logits - self.assertTrue(torch.allclose(next_logits_wo_padding, next_logits_with_padding, atol=1e-5)) + no_failures = True + for _ in range(10): # there may be false positives with 10 runs, we rely on the CI to catch the flakiness + _, input_ids, attention_mask, _ = self._get_input_ids_and_config() + model_kwargs = {"input_ids": input_ids, "attention_mask": attention_mask} + if "position_ids" in signature: + position_ids = torch.cumsum(attention_mask, dim=-1) - 1 + position_ids.masked_fill_(attention_mask == 0, 1) + model_kwargs["position_ids"] = position_ids + next_logits_wo_padding = model(**model_kwargs).logits[:, -1, :] + + pad_size = (input_ids.shape[0], 1) + padding = torch.ones(pad_size, dtype=input_ids.dtype, device=torch_device) * config.pad_token_id + padded_input_ids = torch.cat((padding, input_ids), dim=1) + padded_attention_mask = torch.cat((torch.zeros_like(padding), attention_mask), dim=1) + model_kwargs = {"input_ids": padded_input_ids, "attention_mask": padded_attention_mask} + if "position_ids" in signature: + position_ids = torch.cumsum(padded_attention_mask, dim=-1) - 1 + position_ids.masked_fill_(padded_attention_mask == 0, 1) + model_kwargs["position_ids"] = position_ids + next_logits_with_padding = model(**model_kwargs).logits[:, -1, :] + if not torch.allclose(next_logits_wo_padding, next_logits_with_padding, atol=1e-7): + no_failures = False + break + + self.assertTrue(no_failures) - @pytest.mark.generate def test_past_key_values_format(self): # Test that the KV cache is formatted correctly. Exceptions need to explicitly overwrite this test. 
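As a side note, the left-padding compatibility check shown above boils down to comparing last-token logits with and without one column of left padding. A hedged, standalone sketch of that comparison (using an arbitrary tiny decoder-only checkpoint instead of the tested model classes) is:

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-gpt2").eval()
tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-gpt2")
inputs = tokenizer("Today a dragon flew over Paris", return_tensors="pt")
input_ids, attention_mask = inputs.input_ids, inputs.attention_mask

def logits_of_last_token(input_ids, attention_mask):
    # position ids derived from the attention mask, as in the test above
    position_ids = torch.cumsum(attention_mask, dim=-1) - 1
    position_ids.masked_fill_(attention_mask == 0, 1)
    out = model(input_ids=input_ids, attention_mask=attention_mask, position_ids=position_ids)
    return out.logits[:, -1, :]

no_pad = logits_of_last_token(input_ids, attention_mask)

pad_token_id = model.config.pad_token_id if model.config.pad_token_id is not None else 0
padding = torch.full((input_ids.shape[0], 1), pad_token_id, dtype=input_ids.dtype)
left_padded = logits_of_last_token(
    torch.cat((padding, input_ids), dim=1),
    torch.cat((torch.zeros_like(padding), attention_mask), dim=1),
)

# left padding should only introduce small numerical differences
print(torch.allclose(no_pad, left_padded, atol=1e-5))
```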
Having a # standard KV cache format is important for a consistent API (and for advanced generation methods). @@ -1604,7 +2075,7 @@ def test_past_key_values_format(self): # If it doesn't support cache, pass the test if not hasattr(config, "use_cache"): - self.skipTest(reason=f"{model_class.__name__} doesn't support caching") + return model = model_class(config).to(torch_device) if "use_cache" not in inputs: @@ -1613,7 +2084,7 @@ def test_past_key_values_format(self): # If "past_key_values" is not returned, pass the test (e.g. RWKV uses a different cache name and format) if "past_key_values" not in outputs: - self.skipTest(reason="This model doesn't return `past_key_values`") + return num_hidden_layers = ( getattr(config, "decoder_layers", None) @@ -1667,7 +2138,6 @@ def test_past_key_values_format(self): past_kv[i][1].shape, (batch_size, num_attention_heads, seq_length, per_head_embed_dim) ) - @pytest.mark.generate def test_generate_from_inputs_embeds_decoder_only(self): # When supported, tests that the decoder model can generate from `inputs_embeds` instead of `input_ids` # if fails, you should probably update the `prepare_inputs_for_generation` function @@ -1694,581 +2164,100 @@ def test_generate_from_inputs_embeds_decoder_only(self): continue # Traditional way of generating text - outputs_from_ids = model.generate( - input_ids, max_new_tokens=5, return_dict_in_generate=True, output_scores=True - ) - self.assertEqual(outputs_from_ids.sequences.shape, (input_ids.shape[0], input_ids.shape[1] + 5)) + outputs_from_ids = model.generate(input_ids) + self.assertEqual(outputs_from_ids.shape, (2, 20)) # Same thing, but from input embeddings (`input_ids` is passed so the prompt is present in the output) inputs_embeds = model.get_input_embeddings()(input_ids) - outputs_from_embeds = model.generate( - input_ids, - inputs_embeds=inputs_embeds, - max_new_tokens=5, - return_dict_in_generate=True, - output_scores=True, - ) - self.assertListEqual(outputs_from_ids.sequences.tolist(), outputs_from_embeds.sequences.tolist()) + outputs_from_embeds = model.generate(input_ids, inputs_embeds=inputs_embeds) + self.assertListEqual(outputs_from_ids.tolist(), outputs_from_embeds.tolist()) - # But if we pass different inputs_embeds, we should get different outputs (the output text may be the - # same, but the logits will almost surely be different) + # But if we pass different inputs_embeds, we should get different outputs + torch.manual_seed(0) random_embeds = torch.rand_like(inputs_embeds) - outputs_from_rand_embeds = model.generate( - input_ids, - inputs_embeds=random_embeds, - max_new_tokens=5, - return_dict_in_generate=True, - output_scores=True, - ) - for i in range(len(outputs_from_rand_embeds.scores)): - self.assertFalse(torch.allclose(outputs_from_embeds.scores[i], outputs_from_rand_embeds.scores[i])) + outputs_from_rand_embeds = model.generate(input_ids, inputs_embeds=random_embeds) + with self.assertRaises(AssertionError): + self.assertListEqual(outputs_from_rand_embeds.tolist(), outputs_from_embeds.tolist()) # input_ids is not a required input -- if we don't pass it, the newly generated tokens will be the same outputs_from_embeds_wo_ids = model.generate( - inputs_embeds=inputs_embeds, max_new_tokens=5, return_dict_in_generate=True, output_scores=True + inputs_embeds=inputs_embeds, max_new_tokens=20 - inputs_embeds.shape[1] ) self.assertListEqual( - outputs_from_embeds.sequences[:, inputs_embeds.shape[1] :].tolist(), - outputs_from_embeds_wo_ids.sequences.tolist(), + outputs_from_embeds[:, 
inputs_embeds.shape[1] :].tolist(), + outputs_from_embeds_wo_ids.tolist(), ) - @pytest.mark.generate - def test_generate_from_inputs_embeds_with_static_cache(self): - """ - Test that StaticCache can generate from inputs_embeds and calculates max_cache_length - correctly in `generate()`. We force the model to not stop generation until max-length is reached - to verify that the cache length is indeed set correctly and we don't run out of index when slicing the cache. - """ - for model_class in self.all_generative_model_classes: - if not model_class._supports_static_cache: - self.skipTest(reason="This model does not support the static cache format") - - config, input_ids, attention_mask, inputs_dict = self._get_input_ids_and_config() - if config.is_encoder_decoder: - self.skipTest(reason="This model is encoder-decoder and has Encoder-Decoder Cache") - - model = model_class(config).to(torch_device).eval() - if "inputs_embeds" not in inspect.signature(model.prepare_inputs_for_generation).parameters.keys(): - self.skipTest(reason="This model does not support `inputs_embeds` in generation") - - model.config.use_cache = True - model.config.is_decoder = True - batch_size, seq_length = input_ids.shape - max_cache_len = 30 + def _check_outputs(self, output, input_ids, config, use_cache=False, num_return_sequences=1): + batch_size, seq_length = input_ids.shape + num_sequences_in_output = batch_size * num_return_sequences + gen_len = ( + output.sequences.shape[-1] - 1 if config.is_encoder_decoder else output.sequences.shape[-1] - seq_length + ) - # here we force to not stop at eos and go until max-length - model.generation_config.eos_token_id = model.config.get_text_config().eos_token_id = -1 - generation_kwargs = { - "max_length": max_cache_len, - "cache_implementation": "static", - "return_dict_in_generate": True, # Required to return `past_key_values` - } + # scores + self._check_scores(num_sequences_in_output, output.scores, length=gen_len, config=config) - text_config = model.config.get_text_config() - head_dim = ( - text_config.head_dim - if hasattr(text_config, "head_dim") - else text_config.hidden_size // text_config.num_attention_heads + # Attentions + if config.is_encoder_decoder: + # encoder + self._check_encoder_attention_for_generate(output.encoder_attentions, batch_size, config, seq_length) + # decoder + self._check_attentions_for_generate( + num_sequences_in_output, + output.decoder_attentions, + min_length=1, + max_length=output.sequences.shape[-1], + config=config, + use_cache=use_cache, ) - num_key_value_heads = ( - text_config.num_attention_heads - if getattr(text_config, "num_key_value_heads", None) is None - else text_config.num_key_value_heads + else: + # if use_cache first input is equal to no use_cache, so skip here + attentions = output.attentions if not use_cache else output.attentions[1:] + min_length = seq_length if not use_cache else seq_length + 1 + self._check_attentions_for_generate( + num_sequences_in_output, + attentions=attentions, + min_length=min_length, + max_length=output.sequences.shape[-1], + config=config, + use_cache=use_cache, ) - num_hidden_layers = text_config.num_hidden_layers - inputs_embeds = model.get_input_embeddings()(input_ids) - outputs = model.generate( - inputs_embeds=inputs_embeds, attention_mask=attention_mask, **generation_kwargs, **inputs_dict + # Hidden States + if config.is_encoder_decoder: + # encoder + self._check_encoder_hidden_states_for_generate( + output.encoder_hidden_states, batch_size, config, seq_length ) - # we should get 
`max_length` in shape, not `max_length - embeds_length` - cache_shape = (batch_size, num_key_value_heads, max_cache_len, head_dim) - self.assertTrue(isinstance(outputs.past_key_values, StaticCache)) - self.assertTrue(len(outputs.past_key_values.key_cache) == num_hidden_layers) - self.assertTrue(outputs.past_key_values.key_cache[0].shape == cache_shape) - - @pytest.mark.generate - def test_generate_continue_from_past_key_values(self): - # Tests that we can continue generating from past key values, returned from a previous `generate` call - for model_class in self.all_generative_model_classes: - if any(model_name in model_class.__name__.lower() for model_name in ["imagegpt"]): - self.skipTest(reason="Won't fix: old model with unique inputs/caches/other") - if any(model_name in model_class.__name__.lower() for model_name in ["umt5"]): - self.skipTest(reason="TODO: needs modeling or test input preparation fixes for compatibility") - - config, inputs = self.model_tester.prepare_config_and_inputs_for_common() + # decoder + self._check_hidden_states_for_generate( + num_sequences_in_output, + output.decoder_hidden_states, + min_length=1, + max_length=output.sequences.shape[-1], + config=config, + use_cache=use_cache, + ) + else: + # if use_cache first input is equal to no use_cache, so skip here + hidden_states = output.hidden_states if not use_cache else output.hidden_states[1:] + min_length = seq_length if not use_cache else seq_length + 1 + self._check_hidden_states_for_generate( + num_sequences_in_output, + hidden_states, + min_length=min_length, + max_length=output.sequences.shape[-1], + config=config, + use_cache=use_cache, + ) - if not hasattr(config, "use_cache"): - self.skipTest(reason=f"{model_class.__name__} doesn't support caching") - - # Let's make it always: - # 1. use cache (for obvious reasons) - # 2. generate to max length (which can be achieved by setting the eos token to an invalid value), which - # would make the test flaky (e.g. EOS is generated on iteration 1 on both generations, but the - # continuation would force it to generate beyond an EOS token) - # 3. ignore `token_type_ids` for simplicity - # 4. ignore `forced_eos_token_id`, which requires further manipulation of the continuation inputs and is - # active by default on some models - # 5. ignore `encoder_no_repeat_ngram_size`, which is set by default in some encoder-decoder models. When - # we use their decoder as a stand-alone model, `encoder_no_repeat_ngram_size` actually prevents - # repetition exclusively from the prompt. This test relies on comparing one call vs 2 calls - # with cache, what is considered a prompt is different in the two cases. - - if "token_type_ids" in inputs: - del inputs["token_type_ids"] - - model = model_class(config).to(torch_device) - model.eval() - model.generation_config.pad_token_id = model.generation_config.eos_token_id = -1 - model.generation_config.forced_eos_token_id = None - model.generation_config.encoder_no_repeat_ngram_size = 0 - model.generation_config.use_cache = True - - # If "past_key_values" is not returned, skip the test (e.g. 
RWKV uses a different cache name and format) - outputs = model(**inputs) - if "past_key_values" not in outputs: - self.skipTest(reason="This model doesn't return `past_key_values`") - - # Traditional way of generating text, with `return_dict_in_generate` to return the past key values - outputs = model.generate(**inputs, do_sample=False, max_new_tokens=4, return_dict_in_generate=True) - - # Let's generate again, but passing the past key values in between (3 + 1 = 4 tokens). Note that the - # inputs may need to be tweaked across `generate` calls (like the attention mask). - outputs_cached = model.generate(**inputs, do_sample=False, max_new_tokens=3, return_dict_in_generate=True) - - # Continue from the tokens generated above, preparing the inputs accordingly - inputs["past_key_values"] = outputs_cached.past_key_values - new_attention_len = outputs_cached.sequences.shape[-1] - if config.is_encoder_decoder: - inputs["decoder_input_ids"] = outputs_cached.sequences - if "decoder_attention_mask" in inputs: - inputs["decoder_attention_mask"] = torch.nn.functional.pad( - inputs["decoder_attention_mask"], - (0, new_attention_len - inputs["decoder_attention_mask"].shape[1]), - mode="constant", - value=1, - ) - else: - inputs["input_ids"] = outputs_cached.sequences - if "attention_mask" in inputs: - inputs["attention_mask"] = torch.nn.functional.pad( - inputs["attention_mask"], - (0, new_attention_len - inputs["attention_mask"].shape[1]), - mode="constant", - value=1, - ) - outputs_cached = model.generate(**inputs, do_sample=False, max_new_tokens=1, return_dict_in_generate=True) - - # The two sets of generated text and past kv should be equal to each other - self.assertListEqual(outputs.sequences.tolist(), outputs_cached.sequences.tolist()) - for layer_idx in range(len(outputs_cached.past_key_values)): - for kv_idx in range(len(outputs_cached.past_key_values[layer_idx])): - self.assertTrue( - torch.allclose( - outputs.past_key_values[layer_idx][kv_idx], - outputs_cached.past_key_values[layer_idx][kv_idx], - ) - ) - - @parameterized.expand([(1, False), (1, True), (4, False)]) - @pytest.mark.generate - def test_new_cache_format(self, num_beams, do_sample): - # Tests that generating with the new format is exactly the same as the legacy one (for models that support it). - # 👉 tests with and without beam search so that we can test with and without cache reordering. - # 👉 tests with and without sampling so we can cover the most common use cases. 
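The cache-format comparison that follows converts between the legacy tuple cache and the `Cache` classes via `from_legacy_cache` / `to_legacy_cache`. A minimal round-trip illustration (assuming `DynamicCache` is available from the installed transformers release; tensor shapes are arbitrary) is:

```python
import torch
from transformers import DynamicCache

# legacy format: tuple over layers of (key, value) tensors,
# each shaped (batch, num_heads, seq_len, head_dim)
legacy = tuple((torch.randn(1, 2, 4, 8), torch.randn(1, 2, 4, 8)) for _ in range(3))

cache = DynamicCache.from_legacy_cache(legacy)  # new cache object
round_trip = cache.to_legacy_cache()            # back to the legacy tuple

for (k_old, v_old), (k_new, v_new) in zip(legacy, round_trip):
    assert torch.equal(k_old, k_new) and torch.equal(v_old, v_new)
```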
- for model_class in self.all_generative_model_classes: - if not model_class._supports_cache_class: - self.skipTest(reason="This model does not support the new cache format") - - config, input_ids, attention_mask, inputs_dict = self._get_input_ids_and_config() - - model = model_class(config).to(torch_device).eval() - generation_kwargs = { - "max_new_tokens": 5, - "do_sample": do_sample, - "num_beams": num_beams, - "num_return_sequences": num_beams, - "return_dict_in_generate": True, # Required to return `past_key_values` - "use_cache": True, - } - - # Sets seed before calling `generate` for the case with do_sample=True - seed = torch.randint(0, 1000000, (1,)).item() - set_seed(seed) - legacy_results = model.generate( - input_ids, attention_mask=attention_mask, **generation_kwargs, **inputs_dict - ) - set_seed(seed) - if config.is_encoder_decoder: - cache_cls = EncoderDecoderCache - past_key_values = cache_cls(DynamicCache(), DynamicCache()) - past_key_values = cache_cls(DynamicCache(), DynamicCache()) - else: - cache_cls = DynamicCache - past_key_values = cache_cls() - - new_results = model.generate( - input_ids, - attention_mask=attention_mask, - past_key_values=past_key_values, - **generation_kwargs, - **inputs_dict, - ) - - # The two sets of generated sequences must match, despite the cache format between forward passes being - # different - self.assertListEqual(legacy_results.sequences.tolist(), new_results.sequences.tolist()) - self.assertTrue(isinstance(legacy_results.past_key_values, tuple)) - self.assertTrue(isinstance(new_results.past_key_values, cache_cls)) - - # The contents of the two caches, when converted to the same format (in both directions!), must match - legacy_cache = legacy_results.past_key_values - new_cache_converted = new_results.past_key_values.to_legacy_cache() - for layer_idx in range(len(legacy_cache)): - for kv_idx in range(len(legacy_cache[layer_idx])): - # TODO: @raushan, please look into this for new cache format - if legacy_cache[layer_idx][kv_idx] != []: - self.assertTrue( - torch.allclose( - legacy_cache[layer_idx][kv_idx], - new_cache_converted[layer_idx][kv_idx], - ) - ) - - new_cache = new_results.past_key_values - legacy_cache_converted = cache_cls.from_legacy_cache(legacy_results.past_key_values) - for layer_idx in range(len(new_cache)): - for kv_idx in range(len(new_cache[layer_idx])): - # TODO: @raushan, please look into this for new cache format - if new_cache[layer_idx][kv_idx] != []: - self.assertTrue( - torch.allclose( - new_cache[layer_idx][kv_idx], - legacy_cache_converted[layer_idx][kv_idx], - ) - ) - - @pytest.mark.generate - def test_generate_with_static_cache(self): - """ - Tests if StaticCache works if we set attn_implementation=static when generation. - This doesn't test if generation quality is good, but tests that models with - self._supports_static_cache don't throw an error when generating and return - a StaticCache object at the end. 
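Outside the test harness, the static-cache behaviour being checked here can be sketched as follows. This is an illustrative example, not the test itself, and assumes a tiny Llama checkpoint whose architecture supports the static cache.

```python
from transformers import AutoModelForCausalLM, AutoTokenizer, StaticCache

checkpoint = "hf-internal-testing/tiny-random-LlamaForCausalLM"  # assumed static-cache-capable model
model = AutoModelForCausalLM.from_pretrained(checkpoint)
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
inputs = tokenizer("Hello world", return_tensors="pt")

out = model.generate(
    **inputs,
    max_new_tokens=5,
    cache_implementation="static",
    return_dict_in_generate=True,  # needed so past_key_values is returned
)
assert isinstance(out.past_key_values, StaticCache)
# each key_cache entry is pre-allocated to (batch, num_kv_heads, max_cache_len, head_dim)
print(out.past_key_values.key_cache[0].shape)
```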
- """ - for model_class in self.all_generative_model_classes: - if not model_class._supports_static_cache: - self.skipTest(reason="This model does not support the static cache format") - - config, input_ids, attention_mask, inputs_dict = self._get_input_ids_and_config() - if config.is_encoder_decoder: - self.skipTest(reason="This model is encoder-decoder and has Encoder-Decoder Cache") - - config.is_decoder = True - batch_size, seq_length = input_ids.shape - max_new_tokens = 20 - - model = model_class(config).to(torch_device).eval() - generation_kwargs = { - "max_length": None, - "max_new_tokens": max_new_tokens, - "cache_implementation": "static", - "return_dict_in_generate": True, # Required to return `past_key_values` - "use_cache": True, - } - - max_cache_len = seq_length + max_new_tokens - config = config.text_config if hasattr(config, "text_config") else config - head_dim = ( - config.head_dim if hasattr(config, "head_dim") else config.hidden_size // config.num_attention_heads - ) - num_key_value_heads = ( - config.num_attention_heads - if getattr(config, "num_key_value_heads", None) is None - else config.num_key_value_heads - ) - num_hidden_layers = config.num_hidden_layers - results = model.generate(input_ids, attention_mask=attention_mask, **generation_kwargs, **inputs_dict) - - cache_shape = (batch_size, num_key_value_heads, max_cache_len, head_dim) - self.assertTrue(isinstance(results.past_key_values, StaticCache)) - self.assertTrue(len(results.past_key_values.key_cache) == num_hidden_layers) - self.assertTrue(results.past_key_values.key_cache[0].shape == cache_shape) - - @require_quanto - @pytest.mark.generate - def test_generate_with_quant_cache(self): - for model_class in self.all_generative_model_classes: - if not model_class._supports_quantized_cache: - self.skipTest(reason="This model does not support the quantized cache format") - - config, input_ids, attention_mask, inputs_dict = self._get_input_ids_and_config() - config.is_decoder = True - - model = model_class(config).to(torch_device).eval() - generation_kwargs = { - "max_new_tokens": 5, - "cache_implementation": "quantized", - # careful with group size, should be divisor of model's hidden size - "cache_config": {"backend": "quanto", "nbits": 2, "q_group_size": 8, "residual_length": 128}, - "return_dict_in_generate": True, # Required to return `past_key_values` - "use_cache": True, - } - - results = model.generate(input_ids, attention_mask=attention_mask, **generation_kwargs, **inputs_dict) - self.assertTrue(isinstance(results.past_key_values, QuantoQuantizedCache)) - - # passing past key values of different type should raise Error - with self.assertRaises(ValueError): - num_hidden_layers = config.get_text_config().num_hidden_layers - model.generate( - input_ids, - attention_mask=attention_mask, - past_key_valyes=DynamicCache(num_hidden_layers), - **generation_kwargs, - ) - - # setting incorrect cache_config args should raise an Error, i.e. nbits=60 does not make sense - generation_kwargs["cache_config"] = {"nbits": 60, "q_group_size": 8, "residual_length": 128} - with self.assertRaises(ValueError): - model.generate(input_ids, attention_mask=attention_mask, **generation_kwargs) - - @pytest.mark.generate - @require_torch_gpu - @slow - @is_flaky() # compilation may result in equivalent (!= same) FP ops, causing the argmax in `generate` to be flaky - def test_generate_compile_fullgraph(self): - """ - Tests that `.generate` is compatible with torch.compile without graph breaks, keeping the same results. 
- ⚠️ Runs two sequential generations to ensure the cache doesn't get stuck after the first compiled run! ⚠️ - """ - for model_class in self.all_generative_model_classes: - if not model_class._supports_static_cache: - self.skipTest("This model doesn't support static cache") - # TODO (joao) -- fix and enable me :) - if any(model_name in model_class.__name__.lower() for model_name in ["whisper"]): - self.skipTest("whisper model end-to-end generate compile not yet supported") - - config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() - # TODO (joao) -- fix and enable me :) - if config.is_encoder_decoder: - self.skipTest("Encoder-decoder model end-to-end generate compile not yet supported") - - model = model_class(config).to(torch_device) - model.eval() # otherwise `self.training` is `True` -- this flag is used at attn mask creation time - - input_ids = inputs_dict["input_ids"].to(torch_device) - # creates two sets of *different* inputs with the same shape - half_batch_size = input_ids.shape[0] // 2 - input_ids_sets = [input_ids[:half_batch_size, :], input_ids[half_batch_size : half_batch_size * 2, :]] - self.assertTrue(input_ids_sets[0].shape == input_ids_sets[1].shape) - - generation_kwargs = { - "do_sample": False, - "max_new_tokens": 10, - } - - max_cache_len = input_ids.shape[1] + generation_kwargs["max_new_tokens"] - config = config.get_text_config() - past_key_values = StaticCache( - config, batch_size=half_batch_size, max_cache_len=max_cache_len, device=torch_device - ) - - for model_inputs in input_ids_sets: - # eager dynamic cache - output_dynamic = model.generate(model_inputs, **generation_kwargs) - - # end-to-end compiled dynamic cache - torch.compiler.reset() - compiled_generate = torch.compile(model.generate, fullgraph=True, mode="reduce-overhead") - generation_config = copy.deepcopy(model.generation_config) - generation_config.update(**generation_kwargs) - output_compiled = compiled_generate( - model_inputs, generation_config=generation_config, past_key_values=past_key_values - ) - self.assertListEqual(output_dynamic.tolist(), output_compiled.tolist()) - - @pytest.mark.generate - def test_generate_methods_with_num_logits_to_keep(self): - for model_class in self.all_generative_model_classes: - if "num_logits_to_keep" not in set(inspect.signature(model_class.forward).parameters.keys()): - self.skipTest(reason="This model does not support `num_logits_to_keep` argument.") - - config, input_ids, attention_mask, inputs_dict = self._get_input_ids_and_config() - config.use_cache = True - config.is_decoder = True - - model = model_class(config).to(torch_device).eval() - # All generation methods (except assisted decoding) rely on always extracting the last token logits of the - # full logits matrix, so testing out only greedy search and assisted decoding is enough (if it works, - # other methods will work as well) - generation_kwargs = { - "max_new_tokens": 10, - "do_sample": False, - } - - # Setting num_logits_to_keep at 0 keeps all logits (old behavior) - with_all_logits = model.generate( - input_ids, attention_mask=attention_mask, **generation_kwargs, **inputs_dict, num_logits_to_keep=0 - ) - # By default, num_logits_to_keep is automatically set to 1 if not provided (new behavior) - without_all_logits = model.generate( - input_ids, attention_mask=attention_mask, **inputs_dict, **generation_kwargs - ) - self.assertEqual(with_all_logits.tolist(), without_all_logits.tolist()) - - @pytest.mark.generate - @is_flaky() # assisted generation tests are flaky (minor fp ops 
differences) - def test_assisted_decoding_with_num_logits_to_keep(self): - for model_class in self.all_generative_model_classes: - if "num_logits_to_keep" not in set(inspect.signature(model_class.forward).parameters.keys()): - self.skipTest(reason="This model does not support `num_logits_to_keep` argument.") - if model_class._is_stateful: - self.skipTest(reason="Stateful models don't support assisted generation") - - config, input_ids, attention_mask, inputs_dict = self._get_input_ids_and_config(batch_size=1) - config.use_cache = True - config.is_decoder = True - - model = model_class(config).to(torch_device).eval() - assistant_model = model - # All generation methods (except assisted decoding) rely on always extracting the last token logits of the - # full logits matrix, so testing out only greedy search and assisted decoding is enough (if it works, - # other methods will work as well) - generation_kwargs = { - "max_new_tokens": 10, - "do_sample": False, - "assistant_model": assistant_model, - } - - assistant_model.generation_config.assistant_confidence_threshold = None - # Setting num_logits_to_keep at 0 keeps all logits (old behavior) - with_all_logits = model.generate( - input_ids, attention_mask=attention_mask, **generation_kwargs, **inputs_dict, num_logits_to_keep=0 - ) - # By default, num_logits_to_keep is automatically set to 1 if not provided (new behavior) - without_all_logits = model.generate( - input_ids, attention_mask=attention_mask, **inputs_dict, **generation_kwargs - ) - self.assertEqual(with_all_logits.tolist(), without_all_logits.tolist()) - - @pytest.mark.generate - def test_inherits_generation_mixin(self): - """ - Tests that the model class directly inherits `GenerationMixin`, as opposed to relying on `PreTrainedModel` - to inherit it. 
- """ - for model_class in self.all_generative_model_classes: - self.assertTrue("GenerationMixin" in str(model_class.__bases__)) - - def _check_outputs(self, output, input_ids, config, use_cache=False, num_return_sequences=1): - batch_size, seq_length = input_ids.shape - config = config.text_config if hasattr(config, "text_config") else config - num_sequences_in_output = batch_size * num_return_sequences - - gen_len = ( - output.sequences.shape[-1] - 1 if config.is_encoder_decoder else output.sequences.shape[-1] - seq_length - ) - - # scores - self._check_scores(num_sequences_in_output, output.scores, length=gen_len, config=config) - - # unprocessed logits - self._check_logits(num_sequences_in_output, output.logits, config=config) - - # Attentions - if self.has_attentions: - if config.is_encoder_decoder: - # encoder - self._check_encoder_attention_for_generate(output.encoder_attentions, batch_size, config, seq_length) - # decoder - self._check_attentions_for_generate( - num_sequences_in_output, - output.decoder_attentions, - min_length=1, - max_length=output.sequences.shape[-1], - config=config, - use_cache=use_cache, - ) - else: - # if use_cache first input is equal to no use_cache, so skip here - attentions = output.attentions if not use_cache else output.attentions[1:] - min_length = seq_length if not use_cache else seq_length + 1 - self._check_attentions_for_generate( - num_sequences_in_output, - attentions=attentions, - min_length=min_length, - max_length=output.sequences.shape[-1], - config=config, - use_cache=use_cache, - ) - - # Hidden States - if config.is_encoder_decoder: - # encoder - self._check_encoder_hidden_states_for_generate( - output.encoder_hidden_states, batch_size, config, seq_length - ) - - # decoder - self._check_hidden_states_for_generate( - num_sequences_in_output, - output.decoder_hidden_states, - min_length=1, - max_length=output.sequences.shape[-1], - config=config, - use_cache=use_cache, - ) - else: - # if use_cache first input is equal to no use_cache, so skip here - hidden_states = output.hidden_states if not use_cache else output.hidden_states[1:] - min_length = seq_length if not use_cache else seq_length + 1 - self._check_hidden_states_for_generate( - num_sequences_in_output, - hidden_states, - min_length=min_length, - max_length=output.sequences.shape[-1], - config=config, - use_cache=use_cache, - ) - - # Past Key Value States -- a few notes here: - # 1. Its inner sequence length is with respect to the inputs of the latest forward pass, hence the "-1" - # 2. We ignore models that have unique cache structures (e.g. 
mamba) or are in need of refatoring to match the - # standard cache format (e.g.gptbigcode ) - models_without_standard_cache = ("ctrl", "fsmt", "gptbigcode", "mega", "reformer", "jamba", "mamba", "xlnet") - has_standard_cache = not any( - model_name in config.__class__.__name__.lower() for model_name in models_without_standard_cache - ) - if has_standard_cache: - if use_cache: - past_key_values = output.past_key_values - past_sequence_length = output.sequences.shape[-1] - 1 - self._check_past_key_values_for_generate( - num_sequences_in_output, - past_key_values, - seq_length=past_sequence_length, - config=config, - ) - elif use_cache is False: - self.assertTrue(output.past_key_values is None) - - def _check_scores(self, batch_size, scores, length, config): - vocab_size = config.get_text_config(decoder=True).vocab_size - expected_shape = (batch_size, vocab_size) - self.assertIsInstance(scores, tuple) - self.assertEqual(len(scores), length) - self.assertListEqual([iter_scores.shape for iter_scores in scores], [expected_shape] * len(scores)) - - def _check_logits(self, batch_size, scores, config): - vocab_size = config.get_text_config(decoder=True).vocab_size - self.assertIsInstance(scores, tuple) - self.assertListEqual([iter_scores.shape[0] for iter_scores in scores], [batch_size] * len(scores)) - # vocabulary difference equal to one (imagegptmodel?) or zero (all other models) - vocab_diff = vocab_size - scores[0].shape[-1] - self.assertTrue(vocab_diff in [0, 1]) - self.assertListEqual([vocab_size - score.shape[-1] for score in scores], [vocab_diff] * len(scores)) + def _check_scores(self, batch_size, scores, length, config): + expected_shape = (batch_size, config.vocab_size) + self.assertIsInstance(scores, tuple) + self.assertEqual(len(scores), length) + self.assertListEqual([iter_scores.shape for iter_scores in scores], [expected_shape] * len(scores)) def _check_attentions_for_generate( self, batch_size, attentions, min_length, max_length, config, use_cache=False, num_beam_groups=1 @@ -2329,30 +2318,6 @@ def _check_encoder_hidden_states_for_generate(self, hidden_states, batch_size, c [encoder_expected_shape] * len(hidden_states), ) - def _check_past_key_values_for_generate(self, batch_size, past_key_values, seq_length, config, num_beam_groups=1): - self.assertIsInstance(past_key_values, tuple) - self.assertListEqual( - [isinstance(iter_past_key_values, tuple) for iter_past_key_values in past_key_values], - [True] * len(past_key_values), - ) - - # (batch, head, seq_length, head_features) - expected_shape = ( - batch_size * num_beam_groups, - config.num_key_value_heads if hasattr(config, "num_key_value_heads") else config.num_attention_heads, - seq_length, - config.hidden_size // config.num_attention_heads, - ) - # check shape key, value - self.assertListEqual( - [layer_past_key_values[0].shape for layer_past_key_values in past_key_values], - [expected_shape] * len(past_key_values), - ) - self.assertListEqual( - [layer_past_key_values[1].shape for layer_past_key_values in past_key_values], - [expected_shape] * len(past_key_values), - ) - def _check_sequence_inside_sequence(self, tensor_1, tensor_2): # check if tensor_1 inside tensor_2 or tensor_2 inside tensor_1. # set to same device. we don't care what device. 
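The removed `_check_past_key_values_for_generate` helper above encodes the expected per-layer cache shape. A standalone sketch of the same shape check (a hypothetical helper, not part of the test file) is:

```python
def check_past_key_values_shape(past_key_values, batch_size, config, seq_length, num_beam_groups=1):
    """Assert that every layer of a legacy-format cache has the standard (key, value) shape."""
    num_kv_heads = getattr(config, "num_key_value_heads", None) or config.num_attention_heads
    head_dim = config.hidden_size // config.num_attention_heads
    # (batch * num_beam_groups, kv heads, sequence length, per-head embedding dim)
    expected = (batch_size * num_beam_groups, num_kv_heads, seq_length, head_dim)
    for key, value in past_key_values:
        assert key.shape == expected, f"unexpected key shape {tuple(key.shape)}"
        assert value.shape == expected, f"unexpected value shape {tuple(value.shape)}"
```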
@@ -2377,45 +2342,6 @@ def _check_sequence_inside_sequence(self, tensor_1, tensor_2): self.assertTrue(flag) -@require_torch -class UtilsFunctionsTest(unittest.TestCase): - def test_speculative_sampling(self): - # assume vocab size 10, input length 5 + 3 generated candidates - candidate_input_ids = torch.tensor([[8, 0, 3, 9, 8, 1, 4, 5]]) # input tokens - candidate_logits = torch.tensor( - [ - [ - [-10.0, 10.0, -10.0, -10.0, -10.0, -10.0, -10.0, -10.0, -10.0, -10.0], # generated 1 - [-10.0, -10.0, -10.0, -10.0, 10.0, -10.0, -10.0, -10.0, -10.0, -10.0], # generated 4 - [-10.0, -10.0, -10.0, -10.0, -10.0, 10.0, -10.0, -10.0, -10.0, -10.0], # generated 5 - ] - ] - ) - candidate_length = 3 - inf = float("inf") - new_logits = torch.tensor( - [ - [ - [-10.0, 10.0, -10.0, -10.0, -10.0, -10.0, -10.0, -10.0, -10.0, -10.0], # accepts 1 - [-10.0, -10.0, -10.0, -10.0, 10.0, -10.0, -10.0, -10.0, -10.0, -10.0], # accepts 4 - [-inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, 10.0, -inf], # rejects 5, accepts 8 - [-10.0, -10.0, -10.0, -10.0, -10.0, -10.0, -10.0, -10.0, -10.0, -10.0], # N/A - ] - ] - ) - last_assistant_token_is_eos = False - validated_tokens, n_matches = _speculative_sampling( - candidate_input_ids, - candidate_logits, - candidate_length, - new_logits, - last_assistant_token_is_eos, - ) - self.assertTrue(n_matches.item() == 2) - self.assertTrue(validated_tokens.tolist()[0] == [1, 4, 8]) - - -@pytest.mark.generate @require_torch class GenerationIntegrationTests(unittest.TestCase, GenerationIntegrationTestsMixin): # setting framework_dependent_parameters needs to be gated, just like its contents' imports @@ -2433,7 +2359,6 @@ class GenerationIntegrationTests(unittest.TestCase, GenerationIntegrationTestsMi } @slow - @pytest.mark.skip("Group beam search is not supported by optimum-habana") def test_diverse_beam_search(self): # PT-only test: TF doesn't have a diverse beam search implementation article = """Justin Timberlake and Jessica Biel, welcome to parenthood. @@ -2468,101 +2393,284 @@ def test_diverse_beam_search(self): ], ) - def test_max_length_if_input_embeds(self): + def test_max_length_backward_compat_greedy(self): # PT-only test: TF doesn't have StoppingCriteria - article = "Today a dragon flew over Paris." - model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-gpt2").to(torch_device) - tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-gpt2") - input_ids = tokenizer(article, return_tensors="pt").input_ids.to(torch_device) - inputs_embeds = model.get_input_embeddings()(input_ids) - - # Controlling max_length via the configuration is deprecated in favor of max_new_tokens - max_new_tokens = 20 - input_len = input_ids.shape[-1] - out_gen = model.generate(input_ids=input_ids, max_new_tokens=max_new_tokens) - out_gen_embeds = model.generate(inputs_embeds=inputs_embeds, max_new_tokens=max_new_tokens) - self.assertEqual(out_gen.shape[-1], input_len + out_gen_embeds.shape[-1]) + article = """Justin Timberlake and Jessica Biel, welcome to parenthood.""" + bart_tokenizer = BartTokenizer.from_pretrained("hf-internal-testing/tiny-random-bart") + bart_model = BartForConditionalGeneration.from_pretrained("hf-internal-testing/tiny-random-bart").to( + torch_device + ) + input_ids = bart_tokenizer(article, return_tensors="pt").input_ids.to(torch_device) - def test_min_length_if_input_embeds(self): - # PT-only test: TF doesn't have StoppingCriteria - article = "Today a dragon flew over Paris." 
- model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-gpt2").to(torch_device) - tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-gpt2") - input_ids = tokenizer(article, return_tensors="pt").input_ids.to(torch_device) - inputs_embeds = model.get_input_embeddings()(input_ids) + max_length = 20 + input_ids = input_ids.expand(2, -1) + model_kwargs = bart_model._prepare_encoder_decoder_kwargs_for_generation(input_ids, {}) + input_ids, model_kwargs = bart_model._prepare_decoder_input_ids_for_generation( + batch_size=input_ids.shape[0], + model_input_name=bart_model.main_input_name, + model_kwargs=model_kwargs, + decoder_start_token_id=bart_model.config.decoder_start_token_id, + bos_token_id=bart_model.config.bos_token_id, + ) - # Controlling max_length via the configuration is deprecated in favor of max_new_tokens - min_length = 10 - input_len = input_ids.shape[-1] - out_gen = model.generate(input_ids=input_ids, min_length=min_length, max_new_tokens=20) - out_gen_embeds = model.generate(inputs_embeds=inputs_embeds, min_length=min_length, max_new_tokens=20) - self.assertEqual(out_gen.shape[-1], input_len + out_gen_embeds.shape[-1]) + with self.assertWarns(UserWarning): + bart_model.greedy_search( + input_ids, + max_length=max_length, + pad_token_id=bart_model.config.pad_token_id, + eos_token_id=bart_model.config.eos_token_id, + **model_kwargs, + ) - def test_custom_stopping_criteria_overload_error(self): + def test_max_length_backward_compat_sample(self): # PT-only test: TF doesn't have StoppingCriteria article = """Justin Timberlake and Jessica Biel, welcome to parenthood.""" - bart_tokenizer = BartTokenizer.from_pretrained("sshleifer/bart-tiny-random") - bart_model = BartForConditionalGeneration.from_pretrained("sshleifer/bart-tiny-random").to(torch_device) - + bart_tokenizer = BartTokenizer.from_pretrained("hf-internal-testing/tiny-random-bart") + bart_model = BartForConditionalGeneration.from_pretrained("hf-internal-testing/tiny-random-bart").to( + torch_device + ) input_ids = bart_tokenizer(article, return_tensors="pt").input_ids.to(torch_device) - stopping_criteria = StoppingCriteriaList() - stopping_criteria.append(MaxLengthCriteria(max_length=42)) - with self.assertRaises(ValueError): - bart_model.generate(input_ids, stopping_criteria=stopping_criteria) - with self.assertRaises(ValueError): - bart_model.generate(input_ids, stopping_criteria=stopping_criteria, max_length=32) - def test_custom_stopping_criteria(self): + max_length = 20 + input_ids = input_ids.expand(2, -1) + model_kwargs = bart_model._prepare_encoder_decoder_kwargs_for_generation(input_ids, {}) + input_ids, model_kwargs = bart_model._prepare_decoder_input_ids_for_generation( + batch_size=input_ids.shape[0], + model_input_name=bart_model.main_input_name, + model_kwargs=model_kwargs, + decoder_start_token_id=bart_model.config.decoder_start_token_id, + bos_token_id=bart_model.config.bos_token_id, + ) + with torch.no_grad(): + with self.assertWarns(UserWarning): + bart_model.sample( + input_ids, + max_length=max_length, + pad_token_id=bart_model.config.pad_token_id, + eos_token_id=bart_model.config.eos_token_id, + **model_kwargs, + ) + + def test_max_length_backward_compat_beam_search(self): # PT-only test: TF doesn't have StoppingCriteria article = """Justin Timberlake and Jessica Biel, welcome to parenthood.""" - bart_tokenizer = BartTokenizer.from_pretrained("sshleifer/bart-tiny-random") - bart_model = 
BartForConditionalGeneration.from_pretrained("sshleifer/bart-tiny-random").to(torch_device) + bart_tokenizer = BartTokenizer.from_pretrained("hf-internal-testing/tiny-random-bart") + bart_model = BartForConditionalGeneration.from_pretrained("hf-internal-testing/tiny-random-bart").to( + torch_device + ) input_ids = bart_tokenizer(article, return_tensors="pt").input_ids.to(torch_device) - class DummyCriteria(StoppingCriteria): - def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool: - return input_ids.shape[-1] >= 20 - - stopping_criteria = StoppingCriteriaList() - stopping_criteria.append(DummyCriteria()) + batch_size = 1 + max_length = 20 + num_beams = 2 - output = bart_model.generate(input_ids, stopping_criteria=stopping_criteria, max_length=22) - self.assertEqual( - list(output.shape), - [1, 22], # still produces the max_length + input_ids = input_ids.expand(2, -1) + model_kwargs = bart_model._prepare_encoder_decoder_kwargs_for_generation(input_ids, {}) + input_ids, model_kwargs = bart_model._prepare_decoder_input_ids_for_generation( + batch_size=input_ids.shape[0], + model_input_name=bart_model.main_input_name, + model_kwargs=model_kwargs, + decoder_start_token_id=bart_model.config.decoder_start_token_id, + bos_token_id=bart_model.config.bos_token_id, ) - # make sure final tokens are padding - self.assertEqual(output[:, 20:].tolist(), [[bart_model.config.pad_token_id, bart_model.config.pad_token_id]]) - self.assertEqual( - list(bart_model.generate(input_ids, stopping_criteria=stopping_criteria, max_length=18).shape), - [1, 18], + beam_scorer = BeamSearchScorer( + batch_size=batch_size, + num_beams=num_beams, + device=torch_device, ) + with self.assertWarns(UserWarning): + _ = bart_model.beam_search( + input_ids, num_beams=num_beams, max_length=max_length, beam_scorer=beam_scorer, **model_kwargs + ) - # TODO (joao): replace `stop_sequence` in the pipeline by the more recent `generate` functionality - # TODO [gustavo] Enable this test to Optimum-habana - @pytest.mark.xfail - def test_stop_sequence_stopping_criteria(self): - # PT-only test: TF doesn't have StoppingCriteria - prompt = """Hello I believe in""" - generator = pipeline("text-generation", model="hf-internal-testing/tiny-random-bart") - output = generator(prompt) - self.assertEqual( - output, - [{"generated_text": ("Hello I believe in we we we we we we we we we")}], + def test_max_length_backward_compat_group_beam_search(self): + # PT-only test: TF doesn't have StoppingCriteria & group beam search + article = """Justin Timberlake and Jessica Biel, welcome to parenthood.""" + bart_tokenizer = BartTokenizer.from_pretrained("hf-internal-testing/tiny-random-bart") + bart_model = BartForConditionalGeneration.from_pretrained("hf-internal-testing/tiny-random-bart").to( + torch_device ) + input_ids = bart_tokenizer(article, return_tensors="pt").input_ids.to(torch_device) - output = generator(prompt, stop_sequence=" we") - self.assertEqual(output, [{"generated_text": "Hello I believe in we"}]) + batch_size = 1 + max_length = 20 + num_beams = 6 + num_beam_groups = 3 + num_return_sequences = num_beams * batch_size - def test_generate_non_nlp_input_ids_as_kwarg(self): - # PT-only test: AFAIK there's no non-NLP model architecture in TF that supports `input_ids` as its only input - model = ImageGPTForCausalImageModeling.from_pretrained( - "hf-internal-testing/tiny-random-imagegpt", max_length=10 - ).to(torch_device) - input_ids = ids_tensor((3, 5), vocab_size=10) + input_ids = input_ids.expand(6, -1) + 
model_kwargs = bart_model._prepare_encoder_decoder_kwargs_for_generation(input_ids, {}) + input_ids, model_kwargs = bart_model._prepare_decoder_input_ids_for_generation( + batch_size=input_ids.shape[0], + model_input_name=bart_model.main_input_name, + model_kwargs=model_kwargs, + decoder_start_token_id=bart_model.config.decoder_start_token_id, + bos_token_id=bart_model.config.bos_token_id, + ) - output_sequences_kwargs = model.generate(input_ids=input_ids).cpu() + diverse_beam_scorer = BeamSearchScorer( + batch_size=batch_size, + num_beams=num_beams, + device=torch_device, + num_beam_hyps_to_keep=num_return_sequences, + num_beam_groups=num_beam_groups, + ) + with self.assertWarns(UserWarning): + bart_model.group_beam_search( + input_ids, diverse_beam_scorer, num_beams=num_beams, max_length=max_length, **model_kwargs + ) + + def test_max_length_warning_if_different(self): + # PT-only test: TF doesn't have StoppingCriteria + article = """Justin Timberlake and Jessica Biel, welcome to parenthood.""" + bart_tokenizer = BartTokenizer.from_pretrained("hf-internal-testing/tiny-random-bart") + bart_model = BartForConditionalGeneration.from_pretrained("hf-internal-testing/tiny-random-bart").to( + torch_device + ) + input_ids = bart_tokenizer(article, return_tensors="pt").input_ids.to(torch_device) + + batch_size = 1 + + max_length = 20 + num_beams = 6 + num_beam_groups = 3 + num_return_sequences = num_beams * batch_size + stopping_criteria_max_length = 18 + stopping_criteria = StoppingCriteriaList([MaxLengthCriteria(max_length=stopping_criteria_max_length)]) + + # Greedy + input_ids = input_ids.expand(6, -1) + model_kwargs = bart_model._prepare_encoder_decoder_kwargs_for_generation(input_ids, {}) + input_ids, model_kwargs = bart_model._prepare_decoder_input_ids_for_generation( + batch_size=input_ids.shape[0], + model_input_name=bart_model.main_input_name, + model_kwargs=model_kwargs, + decoder_start_token_id=bart_model.config.decoder_start_token_id, + bos_token_id=bart_model.config.bos_token_id, + ) + + with self.assertWarns(UserWarning): + bart_model.greedy_search( + input_ids, + max_length=max_length, + pad_token_id=bart_model.config.pad_token_id, + stopping_criteria=stopping_criteria, + eos_token_id=bart_model.config.eos_token_id, + **model_kwargs, + ) + + # Sample + with self.assertWarns(UserWarning): + with torch.no_grad(): + bart_model.sample( + input_ids, + max_length=max_length, + stopping_criteria=stopping_criteria, + pad_token_id=bart_model.config.pad_token_id, + eos_token_id=bart_model.config.eos_token_id, + **model_kwargs, + ) + + # Beam + beam_scorer = BeamSearchScorer( + batch_size=batch_size, + num_beams=num_beams, + device=torch_device, + ) + with self.assertWarns(UserWarning): + with torch.no_grad(): + bart_model.beam_search( + input_ids, + num_beams=num_beams, + stopping_criteria=stopping_criteria, + max_length=max_length, + beam_scorer=beam_scorer, + **model_kwargs, + ) + + # Grouped beam search + diverse_beam_scorer = BeamSearchScorer( + batch_size=batch_size, + num_beams=num_beams, + device=torch_device, + num_beam_hyps_to_keep=num_return_sequences, + num_beam_groups=num_beam_groups, + ) + with self.assertWarns(UserWarning): + bart_model.group_beam_search( + input_ids, + diverse_beam_scorer, + stopping_criteria=stopping_criteria, + num_beams=num_beams, + max_length=max_length, + **model_kwargs, + ) + + def test_custom_stopping_criteria_overload_error(self): + # PT-only test: TF doesn't have StoppingCriteria + article = """Justin Timberlake and Jessica Biel, welcome to 
parenthood.""" + bart_tokenizer = BartTokenizer.from_pretrained("sshleifer/bart-tiny-random") + bart_model = BartForConditionalGeneration.from_pretrained("sshleifer/bart-tiny-random").to(torch_device) + + input_ids = bart_tokenizer(article, return_tensors="pt").input_ids.to(torch_device) + stopping_criteria = StoppingCriteriaList() + stopping_criteria.append(MaxLengthCriteria(max_length=42)) + with self.assertRaises(ValueError): + bart_model.generate(input_ids, stopping_criteria=stopping_criteria) + with self.assertRaises(ValueError): + bart_model.generate(input_ids, stopping_criteria=stopping_criteria, max_length=32) + + def test_custom_stopping_criteria(self): + # PT-only test: TF doesn't have StoppingCriteria + article = """Justin Timberlake and Jessica Biel, welcome to parenthood.""" + bart_tokenizer = BartTokenizer.from_pretrained("sshleifer/bart-tiny-random") + bart_model = BartForConditionalGeneration.from_pretrained("sshleifer/bart-tiny-random").to(torch_device) + input_ids = bart_tokenizer(article, return_tensors="pt").input_ids.to(torch_device) + + class DummyCriteria(StoppingCriteria): + def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool: + return input_ids.shape[-1] >= 20 + + stopping_criteria = StoppingCriteriaList() + stopping_criteria.append(DummyCriteria()) + + self.assertEqual( + list(bart_model.generate(input_ids, stopping_criteria=stopping_criteria, max_length=22).shape), + [1, 20], + ) + self.assertEqual( + list(bart_model.generate(input_ids, stopping_criteria=stopping_criteria, max_length=18).shape), + [1, 18], + ) + + def test_stop_sequence_stopping_criteria(self): + # PT-only test: TF doesn't have StoppingCriteria + prompt = """Hello I believe in""" + generator = pipeline("text-generation", model="hf-internal-testing/tiny-random-bart") + output = generator(prompt) + self.assertEqual( + output, + [ + { + "generated_text": ( + "Hello I believe in in in number number number number number number number number number" + ) + } + ], + ) + + output = generator(prompt, stop_sequence=" number") + self.assertEqual(output, [{"generated_text": "Hello I believe in in in number"}]) + + def test_generate_non_nlp_input_ids_as_kwarg(self): + # PT-only test: AFAIK there's no non-NLP model architecture in TF that supports `input_ids` as its only input + model = ImageGPTForCausalImageModeling.from_pretrained( + "hf-internal-testing/tiny-random-imagegpt", max_length=10 + ).to(torch_device) + input_ids = ids_tensor((3, 5), vocab_size=10) + + output_sequences_kwargs = model.generate(input_ids=input_ids).cpu() output_sequences = model.generate(input_ids).cpu() self.assertListEqual(output_sequences.tolist(), output_sequences_kwargs.tolist()) @@ -2579,7 +2687,6 @@ def test_generate_input_values_as_encoder_kwarg(self): self.assertListEqual(output_sequences.tolist(), output_sequences_kwargs.tolist()) self.assertEqual(output_sequences.shape, (2, 5)) - @pytest.mark.skip("Group beam search is not supported by optimum-habana") def test_transition_scores_group_beam_search_encoder_decoder(self): # PT-only test: TF doesn't have group beam search articles = [ @@ -2609,61 +2716,13 @@ def test_transition_scores_group_beam_search_encoder_decoder(self): self.assertTrue(torch.allclose(transition_scores_sum, outputs.sequences_scores, atol=1e-3)) - # TODO [gustavo] Enable this test to Optimum-habana - @pytest.mark.xfail - def test_beam_search_low_memory(self): - tokenizer = GPT2Tokenizer.from_pretrained("openai-community/gpt2") - model = 
AutoModelForCausalLM.from_pretrained("openai-community/gpt2") - tokenizer.pad_token_id = tokenizer.eos_token_id - model_inputs = tokenizer("I", return_tensors="pt")["input_ids"] - - low_output = model.generate(model_inputs, max_new_tokens=40, num_beams=5, early_stopping=True, low_memory=True) - - high_output = model.generate( - model_inputs, max_new_tokens=40, num_beams=5, early_stopping=True, low_memory=False - ) - self.assertListEqual(low_output.tolist(), high_output.tolist()) - - @slow - @pytest.mark.skip("Watermarking is not supported by optimum-habana yet") - def test_watermark_generation(self): - tokenizer = GPT2Tokenizer.from_pretrained("openai-community/gpt2") - model = AutoModelForCausalLM.from_pretrained("openai-community/gpt2").to(torch_device) - tokenizer.pad_token_id = tokenizer.eos_token_id - model_inputs = tokenizer("I will be", return_tensors="pt").to(torch_device) - input_len = model_inputs["input_ids"].shape[-1] - - # generation should work with both input types: WatermarkingConfig or Dict, so let's check it here :) - watermark_config = WatermarkingConfig(bias=2.5, seeding_scheme="selfhash") - _ = model.generate(**model_inputs, watermarking_config=watermark_config, do_sample=False, max_length=15) - - # We will not check watermarked text, since we check it in `logits_processors` tests - # Checking if generated ids are as expected fails on different hardware - args = { - "bias": 2.0, - "context_width": 1, - "seeding_scheme": "selfhash", - "greenlist_ratio": 0.25, - "hashing_key": 15485863, - } - output = model.generate(**model_inputs, do_sample=False, max_length=15) - output_selfhash = model.generate(**model_inputs, watermarking_config=args, do_sample=False, max_length=15) - - # Check that the detector is detecting watermarked text - detector = WatermarkDetector(model_config=model.config, device=torch_device, watermarking_config=args) - detection_out_watermarked = detector(output_selfhash[:, input_len:], return_dict=True) - detection_out = detector(output[:, input_len:], return_dict=True) - - self.assertListEqual(detection_out_watermarked.prediction.tolist(), [True]) - self.assertListEqual(detection_out.prediction.tolist(), [False]) - @slow def test_beam_search_example_integration(self): # PT-only test: TF doesn't have a BeamSearchScorer # exactly the example provided in the docstrings of beam search, which previously # failed after directly copying from it. Refer to PR #15555 - tokenizer = AutoTokenizer.from_pretrained("google-t5/t5-base") - model = AutoModelForSeq2SeqLM.from_pretrained("google-t5/t5-base") + tokenizer = AutoTokenizer.from_pretrained("t5-base") + model = AutoModelForSeq2SeqLM.from_pretrained("t5-base") encoder_input_str = "translate English to German: How old are you?" 
encoder_input_ids = tokenizer(encoder_input_str, return_tensors="pt").input_ids @@ -2671,15 +2730,31 @@ def test_beam_search_example_integration(self): # lets run beam search using 3 beams num_beams = 3 # define decoder start token ids - input_ids = torch.ones((1, 1), device=model.device, dtype=torch.long) + input_ids = torch.ones((num_beams, 1), device=model.device, dtype=torch.long) input_ids = input_ids * model.config.decoder_start_token_id # add encoder_outputs to model keyword arguments - model_kwargs = {"encoder_outputs": model.get_encoder()(encoder_input_ids, return_dict=True)} + model_kwargs = { + "encoder_outputs": model.get_encoder()( + encoder_input_ids.repeat_interleave(num_beams, dim=0), return_dict=True + ) + } - outputs = model.generate( - input_ids, num_beams=num_beams, min_length=5, eos_token_id=model.config.eos_token_id, **model_kwargs + # instantiate beam scorer + beam_scorer = BeamSearchScorer( + batch_size=1, + num_beams=num_beams, + device=model.device, + ) + + # instantiate logits processors + logits_processor = LogitsProcessorList( + [ + MinLengthLogitsProcessor(5, eos_token_id=model.config.eos_token_id), + ] ) + + outputs = model.beam_search(input_ids, beam_scorer, logits_processor=logits_processor, **model_kwargs) outputs = tokenizer.batch_decode(outputs, skip_special_tokens=True) self.assertListEqual(outputs, ["Wie alt bist du?"]) @@ -2687,8 +2762,8 @@ def test_beam_search_example_integration(self): @slow def test_constrained_beam_search(self): # PT-only test: TF doesn't have constrained beam search - model = GPT2LMHeadModel.from_pretrained("openai-community/gpt2").to(torch_device) - tokenizer = GPT2Tokenizer.from_pretrained("openai-community/gpt2") + model = GPT2LMHeadModel.from_pretrained("gpt2").to(torch_device) + tokenizer = GPT2Tokenizer.from_pretrained("gpt2") force_tokens = tokenizer("scared", add_prefix_space=True, add_special_tokens=False).input_ids force_tokens_2 = tokenizer("big weapons", add_prefix_space=True, add_special_tokens=False).input_ids @@ -2725,8 +2800,8 @@ def test_constrained_beam_search(self): @slow def test_constrained_beam_search_mixed(self): # PT-only test: TF doesn't have constrained beam search - model = GPT2LMHeadModel.from_pretrained("openai-community/gpt2").to(torch_device) - tokenizer = GPT2Tokenizer.from_pretrained("openai-community/gpt2") + model = GPT2LMHeadModel.from_pretrained("gpt2").to(torch_device) + tokenizer = GPT2Tokenizer.from_pretrained("gpt2") force_phrase = tokenizer("scared", add_prefix_space=True, add_special_tokens=False).input_ids flexible_phrases = tokenizer( @@ -2766,8 +2841,8 @@ def test_constrained_beam_search_mixed(self): @slow def test_constrained_beam_search_mixed_mixin(self): # PT-only test: TF doesn't have constrained beam search - model = GPT2LMHeadModel.from_pretrained("openai-community/gpt2").to(torch_device) - tokenizer = GPT2Tokenizer.from_pretrained("openai-community/gpt2") + model = GPT2LMHeadModel.from_pretrained("gpt2").to(torch_device) + tokenizer = GPT2Tokenizer.from_pretrained("gpt2") force_word = "scared" force_flexible = ["scream", "screams", "screaming", "screamed"] @@ -2802,15 +2877,9 @@ def test_constrained_beam_search_mixed_mixin(self): ) @slow - @pytest.mark.xfail def test_cfg_mixin(self): - model = GPT2LMHeadModel.from_pretrained("openai-community/gpt2").to(torch_device) - tokenizer = GPT2Tokenizer.from_pretrained("openai-community/gpt2") - - # add pad_token_id for static shape - if tokenizer.pad_token is None: - tokenizer.pad_token = tokenizer.eos_token - 
model.generation_config.pad_token_id = model.generation_config.eos_token_id + model = GPT2LMHeadModel.from_pretrained("gpt2").to(torch_device) + tokenizer = GPT2Tokenizer.from_pretrained("gpt2") input = tokenizer(["The dragon flew over Paris,"], return_tensors="pt", return_attention_mask=True) input["input_ids"] = input["input_ids"].to(torch_device) @@ -2850,8 +2919,8 @@ def test_cfg_mixin(self): @slow def test_constrained_beam_search_example_translation_mixin(self): # PT-only test: TF doesn't have constrained beam search - tokenizer = AutoTokenizer.from_pretrained("google-t5/t5-base") - model = AutoModelForSeq2SeqLM.from_pretrained("google-t5/t5-base") + tokenizer = AutoTokenizer.from_pretrained("t5-base") + model = AutoModelForSeq2SeqLM.from_pretrained("t5-base") encoder_input_str = "translate English to German: How old are you?" force_words = ["sind"] @@ -2875,8 +2944,8 @@ def test_constrained_beam_search_example_translation_mixin(self): @slow def test_constrained_beam_search_example_integration(self): # PT-only test: TF doesn't have constrained beam search - tokenizer = AutoTokenizer.from_pretrained("google-t5/t5-base") - model = AutoModelForSeq2SeqLM.from_pretrained("google-t5/t5-base") + tokenizer = AutoTokenizer.from_pretrained("t5-base") + model = AutoModelForSeq2SeqLM.from_pretrained("t5-base") encoder_input_str = "translate English to German: How old are you?" encoder_input_ids = tokenizer(encoder_input_str, return_tensors="pt").input_ids @@ -2884,65 +2953,38 @@ def test_constrained_beam_search_example_integration(self): # lets run beam search using 5 beams num_beams = 5 # define decoder start token ids - input_ids = torch.ones((1, 1), device=model.device, dtype=torch.long) + input_ids = torch.ones((num_beams, 1), device=model.device, dtype=torch.long) input_ids = input_ids * model.config.decoder_start_token_id # add encoder_outputs to model keyword arguments - model_kwargs = {"encoder_outputs": model.get_encoder()(encoder_input_ids, return_dict=True)} + model_kwargs = { + "encoder_outputs": model.get_encoder()( + encoder_input_ids.repeat_interleave(num_beams, dim=0), return_dict=True + ) + } constraint_str = "sind" constraint_token_ids = tokenizer.encode(constraint_str)[:-1] # remove eos token + constraints = [PhrasalConstraint(token_ids=constraint_token_ids)] - outputs = model.generate( - input_ids, - num_beams=num_beams, - force_words_ids=[constraint_token_ids], - min_length=5, - eos_token_id=model.config.eos_token_id, - **model_kwargs, + # instantiate beam scorer + beam_scorer = ConstrainedBeamSearchScorer( + batch_size=1, num_beams=num_beams, device=model.device, constraints=constraints ) - outputs = tokenizer.batch_decode(outputs, skip_special_tokens=True) - self.assertListEqual(outputs, ["Wie alt sind Sie?"]) - - # TODO [gustavo] Enable this test to Optimum-habana - @pytest.mark.xfail - @slow - def test_per_row_stopping_criteria(self): - text = [ - "They completed the challenging puzzle, revealing the hidden", - "Today a dragon flew over France", - "The aroma of freshly baked pizza filled the kitchen", - ] - stop_strings = ["secrets"] + # instantiate logits processors + logits_processor = LogitsProcessorList( + [ + MinLengthLogitsProcessor(5, eos_token_id=model.config.eos_token_id), + ] + ) - model = AutoModelForCausalLM.from_pretrained("openai-community/gpt2").to(torch_device) - tokenizer = AutoTokenizer.from_pretrained("openai-community/gpt2") - tokenizer.padding_side = "left" - tokenizer.pad_token_id = tokenizer.eos_token_id - input_ids = tokenizer(text, 
return_tensors="pt", padding="longest", add_special_tokens=False).input_ids.to( - torch_device + outputs = model.constrained_beam_search( + input_ids, beam_scorer, constraints=constraints, logits_processor=logits_processor, **model_kwargs ) + outputs = tokenizer.batch_decode(outputs, skip_special_tokens=True) - # normal generation with one stopping criteria - out = model.generate(input_ids, max_length=15) - out_text = tokenizer.batch_decode(out) - expected_out = [ - "They completed the challenging puzzle, revealing the hidden secrets of the world.\n", - "<|endoftext|><|endoftext|><|endoftext|>Today a dragon flew over France and the French government was forced", - "The aroma of freshly baked pizza filled the kitchen with a sense of freshness", - ] - self.assertListEqual(out_text, expected_out) - - # generation should stop at "secrets" for first batch only, filling the rest with eos tokens - out = model.generate(input_ids, max_length=15, stop_strings=stop_strings, tokenizer=tokenizer) - out_text = tokenizer.batch_decode(out) - expected_out = [ - "They completed the challenging puzzle, revealing the hidden secrets<|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|>", - "<|endoftext|><|endoftext|><|endoftext|>Today a dragon flew over France and the French government was forced", - "The aroma of freshly baked pizza filled the kitchen with a sense of freshness", - ] - self.assertListEqual(out_text, expected_out) + self.assertListEqual(outputs, ["Wie alt sind Sie?"]) def test_constrained_beam_search_mixin_type_checks(self): # PT-only test: TF doesn't have constrained beam search @@ -2985,55 +3027,6 @@ def test_constrained_beam_search_mixin_type_checks(self): with self.assertRaises(ValueError): model.generate(input_ids, force_words_ids=[[[-1]]]) - def test_batched_decoder_start_id(self): - # PT-only test: TF doesn't support batched_decoder_start_id - articles = [ - "Justin Timberlake and Jessica Biel, welcome to parenthood.", - "Michael Phelps is arguably the most decorated Olympian of all time.", - ] - bart_tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-bart") - bart_model = BartForConditionalGeneration.from_pretrained("hf-internal-testing/tiny-random-bart").to( - torch_device - ) - input_ids = bart_tokenizer(articles, return_tensors="pt", padding=True).input_ids.to(torch_device) - decoder_start_token_id = bart_model.generation_config.decoder_start_token_id - decoder_start_token_id_batch = [decoder_start_token_id] * input_ids.shape[0] - - outputs = bart_model.generate(input_ids, decoder_start_token_id=decoder_start_token_id) - - outputs_batched_ids = bart_model.generate(input_ids, decoder_start_token_id=decoder_start_token_id_batch) - - self.assertListEqual(outputs.tolist(), outputs_batched_ids.tolist()) - - def test_decoder_start_id_from_config(self): - # Refer to: (#30899) - articles = [ - "Justin Timberlake and Jessica Biel, welcome to parenthood.", - "Michael Phelps is arguably the most decorated Olympian of all time.", - ] - bart_tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-bart") - bart_model = BartForConditionalGeneration.from_pretrained("hf-internal-testing/tiny-random-bart").to( - torch_device - ) - input_ids = bart_tokenizer(articles, return_tensors="pt", padding=True).input_ids.to(torch_device) - decoder_start_token_id = bart_model.generation_config.decoder_start_token_id - - # we should be able to take `decoder_start_token_id` from model's generation config if user passes a `GenerationConfig` type - outputs = 
bart_model.generate(input_ids, generation_config=GenerationConfig(do_sample=False)) - - # If the generation config has no `decoder_start_token_id` or `bos_token_id`, we will raise an error unless the user passes it in the config - bart_model.generation_config.decoder_start_token_id = None - bart_model.generation_config.bos_token_id = None - outputs_with_user_id = bart_model.generate( - input_ids, - generation_config=GenerationConfig(do_sample=False, decoder_start_token_id=decoder_start_token_id), - ) - - self.assertListEqual(outputs.tolist(), outputs_with_user_id.tolist()) - - with self.assertRaises(ValueError): - outputs = bart_model.generate(input_ids, generation_config=GenerationConfig(do_sample=False)) - - def test_contrastive_search_batched(self): # PT-only test: TF doesn't have constrained beam search # Tests that contrastive search works with batched inputs (i.e. has the same output as for non-batched inputs) @@ -3060,27 +3053,6 @@ def test_contrastive_search_batched(self): max_score_diff = (output_sequences_batched.scores[0][1] - output_sequences.scores[0][0]).abs().max() self.assertTrue(max_score_diff < 1e-5) - def test_logits_processor_not_inplace(self): - # PT-only test: TF fixes were not made - article = "Today a dragon flew over Paris." - model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-gpt2").to(torch_device) - tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-gpt2") - input_ids = tokenizer(article, return_tensors="pt").input_ids.to(torch_device) - - out = model.generate(input_ids, output_logits=True, output_scores=True, return_dict_in_generate=True) - out_with_temp = model.generate( - input_ids, - temperature=0.5, - do_sample=True, - output_logits=True, - output_scores=True, - return_dict_in_generate=True, - ) - - # if no logits processor is used, scores == logits. Otherwise, the processor has to modify the scores - self.assertListEqual(out.logits[-1].tolist(), out.scores[-1].tolist()) - self.assertNotEqual(out_with_temp.logits[-1].tolist(), out_with_temp.scores[-1].tolist()) - def test_eos_token_id_int_and_list_top_k_top_sampling(self): # Has TF equivalent: this test relies on random sampling generation_kwargs = { @@ -3135,10 +3107,6 @@ def forward(self, input_ids, foo=None, **kwargs): # because it doesn't do signature filtering. 
class FakeEncoder(bart_model.model.encoder.__class__): def forward(self, input_ids, **kwargs): - # We remove these to pass gaudi_BartEncoder_forward TypeError - kwargs.pop("bucket_size", None) - kwargs.pop("bucket_internal", None) - kwargs.pop("reduce_recompile", None) return super().forward(input_ids, **kwargs) fake_encoder = FakeEncoder(bart_model.config, bart_model.model.shared).to(torch_device) @@ -3153,16 +3121,15 @@ def forward(self, input_ids, **kwargs): def test_default_max_length_warning(self): model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-gpt2").to(torch_device) tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-gpt2") - model.generation_config.pad_token_id = tokenizer.eos_token_id + model.config.pad_token_id = tokenizer.eos_token_id text = "Hello world" tokenized_inputs = tokenizer([text], return_tensors="pt") input_ids = tokenized_inputs.input_ids.to(torch_device) # Default generation config value of 20 -> emits warning - # NOTE: in OH we do not have this warning - # with self.assertWarns(UserWarning): - # model.generate(input_ids) + with self.assertWarns(UserWarning): + model.generate(input_ids) # Explicitly setting max_length to 20 -> no warning with warnings.catch_warnings(record=True) as warning_list: @@ -3171,805 +3138,7 @@ def test_default_max_length_warning(self): # Generation config max_length != 20 -> no warning with warnings.catch_warnings(record=True) as warning_list: - # generation_config is modified -> legacy mode is disabled = generation_config takes precedence model.generation_config.max_length = 10 + model.generation_config._from_model_config = False # otherwise model.config.max_length=20 takes precedence model.generate(input_ids) self.assertEqual(len(warning_list), 0) - - # TODO [gustavo] Enable this test to Optimum-habana - @pytest.mark.xfail - def test_length_warning_assisted_generation(self): - # PT-only test: TF doesn't support assisted decoding yet. - model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-gpt2").to(torch_device) - assistant = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-gpt2").to(torch_device) - tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-gpt2") - model.generation_config.pad_token_id = tokenizer.eos_token_id - assistant.generation_config.pad_token_id = tokenizer.eos_token_id - - text = "Hello world" - tokenized_inputs = tokenizer([text], return_tensors="pt") - input_ids = tokenized_inputs.input_ids.to(torch_device) - - # This should not raise any warning that min length is not feasible in candidate generation - with warnings.catch_warnings(record=True) as warning_list: - model.generate( - input_ids, - assistant_model=assistant, - min_new_tokens=10, - max_length=20, - ) - self.assertEqual(len(warning_list), 0) - - def test_default_assisted_generation(self): - # Initialize the GenerationConfig object - config = GenerationConfig() - - # Check the default values - self.assertEqual(config.num_assistant_tokens, 20) - self.assertEqual(config.num_assistant_tokens_schedule, "constant") - self.assertEqual(config.assistant_confidence_threshold, 0.4) - self.assertEqual(config.is_assistant, False) - - # TODO [gustavo] Enable this test to Optimum-habana - @pytest.mark.xfail - def test_generated_length_assisted_generation(self): - # PT-only test: TF doesn't support assisted decoding yet. 
- model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-gpt2").to(torch_device) - assistant = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-gpt2").to(torch_device) - tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-gpt2") - model.generation_config.pad_token_id = tokenizer.eos_token_id - assistant.generation_config.pad_token_id = tokenizer.eos_token_id - - text = "Hello world" - tokenized_inputs = tokenizer([text], return_tensors="pt") - input_ids = tokenized_inputs.input_ids.to(torch_device) - input_length = input_ids.shape[-1] - - out = model.generate( - input_ids, - assistant_model=assistant, - min_new_tokens=10, - max_new_tokens=20, - ) - self.assertTrue((10 + input_length) <= out.shape[-1] <= (20 + input_length)) - - out = model.generate( - input_ids, - assistant_model=assistant, - min_new_tokens=10, - ) - self.assertTrue((input_length + 10) <= out.shape[-1] <= 20) - - # TODO [gustavo] Enable this test to Optimum-habana - @pytest.mark.xfail - def test_model_kwarg_assisted_decoding_decoder_only(self): - # PT-only test: TF doesn't support assisted decoding yet. - model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-gpt2").to(torch_device) - tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-gpt2") - model.generation_config.pad_token_id = tokenizer.eos_token_id - - text = "Hello world" - tokenized_inputs = tokenizer([text], return_tensors="pt") - input_ids = tokenized_inputs.input_ids.to(torch_device) - - # Traditional way of generating text - outputs_normal = model.generate(input_ids) - self.assertEqual(outputs_normal.shape, (1, 20)) - - # Should be different with token_type_ids - outputs_tti = model.generate( - input_ids, - token_type_ids=torch.zeros(input_ids.shape, dtype=torch.long).to(torch_device), - ) - with self.assertRaises(AssertionError): - self.assertListEqual(outputs_tti.tolist(), outputs_normal.tolist()) - - # Assistant model - assistant = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-gpt2").to(torch_device) - assistant.config.pad_token_id = tokenizer.eos_token_id - - # If assisted generation passes model_kwargs correctly, should be same as previous - outputs_assisted = model.generate( - input_ids, - token_type_ids=torch.zeros(input_ids.shape, dtype=torch.long).to(torch_device), - assistant_model=assistant, - ) - self.assertListEqual(outputs_assisted.tolist(), outputs_tti.tolist()) - - # TODO [gustavo] Enable this test to Optimum-habana - @pytest.mark.xfail - def test_model_kwarg_assisted_decoding_encoder_decoder(self): - """ - Tests that the following scenario is compatible with assisted generation: - 1. encoder-decoder main model - 2. encoder-decoder assistant model - 3. both have a custom input - (e.g. Whisper) - """ - - # PT-only test: TF doesn't support assisted decoding yet. 
- # Bart subclass with a kwarg that distorts the output - class FakeBart(BartForConditionalGeneration): - def forward(self, input_ids, past_key_values, foo=False, **kwargs): - outs = super().forward(input_ids, past_key_values=past_key_values, **kwargs) - if foo: - outs["logits"][:, :, :] = 0.0 - return outs - - def prepare_inputs_for_generation(self, *args, foo=False, encoder_outputs=None, **kwargs): - kwargs["encoder_outputs"] = encoder_outputs - inputs = super().prepare_inputs_for_generation(*args, **kwargs) - inputs["foo"] = foo - return inputs - - model = FakeBart.from_pretrained("hf-internal-testing/tiny-random-BartForConditionalGeneration").to( - torch_device - ) - tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-BartForConditionalGeneration") - - text = "Hello world" - tokenized_inputs = tokenizer([text], return_tensors="pt") - input_ids = tokenized_inputs.input_ids.to(torch_device) - - # Traditional way of generating text - outputs_normal = model.generate(input_ids) - self.assertEqual(outputs_normal.shape, (1, 20)) - - # Should be different with foo - outputs_foo = model.generate(input_ids, foo=True) - with self.assertRaises(AssertionError): - self.assertListEqual(outputs_foo.tolist(), outputs_normal.tolist()) - - # Assistant model - assistant = FakeBart.from_pretrained("hf-internal-testing/tiny-random-BartForConditionalGeneration").to( - torch_device - ) - - # If assisted generation passes model_kwargs correctly, should be same as previous - outputs_assisted = model.generate( - input_ids, - foo=True, - assistant_model=assistant, - ) - self.assertListEqual(outputs_assisted.tolist(), outputs_foo.tolist()) - - # Check that passing encoder_outputs directly also works as expected - encoder_outputs = assistant.get_encoder()(input_ids) - - outputs_assisted = model.generate( - foo=True, - assistant_model=assistant, - encoder_outputs=encoder_outputs, - assistant_encoder_outputs=encoder_outputs, - ) - self.assertListEqual(outputs_assisted.tolist(), outputs_foo.tolist()) - - # TODO [gustavo] Enable this test to Optimum-habana - @pytest.mark.xfail - def test_assisted_decoding_encoder_decoder_shared_encoder(self): - """ - Tests that the following scenario is compatible with assisted generation: - 1. encoder-decoder main model - 2. decoder-only assistant model - 3. both have a custom input - (e.g. DistilWhisper) - """ - - # PT-only test: TF doesn't support assisted decoding yet. 
- # Bart subclass with a kwarg called foo that distorts the output - class FakeBartSeq2Seq(BartForConditionalGeneration): - def forward(self, input_ids, foo=False, **kwargs): - outs = super().forward(input_ids, **kwargs) - if foo: - outs["logits"][:, :, :] = 0.0 - return outs - - def prepare_inputs_for_generation(self, *args, foo=False, encoder_outputs=None, **kwargs): - kwargs["encoder_outputs"] = encoder_outputs - inputs = super().prepare_inputs_for_generation(*args, **kwargs) - inputs["foo"] = foo - return inputs - - class FakeBartCausalLM(BartForCausalLM): - def forward(self, input_ids, attention_mask, past_key_values, foo=False, **kwargs): - outs = super().forward(input_ids, attention_mask, past_key_values=past_key_values, **kwargs) - if foo: - outs["logits"][:, :, :] = 0.0 - return outs - - def prepare_inputs_for_generation(self, *args, foo=False, encoder_outputs=None, **kwargs): - kwargs["encoder_outputs"] = encoder_outputs - inputs = super().prepare_inputs_for_generation(*args, **kwargs) - inputs["foo"] = foo - return inputs - - model = FakeBartSeq2Seq.from_pretrained("hf-internal-testing/tiny-random-BartForConditionalGeneration").to( - torch_device - ) - tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-BartForConditionalGeneration") - - text = "Hello world" - tokenized_inputs = tokenizer([text], return_tensors="pt") - input_ids = tokenized_inputs.input_ids.to(torch_device) - - # Traditional way of generating text - outputs_normal = model.generate(input_ids) - self.assertEqual(outputs_normal.shape, (1, 20)) - - # Should be different with foo - outputs_foo = model.generate(input_ids, foo=True) - with self.assertRaises(AssertionError): - self.assertListEqual(outputs_foo.tolist(), outputs_normal.tolist()) - - # Assistant model - assistant = FakeBartCausalLM.from_pretrained( - "hf-internal-testing/tiny-random-BartForConditionalGeneration" - ).to(torch_device) - - # If assisted generation passes model_kwargs correctly, should be same as previous - outputs_assisted = model.generate( - input_ids, - foo=True, - assistant_model=assistant, - ) - self.assertListEqual(outputs_assisted.tolist(), outputs_foo.tolist()) - - # Check that passing encoder_outputs directly also works as expected - encoder_outputs = model.get_encoder()(input_ids) - - outputs_assisted = model.generate( - foo=True, - assistant_model=assistant, - encoder_outputs=encoder_outputs, - ) - self.assertListEqual(outputs_assisted.tolist(), outputs_foo.tolist()) - - # TODO [gustavo] Enable this test to Optimum-habana - @pytest.mark.xfail - def test_assisted_decoding_num_assistant_tokens_heuristic_schedule(self): - # This test ensures that the assisted generation num_assistant_tokens 'heuristic' schedule works properly. 
- - prompt = "Alice and Bob" - checkpoint = "EleutherAI/pythia-160m-deduped" - tokenizer = AutoTokenizer.from_pretrained(checkpoint) - inputs = tokenizer(prompt, return_tensors="pt") - - model = AutoModelForCausalLM.from_pretrained(checkpoint) - - assistant_model = model - assistant_model.generation_config.num_assistant_tokens = 5 - assistant_model.generation_config.num_assistant_tokens_schedule = "heuristic" - generation_kwargs = { - "eos_token_id": -1, - "max_new_tokens": 5, - "do_sample": False, - "assistant_model": assistant_model, - } - model.generate(**inputs, **generation_kwargs) - # update_candidate_strategy is called only once and therefore, assistant_model.generation_config.num_assistant_tokens should be either 4 or 7 - self.assertTrue(assistant_model.generation_config.num_assistant_tokens in (4, 7)) - - # TODO [gustavo] Enable this test to Optimum-habana - @pytest.mark.xfail - def test_assisted_decoding_num_assistant_tokens_heuristic_transient_schedule(self): - # This test ensures that the assisted generation num_assistant_tokens 'heuristic' schedule works properly. - - prompt = "Alice and Bob" - checkpoint = "EleutherAI/pythia-160m-deduped" - tokenizer = AutoTokenizer.from_pretrained(checkpoint) - inputs = tokenizer(prompt, return_tensors="pt") - - model = AutoModelForCausalLM.from_pretrained(checkpoint) - - assistant_model = model - assistant_model.generation_config.num_assistant_tokens = 5 - assistant_model.generation_config.num_assistant_tokens_schedule = "heuristic_transient" - generation_kwargs = { - "eos_token_id": -1, - "max_new_tokens": 5, - "do_sample": False, - "assistant_model": assistant_model, - } - model.generate(**inputs, **generation_kwargs) - # update_candidate_strategy is called once but assistant_model.generation_config.num_assistant_tokens should stay 5 - self.assertEqual(assistant_model.generation_config.num_assistant_tokens, 5) - - # TODO [gustavo] Enable this test to Optimum-habana - @slow - @pytest.mark.xfail - def test_validate_assistant(self): - # Generate a random sample: - inputs = np.random.rand(160000) - - # Load a main encoder-decoder model: - model_id = "openai/whisper-large-v2" - processor = AutoProcessor.from_pretrained(model_id) - model = AutoModelForSpeechSeq2Seq.from_pretrained( - model_id, - low_cpu_mem_usage=True, - use_safetensors=True, - ) - model.to(torch_device) - - # process the input: - features = processor(inputs, return_tensors="pt").to(torch_device) - - # Load an encoder-decoder assistant with same encoder as the main model: - assistant_distil_model_id = "distil-whisper/distil-large-v2" - assistant_seq_to_seq = AutoModelForSpeechSeq2Seq.from_pretrained( - assistant_distil_model_id, - use_safetensors=True, - ).to(torch_device) - self.assertTrue(model.generate(**features, assistant_model=assistant_seq_to_seq).sum()) - - # Load its decoder only version: - assistant_causal_lm = AutoModelForCausalLM.from_pretrained( - assistant_distil_model_id, - low_cpu_mem_usage=True, - use_safetensors=True, - ).to(torch_device) - self.assertTrue(model.generate(**features, assistant_model=assistant_causal_lm).sum()) - - # Load an encoder-decoder assistant with a different encoder than the main model: - assistant_distil_model_id = "openai/whisper-tiny" - assistant_seq_to_seq = AutoModelForSpeechSeq2Seq.from_pretrained( - assistant_distil_model_id, - use_safetensors=True, - ).to(torch_device) - self.assertTrue(model.generate(**features, assistant_model=assistant_seq_to_seq).sum()) - - # Load its decoder only version: - assistant_causal_lm = 
AutoModelForCausalLM.from_pretrained( - assistant_distil_model_id, - low_cpu_mem_usage=True, - use_safetensors=True, - ).to(torch_device) - # It will raise an error as the encoders of the main and assistant models are not compatible: - with self.assertRaises(ValueError): - model.generate(**features, assistant_model=assistant_causal_lm) - - # Load an encoder-decoder model with a different tokenizer than the main model: - assistant_distil_model_id = "hf-internal-testing/tiny-random-SeamlessM4Tv2ForSpeechToText" - assistant_seq_to_seq = AutoModelForSpeechSeq2Seq.from_pretrained( - assistant_distil_model_id, - ).to(torch_device) - # This should raise an error as the main and assistant models don't use the same tokenizer: - with self.assertRaises(ValueError): - model.generate(**features, assistant_model=assistant_seq_to_seq) - - def test_compare_unprocessed_logit_scores(self): - # Get unprocessed logit scores back from the model's generate function. - # Assert that unprocessed logits from generate() are the same as those from a direct forward pass of the model - - # tell the model to generate text and return unprocessed/unwarped logit scores - tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-gpt2") - text = "generate yes or no: " - input_ids = tokenizer([text], return_tensors="pt").input_ids.to(torch_device) - - model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-gpt2").to(torch_device) - - with torch.no_grad(): - # Get logits for the next token from a forward pass - logits_fwd = model(input_ids).logits[:, -1, :][0] - - # Get logits for the next token from the generate function - outputs = model.generate( - input_ids=input_ids, - return_dict_in_generate=True, - output_logits=True, - max_new_tokens=1, - do_sample=True, - ) - logits_gen = outputs.logits[0][0] - - # assert that unprocessed logits from generate() are the same as those from a direct forward pass of the model - self.assertListEqual(logits_fwd.tolist(), logits_gen.tolist()) - - def test_return_unprocessed_logit_scores(self): - # tell the model to generate text and return unprocessed/unwarped logit scores - tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-gpt2") - text = "generate yes or no: " - input_ids = tokenizer([text], return_tensors="pt").input_ids.to(torch_device) - model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-gpt2").to(torch_device) - - outputs = model.generate( - input_ids=input_ids, return_dict_in_generate=True, output_logits=True, max_new_tokens=3 - ) - - # perform a dummy check that the unprocessed logits make sense. - # do preselection on high probabilities; find scores of y and n tokens - probs_all = torch.nn.functional.softmax(outputs.logits[2][0], dim=-1) - indices = torch.argwhere(probs_all > 0.001) - indices = indices[:, -1] - tokens_max = tokenizer.batch_decode(indices, skip_special_tokens=True) - probs_max = probs_all[probs_all > 0.001] - - self.assertTrue(len(indices) >= 2) - next_token_dict = {str(t): p for t, p in zip(tokens_max, probs_max)} - self.assertTrue("n" in next_token_dict) - self.assertTrue("y" in next_token_dict) - y_prob = next_token_dict["y"] - n_prob = next_token_dict["n"] - - self.assertTrue(y_prob > 0.001 and n_prob > 0.001) - self.assertTrue(y_prob <= 1.0 and n_prob <= 1.0) - - @slow - @require_torch_multi_gpu - def test_assisted_decoding_in_different_gpu(self): - # PT-only test: TF doesn't support assisted decoding yet. 
- model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-MistralForCausalLM").to("cuda:0") - assistant = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-MistralForCausalLM").to( - "cuda:1" - ) - tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-MistralForCausalLM") - model.config.pad_token_id = tokenizer.eos_token_id - assistant.config.pad_token_id = tokenizer.eos_token_id - - text = "Hello world" - tokenized_inputs = tokenizer([text], return_tensors="pt") - input_ids = tokenized_inputs.input_ids.to(torch_device) - input_length = input_ids.shape[-1] - - out = model.generate( - input_ids, - assistant_model=assistant, - max_new_tokens=20, - ) - self.assertTrue(input_length <= out.shape[-1] <= input_length + 20) - - @slow - @require_torch_gpu - def test_assisted_decoding_model_in_gpu_assistant_in_cpu(self): - # PT-only test: TF doesn't support assisted decoding yet. - model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-MistralForCausalLM").to("cuda") - assistant = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-MistralForCausalLM").to( - "cpu" - ) - tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-MistralForCausalLM") - model.config.pad_token_id = tokenizer.eos_token_id - assistant.config.pad_token_id = tokenizer.eos_token_id - - text = "Hello world" - tokenized_inputs = tokenizer([text], return_tensors="pt") - input_ids = tokenized_inputs.input_ids.to(torch_device) - input_length = input_ids.shape[-1] - - out = model.generate( - input_ids, - assistant_model=assistant, - max_new_tokens=20, - ) - self.assertTrue(input_length <= out.shape[-1] <= input_length + 20) - - def test_special_tokens_fall_back_to_model_default(self): - # PT-only test: TF doesn't support assisted decoding yet. 
- model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-MistralForCausalLM").to( - torch_device - ) - test_bos_id = 50 - - # Sanity-check: the model has a BOS token set, and the first generated token is a BOS token - gen_output = model.generate() - self.assertTrue(model.generation_config.bos_token_id is not None) - self.assertTrue(model.generation_config.bos_token_id == gen_output[0, 0]) - - # If we pass a generation config **with** a BOS token, `generate` will use it - generation_config = GenerationConfig(bos_token_id=test_bos_id) - gen_output = model.generate(generation_config=generation_config) - self.assertFalse(model.generation_config.bos_token_id == gen_output[0, 0]) - self.assertTrue(generation_config.bos_token_id == gen_output[0, 0]) - self.assertTrue(test_bos_id == gen_output[0, 0]) - - # If we pass a generation config **without** a BOS token, `generate` will fetch the BOS token from - # `model.generation_config` - generation_config = GenerationConfig(bos_token_id=None) - gen_output = model.generate(generation_config=generation_config) - self.assertTrue(model.generation_config.bos_token_id == gen_output[0, 0]) - self.assertFalse(test_bos_id == gen_output[0, 0]) - self.assertTrue(generation_config.bos_token_id is None) - - # Changing `model.generation_config` will affect fallback behavior - model.generation_config.bos_token_id = test_bos_id - gen_output = model.generate(generation_config=generation_config) - self.assertTrue(model.generation_config.bos_token_id == gen_output[0, 0]) - self.assertTrue(test_bos_id == gen_output[0, 0]) - self.assertTrue(generation_config.bos_token_id is None) - - @pytest.mark.generate - @require_torch_multi_gpu - def test_generate_with_static_cache_multi_gpu(self): - """ - Tests if the static cache has been set correctly and if generate works correctly when we are using multi-gpus. - """ - # need to split manually as auto doesn't work well with unbalanced model - device_map = {"model.embed_tokens": 0, "model.layers.0": 0, "model.layers.1": 1, "model.norm": 1, "lm_head": 0} - model = AutoModelForCausalLM.from_pretrained( - "hf-internal-testing/tiny-random-MistralForCausalLM", device_map=device_map - ) - tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-MistralForCausalLM") - - text = "Hello world" - tokenized_inputs = tokenizer([text], return_tensors="pt") - input_ids = tokenized_inputs.input_ids.to(torch_device) - - generation_kwargs = { - "max_new_tokens": 20, - "cache_implementation": "static", - "return_dict_in_generate": True, # Required to return `past_key_values` - } - - results = model.generate(input_ids, **generation_kwargs) - self.assertTrue(isinstance(results.past_key_values, StaticCache)) - - # check device of each layer - key_cache_0 = results.past_key_values.key_cache[0] - value_cache_0 = results.past_key_values.value_cache[0] - self.assertTrue(key_cache_0.device == value_cache_0.device == torch.device(0)) - - key_cache_1 = results.past_key_values.key_cache[1] - value_cache_1 = results.past_key_values.value_cache[1] - self.assertTrue(key_cache_1.device == value_cache_1.device == torch.device(1)) - - @pytest.mark.generate - @require_torch_multi_gpu - def test_init_static_cache_multi_gpu(self): - """ - Tests if the static cache has been set correctly when we initialize it manually in a multi-gpu setup. 
- """ - # need to split manually as auto doesn't work well with unbalanced model - device_map = {"model.embed_tokens": 0, "model.layers.0": 0, "model.layers.1": 1, "model.norm": 1, "lm_head": 0} - model = AutoModelForCausalLM.from_pretrained( - "hf-internal-testing/tiny-random-MistralForCausalLM", device_map=device_map - ) - tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-MistralForCausalLM") - - text = "Hello world" - tokenized_inputs = tokenizer([text], return_tensors="pt") - input_ids = tokenized_inputs.input_ids.to(torch_device) - - generation_kwargs = { - "max_new_tokens": 20, - "return_dict_in_generate": True, # Required to return `past_key_values` - } - - # TODO: We need to raise a warning in case the cache is not set correctly - # with self.assertRaisesRegex(ValueError, "If you are manually initializing the cache"): - # past_key_values = StaticCache( - # config=model.config, batch_size=1, max_cache_len=30, device=torch_device, dtype=model.dtype - # ) - # results = model.generate(input_ids, past_key_values=past_key_values, **generation_kwargs) - - # deduced from the device_map : layer 0 on device 0 and layer 1 on device 1 - layer_device_map = {0: 0, 1: 1} - past_key_values = StaticCache( - config=model.config, - batch_size=1, - max_cache_len=30, - device=torch_device, - dtype=model.dtype, - layer_device_map=layer_device_map, - ) - results = model.generate(input_ids, past_key_values=past_key_values, **generation_kwargs) - - # check device of each layer - key_cache_0 = results.past_key_values.key_cache[0] - value_cache_0 = results.past_key_values.value_cache[0] - self.assertTrue(key_cache_0.device == value_cache_0.device == torch.device(0)) - - key_cache_1 = results.past_key_values.key_cache[1] - value_cache_1 = results.past_key_values.value_cache[1] - self.assertTrue(key_cache_1.device == value_cache_1.device == torch.device(1)) - - @slow - # TODO [gustavo] Enable this test to Optimum-habana - @pytest.mark.xfail - def test_padding_input_contrastive_search_gpt2(self): - # Load the pre-trained GPT-2 model and tokenizer - model = GPT2LMHeadModel.from_pretrained("openai-community/gpt2") - model.to(torch_device) - tokenizer = AutoTokenizer.from_pretrained("openai-community/gpt2", clean_up_tokenization_spaces=True) - - # Set the tokenizer to left-pad the sequences - tokenizer.padding_side = "left" - - # Define the PAD token as the EOS token - tokenizer.pad_token = tokenizer.eos_token - model.generation_config.pad_token_id = model.generation_config.eos_token_id - - # Define the input prompt - prompt_text = "The whispered legends of the haunted mansion spoke" - - # Tokenize the input prompt - encoded_prompt = tokenizer(prompt_text, return_tensors="pt", padding=True) - input_ids = encoded_prompt.input_ids.to(torch_device) - attention_mask = encoded_prompt.attention_mask.to(torch_device) - - # Define the contrastive search params - penalty_alpha = 0.6 - top_k = 4 - - # Define the padding length to add to the input IDs and attention mask - padding_length = 10 - - # Generate text without padding - outputs = model.generate( - input_ids=input_ids, - attention_mask=attention_mask, - do_sample=False, - penalty_alpha=penalty_alpha, - top_k=top_k, - max_new_tokens=64, - ) - generated_text_no_padding = tokenizer.decode(outputs[0], skip_special_tokens=True) - - # Pad the input IDs and attention mask on the left - padded_input_ids = F.pad( - input_ids, (padding_length, 0), "constant", value=model.generation_config.pad_token_id - ) - padded_attention_mask = F.pad(attention_mask, 
(padding_length, 0), "constant", value=0) - - # Generate text with padded inputs - outputs_with_padding = model.generate( - input_ids=padded_input_ids, - attention_mask=padded_attention_mask, - do_sample=False, - penalty_alpha=penalty_alpha, - top_k=top_k, - max_new_tokens=64, - ) - generated_text_with_padding = tokenizer.decode(outputs_with_padding[0], skip_special_tokens=True) - - # Assert that the generated texts are identical for padded and non-padded inputs - self.assertEqual(generated_text_no_padding, generated_text_with_padding) - self.assertEqual( - generated_text_with_padding, - 'The whispered legends of the haunted mansion spoke of the "souls of the dead" who were "falling ' - 'out of the sky" and "falling into the sea."\n\nThe ghostly apparitions were said to have been ' - 'created by the spirits of the dead, who were "falling out of the sky" and "falling into the sea', - ) - - @slow - # TODO [gustavo] Enable this test to Optimum-habana - @pytest.mark.xfail - def test_padding_input_contrastive_search_t5(self): - # Load the pre-trained T5 model and tokenizer - model = T5ForConditionalGeneration.from_pretrained("google-t5/t5-small") - model.to(torch_device) - tokenizer = AutoTokenizer.from_pretrained("google-t5/t5-small", clean_up_tokenization_spaces=True) - - # Define the input prompt - prompt_text = "translate English to German: I need to finish this task before the end of the day." - - # Tokenize the input prompt - encoded_prompt = tokenizer(prompt_text, return_tensors="pt") - input_ids = encoded_prompt.input_ids.to(torch_device) - attention_mask = encoded_prompt.attention_mask.to(torch_device) - - # Define the decoder prompt - decoder_prompt_text = "Ich muss diese Aufgabe" - encoded_decoder_prompt = tokenizer(decoder_prompt_text, add_special_tokens=False, return_tensors="pt") - decoder_input_ids = encoded_decoder_prompt.input_ids.to(torch_device) - decoder_attention_mask = encoded_decoder_prompt.attention_mask.to(torch_device) - - # Define the contrastive search params - penalty_alpha = 0.6 - top_k = 4 - - # Generate text without padding - outputs = model.generate( - input_ids=input_ids, - attention_mask=attention_mask, - decoder_input_ids=decoder_input_ids, - decoder_attention_mask=decoder_attention_mask, - do_sample=False, - penalty_alpha=penalty_alpha, - top_k=top_k, - max_new_tokens=64, - ) - generated_text_no_padding = tokenizer.decode(outputs[0], skip_special_tokens=True) - - # Define the padding length to add to the input IDs and attention mask - padding_length = 10 - - # Pad the decoder input IDs and attention mask on the left - padded_decoder_input_ids = F.pad( - decoder_input_ids, (padding_length, 0), "constant", value=model.generation_config.pad_token_id - ) - padded_decoder_attention_mask = F.pad(decoder_attention_mask, (padding_length, 0), "constant", value=0) - # Since the decoder_start_token_id is the same as the pad_token_id, - # the last padded token represents the decoder start token. - # Set the attention mask for the decoder_start_token_id to True (1). 
- padded_decoder_attention_mask[:, padding_length - 1] = 1 - # Generate text with padded inputs - outputs_with_padding = model.generate( - input_ids=input_ids, - attention_mask=attention_mask, - decoder_input_ids=padded_decoder_input_ids, - decoder_attention_mask=padded_decoder_attention_mask, - do_sample=False, - penalty_alpha=penalty_alpha, - top_k=top_k, - max_new_tokens=64, - ) - generated_text_with_padding = tokenizer.decode(outputs_with_padding[0], skip_special_tokens=True) - - # Assert that the generated texts are identical for padded and non-padded inputs - self.assertEqual(generated_text_no_padding, generated_text_with_padding) - self.assertEqual(generated_text_no_padding, "Ich muss diese Aufgabe vor Ende des Tages beenden.") - - # TODO [gustavo] Enable this test to Optimum-habana - @pytest.mark.xfail - def test_generate_compile_fullgraph_tiny(self): - """ - Tests that we can call end-to-end generation with a tiny model (i.e. doesn't crash) - NOTE: this test is quite slow (~20s on a consumer desktop), but it is important that we keep it as part of the - non-slow tests to prevent regressions! - """ - model = AutoModelForCausalLM.from_pretrained( - "hf-internal-testing/tiny-random-LlamaForCausalLM", torch_dtype=torch.bfloat16, device_map="auto" - ) - tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-LlamaForCausalLM") - - # compile generate - compiled_generate = torch.compile(model.generate, fullgraph=True, mode="reduce-overhead") - - # compiled generate does NOT accept parameterization except a) model inputs b) a generation config - generation_config = copy.deepcopy(model.generation_config) - generation_config.pad_token_id = model.config.eos_token_id - - model_inputs = tokenizer(["Write a poem about the market crashing in summer"], return_tensors="pt") - model_inputs = model_inputs.to(model.device) - gen_out = compiled_generate(**model_inputs, generation_config=generation_config) - self.assertTrue(gen_out.shape[1] > model_inputs["input_ids"].shape[1]) # some text was generated - - -@require_torch -class TokenHealingTestCase(unittest.TestCase): - @parameterized.expand( - [ - ( - "square_bracket", - 'An example ["like this"] and another example [', - 'An example ["like this"] and another example ["', - ), - ("url", 'The link is