diff --git a/README.md b/README.md index bee8635d..44f06807 100644 --- a/README.md +++ b/README.md @@ -118,6 +118,7 @@ Several folders contain optional materials as a bonus for interested readers: - [Building a User Interface to Interact With the Pretrained LLM](ch05/06_user_interface) - [Converting GPT to Llama](ch05/07_gpt_to_llama) - [Llama 3.2 From Scratch](ch05/07_gpt_to_llama/standalone-llama32.ipynb) + - [Memory-efficient Model Weight Loading](ch05/08_memory_efficient_weight_loading/memory-efficient-state-dict.ipynb) - **Chapter 6:** - [Additional experiments finetuning different layers and using larger models](ch06/02_bonus_additional-experiments) - [Finetuning different models on 50k IMDB movie review dataset](ch06/03_bonus_imdb-classification) diff --git a/ch05/08_memory_efficient_weight_loading/README.md b/ch05/08_memory_efficient_weight_loading/README.md new file mode 100644 index 00000000..2b8fef08 --- /dev/null +++ b/ch05/08_memory_efficient_weight_loading/README.md @@ -0,0 +1,5 @@ +# Memory-efficient Model Weight Loading + +This folder contains code to illustrate how to load model weights more efficiently + +- [memory-efficient-state-dict.ipynb](memory-efficient-state-dict.ipynb): contains code to load model weights via PyTorch's `load_state_dict` method more efficiently diff --git a/ch05/08_memory_efficient_weight_loading/memory-efficient-state-dict.ipynb b/ch05/08_memory_efficient_weight_loading/memory-efficient-state-dict.ipynb new file mode 100644 index 00000000..a8b3d351 --- /dev/null +++ b/ch05/08_memory_efficient_weight_loading/memory-efficient-state-dict.ipynb @@ -0,0 +1,866 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": { + "id": "1E_HhLEeYqFG" + }, + "source": [ + "\n", + "\n", + "\n", + "\n", + "\n", + "
\n", + "\n", + "Supplementary code for the Build a Large Language Model From Scratch book by Sebastian Raschka
\n", + "
Code repository: https://github.com/rasbt/LLMs-from-scratch\n", + "
\n", + "
\n", + "\n", + "
" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "ZuWudYFWYiH7" + }, + "source": [ + "# Memory-efficient Model Weight Loading" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "qt0Qyg6ewUt6" + }, + "source": [ + "- This notebook provides tips for loading larger pretrained or finetuned models when GPU (or CPU) memory is limited\n", + "- Specifically, it focuses on cases where you saved the model using `torch.save(model.state_dict(), \"model.pth\")` (for example, in chapters 5-7) and want to load it in a new session later for continued pretraining or additional finetuning\n", + "- While the example uses an LLM, the methods explained in this notebook are general and apply to loading any PyTorch model, not just LLMs" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "SxQzFoS-IXdY", + "outputId": "b28ebfbd-9036-4696-d95a-7f96fdf29919" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "memory_profiler version: 0.61.0\n", + "torch version: 2.4.1+cu121\n" + ] + } + ], + "source": [ + "from importlib.metadata import version\n", + "\n", + "pkgs = [\n", + " \"torch\",\n", + "]\n", + "for p in pkgs:\n", + " print(f\"{p} version: {version(p)}\")" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "y47iQaQKyHap" + }, + "source": [ + " \n", + "## 1. Benchmark utilities" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "nQeOEoo6yT0X" + }, + "source": [ + "- First, let's define some utility code to track VRAM (GPU memory)\n", + "- Later, we will also introduce a tool to track the main system RAM (CPU memory)\n", + "- The purpose of these functions will become clear when we apply them later" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": { + "id": "pEiqjYrVivgt" + }, + "outputs": [], + "source": [ + "import gc\n", + "import time\n", + "import torch\n", + "\n", + "\n", + "def start_memory_tracking():\n", + " \"\"\"Initialize GPU memory tracking.\"\"\"\n", + " if torch.cuda.is_available():\n", + " torch.cuda.reset_peak_memory_stats()\n", + " else:\n", + " print(\"This notebook is intended for CUDA GPUs but CUDA is not available.\")\n", + "\n", + "def print_memory_usage():\n", + " max_gpu_memory = torch.cuda.max_memory_allocated() / (1024 ** 3) # Convert bytes to GB\n", + " print(f\"Maximum GPU memory allocated: {max_gpu_memory:.1f} GB\")\n", + "\n", + "def cleanup():\n", + " gc.collect()\n", + " torch.cuda.empty_cache()\n", + " time.sleep(3) # some buffer time to allow memory to clear\n", + " torch.cuda.reset_peak_memory_stats()\n", + " max_memory_allocated = torch.cuda.max_memory_allocated(device) / (1024 ** 3)\n", + " print(f\"Maximum GPU memory allocated: {max_memory_allocated:.1f} GB\")" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "z5oJwoc-kkXs" + }, + "source": [ + " \n", + "## 2. Model setup" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "YfJE0vnMyr88" + }, + "source": [ + "- This code section sets up the model itself\n", + "- Here, we use the \"large\" GPT-2 model to make things more interesting (you may use the \"gpt2-small (124M)\" to lower the memory requirements and execution time of this notebook)" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": { + "id": "tMuhCYaVI0w7" + }, + "outputs": [], + "source": [ + "from previous_chapters import GPTModel\n", + "\n", + "\n", + "BASE_CONFIG = {\n", + " \"vocab_size\": 50257, # Vocabulary size\n", + " \"context_length\": 1024, # Context length\n", + " \"drop_rate\": 0.0, # Dropout rate\n", + " \"qkv_bias\": True # Query-key-value bias\n", + "}\n", + "\n", + "model_configs = {\n", + " \"gpt2-small (124M)\": {\"emb_dim\": 768, \"n_layers\": 12, \"n_heads\": 12},\n", + " \"gpt2-medium (355M)\": {\"emb_dim\": 1024, \"n_layers\": 24, \"n_heads\": 16},\n", + " \"gpt2-large (774M)\": {\"emb_dim\": 1280, \"n_layers\": 36, \"n_heads\": 20},\n", + " \"gpt2-xl (1558M)\": {\"emb_dim\": 1600, \"n_layers\": 48, \"n_heads\": 25},\n", + "}\n", + "\n", + "CHOOSE_MODEL = \"gpt2-xl (1558M)\"\n", + "\n", + "BASE_CONFIG.update(model_configs[CHOOSE_MODEL])" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "KWYoo1z5y8aX" + }, + "source": [ + "- Now, let's see the GPU memory functions in action:" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "GK3NEA3eJv3f", + "outputId": "60573d6e-c603-45e7-8283-b1e92e2a0013" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Maximum GPU memory allocated: 6.4 GB\n" + ] + } + ], + "source": [ + "start_memory_tracking()\n", + "\n", + "\n", + "model = GPTModel(BASE_CONFIG)\n", + "device = torch.device(\"cuda\")\n", + "model.to(device)\n", + "\n", + "print_memory_usage()" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "GIhwBEBxzBsF" + }, + "source": [ + "- Additionally, let's make sure that the model runs okay by passing in some example tensor" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": { + "id": "i_j6nZruUd7g" + }, + "outputs": [], + "source": [ + "# Test if the model works (no need to track memory here)\n", + "test_input = torch.tensor([[1, 2, 3]]).to(device)\n", + "model.eval()\n", + "\n", + "with torch.no_grad():\n", + " model(test_input)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "UgNb8c32zh4g" + }, + "source": [ + "- Next, imagine we were pretraining the model and saving it for later use\n", + "- We skip the actual pretraining here for simplicity and just save the initialized model (but the same concept applies)" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": { + "id": "wUIXjcsimXU7" + }, + "outputs": [], + "source": [ + "# Training code would go here...\n", + "\n", + "model.train()\n", + "torch.save(model.state_dict(), \"model.pth\")" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "s9tBS4HUzz1g" + }, + "source": [ + "- Lastly, we delete the model and example tensor in the Python session to reset the GPU memory" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "SqmTzztqKnTs", + "outputId": "1198afb9-2d97-4b6a-9bdb-41551f25749d" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Maximum GPU memory allocated: 0.0 GB\n" + ] + } + ], + "source": [ + "del model, test_input\n", + "cleanup()" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "7EnO8beUJ6Sb" + }, + "source": [ + " \n", + "## 3. Weight loading" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "JtAXKjsG0AVL" + }, + "source": [ + "- Now begins the interesting part where we load the pretrained model weights\n", + "- Let's see how much GPU memory is required to load the previously saved model" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "wCrQNbSJJO9w", + "outputId": "9b203868-a8ef-4011-fc2b-611cc0d10994" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Maximum GPU memory allocated: 12.8 GB\n" + ] + } + ], + "source": [ + "# Then load pretrained weights\n", + "\n", + "start_memory_tracking()\n", + "\n", + "model = GPTModel(BASE_CONFIG)\n", + "model.to(device)\n", + "\n", + "model.load_state_dict(\n", + " torch.load(\"model.pth\", map_location=device, weights_only=True)\n", + ")\n", + "model.to(device)\n", + "model.eval();\n", + "\n", + "print_memory_usage()" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "4AGvOrcN0KdJ" + }, + "source": [ + "- Notice that the memory is 2x as large as in the previous session\n", + "- This is because we have the same model in memory twice, for a short period of time:\n", + " - The first time via `model.to(device)`\n", + " - The second time via the code line `model.load_state_dict(torch.load(\"model.pth\", map_location=device, weights_only=True))`; eventually, the loaded model weights will be copied into the model, and the `state_dict` will be discarded, but for a brief amount of time, we have both the main model and the loaded `state_dict` in memory\n", + "- The remaining sections focus on addressing this\n", + "- But first, let's test the model and reset the GPU memory\n" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "DvlUn-nmmbuj", + "outputId": "11d3ab68-f570-4c1e-c631-fe5547026799" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Maximum GPU memory allocated: 0.0 GB\n" + ] + } + ], + "source": [ + "# Test if the model works (no need to track memory here)\n", + "test_input = torch.tensor([[1, 2, 3]]).to(device)\n", + "model.eval()\n", + "\n", + "with torch.no_grad():\n", + " model(test_input)\n", + "\n", + "del model, test_input\n", + "cleanup()" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "RdPnW3iLLrjX" + }, + "source": [ + " \n", + "## 4. Loading weights sequentially" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "FYqtUON602TD" + }, + "source": [ + "- One workaround for the problem of having the model weights in GPU memory twice, as highlighted in the previous section, is to load the model sequentially\n", + "- Below, we:\n", + " - first load the model into GPU memory\n", + " - then load the model weights into CPU memory\n", + " - and finally copy each parameter one by one into GPU memory\n" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "DOIGTNWTmx9G", + "outputId": "145162e6-aaa6-4c2a-ed8f-f1cf068adb80" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Maximum GPU memory allocated: 6.4 GB\n", + "Maximum GPU memory allocated: 6.7 GB\n" + ] + } + ], + "source": [ + "start_memory_tracking()\n", + "\n", + "model = GPTModel(BASE_CONFIG).to(device)\n", + "\n", + "state_dict = torch.load(\"model.pth\", map_location=\"cpu\", weights_only=True)\n", + "\n", + "print_memory_usage()\n", + "\n", + "# Sequentially copy weights to the model's parameters\n", + "with torch.no_grad():\n", + " for name, param in model.named_parameters():\n", + " if name in state_dict:\n", + " param.copy_(state_dict[name].to(device))\n", + " else:\n", + " print(f\"Warning: {name} not found in state_dict.\")\n", + "\n", + "print_memory_usage()" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "Pn9xD_xL1ZzM" + }, + "source": [ + "- As we can see above, the memory usage is much lower than before\n", + "- Notice that the memory increases from 6.4 to 6.7 GB because initially, we only have the model in memory, and then we have the model plus 1 parameter tensor in memory (we temporarily move the parameter tensor to the GPU so we can assign it using `\".to\"` the model)\n", + "- Overall, this is a significant improvement\n", + "- Again, let's briefly test the model and then reset the GPU memory for the next section" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "PRHnjA48nJgw", + "outputId": "dcd6b1b2-538f-4862-96a6-a5fcbf3326a4" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Maximum GPU memory allocated: 0.0 GB\n" + ] + } + ], + "source": [ + "# Test if the model works (no need to track memory here)\n", + "test_input = torch.tensor([[1, 2, 3]]).to(device)\n", + "model.eval()\n", + "\n", + "with torch.no_grad():\n", + " model(test_input)\n", + "\n", + "del model, test_input, state_dict, param\n", + "cleanup()" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "5M92LK7usb-Z" + }, + "source": [ + " \n", + "## 5. Loading the model with low CPU memory" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "R45qgeB613e2" + }, + "source": [ + "- In the previous session, we reduced GPU memory use by loading the weights (`state_dict`) into CPU memory first before copying them one-by-one into the model\n", + "- However, what do we do if we have limited CPU memory?\n", + "- This section uses PyTorch's so-called `\"meta\"` device approach to load a model on machines with large GPU memory but small CPU memory\n", + "- But first, let's define a convenience function to monitor CPU memory" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "metadata": { + "id": "BrcWy0q-3Bbe" + }, + "outputs": [], + "source": [ + "import os\n", + "import psutil\n", + "from threading import Thread\n", + "\n", + "\n", + "def memory_usage_in_gb(func, *args, **kwargs):\n", + " process = psutil.Process(os.getpid())\n", + "\n", + " # Measure the baseline memory usage before running the function\n", + " baseline_mem = process.memory_info().rss / 1024 ** 3 # in GB\n", + "\n", + " # Start monitoring memory in a separate thread\n", + " mem_usage = []\n", + " done = False\n", + "\n", + " def monitor_memory():\n", + " while not done:\n", + " mem_usage.append(process.memory_info().rss / 1024 ** 3) # Convert to GB\n", + " time.sleep(0.1)\n", + "\n", + " t = Thread(target=monitor_memory)\n", + " t.start()\n", + "\n", + " # Run the function\n", + " func(*args, **kwargs)\n", + "\n", + " # Stop monitoring\n", + " done = True\n", + " t.join()\n", + "\n", + " peak_mem_usage_gb = max(mem_usage) - baseline_mem\n", + " return peak_mem_usage_gb\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "Ayy30Ytd5hjF" + }, + "source": [ + "- To start with, let's track the CPU memory of the sequential weight loading approach from the previous section" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "rCkV6IbQtpVn", + "outputId": "26c0435a-1e3d-4e8f-fbe2-f9655bad61b4" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Maximum GPU memory allocated: 6.4 GB\n", + "Maximum GPU memory allocated: 6.7 GB\n", + "-> Maximum CPU memory allocated: 6.3 GB\n" + ] + } + ], + "source": [ + "def load_sequentially():\n", + " start_memory_tracking()\n", + "\n", + " model = GPTModel(BASE_CONFIG).to(device)\n", + "\n", + " state_dict = torch.load(\"model.pth\", map_location=\"cpu\", weights_only=True)\n", + "\n", + " print_memory_usage()\n", + "\n", + " # Sequentially copy weights to the model's parameters\n", + " with torch.no_grad():\n", + " for name, param in model.named_parameters():\n", + " if name in state_dict:\n", + " param.copy_(state_dict[name].to(device))\n", + " else:\n", + " print(f\"Warning: {name} not found in state_dict.\")\n", + "\n", + " print_memory_usage()\n", + "\n", + "\n", + "peak_memory_used = memory_usage_in_gb(load_sequentially)\n", + "print(f\"-> Maximum CPU memory allocated: {peak_memory_used:.1f} GB\")" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "UWrmnCML5oKy" + }, + "source": [ + "- Now, suppose we have a machine with low CPU memory but large GPU memory\n", + "- We can trade off CPU memory and GPU memory usage by introducing PyTorch's so-called \"meta\" device\n", + "- PyTorch's meta device is a special device type that allows you to create tensors without allocating actual memory for their data, effectively creating \"meta\" tensors\n", + "- This is useful for tasks like model analysis or architecture definition, where you need tensor shapes and types without the overhead of memory allocation" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "PBErC_5Yt8ly", + "outputId": "8799db06-191c-47c4-92fa-fbb95d685aa9" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Maximum GPU memory allocated: 12.8 GB\n", + "Maximum GPU memory allocated: 12.8 GB\n", + "-> Maximum CPU memory allocated: 1.3 GB\n" + ] + } + ], + "source": [ + "def load_sequentially_with_meta():\n", + " start_memory_tracking()\n", + "\n", + " with torch.device(\"meta\"):\n", + " model = GPTModel(BASE_CONFIG)\n", + "\n", + " model = model.to_empty(device=device)\n", + "\n", + " state_dict = torch.load(\"model.pth\", map_location=device, weights_only=True)\n", + "\n", + " print_memory_usage()\n", + "\n", + " # Sequentially copy weights to the model's parameters\n", + " with torch.no_grad():\n", + " for name, param in model.named_parameters():\n", + " if name in state_dict:\n", + " param.copy_(state_dict[name])\n", + " else:\n", + " print(f\"Warning: {name} not found in state_dict.\")\n", + "\n", + " print_memory_usage()\n", + "\n", + "peak_memory_used = memory_usage_in_gb(load_sequentially_with_meta)\n", + "print(f\"-> Maximum CPU memory allocated: {peak_memory_used:.1f} GB\")" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "VpnCABp75-VQ" + }, + "source": [ + "- As we can see above, by creating the model on the meta-device and loading the weights directly into GPU memory, we effectively reduced the CPU memory requirements\n", + "- One might ask: \"Is the sequential weight loading still necessary then, and how does that compare to the original approach?\"\n", + "- Let's check the simple PyTorch weight loading approach for comparison (from the first weight loading section in this notebook):" + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "4f-bqBNRuR39", + "outputId": "f7c0a901-b404-433a-9b93-2bbfa8183c56" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Maximum GPU memory allocated: 12.8 GB\n", + "-> Maximum CPU memory allocated: 4.4 GB\n" + ] + } + ], + "source": [ + "def baseline():\n", + " start_memory_tracking()\n", + "\n", + " model = GPTModel(BASE_CONFIG)\n", + " model.to(device)\n", + "\n", + " model.load_state_dict(torch.load(\"model.pth\", map_location=device, weights_only=True))\n", + " model.to(device)\n", + " model.eval();\n", + "\n", + " print_memory_usage()\n", + "\n", + "peak_memory_used = memory_usage_in_gb(baseline)\n", + "print(f\"-> Maximum CPU memory allocated: {peak_memory_used:.1f} GB\")" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "NKAjxbX86xnb" + }, + "source": [ + "- As we can see above, the \"simple\" weight loading without the meta device uses more memory\n", + "- In other words, if you have a machine with limited CPU memory, you can use the meta device approach to directly load the model weights into GPU memory to reduce peak CPU memory usage" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + " \n", + "## 6. Other methods" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "- This notebook is focused on simple, built-in methods for loading weights in PyTorch.\n", + "- In case none of these methods work because you (1) don't have enough CPU memory for the `load_sequentially` approach and don't have enough GPU VRAM to have 2 copies of the weights in memory (the `load_sequentially_with_meta` approach), one option is to save and load each weight tensor separately:" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "metadata": { + "id": "2CgPEZUIb00w" + }, + "outputs": [], + "source": [ + "model = GPTModel(BASE_CONFIG)\n", + "# Assume `model` is your trained model\n", + "state_dict = model.state_dict()\n", + "\n", + "# Create a directory to store individual parameter files\n", + "os.makedirs(\"model_parameters\", exist_ok=True)\n", + "\n", + "# Save each parameter tensor separately\n", + "for name, param in state_dict.items():\n", + " torch.save(param.cpu(), f\"model_parameters/{name}.pt\")\n", + "\n", + "del model" + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "gTsmtJK-b4yy", + "outputId": "d361e2d3-e34c-48d7-9047-846c9bfd291e" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Maximum GPU memory allocated: 6.4 GB\n", + "Maximum GPU memory allocated: 6.4 GB\n", + "-> Maximum CPU memory allocated: 0.3 GB\n" + ] + } + ], + "source": [ + "def load_individual_weights():\n", + "\n", + " start_memory_tracking()\n", + "\n", + " with torch.device(\"meta\"):\n", + " model = GPTModel(BASE_CONFIG)\n", + "\n", + " model = model.to_empty(device=device)\n", + "\n", + " print_memory_usage()\n", + " param_dir = \"model_parameters\"\n", + "\n", + " with torch.no_grad():\n", + " for name, param in model.named_parameters():\n", + " weight_path = os.path.join(param_dir, f\"{name}.pt\")\n", + " if os.path.exists(weight_path):\n", + " param_data = torch.load(weight_path, map_location=\"cpu\", weights_only=True)\n", + " param.copy_(param_data)\n", + " del param_data # Free memory\n", + " else:\n", + " print(f\"Warning: {name} not found in {param_dir}.\")\n", + "\n", + " print_memory_usage()\n", + "\n", + "\n", + "peak_memory_used = memory_usage_in_gb(load_individual_weights)\n", + "print(f\"-> Maximum CPU memory allocated: {peak_memory_used:.1f} GB\")" + ] + } + ], + "metadata": { + "accelerator": "GPU", + "colab": { + "gpuType": "L4", + "provenance": [] + }, + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.4" + } + }, + "nbformat": 4, + "nbformat_minor": 4 +} diff --git a/ch05/08_memory_efficient_weight_loading/previous_chapters.py b/ch05/08_memory_efficient_weight_loading/previous_chapters.py new file mode 100644 index 00000000..3674201e --- /dev/null +++ b/ch05/08_memory_efficient_weight_loading/previous_chapters.py @@ -0,0 +1,175 @@ +# Copyright (c) Sebastian Raschka under Apache License 2.0 (see LICENSE.txt). +# Source for "Build a Large Language Model From Scratch" +# - https://www.manning.com/books/build-a-large-language-model-from-scratch +# Code: https://github.com/rasbt/LLMs-from-scratch +# +# This file collects all the relevant code that we covered thus far +# throughout Chapters 2-5. + +import json +import os +import urllib + +import numpy as np +import tensorflow as tf +import torch +import torch.nn as nn +from tqdm import tqdm + + +##################################### +# Chapter 3 +##################################### +class MultiHeadAttention(nn.Module): + def __init__(self, d_in, d_out, context_length, dropout, num_heads, qkv_bias=False): + super().__init__() + assert d_out % num_heads == 0, "d_out must be divisible by n_heads" + + self.d_out = d_out + self.num_heads = num_heads + self.head_dim = d_out // num_heads # Reduce the projection dim to match desired output dim + + self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias) + self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias) + self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias) + self.out_proj = nn.Linear(d_out, d_out) # Linear layer to combine head outputs + self.dropout = nn.Dropout(dropout) + self.register_buffer('mask', torch.triu(torch.ones(context_length, context_length), diagonal=1)) + + def forward(self, x): + b, num_tokens, d_in = x.shape + + keys = self.W_key(x) # Shape: (b, num_tokens, d_out) + queries = self.W_query(x) + values = self.W_value(x) + + # We implicitly split the matrix by adding a `num_heads` dimension + # Unroll last dim: (b, num_tokens, d_out) -> (b, num_tokens, num_heads, head_dim) + keys = keys.view(b, num_tokens, self.num_heads, self.head_dim) + values = values.view(b, num_tokens, self.num_heads, self.head_dim) + queries = queries.view(b, num_tokens, self.num_heads, self.head_dim) + + # Transpose: (b, num_tokens, num_heads, head_dim) -> (b, num_heads, num_tokens, head_dim) + keys = keys.transpose(1, 2) + queries = queries.transpose(1, 2) + values = values.transpose(1, 2) + + # Compute scaled dot-product attention (aka self-attention) with a causal mask + attn_scores = queries @ keys.transpose(2, 3) # Dot product for each head + + # Original mask truncated to the number of tokens and converted to boolean + mask_bool = self.mask.bool()[:num_tokens, :num_tokens] + + # Use the mask to fill attention scores + attn_scores.masked_fill_(mask_bool, -torch.inf) + + attn_weights = torch.softmax(attn_scores / keys.shape[-1]**0.5, dim=-1) + attn_weights = self.dropout(attn_weights) + + # Shape: (b, num_tokens, num_heads, head_dim) + context_vec = (attn_weights @ values).transpose(1, 2) + + # Combine heads, where self.d_out = self.num_heads * self.head_dim + context_vec = context_vec.reshape(b, num_tokens, self.d_out) + context_vec = self.out_proj(context_vec) # optional projection + + return context_vec + + +##################################### +# Chapter 4 +##################################### +class LayerNorm(nn.Module): + def __init__(self, emb_dim): + super().__init__() + self.eps = 1e-5 + self.scale = nn.Parameter(torch.ones(emb_dim)) + self.shift = nn.Parameter(torch.zeros(emb_dim)) + + def forward(self, x): + mean = x.mean(dim=-1, keepdim=True) + var = x.var(dim=-1, keepdim=True, unbiased=False) + norm_x = (x - mean) / torch.sqrt(var + self.eps) + return self.scale * norm_x + self.shift + + +class GELU(nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x): + return 0.5 * x * (1 + torch.tanh( + torch.sqrt(torch.tensor(2.0 / torch.pi)) * + (x + 0.044715 * torch.pow(x, 3)) + )) + + +class FeedForward(nn.Module): + def __init__(self, cfg): + super().__init__() + self.layers = nn.Sequential( + nn.Linear(cfg["emb_dim"], 4 * cfg["emb_dim"]), + GELU(), + nn.Linear(4 * cfg["emb_dim"], cfg["emb_dim"]), + ) + + def forward(self, x): + return self.layers(x) + + +class TransformerBlock(nn.Module): + def __init__(self, cfg): + super().__init__() + self.att = MultiHeadAttention( + d_in=cfg["emb_dim"], + d_out=cfg["emb_dim"], + context_length=cfg["context_length"], + num_heads=cfg["n_heads"], + dropout=cfg["drop_rate"], + qkv_bias=cfg["qkv_bias"]) + self.ff = FeedForward(cfg) + self.norm1 = LayerNorm(cfg["emb_dim"]) + self.norm2 = LayerNorm(cfg["emb_dim"]) + self.drop_shortcut = nn.Dropout(cfg["drop_rate"]) + + def forward(self, x): + # Shortcut connection for attention block + shortcut = x + x = self.norm1(x) + x = self.att(x) # Shape [batch_size, num_tokens, emb_size] + x = self.drop_shortcut(x) + x = x + shortcut # Add the original input back + + # Shortcut connection for feed-forward block + shortcut = x + x = self.norm2(x) + x = self.ff(x) + x = self.drop_shortcut(x) + x = x + shortcut # Add the original input back + + return x + + +class GPTModel(nn.Module): + def __init__(self, cfg): + super().__init__() + self.tok_emb = nn.Embedding(cfg["vocab_size"], cfg["emb_dim"]) + self.pos_emb = nn.Embedding(cfg["context_length"], cfg["emb_dim"]) + self.drop_emb = nn.Dropout(cfg["drop_rate"]) + + self.trf_blocks = nn.Sequential( + *[TransformerBlock(cfg) for _ in range(cfg["n_layers"])]) + + self.final_norm = LayerNorm(cfg["emb_dim"]) + self.out_head = nn.Linear(cfg["emb_dim"], cfg["vocab_size"], bias=False) + + def forward(self, in_idx): + batch_size, seq_len = in_idx.shape + tok_embeds = self.tok_emb(in_idx) + pos_embeds = self.pos_emb(torch.arange(seq_len, device=in_idx.device)) + x = tok_embeds + pos_embeds # Shape [batch_size, num_tokens, emb_size] + x = self.drop_emb(x) + x = self.trf_blocks(x) + x = self.final_norm(x) + logits = self.out_head(x) + return logits diff --git a/ch05/README.md b/ch05/README.md index 4718a509..0a446e43 100644 --- a/ch05/README.md +++ b/ch05/README.md @@ -14,3 +14,4 @@ - [05_bonus_hparam_tuning](05_bonus_hparam_tuning) contains an optional hyperparameter tuning script - [06_user_interface](06_user_interface) implements an interactive user interface to interact with the pretrained LLM - [07_gpt_to_llama](07_gpt_to_llama) contains a step-by-step guide for converting a GPT architecture implementation to Llama 3.2 and loads pretrained weights from Meta AI +- [08_memory_efficient_weight_loading](08_memory_efficient_weight_loading) contains a bonus notebook showing how to load model weights via PyTorch's `load_state_dict` method more efficiently