From f77c376b0524398ca4f23e34de34c622185dd709 Mon Sep 17 00:00:00 2001 From: Mingyuan Xu Date: Fri, 13 Sep 2024 21:01:52 +0800 Subject: [PATCH] Run generate example in ch06 optionally on GPU (#352) * model.to("cuda") model.to("cuda") * update device placement --------- Co-authored-by: rasbt --- .../load-finetuned-model.ipynb | 64 +++++++++++++++---- 1 file changed, 52 insertions(+), 12 deletions(-) diff --git a/ch06/01_main-chapter-code/load-finetuned-model.ipynb b/ch06/01_main-chapter-code/load-finetuned-model.ipynb index 7ac6f5e6..4e931695 100644 --- a/ch06/01_main-chapter-code/load-finetuned-model.ipynb +++ b/ch06/01_main-chapter-code/load-finetuned-model.ipynb @@ -3,7 +3,9 @@ { "cell_type": "markdown", "id": "1545a16b-bc8d-4e49-b9a6-db6631e7483d", - "metadata": {}, + "metadata": { + "id": "1545a16b-bc8d-4e49-b9a6-db6631e7483d" + }, "source": [ "\n", "\n", @@ -23,7 +25,9 @@ { "cell_type": "markdown", "id": "f3f83194-82b9-4478-9550-5ad793467bd0", - "metadata": {}, + "metadata": { + "id": "f3f83194-82b9-4478-9550-5ad793467bd0" + }, "source": [ "# Load And Use Finetuned Model" ] @@ -31,7 +35,9 @@ { "cell_type": "markdown", "id": "466b564e-4fd5-4d76-a3a1-63f9f0993b7e", - "metadata": {}, + "metadata": { + "id": "466b564e-4fd5-4d76-a3a1-63f9f0993b7e" + }, "source": [ "This notebook contains minimal code to load the finetuned model that was created and saved in chapter 6 via [ch06.ipynb](ch06.ipynb)." ] @@ -40,7 +46,13 @@ "cell_type": "code", "execution_count": 1, "id": "fd80e5f5-0f79-4a6c-bf31-2026e7d30e52", - "metadata": {}, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "fd80e5f5-0f79-4a6c-bf31-2026e7d30e52", + "outputId": "9eeefb8e-a7eb-4d62-cf78-c797b3ed4e2e" + }, "outputs": [ { "name": "stdout", @@ -66,7 +78,9 @@ "cell_type": "code", "execution_count": 2, "id": "ed86d6b7-f32d-4601-b585-a2ea3dbf7201", - "metadata": {}, + "metadata": { + "id": "ed86d6b7-f32d-4601-b585-a2ea3dbf7201" + }, "outputs": [], "source": [ "from pathlib import Path\n", @@ -83,7 +97,9 @@ "cell_type": "code", "execution_count": 3, "id": "fb02584a-5e31-45d5-8377-794876907bc6", - "metadata": {}, + "metadata": { + "id": "fb02584a-5e31-45d5-8377-794876907bc6" + }, "outputs": [], "source": [ "from previous_chapters import GPTModel\n", @@ -116,7 +132,9 @@ "cell_type": "code", "execution_count": 4, "id": "f1ccf2b7-176e-4cfd-af7a-53fb76010b94", - "metadata": {}, + "metadata": { + "id": "f1ccf2b7-176e-4cfd-af7a-53fb76010b94" + }, "outputs": [], "source": [ "import torch\n", @@ -128,6 +146,7 @@ "# Then load pretrained weights\n", "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n", "model.load_state_dict(torch.load(\"review_classifier.pth\", map_location=device, weights_only=True))\n", + "model.to(device)\n", "model.eval();" ] }, @@ -135,7 +154,9 @@ "cell_type": "code", "execution_count": 5, "id": "a1fd174e-9555-46c5-8780-19b0aa4f26e5", - "metadata": {}, + "metadata": { + "id": "a1fd174e-9555-46c5-8780-19b0aa4f26e5" + }, "outputs": [], "source": [ "import tiktoken\n", @@ -147,7 +168,9 @@ "cell_type": "code", "execution_count": 6, "id": "2a4c0129-efe5-46e9-bb90-ba08d407c1a2", - "metadata": {}, + "metadata": { + "id": "2a4c0129-efe5-46e9-bb90-ba08d407c1a2" + }, "outputs": [], "source": [ "# This function was implemented in ch06.ipynb\n", @@ -167,7 +190,7 @@ "\n", " # Model inference\n", " with torch.no_grad():\n", - " logits = model(input_tensor)[:, -1, :] # Logits of the last output token\n", + " logits = model(input_tensor.to(device))[:, -1, :] # Logits of the last output token\n", " predicted_label = torch.argmax(logits, dim=-1).item()\n", "\n", " # Return the classified result\n", @@ -178,7 +201,13 @@ "cell_type": "code", "execution_count": 7, "id": "1e26862c-10b5-4a0f-9dd6-b6ddbad2fc3f", - "metadata": {}, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "1e26862c-10b5-4a0f-9dd6-b6ddbad2fc3f", + "outputId": "28eb2c02-0e38-4356-b2a3-2bf6accb5316" + }, "outputs": [ { "name": "stdout", @@ -203,7 +232,13 @@ "cell_type": "code", "execution_count": 8, "id": "78472e05-cb4e-4ec4-82e8-23777aa90cf8", - "metadata": {}, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "78472e05-cb4e-4ec4-82e8-23777aa90cf8", + "outputId": "0cd3cd62-f407-45f3-fa4f-51ff665355eb" + }, "outputs": [ { "name": "stdout", @@ -226,6 +261,11 @@ } ], "metadata": { + "accelerator": "GPU", + "colab": { + "gpuType": "L4", + "provenance": [] + }, "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python",