diff --git a/conftest.py b/conftest.py
index 5775644c48..71cb6bb7ca 100644
--- a/conftest.py
+++ b/conftest.py
@@ -1,88 +1,3 @@
-# coding=utf-8
-# Copyright 2020 The HuggingFace Team. All rights reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-# tests directory-specific settings - this file is run automatically
-# by pytest before any tests are run
-import doctest
-import sys
-import warnings
-from os.path import abspath, dirname, join
-
-import _pytest
-import pytest
-from transformers.testing_utils import HfDoctestModule, HfDocTestParser
-
-
-NOT_DEVICE_TESTS = {
- "test_tokenization",
- "test_processor",
- "test_processing",
- "test_beam_constraints",
- "test_configuration_utils",
- "test_data_collator",
- "test_trainer_callback",
- "test_trainer_utils",
- "test_feature_extraction",
- "test_image_processing",
- "test_image_processor",
- "test_image_transforms",
- "test_optimization",
- "test_retrieval",
- "test_config",
- "test_from_pretrained_no_checkpoint",
- "test_keep_in_fp32_modules",
- "test_gradient_checkpointing_backward_compatibility",
- "test_gradient_checkpointing_enable_disable",
- "test_save_load_fast_init_from_base",
- "test_fast_init_context_manager",
- "test_fast_init_tied_embeddings",
- "test_save_load_fast_init_to_base",
- "test_torch_save_load",
- "test_initialization",
- "test_forward_signature",
- "test_model_get_set_embeddings",
- "test_model_main_input_name",
- "test_correct_missing_keys",
- "test_tie_model_weights",
- "test_can_use_safetensors",
- "test_load_save_without_tied_weights",
- "test_tied_weights_keys",
- "test_model_weights_reload_no_missing_tied_weights",
- "test_pt_tf_model_equivalence",
- "test_mismatched_shapes_have_properly_initialized_weights",
- "test_matched_shapes_have_loaded_weights_when_some_mismatched_shapes_exist",
- "test_model_is_small",
- "test_tf_from_pt_safetensors",
- "test_flax_from_pt_safetensors",
- "ModelTest::test_pipeline_", # None of the pipeline tests from PipelineTesterMixin (of which XxxModelTest inherits from) are running on device
- "ModelTester::test_pipeline_",
- "/repo_utils/",
- "/utils/",
- "/agents/",
-}
-
-# allow having multiple repository checkouts and not needing to remember to rerun
-# `pip install -e '.[dev]'` when switching between checkouts and running tests.
-git_repo_path = abspath(join(dirname(__file__), "src"))
-sys.path.insert(1, git_repo_path)
-
-# silence FutureWarning warnings in tests since often we can't act on them until
-# they become normal warnings - i.e. the tests still need to test the current functionality
-warnings.simplefilter(action="ignore", category=FutureWarning)
-
-
class Secret:
"""
Taken from: https://stackoverflow.com/a/67393351
@@ -98,47 +13,9 @@ def __str___(self):
return "*******"
-def pytest_configure(config):
- config.addinivalue_line(
- "markers", "is_pt_tf_cross_test: mark test to run only when PT and TF interactions are tested"
- )
- config.addinivalue_line(
- "markers", "is_pt_flax_cross_test: mark test to run only when PT and FLAX interactions are tested"
- )
- config.addinivalue_line("markers", "is_pipeline_test: mark test to run only when pipelines are tested")
- config.addinivalue_line("markers", "is_staging_test: mark test to run only in the staging environment")
- config.addinivalue_line("markers", "accelerate_tests: mark test that require accelerate")
- config.addinivalue_line("markers", "agent_tests: mark the agent tests that are run on their specific schedule")
- config.addinivalue_line("markers", "not_device_test: mark the tests always running on cpu")
-
-
-def pytest_collection_modifyitems(items):
- for item in items:
- if any(test_name in item.nodeid for test_name in NOT_DEVICE_TESTS):
- item.add_marker(pytest.mark.not_device_test)
-
-
def pytest_addoption(parser):
parser.addoption("--token", action="store", default=None)
- from transformers.testing_utils import pytest_addoption_shared
-
- pytest_addoption_shared(parser)
-
-
-def pytest_terminal_summary(terminalreporter):
- from transformers.testing_utils import pytest_terminal_summary_main
-
- make_reports = terminalreporter.config.getoption("--make-reports")
- if make_reports:
- pytest_terminal_summary_main(terminalreporter, id=make_reports)
-
-
-def pytest_sessionfinish(session, exitstatus):
- # If no tests are collected, pytest exists with code 5, which makes the CI fail.
- if exitstatus == 5:
- session.exitstatus = 0
-
def pytest_generate_tests(metafunc):
# This is called for every test. Only get/set command line arguments
@@ -146,21 +23,3 @@ def pytest_generate_tests(metafunc):
option_value = Secret(metafunc.config.option.token)
if "token" in metafunc.fixturenames:
metafunc.parametrize("token", [option_value])
-
-
-# Doctest custom flag to ignore output.
-IGNORE_RESULT = doctest.register_optionflag("IGNORE_RESULT")
-
-OutputChecker = doctest.OutputChecker
-
-
-class CustomOutputChecker(OutputChecker):
- def check_output(self, want, got, optionflags):
- if IGNORE_RESULT & optionflags:
- return True
- return OutputChecker.check_output(self, want, got, optionflags)
-
-
-doctest.OutputChecker = CustomOutputChecker
-_pytest.doctest.DoctestModule = HfDoctestModule
-doctest.DocTestParser = HfDocTestParser
diff --git a/optimum/habana/transformers/generation/utils.py b/optimum/habana/transformers/generation/utils.py
index d81e0d179a..68b445c1b2 100644
--- a/optimum/habana/transformers/generation/utils.py
+++ b/optimum/habana/transformers/generation/utils.py
@@ -211,20 +211,19 @@ def _prepare_decoder_input_ids_for_generation(
# 2. `decoder_start_token_id` must have shape (batch_size, 1)
if device is None:
device = self.device
- if decoder_start_token_id.ndim == 1:
- if decoder_start_token_id.shape[0] != batch_size:
- raise ValueError(
- f"`decoder_start_token_id` expected to have length {batch_size} but got {decoder_start_token_id.shape[0]}"
+ if token_idx is None:
+ if decoder_start_token_id.ndim == 1:
+ if decoder_start_token_id.shape[0] != batch_size:
+ raise ValueError(
+ f"`decoder_start_token_id` expected to have length {batch_size} but got {decoder_start_token_id.shape[0]}"
+ )
+ decoder_start_token_id = decoder_start_token_id.view(-1, 1)
+ else:
+ decoder_start_token_id = (
+ torch.ones((batch_size, 1), dtype=torch.long, device=device) * decoder_start_token_id
)
- decoder_start_token_id = decoder_start_token_id.view(-1, 1)
else:
- decoder_start_token_id = (
- torch.ones((batch_size, 1), dtype=torch.long, device=device) * decoder_start_token_id
- )
-
- if token_idx is not None:
- # creating padded decoder_input_ids to achieve static shapes.
- # Later new tokens once generated are copied in to decoder_input_ids based on token_idx
+ # creating padded decoder_input_ids to achieve static shapes. Later new tokens once generated are copied in to decoder_input_ids based on token_idx
max_length = max_new_tokens + 1 if max_new_tokens is not None else self.generation_config.max_length
decoder_start_token_id = (
torch.ones((batch_size, 1), dtype=torch.long, device=device) * decoder_start_token_id
@@ -3040,8 +3039,7 @@ def expand_if_needed(tensor, new_size, value, dim=-1):
if self.generation_config.early_stopping:
num_eos_tokens.add_(beam_tokens[0:num_beams].eq(self.config.eos_token_id).sum())
- if self.config.eos_token_id is not None:
- beam_scores.add_(torch.where(beam_tokens.eq(self.config.eos_token_id), float("-inf"), 0.0))
+ beam_scores.add_(torch.where(beam_tokens.eq(self.config.eos_token_id), float("-inf"), 0.0))
beam_scores = beam_scores.view(batch_size, -1).unsqueeze(0)
_, selected = torch.topk(beam_scores, k=num_beams, dim=-1, largest=True, sorted=True)
offset = torch.arange(0, torch.numel(beam_scores), beam_scores.shape[-1]).unsqueeze(-1)
@@ -3213,9 +3211,6 @@ def move(obj, device):
if not output_scores:
sequence_outputs["sequence_scores"] = None
- if self.generation_config.static_shapes:
- raise NotImplementedError("sequence_scores is not implemented for static_shapes")
-
if self.config.is_encoder_decoder:
return GenerateBeamEncoderDecoderOutput(
sequences=sequence_outputs["sequences"],
diff --git a/optimum/habana/transformers/models/bart/modeling_bart.py b/optimum/habana/transformers/models/bart/modeling_bart.py
index 08ea48e1a5..3e5f822cb1 100644
--- a/optimum/habana/transformers/models/bart/modeling_bart.py
+++ b/optimum/habana/transformers/models/bart/modeling_bart.py
@@ -458,9 +458,7 @@ def gaudi_BartDecoder_forward(
# past_key_values_length
past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0
- tensor_past_key_values_length = (
- token_idx - 1 if (use_cache and token_idx is not None) else torch.tensor(past_key_values_length)
- )
+ tensor_past_key_values_length = token_idx - 1 if use_cache else torch.tensor(past_key_values_length)
if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input)
diff --git a/pyproject.toml b/pyproject.toml
index f53b25d1c0..b7896da5e8 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -41,12 +41,3 @@ skip-magic-trailing-comma = false
# Like Black, automatically detect the appropriate line ending.
line-ending = "auto"
-
-[tool.pytest.ini_options]
-addopts = "--doctest-glob='**/*.md'"
-doctest_optionflags="NUMBER NORMALIZE_WHITESPACE ELLIPSIS"
-markers = [
- "flash_attn_test: marks tests related to flash attention (deselect with '-m \"not flash_attn_test\"')",
- "bitsandbytes: select (or deselect with `not`) bitsandbytes integration tests",
- "generate: marks tests that use the GenerationTesterMixin"
-]
diff --git a/tests/transformers/tests/generation/test_framework_agnostic.py b/tests/transformers/tests/generation/test_framework_agnostic.py
index 906a90a95a..7fcc4de752 100644
--- a/tests/transformers/tests/generation/test_framework_agnostic.py
+++ b/tests/transformers/tests/generation/test_framework_agnostic.py
@@ -3,12 +3,8 @@
"""
import numpy as np
-import pytest
from transformers import AutoTokenizer
-from transformers.testing_utils import slow
-
-
-torch_device = "hpu"
+from transformers.testing_utils import slow, torch_device
class GenerationIntegrationTestsMixin:
@@ -50,8 +46,6 @@ def test_validate_generation_inputs(self):
valid_model_kwargs = {"attention_mask": create_tensor_fn(np.zeros_like(input_ids))}
model.generate(input_ids, **valid_model_kwargs)
- # TODO [gustavo] Enable this test to Optimum-habana
- @pytest.mark.xfail
def test_custom_logits_processor(self):
model_cls = self.framework_dependent_parameters["AutoModelForSeq2SeqLM"]
logits_processor_list_cls = self.framework_dependent_parameters["LogitsProcessorList"]
@@ -72,8 +66,6 @@ def test_custom_logits_processor(self):
bart_model.config.min_length = None
bart_model.generate(input_ids, logits_processor=logits_processor)
- # TODO [gustavo] Enable this test to Optimum-habana
- @pytest.mark.xfail
def test_max_new_tokens_encoder_decoder(self):
model_cls = self.framework_dependent_parameters["AutoModelForSeq2SeqLM"]
return_tensors = self.framework_dependent_parameters["return_tensors"]
@@ -230,8 +222,6 @@ def test_transition_scores_greedy_search_normalized(self):
)
self.assertTrue(np.allclose(transition_scores, expected_scores, atol=1e-3))
- # TODO [gustavo] Enable this test to Optimum-habana
- @pytest.mark.xfail
def test_transition_scores_beam_search_encoder_decoder(self):
model_cls = self.framework_dependent_parameters["AutoModelForSeq2SeqLM"]
return_tensors = self.framework_dependent_parameters["return_tensors"]
@@ -267,8 +257,6 @@ def test_transition_scores_beam_search_encoder_decoder(self):
self.assertTrue(np.allclose(np.sum(transition_scores, axis=-1), outputs.sequences_scores, atol=1e-3))
- # TODO [gustavo] Enable this test to Optimum-habana
- @pytest.mark.xfail
def test_transition_scores_beam_search_encoder_decoder_with_eos(self):
model_cls = self.framework_dependent_parameters["AutoModelForSeq2SeqLM"]
return_tensors = self.framework_dependent_parameters["return_tensors"]
@@ -303,8 +291,6 @@ def test_transition_scores_beam_search_encoder_decoder_with_eos(self):
self.assertTrue(np.allclose(np.sum(transition_scores, axis=-1), outputs.sequences_scores, atol=1e-3))
- # TODO [gustavo] Enable this test to Optimum-habana
- @pytest.mark.xfail
def test_transition_scores_beam_search_decoder_only(self):
model_cls = self.framework_dependent_parameters["AutoModelForCausalLM"]
return_tensors = self.framework_dependent_parameters["return_tensors"]
@@ -342,8 +328,6 @@ def test_transition_scores_beam_search_decoder_only(self):
self.assertTrue(np.allclose(np.sum(transition_scores, axis=-1), outputs.sequences_scores, atol=1e-3))
- # TODO [gustavo] Enable this test to Optimum-habana
- @pytest.mark.xfail
def test_transition_scores_beam_sample_encoder_decoder(self):
model_cls = self.framework_dependent_parameters["AutoModelForSeq2SeqLM"]
return_tensors = self.framework_dependent_parameters["return_tensors"]
@@ -381,7 +365,6 @@ def test_transition_scores_beam_sample_encoder_decoder(self):
self.assertTrue(np.allclose(np.sum(transition_scores, axis=-1), outputs.sequences_scores, atol=1e-3))
@slow
- @pytest.mark.skip("Not Implemented: sequence_scores is not implemented for static_shapes")
def test_transition_scores_early_stopping(self):
# This is an aggressive test that makes sure that `beam_search's`
# transition scores are computed correctly for varying `num_return_sequences`, `num_beams` and `batch_size > 1`
@@ -417,8 +400,6 @@ def test_transition_scores_early_stopping(self):
self.assertTrue(np.allclose(np.sum(transition_scores, axis=-1), outputs.sequences_scores))
- # TODO [gustavo] Enable this test to Optimum-habana
- @pytest.mark.xfail
def test_encoder_decoder_generate_attention_mask(self):
model_cls = self.framework_dependent_parameters["AutoModelForSeq2SeqLM"]
return_tensors = self.framework_dependent_parameters["return_tensors"]
@@ -520,8 +501,6 @@ def test_generate_too_many_encoder_kwargs(self):
with self.assertRaises(ValueError):
model.generate(input_ids=input_ids, inputs_embeds=input_ids)
- # TODO [gustavo] Enable this test to Optimum-habana
- @pytest.mark.xfail
def test_generate_input_features_as_encoder_kwarg(self):
model_cls = self.framework_dependent_parameters["AutoModelForSpeechSeq2Seq"]
floats_tensor = self.framework_dependent_parameters["floats_tensor"]
@@ -563,8 +542,6 @@ def test_generate_pixel_values_as_encoder_kwarg(self):
self.assertTrue(np.array_equal(output_sequences, output_sequences_kwargs))
self.assertEqual(output_sequences.shape, (2, 5))
- # TODO [gustavo] Enable this test to Optimum-habana
- @pytest.mark.xfail
def test_generate_encoder_outputs_attention_mask(self):
model_cls = self.framework_dependent_parameters["AutoModelForSpeechSeq2Seq"]
floats_tensor = self.framework_dependent_parameters["floats_tensor"]
@@ -599,6 +576,7 @@ def test_eos_token_id_int_and_list_greedy_search(self):
"do_sample": False,
"num_beams": 1,
}
+ expectation = 13
tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-gpt2")
text = """Hello, my dog is cute and"""
@@ -608,7 +586,6 @@ def test_eos_token_id_int_and_list_greedy_search(self):
model = model.to(torch_device)
tokens = tokens.to(torch_device)
- expectation = model.config.max_length # static shape should give max_length
eos_token_id = 873
generated_tokens = model.generate(**tokens, eos_token_id=eos_token_id, **generation_kwargs)
self.assertTrue(expectation == len(generated_tokens[0]))
@@ -628,6 +605,7 @@ def test_eos_token_id_int_and_list_contrastive_search(self):
"penalty_alpha": 0.6,
"top_k": 4,
}
+ expectation = 17
tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-gpt2")
text = """Hello, my dog is cute and"""
@@ -637,7 +615,6 @@ def test_eos_token_id_int_and_list_contrastive_search(self):
model = model.to(torch_device)
tokens = tokens.to(torch_device)
- expectation = model.config.max_length # static shape should give max_length
eos_token_id = 225
generated_tokens = model.generate(**tokens, eos_token_id=eos_token_id, **generation_kwargs)
self.assertTrue(expectation == len(generated_tokens[0]))
@@ -646,8 +623,6 @@ def test_eos_token_id_int_and_list_contrastive_search(self):
generated_tokens = model.generate(**tokens, eos_token_id=eos_token_id, **generation_kwargs)
self.assertTrue(expectation == len(generated_tokens[0]))
- # TODO [gustavo] Enable this test to Optimum-habana
- @pytest.mark.xfail
def test_eos_token_id_int_and_list_beam_search(self):
model_cls = self.framework_dependent_parameters["AutoModelForCausalLM"]
return_tensors = self.framework_dependent_parameters["return_tensors"]
@@ -673,10 +648,7 @@ def test_eos_token_id_int_and_list_beam_search(self):
padded_correct_condition = expectation < len(generated_tokens[0]) and all(
token == model.config.pad_token_id for token in generated_tokens[0][expectation:]
)
- static_shape_condition = expectation < len(generated_tokens[0]) and all(
- token == eos_token_id for token in generated_tokens[0][expectation:]
- )
- self.assertTrue(unpadded_correct_condition or padded_correct_condition or static_shape_condition)
+ self.assertTrue(unpadded_correct_condition or padded_correct_condition)
eos_token_id = [873, 198]
generated_tokens = model.generate(**tokens, eos_token_id=eos_token_id, **generation_kwargs)
@@ -684,13 +656,8 @@ def test_eos_token_id_int_and_list_beam_search(self):
padded_correct_condition = expectation < len(generated_tokens[0]) and all(
token == model.config.pad_token_id for token in generated_tokens[0][expectation:]
)
- static_shape_condition = expectation < len(generated_tokens[0]) and all(
- token in eos_token_id for token in generated_tokens[0][expectation:]
- )
- self.assertTrue(unpadded_correct_condition or padded_correct_condition or static_shape_condition)
+ self.assertTrue(unpadded_correct_condition or padded_correct_condition)
- # TODO [gustavo] Enable this test to Optimum-habana
- @pytest.mark.xfail
def test_generate_vision2text_conditioning(self):
model_cls = self.framework_dependent_parameters["AutoModelForVision2Seq"]
floats_tensor = self.framework_dependent_parameters["floats_tensor"]
diff --git a/tests/transformers/tests/generation/test_utils.py b/tests/transformers/tests/generation/test_utils.py
index 954bcd14d5..512935e9dd 100644
--- a/tests/transformers/tests/generation/test_utils.py
+++ b/tests/transformers/tests/generation/test_utils.py
@@ -14,27 +14,14 @@
# limitations under the License.
-import copy
import inspect
-import tempfile
import unittest
import warnings
import numpy as np
import pytest
-from parameterized import parameterized
-from transformers import is_torch_available, pipeline, set_seed
-from transformers.testing_utils import (
- is_flaky,
- require_accelerate,
- require_auto_gptq,
- require_quanto,
- require_torch,
- require_torch_gpu,
- require_torch_multi_accelerator,
- require_torch_multi_gpu,
- slow,
-)
+from transformers import is_torch_available, pipeline
+from transformers.testing_utils import require_torch, slow
from optimum.habana.checkpoint_utils import model_is_optimized
from optimum.habana.transformers.modeling_utils import adapt_transformers_to_gaudi
@@ -45,50 +32,54 @@
if is_torch_available():
import torch
- import torch.nn.functional as F
from transformers import (
AutoModelForCausalLM,
AutoModelForSeq2SeqLM,
AutoModelForSpeechSeq2Seq,
AutoModelForVision2Seq,
- AutoProcessor,
AutoTokenizer,
- BartForCausalLM,
BartForConditionalGeneration,
BartTokenizer,
GPT2LMHeadModel,
GPT2Tokenizer,
ImageGPTForCausalImageModeling,
+ PreTrainedModel,
SpeechEncoderDecoderModel,
- T5ForConditionalGeneration,
)
- from transformers.cache_utils import DynamicCache, EncoderDecoderCache, QuantoQuantizedCache, StaticCache
from transformers.generation import (
BeamSampleDecoderOnlyOutput,
BeamSampleEncoderDecoderOutput,
BeamSearchDecoderOnlyOutput,
BeamSearchEncoderDecoderOutput,
+ BeamSearchScorer,
+ ConstrainedBeamSearchScorer,
DisjunctiveConstraint,
+ ForcedBOSTokenLogitsProcessor,
+ ForcedEOSTokenLogitsProcessor,
GenerateBeamDecoderOnlyOutput,
GenerateBeamEncoderDecoderOutput,
GenerateDecoderOnlyOutput,
GenerateEncoderDecoderOutput,
- GenerationConfig,
GreedySearchDecoderOnlyOutput,
GreedySearchEncoderDecoderOutput,
+ HammingDiversityLogitsProcessor,
LogitsProcessorList,
MaxLengthCriteria,
MinLengthLogitsProcessor,
+ NoBadWordsLogitsProcessor,
+ NoRepeatNGramLogitsProcessor,
PhrasalConstraint,
- PromptLookupCandidateGenerator,
+ RepetitionPenaltyLogitsProcessor,
SampleDecoderOnlyOutput,
SampleEncoderDecoderOutput,
StoppingCriteria,
StoppingCriteriaList,
- WatermarkDetector,
- WatermarkingConfig,
+ TemperatureLogitsWarper,
+ TopKLogitsWarper,
+ TopPLogitsWarper,
)
- from transformers.generation.utils import _speculative_sampling
+ from transformers.generation.candidate_generator import AssistedCandidateGenerator, CandidateGenerator
+ from transformers.generation.streamers import BaseStreamer
torch_device = "hpu"
adapt_transformers_to_gaudi()
@@ -100,84 +91,116 @@ class GenerationTesterMixin:
input_name = "input_ids"
max_new_tokens = 3
+ def _update_default_model_kwargs(self, model_kwargs):
+ model_kwargs["limit_hpu_graphs"] = False
+ model_kwargs["reuse_cache"] = False
+ model_kwargs["bucket_size"] = -1
+
def _get_input_ids_and_config(self, batch_size=2):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
- # TODO: @raushan or @gante, use `model.main_input_name` as the main input instead of relyinn on `input_ids`
- input_ids = inputs_dict.pop(self.input_name)[:batch_size, :]
- inputs_dict.pop("attention_mask", None)
-
- # we don't want encoder-decoder models to start from filled decoder ids
- inputs_dict.pop("decoder_input_ids", None)
- inputs_dict.pop("decoder_attention_mask", None)
+ input_ids = inputs_dict[self.input_name]
# cut to half length & take max batch_size 3
sequence_length = input_ids.shape[-1] // 2
input_ids = input_ids[:batch_size, :sequence_length]
- # we'll set cache use in each test differently
- inputs_dict.pop("use_cache", None)
-
- inputs_dict = {
- k: v[:batch_size, ...]
- for k, v in inputs_dict.items()
- if "head_mask" not in k and isinstance(v, torch.Tensor)
- }
+ # generate max 3 tokens
+ max_length = input_ids.shape[-1] + 3
if config.eos_token_id is not None and config.pad_token_id is None:
# hack to allow generate for models such as GPT2 as is done in `generate()`
if isinstance(config.eos_token_id, int):
config.eos_token_id = [config.eos_token_id]
config.pad_token_id = config.eos_token_id[0]
-
- if self.has_attentions:
- attention_mask = torch.ones_like(input_ids, dtype=torch.long)
- else:
+ # TransfoXL has no attention mask
+ if "transfoxl" in config.__class__.__name__.lower():
attention_mask = None
-
- # It is important set the eos_token_id to None to ensure that no sequences
- # shorter than `max_length` can be generated
- config.eos_token_id = None
- config.forced_eos_token_id = None
-
- return config, input_ids, attention_mask, inputs_dict
-
- def _get_logits_processor_kwargs(self, do_sample=False, config=None):
- logits_processor_kwargs = {
+ else:
+ attention_mask = torch.ones_like(input_ids, dtype=torch.long)[:batch_size, :sequence_length]
+
+ return config, input_ids, attention_mask, max_length
+
+ @staticmethod
+ def _get_logits_processor_and_kwargs(
+ input_length,
+ eos_token_id,
+ forced_bos_token_id=None,
+ forced_eos_token_id=None,
+ max_length=None,
+ diversity_penalty=None,
+ ):
+ process_kwargs = {
+ "min_length": input_length + 1 if max_length is None else max_length - 1,
"bad_words_ids": [[1, 0]],
+ "no_repeat_ngram_size": 2,
"repetition_penalty": 1.2,
- "remove_invalid_values": True,
}
- if do_sample:
- logits_processor_kwargs.update(
- {
- "top_k": 10,
- "top_p": 0.7,
- "temperature": 0.7,
- }
+ logits_processor = LogitsProcessorList(
+ (
+ [
+ HammingDiversityLogitsProcessor(diversity_penalty, num_beams=2, num_beam_groups=2),
+ ]
+ if diversity_penalty is not None
+ else []
)
- # TODO (joao, raushan): see this comment for a long-term fix
- # https://github.com/huggingface/transformers/pull/33593#issuecomment-2361824264)
- # This is a band-aid for VLM models, to ensure they don't generate image/video tokens which would cause them
- # to crash. On pretrained models this isn't a risk, as they are trained to not generate these tokens.
- if config is not None:
- image_token_index = config.image_token_index if hasattr(config, "image_token_index") else None
- video_token_index = config.video_token_index if hasattr(config, "video_token_index") else None
- if image_token_index is not None and image_token_index < config.get_text_config().vocab_size:
- logits_processor_kwargs["bad_words_ids"].append([image_token_index])
- if video_token_index is not None and video_token_index < config.get_text_config().vocab_size:
- logits_processor_kwargs["bad_words_ids"].append([video_token_index])
-
- return logits_processor_kwargs
-
- def _get_beam_kwargs(self, num_return_sequences=1):
+ + (
+ [
+ MinLengthLogitsProcessor(process_kwargs["min_length"], eos_token_id),
+ ]
+ if eos_token_id is not None
+ else []
+ )
+ + (
+ [
+ ForcedBOSTokenLogitsProcessor(forced_bos_token_id),
+ ]
+ if forced_bos_token_id is not None
+ else []
+ )
+ + (
+ [ForcedEOSTokenLogitsProcessor(max_length, forced_eos_token_id)]
+ if forced_eos_token_id is not None
+ else []
+ )
+ + [
+ NoBadWordsLogitsProcessor(process_kwargs["bad_words_ids"], eos_token_id),
+ NoRepeatNGramLogitsProcessor(process_kwargs["no_repeat_ngram_size"]),
+ RepetitionPenaltyLogitsProcessor(process_kwargs["repetition_penalty"]),
+ ]
+ )
+ return process_kwargs, logits_processor
+
+ @staticmethod
+ def _get_warper_and_kwargs(num_beams):
+ warp_kwargs = {"top_k": 10, "top_p": 0.7, "temperature": 0.7}
+ logits_warper = LogitsProcessorList(
+ [
+ TemperatureLogitsWarper(warp_kwargs["temperature"]),
+ TopKLogitsWarper(top_k=warp_kwargs["top_k"], min_tokens_to_keep=(2 if num_beams > 1 else 1)),
+ TopPLogitsWarper(top_p=warp_kwargs["top_p"], min_tokens_to_keep=(2 if num_beams > 1 else 1)),
+ ]
+ )
+ return warp_kwargs, logits_warper
+
+ @staticmethod
+ def _get_beam_scorer_and_kwargs(batch_size, max_length, num_return_sequences=1):
beam_kwargs = {
"early_stopping": False,
"length_penalty": 2.0,
"num_beams": 2,
"num_return_sequences": num_return_sequences,
}
- return beam_kwargs
+ beam_scorer = BeamSearchScorer(
+ batch_size=batch_size,
+ num_beams=beam_kwargs["num_beams"],
+ device=torch_device,
+ length_penalty=beam_kwargs["length_penalty"],
+ do_early_stopping=beam_kwargs["early_stopping"],
+ num_beam_hyps_to_keep=num_return_sequences,
+ )
+ return beam_kwargs, beam_scorer
- def _get_diverse_beam_kwargs(self, num_return_sequences=1):
+ @staticmethod
+ def _get_diverse_beam_scorer_and_kwargs(batch_size, max_length, num_return_sequences=1):
beam_kwargs = {
"early_stopping": False,
"length_penalty": 2.0,
@@ -186,46 +209,93 @@ def _get_diverse_beam_kwargs(self, num_return_sequences=1):
"num_beam_groups": 2, # one beam per group
"diversity_penalty": 2.0,
}
- return beam_kwargs
+ beam_scorer = BeamSearchScorer(
+ batch_size=batch_size,
+ num_beams=beam_kwargs["num_beams"],
+ device=torch_device,
+ length_penalty=beam_kwargs["length_penalty"],
+ do_early_stopping=beam_kwargs["early_stopping"],
+ num_beam_hyps_to_keep=num_return_sequences,
+ num_beam_groups=beam_kwargs["num_beam_groups"],
+ )
+ return beam_kwargs, beam_scorer
- def _get_constrained_beam_kwargs(self, num_return_sequences=1):
+ @staticmethod
+ def _get_constrained_beam_scorer_and_kwargs(batch_size, max_length, constraints, num_return_sequences=1):
beam_kwargs = {
"early_stopping": False,
"length_penalty": 2.0,
"num_beams": num_return_sequences * 4,
"num_return_sequences": num_return_sequences,
}
- return beam_kwargs
+ beam_scorer = ConstrainedBeamSearchScorer(
+ batch_size=batch_size,
+ constraints=constraints,
+ num_beams=beam_kwargs["num_beams"],
+ device=torch_device,
+ length_penalty=beam_kwargs["length_penalty"],
+ do_early_stopping=beam_kwargs["early_stopping"],
+ num_beam_hyps_to_keep=num_return_sequences,
+ )
+ return beam_kwargs, beam_scorer
+
+ @staticmethod
+ def _get_encoder_outputs(
+ model, input_ids, attention_mask, output_attentions=None, output_hidden_states=None, num_interleave=1
+ ):
+ encoder = model.get_encoder()
+ encoder_outputs = encoder(
+ input_ids,
+ attention_mask=attention_mask,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ )
+ encoder_outputs["last_hidden_state"] = encoder_outputs.last_hidden_state.repeat_interleave(
+ num_interleave, dim=0
+ )
+ input_ids = torch.zeros_like(input_ids[:, :1]) + model._get_decoder_start_token_id()
+ attention_mask = None
+ return encoder_outputs, input_ids, attention_mask
+
+ @staticmethod
+ def _get_static_shapes():
+ return False
def _greedy_generate(
self,
model,
input_ids,
attention_mask,
- inputs_dict,
+ max_length,
output_scores=False,
- output_logits=False,
output_attentions=False,
output_hidden_states=False,
return_dict_in_generate=False,
- use_cache=True,
):
- logits_processor_kwargs = self._get_logits_processor_kwargs(do_sample=False, config=model.config)
+ if model.config.is_encoder_decoder:
+ max_length = 4
+ logits_process_kwargs, logits_processor = self._get_logits_processor_and_kwargs(
+ input_ids.shape[-1],
+ eos_token_id=model.config.eos_token_id,
+ forced_bos_token_id=model.config.forced_bos_token_id,
+ forced_eos_token_id=model.config.forced_eos_token_id,
+ max_length=max_length,
+ )
+
model_kwargs = {"attention_mask": attention_mask} if attention_mask is not None else {}
+ model.generation_config.static_shapes = self._get_static_shapes()
output_generate = model.generate(
input_ids,
do_sample=False,
num_beams=1,
- max_new_tokens=self.max_new_tokens,
+ max_length=max_length,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
output_scores=output_scores,
- output_logits=output_logits,
return_dict_in_generate=return_dict_in_generate,
- use_cache=use_cache,
- **logits_processor_kwargs,
+ remove_invalid_values=True,
+ **logits_process_kwargs,
**model_kwargs,
- **inputs_dict,
)
return output_generate
@@ -235,33 +305,35 @@ def _sample_generate(
model,
input_ids,
attention_mask,
- inputs_dict,
+ max_length,
num_return_sequences,
+ logits_processor,
+ logits_warper,
+ logits_warper_kwargs,
+ process_kwargs,
output_scores=False,
- output_logits=False,
output_attentions=False,
output_hidden_states=False,
return_dict_in_generate=False,
- use_cache=True,
):
torch.manual_seed(0)
- logits_processor_kwargs = self._get_logits_processor_kwargs(do_sample=True, config=model.config)
model_kwargs = {"attention_mask": attention_mask} if attention_mask is not None else {}
+ self._update_default_model_kwargs(model_kwargs)
+ model.generation_config.static_shapes = self._get_static_shapes()
output_generate = model.generate(
input_ids,
do_sample=True,
num_beams=1,
- max_new_tokens=self.max_new_tokens,
+ max_length=max_length,
num_return_sequences=num_return_sequences,
output_scores=output_scores,
- output_logits=output_logits,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict_in_generate=return_dict_in_generate,
- use_cache=use_cache,
- **logits_processor_kwargs,
+ remove_invalid_values=True,
+ **logits_warper_kwargs,
+ **process_kwargs,
**model_kwargs,
- **inputs_dict,
)
return output_generate
@@ -271,31 +343,31 @@ def _beam_search_generate(
model,
input_ids,
attention_mask,
- inputs_dict,
+ max_length,
+ beam_scorer,
beam_kwargs,
+ logits_processor,
+ logits_process_kwargs,
output_scores=False,
- output_logits=False,
output_attentions=False,
output_hidden_states=False,
return_dict_in_generate=False,
- use_cache=True,
):
- logits_processor_kwargs = self._get_logits_processor_kwargs(do_sample=False, config=model.config)
model_kwargs = {"attention_mask": attention_mask} if attention_mask is not None else {}
+ self._update_default_model_kwargs(model_kwargs)
+ model.generation_config.static_shapes = self._get_static_shapes()
output_generate = model.generate(
input_ids,
do_sample=False,
- max_new_tokens=self.max_new_tokens,
+ max_length=max_length,
output_scores=output_scores,
- output_logits=output_logits,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict_in_generate=return_dict_in_generate,
- use_cache=use_cache,
+ remove_invalid_values=True,
**beam_kwargs,
- **logits_processor_kwargs,
+ **logits_process_kwargs,
**model_kwargs,
- **inputs_dict,
)
return output_generate
@@ -305,34 +377,32 @@ def _beam_sample_generate(
model,
input_ids,
attention_mask,
- inputs_dict,
+ max_length,
+ beam_scorer,
beam_kwargs,
+ logits_warper,
+ logits_warper_kwargs,
output_scores=False,
- output_logits=False,
output_attentions=False,
output_hidden_states=False,
return_dict_in_generate=False,
- use_cache=True,
):
torch.manual_seed(0)
- logits_processor_kwargs = self._get_logits_processor_kwargs(do_sample=True, config=model.config)
model_kwargs = {"attention_mask": attention_mask} if attention_mask is not None else {}
+ self._update_default_model_kwargs(model_kwargs)
output_generate = model.generate(
input_ids,
do_sample=True,
- max_new_tokens=self.max_new_tokens,
+ max_length=max_length,
output_scores=output_scores,
- output_logits=output_logits,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict_in_generate=return_dict_in_generate,
- use_cache=use_cache,
+ remove_invalid_values=True,
**beam_kwargs,
- **logits_processor_kwargs,
+ **logits_warper_kwargs,
**model_kwargs,
- **inputs_dict,
)
-
return output_generate
def _group_beam_search_generate(
@@ -340,31 +410,30 @@ def _group_beam_search_generate(
model,
input_ids,
attention_mask,
- inputs_dict,
+ max_length,
+ beam_scorer,
beam_kwargs,
+ logits_processor,
+ logits_process_kwargs,
output_scores=False,
- output_logits=False,
output_attentions=False,
output_hidden_states=False,
return_dict_in_generate=False,
- use_cache=True,
):
- logits_processor_kwargs = self._get_logits_processor_kwargs(do_sample=False, config=model.config)
model_kwargs = {"attention_mask": attention_mask} if attention_mask is not None else {}
+ self._update_default_model_kwargs(model_kwargs)
output_generate = model.generate(
input_ids,
do_sample=False,
- max_new_tokens=self.max_new_tokens,
+ max_length=max_length,
output_scores=output_scores,
- output_logits=output_logits,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict_in_generate=return_dict_in_generate,
- use_cache=use_cache,
+ remove_invalid_values=True,
**beam_kwargs,
- **logits_processor_kwargs,
+ **logits_process_kwargs,
**model_kwargs,
- **inputs_dict,
)
return output_generate
@@ -374,33 +443,33 @@ def _constrained_beam_search_generate(
model,
input_ids,
attention_mask,
- inputs_dict,
+ max_length,
+ constrained_beam_scorer,
constraints,
beam_kwargs,
+ logits_processor,
+ logits_process_kwargs,
output_scores=False,
- output_logits=False,
output_attentions=False,
output_hidden_states=False,
return_dict_in_generate=False,
- use_cache=True,
):
- logits_processor_kwargs = self._get_logits_processor_kwargs(do_sample=False, config=model.config)
model_kwargs = {"attention_mask": attention_mask} if attention_mask is not None else {}
+ self._update_default_model_kwargs(model_kwargs)
+ model.generation_config.static_shapes = self._get_static_shapes()
output_generate = model.generate(
input_ids,
do_sample=False,
- max_new_tokens=self.max_new_tokens,
+ max_length=max_length,
output_scores=output_scores,
- output_logits=output_logits,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict_in_generate=return_dict_in_generate,
+ remove_invalid_values=True,
constraints=constraints,
- use_cache=use_cache,
**beam_kwargs,
- **logits_processor_kwargs,
+ **logits_process_kwargs,
**model_kwargs,
- **inputs_dict,
)
return output_generate
@@ -410,72 +479,76 @@ def _contrastive_generate(
model,
input_ids,
attention_mask,
- inputs_dict,
+ max_length,
output_scores=False,
- output_logits=False,
output_attentions=False,
output_hidden_states=False,
return_dict_in_generate=False,
- use_cache=True,
):
contrastive_search_kwargs = {
"penalty_alpha": 0.6,
"top_k": 5,
}
- logits_processor_kwargs = self._get_logits_processor_kwargs(do_sample=False, config=model.config)
+ if model.config.is_encoder_decoder:
+ max_length = 4
+ logits_process_kwargs, logits_processor = self._get_logits_processor_and_kwargs(
+ input_ids.shape[-1],
+ eos_token_id=model.config.eos_token_id,
+ forced_bos_token_id=model.config.forced_bos_token_id,
+ forced_eos_token_id=model.config.forced_eos_token_id,
+ max_length=max_length,
+ )
+
model_kwargs = {"attention_mask": attention_mask} if attention_mask is not None else {}
+ self._update_default_model_kwargs(model_kwargs)
+ model.generation_config.static_shapes = self._get_static_shapes()
output_generate = model.generate(
input_ids,
do_sample=False,
num_beams=1,
- max_new_tokens=self.max_new_tokens,
+ max_length=max_length,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
output_scores=output_scores,
- output_logits=output_logits,
return_dict_in_generate=return_dict_in_generate,
- use_cache=use_cache,
- **logits_processor_kwargs,
+ remove_invalid_values=True,
+ **logits_process_kwargs,
**model_kwargs,
**contrastive_search_kwargs,
- **inputs_dict,
)
return output_generate
- @pytest.mark.generate
def test_greedy_generate(self):
+ # check `generate()` and `greedy_search()` are equal
for model_class in self.all_generative_model_classes:
- config, input_ids, attention_mask, inputs_dict = self._get_input_ids_and_config()
-
+ config, input_ids, attention_mask, max_length = self._get_input_ids_and_config()
+ # test old generation output for backwards compatibility
model = model_class(config).to(torch_device).eval()
output_generate = self._greedy_generate(
- model=model, input_ids=input_ids, attention_mask=attention_mask, inputs_dict=inputs_dict
+ model=model, input_ids=input_ids, attention_mask=attention_mask, max_length=max_length
)
-
if model.config.is_encoder_decoder:
self.assertTrue(output_generate.shape[-1] == self.max_new_tokens + 1)
else:
self.assertTrue(output_generate.shape[-1] == self.max_new_tokens + input_ids.shape[-1])
- @pytest.mark.generate
def test_greedy_generate_dict_outputs(self):
for model_class in self.all_generative_model_classes:
- config, input_ids, attention_mask, inputs_dict = self._get_input_ids_and_config()
-
+ # disable cache
+ config, input_ids, attention_mask, max_length = self._get_input_ids_and_config()
+ config.use_cache = False
model = model_class(config).to(torch_device).eval()
output_generate = self._greedy_generate(
model=model,
input_ids=input_ids,
attention_mask=attention_mask,
- inputs_dict=inputs_dict,
+ max_length=max_length,
output_scores=True,
- output_logits=True,
output_hidden_states=True,
- output_attentions=self.has_attentions,
+ output_attentions=True,
return_dict_in_generate=True,
- use_cache=False,
)
if model.config.is_encoder_decoder:
@@ -491,50 +564,58 @@ def test_greedy_generate_dict_outputs(self):
self._check_outputs(output_generate, input_ids, model.config)
- @pytest.mark.generate
def test_greedy_generate_dict_outputs_use_cache(self):
for model_class in self.all_generative_model_classes:
- config, input_ids, attention_mask, inputs_dict = self._get_input_ids_and_config()
+ # enable cache
+ config, input_ids, attention_mask, max_length = self._get_input_ids_and_config()
if not hasattr(config, "use_cache"):
- self.skipTest(reason=f"{model_class.__name__} doesn't support caching")
- if any(model_name in model_class.__name__.lower() for model_name in ["rwkv"]):
- self.skipTest(reason="Won't fix: model with non-standard dictionary output shapes")
+ # only relevant if model has "use_cache"
+ return
+ config.use_cache = True
config.is_decoder = True
model = model_class(config).to(torch_device).eval()
output_generate = self._greedy_generate(
model=model,
input_ids=input_ids,
attention_mask=attention_mask,
- inputs_dict=inputs_dict,
+ max_length=max_length,
output_scores=True,
- output_logits=True,
output_hidden_states=True,
- output_attentions=self.has_attentions,
+ output_attentions=True,
return_dict_in_generate=True,
- use_cache=True,
)
- if model.config.is_encoder_decoder:
- self.assertTrue(output_generate.sequences.shape[-1] == self.max_new_tokens + 1)
- else:
- self.assertTrue(output_generate.sequences.shape[-1] == self.max_new_tokens + input_ids.shape[-1])
-
self._check_outputs(output_generate, input_ids, model.config, use_cache=True)
- @pytest.mark.generate
def test_sample_generate(self):
for model_class in self.all_generative_model_classes:
- config, input_ids, attention_mask, inputs_dict = self._get_input_ids_and_config()
-
+ config, input_ids, attention_mask, max_length = self._get_input_ids_and_config()
model = model_class(config).to(torch_device).eval()
+
+ if model.config.is_encoder_decoder:
+ max_length = 4
+
+ process_kwargs, logits_processor = self._get_logits_processor_and_kwargs(
+ input_ids.shape[-1],
+ model.config.eos_token_id,
+ forced_bos_token_id=model.config.forced_bos_token_id,
+ forced_eos_token_id=model.config.forced_eos_token_id,
+ max_length=max_length,
+ )
+ logits_warper_kwargs, logits_warper = self._get_warper_and_kwargs(num_beams=2)
+
output_generate = self._sample_generate(
model=model,
input_ids=input_ids,
attention_mask=attention_mask,
- inputs_dict=inputs_dict,
+ max_length=max_length,
num_return_sequences=1,
+ logits_processor=logits_processor,
+ logits_warper=logits_warper,
+ logits_warper_kwargs=logits_warper_kwargs,
+ process_kwargs=process_kwargs,
)
if model.config.is_encoder_decoder:
@@ -542,24 +623,38 @@ def test_sample_generate(self):
else:
self.assertTrue(output_generate.shape[-1] == self.max_new_tokens + input_ids.shape[-1])
- @pytest.mark.generate
def test_sample_generate_dict_output(self):
for model_class in self.all_generative_model_classes:
- config, input_ids, attention_mask, inputs_dict = self._get_input_ids_and_config()
-
+ # disable cache
+ config, input_ids, attention_mask, max_length = self._get_input_ids_and_config()
+ config.use_cache = False
model = model_class(config).to(torch_device).eval()
+ if model.config.is_encoder_decoder:
+ max_length = 4
+
+ process_kwargs, logits_processor = self._get_logits_processor_and_kwargs(
+ input_ids.shape[-1],
+ model.config.eos_token_id,
+ forced_bos_token_id=model.config.forced_bos_token_id,
+ forced_eos_token_id=model.config.forced_eos_token_id,
+ max_length=max_length,
+ )
+ logits_warper_kwargs, logits_warper = self._get_warper_and_kwargs(num_beams=1)
+
output_generate = self._sample_generate(
model=model,
input_ids=input_ids,
attention_mask=attention_mask,
- inputs_dict=inputs_dict,
+ max_length=max_length,
num_return_sequences=2,
+ logits_processor=logits_processor,
+ logits_warper=logits_warper,
+ logits_warper_kwargs=logits_warper_kwargs,
+ process_kwargs=process_kwargs,
output_scores=True,
- output_logits=True,
output_hidden_states=True,
- output_attentions=self.has_attentions,
+ output_attentions=True,
return_dict_in_generate=True,
- use_cache=False,
)
if model.config.is_encoder_decoder:
@@ -575,20 +670,38 @@ def test_sample_generate_dict_output(self):
self._check_outputs(output_generate, input_ids, model.config, num_return_sequences=2)
- @pytest.mark.generate
def test_beam_search_generate(self):
for model_class in self.all_generative_model_classes:
- config, input_ids, attention_mask, inputs_dict = self._get_input_ids_and_config()
+ config, input_ids, attention_mask, max_length = self._get_input_ids_and_config()
+
+ # It is important set set the eos_token_id to None to ensure that no sequences
+ # shorter than `max_length` can be generated which could lead to flaky circle ci
+ # failures if the top `num_return_sequences` beams are all shorter than the longest beam
+ config.eos_token_id = None
+ config.forced_eos_token_id = None
model = model_class(config).to(torch_device).eval()
+ if model.config.is_encoder_decoder:
+ max_length = 4
+
+ logits_process_kwargs, logits_processor = self._get_logits_processor_and_kwargs(
+ input_ids.shape[-1],
+ config.eos_token_id,
+ config.forced_bos_token_id,
+ config.forced_eos_token_id,
+ max_length,
+ )
+ beam_kwargs, beam_scorer = self._get_beam_scorer_and_kwargs(input_ids.shape[0], max_length)
- beam_kwargs = self._get_beam_kwargs()
output_generate = self._beam_search_generate(
model=model,
input_ids=input_ids,
attention_mask=attention_mask,
- inputs_dict=inputs_dict,
+ max_length=max_length,
+ beam_scorer=beam_scorer,
beam_kwargs=beam_kwargs,
+ logits_process_kwargs=logits_process_kwargs,
+ logits_processor=logits_processor,
)
if model.config.is_encoder_decoder:
@@ -596,26 +709,72 @@ def test_beam_search_generate(self):
else:
self.assertTrue(output_generate.shape[-1] == self.max_new_tokens + input_ids.shape[-1])
- @pytest.mark.generate
+ if model.config.is_encoder_decoder:
+ max_length = 4
+ beam_kwargs, beam_scorer = self._get_beam_scorer_and_kwargs(input_ids.shape[0], max_length)
+
+ output_generate = self._beam_search_generate(
+ model=model,
+ input_ids=input_ids,
+ attention_mask=attention_mask,
+ max_length=max_length,
+ beam_scorer=beam_scorer,
+ beam_kwargs=beam_kwargs,
+ logits_process_kwargs=logits_process_kwargs,
+ logits_processor=logits_processor,
+ )
+ if model.config.is_encoder_decoder:
+ self.assertTrue(output_generate.shape[-1] == self.max_new_tokens + 1)
+ else:
+ self.assertTrue(output_generate.shape[-1] == self.max_new_tokens + input_ids.shape[-1])
+
def test_beam_search_generate_dict_output(self):
for model_class in self.all_generative_model_classes:
- config, input_ids, attention_mask, inputs_dict = self._get_input_ids_and_config()
+ config, input_ids, attention_mask, max_length = self._get_input_ids_and_config()
+
+ # disable cache
+ config.use_cache = False
+
+ # It is important set set the eos_token_id to None to ensure that no sequences
+ # shorter than `max_length` can be generated which could lead to flaky circle ci
+ # failures if the top `num_return_sequences` beams are all shorter than the longest beam
+ config.eos_token_id = None
+ config.forced_eos_token_id = None
model = model_class(config).to(torch_device).eval()
- beam_kwargs = self._get_beam_kwargs()
+ if model.config.is_encoder_decoder:
+ max_length = 4
+
+ logits_process_kwargs, logits_processor = self._get_logits_processor_and_kwargs(
+ input_ids.shape[-1],
+ config.eos_token_id,
+ config.forced_bos_token_id,
+ config.forced_eos_token_id,
+ max_length,
+ )
+ beam_kwargs, beam_scorer = self._get_beam_scorer_and_kwargs(input_ids.shape[0], max_length)
output_generate = self._beam_search_generate(
model=model,
input_ids=input_ids,
attention_mask=attention_mask,
- inputs_dict=inputs_dict,
+ max_length=max_length,
+ beam_scorer=beam_scorer,
beam_kwargs=beam_kwargs,
+ logits_process_kwargs=logits_process_kwargs,
+ logits_processor=logits_processor,
output_scores=True,
- output_logits=True,
output_hidden_states=True,
- output_attentions=self.has_attentions,
+ output_attentions=True,
return_dict_in_generate=True,
- use_cache=False,
)
+ if model.config.is_encoder_decoder:
+ self.assertIsInstance(output_generate, BeamSearchEncoderDecoderOutput)
+ else:
+ self.assertIsInstance(output_generate, BeamSearchDecoderOnlyOutput)
+
+ self.assertTrue(output_generate["sequences_scores"].shape == (output_generate["sequences"].shape[0],))
+ self.assertTrue((output_generate["sequences_scores"] < 0).all().item())
+
if model.config.is_encoder_decoder:
self.assertTrue(output_generate.sequences.shape[-1] == self.max_new_tokens + 1)
self.assertIsInstance(output_generate, GenerateBeamEncoderDecoderOutput)
@@ -631,139 +790,148 @@ def test_beam_search_generate_dict_output(self):
output_generate, input_ids, model.config, num_return_sequences=beam_kwargs["num_beams"]
)
- @pytest.mark.generate
def test_beam_search_generate_dict_outputs_use_cache(self):
for model_class in self.all_generative_model_classes:
# enable cache
- config, input_ids, attention_mask, inputs_dict = self._get_input_ids_and_config()
+ config, input_ids, attention_mask, max_length = self._get_input_ids_and_config()
+
+ # It is important set set the eos_token_id to None to ensure that no sequences
+ # shorter than `max_length` can be generated which could lead to flaky circle ci
+ # failures if the top `num_return_sequences` beams are all shorter than the longest beam
+ config.eos_token_id = None
+ config.forced_eos_token_id = None
if not hasattr(config, "use_cache"):
- self.skipTest(reason=f"{model_class.__name__} doesn't support caching")
- if any(model_name in model_class.__name__.lower() for model_name in ["rwkv"]):
- self.skipTest(reason="Won't fix: model with non-standard dictionary output shapes")
+ # only relevant if model has "use_cache"
+ return
model = model_class(config).to(torch_device).eval()
- beam_kwargs = self._get_beam_kwargs()
+ if model.config.is_encoder_decoder:
+ max_length = 4
+
+ logits_process_kwargs, logits_processor = self._get_logits_processor_and_kwargs(
+ input_ids.shape[-1],
+ config.eos_token_id,
+ config.forced_bos_token_id,
+ config.forced_eos_token_id,
+ max_length,
+ )
+
+ beam_kwargs, beam_scorer = self._get_beam_scorer_and_kwargs(input_ids.shape[0], max_length)
+ config.use_cache = True
config.is_decoder = True
model = model_class(config).to(torch_device).eval()
output_generate = self._beam_search_generate(
model=model,
input_ids=input_ids,
attention_mask=attention_mask,
- inputs_dict=inputs_dict,
+ max_length=max_length,
+ beam_scorer=beam_scorer,
beam_kwargs=beam_kwargs,
+ logits_process_kwargs=logits_process_kwargs,
+ logits_processor=logits_processor,
output_scores=True,
- output_logits=True,
output_hidden_states=True,
- output_attentions=self.has_attentions,
+ output_attentions=True,
return_dict_in_generate=True,
- use_cache=True,
)
if model.config.is_encoder_decoder:
self.assertTrue(output_generate.sequences.shape[-1] == self.max_new_tokens + 1)
else:
self.assertTrue(output_generate.sequences.shape[-1] == self.max_new_tokens + input_ids.shape[-1])
-
self._check_outputs(
- output_generate, input_ids, model.config, use_cache=True, num_return_sequences=beam_kwargs["num_beams"]
+ output_generate, input_ids, model.config, use_cache=True, num_return_sequences=beam_scorer.num_beams
)
- @require_accelerate
- @require_torch_multi_accelerator
- @pytest.mark.generate
- def test_model_parallel_beam_search(self):
+ @pytest.mark.skip("Beam search sampling is not supported by optimum-habana yet")
+ def test_beam_sample_generate(self):
for model_class in self.all_generative_model_classes:
- if "xpu" in torch_device:
- return unittest.skip(reason="device_map='auto' does not work with XPU devices")
-
- if model_class._no_split_modules is None:
- continue
+ config, input_ids, attention_mask, max_length = self._get_input_ids_and_config()
- config, input_ids, attention_mask, inputs_dict = self._get_input_ids_and_config()
+ # It is important set set the eos_token_id to None to ensure that no sequences
+ # shorter than `max_length` can be generated which could lead to flaky circle ci
+ # failures if the top `num_return_sequences` beams are all shorter than the longest beam
+ config.eos_token_id = None
+ config.forced_eos_token_id = None
- model = model_class(config).eval()
- with tempfile.TemporaryDirectory() as tmp_dir:
- model.cpu().save_pretrained(tmp_dir)
- new_model = model_class.from_pretrained(tmp_dir, device_map="auto")
+ logits_warper_kwargs, logits_warper = self._get_warper_and_kwargs(num_beams=1)
- new_model.generate(
- input_ids,
- attention_mask=attention_mask,
- max_new_tokens=self.max_new_tokens,
- num_beams=2,
- **inputs_dict,
- )
+ model = model_class(config).to(torch_device).eval()
- @pytest.mark.skip("Beam search sampling is not supported by optimum-habana yet")
- @pytest.mark.generate
- def test_beam_sample_generate(self):
- for model_class in self.all_generative_model_classes:
- config, input_ids, attention_mask, inputs_dict = self._get_input_ids_and_config()
+ # check `generate()` and `beam_search()` are equal
+ if model.config.is_encoder_decoder:
+ max_length = 4
+ beam_kwargs, beam_scorer = self._get_beam_scorer_and_kwargs(input_ids.shape[0], max_length)
- model = model_class(config).to(torch_device).eval()
- beam_kwargs = self._get_beam_kwargs()
output_generate = self._beam_sample_generate(
model=model,
input_ids=input_ids,
attention_mask=attention_mask,
- inputs_dict=inputs_dict,
+ max_length=max_length,
+ beam_scorer=beam_scorer,
beam_kwargs=beam_kwargs,
+ logits_warper=logits_warper,
+ logits_warper_kwargs=logits_warper_kwargs,
)
if model.config.is_encoder_decoder:
self.assertTrue(output_generate.shape[-1] == self.max_new_tokens + 1)
else:
self.assertTrue(output_generate.shape[-1] == self.max_new_tokens + input_ids.shape[-1])
-
- # for VLMs inputs embeds won't match input ids unless images are encoded and merged with ids properly
- # no quick fix available, since obtaining image embeddings step is very model-specific
- if any(name in model.__class__.__name__.lower() for name in ("blip", "llava", "paligemma")):
- prepare_inputs_for_generation_args = set(
- inspect.signature(model.prepare_inputs_for_generation).parameters
+ if "inputs_embeds" in set(inspect.signature(model.prepare_inputs_for_generation).parameters):
+ input_embeds = model.get_input_embeddings()(input_ids)
+ beam_kwargs.update({"inputs_embeds": input_embeds})
+ output_generate2 = self._beam_sample_generate(
+ model=model,
+ input_ids=None,
+ attention_mask=attention_mask,
+ beam_kwargs=beam_kwargs,
+ logits_warper_kwargs=logits_warper_kwargs,
)
- # `inputs_embeds` input is well supported when `cache_positions` is used, because it means the modeling
- # code is up to date with our most recent standards
- if (
- "inputs_embeds" in prepare_inputs_for_generation_args
- and "cache_positions" in prepare_inputs_for_generation_args
- ):
- input_embeds = model.get_input_embeddings()(input_ids)
- beam_kwargs.update({"inputs_embeds": input_embeds})
- output_generate2 = self._beam_sample_generate(
- model=model,
- input_ids=None,
- attention_mask=attention_mask,
- inputs_dict={},
- beam_kwargs=beam_kwargs,
- )
- torch.testing.assert_close(output_generate[:, input_embeds.shape[1] :], output_generate2)
+ torch.testing.assert_close(output_generate[:, input_embeds.shape[1] :], output_generate2)
@pytest.mark.skip("Beam search sampling is not supported by optimum-habana yet")
- @pytest.mark.generate
def test_beam_sample_generate_dict_output(self):
for model_class in self.all_generative_model_classes:
- config, input_ids, attention_mask, inputs_dict = self._get_input_ids_and_config()
+ config, input_ids, attention_mask, max_length = self._get_input_ids_and_config()
+
+ # disable cache
+ config.use_cache = False
+
+ # It is important set set the eos_token_id to None to ensure that no sequences
+ # shorter than `max_length` can be generated which could lead to flaky circle ci
+ # failures if the top `num_return_sequences` beams are all shorter than the longest beam
+ config.eos_token_id = None
+ config.forced_eos_token_id = None
model = model_class(config).to(torch_device).eval()
- beam_kwargs = self._get_beam_kwargs()
+ logits_warper_kwargs, logits_warper = self._get_warper_and_kwargs(num_beams=1)
+
+ if model.config.is_encoder_decoder:
+ max_length = 4
+ beam_kwargs, beam_scorer = self._get_beam_scorer_and_kwargs(input_ids.shape[0], max_length)
output_generate = self._beam_sample_generate(
model=model,
input_ids=input_ids,
attention_mask=attention_mask,
- inputs_dict=inputs_dict,
+ max_length=max_length,
+ beam_scorer=beam_scorer,
beam_kwargs=beam_kwargs,
+ logits_warper=logits_warper,
+ logits_warper_kwargs=logits_warper_kwargs,
output_scores=True,
- output_logits=True,
output_hidden_states=True,
- output_attentions=self.has_attentions,
+ output_attentions=True,
return_dict_in_generate=True,
- use_cache=False,
)
+ self.assertTrue(output_generate["sequences_scores"].shape == (output_generate["sequences"].shape[0],))
+ self.assertTrue((output_generate["sequences_scores"] < 0).all().item())
+
if model.config.is_encoder_decoder:
self.assertTrue(output_generate.sequences.shape[-1] == self.max_new_tokens + 1)
self.assertIsInstance(output_generate, GenerateBeamEncoderDecoderOutput)
@@ -779,131 +947,192 @@ def test_beam_sample_generate_dict_output(self):
output_generate, input_ids, model.config, num_return_sequences=beam_kwargs["num_beams"]
)
- @pytest.mark.generate
def test_generate_without_input_ids(self):
- config, _, _, _ = self._get_input_ids_and_config()
+ config, _, _, max_length = self._get_input_ids_and_config()
# if no bos token id => cannot generate from None
if config.bos_token_id is None:
- self.skipTest(reason="bos_token_id is None")
-
- # hack in case they are equal, otherwise the attn mask will be [0]
- if config.bos_token_id == config.pad_token_id:
- config.pad_token_id = None
+ return
for model_class in self.all_generative_model_classes:
model = model_class(config).to(torch_device)
model.eval()
- output_ids_generate = model.generate(
- do_sample=False, max_new_tokens=self.max_new_tokens, remove_invalid_values=True
- )
+ output_ids_generate = model.generate(do_sample=False, max_length=max_length, remove_invalid_values=True)
self.assertIsNotNone(output_ids_generate)
@pytest.mark.skip("Group beam search is not supported by optimum-habana")
- @pytest.mark.generate
def test_group_beam_search_generate(self):
for model_class in self.all_generative_model_classes:
- config, input_ids, attention_mask, inputs_dict = self._get_input_ids_and_config()
+ config, input_ids, attention_mask, max_length = self._get_input_ids_and_config()
+
+ # It is important set set the eos_token_id to None to ensure that no sequences
+ # shorter than `max_length` can be generated which could lead to flaky circle ci
+ # failures if the top `num_return_sequences` beams are all shorter than the longest beam
+ config.eos_token_id = None
+ config.forced_eos_token_id = None
model = model_class(config).to(torch_device).eval()
+ if model.config.is_encoder_decoder:
+ max_length = 4
+
+ logits_process_kwargs, logits_processor = self._get_logits_processor_and_kwargs(
+ input_ids.shape[-1],
+ config.eos_token_id,
+ config.forced_bos_token_id,
+ config.forced_eos_token_id,
+ max_length,
+ diversity_penalty=2.0,
+ )
+
# check `generate()` and `group_beam_search()` are equal
- beam_kwargs = self._get_diverse_beam_kwargs()
- output_generate = self._group_beam_search_generate(
+ beam_kwargs, beam_scorer = self._get_diverse_beam_scorer_and_kwargs(input_ids.shape[0], max_length)
+ output_generate, output_group_beam_search = self._group_beam_search_generate(
model=model,
input_ids=input_ids,
attention_mask=attention_mask,
- inputs_dict=inputs_dict,
+ max_length=max_length,
+ beam_scorer=beam_scorer,
beam_kwargs=beam_kwargs,
+ logits_processor=logits_processor,
+ logits_process_kwargs=logits_process_kwargs,
)
- if model.config.is_encoder_decoder:
- self.assertTrue(output_generate.shape[-1] == self.max_new_tokens + 1)
- else:
- self.assertTrue(output_generate.shape[-1] == self.max_new_tokens + input_ids.shape[-1])
+ self.assertListEqual(output_generate.tolist(), output_group_beam_search.tolist())
- # check `group_beam_search` for higher than 1 `num_return_sequences`
+ # check `generate()` and `group_beam_search()` are equal for `num_return_sequences`
num_return_sequences = 2
- beam_kwargs = self._get_diverse_beam_kwargs(num_return_sequences=num_return_sequences)
- output_generate = self._group_beam_search_generate(
+ if model.config.is_encoder_decoder:
+ max_length = 4
+ beam_kwargs, beam_scorer = self._get_diverse_beam_scorer_and_kwargs(
+ input_ids.shape[0], max_length, num_return_sequences=num_return_sequences
+ )
+ output_generate, output_group_beam_search = self._group_beam_search_generate(
model=model,
input_ids=input_ids,
attention_mask=attention_mask,
- inputs_dict=inputs_dict,
+ max_length=max_length,
+ beam_scorer=beam_scorer,
beam_kwargs=beam_kwargs,
+ logits_processor=logits_processor,
+ logits_process_kwargs=logits_process_kwargs,
)
- if model.config.is_encoder_decoder:
- self.assertTrue(output_generate.shape[-1] == self.max_new_tokens + 1)
- else:
- self.assertTrue(output_generate.shape[-1] == self.max_new_tokens + input_ids.shape[-1])
+ self.assertListEqual(output_generate.tolist(), output_group_beam_search.tolist())
@pytest.mark.skip("Group beam search is not supported by optimum-habana")
- @pytest.mark.generate
def test_group_beam_search_generate_dict_output(self):
for model_class in self.all_generative_model_classes:
- config, input_ids, attention_mask, inputs_dict = self._get_input_ids_and_config()
+ config, input_ids, attention_mask, max_length = self._get_input_ids_and_config()
+ config.use_cache = False
+
+ # It is important set set the eos_token_id to None to ensure that no sequences
+ # shorter than `max_length` can be generated which could lead to flaky circle ci
+ # failures if the top `num_return_sequences` beams are all shorter than the longest beam
+ config.eos_token_id = None
+ config.forced_eos_token_id = None
model = model_class(config).to(torch_device).eval()
- beam_kwargs = self._get_diverse_beam_kwargs()
- output_generate = self._group_beam_search_generate(
+ if model.config.is_encoder_decoder:
+ max_length = 4
+
+ logits_process_kwargs, logits_processor = self._get_logits_processor_and_kwargs(
+ input_ids.shape[-1],
+ config.eos_token_id,
+ config.forced_bos_token_id,
+ config.forced_eos_token_id,
+ max_length,
+ diversity_penalty=2.0,
+ )
+
+ num_return_sequences = 1
+ beam_kwargs, beam_scorer = self._get_diverse_beam_scorer_and_kwargs(
+ input_ids.shape[0], max_length, num_return_sequences=num_return_sequences
+ )
+ output_generate, output_group_beam_search = self._group_beam_search_generate(
model=model,
input_ids=input_ids,
attention_mask=attention_mask,
- inputs_dict=inputs_dict,
+ max_length=max_length,
+ beam_scorer=beam_scorer,
beam_kwargs=beam_kwargs,
+ logits_processor=logits_processor,
+ logits_process_kwargs=logits_process_kwargs,
output_scores=True,
- output_logits=True,
output_hidden_states=True,
- output_attentions=self.has_attentions,
+ output_attentions=True,
return_dict_in_generate=True,
- use_cache=False,
)
if model.config.is_encoder_decoder:
- self.assertTrue(output_generate.sequences.shape[-1] == self.max_new_tokens + 1)
- self.assertIsInstance(output_generate, GenerateBeamEncoderDecoderOutput)
- # Retrocompatibility check
+ self.assertIsInstance(output_group_beam_search, BeamSearchEncoderDecoderOutput)
self.assertIsInstance(output_generate, BeamSearchEncoderDecoderOutput)
else:
- self.assertTrue(output_generate.sequences.shape[-1] == self.max_new_tokens + input_ids.shape[-1])
- self.assertIsInstance(output_generate, GenerateBeamDecoderOnlyOutput)
- # Retrocompatibility check
+ self.assertIsInstance(output_group_beam_search, BeamSearchDecoderOnlyOutput)
self.assertIsInstance(output_generate, BeamSearchDecoderOnlyOutput)
- self._check_outputs(
- output_generate, input_ids, model.config, num_return_sequences=beam_kwargs["num_beams"]
+ self.assertListEqual(output_generate.sequences.tolist(), output_group_beam_search.sequences.tolist())
+ self.assertTrue(
+ torch.allclose(
+ output_generate["sequences_scores"], output_group_beam_search["sequences_scores"], atol=1e-3
+ )
)
+ self.assertTrue(output_generate["sequences_scores"].shape == (output_generate["sequences"].shape[0],))
+ self.assertTrue((output_generate["sequences_scores"] < 0).all().item())
+
+ for output in (output_group_beam_search, output_generate):
+ self._check_outputs(
+ output, input_ids, model.config, num_return_sequences=num_return_sequences * beam_scorer.num_beams
+ )
- # TODO: @gante
- @is_flaky()
- @pytest.mark.generate
def test_constrained_beam_search_generate(self):
for model_class in self.all_generative_model_classes:
- config, input_ids, attention_mask, inputs_dict = self._get_input_ids_and_config()
+ config, input_ids, attention_mask, max_length = self._get_input_ids_and_config()
+
+ # It is important set set the eos_token_id to None to ensure that no sequences
+ # shorter than `max_length` can be generated which could lead to flaky circle ci
+ # failures if the top `num_return_sequences` beams are all shorter than the longest beam
+ config.eos_token_id = None
+ config.forced_eos_token_id = None
model = model_class(config).to(torch_device).eval()
+ max_length = 20
+ logits_process_kwargs, logits_processor = self._get_logits_processor_and_kwargs(
+ input_ids.shape[-1],
+ config.eos_token_id,
+ config.forced_bos_token_id,
+ config.forced_eos_token_id,
+ max_length,
+ )
+
+ # check `generate()` and `constrained_beam_search()` are equal
# Sample constraints
- min_id = 3
- max_id = config.get_text_config(decoder=True).vocab_size
+ if not input_ids.dtype == torch.float32:
+ min_id = torch.min(input_ids) + 3
+ max_id = torch.max(input_ids)
+ else:
+ # otherwise this throws an error for Speech2TextModel since its inputs are floating points
+ min_id = 3
+ max_id = 100
force_tokens = torch.randint(min_id, max_id, (1, 2)).tolist()[0]
constraints = [
PhrasalConstraint(force_tokens),
]
- beam_kwargs = self._get_constrained_beam_kwargs()
+ beam_kwargs, beam_scorer = self._get_constrained_beam_scorer_and_kwargs(
+ input_ids.shape[0], max_length, constraints, num_return_sequences=1
+ )
output_generate = self._constrained_beam_search_generate(
model=model,
input_ids=input_ids,
attention_mask=attention_mask,
- inputs_dict=inputs_dict,
+ max_length=max_length,
+ constrained_beam_scorer=beam_scorer,
constraints=constraints,
beam_kwargs=beam_kwargs,
+ logits_processor=logits_processor,
+ logits_process_kwargs=logits_process_kwargs,
)
-
- if model.config.is_encoder_decoder:
- self.assertTrue(output_generate.shape[-1] == self.max_new_tokens + 1)
- else:
- self.assertTrue(output_generate.shape[-1] == self.max_new_tokens + input_ids.shape[-1])
+ self.assertTrue(output_generate.shape[-1] == max_length)
for generation_output in output_generate:
self._check_sequence_inside_sequence(force_tokens, generation_output)
@@ -915,63 +1144,86 @@ def test_constrained_beam_search_generate(self):
PhrasalConstraint(force_tokens),
]
- beam_kwargs = self._get_constrained_beam_kwargs(num_return_sequences=2)
+ num_return_sequences = 2
+ max_length = 20
+
+ beam_kwargs, beam_scorer = self._get_constrained_beam_scorer_and_kwargs(
+ input_ids.shape[0], max_length, constraints, num_return_sequences=num_return_sequences
+ )
output_generate = self._constrained_beam_search_generate(
model=model,
input_ids=input_ids,
attention_mask=attention_mask,
- inputs_dict=inputs_dict,
+ max_length=max_length,
+ constrained_beam_scorer=beam_scorer,
constraints=constraints,
beam_kwargs=beam_kwargs,
+ logits_processor=logits_processor,
+ logits_process_kwargs=logits_process_kwargs,
)
-
- if model.config.is_encoder_decoder:
- self.assertTrue(output_generate.shape[-1] == self.max_new_tokens + 1)
- else:
- self.assertTrue(output_generate.shape[-1] == self.max_new_tokens + input_ids.shape[-1])
+ self.assertTrue(output_generate.shape[-1] == max_length)
for generation_output in output_generate:
self._check_sequence_inside_sequence(force_tokens, generation_output)
- @pytest.mark.generate
def test_constrained_beam_search_generate_dict_output(self):
for model_class in self.all_generative_model_classes:
- config, input_ids, attention_mask, inputs_dict = self._get_input_ids_and_config()
+ config, input_ids, attention_mask, max_length = self._get_input_ids_and_config()
+
+ # disable cache
+ config.use_cache = False
+
+ # It is important set set the eos_token_id to None to ensure that no sequences
+ # shorter than `max_length` can be generated which could lead to flaky circle ci
+ # failures if the top `num_return_sequences` beams are all shorter than the longest beam
+ config.eos_token_id = None
+ config.forced_eos_token_id = None
model = model_class(config).to(torch_device).eval()
+ if model.config.is_encoder_decoder:
+ max_length = 20
+
+ logits_process_kwargs, logits_processor = self._get_logits_processor_and_kwargs(
+ input_ids.shape[-1],
+ config.eos_token_id,
+ config.forced_bos_token_id,
+ config.forced_eos_token_id,
+ max_length,
+ )
# Sample constraints
min_id = 3
- max_id = model.config.get_text_config(decoder=True).vocab_size
+ max_id = model.config.vocab_size
force_tokens = torch.randint(min_id, max_id, (1, 2)).tolist()[0]
constraints = [
PhrasalConstraint(force_tokens),
]
- beam_kwargs = self._get_constrained_beam_kwargs()
+ beam_kwargs, beam_scorer = self._get_constrained_beam_scorer_and_kwargs(
+ input_ids.shape[0], max_length, constraints, num_return_sequences=1
+ )
output_generate = self._constrained_beam_search_generate(
model=model,
input_ids=input_ids,
attention_mask=attention_mask,
- inputs_dict=inputs_dict,
+ max_length=max_length,
+ constrained_beam_scorer=beam_scorer,
constraints=constraints,
beam_kwargs=beam_kwargs,
+ logits_processor=logits_processor,
+ logits_process_kwargs=logits_process_kwargs,
output_scores=True,
- output_logits=True,
output_hidden_states=True,
- output_attentions=self.has_attentions,
+ output_attentions=True,
return_dict_in_generate=True,
- use_cache=False,
)
-
+ self.assertTrue(output_generate.sequences.shape[-1] == max_length)
if model.config.is_encoder_decoder:
- self.assertTrue(output_generate.sequences.shape[-1] == self.max_new_tokens + 1)
self.assertIsInstance(output_generate, GenerateBeamEncoderDecoderOutput)
# Retrocompatibility check
self.assertIsInstance(output_generate, BeamSearchEncoderDecoderOutput)
else:
- self.assertTrue(output_generate.sequences.shape[-1] == self.max_new_tokens + input_ids.shape[-1])
self.assertIsInstance(output_generate, GenerateBeamDecoderOnlyOutput)
# Retrocompatibility check
self.assertIsInstance(output_generate, BeamSearchDecoderOnlyOutput)
@@ -980,52 +1232,47 @@ def test_constrained_beam_search_generate_dict_output(self):
output_generate, input_ids, model.config, num_return_sequences=beam_kwargs["num_beams"]
)
- @pytest.mark.generate
+ self.assertTrue(output_generate["sequences_scores"].shape == (output_generate["sequences"].shape[0],))
+ self.assertTrue((output_generate["sequences_scores"] < 0).all().item())
+
def test_contrastive_generate(self):
+ # check `generate()` and `contrastive_search()` are equal
for model_class in self.all_generative_model_classes:
- if model_class._is_stateful:
- self.skipTest(reason="Stateful models don't support contrastive search generation")
-
# won't fix: FSMT and Reformer have a different cache variable type (and format).
if any(model_name in model_class.__name__.lower() for model_name in ["fsmt", "reformer"]):
- self.skipTest(reason="Won't fix: old model with different cache format")
+ return
- config, input_ids, attention_mask, inputs_dict = self._get_input_ids_and_config()
+ config, input_ids, attention_mask, max_length = self._get_input_ids_and_config()
# NOTE: contrastive search only works with cache on at the moment.
if not hasattr(config, "use_cache"):
- self.skipTest(reason=f"{model_class.__name__} doesn't support caching")
+ return
+ config.use_cache = True
config.is_decoder = True
# test old generation output for backwards compatibility
model = model_class(config).to(torch_device).eval()
output_generate = self._contrastive_generate(
- model=model,
- input_ids=input_ids,
- attention_mask=attention_mask,
- inputs_dict=inputs_dict,
- use_cache=True,
+ model=model, input_ids=input_ids, attention_mask=attention_mask, max_length=max_length
)
if model.config.is_encoder_decoder:
self.assertTrue(output_generate.shape[-1] == self.max_new_tokens + 1)
else:
self.assertTrue(output_generate.shape[-1] == self.max_new_tokens + input_ids.shape[-1])
- @pytest.mark.generate
def test_contrastive_generate_dict_outputs_use_cache(self):
for model_class in self.all_generative_model_classes:
- if model_class._is_stateful:
- self.skipTest(reason="Stateful models don't support contrastive search generation")
-
# won't fix: FSMT and Reformer have a different cache variable type (and format).
if any(model_name in model_class.__name__.lower() for model_name in ["fsmt", "reformer"]):
- self.skipTest(reason="Won't fix: old model with different cache format")
+ return
- config, input_ids, attention_mask, inputs_dict = self._get_input_ids_and_config()
+ # enable cache
+ config, input_ids, attention_mask, max_length = self._get_input_ids_and_config()
# NOTE: contrastive search only works with cache on at the moment.
if not hasattr(config, "use_cache"):
- self.skipTest(reason=f"{model_class.__name__} doesn't support caching")
+ return
+ config.use_cache = True
config.is_decoder = True
model = model_class(config).to(torch_device).eval()
@@ -1033,40 +1280,36 @@ def test_contrastive_generate_dict_outputs_use_cache(self):
model=model,
input_ids=input_ids,
attention_mask=attention_mask,
- inputs_dict=inputs_dict,
+ max_length=max_length,
output_scores=True,
- output_logits=True,
output_hidden_states=True,
- output_attentions=self.has_attentions,
+ output_attentions=True,
return_dict_in_generate=True,
- use_cache=True,
)
if model.config.is_encoder_decoder:
self.assertTrue(output_generate.sequences.shape[-1] == self.max_new_tokens + 1)
else:
self.assertTrue(output_generate.sequences.shape[-1] == self.max_new_tokens + input_ids.shape[-1])
-
self._check_outputs(output_generate, input_ids, model.config, use_cache=True)
- @pytest.mark.generate
def test_contrastive_generate_low_memory(self):
# Check that choosing 'low_memory' does not change the model output
for model_class in self.all_generative_model_classes:
- if model_class._is_stateful:
- self.skipTest(reason="Stateful models don't support contrastive search generation")
-
- if any(model_name in model_class.__name__.lower() for model_name in ["fsmt", "reformer", "speech2text"]):
- self.skipTest(reason="Won't fix: old model with different cache format")
- if any(model_name in model_class.__name__.lower() for model_name in ["gptbigcode"]):
- self.skipTest(reason="TODO: fix me")
+ # won't fix: FSMT, Reformer, gptbigcode, and speech2text have a different cache variable type (and format).
+ if any(
+ model_name in model_class.__name__.lower()
+ for model_name in ["fsmt", "reformer", "gptbigcode", "speech2text"]
+ ):
+ return
- config, input_ids, attention_mask, inputs_dict = self._get_input_ids_and_config(batch_size=1)
+ config, input_ids, attention_mask, max_length = self._get_input_ids_and_config(batch_size=1)
# NOTE: contrastive search only works with cache on at the moment.
if not hasattr(config, "use_cache"):
- self.skipTest(reason=f"{model_class.__name__} doesn't support caching")
+ return
+ config.use_cache = True
config.is_decoder = True
# test output equality of low versus high memory
@@ -1077,10 +1320,8 @@ def test_contrastive_generate_low_memory(self):
top_k=4,
penalty_alpha=0.6,
low_memory=True,
- max_new_tokens=self.max_new_tokens,
+ max_length=max_length,
attention_mask=attention_mask,
- **inputs_dict,
- use_cache=True,
)
high_output = model.generate(
@@ -1088,10 +1329,8 @@ def test_contrastive_generate_low_memory(self):
top_k=4,
penalty_alpha=0.6,
low_memory=False,
- max_new_tokens=self.max_new_tokens,
+ max_length=max_length,
attention_mask=attention_mask,
- **inputs_dict,
- use_cache=True,
)
self.assertListEqual(low_output.tolist(), high_output.tolist())
@@ -1138,75 +1377,89 @@ def test_contrastive_generate_dynamic_shapes(self):
)
self.assertListEqual(dynamic_output.tolist(), static_output.tolist())
- @pytest.mark.generate
- @unittest.skip("Started to break with https://github.com/huggingface/transformers/pull/33703")
- def test_beam_search_low_memory(self):
- # Check that choosing 'low_memory' does not change the model output
+ # TODO [sasarkar] it is supported now. Enable this test, or delete it if its not applicable
+ @pytest.mark.skip(reason="Assisted decoding not yet supported by optimum-habana")
+ @slow # TODO(Joao): remove this. Some models (e.g. data2vec, xcom, roberta) have an error rate between 1 and 10%.
+ def test_assisted_decoding_matches_greedy_search(self):
+ # This test ensures that the assisted generation does not introduce output changes over greedy search.
+ # It breaks the pattern in the tests above, for multiple reasons:
+ # - assisted_decoding, contrarily to the other methods, can't be called on its own (e.g. needs to
+ # prepare the assistant encoder outputs in the main generate body);
+ # - assisted_decoding does not support `use_cache = False`
+ # - assisted_decoding does not support `batch_size > 1`
+
for model_class in self.all_generative_model_classes:
- if model_class._is_stateful:
- self.skipTest(reason="May fix in the future: need custom cache handling")
+ # won't fix: FSMT and Reformer have a different cache variable type (and format).
if any(model_name in model_class.__name__.lower() for model_name in ["fsmt", "reformer"]):
- self.skipTest(reason="Won't fix: old model with different cache format")
+ return
+ # may fix in the future: the following models fail with assisted decoding, and need model-specific fixes
if any(
model_name in model_class.__name__.lower()
- for model_name in [
- "ctrl",
- "gptbigcode",
- "transo_xl",
- "xlnet",
- "cpm",
- "jamba",
- ]
+ for model_name in ["bigbirdpegasus", "led", "mega", "speech2text", "git", "prophetnet"]
):
- self.skipTest(reason="May fix in the future: need model-specific fixes")
- config, input_ids, _, _ = self._get_input_ids_and_config(batch_size=2)
- # batch_size=1 is ok, but batch_size>1 will cause non-identical output
+ return
- config.use_cache = True
- config.is_decoder = True
+ # This for loop is a naive and temporary effort to make the test less flaky.
+ failed = 0
+ for i in range(10):
+ # enable cache
+ config, input_ids, attention_mask, max_length = self._get_input_ids_and_config(batch_size=1)
- # test output equality of low versus high memory
- model = model_class(config).to(torch_device).eval()
+ # NOTE: assisted generation only works with cache on at the moment.
+ if not hasattr(config, "use_cache"):
+ return
- low_output = model.generate(
- input_ids,
- max_new_tokens=8,
- num_beams=5,
- early_stopping=True,
- low_memory=True,
- use_cache=True,
- )
+ config.use_cache = True
+ config.is_decoder = True
+ model = model_class(config).to(torch_device).eval()
+ output_greedy = model.generate(
+ input_ids,
+ attention_mask=attention_mask,
+ max_length=max_length,
+ num_beams=1,
+ do_sample=False,
+ output_scores=True,
+ output_hidden_states=True,
+ output_attentions=True,
+ return_dict_in_generate=True,
+ )
+ # Note: with assisted generate, if the same model is used as assistant, then all assistant tokens will
+ # be correct
+ output_assisted = model.generate(
+ input_ids,
+ attention_mask=attention_mask,
+ max_length=max_length,
+ num_beams=1,
+ do_sample=False,
+ assistant_model=model,
+ output_scores=True,
+ output_hidden_states=True,
+ output_attentions=True,
+ return_dict_in_generate=True,
+ )
- high_output = model.generate(
- input_ids,
- max_new_tokens=8,
- num_beams=5,
- early_stopping=True,
- low_memory=False,
- use_cache=True,
- )
- self.assertListEqual(low_output.tolist(), high_output.tolist())
+ try:
+ self.assertListEqual(output_greedy.sequences.tolist(), output_assisted.sequences.tolist())
- @pytest.mark.generate
- @parameterized.expand([("random",), ("same",)])
- @is_flaky() # Read NOTE (1) below. If there are API issues, all attempts will fail.
- def test_assisted_decoding_matches_greedy_search(self, assistant_type):
- # This test ensures that the assisted generation does not introduce output changes over greedy search.
- # NOTE (1): The sentence above is true most of the time, there is a tiny difference in the logits due to matmul
- # shape differences -- and it may result in a different output. The input shape difference happens in the
- # main model, that runs the forward pass with several candidates at once (as opposed to generating one token at
- # a time). See https://github.com/huggingface/transformers/issues/25420#issuecomment-1775317535 for more info.
- # NOTE (2): It breaks the pattern in the tests above, for multiple reasons:
- # - assisted_decoding, contrarily to the other methods, can't be called on its own (e.g. needs to
- # prepare the assistant encoder outputs in the main generate body);
- # - assisted_decoding does not support `use_cache = False`
- # - assisted_decoding does not support `batch_size > 1`
+ for output in (output_greedy, output_assisted):
+ self._check_outputs(output, input_ids, model.config, use_cache=True)
+ except AssertionError:
+ failed += 1
+ if failed > 1:
+ self.assertListEqual(output_greedy.sequences.tolist(), output_assisted.sequences.tolist())
+
+ for output in (output_greedy, output_assisted):
+ self._check_outputs(output, input_ids, model.config, use_cache=True)
+ # TODO [sasarkar] it is supported now. Enable this test, or delete it if its not applicable
+ @pytest.mark.skip(reason="Assisted decoding not yet supported by optimum-habana")
+ def test_assisted_decoding_sample(self):
+ # In this test we don't check assisted vs non-assisted output -- seeded assisted decoding with sample will not
+ # match sample for the same seed, as the forward pass does not return the exact same logits (due to matmul with
+ # different shapes, see https://github.com/huggingface/transformers/issues/25420#issuecomment-1775317535).
for model_class in self.all_generative_model_classes:
- if model_class._is_stateful:
- self.skipTest(reason="Stateful models don't support assisted generation")
if any(model_name in model_class.__name__.lower() for model_name in ["fsmt", "reformer"]):
- self.skipTest(reason="Won't fix: old model with different cache format")
+ self.skipTest("Won't fix: old model with different cache format")
if any(
model_name in model_class.__name__.lower()
for model_name in [
@@ -1220,15 +1473,16 @@ def test_assisted_decoding_matches_greedy_search(self, assistant_type):
"clvp",
]
):
- self.skipTest(reason="May fix in the future: need model-specific fixes")
+ self.skipTest("May fix in the future: need model-specific fixes")
# enable cache
- config, input_ids, attention_mask, inputs_dict = self._get_input_ids_and_config(batch_size=1)
+ config, input_ids, attention_mask, _ = self._get_input_ids_and_config(batch_size=1)
# NOTE: assisted generation only works with cache on at the moment.
if not hasattr(config, "use_cache"):
- self.skipTest(reason=f"{model_class.__name__} doesn't support caching")
+ self.skipTest("This model doesn't support caching")
+ config.use_cache = True
config.is_decoder = True
model = model_class(config).to(torch_device).eval()
# Sets assisted generation arguments such that:
@@ -1237,253 +1491,503 @@ def test_assisted_decoding_matches_greedy_search(self, assistant_type):
# the assistant model is correct
# c) there are at least two forward passes in the main model, to ensure the input preparation of
# the main model is correct
- generation_kwargs = {
- "eos_token_id": -1, # see a)
- "max_new_tokens": 4, # see c)
- "num_beams": 1,
- "do_sample": False,
- "output_scores": True,
- "output_logits": True,
- "output_hidden_states": True,
- "output_attentions": self.has_attentions,
- "return_dict_in_generate": True,
- "use_cache": True,
- }
- output_greedy = model.generate(
- input_ids, attention_mask=attention_mask, **generation_kwargs, **inputs_dict
- )
-
- # test with the same assistant model or randomly init one
- # in the first case all candidate tokens are accepted, in the second none is accepted
- # case when some are accepted and some not is hard to reproduce, so let's hope this catches most errors :)
- if assistant_type == "random":
- assistant_model = model_class(config).to(torch_device).eval()
- else:
- assistant_model = model
+ assistant_model = model
assistant_model.generation_config.num_assistant_tokens = 2 # see b)
assistant_model.generation_config.num_assistant_tokens_schedule = "constant" # see b)
- generation_kwargs.update({"assistant_model": assistant_model})
- output_assisted = model.generate(
- input_ids, attention_mask=attention_mask, **generation_kwargs, **inputs_dict
- )
-
- # The two outputs must match and their shape must be as expected
-
- self.assertListEqual(output_greedy.sequences.tolist(), output_assisted.sequences.tolist())
- for output in (output_greedy, output_assisted):
- self._check_outputs(output, input_ids, model.config, use_cache=True)
-
- @is_flaky()
- @pytest.mark.generate
- def test_prompt_lookup_decoding_matches_greedy_search(self):
- # This test ensures that the prompt lookup generation does not introduce output changes over greedy search.
- # This test is mostly a copy of test_assisted_decoding_matches_greedy_search
-
- for model_class in self.all_generative_model_classes:
- if model_class._is_stateful:
- self.skipTest(reason="Stateful models don't support assisted generation")
- if any(model_name in model_class.__name__.lower() for model_name in ["fsmt", "reformer"]):
- self.skipTest(reason="Won't fix: old model with different cache format")
- if any(
- model_name in model_class.__name__.lower()
- for model_name in [
- "bigbirdpegasus",
- "led",
- "mega",
- "speech2text",
- "git",
- "prophetnet",
- "seamlessm4t",
- "clvp",
- ]
- ):
- self.skipTest(reason="May fix in the future: need model-specific fixes")
-
- # enable cache
- config, input_ids, attention_mask, inputs_dict = self._get_input_ids_and_config(batch_size=1)
-
- # NOTE: assisted generation only works with cache on at the moment.
- if not hasattr(config, "use_cache"):
- self.skipTest(reason=f"{model_class.__name__} doesn't support caching")
-
- config.is_decoder = True
- model = model_class(config).to(torch_device).eval()
- # Sets assisted generation arguments such that:
- # a) no EOS is generated, to ensure generation doesn't break early
- # b) the prompt lookup tries to give the model 2 tokens, to ensure the input preparation of
- # prompt lookup is correct
- # c) there are at least two forward passes in the main model, to ensure the input preparation of
- # the main model is correct
generation_kwargs = {
"eos_token_id": -1, # see a)
"max_new_tokens": 4, # see c)
"num_beams": 1,
- "do_sample": False,
+ "do_sample": True,
+ "assistant_model": assistant_model,
"output_scores": True,
- "output_logits": True,
"output_hidden_states": True,
- "output_attentions": self.has_attentions,
+ "output_attentions": True,
"return_dict_in_generate": True,
- "use_cache": True,
}
- output_greedy = model.generate(
- input_ids, attention_mask=attention_mask, **generation_kwargs, **inputs_dict
- )
-
- generation_kwargs.update({"prompt_lookup_num_tokens": 2}) # see b)
- output_prompt_lookup = model.generate(
- input_ids, attention_mask=attention_mask, **generation_kwargs, **inputs_dict
- )
-
- # The two outputs must match and their shape must be as expected
+ #######################################################################
+ # Monkey patch assisted decoding function till SW issue is resolved
+ import copy
+ from types import MethodType
+ from typing import List, Optional, Union
+
+ from transformers.generation.utils import (
+ GenerateDecoderOnlyOutput,
+ _crop_past_key_values,
+ _prepare_attention_mask,
+ _prepare_token_type_ids,
+ _split_model_outputs,
+ )
+
+ def _speculative_sampling(
+ candidate_input_ids,
+ candidate_logits,
+ candidate_length,
+ new_logits,
+ last_assistant_token_is_eos,
+ max_matches,
+ ):
+ """
+ Applies sampling as in the speculative decoding paper (https://arxiv.org/pdf/2211.17192.pdf, algorithm 1). Returns
+ the selected tokens, as well as the number of candidate matches.
+
+ NOTE: Unless otherwise stated, the variable names match those in the paper.
+ """
+ new_candidate_input_ids = candidate_input_ids[:, -candidate_length:]
+ # Gets the probabilities from the logits. q_i and p_i denote the assistant and model probabilities of the tokens
+ # selected by the assistant, respectively.
+ q = candidate_logits.softmax(dim=-1)
+ q_i = q[:, torch.arange(candidate_length), new_candidate_input_ids.squeeze()].squeeze(0, 1)
+ p = new_logits.softmax(dim=-1)
+ p_i = p[:, torch.arange(candidate_length), new_candidate_input_ids.squeeze()].squeeze(0, 1)
+ probability_ratio = p_i / q_i
+
+ # When probability_ratio > 1 (i.e. q_i(x) < p_i(x), or "assistant probability of the candidate token is smaller
+ # than the model probability for the same token"), keep the token. Otherwise reject with p = 1 - probability_ratio
+ # (= keep with p = probability_ratio). Keep all the tokens until the first rejection
+ r_i = torch.rand_like(probability_ratio)
+ is_accepted = r_i <= probability_ratio
+ n_matches = ((~is_accepted).cumsum(dim=-1) < 1).sum() # this is `n` in algorithm 1
+
+ # Ensure we don't generate beyond max_len or an EOS token (not in algorithm 1, but needed for correct behavior)
+ if last_assistant_token_is_eos and n_matches == candidate_length:
+ # Output length is assumed to be `n_matches + 1`. Since we won't generate another token with the target model
+ # due to acceptance on EOS we fix `n_matches`
+ n_matches -= 1
+ valid_tokens = new_candidate_input_ids[:, : n_matches + 1]
+ else:
+ n_matches = min(n_matches, max_matches)
+
+ # Next token selection: if there is a rejection, adjust the distribution from the main model before sampling.
+ gamma = min(candidate_logits.shape[1], max_matches)
+ p_n_plus_1 = p[:, n_matches, :]
+ if n_matches < gamma:
+ q_n_plus_1 = q[:, n_matches, :]
+ p_prime = torch.clamp((p_n_plus_1 - q_n_plus_1), min=0)
+ p_prime.div_(p_prime.sum())
+ else:
+ p_prime = p_n_plus_1
+ t = torch.multinomial(p_prime, num_samples=1).squeeze(1)[None, :]
+
+ # The selected tokens include the matches (if any) plus the next sampled tokens
+ if n_matches > 0:
+ valid_tokens = torch.cat((new_candidate_input_ids[:, :n_matches], t), dim=-1)
+ else:
+ valid_tokens = t
+
+ return valid_tokens, n_matches
+
+ def assisted_decoding(
+ self,
+ input_ids: torch.LongTensor,
+ assistant_model: Optional["PreTrainedModel"] = None,
+ candidate_generator: Optional["CandidateGenerator"] = None,
+ do_sample: bool = False,
+ logits_processor: Optional[LogitsProcessorList] = None,
+ logits_warper: Optional[LogitsProcessorList] = None,
+ stopping_criteria: Optional[StoppingCriteriaList] = None,
+ pad_token_id: Optional[int] = None,
+ eos_token_id: Optional[Union[int, List[int]]] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ output_scores: Optional[bool] = None,
+ return_dict_in_generate: Optional[bool] = None,
+ synced_gpus: bool = False,
+ streamer: Optional["BaseStreamer"] = None,
+ **model_kwargs,
+ ):
+ r"""
+ Generates sequences of token ids for models with a language modeling head using **greedy decoding** or
+ **sample** (depending on `do_sample`), assisted by candidate sequences. Assisted generation is an example of a
+ candidate decoding strategy. Can be used for text-decoder, text-to-text, speech-to-text, and vision-to-text
+ models.
+
+
+
+ In most cases, you do not need to call [`~generation.GenerationMixin.candidate_decoding`] directly. Use
+ generate() instead. For an overview of generation strategies and code examples, check the [following
+ guide](../generation_strategies).
+
+
+
+ Parameters:
+ input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
+ The sequence used as a prompt for the generation.
+ candidate_generator (`CandidateGenerator`, *optional*):
+ A derived instance of [`CandidateGenerator`] that defines how candidate sequences are generated. For
+ more information, the documentation of [`CandidateGenerator`] should be read. Only one of `assistant_model` or `candidate_generator` should be passed as input to this function.
+ assistant_model (`PreTrainedModel`, *optional*):
+ An assistant model that can be used to accelerate generation. The assistant model must have the exact
+ same tokenizer. The acceleration is achieved when forecasting candidate tokens with the assistent model
+ is much faster than running generation with the model you're calling generate from. As such, the
+ assistant model should be much smaller.
+ do_sample (`bool`, *optional*, defaults to `False`):
+ Whether or not to use sampling ; use greedy decoding otherwise.
+ logits_processor (`LogitsProcessorList`, *optional*):
+ An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsProcessor`]
+ used to modify the prediction scores of the language modeling head applied at each generation step.
+ logits_warper (`LogitsProcessorList`, *optional*):
+ An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsWarper`] used
+ to warp the prediction score distribution of the language modeling head applied before multinomial
+ sampling at each generation step.
+ stopping_criteria (`StoppingCriteriaList`, *optional*):
+ An instance of [`StoppingCriteriaList`]. List of instances of class derived from [`StoppingCriteria`]
+ used to tell if the generation loop should stop.
+ pad_token_id (`int`, *optional*):
+ The id of the *padding* token.
+ eos_token_id (`Union[int, List[int]]`, *optional*):
+ The id of the *end-of-sequence* token. Optionally, use a list to set multiple *end-of-sequence* tokens.
+ output_attentions (`bool`, *optional*, defaults to `False`):
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
+ returned tensors for more details.
+ output_hidden_states (`bool`, *optional*, defaults to `False`):
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
+ for more details.
+ output_scores (`bool`, *optional*, defaults to `False`):
+ Whether or not to return the prediction scores. See `scores` under returned tensors for more details.
+ return_dict_in_generate (`bool`, *optional*, defaults to `False`):
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
+ synced_gpus (`bool`, *optional*, defaults to `False`):
+ Whether to continue running the while loop until max_length (needed for ZeRO stage 3)
+ streamer (`BaseStreamer`, *optional*):
+ Streamer object that will be used to stream the generated sequences. Generated tokens are passed
+ through `streamer.put(token_ids)` and the streamer is responsible for any further processing.
+ model_kwargs:
+ Additional model specific keyword arguments will be forwarded to the `forward` function of the model.
+ If model is an encoder-decoder model the kwargs should include `encoder_outputs`.
+
+ Return:
+ [`~generation.GenerateDecoderOnlyOutput`], [`~generation.GenerateEncoderDecoderOutput`] or
+ `torch.LongTensor`: A `torch.LongTensor` containing the generated tokens (default behaviour) or a
+ [`~generation.GenerateDecoderOnlyOutput`] if `model.config.is_encoder_decoder=False` and
+ `return_dict_in_generate=True` or a [`~generation.GenerateEncoderDecoderOutput`] if
+ `model.config.is_encoder_decoder=True`.
+
+ Examples:
+
+ ```python
+ >>> from transformers import (
+ ... AutoTokenizer,
+ ... AutoModelForCausalLM,
+ ... LogitsProcessorList,
+ ... MinLengthLogitsProcessor,
+ ... StoppingCriteriaList,
+ ... MaxLengthCriteria,
+ ... )
+
+ >>> tokenizer = AutoTokenizer.from_pretrained("gpt2")
+ >>> model = AutoModelForCausalLM.from_pretrained("gpt2")
+ >>> assistant_model = AutoModelForCausalLM.from_pretrained("distilgpt2")
+ >>> # set pad_token_id to eos_token_id because GPT2 does not have a PAD token
+ >>> model.generation_config.pad_token_id = model.generation_config.eos_token_id
+ >>> input_prompt = "It might be possible to"
+ >>> input_ids = tokenizer(input_prompt, return_tensors="pt").input_ids
+ >>> # instantiate logits processors
+ >>> logits_processor = LogitsProcessorList(
+ ... [
+ ... MinLengthLogitsProcessor(10, eos_token_id=model.generation_config.eos_token_id),
+ ... ]
+ ... )
+ >>> stopping_criteria = StoppingCriteriaList([MaxLengthCriteria(max_length=20)])
+ >>> outputs = model.assisted_decoding(
+ ... input_ids,
+ ... assistant_model=assistant_model,
+ ... logits_processor=logits_processor,
+ ... stopping_criteria=stopping_criteria,
+ ... )
+ >>> tokenizer.batch_decode(outputs, skip_special_tokens=True)
+ ["It might be possible to get a better understanding of the nature of the problem, but it's not"]
+ ```"""
+ # handling deprecated arguments
+ if (assistant_model is None) == (candidate_generator is None):
+ raise ValueError(
+ "One (and only one) of `assistant_model` and `candidate_generator` should be defined."
+ )
- self.assertListEqual(output_greedy.sequences.tolist(), output_prompt_lookup.sequences.tolist())
- for output in (output_greedy, output_prompt_lookup):
- self._check_outputs(output, input_ids, model.config, use_cache=True)
+ if assistant_model is not None:
+ candidate_generator = AssistedCandidateGenerator(
+ input_ids=input_ids,
+ assistant_model=assistant_model,
+ logits_processor=logits_processor,
+ model_kwargs=model_kwargs,
+ eos_token_id=eos_token_id,
+ )
+ warnings.warn(
+ "Passing `assistant_model` to `assisted_decoding` is deprecated and will be removed in v4.38. "
+ "Pass the `candidate_generator` argument instead.",
+ FutureWarning,
+ )
- @pytest.mark.generate
- def test_dola_decoding_sample(self):
- # TODO (joao): investigate skips, try to reduce incompatibilities
- for model_class in self.all_generative_model_classes:
- if model_class._is_stateful:
- self.skipTest(reason="Stateful models don't support DoLa decoding")
+ # init values
+ logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
+ logits_warper = logits_warper if logits_warper is not None else LogitsProcessorList()
+ stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()
+ pad_token_id = pad_token_id if pad_token_id is not None else self.generation_config.pad_token_id
+ eos_token_id = eos_token_id if eos_token_id is not None else self.generation_config.eos_token_id
+ if eos_token_id is not None and pad_token_id is None:
+ raise ValueError("If `eos_token_id` is defined, make sure that `pad_token_id` is defined.")
+ if isinstance(eos_token_id, int):
+ eos_token_id = [eos_token_id]
+ eos_token_id_tensor = (
+ torch.tensor(eos_token_id).to(input_ids.device) if eos_token_id is not None else None
+ )
+ output_scores = output_scores if output_scores is not None else self.generation_config.output_scores
+ output_attentions = (
+ output_attentions if output_attentions is not None else self.generation_config.output_attentions
+ )
+ output_hidden_states = (
+ output_hidden_states
+ if output_hidden_states is not None
+ else self.generation_config.output_hidden_states
+ )
+ return_dict_in_generate = (
+ return_dict_in_generate
+ if return_dict_in_generate is not None
+ else self.generation_config.return_dict_in_generate
+ )
- if any(model_name in model_class.__name__.lower() for model_name in ["reformer"]):
- self.skipTest("Skip Reformer as the lm_head input size is 2 * hidden size, adopted from Rev Nets.")
+ # init attention / hidden states / scores tuples
+ scores = () if (return_dict_in_generate and output_scores) else None
+ decoder_attentions = () if (return_dict_in_generate and output_attentions) else None
+ cross_attentions = () if (return_dict_in_generate and output_attentions) else None
+ decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None
- if any(model_name in model_class.__name__.lower() for model_name in ["marian", "mbart", "pegasus"]):
- self.skipTest("DoLa is not supported for models that don't return layerwise hidden states")
+ # if model is an encoder-decoder, retrieve encoder attention weights and hidden states
+ if return_dict_in_generate and self.config.is_encoder_decoder:
+ encoder_attentions = (
+ model_kwargs["encoder_outputs"].get("attentions") if output_attentions else None
+ )
+ encoder_hidden_states = (
+ model_kwargs["encoder_outputs"].get("hidden_states") if output_hidden_states else None
+ )
- # enable cache if the model is not openai-gpt, xlnet, cpm, or xlm
- config, input_ids, attention_mask, inputs_dict = self._get_input_ids_and_config()
+ # keep track of which sequences are already finished
+ unfinished_sequences = input_ids.new(input_ids.shape[0]).fill_(1)
+
+ # other auxiliary variables
+ max_len = stopping_criteria[0].max_length
+
+ this_peer_finished = False # used by synced_gpus only
+ while True:
+ if synced_gpus:
+ # Under synced_gpus the `forward` call must continue until all gpus complete their sequence.
+ # The following logic allows an early break if all peers finished generating their sequence
+ this_peer_finished_flag = torch.tensor(0.0 if this_peer_finished else 1.0).to(input_ids.device)
+ # send 0.0 if we finished, 1.0 otherwise
+ torch.dist.all_reduce(this_peer_finished_flag, op=torch.dist.ReduceOp.SUM)
+ # did all peers finish? the reduced sum will be 0.0 then
+ if this_peer_finished_flag.item() == 0.0:
+ break
+
+ cur_len = input_ids.shape[-1]
+
+ # 1. Fetch candidate sequences from a `CandidateGenerator`
+ candidate_input_ids, candidate_logits = candidate_generator.get_candidates(input_ids)
+ candidate_input_ids = candidate_input_ids.to(self.device)
+ if candidate_logits is not None:
+ candidate_logits = candidate_logits.to(self.device)
+
+ candidate_length = candidate_input_ids.shape[1] - input_ids.shape[1]
+ last_assistant_token_is_eos = (
+ ~candidate_input_ids[:, -1]
+ .tile(eos_token_id_tensor.shape[0], 1)
+ .ne(eos_token_id_tensor.unsqueeze(1))
+ .prod(dim=0)
+ .bool()
+ )
- # Encoder-decoder models are not supported
- if config.is_encoder_decoder:
- self.skipTest("DoLa is not supported for encoder-decoder models")
- config.is_decoder = True
- model = model_class(config).to(torch_device).eval()
+ # 2. Use the original model to obtain the next token logits given the candidate sequence. We obtain
+ # `candidate_length + 1` relevant logits from this process: in the event that all candidates are correct,
+ # we use this forward pass to also pick the subsequent logits in the original model.
- if model.get_output_embeddings() is None:
- self.skipTest("DoLa is not supported for models that don't have output embeddings")
- # Sets dola generation arguments such that:
- # a) no EOS is generated, to ensure generation doesn't break early
- # b) there are at least two forward passes in the main model, to ensure the input preparation of
- # the main model is correct
- generation_kwargs = {
- "eos_token_id": -1, # see a)
- "max_new_tokens": 4, # see b)
- "num_beams": 1,
- "do_sample": True,
- "output_scores": True,
- "output_logits": True,
- "output_hidden_states": True,
- "output_attentions": self.has_attentions,
- "return_dict_in_generate": True,
- "use_cache": hasattr(config, "use_cache"), # Some models don't support the cache
- }
- generation_kwargs.update({"dola_layers": "low"})
- model_kwargs = {"attention_mask": attention_mask} if attention_mask is not None else {}
- output_dola = model.generate(input_ids, **model_kwargs, **generation_kwargs, **inputs_dict)
- self._check_outputs(output_dola, input_ids, model.config, use_cache=hasattr(config, "use_cache"))
+ # 2.1. Prepare the model inputs
+ candidate_kwargs = copy.copy(model_kwargs)
+ candidate_kwargs = _prepare_attention_mask(
+ candidate_kwargs, candidate_input_ids.shape[1], self.config.is_encoder_decoder
+ )
+ candidate_kwargs = _prepare_token_type_ids(candidate_kwargs, candidate_input_ids.shape[1])
- @pytest.mark.generate
- def test_assisted_decoding_sample(self):
- # In this test we don't check assisted vs non-assisted output -- seeded assisted decoding with sample will not
- # match sample for the same seed, as the forward pass does not return the exact same logits (due to matmul with
- # different shapes, see https://github.com/huggingface/transformers/issues/25420#issuecomment-1775317535).
- for model_class in self.all_generative_model_classes:
- if model_class._is_stateful:
- self.skipTest(reason="Stateful models don't support assisted generation")
- if any(model_name in model_class.__name__.lower() for model_name in ["fsmt", "reformer"]):
- self.skipTest(reason="Won't fix: old model with different cache format")
- if any(
- model_name in model_class.__name__.lower()
- for model_name in [
- "bigbirdpegasus",
- "led",
- "mega",
- "speech2text",
- "git",
- "prophetnet",
- "seamlessm4t",
- "clvp",
- ]
- ):
- self.skipTest(reason="May fix in the future: need model-specific fixes")
+ model_inputs = self.prepare_inputs_for_generation(candidate_input_ids, **candidate_kwargs)
- # enable cache
- config, input_ids, attention_mask, inputs_dict = self._get_input_ids_and_config(batch_size=1)
+ # 2.2. Run a forward pass on the candidate sequence
+ outputs = self(
+ **model_inputs,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ )
- # NOTE: assisted generation only works with cache on at the moment.
- if not hasattr(config, "use_cache"):
- self.skipTest(reason=f"{model_class.__name__} doesn't support caching")
+ # 2.3. Process the new logits
+ new_logits = outputs.logits[:, -candidate_length - 1 :] # excludes the input prompt if present
+ if len(logits_processor) > 0:
+ for i in range(candidate_length + 1):
+ new_logits[:, i, :] = logits_processor(
+ candidate_input_ids[:, : cur_len + i], new_logits[:, i, :]
+ )
+ if len(logits_warper) > 0:
+ for i in range(candidate_length + 1):
+ new_logits[:, i, :] = logits_warper(
+ candidate_input_ids[:, : cur_len + i], new_logits[:, i, :]
+ )
- config.is_decoder = True
- model = model_class(config).to(torch_device).eval()
- # Sets assisted generation arguments such that:
- # a) no EOS is generated, to ensure generation doesn't break early
- # b) the assistant model always generates two tokens when it is called, to ensure the input preparation of
- # the assistant model is correct
- # c) there are at least two forward passes in the main model, to ensure the input preparation of
- # the main model is correct
- assistant_model = model
- assistant_model.generation_config.num_assistant_tokens = 2 # see b)
- assistant_model.generation_config.num_assistant_tokens_schedule = "constant" # see b)
- generation_kwargs = {
- "eos_token_id": -1, # see a)
- "max_new_tokens": 4, # see c)
- "num_beams": 1,
- "do_sample": True,
- "assistant_model": assistant_model,
- "output_scores": True,
- "output_logits": True,
- "output_hidden_states": True,
- "output_attentions": self.has_attentions,
- "return_dict_in_generate": True,
- "use_cache": True,
- }
- output_assisted = model.generate(
- input_ids, attention_mask=attention_mask, **generation_kwargs, **inputs_dict
- )
+ # 3. Select the accepted tokens. There are two possible cases:
+ # Case 1: `do_sample=True` and we have logits for the candidates (originally from speculative decoding)
+ # 👉 Apply algorithm 1 from the speculative decoding paper (https://arxiv.org/pdf/2211.17192.pdf).
+ max_matches = max_len - cur_len - 1
+ if do_sample and candidate_logits is not None:
+ valid_tokens, n_matches = _speculative_sampling(
+ candidate_input_ids,
+ candidate_logits,
+ candidate_length,
+ new_logits,
+ last_assistant_token_is_eos,
+ max_matches,
+ )
- self._check_outputs(output_assisted, input_ids, config, use_cache=True)
+ # Case 2: all other cases (originally from assisted generation) 👉 Compare the tokens selected from the
+ # original model logits with the candidate tokens. We can keep the candidate tokens until the first
+ # mismatch, or until the max length is reached.
+ else:
+ if do_sample:
+ probs = new_logits.softmax(dim=-1)
+ selected_tokens = torch.multinomial(probs[0, :, :], num_samples=1).squeeze(1)[None, :]
+ else:
+ selected_tokens = new_logits.argmax(dim=-1)
+
+ candidate_new_tokens = candidate_input_ids[:, cur_len:]
+ n_matches = ((~(candidate_new_tokens == selected_tokens[:, :-1])).cumsum(dim=-1) < 1).sum()
+
+ # Ensure we don't generate beyond max_len or an EOS token
+ if last_assistant_token_is_eos and n_matches == candidate_length:
+ n_matches -= 1
+ n_matches = min(n_matches, max_matches)
+ valid_tokens = selected_tokens[:, : n_matches + 1]
+
+ # 4. Update variables according to the number of matching assistant tokens. Remember: the token generated
+ # by the model after the last candidate match is also valid, as it is generated from a correct sequence.
+ # Because of this last token, assisted generation search reduces to a normal greedy search/sample if there
+ # is no match.
+
+ # 4.1. Get the valid continuation, after the matching tokens
+ input_ids = torch.cat((input_ids, valid_tokens), dim=-1)
+ if streamer is not None:
+ streamer.put(valid_tokens.cpu())
+ new_cur_len = input_ids.shape[-1]
+
+ # 4.2. Discard past key values relative to unused assistant tokens
+ new_cache_size = new_cur_len - 1
+ outputs.past_key_values = _crop_past_key_values(self, outputs.past_key_values, new_cache_size)
+
+ # 5. Update the candidate generation strategy if needed
+ candidate_generator.update_candidate_strategy(input_ids, new_logits, n_matches)
+
+ if synced_gpus and this_peer_finished:
+ continue # don't waste resources running the code we don't need
+
+ # Store scores, attentions and hidden_states when required
+ # Assistant: modified to append one tuple element per token, as in the other generation methods.
+ if return_dict_in_generate:
+ if output_scores:
+ scores += tuple(new_logits[:, i, :] for i in range(n_matches + 1))
+
+ if "past_key_values" not in model_kwargs:
+ added_len = new_cur_len
+ else:
+ added_len = n_matches + 1
+
+ if output_attentions:
+ if self.config.is_encoder_decoder:
+ cross_attentions = _split_model_outputs(
+ cross_attentions, outputs.cross_attentions, cur_len, added_len
+ )
+ decoder_attentions = _split_model_outputs(
+ decoder_attentions,
+ outputs.decoder_attentions,
+ cur_len,
+ added_len,
+ is_decoder_attention=True,
+ )
+ else:
+ decoder_attentions = _split_model_outputs(
+ decoder_attentions,
+ outputs.attentions,
+ cur_len,
+ added_len,
+ is_decoder_attention=True,
+ )
+ if output_hidden_states:
+ if self.config.is_encoder_decoder:
+ decoder_hidden_states = _split_model_outputs(
+ decoder_hidden_states, outputs.decoder_hidden_states, cur_len, added_len
+ )
+ else:
+ decoder_hidden_states = _split_model_outputs(
+ decoder_hidden_states, outputs.hidden_states, cur_len, added_len
+ )
+
+ model_kwargs = self._update_model_kwargs_for_generation(
+ outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder
+ )
- @pytest.mark.generate
- def test_prompt_lookup_decoding_stops_at_eos(self):
- # This test ensures that the prompt lookup generation stops at eos token and does not suggest more tokens
- # (see https://github.com/huggingface/transformers/pull/31301)
+ # if eos_token was found in one sentence, set sentence to finished
+ if eos_token_id_tensor is not None:
+ unfinished_sequences = unfinished_sequences.mul(
+ input_ids[:, -1]
+ .tile(eos_token_id_tensor.shape[0], 1)
+ .ne(eos_token_id_tensor.unsqueeze(1))
+ .prod(dim=0)
+ )
- # The main idea is to have an ngram (unigram in our case) that is repeated twice in the input ids.
- # First time at the very end, so input ends with the unigrams, and second any arbitrary location.
- # Also, we need an EOS token which will be injected just after the arbitrary located ngram.
- # We verify that PLD will not copy and propose candidated that contain an EOS token, even if there are overlapping ngrams
- # in input ids. Otherwise a proposed EOS along with the trailing (ngrams-1) tokens might be accepted by the target model.
- # That seems as if the model "generated" and EOS but didn't stop from user's perspective
+ # stop when each sentence is finished
+ if unfinished_sequences.max() == 0:
+ this_peer_finished = True
+
+ # stop if we exceed the maximum length
+ if stopping_criteria(input_ids, scores):
+ this_peer_finished = True
+
+ if this_peer_finished and not synced_gpus:
+ break
+
+ if streamer is not None:
+ streamer.end()
+
+ if return_dict_in_generate:
+ if self.config.is_encoder_decoder:
+ return GenerateEncoderDecoderOutput(
+ sequences=input_ids,
+ scores=scores,
+ encoder_attentions=encoder_attentions,
+ encoder_hidden_states=encoder_hidden_states,
+ decoder_attentions=decoder_attentions,
+ cross_attentions=cross_attentions,
+ decoder_hidden_states=decoder_hidden_states,
+ past_key_values=model_kwargs.get("past_key_values"),
+ )
+ else:
+ return GenerateDecoderOnlyOutput(
+ sequences=input_ids,
+ scores=scores,
+ attentions=decoder_attentions,
+ hidden_states=decoder_hidden_states,
+ past_key_values=model_kwargs.get("past_key_values"),
+ )
+ else:
+ return input_ids
- input_ids = torch.randint(1, 50, (1, 10), device=torch_device) # generate inputs in range from 1-50
- arbitrary_ngram = 51 # this is the arbitrary ngram, specifically chosen OOV to prevent flaky tests
- input_ids[:, 3] = arbitrary_ngram # set pre-eos to arbitrary_ngram which is for sure not present in inputs
- input_ids[:, -1] = arbitrary_ngram # put arbitrary_ngram in the end for the necessary match to happen
+ model.assisted_decoding = MethodType(assisted_decoding, model)
- eos_token_id = torch.tensor([0], device=torch_device)
- input_ids[:, 4] = eos_token_id # inject eos-token-id in input ids so that it is located after arbitrary_ngram
+ #######################################################################
- # init cand geenerator with max_matching_ngram_size=1 to match per-token
- candidate_generator = PromptLookupCandidateGenerator(
- eos_token_id=eos_token_id, num_output_tokens=4, max_matching_ngram_size=1
- )
- output_prompt_lookup = candidate_generator.get_candidates(input_ids)[0]
+ output_assisted = model.generate(input_ids, attention_mask=attention_mask, **generation_kwargs)
- # PLD shouldn't propose any new tokens based on eos-match
- self.assertTrue(output_prompt_lookup.shape[-1] == 10)
+ self._check_outputs(output_assisted, input_ids, model.config, use_cache=True)
- @pytest.mark.generate
def test_generate_with_head_masking(self):
"""Test designed for encoder-decoder models to ensure the attention head masking is used."""
attention_names = ["encoder_attentions", "decoder_attentions", "cross_attentions"]
for model_class in self.all_generative_model_classes:
- config, input_ids, attention_mask, inputs_dict = self._get_input_ids_and_config()
+ config, input_ids, attention_mask, max_length = self._get_input_ids_and_config()
# We want to test only encoder-decoder models
if not config.is_encoder_decoder:
continue
@@ -1509,93 +2013,60 @@ def test_generate_with_head_masking(self):
input_ids,
attention_mask=attention_mask,
num_beams=1,
- output_attentions=self.has_attentions,
+ output_attentions=True,
return_dict_in_generate=True,
remove_invalid_values=True,
**{name: mask},
- **inputs_dict,
)
# We check the state of decoder_attentions and cross_attentions just from the last step
attn_weights = out[attn_name] if attn_name == attention_names[0] else out[attn_name][-1]
self.assertEqual(sum([w.sum().item() for w in attn_weights]), 0.0)
- @pytest.mark.generate
def test_left_padding_compatibility(self):
- # NOTE: left-padding results in small numerical differences. This is expected.
- # See https://github.com/huggingface/transformers/issues/25420#issuecomment-1775317535
-
- # First, filter out models that don't support left padding
- # - The model must have generative capabilities
- if len(self.all_generative_model_classes) == 0:
- self.skipTest(reason="No generative architecture available for this model.")
+ # The check done in this test is fairly difficult -- depending on the model architecture, passing the right
+ # position index for the position embeddings can still result in a different output, due to numerical masking.
+ # On the other hand, for some types of position embeddings, an incorrect position index can have a minimal
+ # impact on the output.
+ # There are two tricks employed to check whether left-padding compatibility is in place:
+ # 1 - To reduce the negative impact of the numerical attention mask on a correct position index, we set the
+ # padding size to 1.
+ # 2 - To reduce the chance of false positives (i.e. passing when it should be failing), we run the check
+ # multiple times with random inputs, and it has to pass with all of them.
+ # NOTE: because of 2), there is some chance of false positives in this test.
- # - The model must support padding
- if not self.has_attentions:
- self.skipTest(reason="This model doesn't support padding.")
-
- # - The model must be a decoder-only architecture (encoder-based architectures use right-padding)
- decoder_only_classes = []
for model_class in self.all_generative_model_classes:
config, _, _, _ = self._get_input_ids_and_config()
if config.is_encoder_decoder:
- continue
- else:
- decoder_only_classes.append(model_class)
- if len(decoder_only_classes) == 0:
- self.skipTest(reason="No decoder-only architecture available for this model.")
-
- # - Decoder-only architectures derived from encoder-decoder models could support it in theory, but we haven't
- # added support for it yet. We skip these models for now.
- has_encoder_attributes = any(
- attr_name
- for attr_name in config.to_dict().keys()
- if attr_name.startswith("encoder") and attr_name != "encoder_no_repeat_ngram_size"
- )
- if has_encoder_attributes:
- self.skipTest(
- reason="The decoder-only derived from encoder-decoder models are not expected to support left-padding."
- )
-
- # Then, test left-padding
- def _prepare_model_kwargs(input_ids, attention_mask, signature):
- model_kwargs = {"input_ids": input_ids, "attention_mask": attention_mask}
- if "position_ids" in signature:
- position_ids = torch.cumsum(attention_mask, dim=-1) - 1
- position_ids.masked_fill_(attention_mask == 0, 1)
- model_kwargs["position_ids"] = position_ids
- if "cache_position" in signature:
- cache_position = torch.arange(input_ids.shape[-1], device=torch_device)
- model_kwargs["cache_position"] = cache_position
- return model_kwargs
-
- for model_class in decoder_only_classes:
- config, input_ids, attention_mask, _ = self._get_input_ids_and_config()
+ continue # skip for encoder-decoder models -- they don't need left-padding compatibility
model = model_class(config).to(torch_device).eval()
signature = inspect.signature(model.forward).parameters.keys()
- # no cache as some models require special cache classes to be init outside forward
- model.generation_config.use_cache = False
-
- # Without padding
- model_kwargs = _prepare_model_kwargs(input_ids, attention_mask, signature)
- next_logits_wo_padding = model(**model_kwargs).logits[:, -1, :]
-
- # With left-padding (length 32)
- # can hardcode pad_token to be 0 as we'll do attn masking anyway
- pad_token_id = (
- config.get_text_config().pad_token_id if config.get_text_config().pad_token_id is not None else 0
- )
- pad_size = (input_ids.shape[0], 32)
- padding = torch.ones(pad_size, dtype=input_ids.dtype, device=torch_device) * pad_token_id
- padded_input_ids = torch.cat((padding, input_ids), dim=1)
- padded_attention_mask = torch.cat((torch.zeros_like(padding), attention_mask), dim=1)
- model_kwargs = _prepare_model_kwargs(padded_input_ids, padded_attention_mask, signature)
- next_logits_with_padding = model(**model_kwargs).logits[:, -1, :]
-
- # They should result in very similar logits
- self.assertTrue(torch.allclose(next_logits_wo_padding, next_logits_with_padding, atol=1e-5))
+ no_failures = True
+ for _ in range(10): # there may be false positives with 10 runs, we rely on the CI to catch the flakiness
+ _, input_ids, attention_mask, _ = self._get_input_ids_and_config()
+ model_kwargs = {"input_ids": input_ids, "attention_mask": attention_mask}
+ if "position_ids" in signature:
+ position_ids = torch.cumsum(attention_mask, dim=-1) - 1
+ position_ids.masked_fill_(attention_mask == 0, 1)
+ model_kwargs["position_ids"] = position_ids
+ next_logits_wo_padding = model(**model_kwargs).logits[:, -1, :]
+
+ pad_size = (input_ids.shape[0], 1)
+ padding = torch.ones(pad_size, dtype=input_ids.dtype, device=torch_device) * config.pad_token_id
+ padded_input_ids = torch.cat((padding, input_ids), dim=1)
+ padded_attention_mask = torch.cat((torch.zeros_like(padding), attention_mask), dim=1)
+ model_kwargs = {"input_ids": padded_input_ids, "attention_mask": padded_attention_mask}
+ if "position_ids" in signature:
+ position_ids = torch.cumsum(padded_attention_mask, dim=-1) - 1
+ position_ids.masked_fill_(padded_attention_mask == 0, 1)
+ model_kwargs["position_ids"] = position_ids
+ next_logits_with_padding = model(**model_kwargs).logits[:, -1, :]
+ if not torch.allclose(next_logits_wo_padding, next_logits_with_padding, atol=1e-7):
+ no_failures = False
+ break
+
+ self.assertTrue(no_failures)
- @pytest.mark.generate
def test_past_key_values_format(self):
# Test that the KV cache is formatted correctly. Exceptions need to explicitly overwrite this test. Having a
# standard KV cache format is important for a consistent API (and for advanced generation methods).
@@ -1604,7 +2075,7 @@ def test_past_key_values_format(self):
# If it doesn't support cache, pass the test
if not hasattr(config, "use_cache"):
- self.skipTest(reason=f"{model_class.__name__} doesn't support caching")
+ return
model = model_class(config).to(torch_device)
if "use_cache" not in inputs:
@@ -1613,7 +2084,7 @@ def test_past_key_values_format(self):
# If "past_key_values" is not returned, pass the test (e.g. RWKV uses a different cache name and format)
if "past_key_values" not in outputs:
- self.skipTest(reason="This model doesn't return `past_key_values`")
+ return
num_hidden_layers = (
getattr(config, "decoder_layers", None)
@@ -1667,7 +2138,6 @@ def test_past_key_values_format(self):
past_kv[i][1].shape, (batch_size, num_attention_heads, seq_length, per_head_embed_dim)
)
- @pytest.mark.generate
def test_generate_from_inputs_embeds_decoder_only(self):
# When supported, tests that the decoder model can generate from `inputs_embeds` instead of `input_ids`
# if fails, you should probably update the `prepare_inputs_for_generation` function
@@ -1694,581 +2164,100 @@ def test_generate_from_inputs_embeds_decoder_only(self):
continue
# Traditional way of generating text
- outputs_from_ids = model.generate(
- input_ids, max_new_tokens=5, return_dict_in_generate=True, output_scores=True
- )
- self.assertEqual(outputs_from_ids.sequences.shape, (input_ids.shape[0], input_ids.shape[1] + 5))
+ outputs_from_ids = model.generate(input_ids)
+ self.assertEqual(outputs_from_ids.shape, (2, 20))
# Same thing, but from input embeddings (`input_ids` is passed so the prompt is present in the output)
inputs_embeds = model.get_input_embeddings()(input_ids)
- outputs_from_embeds = model.generate(
- input_ids,
- inputs_embeds=inputs_embeds,
- max_new_tokens=5,
- return_dict_in_generate=True,
- output_scores=True,
- )
- self.assertListEqual(outputs_from_ids.sequences.tolist(), outputs_from_embeds.sequences.tolist())
+ outputs_from_embeds = model.generate(input_ids, inputs_embeds=inputs_embeds)
+ self.assertListEqual(outputs_from_ids.tolist(), outputs_from_embeds.tolist())
- # But if we pass different inputs_embeds, we should get different outputs (the output text may be the
- # same, but the logits will almost surely be different)
+ # But if we pass different inputs_embeds, we should get different outputs
+ torch.manual_seed(0)
random_embeds = torch.rand_like(inputs_embeds)
- outputs_from_rand_embeds = model.generate(
- input_ids,
- inputs_embeds=random_embeds,
- max_new_tokens=5,
- return_dict_in_generate=True,
- output_scores=True,
- )
- for i in range(len(outputs_from_rand_embeds.scores)):
- self.assertFalse(torch.allclose(outputs_from_embeds.scores[i], outputs_from_rand_embeds.scores[i]))
+ outputs_from_rand_embeds = model.generate(input_ids, inputs_embeds=random_embeds)
+ with self.assertRaises(AssertionError):
+ self.assertListEqual(outputs_from_rand_embeds.tolist(), outputs_from_embeds.tolist())
# input_ids is not a required input -- if we don't pass it, the newly generated tokens will be the same
outputs_from_embeds_wo_ids = model.generate(
- inputs_embeds=inputs_embeds, max_new_tokens=5, return_dict_in_generate=True, output_scores=True
+ inputs_embeds=inputs_embeds, max_new_tokens=20 - inputs_embeds.shape[1]
)
self.assertListEqual(
- outputs_from_embeds.sequences[:, inputs_embeds.shape[1] :].tolist(),
- outputs_from_embeds_wo_ids.sequences.tolist(),
+ outputs_from_embeds[:, inputs_embeds.shape[1] :].tolist(),
+ outputs_from_embeds_wo_ids.tolist(),
)
- @pytest.mark.generate
- def test_generate_from_inputs_embeds_with_static_cache(self):
- """
- Test that StaticCache can generate from inputs_embeds and calculates max_cache_length
- correctly in `generate()`. We force the model to not stop generation until max-length is reached
- to verify that the cache length is indeed set correctly and we don't run out of index when slicing the cache.
- """
- for model_class in self.all_generative_model_classes:
- if not model_class._supports_static_cache:
- self.skipTest(reason="This model does not support the static cache format")
-
- config, input_ids, attention_mask, inputs_dict = self._get_input_ids_and_config()
- if config.is_encoder_decoder:
- self.skipTest(reason="This model is encoder-decoder and has Encoder-Decoder Cache")
-
- model = model_class(config).to(torch_device).eval()
- if "inputs_embeds" not in inspect.signature(model.prepare_inputs_for_generation).parameters.keys():
- self.skipTest(reason="This model does not support `inputs_embeds` in generation")
-
- model.config.use_cache = True
- model.config.is_decoder = True
- batch_size, seq_length = input_ids.shape
- max_cache_len = 30
+ def _check_outputs(self, output, input_ids, config, use_cache=False, num_return_sequences=1):
+ batch_size, seq_length = input_ids.shape
+ num_sequences_in_output = batch_size * num_return_sequences
+ gen_len = (
+ output.sequences.shape[-1] - 1 if config.is_encoder_decoder else output.sequences.shape[-1] - seq_length
+ )
- # here we force to not stop at eos and go until max-length
- model.generation_config.eos_token_id = model.config.get_text_config().eos_token_id = -1
- generation_kwargs = {
- "max_length": max_cache_len,
- "cache_implementation": "static",
- "return_dict_in_generate": True, # Required to return `past_key_values`
- }
+ # scores
+ self._check_scores(num_sequences_in_output, output.scores, length=gen_len, config=config)
- text_config = model.config.get_text_config()
- head_dim = (
- text_config.head_dim
- if hasattr(text_config, "head_dim")
- else text_config.hidden_size // text_config.num_attention_heads
+ # Attentions
+ if config.is_encoder_decoder:
+ # encoder
+ self._check_encoder_attention_for_generate(output.encoder_attentions, batch_size, config, seq_length)
+ # decoder
+ self._check_attentions_for_generate(
+ num_sequences_in_output,
+ output.decoder_attentions,
+ min_length=1,
+ max_length=output.sequences.shape[-1],
+ config=config,
+ use_cache=use_cache,
)
- num_key_value_heads = (
- text_config.num_attention_heads
- if getattr(text_config, "num_key_value_heads", None) is None
- else text_config.num_key_value_heads
+ else:
+ # if use_cache first input is equal to no use_cache, so skip here
+ attentions = output.attentions if not use_cache else output.attentions[1:]
+ min_length = seq_length if not use_cache else seq_length + 1
+ self._check_attentions_for_generate(
+ num_sequences_in_output,
+ attentions=attentions,
+ min_length=min_length,
+ max_length=output.sequences.shape[-1],
+ config=config,
+ use_cache=use_cache,
)
- num_hidden_layers = text_config.num_hidden_layers
- inputs_embeds = model.get_input_embeddings()(input_ids)
- outputs = model.generate(
- inputs_embeds=inputs_embeds, attention_mask=attention_mask, **generation_kwargs, **inputs_dict
+ # Hidden States
+ if config.is_encoder_decoder:
+ # encoder
+ self._check_encoder_hidden_states_for_generate(
+ output.encoder_hidden_states, batch_size, config, seq_length
)
- # we should get `max_length` in shape, not `max_length - embeds_length`
- cache_shape = (batch_size, num_key_value_heads, max_cache_len, head_dim)
- self.assertTrue(isinstance(outputs.past_key_values, StaticCache))
- self.assertTrue(len(outputs.past_key_values.key_cache) == num_hidden_layers)
- self.assertTrue(outputs.past_key_values.key_cache[0].shape == cache_shape)
-
- @pytest.mark.generate
- def test_generate_continue_from_past_key_values(self):
- # Tests that we can continue generating from past key values, returned from a previous `generate` call
- for model_class in self.all_generative_model_classes:
- if any(model_name in model_class.__name__.lower() for model_name in ["imagegpt"]):
- self.skipTest(reason="Won't fix: old model with unique inputs/caches/other")
- if any(model_name in model_class.__name__.lower() for model_name in ["umt5"]):
- self.skipTest(reason="TODO: needs modeling or test input preparation fixes for compatibility")
-
- config, inputs = self.model_tester.prepare_config_and_inputs_for_common()
+ # decoder
+ self._check_hidden_states_for_generate(
+ num_sequences_in_output,
+ output.decoder_hidden_states,
+ min_length=1,
+ max_length=output.sequences.shape[-1],
+ config=config,
+ use_cache=use_cache,
+ )
+ else:
+ # if use_cache first input is equal to no use_cache, so skip here
+ hidden_states = output.hidden_states if not use_cache else output.hidden_states[1:]
+ min_length = seq_length if not use_cache else seq_length + 1
+ self._check_hidden_states_for_generate(
+ num_sequences_in_output,
+ hidden_states,
+ min_length=min_length,
+ max_length=output.sequences.shape[-1],
+ config=config,
+ use_cache=use_cache,
+ )
- if not hasattr(config, "use_cache"):
- self.skipTest(reason=f"{model_class.__name__} doesn't support caching")
-
- # Let's make it always:
- # 1. use cache (for obvious reasons)
- # 2. generate to max length (which can be achieved by setting the eos token to an invalid value), which
- # would make the test flaky (e.g. EOS is generated on iteration 1 on both generations, but the
- # continuation would force it to generate beyond an EOS token)
- # 3. ignore `token_type_ids` for simplicity
- # 4. ignore `forced_eos_token_id`, which requires further manipulation of the continuation inputs and is
- # active by default on some models
- # 5. ignore `encoder_no_repeat_ngram_size`, which is set by default in some encoder-decoder models. When
- # we use their decoder as a stand-alone model, `encoder_no_repeat_ngram_size` actually prevents
- # repetition exclusively from the prompt. This test relies on comparing one call vs 2 calls
- # with cache, what is considered a prompt is different in the two cases.
-
- if "token_type_ids" in inputs:
- del inputs["token_type_ids"]
-
- model = model_class(config).to(torch_device)
- model.eval()
- model.generation_config.pad_token_id = model.generation_config.eos_token_id = -1
- model.generation_config.forced_eos_token_id = None
- model.generation_config.encoder_no_repeat_ngram_size = 0
- model.generation_config.use_cache = True
-
- # If "past_key_values" is not returned, skip the test (e.g. RWKV uses a different cache name and format)
- outputs = model(**inputs)
- if "past_key_values" not in outputs:
- self.skipTest(reason="This model doesn't return `past_key_values`")
-
- # Traditional way of generating text, with `return_dict_in_generate` to return the past key values
- outputs = model.generate(**inputs, do_sample=False, max_new_tokens=4, return_dict_in_generate=True)
-
- # Let's generate again, but passing the past key values in between (3 + 1 = 4 tokens). Note that the
- # inputs may need to be tweaked across `generate` calls (like the attention mask).
- outputs_cached = model.generate(**inputs, do_sample=False, max_new_tokens=3, return_dict_in_generate=True)
-
- # Continue from the tokens generated above, preparing the inputs accordingly
- inputs["past_key_values"] = outputs_cached.past_key_values
- new_attention_len = outputs_cached.sequences.shape[-1]
- if config.is_encoder_decoder:
- inputs["decoder_input_ids"] = outputs_cached.sequences
- if "decoder_attention_mask" in inputs:
- inputs["decoder_attention_mask"] = torch.nn.functional.pad(
- inputs["decoder_attention_mask"],
- (0, new_attention_len - inputs["decoder_attention_mask"].shape[1]),
- mode="constant",
- value=1,
- )
- else:
- inputs["input_ids"] = outputs_cached.sequences
- if "attention_mask" in inputs:
- inputs["attention_mask"] = torch.nn.functional.pad(
- inputs["attention_mask"],
- (0, new_attention_len - inputs["attention_mask"].shape[1]),
- mode="constant",
- value=1,
- )
- outputs_cached = model.generate(**inputs, do_sample=False, max_new_tokens=1, return_dict_in_generate=True)
-
- # The two sets of generated text and past kv should be equal to each other
- self.assertListEqual(outputs.sequences.tolist(), outputs_cached.sequences.tolist())
- for layer_idx in range(len(outputs_cached.past_key_values)):
- for kv_idx in range(len(outputs_cached.past_key_values[layer_idx])):
- self.assertTrue(
- torch.allclose(
- outputs.past_key_values[layer_idx][kv_idx],
- outputs_cached.past_key_values[layer_idx][kv_idx],
- )
- )
-
- @parameterized.expand([(1, False), (1, True), (4, False)])
- @pytest.mark.generate
- def test_new_cache_format(self, num_beams, do_sample):
- # Tests that generating with the new format is exactly the same as the legacy one (for models that support it).
- # 👉 tests with and without beam search so that we can test with and without cache reordering.
- # 👉 tests with and without sampling so we can cover the most common use cases.
- for model_class in self.all_generative_model_classes:
- if not model_class._supports_cache_class:
- self.skipTest(reason="This model does not support the new cache format")
-
- config, input_ids, attention_mask, inputs_dict = self._get_input_ids_and_config()
-
- model = model_class(config).to(torch_device).eval()
- generation_kwargs = {
- "max_new_tokens": 5,
- "do_sample": do_sample,
- "num_beams": num_beams,
- "num_return_sequences": num_beams,
- "return_dict_in_generate": True, # Required to return `past_key_values`
- "use_cache": True,
- }
-
- # Sets seed before calling `generate` for the case with do_sample=True
- seed = torch.randint(0, 1000000, (1,)).item()
- set_seed(seed)
- legacy_results = model.generate(
- input_ids, attention_mask=attention_mask, **generation_kwargs, **inputs_dict
- )
- set_seed(seed)
- if config.is_encoder_decoder:
- cache_cls = EncoderDecoderCache
- past_key_values = cache_cls(DynamicCache(), DynamicCache())
- past_key_values = cache_cls(DynamicCache(), DynamicCache())
- else:
- cache_cls = DynamicCache
- past_key_values = cache_cls()
-
- new_results = model.generate(
- input_ids,
- attention_mask=attention_mask,
- past_key_values=past_key_values,
- **generation_kwargs,
- **inputs_dict,
- )
-
- # The two sets of generated sequences must match, despite the cache format between forward passes being
- # different
- self.assertListEqual(legacy_results.sequences.tolist(), new_results.sequences.tolist())
- self.assertTrue(isinstance(legacy_results.past_key_values, tuple))
- self.assertTrue(isinstance(new_results.past_key_values, cache_cls))
-
- # The contents of the two caches, when converted to the same format (in both directions!), must match
- legacy_cache = legacy_results.past_key_values
- new_cache_converted = new_results.past_key_values.to_legacy_cache()
- for layer_idx in range(len(legacy_cache)):
- for kv_idx in range(len(legacy_cache[layer_idx])):
- # TODO: @raushan, please look into this for new cache format
- if legacy_cache[layer_idx][kv_idx] != []:
- self.assertTrue(
- torch.allclose(
- legacy_cache[layer_idx][kv_idx],
- new_cache_converted[layer_idx][kv_idx],
- )
- )
-
- new_cache = new_results.past_key_values
- legacy_cache_converted = cache_cls.from_legacy_cache(legacy_results.past_key_values)
- for layer_idx in range(len(new_cache)):
- for kv_idx in range(len(new_cache[layer_idx])):
- # TODO: @raushan, please look into this for new cache format
- if new_cache[layer_idx][kv_idx] != []:
- self.assertTrue(
- torch.allclose(
- new_cache[layer_idx][kv_idx],
- legacy_cache_converted[layer_idx][kv_idx],
- )
- )
-
- @pytest.mark.generate
- def test_generate_with_static_cache(self):
- """
- Tests if StaticCache works if we set attn_implementation=static when generation.
- This doesn't test if generation quality is good, but tests that models with
- self._supports_static_cache don't throw an error when generating and return
- a StaticCache object at the end.
- """
- for model_class in self.all_generative_model_classes:
- if not model_class._supports_static_cache:
- self.skipTest(reason="This model does not support the static cache format")
-
- config, input_ids, attention_mask, inputs_dict = self._get_input_ids_and_config()
- if config.is_encoder_decoder:
- self.skipTest(reason="This model is encoder-decoder and has Encoder-Decoder Cache")
-
- config.is_decoder = True
- batch_size, seq_length = input_ids.shape
- max_new_tokens = 20
-
- model = model_class(config).to(torch_device).eval()
- generation_kwargs = {
- "max_length": None,
- "max_new_tokens": max_new_tokens,
- "cache_implementation": "static",
- "return_dict_in_generate": True, # Required to return `past_key_values`
- "use_cache": True,
- }
-
- max_cache_len = seq_length + max_new_tokens
- config = config.text_config if hasattr(config, "text_config") else config
- head_dim = (
- config.head_dim if hasattr(config, "head_dim") else config.hidden_size // config.num_attention_heads
- )
- num_key_value_heads = (
- config.num_attention_heads
- if getattr(config, "num_key_value_heads", None) is None
- else config.num_key_value_heads
- )
- num_hidden_layers = config.num_hidden_layers
- results = model.generate(input_ids, attention_mask=attention_mask, **generation_kwargs, **inputs_dict)
-
- cache_shape = (batch_size, num_key_value_heads, max_cache_len, head_dim)
- self.assertTrue(isinstance(results.past_key_values, StaticCache))
- self.assertTrue(len(results.past_key_values.key_cache) == num_hidden_layers)
- self.assertTrue(results.past_key_values.key_cache[0].shape == cache_shape)
-
- @require_quanto
- @pytest.mark.generate
- def test_generate_with_quant_cache(self):
- for model_class in self.all_generative_model_classes:
- if not model_class._supports_quantized_cache:
- self.skipTest(reason="This model does not support the quantized cache format")
-
- config, input_ids, attention_mask, inputs_dict = self._get_input_ids_and_config()
- config.is_decoder = True
-
- model = model_class(config).to(torch_device).eval()
- generation_kwargs = {
- "max_new_tokens": 5,
- "cache_implementation": "quantized",
- # careful with group size, should be divisor of model's hidden size
- "cache_config": {"backend": "quanto", "nbits": 2, "q_group_size": 8, "residual_length": 128},
- "return_dict_in_generate": True, # Required to return `past_key_values`
- "use_cache": True,
- }
-
- results = model.generate(input_ids, attention_mask=attention_mask, **generation_kwargs, **inputs_dict)
- self.assertTrue(isinstance(results.past_key_values, QuantoQuantizedCache))
-
- # passing past key values of different type should raise Error
- with self.assertRaises(ValueError):
- num_hidden_layers = config.get_text_config().num_hidden_layers
- model.generate(
- input_ids,
- attention_mask=attention_mask,
- past_key_valyes=DynamicCache(num_hidden_layers),
- **generation_kwargs,
- )
-
- # setting incorrect cache_config args should raise an Error, i.e. nbits=60 does not make sense
- generation_kwargs["cache_config"] = {"nbits": 60, "q_group_size": 8, "residual_length": 128}
- with self.assertRaises(ValueError):
- model.generate(input_ids, attention_mask=attention_mask, **generation_kwargs)
-
- @pytest.mark.generate
- @require_torch_gpu
- @slow
- @is_flaky() # compilation may result in equivalent (!= same) FP ops, causing the argmax in `generate` to be flaky
- def test_generate_compile_fullgraph(self):
- """
- Tests that `.generate` is compatible with torch.compile without graph breaks, keeping the same results.
- ⚠️ Runs two sequential generations to ensure the cache doesn't get stuck after the first compiled run! ⚠️
- """
- for model_class in self.all_generative_model_classes:
- if not model_class._supports_static_cache:
- self.skipTest("This model doesn't support static cache")
- # TODO (joao) -- fix and enable me :)
- if any(model_name in model_class.__name__.lower() for model_name in ["whisper"]):
- self.skipTest("whisper model end-to-end generate compile not yet supported")
-
- config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
- # TODO (joao) -- fix and enable me :)
- if config.is_encoder_decoder:
- self.skipTest("Encoder-decoder model end-to-end generate compile not yet supported")
-
- model = model_class(config).to(torch_device)
- model.eval() # otherwise `self.training` is `True` -- this flag is used at attn mask creation time
-
- input_ids = inputs_dict["input_ids"].to(torch_device)
- # creates two sets of *different* inputs with the same shape
- half_batch_size = input_ids.shape[0] // 2
- input_ids_sets = [input_ids[:half_batch_size, :], input_ids[half_batch_size : half_batch_size * 2, :]]
- self.assertTrue(input_ids_sets[0].shape == input_ids_sets[1].shape)
-
- generation_kwargs = {
- "do_sample": False,
- "max_new_tokens": 10,
- }
-
- max_cache_len = input_ids.shape[1] + generation_kwargs["max_new_tokens"]
- config = config.get_text_config()
- past_key_values = StaticCache(
- config, batch_size=half_batch_size, max_cache_len=max_cache_len, device=torch_device
- )
-
- for model_inputs in input_ids_sets:
- # eager dynamic cache
- output_dynamic = model.generate(model_inputs, **generation_kwargs)
-
- # end-to-end compiled dynamic cache
- torch.compiler.reset()
- compiled_generate = torch.compile(model.generate, fullgraph=True, mode="reduce-overhead")
- generation_config = copy.deepcopy(model.generation_config)
- generation_config.update(**generation_kwargs)
- output_compiled = compiled_generate(
- model_inputs, generation_config=generation_config, past_key_values=past_key_values
- )
- self.assertListEqual(output_dynamic.tolist(), output_compiled.tolist())
-
- @pytest.mark.generate
- def test_generate_methods_with_num_logits_to_keep(self):
- for model_class in self.all_generative_model_classes:
- if "num_logits_to_keep" not in set(inspect.signature(model_class.forward).parameters.keys()):
- self.skipTest(reason="This model does not support `num_logits_to_keep` argument.")
-
- config, input_ids, attention_mask, inputs_dict = self._get_input_ids_and_config()
- config.use_cache = True
- config.is_decoder = True
-
- model = model_class(config).to(torch_device).eval()
- # All generation methods (except assisted decoding) rely on always extracting the last token logits of the
- # full logits matrix, so testing out only greedy search and assisted decoding is enough (if it works,
- # other methods will work as well)
- generation_kwargs = {
- "max_new_tokens": 10,
- "do_sample": False,
- }
-
- # Setting num_logits_to_keep at 0 keeps all logits (old behavior)
- with_all_logits = model.generate(
- input_ids, attention_mask=attention_mask, **generation_kwargs, **inputs_dict, num_logits_to_keep=0
- )
- # By default, num_logits_to_keep is automatically set to 1 if not provided (new behavior)
- without_all_logits = model.generate(
- input_ids, attention_mask=attention_mask, **inputs_dict, **generation_kwargs
- )
- self.assertEqual(with_all_logits.tolist(), without_all_logits.tolist())
-
- @pytest.mark.generate
- @is_flaky() # assisted generation tests are flaky (minor fp ops differences)
- def test_assisted_decoding_with_num_logits_to_keep(self):
- for model_class in self.all_generative_model_classes:
- if "num_logits_to_keep" not in set(inspect.signature(model_class.forward).parameters.keys()):
- self.skipTest(reason="This model does not support `num_logits_to_keep` argument.")
- if model_class._is_stateful:
- self.skipTest(reason="Stateful models don't support assisted generation")
-
- config, input_ids, attention_mask, inputs_dict = self._get_input_ids_and_config(batch_size=1)
- config.use_cache = True
- config.is_decoder = True
-
- model = model_class(config).to(torch_device).eval()
- assistant_model = model
- # All generation methods (except assisted decoding) rely on always extracting the last token logits of the
- # full logits matrix, so testing out only greedy search and assisted decoding is enough (if it works,
- # other methods will work as well)
- generation_kwargs = {
- "max_new_tokens": 10,
- "do_sample": False,
- "assistant_model": assistant_model,
- }
-
- assistant_model.generation_config.assistant_confidence_threshold = None
- # Setting num_logits_to_keep at 0 keeps all logits (old behavior)
- with_all_logits = model.generate(
- input_ids, attention_mask=attention_mask, **generation_kwargs, **inputs_dict, num_logits_to_keep=0
- )
- # By default, num_logits_to_keep is automatically set to 1 if not provided (new behavior)
- without_all_logits = model.generate(
- input_ids, attention_mask=attention_mask, **inputs_dict, **generation_kwargs
- )
- self.assertEqual(with_all_logits.tolist(), without_all_logits.tolist())
-
- @pytest.mark.generate
- def test_inherits_generation_mixin(self):
- """
- Tests that the model class directly inherits `GenerationMixin`, as opposed to relying on `PreTrainedModel`
- to inherit it.
- """
- for model_class in self.all_generative_model_classes:
- self.assertTrue("GenerationMixin" in str(model_class.__bases__))
-
- def _check_outputs(self, output, input_ids, config, use_cache=False, num_return_sequences=1):
- batch_size, seq_length = input_ids.shape
- config = config.text_config if hasattr(config, "text_config") else config
- num_sequences_in_output = batch_size * num_return_sequences
-
- gen_len = (
- output.sequences.shape[-1] - 1 if config.is_encoder_decoder else output.sequences.shape[-1] - seq_length
- )
-
- # scores
- self._check_scores(num_sequences_in_output, output.scores, length=gen_len, config=config)
-
- # unprocessed logits
- self._check_logits(num_sequences_in_output, output.logits, config=config)
-
- # Attentions
- if self.has_attentions:
- if config.is_encoder_decoder:
- # encoder
- self._check_encoder_attention_for_generate(output.encoder_attentions, batch_size, config, seq_length)
- # decoder
- self._check_attentions_for_generate(
- num_sequences_in_output,
- output.decoder_attentions,
- min_length=1,
- max_length=output.sequences.shape[-1],
- config=config,
- use_cache=use_cache,
- )
- else:
- # if use_cache first input is equal to no use_cache, so skip here
- attentions = output.attentions if not use_cache else output.attentions[1:]
- min_length = seq_length if not use_cache else seq_length + 1
- self._check_attentions_for_generate(
- num_sequences_in_output,
- attentions=attentions,
- min_length=min_length,
- max_length=output.sequences.shape[-1],
- config=config,
- use_cache=use_cache,
- )
-
- # Hidden States
- if config.is_encoder_decoder:
- # encoder
- self._check_encoder_hidden_states_for_generate(
- output.encoder_hidden_states, batch_size, config, seq_length
- )
-
- # decoder
- self._check_hidden_states_for_generate(
- num_sequences_in_output,
- output.decoder_hidden_states,
- min_length=1,
- max_length=output.sequences.shape[-1],
- config=config,
- use_cache=use_cache,
- )
- else:
- # if use_cache first input is equal to no use_cache, so skip here
- hidden_states = output.hidden_states if not use_cache else output.hidden_states[1:]
- min_length = seq_length if not use_cache else seq_length + 1
- self._check_hidden_states_for_generate(
- num_sequences_in_output,
- hidden_states,
- min_length=min_length,
- max_length=output.sequences.shape[-1],
- config=config,
- use_cache=use_cache,
- )
-
- # Past Key Value States -- a few notes here:
- # 1. Its inner sequence length is with respect to the inputs of the latest forward pass, hence the "-1"
- # 2. We ignore models that have unique cache structures (e.g. mamba) or are in need of refatoring to match the
- # standard cache format (e.g.gptbigcode )
- models_without_standard_cache = ("ctrl", "fsmt", "gptbigcode", "mega", "reformer", "jamba", "mamba", "xlnet")
- has_standard_cache = not any(
- model_name in config.__class__.__name__.lower() for model_name in models_without_standard_cache
- )
- if has_standard_cache:
- if use_cache:
- past_key_values = output.past_key_values
- past_sequence_length = output.sequences.shape[-1] - 1
- self._check_past_key_values_for_generate(
- num_sequences_in_output,
- past_key_values,
- seq_length=past_sequence_length,
- config=config,
- )
- elif use_cache is False:
- self.assertTrue(output.past_key_values is None)
-
- def _check_scores(self, batch_size, scores, length, config):
- vocab_size = config.get_text_config(decoder=True).vocab_size
- expected_shape = (batch_size, vocab_size)
- self.assertIsInstance(scores, tuple)
- self.assertEqual(len(scores), length)
- self.assertListEqual([iter_scores.shape for iter_scores in scores], [expected_shape] * len(scores))
-
- def _check_logits(self, batch_size, scores, config):
- vocab_size = config.get_text_config(decoder=True).vocab_size
- self.assertIsInstance(scores, tuple)
- self.assertListEqual([iter_scores.shape[0] for iter_scores in scores], [batch_size] * len(scores))
- # vocabulary difference equal to one (imagegptmodel?) or zero (all other models)
- vocab_diff = vocab_size - scores[0].shape[-1]
- self.assertTrue(vocab_diff in [0, 1])
- self.assertListEqual([vocab_size - score.shape[-1] for score in scores], [vocab_diff] * len(scores))
+ def _check_scores(self, batch_size, scores, length, config):
+ expected_shape = (batch_size, config.vocab_size)
+ self.assertIsInstance(scores, tuple)
+ self.assertEqual(len(scores), length)
+ self.assertListEqual([iter_scores.shape for iter_scores in scores], [expected_shape] * len(scores))
def _check_attentions_for_generate(
self, batch_size, attentions, min_length, max_length, config, use_cache=False, num_beam_groups=1
@@ -2329,30 +2318,6 @@ def _check_encoder_hidden_states_for_generate(self, hidden_states, batch_size, c
[encoder_expected_shape] * len(hidden_states),
)
- def _check_past_key_values_for_generate(self, batch_size, past_key_values, seq_length, config, num_beam_groups=1):
- self.assertIsInstance(past_key_values, tuple)
- self.assertListEqual(
- [isinstance(iter_past_key_values, tuple) for iter_past_key_values in past_key_values],
- [True] * len(past_key_values),
- )
-
- # (batch, head, seq_length, head_features)
- expected_shape = (
- batch_size * num_beam_groups,
- config.num_key_value_heads if hasattr(config, "num_key_value_heads") else config.num_attention_heads,
- seq_length,
- config.hidden_size // config.num_attention_heads,
- )
- # check shape key, value
- self.assertListEqual(
- [layer_past_key_values[0].shape for layer_past_key_values in past_key_values],
- [expected_shape] * len(past_key_values),
- )
- self.assertListEqual(
- [layer_past_key_values[1].shape for layer_past_key_values in past_key_values],
- [expected_shape] * len(past_key_values),
- )
-
def _check_sequence_inside_sequence(self, tensor_1, tensor_2):
# check if tensor_1 inside tensor_2 or tensor_2 inside tensor_1.
# set to same device. we don't care what device.
@@ -2377,45 +2342,6 @@ def _check_sequence_inside_sequence(self, tensor_1, tensor_2):
self.assertTrue(flag)
-@require_torch
-class UtilsFunctionsTest(unittest.TestCase):
- def test_speculative_sampling(self):
- # assume vocab size 10, input length 5 + 3 generated candidates
- candidate_input_ids = torch.tensor([[8, 0, 3, 9, 8, 1, 4, 5]]) # input tokens
- candidate_logits = torch.tensor(
- [
- [
- [-10.0, 10.0, -10.0, -10.0, -10.0, -10.0, -10.0, -10.0, -10.0, -10.0], # generated 1
- [-10.0, -10.0, -10.0, -10.0, 10.0, -10.0, -10.0, -10.0, -10.0, -10.0], # generated 4
- [-10.0, -10.0, -10.0, -10.0, -10.0, 10.0, -10.0, -10.0, -10.0, -10.0], # generated 5
- ]
- ]
- )
- candidate_length = 3
- inf = float("inf")
- new_logits = torch.tensor(
- [
- [
- [-10.0, 10.0, -10.0, -10.0, -10.0, -10.0, -10.0, -10.0, -10.0, -10.0], # accepts 1
- [-10.0, -10.0, -10.0, -10.0, 10.0, -10.0, -10.0, -10.0, -10.0, -10.0], # accepts 4
- [-inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, 10.0, -inf], # rejects 5, accepts 8
- [-10.0, -10.0, -10.0, -10.0, -10.0, -10.0, -10.0, -10.0, -10.0, -10.0], # N/A
- ]
- ]
- )
- last_assistant_token_is_eos = False
- validated_tokens, n_matches = _speculative_sampling(
- candidate_input_ids,
- candidate_logits,
- candidate_length,
- new_logits,
- last_assistant_token_is_eos,
- )
- self.assertTrue(n_matches.item() == 2)
- self.assertTrue(validated_tokens.tolist()[0] == [1, 4, 8])
-
-
-@pytest.mark.generate
@require_torch
class GenerationIntegrationTests(unittest.TestCase, GenerationIntegrationTestsMixin):
# setting framework_dependent_parameters needs to be gated, just like its contents' imports
@@ -2433,7 +2359,6 @@ class GenerationIntegrationTests(unittest.TestCase, GenerationIntegrationTestsMi
}
@slow
- @pytest.mark.skip("Group beam search is not supported by optimum-habana")
def test_diverse_beam_search(self):
# PT-only test: TF doesn't have a diverse beam search implementation
article = """Justin Timberlake and Jessica Biel, welcome to parenthood.
@@ -2468,101 +2393,284 @@ def test_diverse_beam_search(self):
],
)
- def test_max_length_if_input_embeds(self):
+ def test_max_length_backward_compat_greedy(self):
# PT-only test: TF doesn't have StoppingCriteria
- article = "Today a dragon flew over Paris."
- model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-gpt2").to(torch_device)
- tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-gpt2")
- input_ids = tokenizer(article, return_tensors="pt").input_ids.to(torch_device)
- inputs_embeds = model.get_input_embeddings()(input_ids)
-
- # Controlling max_length via the configuration is deprecated in favor of max_new_tokens
- max_new_tokens = 20
- input_len = input_ids.shape[-1]
- out_gen = model.generate(input_ids=input_ids, max_new_tokens=max_new_tokens)
- out_gen_embeds = model.generate(inputs_embeds=inputs_embeds, max_new_tokens=max_new_tokens)
- self.assertEqual(out_gen.shape[-1], input_len + out_gen_embeds.shape[-1])
+ article = """Justin Timberlake and Jessica Biel, welcome to parenthood."""
+ bart_tokenizer = BartTokenizer.from_pretrained("hf-internal-testing/tiny-random-bart")
+ bart_model = BartForConditionalGeneration.from_pretrained("hf-internal-testing/tiny-random-bart").to(
+ torch_device
+ )
+ input_ids = bart_tokenizer(article, return_tensors="pt").input_ids.to(torch_device)
- def test_min_length_if_input_embeds(self):
- # PT-only test: TF doesn't have StoppingCriteria
- article = "Today a dragon flew over Paris."
- model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-gpt2").to(torch_device)
- tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-gpt2")
- input_ids = tokenizer(article, return_tensors="pt").input_ids.to(torch_device)
- inputs_embeds = model.get_input_embeddings()(input_ids)
+ max_length = 20
+ input_ids = input_ids.expand(2, -1)
+ model_kwargs = bart_model._prepare_encoder_decoder_kwargs_for_generation(input_ids, {})
+ input_ids, model_kwargs = bart_model._prepare_decoder_input_ids_for_generation(
+ batch_size=input_ids.shape[0],
+ model_input_name=bart_model.main_input_name,
+ model_kwargs=model_kwargs,
+ decoder_start_token_id=bart_model.config.decoder_start_token_id,
+ bos_token_id=bart_model.config.bos_token_id,
+ )
- # Controlling max_length via the configuration is deprecated in favor of max_new_tokens
- min_length = 10
- input_len = input_ids.shape[-1]
- out_gen = model.generate(input_ids=input_ids, min_length=min_length, max_new_tokens=20)
- out_gen_embeds = model.generate(inputs_embeds=inputs_embeds, min_length=min_length, max_new_tokens=20)
- self.assertEqual(out_gen.shape[-1], input_len + out_gen_embeds.shape[-1])
+ with self.assertWarns(UserWarning):
+ bart_model.greedy_search(
+ input_ids,
+ max_length=max_length,
+ pad_token_id=bart_model.config.pad_token_id,
+ eos_token_id=bart_model.config.eos_token_id,
+ **model_kwargs,
+ )
- def test_custom_stopping_criteria_overload_error(self):
+ def test_max_length_backward_compat_sample(self):
# PT-only test: TF doesn't have StoppingCriteria
article = """Justin Timberlake and Jessica Biel, welcome to parenthood."""
- bart_tokenizer = BartTokenizer.from_pretrained("sshleifer/bart-tiny-random")
- bart_model = BartForConditionalGeneration.from_pretrained("sshleifer/bart-tiny-random").to(torch_device)
-
+ bart_tokenizer = BartTokenizer.from_pretrained("hf-internal-testing/tiny-random-bart")
+ bart_model = BartForConditionalGeneration.from_pretrained("hf-internal-testing/tiny-random-bart").to(
+ torch_device
+ )
input_ids = bart_tokenizer(article, return_tensors="pt").input_ids.to(torch_device)
- stopping_criteria = StoppingCriteriaList()
- stopping_criteria.append(MaxLengthCriteria(max_length=42))
- with self.assertRaises(ValueError):
- bart_model.generate(input_ids, stopping_criteria=stopping_criteria)
- with self.assertRaises(ValueError):
- bart_model.generate(input_ids, stopping_criteria=stopping_criteria, max_length=32)
- def test_custom_stopping_criteria(self):
+ max_length = 20
+ input_ids = input_ids.expand(2, -1)
+ model_kwargs = bart_model._prepare_encoder_decoder_kwargs_for_generation(input_ids, {})
+ input_ids, model_kwargs = bart_model._prepare_decoder_input_ids_for_generation(
+ batch_size=input_ids.shape[0],
+ model_input_name=bart_model.main_input_name,
+ model_kwargs=model_kwargs,
+ decoder_start_token_id=bart_model.config.decoder_start_token_id,
+ bos_token_id=bart_model.config.bos_token_id,
+ )
+ with torch.no_grad():
+ with self.assertWarns(UserWarning):
+ bart_model.sample(
+ input_ids,
+ max_length=max_length,
+ pad_token_id=bart_model.config.pad_token_id,
+ eos_token_id=bart_model.config.eos_token_id,
+ **model_kwargs,
+ )
+
+ def test_max_length_backward_compat_beam_search(self):
# PT-only test: TF doesn't have StoppingCriteria
article = """Justin Timberlake and Jessica Biel, welcome to parenthood."""
- bart_tokenizer = BartTokenizer.from_pretrained("sshleifer/bart-tiny-random")
- bart_model = BartForConditionalGeneration.from_pretrained("sshleifer/bart-tiny-random").to(torch_device)
+ bart_tokenizer = BartTokenizer.from_pretrained("hf-internal-testing/tiny-random-bart")
+ bart_model = BartForConditionalGeneration.from_pretrained("hf-internal-testing/tiny-random-bart").to(
+ torch_device
+ )
input_ids = bart_tokenizer(article, return_tensors="pt").input_ids.to(torch_device)
- class DummyCriteria(StoppingCriteria):
- def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
- return input_ids.shape[-1] >= 20
-
- stopping_criteria = StoppingCriteriaList()
- stopping_criteria.append(DummyCriteria())
+ batch_size = 1
+ max_length = 20
+ num_beams = 2
- output = bart_model.generate(input_ids, stopping_criteria=stopping_criteria, max_length=22)
- self.assertEqual(
- list(output.shape),
- [1, 22], # still produces the max_length
+ input_ids = input_ids.expand(2, -1)
+ model_kwargs = bart_model._prepare_encoder_decoder_kwargs_for_generation(input_ids, {})
+ input_ids, model_kwargs = bart_model._prepare_decoder_input_ids_for_generation(
+ batch_size=input_ids.shape[0],
+ model_input_name=bart_model.main_input_name,
+ model_kwargs=model_kwargs,
+ decoder_start_token_id=bart_model.config.decoder_start_token_id,
+ bos_token_id=bart_model.config.bos_token_id,
)
- # make sure final tokens are padding
- self.assertEqual(output[:, 20:].tolist(), [[bart_model.config.pad_token_id, bart_model.config.pad_token_id]])
- self.assertEqual(
- list(bart_model.generate(input_ids, stopping_criteria=stopping_criteria, max_length=18).shape),
- [1, 18],
+ beam_scorer = BeamSearchScorer(
+ batch_size=batch_size,
+ num_beams=num_beams,
+ device=torch_device,
)
+ with self.assertWarns(UserWarning):
+ _ = bart_model.beam_search(
+ input_ids, num_beams=num_beams, max_length=max_length, beam_scorer=beam_scorer, **model_kwargs
+ )
- # TODO (joao): replace `stop_sequence` in the pipeline by the more recent `generate` functionality
- # TODO [gustavo] Enable this test to Optimum-habana
- @pytest.mark.xfail
- def test_stop_sequence_stopping_criteria(self):
- # PT-only test: TF doesn't have StoppingCriteria
- prompt = """Hello I believe in"""
- generator = pipeline("text-generation", model="hf-internal-testing/tiny-random-bart")
- output = generator(prompt)
- self.assertEqual(
- output,
- [{"generated_text": ("Hello I believe in we we we we we we we we we")}],
+ def test_max_length_backward_compat_group_beam_search(self):
+ # PT-only test: TF doesn't have StoppingCriteria & group beam search
+ article = """Justin Timberlake and Jessica Biel, welcome to parenthood."""
+ bart_tokenizer = BartTokenizer.from_pretrained("hf-internal-testing/tiny-random-bart")
+ bart_model = BartForConditionalGeneration.from_pretrained("hf-internal-testing/tiny-random-bart").to(
+ torch_device
)
+ input_ids = bart_tokenizer(article, return_tensors="pt").input_ids.to(torch_device)
- output = generator(prompt, stop_sequence=" we")
- self.assertEqual(output, [{"generated_text": "Hello I believe in we"}])
+ batch_size = 1
+ max_length = 20
+ num_beams = 6
+ num_beam_groups = 3
+ num_return_sequences = num_beams * batch_size
- def test_generate_non_nlp_input_ids_as_kwarg(self):
- # PT-only test: AFAIK there's no non-NLP model architecture in TF that supports `input_ids` as its only input
- model = ImageGPTForCausalImageModeling.from_pretrained(
- "hf-internal-testing/tiny-random-imagegpt", max_length=10
- ).to(torch_device)
- input_ids = ids_tensor((3, 5), vocab_size=10)
+ input_ids = input_ids.expand(6, -1)
+ model_kwargs = bart_model._prepare_encoder_decoder_kwargs_for_generation(input_ids, {})
+ input_ids, model_kwargs = bart_model._prepare_decoder_input_ids_for_generation(
+ batch_size=input_ids.shape[0],
+ model_input_name=bart_model.main_input_name,
+ model_kwargs=model_kwargs,
+ decoder_start_token_id=bart_model.config.decoder_start_token_id,
+ bos_token_id=bart_model.config.bos_token_id,
+ )
- output_sequences_kwargs = model.generate(input_ids=input_ids).cpu()
+ diverse_beam_scorer = BeamSearchScorer(
+ batch_size=batch_size,
+ num_beams=num_beams,
+ device=torch_device,
+ num_beam_hyps_to_keep=num_return_sequences,
+ num_beam_groups=num_beam_groups,
+ )
+ with self.assertWarns(UserWarning):
+ bart_model.group_beam_search(
+ input_ids, diverse_beam_scorer, num_beams=num_beams, max_length=max_length, **model_kwargs
+ )
+
+ def test_max_length_warning_if_different(self):
+ # PT-only test: TF doesn't have StoppingCriteria
+ article = """Justin Timberlake and Jessica Biel, welcome to parenthood."""
+ bart_tokenizer = BartTokenizer.from_pretrained("hf-internal-testing/tiny-random-bart")
+ bart_model = BartForConditionalGeneration.from_pretrained("hf-internal-testing/tiny-random-bart").to(
+ torch_device
+ )
+ input_ids = bart_tokenizer(article, return_tensors="pt").input_ids.to(torch_device)
+
+ batch_size = 1
+
+ max_length = 20
+ num_beams = 6
+ num_beam_groups = 3
+ num_return_sequences = num_beams * batch_size
+ stopping_criteria_max_length = 18
+ stopping_criteria = StoppingCriteriaList([MaxLengthCriteria(max_length=stopping_criteria_max_length)])
+
+ # Greedy
+ input_ids = input_ids.expand(6, -1)
+ model_kwargs = bart_model._prepare_encoder_decoder_kwargs_for_generation(input_ids, {})
+ input_ids, model_kwargs = bart_model._prepare_decoder_input_ids_for_generation(
+ batch_size=input_ids.shape[0],
+ model_input_name=bart_model.main_input_name,
+ model_kwargs=model_kwargs,
+ decoder_start_token_id=bart_model.config.decoder_start_token_id,
+ bos_token_id=bart_model.config.bos_token_id,
+ )
+
+ with self.assertWarns(UserWarning):
+ bart_model.greedy_search(
+ input_ids,
+ max_length=max_length,
+ pad_token_id=bart_model.config.pad_token_id,
+ stopping_criteria=stopping_criteria,
+ eos_token_id=bart_model.config.eos_token_id,
+ **model_kwargs,
+ )
+
+ # Sample
+ with self.assertWarns(UserWarning):
+ with torch.no_grad():
+ bart_model.sample(
+ input_ids,
+ max_length=max_length,
+ stopping_criteria=stopping_criteria,
+ pad_token_id=bart_model.config.pad_token_id,
+ eos_token_id=bart_model.config.eos_token_id,
+ **model_kwargs,
+ )
+
+ # Beam
+ beam_scorer = BeamSearchScorer(
+ batch_size=batch_size,
+ num_beams=num_beams,
+ device=torch_device,
+ )
+ with self.assertWarns(UserWarning):
+ with torch.no_grad():
+ bart_model.beam_search(
+ input_ids,
+ num_beams=num_beams,
+ stopping_criteria=stopping_criteria,
+ max_length=max_length,
+ beam_scorer=beam_scorer,
+ **model_kwargs,
+ )
+
+ # Grouped beam search
+ diverse_beam_scorer = BeamSearchScorer(
+ batch_size=batch_size,
+ num_beams=num_beams,
+ device=torch_device,
+ num_beam_hyps_to_keep=num_return_sequences,
+ num_beam_groups=num_beam_groups,
+ )
+ with self.assertWarns(UserWarning):
+ bart_model.group_beam_search(
+ input_ids,
+ diverse_beam_scorer,
+ stopping_criteria=stopping_criteria,
+ num_beams=num_beams,
+ max_length=max_length,
+ **model_kwargs,
+ )
+
+ def test_custom_stopping_criteria_overload_error(self):
+ # PT-only test: TF doesn't have StoppingCriteria
+ article = """Justin Timberlake and Jessica Biel, welcome to parenthood."""
+ bart_tokenizer = BartTokenizer.from_pretrained("sshleifer/bart-tiny-random")
+ bart_model = BartForConditionalGeneration.from_pretrained("sshleifer/bart-tiny-random").to(torch_device)
+
+ input_ids = bart_tokenizer(article, return_tensors="pt").input_ids.to(torch_device)
+ stopping_criteria = StoppingCriteriaList()
+ stopping_criteria.append(MaxLengthCriteria(max_length=42))
+ with self.assertRaises(ValueError):
+ bart_model.generate(input_ids, stopping_criteria=stopping_criteria)
+ with self.assertRaises(ValueError):
+ bart_model.generate(input_ids, stopping_criteria=stopping_criteria, max_length=32)
+
+ def test_custom_stopping_criteria(self):
+ # PT-only test: TF doesn't have StoppingCriteria
+ article = """Justin Timberlake and Jessica Biel, welcome to parenthood."""
+ bart_tokenizer = BartTokenizer.from_pretrained("sshleifer/bart-tiny-random")
+ bart_model = BartForConditionalGeneration.from_pretrained("sshleifer/bart-tiny-random").to(torch_device)
+ input_ids = bart_tokenizer(article, return_tensors="pt").input_ids.to(torch_device)
+
+ class DummyCriteria(StoppingCriteria):
+ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
+ return input_ids.shape[-1] >= 20
+
+ stopping_criteria = StoppingCriteriaList()
+ stopping_criteria.append(DummyCriteria())
+
+ self.assertEqual(
+ list(bart_model.generate(input_ids, stopping_criteria=stopping_criteria, max_length=22).shape),
+ [1, 20],
+ )
+ self.assertEqual(
+ list(bart_model.generate(input_ids, stopping_criteria=stopping_criteria, max_length=18).shape),
+ [1, 18],
+ )
+
+ def test_stop_sequence_stopping_criteria(self):
+ # PT-only test: TF doesn't have StoppingCriteria
+ prompt = """Hello I believe in"""
+ generator = pipeline("text-generation", model="hf-internal-testing/tiny-random-bart")
+ output = generator(prompt)
+ self.assertEqual(
+ output,
+ [
+ {
+ "generated_text": (
+ "Hello I believe in in in number number number number number number number number number"
+ )
+ }
+ ],
+ )
+
+ output = generator(prompt, stop_sequence=" number")
+ self.assertEqual(output, [{"generated_text": "Hello I believe in in in number"}])
+
+ def test_generate_non_nlp_input_ids_as_kwarg(self):
+ # PT-only test: AFAIK there's no non-NLP model architecture in TF that supports `input_ids` as its only input
+ model = ImageGPTForCausalImageModeling.from_pretrained(
+ "hf-internal-testing/tiny-random-imagegpt", max_length=10
+ ).to(torch_device)
+ input_ids = ids_tensor((3, 5), vocab_size=10)
+
+ output_sequences_kwargs = model.generate(input_ids=input_ids).cpu()
output_sequences = model.generate(input_ids).cpu()
self.assertListEqual(output_sequences.tolist(), output_sequences_kwargs.tolist())
@@ -2579,7 +2687,6 @@ def test_generate_input_values_as_encoder_kwarg(self):
self.assertListEqual(output_sequences.tolist(), output_sequences_kwargs.tolist())
self.assertEqual(output_sequences.shape, (2, 5))
- @pytest.mark.skip("Group beam search is not supported by optimum-habana")
def test_transition_scores_group_beam_search_encoder_decoder(self):
# PT-only test: TF doesn't have group beam search
articles = [
@@ -2609,61 +2716,13 @@ def test_transition_scores_group_beam_search_encoder_decoder(self):
self.assertTrue(torch.allclose(transition_scores_sum, outputs.sequences_scores, atol=1e-3))
- # TODO [gustavo] Enable this test to Optimum-habana
- @pytest.mark.xfail
- def test_beam_search_low_memory(self):
- tokenizer = GPT2Tokenizer.from_pretrained("openai-community/gpt2")
- model = AutoModelForCausalLM.from_pretrained("openai-community/gpt2")
- tokenizer.pad_token_id = tokenizer.eos_token_id
- model_inputs = tokenizer("I", return_tensors="pt")["input_ids"]
-
- low_output = model.generate(model_inputs, max_new_tokens=40, num_beams=5, early_stopping=True, low_memory=True)
-
- high_output = model.generate(
- model_inputs, max_new_tokens=40, num_beams=5, early_stopping=True, low_memory=False
- )
- self.assertListEqual(low_output.tolist(), high_output.tolist())
-
- @slow
- @pytest.mark.skip("Watermarking is not supported by optimum-habana yet")
- def test_watermark_generation(self):
- tokenizer = GPT2Tokenizer.from_pretrained("openai-community/gpt2")
- model = AutoModelForCausalLM.from_pretrained("openai-community/gpt2").to(torch_device)
- tokenizer.pad_token_id = tokenizer.eos_token_id
- model_inputs = tokenizer("I will be", return_tensors="pt").to(torch_device)
- input_len = model_inputs["input_ids"].shape[-1]
-
- # generation should work with both input types: WatermarkingConfig or Dict, so let's check it here :)
- watermark_config = WatermarkingConfig(bias=2.5, seeding_scheme="selfhash")
- _ = model.generate(**model_inputs, watermarking_config=watermark_config, do_sample=False, max_length=15)
-
- # We will not check watermarked text, since we check it in `logits_processors` tests
- # Checking if generated ids are as expected fails on different hardware
- args = {
- "bias": 2.0,
- "context_width": 1,
- "seeding_scheme": "selfhash",
- "greenlist_ratio": 0.25,
- "hashing_key": 15485863,
- }
- output = model.generate(**model_inputs, do_sample=False, max_length=15)
- output_selfhash = model.generate(**model_inputs, watermarking_config=args, do_sample=False, max_length=15)
-
- # Check that the detector is detecting watermarked text
- detector = WatermarkDetector(model_config=model.config, device=torch_device, watermarking_config=args)
- detection_out_watermarked = detector(output_selfhash[:, input_len:], return_dict=True)
- detection_out = detector(output[:, input_len:], return_dict=True)
-
- self.assertListEqual(detection_out_watermarked.prediction.tolist(), [True])
- self.assertListEqual(detection_out.prediction.tolist(), [False])
-
@slow
def test_beam_search_example_integration(self):
# PT-only test: TF doesn't have a BeamSearchScorer
# exactly the example provided in the docstrings of beam search, which previously
# failed after directly copying from it. Refer to PR #15555
- tokenizer = AutoTokenizer.from_pretrained("google-t5/t5-base")
- model = AutoModelForSeq2SeqLM.from_pretrained("google-t5/t5-base")
+ tokenizer = AutoTokenizer.from_pretrained("t5-base")
+ model = AutoModelForSeq2SeqLM.from_pretrained("t5-base")
encoder_input_str = "translate English to German: How old are you?"
encoder_input_ids = tokenizer(encoder_input_str, return_tensors="pt").input_ids
@@ -2671,15 +2730,31 @@ def test_beam_search_example_integration(self):
# lets run beam search using 3 beams
num_beams = 3
# define decoder start token ids
- input_ids = torch.ones((1, 1), device=model.device, dtype=torch.long)
+ input_ids = torch.ones((num_beams, 1), device=model.device, dtype=torch.long)
input_ids = input_ids * model.config.decoder_start_token_id
# add encoder_outputs to model keyword arguments
- model_kwargs = {"encoder_outputs": model.get_encoder()(encoder_input_ids, return_dict=True)}
+ model_kwargs = {
+ "encoder_outputs": model.get_encoder()(
+ encoder_input_ids.repeat_interleave(num_beams, dim=0), return_dict=True
+ )
+ }
- outputs = model.generate(
- input_ids, num_beams=num_beams, min_length=5, eos_token_id=model.config.eos_token_id, **model_kwargs
+ # instantiate beam scorer
+ beam_scorer = BeamSearchScorer(
+ batch_size=1,
+ num_beams=num_beams,
+ device=model.device,
+ )
+
+ # instantiate logits processors
+ logits_processor = LogitsProcessorList(
+ [
+ MinLengthLogitsProcessor(5, eos_token_id=model.config.eos_token_id),
+ ]
)
+
+ outputs = model.beam_search(input_ids, beam_scorer, logits_processor=logits_processor, **model_kwargs)
outputs = tokenizer.batch_decode(outputs, skip_special_tokens=True)
self.assertListEqual(outputs, ["Wie alt bist du?"])
@@ -2687,8 +2762,8 @@ def test_beam_search_example_integration(self):
@slow
def test_constrained_beam_search(self):
# PT-only test: TF doesn't have constrained beam search
- model = GPT2LMHeadModel.from_pretrained("openai-community/gpt2").to(torch_device)
- tokenizer = GPT2Tokenizer.from_pretrained("openai-community/gpt2")
+ model = GPT2LMHeadModel.from_pretrained("gpt2").to(torch_device)
+ tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
force_tokens = tokenizer("scared", add_prefix_space=True, add_special_tokens=False).input_ids
force_tokens_2 = tokenizer("big weapons", add_prefix_space=True, add_special_tokens=False).input_ids
@@ -2725,8 +2800,8 @@ def test_constrained_beam_search(self):
@slow
def test_constrained_beam_search_mixed(self):
# PT-only test: TF doesn't have constrained beam search
- model = GPT2LMHeadModel.from_pretrained("openai-community/gpt2").to(torch_device)
- tokenizer = GPT2Tokenizer.from_pretrained("openai-community/gpt2")
+ model = GPT2LMHeadModel.from_pretrained("gpt2").to(torch_device)
+ tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
force_phrase = tokenizer("scared", add_prefix_space=True, add_special_tokens=False).input_ids
flexible_phrases = tokenizer(
@@ -2766,8 +2841,8 @@ def test_constrained_beam_search_mixed(self):
@slow
def test_constrained_beam_search_mixed_mixin(self):
# PT-only test: TF doesn't have constrained beam search
- model = GPT2LMHeadModel.from_pretrained("openai-community/gpt2").to(torch_device)
- tokenizer = GPT2Tokenizer.from_pretrained("openai-community/gpt2")
+ model = GPT2LMHeadModel.from_pretrained("gpt2").to(torch_device)
+ tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
force_word = "scared"
force_flexible = ["scream", "screams", "screaming", "screamed"]
@@ -2802,15 +2877,9 @@ def test_constrained_beam_search_mixed_mixin(self):
)
@slow
- @pytest.mark.xfail
def test_cfg_mixin(self):
- model = GPT2LMHeadModel.from_pretrained("openai-community/gpt2").to(torch_device)
- tokenizer = GPT2Tokenizer.from_pretrained("openai-community/gpt2")
-
- # add pad_token_id for static shape
- if tokenizer.pad_token is None:
- tokenizer.pad_token = tokenizer.eos_token
- model.generation_config.pad_token_id = model.generation_config.eos_token_id
+ model = GPT2LMHeadModel.from_pretrained("gpt2").to(torch_device)
+ tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
input = tokenizer(["The dragon flew over Paris,"], return_tensors="pt", return_attention_mask=True)
input["input_ids"] = input["input_ids"].to(torch_device)
@@ -2850,8 +2919,8 @@ def test_cfg_mixin(self):
@slow
def test_constrained_beam_search_example_translation_mixin(self):
# PT-only test: TF doesn't have constrained beam search
- tokenizer = AutoTokenizer.from_pretrained("google-t5/t5-base")
- model = AutoModelForSeq2SeqLM.from_pretrained("google-t5/t5-base")
+ tokenizer = AutoTokenizer.from_pretrained("t5-base")
+ model = AutoModelForSeq2SeqLM.from_pretrained("t5-base")
encoder_input_str = "translate English to German: How old are you?"
force_words = ["sind"]
@@ -2875,8 +2944,8 @@ def test_constrained_beam_search_example_translation_mixin(self):
@slow
def test_constrained_beam_search_example_integration(self):
# PT-only test: TF doesn't have constrained beam search
- tokenizer = AutoTokenizer.from_pretrained("google-t5/t5-base")
- model = AutoModelForSeq2SeqLM.from_pretrained("google-t5/t5-base")
+ tokenizer = AutoTokenizer.from_pretrained("t5-base")
+ model = AutoModelForSeq2SeqLM.from_pretrained("t5-base")
encoder_input_str = "translate English to German: How old are you?"
encoder_input_ids = tokenizer(encoder_input_str, return_tensors="pt").input_ids
@@ -2884,65 +2953,38 @@ def test_constrained_beam_search_example_integration(self):
# lets run beam search using 5 beams
num_beams = 5
# define decoder start token ids
- input_ids = torch.ones((1, 1), device=model.device, dtype=torch.long)
+ input_ids = torch.ones((num_beams, 1), device=model.device, dtype=torch.long)
input_ids = input_ids * model.config.decoder_start_token_id
# add encoder_outputs to model keyword arguments
- model_kwargs = {"encoder_outputs": model.get_encoder()(encoder_input_ids, return_dict=True)}
+ model_kwargs = {
+ "encoder_outputs": model.get_encoder()(
+ encoder_input_ids.repeat_interleave(num_beams, dim=0), return_dict=True
+ )
+ }
constraint_str = "sind"
constraint_token_ids = tokenizer.encode(constraint_str)[:-1] # remove eos token
+ constraints = [PhrasalConstraint(token_ids=constraint_token_ids)]
- outputs = model.generate(
- input_ids,
- num_beams=num_beams,
- force_words_ids=[constraint_token_ids],
- min_length=5,
- eos_token_id=model.config.eos_token_id,
- **model_kwargs,
+ # instantiate beam scorer
+ beam_scorer = ConstrainedBeamSearchScorer(
+ batch_size=1, num_beams=num_beams, device=model.device, constraints=constraints
)
- outputs = tokenizer.batch_decode(outputs, skip_special_tokens=True)
- self.assertListEqual(outputs, ["Wie alt sind Sie?"])
-
- # TODO [gustavo] Enable this test to Optimum-habana
- @pytest.mark.xfail
- @slow
- def test_per_row_stopping_criteria(self):
- text = [
- "They completed the challenging puzzle, revealing the hidden",
- "Today a dragon flew over France",
- "The aroma of freshly baked pizza filled the kitchen",
- ]
- stop_strings = ["secrets"]
+ # instantiate logits processors
+ logits_processor = LogitsProcessorList(
+ [
+ MinLengthLogitsProcessor(5, eos_token_id=model.config.eos_token_id),
+ ]
+ )
- model = AutoModelForCausalLM.from_pretrained("openai-community/gpt2").to(torch_device)
- tokenizer = AutoTokenizer.from_pretrained("openai-community/gpt2")
- tokenizer.padding_side = "left"
- tokenizer.pad_token_id = tokenizer.eos_token_id
- input_ids = tokenizer(text, return_tensors="pt", padding="longest", add_special_tokens=False).input_ids.to(
- torch_device
+ outputs = model.constrained_beam_search(
+ input_ids, beam_scorer, constraints=constraints, logits_processor=logits_processor, **model_kwargs
)
+ outputs = tokenizer.batch_decode(outputs, skip_special_tokens=True)
- # normal generation with one stopping criteria
- out = model.generate(input_ids, max_length=15)
- out_text = tokenizer.batch_decode(out)
- expected_out = [
- "They completed the challenging puzzle, revealing the hidden secrets of the world.\n",
- "<|endoftext|><|endoftext|><|endoftext|>Today a dragon flew over France and the French government was forced",
- "The aroma of freshly baked pizza filled the kitchen with a sense of freshness",
- ]
- self.assertListEqual(out_text, expected_out)
-
- # generation should stop at "secrets" for first batch only, filling the rest with eos tokens
- out = model.generate(input_ids, max_length=15, stop_strings=stop_strings, tokenizer=tokenizer)
- out_text = tokenizer.batch_decode(out)
- expected_out = [
- "They completed the challenging puzzle, revealing the hidden secrets<|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|>",
- "<|endoftext|><|endoftext|><|endoftext|>Today a dragon flew over France and the French government was forced",
- "The aroma of freshly baked pizza filled the kitchen with a sense of freshness",
- ]
- self.assertListEqual(out_text, expected_out)
+ self.assertListEqual(outputs, ["Wie alt sind Sie?"])
def test_constrained_beam_search_mixin_type_checks(self):
# PT-only test: TF doesn't have constrained beam search
@@ -2985,55 +3027,6 @@ def test_constrained_beam_search_mixin_type_checks(self):
with self.assertRaises(ValueError):
model.generate(input_ids, force_words_ids=[[[-1]]])
- def test_batched_decoder_start_id(self):
- # PT-only test: TF doesn't support batched_decoder_start_id
- articles = [
- "Justin Timberlake and Jessica Biel, welcome to parenthood.",
- "Michael Phelps is arguably the most decorated Olympian of all time.",
- ]
- bart_tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-bart")
- bart_model = BartForConditionalGeneration.from_pretrained("hf-internal-testing/tiny-random-bart").to(
- torch_device
- )
- input_ids = bart_tokenizer(articles, return_tensors="pt", padding=True).input_ids.to(torch_device)
- decoder_start_token_id = bart_model.generation_config.decoder_start_token_id
- decoder_start_token_id_batch = [decoder_start_token_id] * input_ids.shape[0]
-
- outputs = bart_model.generate(input_ids, decoder_start_token_id=decoder_start_token_id)
-
- outputs_batched_ids = bart_model.generate(input_ids, decoder_start_token_id=decoder_start_token_id_batch)
-
- self.assertListEqual(outputs.tolist(), outputs_batched_ids.tolist())
-
- def test_decoder_start_id_from_config(self):
- # Refer to: (#30899)
- articles = [
- "Justin Timberlake and Jessica Biel, welcome to parenthood.",
- "Michael Phelps is arguably the most decorated Olympian of all time.",
- ]
- bart_tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-bart")
- bart_model = BartForConditionalGeneration.from_pretrained("hf-internal-testing/tiny-random-bart").to(
- torch_device
- )
- input_ids = bart_tokenizer(articles, return_tensors="pt", padding=True).input_ids.to(torch_device)
- decoder_start_token_id = bart_model.generation_config.decoder_start_token_id
-
- # we should be able to take `decoder_start_token_id` from model's generation config if user passes a `GenerationConfig` type
- outputs = bart_model.generate(input_ids, generation_config=GenerationConfig(do_sample=False))
-
- # If the generatoin config has no `decoder_start_token_id` or `bos_token_id`, we will raise an error unless user passes it in config
- bart_model.generation_config.decoder_start_token_id = None
- bart_model.generation_config.bos_token_id = None
- outputs_with_user_id = bart_model.generate(
- input_ids,
- generation_config=GenerationConfig(do_sample=False, decoder_start_token_id=decoder_start_token_id),
- )
-
- self.assertListEqual(outputs.tolist(), outputs_with_user_id.tolist())
-
- with self.assertRaises(ValueError):
- outputs = bart_model.generate(input_ids, generation_config=GenerationConfig(do_sample=False))
-
def test_contrastive_search_batched(self):
# PT-only test: TF doesn't have constrained beam search
# Tests that contrastive search works with batched inputs (i.e. has the same output as for non-batched inputs)
@@ -3060,27 +3053,6 @@ def test_contrastive_search_batched(self):
max_score_diff = (output_sequences_batched.scores[0][1] - output_sequences.scores[0][0]).abs().max()
self.assertTrue(max_score_diff < 1e-5)
- def test_logits_processor_not_inplace(self):
- # PT-only test: TF fixes were not made
- article = "Today a dragon flew over Paris."
- model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-gpt2").to(torch_device)
- tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-gpt2")
- input_ids = tokenizer(article, return_tensors="pt").input_ids.to(torch_device)
-
- out = model.generate(input_ids, output_logits=True, output_scores=True, return_dict_in_generate=True)
- out_with_temp = model.generate(
- input_ids,
- temperature=0.5,
- do_sample=True,
- output_logits=True,
- output_scores=True,
- return_dict_in_generate=True,
- )
-
- # if no logits processor is used, scores == logits. Otherwise, the processor has to modify the scores
- self.assertListEqual(out.logits[-1].tolist(), out.scores[-1].tolist())
- self.assertNotEqual(out_with_temp.logits[-1].tolist(), out_with_temp.scores[-1].tolist())
-
def test_eos_token_id_int_and_list_top_k_top_sampling(self):
# Has TF equivalent: this test relies on random sampling
generation_kwargs = {
@@ -3135,10 +3107,6 @@ def forward(self, input_ids, foo=None, **kwargs):
# because it doesn't do signature filtering.
class FakeEncoder(bart_model.model.encoder.__class__):
def forward(self, input_ids, **kwargs):
- # We remove these to pass gaudi_BartEncoder_forward TypeError
- kwargs.pop("bucket_size", None)
- kwargs.pop("bucket_internal", None)
- kwargs.pop("reduce_recompile", None)
return super().forward(input_ids, **kwargs)
fake_encoder = FakeEncoder(bart_model.config, bart_model.model.shared).to(torch_device)
@@ -3153,16 +3121,15 @@ def forward(self, input_ids, **kwargs):
def test_default_max_length_warning(self):
model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-gpt2").to(torch_device)
tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-gpt2")
- model.generation_config.pad_token_id = tokenizer.eos_token_id
+ model.config.pad_token_id = tokenizer.eos_token_id
text = "Hello world"
tokenized_inputs = tokenizer([text], return_tensors="pt")
input_ids = tokenized_inputs.input_ids.to(torch_device)
# Default generation config value of 20 -> emits warning
- # NOTE: in OH we do not have this warning
- # with self.assertWarns(UserWarning):
- # model.generate(input_ids)
+ with self.assertWarns(UserWarning):
+ model.generate(input_ids)
# Explicitly setting max_length to 20 -> no warning
with warnings.catch_warnings(record=True) as warning_list:
@@ -3171,805 +3138,7 @@ def test_default_max_length_warning(self):
# Generation config max_length != 20 -> no warning
with warnings.catch_warnings(record=True) as warning_list:
- # generation_config is modified -> legacy mode is disabled = generation_config takes precedence
model.generation_config.max_length = 10
+ model.generation_config._from_model_config = False # otherwise model.config.max_length=20 takes precedence
model.generate(input_ids)
self.assertEqual(len(warning_list), 0)
-
- # TODO [gustavo] Enable this test to Optimum-habana
- @pytest.mark.xfail
- def test_length_warning_assisted_generation(self):
- # PT-only test: TF doesn't support assisted decoding yet.
- model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-gpt2").to(torch_device)
- assistant = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-gpt2").to(torch_device)
- tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-gpt2")
- model.generation_config.pad_token_id = tokenizer.eos_token_id
- assistant.generation_config.pad_token_id = tokenizer.eos_token_id
-
- text = "Hello world"
- tokenized_inputs = tokenizer([text], return_tensors="pt")
- input_ids = tokenized_inputs.input_ids.to(torch_device)
-
- # This should not raise any warning that min length is not feasible in candidate generation
- with warnings.catch_warnings(record=True) as warning_list:
- model.generate(
- input_ids,
- assistant_model=assistant,
- min_new_tokens=10,
- max_length=20,
- )
- self.assertEqual(len(warning_list), 0)
-
- def test_default_assisted_generation(self):
- # Initialize the GenerationConfig object
- config = GenerationConfig()
-
- # Check the default values
- self.assertEqual(config.num_assistant_tokens, 20)
- self.assertEqual(config.num_assistant_tokens_schedule, "constant")
- self.assertEqual(config.assistant_confidence_threshold, 0.4)
- self.assertEqual(config.is_assistant, False)
-
- # TODO [gustavo] Enable this test to Optimum-habana
- @pytest.mark.xfail
- def test_generated_length_assisted_generation(self):
- # PT-only test: TF doesn't support assisted decoding yet.
- model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-gpt2").to(torch_device)
- assistant = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-gpt2").to(torch_device)
- tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-gpt2")
- model.generation_config.pad_token_id = tokenizer.eos_token_id
- assistant.generation_config.pad_token_id = tokenizer.eos_token_id
-
- text = "Hello world"
- tokenized_inputs = tokenizer([text], return_tensors="pt")
- input_ids = tokenized_inputs.input_ids.to(torch_device)
- input_length = input_ids.shape[-1]
-
- out = model.generate(
- input_ids,
- assistant_model=assistant,
- min_new_tokens=10,
- max_new_tokens=20,
- )
- self.assertTrue((10 + input_length) <= out.shape[-1] <= (20 + input_length))
-
- out = model.generate(
- input_ids,
- assistant_model=assistant,
- min_new_tokens=10,
- )
- self.assertTrue((input_length + 10) <= out.shape[-1] <= 20)
-
- # TODO [gustavo] Enable this test to Optimum-habana
- @pytest.mark.xfail
- def test_model_kwarg_assisted_decoding_decoder_only(self):
- # PT-only test: TF doesn't support assisted decoding yet.
- model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-gpt2").to(torch_device)
- tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-gpt2")
- model.generation_config.pad_token_id = tokenizer.eos_token_id
-
- text = "Hello world"
- tokenized_inputs = tokenizer([text], return_tensors="pt")
- input_ids = tokenized_inputs.input_ids.to(torch_device)
-
- # Traditional way of generating text
- outputs_normal = model.generate(input_ids)
- self.assertEqual(outputs_normal.shape, (1, 20))
-
- # Should be different with token_type_ids
- outputs_tti = model.generate(
- input_ids,
- token_type_ids=torch.zeros(input_ids.shape, dtype=torch.long).to(torch_device),
- )
- with self.assertRaises(AssertionError):
- self.assertListEqual(outputs_tti.tolist(), outputs_normal.tolist())
-
- # Assistant model
- assistant = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-gpt2").to(torch_device)
- assistant.config.pad_token_id = tokenizer.eos_token_id
-
- # If assisted generation passes model_kwargs correctly, should be same as previous
- outputs_assisted = model.generate(
- input_ids,
- token_type_ids=torch.zeros(input_ids.shape, dtype=torch.long).to(torch_device),
- assistant_model=assistant,
- )
- self.assertListEqual(outputs_assisted.tolist(), outputs_tti.tolist())
-
- # TODO [gustavo] Enable this test to Optimum-habana
- @pytest.mark.xfail
- def test_model_kwarg_assisted_decoding_encoder_decoder(self):
- """
- Tests that the following scenario is compatible with assisted generation:
- 1. encoder-decoder main model
- 2. encoder-decoder assistant model
- 3. both have a custom input
- (e.g. Whisper)
- """
-
- # PT-only test: TF doesn't support assisted decoding yet.
- # Bart subclass with a kwarg that distorts the output
- class FakeBart(BartForConditionalGeneration):
- def forward(self, input_ids, past_key_values, foo=False, **kwargs):
- outs = super().forward(input_ids, past_key_values=past_key_values, **kwargs)
- if foo:
- outs["logits"][:, :, :] = 0.0
- return outs
-
- def prepare_inputs_for_generation(self, *args, foo=False, encoder_outputs=None, **kwargs):
- kwargs["encoder_outputs"] = encoder_outputs
- inputs = super().prepare_inputs_for_generation(*args, **kwargs)
- inputs["foo"] = foo
- return inputs
-
- model = FakeBart.from_pretrained("hf-internal-testing/tiny-random-BartForConditionalGeneration").to(
- torch_device
- )
- tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-BartForConditionalGeneration")
-
- text = "Hello world"
- tokenized_inputs = tokenizer([text], return_tensors="pt")
- input_ids = tokenized_inputs.input_ids.to(torch_device)
-
- # Traditional way of generating text
- outputs_normal = model.generate(input_ids)
- self.assertEqual(outputs_normal.shape, (1, 20))
-
- # Should be different with foo
- outputs_foo = model.generate(input_ids, foo=True)
- with self.assertRaises(AssertionError):
- self.assertListEqual(outputs_foo.tolist(), outputs_normal.tolist())
-
- # Assistant model
- assistant = FakeBart.from_pretrained("hf-internal-testing/tiny-random-BartForConditionalGeneration").to(
- torch_device
- )
-
- # If assisted generation passes model_kwargs correctly, should be same as previous
- outputs_assisted = model.generate(
- input_ids,
- foo=True,
- assistant_model=assistant,
- )
- self.assertListEqual(outputs_assisted.tolist(), outputs_foo.tolist())
-
- # Check that passing encoder_outputs directly also works as expected
- encoder_outputs = assistant.get_encoder()(input_ids)
-
- outputs_assisted = model.generate(
- foo=True,
- assistant_model=assistant,
- encoder_outputs=encoder_outputs,
- assistant_encoder_outputs=encoder_outputs,
- )
- self.assertListEqual(outputs_assisted.tolist(), outputs_foo.tolist())
-
- # TODO [gustavo] Enable this test to Optimum-habana
- @pytest.mark.xfail
- def test_assisted_decoding_encoder_decoder_shared_encoder(self):
- """
- Tests that the following scenario is compatible with assisted generation:
- 1. encoder-decoder main model
- 2. decoder-only assistant model
- 3. both have a custom input
- (e.g. DistilWhisper)
- """
-
- # PT-only test: TF doesn't support assisted decoding yet.
- # Bart subclass with a kwarg called foo that distorts the output
- class FakeBartSeq2Seq(BartForConditionalGeneration):
- def forward(self, input_ids, foo=False, **kwargs):
- outs = super().forward(input_ids, **kwargs)
- if foo:
- outs["logits"][:, :, :] = 0.0
- return outs
-
- def prepare_inputs_for_generation(self, *args, foo=False, encoder_outputs=None, **kwargs):
- kwargs["encoder_outputs"] = encoder_outputs
- inputs = super().prepare_inputs_for_generation(*args, **kwargs)
- inputs["foo"] = foo
- return inputs
-
- class FakeBartCausalLM(BartForCausalLM):
- def forward(self, input_ids, attention_mask, past_key_values, foo=False, **kwargs):
- outs = super().forward(input_ids, attention_mask, past_key_values=past_key_values, **kwargs)
- if foo:
- outs["logits"][:, :, :] = 0.0
- return outs
-
- def prepare_inputs_for_generation(self, *args, foo=False, encoder_outputs=None, **kwargs):
- kwargs["encoder_outputs"] = encoder_outputs
- inputs = super().prepare_inputs_for_generation(*args, **kwargs)
- inputs["foo"] = foo
- return inputs
-
- model = FakeBartSeq2Seq.from_pretrained("hf-internal-testing/tiny-random-BartForConditionalGeneration").to(
- torch_device
- )
- tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-BartForConditionalGeneration")
-
- text = "Hello world"
- tokenized_inputs = tokenizer([text], return_tensors="pt")
- input_ids = tokenized_inputs.input_ids.to(torch_device)
-
- # Traditional way of generating text
- outputs_normal = model.generate(input_ids)
- self.assertEqual(outputs_normal.shape, (1, 20))
-
- # Should be different with foo
- outputs_foo = model.generate(input_ids, foo=True)
- with self.assertRaises(AssertionError):
- self.assertListEqual(outputs_foo.tolist(), outputs_normal.tolist())
-
- # Assistant model
- assistant = FakeBartCausalLM.from_pretrained(
- "hf-internal-testing/tiny-random-BartForConditionalGeneration"
- ).to(torch_device)
-
- # If assisted generation passes model_kwargs correctly, should be same as previous
- outputs_assisted = model.generate(
- input_ids,
- foo=True,
- assistant_model=assistant,
- )
- self.assertListEqual(outputs_assisted.tolist(), outputs_foo.tolist())
-
- # Check that passing encoder_outputs directly also works as expected
- encoder_outputs = model.get_encoder()(input_ids)
-
- outputs_assisted = model.generate(
- foo=True,
- assistant_model=assistant,
- encoder_outputs=encoder_outputs,
- )
- self.assertListEqual(outputs_assisted.tolist(), outputs_foo.tolist())
-
- # TODO [gustavo] Enable this test to Optimum-habana
- @pytest.mark.xfail
- def test_assisted_decoding_num_assistant_tokens_heuristic_schedule(self):
- # This test ensures that the assisted generation num_assistant_tokens 'heuristic' schedule works properly.
-
- prompt = "Alice and Bob"
- checkpoint = "EleutherAI/pythia-160m-deduped"
- tokenizer = AutoTokenizer.from_pretrained(checkpoint)
- inputs = tokenizer(prompt, return_tensors="pt")
-
- model = AutoModelForCausalLM.from_pretrained(checkpoint)
-
- assistant_model = model
- assistant_model.generation_config.num_assistant_tokens = 5
- assistant_model.generation_config.num_assistant_tokens_schedule = "heuristic"
- generation_kwargs = {
- "eos_token_id": -1,
- "max_new_tokens": 5,
- "do_sample": False,
- "assistant_model": assistant_model,
- }
- model.generate(**inputs, **generation_kwargs)
- # update_candidate_strategy is called only once and therefore, assistant_model.generation_config.num_assistant_tokens should be either 4 or 7
- self.assertTrue(assistant_model.generation_config.num_assistant_tokens in (4, 7))
-
- # TODO [gustavo] Enable this test to Optimum-habana
- @pytest.mark.xfail
- def test_assisted_decoding_num_assistant_tokens_heuristic_transient_schedule(self):
- # This test ensures that the assisted generation num_assistant_tokens 'heuristic' schedule works properly.
-
- prompt = "Alice and Bob"
- checkpoint = "EleutherAI/pythia-160m-deduped"
- tokenizer = AutoTokenizer.from_pretrained(checkpoint)
- inputs = tokenizer(prompt, return_tensors="pt")
-
- model = AutoModelForCausalLM.from_pretrained(checkpoint)
-
- assistant_model = model
- assistant_model.generation_config.num_assistant_tokens = 5
- assistant_model.generation_config.num_assistant_tokens_schedule = "heuristic_transient"
- generation_kwargs = {
- "eos_token_id": -1,
- "max_new_tokens": 5,
- "do_sample": False,
- "assistant_model": assistant_model,
- }
- model.generate(**inputs, **generation_kwargs)
- # update_candidate_strategy is called once but assistant_model.generation_config.num_assistant_tokens should stay 5
- self.assertEqual(assistant_model.generation_config.num_assistant_tokens, 5)
-
- # TODO [gustavo] Enable this test to Optimum-habana
- @slow
- @pytest.mark.xfail
- def test_validate_assistant(self):
- # Generate a random sample:
- inputs = np.random.rand(160000)
-
- # Load a main encoder-decoder model:
- model_id = "openai/whisper-large-v2"
- processor = AutoProcessor.from_pretrained(model_id)
- model = AutoModelForSpeechSeq2Seq.from_pretrained(
- model_id,
- low_cpu_mem_usage=True,
- use_safetensors=True,
- )
- model.to(torch_device)
-
- # process the input:
- features = processor(inputs, return_tensors="pt").to(torch_device)
-
- # Load an encoder-decoder assistant with same encoder as the main model:
- assistant_distil_model_id = "distil-whisper/distil-large-v2"
- assistant_seq_to_seq = AutoModelForSpeechSeq2Seq.from_pretrained(
- assistant_distil_model_id,
- use_safetensors=True,
- ).to(torch_device)
- self.assertTrue(model.generate(**features, assistant_model=assistant_seq_to_seq).sum())
-
- # Load its decoder only version:
- assistant_causal_lm = AutoModelForCausalLM.from_pretrained(
- assistant_distil_model_id,
- low_cpu_mem_usage=True,
- use_safetensors=True,
- ).to(torch_device)
- self.assertTrue(model.generate(**features, assistant_model=assistant_causal_lm).sum())
-
- # Load an encoder-decoder assistant with a different encoder than the main model:
- assistant_distil_model_id = "openai/whisper-tiny"
- assistant_seq_to_seq = AutoModelForSpeechSeq2Seq.from_pretrained(
- assistant_distil_model_id,
- use_safetensors=True,
- ).to(torch_device)
- self.assertTrue(model.generate(**features, assistant_model=assistant_seq_to_seq).sum())
-
- # Load its decoder only version:
- assistant_causal_lm = AutoModelForCausalLM.from_pretrained(
- assistant_distil_model_id,
- low_cpu_mem_usage=True,
- use_safetensors=True,
- ).to(torch_device)
- # It will raise an error as the encoder of the main and assistant model are not compatible:
- with self.assertRaises(ValueError):
- model.generate(**features, assistant_model=assistant_causal_lm)
-
- # Load an encoder-decoder model with a different tokenizer than the main model:
- assistant_distil_model_id = "hf-internal-testing/tiny-random-SeamlessM4Tv2ForSpeechToText"
- assistant_seq_to_seq = AutoModelForSpeechSeq2Seq.from_pretrained(
- assistant_distil_model_id,
- ).to(torch_device)
- # This should raise an error as the main and assistant model don't use the same tokenizer:
- with self.assertRaises(ValueError):
- model.generate(**features, assistant_model=assistant_seq_to_seq)
-
- def test_compare_unprocessed_logit_scores(self):
- # Get unprocessed logit scores back from model generate function.
- # Assert that unprocessed logits from generate() are same as those from modal eval()
-
- # tell model to generate text and return unprocessed/unwarped logit scores
- tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-gpt2")
- text = "generate yes or no: "
- input_ids = tokenizer([text], return_tensors="pt").input_ids.to(torch_device)
-
- model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-gpt2").to(torch_device)
-
- with torch.no_grad():
- # Get logits for the next token from fwd pass
- logits_fwd = model(input_ids).logits[:, -1, :][0]
-
- # Get logits for the next token from generate function
- outputs = model.generate(
- input_ids=input_ids,
- return_dict_in_generate=True,
- output_logits=True,
- max_new_tokens=1,
- do_sample=True,
- )
- logits_gen = outputs.logits[0][0]
-
- # assert that unprocessed logits from generate() are same as those from modal eval()
- self.assertListEqual(logits_fwd.tolist(), logits_gen.tolist())
-
- def test_return_unprocessed_logit_scores(self):
- # tell model to generate text and return unprocessed/unwarped logit scores
- tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-gpt2")
- text = "generate yes or no: "
- input_ids = tokenizer([text], return_tensors="pt").input_ids.to(torch_device)
- model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-gpt2").to(torch_device)
-
- outputs = model.generate(
- input_ids=input_ids, return_dict_in_generate=True, output_logits=True, max_new_tokens=3
- )
-
- # perform dummy check if unpreprocessed logits make sense.
- # do preselection on high probabilities; find scores of y and n tokens
- probs_all = torch.nn.functional.softmax(outputs.logits[2][0], dim=-1)
- indices = torch.argwhere(probs_all > 0.001)
- indices = indices[:, -1]
- tokens_max = tokenizer.batch_decode(indices, skip_special_tokens=True)
- probs_max = probs_all[probs_all > 0.001]
-
- self.assertTrue(len(indices) >= 2)
- next_token_dict = {str(t): p for t, p in zip(tokens_max, probs_max)}
- self.assertTrue("n" in next_token_dict)
- self.assertTrue("y" in next_token_dict)
- y_prob = next_token_dict["y"]
- n_prob = next_token_dict["n"]
-
- self.assertTrue(y_prob > 0.001 and n_prob > 0.001)
- self.assertTrue(y_prob <= 1.0 and n_prob <= 1.0)
-
- @slow
- @require_torch_multi_gpu
- def test_assisted_decoding_in_different_gpu(self):
- # PT-only test: TF doesn't support assisted decoding yet.
- model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-MistralForCausalLM").to("cuda:0")
- assistant = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-MistralForCausalLM").to(
- "cuda:1"
- )
- tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-MistralForCausalLM")
- model.config.pad_token_id = tokenizer.eos_token_id
- assistant.config.pad_token_id = tokenizer.eos_token_id
-
- text = "Hello world"
- tokenized_inputs = tokenizer([text], return_tensors="pt")
- input_ids = tokenized_inputs.input_ids.to(torch_device)
- input_length = input_ids.shape[-1]
-
- out = model.generate(
- input_ids,
- assistant_model=assistant,
- max_new_tokens=20,
- )
- self.assertTrue(input_length <= out.shape[-1] <= input_length + 20)
-
- @slow
- @require_torch_gpu
- def test_assisted_decoding_model_in_gpu_assistant_in_cpu(self):
- # PT-only test: TF doesn't support assisted decoding yet.
- model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-MistralForCausalLM").to("cuda")
- assistant = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-MistralForCausalLM").to(
- "cpu"
- )
- tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-MistralForCausalLM")
- model.config.pad_token_id = tokenizer.eos_token_id
- assistant.config.pad_token_id = tokenizer.eos_token_id
-
- text = "Hello world"
- tokenized_inputs = tokenizer([text], return_tensors="pt")
- input_ids = tokenized_inputs.input_ids.to(torch_device)
- input_length = input_ids.shape[-1]
-
- out = model.generate(
- input_ids,
- assistant_model=assistant,
- max_new_tokens=20,
- )
- self.assertTrue(input_length <= out.shape[-1] <= input_length + 20)
-
- def test_special_tokens_fall_back_to_model_default(self):
- # PT-only test: TF doesn't support assisted decoding yet.
- model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-MistralForCausalLM").to(
- torch_device
- )
- test_bos_id = 50
-
- # Sanity-check: the model has a BOS token set, and the first generated token is a BOS token
- gen_output = model.generate()
- self.assertTrue(model.generation_config.bos_token_id is not None)
- self.assertTrue(model.generation_config.bos_token_id == gen_output[0, 0])
-
- # If we pass a generation config **with** a BOS token, `generate` will use it
- generation_config = GenerationConfig(bos_token_id=test_bos_id)
- gen_output = model.generate(generation_config=generation_config)
- self.assertFalse(model.generation_config.bos_token_id == gen_output[0, 0])
- self.assertTrue(generation_config.bos_token_id == gen_output[0, 0])
- self.assertTrue(test_bos_id == gen_output[0, 0])
-
- # If we pass a generation config **without** a BOS token, `generate` will fetch the BOS token from
- # `model.generation_config`
- generation_config = GenerationConfig(bos_token_id=None)
- gen_output = model.generate(generation_config=generation_config)
- self.assertTrue(model.generation_config.bos_token_id == gen_output[0, 0])
- self.assertFalse(test_bos_id == gen_output[0, 0])
- self.assertTrue(generation_config.bos_token_id is None)
-
- # Changing `model.generation_config` will affect fallback behavior
- model.generation_config.bos_token_id = test_bos_id
- gen_output = model.generate(generation_config=generation_config)
- self.assertTrue(model.generation_config.bos_token_id == gen_output[0, 0])
- self.assertTrue(test_bos_id == gen_output[0, 0])
- self.assertTrue(generation_config.bos_token_id is None)
-
- @pytest.mark.generate
- @require_torch_multi_gpu
- def test_generate_with_static_cache_multi_gpu(self):
- """
- Tests if the static cache has been set correctly and if generate works correctly when we are using multi-gpus.
- """
- # need to split manually as auto doesn't work well with unbalanced model
- device_map = {"model.embed_tokens": 0, "model.layers.0": 0, "model.layers.1": 1, "model.norm": 1, "lm_head": 0}
- model = AutoModelForCausalLM.from_pretrained(
- "hf-internal-testing/tiny-random-MistralForCausalLM", device_map=device_map
- )
- tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-MistralForCausalLM")
-
- text = "Hello world"
- tokenized_inputs = tokenizer([text], return_tensors="pt")
- input_ids = tokenized_inputs.input_ids.to(torch_device)
-
- generation_kwargs = {
- "max_new_tokens": 20,
- "cache_implementation": "static",
- "return_dict_in_generate": True, # Required to return `past_key_values`
- }
-
- results = model.generate(input_ids, **generation_kwargs)
- self.assertTrue(isinstance(results.past_key_values, StaticCache))
-
- # check device of each layer
- key_cache_0 = results.past_key_values.key_cache[0]
- value_cache_0 = results.past_key_values.value_cache[0]
- self.assertTrue(key_cache_0.device == value_cache_0.device == torch.device(0))
-
- key_cache_1 = results.past_key_values.key_cache[1]
- value_cache_1 = results.past_key_values.value_cache[1]
- self.assertTrue(key_cache_1.device == value_cache_1.device == torch.device(1))
-
- @pytest.mark.generate
- @require_torch_multi_gpu
- def test_init_static_cache_multi_gpu(self):
- """
- Tests if the static cache has been set correctly when we initialize it manually in a multi-gpu setup.
- """
- # need to split manually as auto doesn't work well with unbalanced model
- device_map = {"model.embed_tokens": 0, "model.layers.0": 0, "model.layers.1": 1, "model.norm": 1, "lm_head": 0}
- model = AutoModelForCausalLM.from_pretrained(
- "hf-internal-testing/tiny-random-MistralForCausalLM", device_map=device_map
- )
- tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-MistralForCausalLM")
-
- text = "Hello world"
- tokenized_inputs = tokenizer([text], return_tensors="pt")
- input_ids = tokenized_inputs.input_ids.to(torch_device)
-
- generation_kwargs = {
- "max_new_tokens": 20,
- "return_dict_in_generate": True, # Required to return `past_key_values`
- }
-
- # TODO: We need to raise a warning in case the cache is not set correctly
- # with self.assertRaisesRegex(ValueError, "If you are manually initializing the cache"):
- # past_key_values = StaticCache(
- # config=model.config, batch_size=1, max_cache_len=30, device=torch_device, dtype=model.dtype
- # )
- # results = model.generate(input_ids, past_key_values=past_key_values, **generation_kwargs)
-
- # deduced from the device_map : layer 0 on device 0 and layer 1 on device 1
- layer_device_map = {0: 0, 1: 1}
- past_key_values = StaticCache(
- config=model.config,
- batch_size=1,
- max_cache_len=30,
- device=torch_device,
- dtype=model.dtype,
- layer_device_map=layer_device_map,
- )
- results = model.generate(input_ids, past_key_values=past_key_values, **generation_kwargs)
-
- # check device of each layer
- key_cache_0 = results.past_key_values.key_cache[0]
- value_cache_0 = results.past_key_values.value_cache[0]
- self.assertTrue(key_cache_0.device == value_cache_0.device == torch.device(0))
-
- key_cache_1 = results.past_key_values.key_cache[1]
- value_cache_1 = results.past_key_values.value_cache[1]
- self.assertTrue(key_cache_1.device == value_cache_1.device == torch.device(1))
-
- @slow
- # TODO [gustavo] Enable this test to Optimum-habana
- @pytest.mark.xfail
- def test_padding_input_contrastive_search_gpt2(self):
- # Load the pre-trained GPT-2 model and tokenizer
- model = GPT2LMHeadModel.from_pretrained("openai-community/gpt2")
- model.to(torch_device)
- tokenizer = AutoTokenizer.from_pretrained("openai-community/gpt2", clean_up_tokenization_spaces=True)
-
- # Set the tokenizer to left-pad the sequences
- tokenizer.padding_side = "left"
-
- # Define the PAD token as the EOS token
- tokenizer.pad_token = tokenizer.eos_token
- model.generation_config.pad_token_id = model.generation_config.eos_token_id
-
- # Define the input prompt
- prompt_text = "The whispered legends of the haunted mansion spoke"
-
- # Tokenize the input prompt
- encoded_prompt = tokenizer(prompt_text, return_tensors="pt", padding=True)
- input_ids = encoded_prompt.input_ids.to(torch_device)
- attention_mask = encoded_prompt.attention_mask.to(torch_device)
-
- # Define the contrastive search params
- penalty_alpha = 0.6
- top_k = 4
-
- # Define the padding length to add to the input IDs and attention mask
- padding_length = 10
-
- # Generate text without padding
- outputs = model.generate(
- input_ids=input_ids,
- attention_mask=attention_mask,
- do_sample=False,
- penalty_alpha=penalty_alpha,
- top_k=top_k,
- max_new_tokens=64,
- )
- generated_text_no_padding = tokenizer.decode(outputs[0], skip_special_tokens=True)
-
- # Pad the input IDs and attention mask on the left
- padded_input_ids = F.pad(
- input_ids, (padding_length, 0), "constant", value=model.generation_config.pad_token_id
- )
- padded_attention_mask = F.pad(attention_mask, (padding_length, 0), "constant", value=0)
-
- # Generate text with padded inputs
- outputs_with_padding = model.generate(
- input_ids=padded_input_ids,
- attention_mask=padded_attention_mask,
- do_sample=False,
- penalty_alpha=penalty_alpha,
- top_k=top_k,
- max_new_tokens=64,
- )
- generated_text_with_padding = tokenizer.decode(outputs_with_padding[0], skip_special_tokens=True)
-
- # Assert that the generated texts are identical for padded and non-padded inputs
- self.assertEqual(generated_text_no_padding, generated_text_with_padding)
- self.assertEqual(
- generated_text_with_padding,
- 'The whispered legends of the haunted mansion spoke of the "souls of the dead" who were "falling '
- 'out of the sky" and "falling into the sea."\n\nThe ghostly apparitions were said to have been '
- 'created by the spirits of the dead, who were "falling out of the sky" and "falling into the sea',
- )
-
- @slow
- # TODO [gustavo] Enable this test to Optimum-habana
- @pytest.mark.xfail
- def test_padding_input_contrastive_search_t5(self):
- # Load the pre-trained T5 model and tokenizer
- model = T5ForConditionalGeneration.from_pretrained("google-t5/t5-small")
- model.to(torch_device)
- tokenizer = AutoTokenizer.from_pretrained("google-t5/t5-small", clean_up_tokenization_spaces=True)
-
- # Define the input prompt
- prompt_text = "translate English to German: I need to finish this task before the end of the day."
-
- # Tokenize the input prompt
- encoded_prompt = tokenizer(prompt_text, return_tensors="pt")
- input_ids = encoded_prompt.input_ids.to(torch_device)
- attention_mask = encoded_prompt.attention_mask.to(torch_device)
-
- # Define the decoder prompt
- decoder_prompt_text = "Ich muss diese Aufgabe"
- encoded_decoder_prompt = tokenizer(decoder_prompt_text, add_special_tokens=False, return_tensors="pt")
- decoder_input_ids = encoded_decoder_prompt.input_ids.to(torch_device)
- decoder_attention_mask = encoded_decoder_prompt.attention_mask.to(torch_device)
-
- # Define the contrastive search params
- penalty_alpha = 0.6
- top_k = 4
-
- # Generate text without padding
- outputs = model.generate(
- input_ids=input_ids,
- attention_mask=attention_mask,
- decoder_input_ids=decoder_input_ids,
- decoder_attention_mask=decoder_attention_mask,
- do_sample=False,
- penalty_alpha=penalty_alpha,
- top_k=top_k,
- max_new_tokens=64,
- )
- generated_text_no_padding = tokenizer.decode(outputs[0], skip_special_tokens=True)
-
- # Define the padding length to add to the input IDs and attention mask
- padding_length = 10
-
- # Pad the decoder input IDs and attention mask on the left
- padded_decoder_input_ids = F.pad(
- decoder_input_ids, (padding_length, 0), "constant", value=model.generation_config.pad_token_id
- )
- padded_decoder_attention_mask = F.pad(decoder_attention_mask, (padding_length, 0), "constant", value=0)
- # Since the decoder_start_token_id is the same as the pad_token_id,
- # the last padded token represents the decoder start token.
- # Set the attention mask for the decoder_start_token_id to True (1).
- padded_decoder_attention_mask[:, padding_length - 1] = 1
- # Generate text with padded inputs
- outputs_with_padding = model.generate(
- input_ids=input_ids,
- attention_mask=attention_mask,
- decoder_input_ids=padded_decoder_input_ids,
- decoder_attention_mask=padded_decoder_attention_mask,
- do_sample=False,
- penalty_alpha=penalty_alpha,
- top_k=top_k,
- max_new_tokens=64,
- )
- generated_text_with_padding = tokenizer.decode(outputs_with_padding[0], skip_special_tokens=True)
-
- # Assert that the generated texts are identical for padded and non-padded inputs
- self.assertEqual(generated_text_no_padding, generated_text_with_padding)
- self.assertEqual(generated_text_no_padding, "Ich muss diese Aufgabe vor Ende des Tages beenden.")
-
- # TODO [gustavo] Enable this test to Optimum-habana
- @pytest.mark.xfail
- def test_generate_compile_fullgraph_tiny(self):
- """
- Tests that we can call end-to-end generation with a tiny model (i.e. doesn't crash)
- NOTE: this test is quite slow (~20s on a consumer desktop), but it is important that we keep it as part of the
- non-slow tests to prevent regressions!
- """
- model = AutoModelForCausalLM.from_pretrained(
- "hf-internal-testing/tiny-random-LlamaForCausalLM", torch_dtype=torch.bfloat16, device_map="auto"
- )
- tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-LlamaForCausalLM")
-
- # compile generate
- compiled_generate = torch.compile(model.generate, fullgraph=True, mode="reduce-overhead")
-
- # compiled generate does NOT accept parameterization except a) model inputs b) a generation config
- generation_config = copy.deepcopy(model.generation_config)
- generation_config.pad_token_id = model.config.eos_token_id
-
- model_inputs = tokenizer(["Write a poem about the market crashing in summer"], return_tensors="pt")
- model_inputs = model_inputs.to(model.device)
- gen_out = compiled_generate(**model_inputs, generation_config=generation_config)
- self.assertTrue(gen_out.shape[1] > model_inputs["input_ids"].shape[1]) # some text was generated
-
-
-@require_torch
-class TokenHealingTestCase(unittest.TestCase):
- @parameterized.expand(
- [
- (
- "square_bracket",
- 'An example ["like this"] and another example [',
- 'An example ["like this"] and another example ["',
- ),
- ("url", 'The link is