diff --git a/README.md b/README.md
index bee8635d..44f06807 100644
--- a/README.md
+++ b/README.md
@@ -118,6 +118,7 @@ Several folders contain optional materials as a bonus for interested readers:
- [Building a User Interface to Interact With the Pretrained LLM](ch05/06_user_interface)
- [Converting GPT to Llama](ch05/07_gpt_to_llama)
- [Llama 3.2 From Scratch](ch05/07_gpt_to_llama/standalone-llama32.ipynb)
+ - [Memory-efficient Model Weight Loading](ch05/08_memory_efficient_weight_loading/memory-efficient-state-dict.ipynb)
- **Chapter 6:**
- [Additional experiments finetuning different layers and using larger models](ch06/02_bonus_additional-experiments)
- [Finetuning different models on 50k IMDB movie review dataset](ch06/03_bonus_imdb-classification)
diff --git a/ch05/08_memory_efficient_weight_loading/README.md b/ch05/08_memory_efficient_weight_loading/README.md
new file mode 100644
index 00000000..2b8fef08
--- /dev/null
+++ b/ch05/08_memory_efficient_weight_loading/README.md
@@ -0,0 +1,5 @@
+# Memory-efficient Model Weight Loading
+
+This folder contains code that illustrates how to load model weights more memory-efficiently.
+
+- [memory-efficient-state-dict.ipynb](memory-efficient-state-dict.ipynb): a notebook showing how to load model weights via PyTorch's `load_state_dict` method in a more memory-efficient way
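+
+Below is a minimal sketch of the core idea used in the notebook (copying a saved `state_dict` into the model one parameter tensor at a time, instead of materializing a second full copy of the weights on the GPU). It assumes a `model` with the matching architecture and a checkpoint created via `torch.save(model.state_dict(), "model.pth")`:
+
+```python
+import torch
+
+# Load the checkpoint into CPU memory only
+state_dict = torch.load("model.pth", map_location="cpu", weights_only=True)
+
+with torch.no_grad():
+    for name, param in model.named_parameters():
+        if name in state_dict:
+            # Copy one parameter tensor at a time so the GPU never has to
+            # hold the full set of weights twice
+            param.copy_(state_dict[name].to(param.device))
+        else:
+            print(f"Warning: {name} not found in state_dict.")
+```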
diff --git a/ch05/08_memory_efficient_weight_loading/memory-efficient-state-dict.ipynb b/ch05/08_memory_efficient_weight_loading/memory-efficient-state-dict.ipynb
new file mode 100644
index 00000000..a8b3d351
--- /dev/null
+++ b/ch05/08_memory_efficient_weight_loading/memory-efficient-state-dict.ipynb
@@ -0,0 +1,866 @@
+{
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "1E_HhLEeYqFG"
+ },
+ "source": [
+ "
\n",
+ "\n",
+ "\n",
+ "\n",
+ "Supplementary code for the Build a Large Language Model From Scratch book by Sebastian Raschka \n",
+ " Code repository: https://github.com/rasbt/LLMs-from-scratch\n",
+ "\n",
+ " | \n",
+ "\n",
+ "\n",
+ " | \n",
+ "
\n",
+ "
"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "ZuWudYFWYiH7"
+ },
+ "source": [
+ "# Memory-efficient Model Weight Loading"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "qt0Qyg6ewUt6"
+ },
+ "source": [
+ "- This notebook provides tips for loading larger pretrained or finetuned models when GPU (or CPU) memory is limited\n",
+ "- Specifically, it focuses on cases where you saved the model using `torch.save(model.state_dict(), \"model.pth\")` (for example, in chapters 5-7) and want to load it in a new session later for continued pretraining or additional finetuning\n",
+ "- While the example uses an LLM, the methods explained in this notebook are general and apply to loading any PyTorch model, not just LLMs"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 1,
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "SxQzFoS-IXdY",
+ "outputId": "b28ebfbd-9036-4696-d95a-7f96fdf29919"
+ },
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "memory_profiler version: 0.61.0\n",
+ "torch version: 2.4.1+cu121\n"
+ ]
+ }
+ ],
+ "source": [
+ "from importlib.metadata import version\n",
+ "\n",
+ "pkgs = [\n",
+ " \"torch\",\n",
+ "]\n",
+ "for p in pkgs:\n",
+ " print(f\"{p} version: {version(p)}\")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "y47iQaQKyHap"
+ },
+ "source": [
+ " \n",
+ "## 1. Benchmark utilities"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "nQeOEoo6yT0X"
+ },
+ "source": [
+ "- First, let's define some utility code to track VRAM (GPU memory)\n",
+ "- Later, we will also introduce a tool to track the main system RAM (CPU memory)\n",
+ "- The purpose of these functions will become clear when we apply them later"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 2,
+ "metadata": {
+ "id": "pEiqjYrVivgt"
+ },
+ "outputs": [],
+ "source": [
+ "import gc\n",
+ "import time\n",
+ "import torch\n",
+ "\n",
+ "\n",
+ "def start_memory_tracking():\n",
+ " \"\"\"Initialize GPU memory tracking.\"\"\"\n",
+ " if torch.cuda.is_available():\n",
+ " torch.cuda.reset_peak_memory_stats()\n",
+ " else:\n",
+ " print(\"This notebook is intended for CUDA GPUs but CUDA is not available.\")\n",
+ "\n",
+ "def print_memory_usage():\n",
+ " max_gpu_memory = torch.cuda.max_memory_allocated() / (1024 ** 3) # Convert bytes to GB\n",
+ " print(f\"Maximum GPU memory allocated: {max_gpu_memory:.1f} GB\")\n",
+ "\n",
+ "def cleanup():\n",
+ " gc.collect()\n",
+ " torch.cuda.empty_cache()\n",
+ " time.sleep(3) # some buffer time to allow memory to clear\n",
+ " torch.cuda.reset_peak_memory_stats()\n",
+ " max_memory_allocated = torch.cuda.max_memory_allocated(device) / (1024 ** 3)\n",
+ " print(f\"Maximum GPU memory allocated: {max_memory_allocated:.1f} GB\")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "z5oJwoc-kkXs"
+ },
+ "source": [
+ " \n",
+ "## 2. Model setup"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "YfJE0vnMyr88"
+ },
+ "source": [
+ "- This code section sets up the model itself\n",
+ "- Here, we use the \"large\" GPT-2 model to make things more interesting (you may use the \"gpt2-small (124M)\" to lower the memory requirements and execution time of this notebook)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 3,
+ "metadata": {
+ "id": "tMuhCYaVI0w7"
+ },
+ "outputs": [],
+ "source": [
+ "from previous_chapters import GPTModel\n",
+ "\n",
+ "\n",
+ "BASE_CONFIG = {\n",
+ " \"vocab_size\": 50257, # Vocabulary size\n",
+ " \"context_length\": 1024, # Context length\n",
+ " \"drop_rate\": 0.0, # Dropout rate\n",
+ " \"qkv_bias\": True # Query-key-value bias\n",
+ "}\n",
+ "\n",
+ "model_configs = {\n",
+ " \"gpt2-small (124M)\": {\"emb_dim\": 768, \"n_layers\": 12, \"n_heads\": 12},\n",
+ " \"gpt2-medium (355M)\": {\"emb_dim\": 1024, \"n_layers\": 24, \"n_heads\": 16},\n",
+ " \"gpt2-large (774M)\": {\"emb_dim\": 1280, \"n_layers\": 36, \"n_heads\": 20},\n",
+ " \"gpt2-xl (1558M)\": {\"emb_dim\": 1600, \"n_layers\": 48, \"n_heads\": 25},\n",
+ "}\n",
+ "\n",
+ "CHOOSE_MODEL = \"gpt2-xl (1558M)\"\n",
+ "\n",
+ "BASE_CONFIG.update(model_configs[CHOOSE_MODEL])"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "KWYoo1z5y8aX"
+ },
+ "source": [
+ "- Now, let's see the GPU memory functions in action:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 4,
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "GK3NEA3eJv3f",
+ "outputId": "60573d6e-c603-45e7-8283-b1e92e2a0013"
+ },
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Maximum GPU memory allocated: 6.4 GB\n"
+ ]
+ }
+ ],
+ "source": [
+ "start_memory_tracking()\n",
+ "\n",
+ "\n",
+ "model = GPTModel(BASE_CONFIG)\n",
+ "device = torch.device(\"cuda\")\n",
+ "model.to(device)\n",
+ "\n",
+ "print_memory_usage()"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "GIhwBEBxzBsF"
+ },
+ "source": [
+ "- Additionally, let's make sure that the model runs okay by passing in some example tensor"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 5,
+ "metadata": {
+ "id": "i_j6nZruUd7g"
+ },
+ "outputs": [],
+ "source": [
+ "# Test if the model works (no need to track memory here)\n",
+ "test_input = torch.tensor([[1, 2, 3]]).to(device)\n",
+ "model.eval()\n",
+ "\n",
+ "with torch.no_grad():\n",
+ " model(test_input)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "UgNb8c32zh4g"
+ },
+ "source": [
+ "- Next, imagine we were pretraining the model and saving it for later use\n",
+ "- We skip the actual pretraining here for simplicity and just save the initialized model (but the same concept applies)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 6,
+ "metadata": {
+ "id": "wUIXjcsimXU7"
+ },
+ "outputs": [],
+ "source": [
+ "# Training code would go here...\n",
+ "\n",
+ "model.train()\n",
+ "torch.save(model.state_dict(), \"model.pth\")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "s9tBS4HUzz1g"
+ },
+ "source": [
+ "- Lastly, we delete the model and example tensor in the Python session to reset the GPU memory"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 7,
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "SqmTzztqKnTs",
+ "outputId": "1198afb9-2d97-4b6a-9bdb-41551f25749d"
+ },
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Maximum GPU memory allocated: 0.0 GB\n"
+ ]
+ }
+ ],
+ "source": [
+ "del model, test_input\n",
+ "cleanup()"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "7EnO8beUJ6Sb"
+ },
+ "source": [
+ " \n",
+ "## 3. Weight loading"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "JtAXKjsG0AVL"
+ },
+ "source": [
+ "- Now begins the interesting part where we load the pretrained model weights\n",
+ "- Let's see how much GPU memory is required to load the previously saved model"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 8,
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "wCrQNbSJJO9w",
+ "outputId": "9b203868-a8ef-4011-fc2b-611cc0d10994"
+ },
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Maximum GPU memory allocated: 12.8 GB\n"
+ ]
+ }
+ ],
+ "source": [
+ "# Then load pretrained weights\n",
+ "\n",
+ "start_memory_tracking()\n",
+ "\n",
+ "model = GPTModel(BASE_CONFIG)\n",
+ "model.to(device)\n",
+ "\n",
+ "model.load_state_dict(\n",
+ " torch.load(\"model.pth\", map_location=device, weights_only=True)\n",
+ ")\n",
+ "model.to(device)\n",
+ "model.eval();\n",
+ "\n",
+ "print_memory_usage()"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "4AGvOrcN0KdJ"
+ },
+ "source": [
+ "- Notice that the memory is 2x as large as in the previous session\n",
+ "- This is because we have the same model in memory twice, for a short period of time:\n",
+ " - The first time via `model.to(device)`\n",
+ " - The second time via the code line `model.load_state_dict(torch.load(\"model.pth\", map_location=device, weights_only=True))`; eventually, the loaded model weights will be copied into the model, and the `state_dict` will be discarded, but for a brief amount of time, we have both the main model and the loaded `state_dict` in memory\n",
+ "- The remaining sections focus on addressing this\n",
+ "- But first, let's test the model and reset the GPU memory\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 9,
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "DvlUn-nmmbuj",
+ "outputId": "11d3ab68-f570-4c1e-c631-fe5547026799"
+ },
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Maximum GPU memory allocated: 0.0 GB\n"
+ ]
+ }
+ ],
+ "source": [
+ "# Test if the model works (no need to track memory here)\n",
+ "test_input = torch.tensor([[1, 2, 3]]).to(device)\n",
+ "model.eval()\n",
+ "\n",
+ "with torch.no_grad():\n",
+ " model(test_input)\n",
+ "\n",
+ "del model, test_input\n",
+ "cleanup()"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "RdPnW3iLLrjX"
+ },
+ "source": [
+ " \n",
+ "## 4. Loading weights sequentially"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "FYqtUON602TD"
+ },
+ "source": [
+ "- One workaround for the problem of having the model weights in GPU memory twice, as highlighted in the previous section, is to load the model sequentially\n",
+ "- Below, we:\n",
+ " - first load the model into GPU memory\n",
+ " - then load the model weights into CPU memory\n",
+ " - and finally copy each parameter one by one into GPU memory\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 10,
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "DOIGTNWTmx9G",
+ "outputId": "145162e6-aaa6-4c2a-ed8f-f1cf068adb80"
+ },
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Maximum GPU memory allocated: 6.4 GB\n",
+ "Maximum GPU memory allocated: 6.7 GB\n"
+ ]
+ }
+ ],
+ "source": [
+ "start_memory_tracking()\n",
+ "\n",
+ "model = GPTModel(BASE_CONFIG).to(device)\n",
+ "\n",
+ "state_dict = torch.load(\"model.pth\", map_location=\"cpu\", weights_only=True)\n",
+ "\n",
+ "print_memory_usage()\n",
+ "\n",
+ "# Sequentially copy weights to the model's parameters\n",
+ "with torch.no_grad():\n",
+ " for name, param in model.named_parameters():\n",
+ " if name in state_dict:\n",
+ " param.copy_(state_dict[name].to(device))\n",
+ " else:\n",
+ " print(f\"Warning: {name} not found in state_dict.\")\n",
+ "\n",
+ "print_memory_usage()"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "Pn9xD_xL1ZzM"
+ },
+ "source": [
+ "- As we can see above, the memory usage is much lower than before\n",
+ "- Notice that the memory increases from 6.4 to 6.7 GB because initially, we only have the model in memory, and then we have the model plus 1 parameter tensor in memory (we temporarily move the parameter tensor to the GPU so we can assign it using `\".to\"` the model)\n",
+ "- Overall, this is a significant improvement\n",
+ "- Again, let's briefly test the model and then reset the GPU memory for the next section"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 11,
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "PRHnjA48nJgw",
+ "outputId": "dcd6b1b2-538f-4862-96a6-a5fcbf3326a4"
+ },
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Maximum GPU memory allocated: 0.0 GB\n"
+ ]
+ }
+ ],
+ "source": [
+ "# Test if the model works (no need to track memory here)\n",
+ "test_input = torch.tensor([[1, 2, 3]]).to(device)\n",
+ "model.eval()\n",
+ "\n",
+ "with torch.no_grad():\n",
+ " model(test_input)\n",
+ "\n",
+ "del model, test_input, state_dict, param\n",
+ "cleanup()"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "5M92LK7usb-Z"
+ },
+ "source": [
+ " \n",
+ "## 5. Loading the model with low CPU memory"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "R45qgeB613e2"
+ },
+ "source": [
+ "- In the previous session, we reduced GPU memory use by loading the weights (`state_dict`) into CPU memory first before copying them one-by-one into the model\n",
+ "- However, what do we do if we have limited CPU memory?\n",
+ "- This section uses PyTorch's so-called `\"meta\"` device approach to load a model on machines with large GPU memory but small CPU memory\n",
+ "- But first, let's define a convenience function to monitor CPU memory"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 12,
+ "metadata": {
+ "id": "BrcWy0q-3Bbe"
+ },
+ "outputs": [],
+ "source": [
+ "import os\n",
+ "import psutil\n",
+ "from threading import Thread\n",
+ "\n",
+ "\n",
+ "def memory_usage_in_gb(func, *args, **kwargs):\n",
+ " process = psutil.Process(os.getpid())\n",
+ "\n",
+ " # Measure the baseline memory usage before running the function\n",
+ " baseline_mem = process.memory_info().rss / 1024 ** 3 # in GB\n",
+ "\n",
+ " # Start monitoring memory in a separate thread\n",
+ " mem_usage = []\n",
+ " done = False\n",
+ "\n",
+ " def monitor_memory():\n",
+ " while not done:\n",
+ " mem_usage.append(process.memory_info().rss / 1024 ** 3) # Convert to GB\n",
+ " time.sleep(0.1)\n",
+ "\n",
+ " t = Thread(target=monitor_memory)\n",
+ " t.start()\n",
+ "\n",
+ " # Run the function\n",
+ " func(*args, **kwargs)\n",
+ "\n",
+ " # Stop monitoring\n",
+ " done = True\n",
+ " t.join()\n",
+ "\n",
+ " peak_mem_usage_gb = max(mem_usage) - baseline_mem\n",
+ " return peak_mem_usage_gb\n"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "Ayy30Ytd5hjF"
+ },
+ "source": [
+ "- To start with, let's track the CPU memory of the sequential weight loading approach from the previous section"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 13,
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "rCkV6IbQtpVn",
+ "outputId": "26c0435a-1e3d-4e8f-fbe2-f9655bad61b4"
+ },
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Maximum GPU memory allocated: 6.4 GB\n",
+ "Maximum GPU memory allocated: 6.7 GB\n",
+ "-> Maximum CPU memory allocated: 6.3 GB\n"
+ ]
+ }
+ ],
+ "source": [
+ "def load_sequentially():\n",
+ " start_memory_tracking()\n",
+ "\n",
+ " model = GPTModel(BASE_CONFIG).to(device)\n",
+ "\n",
+ " state_dict = torch.load(\"model.pth\", map_location=\"cpu\", weights_only=True)\n",
+ "\n",
+ " print_memory_usage()\n",
+ "\n",
+ " # Sequentially copy weights to the model's parameters\n",
+ " with torch.no_grad():\n",
+ " for name, param in model.named_parameters():\n",
+ " if name in state_dict:\n",
+ " param.copy_(state_dict[name].to(device))\n",
+ " else:\n",
+ " print(f\"Warning: {name} not found in state_dict.\")\n",
+ "\n",
+ " print_memory_usage()\n",
+ "\n",
+ "\n",
+ "peak_memory_used = memory_usage_in_gb(load_sequentially)\n",
+ "print(f\"-> Maximum CPU memory allocated: {peak_memory_used:.1f} GB\")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "UWrmnCML5oKy"
+ },
+ "source": [
+ "- Now, suppose we have a machine with low CPU memory but large GPU memory\n",
+ "- We can trade off CPU memory and GPU memory usage by introducing PyTorch's so-called \"meta\" device\n",
+ "- PyTorch's meta device is a special device type that allows you to create tensors without allocating actual memory for their data, effectively creating \"meta\" tensors\n",
+ "- This is useful for tasks like model analysis or architecture definition, where you need tensor shapes and types without the overhead of memory allocation"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 14,
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "PBErC_5Yt8ly",
+ "outputId": "8799db06-191c-47c4-92fa-fbb95d685aa9"
+ },
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Maximum GPU memory allocated: 12.8 GB\n",
+ "Maximum GPU memory allocated: 12.8 GB\n",
+ "-> Maximum CPU memory allocated: 1.3 GB\n"
+ ]
+ }
+ ],
+ "source": [
+ "def load_sequentially_with_meta():\n",
+ " start_memory_tracking()\n",
+ "\n",
+ " with torch.device(\"meta\"):\n",
+ " model = GPTModel(BASE_CONFIG)\n",
+ "\n",
+ " model = model.to_empty(device=device)\n",
+ "\n",
+ " state_dict = torch.load(\"model.pth\", map_location=device, weights_only=True)\n",
+ "\n",
+ " print_memory_usage()\n",
+ "\n",
+ " # Sequentially copy weights to the model's parameters\n",
+ " with torch.no_grad():\n",
+ " for name, param in model.named_parameters():\n",
+ " if name in state_dict:\n",
+ " param.copy_(state_dict[name])\n",
+ " else:\n",
+ " print(f\"Warning: {name} not found in state_dict.\")\n",
+ "\n",
+ " print_memory_usage()\n",
+ "\n",
+ "peak_memory_used = memory_usage_in_gb(load_sequentially_with_meta)\n",
+ "print(f\"-> Maximum CPU memory allocated: {peak_memory_used:.1f} GB\")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "VpnCABp75-VQ"
+ },
+ "source": [
+ "- As we can see above, by creating the model on the meta-device and loading the weights directly into GPU memory, we effectively reduced the CPU memory requirements\n",
+ "- One might ask: \"Is the sequential weight loading still necessary then, and how does that compare to the original approach?\"\n",
+ "- Let's check the simple PyTorch weight loading approach for comparison (from the first weight loading section in this notebook):"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 15,
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "4f-bqBNRuR39",
+ "outputId": "f7c0a901-b404-433a-9b93-2bbfa8183c56"
+ },
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Maximum GPU memory allocated: 12.8 GB\n",
+ "-> Maximum CPU memory allocated: 4.4 GB\n"
+ ]
+ }
+ ],
+ "source": [
+ "def baseline():\n",
+ " start_memory_tracking()\n",
+ "\n",
+ " model = GPTModel(BASE_CONFIG)\n",
+ " model.to(device)\n",
+ "\n",
+ " model.load_state_dict(torch.load(\"model.pth\", map_location=device, weights_only=True))\n",
+ " model.to(device)\n",
+ " model.eval();\n",
+ "\n",
+ " print_memory_usage()\n",
+ "\n",
+ "peak_memory_used = memory_usage_in_gb(baseline)\n",
+ "print(f\"-> Maximum CPU memory allocated: {peak_memory_used:.1f} GB\")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "NKAjxbX86xnb"
+ },
+ "source": [
+ "- As we can see above, the \"simple\" weight loading without the meta device uses more memory\n",
+ "- In other words, if you have a machine with limited CPU memory, you can use the meta device approach to directly load the model weights into GPU memory to reduce peak CPU memory usage"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ " \n",
+ "## 6. Other methods"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "- This notebook is focused on simple, built-in methods for loading weights in PyTorch.\n",
+ "- In case none of these methods work because you (1) don't have enough CPU memory for the `load_sequentially` approach and don't have enough GPU VRAM to have 2 copies of the weights in memory (the `load_sequentially_with_meta` approach), one option is to save and load each weight tensor separately:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 13,
+ "metadata": {
+ "id": "2CgPEZUIb00w"
+ },
+ "outputs": [],
+ "source": [
+ "model = GPTModel(BASE_CONFIG)\n",
+ "# Assume `model` is your trained model\n",
+ "state_dict = model.state_dict()\n",
+ "\n",
+ "# Create a directory to store individual parameter files\n",
+ "os.makedirs(\"model_parameters\", exist_ok=True)\n",
+ "\n",
+ "# Save each parameter tensor separately\n",
+ "for name, param in state_dict.items():\n",
+ " torch.save(param.cpu(), f\"model_parameters/{name}.pt\")\n",
+ "\n",
+ "del model"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 16,
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "gTsmtJK-b4yy",
+ "outputId": "d361e2d3-e34c-48d7-9047-846c9bfd291e"
+ },
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Maximum GPU memory allocated: 6.4 GB\n",
+ "Maximum GPU memory allocated: 6.4 GB\n",
+ "-> Maximum CPU memory allocated: 0.3 GB\n"
+ ]
+ }
+ ],
+ "source": [
+ "def load_individual_weights():\n",
+ "\n",
+ " start_memory_tracking()\n",
+ "\n",
+ " with torch.device(\"meta\"):\n",
+ " model = GPTModel(BASE_CONFIG)\n",
+ "\n",
+ " model = model.to_empty(device=device)\n",
+ "\n",
+ " print_memory_usage()\n",
+ " param_dir = \"model_parameters\"\n",
+ "\n",
+ " with torch.no_grad():\n",
+ " for name, param in model.named_parameters():\n",
+ " weight_path = os.path.join(param_dir, f\"{name}.pt\")\n",
+ " if os.path.exists(weight_path):\n",
+ " param_data = torch.load(weight_path, map_location=\"cpu\", weights_only=True)\n",
+ " param.copy_(param_data)\n",
+ " del param_data # Free memory\n",
+ " else:\n",
+ " print(f\"Warning: {name} not found in {param_dir}.\")\n",
+ "\n",
+ " print_memory_usage()\n",
+ "\n",
+ "\n",
+ "peak_memory_used = memory_usage_in_gb(load_individual_weights)\n",
+ "print(f\"-> Maximum CPU memory allocated: {peak_memory_used:.1f} GB\")"
+ ]
+ }
+ ],
+ "metadata": {
+ "accelerator": "GPU",
+ "colab": {
+ "gpuType": "L4",
+ "provenance": []
+ },
+ "kernelspec": {
+ "display_name": "Python 3 (ipykernel)",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.11.4"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 4
+}
diff --git a/ch05/08_memory_efficient_weight_loading/previous_chapters.py b/ch05/08_memory_efficient_weight_loading/previous_chapters.py
new file mode 100644
index 00000000..3674201e
--- /dev/null
+++ b/ch05/08_memory_efficient_weight_loading/previous_chapters.py
@@ -0,0 +1,175 @@
+# Copyright (c) Sebastian Raschka under Apache License 2.0 (see LICENSE.txt).
+# Source for "Build a Large Language Model From Scratch"
+# - https://www.manning.com/books/build-a-large-language-model-from-scratch
+# Code: https://github.com/rasbt/LLMs-from-scratch
+#
+# This file collects all the relevant code that we covered thus far
+# throughout Chapters 3-4.
+
+import torch
+import torch.nn as nn
+
+
+#####################################
+# Chapter 3
+#####################################
+class MultiHeadAttention(nn.Module):
+ def __init__(self, d_in, d_out, context_length, dropout, num_heads, qkv_bias=False):
+ super().__init__()
+ assert d_out % num_heads == 0, "d_out must be divisible by n_heads"
+
+ self.d_out = d_out
+ self.num_heads = num_heads
+ self.head_dim = d_out // num_heads # Reduce the projection dim to match desired output dim
+
+ self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
+ self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)
+ self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)
+ self.out_proj = nn.Linear(d_out, d_out) # Linear layer to combine head outputs
+ self.dropout = nn.Dropout(dropout)
+ self.register_buffer('mask', torch.triu(torch.ones(context_length, context_length), diagonal=1))
+
+ def forward(self, x):
+ b, num_tokens, d_in = x.shape
+
+ keys = self.W_key(x) # Shape: (b, num_tokens, d_out)
+ queries = self.W_query(x)
+ values = self.W_value(x)
+
+ # We implicitly split the matrix by adding a `num_heads` dimension
+ # Unroll last dim: (b, num_tokens, d_out) -> (b, num_tokens, num_heads, head_dim)
+ keys = keys.view(b, num_tokens, self.num_heads, self.head_dim)
+ values = values.view(b, num_tokens, self.num_heads, self.head_dim)
+ queries = queries.view(b, num_tokens, self.num_heads, self.head_dim)
+
+ # Transpose: (b, num_tokens, num_heads, head_dim) -> (b, num_heads, num_tokens, head_dim)
+ keys = keys.transpose(1, 2)
+ queries = queries.transpose(1, 2)
+ values = values.transpose(1, 2)
+
+ # Compute scaled dot-product attention (aka self-attention) with a causal mask
+ attn_scores = queries @ keys.transpose(2, 3) # Dot product for each head
+
+ # Original mask truncated to the number of tokens and converted to boolean
+ mask_bool = self.mask.bool()[:num_tokens, :num_tokens]
+
+ # Use the mask to fill attention scores
+ attn_scores.masked_fill_(mask_bool, -torch.inf)
+
+ attn_weights = torch.softmax(attn_scores / keys.shape[-1]**0.5, dim=-1)
+ attn_weights = self.dropout(attn_weights)
+
+ # Shape: (b, num_tokens, num_heads, head_dim)
+ context_vec = (attn_weights @ values).transpose(1, 2)
+
+ # Combine heads, where self.d_out = self.num_heads * self.head_dim
+ context_vec = context_vec.reshape(b, num_tokens, self.d_out)
+ context_vec = self.out_proj(context_vec) # optional projection
+
+ return context_vec
+
+
+#####################################
+# Chapter 4
+#####################################
+class LayerNorm(nn.Module):
+ def __init__(self, emb_dim):
+ super().__init__()
+ self.eps = 1e-5
+ self.scale = nn.Parameter(torch.ones(emb_dim))
+ self.shift = nn.Parameter(torch.zeros(emb_dim))
+
+ def forward(self, x):
+ mean = x.mean(dim=-1, keepdim=True)
+ var = x.var(dim=-1, keepdim=True, unbiased=False)
+ norm_x = (x - mean) / torch.sqrt(var + self.eps)
+ return self.scale * norm_x + self.shift
+
+
+class GELU(nn.Module):
+ def __init__(self):
+ super().__init__()
+
+ def forward(self, x):
+ return 0.5 * x * (1 + torch.tanh(
+ torch.sqrt(torch.tensor(2.0 / torch.pi)) *
+ (x + 0.044715 * torch.pow(x, 3))
+ ))
+
+
+class FeedForward(nn.Module):
+ def __init__(self, cfg):
+ super().__init__()
+ self.layers = nn.Sequential(
+ nn.Linear(cfg["emb_dim"], 4 * cfg["emb_dim"]),
+ GELU(),
+ nn.Linear(4 * cfg["emb_dim"], cfg["emb_dim"]),
+ )
+
+ def forward(self, x):
+ return self.layers(x)
+
+
+class TransformerBlock(nn.Module):
+ def __init__(self, cfg):
+ super().__init__()
+ self.att = MultiHeadAttention(
+ d_in=cfg["emb_dim"],
+ d_out=cfg["emb_dim"],
+ context_length=cfg["context_length"],
+ num_heads=cfg["n_heads"],
+ dropout=cfg["drop_rate"],
+ qkv_bias=cfg["qkv_bias"])
+ self.ff = FeedForward(cfg)
+ self.norm1 = LayerNorm(cfg["emb_dim"])
+ self.norm2 = LayerNorm(cfg["emb_dim"])
+ self.drop_shortcut = nn.Dropout(cfg["drop_rate"])
+
+ def forward(self, x):
+ # Shortcut connection for attention block
+ shortcut = x
+ x = self.norm1(x)
+ x = self.att(x) # Shape [batch_size, num_tokens, emb_size]
+ x = self.drop_shortcut(x)
+ x = x + shortcut # Add the original input back
+
+ # Shortcut connection for feed-forward block
+ shortcut = x
+ x = self.norm2(x)
+ x = self.ff(x)
+ x = self.drop_shortcut(x)
+ x = x + shortcut # Add the original input back
+
+ return x
+
+
+class GPTModel(nn.Module):
+ def __init__(self, cfg):
+ super().__init__()
+ self.tok_emb = nn.Embedding(cfg["vocab_size"], cfg["emb_dim"])
+ self.pos_emb = nn.Embedding(cfg["context_length"], cfg["emb_dim"])
+ self.drop_emb = nn.Dropout(cfg["drop_rate"])
+
+ self.trf_blocks = nn.Sequential(
+ *[TransformerBlock(cfg) for _ in range(cfg["n_layers"])])
+
+ self.final_norm = LayerNorm(cfg["emb_dim"])
+ self.out_head = nn.Linear(cfg["emb_dim"], cfg["vocab_size"], bias=False)
+
+ def forward(self, in_idx):
+ batch_size, seq_len = in_idx.shape
+ tok_embeds = self.tok_emb(in_idx)
+ pos_embeds = self.pos_emb(torch.arange(seq_len, device=in_idx.device))
+ x = tok_embeds + pos_embeds # Shape [batch_size, num_tokens, emb_size]
+ x = self.drop_emb(x)
+ x = self.trf_blocks(x)
+ x = self.final_norm(x)
+ logits = self.out_head(x)
+ return logits
diff --git a/ch05/README.md b/ch05/README.md
index 4718a509..0a446e43 100644
--- a/ch05/README.md
+++ b/ch05/README.md
@@ -14,3 +14,4 @@
- [05_bonus_hparam_tuning](05_bonus_hparam_tuning) contains an optional hyperparameter tuning script
- [06_user_interface](06_user_interface) implements an interactive user interface to interact with the pretrained LLM
- [07_gpt_to_llama](07_gpt_to_llama) contains a step-by-step guide for converting a GPT architecture implementation to Llama 3.2 and loads pretrained weights from Meta AI
+- [08_memory_efficient_weight_loading](08_memory_efficient_weight_loading) contains a bonus notebook showing how to load model weights via PyTorch's `load_state_dict` method in a more memory-efficient way