
Commit 33ee016

Fix prefix tuning finetune issue and update test (huggingface#975)

Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>

1 parent 1825d15

File tree

2 files changed (+9, -10)

optimum/habana/transformers/models/llama/modeling_llama.py (+1, -2)

@@ -262,11 +262,10 @@ def update(self, prev, cur, dim, idx, inp_seq_len):
         if prev.shape == cur.shape:
             prev.copy_(cur)
             return orig_cur
-        if cur.shape[2] > 1 and cur.shape[2] <= prev.shape[2]:
+        if idx is not None and cur.shape[2] > 1 and cur.shape[2] <= prev.shape[2]:
             # Initialize
             prev[:, :, :inp_seq_len, :].copy_(cur)
             return orig_cur
-        assert cur.shape[2] == 1, f"Cannot update kv-cache. Unsupported shapes. prev:{prev.shape} cur:{cur.shape}"
         if idx is not None:
             prev.index_copy_(dim, idx - 1, cur)
             return prev
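
For readers outside the diff context, here is a minimal, self-contained sketch of the patched kv-cache update logic. It is an illustration, not the library code: the tensor layout in the comments and the final torch.cat fallback (what the real function does when idx is None lies outside the lines shown above) are assumptions. The new "idx is not None" guard keeps the in-place initialize path on the static-cache route; without it, prefix tuning, which seeds the cache with virtual-token key/value states so cur can span several positions while no write index exists, ran into the assert that this commit removes.

import torch

def update(prev, cur, dim, idx, inp_seq_len):
    # prev: preallocated cache, assumed [batch, heads, max_seq_len, head_dim]
    # cur:  new key/value states for this step
    orig_cur = cur
    if prev.shape == cur.shape:
        # Cache and new states match exactly: overwrite in place.
        prev.copy_(cur)
        return orig_cur
    if idx is not None and cur.shape[2] > 1 and cur.shape[2] <= prev.shape[2]:
        # Prefill with a static cache: write the whole prompt at once.
        prev[:, :, :inp_seq_len, :].copy_(cur)
        return orig_cur
    if idx is not None:
        # Single-token decode step: scatter the new token at position idx - 1.
        prev.index_copy_(dim, idx - 1, cur)
        return prev
    # No static write index (e.g. prefix tuning seeded the cache with virtual
    # tokens): grow the cache dynamically. This fallback is an assumption about
    # the code that follows the hunk shown above.
    return torch.cat((prev, cur), dim=dim)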

tests/baselines/llama_7b.json (+8, -8)

@@ -230,16 +230,16 @@
         "multi_card": {
             "learning_rate": 5e-4,
             "train_batch_size": 1,
-            "train_runtime": 16.5,
-            "train_samples_per_second": 63.161,
-            "perplexity": 1.224,
+            "train_runtime": 16.1,
+            "train_samples_per_second": 63.249,
+            "perplexity": 1.172,
             "extra_arguments": [
                 "--num_virtual_tokens 8",
                 "--max_seq_length 64",
                 "--logging_steps 1",
                 "--report_to none",
                 "--max_steps 100",
-                "--peft_type prompt_tuning",
+                "--peft_type prefix_tuning",
                 "--max_seq_length 64",
                 "--lr_scheduler_type cosine",
                 "--warmup_steps 0",
@@ -256,16 +256,16 @@
         "multi_card": {
             "learning_rate": 5e-4,
             "train_batch_size": 1,
-            "train_runtime": 16.5,
+            "train_runtime": 18.7,
             "train_samples_per_second": 63.161,
-            "perplexity": 1.224,
+            "perplexity": 1.047,
             "extra_arguments": [
                 "--num_virtual_tokens 8",
                 "--max_seq_length 64",
                 "--logging_steps 1",
                 "--report_to none",
                 "--max_steps 100",
-                "--peft_type prompt_tuning",
+                "--peft_type p_tuning",
                 "--max_seq_length 64",
                 "--lr_scheduler_type cosine",
                 "--warmup_steps 0",
@@ -276,4 +276,4 @@
             }
         }
     }
-}
+}
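
Both baseline entries above previously passed --peft_type prompt_tuning; the fix points each at the method its baseline is meant to exercise and refreshes the expected runtime, throughput, and perplexity numbers. As a rough, assumed mapping (using the Hugging Face peft library, which this commit does not itself touch), the --peft_type values used across these baselines correspond to the following configs:

from peft import PrefixTuningConfig, PromptEncoderConfig, PromptTuningConfig, TaskType

# --peft_type prefix_tuning: trainable key/value prefixes prepended in every
# attention layer's kv-cache (the path touched by the modeling fix above).
prefix_cfg = PrefixTuningConfig(task_type=TaskType.CAUSAL_LM, num_virtual_tokens=8)

# --peft_type p_tuning: virtual-token embeddings produced by a small prompt encoder.
p_tuning_cfg = PromptEncoderConfig(task_type=TaskType.CAUSAL_LM, num_virtual_tokens=8)

# --peft_type prompt_tuning: directly learned soft-prompt embeddings.
prompt_cfg = PromptTuningConfig(task_type=TaskType.CAUSAL_LM, num_virtual_tokens=8)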
