fix dpo metrics
hiyouga committed Nov 2, 2024
1 parent 83535bb commit 225a545
Showing 6 changed files with 118 additions and 57 deletions.
31 changes: 16 additions & 15 deletions src/llamafactory/train/callbacks.py
@@ -287,13 +287,13 @@ def on_log(self, args: "TrainingArguments", state: "TrainerState", control: "Tra
logs = dict(
current_steps=self.cur_steps,
total_steps=self.max_steps,
- loss=state.log_history[-1].get("loss", None),
- eval_loss=state.log_history[-1].get("eval_loss", None),
- predict_loss=state.log_history[-1].get("predict_loss", None),
- reward=state.log_history[-1].get("reward", None),
- accuracy=state.log_history[-1].get("rewards/accuracies", None),
- learning_rate=state.log_history[-1].get("learning_rate", None),
- epoch=state.log_history[-1].get("epoch", None),
+ loss=state.log_history[-1].get("loss"),
+ eval_loss=state.log_history[-1].get("eval_loss"),
+ predict_loss=state.log_history[-1].get("predict_loss"),
+ reward=state.log_history[-1].get("reward"),
+ accuracy=state.log_history[-1].get("rewards/accuracies"),
+ lr=state.log_history[-1].get("learning_rate"),
+ epoch=state.log_history[-1].get("epoch"),
percentage=round(self.cur_steps / self.max_steps * 100, 2) if self.max_steps != 0 else 100,
elapsed_time=self.elapsed_time,
remaining_time=self.remaining_time,
@@ -304,16 +304,17 @@ def on_log(self, args: "TrainingArguments", state: "TrainerState", control: "Tra

if os.environ.get("RECORD_VRAM", "0").lower() in ["true", "1"]:
vram_allocated, vram_reserved = get_peak_memory()
logs["vram_allocated"] = round(vram_allocated / 1024 / 1024 / 1024, 2)
logs["vram_reserved"] = round(vram_reserved / 1024 / 1024 / 1024, 2)
logs["vram_allocated"] = round(vram_allocated / (1024**3), 2)
logs["vram_reserved"] = round(vram_reserved / (1024**3), 2)

logs = {k: v for k, v in logs.items() if v is not None}
- if self.webui_mode and all(key in logs for key in ["loss", "learning_rate", "epoch"]):
-     logger.info_rank0(
-         "{{'loss': {:.4f}, 'learning_rate': {:2.4e}, 'epoch': {:.2f}, 'throughput': {}}}".format(
-             logs["loss"], logs["learning_rate"], logs["epoch"], logs.get("throughput", "N/A")
-         )
-     )
+ if self.webui_mode and all(key in logs for key in ("loss", "lr", "epoch")):
+     log_str = f"'loss': {logs['loss']:.4f}, 'learning_rate': {logs['lr']:2.4e}, 'epoch': {logs['epoch']:.2f}"
+     for extra_key in ("reward", "accuracy", "throughput"):
+         if logs.get(extra_key):
+             log_str += f", '{extra_key}': {logs[extra_key]:.2f}"
+
+     logger.info_rank0("{" + log_str + "}")

if self.thread_pool is not None:
self.thread_pool.submit(self._write_log, args.output_dir, logs)
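
Note on the callbacks.py hunk above: the web UI log line is now built incrementally, so fields such as reward and accuracy only appear when they were actually logged. A minimal standalone sketch of the new formatting logic (the `logs` values here are made up for illustration):

logs = {"loss": 0.6931, "lr": 5e-5, "epoch": 1.25, "accuracy": 0.81}

log_str = f"'loss': {logs['loss']:.4f}, 'learning_rate': {logs['lr']:2.4e}, 'epoch': {logs['epoch']:.2f}"
for extra_key in ("reward", "accuracy", "throughput"):
    if logs.get(extra_key):  # skipped when the key is absent (or falsy)
        log_str += f", '{extra_key}': {logs[extra_key]:.2f}"

print("{" + log_str + "}")
# {'loss': 0.6931, 'learning_rate': 5.0000e-05, 'epoch': 1.25, 'accuracy': 0.81}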
48 changes: 35 additions & 13 deletions src/llamafactory/train/dpo/trainer.py
@@ -250,20 +250,18 @@ def get_batch_loss_metrics(
if self.ftx_gamma > 1e-6:
losses += self.ftx_gamma * sft_loss

- reward_accuracies = (chosen_rewards > rejected_rewards).float()
-
prefix = "eval_" if train_eval == "eval" else ""
metrics[f"{prefix}rewards/chosen"] = chosen_rewards.mean().cpu()
metrics[f"{prefix}rewards/rejected"] = rejected_rewards.mean().cpu()
metrics[f"{prefix}rewards/accuracies"] = reward_accuracies.mean().cpu()
metrics[f"{prefix}rewards/margins"] = (chosen_rewards - rejected_rewards).mean().cpu()
metrics[f"{prefix}logps/rejected"] = policy_rejected_logps.detach().mean().cpu()
metrics[f"{prefix}logps/chosen"] = policy_chosen_logps.detach().mean().cpu()
metrics[f"{prefix}logits/rejected"] = policy_rejected_logits.detach().mean().cpu()
metrics[f"{prefix}logits/chosen"] = policy_chosen_logits.detach().mean().cpu()
metrics[f"{prefix}rewards/chosen"] = chosen_rewards.mean().item()
metrics[f"{prefix}rewards/rejected"] = rejected_rewards.mean().item()
metrics[f"{prefix}rewards/accuracies"] = (chosen_rewards > rejected_rewards).float().mean().item()
metrics[f"{prefix}rewards/margins"] = (chosen_rewards - rejected_rewards).mean().item()
metrics[f"{prefix}logps/rejected"] = policy_chosen_logps.mean().item()
metrics[f"{prefix}logps/chosen"] = policy_rejected_logps.mean().item()
metrics[f"{prefix}logits/rejected"] = policy_chosen_logits.mean().item()
metrics[f"{prefix}logits/chosen"] = policy_rejected_logits.mean().item()
if self.loss_type == "orpo":
metrics[f"{prefix}sft_loss"] = sft_loss.detach().mean().cpu()
metrics[f"{prefix}odds_ratio_loss"] = ((losses - sft_loss) / self.beta).detach().mean().cpu()
metrics[f"{prefix}sft_loss"] = sft_loss.mean().item()
metrics[f"{prefix}odds_ratio_loss"] = ((losses - sft_loss) / self.beta).mean().item()

return losses.mean(), metrics

@@ -275,6 +273,30 @@ def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
"""
loss = super().compute_loss(model, inputs, return_outputs)
if is_transformers_version_equal_to_4_46() and kwargs.pop("num_items_in_batch", False):
- loss /= self.args.gradient_accumulation_steps
+ if return_outputs:
+     return (loss[0] / self.args.gradient_accumulation_steps, *loss[1:])
+ else:
+     return loss / self.args.gradient_accumulation_steps

return loss
+
+ @override
+ def log(self, logs: Dict[str, float]) -> None:
+     r"""
+     Log `logs` on the various objects watching training, including stored metrics.
+     """
+     # logs either has "loss" or "eval_loss"
+     train_eval = "train" if "loss" in logs else "eval"
+     # Add averaged stored metrics to logs
+     key_list, metric_list = [], []
+     for key, metrics in self._stored_metrics[train_eval].items():
+         key_list.append(key)
+         metric_list.append(torch.tensor(metrics).to(self.accelerator.device).mean().item())
+
+     del self._stored_metrics[train_eval]
+     metric_list = torch.tensor(metric_list, dtype=torch.float).to(self.accelerator.device)
+     metric_list = self.accelerator.reduce(metric_list, "mean").tolist()
+     for key, metric in zip(key_list, metric_list):
+         logs[key] = metric
+
+     return Trainer.log(self, logs)
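
The new log() override is the heart of the fix: DPO metrics are stored per step as plain floats, averaged once per logging event, and only then averaged across distributed ranks, instead of logging per-batch CPU tensors. A single-process sketch of that aggregation (the stored values are illustrative; in the trainer, accelerator.reduce(tensor, "mean") performs the cross-rank averaging):

import torch

stored_metrics = {"rewards/accuracies": [0.75, 0.81, 0.78], "rewards/margins": [0.9, 1.1, 1.0]}
logs = {"loss": 0.6931}

key_list, metric_list = [], []
for key, values in stored_metrics.items():
    key_list.append(key)
    metric_list.append(torch.tensor(values).mean().item())  # mean over the stored steps

# accelerator.reduce(torch.tensor(metric_list), "mean") would go here in the multi-GPU case.
for key, metric in zip(key_list, metric_list):
    logs[key] = metric

print(logs)  # {'loss': 0.6931, 'rewards/accuracies': ~0.78, 'rewards/margins': ~1.0}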
80 changes: 54 additions & 26 deletions src/llamafactory/train/kto/trainer.py
@@ -131,7 +131,7 @@ def get_batch_samples(self, epoch_iterator, num_batches):
@override
def forward(
self, model: "PreTrainedModel", batch: Dict[str, "torch.Tensor"], prefix: Literal["", "kl_"] = ""
- ) -> Tuple["torch.Tensor", "torch.Tensor"]:
+ ) -> Tuple["torch.Tensor", "torch.Tensor", "torch.Tensor"]:
r"""
Runs forward pass and computes the log probabilities.
"""
@@ -151,23 +151,25 @@ def forward(

logits = model(**model_inputs, return_dict=True, use_cache=False).logits.to(torch.float32)
logps, valid_length = get_batch_logps(logits=logits, labels=batch[f"{prefix}labels"])
- return logps, logps / valid_length
+ return logits, logps, logps / valid_length

@override
def concatenated_forward(
self, model: "PreTrainedModel", batch: Dict[str, "torch.Tensor"]
- ) -> Tuple["torch.Tensor", "torch.Tensor", "torch.Tensor", "torch.Tensor"]:
-     target_logps, target_logps_avg = self.forward(model, batch)
+ ) -> Tuple["torch.Tensor", "torch.Tensor", "torch.Tensor", "torch.Tensor", "torch.Tensor", "torch.Tensor"]:
+     target_logits, target_logps, target_logps_avg = self.forward(model, batch)
with torch.no_grad():
-     kl_logps, _ = self.forward(model, batch, prefix="kl_")
+     _, kl_logps, _ = self.forward(model, batch, prefix="kl_")

if len(target_logps) != len(batch["kto_tags"]):
raise ValueError("Mismatched shape of inputs and labels.")

+ chosen_logits = target_logits[batch["kto_tags"]]
chosen_logps = target_logps[batch["kto_tags"]]
+ rejected_logits = target_logits[~batch["kto_tags"]]
rejected_logps = target_logps[~batch["kto_tags"]]
chosen_logps_avg = target_logps_avg[batch["kto_tags"]]
- return chosen_logps, rejected_logps, kl_logps, chosen_logps_avg
+ return chosen_logps, rejected_logps, chosen_logits, rejected_logits, kl_logps, chosen_logps_avg

@override
def compute_reference_log_probs(
@@ -184,7 +186,7 @@ def compute_reference_log_probs(
ref_context = nullcontext()

with torch.no_grad(), ref_context:
- reference_chosen_logps, reference_rejected_logps, reference_kl_logps, _ = self.concatenated_forward(
+ reference_chosen_logps, reference_rejected_logps, _, _, reference_kl_logps, _ = self.concatenated_forward(
ref_model, batch
)

@@ -200,9 +202,14 @@ def get_batch_loss_metrics(
Computes the DPO loss and other metrics for the given batch of inputs for train or test.
"""
metrics = {}
- policy_chosen_logps, policy_rejected_logps, policy_kl_logps, policy_chosen_logps_avg = (
-     self.concatenated_forward(model, batch)
- )
+ (
+     policy_chosen_logps,
+     policy_rejected_logps,
+     policy_chosen_logits,
+     policy_rejected_logits,
+     policy_kl_logps,
+     policy_chosen_logps_avg,
+ ) = self.concatenated_forward(model, batch)
reference_chosen_logps, reference_rejected_logps, reference_kl_logps = self.compute_reference_log_probs(
model, batch
)
@@ -220,24 +227,21 @@
sft_loss = -policy_chosen_logps_avg
losses += self.ftx_gamma * sft_loss.nanmean() / len(policy_chosen_logps) * len(batch["labels"])

- num_chosen = torch.Tensor([len(chosen_rewards)]).to(self.accelerator.device)
- num_rejected = torch.Tensor([len(rejected_rewards)]).to(self.accelerator.device)
-
- all_num_chosen = self.accelerator.gather(num_chosen).sum().item()
- all_num_rejected = self.accelerator.gather(num_rejected).sum().item()
-
- if all_num_chosen > 0:
-     metrics["rewards/chosen_sum"] = self.accelerator.gather(chosen_rewards.nansum()).nansum().item()
-     metrics["logps/chosen_sum"] = self.accelerator.gather(policy_chosen_logps.nansum()).nansum().item()
-     metrics["count/chosen"] = all_num_chosen
+ num_chosen = len(chosen_rewards)
+ num_rejected = len(rejected_rewards)
+ if num_chosen > 0:
+     metrics["rewards/chosen_sum"] = chosen_rewards.nansum().item()
+     metrics["logps/chosen_sum"] = policy_chosen_logps.nansum().item()
+     metrics["logits/chosen_sum"] = policy_chosen_logits.nansum().item()
+     metrics["count/chosen"] = num_chosen

- if all_num_rejected > 0:
-     metrics["rewards/rejected_sum"] = self.accelerator.gather(rejected_rewards.nansum()).nansum().item()
-     metrics["logps/rejected_sum"] = self.accelerator.gather(policy_rejected_logps.nansum()).nansum().item()
-     metrics["count/rejected"] = all_num_rejected
+ if num_rejected > 0:
+     metrics["rewards/rejected_sum"] = rejected_rewards.nansum().item()
+     metrics["logps/rejected_sum"] = policy_rejected_logps.nansum().item()
+     metrics["logits/rejected_sum"] = policy_rejected_logits.nansum().item()
+     metrics["count/rejected"] = num_rejected

metrics["kl"] = kl.item()

return losses, metrics

@override
@@ -248,6 +252,30 @@ def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
"""
loss = super().compute_loss(model, inputs, return_outputs)
if is_transformers_version_equal_to_4_46() and kwargs.pop("num_items_in_batch", False):
- loss /= self.args.gradient_accumulation_steps
+ if return_outputs:
+     return (loss[0] / self.args.gradient_accumulation_steps, *loss[1:])
+ else:
+     return loss / self.args.gradient_accumulation_steps

return loss
+
+ @override
+ def log(self, logs: Dict[str, float]) -> None:
+     r"""
+     Log `logs` on the various objects watching training, including stored metrics.
+     """
+     # logs either has "loss" or "eval_loss"
+     train_eval = "train" if "loss" in logs else "eval"
+     # Add averaged stored metrics to logs
+     key_list, metric_list = [], []
+     for key, metrics in self._stored_metrics[train_eval].items():
+         key_list.append(key)
+         metric_list.append(torch.tensor(metrics).to(self.accelerator.device).mean().item())
+
+     del self._stored_metrics[train_eval]
+     metric_list = torch.tensor(metric_list, dtype=torch.float).to(self.accelerator.device)
+     metric_list = self.accelerator.reduce(metric_list, "mean").tolist()
+     for key, metric in zip(key_list, metric_list):
+         logs[key] = metric
+
+     return Trainer.log(self, logs)
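
Why the kto/trainer.py hunks replace gathered per-batch averages with *_sum and count/* metrics: sums and counts can be pooled exactly at log time (total sum divided by total count), whereas averaging per-batch means over-weights batches with few chosen or rejected examples. A toy comparison with made-up numbers:

import torch

reward_sums = torch.tensor([2.4, 9.0])  # per-batch "rewards/chosen_sum"
counts = torch.tensor([2.0, 6.0])       # per-batch "count/chosen"

exact_mean = reward_sums.sum() / counts.sum()  # 11.4 / 8 = 1.425
naive_mean = (reward_sums / counts).mean()     # (1.2 + 1.5) / 2 = 1.35, biased

print(exact_mean.item(), naive_mean.item())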
6 changes: 5 additions & 1 deletion src/llamafactory/train/pt/trainer.py
@@ -74,6 +74,10 @@ def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
"""
loss = super().compute_loss(model, inputs, return_outputs, **kwargs)
if is_transformers_version_equal_to_4_46() and not getattr(self, "model_accepts_loss_kwargs", False):
- loss /= self.args.gradient_accumulation_steps  # other model should not scale the loss
+ # other model should not scale the loss
+ if return_outputs:
+     return (loss[0] / self.args.gradient_accumulation_steps, *loss[1:])
+ else:
+     return loss / self.args.gradient_accumulation_steps

return loss
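
This compute_loss guard (the same pattern appears in the dpo and kto trainers above and in the sft trainer below) is needed because Trainer.compute_loss returns a bare loss tensor when return_outputs=False but a (loss, outputs) tuple when return_outputs=True, and dividing the tuple in place would raise a TypeError. A standalone sketch of the tuple-safe scaling, with placeholder values:

import torch

def compute_loss(return_outputs: bool = False):
    loss = torch.tensor(1.2)                 # placeholder loss
    outputs = {"logits": torch.zeros(2, 4)}  # placeholder model outputs
    return (loss, outputs) if return_outputs else loss

grad_accum_steps = 4
result = compute_loss(return_outputs=True)
scaled = (result[0] / grad_accum_steps, *result[1:])  # scale only the loss term
print(scaled[0])  # tensor(0.3000)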
6 changes: 5 additions & 1 deletion src/llamafactory/train/sft/trainer.py
@@ -87,7 +87,11 @@ def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
"""
loss = super().compute_loss(model, inputs, return_outputs, **kwargs)
if is_transformers_version_equal_to_4_46() and not getattr(self, "model_accepts_loss_kwargs", False):
- loss /= self.args.gradient_accumulation_steps  # other model should not scale the loss
+ # other model should not scale the loss
+ if return_outputs:
+     return (loss[0] / self.args.gradient_accumulation_steps, *loss[1:])
+ else:
+     return loss / self.args.gradient_accumulation_steps

return loss

4 changes: 3 additions & 1 deletion src/llamafactory/webui/runner.py
@@ -98,6 +98,7 @@ def _initialize(self, data: Dict["Component", Any], do_train: bool, from_preview

def _finalize(self, lang: str, finish_info: str) -> str:
finish_info = ALERTS["info_aborted"][lang] if self.aborted else finish_info
+ gr.Info(finish_info)
self.trainer = None
self.aborted = False
self.running = False
@@ -357,6 +358,7 @@ def monitor(self):
progress_bar = self.manager.get_elem_by_id("{}.progress_bar".format("train" if self.do_train else "eval"))
loss_viewer = self.manager.get_elem_by_id("train.loss_viewer") if self.do_train else None

running_log = ""
while self.trainer is not None:
if self.aborted:
yield {
@@ -392,7 +394,7 @@ def monitor(self):
finish_info = ALERTS["err_failed"][lang]

return_dict = {
- output_box: self._finalize(lang, finish_info),
+ output_box: self._finalize(lang, finish_info) + "\n\n" + running_log,
progress_bar: gr.Slider(visible=False),
}
yield return_dict
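
The running_log = "" initialization added above matters because monitor() can reach this final yield without the while body ever running (e.g. the trainer is already gone), in which case concatenating running_log would raise UnboundLocalError. A simplified sketch of that control flow:

def monitor(trainer_alive: bool) -> str:
    running_log = ""  # without this line, the return below can raise UnboundLocalError
    while trainer_alive:
        running_log = "step 10: {'loss': 0.52}"  # illustrative log tail
        trainer_alive = False
    return "Finished." + "\n\n" + running_log

print(monitor(trainer_alive=False))  # prints "Finished." followed by an empty tail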
