diff --git a/ch07/04_preference-tuning-with-dpo/dpo-from-scratch.ipynb b/ch07/04_preference-tuning-with-dpo/dpo-from-scratch.ipynb
index ce328e64..c2c5d9e7 100644
--- a/ch07/04_preference-tuning-with-dpo/dpo-from-scratch.ipynb
+++ b/ch07/04_preference-tuning-with-dpo/dpo-from-scratch.ipynb
@@ -2149,16 +2149,18 @@
     "        labels=batch[\"rejected\"],\n",
     "        selection_mask=batch[\"rejected_mask\"]\n",
     "    )\n",
-    "    ref_chosen_log_probas = compute_logprobs(\n",
-    "        logits=reference_model(batch[\"chosen\"]),\n",
-    "        labels=batch[\"chosen\"],\n",
-    "        selection_mask=batch[\"chosen_mask\"]\n",
-    "    )\n",
-    "    ref_rejected_log_probas = compute_logprobs(\n",
-    "        logits=reference_model(batch[\"rejected\"]),\n",
-    "        labels=batch[\"rejected\"],\n",
-    "        selection_mask=batch[\"rejected_mask\"]\n",
-    "    )\n",
+    "\n",
+    "    with torch.no_grad():\n",
+    "        ref_chosen_log_probas = compute_logprobs(\n",
+    "            logits=reference_model(batch[\"chosen\"]),\n",
+    "            labels=batch[\"chosen\"],\n",
+    "            selection_mask=batch[\"chosen_mask\"]\n",
+    "        )\n",
+    "        ref_rejected_log_probas = compute_logprobs(\n",
+    "            logits=reference_model(batch[\"rejected\"]),\n",
+    "            labels=batch[\"rejected\"],\n",
+    "            selection_mask=batch[\"rejected_mask\"]\n",
+    "        )\n",
     "    loss, chosen_rewards, rejected_rewards = compute_dpo_loss(\n",
     "        model_chosen_logprobs=policy_chosen_log_probas,\n",
     "        model_rejected_logprobs=policy_rejected_log_probas,\n",
@@ -3090,7 +3092,7 @@
    "name": "python",
    "nbconvert_exporter": "python",
    "pygments_lexer": "ipython3",
-   "version": "3.11.4"
+   "version": "3.10.6"
   }
  },
  "nbformat": 4,
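
The change above wraps the frozen reference model's forward passes in `torch.no_grad()`, so no autograd graph is built for them; only the policy model's log-probabilities carry gradients into the DPO loss. Below is a minimal, self-contained sketch of that pattern (not the notebook's exact code): the model definitions, shapes, and the `sequence_logprobs` helper are illustrative stand-ins, not the notebook's `compute_logprobs`/`compute_dpo_loss` functions.

```python
# Sketch: compute reference log-probs without autograd, policy log-probs with it.
# All names and shapes here are illustrative assumptions, not the notebook's code.
import torch
import torch.nn.functional as F

torch.manual_seed(0)
vocab_size, seq_len, batch_size = 50, 8, 2

policy_model = torch.nn.Linear(vocab_size, vocab_size)     # stand-in for the trainable LLM
reference_model = torch.nn.Linear(vocab_size, vocab_size)  # stand-in for the frozen copy

# Dummy token ids turned into one-hot "embeddings" so the linear layers yield logits
tokens = torch.randint(0, vocab_size, (batch_size, seq_len))
inputs = F.one_hot(tokens, vocab_size).float()

def sequence_logprobs(logits, labels):
    # Sum of per-token log-probabilities of the given labels over the sequence
    logprobs = F.log_softmax(logits, dim=-1)
    return torch.gather(logprobs, -1, labels.unsqueeze(-1)).squeeze(-1).sum(-1)

# Policy pass: gradients are needed for the DPO loss, so no no_grad() here
policy_logprobs = sequence_logprobs(policy_model(inputs), tokens)

# Reference pass: the model is frozen, so skip autograd bookkeeping entirely
with torch.no_grad():
    ref_logprobs = sequence_logprobs(reference_model(inputs), tokens)

print(policy_logprobs.requires_grad)  # True  -> participates in the gradient
print(ref_logprobs.requires_grad)     # False -> treated as a constant
```

Because the reference model's activations are never stored for a backward pass, this reduces peak memory during DPO training; an alternative with the same effect would be precomputing the reference log-probabilities once before the training loop.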