Fix tests.test_peft_inference failure (#1543)
Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>
sywangyi authored Dec 3, 2024
1 parent 36812cd commit e883691
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions optimum/habana/transformers/generation/utils.py
```diff
@@ -520,11 +520,11 @@ def create_pad_arg(pad_amount, i, j):
                     # This is a necessary (but not sufficient) condition: whatever dimension we are padding should be a multiple of bucket_size
                     # This check is added in case we get a new model with a new kv-cache structure and we attempt to pad the wrong dimension
                     # In the PEFT case, if there are virtual tokens, model_kwargs["past_key_values"][i][j].shape[-(len(pad_tuple) // 2)] % bucket_size == num_virtual_tokens, so there is no need for an assert; the pad length of past_key_values should be aligned with the input ids and attention_mask
+                    num_virtual_tokens = model_kwargs.get("num_virtual_tokens", 0)
                     if (
                         model_kwargs["past_key_values"][i][j].shape[-(len(pad_tuple) // 2)]
-                        == params["allocated_space"] - pad_amount
+                        == params["allocated_space"] - pad_amount + num_virtual_tokens
                     ):
-                        num_virtual_tokens = model_kwargs.get("num_virtual_tokens", 0)
                         assert (
                             model_kwargs["past_key_values"][i][j].shape[-(len(pad_tuple) // 2)] % bucket_size
                             == num_virtual_tokens
```
