diff --git a/examples/fine-tuning/kaito_workspace_tuning_phi_3.yaml b/examples/fine-tuning/kaito_workspace_tuning_phi_3.yaml
index c83c6f67b..3a4988dc4 100644
--- a/examples/fine-tuning/kaito_workspace_tuning_phi_3.yaml
+++ b/examples/fine-tuning/kaito_workspace_tuning_phi_3.yaml
@@ -3,7 +3,7 @@ kind: Workspace
 metadata:
   name: workspace-tuning-phi-3
 resource:
-  instanceType: "Standard_NC6s_v3"
+  instanceType: "Standard_NC24ads_A100_v4"
   labelSelector:
     matchLabels:
       app: tuning-phi-3
diff --git a/presets/tuning/text-generation/fine_tuning.py b/presets/tuning/text-generation/fine_tuning.py
index 5809f887a..095f94a3a 100644
--- a/presets/tuning/text-generation/fine_tuning.py
+++ b/presets/tuning/text-generation/fine_tuning.py
@@ -13,6 +13,7 @@
 from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
 from transformers import (AutoModelForCausalLM, AutoTokenizer,
                           BitsAndBytesConfig, HfArgumentParser, Trainer,
+                          TrainerCallback, TrainerControl, TrainerState,
                           TrainingArguments)
 from trl import SFTTrainer
 
@@ -91,7 +92,11 @@
 
 train_dataset, eval_dataset = dm.split_dataset()
 
-# checkpoint_callback = CheckpointCallback()
+class EmptyCacheCallback(TrainerCallback):
+    def on_step_end(self, args, state: TrainerState, control: TrainerControl, **kwargs):
+        torch.cuda.empty_cache()
+        return control
+empty_cache_callback = EmptyCacheCallback()
 
 # Prepare for training
 torch.cuda.set_device(accelerator.process_index)
@@ -105,6 +110,7 @@
     args=ta_args,
     data_collator=dc_args,
     dataset_text_field=dm.dataset_text_field,
+    callbacks=[empty_cache_callback]
     # metrics = "tensorboard" or "wandb" # TODO
 ))
 trainer.train()
@@ -113,6 +119,7 @@
 
 # Write file to signify training completion
 timestamp = datetime.now().strftime("%Y-%m-%d-%H-%M-%S")
+print("Fine-Tuning completed\n")
 completion_indicator_path = os.path.join(ta_args.output_dir, "fine_tuning_completed.txt")
 with open(completion_indicator_path, 'w') as f:
     f.write(f"Fine-Tuning completed at {timestamp}\n")
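
The callback added in the Python hunks follows the standard Hugging Face `TrainerCallback` pattern: `on_step_end` runs after every optimizer step and calls `torch.cuda.empty_cache()`. The sketch below restates it in isolation for reference; the trainer construction in the trailing comment is illustrative only and does not mirror the preset's full argument list.

```python
# Minimal standalone sketch of the pattern introduced by the diff above.
# Names in the usage comment (model, training_args, train_dataset) are
# placeholders, not the preset's actual configuration.
import torch
from transformers import (TrainerCallback, TrainerControl, TrainerState,
                          TrainingArguments)


class EmptyCacheCallback(TrainerCallback):
    def on_step_end(self, args: TrainingArguments, state: TrainerState,
                    control: TrainerControl, **kwargs):
        # Release unused cached GPU memory blocks between steps to reduce
        # peak memory pressure; this does not affect training results.
        torch.cuda.empty_cache()
        return control


# Usage: pass the callback when constructing the trainer, e.g.
# trainer = SFTTrainer(model=model, args=training_args,
#                      train_dataset=train_dataset,
#                      callbacks=[EmptyCacheCallback()])
```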