Skip to content

Commit

Permalink
fix: Clear CUDA cache to reduce OOM (#536)
Browse files Browse the repository at this point in the history
**Reason for Change**:
Use callback function to clear CUDA cache to reduce OOM every training
step.
  • Loading branch information
ishaansehgal99 authored Jul 25, 2024
1 parent f259329 commit 66f5711
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 2 deletions.
2 changes: 1 addition & 1 deletion examples/fine-tuning/kaito_workspace_tuning_phi_3.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ kind: Workspace
metadata:
name: workspace-tuning-phi-3
resource:
instanceType: "Standard_NC6s_v3"
instanceType: "Standard_NC24ads_A100_v4"
labelSelector:
matchLabels:
app: tuning-phi-3
Expand Down
9 changes: 8 additions & 1 deletion presets/tuning/text-generation/fine_tuning.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
from transformers import (AutoModelForCausalLM, AutoTokenizer,
BitsAndBytesConfig, HfArgumentParser, Trainer,
TrainerCallback, TrainerControl, TrainerState,
TrainingArguments)
from trl import SFTTrainer

Expand Down Expand Up @@ -91,7 +92,11 @@

train_dataset, eval_dataset = dm.split_dataset()

# checkpoint_callback = CheckpointCallback()
class EmptyCacheCallback(TrainerCallback):
def on_step_end(self, args, state: TrainerState, control: TrainerControl, **kwargs):
torch.cuda.empty_cache()
return control
empty_cache_callback = EmptyCacheCallback()

# Prepare for training
torch.cuda.set_device(accelerator.process_index)
Expand All @@ -105,6 +110,7 @@
args=ta_args,
data_collator=dc_args,
dataset_text_field=dm.dataset_text_field,
callbacks=[empty_cache_callback]
# metrics = "tensorboard" or "wandb" # TODO
))
trainer.train()
Expand All @@ -113,6 +119,7 @@

# Write file to signify training completion
timestamp = datetime.now().strftime("%Y-%m-%d-%H-%M-%S")
print("Fine-Tuning completed\n")
completion_indicator_path = os.path.join(ta_args.output_dir, "fine_tuning_completed.txt")
with open(completion_indicator_path, 'w') as f:
f.write(f"Fine-Tuning completed at {timestamp}\n")

0 comments on commit 66f5711

Please sign in to comment.