diff --git a/ch05/08_memory_efficient_weight_loading/memory-efficient-state-dict.ipynb b/ch05/08_memory_efficient_weight_loading/memory-efficient-state-dict.ipynb
index 37825279..8ab9d4f7 100644
--- a/ch05/08_memory_efficient_weight_loading/memory-efficient-state-dict.ipynb
+++ b/ch05/08_memory_efficient_weight_loading/memory-efficient-state-dict.ipynb
@@ -752,7 +752,7 @@
    "metadata": {},
    "source": [
     " \n",
-    "## 6. Using `mmap=True`"
+    "## 6. Using `mmap=True` (recommended)"
    ]
   },
   {
@@ -760,19 +760,20 @@
    "metadata": {},
    "source": [
     "- As an intermediate or advanced `torch.load` user, you may wonder how these approaches compare to the `mmap=True` setting in PyTorch\n",
-    "- The `mmap=True` setting in PyTorch enables memory-mapped file I/O, which allows the tensor to access data directly from disk storage, thus reducing memory usage by not loading the entire file into RAM\n",
-    "- However, in practice, I found it to be less efficient than the sequential approaches above"
+    "- The `mmap=True` setting in PyTorch enables memory-mapped file I/O, which allows the tensor to access data directly from disk storage, thus reducing memory usage by not loading the entire file into RAM when RAM is limited\n",
+    "- Also, see the helpful comment by [mikaylagawarecki](https://github.com/rasbt/LLMs-from-scratch/issues/402)\n",
+    "- At first glance, it may look less efficient than the sequential approaches above:"
    ]
   },
   {
    "cell_type": "code",
-   "execution_count": 16,
+   "execution_count": 37,
    "metadata": {
     "colab": {
      "base_uri": "https://localhost:8080/"
     },
-    "id": "7AX3vPrpv5c_",
-    "outputId": "e6fca10b-55c3-4e89-8674-075df5ce26e7"
+    "id": "GKwV0AMNemuR",
+    "outputId": "e207f2bf-5c87-498e-80fe-e8c4016ac711"
    },
    "outputs": [
     {
@@ -780,63 +781,32 @@
      "output_type": "stream",
      "text": [
       "Maximum GPU memory allocated: 6.4 GB\n",
-      "-> Maximum CPU memory allocated: 9.9 GB\n"
+      "-> Maximum CPU memory allocated: 5.9 GB\n"
      ]
     }
    ],
    "source": [
-    "def baseline_mmap():\n",
-    "    start_memory_tracking()\n",
+    "def best_practices():\n",
+    "  with torch.device(\"meta\"):\n",
+    "      model = GPTModel(BASE_CONFIG)\n",
     "\n",
-    "    model = GPTModel(BASE_CONFIG) # load model on CPU\n",
-    "\n",
-    "    model.load_state_dict(\n",
-    "        torch.load(\"model.pth\", map_location=\"cpu\", weights_only=True, mmap=True)\n",
-    "    )\n",
-    "    model.to(device) # Move model to GPU\n",
-    "    model.eval();\n",
+    "  model.load_state_dict(\n",
+    "      torch.load(\"model.pth\", map_location=device, weights_only=True, mmap=True),\n",
+    "      assign=True\n",
+    "  )\n",
     "\n",
-    "    print_memory_usage()\n",
+    "  print_memory_usage()\n",
     "\n",
-    "peak_memory_used = memory_usage_in_gb(baseline_mmap)\n",
+    "peak_memory_used = memory_usage_in_gb(best_practices)\n",
     "print(f\"-> Maximum CPU memory allocated: {peak_memory_used:.1f} GB\")"
    ]
   },
   {
-   "cell_type": "code",
-   "execution_count": 19,
-   "metadata": {
-    "colab": {
-     "base_uri": "https://localhost:8080/"
-    },
-    "id": "KUyK3QVRwmjR",
-    "outputId": "a77c191a-2f9e-4ae5-be19-8ce128e704e9"
-   },
-   "outputs": [
-    {
-     "name": "stdout",
-     "output_type": "stream",
-     "text": [
-      "Maximum GPU memory allocated: 12.8 GB\n",
-      "-> Maximum CPU memory allocated: 7.0 GB\n"
-     ]
-    }
-   ],
+   "cell_type": "markdown",
+   "metadata": {},
    "source": [
-    "def baseline_mmap_2():\n",
-    "    start_memory_tracking()\n",
-    "\n",
-    "    model = GPTModel(BASE_CONFIG).to(device)\n",
-    "\n",
-    "    model.load_state_dict(\n",
-    "        torch.load(\"model.pth\", map_location=device, weights_only=True, mmap=True)\n",
-    "    )\n",
-    "    model.eval();\n",
-    "\n",
-    "    print_memory_usage()\n",
-    "\n",
-    "peak_memory_used = memory_usage_in_gb(baseline_mmap_2)\n",
-    "print(f\"-> Maximum CPU memory allocated: {peak_memory_used:.1f} GB\")"
+    "- The reason the CPU RAM usage is so high here is that there's enough CPU RAM available on this machine\n",
+    "- However, if you were to run this on a machine with limited CPU RAM, the `mmap` approach would use less memory"
    ]
   },
   {
@@ -851,8 +821,9 @@
    "cell_type": "markdown",
    "metadata": {},
    "source": [
-    "- This notebook is focused on simple, built-in methods for loading weights in PyTorch.\n",
-    "- In case none of these methods work because you (1) don't have enough CPU memory for the `load_sequentially` approach and don't have enough GPU VRAM to have 2 copies of the weights in memory (the `load_sequentially_with_meta` approach), one option is to save and load each weight tensor separately:"
+    "- This notebook is focused on simple, built-in methods for loading weights in PyTorch\n",
+    "- The recommended approach for limited CPU memory cases is the `mmap=True` approach explained above\n",
+    "- Alternatively, a brute-force option is to save and load each weight tensor separately:"
    ]
   },
   {