@@ -259,6 +259,27 @@ class DataArguments:
     save_last_ckpt: bool = field(
         default=True, metadata={"help": "Whether to save a checkpoint at the end of the training."}
     )
+    instruction_column_name: Optional[str] = field(
+        default=None,
+        metadata={
+            "help": "Name of the column in the dataset that describes the task that the model should perform. By "
+            "default, the 'instruction' column is used for non-SQL prompts and the 'question' column is used for SQL prompts."
+        },
+    )
+    input_column_name: Optional[str] = field(
+        default=None,
+        metadata={
+            "help": "Name of the column in the dataset that optionally provides context or input for the task. By "
+            "default, the 'input' column is used for non-SQL prompts and the 'context' column is used for SQL prompts."
+        },
+    )
+    output_column_name: Optional[str] = field(
+        default=None,
+        metadata={
+            "help": "Name of the column in the dataset with the answer to the instruction. By default, the "
+            "'output' column is used for non-SQL prompts and the 'answer' column is used for SQL prompts."
+        },
+    )
 
 
 @dataclass
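
These new dataclass fields surface as command-line flags through Hugging Face's HfArgumentParser, which scripts of this kind typically use to parse their arguments. A minimal, self-contained sketch of that mechanism, with DataArguments trimmed to the three new fields (the full class in the diff has many more):

from dataclasses import dataclass, field
from typing import Optional

from transformers import HfArgumentParser


@dataclass
class DataArguments:
    # Trimmed to the three new fields for illustration.
    instruction_column_name: Optional[str] = field(
        default=None, metadata={"help": "Column describing the task."}
    )
    input_column_name: Optional[str] = field(
        default=None, metadata={"help": "Column with optional context for the task."}
    )
    output_column_name: Optional[str] = field(
        default=None, metadata={"help": "Column holding the answer."}
    )


parser = HfArgumentParser(DataArguments)
# Invoked as e.g.: python train.py --instruction_column_name question --output_column_name answer
(data_args,) = parser.parse_args_into_dataclasses()
print(data_args.instruction_column_name, data_args.output_column_name)
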
@@ -357,7 +378,7 @@ def create_prompts(examples):
     prompts["target"] = []
     for example in examples:
         prompt_template = (
-            PROMPT_DICT["prompt_with_input"] if example["input"] != "" else PROMPT_DICT["prompt_without_input"]
+            PROMPT_DICT["prompt_with_input"] if example.get("input", "") != "" else PROMPT_DICT["prompt_without_input"]
         )
         source = prompt_template.format_map(example)
         prompts["source"].append(source)
@@ -531,19 +552,7 @@ def main():
         **dataset_args,
     )
 
-    if data_args.dataset_name == "tatsu-lab/alpaca" or data_args.sql_prompt:
-        # Preprocessing the datasets.
-        for key in raw_datasets:
-            prompts = (
-                create_prompts(raw_datasets[key])
-                if not data_args.sql_prompt
-                else create_sql_prompts(raw_datasets[key])
-            )
-            columns_to_be_removed = list(raw_datasets[key].features.keys())
-            raw_datasets[key] = raw_datasets[key].add_column("prompt_sources", prompts["source"])
-            raw_datasets[key] = raw_datasets[key].add_column("prompt_targets", prompts["target"])
-            raw_datasets[key] = raw_datasets[key].remove_columns(columns_to_be_removed)
-    elif (
+    if (
         data_args.dataset_name == "timdettmers/openassistant-guanaco"
     ):  # from https://github.com/artidoro/qlora/blob/main/qlora.py#L621
         raw_datasets = raw_datasets.map(
@@ -557,7 +566,33 @@ def main():
             [col for col in raw_datasets.column_names["train"] if col not in ["input", "output"]]
         )
     else:
-        raise ValueError("Unsupported dataset")
+        # Preprocessing the datasets.
+        for key in raw_datasets:
+            if data_args.instruction_column_name:
+                raw_datasets[key] = raw_datasets[key].rename_column(
+                    data_args.instruction_column_name, "question" if data_args.sql_prompt else "instruction"
+                )
+
+            if data_args.input_column_name:
+                raw_datasets[key] = raw_datasets[key].rename_column(
+                    data_args.input_column_name, "context" if data_args.sql_prompt else "input"
+                )
+
+            if data_args.output_column_name:
+                raw_datasets[key] = raw_datasets[key].rename_column(
+                    data_args.output_column_name, "answer" if data_args.sql_prompt else "output"
+                )
+
+            prompts = (
+                create_prompts(raw_datasets[key])
+                if not data_args.sql_prompt
+                else create_sql_prompts(raw_datasets[key])
+            )
+            columns_to_be_removed = list(raw_datasets[key].features.keys())
+            raw_datasets[key] = raw_datasets[key].add_column("prompt_sources", prompts["source"])
+            raw_datasets[key] = raw_datasets[key].add_column("prompt_targets", prompts["target"])
+            raw_datasets[key] = raw_datasets[key].remove_columns(columns_to_be_removed)
+
 
     # Load model
     if model_args.model_name_or_path:
         model_dtype = torch.bfloat16 if training_args.bf16 else None
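
The rename_column calls above normalize user-specified column names to the ones create_prompts and create_sql_prompts expect ("instruction"/"input"/"output", or "question"/"context"/"answer" when --sql_prompt is set). A short sketch of the same mapping with the datasets library, using made-up data:

from datasets import Dataset

ds = Dataset.from_dict({
    "task": ["Translate to French."],
    "source_text": ["Hello"],
    "translation": ["Bonjour"],
})

# Equivalent to passing --instruction_column_name task --input_column_name source_text
# --output_column_name translation without --sql_prompt.
for old, new in [("task", "instruction"), ("source_text", "input"), ("translation", "output")]:
    ds = ds.rename_column(old, new)  # rename_column returns a new Dataset

print(ds.column_names)  # ['instruction', 'input', 'output']
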
@@ -661,18 +696,16 @@ def concatenate_data(dataset, max_seq_length):
             concatenated_dataset[column] = reshaped_data
         return datasets.Dataset.from_dict(concatenated_dataset)
 
-    if data_args.dataset_name == "tatsu-lab/alpaca" or data_args.sql_prompt:
+    if data_args.dataset_name == "timdettmers/openassistant-guanaco":
+        tokenized_datasets_ = tokenized_datasets["train"].remove_columns(["input", "output"])
+        if training_args.do_eval:
+            tokenized_datasets_eval_ = tokenized_datasets["test"].remove_columns(["input", "output"])
+    else:
         tokenized_datasets_ = tokenized_datasets["train"].remove_columns(["prompt_sources", "prompt_targets"])
         if training_args.do_eval:
             tokenized_datasets_eval_ = tokenized_datasets["validation"].remove_columns(
                 ["prompt_sources", "prompt_targets"]
             )
-    elif data_args.dataset_name == "timdettmers/openassistant-guanaco":
-        tokenized_datasets_ = tokenized_datasets["train"].remove_columns(["input", "output"])
-        if training_args.do_eval:
-            tokenized_datasets_eval_ = tokenized_datasets["test"].remove_columns(["input", "output"])
-    else:
-        raise ValueError("Unsupported dataset")
     tokenized_datasets["train"] = concatenate_data(tokenized_datasets_, data_args.max_seq_length)
     if training_args.do_eval:
         tokenized_datasets["validation"] = concatenate_data(tokenized_datasets_eval_, data_args.max_seq_length)