diff --git a/conftest.py b/conftest.py
index 71cb6bb7ca..5775644c48 100644
--- a/conftest.py
+++ b/conftest.py
@@ -1,3 +1,88 @@
+# coding=utf-8
+# Copyright 2020 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# tests directory-specific settings - this file is run automatically
+# by pytest before any tests are run
+import doctest
+import sys
+import warnings
+from os.path import abspath, dirname, join
+
+import _pytest
+import pytest
+from transformers.testing_utils import HfDoctestModule, HfDocTestParser
+
+
+NOT_DEVICE_TESTS = {
+ "test_tokenization",
+ "test_processor",
+ "test_processing",
+ "test_beam_constraints",
+ "test_configuration_utils",
+ "test_data_collator",
+ "test_trainer_callback",
+ "test_trainer_utils",
+ "test_feature_extraction",
+ "test_image_processing",
+ "test_image_processor",
+ "test_image_transforms",
+ "test_optimization",
+ "test_retrieval",
+ "test_config",
+ "test_from_pretrained_no_checkpoint",
+ "test_keep_in_fp32_modules",
+ "test_gradient_checkpointing_backward_compatibility",
+ "test_gradient_checkpointing_enable_disable",
+ "test_save_load_fast_init_from_base",
+ "test_fast_init_context_manager",
+ "test_fast_init_tied_embeddings",
+ "test_save_load_fast_init_to_base",
+ "test_torch_save_load",
+ "test_initialization",
+ "test_forward_signature",
+ "test_model_get_set_embeddings",
+ "test_model_main_input_name",
+ "test_correct_missing_keys",
+ "test_tie_model_weights",
+ "test_can_use_safetensors",
+ "test_load_save_without_tied_weights",
+ "test_tied_weights_keys",
+ "test_model_weights_reload_no_missing_tied_weights",
+ "test_pt_tf_model_equivalence",
+ "test_mismatched_shapes_have_properly_initialized_weights",
+ "test_matched_shapes_have_loaded_weights_when_some_mismatched_shapes_exist",
+ "test_model_is_small",
+ "test_tf_from_pt_safetensors",
+ "test_flax_from_pt_safetensors",
+ "ModelTest::test_pipeline_", # None of the pipeline tests from PipelineTesterMixin (of which XxxModelTest inherits from) are running on device
+ "ModelTester::test_pipeline_",
+ "/repo_utils/",
+ "/utils/",
+ "/agents/",
+}
+
+# allow having multiple repository checkouts and not needing to remember to rerun
+# `pip install -e '.[dev]'` when switching between checkouts and running tests.
+git_repo_path = abspath(join(dirname(__file__), "src"))
+sys.path.insert(1, git_repo_path)
+
+# silence FutureWarning warnings in tests since often we can't act on them until
+# they become normal warnings - i.e. the tests still need to test the current functionality
+warnings.simplefilter(action="ignore", category=FutureWarning)
+
+
class Secret:
"""
Taken from: https://stackoverflow.com/a/67393351
@@ -13,9 +98,47 @@ def __str___(self):
return "*******"
+def pytest_configure(config):
+ config.addinivalue_line(
+ "markers", "is_pt_tf_cross_test: mark test to run only when PT and TF interactions are tested"
+ )
+ config.addinivalue_line(
+ "markers", "is_pt_flax_cross_test: mark test to run only when PT and FLAX interactions are tested"
+ )
+ config.addinivalue_line("markers", "is_pipeline_test: mark test to run only when pipelines are tested")
+ config.addinivalue_line("markers", "is_staging_test: mark test to run only in the staging environment")
+ config.addinivalue_line("markers", "accelerate_tests: mark test that require accelerate")
+ config.addinivalue_line("markers", "agent_tests: mark the agent tests that are run on their specific schedule")
+ config.addinivalue_line("markers", "not_device_test: mark the tests always running on cpu")
+
+
+def pytest_collection_modifyitems(items):
+ for item in items:
+ if any(test_name in item.nodeid for test_name in NOT_DEVICE_TESTS):
+ item.add_marker(pytest.mark.not_device_test)
+
+
def pytest_addoption(parser):
parser.addoption("--token", action="store", default=None)
+ from transformers.testing_utils import pytest_addoption_shared
+
+ pytest_addoption_shared(parser)
+
+
+def pytest_terminal_summary(terminalreporter):
+ from transformers.testing_utils import pytest_terminal_summary_main
+
+ make_reports = terminalreporter.config.getoption("--make-reports")
+ if make_reports:
+ pytest_terminal_summary_main(terminalreporter, id=make_reports)
+
+
+def pytest_sessionfinish(session, exitstatus):
+ # If no tests are collected, pytest exists with code 5, which makes the CI fail.
+ if exitstatus == 5:
+ session.exitstatus = 0
+
def pytest_generate_tests(metafunc):
# This is called for every test. Only get/set command line arguments
@@ -23,3 +146,21 @@ def pytest_generate_tests(metafunc):
option_value = Secret(metafunc.config.option.token)
if "token" in metafunc.fixturenames:
metafunc.parametrize("token", [option_value])
+
+
+# Doctest custom flag to ignore output.
+IGNORE_RESULT = doctest.register_optionflag("IGNORE_RESULT")
+
+OutputChecker = doctest.OutputChecker
+
+
+class CustomOutputChecker(OutputChecker):
+ def check_output(self, want, got, optionflags):
+ if IGNORE_RESULT & optionflags:
+ return True
+ return OutputChecker.check_output(self, want, got, optionflags)
+
+
+doctest.OutputChecker = CustomOutputChecker
+_pytest.doctest.DoctestModule = HfDoctestModule
+doctest.DocTestParser = HfDocTestParser
diff --git a/optimum/habana/transformers/generation/utils.py b/optimum/habana/transformers/generation/utils.py
index 68b445c1b2..d81e0d179a 100644
--- a/optimum/habana/transformers/generation/utils.py
+++ b/optimum/habana/transformers/generation/utils.py
@@ -211,19 +211,20 @@ def _prepare_decoder_input_ids_for_generation(
# 2. `decoder_start_token_id` must have shape (batch_size, 1)
if device is None:
device = self.device
- if token_idx is None:
- if decoder_start_token_id.ndim == 1:
- if decoder_start_token_id.shape[0] != batch_size:
- raise ValueError(
- f"`decoder_start_token_id` expected to have length {batch_size} but got {decoder_start_token_id.shape[0]}"
- )
- decoder_start_token_id = decoder_start_token_id.view(-1, 1)
- else:
- decoder_start_token_id = (
- torch.ones((batch_size, 1), dtype=torch.long, device=device) * decoder_start_token_id
+ if decoder_start_token_id.ndim == 1:
+ if decoder_start_token_id.shape[0] != batch_size:
+ raise ValueError(
+ f"`decoder_start_token_id` expected to have length {batch_size} but got {decoder_start_token_id.shape[0]}"
)
+ decoder_start_token_id = decoder_start_token_id.view(-1, 1)
else:
- # creating padded decoder_input_ids to achieve static shapes. Later new tokens once generated are copied in to decoder_input_ids based on token_idx
+ decoder_start_token_id = (
+ torch.ones((batch_size, 1), dtype=torch.long, device=device) * decoder_start_token_id
+ )
+
+ if token_idx is not None:
+ # creating padded decoder_input_ids to achieve static shapes.
+ # Later new tokens once generated are copied in to decoder_input_ids based on token_idx
max_length = max_new_tokens + 1 if max_new_tokens is not None else self.generation_config.max_length
decoder_start_token_id = (
torch.ones((batch_size, 1), dtype=torch.long, device=device) * decoder_start_token_id
@@ -3039,7 +3040,8 @@ def expand_if_needed(tensor, new_size, value, dim=-1):
if self.generation_config.early_stopping:
num_eos_tokens.add_(beam_tokens[0:num_beams].eq(self.config.eos_token_id).sum())
- beam_scores.add_(torch.where(beam_tokens.eq(self.config.eos_token_id), float("-inf"), 0.0))
+ if self.config.eos_token_id is not None:
+ beam_scores.add_(torch.where(beam_tokens.eq(self.config.eos_token_id), float("-inf"), 0.0))
beam_scores = beam_scores.view(batch_size, -1).unsqueeze(0)
_, selected = torch.topk(beam_scores, k=num_beams, dim=-1, largest=True, sorted=True)
offset = torch.arange(0, torch.numel(beam_scores), beam_scores.shape[-1]).unsqueeze(-1)
@@ -3211,6 +3213,9 @@ def move(obj, device):
if not output_scores:
sequence_outputs["sequence_scores"] = None
+ if self.generation_config.static_shapes:
+ raise NotImplementedError("sequence_scores is not implemented for static_shapes")
+
if self.config.is_encoder_decoder:
return GenerateBeamEncoderDecoderOutput(
sequences=sequence_outputs["sequences"],
diff --git a/optimum/habana/transformers/models/bart/modeling_bart.py b/optimum/habana/transformers/models/bart/modeling_bart.py
index 3e5f822cb1..08ea48e1a5 100644
--- a/optimum/habana/transformers/models/bart/modeling_bart.py
+++ b/optimum/habana/transformers/models/bart/modeling_bart.py
@@ -458,7 +458,9 @@ def gaudi_BartDecoder_forward(
# past_key_values_length
past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0
- tensor_past_key_values_length = token_idx - 1 if use_cache else torch.tensor(past_key_values_length)
+ tensor_past_key_values_length = (
+ token_idx - 1 if (use_cache and token_idx is not None) else torch.tensor(past_key_values_length)
+ )
if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input)
diff --git a/pyproject.toml b/pyproject.toml
index b7896da5e8..f53b25d1c0 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -41,3 +41,12 @@ skip-magic-trailing-comma = false
# Like Black, automatically detect the appropriate line ending.
line-ending = "auto"
+
+[tool.pytest.ini_options]
+addopts = "--doctest-glob='**/*.md'"
+doctest_optionflags="NUMBER NORMALIZE_WHITESPACE ELLIPSIS"
+markers = [
+ "flash_attn_test: marks tests related to flash attention (deselect with '-m \"not flash_attn_test\"')",
+ "bitsandbytes: select (or deselect with `not`) bitsandbytes integration tests",
+ "generate: marks tests that use the GenerationTesterMixin"
+]
diff --git a/tests/transformers/tests/generation/test_framework_agnostic.py b/tests/transformers/tests/generation/test_framework_agnostic.py
index 7fcc4de752..906a90a95a 100644
--- a/tests/transformers/tests/generation/test_framework_agnostic.py
+++ b/tests/transformers/tests/generation/test_framework_agnostic.py
@@ -3,8 +3,12 @@
"""
import numpy as np
+import pytest
from transformers import AutoTokenizer
-from transformers.testing_utils import slow, torch_device
+from transformers.testing_utils import slow
+
+
+torch_device = "hpu"
class GenerationIntegrationTestsMixin:
@@ -46,6 +50,8 @@ def test_validate_generation_inputs(self):
valid_model_kwargs = {"attention_mask": create_tensor_fn(np.zeros_like(input_ids))}
model.generate(input_ids, **valid_model_kwargs)
+ # TODO [gustavo] Enable this test to Optimum-habana
+ @pytest.mark.xfail
def test_custom_logits_processor(self):
model_cls = self.framework_dependent_parameters["AutoModelForSeq2SeqLM"]
logits_processor_list_cls = self.framework_dependent_parameters["LogitsProcessorList"]
@@ -66,6 +72,8 @@ def test_custom_logits_processor(self):
bart_model.config.min_length = None
bart_model.generate(input_ids, logits_processor=logits_processor)
+ # TODO [gustavo] Enable this test to Optimum-habana
+ @pytest.mark.xfail
def test_max_new_tokens_encoder_decoder(self):
model_cls = self.framework_dependent_parameters["AutoModelForSeq2SeqLM"]
return_tensors = self.framework_dependent_parameters["return_tensors"]
@@ -222,6 +230,8 @@ def test_transition_scores_greedy_search_normalized(self):
)
self.assertTrue(np.allclose(transition_scores, expected_scores, atol=1e-3))
+ # TODO [gustavo] Enable this test to Optimum-habana
+ @pytest.mark.xfail
def test_transition_scores_beam_search_encoder_decoder(self):
model_cls = self.framework_dependent_parameters["AutoModelForSeq2SeqLM"]
return_tensors = self.framework_dependent_parameters["return_tensors"]
@@ -257,6 +267,8 @@ def test_transition_scores_beam_search_encoder_decoder(self):
self.assertTrue(np.allclose(np.sum(transition_scores, axis=-1), outputs.sequences_scores, atol=1e-3))
+ # TODO [gustavo] Enable this test to Optimum-habana
+ @pytest.mark.xfail
def test_transition_scores_beam_search_encoder_decoder_with_eos(self):
model_cls = self.framework_dependent_parameters["AutoModelForSeq2SeqLM"]
return_tensors = self.framework_dependent_parameters["return_tensors"]
@@ -291,6 +303,8 @@ def test_transition_scores_beam_search_encoder_decoder_with_eos(self):
self.assertTrue(np.allclose(np.sum(transition_scores, axis=-1), outputs.sequences_scores, atol=1e-3))
+ # TODO [gustavo] Enable this test to Optimum-habana
+ @pytest.mark.xfail
def test_transition_scores_beam_search_decoder_only(self):
model_cls = self.framework_dependent_parameters["AutoModelForCausalLM"]
return_tensors = self.framework_dependent_parameters["return_tensors"]
@@ -328,6 +342,8 @@ def test_transition_scores_beam_search_decoder_only(self):
self.assertTrue(np.allclose(np.sum(transition_scores, axis=-1), outputs.sequences_scores, atol=1e-3))
+ # TODO [gustavo] Enable this test to Optimum-habana
+ @pytest.mark.xfail
def test_transition_scores_beam_sample_encoder_decoder(self):
model_cls = self.framework_dependent_parameters["AutoModelForSeq2SeqLM"]
return_tensors = self.framework_dependent_parameters["return_tensors"]
@@ -365,6 +381,7 @@ def test_transition_scores_beam_sample_encoder_decoder(self):
self.assertTrue(np.allclose(np.sum(transition_scores, axis=-1), outputs.sequences_scores, atol=1e-3))
@slow
+ @pytest.mark.skip("Not Implemented: sequence_scores is not implemented for static_shapes")
def test_transition_scores_early_stopping(self):
# This is an aggressive test that makes sure that `beam_search's`
# transition scores are computed correctly for varying `num_return_sequences`, `num_beams` and `batch_size > 1`
@@ -400,6 +417,8 @@ def test_transition_scores_early_stopping(self):
self.assertTrue(np.allclose(np.sum(transition_scores, axis=-1), outputs.sequences_scores))
+ # TODO [gustavo] Enable this test to Optimum-habana
+ @pytest.mark.xfail
def test_encoder_decoder_generate_attention_mask(self):
model_cls = self.framework_dependent_parameters["AutoModelForSeq2SeqLM"]
return_tensors = self.framework_dependent_parameters["return_tensors"]
@@ -501,6 +520,8 @@ def test_generate_too_many_encoder_kwargs(self):
with self.assertRaises(ValueError):
model.generate(input_ids=input_ids, inputs_embeds=input_ids)
+ # TODO [gustavo] Enable this test to Optimum-habana
+ @pytest.mark.xfail
def test_generate_input_features_as_encoder_kwarg(self):
model_cls = self.framework_dependent_parameters["AutoModelForSpeechSeq2Seq"]
floats_tensor = self.framework_dependent_parameters["floats_tensor"]
@@ -542,6 +563,8 @@ def test_generate_pixel_values_as_encoder_kwarg(self):
self.assertTrue(np.array_equal(output_sequences, output_sequences_kwargs))
self.assertEqual(output_sequences.shape, (2, 5))
+ # TODO [gustavo] Enable this test to Optimum-habana
+ @pytest.mark.xfail
def test_generate_encoder_outputs_attention_mask(self):
model_cls = self.framework_dependent_parameters["AutoModelForSpeechSeq2Seq"]
floats_tensor = self.framework_dependent_parameters["floats_tensor"]
@@ -576,7 +599,6 @@ def test_eos_token_id_int_and_list_greedy_search(self):
"do_sample": False,
"num_beams": 1,
}
- expectation = 13
tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-gpt2")
text = """Hello, my dog is cute and"""
@@ -586,6 +608,7 @@ def test_eos_token_id_int_and_list_greedy_search(self):
model = model.to(torch_device)
tokens = tokens.to(torch_device)
+ expectation = model.config.max_length # static shape should give max_length
eos_token_id = 873
generated_tokens = model.generate(**tokens, eos_token_id=eos_token_id, **generation_kwargs)
self.assertTrue(expectation == len(generated_tokens[0]))
@@ -605,7 +628,6 @@ def test_eos_token_id_int_and_list_contrastive_search(self):
"penalty_alpha": 0.6,
"top_k": 4,
}
- expectation = 17
tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-gpt2")
text = """Hello, my dog is cute and"""
@@ -615,6 +637,7 @@ def test_eos_token_id_int_and_list_contrastive_search(self):
model = model.to(torch_device)
tokens = tokens.to(torch_device)
+ expectation = model.config.max_length # static shape should give max_length
eos_token_id = 225
generated_tokens = model.generate(**tokens, eos_token_id=eos_token_id, **generation_kwargs)
self.assertTrue(expectation == len(generated_tokens[0]))
@@ -623,6 +646,8 @@ def test_eos_token_id_int_and_list_contrastive_search(self):
generated_tokens = model.generate(**tokens, eos_token_id=eos_token_id, **generation_kwargs)
self.assertTrue(expectation == len(generated_tokens[0]))
+ # TODO [gustavo] Enable this test to Optimum-habana
+ @pytest.mark.xfail
def test_eos_token_id_int_and_list_beam_search(self):
model_cls = self.framework_dependent_parameters["AutoModelForCausalLM"]
return_tensors = self.framework_dependent_parameters["return_tensors"]
@@ -648,7 +673,10 @@ def test_eos_token_id_int_and_list_beam_search(self):
padded_correct_condition = expectation < len(generated_tokens[0]) and all(
token == model.config.pad_token_id for token in generated_tokens[0][expectation:]
)
- self.assertTrue(unpadded_correct_condition or padded_correct_condition)
+ static_shape_condition = expectation < len(generated_tokens[0]) and all(
+ token == eos_token_id for token in generated_tokens[0][expectation:]
+ )
+ self.assertTrue(unpadded_correct_condition or padded_correct_condition or static_shape_condition)
eos_token_id = [873, 198]
generated_tokens = model.generate(**tokens, eos_token_id=eos_token_id, **generation_kwargs)
@@ -656,8 +684,13 @@ def test_eos_token_id_int_and_list_beam_search(self):
padded_correct_condition = expectation < len(generated_tokens[0]) and all(
token == model.config.pad_token_id for token in generated_tokens[0][expectation:]
)
- self.assertTrue(unpadded_correct_condition or padded_correct_condition)
+ static_shape_condition = expectation < len(generated_tokens[0]) and all(
+ token in eos_token_id for token in generated_tokens[0][expectation:]
+ )
+ self.assertTrue(unpadded_correct_condition or padded_correct_condition or static_shape_condition)
+ # TODO [gustavo] Enable this test to Optimum-habana
+ @pytest.mark.xfail
def test_generate_vision2text_conditioning(self):
model_cls = self.framework_dependent_parameters["AutoModelForVision2Seq"]
floats_tensor = self.framework_dependent_parameters["floats_tensor"]
diff --git a/tests/transformers/tests/generation/test_utils.py b/tests/transformers/tests/generation/test_utils.py
index 512935e9dd..954bcd14d5 100644
--- a/tests/transformers/tests/generation/test_utils.py
+++ b/tests/transformers/tests/generation/test_utils.py
@@ -14,14 +14,27 @@
# limitations under the License.
+import copy
import inspect
+import tempfile
import unittest
import warnings
import numpy as np
import pytest
-from transformers import is_torch_available, pipeline
-from transformers.testing_utils import require_torch, slow
+from parameterized import parameterized
+from transformers import is_torch_available, pipeline, set_seed
+from transformers.testing_utils import (
+ is_flaky,
+ require_accelerate,
+ require_auto_gptq,
+ require_quanto,
+ require_torch,
+ require_torch_gpu,
+ require_torch_multi_accelerator,
+ require_torch_multi_gpu,
+ slow,
+)
from optimum.habana.checkpoint_utils import model_is_optimized
from optimum.habana.transformers.modeling_utils import adapt_transformers_to_gaudi
@@ -32,54 +45,50 @@
if is_torch_available():
import torch
+ import torch.nn.functional as F
from transformers import (
AutoModelForCausalLM,
AutoModelForSeq2SeqLM,
AutoModelForSpeechSeq2Seq,
AutoModelForVision2Seq,
+ AutoProcessor,
AutoTokenizer,
+ BartForCausalLM,
BartForConditionalGeneration,
BartTokenizer,
GPT2LMHeadModel,
GPT2Tokenizer,
ImageGPTForCausalImageModeling,
- PreTrainedModel,
SpeechEncoderDecoderModel,
+ T5ForConditionalGeneration,
)
+ from transformers.cache_utils import DynamicCache, EncoderDecoderCache, QuantoQuantizedCache, StaticCache
from transformers.generation import (
BeamSampleDecoderOnlyOutput,
BeamSampleEncoderDecoderOutput,
BeamSearchDecoderOnlyOutput,
BeamSearchEncoderDecoderOutput,
- BeamSearchScorer,
- ConstrainedBeamSearchScorer,
DisjunctiveConstraint,
- ForcedBOSTokenLogitsProcessor,
- ForcedEOSTokenLogitsProcessor,
GenerateBeamDecoderOnlyOutput,
GenerateBeamEncoderDecoderOutput,
GenerateDecoderOnlyOutput,
GenerateEncoderDecoderOutput,
+ GenerationConfig,
GreedySearchDecoderOnlyOutput,
GreedySearchEncoderDecoderOutput,
- HammingDiversityLogitsProcessor,
LogitsProcessorList,
MaxLengthCriteria,
MinLengthLogitsProcessor,
- NoBadWordsLogitsProcessor,
- NoRepeatNGramLogitsProcessor,
PhrasalConstraint,
- RepetitionPenaltyLogitsProcessor,
+ PromptLookupCandidateGenerator,
SampleDecoderOnlyOutput,
SampleEncoderDecoderOutput,
StoppingCriteria,
StoppingCriteriaList,
- TemperatureLogitsWarper,
- TopKLogitsWarper,
- TopPLogitsWarper,
+ WatermarkDetector,
+ WatermarkingConfig,
)
- from transformers.generation.candidate_generator import AssistedCandidateGenerator, CandidateGenerator
- from transformers.generation.streamers import BaseStreamer
+ from transformers.generation.utils import _speculative_sampling
torch_device = "hpu"
adapt_transformers_to_gaudi()
@@ -91,116 +100,84 @@ class GenerationTesterMixin:
input_name = "input_ids"
max_new_tokens = 3
- def _update_default_model_kwargs(self, model_kwargs):
- model_kwargs["limit_hpu_graphs"] = False
- model_kwargs["reuse_cache"] = False
- model_kwargs["bucket_size"] = -1
-
def _get_input_ids_and_config(self, batch_size=2):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
- input_ids = inputs_dict[self.input_name]
+ # TODO: @raushan or @gante, use `model.main_input_name` as the main input instead of relyinn on `input_ids`
+ input_ids = inputs_dict.pop(self.input_name)[:batch_size, :]
+ inputs_dict.pop("attention_mask", None)
+
+ # we don't want encoder-decoder models to start from filled decoder ids
+ inputs_dict.pop("decoder_input_ids", None)
+ inputs_dict.pop("decoder_attention_mask", None)
# cut to half length & take max batch_size 3
sequence_length = input_ids.shape[-1] // 2
input_ids = input_ids[:batch_size, :sequence_length]
- # generate max 3 tokens
- max_length = input_ids.shape[-1] + 3
+ # we'll set cache use in each test differently
+ inputs_dict.pop("use_cache", None)
+
+ inputs_dict = {
+ k: v[:batch_size, ...]
+ for k, v in inputs_dict.items()
+ if "head_mask" not in k and isinstance(v, torch.Tensor)
+ }
if config.eos_token_id is not None and config.pad_token_id is None:
# hack to allow generate for models such as GPT2 as is done in `generate()`
if isinstance(config.eos_token_id, int):
config.eos_token_id = [config.eos_token_id]
config.pad_token_id = config.eos_token_id[0]
- # TransfoXL has no attention mask
- if "transfoxl" in config.__class__.__name__.lower():
- attention_mask = None
+
+ if self.has_attentions:
+ attention_mask = torch.ones_like(input_ids, dtype=torch.long)
else:
- attention_mask = torch.ones_like(input_ids, dtype=torch.long)[:batch_size, :sequence_length]
-
- return config, input_ids, attention_mask, max_length
-
- @staticmethod
- def _get_logits_processor_and_kwargs(
- input_length,
- eos_token_id,
- forced_bos_token_id=None,
- forced_eos_token_id=None,
- max_length=None,
- diversity_penalty=None,
- ):
- process_kwargs = {
- "min_length": input_length + 1 if max_length is None else max_length - 1,
+ attention_mask = None
+
+ # It is important set the eos_token_id to None to ensure that no sequences
+ # shorter than `max_length` can be generated
+ config.eos_token_id = None
+ config.forced_eos_token_id = None
+
+ return config, input_ids, attention_mask, inputs_dict
+
+ def _get_logits_processor_kwargs(self, do_sample=False, config=None):
+ logits_processor_kwargs = {
"bad_words_ids": [[1, 0]],
- "no_repeat_ngram_size": 2,
"repetition_penalty": 1.2,
+ "remove_invalid_values": True,
}
- logits_processor = LogitsProcessorList(
- (
- [
- HammingDiversityLogitsProcessor(diversity_penalty, num_beams=2, num_beam_groups=2),
- ]
- if diversity_penalty is not None
- else []
- )
- + (
- [
- MinLengthLogitsProcessor(process_kwargs["min_length"], eos_token_id),
- ]
- if eos_token_id is not None
- else []
+ if do_sample:
+ logits_processor_kwargs.update(
+ {
+ "top_k": 10,
+ "top_p": 0.7,
+ "temperature": 0.7,
+ }
)
- + (
- [
- ForcedBOSTokenLogitsProcessor(forced_bos_token_id),
- ]
- if forced_bos_token_id is not None
- else []
- )
- + (
- [ForcedEOSTokenLogitsProcessor(max_length, forced_eos_token_id)]
- if forced_eos_token_id is not None
- else []
- )
- + [
- NoBadWordsLogitsProcessor(process_kwargs["bad_words_ids"], eos_token_id),
- NoRepeatNGramLogitsProcessor(process_kwargs["no_repeat_ngram_size"]),
- RepetitionPenaltyLogitsProcessor(process_kwargs["repetition_penalty"]),
- ]
- )
- return process_kwargs, logits_processor
-
- @staticmethod
- def _get_warper_and_kwargs(num_beams):
- warp_kwargs = {"top_k": 10, "top_p": 0.7, "temperature": 0.7}
- logits_warper = LogitsProcessorList(
- [
- TemperatureLogitsWarper(warp_kwargs["temperature"]),
- TopKLogitsWarper(top_k=warp_kwargs["top_k"], min_tokens_to_keep=(2 if num_beams > 1 else 1)),
- TopPLogitsWarper(top_p=warp_kwargs["top_p"], min_tokens_to_keep=(2 if num_beams > 1 else 1)),
- ]
- )
- return warp_kwargs, logits_warper
-
- @staticmethod
- def _get_beam_scorer_and_kwargs(batch_size, max_length, num_return_sequences=1):
+ # TODO (joao, raushan): see this comment for a long-term fix
+ # https://github.com/huggingface/transformers/pull/33593#issuecomment-2361824264)
+ # This is a band-aid for VLM models, to ensure they don't generate image/video tokens which would cause them
+ # to crash. On pretrained models this isn't a risk, as they are trained to not generate these tokens.
+ if config is not None:
+ image_token_index = config.image_token_index if hasattr(config, "image_token_index") else None
+ video_token_index = config.video_token_index if hasattr(config, "video_token_index") else None
+ if image_token_index is not None and image_token_index < config.get_text_config().vocab_size:
+ logits_processor_kwargs["bad_words_ids"].append([image_token_index])
+ if video_token_index is not None and video_token_index < config.get_text_config().vocab_size:
+ logits_processor_kwargs["bad_words_ids"].append([video_token_index])
+
+ return logits_processor_kwargs
+
+ def _get_beam_kwargs(self, num_return_sequences=1):
beam_kwargs = {
"early_stopping": False,
"length_penalty": 2.0,
"num_beams": 2,
"num_return_sequences": num_return_sequences,
}
- beam_scorer = BeamSearchScorer(
- batch_size=batch_size,
- num_beams=beam_kwargs["num_beams"],
- device=torch_device,
- length_penalty=beam_kwargs["length_penalty"],
- do_early_stopping=beam_kwargs["early_stopping"],
- num_beam_hyps_to_keep=num_return_sequences,
- )
- return beam_kwargs, beam_scorer
+ return beam_kwargs
- @staticmethod
- def _get_diverse_beam_scorer_and_kwargs(batch_size, max_length, num_return_sequences=1):
+ def _get_diverse_beam_kwargs(self, num_return_sequences=1):
beam_kwargs = {
"early_stopping": False,
"length_penalty": 2.0,
@@ -209,93 +186,46 @@ def _get_diverse_beam_scorer_and_kwargs(batch_size, max_length, num_return_seque
"num_beam_groups": 2, # one beam per group
"diversity_penalty": 2.0,
}
- beam_scorer = BeamSearchScorer(
- batch_size=batch_size,
- num_beams=beam_kwargs["num_beams"],
- device=torch_device,
- length_penalty=beam_kwargs["length_penalty"],
- do_early_stopping=beam_kwargs["early_stopping"],
- num_beam_hyps_to_keep=num_return_sequences,
- num_beam_groups=beam_kwargs["num_beam_groups"],
- )
- return beam_kwargs, beam_scorer
+ return beam_kwargs
- @staticmethod
- def _get_constrained_beam_scorer_and_kwargs(batch_size, max_length, constraints, num_return_sequences=1):
+ def _get_constrained_beam_kwargs(self, num_return_sequences=1):
beam_kwargs = {
"early_stopping": False,
"length_penalty": 2.0,
"num_beams": num_return_sequences * 4,
"num_return_sequences": num_return_sequences,
}
- beam_scorer = ConstrainedBeamSearchScorer(
- batch_size=batch_size,
- constraints=constraints,
- num_beams=beam_kwargs["num_beams"],
- device=torch_device,
- length_penalty=beam_kwargs["length_penalty"],
- do_early_stopping=beam_kwargs["early_stopping"],
- num_beam_hyps_to_keep=num_return_sequences,
- )
- return beam_kwargs, beam_scorer
-
- @staticmethod
- def _get_encoder_outputs(
- model, input_ids, attention_mask, output_attentions=None, output_hidden_states=None, num_interleave=1
- ):
- encoder = model.get_encoder()
- encoder_outputs = encoder(
- input_ids,
- attention_mask=attention_mask,
- output_attentions=output_attentions,
- output_hidden_states=output_hidden_states,
- )
- encoder_outputs["last_hidden_state"] = encoder_outputs.last_hidden_state.repeat_interleave(
- num_interleave, dim=0
- )
- input_ids = torch.zeros_like(input_ids[:, :1]) + model._get_decoder_start_token_id()
- attention_mask = None
- return encoder_outputs, input_ids, attention_mask
-
- @staticmethod
- def _get_static_shapes():
- return False
+ return beam_kwargs
def _greedy_generate(
self,
model,
input_ids,
attention_mask,
- max_length,
+ inputs_dict,
output_scores=False,
+ output_logits=False,
output_attentions=False,
output_hidden_states=False,
return_dict_in_generate=False,
+ use_cache=True,
):
- if model.config.is_encoder_decoder:
- max_length = 4
- logits_process_kwargs, logits_processor = self._get_logits_processor_and_kwargs(
- input_ids.shape[-1],
- eos_token_id=model.config.eos_token_id,
- forced_bos_token_id=model.config.forced_bos_token_id,
- forced_eos_token_id=model.config.forced_eos_token_id,
- max_length=max_length,
- )
-
+ logits_processor_kwargs = self._get_logits_processor_kwargs(do_sample=False, config=model.config)
model_kwargs = {"attention_mask": attention_mask} if attention_mask is not None else {}
- model.generation_config.static_shapes = self._get_static_shapes()
output_generate = model.generate(
input_ids,
do_sample=False,
num_beams=1,
- max_length=max_length,
+ max_new_tokens=self.max_new_tokens,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
output_scores=output_scores,
+ output_logits=output_logits,
return_dict_in_generate=return_dict_in_generate,
- remove_invalid_values=True,
- **logits_process_kwargs,
+ use_cache=use_cache,
+ **logits_processor_kwargs,
**model_kwargs,
+ **inputs_dict,
)
return output_generate
@@ -305,35 +235,33 @@ def _sample_generate(
model,
input_ids,
attention_mask,
- max_length,
+ inputs_dict,
num_return_sequences,
- logits_processor,
- logits_warper,
- logits_warper_kwargs,
- process_kwargs,
output_scores=False,
+ output_logits=False,
output_attentions=False,
output_hidden_states=False,
return_dict_in_generate=False,
+ use_cache=True,
):
torch.manual_seed(0)
+ logits_processor_kwargs = self._get_logits_processor_kwargs(do_sample=True, config=model.config)
model_kwargs = {"attention_mask": attention_mask} if attention_mask is not None else {}
- self._update_default_model_kwargs(model_kwargs)
- model.generation_config.static_shapes = self._get_static_shapes()
output_generate = model.generate(
input_ids,
do_sample=True,
num_beams=1,
- max_length=max_length,
+ max_new_tokens=self.max_new_tokens,
num_return_sequences=num_return_sequences,
output_scores=output_scores,
+ output_logits=output_logits,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict_in_generate=return_dict_in_generate,
- remove_invalid_values=True,
- **logits_warper_kwargs,
- **process_kwargs,
+ use_cache=use_cache,
+ **logits_processor_kwargs,
**model_kwargs,
+ **inputs_dict,
)
return output_generate
@@ -343,31 +271,31 @@ def _beam_search_generate(
model,
input_ids,
attention_mask,
- max_length,
- beam_scorer,
+ inputs_dict,
beam_kwargs,
- logits_processor,
- logits_process_kwargs,
output_scores=False,
+ output_logits=False,
output_attentions=False,
output_hidden_states=False,
return_dict_in_generate=False,
+ use_cache=True,
):
+ logits_processor_kwargs = self._get_logits_processor_kwargs(do_sample=False, config=model.config)
model_kwargs = {"attention_mask": attention_mask} if attention_mask is not None else {}
- self._update_default_model_kwargs(model_kwargs)
- model.generation_config.static_shapes = self._get_static_shapes()
output_generate = model.generate(
input_ids,
do_sample=False,
- max_length=max_length,
+ max_new_tokens=self.max_new_tokens,
output_scores=output_scores,
+ output_logits=output_logits,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict_in_generate=return_dict_in_generate,
- remove_invalid_values=True,
+ use_cache=use_cache,
**beam_kwargs,
- **logits_process_kwargs,
+ **logits_processor_kwargs,
**model_kwargs,
+ **inputs_dict,
)
return output_generate
@@ -377,32 +305,34 @@ def _beam_sample_generate(
model,
input_ids,
attention_mask,
- max_length,
- beam_scorer,
+ inputs_dict,
beam_kwargs,
- logits_warper,
- logits_warper_kwargs,
output_scores=False,
+ output_logits=False,
output_attentions=False,
output_hidden_states=False,
return_dict_in_generate=False,
+ use_cache=True,
):
torch.manual_seed(0)
+ logits_processor_kwargs = self._get_logits_processor_kwargs(do_sample=True, config=model.config)
model_kwargs = {"attention_mask": attention_mask} if attention_mask is not None else {}
- self._update_default_model_kwargs(model_kwargs)
output_generate = model.generate(
input_ids,
do_sample=True,
- max_length=max_length,
+ max_new_tokens=self.max_new_tokens,
output_scores=output_scores,
+ output_logits=output_logits,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict_in_generate=return_dict_in_generate,
- remove_invalid_values=True,
+ use_cache=use_cache,
**beam_kwargs,
- **logits_warper_kwargs,
+ **logits_processor_kwargs,
**model_kwargs,
+ **inputs_dict,
)
+
return output_generate
def _group_beam_search_generate(
@@ -410,30 +340,31 @@ def _group_beam_search_generate(
model,
input_ids,
attention_mask,
- max_length,
- beam_scorer,
+ inputs_dict,
beam_kwargs,
- logits_processor,
- logits_process_kwargs,
output_scores=False,
+ output_logits=False,
output_attentions=False,
output_hidden_states=False,
return_dict_in_generate=False,
+ use_cache=True,
):
+ logits_processor_kwargs = self._get_logits_processor_kwargs(do_sample=False, config=model.config)
model_kwargs = {"attention_mask": attention_mask} if attention_mask is not None else {}
- self._update_default_model_kwargs(model_kwargs)
output_generate = model.generate(
input_ids,
do_sample=False,
- max_length=max_length,
+ max_new_tokens=self.max_new_tokens,
output_scores=output_scores,
+ output_logits=output_logits,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict_in_generate=return_dict_in_generate,
- remove_invalid_values=True,
+ use_cache=use_cache,
**beam_kwargs,
- **logits_process_kwargs,
+ **logits_processor_kwargs,
**model_kwargs,
+ **inputs_dict,
)
return output_generate
@@ -443,33 +374,33 @@ def _constrained_beam_search_generate(
model,
input_ids,
attention_mask,
- max_length,
- constrained_beam_scorer,
+ inputs_dict,
constraints,
beam_kwargs,
- logits_processor,
- logits_process_kwargs,
output_scores=False,
+ output_logits=False,
output_attentions=False,
output_hidden_states=False,
return_dict_in_generate=False,
+ use_cache=True,
):
+ logits_processor_kwargs = self._get_logits_processor_kwargs(do_sample=False, config=model.config)
model_kwargs = {"attention_mask": attention_mask} if attention_mask is not None else {}
- self._update_default_model_kwargs(model_kwargs)
- model.generation_config.static_shapes = self._get_static_shapes()
output_generate = model.generate(
input_ids,
do_sample=False,
- max_length=max_length,
+ max_new_tokens=self.max_new_tokens,
output_scores=output_scores,
+ output_logits=output_logits,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict_in_generate=return_dict_in_generate,
- remove_invalid_values=True,
constraints=constraints,
+ use_cache=use_cache,
**beam_kwargs,
- **logits_process_kwargs,
+ **logits_processor_kwargs,
**model_kwargs,
+ **inputs_dict,
)
return output_generate
@@ -479,76 +410,72 @@ def _contrastive_generate(
model,
input_ids,
attention_mask,
- max_length,
+ inputs_dict,
output_scores=False,
+ output_logits=False,
output_attentions=False,
output_hidden_states=False,
return_dict_in_generate=False,
+ use_cache=True,
):
contrastive_search_kwargs = {
"penalty_alpha": 0.6,
"top_k": 5,
}
- if model.config.is_encoder_decoder:
- max_length = 4
- logits_process_kwargs, logits_processor = self._get_logits_processor_and_kwargs(
- input_ids.shape[-1],
- eos_token_id=model.config.eos_token_id,
- forced_bos_token_id=model.config.forced_bos_token_id,
- forced_eos_token_id=model.config.forced_eos_token_id,
- max_length=max_length,
- )
-
+ logits_processor_kwargs = self._get_logits_processor_kwargs(do_sample=False, config=model.config)
model_kwargs = {"attention_mask": attention_mask} if attention_mask is not None else {}
- self._update_default_model_kwargs(model_kwargs)
- model.generation_config.static_shapes = self._get_static_shapes()
output_generate = model.generate(
input_ids,
do_sample=False,
num_beams=1,
- max_length=max_length,
+ max_new_tokens=self.max_new_tokens,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
output_scores=output_scores,
+ output_logits=output_logits,
return_dict_in_generate=return_dict_in_generate,
- remove_invalid_values=True,
- **logits_process_kwargs,
+ use_cache=use_cache,
+ **logits_processor_kwargs,
**model_kwargs,
**contrastive_search_kwargs,
+ **inputs_dict,
)
return output_generate
+ @pytest.mark.generate
def test_greedy_generate(self):
- # check `generate()` and `greedy_search()` are equal
for model_class in self.all_generative_model_classes:
- config, input_ids, attention_mask, max_length = self._get_input_ids_and_config()
- # test old generation output for backwards compatibility
+ config, input_ids, attention_mask, inputs_dict = self._get_input_ids_and_config()
+
model = model_class(config).to(torch_device).eval()
output_generate = self._greedy_generate(
- model=model, input_ids=input_ids, attention_mask=attention_mask, max_length=max_length
+ model=model, input_ids=input_ids, attention_mask=attention_mask, inputs_dict=inputs_dict
)
+
if model.config.is_encoder_decoder:
self.assertTrue(output_generate.shape[-1] == self.max_new_tokens + 1)
else:
self.assertTrue(output_generate.shape[-1] == self.max_new_tokens + input_ids.shape[-1])
+ @pytest.mark.generate
def test_greedy_generate_dict_outputs(self):
for model_class in self.all_generative_model_classes:
- # disable cache
- config, input_ids, attention_mask, max_length = self._get_input_ids_and_config()
- config.use_cache = False
+ config, input_ids, attention_mask, inputs_dict = self._get_input_ids_and_config()
+
model = model_class(config).to(torch_device).eval()
output_generate = self._greedy_generate(
model=model,
input_ids=input_ids,
attention_mask=attention_mask,
- max_length=max_length,
+ inputs_dict=inputs_dict,
output_scores=True,
+ output_logits=True,
output_hidden_states=True,
- output_attentions=True,
+ output_attentions=self.has_attentions,
return_dict_in_generate=True,
+ use_cache=False,
)
if model.config.is_encoder_decoder:
@@ -564,58 +491,50 @@ def test_greedy_generate_dict_outputs(self):
self._check_outputs(output_generate, input_ids, model.config)
+ @pytest.mark.generate
def test_greedy_generate_dict_outputs_use_cache(self):
for model_class in self.all_generative_model_classes:
- # enable cache
- config, input_ids, attention_mask, max_length = self._get_input_ids_and_config()
+ config, input_ids, attention_mask, inputs_dict = self._get_input_ids_and_config()
if not hasattr(config, "use_cache"):
- # only relevant if model has "use_cache"
- return
+ self.skipTest(reason=f"{model_class.__name__} doesn't support caching")
+ if any(model_name in model_class.__name__.lower() for model_name in ["rwkv"]):
+ self.skipTest(reason="Won't fix: model with non-standard dictionary output shapes")
- config.use_cache = True
config.is_decoder = True
model = model_class(config).to(torch_device).eval()
output_generate = self._greedy_generate(
model=model,
input_ids=input_ids,
attention_mask=attention_mask,
- max_length=max_length,
+ inputs_dict=inputs_dict,
output_scores=True,
+ output_logits=True,
output_hidden_states=True,
- output_attentions=True,
+ output_attentions=self.has_attentions,
return_dict_in_generate=True,
+ use_cache=True,
)
+ if model.config.is_encoder_decoder:
+ self.assertTrue(output_generate.sequences.shape[-1] == self.max_new_tokens + 1)
+ else:
+ self.assertTrue(output_generate.sequences.shape[-1] == self.max_new_tokens + input_ids.shape[-1])
+
self._check_outputs(output_generate, input_ids, model.config, use_cache=True)
+ @pytest.mark.generate
def test_sample_generate(self):
for model_class in self.all_generative_model_classes:
- config, input_ids, attention_mask, max_length = self._get_input_ids_and_config()
- model = model_class(config).to(torch_device).eval()
-
- if model.config.is_encoder_decoder:
- max_length = 4
-
- process_kwargs, logits_processor = self._get_logits_processor_and_kwargs(
- input_ids.shape[-1],
- model.config.eos_token_id,
- forced_bos_token_id=model.config.forced_bos_token_id,
- forced_eos_token_id=model.config.forced_eos_token_id,
- max_length=max_length,
- )
- logits_warper_kwargs, logits_warper = self._get_warper_and_kwargs(num_beams=2)
+ config, input_ids, attention_mask, inputs_dict = self._get_input_ids_and_config()
+ model = model_class(config).to(torch_device).eval()
output_generate = self._sample_generate(
model=model,
input_ids=input_ids,
attention_mask=attention_mask,
- max_length=max_length,
+ inputs_dict=inputs_dict,
num_return_sequences=1,
- logits_processor=logits_processor,
- logits_warper=logits_warper,
- logits_warper_kwargs=logits_warper_kwargs,
- process_kwargs=process_kwargs,
)
if model.config.is_encoder_decoder:
@@ -623,38 +542,24 @@ def test_sample_generate(self):
else:
self.assertTrue(output_generate.shape[-1] == self.max_new_tokens + input_ids.shape[-1])
+ @pytest.mark.generate
def test_sample_generate_dict_output(self):
for model_class in self.all_generative_model_classes:
- # disable cache
- config, input_ids, attention_mask, max_length = self._get_input_ids_and_config()
- config.use_cache = False
- model = model_class(config).to(torch_device).eval()
- if model.config.is_encoder_decoder:
- max_length = 4
-
- process_kwargs, logits_processor = self._get_logits_processor_and_kwargs(
- input_ids.shape[-1],
- model.config.eos_token_id,
- forced_bos_token_id=model.config.forced_bos_token_id,
- forced_eos_token_id=model.config.forced_eos_token_id,
- max_length=max_length,
- )
- logits_warper_kwargs, logits_warper = self._get_warper_and_kwargs(num_beams=1)
+ config, input_ids, attention_mask, inputs_dict = self._get_input_ids_and_config()
+ model = model_class(config).to(torch_device).eval()
output_generate = self._sample_generate(
model=model,
input_ids=input_ids,
attention_mask=attention_mask,
- max_length=max_length,
+ inputs_dict=inputs_dict,
num_return_sequences=2,
- logits_processor=logits_processor,
- logits_warper=logits_warper,
- logits_warper_kwargs=logits_warper_kwargs,
- process_kwargs=process_kwargs,
output_scores=True,
+ output_logits=True,
output_hidden_states=True,
- output_attentions=True,
+ output_attentions=self.has_attentions,
return_dict_in_generate=True,
+ use_cache=False,
)
if model.config.is_encoder_decoder:
@@ -670,38 +575,20 @@ def test_sample_generate_dict_output(self):
self._check_outputs(output_generate, input_ids, model.config, num_return_sequences=2)
+ @pytest.mark.generate
def test_beam_search_generate(self):
for model_class in self.all_generative_model_classes:
- config, input_ids, attention_mask, max_length = self._get_input_ids_and_config()
-
- # It is important set set the eos_token_id to None to ensure that no sequences
- # shorter than `max_length` can be generated which could lead to flaky circle ci
- # failures if the top `num_return_sequences` beams are all shorter than the longest beam
- config.eos_token_id = None
- config.forced_eos_token_id = None
+ config, input_ids, attention_mask, inputs_dict = self._get_input_ids_and_config()
model = model_class(config).to(torch_device).eval()
- if model.config.is_encoder_decoder:
- max_length = 4
-
- logits_process_kwargs, logits_processor = self._get_logits_processor_and_kwargs(
- input_ids.shape[-1],
- config.eos_token_id,
- config.forced_bos_token_id,
- config.forced_eos_token_id,
- max_length,
- )
- beam_kwargs, beam_scorer = self._get_beam_scorer_and_kwargs(input_ids.shape[0], max_length)
+ beam_kwargs = self._get_beam_kwargs()
output_generate = self._beam_search_generate(
model=model,
input_ids=input_ids,
attention_mask=attention_mask,
- max_length=max_length,
- beam_scorer=beam_scorer,
+ inputs_dict=inputs_dict,
beam_kwargs=beam_kwargs,
- logits_process_kwargs=logits_process_kwargs,
- logits_processor=logits_processor,
)
if model.config.is_encoder_decoder:
@@ -709,72 +596,26 @@ def test_beam_search_generate(self):
else:
self.assertTrue(output_generate.shape[-1] == self.max_new_tokens + input_ids.shape[-1])
- if model.config.is_encoder_decoder:
- max_length = 4
- beam_kwargs, beam_scorer = self._get_beam_scorer_and_kwargs(input_ids.shape[0], max_length)
-
- output_generate = self._beam_search_generate(
- model=model,
- input_ids=input_ids,
- attention_mask=attention_mask,
- max_length=max_length,
- beam_scorer=beam_scorer,
- beam_kwargs=beam_kwargs,
- logits_process_kwargs=logits_process_kwargs,
- logits_processor=logits_processor,
- )
- if model.config.is_encoder_decoder:
- self.assertTrue(output_generate.shape[-1] == self.max_new_tokens + 1)
- else:
- self.assertTrue(output_generate.shape[-1] == self.max_new_tokens + input_ids.shape[-1])
-
+ @pytest.mark.generate
def test_beam_search_generate_dict_output(self):
for model_class in self.all_generative_model_classes:
- config, input_ids, attention_mask, max_length = self._get_input_ids_and_config()
-
- # disable cache
- config.use_cache = False
-
- # It is important set set the eos_token_id to None to ensure that no sequences
- # shorter than `max_length` can be generated which could lead to flaky circle ci
- # failures if the top `num_return_sequences` beams are all shorter than the longest beam
- config.eos_token_id = None
- config.forced_eos_token_id = None
+ config, input_ids, attention_mask, inputs_dict = self._get_input_ids_and_config()
model = model_class(config).to(torch_device).eval()
- if model.config.is_encoder_decoder:
- max_length = 4
-
- logits_process_kwargs, logits_processor = self._get_logits_processor_and_kwargs(
- input_ids.shape[-1],
- config.eos_token_id,
- config.forced_bos_token_id,
- config.forced_eos_token_id,
- max_length,
- )
- beam_kwargs, beam_scorer = self._get_beam_scorer_and_kwargs(input_ids.shape[0], max_length)
+ beam_kwargs = self._get_beam_kwargs()
output_generate = self._beam_search_generate(
model=model,
input_ids=input_ids,
attention_mask=attention_mask,
- max_length=max_length,
- beam_scorer=beam_scorer,
+ inputs_dict=inputs_dict,
beam_kwargs=beam_kwargs,
- logits_process_kwargs=logits_process_kwargs,
- logits_processor=logits_processor,
output_scores=True,
+ output_logits=True,
output_hidden_states=True,
- output_attentions=True,
+ output_attentions=self.has_attentions,
return_dict_in_generate=True,
+ use_cache=False,
)
- if model.config.is_encoder_decoder:
- self.assertIsInstance(output_generate, BeamSearchEncoderDecoderOutput)
- else:
- self.assertIsInstance(output_generate, BeamSearchDecoderOnlyOutput)
-
- self.assertTrue(output_generate["sequences_scores"].shape == (output_generate["sequences"].shape[0],))
- self.assertTrue((output_generate["sequences_scores"] < 0).all().item())
-
if model.config.is_encoder_decoder:
self.assertTrue(output_generate.sequences.shape[-1] == self.max_new_tokens + 1)
self.assertIsInstance(output_generate, GenerateBeamEncoderDecoderOutput)
@@ -790,148 +631,139 @@ def test_beam_search_generate_dict_output(self):
output_generate, input_ids, model.config, num_return_sequences=beam_kwargs["num_beams"]
)
+ @pytest.mark.generate
def test_beam_search_generate_dict_outputs_use_cache(self):
for model_class in self.all_generative_model_classes:
# enable cache
- config, input_ids, attention_mask, max_length = self._get_input_ids_and_config()
-
- # It is important set set the eos_token_id to None to ensure that no sequences
- # shorter than `max_length` can be generated which could lead to flaky circle ci
- # failures if the top `num_return_sequences` beams are all shorter than the longest beam
- config.eos_token_id = None
- config.forced_eos_token_id = None
+ config, input_ids, attention_mask, inputs_dict = self._get_input_ids_and_config()
if not hasattr(config, "use_cache"):
- # only relevant if model has "use_cache"
- return
+ self.skipTest(reason=f"{model_class.__name__} doesn't support caching")
+ if any(model_name in model_class.__name__.lower() for model_name in ["rwkv"]):
+ self.skipTest(reason="Won't fix: model with non-standard dictionary output shapes")
model = model_class(config).to(torch_device).eval()
- if model.config.is_encoder_decoder:
- max_length = 4
-
- logits_process_kwargs, logits_processor = self._get_logits_processor_and_kwargs(
- input_ids.shape[-1],
- config.eos_token_id,
- config.forced_bos_token_id,
- config.forced_eos_token_id,
- max_length,
- )
-
- beam_kwargs, beam_scorer = self._get_beam_scorer_and_kwargs(input_ids.shape[0], max_length)
+ beam_kwargs = self._get_beam_kwargs()
- config.use_cache = True
config.is_decoder = True
model = model_class(config).to(torch_device).eval()
output_generate = self._beam_search_generate(
model=model,
input_ids=input_ids,
attention_mask=attention_mask,
- max_length=max_length,
- beam_scorer=beam_scorer,
+ inputs_dict=inputs_dict,
beam_kwargs=beam_kwargs,
- logits_process_kwargs=logits_process_kwargs,
- logits_processor=logits_processor,
output_scores=True,
+ output_logits=True,
output_hidden_states=True,
- output_attentions=True,
+ output_attentions=self.has_attentions,
return_dict_in_generate=True,
+ use_cache=True,
)
if model.config.is_encoder_decoder:
self.assertTrue(output_generate.sequences.shape[-1] == self.max_new_tokens + 1)
else:
self.assertTrue(output_generate.sequences.shape[-1] == self.max_new_tokens + input_ids.shape[-1])
+
self._check_outputs(
- output_generate, input_ids, model.config, use_cache=True, num_return_sequences=beam_scorer.num_beams
+ output_generate, input_ids, model.config, use_cache=True, num_return_sequences=beam_kwargs["num_beams"]
)
- @pytest.mark.skip("Beam search sampling is not supported by optimum-habana yet")
- def test_beam_sample_generate(self):
+ @require_accelerate
+ @require_torch_multi_accelerator
+ @pytest.mark.generate
+ def test_model_parallel_beam_search(self):
for model_class in self.all_generative_model_classes:
- config, input_ids, attention_mask, max_length = self._get_input_ids_and_config()
+ if "xpu" in torch_device:
+ return unittest.skip(reason="device_map='auto' does not work with XPU devices")
- # It is important set set the eos_token_id to None to ensure that no sequences
- # shorter than `max_length` can be generated which could lead to flaky circle ci
- # failures if the top `num_return_sequences` beams are all shorter than the longest beam
- config.eos_token_id = None
- config.forced_eos_token_id = None
+ if model_class._no_split_modules is None:
+ continue
- logits_warper_kwargs, logits_warper = self._get_warper_and_kwargs(num_beams=1)
+ config, input_ids, attention_mask, inputs_dict = self._get_input_ids_and_config()
- model = model_class(config).to(torch_device).eval()
+ model = model_class(config).eval()
+ with tempfile.TemporaryDirectory() as tmp_dir:
+ model.cpu().save_pretrained(tmp_dir)
+ new_model = model_class.from_pretrained(tmp_dir, device_map="auto")
- # check `generate()` and `beam_search()` are equal
- if model.config.is_encoder_decoder:
- max_length = 4
- beam_kwargs, beam_scorer = self._get_beam_scorer_and_kwargs(input_ids.shape[0], max_length)
+ new_model.generate(
+ input_ids,
+ attention_mask=attention_mask,
+ max_new_tokens=self.max_new_tokens,
+ num_beams=2,
+ **inputs_dict,
+ )
+
+ @pytest.mark.skip("Beam search sampling is not supported by optimum-habana yet")
+ @pytest.mark.generate
+ def test_beam_sample_generate(self):
+ for model_class in self.all_generative_model_classes:
+ config, input_ids, attention_mask, inputs_dict = self._get_input_ids_and_config()
+ model = model_class(config).to(torch_device).eval()
+ beam_kwargs = self._get_beam_kwargs()
output_generate = self._beam_sample_generate(
model=model,
input_ids=input_ids,
attention_mask=attention_mask,
- max_length=max_length,
- beam_scorer=beam_scorer,
+ inputs_dict=inputs_dict,
beam_kwargs=beam_kwargs,
- logits_warper=logits_warper,
- logits_warper_kwargs=logits_warper_kwargs,
)
if model.config.is_encoder_decoder:
self.assertTrue(output_generate.shape[-1] == self.max_new_tokens + 1)
else:
self.assertTrue(output_generate.shape[-1] == self.max_new_tokens + input_ids.shape[-1])
- if "inputs_embeds" in set(inspect.signature(model.prepare_inputs_for_generation).parameters):
- input_embeds = model.get_input_embeddings()(input_ids)
- beam_kwargs.update({"inputs_embeds": input_embeds})
- output_generate2 = self._beam_sample_generate(
- model=model,
- input_ids=None,
- attention_mask=attention_mask,
- beam_kwargs=beam_kwargs,
- logits_warper_kwargs=logits_warper_kwargs,
+
+ # for VLMs inputs embeds won't match input ids unless images are encoded and merged with ids properly
+ # no quick fix available, since obtaining image embeddings step is very model-specific
+ if any(name in model.__class__.__name__.lower() for name in ("blip", "llava", "paligemma")):
+ prepare_inputs_for_generation_args = set(
+ inspect.signature(model.prepare_inputs_for_generation).parameters
)
+ # `inputs_embeds` input is well supported when `cache_positions` is used, because it means the modeling
+ # code is up to date with our most recent standards
+ if (
+ "inputs_embeds" in prepare_inputs_for_generation_args
+ and "cache_positions" in prepare_inputs_for_generation_args
+ ):
+ input_embeds = model.get_input_embeddings()(input_ids)
+ beam_kwargs.update({"inputs_embeds": input_embeds})
+ output_generate2 = self._beam_sample_generate(
+ model=model,
+ input_ids=None,
+ attention_mask=attention_mask,
+ inputs_dict={},
+ beam_kwargs=beam_kwargs,
+ )
- torch.testing.assert_close(output_generate[:, input_embeds.shape[1] :], output_generate2)
+ torch.testing.assert_close(output_generate[:, input_embeds.shape[1] :], output_generate2)
@pytest.mark.skip("Beam search sampling is not supported by optimum-habana yet")
+ @pytest.mark.generate
def test_beam_sample_generate_dict_output(self):
for model_class in self.all_generative_model_classes:
- config, input_ids, attention_mask, max_length = self._get_input_ids_and_config()
-
- # disable cache
- config.use_cache = False
-
- # It is important set set the eos_token_id to None to ensure that no sequences
- # shorter than `max_length` can be generated which could lead to flaky circle ci
- # failures if the top `num_return_sequences` beams are all shorter than the longest beam
- config.eos_token_id = None
- config.forced_eos_token_id = None
+ config, input_ids, attention_mask, inputs_dict = self._get_input_ids_and_config()
model = model_class(config).to(torch_device).eval()
- logits_warper_kwargs, logits_warper = self._get_warper_and_kwargs(num_beams=1)
-
- if model.config.is_encoder_decoder:
- max_length = 4
- beam_kwargs, beam_scorer = self._get_beam_scorer_and_kwargs(input_ids.shape[0], max_length)
+ beam_kwargs = self._get_beam_kwargs()
output_generate = self._beam_sample_generate(
model=model,
input_ids=input_ids,
attention_mask=attention_mask,
- max_length=max_length,
- beam_scorer=beam_scorer,
+ inputs_dict=inputs_dict,
beam_kwargs=beam_kwargs,
- logits_warper=logits_warper,
- logits_warper_kwargs=logits_warper_kwargs,
output_scores=True,
+ output_logits=True,
output_hidden_states=True,
- output_attentions=True,
+ output_attentions=self.has_attentions,
return_dict_in_generate=True,
+ use_cache=False,
)
- self.assertTrue(output_generate["sequences_scores"].shape == (output_generate["sequences"].shape[0],))
- self.assertTrue((output_generate["sequences_scores"] < 0).all().item())
-
if model.config.is_encoder_decoder:
self.assertTrue(output_generate.sequences.shape[-1] == self.max_new_tokens + 1)
self.assertIsInstance(output_generate, GenerateBeamEncoderDecoderOutput)
@@ -947,192 +779,131 @@ def test_beam_sample_generate_dict_output(self):
output_generate, input_ids, model.config, num_return_sequences=beam_kwargs["num_beams"]
)
+ @pytest.mark.generate
def test_generate_without_input_ids(self):
- config, _, _, max_length = self._get_input_ids_and_config()
+ config, _, _, _ = self._get_input_ids_and_config()
# if no bos token id => cannot generate from None
if config.bos_token_id is None:
- return
+ self.skipTest(reason="bos_token_id is None")
+
+ # hack in case they are equal, otherwise the attn mask will be [0]
+ if config.bos_token_id == config.pad_token_id:
+ config.pad_token_id = None
for model_class in self.all_generative_model_classes:
model = model_class(config).to(torch_device)
model.eval()
- output_ids_generate = model.generate(do_sample=False, max_length=max_length, remove_invalid_values=True)
+ output_ids_generate = model.generate(
+ do_sample=False, max_new_tokens=self.max_new_tokens, remove_invalid_values=True
+ )
self.assertIsNotNone(output_ids_generate)
@pytest.mark.skip("Group beam search is not supported by optimum-habana")
+ @pytest.mark.generate
def test_group_beam_search_generate(self):
for model_class in self.all_generative_model_classes:
- config, input_ids, attention_mask, max_length = self._get_input_ids_and_config()
-
- # It is important set set the eos_token_id to None to ensure that no sequences
- # shorter than `max_length` can be generated which could lead to flaky circle ci
- # failures if the top `num_return_sequences` beams are all shorter than the longest beam
- config.eos_token_id = None
- config.forced_eos_token_id = None
+ config, input_ids, attention_mask, inputs_dict = self._get_input_ids_and_config()
model = model_class(config).to(torch_device).eval()
- if model.config.is_encoder_decoder:
- max_length = 4
-
- logits_process_kwargs, logits_processor = self._get_logits_processor_and_kwargs(
- input_ids.shape[-1],
- config.eos_token_id,
- config.forced_bos_token_id,
- config.forced_eos_token_id,
- max_length,
- diversity_penalty=2.0,
- )
-
# check `generate()` and `group_beam_search()` are equal
- beam_kwargs, beam_scorer = self._get_diverse_beam_scorer_and_kwargs(input_ids.shape[0], max_length)
- output_generate, output_group_beam_search = self._group_beam_search_generate(
+ beam_kwargs = self._get_diverse_beam_kwargs()
+ output_generate = self._group_beam_search_generate(
model=model,
input_ids=input_ids,
attention_mask=attention_mask,
- max_length=max_length,
- beam_scorer=beam_scorer,
+ inputs_dict=inputs_dict,
beam_kwargs=beam_kwargs,
- logits_processor=logits_processor,
- logits_process_kwargs=logits_process_kwargs,
)
- self.assertListEqual(output_generate.tolist(), output_group_beam_search.tolist())
+ if model.config.is_encoder_decoder:
+ self.assertTrue(output_generate.shape[-1] == self.max_new_tokens + 1)
+ else:
+ self.assertTrue(output_generate.shape[-1] == self.max_new_tokens + input_ids.shape[-1])
- # check `generate()` and `group_beam_search()` are equal for `num_return_sequences`
+ # check `group_beam_search` for higher than 1 `num_return_sequences`
num_return_sequences = 2
- if model.config.is_encoder_decoder:
- max_length = 4
- beam_kwargs, beam_scorer = self._get_diverse_beam_scorer_and_kwargs(
- input_ids.shape[0], max_length, num_return_sequences=num_return_sequences
- )
- output_generate, output_group_beam_search = self._group_beam_search_generate(
+ beam_kwargs = self._get_diverse_beam_kwargs(num_return_sequences=num_return_sequences)
+ output_generate = self._group_beam_search_generate(
model=model,
input_ids=input_ids,
attention_mask=attention_mask,
- max_length=max_length,
- beam_scorer=beam_scorer,
+ inputs_dict=inputs_dict,
beam_kwargs=beam_kwargs,
- logits_processor=logits_processor,
- logits_process_kwargs=logits_process_kwargs,
)
- self.assertListEqual(output_generate.tolist(), output_group_beam_search.tolist())
+ if model.config.is_encoder_decoder:
+ self.assertTrue(output_generate.shape[-1] == self.max_new_tokens + 1)
+ else:
+ self.assertTrue(output_generate.shape[-1] == self.max_new_tokens + input_ids.shape[-1])
@pytest.mark.skip("Group beam search is not supported by optimum-habana")
+ @pytest.mark.generate
def test_group_beam_search_generate_dict_output(self):
for model_class in self.all_generative_model_classes:
- config, input_ids, attention_mask, max_length = self._get_input_ids_and_config()
- config.use_cache = False
-
- # It is important set set the eos_token_id to None to ensure that no sequences
- # shorter than `max_length` can be generated which could lead to flaky circle ci
- # failures if the top `num_return_sequences` beams are all shorter than the longest beam
- config.eos_token_id = None
- config.forced_eos_token_id = None
+ config, input_ids, attention_mask, inputs_dict = self._get_input_ids_and_config()
model = model_class(config).to(torch_device).eval()
- if model.config.is_encoder_decoder:
- max_length = 4
-
- logits_process_kwargs, logits_processor = self._get_logits_processor_and_kwargs(
- input_ids.shape[-1],
- config.eos_token_id,
- config.forced_bos_token_id,
- config.forced_eos_token_id,
- max_length,
- diversity_penalty=2.0,
- )
-
- num_return_sequences = 1
- beam_kwargs, beam_scorer = self._get_diverse_beam_scorer_and_kwargs(
- input_ids.shape[0], max_length, num_return_sequences=num_return_sequences
- )
- output_generate, output_group_beam_search = self._group_beam_search_generate(
+ beam_kwargs = self._get_diverse_beam_kwargs()
+ output_generate = self._group_beam_search_generate(
model=model,
input_ids=input_ids,
attention_mask=attention_mask,
- max_length=max_length,
- beam_scorer=beam_scorer,
+ inputs_dict=inputs_dict,
beam_kwargs=beam_kwargs,
- logits_processor=logits_processor,
- logits_process_kwargs=logits_process_kwargs,
output_scores=True,
+ output_logits=True,
output_hidden_states=True,
- output_attentions=True,
+ output_attentions=self.has_attentions,
return_dict_in_generate=True,
+ use_cache=False,
)
if model.config.is_encoder_decoder:
- self.assertIsInstance(output_group_beam_search, BeamSearchEncoderDecoderOutput)
+ self.assertTrue(output_generate.sequences.shape[-1] == self.max_new_tokens + 1)
+ self.assertIsInstance(output_generate, GenerateBeamEncoderDecoderOutput)
+ # Retrocompatibility check
self.assertIsInstance(output_generate, BeamSearchEncoderDecoderOutput)
else:
- self.assertIsInstance(output_group_beam_search, BeamSearchDecoderOnlyOutput)
+ self.assertTrue(output_generate.sequences.shape[-1] == self.max_new_tokens + input_ids.shape[-1])
+ self.assertIsInstance(output_generate, GenerateBeamDecoderOnlyOutput)
+ # Retrocompatibility check
self.assertIsInstance(output_generate, BeamSearchDecoderOnlyOutput)
- self.assertListEqual(output_generate.sequences.tolist(), output_group_beam_search.sequences.tolist())
- self.assertTrue(
- torch.allclose(
- output_generate["sequences_scores"], output_group_beam_search["sequences_scores"], atol=1e-3
- )
+ self._check_outputs(
+ output_generate, input_ids, model.config, num_return_sequences=beam_kwargs["num_beams"]
)
- self.assertTrue(output_generate["sequences_scores"].shape == (output_generate["sequences"].shape[0],))
- self.assertTrue((output_generate["sequences_scores"] < 0).all().item())
-
- for output in (output_group_beam_search, output_generate):
- self._check_outputs(
- output, input_ids, model.config, num_return_sequences=num_return_sequences * beam_scorer.num_beams
- )
+ # TODO: @gante
+ @is_flaky()
+ @pytest.mark.generate
def test_constrained_beam_search_generate(self):
for model_class in self.all_generative_model_classes:
- config, input_ids, attention_mask, max_length = self._get_input_ids_and_config()
-
- # It is important set set the eos_token_id to None to ensure that no sequences
- # shorter than `max_length` can be generated which could lead to flaky circle ci
- # failures if the top `num_return_sequences` beams are all shorter than the longest beam
- config.eos_token_id = None
- config.forced_eos_token_id = None
+ config, input_ids, attention_mask, inputs_dict = self._get_input_ids_and_config()
model = model_class(config).to(torch_device).eval()
- max_length = 20
- logits_process_kwargs, logits_processor = self._get_logits_processor_and_kwargs(
- input_ids.shape[-1],
- config.eos_token_id,
- config.forced_bos_token_id,
- config.forced_eos_token_id,
- max_length,
- )
-
- # check `generate()` and `constrained_beam_search()` are equal
# Sample constraints
- if not input_ids.dtype == torch.float32:
- min_id = torch.min(input_ids) + 3
- max_id = torch.max(input_ids)
- else:
- # otherwise this throws an error for Speech2TextModel since its inputs are floating points
- min_id = 3
- max_id = 100
+ min_id = 3
+ max_id = config.get_text_config(decoder=True).vocab_size
force_tokens = torch.randint(min_id, max_id, (1, 2)).tolist()[0]
constraints = [
PhrasalConstraint(force_tokens),
]
- beam_kwargs, beam_scorer = self._get_constrained_beam_scorer_and_kwargs(
- input_ids.shape[0], max_length, constraints, num_return_sequences=1
- )
+ beam_kwargs = self._get_constrained_beam_kwargs()
output_generate = self._constrained_beam_search_generate(
model=model,
input_ids=input_ids,
attention_mask=attention_mask,
- max_length=max_length,
- constrained_beam_scorer=beam_scorer,
+ inputs_dict=inputs_dict,
constraints=constraints,
beam_kwargs=beam_kwargs,
- logits_processor=logits_processor,
- logits_process_kwargs=logits_process_kwargs,
)
- self.assertTrue(output_generate.shape[-1] == max_length)
+
+ if model.config.is_encoder_decoder:
+ self.assertTrue(output_generate.shape[-1] == self.max_new_tokens + 1)
+ else:
+ self.assertTrue(output_generate.shape[-1] == self.max_new_tokens + input_ids.shape[-1])
for generation_output in output_generate:
self._check_sequence_inside_sequence(force_tokens, generation_output)
@@ -1144,86 +915,63 @@ def test_constrained_beam_search_generate(self):
PhrasalConstraint(force_tokens),
]
- num_return_sequences = 2
- max_length = 20
-
- beam_kwargs, beam_scorer = self._get_constrained_beam_scorer_and_kwargs(
- input_ids.shape[0], max_length, constraints, num_return_sequences=num_return_sequences
- )
+ beam_kwargs = self._get_constrained_beam_kwargs(num_return_sequences=2)
output_generate = self._constrained_beam_search_generate(
model=model,
input_ids=input_ids,
attention_mask=attention_mask,
- max_length=max_length,
- constrained_beam_scorer=beam_scorer,
+ inputs_dict=inputs_dict,
constraints=constraints,
beam_kwargs=beam_kwargs,
- logits_processor=logits_processor,
- logits_process_kwargs=logits_process_kwargs,
)
- self.assertTrue(output_generate.shape[-1] == max_length)
+
+ if model.config.is_encoder_decoder:
+ self.assertTrue(output_generate.shape[-1] == self.max_new_tokens + 1)
+ else:
+ self.assertTrue(output_generate.shape[-1] == self.max_new_tokens + input_ids.shape[-1])
for generation_output in output_generate:
self._check_sequence_inside_sequence(force_tokens, generation_output)
+ @pytest.mark.generate
def test_constrained_beam_search_generate_dict_output(self):
for model_class in self.all_generative_model_classes:
- config, input_ids, attention_mask, max_length = self._get_input_ids_and_config()
-
- # disable cache
- config.use_cache = False
-
- # It is important set set the eos_token_id to None to ensure that no sequences
- # shorter than `max_length` can be generated which could lead to flaky circle ci
- # failures if the top `num_return_sequences` beams are all shorter than the longest beam
- config.eos_token_id = None
- config.forced_eos_token_id = None
+ config, input_ids, attention_mask, inputs_dict = self._get_input_ids_and_config()
model = model_class(config).to(torch_device).eval()
- if model.config.is_encoder_decoder:
- max_length = 20
-
- logits_process_kwargs, logits_processor = self._get_logits_processor_and_kwargs(
- input_ids.shape[-1],
- config.eos_token_id,
- config.forced_bos_token_id,
- config.forced_eos_token_id,
- max_length,
- )
# Sample constraints
min_id = 3
- max_id = model.config.vocab_size
+ max_id = model.config.get_text_config(decoder=True).vocab_size
force_tokens = torch.randint(min_id, max_id, (1, 2)).tolist()[0]
constraints = [
PhrasalConstraint(force_tokens),
]
- beam_kwargs, beam_scorer = self._get_constrained_beam_scorer_and_kwargs(
- input_ids.shape[0], max_length, constraints, num_return_sequences=1
- )
+ beam_kwargs = self._get_constrained_beam_kwargs()
output_generate = self._constrained_beam_search_generate(
model=model,
input_ids=input_ids,
attention_mask=attention_mask,
- max_length=max_length,
- constrained_beam_scorer=beam_scorer,
+ inputs_dict=inputs_dict,
constraints=constraints,
beam_kwargs=beam_kwargs,
- logits_processor=logits_processor,
- logits_process_kwargs=logits_process_kwargs,
output_scores=True,
+ output_logits=True,
output_hidden_states=True,
- output_attentions=True,
+ output_attentions=self.has_attentions,
return_dict_in_generate=True,
+ use_cache=False,
)
- self.assertTrue(output_generate.sequences.shape[-1] == max_length)
+
if model.config.is_encoder_decoder:
+ self.assertTrue(output_generate.sequences.shape[-1] == self.max_new_tokens + 1)
self.assertIsInstance(output_generate, GenerateBeamEncoderDecoderOutput)
# Retrocompatibility check
self.assertIsInstance(output_generate, BeamSearchEncoderDecoderOutput)
else:
+ self.assertTrue(output_generate.sequences.shape[-1] == self.max_new_tokens + input_ids.shape[-1])
self.assertIsInstance(output_generate, GenerateBeamDecoderOnlyOutput)
# Retrocompatibility check
self.assertIsInstance(output_generate, BeamSearchDecoderOnlyOutput)
@@ -1232,47 +980,52 @@ def test_constrained_beam_search_generate_dict_output(self):
output_generate, input_ids, model.config, num_return_sequences=beam_kwargs["num_beams"]
)
- self.assertTrue(output_generate["sequences_scores"].shape == (output_generate["sequences"].shape[0],))
- self.assertTrue((output_generate["sequences_scores"] < 0).all().item())
-
+ @pytest.mark.generate
def test_contrastive_generate(self):
- # check `generate()` and `contrastive_search()` are equal
for model_class in self.all_generative_model_classes:
+ if model_class._is_stateful:
+ self.skipTest(reason="Stateful models don't support contrastive search generation")
+
# won't fix: FSMT and Reformer have a different cache variable type (and format).
if any(model_name in model_class.__name__.lower() for model_name in ["fsmt", "reformer"]):
- return
+ self.skipTest(reason="Won't fix: old model with different cache format")
- config, input_ids, attention_mask, max_length = self._get_input_ids_and_config()
+ config, input_ids, attention_mask, inputs_dict = self._get_input_ids_and_config()
# NOTE: contrastive search only works with cache on at the moment.
if not hasattr(config, "use_cache"):
- return
- config.use_cache = True
+ self.skipTest(reason=f"{model_class.__name__} doesn't support caching")
config.is_decoder = True
# test old generation output for backwards compatibility
model = model_class(config).to(torch_device).eval()
output_generate = self._contrastive_generate(
- model=model, input_ids=input_ids, attention_mask=attention_mask, max_length=max_length
+ model=model,
+ input_ids=input_ids,
+ attention_mask=attention_mask,
+ inputs_dict=inputs_dict,
+ use_cache=True,
)
if model.config.is_encoder_decoder:
self.assertTrue(output_generate.shape[-1] == self.max_new_tokens + 1)
else:
self.assertTrue(output_generate.shape[-1] == self.max_new_tokens + input_ids.shape[-1])
+ @pytest.mark.generate
def test_contrastive_generate_dict_outputs_use_cache(self):
for model_class in self.all_generative_model_classes:
+ if model_class._is_stateful:
+ self.skipTest(reason="Stateful models don't support contrastive search generation")
+
# won't fix: FSMT and Reformer have a different cache variable type (and format).
if any(model_name in model_class.__name__.lower() for model_name in ["fsmt", "reformer"]):
- return
+ self.skipTest(reason="Won't fix: old model with different cache format")
- # enable cache
- config, input_ids, attention_mask, max_length = self._get_input_ids_and_config()
+ config, input_ids, attention_mask, inputs_dict = self._get_input_ids_and_config()
# NOTE: contrastive search only works with cache on at the moment.
if not hasattr(config, "use_cache"):
- return
- config.use_cache = True
+ self.skipTest(reason=f"{model_class.__name__} doesn't support caching")
config.is_decoder = True
model = model_class(config).to(torch_device).eval()
@@ -1280,36 +1033,40 @@ def test_contrastive_generate_dict_outputs_use_cache(self):
model=model,
input_ids=input_ids,
attention_mask=attention_mask,
- max_length=max_length,
+ inputs_dict=inputs_dict,
output_scores=True,
+ output_logits=True,
output_hidden_states=True,
- output_attentions=True,
+ output_attentions=self.has_attentions,
return_dict_in_generate=True,
+ use_cache=True,
)
if model.config.is_encoder_decoder:
self.assertTrue(output_generate.sequences.shape[-1] == self.max_new_tokens + 1)
else:
self.assertTrue(output_generate.sequences.shape[-1] == self.max_new_tokens + input_ids.shape[-1])
+
self._check_outputs(output_generate, input_ids, model.config, use_cache=True)
+ @pytest.mark.generate
def test_contrastive_generate_low_memory(self):
# Check that choosing 'low_memory' does not change the model output
for model_class in self.all_generative_model_classes:
- # won't fix: FSMT, Reformer, gptbigcode, and speech2text have a different cache variable type (and format).
- if any(
- model_name in model_class.__name__.lower()
- for model_name in ["fsmt", "reformer", "gptbigcode", "speech2text"]
- ):
- return
+ if model_class._is_stateful:
+ self.skipTest(reason="Stateful models don't support contrastive search generation")
- config, input_ids, attention_mask, max_length = self._get_input_ids_and_config(batch_size=1)
+ if any(model_name in model_class.__name__.lower() for model_name in ["fsmt", "reformer", "speech2text"]):
+ self.skipTest(reason="Won't fix: old model with different cache format")
+ if any(model_name in model_class.__name__.lower() for model_name in ["gptbigcode"]):
+ self.skipTest(reason="TODO: fix me")
+
+ config, input_ids, attention_mask, inputs_dict = self._get_input_ids_and_config(batch_size=1)
# NOTE: contrastive search only works with cache on at the moment.
if not hasattr(config, "use_cache"):
- return
+ self.skipTest(reason=f"{model_class.__name__} doesn't support caching")
- config.use_cache = True
config.is_decoder = True
# test output equality of low versus high memory
@@ -1320,8 +1077,10 @@ def test_contrastive_generate_low_memory(self):
top_k=4,
penalty_alpha=0.6,
low_memory=True,
- max_length=max_length,
+ max_new_tokens=self.max_new_tokens,
attention_mask=attention_mask,
+ **inputs_dict,
+ use_cache=True,
)
high_output = model.generate(
@@ -1329,8 +1088,10 @@ def test_contrastive_generate_low_memory(self):
top_k=4,
penalty_alpha=0.6,
low_memory=False,
- max_length=max_length,
+ max_new_tokens=self.max_new_tokens,
attention_mask=attention_mask,
+ **inputs_dict,
+ use_cache=True,
)
self.assertListEqual(low_output.tolist(), high_output.tolist())
@@ -1377,89 +1138,75 @@ def test_contrastive_generate_dynamic_shapes(self):
)
self.assertListEqual(dynamic_output.tolist(), static_output.tolist())
- # TODO [sasarkar] it is supported now. Enable this test, or delete it if its not applicable
- @pytest.mark.skip(reason="Assisted decoding not yet supported by optimum-habana")
- @slow # TODO(Joao): remove this. Some models (e.g. data2vec, xcom, roberta) have an error rate between 1 and 10%.
- def test_assisted_decoding_matches_greedy_search(self):
- # This test ensures that the assisted generation does not introduce output changes over greedy search.
- # It breaks the pattern in the tests above, for multiple reasons:
- # - assisted_decoding, contrarily to the other methods, can't be called on its own (e.g. needs to
- # prepare the assistant encoder outputs in the main generate body);
- # - assisted_decoding does not support `use_cache = False`
- # - assisted_decoding does not support `batch_size > 1`
-
+ @pytest.mark.generate
+ @unittest.skip("Started to break with https://github.com/huggingface/transformers/pull/33703")
+ def test_beam_search_low_memory(self):
+ # Check that choosing 'low_memory' does not change the model output
for model_class in self.all_generative_model_classes:
- # won't fix: FSMT and Reformer have a different cache variable type (and format).
+ if model_class._is_stateful:
+ self.skipTest(reason="May fix in the future: need custom cache handling")
if any(model_name in model_class.__name__.lower() for model_name in ["fsmt", "reformer"]):
- return
- # may fix in the future: the following models fail with assisted decoding, and need model-specific fixes
+ self.skipTest(reason="Won't fix: old model with different cache format")
if any(
model_name in model_class.__name__.lower()
- for model_name in ["bigbirdpegasus", "led", "mega", "speech2text", "git", "prophetnet"]
+ for model_name in [
+ "ctrl",
+ "gptbigcode",
+ "transo_xl",
+ "xlnet",
+ "cpm",
+ "jamba",
+ ]
):
- return
-
- # This for loop is a naive and temporary effort to make the test less flaky.
- failed = 0
- for i in range(10):
- # enable cache
- config, input_ids, attention_mask, max_length = self._get_input_ids_and_config(batch_size=1)
+ self.skipTest(reason="May fix in the future: need model-specific fixes")
+ config, input_ids, _, _ = self._get_input_ids_and_config(batch_size=2)
+ # batch_size=1 is ok, but batch_size>1 will cause non-identical output
- # NOTE: assisted generation only works with cache on at the moment.
- if not hasattr(config, "use_cache"):
- return
+ config.use_cache = True
+ config.is_decoder = True
- config.use_cache = True
- config.is_decoder = True
- model = model_class(config).to(torch_device).eval()
- output_greedy = model.generate(
- input_ids,
- attention_mask=attention_mask,
- max_length=max_length,
- num_beams=1,
- do_sample=False,
- output_scores=True,
- output_hidden_states=True,
- output_attentions=True,
- return_dict_in_generate=True,
- )
- # Note: with assisted generate, if the same model is used as assistant, then all assistant tokens will
- # be correct
- output_assisted = model.generate(
- input_ids,
- attention_mask=attention_mask,
- max_length=max_length,
- num_beams=1,
- do_sample=False,
- assistant_model=model,
- output_scores=True,
- output_hidden_states=True,
- output_attentions=True,
- return_dict_in_generate=True,
- )
+ # test output equality of low versus high memory
+ model = model_class(config).to(torch_device).eval()
- try:
- self.assertListEqual(output_greedy.sequences.tolist(), output_assisted.sequences.tolist())
+ low_output = model.generate(
+ input_ids,
+ max_new_tokens=8,
+ num_beams=5,
+ early_stopping=True,
+ low_memory=True,
+ use_cache=True,
+ )
- for output in (output_greedy, output_assisted):
- self._check_outputs(output, input_ids, model.config, use_cache=True)
- except AssertionError:
- failed += 1
- if failed > 1:
- self.assertListEqual(output_greedy.sequences.tolist(), output_assisted.sequences.tolist())
+ high_output = model.generate(
+ input_ids,
+ max_new_tokens=8,
+ num_beams=5,
+ early_stopping=True,
+ low_memory=False,
+ use_cache=True,
+ )
+ self.assertListEqual(low_output.tolist(), high_output.tolist())
- for output in (output_greedy, output_assisted):
- self._check_outputs(output, input_ids, model.config, use_cache=True)
+ @pytest.mark.generate
+ @parameterized.expand([("random",), ("same",)])
+ @is_flaky() # Read NOTE (1) below. If there are API issues, all attempts will fail.
+ def test_assisted_decoding_matches_greedy_search(self, assistant_type):
+ # This test ensures that the assisted generation does not introduce output changes over greedy search.
+ # NOTE (1): The sentence above is true most of the time, there is a tiny difference in the logits due to matmul
+ # shape differences -- and it may result in a different output. The input shape difference happens in the
+ # main model, that runs the forward pass with several candidates at once (as opposed to generating one token at
+ # a time). See https://github.com/huggingface/transformers/issues/25420#issuecomment-1775317535 for more info.
+ # NOTE (2): It breaks the pattern in the tests above, for multiple reasons:
+ # - assisted_decoding, contrarily to the other methods, can't be called on its own (e.g. needs to
+ # prepare the assistant encoder outputs in the main generate body);
+ # - assisted_decoding does not support `use_cache = False`
+ # - assisted_decoding does not support `batch_size > 1`
- # TODO [sasarkar] it is supported now. Enable this test, or delete it if its not applicable
- @pytest.mark.skip(reason="Assisted decoding not yet supported by optimum-habana")
- def test_assisted_decoding_sample(self):
- # In this test we don't check assisted vs non-assisted output -- seeded assisted decoding with sample will not
- # match sample for the same seed, as the forward pass does not return the exact same logits (due to matmul with
- # different shapes, see https://github.com/huggingface/transformers/issues/25420#issuecomment-1775317535).
for model_class in self.all_generative_model_classes:
+ if model_class._is_stateful:
+ self.skipTest(reason="Stateful models don't support assisted generation")
if any(model_name in model_class.__name__.lower() for model_name in ["fsmt", "reformer"]):
- self.skipTest("Won't fix: old model with different cache format")
+ self.skipTest(reason="Won't fix: old model with different cache format")
if any(
model_name in model_class.__name__.lower()
for model_name in [
@@ -1473,16 +1220,15 @@ def test_assisted_decoding_sample(self):
"clvp",
]
):
- self.skipTest("May fix in the future: need model-specific fixes")
+ self.skipTest(reason="May fix in the future: need model-specific fixes")
# enable cache
- config, input_ids, attention_mask, _ = self._get_input_ids_and_config(batch_size=1)
+ config, input_ids, attention_mask, inputs_dict = self._get_input_ids_and_config(batch_size=1)
# NOTE: assisted generation only works with cache on at the moment.
if not hasattr(config, "use_cache"):
- self.skipTest("This model doesn't support caching")
+ self.skipTest(reason=f"{model_class.__name__} doesn't support caching")
- config.use_cache = True
config.is_decoder = True
model = model_class(config).to(torch_device).eval()
# Sets assisted generation arguments such that:
@@ -1491,503 +1237,253 @@ def test_assisted_decoding_sample(self):
# the assistant model is correct
# c) there are at least two forward passes in the main model, to ensure the input preparation of
# the main model is correct
- assistant_model = model
+ generation_kwargs = {
+ "eos_token_id": -1, # see a)
+ "max_new_tokens": 4, # see c)
+ "num_beams": 1,
+ "do_sample": False,
+ "output_scores": True,
+ "output_logits": True,
+ "output_hidden_states": True,
+ "output_attentions": self.has_attentions,
+ "return_dict_in_generate": True,
+ "use_cache": True,
+ }
+ output_greedy = model.generate(
+ input_ids, attention_mask=attention_mask, **generation_kwargs, **inputs_dict
+ )
+
+ # test with the same assistant model or randomly init one
+ # in the first case all candidate tokens are accepted, in the second none is accepted
+ # case when some are accepted and some not is hard to reproduce, so let's hope this catches most errors :)
+ if assistant_type == "random":
+ assistant_model = model_class(config).to(torch_device).eval()
+ else:
+ assistant_model = model
assistant_model.generation_config.num_assistant_tokens = 2 # see b)
assistant_model.generation_config.num_assistant_tokens_schedule = "constant" # see b)
+ generation_kwargs.update({"assistant_model": assistant_model})
+ output_assisted = model.generate(
+ input_ids, attention_mask=attention_mask, **generation_kwargs, **inputs_dict
+ )
+
+ # The two outputs must match and their shape must be as expected
+
+ self.assertListEqual(output_greedy.sequences.tolist(), output_assisted.sequences.tolist())
+ for output in (output_greedy, output_assisted):
+ self._check_outputs(output, input_ids, model.config, use_cache=True)
+
+ @is_flaky()
+ @pytest.mark.generate
+ def test_prompt_lookup_decoding_matches_greedy_search(self):
+ # This test ensures that the prompt lookup generation does not introduce output changes over greedy search.
+ # This test is mostly a copy of test_assisted_decoding_matches_greedy_search
+
+ for model_class in self.all_generative_model_classes:
+ if model_class._is_stateful:
+ self.skipTest(reason="Stateful models don't support assisted generation")
+ if any(model_name in model_class.__name__.lower() for model_name in ["fsmt", "reformer"]):
+ self.skipTest(reason="Won't fix: old model with different cache format")
+ if any(
+ model_name in model_class.__name__.lower()
+ for model_name in [
+ "bigbirdpegasus",
+ "led",
+ "mega",
+ "speech2text",
+ "git",
+ "prophetnet",
+ "seamlessm4t",
+ "clvp",
+ ]
+ ):
+ self.skipTest(reason="May fix in the future: need model-specific fixes")
+
+ # enable cache
+ config, input_ids, attention_mask, inputs_dict = self._get_input_ids_and_config(batch_size=1)
+
+ # NOTE: assisted generation only works with cache on at the moment.
+ if not hasattr(config, "use_cache"):
+ self.skipTest(reason=f"{model_class.__name__} doesn't support caching")
+
+ config.is_decoder = True
+ model = model_class(config).to(torch_device).eval()
+ # Sets assisted generation arguments such that:
+ # a) no EOS is generated, to ensure generation doesn't break early
+ # b) the prompt lookup tries to give the model 2 tokens, to ensure the input preparation of
+ # prompt lookup is correct
+ # c) there are at least two forward passes in the main model, to ensure the input preparation of
+ # the main model is correct
generation_kwargs = {
"eos_token_id": -1, # see a)
"max_new_tokens": 4, # see c)
"num_beams": 1,
- "do_sample": True,
- "assistant_model": assistant_model,
+ "do_sample": False,
"output_scores": True,
+ "output_logits": True,
"output_hidden_states": True,
- "output_attentions": True,
+ "output_attentions": self.has_attentions,
"return_dict_in_generate": True,
+ "use_cache": True,
}
- #######################################################################
- # Monkey patch assisted decoding function till SW issue is resolved
- import copy
- from types import MethodType
- from typing import List, Optional, Union
-
- from transformers.generation.utils import (
- GenerateDecoderOnlyOutput,
- _crop_past_key_values,
- _prepare_attention_mask,
- _prepare_token_type_ids,
- _split_model_outputs,
- )
-
- def _speculative_sampling(
- candidate_input_ids,
- candidate_logits,
- candidate_length,
- new_logits,
- last_assistant_token_is_eos,
- max_matches,
- ):
- """
- Applies sampling as in the speculative decoding paper (https://arxiv.org/pdf/2211.17192.pdf, algorithm 1). Returns
- the selected tokens, as well as the number of candidate matches.
-
- NOTE: Unless otherwise stated, the variable names match those in the paper.
- """
- new_candidate_input_ids = candidate_input_ids[:, -candidate_length:]
- # Gets the probabilities from the logits. q_i and p_i denote the assistant and model probabilities of the tokens
- # selected by the assistant, respectively.
- q = candidate_logits.softmax(dim=-1)
- q_i = q[:, torch.arange(candidate_length), new_candidate_input_ids.squeeze()].squeeze(0, 1)
- p = new_logits.softmax(dim=-1)
- p_i = p[:, torch.arange(candidate_length), new_candidate_input_ids.squeeze()].squeeze(0, 1)
- probability_ratio = p_i / q_i
-
- # When probability_ratio > 1 (i.e. q_i(x) < p_i(x), or "assistant probability of the candidate token is smaller
- # than the model probability for the same token"), keep the token. Otherwise reject with p = 1 - probability_ratio
- # (= keep with p = probability_ratio). Keep all the tokens until the first rejection
- r_i = torch.rand_like(probability_ratio)
- is_accepted = r_i <= probability_ratio
- n_matches = ((~is_accepted).cumsum(dim=-1) < 1).sum() # this is `n` in algorithm 1
-
- # Ensure we don't generate beyond max_len or an EOS token (not in algorithm 1, but needed for correct behavior)
- if last_assistant_token_is_eos and n_matches == candidate_length:
- # Output length is assumed to be `n_matches + 1`. Since we won't generate another token with the target model
- # due to acceptance on EOS we fix `n_matches`
- n_matches -= 1
- valid_tokens = new_candidate_input_ids[:, : n_matches + 1]
- else:
- n_matches = min(n_matches, max_matches)
-
- # Next token selection: if there is a rejection, adjust the distribution from the main model before sampling.
- gamma = min(candidate_logits.shape[1], max_matches)
- p_n_plus_1 = p[:, n_matches, :]
- if n_matches < gamma:
- q_n_plus_1 = q[:, n_matches, :]
- p_prime = torch.clamp((p_n_plus_1 - q_n_plus_1), min=0)
- p_prime.div_(p_prime.sum())
- else:
- p_prime = p_n_plus_1
- t = torch.multinomial(p_prime, num_samples=1).squeeze(1)[None, :]
-
- # The selected tokens include the matches (if any) plus the next sampled tokens
- if n_matches > 0:
- valid_tokens = torch.cat((new_candidate_input_ids[:, :n_matches], t), dim=-1)
- else:
- valid_tokens = t
-
- return valid_tokens, n_matches
-
- def assisted_decoding(
- self,
- input_ids: torch.LongTensor,
- assistant_model: Optional["PreTrainedModel"] = None,
- candidate_generator: Optional["CandidateGenerator"] = None,
- do_sample: bool = False,
- logits_processor: Optional[LogitsProcessorList] = None,
- logits_warper: Optional[LogitsProcessorList] = None,
- stopping_criteria: Optional[StoppingCriteriaList] = None,
- pad_token_id: Optional[int] = None,
- eos_token_id: Optional[Union[int, List[int]]] = None,
- output_attentions: Optional[bool] = None,
- output_hidden_states: Optional[bool] = None,
- output_scores: Optional[bool] = None,
- return_dict_in_generate: Optional[bool] = None,
- synced_gpus: bool = False,
- streamer: Optional["BaseStreamer"] = None,
- **model_kwargs,
- ):
- r"""
- Generates sequences of token ids for models with a language modeling head using **greedy decoding** or
- **sample** (depending on `do_sample`), assisted by candidate sequences. Assisted generation is an example of a
- candidate decoding strategy. Can be used for text-decoder, text-to-text, speech-to-text, and vision-to-text
- models.
-
-
-
- In most cases, you do not need to call [`~generation.GenerationMixin.candidate_decoding`] directly. Use
- generate() instead. For an overview of generation strategies and code examples, check the [following
- guide](../generation_strategies).
-
-
-
- Parameters:
- input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
- The sequence used as a prompt for the generation.
- candidate_generator (`CandidateGenerator`, *optional*):
- A derived instance of [`CandidateGenerator`] that defines how candidate sequences are generated. For
- more information, the documentation of [`CandidateGenerator`] should be read. Only one of `assistant_model` or `candidate_generator` should be passed as input to this function.
- assistant_model (`PreTrainedModel`, *optional*):
- An assistant model that can be used to accelerate generation. The assistant model must have the exact
- same tokenizer. The acceleration is achieved when forecasting candidate tokens with the assistent model
- is much faster than running generation with the model you're calling generate from. As such, the
- assistant model should be much smaller.
- do_sample (`bool`, *optional*, defaults to `False`):
- Whether or not to use sampling ; use greedy decoding otherwise.
- logits_processor (`LogitsProcessorList`, *optional*):
- An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsProcessor`]
- used to modify the prediction scores of the language modeling head applied at each generation step.
- logits_warper (`LogitsProcessorList`, *optional*):
- An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsWarper`] used
- to warp the prediction score distribution of the language modeling head applied before multinomial
- sampling at each generation step.
- stopping_criteria (`StoppingCriteriaList`, *optional*):
- An instance of [`StoppingCriteriaList`]. List of instances of class derived from [`StoppingCriteria`]
- used to tell if the generation loop should stop.
- pad_token_id (`int`, *optional*):
- The id of the *padding* token.
- eos_token_id (`Union[int, List[int]]`, *optional*):
- The id of the *end-of-sequence* token. Optionally, use a list to set multiple *end-of-sequence* tokens.
- output_attentions (`bool`, *optional*, defaults to `False`):
- Whether or not to return the attentions tensors of all attention layers. See `attentions` under
- returned tensors for more details.
- output_hidden_states (`bool`, *optional*, defaults to `False`):
- Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
- for more details.
- output_scores (`bool`, *optional*, defaults to `False`):
- Whether or not to return the prediction scores. See `scores` under returned tensors for more details.
- return_dict_in_generate (`bool`, *optional*, defaults to `False`):
- Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
- synced_gpus (`bool`, *optional*, defaults to `False`):
- Whether to continue running the while loop until max_length (needed for ZeRO stage 3)
- streamer (`BaseStreamer`, *optional*):
- Streamer object that will be used to stream the generated sequences. Generated tokens are passed
- through `streamer.put(token_ids)` and the streamer is responsible for any further processing.
- model_kwargs:
- Additional model specific keyword arguments will be forwarded to the `forward` function of the model.
- If model is an encoder-decoder model the kwargs should include `encoder_outputs`.
-
- Return:
- [`~generation.GenerateDecoderOnlyOutput`], [`~generation.GenerateEncoderDecoderOutput`] or
- `torch.LongTensor`: A `torch.LongTensor` containing the generated tokens (default behaviour) or a
- [`~generation.GenerateDecoderOnlyOutput`] if `model.config.is_encoder_decoder=False` and
- `return_dict_in_generate=True` or a [`~generation.GenerateEncoderDecoderOutput`] if
- `model.config.is_encoder_decoder=True`.
-
- Examples:
-
- ```python
- >>> from transformers import (
- ... AutoTokenizer,
- ... AutoModelForCausalLM,
- ... LogitsProcessorList,
- ... MinLengthLogitsProcessor,
- ... StoppingCriteriaList,
- ... MaxLengthCriteria,
- ... )
-
- >>> tokenizer = AutoTokenizer.from_pretrained("gpt2")
- >>> model = AutoModelForCausalLM.from_pretrained("gpt2")
- >>> assistant_model = AutoModelForCausalLM.from_pretrained("distilgpt2")
- >>> # set pad_token_id to eos_token_id because GPT2 does not have a PAD token
- >>> model.generation_config.pad_token_id = model.generation_config.eos_token_id
- >>> input_prompt = "It might be possible to"
- >>> input_ids = tokenizer(input_prompt, return_tensors="pt").input_ids
- >>> # instantiate logits processors
- >>> logits_processor = LogitsProcessorList(
- ... [
- ... MinLengthLogitsProcessor(10, eos_token_id=model.generation_config.eos_token_id),
- ... ]
- ... )
- >>> stopping_criteria = StoppingCriteriaList([MaxLengthCriteria(max_length=20)])
- >>> outputs = model.assisted_decoding(
- ... input_ids,
- ... assistant_model=assistant_model,
- ... logits_processor=logits_processor,
- ... stopping_criteria=stopping_criteria,
- ... )
- >>> tokenizer.batch_decode(outputs, skip_special_tokens=True)
- ["It might be possible to get a better understanding of the nature of the problem, but it's not"]
- ```"""
- # handling deprecated arguments
- if (assistant_model is None) == (candidate_generator is None):
- raise ValueError(
- "One (and only one) of `assistant_model` and `candidate_generator` should be defined."
- )
+ output_greedy = model.generate(
+ input_ids, attention_mask=attention_mask, **generation_kwargs, **inputs_dict
+ )
- if assistant_model is not None:
- candidate_generator = AssistedCandidateGenerator(
- input_ids=input_ids,
- assistant_model=assistant_model,
- logits_processor=logits_processor,
- model_kwargs=model_kwargs,
- eos_token_id=eos_token_id,
- )
- warnings.warn(
- "Passing `assistant_model` to `assisted_decoding` is deprecated and will be removed in v4.38. "
- "Pass the `candidate_generator` argument instead.",
- FutureWarning,
- )
+ generation_kwargs.update({"prompt_lookup_num_tokens": 2}) # see b)
+ output_prompt_lookup = model.generate(
+ input_ids, attention_mask=attention_mask, **generation_kwargs, **inputs_dict
+ )
- # init values
- logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
- logits_warper = logits_warper if logits_warper is not None else LogitsProcessorList()
- stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()
- pad_token_id = pad_token_id if pad_token_id is not None else self.generation_config.pad_token_id
- eos_token_id = eos_token_id if eos_token_id is not None else self.generation_config.eos_token_id
- if eos_token_id is not None and pad_token_id is None:
- raise ValueError("If `eos_token_id` is defined, make sure that `pad_token_id` is defined.")
- if isinstance(eos_token_id, int):
- eos_token_id = [eos_token_id]
- eos_token_id_tensor = (
- torch.tensor(eos_token_id).to(input_ids.device) if eos_token_id is not None else None
- )
- output_scores = output_scores if output_scores is not None else self.generation_config.output_scores
- output_attentions = (
- output_attentions if output_attentions is not None else self.generation_config.output_attentions
- )
- output_hidden_states = (
- output_hidden_states
- if output_hidden_states is not None
- else self.generation_config.output_hidden_states
- )
- return_dict_in_generate = (
- return_dict_in_generate
- if return_dict_in_generate is not None
- else self.generation_config.return_dict_in_generate
- )
+ # The two outputs must match and their shape must be as expected
- # init attention / hidden states / scores tuples
- scores = () if (return_dict_in_generate and output_scores) else None
- decoder_attentions = () if (return_dict_in_generate and output_attentions) else None
- cross_attentions = () if (return_dict_in_generate and output_attentions) else None
- decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None
+ self.assertListEqual(output_greedy.sequences.tolist(), output_prompt_lookup.sequences.tolist())
+ for output in (output_greedy, output_prompt_lookup):
+ self._check_outputs(output, input_ids, model.config, use_cache=True)
- # if model is an encoder-decoder, retrieve encoder attention weights and hidden states
- if return_dict_in_generate and self.config.is_encoder_decoder:
- encoder_attentions = (
- model_kwargs["encoder_outputs"].get("attentions") if output_attentions else None
- )
- encoder_hidden_states = (
- model_kwargs["encoder_outputs"].get("hidden_states") if output_hidden_states else None
- )
+ @pytest.mark.generate
+ def test_dola_decoding_sample(self):
+ # TODO (joao): investigate skips, try to reduce incompatibilities
+ for model_class in self.all_generative_model_classes:
+ if model_class._is_stateful:
+ self.skipTest(reason="Stateful models don't support DoLa decoding")
- # keep track of which sequences are already finished
- unfinished_sequences = input_ids.new(input_ids.shape[0]).fill_(1)
-
- # other auxiliary variables
- max_len = stopping_criteria[0].max_length
-
- this_peer_finished = False # used by synced_gpus only
- while True:
- if synced_gpus:
- # Under synced_gpus the `forward` call must continue until all gpus complete their sequence.
- # The following logic allows an early break if all peers finished generating their sequence
- this_peer_finished_flag = torch.tensor(0.0 if this_peer_finished else 1.0).to(input_ids.device)
- # send 0.0 if we finished, 1.0 otherwise
- torch.dist.all_reduce(this_peer_finished_flag, op=torch.dist.ReduceOp.SUM)
- # did all peers finish? the reduced sum will be 0.0 then
- if this_peer_finished_flag.item() == 0.0:
- break
-
- cur_len = input_ids.shape[-1]
-
- # 1. Fetch candidate sequences from a `CandidateGenerator`
- candidate_input_ids, candidate_logits = candidate_generator.get_candidates(input_ids)
- candidate_input_ids = candidate_input_ids.to(self.device)
- if candidate_logits is not None:
- candidate_logits = candidate_logits.to(self.device)
-
- candidate_length = candidate_input_ids.shape[1] - input_ids.shape[1]
- last_assistant_token_is_eos = (
- ~candidate_input_ids[:, -1]
- .tile(eos_token_id_tensor.shape[0], 1)
- .ne(eos_token_id_tensor.unsqueeze(1))
- .prod(dim=0)
- .bool()
- )
+ if any(model_name in model_class.__name__.lower() for model_name in ["reformer"]):
+ self.skipTest("Skip Reformer as the lm_head input size is 2 * hidden size, adopted from Rev Nets.")
- # 2. Use the original model to obtain the next token logits given the candidate sequence. We obtain
- # `candidate_length + 1` relevant logits from this process: in the event that all candidates are correct,
- # we use this forward pass to also pick the subsequent logits in the original model.
+ if any(model_name in model_class.__name__.lower() for model_name in ["marian", "mbart", "pegasus"]):
+ self.skipTest("DoLa is not supported for models that don't return layerwise hidden states")
- # 2.1. Prepare the model inputs
- candidate_kwargs = copy.copy(model_kwargs)
- candidate_kwargs = _prepare_attention_mask(
- candidate_kwargs, candidate_input_ids.shape[1], self.config.is_encoder_decoder
- )
- candidate_kwargs = _prepare_token_type_ids(candidate_kwargs, candidate_input_ids.shape[1])
+ # enable cache if the model is not openai-gpt, xlnet, cpm, or xlm
+ config, input_ids, attention_mask, inputs_dict = self._get_input_ids_and_config()
- model_inputs = self.prepare_inputs_for_generation(candidate_input_ids, **candidate_kwargs)
+ # Encoder-decoder models are not supported
+ if config.is_encoder_decoder:
+ self.skipTest("DoLa is not supported for encoder-decoder models")
+ config.is_decoder = True
+ model = model_class(config).to(torch_device).eval()
- # 2.2. Run a forward pass on the candidate sequence
- outputs = self(
- **model_inputs,
- output_attentions=output_attentions,
- output_hidden_states=output_hidden_states,
- )
+ if model.get_output_embeddings() is None:
+ self.skipTest("DoLa is not supported for models that don't have output embeddings")
+ # Sets dola generation arguments such that:
+ # a) no EOS is generated, to ensure generation doesn't break early
+ # b) there are at least two forward passes in the main model, to ensure the input preparation of
+ # the main model is correct
+ generation_kwargs = {
+ "eos_token_id": -1, # see a)
+ "max_new_tokens": 4, # see b)
+ "num_beams": 1,
+ "do_sample": True,
+ "output_scores": True,
+ "output_logits": True,
+ "output_hidden_states": True,
+ "output_attentions": self.has_attentions,
+ "return_dict_in_generate": True,
+ "use_cache": hasattr(config, "use_cache"), # Some models don't support the cache
+ }
+ generation_kwargs.update({"dola_layers": "low"})
+ model_kwargs = {"attention_mask": attention_mask} if attention_mask is not None else {}
+ output_dola = model.generate(input_ids, **model_kwargs, **generation_kwargs, **inputs_dict)
+ self._check_outputs(output_dola, input_ids, model.config, use_cache=hasattr(config, "use_cache"))
- # 2.3. Process the new logits
- new_logits = outputs.logits[:, -candidate_length - 1 :] # excludes the input prompt if present
- if len(logits_processor) > 0:
- for i in range(candidate_length + 1):
- new_logits[:, i, :] = logits_processor(
- candidate_input_ids[:, : cur_len + i], new_logits[:, i, :]
- )
- if len(logits_warper) > 0:
- for i in range(candidate_length + 1):
- new_logits[:, i, :] = logits_warper(
- candidate_input_ids[:, : cur_len + i], new_logits[:, i, :]
- )
+ @pytest.mark.generate
+ def test_assisted_decoding_sample(self):
+ # In this test we don't check assisted vs non-assisted output -- seeded assisted decoding with sample will not
+ # match sample for the same seed, as the forward pass does not return the exact same logits (due to matmul with
+ # different shapes, see https://github.com/huggingface/transformers/issues/25420#issuecomment-1775317535).
+ for model_class in self.all_generative_model_classes:
+ if model_class._is_stateful:
+ self.skipTest(reason="Stateful models don't support assisted generation")
+ if any(model_name in model_class.__name__.lower() for model_name in ["fsmt", "reformer"]):
+ self.skipTest(reason="Won't fix: old model with different cache format")
+ if any(
+ model_name in model_class.__name__.lower()
+ for model_name in [
+ "bigbirdpegasus",
+ "led",
+ "mega",
+ "speech2text",
+ "git",
+ "prophetnet",
+ "seamlessm4t",
+ "clvp",
+ ]
+ ):
+ self.skipTest(reason="May fix in the future: need model-specific fixes")
- # 3. Select the accepted tokens. There are two possible cases:
- # Case 1: `do_sample=True` and we have logits for the candidates (originally from speculative decoding)
- # 👉 Apply algorithm 1 from the speculative decoding paper (https://arxiv.org/pdf/2211.17192.pdf).
- max_matches = max_len - cur_len - 1
- if do_sample and candidate_logits is not None:
- valid_tokens, n_matches = _speculative_sampling(
- candidate_input_ids,
- candidate_logits,
- candidate_length,
- new_logits,
- last_assistant_token_is_eos,
- max_matches,
- )
+ # enable cache
+ config, input_ids, attention_mask, inputs_dict = self._get_input_ids_and_config(batch_size=1)
- # Case 2: all other cases (originally from assisted generation) 👉 Compare the tokens selected from the
- # original model logits with the candidate tokens. We can keep the candidate tokens until the first
- # mismatch, or until the max length is reached.
- else:
- if do_sample:
- probs = new_logits.softmax(dim=-1)
- selected_tokens = torch.multinomial(probs[0, :, :], num_samples=1).squeeze(1)[None, :]
- else:
- selected_tokens = new_logits.argmax(dim=-1)
-
- candidate_new_tokens = candidate_input_ids[:, cur_len:]
- n_matches = ((~(candidate_new_tokens == selected_tokens[:, :-1])).cumsum(dim=-1) < 1).sum()
-
- # Ensure we don't generate beyond max_len or an EOS token
- if last_assistant_token_is_eos and n_matches == candidate_length:
- n_matches -= 1
- n_matches = min(n_matches, max_matches)
- valid_tokens = selected_tokens[:, : n_matches + 1]
-
- # 4. Update variables according to the number of matching assistant tokens. Remember: the token generated
- # by the model after the last candidate match is also valid, as it is generated from a correct sequence.
- # Because of this last token, assisted generation search reduces to a normal greedy search/sample if there
- # is no match.
-
- # 4.1. Get the valid continuation, after the matching tokens
- input_ids = torch.cat((input_ids, valid_tokens), dim=-1)
- if streamer is not None:
- streamer.put(valid_tokens.cpu())
- new_cur_len = input_ids.shape[-1]
-
- # 4.2. Discard past key values relative to unused assistant tokens
- new_cache_size = new_cur_len - 1
- outputs.past_key_values = _crop_past_key_values(self, outputs.past_key_values, new_cache_size)
-
- # 5. Update the candidate generation strategy if needed
- candidate_generator.update_candidate_strategy(input_ids, new_logits, n_matches)
-
- if synced_gpus and this_peer_finished:
- continue # don't waste resources running the code we don't need
-
- # Store scores, attentions and hidden_states when required
- # Assistant: modified to append one tuple element per token, as in the other generation methods.
- if return_dict_in_generate:
- if output_scores:
- scores += tuple(new_logits[:, i, :] for i in range(n_matches + 1))
-
- if "past_key_values" not in model_kwargs:
- added_len = new_cur_len
- else:
- added_len = n_matches + 1
-
- if output_attentions:
- if self.config.is_encoder_decoder:
- cross_attentions = _split_model_outputs(
- cross_attentions, outputs.cross_attentions, cur_len, added_len
- )
- decoder_attentions = _split_model_outputs(
- decoder_attentions,
- outputs.decoder_attentions,
- cur_len,
- added_len,
- is_decoder_attention=True,
- )
- else:
- decoder_attentions = _split_model_outputs(
- decoder_attentions,
- outputs.attentions,
- cur_len,
- added_len,
- is_decoder_attention=True,
- )
- if output_hidden_states:
- if self.config.is_encoder_decoder:
- decoder_hidden_states = _split_model_outputs(
- decoder_hidden_states, outputs.decoder_hidden_states, cur_len, added_len
- )
- else:
- decoder_hidden_states = _split_model_outputs(
- decoder_hidden_states, outputs.hidden_states, cur_len, added_len
- )
-
- model_kwargs = self._update_model_kwargs_for_generation(
- outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder
- )
+ # NOTE: assisted generation only works with cache on at the moment.
+ if not hasattr(config, "use_cache"):
+ self.skipTest(reason=f"{model_class.__name__} doesn't support caching")
- # if eos_token was found in one sentence, set sentence to finished
- if eos_token_id_tensor is not None:
- unfinished_sequences = unfinished_sequences.mul(
- input_ids[:, -1]
- .tile(eos_token_id_tensor.shape[0], 1)
- .ne(eos_token_id_tensor.unsqueeze(1))
- .prod(dim=0)
- )
+ config.is_decoder = True
+ model = model_class(config).to(torch_device).eval()
+ # Sets assisted generation arguments such that:
+ # a) no EOS is generated, to ensure generation doesn't break early
+ # b) the assistant model always generates two tokens when it is called, to ensure the input preparation of
+ # the assistant model is correct
+ # c) there are at least two forward passes in the main model, to ensure the input preparation of
+ # the main model is correct
+ assistant_model = model
+ assistant_model.generation_config.num_assistant_tokens = 2 # see b)
+ assistant_model.generation_config.num_assistant_tokens_schedule = "constant" # see b)
+ generation_kwargs = {
+ "eos_token_id": -1, # see a)
+ "max_new_tokens": 4, # see c)
+ "num_beams": 1,
+ "do_sample": True,
+ "assistant_model": assistant_model,
+ "output_scores": True,
+ "output_logits": True,
+ "output_hidden_states": True,
+ "output_attentions": self.has_attentions,
+ "return_dict_in_generate": True,
+ "use_cache": True,
+ }
+ output_assisted = model.generate(
+ input_ids, attention_mask=attention_mask, **generation_kwargs, **inputs_dict
+ )
- # stop when each sentence is finished
- if unfinished_sequences.max() == 0:
- this_peer_finished = True
-
- # stop if we exceed the maximum length
- if stopping_criteria(input_ids, scores):
- this_peer_finished = True
-
- if this_peer_finished and not synced_gpus:
- break
-
- if streamer is not None:
- streamer.end()
-
- if return_dict_in_generate:
- if self.config.is_encoder_decoder:
- return GenerateEncoderDecoderOutput(
- sequences=input_ids,
- scores=scores,
- encoder_attentions=encoder_attentions,
- encoder_hidden_states=encoder_hidden_states,
- decoder_attentions=decoder_attentions,
- cross_attentions=cross_attentions,
- decoder_hidden_states=decoder_hidden_states,
- past_key_values=model_kwargs.get("past_key_values"),
- )
- else:
- return GenerateDecoderOnlyOutput(
- sequences=input_ids,
- scores=scores,
- attentions=decoder_attentions,
- hidden_states=decoder_hidden_states,
- past_key_values=model_kwargs.get("past_key_values"),
- )
- else:
- return input_ids
+ self._check_outputs(output_assisted, input_ids, config, use_cache=True)
+
+ @pytest.mark.generate
+ def test_prompt_lookup_decoding_stops_at_eos(self):
+ # This test ensures that the prompt lookup generation stops at eos token and does not suggest more tokens
+ # (see https://github.com/huggingface/transformers/pull/31301)
+
+ # The main idea is to have an ngram (unigram in our case) that is repeated twice in the input ids.
+ # First time at the very end, so input ends with the unigrams, and second any arbitrary location.
+ # Also, we need an EOS token which will be injected just after the arbitrary located ngram.
+ # We verify that PLD will not copy and propose candidated that contain an EOS token, even if there are overlapping ngrams
+ # in input ids. Otherwise a proposed EOS along with the trailing (ngrams-1) tokens might be accepted by the target model.
+ # That seems as if the model "generated" and EOS but didn't stop from user's perspective
- model.assisted_decoding = MethodType(assisted_decoding, model)
+ input_ids = torch.randint(1, 50, (1, 10), device=torch_device) # generate inputs in range from 1-50
+ arbitrary_ngram = 51 # this is the arbitrary ngram, specifically chosen OOV to prevent flaky tests
+ input_ids[:, 3] = arbitrary_ngram # set pre-eos to arbitrary_ngram which is for sure not present in inputs
+ input_ids[:, -1] = arbitrary_ngram # put arbitrary_ngram in the end for the necessary match to happen
- #######################################################################
+ eos_token_id = torch.tensor([0], device=torch_device)
+ input_ids[:, 4] = eos_token_id # inject eos-token-id in input ids so that it is located after arbitrary_ngram
- output_assisted = model.generate(input_ids, attention_mask=attention_mask, **generation_kwargs)
+ # init cand geenerator with max_matching_ngram_size=1 to match per-token
+ candidate_generator = PromptLookupCandidateGenerator(
+ eos_token_id=eos_token_id, num_output_tokens=4, max_matching_ngram_size=1
+ )
+ output_prompt_lookup = candidate_generator.get_candidates(input_ids)[0]
- self._check_outputs(output_assisted, input_ids, model.config, use_cache=True)
+ # PLD shouldn't propose any new tokens based on eos-match
+ self.assertTrue(output_prompt_lookup.shape[-1] == 10)
+ @pytest.mark.generate
def test_generate_with_head_masking(self):
"""Test designed for encoder-decoder models to ensure the attention head masking is used."""
attention_names = ["encoder_attentions", "decoder_attentions", "cross_attentions"]
for model_class in self.all_generative_model_classes:
- config, input_ids, attention_mask, max_length = self._get_input_ids_and_config()
+ config, input_ids, attention_mask, inputs_dict = self._get_input_ids_and_config()
# We want to test only encoder-decoder models
if not config.is_encoder_decoder:
continue
@@ -2013,60 +1509,93 @@ def test_generate_with_head_masking(self):
input_ids,
attention_mask=attention_mask,
num_beams=1,
- output_attentions=True,
+ output_attentions=self.has_attentions,
return_dict_in_generate=True,
remove_invalid_values=True,
**{name: mask},
+ **inputs_dict,
)
# We check the state of decoder_attentions and cross_attentions just from the last step
attn_weights = out[attn_name] if attn_name == attention_names[0] else out[attn_name][-1]
self.assertEqual(sum([w.sum().item() for w in attn_weights]), 0.0)
+ @pytest.mark.generate
def test_left_padding_compatibility(self):
- # The check done in this test is fairly difficult -- depending on the model architecture, passing the right
- # position index for the position embeddings can still result in a different output, due to numerical masking.
- # On the other hand, for some types of position embeddings, an incorrect position index can have a minimal
- # impact on the output.
- # There are two tricks employed to check whether left-padding compatibility is in place:
- # 1 - To reduce the negative impact of the numerical attention mask on a correct position index, we set the
- # padding size to 1.
- # 2 - To reduce the chance of false positives (i.e. passing when it should be failing), we run the check
- # multiple times with random inputs, and it has to pass with all of them.
- # NOTE: because of 2), there is some chance of false positives in this test.
+ # NOTE: left-padding results in small numerical differences. This is expected.
+ # See https://github.com/huggingface/transformers/issues/25420#issuecomment-1775317535
+
+ # First, filter out models that don't support left padding
+ # - The model must have generative capabilities
+ if len(self.all_generative_model_classes) == 0:
+ self.skipTest(reason="No generative architecture available for this model.")
+ # - The model must support padding
+ if not self.has_attentions:
+ self.skipTest(reason="This model doesn't support padding.")
+
+ # - The model must be a decoder-only architecture (encoder-based architectures use right-padding)
+ decoder_only_classes = []
for model_class in self.all_generative_model_classes:
config, _, _, _ = self._get_input_ids_and_config()
if config.is_encoder_decoder:
- continue # skip for encoder-decoder models -- they don't need left-padding compatibility
+ continue
+ else:
+ decoder_only_classes.append(model_class)
+ if len(decoder_only_classes) == 0:
+ self.skipTest(reason="No decoder-only architecture available for this model.")
+
+ # - Decoder-only architectures derived from encoder-decoder models could support it in theory, but we haven't
+ # added support for it yet. We skip these models for now.
+ has_encoder_attributes = any(
+ attr_name
+ for attr_name in config.to_dict().keys()
+ if attr_name.startswith("encoder") and attr_name != "encoder_no_repeat_ngram_size"
+ )
+ if has_encoder_attributes:
+ self.skipTest(
+ reason="The decoder-only derived from encoder-decoder models are not expected to support left-padding."
+ )
+
+ # Then, test left-padding
+ def _prepare_model_kwargs(input_ids, attention_mask, signature):
+ model_kwargs = {"input_ids": input_ids, "attention_mask": attention_mask}
+ if "position_ids" in signature:
+ position_ids = torch.cumsum(attention_mask, dim=-1) - 1
+ position_ids.masked_fill_(attention_mask == 0, 1)
+ model_kwargs["position_ids"] = position_ids
+ if "cache_position" in signature:
+ cache_position = torch.arange(input_ids.shape[-1], device=torch_device)
+ model_kwargs["cache_position"] = cache_position
+ return model_kwargs
+
+ for model_class in decoder_only_classes:
+ config, input_ids, attention_mask, _ = self._get_input_ids_and_config()
model = model_class(config).to(torch_device).eval()
signature = inspect.signature(model.forward).parameters.keys()
- no_failures = True
- for _ in range(10): # there may be false positives with 10 runs, we rely on the CI to catch the flakiness
- _, input_ids, attention_mask, _ = self._get_input_ids_and_config()
- model_kwargs = {"input_ids": input_ids, "attention_mask": attention_mask}
- if "position_ids" in signature:
- position_ids = torch.cumsum(attention_mask, dim=-1) - 1
- position_ids.masked_fill_(attention_mask == 0, 1)
- model_kwargs["position_ids"] = position_ids
- next_logits_wo_padding = model(**model_kwargs).logits[:, -1, :]
-
- pad_size = (input_ids.shape[0], 1)
- padding = torch.ones(pad_size, dtype=input_ids.dtype, device=torch_device) * config.pad_token_id
- padded_input_ids = torch.cat((padding, input_ids), dim=1)
- padded_attention_mask = torch.cat((torch.zeros_like(padding), attention_mask), dim=1)
- model_kwargs = {"input_ids": padded_input_ids, "attention_mask": padded_attention_mask}
- if "position_ids" in signature:
- position_ids = torch.cumsum(padded_attention_mask, dim=-1) - 1
- position_ids.masked_fill_(padded_attention_mask == 0, 1)
- model_kwargs["position_ids"] = position_ids
- next_logits_with_padding = model(**model_kwargs).logits[:, -1, :]
- if not torch.allclose(next_logits_wo_padding, next_logits_with_padding, atol=1e-7):
- no_failures = False
- break
-
- self.assertTrue(no_failures)
+ # no cache as some models require special cache classes to be init outside forward
+ model.generation_config.use_cache = False
+
+ # Without padding
+ model_kwargs = _prepare_model_kwargs(input_ids, attention_mask, signature)
+ next_logits_wo_padding = model(**model_kwargs).logits[:, -1, :]
+
+ # With left-padding (length 32)
+ # can hardcode pad_token to be 0 as we'll do attn masking anyway
+ pad_token_id = (
+ config.get_text_config().pad_token_id if config.get_text_config().pad_token_id is not None else 0
+ )
+ pad_size = (input_ids.shape[0], 32)
+ padding = torch.ones(pad_size, dtype=input_ids.dtype, device=torch_device) * pad_token_id
+ padded_input_ids = torch.cat((padding, input_ids), dim=1)
+ padded_attention_mask = torch.cat((torch.zeros_like(padding), attention_mask), dim=1)
+ model_kwargs = _prepare_model_kwargs(padded_input_ids, padded_attention_mask, signature)
+ next_logits_with_padding = model(**model_kwargs).logits[:, -1, :]
+
+ # They should result in very similar logits
+ self.assertTrue(torch.allclose(next_logits_wo_padding, next_logits_with_padding, atol=1e-5))
+ @pytest.mark.generate
def test_past_key_values_format(self):
# Test that the KV cache is formatted correctly. Exceptions need to explicitly overwrite this test. Having a
# standard KV cache format is important for a consistent API (and for advanced generation methods).
@@ -2075,7 +1604,7 @@ def test_past_key_values_format(self):
# If it doesn't support cache, pass the test
if not hasattr(config, "use_cache"):
- return
+ self.skipTest(reason=f"{model_class.__name__} doesn't support caching")
model = model_class(config).to(torch_device)
if "use_cache" not in inputs:
@@ -2084,7 +1613,7 @@ def test_past_key_values_format(self):
# If "past_key_values" is not returned, pass the test (e.g. RWKV uses a different cache name and format)
if "past_key_values" not in outputs:
- return
+ self.skipTest(reason="This model doesn't return `past_key_values`")
num_hidden_layers = (
getattr(config, "decoder_layers", None)
@@ -2138,6 +1667,7 @@ def test_past_key_values_format(self):
past_kv[i][1].shape, (batch_size, num_attention_heads, seq_length, per_head_embed_dim)
)
+ @pytest.mark.generate
def test_generate_from_inputs_embeds_decoder_only(self):
# When supported, tests that the decoder model can generate from `inputs_embeds` instead of `input_ids`
# if fails, you should probably update the `prepare_inputs_for_generation` function
@@ -2164,106 +1694,587 @@ def test_generate_from_inputs_embeds_decoder_only(self):
continue
# Traditional way of generating text
- outputs_from_ids = model.generate(input_ids)
- self.assertEqual(outputs_from_ids.shape, (2, 20))
+ outputs_from_ids = model.generate(
+ input_ids, max_new_tokens=5, return_dict_in_generate=True, output_scores=True
+ )
+ self.assertEqual(outputs_from_ids.sequences.shape, (input_ids.shape[0], input_ids.shape[1] + 5))
# Same thing, but from input embeddings (`input_ids` is passed so the prompt is present in the output)
inputs_embeds = model.get_input_embeddings()(input_ids)
- outputs_from_embeds = model.generate(input_ids, inputs_embeds=inputs_embeds)
- self.assertListEqual(outputs_from_ids.tolist(), outputs_from_embeds.tolist())
+ outputs_from_embeds = model.generate(
+ input_ids,
+ inputs_embeds=inputs_embeds,
+ max_new_tokens=5,
+ return_dict_in_generate=True,
+ output_scores=True,
+ )
+ self.assertListEqual(outputs_from_ids.sequences.tolist(), outputs_from_embeds.sequences.tolist())
- # But if we pass different inputs_embeds, we should get different outputs
- torch.manual_seed(0)
+ # But if we pass different inputs_embeds, we should get different outputs (the output text may be the
+ # same, but the logits will almost surely be different)
random_embeds = torch.rand_like(inputs_embeds)
- outputs_from_rand_embeds = model.generate(input_ids, inputs_embeds=random_embeds)
- with self.assertRaises(AssertionError):
- self.assertListEqual(outputs_from_rand_embeds.tolist(), outputs_from_embeds.tolist())
+ outputs_from_rand_embeds = model.generate(
+ input_ids,
+ inputs_embeds=random_embeds,
+ max_new_tokens=5,
+ return_dict_in_generate=True,
+ output_scores=True,
+ )
+ for i in range(len(outputs_from_rand_embeds.scores)):
+ self.assertFalse(torch.allclose(outputs_from_embeds.scores[i], outputs_from_rand_embeds.scores[i]))
# input_ids is not a required input -- if we don't pass it, the newly generated tokens will be the same
outputs_from_embeds_wo_ids = model.generate(
- inputs_embeds=inputs_embeds, max_new_tokens=20 - inputs_embeds.shape[1]
+ inputs_embeds=inputs_embeds, max_new_tokens=5, return_dict_in_generate=True, output_scores=True
)
self.assertListEqual(
- outputs_from_embeds[:, inputs_embeds.shape[1] :].tolist(),
- outputs_from_embeds_wo_ids.tolist(),
+ outputs_from_embeds.sequences[:, inputs_embeds.shape[1] :].tolist(),
+ outputs_from_embeds_wo_ids.sequences.tolist(),
)
- def _check_outputs(self, output, input_ids, config, use_cache=False, num_return_sequences=1):
- batch_size, seq_length = input_ids.shape
- num_sequences_in_output = batch_size * num_return_sequences
- gen_len = (
- output.sequences.shape[-1] - 1 if config.is_encoder_decoder else output.sequences.shape[-1] - seq_length
- )
+ @pytest.mark.generate
+ def test_generate_from_inputs_embeds_with_static_cache(self):
+ """
+ Test that StaticCache can generate from inputs_embeds and calculates max_cache_length
+ correctly in `generate()`. We force the model to not stop generation until max-length is reached
+ to verify that the cache length is indeed set correctly and we don't run out of index when slicing the cache.
+ """
+ for model_class in self.all_generative_model_classes:
+ if not model_class._supports_static_cache:
+ self.skipTest(reason="This model does not support the static cache format")
- # scores
- self._check_scores(num_sequences_in_output, output.scores, length=gen_len, config=config)
+ config, input_ids, attention_mask, inputs_dict = self._get_input_ids_and_config()
+ if config.is_encoder_decoder:
+ self.skipTest(reason="This model is encoder-decoder and has Encoder-Decoder Cache")
- # Attentions
- if config.is_encoder_decoder:
- # encoder
- self._check_encoder_attention_for_generate(output.encoder_attentions, batch_size, config, seq_length)
- # decoder
- self._check_attentions_for_generate(
- num_sequences_in_output,
- output.decoder_attentions,
- min_length=1,
- max_length=output.sequences.shape[-1],
- config=config,
- use_cache=use_cache,
+ model = model_class(config).to(torch_device).eval()
+ if "inputs_embeds" not in inspect.signature(model.prepare_inputs_for_generation).parameters.keys():
+ self.skipTest(reason="This model does not support `inputs_embeds` in generation")
+
+ model.config.use_cache = True
+ model.config.is_decoder = True
+ batch_size, seq_length = input_ids.shape
+ max_cache_len = 30
+
+ # here we force to not stop at eos and go until max-length
+ model.generation_config.eos_token_id = model.config.get_text_config().eos_token_id = -1
+ generation_kwargs = {
+ "max_length": max_cache_len,
+ "cache_implementation": "static",
+ "return_dict_in_generate": True, # Required to return `past_key_values`
+ }
+
+ text_config = model.config.get_text_config()
+ head_dim = (
+ text_config.head_dim
+ if hasattr(text_config, "head_dim")
+ else text_config.hidden_size // text_config.num_attention_heads
)
- else:
- # if use_cache first input is equal to no use_cache, so skip here
- attentions = output.attentions if not use_cache else output.attentions[1:]
- min_length = seq_length if not use_cache else seq_length + 1
- self._check_attentions_for_generate(
- num_sequences_in_output,
- attentions=attentions,
- min_length=min_length,
- max_length=output.sequences.shape[-1],
- config=config,
- use_cache=use_cache,
+ num_key_value_heads = (
+ text_config.num_attention_heads
+ if getattr(text_config, "num_key_value_heads", None) is None
+ else text_config.num_key_value_heads
)
+ num_hidden_layers = text_config.num_hidden_layers
- # Hidden States
- if config.is_encoder_decoder:
- # encoder
- self._check_encoder_hidden_states_for_generate(
- output.encoder_hidden_states, batch_size, config, seq_length
+ inputs_embeds = model.get_input_embeddings()(input_ids)
+ outputs = model.generate(
+ inputs_embeds=inputs_embeds, attention_mask=attention_mask, **generation_kwargs, **inputs_dict
)
- # decoder
- self._check_hidden_states_for_generate(
- num_sequences_in_output,
- output.decoder_hidden_states,
- min_length=1,
- max_length=output.sequences.shape[-1],
- config=config,
- use_cache=use_cache,
- )
- else:
- # if use_cache first input is equal to no use_cache, so skip here
- hidden_states = output.hidden_states if not use_cache else output.hidden_states[1:]
- min_length = seq_length if not use_cache else seq_length + 1
- self._check_hidden_states_for_generate(
- num_sequences_in_output,
- hidden_states,
- min_length=min_length,
- max_length=output.sequences.shape[-1],
- config=config,
- use_cache=use_cache,
- )
+ # we should get `max_length` in shape, not `max_length - embeds_length`
+ cache_shape = (batch_size, num_key_value_heads, max_cache_len, head_dim)
+ self.assertTrue(isinstance(outputs.past_key_values, StaticCache))
+ self.assertTrue(len(outputs.past_key_values.key_cache) == num_hidden_layers)
+ self.assertTrue(outputs.past_key_values.key_cache[0].shape == cache_shape)
- def _check_scores(self, batch_size, scores, length, config):
- expected_shape = (batch_size, config.vocab_size)
- self.assertIsInstance(scores, tuple)
- self.assertEqual(len(scores), length)
- self.assertListEqual([iter_scores.shape for iter_scores in scores], [expected_shape] * len(scores))
+ @pytest.mark.generate
+ def test_generate_continue_from_past_key_values(self):
+ # Tests that we can continue generating from past key values, returned from a previous `generate` call
+ for model_class in self.all_generative_model_classes:
+ if any(model_name in model_class.__name__.lower() for model_name in ["imagegpt"]):
+ self.skipTest(reason="Won't fix: old model with unique inputs/caches/other")
+ if any(model_name in model_class.__name__.lower() for model_name in ["umt5"]):
+ self.skipTest(reason="TODO: needs modeling or test input preparation fixes for compatibility")
- def _check_attentions_for_generate(
- self, batch_size, attentions, min_length, max_length, config, use_cache=False, num_beam_groups=1
- ):
- self.assertIsInstance(attentions, tuple)
- self.assertListEqual(
+ config, inputs = self.model_tester.prepare_config_and_inputs_for_common()
+
+ if not hasattr(config, "use_cache"):
+ self.skipTest(reason=f"{model_class.__name__} doesn't support caching")
+
+ # Let's make it always:
+ # 1. use cache (for obvious reasons)
+ # 2. generate to max length (which can be achieved by setting the eos token to an invalid value), which
+ # would make the test flaky (e.g. EOS is generated on iteration 1 on both generations, but the
+ # continuation would force it to generate beyond an EOS token)
+ # 3. ignore `token_type_ids` for simplicity
+ # 4. ignore `forced_eos_token_id`, which requires further manipulation of the continuation inputs and is
+ # active by default on some models
+ # 5. ignore `encoder_no_repeat_ngram_size`, which is set by default in some encoder-decoder models. When
+ # we use their decoder as a stand-alone model, `encoder_no_repeat_ngram_size` actually prevents
+ # repetition exclusively from the prompt. This test relies on comparing one call vs 2 calls
+ # with cache, what is considered a prompt is different in the two cases.
+
+ if "token_type_ids" in inputs:
+ del inputs["token_type_ids"]
+
+ model = model_class(config).to(torch_device)
+ model.eval()
+ model.generation_config.pad_token_id = model.generation_config.eos_token_id = -1
+ model.generation_config.forced_eos_token_id = None
+ model.generation_config.encoder_no_repeat_ngram_size = 0
+ model.generation_config.use_cache = True
+
+ # If "past_key_values" is not returned, skip the test (e.g. RWKV uses a different cache name and format)
+ outputs = model(**inputs)
+ if "past_key_values" not in outputs:
+ self.skipTest(reason="This model doesn't return `past_key_values`")
+
+ # Traditional way of generating text, with `return_dict_in_generate` to return the past key values
+ outputs = model.generate(**inputs, do_sample=False, max_new_tokens=4, return_dict_in_generate=True)
+
+ # Let's generate again, but passing the past key values in between (3 + 1 = 4 tokens). Note that the
+ # inputs may need to be tweaked across `generate` calls (like the attention mask).
+ outputs_cached = model.generate(**inputs, do_sample=False, max_new_tokens=3, return_dict_in_generate=True)
+
+ # Continue from the tokens generated above, preparing the inputs accordingly
+ inputs["past_key_values"] = outputs_cached.past_key_values
+ new_attention_len = outputs_cached.sequences.shape[-1]
+ if config.is_encoder_decoder:
+ inputs["decoder_input_ids"] = outputs_cached.sequences
+ if "decoder_attention_mask" in inputs:
+ inputs["decoder_attention_mask"] = torch.nn.functional.pad(
+ inputs["decoder_attention_mask"],
+ (0, new_attention_len - inputs["decoder_attention_mask"].shape[1]),
+ mode="constant",
+ value=1,
+ )
+ else:
+ inputs["input_ids"] = outputs_cached.sequences
+ if "attention_mask" in inputs:
+ inputs["attention_mask"] = torch.nn.functional.pad(
+ inputs["attention_mask"],
+ (0, new_attention_len - inputs["attention_mask"].shape[1]),
+ mode="constant",
+ value=1,
+ )
+ outputs_cached = model.generate(**inputs, do_sample=False, max_new_tokens=1, return_dict_in_generate=True)
+
+ # The two sets of generated text and past kv should be equal to each other
+ self.assertListEqual(outputs.sequences.tolist(), outputs_cached.sequences.tolist())
+ for layer_idx in range(len(outputs_cached.past_key_values)):
+ for kv_idx in range(len(outputs_cached.past_key_values[layer_idx])):
+ self.assertTrue(
+ torch.allclose(
+ outputs.past_key_values[layer_idx][kv_idx],
+ outputs_cached.past_key_values[layer_idx][kv_idx],
+ )
+ )
+
+ @parameterized.expand([(1, False), (1, True), (4, False)])
+ @pytest.mark.generate
+ def test_new_cache_format(self, num_beams, do_sample):
+ # Tests that generating with the new format is exactly the same as the legacy one (for models that support it).
+ # 👉 tests with and without beam search so that we can test with and without cache reordering.
+ # 👉 tests with and without sampling so we can cover the most common use cases.
+ for model_class in self.all_generative_model_classes:
+ if not model_class._supports_cache_class:
+ self.skipTest(reason="This model does not support the new cache format")
+
+ config, input_ids, attention_mask, inputs_dict = self._get_input_ids_and_config()
+
+ model = model_class(config).to(torch_device).eval()
+ generation_kwargs = {
+ "max_new_tokens": 5,
+ "do_sample": do_sample,
+ "num_beams": num_beams,
+ "num_return_sequences": num_beams,
+ "return_dict_in_generate": True, # Required to return `past_key_values`
+ "use_cache": True,
+ }
+
+ # Sets seed before calling `generate` for the case with do_sample=True
+ seed = torch.randint(0, 1000000, (1,)).item()
+ set_seed(seed)
+ legacy_results = model.generate(
+ input_ids, attention_mask=attention_mask, **generation_kwargs, **inputs_dict
+ )
+ set_seed(seed)
+ if config.is_encoder_decoder:
+ cache_cls = EncoderDecoderCache
+ past_key_values = cache_cls(DynamicCache(), DynamicCache())
+ past_key_values = cache_cls(DynamicCache(), DynamicCache())
+ else:
+ cache_cls = DynamicCache
+ past_key_values = cache_cls()
+
+ new_results = model.generate(
+ input_ids,
+ attention_mask=attention_mask,
+ past_key_values=past_key_values,
+ **generation_kwargs,
+ **inputs_dict,
+ )
+
+ # The two sets of generated sequences must match, despite the cache format between forward passes being
+ # different
+ self.assertListEqual(legacy_results.sequences.tolist(), new_results.sequences.tolist())
+ self.assertTrue(isinstance(legacy_results.past_key_values, tuple))
+ self.assertTrue(isinstance(new_results.past_key_values, cache_cls))
+
+ # The contents of the two caches, when converted to the same format (in both directions!), must match
+ legacy_cache = legacy_results.past_key_values
+ new_cache_converted = new_results.past_key_values.to_legacy_cache()
+ for layer_idx in range(len(legacy_cache)):
+ for kv_idx in range(len(legacy_cache[layer_idx])):
+ # TODO: @raushan, please look into this for new cache format
+ if legacy_cache[layer_idx][kv_idx] != []:
+ self.assertTrue(
+ torch.allclose(
+ legacy_cache[layer_idx][kv_idx],
+ new_cache_converted[layer_idx][kv_idx],
+ )
+ )
+
+ new_cache = new_results.past_key_values
+ legacy_cache_converted = cache_cls.from_legacy_cache(legacy_results.past_key_values)
+ for layer_idx in range(len(new_cache)):
+ for kv_idx in range(len(new_cache[layer_idx])):
+ # TODO: @raushan, please look into this for new cache format
+ if new_cache[layer_idx][kv_idx] != []:
+ self.assertTrue(
+ torch.allclose(
+ new_cache[layer_idx][kv_idx],
+ legacy_cache_converted[layer_idx][kv_idx],
+ )
+ )
+
+ @pytest.mark.generate
+ def test_generate_with_static_cache(self):
+ """
+ Tests if StaticCache works if we set attn_implementation=static when generation.
+ This doesn't test if generation quality is good, but tests that models with
+ self._supports_static_cache don't throw an error when generating and return
+ a StaticCache object at the end.
+ """
+ for model_class in self.all_generative_model_classes:
+ if not model_class._supports_static_cache:
+ self.skipTest(reason="This model does not support the static cache format")
+
+ config, input_ids, attention_mask, inputs_dict = self._get_input_ids_and_config()
+ if config.is_encoder_decoder:
+ self.skipTest(reason="This model is encoder-decoder and has Encoder-Decoder Cache")
+
+ config.is_decoder = True
+ batch_size, seq_length = input_ids.shape
+ max_new_tokens = 20
+
+ model = model_class(config).to(torch_device).eval()
+ generation_kwargs = {
+ "max_length": None,
+ "max_new_tokens": max_new_tokens,
+ "cache_implementation": "static",
+ "return_dict_in_generate": True, # Required to return `past_key_values`
+ "use_cache": True,
+ }
+
+ max_cache_len = seq_length + max_new_tokens
+ config = config.text_config if hasattr(config, "text_config") else config
+ head_dim = (
+ config.head_dim if hasattr(config, "head_dim") else config.hidden_size // config.num_attention_heads
+ )
+ num_key_value_heads = (
+ config.num_attention_heads
+ if getattr(config, "num_key_value_heads", None) is None
+ else config.num_key_value_heads
+ )
+ num_hidden_layers = config.num_hidden_layers
+ results = model.generate(input_ids, attention_mask=attention_mask, **generation_kwargs, **inputs_dict)
+
+ cache_shape = (batch_size, num_key_value_heads, max_cache_len, head_dim)
+ self.assertTrue(isinstance(results.past_key_values, StaticCache))
+ self.assertTrue(len(results.past_key_values.key_cache) == num_hidden_layers)
+ self.assertTrue(results.past_key_values.key_cache[0].shape == cache_shape)
+
+ @require_quanto
+ @pytest.mark.generate
+ def test_generate_with_quant_cache(self):
+ for model_class in self.all_generative_model_classes:
+ if not model_class._supports_quantized_cache:
+ self.skipTest(reason="This model does not support the quantized cache format")
+
+ config, input_ids, attention_mask, inputs_dict = self._get_input_ids_and_config()
+ config.is_decoder = True
+
+ model = model_class(config).to(torch_device).eval()
+ generation_kwargs = {
+ "max_new_tokens": 5,
+ "cache_implementation": "quantized",
+ # careful with group size, should be divisor of model's hidden size
+ "cache_config": {"backend": "quanto", "nbits": 2, "q_group_size": 8, "residual_length": 128},
+ "return_dict_in_generate": True, # Required to return `past_key_values`
+ "use_cache": True,
+ }
+
+ results = model.generate(input_ids, attention_mask=attention_mask, **generation_kwargs, **inputs_dict)
+ self.assertTrue(isinstance(results.past_key_values, QuantoQuantizedCache))
+
+ # passing past key values of different type should raise Error
+ with self.assertRaises(ValueError):
+ num_hidden_layers = config.get_text_config().num_hidden_layers
+ model.generate(
+ input_ids,
+ attention_mask=attention_mask,
+ past_key_valyes=DynamicCache(num_hidden_layers),
+ **generation_kwargs,
+ )
+
+ # setting incorrect cache_config args should raise an Error, i.e. nbits=60 does not make sense
+ generation_kwargs["cache_config"] = {"nbits": 60, "q_group_size": 8, "residual_length": 128}
+ with self.assertRaises(ValueError):
+ model.generate(input_ids, attention_mask=attention_mask, **generation_kwargs)
+
+ @pytest.mark.generate
+ @require_torch_gpu
+ @slow
+ @is_flaky() # compilation may result in equivalent (!= same) FP ops, causing the argmax in `generate` to be flaky
+ def test_generate_compile_fullgraph(self):
+ """
+ Tests that `.generate` is compatible with torch.compile without graph breaks, keeping the same results.
+ ⚠️ Runs two sequential generations to ensure the cache doesn't get stuck after the first compiled run! ⚠️
+ """
+ for model_class in self.all_generative_model_classes:
+ if not model_class._supports_static_cache:
+ self.skipTest("This model doesn't support static cache")
+ # TODO (joao) -- fix and enable me :)
+ if any(model_name in model_class.__name__.lower() for model_name in ["whisper"]):
+ self.skipTest("whisper model end-to-end generate compile not yet supported")
+
+ config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+ # TODO (joao) -- fix and enable me :)
+ if config.is_encoder_decoder:
+ self.skipTest("Encoder-decoder model end-to-end generate compile not yet supported")
+
+ model = model_class(config).to(torch_device)
+ model.eval() # otherwise `self.training` is `True` -- this flag is used at attn mask creation time
+
+ input_ids = inputs_dict["input_ids"].to(torch_device)
+ # creates two sets of *different* inputs with the same shape
+ half_batch_size = input_ids.shape[0] // 2
+ input_ids_sets = [input_ids[:half_batch_size, :], input_ids[half_batch_size : half_batch_size * 2, :]]
+ self.assertTrue(input_ids_sets[0].shape == input_ids_sets[1].shape)
+
+ generation_kwargs = {
+ "do_sample": False,
+ "max_new_tokens": 10,
+ }
+
+ max_cache_len = input_ids.shape[1] + generation_kwargs["max_new_tokens"]
+ config = config.get_text_config()
+ past_key_values = StaticCache(
+ config, batch_size=half_batch_size, max_cache_len=max_cache_len, device=torch_device
+ )
+
+ for model_inputs in input_ids_sets:
+ # eager dynamic cache
+ output_dynamic = model.generate(model_inputs, **generation_kwargs)
+
+ # end-to-end compiled dynamic cache
+ torch.compiler.reset()
+ compiled_generate = torch.compile(model.generate, fullgraph=True, mode="reduce-overhead")
+ generation_config = copy.deepcopy(model.generation_config)
+ generation_config.update(**generation_kwargs)
+ output_compiled = compiled_generate(
+ model_inputs, generation_config=generation_config, past_key_values=past_key_values
+ )
+ self.assertListEqual(output_dynamic.tolist(), output_compiled.tolist())
+
+ @pytest.mark.generate
+ def test_generate_methods_with_num_logits_to_keep(self):
+ for model_class in self.all_generative_model_classes:
+ if "num_logits_to_keep" not in set(inspect.signature(model_class.forward).parameters.keys()):
+ self.skipTest(reason="This model does not support `num_logits_to_keep` argument.")
+
+ config, input_ids, attention_mask, inputs_dict = self._get_input_ids_and_config()
+ config.use_cache = True
+ config.is_decoder = True
+
+ model = model_class(config).to(torch_device).eval()
+ # All generation methods (except assisted decoding) rely on always extracting the last token logits of the
+ # full logits matrix, so testing out only greedy search and assisted decoding is enough (if it works,
+ # other methods will work as well)
+ generation_kwargs = {
+ "max_new_tokens": 10,
+ "do_sample": False,
+ }
+
+ # Setting num_logits_to_keep at 0 keeps all logits (old behavior)
+ with_all_logits = model.generate(
+ input_ids, attention_mask=attention_mask, **generation_kwargs, **inputs_dict, num_logits_to_keep=0
+ )
+ # By default, num_logits_to_keep is automatically set to 1 if not provided (new behavior)
+ without_all_logits = model.generate(
+ input_ids, attention_mask=attention_mask, **inputs_dict, **generation_kwargs
+ )
+ self.assertEqual(with_all_logits.tolist(), without_all_logits.tolist())
+
+ @pytest.mark.generate
+ @is_flaky() # assisted generation tests are flaky (minor fp ops differences)
+ def test_assisted_decoding_with_num_logits_to_keep(self):
+ for model_class in self.all_generative_model_classes:
+ if "num_logits_to_keep" not in set(inspect.signature(model_class.forward).parameters.keys()):
+ self.skipTest(reason="This model does not support `num_logits_to_keep` argument.")
+ if model_class._is_stateful:
+ self.skipTest(reason="Stateful models don't support assisted generation")
+
+ config, input_ids, attention_mask, inputs_dict = self._get_input_ids_and_config(batch_size=1)
+ config.use_cache = True
+ config.is_decoder = True
+
+ model = model_class(config).to(torch_device).eval()
+ assistant_model = model
+ # All generation methods (except assisted decoding) rely on always extracting the last token logits of the
+ # full logits matrix, so testing out only greedy search and assisted decoding is enough (if it works,
+ # other methods will work as well)
+ generation_kwargs = {
+ "max_new_tokens": 10,
+ "do_sample": False,
+ "assistant_model": assistant_model,
+ }
+
+ assistant_model.generation_config.assistant_confidence_threshold = None
+ # Setting num_logits_to_keep at 0 keeps all logits (old behavior)
+ with_all_logits = model.generate(
+ input_ids, attention_mask=attention_mask, **generation_kwargs, **inputs_dict, num_logits_to_keep=0
+ )
+ # By default, num_logits_to_keep is automatically set to 1 if not provided (new behavior)
+ without_all_logits = model.generate(
+ input_ids, attention_mask=attention_mask, **inputs_dict, **generation_kwargs
+ )
+ self.assertEqual(with_all_logits.tolist(), without_all_logits.tolist())
+
+ @pytest.mark.generate
+ def test_inherits_generation_mixin(self):
+ """
+ Tests that the model class directly inherits `GenerationMixin`, as opposed to relying on `PreTrainedModel`
+ to inherit it.
+ """
+ for model_class in self.all_generative_model_classes:
+ self.assertTrue("GenerationMixin" in str(model_class.__bases__))
+
+ def _check_outputs(self, output, input_ids, config, use_cache=False, num_return_sequences=1):
+ batch_size, seq_length = input_ids.shape
+ config = config.text_config if hasattr(config, "text_config") else config
+ num_sequences_in_output = batch_size * num_return_sequences
+
+ gen_len = (
+ output.sequences.shape[-1] - 1 if config.is_encoder_decoder else output.sequences.shape[-1] - seq_length
+ )
+
+ # scores
+ self._check_scores(num_sequences_in_output, output.scores, length=gen_len, config=config)
+
+ # unprocessed logits
+ self._check_logits(num_sequences_in_output, output.logits, config=config)
+
+ # Attentions
+ if self.has_attentions:
+ if config.is_encoder_decoder:
+ # encoder
+ self._check_encoder_attention_for_generate(output.encoder_attentions, batch_size, config, seq_length)
+ # decoder
+ self._check_attentions_for_generate(
+ num_sequences_in_output,
+ output.decoder_attentions,
+ min_length=1,
+ max_length=output.sequences.shape[-1],
+ config=config,
+ use_cache=use_cache,
+ )
+ else:
+ # if use_cache first input is equal to no use_cache, so skip here
+ attentions = output.attentions if not use_cache else output.attentions[1:]
+ min_length = seq_length if not use_cache else seq_length + 1
+ self._check_attentions_for_generate(
+ num_sequences_in_output,
+ attentions=attentions,
+ min_length=min_length,
+ max_length=output.sequences.shape[-1],
+ config=config,
+ use_cache=use_cache,
+ )
+
+ # Hidden States
+ if config.is_encoder_decoder:
+ # encoder
+ self._check_encoder_hidden_states_for_generate(
+ output.encoder_hidden_states, batch_size, config, seq_length
+ )
+
+ # decoder
+ self._check_hidden_states_for_generate(
+ num_sequences_in_output,
+ output.decoder_hidden_states,
+ min_length=1,
+ max_length=output.sequences.shape[-1],
+ config=config,
+ use_cache=use_cache,
+ )
+ else:
+ # if use_cache first input is equal to no use_cache, so skip here
+ hidden_states = output.hidden_states if not use_cache else output.hidden_states[1:]
+ min_length = seq_length if not use_cache else seq_length + 1
+ self._check_hidden_states_for_generate(
+ num_sequences_in_output,
+ hidden_states,
+ min_length=min_length,
+ max_length=output.sequences.shape[-1],
+ config=config,
+ use_cache=use_cache,
+ )
+
+ # Past Key Value States -- a few notes here:
+ # 1. Its inner sequence length is with respect to the inputs of the latest forward pass, hence the "-1"
+ # 2. We ignore models that have unique cache structures (e.g. mamba) or are in need of refatoring to match the
+ # standard cache format (e.g.gptbigcode )
+ models_without_standard_cache = ("ctrl", "fsmt", "gptbigcode", "mega", "reformer", "jamba", "mamba", "xlnet")
+ has_standard_cache = not any(
+ model_name in config.__class__.__name__.lower() for model_name in models_without_standard_cache
+ )
+ if has_standard_cache:
+ if use_cache:
+ past_key_values = output.past_key_values
+ past_sequence_length = output.sequences.shape[-1] - 1
+ self._check_past_key_values_for_generate(
+ num_sequences_in_output,
+ past_key_values,
+ seq_length=past_sequence_length,
+ config=config,
+ )
+ elif use_cache is False:
+ self.assertTrue(output.past_key_values is None)
+
+ def _check_scores(self, batch_size, scores, length, config):
+ vocab_size = config.get_text_config(decoder=True).vocab_size
+ expected_shape = (batch_size, vocab_size)
+ self.assertIsInstance(scores, tuple)
+ self.assertEqual(len(scores), length)
+ self.assertListEqual([iter_scores.shape for iter_scores in scores], [expected_shape] * len(scores))
+
+ def _check_logits(self, batch_size, scores, config):
+ vocab_size = config.get_text_config(decoder=True).vocab_size
+ self.assertIsInstance(scores, tuple)
+ self.assertListEqual([iter_scores.shape[0] for iter_scores in scores], [batch_size] * len(scores))
+ # vocabulary difference equal to one (imagegptmodel?) or zero (all other models)
+ vocab_diff = vocab_size - scores[0].shape[-1]
+ self.assertTrue(vocab_diff in [0, 1])
+ self.assertListEqual([vocab_size - score.shape[-1] for score in scores], [vocab_diff] * len(scores))
+
+ def _check_attentions_for_generate(
+ self, batch_size, attentions, min_length, max_length, config, use_cache=False, num_beam_groups=1
+ ):
+ self.assertIsInstance(attentions, tuple)
+ self.assertListEqual(
[isinstance(iter_attentions, tuple) for iter_attentions in attentions], [True] * len(attentions)
)
self.assertEqual(len(attentions), (max_length - min_length) * num_beam_groups)
@@ -2318,6 +2329,30 @@ def _check_encoder_hidden_states_for_generate(self, hidden_states, batch_size, c
[encoder_expected_shape] * len(hidden_states),
)
+ def _check_past_key_values_for_generate(self, batch_size, past_key_values, seq_length, config, num_beam_groups=1):
+ self.assertIsInstance(past_key_values, tuple)
+ self.assertListEqual(
+ [isinstance(iter_past_key_values, tuple) for iter_past_key_values in past_key_values],
+ [True] * len(past_key_values),
+ )
+
+ # (batch, head, seq_length, head_features)
+ expected_shape = (
+ batch_size * num_beam_groups,
+ config.num_key_value_heads if hasattr(config, "num_key_value_heads") else config.num_attention_heads,
+ seq_length,
+ config.hidden_size // config.num_attention_heads,
+ )
+ # check shape key, value
+ self.assertListEqual(
+ [layer_past_key_values[0].shape for layer_past_key_values in past_key_values],
+ [expected_shape] * len(past_key_values),
+ )
+ self.assertListEqual(
+ [layer_past_key_values[1].shape for layer_past_key_values in past_key_values],
+ [expected_shape] * len(past_key_values),
+ )
+
def _check_sequence_inside_sequence(self, tensor_1, tensor_2):
# check if tensor_1 inside tensor_2 or tensor_2 inside tensor_1.
# set to same device. we don't care what device.
@@ -2342,6 +2377,45 @@ def _check_sequence_inside_sequence(self, tensor_1, tensor_2):
self.assertTrue(flag)
+@require_torch
+class UtilsFunctionsTest(unittest.TestCase):
+ def test_speculative_sampling(self):
+ # assume vocab size 10, input length 5 + 3 generated candidates
+ candidate_input_ids = torch.tensor([[8, 0, 3, 9, 8, 1, 4, 5]]) # input tokens
+ candidate_logits = torch.tensor(
+ [
+ [
+ [-10.0, 10.0, -10.0, -10.0, -10.0, -10.0, -10.0, -10.0, -10.0, -10.0], # generated 1
+ [-10.0, -10.0, -10.0, -10.0, 10.0, -10.0, -10.0, -10.0, -10.0, -10.0], # generated 4
+ [-10.0, -10.0, -10.0, -10.0, -10.0, 10.0, -10.0, -10.0, -10.0, -10.0], # generated 5
+ ]
+ ]
+ )
+ candidate_length = 3
+ inf = float("inf")
+ new_logits = torch.tensor(
+ [
+ [
+ [-10.0, 10.0, -10.0, -10.0, -10.0, -10.0, -10.0, -10.0, -10.0, -10.0], # accepts 1
+ [-10.0, -10.0, -10.0, -10.0, 10.0, -10.0, -10.0, -10.0, -10.0, -10.0], # accepts 4
+ [-inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, 10.0, -inf], # rejects 5, accepts 8
+ [-10.0, -10.0, -10.0, -10.0, -10.0, -10.0, -10.0, -10.0, -10.0, -10.0], # N/A
+ ]
+ ]
+ )
+ last_assistant_token_is_eos = False
+ validated_tokens, n_matches = _speculative_sampling(
+ candidate_input_ids,
+ candidate_logits,
+ candidate_length,
+ new_logits,
+ last_assistant_token_is_eos,
+ )
+ self.assertTrue(n_matches.item() == 2)
+ self.assertTrue(validated_tokens.tolist()[0] == [1, 4, 8])
+
+
+@pytest.mark.generate
@require_torch
class GenerationIntegrationTests(unittest.TestCase, GenerationIntegrationTestsMixin):
# setting framework_dependent_parameters needs to be gated, just like its contents' imports
@@ -2359,6 +2433,7 @@ class GenerationIntegrationTests(unittest.TestCase, GenerationIntegrationTestsMi
}
@slow
+ @pytest.mark.skip("Group beam search is not supported by optimum-habana")
def test_diverse_beam_search(self):
# PT-only test: TF doesn't have a diverse beam search implementation
article = """Justin Timberlake and Jessica Biel, welcome to parenthood.
@@ -2393,257 +2468,80 @@ def test_diverse_beam_search(self):
],
)
- def test_max_length_backward_compat_greedy(self):
+ def test_max_length_if_input_embeds(self):
# PT-only test: TF doesn't have StoppingCriteria
- article = """Justin Timberlake and Jessica Biel, welcome to parenthood."""
- bart_tokenizer = BartTokenizer.from_pretrained("hf-internal-testing/tiny-random-bart")
- bart_model = BartForConditionalGeneration.from_pretrained("hf-internal-testing/tiny-random-bart").to(
- torch_device
- )
- input_ids = bart_tokenizer(article, return_tensors="pt").input_ids.to(torch_device)
+ article = "Today a dragon flew over Paris."
+ model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-gpt2").to(torch_device)
+ tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-gpt2")
+ input_ids = tokenizer(article, return_tensors="pt").input_ids.to(torch_device)
+ inputs_embeds = model.get_input_embeddings()(input_ids)
- max_length = 20
- input_ids = input_ids.expand(2, -1)
- model_kwargs = bart_model._prepare_encoder_decoder_kwargs_for_generation(input_ids, {})
- input_ids, model_kwargs = bart_model._prepare_decoder_input_ids_for_generation(
- batch_size=input_ids.shape[0],
- model_input_name=bart_model.main_input_name,
- model_kwargs=model_kwargs,
- decoder_start_token_id=bart_model.config.decoder_start_token_id,
- bos_token_id=bart_model.config.bos_token_id,
- )
+ # Controlling max_length via the configuration is deprecated in favor of max_new_tokens
+ max_new_tokens = 20
+ input_len = input_ids.shape[-1]
+ out_gen = model.generate(input_ids=input_ids, max_new_tokens=max_new_tokens)
+ out_gen_embeds = model.generate(inputs_embeds=inputs_embeds, max_new_tokens=max_new_tokens)
+ self.assertEqual(out_gen.shape[-1], input_len + out_gen_embeds.shape[-1])
- with self.assertWarns(UserWarning):
- bart_model.greedy_search(
- input_ids,
- max_length=max_length,
- pad_token_id=bart_model.config.pad_token_id,
- eos_token_id=bart_model.config.eos_token_id,
- **model_kwargs,
- )
+ def test_min_length_if_input_embeds(self):
+ # PT-only test: TF doesn't have StoppingCriteria
+ article = "Today a dragon flew over Paris."
+ model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-gpt2").to(torch_device)
+ tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-gpt2")
+ input_ids = tokenizer(article, return_tensors="pt").input_ids.to(torch_device)
+ inputs_embeds = model.get_input_embeddings()(input_ids)
+
+ # Controlling max_length via the configuration is deprecated in favor of max_new_tokens
+ min_length = 10
+ input_len = input_ids.shape[-1]
+ out_gen = model.generate(input_ids=input_ids, min_length=min_length, max_new_tokens=20)
+ out_gen_embeds = model.generate(inputs_embeds=inputs_embeds, min_length=min_length, max_new_tokens=20)
+ self.assertEqual(out_gen.shape[-1], input_len + out_gen_embeds.shape[-1])
- def test_max_length_backward_compat_sample(self):
+ def test_custom_stopping_criteria_overload_error(self):
# PT-only test: TF doesn't have StoppingCriteria
article = """Justin Timberlake and Jessica Biel, welcome to parenthood."""
- bart_tokenizer = BartTokenizer.from_pretrained("hf-internal-testing/tiny-random-bart")
- bart_model = BartForConditionalGeneration.from_pretrained("hf-internal-testing/tiny-random-bart").to(
- torch_device
- )
- input_ids = bart_tokenizer(article, return_tensors="pt").input_ids.to(torch_device)
+ bart_tokenizer = BartTokenizer.from_pretrained("sshleifer/bart-tiny-random")
+ bart_model = BartForConditionalGeneration.from_pretrained("sshleifer/bart-tiny-random").to(torch_device)
- max_length = 20
- input_ids = input_ids.expand(2, -1)
- model_kwargs = bart_model._prepare_encoder_decoder_kwargs_for_generation(input_ids, {})
- input_ids, model_kwargs = bart_model._prepare_decoder_input_ids_for_generation(
- batch_size=input_ids.shape[0],
- model_input_name=bart_model.main_input_name,
- model_kwargs=model_kwargs,
- decoder_start_token_id=bart_model.config.decoder_start_token_id,
- bos_token_id=bart_model.config.bos_token_id,
- )
- with torch.no_grad():
- with self.assertWarns(UserWarning):
- bart_model.sample(
- input_ids,
- max_length=max_length,
- pad_token_id=bart_model.config.pad_token_id,
- eos_token_id=bart_model.config.eos_token_id,
- **model_kwargs,
- )
+ input_ids = bart_tokenizer(article, return_tensors="pt").input_ids.to(torch_device)
+ stopping_criteria = StoppingCriteriaList()
+ stopping_criteria.append(MaxLengthCriteria(max_length=42))
+ with self.assertRaises(ValueError):
+ bart_model.generate(input_ids, stopping_criteria=stopping_criteria)
+ with self.assertRaises(ValueError):
+ bart_model.generate(input_ids, stopping_criteria=stopping_criteria, max_length=32)
- def test_max_length_backward_compat_beam_search(self):
+ def test_custom_stopping_criteria(self):
# PT-only test: TF doesn't have StoppingCriteria
article = """Justin Timberlake and Jessica Biel, welcome to parenthood."""
- bart_tokenizer = BartTokenizer.from_pretrained("hf-internal-testing/tiny-random-bart")
- bart_model = BartForConditionalGeneration.from_pretrained("hf-internal-testing/tiny-random-bart").to(
- torch_device
- )
+ bart_tokenizer = BartTokenizer.from_pretrained("sshleifer/bart-tiny-random")
+ bart_model = BartForConditionalGeneration.from_pretrained("sshleifer/bart-tiny-random").to(torch_device)
input_ids = bart_tokenizer(article, return_tensors="pt").input_ids.to(torch_device)
- batch_size = 1
- max_length = 20
- num_beams = 2
-
- input_ids = input_ids.expand(2, -1)
- model_kwargs = bart_model._prepare_encoder_decoder_kwargs_for_generation(input_ids, {})
- input_ids, model_kwargs = bart_model._prepare_decoder_input_ids_for_generation(
- batch_size=input_ids.shape[0],
- model_input_name=bart_model.main_input_name,
- model_kwargs=model_kwargs,
- decoder_start_token_id=bart_model.config.decoder_start_token_id,
- bos_token_id=bart_model.config.bos_token_id,
- )
+ class DummyCriteria(StoppingCriteria):
+ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
+ return input_ids.shape[-1] >= 20
- beam_scorer = BeamSearchScorer(
- batch_size=batch_size,
- num_beams=num_beams,
- device=torch_device,
- )
- with self.assertWarns(UserWarning):
- _ = bart_model.beam_search(
- input_ids, num_beams=num_beams, max_length=max_length, beam_scorer=beam_scorer, **model_kwargs
- )
+ stopping_criteria = StoppingCriteriaList()
+ stopping_criteria.append(DummyCriteria())
- def test_max_length_backward_compat_group_beam_search(self):
- # PT-only test: TF doesn't have StoppingCriteria & group beam search
- article = """Justin Timberlake and Jessica Biel, welcome to parenthood."""
- bart_tokenizer = BartTokenizer.from_pretrained("hf-internal-testing/tiny-random-bart")
- bart_model = BartForConditionalGeneration.from_pretrained("hf-internal-testing/tiny-random-bart").to(
- torch_device
+ output = bart_model.generate(input_ids, stopping_criteria=stopping_criteria, max_length=22)
+ self.assertEqual(
+ list(output.shape),
+ [1, 22], # still produces the max_length
)
- input_ids = bart_tokenizer(article, return_tensors="pt").input_ids.to(torch_device)
+ # make sure final tokens are padding
+ self.assertEqual(output[:, 20:].tolist(), [[bart_model.config.pad_token_id, bart_model.config.pad_token_id]])
- batch_size = 1
- max_length = 20
- num_beams = 6
- num_beam_groups = 3
- num_return_sequences = num_beams * batch_size
-
- input_ids = input_ids.expand(6, -1)
- model_kwargs = bart_model._prepare_encoder_decoder_kwargs_for_generation(input_ids, {})
- input_ids, model_kwargs = bart_model._prepare_decoder_input_ids_for_generation(
- batch_size=input_ids.shape[0],
- model_input_name=bart_model.main_input_name,
- model_kwargs=model_kwargs,
- decoder_start_token_id=bart_model.config.decoder_start_token_id,
- bos_token_id=bart_model.config.bos_token_id,
- )
-
- diverse_beam_scorer = BeamSearchScorer(
- batch_size=batch_size,
- num_beams=num_beams,
- device=torch_device,
- num_beam_hyps_to_keep=num_return_sequences,
- num_beam_groups=num_beam_groups,
- )
- with self.assertWarns(UserWarning):
- bart_model.group_beam_search(
- input_ids, diverse_beam_scorer, num_beams=num_beams, max_length=max_length, **model_kwargs
- )
-
- def test_max_length_warning_if_different(self):
- # PT-only test: TF doesn't have StoppingCriteria
- article = """Justin Timberlake and Jessica Biel, welcome to parenthood."""
- bart_tokenizer = BartTokenizer.from_pretrained("hf-internal-testing/tiny-random-bart")
- bart_model = BartForConditionalGeneration.from_pretrained("hf-internal-testing/tiny-random-bart").to(
- torch_device
- )
- input_ids = bart_tokenizer(article, return_tensors="pt").input_ids.to(torch_device)
-
- batch_size = 1
-
- max_length = 20
- num_beams = 6
- num_beam_groups = 3
- num_return_sequences = num_beams * batch_size
- stopping_criteria_max_length = 18
- stopping_criteria = StoppingCriteriaList([MaxLengthCriteria(max_length=stopping_criteria_max_length)])
-
- # Greedy
- input_ids = input_ids.expand(6, -1)
- model_kwargs = bart_model._prepare_encoder_decoder_kwargs_for_generation(input_ids, {})
- input_ids, model_kwargs = bart_model._prepare_decoder_input_ids_for_generation(
- batch_size=input_ids.shape[0],
- model_input_name=bart_model.main_input_name,
- model_kwargs=model_kwargs,
- decoder_start_token_id=bart_model.config.decoder_start_token_id,
- bos_token_id=bart_model.config.bos_token_id,
- )
-
- with self.assertWarns(UserWarning):
- bart_model.greedy_search(
- input_ids,
- max_length=max_length,
- pad_token_id=bart_model.config.pad_token_id,
- stopping_criteria=stopping_criteria,
- eos_token_id=bart_model.config.eos_token_id,
- **model_kwargs,
- )
-
- # Sample
- with self.assertWarns(UserWarning):
- with torch.no_grad():
- bart_model.sample(
- input_ids,
- max_length=max_length,
- stopping_criteria=stopping_criteria,
- pad_token_id=bart_model.config.pad_token_id,
- eos_token_id=bart_model.config.eos_token_id,
- **model_kwargs,
- )
-
- # Beam
- beam_scorer = BeamSearchScorer(
- batch_size=batch_size,
- num_beams=num_beams,
- device=torch_device,
- )
- with self.assertWarns(UserWarning):
- with torch.no_grad():
- bart_model.beam_search(
- input_ids,
- num_beams=num_beams,
- stopping_criteria=stopping_criteria,
- max_length=max_length,
- beam_scorer=beam_scorer,
- **model_kwargs,
- )
-
- # Grouped beam search
- diverse_beam_scorer = BeamSearchScorer(
- batch_size=batch_size,
- num_beams=num_beams,
- device=torch_device,
- num_beam_hyps_to_keep=num_return_sequences,
- num_beam_groups=num_beam_groups,
- )
- with self.assertWarns(UserWarning):
- bart_model.group_beam_search(
- input_ids,
- diverse_beam_scorer,
- stopping_criteria=stopping_criteria,
- num_beams=num_beams,
- max_length=max_length,
- **model_kwargs,
- )
-
- def test_custom_stopping_criteria_overload_error(self):
- # PT-only test: TF doesn't have StoppingCriteria
- article = """Justin Timberlake and Jessica Biel, welcome to parenthood."""
- bart_tokenizer = BartTokenizer.from_pretrained("sshleifer/bart-tiny-random")
- bart_model = BartForConditionalGeneration.from_pretrained("sshleifer/bart-tiny-random").to(torch_device)
-
- input_ids = bart_tokenizer(article, return_tensors="pt").input_ids.to(torch_device)
- stopping_criteria = StoppingCriteriaList()
- stopping_criteria.append(MaxLengthCriteria(max_length=42))
- with self.assertRaises(ValueError):
- bart_model.generate(input_ids, stopping_criteria=stopping_criteria)
- with self.assertRaises(ValueError):
- bart_model.generate(input_ids, stopping_criteria=stopping_criteria, max_length=32)
-
- def test_custom_stopping_criteria(self):
- # PT-only test: TF doesn't have StoppingCriteria
- article = """Justin Timberlake and Jessica Biel, welcome to parenthood."""
- bart_tokenizer = BartTokenizer.from_pretrained("sshleifer/bart-tiny-random")
- bart_model = BartForConditionalGeneration.from_pretrained("sshleifer/bart-tiny-random").to(torch_device)
- input_ids = bart_tokenizer(article, return_tensors="pt").input_ids.to(torch_device)
-
- class DummyCriteria(StoppingCriteria):
- def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
- return input_ids.shape[-1] >= 20
-
- stopping_criteria = StoppingCriteriaList()
- stopping_criteria.append(DummyCriteria())
-
- self.assertEqual(
- list(bart_model.generate(input_ids, stopping_criteria=stopping_criteria, max_length=22).shape),
- [1, 20],
- )
- self.assertEqual(
- list(bart_model.generate(input_ids, stopping_criteria=stopping_criteria, max_length=18).shape),
- [1, 18],
+ self.assertEqual(
+ list(bart_model.generate(input_ids, stopping_criteria=stopping_criteria, max_length=18).shape),
+ [1, 18],
)
+ # TODO (joao): replace `stop_sequence` in the pipeline by the more recent `generate` functionality
+ # TODO [gustavo] Enable this test to Optimum-habana
+ @pytest.mark.xfail
def test_stop_sequence_stopping_criteria(self):
# PT-only test: TF doesn't have StoppingCriteria
prompt = """Hello I believe in"""
@@ -2651,17 +2549,11 @@ def test_stop_sequence_stopping_criteria(self):
output = generator(prompt)
self.assertEqual(
output,
- [
- {
- "generated_text": (
- "Hello I believe in in in number number number number number number number number number"
- )
- }
- ],
+ [{"generated_text": ("Hello I believe in we we we we we we we we we")}],
)
- output = generator(prompt, stop_sequence=" number")
- self.assertEqual(output, [{"generated_text": "Hello I believe in in in number"}])
+ output = generator(prompt, stop_sequence=" we")
+ self.assertEqual(output, [{"generated_text": "Hello I believe in we"}])
def test_generate_non_nlp_input_ids_as_kwarg(self):
# PT-only test: AFAIK there's no non-NLP model architecture in TF that supports `input_ids` as its only input
@@ -2687,6 +2579,7 @@ def test_generate_input_values_as_encoder_kwarg(self):
self.assertListEqual(output_sequences.tolist(), output_sequences_kwargs.tolist())
self.assertEqual(output_sequences.shape, (2, 5))
+ @pytest.mark.skip("Group beam search is not supported by optimum-habana")
def test_transition_scores_group_beam_search_encoder_decoder(self):
# PT-only test: TF doesn't have group beam search
articles = [
@@ -2716,13 +2609,61 @@ def test_transition_scores_group_beam_search_encoder_decoder(self):
self.assertTrue(torch.allclose(transition_scores_sum, outputs.sequences_scores, atol=1e-3))
+ # TODO [gustavo] Enable this test to Optimum-habana
+ @pytest.mark.xfail
+ def test_beam_search_low_memory(self):
+ tokenizer = GPT2Tokenizer.from_pretrained("openai-community/gpt2")
+ model = AutoModelForCausalLM.from_pretrained("openai-community/gpt2")
+ tokenizer.pad_token_id = tokenizer.eos_token_id
+ model_inputs = tokenizer("I", return_tensors="pt")["input_ids"]
+
+ low_output = model.generate(model_inputs, max_new_tokens=40, num_beams=5, early_stopping=True, low_memory=True)
+
+ high_output = model.generate(
+ model_inputs, max_new_tokens=40, num_beams=5, early_stopping=True, low_memory=False
+ )
+ self.assertListEqual(low_output.tolist(), high_output.tolist())
+
+ @slow
+ @pytest.mark.skip("Watermarking is not supported by optimum-habana yet")
+ def test_watermark_generation(self):
+ tokenizer = GPT2Tokenizer.from_pretrained("openai-community/gpt2")
+ model = AutoModelForCausalLM.from_pretrained("openai-community/gpt2").to(torch_device)
+ tokenizer.pad_token_id = tokenizer.eos_token_id
+ model_inputs = tokenizer("I will be", return_tensors="pt").to(torch_device)
+ input_len = model_inputs["input_ids"].shape[-1]
+
+ # generation should work with both input types: WatermarkingConfig or Dict, so let's check it here :)
+ watermark_config = WatermarkingConfig(bias=2.5, seeding_scheme="selfhash")
+ _ = model.generate(**model_inputs, watermarking_config=watermark_config, do_sample=False, max_length=15)
+
+ # We will not check watermarked text, since we check it in `logits_processors` tests
+ # Checking if generated ids are as expected fails on different hardware
+ args = {
+ "bias": 2.0,
+ "context_width": 1,
+ "seeding_scheme": "selfhash",
+ "greenlist_ratio": 0.25,
+ "hashing_key": 15485863,
+ }
+ output = model.generate(**model_inputs, do_sample=False, max_length=15)
+ output_selfhash = model.generate(**model_inputs, watermarking_config=args, do_sample=False, max_length=15)
+
+ # Check that the detector is detecting watermarked text
+ detector = WatermarkDetector(model_config=model.config, device=torch_device, watermarking_config=args)
+ detection_out_watermarked = detector(output_selfhash[:, input_len:], return_dict=True)
+ detection_out = detector(output[:, input_len:], return_dict=True)
+
+ self.assertListEqual(detection_out_watermarked.prediction.tolist(), [True])
+ self.assertListEqual(detection_out.prediction.tolist(), [False])
+
@slow
def test_beam_search_example_integration(self):
# PT-only test: TF doesn't have a BeamSearchScorer
# exactly the example provided in the docstrings of beam search, which previously
# failed after directly copying from it. Refer to PR #15555
- tokenizer = AutoTokenizer.from_pretrained("t5-base")
- model = AutoModelForSeq2SeqLM.from_pretrained("t5-base")
+ tokenizer = AutoTokenizer.from_pretrained("google-t5/t5-base")
+ model = AutoModelForSeq2SeqLM.from_pretrained("google-t5/t5-base")
encoder_input_str = "translate English to German: How old are you?"
encoder_input_ids = tokenizer(encoder_input_str, return_tensors="pt").input_ids
@@ -2730,31 +2671,15 @@ def test_beam_search_example_integration(self):
# lets run beam search using 3 beams
num_beams = 3
# define decoder start token ids
- input_ids = torch.ones((num_beams, 1), device=model.device, dtype=torch.long)
+ input_ids = torch.ones((1, 1), device=model.device, dtype=torch.long)
input_ids = input_ids * model.config.decoder_start_token_id
# add encoder_outputs to model keyword arguments
- model_kwargs = {
- "encoder_outputs": model.get_encoder()(
- encoder_input_ids.repeat_interleave(num_beams, dim=0), return_dict=True
- )
- }
-
- # instantiate beam scorer
- beam_scorer = BeamSearchScorer(
- batch_size=1,
- num_beams=num_beams,
- device=model.device,
- )
+ model_kwargs = {"encoder_outputs": model.get_encoder()(encoder_input_ids, return_dict=True)}
- # instantiate logits processors
- logits_processor = LogitsProcessorList(
- [
- MinLengthLogitsProcessor(5, eos_token_id=model.config.eos_token_id),
- ]
+ outputs = model.generate(
+ input_ids, num_beams=num_beams, min_length=5, eos_token_id=model.config.eos_token_id, **model_kwargs
)
-
- outputs = model.beam_search(input_ids, beam_scorer, logits_processor=logits_processor, **model_kwargs)
outputs = tokenizer.batch_decode(outputs, skip_special_tokens=True)
self.assertListEqual(outputs, ["Wie alt bist du?"])
@@ -2762,8 +2687,8 @@ def test_beam_search_example_integration(self):
@slow
def test_constrained_beam_search(self):
# PT-only test: TF doesn't have constrained beam search
- model = GPT2LMHeadModel.from_pretrained("gpt2").to(torch_device)
- tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
+ model = GPT2LMHeadModel.from_pretrained("openai-community/gpt2").to(torch_device)
+ tokenizer = GPT2Tokenizer.from_pretrained("openai-community/gpt2")
force_tokens = tokenizer("scared", add_prefix_space=True, add_special_tokens=False).input_ids
force_tokens_2 = tokenizer("big weapons", add_prefix_space=True, add_special_tokens=False).input_ids
@@ -2800,8 +2725,8 @@ def test_constrained_beam_search(self):
@slow
def test_constrained_beam_search_mixed(self):
# PT-only test: TF doesn't have constrained beam search
- model = GPT2LMHeadModel.from_pretrained("gpt2").to(torch_device)
- tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
+ model = GPT2LMHeadModel.from_pretrained("openai-community/gpt2").to(torch_device)
+ tokenizer = GPT2Tokenizer.from_pretrained("openai-community/gpt2")
force_phrase = tokenizer("scared", add_prefix_space=True, add_special_tokens=False).input_ids
flexible_phrases = tokenizer(
@@ -2841,8 +2766,8 @@ def test_constrained_beam_search_mixed(self):
@slow
def test_constrained_beam_search_mixed_mixin(self):
# PT-only test: TF doesn't have constrained beam search
- model = GPT2LMHeadModel.from_pretrained("gpt2").to(torch_device)
- tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
+ model = GPT2LMHeadModel.from_pretrained("openai-community/gpt2").to(torch_device)
+ tokenizer = GPT2Tokenizer.from_pretrained("openai-community/gpt2")
force_word = "scared"
force_flexible = ["scream", "screams", "screaming", "screamed"]
@@ -2877,9 +2802,15 @@ def test_constrained_beam_search_mixed_mixin(self):
)
@slow
+ @pytest.mark.xfail
def test_cfg_mixin(self):
- model = GPT2LMHeadModel.from_pretrained("gpt2").to(torch_device)
- tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
+ model = GPT2LMHeadModel.from_pretrained("openai-community/gpt2").to(torch_device)
+ tokenizer = GPT2Tokenizer.from_pretrained("openai-community/gpt2")
+
+ # add pad_token_id for static shape
+ if tokenizer.pad_token is None:
+ tokenizer.pad_token = tokenizer.eos_token
+ model.generation_config.pad_token_id = model.generation_config.eos_token_id
input = tokenizer(["The dragon flew over Paris,"], return_tensors="pt", return_attention_mask=True)
input["input_ids"] = input["input_ids"].to(torch_device)
@@ -2919,8 +2850,8 @@ def test_cfg_mixin(self):
@slow
def test_constrained_beam_search_example_translation_mixin(self):
# PT-only test: TF doesn't have constrained beam search
- tokenizer = AutoTokenizer.from_pretrained("t5-base")
- model = AutoModelForSeq2SeqLM.from_pretrained("t5-base")
+ tokenizer = AutoTokenizer.from_pretrained("google-t5/t5-base")
+ model = AutoModelForSeq2SeqLM.from_pretrained("google-t5/t5-base")
encoder_input_str = "translate English to German: How old are you?"
force_words = ["sind"]
@@ -2944,8 +2875,8 @@ def test_constrained_beam_search_example_translation_mixin(self):
@slow
def test_constrained_beam_search_example_integration(self):
# PT-only test: TF doesn't have constrained beam search
- tokenizer = AutoTokenizer.from_pretrained("t5-base")
- model = AutoModelForSeq2SeqLM.from_pretrained("t5-base")
+ tokenizer = AutoTokenizer.from_pretrained("google-t5/t5-base")
+ model = AutoModelForSeq2SeqLM.from_pretrained("google-t5/t5-base")
encoder_input_str = "translate English to German: How old are you?"
encoder_input_ids = tokenizer(encoder_input_str, return_tensors="pt").input_ids
@@ -2953,38 +2884,65 @@ def test_constrained_beam_search_example_integration(self):
# lets run beam search using 5 beams
num_beams = 5
# define decoder start token ids
- input_ids = torch.ones((num_beams, 1), device=model.device, dtype=torch.long)
+ input_ids = torch.ones((1, 1), device=model.device, dtype=torch.long)
input_ids = input_ids * model.config.decoder_start_token_id
# add encoder_outputs to model keyword arguments
- model_kwargs = {
- "encoder_outputs": model.get_encoder()(
- encoder_input_ids.repeat_interleave(num_beams, dim=0), return_dict=True
- )
- }
+ model_kwargs = {"encoder_outputs": model.get_encoder()(encoder_input_ids, return_dict=True)}
constraint_str = "sind"
constraint_token_ids = tokenizer.encode(constraint_str)[:-1] # remove eos token
- constraints = [PhrasalConstraint(token_ids=constraint_token_ids)]
- # instantiate beam scorer
- beam_scorer = ConstrainedBeamSearchScorer(
- batch_size=1, num_beams=num_beams, device=model.device, constraints=constraints
+ outputs = model.generate(
+ input_ids,
+ num_beams=num_beams,
+ force_words_ids=[constraint_token_ids],
+ min_length=5,
+ eos_token_id=model.config.eos_token_id,
+ **model_kwargs,
)
+ outputs = tokenizer.batch_decode(outputs, skip_special_tokens=True)
- # instantiate logits processors
- logits_processor = LogitsProcessorList(
- [
- MinLengthLogitsProcessor(5, eos_token_id=model.config.eos_token_id),
- ]
- )
+ self.assertListEqual(outputs, ["Wie alt sind Sie?"])
+
+ # TODO [gustavo] Enable this test to Optimum-habana
+ @pytest.mark.xfail
+ @slow
+ def test_per_row_stopping_criteria(self):
+ text = [
+ "They completed the challenging puzzle, revealing the hidden",
+ "Today a dragon flew over France",
+ "The aroma of freshly baked pizza filled the kitchen",
+ ]
+ stop_strings = ["secrets"]
- outputs = model.constrained_beam_search(
- input_ids, beam_scorer, constraints=constraints, logits_processor=logits_processor, **model_kwargs
+ model = AutoModelForCausalLM.from_pretrained("openai-community/gpt2").to(torch_device)
+ tokenizer = AutoTokenizer.from_pretrained("openai-community/gpt2")
+ tokenizer.padding_side = "left"
+ tokenizer.pad_token_id = tokenizer.eos_token_id
+ input_ids = tokenizer(text, return_tensors="pt", padding="longest", add_special_tokens=False).input_ids.to(
+ torch_device
)
- outputs = tokenizer.batch_decode(outputs, skip_special_tokens=True)
- self.assertListEqual(outputs, ["Wie alt sind Sie?"])
+ # normal generation with one stopping criteria
+ out = model.generate(input_ids, max_length=15)
+ out_text = tokenizer.batch_decode(out)
+ expected_out = [
+ "They completed the challenging puzzle, revealing the hidden secrets of the world.\n",
+ "<|endoftext|><|endoftext|><|endoftext|>Today a dragon flew over France and the French government was forced",
+ "The aroma of freshly baked pizza filled the kitchen with a sense of freshness",
+ ]
+ self.assertListEqual(out_text, expected_out)
+
+ # generation should stop at "secrets" for first batch only, filling the rest with eos tokens
+ out = model.generate(input_ids, max_length=15, stop_strings=stop_strings, tokenizer=tokenizer)
+ out_text = tokenizer.batch_decode(out)
+ expected_out = [
+ "They completed the challenging puzzle, revealing the hidden secrets<|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|>",
+ "<|endoftext|><|endoftext|><|endoftext|>Today a dragon flew over France and the French government was forced",
+ "The aroma of freshly baked pizza filled the kitchen with a sense of freshness",
+ ]
+ self.assertListEqual(out_text, expected_out)
def test_constrained_beam_search_mixin_type_checks(self):
# PT-only test: TF doesn't have constrained beam search
@@ -3027,6 +2985,55 @@ def test_constrained_beam_search_mixin_type_checks(self):
with self.assertRaises(ValueError):
model.generate(input_ids, force_words_ids=[[[-1]]])
+ def test_batched_decoder_start_id(self):
+ # PT-only test: TF doesn't support batched_decoder_start_id
+ articles = [
+ "Justin Timberlake and Jessica Biel, welcome to parenthood.",
+ "Michael Phelps is arguably the most decorated Olympian of all time.",
+ ]
+ bart_tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-bart")
+ bart_model = BartForConditionalGeneration.from_pretrained("hf-internal-testing/tiny-random-bart").to(
+ torch_device
+ )
+ input_ids = bart_tokenizer(articles, return_tensors="pt", padding=True).input_ids.to(torch_device)
+ decoder_start_token_id = bart_model.generation_config.decoder_start_token_id
+ decoder_start_token_id_batch = [decoder_start_token_id] * input_ids.shape[0]
+
+ outputs = bart_model.generate(input_ids, decoder_start_token_id=decoder_start_token_id)
+
+ outputs_batched_ids = bart_model.generate(input_ids, decoder_start_token_id=decoder_start_token_id_batch)
+
+ self.assertListEqual(outputs.tolist(), outputs_batched_ids.tolist())
+
+ def test_decoder_start_id_from_config(self):
+ # Refer to: (#30899)
+ articles = [
+ "Justin Timberlake and Jessica Biel, welcome to parenthood.",
+ "Michael Phelps is arguably the most decorated Olympian of all time.",
+ ]
+ bart_tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-bart")
+ bart_model = BartForConditionalGeneration.from_pretrained("hf-internal-testing/tiny-random-bart").to(
+ torch_device
+ )
+ input_ids = bart_tokenizer(articles, return_tensors="pt", padding=True).input_ids.to(torch_device)
+ decoder_start_token_id = bart_model.generation_config.decoder_start_token_id
+
+ # we should be able to take `decoder_start_token_id` from model's generation config if user passes a `GenerationConfig` type
+ outputs = bart_model.generate(input_ids, generation_config=GenerationConfig(do_sample=False))
+
+ # If the generatoin config has no `decoder_start_token_id` or `bos_token_id`, we will raise an error unless user passes it in config
+ bart_model.generation_config.decoder_start_token_id = None
+ bart_model.generation_config.bos_token_id = None
+ outputs_with_user_id = bart_model.generate(
+ input_ids,
+ generation_config=GenerationConfig(do_sample=False, decoder_start_token_id=decoder_start_token_id),
+ )
+
+ self.assertListEqual(outputs.tolist(), outputs_with_user_id.tolist())
+
+ with self.assertRaises(ValueError):
+ outputs = bart_model.generate(input_ids, generation_config=GenerationConfig(do_sample=False))
+
def test_contrastive_search_batched(self):
# PT-only test: TF doesn't have constrained beam search
# Tests that contrastive search works with batched inputs (i.e. has the same output as for non-batched inputs)
@@ -3053,6 +3060,27 @@ def test_contrastive_search_batched(self):
max_score_diff = (output_sequences_batched.scores[0][1] - output_sequences.scores[0][0]).abs().max()
self.assertTrue(max_score_diff < 1e-5)
+ def test_logits_processor_not_inplace(self):
+ # PT-only test: TF fixes were not made
+ article = "Today a dragon flew over Paris."
+ model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-gpt2").to(torch_device)
+ tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-gpt2")
+ input_ids = tokenizer(article, return_tensors="pt").input_ids.to(torch_device)
+
+ out = model.generate(input_ids, output_logits=True, output_scores=True, return_dict_in_generate=True)
+ out_with_temp = model.generate(
+ input_ids,
+ temperature=0.5,
+ do_sample=True,
+ output_logits=True,
+ output_scores=True,
+ return_dict_in_generate=True,
+ )
+
+ # if no logits processor is used, scores == logits. Otherwise, the processor has to modify the scores
+ self.assertListEqual(out.logits[-1].tolist(), out.scores[-1].tolist())
+ self.assertNotEqual(out_with_temp.logits[-1].tolist(), out_with_temp.scores[-1].tolist())
+
def test_eos_token_id_int_and_list_top_k_top_sampling(self):
# Has TF equivalent: this test relies on random sampling
generation_kwargs = {
@@ -3107,6 +3135,10 @@ def forward(self, input_ids, foo=None, **kwargs):
# because it doesn't do signature filtering.
class FakeEncoder(bart_model.model.encoder.__class__):
def forward(self, input_ids, **kwargs):
+ # We remove these to pass gaudi_BartEncoder_forward TypeError
+ kwargs.pop("bucket_size", None)
+ kwargs.pop("bucket_internal", None)
+ kwargs.pop("reduce_recompile", None)
return super().forward(input_ids, **kwargs)
fake_encoder = FakeEncoder(bart_model.config, bart_model.model.shared).to(torch_device)
@@ -3121,15 +3153,16 @@ def forward(self, input_ids, **kwargs):
def test_default_max_length_warning(self):
model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-gpt2").to(torch_device)
tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-gpt2")
- model.config.pad_token_id = tokenizer.eos_token_id
+ model.generation_config.pad_token_id = tokenizer.eos_token_id
text = "Hello world"
tokenized_inputs = tokenizer([text], return_tensors="pt")
input_ids = tokenized_inputs.input_ids.to(torch_device)
# Default generation config value of 20 -> emits warning
- with self.assertWarns(UserWarning):
- model.generate(input_ids)
+ # NOTE: in OH we do not have this warning
+ # with self.assertWarns(UserWarning):
+ # model.generate(input_ids)
# Explicitly setting max_length to 20 -> no warning
with warnings.catch_warnings(record=True) as warning_list:
@@ -3138,7 +3171,805 @@ def test_default_max_length_warning(self):
# Generation config max_length != 20 -> no warning
with warnings.catch_warnings(record=True) as warning_list:
+ # generation_config is modified -> legacy mode is disabled = generation_config takes precedence
model.generation_config.max_length = 10
- model.generation_config._from_model_config = False # otherwise model.config.max_length=20 takes precedence
model.generate(input_ids)
self.assertEqual(len(warning_list), 0)
+
+ # TODO [gustavo] Enable this test to Optimum-habana
+ @pytest.mark.xfail
+ def test_length_warning_assisted_generation(self):
+ # PT-only test: TF doesn't support assisted decoding yet.
+ model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-gpt2").to(torch_device)
+ assistant = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-gpt2").to(torch_device)
+ tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-gpt2")
+ model.generation_config.pad_token_id = tokenizer.eos_token_id
+ assistant.generation_config.pad_token_id = tokenizer.eos_token_id
+
+ text = "Hello world"
+ tokenized_inputs = tokenizer([text], return_tensors="pt")
+ input_ids = tokenized_inputs.input_ids.to(torch_device)
+
+ # This should not raise any warning that min length is not feasible in candidate generation
+ with warnings.catch_warnings(record=True) as warning_list:
+ model.generate(
+ input_ids,
+ assistant_model=assistant,
+ min_new_tokens=10,
+ max_length=20,
+ )
+ self.assertEqual(len(warning_list), 0)
+
+ def test_default_assisted_generation(self):
+ # Initialize the GenerationConfig object
+ config = GenerationConfig()
+
+ # Check the default values
+ self.assertEqual(config.num_assistant_tokens, 20)
+ self.assertEqual(config.num_assistant_tokens_schedule, "constant")
+ self.assertEqual(config.assistant_confidence_threshold, 0.4)
+ self.assertEqual(config.is_assistant, False)
+
+ # TODO [gustavo] Enable this test to Optimum-habana
+ @pytest.mark.xfail
+ def test_generated_length_assisted_generation(self):
+ # PT-only test: TF doesn't support assisted decoding yet.
+ model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-gpt2").to(torch_device)
+ assistant = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-gpt2").to(torch_device)
+ tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-gpt2")
+ model.generation_config.pad_token_id = tokenizer.eos_token_id
+ assistant.generation_config.pad_token_id = tokenizer.eos_token_id
+
+ text = "Hello world"
+ tokenized_inputs = tokenizer([text], return_tensors="pt")
+ input_ids = tokenized_inputs.input_ids.to(torch_device)
+ input_length = input_ids.shape[-1]
+
+ out = model.generate(
+ input_ids,
+ assistant_model=assistant,
+ min_new_tokens=10,
+ max_new_tokens=20,
+ )
+ self.assertTrue((10 + input_length) <= out.shape[-1] <= (20 + input_length))
+
+ out = model.generate(
+ input_ids,
+ assistant_model=assistant,
+ min_new_tokens=10,
+ )
+ self.assertTrue((input_length + 10) <= out.shape[-1] <= 20)
+
+ # TODO [gustavo] Enable this test to Optimum-habana
+ @pytest.mark.xfail
+ def test_model_kwarg_assisted_decoding_decoder_only(self):
+ # PT-only test: TF doesn't support assisted decoding yet.
+ model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-gpt2").to(torch_device)
+ tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-gpt2")
+ model.generation_config.pad_token_id = tokenizer.eos_token_id
+
+ text = "Hello world"
+ tokenized_inputs = tokenizer([text], return_tensors="pt")
+ input_ids = tokenized_inputs.input_ids.to(torch_device)
+
+ # Traditional way of generating text
+ outputs_normal = model.generate(input_ids)
+ self.assertEqual(outputs_normal.shape, (1, 20))
+
+ # Should be different with token_type_ids
+ outputs_tti = model.generate(
+ input_ids,
+ token_type_ids=torch.zeros(input_ids.shape, dtype=torch.long).to(torch_device),
+ )
+ with self.assertRaises(AssertionError):
+ self.assertListEqual(outputs_tti.tolist(), outputs_normal.tolist())
+
+ # Assistant model
+ assistant = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-gpt2").to(torch_device)
+ assistant.config.pad_token_id = tokenizer.eos_token_id
+
+ # If assisted generation passes model_kwargs correctly, should be same as previous
+ outputs_assisted = model.generate(
+ input_ids,
+ token_type_ids=torch.zeros(input_ids.shape, dtype=torch.long).to(torch_device),
+ assistant_model=assistant,
+ )
+ self.assertListEqual(outputs_assisted.tolist(), outputs_tti.tolist())
+
+ # TODO [gustavo] Enable this test to Optimum-habana
+ @pytest.mark.xfail
+ def test_model_kwarg_assisted_decoding_encoder_decoder(self):
+ """
+ Tests that the following scenario is compatible with assisted generation:
+ 1. encoder-decoder main model
+ 2. encoder-decoder assistant model
+ 3. both have a custom input
+ (e.g. Whisper)
+ """
+
+ # PT-only test: TF doesn't support assisted decoding yet.
+ # Bart subclass with a kwarg that distorts the output
+ class FakeBart(BartForConditionalGeneration):
+ def forward(self, input_ids, past_key_values, foo=False, **kwargs):
+ outs = super().forward(input_ids, past_key_values=past_key_values, **kwargs)
+ if foo:
+ outs["logits"][:, :, :] = 0.0
+ return outs
+
+ def prepare_inputs_for_generation(self, *args, foo=False, encoder_outputs=None, **kwargs):
+ kwargs["encoder_outputs"] = encoder_outputs
+ inputs = super().prepare_inputs_for_generation(*args, **kwargs)
+ inputs["foo"] = foo
+ return inputs
+
+ model = FakeBart.from_pretrained("hf-internal-testing/tiny-random-BartForConditionalGeneration").to(
+ torch_device
+ )
+ tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-BartForConditionalGeneration")
+
+ text = "Hello world"
+ tokenized_inputs = tokenizer([text], return_tensors="pt")
+ input_ids = tokenized_inputs.input_ids.to(torch_device)
+
+ # Traditional way of generating text
+ outputs_normal = model.generate(input_ids)
+ self.assertEqual(outputs_normal.shape, (1, 20))
+
+ # Should be different with foo
+ outputs_foo = model.generate(input_ids, foo=True)
+ with self.assertRaises(AssertionError):
+ self.assertListEqual(outputs_foo.tolist(), outputs_normal.tolist())
+
+ # Assistant model
+ assistant = FakeBart.from_pretrained("hf-internal-testing/tiny-random-BartForConditionalGeneration").to(
+ torch_device
+ )
+
+ # If assisted generation passes model_kwargs correctly, should be same as previous
+ outputs_assisted = model.generate(
+ input_ids,
+ foo=True,
+ assistant_model=assistant,
+ )
+ self.assertListEqual(outputs_assisted.tolist(), outputs_foo.tolist())
+
+ # Check that passing encoder_outputs directly also works as expected
+ encoder_outputs = assistant.get_encoder()(input_ids)
+
+ outputs_assisted = model.generate(
+ foo=True,
+ assistant_model=assistant,
+ encoder_outputs=encoder_outputs,
+ assistant_encoder_outputs=encoder_outputs,
+ )
+ self.assertListEqual(outputs_assisted.tolist(), outputs_foo.tolist())
+
+ # TODO [gustavo] Enable this test to Optimum-habana
+ @pytest.mark.xfail
+ def test_assisted_decoding_encoder_decoder_shared_encoder(self):
+ """
+ Tests that the following scenario is compatible with assisted generation:
+ 1. encoder-decoder main model
+ 2. decoder-only assistant model
+ 3. both have a custom input
+ (e.g. DistilWhisper)
+ """
+
+ # PT-only test: TF doesn't support assisted decoding yet.
+ # Bart subclass with a kwarg called foo that distorts the output
+ class FakeBartSeq2Seq(BartForConditionalGeneration):
+ def forward(self, input_ids, foo=False, **kwargs):
+ outs = super().forward(input_ids, **kwargs)
+ if foo:
+ outs["logits"][:, :, :] = 0.0
+ return outs
+
+ def prepare_inputs_for_generation(self, *args, foo=False, encoder_outputs=None, **kwargs):
+ kwargs["encoder_outputs"] = encoder_outputs
+ inputs = super().prepare_inputs_for_generation(*args, **kwargs)
+ inputs["foo"] = foo
+ return inputs
+
+ class FakeBartCausalLM(BartForCausalLM):
+ def forward(self, input_ids, attention_mask, past_key_values, foo=False, **kwargs):
+ outs = super().forward(input_ids, attention_mask, past_key_values=past_key_values, **kwargs)
+ if foo:
+ outs["logits"][:, :, :] = 0.0
+ return outs
+
+ def prepare_inputs_for_generation(self, *args, foo=False, encoder_outputs=None, **kwargs):
+ kwargs["encoder_outputs"] = encoder_outputs
+ inputs = super().prepare_inputs_for_generation(*args, **kwargs)
+ inputs["foo"] = foo
+ return inputs
+
+ model = FakeBartSeq2Seq.from_pretrained("hf-internal-testing/tiny-random-BartForConditionalGeneration").to(
+ torch_device
+ )
+ tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-BartForConditionalGeneration")
+
+ text = "Hello world"
+ tokenized_inputs = tokenizer([text], return_tensors="pt")
+ input_ids = tokenized_inputs.input_ids.to(torch_device)
+
+ # Traditional way of generating text
+ outputs_normal = model.generate(input_ids)
+ self.assertEqual(outputs_normal.shape, (1, 20))
+
+ # Should be different with foo
+ outputs_foo = model.generate(input_ids, foo=True)
+ with self.assertRaises(AssertionError):
+ self.assertListEqual(outputs_foo.tolist(), outputs_normal.tolist())
+
+ # Assistant model
+ assistant = FakeBartCausalLM.from_pretrained(
+ "hf-internal-testing/tiny-random-BartForConditionalGeneration"
+ ).to(torch_device)
+
+ # If assisted generation passes model_kwargs correctly, should be same as previous
+ outputs_assisted = model.generate(
+ input_ids,
+ foo=True,
+ assistant_model=assistant,
+ )
+ self.assertListEqual(outputs_assisted.tolist(), outputs_foo.tolist())
+
+ # Check that passing encoder_outputs directly also works as expected
+ encoder_outputs = model.get_encoder()(input_ids)
+
+ outputs_assisted = model.generate(
+ foo=True,
+ assistant_model=assistant,
+ encoder_outputs=encoder_outputs,
+ )
+ self.assertListEqual(outputs_assisted.tolist(), outputs_foo.tolist())
+
+ # TODO [gustavo] Enable this test to Optimum-habana
+ @pytest.mark.xfail
+ def test_assisted_decoding_num_assistant_tokens_heuristic_schedule(self):
+ # This test ensures that the assisted generation num_assistant_tokens 'heuristic' schedule works properly.
+
+ prompt = "Alice and Bob"
+ checkpoint = "EleutherAI/pythia-160m-deduped"
+ tokenizer = AutoTokenizer.from_pretrained(checkpoint)
+ inputs = tokenizer(prompt, return_tensors="pt")
+
+ model = AutoModelForCausalLM.from_pretrained(checkpoint)
+
+ assistant_model = model
+ assistant_model.generation_config.num_assistant_tokens = 5
+ assistant_model.generation_config.num_assistant_tokens_schedule = "heuristic"
+ generation_kwargs = {
+ "eos_token_id": -1,
+ "max_new_tokens": 5,
+ "do_sample": False,
+ "assistant_model": assistant_model,
+ }
+ model.generate(**inputs, **generation_kwargs)
+ # update_candidate_strategy is called only once and therefore, assistant_model.generation_config.num_assistant_tokens should be either 4 or 7
+ self.assertTrue(assistant_model.generation_config.num_assistant_tokens in (4, 7))
+
+ # TODO [gustavo] Enable this test to Optimum-habana
+ @pytest.mark.xfail
+ def test_assisted_decoding_num_assistant_tokens_heuristic_transient_schedule(self):
+ # This test ensures that the assisted generation num_assistant_tokens 'heuristic' schedule works properly.
+
+ prompt = "Alice and Bob"
+ checkpoint = "EleutherAI/pythia-160m-deduped"
+ tokenizer = AutoTokenizer.from_pretrained(checkpoint)
+ inputs = tokenizer(prompt, return_tensors="pt")
+
+ model = AutoModelForCausalLM.from_pretrained(checkpoint)
+
+ assistant_model = model
+ assistant_model.generation_config.num_assistant_tokens = 5
+ assistant_model.generation_config.num_assistant_tokens_schedule = "heuristic_transient"
+ generation_kwargs = {
+ "eos_token_id": -1,
+ "max_new_tokens": 5,
+ "do_sample": False,
+ "assistant_model": assistant_model,
+ }
+ model.generate(**inputs, **generation_kwargs)
+ # update_candidate_strategy is called once but assistant_model.generation_config.num_assistant_tokens should stay 5
+ self.assertEqual(assistant_model.generation_config.num_assistant_tokens, 5)
+
+ # TODO [gustavo] Enable this test to Optimum-habana
+ @slow
+ @pytest.mark.xfail
+ def test_validate_assistant(self):
+ # Generate a random sample:
+ inputs = np.random.rand(160000)
+
+ # Load a main encoder-decoder model:
+ model_id = "openai/whisper-large-v2"
+ processor = AutoProcessor.from_pretrained(model_id)
+ model = AutoModelForSpeechSeq2Seq.from_pretrained(
+ model_id,
+ low_cpu_mem_usage=True,
+ use_safetensors=True,
+ )
+ model.to(torch_device)
+
+ # process the input:
+ features = processor(inputs, return_tensors="pt").to(torch_device)
+
+ # Load an encoder-decoder assistant with same encoder as the main model:
+ assistant_distil_model_id = "distil-whisper/distil-large-v2"
+ assistant_seq_to_seq = AutoModelForSpeechSeq2Seq.from_pretrained(
+ assistant_distil_model_id,
+ use_safetensors=True,
+ ).to(torch_device)
+ self.assertTrue(model.generate(**features, assistant_model=assistant_seq_to_seq).sum())
+
+ # Load its decoder only version:
+ assistant_causal_lm = AutoModelForCausalLM.from_pretrained(
+ assistant_distil_model_id,
+ low_cpu_mem_usage=True,
+ use_safetensors=True,
+ ).to(torch_device)
+ self.assertTrue(model.generate(**features, assistant_model=assistant_causal_lm).sum())
+
+ # Load an encoder-decoder assistant with a different encoder than the main model:
+ assistant_distil_model_id = "openai/whisper-tiny"
+ assistant_seq_to_seq = AutoModelForSpeechSeq2Seq.from_pretrained(
+ assistant_distil_model_id,
+ use_safetensors=True,
+ ).to(torch_device)
+ self.assertTrue(model.generate(**features, assistant_model=assistant_seq_to_seq).sum())
+
+ # Load its decoder only version:
+ assistant_causal_lm = AutoModelForCausalLM.from_pretrained(
+ assistant_distil_model_id,
+ low_cpu_mem_usage=True,
+ use_safetensors=True,
+ ).to(torch_device)
+ # It will raise an error as the encoder of the main and assistant model are not compatible:
+ with self.assertRaises(ValueError):
+ model.generate(**features, assistant_model=assistant_causal_lm)
+
+ # Load an encoder-decoder model with a different tokenizer than the main model:
+ assistant_distil_model_id = "hf-internal-testing/tiny-random-SeamlessM4Tv2ForSpeechToText"
+ assistant_seq_to_seq = AutoModelForSpeechSeq2Seq.from_pretrained(
+ assistant_distil_model_id,
+ ).to(torch_device)
+ # This should raise an error as the main and assistant model don't use the same tokenizer:
+ with self.assertRaises(ValueError):
+ model.generate(**features, assistant_model=assistant_seq_to_seq)
+
+ def test_compare_unprocessed_logit_scores(self):
+ # Get unprocessed logit scores back from model generate function.
+ # Assert that unprocessed logits from generate() are same as those from modal eval()
+
+ # tell model to generate text and return unprocessed/unwarped logit scores
+ tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-gpt2")
+ text = "generate yes or no: "
+ input_ids = tokenizer([text], return_tensors="pt").input_ids.to(torch_device)
+
+ model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-gpt2").to(torch_device)
+
+ with torch.no_grad():
+ # Get logits for the next token from fwd pass
+ logits_fwd = model(input_ids).logits[:, -1, :][0]
+
+ # Get logits for the next token from generate function
+ outputs = model.generate(
+ input_ids=input_ids,
+ return_dict_in_generate=True,
+ output_logits=True,
+ max_new_tokens=1,
+ do_sample=True,
+ )
+ logits_gen = outputs.logits[0][0]
+
+ # assert that unprocessed logits from generate() are same as those from modal eval()
+ self.assertListEqual(logits_fwd.tolist(), logits_gen.tolist())
+
+ def test_return_unprocessed_logit_scores(self):
+ # tell model to generate text and return unprocessed/unwarped logit scores
+ tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-gpt2")
+ text = "generate yes or no: "
+ input_ids = tokenizer([text], return_tensors="pt").input_ids.to(torch_device)
+ model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-gpt2").to(torch_device)
+
+ outputs = model.generate(
+ input_ids=input_ids, return_dict_in_generate=True, output_logits=True, max_new_tokens=3
+ )
+
+ # perform dummy check if unpreprocessed logits make sense.
+ # do preselection on high probabilities; find scores of y and n tokens
+ probs_all = torch.nn.functional.softmax(outputs.logits[2][0], dim=-1)
+ indices = torch.argwhere(probs_all > 0.001)
+ indices = indices[:, -1]
+ tokens_max = tokenizer.batch_decode(indices, skip_special_tokens=True)
+ probs_max = probs_all[probs_all > 0.001]
+
+ self.assertTrue(len(indices) >= 2)
+ next_token_dict = {str(t): p for t, p in zip(tokens_max, probs_max)}
+ self.assertTrue("n" in next_token_dict)
+ self.assertTrue("y" in next_token_dict)
+ y_prob = next_token_dict["y"]
+ n_prob = next_token_dict["n"]
+
+ self.assertTrue(y_prob > 0.001 and n_prob > 0.001)
+ self.assertTrue(y_prob <= 1.0 and n_prob <= 1.0)
+
+ @slow
+ @require_torch_multi_gpu
+ def test_assisted_decoding_in_different_gpu(self):
+ # PT-only test: TF doesn't support assisted decoding yet.
+ model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-MistralForCausalLM").to("cuda:0")
+ assistant = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-MistralForCausalLM").to(
+ "cuda:1"
+ )
+ tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-MistralForCausalLM")
+ model.config.pad_token_id = tokenizer.eos_token_id
+ assistant.config.pad_token_id = tokenizer.eos_token_id
+
+ text = "Hello world"
+ tokenized_inputs = tokenizer([text], return_tensors="pt")
+ input_ids = tokenized_inputs.input_ids.to(torch_device)
+ input_length = input_ids.shape[-1]
+
+ out = model.generate(
+ input_ids,
+ assistant_model=assistant,
+ max_new_tokens=20,
+ )
+ self.assertTrue(input_length <= out.shape[-1] <= input_length + 20)
+
+ @slow
+ @require_torch_gpu
+ def test_assisted_decoding_model_in_gpu_assistant_in_cpu(self):
+ # PT-only test: TF doesn't support assisted decoding yet.
+ model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-MistralForCausalLM").to("cuda")
+ assistant = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-MistralForCausalLM").to(
+ "cpu"
+ )
+ tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-MistralForCausalLM")
+ model.config.pad_token_id = tokenizer.eos_token_id
+ assistant.config.pad_token_id = tokenizer.eos_token_id
+
+ text = "Hello world"
+ tokenized_inputs = tokenizer([text], return_tensors="pt")
+ input_ids = tokenized_inputs.input_ids.to(torch_device)
+ input_length = input_ids.shape[-1]
+
+ out = model.generate(
+ input_ids,
+ assistant_model=assistant,
+ max_new_tokens=20,
+ )
+ self.assertTrue(input_length <= out.shape[-1] <= input_length + 20)
+
+ def test_special_tokens_fall_back_to_model_default(self):
+ # PT-only test: TF doesn't support assisted decoding yet.
+ model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-MistralForCausalLM").to(
+ torch_device
+ )
+ test_bos_id = 50
+
+ # Sanity-check: the model has a BOS token set, and the first generated token is a BOS token
+ gen_output = model.generate()
+ self.assertTrue(model.generation_config.bos_token_id is not None)
+ self.assertTrue(model.generation_config.bos_token_id == gen_output[0, 0])
+
+ # If we pass a generation config **with** a BOS token, `generate` will use it
+ generation_config = GenerationConfig(bos_token_id=test_bos_id)
+ gen_output = model.generate(generation_config=generation_config)
+ self.assertFalse(model.generation_config.bos_token_id == gen_output[0, 0])
+ self.assertTrue(generation_config.bos_token_id == gen_output[0, 0])
+ self.assertTrue(test_bos_id == gen_output[0, 0])
+
+ # If we pass a generation config **without** a BOS token, `generate` will fetch the BOS token from
+ # `model.generation_config`
+ generation_config = GenerationConfig(bos_token_id=None)
+ gen_output = model.generate(generation_config=generation_config)
+ self.assertTrue(model.generation_config.bos_token_id == gen_output[0, 0])
+ self.assertFalse(test_bos_id == gen_output[0, 0])
+ self.assertTrue(generation_config.bos_token_id is None)
+
+ # Changing `model.generation_config` will affect fallback behavior
+ model.generation_config.bos_token_id = test_bos_id
+ gen_output = model.generate(generation_config=generation_config)
+ self.assertTrue(model.generation_config.bos_token_id == gen_output[0, 0])
+ self.assertTrue(test_bos_id == gen_output[0, 0])
+ self.assertTrue(generation_config.bos_token_id is None)
+
+ @pytest.mark.generate
+ @require_torch_multi_gpu
+ def test_generate_with_static_cache_multi_gpu(self):
+ """
+ Tests if the static cache has been set correctly and if generate works correctly when we are using multi-gpus.
+ """
+ # need to split manually as auto doesn't work well with unbalanced model
+ device_map = {"model.embed_tokens": 0, "model.layers.0": 0, "model.layers.1": 1, "model.norm": 1, "lm_head": 0}
+ model = AutoModelForCausalLM.from_pretrained(
+ "hf-internal-testing/tiny-random-MistralForCausalLM", device_map=device_map
+ )
+ tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-MistralForCausalLM")
+
+ text = "Hello world"
+ tokenized_inputs = tokenizer([text], return_tensors="pt")
+ input_ids = tokenized_inputs.input_ids.to(torch_device)
+
+ generation_kwargs = {
+ "max_new_tokens": 20,
+ "cache_implementation": "static",
+ "return_dict_in_generate": True, # Required to return `past_key_values`
+ }
+
+ results = model.generate(input_ids, **generation_kwargs)
+ self.assertTrue(isinstance(results.past_key_values, StaticCache))
+
+ # check device of each layer
+ key_cache_0 = results.past_key_values.key_cache[0]
+ value_cache_0 = results.past_key_values.value_cache[0]
+ self.assertTrue(key_cache_0.device == value_cache_0.device == torch.device(0))
+
+ key_cache_1 = results.past_key_values.key_cache[1]
+ value_cache_1 = results.past_key_values.value_cache[1]
+ self.assertTrue(key_cache_1.device == value_cache_1.device == torch.device(1))
+
+ @pytest.mark.generate
+ @require_torch_multi_gpu
+ def test_init_static_cache_multi_gpu(self):
+ """
+ Tests if the static cache has been set correctly when we initialize it manually in a multi-gpu setup.
+ """
+ # need to split manually as auto doesn't work well with unbalanced model
+ device_map = {"model.embed_tokens": 0, "model.layers.0": 0, "model.layers.1": 1, "model.norm": 1, "lm_head": 0}
+ model = AutoModelForCausalLM.from_pretrained(
+ "hf-internal-testing/tiny-random-MistralForCausalLM", device_map=device_map
+ )
+ tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-MistralForCausalLM")
+
+ text = "Hello world"
+ tokenized_inputs = tokenizer([text], return_tensors="pt")
+ input_ids = tokenized_inputs.input_ids.to(torch_device)
+
+ generation_kwargs = {
+ "max_new_tokens": 20,
+ "return_dict_in_generate": True, # Required to return `past_key_values`
+ }
+
+ # TODO: We need to raise a warning in case the cache is not set correctly
+ # with self.assertRaisesRegex(ValueError, "If you are manually initializing the cache"):
+ # past_key_values = StaticCache(
+ # config=model.config, batch_size=1, max_cache_len=30, device=torch_device, dtype=model.dtype
+ # )
+ # results = model.generate(input_ids, past_key_values=past_key_values, **generation_kwargs)
+
+ # deduced from the device_map : layer 0 on device 0 and layer 1 on device 1
+ layer_device_map = {0: 0, 1: 1}
+ past_key_values = StaticCache(
+ config=model.config,
+ batch_size=1,
+ max_cache_len=30,
+ device=torch_device,
+ dtype=model.dtype,
+ layer_device_map=layer_device_map,
+ )
+ results = model.generate(input_ids, past_key_values=past_key_values, **generation_kwargs)
+
+ # check device of each layer
+ key_cache_0 = results.past_key_values.key_cache[0]
+ value_cache_0 = results.past_key_values.value_cache[0]
+ self.assertTrue(key_cache_0.device == value_cache_0.device == torch.device(0))
+
+ key_cache_1 = results.past_key_values.key_cache[1]
+ value_cache_1 = results.past_key_values.value_cache[1]
+ self.assertTrue(key_cache_1.device == value_cache_1.device == torch.device(1))
+
+ @slow
+ # TODO [gustavo] Enable this test to Optimum-habana
+ @pytest.mark.xfail
+ def test_padding_input_contrastive_search_gpt2(self):
+ # Load the pre-trained GPT-2 model and tokenizer
+ model = GPT2LMHeadModel.from_pretrained("openai-community/gpt2")
+ model.to(torch_device)
+ tokenizer = AutoTokenizer.from_pretrained("openai-community/gpt2", clean_up_tokenization_spaces=True)
+
+ # Set the tokenizer to left-pad the sequences
+ tokenizer.padding_side = "left"
+
+ # Define the PAD token as the EOS token
+ tokenizer.pad_token = tokenizer.eos_token
+ model.generation_config.pad_token_id = model.generation_config.eos_token_id
+
+ # Define the input prompt
+ prompt_text = "The whispered legends of the haunted mansion spoke"
+
+ # Tokenize the input prompt
+ encoded_prompt = tokenizer(prompt_text, return_tensors="pt", padding=True)
+ input_ids = encoded_prompt.input_ids.to(torch_device)
+ attention_mask = encoded_prompt.attention_mask.to(torch_device)
+
+ # Define the contrastive search params
+ penalty_alpha = 0.6
+ top_k = 4
+
+ # Define the padding length to add to the input IDs and attention mask
+ padding_length = 10
+
+ # Generate text without padding
+ outputs = model.generate(
+ input_ids=input_ids,
+ attention_mask=attention_mask,
+ do_sample=False,
+ penalty_alpha=penalty_alpha,
+ top_k=top_k,
+ max_new_tokens=64,
+ )
+ generated_text_no_padding = tokenizer.decode(outputs[0], skip_special_tokens=True)
+
+ # Pad the input IDs and attention mask on the left
+ padded_input_ids = F.pad(
+ input_ids, (padding_length, 0), "constant", value=model.generation_config.pad_token_id
+ )
+ padded_attention_mask = F.pad(attention_mask, (padding_length, 0), "constant", value=0)
+
+ # Generate text with padded inputs
+ outputs_with_padding = model.generate(
+ input_ids=padded_input_ids,
+ attention_mask=padded_attention_mask,
+ do_sample=False,
+ penalty_alpha=penalty_alpha,
+ top_k=top_k,
+ max_new_tokens=64,
+ )
+ generated_text_with_padding = tokenizer.decode(outputs_with_padding[0], skip_special_tokens=True)
+
+ # Assert that the generated texts are identical for padded and non-padded inputs
+ self.assertEqual(generated_text_no_padding, generated_text_with_padding)
+ self.assertEqual(
+ generated_text_with_padding,
+ 'The whispered legends of the haunted mansion spoke of the "souls of the dead" who were "falling '
+ 'out of the sky" and "falling into the sea."\n\nThe ghostly apparitions were said to have been '
+ 'created by the spirits of the dead, who were "falling out of the sky" and "falling into the sea',
+ )
+
+ @slow
+ # TODO [gustavo] Enable this test to Optimum-habana
+ @pytest.mark.xfail
+ def test_padding_input_contrastive_search_t5(self):
+ # Load the pre-trained T5 model and tokenizer
+ model = T5ForConditionalGeneration.from_pretrained("google-t5/t5-small")
+ model.to(torch_device)
+ tokenizer = AutoTokenizer.from_pretrained("google-t5/t5-small", clean_up_tokenization_spaces=True)
+
+ # Define the input prompt
+ prompt_text = "translate English to German: I need to finish this task before the end of the day."
+
+ # Tokenize the input prompt
+ encoded_prompt = tokenizer(prompt_text, return_tensors="pt")
+ input_ids = encoded_prompt.input_ids.to(torch_device)
+ attention_mask = encoded_prompt.attention_mask.to(torch_device)
+
+ # Define the decoder prompt
+ decoder_prompt_text = "Ich muss diese Aufgabe"
+ encoded_decoder_prompt = tokenizer(decoder_prompt_text, add_special_tokens=False, return_tensors="pt")
+ decoder_input_ids = encoded_decoder_prompt.input_ids.to(torch_device)
+ decoder_attention_mask = encoded_decoder_prompt.attention_mask.to(torch_device)
+
+ # Define the contrastive search params
+ penalty_alpha = 0.6
+ top_k = 4
+
+ # Generate text without padding
+ outputs = model.generate(
+ input_ids=input_ids,
+ attention_mask=attention_mask,
+ decoder_input_ids=decoder_input_ids,
+ decoder_attention_mask=decoder_attention_mask,
+ do_sample=False,
+ penalty_alpha=penalty_alpha,
+ top_k=top_k,
+ max_new_tokens=64,
+ )
+ generated_text_no_padding = tokenizer.decode(outputs[0], skip_special_tokens=True)
+
+ # Define the padding length to add to the input IDs and attention mask
+ padding_length = 10
+
+ # Pad the decoder input IDs and attention mask on the left
+ padded_decoder_input_ids = F.pad(
+ decoder_input_ids, (padding_length, 0), "constant", value=model.generation_config.pad_token_id
+ )
+ padded_decoder_attention_mask = F.pad(decoder_attention_mask, (padding_length, 0), "constant", value=0)
+ # Since the decoder_start_token_id is the same as the pad_token_id,
+ # the last padded token represents the decoder start token.
+ # Set the attention mask for the decoder_start_token_id to True (1).
+ padded_decoder_attention_mask[:, padding_length - 1] = 1
+ # Generate text with padded inputs
+ outputs_with_padding = model.generate(
+ input_ids=input_ids,
+ attention_mask=attention_mask,
+ decoder_input_ids=padded_decoder_input_ids,
+ decoder_attention_mask=padded_decoder_attention_mask,
+ do_sample=False,
+ penalty_alpha=penalty_alpha,
+ top_k=top_k,
+ max_new_tokens=64,
+ )
+ generated_text_with_padding = tokenizer.decode(outputs_with_padding[0], skip_special_tokens=True)
+
+ # Assert that the generated texts are identical for padded and non-padded inputs
+ self.assertEqual(generated_text_no_padding, generated_text_with_padding)
+ self.assertEqual(generated_text_no_padding, "Ich muss diese Aufgabe vor Ende des Tages beenden.")
+
+ # TODO [gustavo] Enable this test to Optimum-habana
+ @pytest.mark.xfail
+ def test_generate_compile_fullgraph_tiny(self):
+ """
+ Tests that we can call end-to-end generation with a tiny model (i.e. doesn't crash)
+ NOTE: this test is quite slow (~20s on a consumer desktop), but it is important that we keep it as part of the
+ non-slow tests to prevent regressions!
+ """
+ model = AutoModelForCausalLM.from_pretrained(
+ "hf-internal-testing/tiny-random-LlamaForCausalLM", torch_dtype=torch.bfloat16, device_map="auto"
+ )
+ tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-LlamaForCausalLM")
+
+ # compile generate
+ compiled_generate = torch.compile(model.generate, fullgraph=True, mode="reduce-overhead")
+
+ # compiled generate does NOT accept parameterization except a) model inputs b) a generation config
+ generation_config = copy.deepcopy(model.generation_config)
+ generation_config.pad_token_id = model.config.eos_token_id
+
+ model_inputs = tokenizer(["Write a poem about the market crashing in summer"], return_tensors="pt")
+ model_inputs = model_inputs.to(model.device)
+ gen_out = compiled_generate(**model_inputs, generation_config=generation_config)
+ self.assertTrue(gen_out.shape[1] > model_inputs["input_ids"].shape[1]) # some text was generated
+
+
+@require_torch
+class TokenHealingTestCase(unittest.TestCase):
+ @parameterized.expand(
+ [
+ (
+ "square_bracket",
+ 'An example ["like this"] and another example [',
+ 'An example ["like this"] and another example ["',
+ ),
+ ("url", 'The link is