From 98f17f5b6b0562b2cc02a87a8fba0d0a5d7a070c Mon Sep 17 00:00:00 2001 From: Saibo Geng Date: Fri, 22 Dec 2023 17:04:12 +0100 Subject: [PATCH] Fix bug in key_value split; this resolves the output mismatch problem --- src/transformers/generation/utils.py | 2 +- tests/generation/test_utils.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index e3426d9382ad..04d8c46202d9 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -4952,7 +4952,7 @@ def _split(data): if isinstance(data[0], tuple): return [ tuple(tuple(tensor[i : i + split_size] for tensor in inner_tuple) for inner_tuple in data) - for i in range(0, full_batch_size, 1) + for i in range(0, full_batch_size, split_size) ] else: diff --git a/tests/generation/test_utils.py b/tests/generation/test_utils.py index 547ff492e388..5b57ead54708 100644 --- a/tests/generation/test_utils.py +++ b/tests/generation/test_utils.py @@ -1516,7 +1516,7 @@ def test_beam_search_low_memory(self): ] ): self.skipTest("May fix in the future: need model-specific fixes") - config, input_ids, attention_mask, max_length = self._get_input_ids_and_config(batch_size=1) + config, input_ids, attention_mask, max_length = self._get_input_ids_and_config(batch_size=2) # batch_size=1 is ok, but batch_size>1 will cause non-identical output config.use_cache = True