diff --git a/samples/colab/pytorch_huggingface_whisper.ipynb b/samples/colab/pytorch_huggingface_whisper.ipynb index ef61b9ce79e6..4d0314e66a1b 100644 --- a/samples/colab/pytorch_huggingface_whisper.ipynb +++ b/samples/colab/pytorch_huggingface_whisper.ipynb @@ -15,6 +15,8 @@ }, "language_info": { "name": "python" + }, + "widgets": { } }, "cells": [ @@ -75,8 +77,7 @@ "!python -m pip uninstall -y fastai torchaudio torchdata torchtext torchvision" ], "metadata": { - "id": "KsPubQSvCbXd", - "cellView": "form" + "id": "KsPubQSvCbXd" }, "execution_count": 2, "outputs": [] @@ -84,14 +85,14 @@ { "cell_type": "code", "source": [ - "!python -m pip install --pre --index-url https://download.pytorch.org/whl/test/cpu --upgrade torch==2.3.0" + "!python -m pip install --pre --index-url https://download.pytorch.org/whl/cpu --upgrade torch==2.5.0" ], "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "oO1tirq2ggmO", - "outputId": "c3390361-9f40-4a49-b5c7-898a62614143" + "outputId": "1c10e964-1bd3-41e7-d7ce-70cf574d817b" }, "execution_count": 3, "outputs": [ @@ -99,24 +100,26 @@ "output_type": "stream", "name": "stdout", "text": [ - "Looking in indexes: https://download.pytorch.org/whl/test/cpu\n", - "Collecting torch==2.3.0\n", - " Downloading https://download.pytorch.org/whl/test/cpu/torch-2.3.0%2Bcpu-cp310-cp310-linux_x86_64.whl (190.4 MB)\n", - "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m190.4/190.4 MB\u001b[0m \u001b[31m4.6 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", - "\u001b[?25hRequirement already satisfied: filelock in /usr/local/lib/python3.10/dist-packages (from torch==2.3.0) (3.13.4)\n", - "Requirement already satisfied: typing-extensions>=4.8.0 in /usr/local/lib/python3.10/dist-packages (from torch==2.3.0) (4.11.0)\n", - "Requirement already satisfied: sympy in /usr/local/lib/python3.10/dist-packages (from torch==2.3.0) (1.12)\n", - "Requirement already satisfied: networkx in /usr/local/lib/python3.10/dist-packages (from torch==2.3.0) (3.3)\n", - "Requirement already satisfied: jinja2 in /usr/local/lib/python3.10/dist-packages (from torch==2.3.0) (3.1.3)\n", - "Requirement already satisfied: fsspec in /usr/local/lib/python3.10/dist-packages (from torch==2.3.0) (2023.6.0)\n", - "Requirement already satisfied: MarkupSafe>=2.0 in /usr/local/lib/python3.10/dist-packages (from jinja2->torch==2.3.0) (2.1.5)\n", - "Requirement already satisfied: mpmath>=0.19 in /usr/local/lib/python3.10/dist-packages (from sympy->torch==2.3.0) (1.3.0)\n", + "Looking in indexes: https://download.pytorch.org/whl/cpu\n", + "Collecting torch==2.5.0\n", + " Downloading https://download.pytorch.org/whl/cpu/torch-2.5.0%2Bcpu-cp310-cp310-linux_x86_64.whl (174.7 MB)\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m174.7/174.7 MB\u001b[0m \u001b[31m4.3 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25hRequirement already satisfied: filelock in /usr/local/lib/python3.10/dist-packages (from torch==2.5.0) (3.16.1)\n", + "Requirement already satisfied: typing-extensions>=4.8.0 in /usr/local/lib/python3.10/dist-packages (from torch==2.5.0) (4.12.2)\n", + "Requirement already satisfied: networkx in /usr/local/lib/python3.10/dist-packages (from torch==2.5.0) (3.4.2)\n", + "Requirement already satisfied: jinja2 in /usr/local/lib/python3.10/dist-packages (from torch==2.5.0) (3.1.5)\n", + "Requirement already satisfied: fsspec in /usr/local/lib/python3.10/dist-packages (from torch==2.5.0) (2024.10.0)\n", + "Requirement already satisfied: sympy==1.13.1 in /usr/local/lib/python3.10/dist-packages (from torch==2.5.0) (1.13.1)\n", + "Requirement already satisfied: mpmath<1.4,>=1.1.0 in /usr/local/lib/python3.10/dist-packages (from sympy==1.13.1->torch==2.5.0) (1.3.0)\n", + "Requirement already satisfied: MarkupSafe>=2.0 in /usr/local/lib/python3.10/dist-packages (from jinja2->torch==2.5.0) (3.0.2)\n", "Installing collected packages: torch\n", " Attempting uninstall: torch\n", - " Found existing installation: torch 2.2.1+cu121\n", - " Uninstalling torch-2.2.1+cu121:\n", - " Successfully uninstalled torch-2.2.1+cu121\n", - "Successfully installed torch-2.3.0+cpu\n" + " Found existing installation: torch 2.5.1+cu121\n", + " Uninstalling torch-2.5.1+cu121:\n", + " Successfully uninstalled torch-2.5.1+cu121\n", + "\u001b[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.\n", + "timm 1.0.12 requires torchvision, which is not installed.\u001b[0m\u001b[31m\n", + "\u001b[0mSuccessfully installed torch-2.5.0+cpu\n" ] } ] @@ -129,7 +132,7 @@ "base_uri": "https://localhost:8080/" }, "id": "4iJFDHbsAzo4", - "outputId": "94721ae8-e222-4203-c356-888b42bc20b9" + "outputId": "c95e32a5-70ab-43e7-8c8c-300d37cccfd3" }, "outputs": [ { @@ -137,27 +140,35 @@ "name": "stdout", "text": [ "Collecting iree-turbine\n", - " Downloading iree_turbine-2.3.0rc20240410-py3-none-any.whl (150 kB)\n", - "\u001b[?25l \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m0.0/150.4 kB\u001b[0m \u001b[31m?\u001b[0m eta \u001b[36m-:--:--\u001b[0m\r\u001b[2K \u001b[91m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[90m╺\u001b[0m\u001b[90m━\u001b[0m \u001b[32m143.4/150.4 kB\u001b[0m \u001b[31m4.4 MB/s\u001b[0m eta \u001b[36m0:00:01\u001b[0m\r\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m150.4/150.4 kB\u001b[0m \u001b[31m3.3 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", - "\u001b[?25hRequirement already satisfied: numpy in /usr/local/lib/python3.10/dist-packages (from iree-turbine) (1.25.2)\n", - "Collecting iree-compiler>=20240410.859 (from iree-turbine)\n", - " Downloading iree_compiler-20240410.859-cp310-cp310-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl (64.4 MB)\n", - "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m64.4/64.4 MB\u001b[0m \u001b[31m10.3 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", - "\u001b[?25hCollecting iree-runtime>=20240410.859 (from iree-turbine)\n", - " Downloading iree_runtime-20240410.859-cp310-cp310-manylinux_2_28_x86_64.whl (7.4 MB)\n", - "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m7.4/7.4 MB\u001b[0m \u001b[31m26.4 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", - "\u001b[?25hRequirement already satisfied: torch>=2.1.0 in /usr/local/lib/python3.10/dist-packages (from iree-turbine) (2.3.0+cpu)\n", - "Requirement already satisfied: PyYAML in /usr/local/lib/python3.10/dist-packages (from iree-compiler>=20240410.859->iree-turbine) (6.0.1)\n", - "Requirement already satisfied: filelock in /usr/local/lib/python3.10/dist-packages (from torch>=2.1.0->iree-turbine) (3.13.4)\n", - "Requirement already satisfied: typing-extensions>=4.8.0 in /usr/local/lib/python3.10/dist-packages (from torch>=2.1.0->iree-turbine) (4.11.0)\n", - "Requirement already satisfied: sympy in /usr/local/lib/python3.10/dist-packages (from torch>=2.1.0->iree-turbine) (1.12)\n", - "Requirement already satisfied: networkx in /usr/local/lib/python3.10/dist-packages (from torch>=2.1.0->iree-turbine) (3.3)\n", - "Requirement already satisfied: jinja2 in /usr/local/lib/python3.10/dist-packages (from torch>=2.1.0->iree-turbine) (3.1.3)\n", - "Requirement already satisfied: fsspec in /usr/local/lib/python3.10/dist-packages (from torch>=2.1.0->iree-turbine) (2023.6.0)\n", - "Requirement already satisfied: MarkupSafe>=2.0 in /usr/local/lib/python3.10/dist-packages (from jinja2->torch>=2.1.0->iree-turbine) (2.1.5)\n", - "Requirement already satisfied: mpmath>=0.19 in /usr/local/lib/python3.10/dist-packages (from sympy->torch>=2.1.0->iree-turbine) (1.3.0)\n", - "Installing collected packages: iree-runtime, iree-compiler, iree-turbine\n", - "Successfully installed iree-compiler-20240410.859 iree-runtime-20240410.859 iree-turbine-2.3.0rc20240410\n" + " Downloading iree_turbine-3.1.0-py3-none-any.whl.metadata (6.7 kB)\n", + "Requirement already satisfied: numpy in /usr/local/lib/python3.10/dist-packages (from iree-turbine) (1.26.4)\n", + "Collecting iree-base-compiler (from iree-turbine)\n", + " Downloading iree_base_compiler-3.1.0-cp310-cp310-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl.metadata (1.0 kB)\n", + "Collecting iree-base-runtime (from iree-turbine)\n", + " Downloading iree_base_runtime-3.1.0-cp310-cp310-manylinux_2_28_x86_64.whl.metadata (1.0 kB)\n", + "Requirement already satisfied: Jinja2>=3.1.3 in /usr/local/lib/python3.10/dist-packages (from iree-turbine) (3.1.5)\n", + "Collecting ml_dtypes>=0.5.0 (from iree-turbine)\n", + " Downloading ml_dtypes-0.5.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (21 kB)\n", + "Requirement already satisfied: typing_extensions in /usr/local/lib/python3.10/dist-packages (from iree-turbine) (4.12.2)\n", + "Requirement already satisfied: MarkupSafe>=2.0 in /usr/local/lib/python3.10/dist-packages (from Jinja2>=3.1.3->iree-turbine) (3.0.2)\n", + "Requirement already satisfied: sympy in /usr/local/lib/python3.10/dist-packages (from iree-base-compiler->iree-turbine) (1.13.1)\n", + "Requirement already satisfied: mpmath<1.4,>=1.1.0 in /usr/local/lib/python3.10/dist-packages (from sympy->iree-base-compiler->iree-turbine) (1.3.0)\n", + "Downloading iree_turbine-3.1.0-py3-none-any.whl (301 kB)\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m301.7/301.7 kB\u001b[0m \u001b[31m6.5 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25hDownloading ml_dtypes-0.5.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (4.7 MB)\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m4.7/4.7 MB\u001b[0m \u001b[31m34.4 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25hDownloading iree_base_compiler-3.1.0-cp310-cp310-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl (71.2 MB)\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m71.2/71.2 MB\u001b[0m \u001b[31m10.6 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25hDownloading iree_base_runtime-3.1.0-cp310-cp310-manylinux_2_28_x86_64.whl (8.2 MB)\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m8.2/8.2 MB\u001b[0m \u001b[31m40.0 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25hInstalling collected packages: ml_dtypes, iree-base-runtime, iree-base-compiler, iree-turbine\n", + " Attempting uninstall: ml_dtypes\n", + " Found existing installation: ml-dtypes 0.4.1\n", + " Uninstalling ml-dtypes-0.4.1:\n", + " Successfully uninstalled ml-dtypes-0.4.1\n", + "\u001b[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.\n", + "tensorflow 2.17.1 requires ml-dtypes<0.5.0,>=0.3.1, but you have ml-dtypes 0.5.1 which is incompatible.\u001b[0m\u001b[31m\n", + "\u001b[0mSuccessfully installed iree-base-compiler-3.1.0 iree-base-runtime-3.1.0 iree-turbine-3.1.0 ml_dtypes-0.5.1\n" ] } ], @@ -182,7 +193,7 @@ "base_uri": "https://localhost:8080/" }, "id": "nkVLzRpcDnVL", - "outputId": "ee4e956f-ca7d-45ac-9913-672ad444d89f" + "outputId": "210a54b9-4044-4426-f9ee-09d5fd23839c" }, "execution_count": 5, "outputs": [ @@ -190,15 +201,15 @@ "output_type": "stream", "name": "stdout", "text": [ - "Installed iree-turbine, Version: 2.3.0rc20240410\n", + "Installed iree-turbine, Version: 3.1.0\n", "\n", "Installed IREE, compiler version information:\n", "IREE (https://iree.dev):\n", - " IREE compiler version 20240410.859 @ b4273a4bfc66ba6dd8f62f6483d74d42a7b936f1\n", - " LLVM version 19.0.0git\n", + " IREE compiler version 3.1.0rc20250107 @ d2242207764230ad398585a5771f9d54ce91b4c8\n", + " LLVM version 20.0.0git\n", " Optimized build\n", "\n", - "Installed PyTorch, version: 2.3.0+cpu\n" + "Installed PyTorch, version: 2.5.0+cpu\n" ] } ] @@ -259,10 +270,288 @@ "example_args = (example_input,)" ], "metadata": { - "id": "HLbfUuoBPHgH" + "id": "HLbfUuoBPHgH", + "colab": { + "base_uri": "https://localhost:8080/", + "height": 476, + "referenced_widgets": [ + "aaaaa863923e4b6fb92380ff063677a7", + "37ee0e0bcecc44a1aaf25a8e66d752ad", + "139c5c777da84307a28ffc7de9e037f6", + "ea6b64fa6c5a4d0aa8797ac9dfa9051a", + "28125f8e13a14bfaad7c645590360b22", + "4b4afd6e36a94799856aa8dd2243f9a2", + "7e2f1ea6da7d430fa5df1fc3b9733e10", + "5257eab4b9cd42bbb106404bfe428903", + "9ae9e376cec54470b398a9cf5fa7b9fd", + "b0a36845fd614468a45de443b2e5c0f8", + "757a8987af504be89d2cc7d058728d9c", + "d88e2b3299f8431bbe7af6e1232ce54c", + "d1d3a7166d0f40b4bb26a146e8b76d5a", + "f105605987244b56b6e33f162fbe6930", + "70a4da0f7f224aa9b828c4493fe31101", + "4e8a310daae6485cb53c853bcbb6b029", + "9c966c41c6eb4407b6c8751c29b9d082", + "86685457be39483fa223541a8e51a79e", + "cc26cd911fb84bc19bdb782060138df4", + "a8fc625551bb40c9aedb37d548837cd2", + "ef484e2a7891478d822929c4728dfdd3", + "c2a512da304643e9ae86eb6b1c434934", + "d24fd8bd3c6349b8a265a18f96901458", + "a217516b31244cb194bb47f4da51ae6c", + "8b50957482094ec58561ed62fe53c720", + "87e3626416c440a785c3898baf2c8bce", + "ecafe4f7ddc64973a6fce7a4d0fcdcb3", + "cfa6bedf2057488ca273dea84107cdc9", + "6e7a805a79c749f48d3bfa1028e3de70", + "fff09ea7e19b47f29304f9f315425884", + "9931ae28665347c883d6b4723f405bfb", + "f454754510404b61b18a8f87cd8ba1ce", + "da14fb80f6994378bf405b38c3f86bab", + "560e52debb244737b1cb8f3088506e80", + "4fe7fe3078b24fc3811072b7790a9371", + "bec9bf346af9464b8fb120fbbaf2fefd", + "ffd5fca2e0b84317a473e358e85f3d77", + "114aa03e366f46188a771c98607e5adb", + "e8dbea7f1cd0443ca3cbe114e24ff3da", + "ab512519d4bb4d28b6109a25d9bd6b88", + "4ead0d0ec9994682ab8cc9ef027eaae2", + "3e61b4815dd14f619b90c715d62df347", + "c1420dce3e7246f3866a8ce85474fe5f", + "fa9daa6d53564dfdb7af8593edc69884", + "a6b1386c9842438ca5801134d26f0b51", + "c21ed7621a004d9bb24e9d68c3a76a6f", + "7b6e231cec8946118d7bcd745f518010", + "2fd77a6a408a41f1be6068ef77057a28", + "2c2b92f12df041b0b973ae9793cdc1ab", + "446e30992e244f48a68d9130e82a7126", + "9aece5f7db4646a7a7f46a186cde18eb", + "6faf78708ec44e5f8d5db701938965c3", + "3a26fee124d442158961fd5c1b28e5bd", + "549fc99a11d44ef1ab48c64b742d58e7", + "7516468965394c4b8bcbc7ea3db3b457", + "d476cb392ec5423cba72f46757f1df1a", + "9fbede93702942c492488056485bdb6c", + "e5b32ffa050946469cc375b9234e35ae", + "4d18beace81f48c08540812193fd5244", + "9076e8da1009489291af97753dcae650", + "5f0a6e2ed40a47cc8016e1a11100579e", + "df19cd089b854db6ad2230f0a457aca8", + "b671614693b141f9821ef7b78ee98ae1", + "f3e226744d6e45918d4a48b924389ad7", + "6f4e119ee815411094a7d9d5311f10b5", + "77503681f26b46419157af3d49a71bcb", + "16778bf77b8942939e323b655ac4dfa6", + "67057ee8ce464d64a7769ade5d7479dd", + "e05d7305000c4b1091631d7b15f7900a", + "c9c0abecdc0e4462a8a72e47bfb1e53d", + "02202837100e448383e6615758556655", + "86ffbd1afeab411fbd322c909cb21a5d", + "5870517d9c124931886a88a310c2386e", + "59dc9a70d31c492788508f67cd975365", + "cf5a0cc0a32d4df2a871ec642d2de5da", + "9a9554c7d1f04d9ab14df042cfefbedd", + "33a04a7c30dd42de9cfe6a43e988604d", + "83e8699174dd43bc9b646d6d49c993b3", + "27dbc17542ba451888f1e847b59f2da7", + "5a480878ee0d441da7ab360d1c93fbc6", + "d646f8677a1b4ca190f2410a8cfc05b7", + "1c4cd997e1104a4ba0f3a15b4ce7dea0", + "342b91aaba4d4df8aa567779d1a6f4e7", + "c7487e228a7048ff9e9be026dc5a9f46", + "60cec20206c243efa658a94c7278719c", + "57f8627eb25c4b84a74c4bf71c6c122f", + "3772acd3cc9d4e61ab2a4bf0f4a15774", + "f7a3e9b12c4b461e972113b8880aa985", + "a8fdc4ca612b47c68fd130b51bcd1ece", + "b5a84ce7032343b9b040a03e5432d96f", + "b893ddddacf74e7c8ca40666fb84c24e", + "9b44401b93664666a42c56d3165de181", + "c976d288588145069099727ab5183da6", + "b952b8ec0ebb4727802daddb7a3d5f4a", + "18293fa01eb14c8b9d7f2b00669ba82a", + "eab894861d1f406c973b525366a8e157", + "b727f95de2e245019a61ac5c9466ece4", + "a23bf4d95ebe4af1991fe1531b6b7b2f", + "fe045669fdc147a5a6bde04b4f31fcef", + "ccee9eaca1d944099ee97f8c0a5790b5", + "675803d6f0b94b07b3b465b427547a00", + "cff95693504d4e7eac5fdfe972cf7e12", + "566899e9d7c04a7c954f7186d988c1e4", + "d994cd6184ee402ca6e4ae0b7db8faea", + "d9d7829fe66049379ecf88ce7b385c34", + "fb82acd9ae0f441d9c1dcf87b47f3486", + "ef1916aba2f449cbabc64248ef9cc95f", + "58d678849b444dfaa467f87a1b7bd9fc", + "bb1027ed5dbb4fb68f90480e85cde62c", + "79cf395270734f4a926dd3fc165f65e2" + ] + }, + "outputId": "c33917e5-8ec5-4e03-85c3-9424a529fac9" }, - "execution_count": null, - "outputs": [] + "execution_count": 6, + "outputs": [ + { + "output_type": "stream", + "name": "stderr", + "text": [ + "/usr/local/lib/python3.10/dist-packages/huggingface_hub/utils/_auth.py:94: UserWarning: \n", + "The secret `HF_TOKEN` does not exist in your Colab secrets.\n", + "To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.\n", + "You will be able to reuse this secret in all of your notebooks.\n", + "Please note that authentication is recommended but still optional to access public models or datasets.\n", + " warnings.warn(\n" + ] + }, + { + "output_type": "display_data", + "data": { + "text/plain": [ + "tokenizer_config.json: 0%| | 0.00/283k [00:00=20231004.665 (from shark-turbine)\n", - " Downloading iree_compiler-20231004.665-cp310-cp310-manylinux_2_28_x86_64.whl (57.2 MB)\n", - "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m57.2/57.2 MB\u001b[0m \u001b[31m17.8 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", - "\u001b[?25hCollecting iree-runtime>=20231004.665 (from shark-turbine)\n", - " Downloading iree_runtime-20231004.665-cp310-cp310-manylinux_2_28_x86_64.whl (7.8 MB)\n", - "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m7.8/7.8 MB\u001b[0m \u001b[31m91.5 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", - "\u001b[?25hRequirement already satisfied: torch>=2.1.0 in /usr/local/lib/python3.10/dist-packages (from shark-turbine) (2.1.0+cu118)\n", - "Requirement already satisfied: PyYAML in /usr/local/lib/python3.10/dist-packages (from iree-compiler>=20231004.665->shark-turbine) (6.0.1)\n", - "Requirement already satisfied: filelock in /usr/local/lib/python3.10/dist-packages (from torch>=2.1.0->shark-turbine) (3.12.4)\n", - "Requirement already satisfied: typing-extensions in /usr/local/lib/python3.10/dist-packages (from torch>=2.1.0->shark-turbine) (4.5.0)\n", - "Requirement already satisfied: sympy in /usr/local/lib/python3.10/dist-packages (from torch>=2.1.0->shark-turbine) (1.12)\n", - "Requirement already satisfied: networkx in /usr/local/lib/python3.10/dist-packages (from torch>=2.1.0->shark-turbine) (3.1)\n", - "Requirement already satisfied: jinja2 in /usr/local/lib/python3.10/dist-packages (from torch>=2.1.0->shark-turbine) (3.1.2)\n", - "Requirement already satisfied: fsspec in /usr/local/lib/python3.10/dist-packages (from torch>=2.1.0->shark-turbine) (2023.6.0)\n", - "Requirement already satisfied: triton==2.1.0 in /usr/local/lib/python3.10/dist-packages (from torch>=2.1.0->shark-turbine) (2.1.0)\n", - "Requirement already satisfied: MarkupSafe>=2.0 in /usr/local/lib/python3.10/dist-packages (from jinja2->torch>=2.1.0->shark-turbine) (2.1.3)\n", - "Requirement already satisfied: mpmath>=0.19 in /usr/local/lib/python3.10/dist-packages (from sympy->torch>=2.1.0->shark-turbine) (1.3.0)\n", - "Building wheels for collected packages: shark-turbine\n", - " Building wheel for shark-turbine (pyproject.toml) ... \u001b[?25l\u001b[?25hdone\n", - " Created wheel for shark-turbine: filename=shark_turbine-0.9.1.dev3-py3-none-any.whl size=70102 sha256=507dec827b9a2eea18f47c6ebdc84347c9956b8f2e0b186d3107a006e0742d81\n", - " Stored in directory: /root/.cache/pip/wheels/e9/78/0f/88c9d8224ef1550fe00b18a014eab5121f26264e2261f31926\n", - "Successfully built shark-turbine\n", - "Installing collected packages: iree-runtime, iree-compiler, shark-turbine\n", - "Successfully installed iree-compiler-20231004.665 iree-runtime-20231004.665 shark-turbine-0.9.1.dev3\n" + "Collecting iree-turbine\n", + " Downloading iree_turbine-3.1.0-py3-none-any.whl.metadata (6.7 kB)\n", + "Requirement already satisfied: numpy in /usr/local/lib/python3.10/dist-packages (from iree-turbine) (1.26.4)\n", + "Collecting iree-base-compiler (from iree-turbine)\n", + " Downloading iree_base_compiler-3.1.0-cp310-cp310-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl.metadata (1.0 kB)\n", + "Collecting iree-base-runtime (from iree-turbine)\n", + " Downloading iree_base_runtime-3.1.0-cp310-cp310-manylinux_2_28_x86_64.whl.metadata (1.0 kB)\n", + "Requirement already satisfied: Jinja2>=3.1.3 in /usr/local/lib/python3.10/dist-packages (from iree-turbine) (3.1.5)\n", + "Collecting ml_dtypes>=0.5.0 (from iree-turbine)\n", + " Downloading ml_dtypes-0.5.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (21 kB)\n", + "Requirement already satisfied: typing_extensions in /usr/local/lib/python3.10/dist-packages (from iree-turbine) (4.12.2)\n", + "Requirement already satisfied: MarkupSafe>=2.0 in /usr/local/lib/python3.10/dist-packages (from Jinja2>=3.1.3->iree-turbine) (3.0.2)\n", + "Requirement already satisfied: sympy in /usr/local/lib/python3.10/dist-packages (from iree-base-compiler->iree-turbine) (1.13.1)\n", + "Requirement already satisfied: mpmath<1.4,>=1.1.0 in /usr/local/lib/python3.10/dist-packages (from sympy->iree-base-compiler->iree-turbine) (1.3.0)\n", + "Downloading iree_turbine-3.1.0-py3-none-any.whl (301 kB)\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m301.7/301.7 kB\u001b[0m \u001b[31m5.0 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25hDownloading ml_dtypes-0.5.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (4.7 MB)\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m4.7/4.7 MB\u001b[0m \u001b[31m21.6 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25hDownloading iree_base_compiler-3.1.0-cp310-cp310-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl (71.2 MB)\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m71.2/71.2 MB\u001b[0m \u001b[31m7.4 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25hDownloading iree_base_runtime-3.1.0-cp310-cp310-manylinux_2_28_x86_64.whl (8.2 MB)\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m8.2/8.2 MB\u001b[0m \u001b[31m25.1 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25hInstalling collected packages: ml_dtypes, iree-base-runtime, iree-base-compiler, iree-turbine\n", + " Attempting uninstall: ml_dtypes\n", + " Found existing installation: ml-dtypes 0.4.1\n", + " Uninstalling ml-dtypes-0.4.1:\n", + " Successfully uninstalled ml-dtypes-0.4.1\n", + "\u001b[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.\n", + "tensorflow 2.17.1 requires ml-dtypes<0.5.0,>=0.3.1, but you have ml-dtypes 0.5.1 which is incompatible.\u001b[0m\u001b[31m\n", + "\u001b[0mSuccessfully installed iree-base-compiler-3.1.0 iree-base-runtime-3.1.0 iree-turbine-3.1.0 ml_dtypes-0.5.1\n" ] } ] @@ -173,7 +172,7 @@ "cell_type": "code", "source": [ "#@title Report version information\n", - "!echo \"Installed SHARK-Turbine, $(python -m pip show shark_turbine | grep Version)\"\n", + "!echo \"Installed iree-turbine, $(python -m pip show iree_turbine | grep Version)\"\n", "\n", "!echo -e \"\\nInstalled IREE, compiler version information:\"\n", "!iree-compile --version\n", @@ -186,23 +185,23 @@ "base_uri": "https://localhost:8080/" }, "id": "Oj5I6R9LI7t_", - "outputId": "35d79e6a-7bd0-46e1-8113-5af1a7bcbb5b" + "outputId": "deaa1abf-dc0e-49d8-d165-47d53592d94f" }, - "execution_count": 4, + "execution_count": 5, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ - "Installed SHARK-Turbine, Version: 0.9.1.dev3\n", + "Installed iree-turbine, Version: 3.1.0\n", "\n", "Installed IREE, compiler version information:\n", "IREE (https://iree.dev):\n", - " IREE compiler version 20231004.665 @ bb51f6f1a1b4ee619fb09a7396f449dadb211447\n", - " LLVM version 18.0.0git\n", + " IREE compiler version 3.1.0rc20250107 @ d2242207764230ad398585a5771f9d54ce91b4c8\n", + " LLVM version 20.0.0git\n", " Optimized build\n", "\n", - "Installed PyTorch, version: 2.1.0+cu118\n" + "Installed PyTorch, version: 2.5.1+cu121\n" ] } ] @@ -210,7 +209,7 @@ { "cell_type": "markdown", "source": [ - "## Create a program using PyTorch + SHARK-Turbine\n", + "## Create a program using PyTorch + iree-turbine\n", "\n", "NOTE: as in other domains, providing more information to a compiler allows it\n", "to generate more efficient code. As a general rule, the slowest varying\n", @@ -227,45 +226,78 @@ { "cell_type": "code", "source": [ - "#@title Define a sample `shark_turbine.aot.CompiledModule` using dynamic shapes\n", + "#@title Define a sample `torch.nn.Module`.\n", "\n", - "import shark_turbine.aot as aot\n", + "import iree.turbine.aot as aot\n", "\n", - "class DynamicShapesModule(aot.CompiledModule, export_name=\"module\"):\n", + "class DynamicShapesModule(torch.nn.Module):\n", " # reduce_sum_1d (dynamic input size, static output size)\n", " # tensor -> tensor\n", " # e.g. [1, 2, 3] -> 6\n", - " def reduce_sum_1d(self, values=aot.AbstractTensor(None, dtype=torch.int32)):\n", - " return self.compute_reduce_sum_1d(values)\n", - "\n", - " @aot.jittable\n", - " def compute_reduce_sum_1d(values):\n", - " return torch.sum(values, dtype=torch.int32)\n", + " def reduce_sum_1d(self, values):\n", + " return torch.sum(values)\n", "\n", " # reduce_sum_2d (partially dynamic input size, static output size)\n", " # tensor -> tensor<3xi32>\n", " # e.g. [[1, 2, 3], [10, 20, 30]] -> [11, 22, 33]\n", - " def reduce_sum_2d(self, values=aot.AbstractTensor(None, 3, dtype=torch.int32)):\n", - " return self.compute_reduce_sum_2d(values)\n", - "\n", - " @aot.jittable\n", - " def compute_reduce_sum_2d(values):\n", - " return torch.sum(values, 0, dtype=torch.int32)\n", + " def reduce_sum_2d(self, values):\n", + " return torch.sum(values, 0)\n", "\n", " # add_one (dynamic input size, dynamic output size)\n", " # tensor) -> tensor\n", " # e.g. [1, 2, 3] -> [2, 3, 4]\n", - " def add_one(self, values=aot.AbstractTensor(None, dtype=torch.int32)):\n", - " return self.compute_add_one(values)\n", - "\n", - " @aot.jittable\n", - " def compute_add_one(values):\n", + " def add_one(self, values):\n", " return values + 1" ], "metadata": { "id": "vsf9F4WxI_DX" }, - "execution_count": 5, + "execution_count": 6, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "#@title Export using FxProgramsBuilder.\n", + "\n", + "fxb = aot.FxProgramsBuilder(DynamicShapesModule())\n", + "\n", + "# Create a single dynamic export dimension.\n", + "dynamic_x = torch.export.Dim(\"x\")\n", + "# Example inputs with a mix of placeholder (dynamic) and static dimensions.\n", + "example_1d = torch.empty(16, dtype=torch.int32)\n", + "example_2d = torch.empty((16, 3), dtype=torch.int32)\n", + "\n", + "# Export reduce_sum_1d with a dynamic dimension.\n", + "@fxb.export_program(\n", + " args=(example_1d,),\n", + " dynamic_shapes={\"values\": {0: dynamic_x}},\n", + ")\n", + "def reduce_sum_1d(module, values):\n", + " return module.reduce_sum_1d(values)\n", + "\n", + "# Export reduce_sum_2d with one dynamic dimension.\n", + "@fxb.export_program(\n", + " args=(example_2d,),\n", + " dynamic_shapes={\"values\": {0: dynamic_x}},\n", + ")\n", + "def reduce_sum_2d(module, values):\n", + " return module.reduce_sum_2d(values)\n", + "\n", + "# Export add_one with a dynamic dimension.\n", + "@fxb.export_program(\n", + " args=(example_1d,),\n", + " dynamic_shapes={\"values\": {0: dynamic_x}},\n", + ")\n", + "def add_one(module, values):\n", + " return module.add_one(values)\n", + "\n", + "export_output = aot.export(fxb)" + ], + "metadata": { + "id": "cCy3nuLBKTAg" + }, + "execution_count": 7, "outputs": [] }, { @@ -273,10 +305,8 @@ "source": [ "from iree.compiler.ir import Context\n", "\n", - "# Import into MLIR and save to disk.\n", - "dynamic_shapes_instance = DynamicShapesModule(context=Context())\n", "imported_mlir_path = os.path.join(ARTIFACTS_DIR, \"dynamic_shapes.mlir\")\n", - "aot.CompiledModule.save_mlir(dynamic_shapes_instance, imported_mlir_path)\n", + "export_output.save_mlir(imported_mlir_path)\n", "print(f\"Wrote MLIR to path '{imported_mlir_path}'\")\n", "\n", "# Inspect the IR.\n", @@ -289,9 +319,9 @@ "base_uri": "https://localhost:8080/" }, "id": "_OQIpOtNr4Gh", - "outputId": "888c0bf3-bec6-403c-9993-ad45d21364fb" + "outputId": "abe96b74-88de-4979-959c-cdfbc981b17c" }, - "execution_count": 6, + "execution_count": 8, "outputs": [ { "output_type": "stream", @@ -300,56 +330,25 @@ "Wrote MLIR to path '/tmp/iree/colab_artifacts/dynamic_shapes.mlir'\n", "\n", "Dynamic Shapes MLIR:\n", - "#map = affine_map<(d0) -> (d0)>\n", - "#map1 = affine_map<(d0) -> ()>\n", - "#map2 = affine_map<(d0, d1) -> (d0, d1)>\n", - "#map3 = affine_map<(d0, d1) -> (d1)>\n", "module @module {\n", - " func.func @reduce_sum_1d(%arg0: tensor) -> tensor attributes {torch.args_schema = \"[1, {\\22type\\22: \\22builtins.tuple\\22, \\22context\\22: \\22null\\22, \\22children_spec\\22: [{\\22type\\22: \\22builtins.list\\22, \\22context\\22: \\22null\\22, \\22children_spec\\22: [{\\22type\\22: null, \\22context\\22: null, \\22children_spec\\22: []}]}, {\\22type\\22: \\22builtins.dict\\22, \\22context\\22: \\22[]\\22, \\22children_spec\\22: []}]}]\", torch.return_schema = \"[1, {\\22type\\22: null, \\22context\\22: null, \\22children_spec\\22: []}]\"} {\n", - " %0 = call @compute_reduce_sum_1d(%arg0) : (tensor) -> tensor\n", - " return %0 : tensor\n", - " }\n", - " func.func private @compute_reduce_sum_1d(%arg0: tensor) -> tensor {\n", - " %c0_i32 = arith.constant 0 : i32\n", - " %0 = tensor.empty() : tensor\n", - " %1 = linalg.fill ins(%c0_i32 : i32) outs(%0 : tensor) -> tensor\n", - " %2 = linalg.generic {indexing_maps = [#map, #map1], iterator_types = [\"reduction\"]} ins(%arg0 : tensor) outs(%1 : tensor) {\n", - " ^bb0(%in: i32, %out: i32):\n", - " %3 = arith.addi %in, %out : i32\n", - " linalg.yield %3 : i32\n", - " } -> tensor\n", - " return %2 : tensor\n", - " }\n", - " func.func @reduce_sum_2d(%arg0: tensor) -> tensor<3xi32> attributes {torch.args_schema = \"[1, {\\22type\\22: \\22builtins.tuple\\22, \\22context\\22: \\22null\\22, \\22children_spec\\22: [{\\22type\\22: \\22builtins.list\\22, \\22context\\22: \\22null\\22, \\22children_spec\\22: [{\\22type\\22: null, \\22context\\22: null, \\22children_spec\\22: []}]}, {\\22type\\22: \\22builtins.dict\\22, \\22context\\22: \\22[]\\22, \\22children_spec\\22: []}]}]\", torch.return_schema = \"[1, {\\22type\\22: null, \\22context\\22: null, \\22children_spec\\22: []}]\"} {\n", - " %0 = call @compute_reduce_sum_2d(%arg0) : (tensor) -> tensor<3xi32>\n", - " return %0 : tensor<3xi32>\n", - " }\n", - " func.func private @compute_reduce_sum_2d(%arg0: tensor) -> tensor<3xi32> {\n", - " %c0_i32 = arith.constant 0 : i32\n", - " %0 = tensor.empty() : tensor<3xi32>\n", - " %1 = linalg.fill ins(%c0_i32 : i32) outs(%0 : tensor<3xi32>) -> tensor<3xi32>\n", - " %2 = linalg.generic {indexing_maps = [#map2, #map3], iterator_types = [\"reduction\", \"parallel\"]} ins(%arg0 : tensor) outs(%1 : tensor<3xi32>) {\n", - " ^bb0(%in: i32, %out: i32):\n", - " %3 = arith.addi %in, %out : i32\n", - " linalg.yield %3 : i32\n", - " } -> tensor<3xi32>\n", - " return %2 : tensor<3xi32>\n", + " func.func @reduce_sum_1d(%arg0: !torch.vtensor<[?],si32>) -> !torch.vtensor<[],si64> attributes {torch.assume_strict_symbolic_shapes} {\n", + " %none = torch.constant.none\n", + " %0 = torch.aten.sum %arg0, %none : !torch.vtensor<[?],si32>, !torch.none -> !torch.vtensor<[],si64>\n", + " return %0 : !torch.vtensor<[],si64>\n", " }\n", - " func.func @add_one(%arg0: tensor) -> tensor attributes {torch.args_schema = \"[1, {\\22type\\22: \\22builtins.tuple\\22, \\22context\\22: \\22null\\22, \\22children_spec\\22: [{\\22type\\22: \\22builtins.list\\22, \\22context\\22: \\22null\\22, \\22children_spec\\22: [{\\22type\\22: null, \\22context\\22: null, \\22children_spec\\22: []}]}, {\\22type\\22: \\22builtins.dict\\22, \\22context\\22: \\22[]\\22, \\22children_spec\\22: []}]}]\", torch.return_schema = \"[1, {\\22type\\22: null, \\22context\\22: null, \\22children_spec\\22: []}]\"} {\n", - " %0 = call @compute_add_one(%arg0) : (tensor) -> tensor\n", - " return %0 : tensor\n", + " func.func @reduce_sum_2d(%arg0: !torch.vtensor<[?,3],si32>) -> !torch.vtensor<[3],si64> attributes {torch.assume_strict_symbolic_shapes} {\n", + " %int0 = torch.constant.int 0\n", + " %0 = torch.prim.ListConstruct %int0 : (!torch.int) -> !torch.list\n", + " %false = torch.constant.bool false\n", + " %none = torch.constant.none\n", + " %1 = torch.aten.sum.dim_IntList %arg0, %0, %false, %none : !torch.vtensor<[?,3],si32>, !torch.list, !torch.bool, !torch.none -> !torch.vtensor<[3],si64>\n", + " return %1 : !torch.vtensor<[3],si64>\n", " }\n", - " func.func private @compute_add_one(%arg0: tensor) -> tensor {\n", - " %c0 = arith.constant 0 : index\n", - " %c1_i32 = arith.constant 1 : i32\n", - " %dim = tensor.dim %arg0, %c0 : tensor\n", - " %0 = tensor.empty(%dim) : tensor\n", - " %1 = linalg.generic {indexing_maps = [#map, #map], iterator_types = [\"parallel\"]} ins(%arg0 : tensor) outs(%0 : tensor) {\n", - " ^bb0(%in: i32, %out: i32):\n", - " %2 = arith.addi %in, %c1_i32 : i32\n", - " linalg.yield %2 : i32\n", - " } -> tensor\n", - " return %1 : tensor\n", + " func.func @add_one(%arg0: !torch.vtensor<[?],si32>) -> !torch.vtensor<[?],si32> attributes {torch.assume_strict_symbolic_shapes} {\n", + " %int1 = torch.constant.int 1\n", + " %int1_0 = torch.constant.int 1\n", + " %0 = torch.aten.add.Scalar %arg0, %int1, %int1_0 : !torch.vtensor<[?],si32>, !torch.int, !torch.int -> !torch.vtensor<[?],si32>\n", + " return %0 : !torch.vtensor<[?],si32>\n", " }\n", "}\n" ] @@ -377,25 +376,22 @@ { "cell_type": "code", "source": [ - "# Export and compile.\n", - "exported_output = aot.export(DynamicShapesModule)\n", - "\n", "# Compile to a file on disk for usage outside of Python.\n", "flatbuffer_path = os.path.join(ARTIFACTS_DIR, \"dynamic_shapes_cpu.vmfb\")\n", - "exported_output.compile(save_to=flatbuffer_path)\n", + "export_output.compile(save_to=flatbuffer_path)\n", "print(f\"Wrote compiled program to path '{flatbuffer_path}'\")\n", "\n", "# Compile into memory for testing.\n", - "binary = exported_output.compile(save_to=None)" + "binary = export_output.compile(save_to=None)" ], "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "0PGyH1tvI_Ic", - "outputId": "23b53928-4d77-461f-e4b8-b2c8ffb25ef0" + "outputId": "2ac3f280-1834-4d6c-f5b0-c9b470549ca7" }, - "execution_count": 7, + "execution_count": 9, "outputs": [ { "output_type": "stream", @@ -429,9 +425,9 @@ "base_uri": "https://localhost:8080/" }, "id": "9ilJY15BI_LD", - "outputId": "57db6e52-83f1-4283-fc08-31e743cc9b42" + "outputId": "f20aec4f-353e-4793-f9f1-066006d4471b" }, - "execution_count": 8, + "execution_count": 10, "outputs": [ { "output_type": "stream", @@ -476,9 +472,9 @@ "height": 86 }, "id": "dgaXpdiWuGtx", - "outputId": "dc0fbca1-c5b0-44f9-e1ff-9bf1307c049f" + "outputId": "94823b69-1095-4a97-9974-7d36fb3e2fb8" }, - "execution_count": 9, + "execution_count": 11, "outputs": [ { "output_type": "stream", @@ -486,7 +482,7 @@ "text": [ "Zipping '/tmp/iree/colab_artifacts' to '/tmp/dynamic_shapes_colab_artifacts.zip' for download...\n", " adding: dynamic_shapes_cpu.vmfb (deflated 66%)\n", - " adding: dynamic_shapes.mlir (deflated 82%)\n", + " adding: dynamic_shapes.mlir (deflated 72%)\n", "Downloading the artifacts zip file...\n" ] }, @@ -549,7 +545,7 @@ "" ], "application/javascript": [ - "download(\"download_e2630f9b-e811-4164-b2d8-80cf52f17145\", \"dynamic_shapes_colab_artifacts.zip\", 5699)" + "download(\"download_7377c999-5cd8-4987-95c4-921d56969f65\", \"dynamic_shapes_colab_artifacts.zip\", 5472)" ] }, "metadata": {} @@ -557,4 +553,4 @@ ] } ] -} +} \ No newline at end of file