Skip to content

Commit

Permalink
fix bug in key_value split, this solves the problem of output mismatch
Browse files Browse the repository at this point in the history
  • Loading branch information
Saibo-creator committed Dec 22, 2023
1 parent 4680d17 commit 98f17f5
Show file tree
Hide file tree
Showing 2 changed files with 2 additions and 2 deletions.
2 changes: 1 addition & 1 deletion src/transformers/generation/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4952,7 +4952,7 @@ def _split(data):
if isinstance(data[0], tuple):
return [
tuple(tuple(tensor[i : i + split_size] for tensor in inner_tuple) for inner_tuple in data)
for i in range(0, full_batch_size, 1)
for i in range(0, full_batch_size, split_size)
]

else:
Expand Down
2 changes: 1 addition & 1 deletion tests/generation/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1516,7 +1516,7 @@ def test_beam_search_low_memory(self):
]
):
self.skipTest("May fix in the future: need model-specific fixes")
config, input_ids, attention_mask, max_length = self._get_input_ids_and_config(batch_size=1)
config, input_ids, attention_mask, max_length = self._get_input_ids_and_config(batch_size=2)
# batch_size=1 is ok, but batch_size>1 will cause non-identical output

config.use_cache = True
Expand Down

0 comments on commit 98f17f5

Please sign in to comment.