diff --git a/apps/accelerate/chatllama/chatllama/rlhf/trainer.py b/apps/accelerate/chatllama/chatllama/rlhf/trainer.py index 74f0b7ec..359725e7 100644 --- a/apps/accelerate/chatllama/chatllama/rlhf/trainer.py +++ b/apps/accelerate/chatllama/chatllama/rlhf/trainer.py @@ -584,7 +584,7 @@ def learn(self, memories: Deque[Memory]) -> None: # compute KL divergence kl_div_loss = ( - (actions_prob * (old_actions_log_probs - actions_log_prob)) + (actions_prob * (actions_log_prob - old_actions_log_probs)) .sum(dim=-1) .mean() )