
Commit 1825d15

Updates run_lora_clm.py with enhanced dataset support (huggingface#955)
1 parent a10726f commit 1825d15

3 files changed: +97 −22 lines


examples/language-modeling/run_lora_clm.py (+55 −22)
@@ -259,6 +259,27 @@ class DataArguments:
     save_last_ckpt: bool = field(
         default=True, metadata={"help": "Whether to save checkpoint at the end of the training."}
     )
+    instruction_column_name: Optional[str] = field(
+        default=None,
+        metadata={
+            "help": "Name of the column in the dataset that describes the task that the model should perform. By "
+            "default, the 'instruction' column is used for non-SQL prompts and the 'question' column is used for SQL prompts."
+        },
+    )
+    input_column_name: Optional[str] = field(
+        default=None,
+        metadata={
+            "help": "Name of the column in the dataset that optionally provides context or input for the task. By "
+            "default, the 'input' column is used for non-SQL prompts and the 'context' column is used for SQL prompts."
+        },
+    )
+    output_column_name: Optional[str] = field(
+        default=None,
+        metadata={
+            "help": "Name of the column in the dataset with the answer to the instruction. By default, the "
+            "'output' column is used for non-SQL prompts and the 'answer' column is used for SQL prompts."
+        },
+    )
 
 
 @dataclass
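For context, a minimal, hedged sketch (not part of the commit) of how these new optional fields surface as command-line flags once parsed with transformers' HfArgumentParser, the same mechanism run_lora_clm.py uses. The trimmed-down dataclass and the sample column names below are illustrative assumptions:

# Illustrative only: a trimmed copy of the new fields, parsed the way
# run_lora_clm.py parses its arguments (via HfArgumentParser).
from dataclasses import dataclass, field
from typing import Optional

from transformers import HfArgumentParser


@dataclass
class DataArguments:
    # Same defaults as the commit: None means "use the built-in column names".
    instruction_column_name: Optional[str] = field(default=None)
    input_column_name: Optional[str] = field(default=None)
    output_column_name: Optional[str] = field(default=None)


parser = HfArgumentParser(DataArguments)
# Hypothetical flags for a Dolly-style dataset (instruction/context/response columns).
(data_args,) = parser.parse_args_into_dataclasses(
    ["--input_column_name", "context", "--output_column_name", "response"]
)
print(data_args.input_column_name, data_args.output_column_name)  # context response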
@@ -357,7 +378,7 @@ def create_prompts(examples):
     prompts["target"] = []
     for example in examples:
         prompt_template = (
-            PROMPT_DICT["prompt_with_input"] if example["input"] != "" else PROMPT_DICT["prompt_without_input"]
+            PROMPT_DICT["prompt_with_input"] if example.get("input", "") != "" else PROMPT_DICT["prompt_without_input"]
         )
         source = prompt_template.format_map(example)
         prompts["source"].append(source)
@@ -531,19 +552,7 @@ def main():
             **dataset_args,
         )
 
-    if data_args.dataset_name == "tatsu-lab/alpaca" or data_args.sql_prompt:
-        # Preprocessing the datasets.
-        for key in raw_datasets:
-            prompts = (
-                create_prompts(raw_datasets[key])
-                if not data_args.sql_prompt
-                else create_sql_prompts(raw_datasets[key])
-            )
-            columns_to_be_removed = list(raw_datasets[key].features.keys())
-            raw_datasets[key] = raw_datasets[key].add_column("prompt_sources", prompts["source"])
-            raw_datasets[key] = raw_datasets[key].add_column("prompt_targets", prompts["target"])
-            raw_datasets[key] = raw_datasets[key].remove_columns(columns_to_be_removed)
-    elif (
+    if (
         data_args.dataset_name == "timdettmers/openassistant-guanaco"
     ):  # from https://github.com/artidoro/qlora/blob/main/qlora.py#L621
         raw_datasets = raw_datasets.map(
@@ -557,7 +566,33 @@ def main():
             [col for col in raw_datasets.column_names["train"] if col not in ["input", "output"]]
         )
     else:
-        raise ValueError("Unsupported dataset")
+        # Preprocessing the datasets.
+        for key in raw_datasets:
+            if data_args.instruction_column_name:
+                raw_datasets[key] = raw_datasets[key].rename_column(
+                    data_args.instruction_column_name, "question" if data_args.sql_prompt else "instruction"
+                )
+
+            if data_args.input_column_name:
+                raw_datasets[key] = raw_datasets[key].rename_column(
+                    data_args.input_column_name, "context" if data_args.sql_prompt else "input"
+                )
+
+            if data_args.output_column_name:
+                raw_datasets[key] = raw_datasets[key].rename_column(
+                    data_args.output_column_name, "answer" if data_args.sql_prompt else "output"
+                )
+
+            prompts = (
+                create_prompts(raw_datasets[key])
+                if not data_args.sql_prompt
+                else create_sql_prompts(raw_datasets[key])
+            )
+            columns_to_be_removed = list(raw_datasets[key].features.keys())
+            raw_datasets[key] = raw_datasets[key].add_column("prompt_sources", prompts["source"])
+            raw_datasets[key] = raw_datasets[key].add_column("prompt_targets", prompts["target"])
+            raw_datasets[key] = raw_datasets[key].remove_columns(columns_to_be_removed)
+
     # Load model
     if model_args.model_name_or_path:
         model_dtype = torch.bfloat16 if training_args.bf16 else None
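Taken together, the new else branch maps any instruction-style dataset onto the schema the prompt builders expect, instead of rejecting it as unsupported. A toy walkthrough with the datasets library (column names are assumptions loosely following databricks/databricks-dolly-15k, and the prompt strings are stand-ins for create_prompts):

from datasets import Dataset

# A tiny Dolly-style split: instruction/context/response instead of instruction/input/output.
raw = Dataset.from_dict(
    {
        "instruction": ["Name a primary color."],
        "context": [""],
        "response": ["Blue."],
    }
)

# Equivalent of passing --input_column_name context --output_column_name response.
raw = raw.rename_column("context", "input")
raw = raw.rename_column("response", "output")

# Stand-in for create_prompts(): build source/target strings per example.
sources = [f"Instruction: {ex['instruction']}\nResponse:" for ex in raw]
targets = [ex["output"] for ex in raw]

columns_to_be_removed = list(raw.features.keys())
raw = raw.add_column("prompt_sources", sources)
raw = raw.add_column("prompt_targets", targets)
raw = raw.remove_columns(columns_to_be_removed)
print(raw.column_names)  # ['prompt_sources', 'prompt_targets']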
@@ -661,18 +696,16 @@ def concatenate_data(dataset, max_seq_length):
                 concatenated_dataset[column] = reshaped_data
             return datasets.Dataset.from_dict(concatenated_dataset)
 
-        if data_args.dataset_name == "tatsu-lab/alpaca" or data_args.sql_prompt:
+        if data_args.dataset_name == "timdettmers/openassistant-guanaco":
+            tokenized_datasets_ = tokenized_datasets["train"].remove_columns(["input", "output"])
+            if training_args.do_eval:
+                tokenized_datasets_eval_ = tokenized_datasets["test"].remove_columns(["input", "output"])
+        else:
             tokenized_datasets_ = tokenized_datasets["train"].remove_columns(["prompt_sources", "prompt_targets"])
             if training_args.do_eval:
                 tokenized_datasets_eval_ = tokenized_datasets["validation"].remove_columns(
                     ["prompt_sources", "prompt_targets"]
                 )
-        elif data_args.dataset_name == "timdettmers/openassistant-guanaco":
-            tokenized_datasets_ = tokenized_datasets["train"].remove_columns(["input", "output"])
-            if training_args.do_eval:
-                tokenized_datasets_eval_ = tokenized_datasets["test"].remove_columns(["input", "output"])
-        else:
-            raise ValueError("Unsupported dataset")
         tokenized_datasets["train"] = concatenate_data(tokenized_datasets_, data_args.max_seq_length)
         if training_args.do_eval:
             tokenized_datasets["validation"] = concatenate_data(tokenized_datasets_eval_, data_args.max_seq_length)

tests/baselines/llama_7b.json (+36 −0)
@@ -25,6 +25,42 @@
         }
     },
     "gaudi2": {
+        "databricks/databricks-dolly-15k": {
+            "num_train_epochs": 1,
+            "eval_batch_size": 8,
+            "distribution": {
+                "single_card": {
+                    "learning_rate": 2e-4,
+                    "train_batch_size": 16,
+                    "perplexity": 3.8436,
+                    "train_runtime": 113.9713,
+                    "train_samples_per_second": 18.428,
+                    "extra_arguments": [
+                        "--bf16",
+                        "--gradient_accumulation_steps 1",
+                        "--evaluation_strategy no",
+                        "--save_strategy no",
+                        "--warmup_ratio 0.03",
+                        "--lr_scheduler_type constant",
+                        "--max_grad_norm 0.3",
+                        "--logging_steps 1",
+                        "--use_hpu_graphs_for_inference",
+                        "--lora_rank 8",
+                        "--lora_alpha 16",
+                        "--lora_dropout 0.1",
+                        "--lora_target_modules q_proj v_proj",
+                        "--dataset_concatenation",
+                        "--low_cpu_mem_usage True",
+                        "--adam_epsilon 1e-08",
+                        "--validation_split_percentage 20",
+                        "--attn_softmax_bf16",
+                        "--max_steps 100",
+                        "--input_column_name context",
+                        "--output_column_name response"
+                    ]
+                }
+            }
+        },
         "tatsu-lab/alpaca": {
             "num_train_epochs": 3,
             "eval_batch_size": 4,

tests/test_examples.py (+6 −0)
@@ -713,6 +713,12 @@ class ProteinFoldingExampleTester2(ExampleTesterBase, metaclass=ExampleTestMeta,
     pass
 
 
+class CausalLanguageModelingLORAExampleTester(
+    ExampleTesterBase, metaclass=ExampleTestMeta, example_name="run_lora_clm"
+):
+    TASK_NAME = "databricks/databricks-dolly-15k"
+
+
 class MultiCardCausalLanguageModelingLORAExampleTester(
     ExampleTesterBase, metaclass=ExampleTestMeta, example_name="run_lora_clm", multi_card=True
 ):
