diff --git a/.github/workflows/ci-mlir.yml b/.github/workflows/ci-mlir.yml index ff7d69e298..6d4b1bb0d7 100644 --- a/.github/workflows/ci-mlir.yml +++ b/.github/workflows/ci-mlir.yml @@ -109,13 +109,6 @@ jobs: export PATH=$PATH:${GITHUB_WORKSPACE}/llvm-project/build/bin/ uv run pytest --nbval docs/mlir_interoperation.ipynb --maxfail 1 -vv - - name: Test ONNX-dependent marimo notebooks - run: | - cd xdsl - # Add mlir-opt to the path - export PATH=$PATH:${GITHUB_WORKSPACE}/llvm-project/build/bin/ - uv run make tests-marimo-onnx - - name: Combine coverage data run: | cd xdsl diff --git a/Makefile b/Makefile index 80089b02ee..026a193278 100644 --- a/Makefile +++ b/Makefile @@ -98,30 +98,10 @@ tests-marimo: uv-installed done @echo "All marimo tests passed successfully." -.PHONY: tests-marimo-onnx -tests-marimo-onnx: uv-installed - @if uv run python -c "import onnx" > /dev/null 2>&1; then \ - echo "onnx is installed, running tests."; \ - if ! command -v mlir-opt > /dev/null 2>&1; then \ - echo "MLIR is not installed, skipping tests."; \ - exit 0; \ - fi; \ - for file in docs/marimo/onnx/*.py; do \ - echo "Running $$file"; \ - error_message=$$(uv run python3 "$$file" 2>&1) || { \ - echo "Error running $$file"; \ - echo "$$error_message"; \ - exit 1; \ - }; \ - done; \ - echo "All marimo onnx tests passed successfully."; \ - else \ - echo "onnx is not installed, skipping tests."; \ - fi # run all tests .PHONY: tests-functional -tests-functional: pytest tests-toy filecheck pytest-nb tests-marimo tests-marimo-onnx +tests-functional: pytest tests-toy filecheck pytest-nb tests-marimo @echo All functional tests done. # run all tests diff --git a/README.md b/README.md index 7ccf4e4e54..9886b3a809 100644 --- a/README.md +++ b/README.md @@ -53,7 +53,7 @@ In order to keep the set of dependencies ot a minimum, these extra dependencies specified explicitly. To install these, use: ``` bash -pip install xdsl[gui,jax,riscv,onnx] +pip install xdsl[gui,jax,riscv] ``` To install the testing/development dependencies, use: diff --git a/docs/marimo/__marimo__/linalg_snitch.ipynb b/docs/marimo/__marimo__/linalg_snitch.ipynb index 255bf55707..b3490c91e6 100644 --- a/docs/marimo/__marimo__/linalg_snitch.ipynb +++ b/docs/marimo/__marimo__/linalg_snitch.ipynb @@ -591,7 +591,7 @@ "riscv_op_counter = OpCounter()\n", "riscv_interpreter = Interpreter(riscv_module, listeners=(riscv_op_counter,))\n", "\n", - "register_implementations(riscv_interpreter, riscv_ctx, include_onnx=False)\n", + "register_implementations(riscv_interpreter, riscv_ctx)\n", "\n", "riscv_interpreter.call_op(\"matmul\", (a_shaped.data_ptr.raw, b_shaped.data_ptr.raw, riscv_c_shaped.data_ptr.raw))\n", "\n", @@ -633,7 +633,7 @@ "\n", "snitch_c_shaped = ShapedArray(TypedPtr.new_float64([0.0] * c_len), c_shape)\n", "\n", - "register_implementations(snitch_interpreter, snitch_stream_ctx, include_onnx=False)\n", + "register_implementations(snitch_interpreter, snitch_stream_ctx)\n", "\n", "snitch_interpreter.call_op(\n", " \"matmul\",\n", diff --git a/docs/marimo/linalg_snitch.py b/docs/marimo/linalg_snitch.py index f1120277da..18d58e2f15 100644 --- a/docs/marimo/linalg_snitch.py +++ b/docs/marimo/linalg_snitch.py @@ -513,7 +513,7 @@ def _(TypedPtr, a_shape, b_shape, c_shape, mo, riscv_ctx, riscv_module): riscv_op_counter = OpCounter() riscv_interpreter = Interpreter(riscv_module, listeners=(riscv_op_counter,)) - register_implementations(riscv_interpreter, riscv_ctx, include_onnx=False) + register_implementations(riscv_interpreter, riscv_ctx) riscv_interpreter.call_op("matmul", (a_shaped.data_ptr.raw, b_shaped.data_ptr.raw, riscv_c_shaped.data_ptr.raw)) @@ -566,7 +566,7 @@ def _( snitch_c_shaped = ShapedArray(TypedPtr.new_float64([0.0] * c_len), c_shape) - register_implementations(snitch_interpreter, snitch_stream_ctx, include_onnx=False) + register_implementations(snitch_interpreter, snitch_stream_ctx) snitch_interpreter.call_op( "matmul", diff --git a/docs/marimo/onnx/README.md b/docs/marimo/onnx/README.md deleted file mode 100644 index 648b1d0dc4..0000000000 --- a/docs/marimo/onnx/README.md +++ /dev/null @@ -1,4 +0,0 @@ -# Marimo notebooks that depend on ONNX and mlir-opt - -For these notebooks to run as intended, the `onnx` package needs to be installed and `mlir-opt` needs to be in the path. -Please see the [MLIR Interoperation](../../mlir_interoperation.md) document for more information. diff --git a/docs/marimo/onnx/__marimo__/onnx_demo.ipynb b/docs/marimo/onnx/__marimo__/onnx_demo.ipynb deleted file mode 100644 index fc4a9e1416..0000000000 --- a/docs/marimo/onnx/__marimo__/onnx_demo.ipynb +++ /dev/null @@ -1,421 +0,0 @@ -{ - "cells": [ - { - "cell_type": "markdown", - "id": "Hbol", - "metadata": { - "marimo": { - "config": { - "hide_code": true - } - } - }, - "source": [ - "# ONNX to Snitch\n", - "\n", - "This notebook uses Marimo, a Jupyter-like notebook with interactive UI elements and reactive state." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "MJUe", - "metadata": { - "marimo": { - "config": { - "hide_code": true - } - } - }, - "outputs": [ - { - "data": { - "text/html": [ - "For example, here is a slider, which can take on values from 1 to 4.\n", - "" - ] - }, - "metadata": {}, - "output_type": "display_data" - } - ], - "source": [ - "rank = mo.ui.slider(1, 4, value=2, label=\"Rank\")\n", - "\n", - "mo.md(\n", - " f\"\"\"\n", - " For example, here is a slider, which can take on values from 1 to 4.\n", - "\n", - " {rank}\n", - " \"\"\"\n", - ")" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "vblA", - "metadata": { - "marimo": { - "config": { - "hide_code": true - } - } - }, - "outputs": [ - { - "data": { - "text/html": [ - "We use the slider to determine the shape of our inputs and outputs:\n", - "
A: 2x3xf64\n",
-       "B: 2x3xf64\n",
-       "C: 2x3xf64\n",
-       "
" - ] - }, - "metadata": {}, - "output_type": "display_data" - } - ], - "source": [ - "shape = tuple(range(2, 2 + rank.value))\n", - "\n", - "mo.md(\n", - " f\"\"\"\n", - " We use the slider to determine the shape of our inputs and outputs:\n", - "\n", - " ```\n", - " A: {'x'.join(str(dim) for dim in shape)}xf64\n", - " B: {'x'.join(str(dim) for dim in shape)}xf64\n", - " C: {'x'.join(str(dim) for dim in shape)}xf64\n", - " ```\n", - " \"\"\"\n", - ")" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "bkHC", - "metadata": { - "marimo": { - "config": { - "hide_code": true - } - } - }, - "outputs": [ - { - "data": { - "text/html": [ - "

The ONNX model

\n", - "We use the ONNX API to build a simple function, one that returns the elementwise sum of two arrays of shape (2, 3)
" - ] - }, - "metadata": {}, - "output_type": "display_data" - } - ], - "source": [ - "mo.md(\n", - " f\"\"\"\n", - " ### The ONNX model\n", - "\n", - " We use the ONNX API to build a simple function, one that returns the elementwise sum of two arrays of shape {shape}\n", - " \"\"\"\n", - ")" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "lEQa", - "metadata": { - "marimo": { - "config": { - "hide_code": true - } - } - }, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "
Traceback (most recent call last):\n",
-      "  File "/Users/sasha/Developer/xdslproject/xdsl/.venv/lib/python3.12/site-packages/marimo/_runtime/executor.py", line 141, in execute_cell\n",
-      "    exec(cell.body, glbls)\n",
-      "  File "/var/folders/84/ql679qw90tdc6pkg78v59jl40000gn/T/marimo_84608/__marimo__cell_lEQa_.py", line 1, in <module>\n",
-      "    import onnx\n",
-      "ModuleNotFoundError: No module named 'onnx'\n",
-      "
\n", - "
" - ] - } - ], - "source": [ - "import onnx\n", - "from onnx import AttributeProto, GraphProto, TensorProto, ValueInfoProto, helper" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "PKri", - "metadata": {}, - "outputs": [], - "source": [ - "# Create one input (ValueInfoProto)\n", - "X1 = helper.make_tensor_value_info(\"X1\", TensorProto.DOUBLE, shape)\n", - "X2 = helper.make_tensor_value_info(\"X2\", TensorProto.DOUBLE, shape)\n", - "\n", - "# Create one output (ValueInfoProto)\n", - "Y = helper.make_tensor_value_info(\"Y\", TensorProto.DOUBLE, shape)\n", - "\n", - "# Create a node (NodeProto) - This is based on Pad-11\n", - "node_def = helper.make_node(\n", - " \"Sub\", # node name\n", - " [\"X1\", \"X2\"], # inputs\n", - " [\"Y\"], # outputs\n", - ")\n", - "\n", - "# Create the graph (GraphProto)\n", - "graph_def = helper.make_graph(\n", - " [node_def],\n", - " \"main_graph\",\n", - " [X1, X2],\n", - " [Y],\n", - ")\n", - "\n", - "# Set opset version to 18\n", - "opset_import = [helper.make_operatorsetid(\"\", 18)]\n", - "\n", - "# Create the model (ModelProto) without using helper.make_model\n", - "model_def = helper.make_model(\n", - " graph_def, producer_name=\"onnx-example\", opset_imports=opset_import\n", - ")\n", - "\n", - "onnx.checker.check_model(model_def)" - ] - }, - { - "cell_type": "markdown", - "id": "Xref", - "metadata": { - "marimo": { - "config": { - "hide_code": true - } - } - }, - "source": [ - "ONNX uses a serialized binary format for neural networks, but can also print a string format, which can be useful for debugging.\n", - "Here is the textual format of our model:" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "SFPL", - "metadata": { - "marimo": { - "config": { - "hide_code": true - } - } - }, - "outputs": [], - "source": [ - "mo.accordion(\n", - " {\n", - " \"ONNX Graph\": mo.plain_text(f\"{model_def}\"),\n", - " }\n", - ")" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "BYtC", - "metadata": {}, - "outputs": [], - "source": [ - "mo.md(f\"\"\"\n", - "### Converting to `linalg`\n", - "\n", - "Here is the xDSL representation of the function, it takes two `tensor` values of our chosen shape, passes them as operands to the `onnx.Add` operation, and returns it:\n", - "\n", - "{xmo.module_html(init_module)}\n", - "\"\"\"\n", - ")" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "RGSE", - "metadata": {}, - "outputs": [], - "source": [ - "init_module = build_module(model_def.graph)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "Kclp", - "metadata": { - "marimo": { - "config": { - "hide_code": true - } - } - }, - "outputs": [], - "source": [ - "ctx = MLContext()\n", - "\n", - "for dialect_name, dialect_factory in get_all_dialects().items():\n", - " ctx.register_dialect(dialect_name, dialect_factory)" - ] - }, - { - "cell_type": "markdown", - "id": "emfo", - "metadata": {}, - "source": [ - "xDSL seamlessly interoperates with MLIR, we the `mlir-opt` tool to compile the input to a form that we want to process:" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "Hstk", - "metadata": {}, - "outputs": [], - "source": [ - "bufferized_ctx, bufferized_module, linalg_html = xmo.pipeline_html(\n", - " ctx,\n", - " init_module,\n", - " (\n", - " (\n", - " mo.md(\n", - " \"\"\"\\\n", - "We can use a pass implemented in xDSL to convert the ONNX operations to builtin operations, here we can use the `tensor.empty` op to create our output buffer, and `linalg.add` to represent the addition in destination-passing style:\n", - "\"\"\"\n", - " ),\n", - " ConvertOnnxToLinalgPass()\n", - " ),\n", - " (\n", - " mo.md(\n", - " \"\"\"\n", - "We can also call into MLIR, here to convert `linalg.add` to `linalg.generic`, a representation of Einstein summation:\n", - "\"\"\"\n", - " ),\n", - " MLIROptPass(\n", - " generic=False,\n", - " arguments=[\"--linalg-generalize-named-ops\"]\n", - " )\n", - " ),\n", - " (\n", - " mo.md(\n", - " \"\"\"We prepare the result tensors for bufferization:\"\"\"\n", - " ),\n", - " EmptyTensorToAllocTensorPass()\n", - " ),\n", - " (\n", - " mo.md(\n", - " \"\"\"We then use MLIR to bufferize our function:\"\"\"\n", - " ),\n", - " MLIROptPass(\n", - " arguments=[\n", - " \"--one-shot-bufferize=bufferize-function-boundaries function-boundary-type-conversion=identity-layout-map\",\n", - " ]\n", - " )\n", - " )\n", - " )\n", - ")\n", - "\n", - "linalg_html" - ] - }, - { - "cell_type": "markdown", - "id": "nWHF", - "metadata": {}, - "source": [ - "From here we can use a number of backends to generate executable code, like LLVM, or RISC-V assembly directly.\n", - "Please see other notebooks for details" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "iLit", - "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "
Traceback (most recent call last):\n",
-      "  File "/Users/sasha/Developer/xdslproject/xdsl/.venv/lib/python3.12/site-packages/marimo/_runtime/executor.py", line 141, in execute_cell\n",
-      "    exec(cell.body, glbls)\n",
-      "  File "/var/folders/84/ql679qw90tdc6pkg78v59jl40000gn/T/marimo_84608/__marimo__cell_iLit_.py", line 2, in <module>\n",
-      "    from xdsl.frontend.onnx.ir_builder import build_module\n",
-      "  File "/Users/sasha/Developer/xdslproject/xdsl/xdsl/frontend/onnx/ir_builder.py", line 1, in <module>\n",
-      "    from onnx import GraphProto, NodeProto, ValueInfoProto\n",
-      "ModuleNotFoundError: No module named 'onnx'\n",
-      "
\n", - "
" - ] - } - ], - "source": [ - "from xdsl.context import MLContext\n", - "from xdsl.frontend.onnx.ir_builder import build_module\n", - "from xdsl.ir import Attribute, SSAValue\n", - "from xdsl.passes import PipelinePass\n", - "from xdsl.tools.command_line_tool import get_all_dialects\n", - "from xdsl.transforms.convert_onnx_to_linalg import ConvertOnnxToLinalgPass\n", - "from xdsl.transforms.empty_tensor_to_alloc_tensor import EmptyTensorToAllocTensorPass\n", - "from xdsl.transforms.mlir_opt import MLIROptPass" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "ZHCJ", - "metadata": { - "marimo": { - "config": { - "hide_code": true - } - } - }, - "outputs": [], - "source": [ - "import marimo as mo" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "ROlb", - "metadata": { - "marimo": { - "config": { - "hide_code": true - } - } - }, - "outputs": [], - "source": [ - "import xdsl.utils.marimo as xmo" - ] - } - ], - "metadata": {}, - "nbformat": 4, - "nbformat_minor": 5 -} diff --git a/docs/marimo/onnx/onnx_demo.py b/docs/marimo/onnx/onnx_demo.py deleted file mode 100644 index c9b3ac6a77..0000000000 --- a/docs/marimo/onnx/onnx_demo.py +++ /dev/null @@ -1,271 +0,0 @@ -import marimo - -__generated_with = "0.10.9" -app = marimo.App(auto_download=["ipynb"]) - - -@app.cell(hide_code=True) -def _(mo): - mo.md( - """ - # ONNX to Snitch - - This notebook uses Marimo, a Jupyter-like notebook with interactive UI elements and reactive state. - """ - ) - return - - -@app.cell(hide_code=True) -def _(mo): - rank = mo.ui.slider(1, 4, value=2, label="Rank") - - mo.md( - f""" - For example, here is a slider, which can take on values from 1 to 4. - - {rank} - """ - ) - return (rank,) - - -@app.cell(hide_code=True) -def _(mo, rank): - shape = tuple(range(2, 2 + rank.value)) - - mo.md( - f""" - We use the slider to determine the shape of our inputs and outputs: - - ``` - A: {'x'.join(str(dim) for dim in shape)}xf64 - B: {'x'.join(str(dim) for dim in shape)}xf64 - C: {'x'.join(str(dim) for dim in shape)}xf64 - ``` - """ - ) - return (shape,) - - -@app.cell(hide_code=True) -def _(mo, shape): - mo.md( - f""" - ### The ONNX model - - We use the ONNX API to build a simple function, one that returns the elementwise sum of two arrays of shape {shape} - """ - ) - return - - -@app.cell(hide_code=True) -def _(): - import onnx - from onnx import AttributeProto, GraphProto, TensorProto, ValueInfoProto, helper - return ( - AttributeProto, - GraphProto, - TensorProto, - ValueInfoProto, - helper, - onnx, - ) - - -@app.cell -def _(TensorProto, helper, onnx, shape): - # Create one input (ValueInfoProto) - X1 = helper.make_tensor_value_info("X1", TensorProto.DOUBLE, shape) - X2 = helper.make_tensor_value_info("X2", TensorProto.DOUBLE, shape) - - # Create one output (ValueInfoProto) - Y = helper.make_tensor_value_info("Y", TensorProto.DOUBLE, shape) - - # Create a node (NodeProto) - This is based on Pad-11 - node_def = helper.make_node( - "Sub", # node name - ["X1", "X2"], # inputs - ["Y"], # outputs - ) - - # Create the graph (GraphProto) - graph_def = helper.make_graph( - [node_def], - "main_graph", - [X1, X2], - [Y], - ) - - # Set opset version to 18 - opset_import = [helper.make_operatorsetid("", 18)] - - # Create the model (ModelProto) without using helper.make_model - model_def = helper.make_model( - graph_def, producer_name="onnx-example", opset_imports=opset_import - ) - - onnx.checker.check_model(model_def) - return X1, X2, Y, graph_def, model_def, node_def, opset_import - - -@app.cell(hide_code=True) -def _(mo): - mo.md( - """ - ONNX uses a serialized binary format for neural networks, but can also print a string format, which can be useful for debugging. - Here is the textual format of our model: - """ - ) - return - - -@app.cell(hide_code=True) -def _(mo, model_def): - mo.accordion( - { - "ONNX Graph": mo.plain_text(f"{model_def}"), - } - ) - return - - -@app.cell -def _(init_module, mo, xmo): - mo.md(f""" - ### Converting to `linalg` - - Here is the xDSL representation of the function, it takes two `tensor` values of our chosen shape, passes them as operands to the `onnx.Add` operation, and returns it: - - {xmo.module_html(init_module)} - """ - ) - return - - -@app.cell -def _(build_module, model_def): - init_module = build_module(model_def.graph) - return (init_module,) - - -@app.cell(hide_code=True) -def _(MLContext, get_all_dialects): - ctx = MLContext() - - for dialect_name, dialect_factory in get_all_dialects().items(): - ctx.register_dialect(dialect_name, dialect_factory) - return ctx, dialect_factory, dialect_name - - -@app.cell -def _(mo): - mo.md("""xDSL seamlessly interoperates with MLIR, we the `mlir-opt` tool to compile the input to a form that we want to process:""") - return - - -@app.cell -def _( - ConvertOnnxToLinalgPass, - EmptyTensorToAllocTensorPass, - MLIROptPass, - ctx, - init_module, - mo, - xmo, -): - bufferized_ctx, bufferized_module, linalg_html = xmo.pipeline_html( - ctx, - init_module, - ( - ( - mo.md( - """\ - We can use a pass implemented in xDSL to convert the ONNX operations to builtin operations, here we can use the `tensor.empty` op to create our output buffer, and `linalg.add` to represent the addition in destination-passing style: - """ - ), - ConvertOnnxToLinalgPass() - ), - ( - mo.md( - """ - We can also call into MLIR, here to convert `linalg.add` to `linalg.generic`, a representation of Einstein summation: - """ - ), - MLIROptPass( - generic=False, - arguments=["--linalg-generalize-named-ops"] - ) - ), - ( - mo.md( - """We prepare the result tensors for bufferization:""" - ), - EmptyTensorToAllocTensorPass() - ), - ( - mo.md( - """We then use MLIR to bufferize our function:""" - ), - MLIROptPass( - arguments=[ - "--one-shot-bufferize=bufferize-function-boundaries function-boundary-type-conversion=identity-layout-map", - ] - ) - ) - ) - ) - - linalg_html - return bufferized_ctx, bufferized_module, linalg_html - - -@app.cell -def _(mo): - mo.md( - """ - From here we can use a number of backends to generate executable code, like LLVM, or RISC-V assembly directly. - Please see other notebooks for details - """ - ) - return - - -@app.cell -def _(): - from xdsl.context import MLContext - from xdsl.frontend.onnx.ir_builder import build_module - from xdsl.ir import Attribute, SSAValue - from xdsl.passes import PipelinePass - from xdsl.tools.command_line_tool import get_all_dialects - from xdsl.transforms.convert_onnx_to_linalg import ConvertOnnxToLinalgPass - from xdsl.transforms.empty_tensor_to_alloc_tensor import EmptyTensorToAllocTensorPass - from xdsl.transforms.mlir_opt import MLIROptPass - return ( - Attribute, - ConvertOnnxToLinalgPass, - EmptyTensorToAllocTensorPass, - MLContext, - MLIROptPass, - PipelinePass, - SSAValue, - build_module, - get_all_dialects, - ) - - -@app.cell(hide_code=True) -def _(): - import marimo as mo - return (mo,) - - -@app.cell(hide_code=True) -def _(): - import xdsl.utils.marimo as xmo - return (xmo,) - - -if __name__ == "__main__": - app.run() diff --git a/pyproject.toml b/pyproject.toml index 7b2deed320..0384b532ec 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -35,7 +35,6 @@ dev = [ ] gui = ["textual==1.0.0", "pyclip==0.7"] jax = ["jax==0.4.38", "numpy==2.2.1"] -onnx = ["onnx==1.17.0", "numpy==2.2.1"] riscv = ["riscemu==2.2.7"] [project.urls] @@ -72,12 +71,8 @@ extraPaths = ["tests"] "include" = ["docs", "xdsl", "tests"] "exclude" = [ "tests/dialects/test_memref.py", - "tests/frontend/onnx/test_build_onnx_ir.py", - "tests/frontend/onnx/test_type.py", "tests/test_frontend_op_resolver.py", "tests/test_frontend_python_code_check.py", - "xdsl/frontend/onnx/ir_builder.py", - "xdsl/frontend/onnx/type.py", ] "ignore" = [ "docs/marimo", diff --git a/tests/dialects/onnx/test_broadcast.py b/tests/dialects/onnx/test_broadcast.py deleted file mode 100644 index c6e49b0c2a..0000000000 --- a/tests/dialects/onnx/test_broadcast.py +++ /dev/null @@ -1,50 +0,0 @@ -import pytest - -from xdsl.dialects.onnx import ( - multidirectional_broadcast_shape, - unidirectional_broadcast_shape, -) - -m_test_cases: list[tuple[tuple[list[int], list[int]], list[int] | None]] = [ - (([2, 3, 4, 5], []), [2, 3, 4, 5]), - (([2, 3, 4, 5], [5]), [2, 3, 4, 5]), - (([4, 5], [2, 3, 4, 5]), [2, 3, 4, 5]), - (([1, 4, 5], [2, 3, 1, 5]), [2, 3, 4, 5]), - (([3, 4, 5], [2, 1, 1, 1]), [2, 3, 4, 5]), - (([2, 3, 4, 5], [6, 1]), None), - (([1, 3, 4, 1], [5]), None), - (([4, 5, 6], [2, 3, 4, 5]), None), - (([1, 4, 5], [2, 3, 1, 5]), None), - (([3, 4, 5], [2, 1, 1, 1]), None), -] - -u_test_cases: list[tuple[tuple[list[int], list[int]], list[int] | None]] = [ - (([2, 3, 4, 5], []), [2, 3, 4, 5]), - (([2, 3, 4, 5], [5]), [2, 3, 4, 5]), - (([2, 3, 4, 5], [2, 1, 1, 5]), [2, 3, 4, 5]), - (([2, 3, 4, 5], [1, 3, 1, 5]), [2, 3, 4, 5]), - (([], [2, 3, 4, 5]), None), - (([5], [2, 3, 4, 5]), None), - (([2, 1, 1, 5], [2, 3, 4, 5]), None), - (([2, 3, 5], [1, 3, 1, 5]), None), -] - - -# Multidirectional Broadcasting Tests -@pytest.mark.parametrize("input_shapes, expected_result", m_test_cases) -def multi_test_broadcast_shape( - input_shapes: tuple[list[int], list[int]], expected_result: list[int] | None -): - lhs, rhs = input_shapes - result = multidirectional_broadcast_shape(lhs, rhs) - assert result == expected_result - - -# Unidirectional Broadcasting Tests -@pytest.mark.parametrize("input_shapes, expected_result", u_test_cases) -def uni_test_broadcast_shape( - input_shapes: tuple[list[int], list[int]], expected_result: list[int] | None -): - lhs, rhs = input_shapes - result = unidirectional_broadcast_shape(lhs, rhs) - assert result == expected_result diff --git a/tests/filecheck/dialects/onnx/onnx_invalid.mlir b/tests/filecheck/dialects/onnx/onnx_invalid.mlir deleted file mode 100644 index b15ba25761..0000000000 --- a/tests/filecheck/dialects/onnx/onnx_invalid.mlir +++ /dev/null @@ -1,571 +0,0 @@ -// RUN: xdsl-opt --verify-diagnostics --split-input-file %s | filecheck %s - -// Non-broadcastable operands are not allowed. - -builtin.module { - %t0, %t1 = "test.op"() : () -> (tensor<2x4xf32>, tensor<3x2xf32>) - - // CHECK: Operation does not verify: operands have incompatible shapes: (2, 4) and (3, 2) - %res_add = "onnx.Add"(%t0, %t1) {onnx_node_name = "/Add"} : (tensor<2x4xf32>, tensor<3x2xf32>) -> tensor<2x4xf32> -} - -// ----- - -builtin.module { - %t0, %t1 = "test.op"() : () -> (tensor<2x4xf32>, tensor<1x4xf32>) - - // CHECK: Operation does not verify: result shape [2, 4] does not match result type tensor<1x4xf32> - %res_sub = "onnx.Sub"(%t0, %t1) {onnx_node_name = "/Sub"} : (tensor<2x4xf32>, tensor<1x4xf32>) -> tensor<1x4xf32> -} - -// ----- - -builtin.module { - %t0, %t1 = "test.op"() : () -> (f32, tensor<2x4xf32>) - - // CHECK: operand at position 0 does not verify: - // CHECK: f32 should be of base attribute tensor - %res_mul = "onnx.Mul"(%t0, %t1) {onnx_node_name = "/Mul"} : (f32, tensor<2x4xf32>) -> tensor<2x4xf32> -} - -// ----- - -builtin.module { - %t0, %t1 = "test.op"() : () -> (tensor<2x4xf32>, tensor<2x4xi32>) - - // CHECK: operand at position 1 does not verify: - // CHECK: attribute f32 expected from variable 'T', but got i32 - %res_div = "onnx.Div"(%t0, %t1) {onnx_node_name = "/Div"} : (tensor<2x4xf32>, tensor<2x4xi32>) -> tensor<2x4xf32> -} - -// ----- - -builtin.module { - %t0 = "test.op"() : () -> (tensor<2x4xf32>) - - // CHECK: operand at position 1 does not verify: - // CHECK: Operation does not verify: Mismatch between operand type and res type of onnx.Relu - %res_relu = "onnx.Relu"(%t0) {onnx_node_name = "/Relu"} : (tensor<2x4xf32>) -> tensor<3x4xf32> -} - -// ----- - -builtin.module { - %t0, %t1, %t2 = "test.op"() : () -> (tensor<2x4xf32>, tensor<3x2xf32>, tensor<3x2xf32>) - - // CHECK: Operation does not verify: operands have incompatible shapes: (2, 4) and (3, 2) - %res_gemm = "onnx.Gemm"(%t0, %t1, %t2) {onnx_node_name = "/Gemm"} : (tensor<2x4xf32>, tensor<3x2xf32>, tensor<3x2xf32>) -> tensor<2x4xf32> -} - -// ----- - -builtin.module { - %t0, %t1, %t2 = "test.op"() : () -> (f32, tensor<2x4xf32>,tensor<2x4xf32>) - - // CHECK: operand at position 0 does not verify: - // CHECK: f32 should be of base attribute tensor - %res_gemm= "onnx.Gemm"(%t0, %t1, %t2) {onnx_node_name = "/Gemm"} : (f32, tensor<2x4xf32>, tensor<2x4xf32>) -> tensor<2x4xf32> -} - -// ----- - -builtin.module { - %t0, %t1, %t2 = "test.op"(): () -> (tensor<5x2xf32>, tensor<2x1xf32>, tensor<5x4xf32>) - - //CHECK: Operation does not verify: result shape [5, 4] does not match result type tensor<5x2xf32> - %res_gemm = "onnx.Gemm"(%t0, %t1, %t2) : (tensor<5x2xf32>, tensor<2x1xf32>, tensor<5x4xf32>) -> tensor<5x2xf32> -} - -// ----- - -builtin.module { - %t0, %t1, %t2 = "test.op"() : () -> (tensor<2x4xf32>, tensor<2x4xi32>, tensor<2x4xf32>) - - // CHECK: operand at position 1 does not verify: - // CHECK: attribute f32 expected from variable 'T', but got i32 - %res_gemm = "onnx.Gemm"(%t0, %t1, %t2) {onnx_node_name = "/Gemm"} : (tensor<2x4xf32>, tensor<2x4xi32>, tensor<2x4xf32>) -> tensor<2x4xf32> - - } - - // ----- - -builtin.module { - %t0, %t1, %t2 = "test.op"() : () -> (tensor<5x3x3xf32>, tensor<3x2xf32>, tensor<5x2xf32>) - // CHECK: Operation does not verify: tensor A should be a 2D tensor - %res_gemm = "onnx.Gemm"(%t0, %t1, %t2) {onnx_node_name = "/Gemm"}: (tensor<5x3x3xf32>, tensor<3x2xf32>, tensor<5x2xf32>) -> tensor<5x2xf32> -} - -// ----- - -builtin.module { - %t0, %t1, %t2 = "test.op"() : () -> (tensor<5x3xf32>, tensor<3x2x3xf32>, tensor<5x2xf32>) - // CHECK: Operation does not verify: tensor B should be a 2D tensor - %res_gemm = "onnx.Gemm"(%t0, %t1, %t2) {onnx_node_name = "/Gemm"}: (tensor<5x3xf32>, tensor<3x2x3xf32>, tensor<5x2xf32>) -> tensor<5x2xf32> -} - -// ----- - -builtin.module { - %t0, %t1, %t2 = "test.op"() : () -> (tensor<5x3xf32>, tensor<3x2xf32>, tensor<5x2x7xf32>) - // CHECK: Operation does not verify: tensor C should be a 1D tensor or 2D tensor - %res_gemm = "onnx.Gemm"(%t0, %t1, %t2) {onnx_node_name = "/Gemm", beta = 47.0 : f32}: (tensor<5x3xf32>, tensor<3x2xf32>, tensor<5x2x7xf32>) -> tensor<5x3xf32> -} - -// ----- - -builtin.module { - %t0, %t1 = "test.op"() : () -> (f32, tensor<2x4xi64>) - - // CHECK: operand at position 0 does not verify: - // CHECK: f32 should be of base attribute tensor - %res_reshape = "onnx.Reshape"(%t0, %t1) {onnx_node_name = "/Reshape"} : (f32, tensor<2x4xi64>) -> tensor<2x4xi64> -} - -// ----- - -builtin.module { - %t0, %t1 = "test.op"() : () -> (tensor<4x3x2xf32>, tensor<1xi64>) - - // CHECK: result at position 0 does not verify: - // CHECK: attribute f32 expected from variable 'T', but got i32 - %res_reshape = "onnx.Reshape"(%t0, %t1) {"onnx_node_name" = "/Reshape"} : (tensor<4x3x2xf32>, tensor<1xi64>) -> tensor<4x3x2xi32> -} - -// ----- - -builtin.module { - %t0, %t1 = "test.op"() : () -> (tensor<6x9x5xf32>, tensor<3xi64>) - - // CHECK: result at position 0 does not verify: - // CHECK: Operation does not verify: Input tensor's shape and output tensor's shape must have the same number of elements - %res_reshape = "onnx.Reshape"(%t0, %t1) {"onnx_node_name" = "/Reshape"} : (tensor<6x9x5xf32>, tensor<3xi64>) -> tensor<6x9xf32> -} - -// ----- - -builtin.module { - %t0, %t1 = "test.op"() : () -> (tensor<6x9x5xf32>, tensor<3xi32>) - - // CHECK: Operation does not verify: shape element type has to be a 64-bit signless integer - %res_reshape = "onnx.Reshape"(%t0, %t1) {"onnx_node_name" = "/Reshape"} : (tensor<6x9x5xf32>, tensor<3xi32>) -> tensor<6x9xf32> -} - -// ----- - -builtin.module { - %t0, %t1 = "test.op"() : () -> (tensor<6x9x5xf32>, tensor<3xi64>) - - // CHECK: result at position 0 does not verify: - // CHECK: vector<6x9x5xf32> should be of base attribute tensor - %res_reshape = "onnx.Reshape"(%t0, %t1) {"onnx_node_name" = "/Reshape"} : (tensor<6x9x5xf32>, tensor<3xi64>) -> vector<6x9x5xf32> -} - -// ----- - -builtin.module { - %t0, %t1 = "test.op"() : () -> (vector<6x9x5xf32>, tensor<3xi64>) - - // CHECK: operand at position 0 does not verify: - // CHECK: vector<6x9x5xf32> should be of base attribute tensor - %res_reshape = "onnx.Reshape"(%t0, %t1) {"onnx_node_name" = "/Reshape"} : (vector<6x9x5xf32>, tensor<3xi64>) -> tensor<6x9x5xf32> -} - -// ----- - -builtin.module { - %t0, %t1 = "test.op"() : () -> (tensor<5x5xf32>, tensor<2x2xi64>) - - //CHECK: Operation does not verify: Shape tensor must have a rank one - %res_reshape = "onnx.Reshape"(%t0, %t1) {onnx_node_name = "/Reshape"}: (tensor<5x5xf32>, tensor<2x2xi64>) -> tensor<5x5xf32> - -} - -// ----- - -builtin.module { - %t0, %t1 = "test.op"() : () -> (tensor<3x3xf32>, tensor) - - //CHECK: Operation does not verify: Shape tensor rank must not be equal to -1 - %res_reshape = "onnx.Reshape"(%t0, %t1) {onnx_node_name = "/Reshape"}: (tensor<3x3xf32>, tensor) -> tensor<3x3xf32> - -} - -// ----- - -builtin.module { - %t0 = "test.op"() : () -> (tensor<3x3xf32>) - - //CHECK: Operation does not verify: Mismatch between operand type and res type of onnx.Abs - %res_abs = "onnx.Abs"(%t0) {onnx_node_name = "/Abs"}: (tensor<3x3xf32>) -> tensor<2x3xf32> - -} - - -// ----- - -builtin.module { - %t0,%t1,%t2 = "test.op"(): () -> (f32, tensor<1x1x3x3xf32>, none) - - // CHECK: operand at position 0 does not verify: - // CHECK: f32 should be of base attribute tensor - %res_conv = "onnx.Conv"(%t0, %t1, %t2) {onnx_node_name = "/Conv"} : (f32, tensor<1x1x3x3xf32>, none) -> tensor<1x1x3x3xf32> -} - -// ----- - -builtin.module { - %t0,%t1,%t2 = "test.op"(): () -> (tensor<1x1x5x5xf32>, tensor<1x1x3x3xf32>, none) - - // CHECK: result at position 0 does not verify: - // CHECK: attribute f32 expected from variable 'T', but got i32 - %res_conv = "onnx.Conv"(%t0, %t1, %t2) {"onnx_node_name" = "/Conv"} : (tensor<1x1x5x5xf32>, tensor<1x1x3x3xf32>, none) -> tensor<1x1x3x3xi32> - -} - -// ----- - -builtin.module { - %t0,%t1,%t2 = "test.op"(): () -> (tensor<1x1x5x5xf32>, tensor<1x1x3x3xf32>, tensor<4x2xf32>) - - // CHECK: Operation does not verify: bias must be 1D - %res_conv = "onnx.Conv"(%t0, %t1, %t2) {"onnx_node_name" = "/Conv", "auto_pad" = "NOTSET", "dilations" = [1 : i64, 1: i64], "group" = 1 : i64, "kernel_shape" = [4 : i64, 4 : i64], "pads" = [0 : i64, 0: i64, 0: i64, 0: i64], "strides" = [1: i64, 1: i64]} : (tensor<1x1x5x5xf32>, tensor<1x1x3x3xf32>, tensor<4x2xf32>) -> tensor<1x1x3x3xf32> - -} - -// ----- - -builtin.module { - %t0,%t1,%t2 = "test.op"(): () -> (tensor<1x1x5x5xf32>, tensor<1x1x3x3xf32>, none) - - // CHECK: Operation does not verify: kernel shape rank and weight tensor rank are not the same - %res_conv = "onnx.Conv"(%t0, %t1, %t2) {"onnx_node_name" = "/Conv", "auto_pad" = "NOTSET", "dilations" = [1 : i64, 1: i64], "group" = 1 : i64, "kernel_shape" = [4 : i64, 4 : i64], "pads" = [0 : i64, 0: i64, 0: i64, 0: i64], "strides" = [1: i64, 1: i64]} : (tensor<1x1x5x5xf32>, tensor<1x1x3x3xf32>, none) -> tensor<1x1x3x3xf32> - -} - -// ----- - -builtin.module { - %t0,%t1,%t2 = "test.op"(): () -> (tensor<1x1x5x5xf32>, tensor<1x1x3x3xf32>, none) - - // CHECK: Operation does not verify: dilation value must be non zero positive - %res_conv = "onnx.Conv"(%t0, %t1, %t2) {"onnx_node_name" = "/Conv", "auto_pad" = "NOTSET", "dilations" = [-2 : i64, -2: i64], "group" = 1 : i64, "kernel_shape" = [3 : i64, 3 : i64], "pads" = [0 : i64, 0: i64, 0: i64, 0: i64], "strides" = [1: i64, 1: i64]} : (tensor<1x1x5x5xf32>, tensor<1x1x3x3xf32>, none) -> tensor<1x1x3x3xf32> -} - -// ----- - -builtin.module { - %t0,%t1,%t2 = "test.op"(): () -> (tensor<1x1x5x5xf32>, tensor<1x1x3x3xf32>, none) - - // CHECK: Operation does not verify: dilations rank and kernel shape rank are not the same - %res_conv = "onnx.Conv"(%t0, %t1, %t2) {"onnx_node_name" = "/Conv", "auto_pad" = "NOTSET", "dilations" = [1 : i64, 1: i64, 3: i64], "group" = 1 : i64, "kernel_shape" = [3 : i64, 3 : i64], "pads" = [0 : i64, 0: i64, 0: i64, 0: i64], "strides" = [1: i64, 1: i64]} : (tensor<1x1x5x5xf32>, tensor<1x1x3x3xf32>, none) -> tensor<1x1x3x3xf32> -} - -// ----- - -builtin.module { - %t0,%t1,%t2 = "test.op"(): () -> (tensor<1x1x5x5xf32>, tensor<1x1x3x3xf32>, none) - - // CHECK: Operation does not verify: group value must be nonnegative - %res_conv = "onnx.Conv"(%t0, %t1, %t2) {"onnx_node_name" = "/Conv", "auto_pad" = "NOTSET", "dilations" = [1 : i64, 1: i64], "group" = 0 : i64, "kernel_shape" = [3 : i64, 3 : i64], "pads" = [0 : i64, 0: i64, 0: i64, 0: i64], "strides" = [1: i64, 1: i64]} : (tensor<1x1x5x5xf32>, tensor<1x1x3x3xf32>, none) -> tensor<1x1x3x3xf32> -} - - -// ----- - -builtin.module { - %t0,%t1,%t2 = "test.op"(): () -> (tensor<1x1x5x5xf32>, tensor<1x1x3x3xf32>, none) - - // CHECK: Operation does not verify: stride value must be non zero positive - %res_conv = "onnx.Conv"(%t0, %t1, %t2) {"onnx_node_name" = "/Conv", "auto_pad" = "NOTSET", "dilations" = [1 : i64, 1: i64], "group" = 1 : i64, "kernel_shape" = [3 : i64, 3 : i64], "pads" = [0 : i64, 0: i64, 0: i64, 0: i64], "strides" = [-2 : i64, -2: i64]} : (tensor<1x1x5x5xf32>, tensor<1x1x3x3xf32>, none) -> tensor<1x1x3x3xf32> -} - -// ----- - -builtin.module { - %t0,%t1,%t2 = "test.op"(): () -> (tensor<1x1x5x5xf32>, tensor<1x1x3x3xf32>, none) - - // CHECK: Operation does not verify: strides rank and kernel shape rank are not the same - %res_conv = "onnx.Conv"(%t0, %t1, %t2) {"onnx_node_name" = "/Conv", "auto_pad" = "NOTSET", "dilations" = [1 : i64, 1: i64], "group" = 1 : i64, "kernel_shape" = [3 : i64, 3 : i64], "pads" = [0 : i64, 0: i64, 0: i64, 0: i64], "strides" = [1: i64, 1: i64, 1: i64]} : (tensor<1x1x5x5xf32>, tensor<1x1x3x3xf32>, none) -> tensor<1x1x3x3xf32> -} - -// ----- - -builtin.module { - %t0,%t1,%t2 = "test.op"(): () -> (tensor<1x1x5x5xf32>, tensor<1x1x3x3xf32>, none) - - // CHECK: Operation does not verify: pads value must be nonnegative - %res_conv = "onnx.Conv"(%t0, %t1, %t2) {"onnx_node_name" = "/Conv", "auto_pad" = "NOTSET", "dilations" = [1 : i64, 1: i64], "group" = 1 : i64, "kernel_shape" = [3 : i64, 3 : i64], "pads" = [-1 : i64, -1: i64, -1: i64, -1: i64], "strides" = [1: i64, 1: i64]} : (tensor<1x1x5x5xf32>, tensor<1x1x3x3xf32>, none) -> tensor<1x1x3x3xf32> -} - - -// ----- - -builtin.module { - %t0,%t1,%t2 = "test.op"(): () -> (tensor<1x1x5x5xf32>, tensor<1x1x3x3xf32>, none) - - // CHECK: Operation does not verify: pads rank is not twice the kernel shape rank - %res_conv = "onnx.Conv"(%t0, %t1, %t2) {"onnx_node_name" = "/Conv", "auto_pad" = "NOTSET", "dilations" = [1 : i64, 1: i64], "group" = 1 : i64, "kernel_shape" = [3 : i64, 3 : i64], "pads" = [1 : i64, 1: i64, 1: i64], "strides" = [1: i64, 1: i64]} : (tensor<1x1x5x5xf32>, tensor<1x1x3x3xf32>, none) -> tensor<1x1x3x3xf32> -} - -// ----- - -builtin.module { - %t0,%t1,%t2 = "test.op"(): () -> (tensor<1x1x5x5xf32>, tensor<1x1x3x3xf32>, none) - - // CHECK: Operation does not verify: Invalid auto_pad string. Must be one of ['NOTSET', 'SAME_UPPER', 'SAME_LOWER', 'VALID'] - %res_conv = "onnx.Conv"(%t0, %t1, %t2) {"onnx_node_name" = "/Conv", "auto_pad" = "INVALID", "dilations" = [1 : i64, 1: i64], "group" = 1 : i64, "kernel_shape" = [3 : i64, 3 : i64], "pads" = [0 : i64, 0: i64, 0: i64, 0: i64], "strides" = [1: i64, 1: i64]} : (tensor<1x1x5x5xf32>, tensor<1x1x3x3xf32>, none) -> tensor<1x1x3x3xf32> - - } - - - // ----- - -builtin.module { - - // CHECK: f32 should be of base attribute tensor - %res_constant = "onnx.Constant"() {value = dense<[3.0]> : tensor<1xf32>} : () -> f32 - - } - -// ----- - -builtin.module { - - // CHECK: Operation does not verify: value attribute type must be of type TensorType - %res_constant = "onnx.Constant"() {value = dense<[3.0]> : vector<1xf32>} : () -> tensor<1xf32> - - } - -// ----- - -builtin.module { - - // CHECK: Operation does not verify: value_int element type has to be a 64-bit signless integer - %res_constant = "onnx.Constant"() {value_int = 4 : i32} : () -> tensor<1xf32> - - } - -// ----- - -builtin.module { - - // CHECK: Operation does not verify: value_ints elements type has to be a 64-bit signless integer - %res_constant = "onnx.Constant"() {value_ints = [1: i64, 2: i32, 3: i64]} : () -> tensor<3xi32> - - } - -// ----- - -builtin.module { - - // CHECK: Operation does not verify: Only one value attribute must be provided, but 2 were specified - %res_constant = "onnx.Constant"() {value_ints = [1: i64, 1: i64], value_int = 3: i64} : () -> tensor<3xi64> - - } - -// ----- - -builtin.module { - %t0 = "test.op"(): () -> (f32) - - // CHECK: operand at position 0 does not verify: - // CHECK: Expected tensor or memref type, got f32 - %res_max_pool_single_out = "onnx.MaxPoolSingleOut"(%t0) {onnx_node_name = "/MaxPoolSingleOut"} : (f32) -> tensor<5x5x32x32xf32> -} - -// ----- - -builtin.module { - %t0= "test.op"(): () -> (tensor<5x5x32x32xf32>) - - // CHECK: result at position 0 does not verify: - // CHECK: attribute f32 expected from variable 'T', but got i32 - %res_max_pool_single_out = "onnx.MaxPoolSingleOut"(%t0) {"onnx_node_name" = "/MaxPoolSingleOut"} : (tensor<5x5x32x32xf32>) -> tensor<5x5x32x32xi32> - -} - -// ----- - -builtin.module { - %t0 = "test.op"(): () -> (tensor<5x5x32x32xf32>) - - // CHECK: Operation does not verify: Invalid auto_pad string. Must be one of ['NOTSET', 'SAME_UPPER', 'SAME_LOWER', 'VALID'] - %res_max_pool_single_out = "onnx.MaxPoolSingleOut"(%t0) {"onnx_node_name" = "/MaxPoolSingleOut", "auto_pad" = "INVALID", "ceil_mode" = 0 : i64, "kernel_shape" = [3 : i64, 3 : i64], "dilations" = [1 : i64, 1 : i64], "pads" = [0 : i64, 0 : i64, 0 : i64, 0 : i64], "storage_order" = 0 : i64, "strides" = [1 : i64, 1 : i64]} : (tensor<5x5x32x32xf32>) -> tensor<5x5x32x32xf32> - -} - -// ----- - -builtin.module { - %t0 = "test.op"(): () -> (tensor<5x5x32x32xf32>) - - // CHECK: Operation does not verify: ceil value must be either zero or one - %res_max_pool_single_out = "onnx.MaxPoolSingleOut"(%t0) {"onnx_node_name" = "/MaxPoolSingleOut", "auto_pad" = "NOTSET", "ceil_mode" = 2 : i64, "kernel_shape" = [3 : i64, 3 : i64], "dilations" = [1 : i64, 1 : i64], "pads" = [0 : i64, 0 : i64, 0 : i64, 0 : i64], "storage_order" = 0 : i64, "strides" = [1 : i64, 1 : i64]} : (tensor<5x5x32x32xf32>) -> tensor<5x5x32x32xf32> - -} - -// ----- - -builtin.module { - %t0 = "test.op"(): () -> (tensor<5x5x32x32xf32>) - - // CHECK: Operation does not verify: input data and kernel shape rank mismatch: (2) vs (1) - %res_max_pool_single_out = "onnx.MaxPoolSingleOut"(%t0) {"onnx_node_name" = "/MaxPoolSingleOut", "auto_pad" = "NOTSET", "ceil_mode" = 0 : i64, "kernel_shape" = [3 : i64], "dilations" = [1 : i64, 1 : i64], "pads" = [0 : i64, 0 : i64, 0 : i64, 0 : i64], "storage_order" = 0 : i64, "strides" = [1 : i64, 1 : i64]} : (tensor<5x5x32x32xf32>) -> tensor<5x5x32x32xf32> - -} - -// ----- - -builtin.module { - %t0 = "test.op"(): () -> (tensor<5x5x32x32xf32>) - - // CHECK: Operation does not verify: dilation value must be non zero positive - %res_max_pool_single_out = "onnx.MaxPoolSingleOut"(%t0) {"onnx_node_name" = "/MaxPoolSingleOut", "auto_pad" = "NOTSET", "ceil_mode" = 0 : i64, "kernel_shape" = [3 : i64, 3 : i64], "dilations" = [-1 : i64, -1 : i64], "pads" = [0 : i64, 0 : i64, 0 : i64, 0 : i64], "storage_order" = 0 : i64, "strides" = [1 : i64, 1 : i64]} : (tensor<5x5x32x32xf32>) -> tensor<5x5x32x32xf32> - -} - -// ----- - -builtin.module { - %t0 = "test.op"(): () -> (tensor<5x5x32x32xf32>) - - // CHECK: Operation does not verify: dilations rank (3) and kernel shape rank (2) are not the same - %res_max_pool_single_out = "onnx.MaxPoolSingleOut"(%t0) {"onnx_node_name" = "/MaxPoolSingleOut", "auto_pad" = "NOTSET", "ceil_mode" = 0 : i64, "kernel_shape" = [3 : i64, 3 : i64], "dilations" = [1 : i64, 1 : i64, 1: i64], "pads" = [0 : i64, 0 : i64, 0 : i64, 0 : i64], "storage_order" = 0 : i64, "strides" = [1 : i64, 1 : i64]} : (tensor<5x5x32x32xf32>) -> tensor<5x5x32x32xf32> - -} - -// ----- - -builtin.module { - %t0 = "test.op"(): () -> (tensor<5x5x32x32xf32>) - - // CHECK: Operation does not verify: column major storage order not implemented yet - %res_max_pool_single_out = "onnx.MaxPoolSingleOut"(%t0) {"onnx_node_name" = "/MaxPoolSingleOut", "auto_pad" = "NOTSET", "ceil_mode" = 0 : i64, "kernel_shape" = [3 : i64, 3 : i64], "dilations" = [1 : i64, 1 : i64], "pads" = [0 : i64, 0 : i64, 0 : i64, 0 : i64], "storage_order" = 1 : i64, "strides" = [1 : i64, 1 : i64]} : (tensor<5x5x32x32xf32>) -> tensor<5x5x32x32xf32> - -} - -// ----- - -builtin.module { - %t0 = "test.op"(): () -> (tensor<5x5x32x32xf32>) - - // CHECK: Operation does not verify: stride value must be non zero positive - %res_max_pool_single_out = "onnx.MaxPoolSingleOut"(%t0) {"onnx_node_name" = "/MaxPoolSingleOut", "auto_pad" = "NOTSET", "ceil_mode" = 0 : i64, "kernel_shape" = [3 : i64, 3 : i64], "dilations" = [1 : i64, 1 : i64], "pads" = [0 : i64, 0 : i64, 0 : i64, 0 : i64], "storage_order" = 0 : i64, "strides" = [-1 : i64, -1 : i64]} : (tensor<5x5x32x32xf32>) -> tensor<5x5x32x32xf32> - -} - -// ----- - -builtin.module { - %t0 = "test.op"(): () -> (tensor<5x5x32x32xf32>) - - // CHECK: Operation does not verify: strides rank (3) and kernel shape rank (2) are not the same - %res_max_pool_single_out = "onnx.MaxPoolSingleOut"(%t0) {"onnx_node_name" = "/MaxPoolSingleOut", "auto_pad" = "NOTSET", "ceil_mode" = 0 : i64, "kernel_shape" = [3 : i64, 3 : i64], "dilations" = [1 : i64, 1 : i64], "pads" = [0 : i64, 0 : i64, 0 : i64, 0 : i64], "storage_order" = 0 : i64, "strides" = [1 : i64, 1 : i64, 1: i64]} : (tensor<5x5x32x32xf32>) -> tensor<5x5x32x32xf32> - -} - -// ----- - -builtin.module { - %t0 = "test.op"(): () -> (tensor<5x5x32x32xf32>) - - // CHECK: Operation does not verify: pads value must be nonnegative - %res_max_pool_single_out = "onnx.MaxPoolSingleOut"(%t0) {"onnx_node_name" = "/MaxPoolSingleOut", "auto_pad" = "NOTSET", "ceil_mode" = 0 : i64, "kernel_shape" = [3 : i64, 3 : i64], "dilations" = [1 : i64, 1 : i64], "pads" = [-2 : i64, 0 : i64, 0 : i64, 0 : i64], "storage_order" = 0 : i64, "strides" = [1 : i64, 1 : i64]} : (tensor<5x5x32x32xf32>) -> tensor<5x5x32x32xf32> - -} - -// ----- - -builtin.module { - %t0 = "test.op"(): () -> (tensor<5x5x32x32xf32>) - - // CHECK: Operation does not verify: pads rank (5) is not twice the kernel shape rank (2) - %res_max_pool_single_out = "onnx.MaxPoolSingleOut"(%t0) {"onnx_node_name" = "/MaxPoolSingleOut", "auto_pad" = "NOTSET", "ceil_mode" = 0 : i64, "kernel_shape" = [3 : i64, 3 : i64], "dilations" = [1 : i64, 1 : i64], "pads" = [0 : i64, 0 : i64, 0 : i64, 0 : i64, 0: i64], "storage_order" = 0 : i64, "strides" = [1 : i64, 1 : i64]} : (tensor<5x5x32x32xf32>) -> tensor<5x5x32x32xf32> - -} - -// ----- - -builtin.module { - %t0, %t1 = "test.op"() : () -> (tensor<2x4x3xf32>, tensor<4x2xf32>) - - // CHECK: Operation does not verify: input matrix A should be a 2D tensor - %res_matmul = "onnx.MatMul"(%t0, %t1) {onnx_node_name = "/MatMul"} : (tensor<2x4x3xf32>, tensor<4x2xf32>) -> tensor<2x2xf32> -} - -// ----- - -builtin.module { - %t0, %t1 = "test.op"() : () -> (tensor<2x4xf32>, tensor<4x2x3xf32>) - - // CHECK: Operation does not verify: input matrix B should be a 2D tensor - %res_matmul = "onnx.MatMul"(%t0, %t1) {onnx_node_name = "/MatMul"} : (tensor<2x4xf32>, tensor<4x2x3xf32>) -> tensor<2x2xf32> -} - -// ----- - -builtin.module { - %t0, %t1 = "test.op"() : () -> (tensor<2x4xf32>, tensor<5x2xf32>) - - // CHECK: Operation does not verify: operands have incompatible shapes: (2, 4) and (5, 2) - %res_matmul = "onnx.MatMul"(%t0, %t1) {onnx_node_name = "/MatMul"} : (tensor<2x4xf32>, tensor<5x2xf32>) -> tensor<2x2xf32> -} - -// ----- - -builtin.module { - %t0, %t1 = "test.op"() : () -> (tensor<2x4xf32>, tensor<4x2xf32>) - - // CHECK: Operation does not verify: result shape [2, 2] does not match result type [2, 3] - %res_matmul = "onnx.MatMul"(%t0, %t1) {onnx_node_name = "/MatMul"} : (tensor<2x4xf32>, tensor<4x2xf32>) -> tensor<2x3xf32> -} - -// ----- - -builtin.module { - %t0 = "test.op"() : () -> (tensor<3x4xf32>) - // CHECK: Operation does not verify: permutation can not contain more than one occurrence of the same dimension: dimension #1 appears 2 times. - %res_transpose = "onnx.Transpose"(%t0) {onnx_node_name = "/Transpose", "perm" = [1 : i64, 1 : i64]}: (tensor<3x4xf32>) -> tensor<4x3xf32> -} - -// ----- - -builtin.module { - %t0 = "test.op"() : () -> (tensor<3x4xf32>) - // CHECK: Operation does not verify: permutation can only contain values between 0 and 2-1: dimension #1 value is 2 - %res_transpose = "onnx.Transpose"(%t0) {onnx_node_name = "/Transpose", "perm" = [1 : i64, 2 : i64]}: (tensor<3x4xf32>) -> tensor<4x3xf32> -} - -// ----- - -builtin.module { - %t0 = "test.op"() : () -> (tensor<1x3x4xf32>) - // CHECK: Operation does not verify: permutation and inputs dimensions must have the same size: #dimensions input is 3, #dimension perimutation is 2 - %res_transpose = "onnx.Transpose"(%t0) {onnx_node_name = "/Transpose", "perm" = [1 : i64, 0 : i64]}: (tensor<1x3x4xf32>) -> tensor<3x1x4xf32> -} - -// ----- - -builtin.module { - %t0 = "test.op"() : () -> (tensor<3x4xf32>) - // CHECK: Operation does not verify: incorrect output shape: output dimension #0 should be equal to 4 - %res_transpose = "onnx.Transpose"(%t0) {onnx_node_name = "/Transpose", "perm" = [1 : i64, 0 : i64]}: (tensor<3x4xf32>) -> tensor<3x3xf32> -} - -// ----- - -builtin.module { - %t0 = "test.op"() : () -> (tensor<1x2x4xf32>) - - // CHECK: Operation does not verify: axes to squeeze must be between 0 and 2, axes: 3 - %res_squeeze = "onnx.Squeeze"(%t0) {onnx_node_name = "/Squeeze", "axes" = 3 : i64} : (tensor<1x2x4xf32>) -> tensor<2x4xf32> -} - -// ----- - -builtin.module { - %t0 = "test.op"() : () -> (tensor<3x4xf32>) - - // CHECK: Operation does not verify: tensor input shape (3, 4) is not equal to tensor output shape (7, 3) - %res_sigmoid = "onnx.Sigmoid"(%t0) {onnx_node_name = "/Sigmoid"} : (tensor<3x4xf32>) -> tensor<7x3xf32> -} diff --git a/tests/filecheck/dialects/onnx/onnx_ops.mlir b/tests/filecheck/dialects/onnx/onnx_ops.mlir deleted file mode 100644 index ed12de520a..0000000000 --- a/tests/filecheck/dialects/onnx/onnx_ops.mlir +++ /dev/null @@ -1,91 +0,0 @@ -// RUN: XDSL_ROUNDTRIP - -%t0, %t1 = "test.op"(): () -> (tensor<1x2x6xf32>, tensor<1x2x6xf32>) -%t2, %t3 = "test.op"(): () -> (tensor<3x2xf32>, tensor<1x2xf32>) -%t4, %t5 = "test.op"(): () -> (tensor<3x1x2xf32>, tensor<3x4x1xf32>) -%t6, %t7 = "test.op"(): () -> (tensor<1x5x1x3xf32>, tensor<2x1x6x3xf32>) -%t8 = "test.op"(): () -> (tensor<3x4xf32>) -%t9, %t10, %t11 = "test.op"(): () -> (tensor<2x2xf32>, tensor<2x2xf32>, tensor<2x2xf32>) -%t12, %t13, %t14 = "test.op"(): () -> (tensor<5x3xf32>, tensor<3x2xf32>, tensor<5x2xf32>) -%t15,%t16 = "test.op"(): () -> (tensor<48x256x64xf32>, tensor<3xi64>) -%t17,%t18 = "test.op"(): () -> (tensor<1x2x3x4x5xf32>, tensor<5xi64>) -%t19 = "test.op"(): () -> (tensor<10x10xf32>) -%t20,%t21,%t22 = "test.op"(): () -> (tensor<1x1x5x5xf32>, tensor<1x1x3x3xf32>, none) -%t23,%t24,%t25 = "test.op"(): () -> (tensor<1x1x7x5xf32>, tensor<1x1x3x3xf32>, none) -%t26 = "test.op"(): () -> (tensor<5x5x32x32xf32>) -%t27, %t28, %t29 = "test.op"(): () -> (tensor<1x320xf32>, tensor<50x320xf32>, tensor<50xf32>) - -%res_add = "onnx.Add"(%t0, %t1) {onnx_node_name = "/Add"} : (tensor<1x2x6xf32>, tensor<1x2x6xf32>) -> tensor<1x2x6xf32> -// CHECK: %res_add = onnx.Add(%t0, %t1) {onnx_node_name = "/Add"} : (tensor<1x2x6xf32>, tensor<1x2x6xf32>) -> tensor<1x2x6xf32> - -%res_sub = "onnx.Sub"(%t2, %t3) {onnx_node_name = "/Sub"}: (tensor<3x2xf32>, tensor<1x2xf32>) -> tensor<3x2xf32> -// CHECK: %res_sub = onnx.Sub(%t2, %t3) {onnx_node_name = "/Sub"} : (tensor<3x2xf32>, tensor<1x2xf32>) -> tensor<3x2xf32> - -%res_mul = "onnx.Mul"(%t4, %t5) {onnx_node_name = "/Mul"}: (tensor<3x1x2xf32>, tensor<3x4x1xf32>) -> tensor<3x4x2xf32> -// CHECK: %res_mul = onnx.Mul(%t4, %t5) {onnx_node_name = "/Mul"} : (tensor<3x1x2xf32>, tensor<3x4x1xf32>) -> tensor<3x4x2xf32> - -%res_div = "onnx.Div"(%t6, %t7) {onnx_node_name = "/Div"}: (tensor<1x5x1x3xf32>, tensor<2x1x6x3xf32>) -> tensor<2x5x6x3xf32> -// CHECK: %res_div = onnx.Div(%t6, %t7) {onnx_node_name = "/Div"} : (tensor<1x5x1x3xf32>, tensor<2x1x6x3xf32>) -> tensor<2x5x6x3xf32> - -%res_relu = "onnx.Relu"(%t8) {onnx_node_name = "/Relu"}: (tensor<3x4xf32>) -> tensor<3x4xf32> -// CHECK: %res_relu = onnx.Relu(%t8) {onnx_node_name = "/Relu"} : (tensor<3x4xf32>) -> tensor<3x4xf32> - -%res_gemm = "onnx.Gemm"(%t9, %t10, %t11) {onnx_node_name = "/Gemm"}: (tensor<2x2xf32>, tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32> -// CHECK: %res_gemm = onnx.Gemm(%t9, %t10, %t11) {onnx_node_name = "/Gemm"} : (tensor<2x2xf32>, tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32> - -%res_gemm_1 = "onnx.Gemm"(%t12, %t13, %t14) {onnx_node_name = "/Gemm"}: (tensor<5x3xf32>, tensor<3x2xf32>, tensor<5x2xf32>) -> tensor<5x2xf32> -// CHECK: %res_gemm_1 = onnx.Gemm(%t12, %t13, %t14) {onnx_node_name = "/Gemm"} : (tensor<5x3xf32>, tensor<3x2xf32>, tensor<5x2xf32>) -> tensor<5x2xf32> - -%res_reshape = "onnx.Reshape"(%t15, %t16) {onnx_node_name = "/Reshape", "allowzero" = 1 : i64}: (tensor<48x256x64xf32>, tensor<3xi64>) -> tensor<48x256x64xf32> -//CHECK: res_reshape = onnx.Reshape(%t15, %t16) {onnx_node_name = "/Reshape", allowzero = 1 : i64} : (tensor<48x256x64xf32>, tensor<3xi64>) -> tensor<48x256x64xf32> - -%res_reshape_1 = "onnx.Reshape"(%t17, %t18) {onnx_node_name = "/Reshape"}: (tensor<1x2x3x4x5xf32>, tensor<5xi64>) -> tensor<1x120xf32> -//CHECK: %res_reshape_1 = onnx.Reshape(%t17, %t18) {onnx_node_name = "/Reshape"} : (tensor<1x2x3x4x5xf32>, tensor<5xi64>) -> tensor<1x120xf32> - -%res_abs = "onnx.Abs"(%t19) {onnx_node_name = "/Abs"}: (tensor<10x10xf32>) -> tensor<10x10xf32> -// CHECK: %res_abs = onnx.Abs(%t19) {onnx_node_name = "/Abs"} : (tensor<10x10xf32>) -> tensor<10x10xf32> - -%res_conv = "onnx.Conv"(%t20, %t21, %t22) {onnx_node_name = "/Conv", "auto_pad" = "NOTSET", "group" = 1 : i64, "kernel_shape" = [3 : i64, 3 : i64], "dilations" = [1 : i64, 1 : i64], "strides" = [1 : i64, 1 : i64], "pads" = [1 : i64, 1 : i64, 1: i64, 1 : i64]}: (tensor<1x1x5x5xf32>, tensor<1x1x3x3xf32>, none) -> tensor<1x1x5x5xf32> -//CHECK: %res_conv = onnx.Conv(%t20, %t21, %t22) {onnx_node_name = "/Conv", auto_pad = "NOTSET", group = 1 : i64, kernel_shape = [3 : i64, 3 : i64], dilations = [1 : i64, 1 : i64], strides = [1 : i64, 1 : i64], pads = [1 : i64, 1 : i64, 1 : i64, 1 : i64]} : (tensor<1x1x5x5xf32>, tensor<1x1x3x3xf32>, none) -> tensor<1x1x5x5xf32> - -%res_conv_1 = "onnx.Conv"(%t20, %t21, %t22) {onnx_node_name = "/Conv", "auto_pad" = "NOTSET", "group" = 1 : i64, "kernel_shape" = [3 : i64, 3 : i64], "dilations" = [1 : i64, 1 : i64], "strides" = [1 : i64, 1 : i64], "pads" = [0 : i64, 0 : i64, 0: i64, 0 : i64]}: (tensor<1x1x5x5xf32>, tensor<1x1x3x3xf32>, none) -> tensor<1x1x3x3xf32> -//CHECK: %res_conv_1 = onnx.Conv(%t20, %t21, %t22) {onnx_node_name = "/Conv", auto_pad = "NOTSET", group = 1 : i64, kernel_shape = [3 : i64, 3 : i64], dilations = [1 : i64, 1 : i64], strides = [1 : i64, 1 : i64], pads = [0 : i64, 0 : i64, 0 : i64, 0 : i64]} : (tensor<1x1x5x5xf32>, tensor<1x1x3x3xf32>, none) -> tensor<1x1x3x3xf32> - -%res_conv_2 = "onnx.Conv"(%t20, %t21, %t22) {onnx_node_name = "/Conv", "auto_pad" = "SAME_LOWER", "group" = 1 : i64, "kernel_shape" = [3 : i64, 3 : i64], "dilations" = [1 : i64, 1 : i64], "strides" = [2 : i64, 2 : i64], "pads" = [0 : i64, 0 : i64, 0: i64, 0 : i64]}: (tensor<1x1x5x5xf32>, tensor<1x1x3x3xf32>, none) -> tensor<1x1x3x3xf32> -//CHECK: %res_conv_2 = onnx.Conv(%t20, %t21, %t22) {onnx_node_name = "/Conv", auto_pad = "SAME_LOWER", group = 1 : i64, kernel_shape = [3 : i64, 3 : i64], dilations = [1 : i64, 1 : i64], strides = [2 : i64, 2 : i64], pads = [0 : i64, 0 : i64, 0 : i64, 0 : i64]} : (tensor<1x1x5x5xf32>, tensor<1x1x3x3xf32>, none) -> tensor<1x1x3x3xf32> - -%res_conv_3 = "onnx.Conv"(%t23, %t24, %t25) {onnx_node_name = "/Conv", "auto_pad" = "NOTSET", "group" = 1 : i64, "kernel_shape" = [3 : i64, 3 : i64], "dilations" = [1 : i64, 1 : i64], "strides" = [2 : i64, 2 : i64], "pads" = [1 : i64, 1 : i64, 1 : i64, 1 : i64]}: (tensor<1x1x7x5xf32>, tensor<1x1x3x3xf32>, none) -> tensor<1x1x4x3xf32> -//CHECK: %res_conv_3 = onnx.Conv(%t23, %t24, %t25) {onnx_node_name = "/Conv", auto_pad = "NOTSET", group = 1 : i64, kernel_shape = [3 : i64, 3 : i64], dilations = [1 : i64, 1 : i64], strides = [2 : i64, 2 : i64], pads = [1 : i64, 1 : i64, 1 : i64, 1 : i64]} : (tensor<1x1x7x5xf32>, tensor<1x1x3x3xf32>, none) -> tensor<1x1x4x3xf32> - -%res_conv_4 = "onnx.Conv"(%t23, %t24, %t25) {onnx_node_name = "/Conv", "auto_pad" = "NOTSET", "group" = 1 : i64, "kernel_shape" = [3 : i64, 3 : i64], "dilations" = [1 : i64, 1 : i64], "strides" = [2 : i64, 2 : i64], "pads" = [0 : i64, 0 : i64, 0 : i64, 0 : i64]}: (tensor<1x1x7x5xf32>, tensor<1x1x3x3xf32>, none) -> tensor<1x1x3x2xf32> -//CHECK: %res_conv_4 = onnx.Conv(%t23, %t24, %t25) {onnx_node_name = "/Conv", auto_pad = "NOTSET", group = 1 : i64, kernel_shape = [3 : i64, 3 : i64], dilations = [1 : i64, 1 : i64], strides = [2 : i64, 2 : i64], pads = [0 : i64, 0 : i64, 0 : i64, 0 : i64]} : (tensor<1x1x7x5xf32>, tensor<1x1x3x3xf32>, none) -> tensor<1x1x3x2xf32> - -%res_conv_5 = "onnx.Conv"(%t23, %t24, %t25) {onnx_node_name = "/Conv", "auto_pad" = "NOTSET", "group" = 1 : i64, "kernel_shape" = [3 : i64, 3 : i64], "dilations" = [1 : i64, 1 : i64], "strides" = [2 : i64, 2 : i64], "pads" = [1 : i64, 0 : i64, 1 : i64, 0 : i64]}: (tensor<1x1x7x5xf32>, tensor<1x1x3x3xf32>, none) -> tensor<1x1x4x2xf32> -//CHECK: %res_conv_5 = onnx.Conv(%t23, %t24, %t25) {onnx_node_name = "/Conv", auto_pad = "NOTSET", group = 1 : i64, kernel_shape = [3 : i64, 3 : i64], dilations = [1 : i64, 1 : i64], strides = [2 : i64, 2 : i64], pads = [1 : i64, 0 : i64, 1 : i64, 0 : i64]} : (tensor<1x1x7x5xf32>, tensor<1x1x3x3xf32>, none) -> tensor<1x1x4x2xf32> - -%res_max_pool_single_out = "onnx.MaxPoolSingleOut"(%t26) {onnx_node_name = "/MaxPoolSingleOut", "auto_pad" = "VALID", "ceil_mode" = 0 : i64, "kernel_shape" = [3 : i64, 3 : i64], "dilations" = [1 : i64, 1 : i64], "pads" = [0 : i64, 0 : i64, 0 : i64, 0 : i64], "storage_order" = 0 : i64, "strides" = [1 : i64, 1 : i64]}: (tensor<5x5x32x32xf32>) -> tensor<5x5x30x30xf32> -//CHECK: %res_max_pool_single_out = onnx.MaxPoolSingleOut(%t26) {onnx_node_name = "/MaxPoolSingleOut", auto_pad = "VALID", ceil_mode = 0 : i64, kernel_shape = [3 : i64, 3 : i64], dilations = [1 : i64, 1 : i64], pads = [0 : i64, 0 : i64, 0 : i64, 0 : i64], storage_order = 0 : i64, strides = [1 : i64, 1 : i64]} : (tensor<5x5x32x32xf32>) -> tensor<5x5x30x30xf32> - -"onnx.EntryPoint"() {onnx_node_name = "/EntryPoint", "func" = @main_graph} : () -> () -//CHECK: "onnx.EntryPoint"() {onnx_node_name = "/EntryPoint", func = @main_graph} : () -> () - -%res_constant = onnx.Constant dense<1> : tensor<1xi64> -//CHECK: %res_constant = onnx.Constant dense<1> : tensor<1xi64> - -%res_constant_1 = onnx.Constant dense<[5, 5, 16, 2]> : tensor<4xi64> -//CHECK: %res_constant_1 = onnx.Constant dense<[5, 5, 16, 2]> : tensor<4xi64> - -%res_gemm_2 = "onnx.Gemm"(%t27, %t28, %t29) {onnx_node_name = "/Gemm", "alpha" = 1.000000e+00 : f32, "beta" = 1.000000e+00 : f32, "transA" = 0 : si64, "transB" = 1 : si64}: (tensor<1x320xf32>, tensor<50x320xf32>, tensor<50xf32>) -> tensor<1x50xf32> -// CHECK: %res_gemm_2 = onnx.Gemm(%t27, %t28, %t29) {onnx_node_name = "/Gemm", alpha = 1.000000e+00 : f32, beta = 1.000000e+00 : f32, transA = 0 : si64, transB = 1 : si64} : (tensor<1x320xf32>, tensor<50x320xf32>, tensor<50xf32>) -> tensor<1x50xf32> - -%res_matmul = "onnx.MatMul"(%t9, %t10) {onnx_node_name = "/MatMul"}: (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32> -// CHECK: %res_matmul = onnx.MatMul(%t9, %t10) {onnx_node_name = "/MatMul"} : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32> - -%res_transpose = "onnx.Transpose"(%t8) {onnx_node_name = "/Transpose", "perm" = [1 : i64, 0 : i64]}: (tensor<3x4xf32>) -> tensor<4x3xf32> -// CHECK: %res_transpose = onnx.Transpose(%t8) {onnx_node_name = "/Transpose", perm = [1 : i64, 0 : i64]} : (tensor<3x4xf32>) -> tensor<4x3xf32> - -%res_squeeze = "onnx.Squeeze"(%t0) {onnx_node_name = "/Squeeze", "axes" = 0}: (tensor<1x2x6xf32>) -> tensor<2x6xf32> -// CHECK: %res_squeeze = onnx.Squeeze(%t0) {onnx_node_name = "/Squeeze", axes = 0 : i64} : (tensor<1x2x6xf32>) -> tensor<2x6xf32> - -%res_sigmoid = "onnx.Sigmoid"(%t8) {onnx_node_name = "/Sigmoid"}: (tensor<3x4xf32>) -> tensor<3x4xf32> -// CHECK: %res_sigmoid = onnx.Sigmoid(%t8) {onnx_node_name = "/Sigmoid"} : (tensor<3x4xf32>) -> tensor<3x4xf32> diff --git a/tests/filecheck/transforms/convert_onnx_to_linalg.mlir b/tests/filecheck/transforms/convert_onnx_to_linalg.mlir deleted file mode 100644 index 76ff2fef9a..0000000000 --- a/tests/filecheck/transforms/convert_onnx_to_linalg.mlir +++ /dev/null @@ -1,117 +0,0 @@ -// RUN: xdsl-opt -p convert-onnx-to-linalg %s | filecheck %s - -// CHECK: builtin.module { - -%t0, %t1 = "test.op"() : () -> (tensor<3x2xf32>, tensor<3x2xf32>) -%res_add = onnx.Add(%t0, %t1) {onnx_node_name = "/Add"} : (tensor<3x2xf32>, tensor<3x2xf32>) -> tensor<3x2xf32> -%res_sub = onnx.Sub(%t0, %t1) {onnx_node_name = "/Sub"} : (tensor<3x2xf32>, tensor<3x2xf32>) -> tensor<3x2xf32> - -// CHECK-NEXT: %t0, %t1 = "test.op"() : () -> (tensor<3x2xf32>, tensor<3x2xf32>) -// CHECK-NEXT: %res_add = tensor.empty() : tensor<3x2xf32> -// CHECK-NEXT: %res_add_1 = linalg.add ins(%t0, %t1 : tensor<3x2xf32>, tensor<3x2xf32>) outs(%res_add : tensor<3x2xf32>) -> tensor<3x2xf32> -// CHECK-NEXT: %res_sub = tensor.empty() : tensor<3x2xf32> -// CHECK-NEXT: %res_sub_1 = linalg.sub ins(%t0, %t1 : tensor<3x2xf32>, tensor<3x2xf32>) outs(%res_sub : tensor<3x2xf32>) -> tensor<3x2xf32> - - -%t2 = "test.op"() : () -> (tensor<3x4xf32>) -%res_relu = "onnx.Relu"(%t2) {onnx_node_name = "/Relu"}: (tensor<3x4xf32>) -> tensor<3x4xf32> - -// CHECK-NEXT: %t2 = "test.op"() : () -> tensor<3x4xf32> -// CHECK-NEXT: %res_relu = tensor.empty() : tensor<3x4xf32> -// CHECK-NEXT: %res_relu_1 = arith.constant 0.000000e+00 : f32 -// CHECK-NEXT: %res_relu_2 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%t2 : tensor<3x4xf32>) outs(%res_relu : tensor<3x4xf32>) { -// CHECK-NEXT: ^0(%0 : f32, %1 : f32): -// CHECK-NEXT: %2 = arith.maximumf %0, %res_relu_1 : f32 -// CHECK-NEXT: linalg.yield %2 : f32 -// CHECK-NEXT: } -> tensor<3x4xf32> - -%t27 = "test.op"() : () -> (tensor<3x4xf64>) -%res_relu_3 = "onnx.Relu"(%t27) {onnx_node_name = "/Relu"}: (tensor<3x4xf64>) -> tensor<3x4xf64> - -// CHECK-NEXT: %t27 = "test.op"() : () -> tensor<3x4xf64> -// CHECK-NEXT: %res_relu_3 = tensor.empty() : tensor<3x4xf64> -// CHECK-NEXT: %res_relu_4 = arith.constant 0.000000e+00 : f64 -// CHECK-NEXT: %res_relu_5 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%t27 : tensor<3x4xf64>) outs(%res_relu_3 : tensor<3x4xf64>) { -// CHECK-NEXT: ^1(%3 : f64, %4 : f64): -// CHECK-NEXT: %5 = arith.maximumf %3, %res_relu_4 : f64 -// CHECK-NEXT: linalg.yield %5 : f64 -// CHECK-NEXT: } -> tensor<3x4xf64> - - -%t3,%t4 = "test.op"(): () -> (tensor<20x2xf32>, tensor<2xi64>) -%res_reshape = "onnx.Reshape"(%t3, %t4) {onnx_node_name = "/Reshape"}: (tensor<20x2xf32>, tensor<2xi64>) -> tensor<1x40xf32> - -// CHECK-NEXT: %t3, %t4 = "test.op"() : () -> (tensor<20x2xf32>, tensor<2xi64>) -// CHECK-NEXT: %res_reshape = tensor.reshape %t3(%t4) : (tensor<20x2xf32>, tensor<2xi64>) -> tensor<1x40xf32> - -%t5, %t6, %t7 = "test.op"(): () -> (tensor<1x320xf32>, tensor<50x320xf32>, tensor<50xf32>) -%res_gemm= "onnx.Gemm"(%t5, %t6, %t7) {onnx_node_name = "/Gemm", "alpha" = 1.000000e+00 : f32, "beta" = 1.000000e+00 : f32, "transA" = 0 : si64, "transB" = 1 : si64}: (tensor<1x320xf32>, tensor<50x320xf32>, tensor<50xf32>) -> tensor<1x50xf32> - -// CHECK-NEXT: %t5, %t6, %t7 = "test.op"() : () -> (tensor<1x320xf32>, tensor<50x320xf32>, tensor<50xf32>) -// CHECK-NEXT: %6 = tensor.empty() : tensor<320x50xf32> -// CHECK-NEXT: %7 = linalg.transpose ins(%t6:tensor<50x320xf32>) outs(%6:tensor<320x50xf32>) permutation = [1, 0] -// CHECK-NEXT: %res_gemm = tensor.empty() : tensor<1x50xf32> -// CHECK-NEXT: %res_gemm_1 = linalg.matmul ins(%t5, %7 : tensor<1x320xf32>, tensor<320x50xf32>) outs(%res_gemm : tensor<1x50xf32>) -> tensor<1x50xf32> -// CHECK-NEXT: %res_gemm_2 = linalg.add ins(%res_gemm_1, %t7 : tensor<1x50xf32>, tensor<50xf32>) outs(%res_gemm_1 : tensor<1x50xf32>) -> tensor<1x50xf32> - - - -%t8, %t9, %t10 = "test.op"(): () -> (tensor<5x3xf32>, tensor<3x2xf32>, tensor<5x2xf32>) -%res_gemm_1 = "onnx.Gemm"(%t8, %t9, %t10) {onnx_node_name = "/Gemm"}: (tensor<5x3xf32>, tensor<3x2xf32>, tensor<5x2xf32>) -> tensor<5x2xf32> - - -// CHECK-NEXT: %t8, %t9, %t10 = "test.op"() : () -> (tensor<5x3xf32>, tensor<3x2xf32>, tensor<5x2xf32>) -// CHECK-NEXT: %res_gemm_3 = tensor.empty() : tensor<5x2xf32> -// CHECK-NEXT: %res_gemm_4 = linalg.matmul ins(%t8, %t9 : tensor<5x3xf32>, tensor<3x2xf32>) outs(%res_gemm_3 : tensor<5x2xf32>) -> tensor<5x2xf32> -// CHECK-NEXT: %res_gemm_5 = linalg.add ins(%res_gemm_4, %t10 : tensor<5x2xf32>, tensor<5x2xf32>) outs(%res_gemm_4 : tensor<5x2xf32>) -> tensor<5x2xf32> - - -%t11, %t12, %t13 = "test.op"(): () -> (tensor<10x5xf32>, tensor<10x3xf32>, tensor<5x3xf32>) -%res_gemm_2 = "onnx.Gemm"(%t11, %t12, %t13) {onnx_node_name = "/Gemm", "alpha" = 0.500000e+00 : f32, "beta" = 0.500000e+00 : f32, "transA" = 1 : si64}: (tensor<10x5xf32>, tensor<10x3xf32>, tensor<5x3xf32>) -> tensor<5x3xf32> - - -// CHECK-NEXT: %t11, %t12, %t13 = "test.op"() : () -> (tensor<10x5xf32>, tensor<10x3xf32>, tensor<5x3xf32>) -// CHECK-NEXT: %8 = tensor.empty() : tensor<5x10xf32> -// CHECK-NEXT: %9 = linalg.transpose ins(%t11:tensor<10x5xf32>) outs(%8:tensor<5x10xf32>) permutation = [1, 0] -// CHECK-NEXT: %10 = arith.constant 5.000000e-01 : f32 -// CHECK-NEXT: %11 = linalg.mul ins(%10, %9 : f32, tensor<5x10xf32>) outs(%9 : tensor<5x10xf32>) -> tensor<5x10xf32> -// CHECK-NEXT: %12 = arith.constant 5.000000e-01 : f32 -// CHECK-NEXT: %13 = linalg.mul ins(%12, %t13 : f32, tensor<5x3xf32>) outs(%t13 : tensor<5x3xf32>) -> tensor<5x3xf32> -// CHECK-NEXT: %res_gemm_6 = tensor.empty() : tensor<5x3xf32> -// CHECK-NEXT: %res_gemm_7 = linalg.matmul ins(%11, %t12 : tensor<5x10xf32>, tensor<10x3xf32>) outs(%res_gemm_6 : tensor<5x3xf32>) -> tensor<5x3xf32> -// CHECK-NEXT: %res_gemm_8 = linalg.add ins(%res_gemm_7, %13 : tensor<5x3xf32>, tensor<5x3xf32>) outs(%res_gemm_7 : tensor<5x3xf32>) -> tensor<5x3xf32> - -%t26 = "test.op"(): () -> (tensor<1x16x14x14xf32>) -%res_max_pool_single_out = "onnx.MaxPoolSingleOut"(%t26) {onnx_node_name = "/MaxPoolSingleOut", "auto_pad" = "NOTSET", "ceil_mode" = 0 : si64, "kernel_shape" = [3 : i64, 3 : i64], "dilations" = [1 : i64, 1 : i64], "pads" = [0 : i64, 0 : i64, 0 : i64, 0 : i64], "storage_order" = 0 : si64, strides = [3 : i64, 3 : i64]} : (tensor<1x16x14x14xf32>) -> tensor<1x16x4x4xf32> - -// CHECK-NEXT: %t26 = "test.op"() : () -> tensor<1x16x14x14xf32> -// CHECK-NEXT: %res_max_pool_single_out = tensor.empty() : tensor<3x3xf32> -// CHECK-NEXT: %res_max_pool_single_out_1 = tensor.empty() : tensor<1x16x4x4xf32> -// CHECK-NEXT: %res_max_pool_single_out_2 = arith.constant -1.000000e+308 : f64 -// CHECK-NEXT: %res_max_pool_single_out_3 = linalg.fill ins(%res_max_pool_single_out_2 : f64) outs(%res_max_pool_single_out_1 : tensor<1x16x4x4xf32>) -> tensor<1x16x4x4xf32> -// CHECK-NEXT: %res_max_pool_single_out_4 = linalg.pooling_nchw_max {dilations = dense<1> : tensor<2xi64>, strides = dense<3> : tensor<2xi64>} ins(%t26, %res_max_pool_single_out : tensor<1x16x14x14xf32>, tensor<3x3xf32>) outs(%res_max_pool_single_out_3 : tensor<1x16x4x4xf32>) -> tensor<1x16x4x4xf32> - -%t20, %t21, %t22 = "test.op"() : () -> (tensor<1x1x5x5xf32>, tensor<1x1x3x3xf32>, none) -%res_conv_2 = "onnx.Conv"(%t20, %t21, %t22) {onnx_node_name = "/Conv", "auto_pad" = "NOTSET", "group" = 1 : i64, "kernel_shape" = [3 : i64, 3 : i64], "dilations" = [1 : i64, 1 : i64], "strides" = [1 : i64, 1 : i64], "pads" = [0 : i64, 0 : i64, 0: i64, 0 : i64]}: (tensor<1x1x5x5xf32>, tensor<1x1x3x3xf32>, none) -> tensor<1x1x3x3xf32> - -// CHECK-NEXT: %t20, %t21, %t22 = "test.op"() : () -> (tensor<1x1x5x5xf32>, tensor<1x1x3x3xf32>, none) -// CHECK-NEXT: %res_conv = tensor.empty() : tensor<1x1x3x3xf32> -// CHECK-NEXT: %res_conv_1 = linalg.conv_2d_nchw_fchw {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%t20, %t21 : tensor<1x1x5x5xf32>, tensor<1x1x3x3xf32>) outs(%res_conv : tensor<1x1x3x3xf32>) -> tensor<1x1x3x3xf32> - -%t23, %t24, %t25 = "test.op"() : () -> (tensor<1x8x14x14xf32>, tensor<16x8x5x5xf32>, tensor<16xf32>) -%res_conv_3 = "onnx.Conv"(%t23, %t24, %t25) {onnx_node_name = "/Conv", "auto_pad" = "SAME_UPPER", "group" = 1 : i64, "kernel_shape" = [5 : i64, 5 : i64], "dilations" = [1 : i64, 1 : i64], "strides" = [1 : i64, 1 : i64], "pads" = [0 : i64, 0 : i64, 0: i64, 0 : i64]} : (tensor<1x8x14x14xf32>, tensor<16x8x5x5xf32>, tensor<16xf32>) -> tensor<1x16x14x14xf32> - -// CHECK-NEXT: %t23, %t24, %t25 = "test.op"() : () -> (tensor<1x8x14x14xf32>, tensor<16x8x5x5xf32>, tensor<16xf32>) -// CHECK-NEXT: %res_conv_2 = tensor.empty() : tensor<1x16x14x14xf32> -// CHECK-NEXT: %res_conv_3 = linalg.conv_2d_nchw_fchw {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%t23, %t24 : tensor<1x8x14x14xf32>, tensor<16x8x5x5xf32>) outs(%res_conv_2 : tensor<1x16x14x14xf32>) -> tensor<1x16x14x14xf32> -// CHECK-NEXT: %res_conv_4 = linalg.add ins(%t25 : tensor<16xf32>) outs(%res_conv_3 : tensor<1x16x14x14xf32>) -> tensor<1x16x14x14xf32> - -%res_constant = "onnx.Constant"() {onnx_node_name = "/Constant", "value" = dense<1> : tensor<1xi64>}: () -> tensor<1xi64> -%res_constant_2 = "onnx.Constant"() {onnx_node_name = "/Constant", "value" = dense<2.0> : tensor<1x5xf32>} : () -> tensor<1x5xf32> - -// CHECK-NEXT: %res_constant = ml_program.global_load_const @onnx_constant_1 : tensor<1xi64> -// CHECK-NEXT: %res_constant_1 = ml_program.global_load_const @onnx_constant_2 : tensor<1x5xf32> -// CHECK-NEXT: ml_program.global private @onnx_constant_1(dense<1> : tensor<1xi64>) : tensor<1xi64> -// CHECK-NEXT: ml_program.global private @onnx_constant_2(dense<2.000000e+00> : tensor<1x5xf32>) : tensor<1x5xf32> - -// CHECK-NEXT: } diff --git a/tests/frontend/onnx/test_build_onnx_ir.py b/tests/frontend/onnx/test_build_onnx_ir.py deleted file mode 100644 index 7a0be0340b..0000000000 --- a/tests/frontend/onnx/test_build_onnx_ir.py +++ /dev/null @@ -1,377 +0,0 @@ -import onnx -import pytest - -from xdsl.dialects.builtin import TensorType, f32 -from xdsl.dialects.onnx import AddOp, MatMulOp, SubOp, TransposeOp -from xdsl.ir import Attribute -from xdsl.utils.test_value import TestSSAValue - -try: - from onnx import TensorProto, ValueInfoProto, helper - - from xdsl.frontend.onnx.ir_builder import ( - OnnxXdslMapping, - build_module, - visit_graph, - visit_node, - visit_value_info, - ) -except ImportError as exc: - print(exc) - pytest.skip("onnx is an optional dependency", allow_module_level=True) - - -def test_visit_value_info(): - # initialize context - ctx = OnnxXdslMapping() - - # create ValueInfoProto input tensor - input_value_info = ValueInfoProto() - input_value_info.name = "input_tensor" - - # define the type of the input tensor - input_tensor_type = input_value_info.type.tensor_type - input_tensor_type.elem_type = TensorProto.FLOAT - input_tensor_type.shape.dim.extend( - [ - onnx.TensorShapeProto.Dimension(dim_value=1), - onnx.TensorShapeProto.Dimension(dim_value=3), - onnx.TensorShapeProto.Dimension(dim_value=224), - onnx.TensorShapeProto.Dimension(dim_value=224), - ] - ) - - # run visit_value_info with empty context - t1 = visit_value_info(input_value_info, ctx) - - # check type info - assert isinstance(t1, Attribute) - assert str(t1) == "tensor<1x3x224x224xf32>" - - # check keys in context - keys = list(ctx.type_by_name.keys()) - assert keys == ["input_tensor"] - - # run visit_value_info again - t2 = visit_value_info(input_value_info, ctx) - - # check type info - assert isinstance(t2, Attribute) - assert str(t2) == "tensor<1x3x224x224xf32>" - - # check keys in context - keys = list(ctx.type_by_name.keys()) - assert keys == ["input_tensor"] - - # check it is returned the same reference - assert t1 is t2 - - -def test_visit_node_unknown_op_name(): - """ - Test for unknown expected onnx op - """ - node_attributes = { - "name": "dummy_name", - "op_type": "dummy_op", - } - - node = helper.make_node( - **node_attributes, inputs=["input1", "input2"], outputs=["output"] - ) - ctx = OnnxXdslMapping() - with pytest.raises(ValueError, match="Unknown ONNX op name dummy_op"): - visit_node(node=node, ctx=ctx) - - -def test_visit_node_add(): - # initialize context - ctx = OnnxXdslMapping() - - # create graph composed only of one Add operation - _, add_node = _create_graph_binary_op("Add", "add_graph", [64], [64], [64]) - - lhs = TestSSAValue(TensorType(f32, [64])) - rhs = TestSSAValue(TensorType(f32, [64])) - ctx.value_by_name["input1"] = lhs - ctx.value_by_name["input2"] = rhs - - lhs_type = TensorType(f32, [64]) - rhs_type = TensorType(f32, [64]) - out_type = TensorType(f32, [64]) - ctx.type_by_name["input1"] = lhs_type - ctx.type_by_name["input2"] = rhs_type - ctx.type_by_name["output"] = out_type - - # visit node - op = visit_node(add_node, ctx) - - assert isinstance(op, AddOp) - assert op.lhs is lhs - assert op.rhs is rhs - assert op.res is ctx.value_by_name["output"] - assert not op.attributes - assert not op.regions - - -def test_visit_node_sub(): - # initialize context - ctx = OnnxXdslMapping() - - # create graph composed only of one Sub operation - _, sub_node = _create_graph_binary_op("Sub", "sub_graph", [64], [64], [64]) - - lhs = TestSSAValue(TensorType(f32, [64])) - rhs = TestSSAValue(TensorType(f32, [64])) - ctx.value_by_name["input1"] = lhs - ctx.value_by_name["input2"] = rhs - - lhs_type = TensorType(f32, [64]) - rhs_type = TensorType(f32, [64]) - out_type = TensorType(f32, [64]) - ctx.type_by_name["input1"] = lhs_type - ctx.type_by_name["input2"] = rhs_type - ctx.type_by_name["output"] = out_type - - op = visit_node(sub_node, ctx) - - assert isinstance(op, SubOp) - assert op.lhs is lhs - assert op.rhs is rhs - assert op.res is ctx.value_by_name["output"] - assert not op.attributes - assert not op.regions - - -def test_visit_node_matmul(): - # initialize context - ctx = OnnxXdslMapping() - - # create graph composed only of one Sub operation - _, matmul_node = _create_graph_binary_op( - "MatMul", "matmul_graph", [64, 128], [128, 64], [64, 64] - ) - - lhs = TestSSAValue(TensorType(f32, [64, 128])) - rhs = TestSSAValue(TensorType(f32, [128, 64])) - ctx.value_by_name["input1"] = lhs - ctx.value_by_name["input2"] = rhs - - lhs_type = TensorType(f32, [64, 128]) - rhs_type = TensorType(f32, [128, 64]) - out_type = TensorType(f32, [64, 64]) - ctx.type_by_name["input1"] = lhs_type - ctx.type_by_name["input2"] = rhs_type - ctx.type_by_name["output"] = out_type - - op = visit_node(matmul_node, ctx) - op.verify() - - assert isinstance(op, MatMulOp) - assert op.matrix_A is lhs - assert op.matrix_B is rhs - assert op.matrix_Y is ctx.value_by_name["output"] - assert not op.attributes - assert not op.regions - - -def test_visit_node_transpose(): - # initialize context - ctx = OnnxXdslMapping() - - # create graph composed only of one Sub operation - _, transpose_node = _create_transpose_op( - graph_name="transpose_graph", dim_in=[64, 128], dim_out=[128, 64], perm=[1, 0] - ) - - in_value = TestSSAValue(TensorType(f32, [64, 128])) - ctx.value_by_name["input"] = in_value - - in_type = TensorType(f32, [64, 128]) - out_type = TensorType(f32, [128, 64]) - ctx.type_by_name["input"] = in_type - ctx.type_by_name["output"] = out_type - - op = visit_node(transpose_node, ctx) - op.verify() - - assert isinstance(op, TransposeOp) - assert op.tensor_input is in_value - assert op.tensor_output is ctx.value_by_name["output"] - assert not op.attributes - assert not op.regions - - -def test_visit_graph_add(): - # initialize context - ctx = OnnxXdslMapping() - - # create graph composed only of one Add operation - graph, _ = _create_graph_binary_op("Add", "add_graph", [64], [64], [64]) - - # visit graph - visit_graph(graph, ctx) - - # check value_by_name keys - keys = list(ctx.value_by_name.keys()) - assert keys == ["input1", "input2", "output"] - - # check expected generated ir - gen_ir = ctx.value_by_name[keys[2]].owner - assert ( - str(gen_ir) - == "%0 = onnx.Add(%1, %2) : (tensor<64xf32>, tensor<64xf32>) -> tensor<64xf32>" - ) - - -def test_visit_graph_sub(): - # initialize context - ctx = OnnxXdslMapping() - - # create graph composed only of one Sub operation - graph, _ = _create_graph_binary_op("Sub", "sub_graph", [64], [64], [64]) - - # run visit graph - visit_graph(graph, ctx) - - # check value_by_names keys - keys = list(ctx.value_by_name.keys()) - assert keys == ["input1", "input2", "output"] - - # check generated ir - gen_ir = ctx.value_by_name[keys[2]].owner - assert ( - str(gen_ir) - == "%0 = onnx.Sub(%1, %2) : (tensor<64xf32>, tensor<64xf32>) -> tensor<64xf32>" - ) - - -def test_visit_graph_matmul(): - # initialize context - ctx = OnnxXdslMapping() - - # create graph composed only of one MatMul operation - graph, _ = _create_graph_binary_op("MatMul", "matmul_graph", [64], [64], [64]) - - # run visit graph - visit_graph(graph, ctx) - - # check value_by_names keys - keys = list(ctx.value_by_name.keys()) - assert keys == ["input1", "input2", "output"] - - # check generated ir - gen_ir = ctx.value_by_name[keys[2]].owner - assert ( - str(gen_ir) - == "%0 = onnx.MatMul(%1, %2) : (tensor<64xf32>, tensor<64xf32>) -> tensor<64xf32>" - ) - - -def test_visit_graph_transpose(): - # initialize context - ctx = OnnxXdslMapping() - - # create graph composed only of one Transpose operation - graph, _ = _create_transpose_op("transpose_graph", [64, 128], [128, 64], [1, 0]) - - # run visit graph - visit_graph(graph, ctx) - - # check value_by_names keys - keys = list(ctx.value_by_name.keys()) - assert keys == ["input", "output"] - - # check generated ir - gen_ir = ctx.value_by_name[keys[1]].owner - assert ( - str(gen_ir) - == "%0 = onnx.Transpose(%1) : (tensor<64x128xf32>) -> tensor<128x64xf32>" - ) - - -def test_build_module(): - # create graph composed only of one Add operation - graph, _ = _create_graph_binary_op("Add", "add_graph", [64], [64], [64]) - - # create module - module = build_module(graph) - - # define expected output - expected = """ -builtin.module { - func.func @add_graph(%0 : tensor<64xf32>, %1 : tensor<64xf32>) -> tensor<64xf32> { - %2 = onnx.Add(%0, %1) : (tensor<64xf32>, tensor<64xf32>) -> tensor<64xf32> - func.return %2 : tensor<64xf32> - } -}""" - - # remove first new line - expected = expected[1:] - - # check output - assert str(module) == expected - - -def _create_graph_binary_op( - op_name: str, - graph_name: str, - dim_in1: list[int], - dim_in2: list[int], - dim_out: list[int], -): - # define input and output names - input1_name = "input1" - input2_name = "input2" - output_name = "output" - - # define input shapes - input1_shape = dim_in1 - input2_shape = dim_in2 - output_shape = dim_out - - # define op node - op_node = helper.make_node( - op_type=op_name, - inputs=[input1_name, input2_name], - outputs=[output_name], - ) - - # create graph (composed of just one operation) - graph = helper.make_graph( - nodes=[op_node], - name=graph_name, - inputs=[ - helper.make_tensor_value_info(input1_name, TensorProto.FLOAT, input1_shape), - helper.make_tensor_value_info(input2_name, TensorProto.FLOAT, input2_shape), - ], - outputs=[ - helper.make_tensor_value_info(output_name, TensorProto.FLOAT, output_shape), - ], - ) - - return graph, op_node - - -def _create_transpose_op( - graph_name: str, dim_in: list[int], dim_out: list[int], perm: list[int] -): - # create input tensor - input_tensor = helper.make_tensor_value_info( - "input", onnx.TensorProto.FLOAT, dim_in - ) - - # create output tensor - output_tensor = helper.make_tensor_value_info( - "output", onnx.TensorProto.FLOAT, dim_out - ) - - # create transpose operation - transpose_node = helper.make_node("Transpose", ["input"], ["output"], perm=perm) - - # create onnx graph - graph_def = helper.make_graph( - [transpose_node], graph_name, [input_tensor], [output_tensor] - ) - - return graph_def, transpose_node diff --git a/tests/frontend/onnx/test_type.py b/tests/frontend/onnx/test_type.py deleted file mode 100644 index 8b8a96a347..0000000000 --- a/tests/frontend/onnx/test_type.py +++ /dev/null @@ -1,89 +0,0 @@ -import onnx -import pytest - -from xdsl.dialects.builtin import Float32Type, TensorType, f32, f64 - -try: - from onnx import TensorShapeProto, TypeProto - - from xdsl.frontend.onnx.type import ( - get_elem_type, - get_shape, - get_tensor_type, - get_type, - ) - from xdsl.utils.hints import isa -except ImportError as exc: - print(exc) - pytest.skip("onnx is an optional dependency", allow_module_level=True) - - -def test_get_elem_type(): - # test case 1: check if 1 corresponds to f32 - assert get_elem_type(1) == f32 - - # test case 11: check if 11 corresponds to f64 - assert get_elem_type(11) == f64 - - # test case -1: check if -1 (or other illegal values) corresponds to None - with pytest.raises(ValueError, match="Unknown elem_type: -1"): - get_elem_type(-1) - - -def test_get_type(): - tensor_type = onnx.TypeProto() - tensor_type.tensor_type.elem_type = onnx.TensorProto.FLOAT - tensor_type.tensor_type.shape.dim.extend( - [ - onnx.TensorShapeProto.Dimension(dim_value=3), - onnx.TensorShapeProto.Dimension(dim_value=4), - onnx.TensorShapeProto.Dimension(dim_value=5), - ] - ) - tt = get_type(tensor_type) - - assert isa(tt, TensorType[Float32Type]) - - assert tt.get_num_dims() == 3 - assert tt.get_shape() == (3, 4, 5) - assert tt.get_element_type().name == "f32" - - -def test_get_shape(): - assert get_shape(TensorShapeProto()) == () - assert get_shape( - TensorShapeProto(dim=(TensorShapeProto.Dimension(dim_value=1),)) - ) == (1,) - assert get_shape( - TensorShapeProto( - dim=( - TensorShapeProto.Dimension(dim_value=1), - TensorShapeProto.Dimension(dim_value=2), - ) - ) - ) == (1, 2) - - -def test_get_tensor_type(): - assert get_tensor_type( - TypeProto.Tensor( - elem_type=1, - shape=TensorShapeProto( - dim=( - TensorShapeProto.Dimension(dim_value=2), - TensorShapeProto.Dimension(dim_value=3), - ) - ), - ) - ) == TensorType(f32, (2, 3)) - assert get_tensor_type( - TypeProto.Tensor( - elem_type=11, - shape=TensorShapeProto( - dim=( - TensorShapeProto.Dimension(dim_value=4), - TensorShapeProto.Dimension(dim_value=5), - ) - ), - ) - ) == TensorType(f64, (4, 5)) diff --git a/tests/interpreters/test_onnx_interpreter.py b/tests/interpreters/test_onnx_interpreter.py deleted file mode 100644 index 01c1e11013..0000000000 --- a/tests/interpreters/test_onnx_interpreter.py +++ /dev/null @@ -1,554 +0,0 @@ -import pytest - -from xdsl.dialects import onnx -from xdsl.dialects.builtin import ( - AnyIntegerAttr, - ArrayAttr, - DenseIntOrFPElementsAttr, - FloatAttr, - ModuleOp, - NoneType, - StringAttr, - TensorType, - f32, - i64, -) -from xdsl.interpreter import Interpreter -from xdsl.interpreters.builtin import BuiltinFunctions -from xdsl.interpreters.shaped_array import ShapedArray -from xdsl.interpreters.utils.ptr import TypedPtr -from xdsl.utils.exceptions import InterpretationError -from xdsl.utils.test_value import TestSSAValue - -pytest.importorskip("numpy", reason="numpy is an optional dependency in xDSL") - -from xdsl.interpreters.onnx import OnnxFunctions # noqa: E402 - - -def test_onnx_add(): - interpreter = Interpreter(ModuleOp([])) - interpreter.register_implementations(OnnxFunctions()) - op = onnx.AddOp( - TestSSAValue(TensorType(f32, [2, 3])), - TestSSAValue(TensorType(f32, [2, 3])), - res_type=TensorType(f32, [2, 3]), - ) - - a = ShapedArray(TypedPtr.new_float32([1, 2, 3, 4, 5, 6]), [2, 3]) - b = ShapedArray(TypedPtr.new_float32([1, 4, 2, 5, 3, 6]), [2, 3]) - - (c,) = interpreter.run_op(op, (a, b)) - assert c == ShapedArray(TypedPtr.new_float32([2, 6, 5, 9, 8, 12]), [2, 3]) - - -def test_onnx_sub(): - interpreter = Interpreter(ModuleOp([])) - interpreter.register_implementations(OnnxFunctions()) - op = onnx.SubOp( - TestSSAValue(TensorType(f32, [2, 3])), - TestSSAValue(TensorType(f32, [2, 3])), - res_type=TensorType(f32, [2, 3]), - ) - - a = ShapedArray(TypedPtr.new_float32([1, 2, 3, 4, 5, 6]), [2, 3]) - b = ShapedArray(TypedPtr.new_float32([1, 4, 2, 5, 3, 6]), [2, 3]) - - (c,) = interpreter.run_op(op, (a, b)) - assert c == ShapedArray(TypedPtr.new_float32([0, -2, 1, -1, 2, 0]), [2, 3]) - - -def test_onnx_mul(): - interpreter = Interpreter(ModuleOp([])) - interpreter.register_implementations(OnnxFunctions()) - op = onnx.MulOp( - TestSSAValue(TensorType(f32, [2, 2])), - TestSSAValue(TensorType(f32, [2, 2])), - res_type=TensorType(f32, [2, 2]), - ) - - a = ShapedArray(TypedPtr.new_float32([1, 4, 7, 1]), [2, 2]) - b = ShapedArray(TypedPtr.new_float32([2, 3, 1, 8]), [2, 2]) - - (c,) = interpreter.run_op(op, (a, b)) - assert c == ShapedArray(TypedPtr.new_float32([2, 12, 7, 8]), [2, 2]) - - -def test_onnx_div(): - interpreter = Interpreter(ModuleOp([])) - interpreter.register_implementations(OnnxFunctions()) - op = onnx.DivOp( - TestSSAValue(TensorType(f32, [2, 2])), - TestSSAValue(TensorType(f32, [2, 2])), - res_type=TensorType(f32, [2, 2]), - ) - - a = ShapedArray(TypedPtr.new_float32([1, 1, 1, 1]), [2, 2]) - b = ShapedArray(TypedPtr.new_float32([5, 2, 1, 2]), [2, 2]) - - (c,) = interpreter.run_op(op, (a, b)) - assert c == ShapedArray(TypedPtr.new_float32([0.2, 0.5, 1.0, 0.5]), [2, 2]) - - -def test_onnx_relu(): - interpreter = Interpreter(ModuleOp([])) - interpreter.register_implementations(OnnxFunctions()) - op = onnx.ReluOp( - TestSSAValue(TensorType(f32, [2, 2])), - ) - - a = ShapedArray(TypedPtr.new_float32([-1, 0, 1, 2]), [2, 2]) - (b,) = interpreter.run_op(op, (a,)) - assert b == ShapedArray(TypedPtr.new_float32([-0.0, 0, 1, 2]), [2, 2]) - - -def test_onnx_constant(): - interpreter = Interpreter(ModuleOp([])) - interpreter.register_implementations(OnnxFunctions()) - interpreter.register_implementations(BuiltinFunctions()) - op = onnx.ConstantOp( - ( - DenseIntOrFPElementsAttr.create_dense_int( - TensorType(i64, [4]), [5, 5, 16, 2] - ) - ), - None, - None, - None, - None, - None, - None, - output_type=TensorType(i64, [4]), - ) - - (a,) = interpreter.run_op(op, ()) - assert a == ShapedArray(TypedPtr.new_int64([5, 5, 16, 2]), [4]) - - -def test_onnx_reshape(): - interpreter = Interpreter(ModuleOp([])) - interpreter.register_implementations(OnnxFunctions()) - op = onnx.ReshapeOp( - (TestSSAValue(TensorType(f32, [1, 10]))), - (TestSSAValue(TensorType(i64, [2]))), - AnyIntegerAttr(0, i64), - ) - a = ShapedArray(TypedPtr.new_float32([1, 2, 3, 4, 5, 6, 7, 8, 9, 10]), [1, 10]) - b = ShapedArray(TypedPtr.new_float32([1, 10]), [2]) - (c,) = interpreter.run_op(op, (a, b)) - assert c == ShapedArray( - TypedPtr.new_float32([1, 2, 3, 4, 5, 6, 7, 8, 9, 10]), [1, 10] - ) - - -def test_onnx_reshape_error(): - interpreter = Interpreter(ModuleOp([])) - interpreter.register_implementations(OnnxFunctions()) - op = onnx.ReshapeOp( - (TestSSAValue(TensorType(f32, [1, 10]))), - (TestSSAValue(TensorType(i64, [2]))), - AnyIntegerAttr(0, i64), - ) - a = ShapedArray(TypedPtr.new_float32([1, 2, 3, 4]), [1, 4]) - b = ShapedArray(TypedPtr.new_float32([2, 2]), [2]) - with pytest.raises( - InterpretationError, match="Mismatch between static shape and new shape" - ): - interpreter.run_op(op, (a, b)) - - -def test_onnx_gemm(): - interpreter = Interpreter(ModuleOp([])) - interpreter.register_implementations(OnnxFunctions()) - op = onnx.GemmOp( - TestSSAValue(TensorType(f32, [2, 2])), - TestSSAValue(TensorType(f32, [2, 2])), - TestSSAValue(TensorType(f32, [2, 2])), - FloatAttr(1, f32), - AnyIntegerAttr(0, i64), - AnyIntegerAttr(0, i64), - FloatAttr(1, f32), - ) - - a = ShapedArray(TypedPtr.new_float32([1, 2, 3, 4]), [2, 2]) - b = ShapedArray(TypedPtr.new_float32([2, 4, 6, 8]), [2, 2]) - c = ShapedArray(TypedPtr.new_float32([1, 1, 1, 1]), [2, 2]) - (d,) = interpreter.run_op(op, (a, b, c)) - assert d == ShapedArray(TypedPtr.new_float32([15, 21, 31, 45]), [2, 2]) - - -def test_onnx_gemm_transpose_b(): - interpreter = Interpreter(ModuleOp([])) - interpreter.register_implementations(OnnxFunctions()) - op = onnx.GemmOp( - TestSSAValue(TensorType(f32, [2, 1])), - TestSSAValue(TensorType(f32, [2, 1])), - TestSSAValue(TensorType(f32, [2, 2])), - FloatAttr(1, f32), - AnyIntegerAttr(0, i64), - AnyIntegerAttr(1, i64), - FloatAttr(1, f32), - ) - - a = ShapedArray(TypedPtr.new_float32([1, 2]), [2, 1]) - b = ShapedArray(TypedPtr.new_float32([4, 9]), [2, 1]) - c = ShapedArray(TypedPtr.new_float32([1, 2, 3, 4]), [2, 2]) - (d,) = interpreter.run_op(op, (a, b, c)) - assert d == ShapedArray(TypedPtr.new_float32([5, 11, 11, 22]), [2, 2]) - - -def test_onnx_gemm_alpha(): - interpreter = Interpreter(ModuleOp([])) - interpreter.register_implementations(OnnxFunctions()) - op = onnx.GemmOp( - TestSSAValue(TensorType(f32, [2, 1])), - TestSSAValue(TensorType(f32, [1, 2])), - TestSSAValue(TensorType(f32, [2, 2])), - FloatAttr(2, f32), - AnyIntegerAttr(0, i64), - AnyIntegerAttr(0, i64), - FloatAttr(1, f32), - ) - - a = ShapedArray(TypedPtr.new_float32([1, 2]), [2, 1]) - b = ShapedArray(TypedPtr.new_float32([4, 9]), [1, 2]) - c = ShapedArray(TypedPtr.new_float32([1, 2, 3, 4]), [2, 2]) - (d,) = interpreter.run_op(op, (a, b, c)) - assert d == ShapedArray(TypedPtr.new_float32([9, 20, 19, 40]), [2, 2]) - - -def test_onnx_conv_no_padding(): - interpreter = Interpreter(ModuleOp([])) - interpreter.register_implementations(OnnxFunctions()) - op = onnx.ConvOp( - TestSSAValue(TensorType(f32, [1, 1, 5, 5])), - TestSSAValue(TensorType(f32, [1, 1, 3, 3])), - TestSSAValue(NoneType()), - StringAttr("NOTSET"), - ArrayAttr([AnyIntegerAttr(1, i64), AnyIntegerAttr(1, i64)]), - AnyIntegerAttr(1, i64), - ArrayAttr([AnyIntegerAttr(3, i64), AnyIntegerAttr(3, i64)]), - ArrayAttr( - [ - AnyIntegerAttr(0, i64), - AnyIntegerAttr(0, i64), - AnyIntegerAttr(0, i64), - AnyIntegerAttr(0, i64), - ] - ), - ArrayAttr([AnyIntegerAttr(1, i64), AnyIntegerAttr(1, i64)]), - ) - a = ShapedArray(TypedPtr.new_float32(range(25)), [1, 1, 5, 5]) - b = ShapedArray( - TypedPtr.new_float32( - [ - 1, - ] - * 9 - ), - [1, 1, 3, 3], - ) - c = ShapedArray(TypedPtr.new_float32([0]), [1]) - (d,) = interpreter.run_op(op, (a, b, c)) - - assert d == ShapedArray( - TypedPtr.new_float32([54, 63, 72, 99, 108, 117, 144, 153, 162]), [1, 1, 3, 3] - ) - - -def test_onnx_conv_with_padding(): - interpreter = Interpreter(ModuleOp([])) - interpreter.register_implementations(OnnxFunctions()) - op = onnx.ConvOp( - TestSSAValue(TensorType(f32, [1, 1, 5, 5])), - TestSSAValue(TensorType(f32, [1, 1, 3, 3])), - TestSSAValue(NoneType()), - StringAttr("NOTSET"), - ArrayAttr([AnyIntegerAttr(1, i64), AnyIntegerAttr(1, i64)]), - AnyIntegerAttr(1, i64), - ArrayAttr([AnyIntegerAttr(3, i64), AnyIntegerAttr(3, i64)]), - ArrayAttr( - [ - AnyIntegerAttr(1, i64), - AnyIntegerAttr(1, i64), - AnyIntegerAttr(1, i64), - AnyIntegerAttr(1, i64), - ] - ), - ArrayAttr([AnyIntegerAttr(1, i64), AnyIntegerAttr(1, i64)]), - ) - a = ShapedArray(TypedPtr.new_float32(range(25)), [1, 1, 5, 5]) - b = ShapedArray( - TypedPtr.new_float32( - [ - 1, - ] - * 9 - ), - [1, 1, 3, 3], - ) - c = ShapedArray(TypedPtr.new_float32([0]), [1]) - (d,) = interpreter.run_op(op, (a, b, c)) - - assert d == ShapedArray( - TypedPtr.new_float32( - [ - 12.0, - 21.0, - 27.0, - 33.0, - 24.0, - 33.0, - 54.0, - 63.0, - 72.0, - 51.0, - 63.0, - 99.0, - 108.0, - 117.0, - 81.0, - 93.0, - 144.0, - 153.0, - 162.0, - 111.0, - 72.0, - 111.0, - 117.0, - 123.0, - 84.0, - ] - ), - [1, 1, 5, 5], - ) - - -def test_onnx_conv_with_same_lower_strides(): - interpreter = Interpreter(ModuleOp([])) - interpreter.register_implementations(OnnxFunctions()) - op = onnx.ConvOp( - TestSSAValue(TensorType(f32, [1, 1, 5, 5])), - TestSSAValue(TensorType(f32, [1, 1, 3, 3])), - TestSSAValue(NoneType()), - StringAttr("SAME_LOWER"), - ArrayAttr([AnyIntegerAttr(1, i64), AnyIntegerAttr(1, i64)]), - AnyIntegerAttr(1, i64), - ArrayAttr([AnyIntegerAttr(3, i64), AnyIntegerAttr(3, i64)]), - ArrayAttr( - [ - AnyIntegerAttr(0, i64), - AnyIntegerAttr(0, i64), - AnyIntegerAttr(0, i64), - AnyIntegerAttr(0, i64), - ] - ), - ArrayAttr([AnyIntegerAttr(2, i64), AnyIntegerAttr(2, i64)]), - ) - a = ShapedArray(TypedPtr.new_float32(range(25)), [1, 1, 5, 5]) - b = ShapedArray( - TypedPtr.new_float32( - [ - 1, - ] - * 9 - ), - [1, 1, 3, 3], - ) - c = ShapedArray(TypedPtr.new_float32([0]), [1]) - (d,) = interpreter.run_op(op, (a, b, c)) - - assert d == ShapedArray( - TypedPtr.new_float32([12.0, 27.0, 24.0, 63.0, 108.0, 81.0, 72.0, 117.0, 84.0]), - [1, 1, 3, 3], - ) - - -def test_onnx_conv_with_strides_padding(): - interpreter = Interpreter(ModuleOp([])) - interpreter.register_implementations(OnnxFunctions()) - op = onnx.ConvOp( - TestSSAValue(TensorType(f32, [1, 1, 5, 5])), - TestSSAValue(TensorType(f32, [1, 1, 3, 3])), - TestSSAValue(NoneType()), - StringAttr("NOTSET"), - ArrayAttr([AnyIntegerAttr(1, i64), AnyIntegerAttr(1, i64)]), - AnyIntegerAttr(1, i64), - ArrayAttr([AnyIntegerAttr(3, i64), AnyIntegerAttr(3, i64)]), - ArrayAttr( - [ - AnyIntegerAttr(1, i64), - AnyIntegerAttr(1, i64), - AnyIntegerAttr(1, i64), - AnyIntegerAttr(1, i64), - ] - ), - ArrayAttr([AnyIntegerAttr(2, i64), AnyIntegerAttr(2, i64)]), - ) - a = ShapedArray(TypedPtr.new_float32(range(35)), [1, 1, 7, 5]) - b = ShapedArray( - TypedPtr.new_float32( - [ - 1, - ] - * 9 - ), - [1, 1, 3, 3], - ) - c = ShapedArray(TypedPtr.new_float32([0]), [1]) - (d,) = interpreter.run_op(op, (a, b, c)) - - assert d == ShapedArray( - TypedPtr.new_float32( - [ - 12.0, - 27.0, - 24.0, - 63.0, - 108.0, - 81.0, - 123.0, - 198.0, - 141.0, - 112.0, - 177.0, - 124.0, - ] - ), - [1, 1, 4, 3], - ) - - -def test_onnx_conv_with_strides_no_padding(): - interpreter = Interpreter(ModuleOp([])) - interpreter.register_implementations(OnnxFunctions()) - op = onnx.ConvOp( - TestSSAValue(TensorType(f32, [1, 1, 5, 5])), - TestSSAValue(TensorType(f32, [1, 1, 3, 3])), - TestSSAValue(NoneType()), - StringAttr("NOTSET"), - ArrayAttr([AnyIntegerAttr(1, i64), AnyIntegerAttr(1, i64)]), - AnyIntegerAttr(1, i64), - ArrayAttr([AnyIntegerAttr(3, i64), AnyIntegerAttr(3, i64)]), - ArrayAttr( - [ - AnyIntegerAttr(0, i64), - AnyIntegerAttr(0, i64), - AnyIntegerAttr(0, i64), - AnyIntegerAttr(0, i64), - ] - ), - ArrayAttr([AnyIntegerAttr(2, i64), AnyIntegerAttr(2, i64)]), - ) - a = ShapedArray(TypedPtr.new_float32(range(35)), [1, 1, 7, 5]) - b = ShapedArray( - TypedPtr.new_float32( - [ - 1, - ] - * 9 - ), - [1, 1, 3, 3], - ) - c = ShapedArray(TypedPtr.new_float32([0]), [1]) - (d,) = interpreter.run_op(op, (a, b, c)) - - assert d == ShapedArray( - TypedPtr.new_float32([54.0, 72.0, 144.0, 162.0, 234.0, 252.0]), [1, 1, 3, 2] - ) - - -def test_onnx_conv_with_strides_asy_padding(): - interpreter = Interpreter(ModuleOp([])) - interpreter.register_implementations(OnnxFunctions()) - op = onnx.ConvOp( - TestSSAValue(TensorType(f32, [1, 1, 5, 5])), - TestSSAValue(TensorType(f32, [1, 1, 3, 3])), - TestSSAValue(NoneType()), - StringAttr("NOTSET"), - ArrayAttr([AnyIntegerAttr(1, i64), AnyIntegerAttr(1, i64)]), - AnyIntegerAttr(1, i64), - ArrayAttr([AnyIntegerAttr(3, i64), AnyIntegerAttr(3, i64)]), - ArrayAttr( - [ - AnyIntegerAttr(1, i64), - AnyIntegerAttr(0, i64), - AnyIntegerAttr(1, i64), - AnyIntegerAttr(0, i64), - ] - ), - ArrayAttr([AnyIntegerAttr(2, i64), AnyIntegerAttr(2, i64)]), - ) - a = ShapedArray(TypedPtr.new_float32(range(35)), [1, 1, 7, 5]) - b = ShapedArray( - TypedPtr.new_float32( - [ - 1, - ] - * 9 - ), - [1, 1, 3, 3], - ) - c = ShapedArray(TypedPtr.new_float32([0]), [1]) - (d,) = interpreter.run_op(op, (a, b, c)) - - assert d == ShapedArray( - TypedPtr.new_float32([21.0, 33.0, 99.0, 117.0, 189.0, 207.0, 171.0, 183.0]), - [1, 1, 4, 2], - ) - - -def test_onnx_max_pool_single_out(): - interpreter = Interpreter(ModuleOp([])) - interpreter.register_implementations(OnnxFunctions()) - op = onnx.MaxPoolSingleOutOp( - TestSSAValue(TensorType(f32, [1, 1, 4, 4])), - StringAttr("NOTSET"), - AnyIntegerAttr(0, i64), - ArrayAttr([AnyIntegerAttr(1, i64), AnyIntegerAttr(1, i64)]), - ArrayAttr([AnyIntegerAttr(2, i64), AnyIntegerAttr(2, i64)]), - pads=ArrayAttr( - [ - AnyIntegerAttr(0, i64), - AnyIntegerAttr(0, i64), - AnyIntegerAttr(0, i64), - AnyIntegerAttr(0, i64), - ] - ), - storage_order=AnyIntegerAttr(0, i64), - strides=ArrayAttr([AnyIntegerAttr(1, i64), AnyIntegerAttr(1, i64)]), - ) - a = ShapedArray(TypedPtr.new_float32(range(1, 17)), [1, 1, 4, 4]) - (b,) = interpreter.run_op(op, (a,)) - - assert b == ShapedArray( - TypedPtr.new_float32([6, 7, 8, 10, 11, 12, 14, 15, 16]), [1, 1, 3, 3] - ) - - -def test_onnx_max_pool_single_out_strides_two(): - interpreter = Interpreter(ModuleOp([])) - interpreter.register_implementations(OnnxFunctions()) - op = onnx.MaxPoolSingleOutOp( - TestSSAValue(TensorType(f32, [1, 1, 4, 4])), - StringAttr("NOTSET"), - AnyIntegerAttr(0, i64), - ArrayAttr([AnyIntegerAttr(1, i64), AnyIntegerAttr(1, i64)]), - ArrayAttr([AnyIntegerAttr(2, i64), AnyIntegerAttr(2, i64)]), - pads=ArrayAttr( - [ - AnyIntegerAttr(0, i64), - AnyIntegerAttr(0, i64), - AnyIntegerAttr(0, i64), - AnyIntegerAttr(0, i64), - ] - ), - storage_order=AnyIntegerAttr(0, i64), - strides=ArrayAttr([AnyIntegerAttr(2, i64), AnyIntegerAttr(2, i64)]), - ) - a = ShapedArray( - TypedPtr.new_float32([1, 1, 2, 4, 5, 6, 7, 8, 3, 2, 1, 0, 1, 2, 3, 4]), - [1, 1, 4, 4], - ) - (b,) = interpreter.run_op(op, (a,)) - - assert b == ShapedArray(TypedPtr.new_float32([6, 8, 3, 4]), [1, 1, 2, 2]) diff --git a/uv.lock b/uv.lock index 694dd59a9c..c7a8751286 100644 --- a/uv.lock +++ b/uv.lock @@ -1373,33 +1373,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/13/3e/1959d5219a9e6d200638d924cedda6a606392f7186a4ed56478252e70d55/numpy-2.2.1-pp310-pypy310_pp73-win_amd64.whl", hash = "sha256:5c5cc0cbabe9452038ed984d05ac87910f89370b9242371bd9079cb4af61811e", size = 12820057 }, ] -[[package]] -name = "onnx" -version = "1.17.0" -source = { registry = "https://pypi.org/simple" } -dependencies = [ - { name = "numpy" }, - { name = "protobuf" }, -] -sdist = { url = "https://files.pythonhosted.org/packages/9a/54/0e385c26bf230d223810a9c7d06628d954008a5e5e4b73ee26ef02327282/onnx-1.17.0.tar.gz", hash = "sha256:48ca1a91ff73c1d5e3ea2eef20ae5d0e709bb8a2355ed798ffc2169753013fd3", size = 12165120 } -wheels = [ - { url = "https://files.pythonhosted.org/packages/2e/29/57053ba7787788ac75efb095cfc1ae290436b6d3a26754693cd7ed1b4fac/onnx-1.17.0-cp310-cp310-macosx_12_0_universal2.whl", hash = "sha256:38b5df0eb22012198cdcee527cc5f917f09cce1f88a69248aaca22bd78a7f023", size = 16645616 }, - { url = "https://files.pythonhosted.org/packages/75/0d/831807a18db2a5e8f7813848c59272b904a4ef3939fe4d1288cbce9ea735/onnx-1.17.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d545335cb49d4d8c47cc803d3a805deb7ad5d9094dc67657d66e568610a36d7d", size = 15908420 }, - { url = "https://files.pythonhosted.org/packages/dd/5b/c4f95dbe652d14aeba9afaceb177e9ffc48ac3c03048dd3f872f26f07e34/onnx-1.17.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3193a3672fc60f1a18c0f4c93ac81b761bc72fd8a6c2035fa79ff5969f07713e", size = 16046244 }, - { url = "https://files.pythonhosted.org/packages/08/a9/c1f218085043dccc6311460239e253fa6957cf12ee4b0a56b82014938d0b/onnx-1.17.0-cp310-cp310-win32.whl", hash = "sha256:0141c2ce806c474b667b7e4499164227ef594584da432fd5613ec17c1855e311", size = 14423516 }, - { url = "https://files.pythonhosted.org/packages/0e/d3/d26ebf590a65686dde6b27fef32493026c5be9e42083340d947395f93405/onnx-1.17.0-cp310-cp310-win_amd64.whl", hash = "sha256:dfd777d95c158437fda6b34758f0877d15b89cbe9ff45affbedc519b35345cf9", size = 14528496 }, - { url = "https://files.pythonhosted.org/packages/e5/a9/8d1b1d53aec70df53e0f57e9f9fcf47004276539e29230c3d5f1f50719ba/onnx-1.17.0-cp311-cp311-macosx_12_0_universal2.whl", hash = "sha256:d6fc3a03fc0129b8b6ac03f03bc894431ffd77c7d79ec023d0afd667b4d35869", size = 16647991 }, - { url = "https://files.pythonhosted.org/packages/7b/e3/cc80110e5996ca61878f7b4c73c7a286cd88918ff35eacb60dc75ab11ef5/onnx-1.17.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f01a4b63d4e1d8ec3e2f069e7b798b2955810aa434f7361f01bc8ca08d69cce4", size = 15908949 }, - { url = "https://files.pythonhosted.org/packages/b1/2f/91092557ed478e323a2b4471e2081fdf88d1dd52ae988ceaf7db4e4506ff/onnx-1.17.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:4a183c6178be001bf398260e5ac2c927dc43e7746e8638d6c05c20e321f8c949", size = 16048190 }, - { url = "https://files.pythonhosted.org/packages/ac/59/9ea23fc22d0bb853133f363e6248e31bcbc6c1c90543a3938c00412ac02a/onnx-1.17.0-cp311-cp311-win32.whl", hash = "sha256:081ec43a8b950171767d99075b6b92553901fa429d4bc5eb3ad66b36ef5dbe3a", size = 14424299 }, - { url = "https://files.pythonhosted.org/packages/51/a5/19b0dfcb567b62e7adf1a21b08b23224f0c2d13842aee4d0abc6f07f9cf5/onnx-1.17.0-cp311-cp311-win_amd64.whl", hash = "sha256:95c03e38671785036bb704c30cd2e150825f6ab4763df3a4f1d249da48525957", size = 14529142 }, - { url = "https://files.pythonhosted.org/packages/b4/dd/c416a11a28847fafb0db1bf43381979a0f522eb9107b831058fde012dd56/onnx-1.17.0-cp312-cp312-macosx_12_0_universal2.whl", hash = "sha256:0e906e6a83437de05f8139ea7eaf366bf287f44ae5cc44b2850a30e296421f2f", size = 16651271 }, - { url = "https://files.pythonhosted.org/packages/f0/6c/f040652277f514ecd81b7251841f96caa5538365af7df07f86c6018cda2b/onnx-1.17.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:3d955ba2939878a520a97614bcf2e79c1df71b29203e8ced478fa78c9a9c63c2", size = 15907522 }, - { url = "https://files.pythonhosted.org/packages/3d/7c/67f4952d1b56b3f74a154b97d0dd0630d525923b354db117d04823b8b49b/onnx-1.17.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:4f3fb5cc4e2898ac5312a7dc03a65133dd2abf9a5e520e69afb880a7251ec97a", size = 16046307 }, - { url = "https://files.pythonhosted.org/packages/ae/20/6da11042d2ab870dfb4ce4a6b52354d7651b6b4112038b6d2229ab9904c4/onnx-1.17.0-cp312-cp312-win32.whl", hash = "sha256:317870fca3349d19325a4b7d1b5628f6de3811e9710b1e3665c68b073d0e68d7", size = 14424235 }, - { url = "https://files.pythonhosted.org/packages/35/55/c4d11bee1fdb0c4bd84b4e3562ff811a19b63266816870ae1f95567aa6e1/onnx-1.17.0-cp312-cp312-win_amd64.whl", hash = "sha256:659b8232d627a5460d74fd3c96947ae83db6d03f035ac633e20cd69cfa029227", size = 14530453 }, -] - [[package]] name = "opt-einsum" version = "3.4.0" @@ -1591,20 +1564,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/41/b6/c5319caea262f4821995dca2107483b94a3345d4607ad797c76cb9c36bcc/propcache-0.2.1-py3-none-any.whl", hash = "sha256:52277518d6aae65536e9cea52d4e7fd2f7a66f4aa2d30ed3f2fcea620ace3c54", size = 11818 }, ] -[[package]] -name = "protobuf" -version = "5.29.3" -source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/f7/d1/e0a911544ca9993e0f17ce6d3cc0932752356c1b0a834397f28e63479344/protobuf-5.29.3.tar.gz", hash = "sha256:5da0f41edaf117bde316404bad1a486cb4ededf8e4a54891296f648e8e076620", size = 424945 } -wheels = [ - { url = "https://files.pythonhosted.org/packages/dc/7a/1e38f3cafa022f477ca0f57a1f49962f21ad25850c3ca0acd3b9d0091518/protobuf-5.29.3-cp310-abi3-win32.whl", hash = "sha256:3ea51771449e1035f26069c4c7fd51fba990d07bc55ba80701c78f886bf9c888", size = 422708 }, - { url = "https://files.pythonhosted.org/packages/61/fa/aae8e10512b83de633f2646506a6d835b151edf4b30d18d73afd01447253/protobuf-5.29.3-cp310-abi3-win_amd64.whl", hash = "sha256:a4fa6f80816a9a0678429e84973f2f98cbc218cca434abe8db2ad0bffc98503a", size = 434508 }, - { url = "https://files.pythonhosted.org/packages/dd/04/3eaedc2ba17a088961d0e3bd396eac764450f431621b58a04ce898acd126/protobuf-5.29.3-cp38-abi3-macosx_10_9_universal2.whl", hash = "sha256:a8434404bbf139aa9e1300dbf989667a83d42ddda9153d8ab76e0d5dcaca484e", size = 417825 }, - { url = "https://files.pythonhosted.org/packages/4f/06/7c467744d23c3979ce250397e26d8ad8eeb2bea7b18ca12ad58313c1b8d5/protobuf-5.29.3-cp38-abi3-manylinux2014_aarch64.whl", hash = "sha256:daaf63f70f25e8689c072cfad4334ca0ac1d1e05a92fc15c54eb9cf23c3efd84", size = 319573 }, - { url = "https://files.pythonhosted.org/packages/a8/45/2ebbde52ad2be18d3675b6bee50e68cd73c9e0654de77d595540b5129df8/protobuf-5.29.3-cp38-abi3-manylinux2014_x86_64.whl", hash = "sha256:c027e08a08be10b67c06bf2370b99c811c466398c357e615ca88c91c07f0910f", size = 319672 }, - { url = "https://files.pythonhosted.org/packages/fd/b2/ab07b09e0f6d143dfb839693aa05765257bceaa13d03bf1a696b78323e7a/protobuf-5.29.3-py3-none-any.whl", hash = "sha256:0a18ed4a24198528f2333802eb075e59dea9d679ab7a6c5efb017a59004d849f", size = 172550 }, -] - [[package]] name = "psutil" version = "6.1.1" @@ -2536,10 +2495,6 @@ jax = [ { name = "jax" }, { name = "numpy" }, ] -onnx = [ - { name = "numpy" }, - { name = "onnx" }, -] riscv = [ { name = "riscemu" }, ] @@ -2557,8 +2512,6 @@ requires-dist = [ { name = "nbconvert", marker = "extra == 'dev'", specifier = ">=7.7.2,<8.0.0" }, { name = "nbval", marker = "extra == 'dev'", specifier = "<0.12" }, { name = "numpy", marker = "extra == 'jax'", specifier = "==2.2.1" }, - { name = "numpy", marker = "extra == 'onnx'", specifier = "==2.2.1" }, - { name = "onnx", marker = "extra == 'onnx'", specifier = "==1.17.0" }, { name = "ordered-set", specifier = "==4.1.0" }, { name = "pip", marker = "extra == 'dev'", specifier = "<25.0" }, { name = "pre-commit", marker = "extra == 'dev'", specifier = "==4.0.1" }, diff --git a/xdsl/dialects/__init__.py b/xdsl/dialects/__init__.py index 256f28802c..0c8eb9e589 100644 --- a/xdsl/dialects/__init__.py +++ b/xdsl/dialects/__init__.py @@ -178,11 +178,6 @@ def get_omp(): return OMP - def get_onnx(): - from xdsl.dialects.onnx import ONNX - - return ONNX - def get_pdl(): from xdsl.dialects.pdl import PDL @@ -363,7 +358,6 @@ def get_transform(): "mod_arith": get_mod_arith, "mpi": get_mpi, "omp": get_omp, - "onnx": get_onnx, "pdl": get_pdl, "printf": get_printf, "ptr_xdsl": get_ptr_xdsl, diff --git a/xdsl/dialects/onnx.py b/xdsl/dialects/onnx.py deleted file mode 100644 index 346975baea..0000000000 --- a/xdsl/dialects/onnx.py +++ /dev/null @@ -1,1106 +0,0 @@ -from __future__ import annotations - -import math -from abc import ABC -from typing import Annotated, ClassVar, cast - -from typing_extensions import Self - -from xdsl.dialects.builtin import ( - Any, - AnyFloat, - AnyFloatConstr, - AnyIntegerAttr, - AnyTensorType, - ArrayAttr, - DenseIntOrFPElementsAttr, - Float32Type, - FloatAttr, - IntegerAttr, - IntegerType, - MemRefType, - NoneType, - SSAValue, - StringAttr, - SymbolRefAttr, - TensorOrMemrefOf, - TensorType, -) -from xdsl.ir import ( - Attribute, - Dialect, -) -from xdsl.irdl import ( - ConstraintVar, - IRDLOperation, - VarConstraint, - attr_def, - base, - irdl_op_definition, - operand_def, - opt_attr_def, - result_def, -) -from xdsl.parser import Parser -from xdsl.printer import Printer -from xdsl.utils.exceptions import VerifyException - - -def verify_unidirectional_broadcast_shape( - lhs: TensorType[Attribute], rhs: TensorType[Attribute], res: TensorType[Attribute] -) -> None: - """ - Returns a unidirectional broadcastable shape - """ - lhs_shape = lhs.get_shape() - rhs_shape = rhs.get_shape() - expected_shape = unidirectional_broadcast_shape(list(lhs_shape), list(rhs_shape)) - if expected_shape is None: - raise VerifyException( - f"operands have incompatible shapes: {lhs_shape} and {rhs_shape}" - ) - res_type_shape = res.get_shape() - if ( - len(expected_shape) != len(res_type_shape) - or tuple(expected_shape) != res_type_shape - ): - raise VerifyException( - f"result shape {expected_shape} does not match result type {res}" - ) - - -def verify_multidirectional_broadcast_shape( - lhs: TensorType[Attribute], rhs: TensorType[Attribute], res: TensorType[Attribute] -) -> None: - """ - Returns a multidirectional broadcastable shape - """ - lhs_shape = lhs.get_shape() - rhs_shape = rhs.get_shape() - expected_shape = multidirectional_broadcast_shape(list(lhs_shape), list(rhs_shape)) - if expected_shape is None: - raise VerifyException( - f"operands have incompatible shapes: {lhs_shape} and {rhs_shape}" - ) - res_type_shape = res.get_shape() - if ( - len(expected_shape) != len(res_type_shape) - or tuple(expected_shape) != res_type_shape - ): - raise VerifyException( - f"result shape {expected_shape} does not match result type {res}" - ) - - -def unidirectional_broadcast_shape(lhs: list[int], rhs: list[int]) -> list[int] | None: - """ - In ONNX, tensor B is unidirectional broadcastable to tensor A if one of the following is true: - - 1. Tensor A and B both have exactly the same shape. - 2. Tensor A and B all have the same number of dimensions and - the length of each dimensions is either a common length or B's length is 1. - - 3.Tensor B has too few dimensions, and B can have its shapes prepended with a dimension of length 1 to satisfy - property 2. - - When unidirectional broadcasting happens, the output's shape is the same as the shape of A (i.e., - the larger shape of two input tensors) - - https://github.com/onnx/onnx/blob/main/docs/Broadcasting.md - """ - # Check if Tensor A and B both have exactly the same shape - if lhs == rhs: - return lhs - - lhs_len = len(lhs) - rhs_len = len(rhs) - prefix_len = lhs_len - rhs_len - if prefix_len < 0: - # lhs must not be shorter than rhs - return None - res_shape = lhs[:prefix_len] - for dl, dr in zip(lhs[prefix_len:], rhs): - if dl == dr or dr == 1 or dl == 1: - res_shape.append(max(dl, dr)) - else: - return None - return res_shape - - -def multidirectional_broadcast_shape( - lhs: list[int], rhs: list[int] -) -> list[int] | None: - """ - In ONNX, a set of tensors are multidirectional broadcastable to the same shape if one of the following is true: - - 1.The tensors all have exactly the same shape. - 2.The tensors all have the same number of dimensions and the length of each dimensions is either a common length or 1. - 3.The tensors that have too few dimensions can have their shapes prepended with a dimension of length 1 to satisfy property 2. - - https://github.com/onnx/onnx/blob/main/docs/Broadcasting.md - """ - - if len(lhs) > len(rhs): - return unidirectional_broadcast_shape(rhs, lhs) - else: - return unidirectional_broadcast_shape(lhs, rhs) - - -class ElementwiseBinOpBase(IRDLOperation, ABC): - """Base class for element-wise binary operations on tensors with Numpy-style broadcasting.""" - - T = Annotated[AnyFloat | IntegerType, ConstraintVar("T")] - lhs = operand_def(TensorType[T]) - rhs = operand_def(TensorType[T]) - res = result_def(TensorType[T]) - assembly_format = "`(` $lhs `,` $rhs `)` attr-dict `:` `(` type($lhs) `,` type($rhs) `)` `->` type($res)" - - def __init__(self, lhs: SSAValue, rhs: SSAValue, res_type: Attribute): - super().__init__( - operands=[lhs, rhs], - result_types=[res_type], - ) - - def verify_(self) -> None: - # Check that the arguments are broadcastable (using Numpy semantics) and that the result type is correct. - assert isinstance(lhs_type := self.lhs.type, TensorType) - assert isinstance(rhs_type := self.rhs.type, TensorType) - assert isinstance(res_type := self.res.type, TensorType) - verify_multidirectional_broadcast_shape(lhs_type, rhs_type, res_type) # pyright: ignore[reportUnknownArgumentType] - - -@irdl_op_definition -class AddOp(ElementwiseBinOpBase): - name = "onnx.Add" - - -@irdl_op_definition -class SubOp(ElementwiseBinOpBase): - name = "onnx.Sub" - - -@irdl_op_definition -class MulOp(ElementwiseBinOpBase): - name = "onnx.Mul" - - -@irdl_op_definition -class DivOp(ElementwiseBinOpBase): - name = "onnx.Div" - - -@irdl_op_definition -class ReluOp(IRDLOperation): - """ - Relu takes one input data (Tensor) and produces one output data (Tensor) where the rectified linear function, - y = max(0, x), is applied to the tensor elementwise. - """ - - name = "onnx.Relu" - T = Annotated[AnyFloat | IntegerType, ConstraintVar("T")] - operand = operand_def(TensorType[T]) - res = result_def(TensorType[T]) - assembly_format = ( - "`(` $operand`)` attr-dict `:` `(` type($operand) `)` `->` type($res)" - ) - - def __init__(self, operand: SSAValue): - super().__init__( - operands=[operand], - result_types=[operand.type], - ) - - def verify_(self) -> None: - assert isinstance(operand_type := self.operand.type, TensorType) - assert isinstance(res_type := self.res.type, TensorType) - operand_type = cast(TensorType[Attribute], operand_type) - res_type = cast(TensorType[Attribute], res_type) - - if operand_type != res_type: - raise VerifyException( - "Mismatch between operand type and res type of onnx.Relu" - ) - - -@irdl_op_definition -class GemmOp(IRDLOperation): - """ - General Matrix multiplication: https://en.wikipedia.org/wiki/Basic_Linear_Algebra_Subprograms#Level_3 - A' = transpose(A) if transA else A - B' = transpose(B) if transB else B - Compute Y = alpha * A' * B' + beta * C, - where input tensor A has shape (M, K) or (K, M), input tensor B has shape (K, N) or (N, K), - input tensor C is broadcastable to shape (M, N), and output tensor Y has shape (M, N). - """ - - name = "onnx.Gemm" - T = Annotated[AnyFloat | IntegerType, ConstraintVar("T")] - tensor_a = operand_def(TensorType[T]) - tensor_b = operand_def(TensorType[T]) - tensor_c = operand_def(TensorType[T]) - - alpha = opt_attr_def(FloatAttr[AnyFloat]) - beta = opt_attr_def(FloatAttr[AnyFloat]) - - trans_a = opt_attr_def(AnyIntegerAttr, attr_name="transA") - trans_b = opt_attr_def(AnyIntegerAttr, attr_name="transB") - - res_tensor = result_def(TensorType[T]) - assembly_format = ( - "`(` $tensor_a `,` $tensor_b `,`$tensor_c`)` attr-dict `:` `(` type($tensor_a) `," - "` type($tensor_b) `,`type($tensor_c)`)` `->` type($res_tensor) " - ) - - def __init__( - self, - tensor_a: SSAValue, - tensor_b: SSAValue, - tensor_c: SSAValue, - alpha: Attribute, - trans_a: Attribute, - trans_b: Attribute, - beta: Attribute, - ): - super().__init__( - attributes={ - "transA": trans_a, - "transB": trans_b, - "alpha": alpha, - "beta": beta, - }, - operands=[tensor_a, tensor_b, tensor_c], - result_types=[tensor_c.type], - ) - - def verify_(self) -> None: - assert isinstance(tensor_a_type := self.tensor_a.type, TensorType) - assert isinstance(tensor_b_type := self.tensor_b.type, TensorType) - assert isinstance(tensor_c_type := self.tensor_c.type, TensorType) - assert isinstance(res_tensor_type := self.res_tensor.type, TensorType) - - tensor_a_type = cast(TensorType[Attribute], tensor_a_type) - tensor_b_type = cast(TensorType[Attribute], tensor_b_type) - tensor_c_type = cast(TensorType[Attribute], tensor_c_type) - res_tensor_type = cast(TensorType[Attribute], res_tensor_type) - - # store dimensions of tensor A and tensor B - res_shape: list[int] = [] - tensor_a_shape = tensor_a_type.get_shape() - tensor_b_shape = tensor_b_type.get_shape() - - if tensor_a_type.get_num_dims() != 2: - raise VerifyException("tensor A should be a 2D tensor") - - if tensor_b_type.get_num_dims() != 2: - raise VerifyException("tensor B should be a 2D tensor") - - if self.trans_a is not None and self.trans_a.value.data == 1: - tensor_a_shape = tuple(reversed(tensor_a_shape)) - - if self.trans_b is not None and self.trans_b.value.data == 1: - tensor_b_shape = tuple(reversed(tensor_b_shape)) - - if self.beta is not None: - c_dims = tensor_c_type.get_num_dims() - if c_dims > 2: - raise VerifyException("tensor C should be a 1D tensor or 2D tensor") - - if tensor_a_shape[1] != tensor_b_shape[0]: - raise VerifyException( - f"operands have incompatible shapes: {tensor_a_shape} and {tensor_b_shape}" - ) - else: - res_shape.append(tensor_a_shape[0]) - res_shape.append(tensor_b_shape[1]) - - # Build tensor of tensor (A * B) computation - tensors_res = TensorType(tensor_a_type.element_type, res_shape) - verify_unidirectional_broadcast_shape( - tensors_res, tensor_c_type, res_tensor_type - ) - - -@irdl_op_definition -class ReshapeOp(IRDLOperation): - """ - Reshape the input tensor similar to numpy.reshape. - First input is the data tensor, second input is a shape tensor which specifies the output shape. It outputs the reshaped tensor. - At most one dimension of the new shape can be -1. In this case, the value is - inferred from the size of the tensor and the remaining dimensions. A dimension - could also be 0, in which case the actual dimension value is unchanged (i.e. taken - from the input tensor). Shape (second input) could be an empty shape, which means converting to a scalar. - The input tensor's shape and the output tensor's shape are required to have the same number of elements. - - Attributes: - - allowzero int (default is 0): By default, when any value in the 'shape' input is equal to zero - the corresponding dimension value is copied from the input tensor dynamically. allowzero=1 indicates that if any - value in the 'shape' input is set to zero, the zero value is honoured, similar to NumPy. - - """ - - name = "onnx.Reshape" - T = Annotated[AnyFloat | IntegerType, ConstraintVar("T")] - data = operand_def(TensorType[T]) - shape = operand_def(TensorType[IntegerType]) - reshaped = result_def(TensorType[T]) - - allow_zero = opt_attr_def(AnyIntegerAttr, attr_name="allowzero") - - assembly_format = "`(` $data `,` $shape `)` attr-dict `:` `(` type($data) `,` type($shape) `)` `->` type($reshaped)" - - def __init__(self, data: SSAValue, shape: SSAValue, allow_zero: Attribute): - super().__init__( - attributes={"allowzero": allow_zero}, - operands=[data, shape], - result_types=[data.type], - ) - - def verify_(self) -> None: - assert isinstance(data_type := self.data.type, TensorType) - assert isinstance(shape_type := self.shape.type, TensorType) - assert isinstance(reshaped_type := self.reshaped.type, TensorType) - - data_type = cast(TensorType[Attribute], data_type) - reshaped_type = cast(TensorType[Attribute], reshaped_type) - shape_type = cast(TensorType[Attribute], shape_type) - - if shape_type.element_type != IntegerType(64): - raise VerifyException( - "shape element type has to be a 64-bit signless integer" - ) - - data_type = data_type.get_shape() - shape_type = shape_type.get_shape() - reshaped_type = reshaped_type.get_shape() - - # Shape tensor rank can't be -1. - if shape_type[0] == -1: - raise VerifyException("Shape tensor rank must not be equal to -1") - - # There is currently only support for rank one shape tensors in onnx-mlir - # Shape tensor must have a constant shape. - if len(shape_type) != 1: - raise VerifyException("Shape tensor must have a rank one") - - # The input tensor's shape and the output tensor's shape are required to have the same number of elements. - if math.prod(data_type) != math.prod(reshaped_type): - raise VerifyException( - "Input tensor's shape and output tensor's shape must have the same number of elements" - ) - - -@irdl_op_definition -class AbsOp(IRDLOperation): - """ - Absolute takes one input data (Tensor) and produces one output data (Tensor) where absolute value, - y = abs(x), is applied to the tensor elementwise. - """ - - name = "onnx.Abs" - T = Annotated[AnyFloat | IntegerType, ConstraintVar("T")] - operand = operand_def(TensorType[T]) - res = result_def(TensorType[T]) - assembly_format = ( - "`(` $operand`)` attr-dict `:` `(` type($operand) `)` `->` type($res)" - ) - - def __init__(self, operand: SSAValue): - super().__init__( - operands=[operand], - result_types=[operand.type], - ) - - def verify_(self) -> None: - assert isinstance(operand_type := self.operand.type, TensorType) - assert isinstance(res_type := self.res.type, TensorType) - operand_type = cast(TensorType[Attribute], operand_type) - res_type = cast(TensorType[Attribute], res_type) - if operand_type != res_type: - raise VerifyException( - "Mismatch between operand type and res type of onnx.Abs" - ) - - -@irdl_op_definition -class ConvOp(IRDLOperation): - """ - The convolution operator consumes an input tensor and a filter, and computes the output. - - Attributes: - - - auto_pad string (default is NOTSET): auto_pad must be either NOTSET, SAME_UPPER, SAME_LOWER or - VALID. Where default value is NOTSET, which means explicit padding is used. SAME_UPPER or SAME_LOWER mean pad the - input so that output_shape[i] = ceil(input_shape[i] / strides[i]) for each axis i. The padding is split between - the two sides equally or almost equally (depending on whether it is even or odd). In case the padding is an odd - number, the extra padding is added at the end for SAME_UPPER and at the beginning for SAME_LOWER. - - -dilations list of ints: dilation value along each spatial axis of the filter. If not present, the dilation - defaults is 1 along each spatial axis. - - -group int (default is '1'): number of groups input channels and output channels are divided into. - - -kernel_shape list of ints: The shape of the convolution kernel. If not present, should be inferred from input W. - - -pads list of ints: Padding for the beginning and ending along each spatial axis, it can take any value greater - than or equal to 0. The value represent the number of pixels added to the beginning and end part of the - corresponding axis. `pads` format should be as follow [x1_begin, x2_begin...x1_end, x2_end,...], where xi_begin - the number of pixels added at the beginning of axis `i` and xi_end, the number of pixels added at the end of axis - `i`. This attribute cannot be used simultaneously with auto_pad attribute. If not present, the padding defaults - to 0 along start and end of each spatial axis. - - -strides list of ints: Stride along each spatial axis. If not present, the stride defaults is 1 along each spatial axis. - - """ - - name = "onnx.Conv" - T = Annotated[AnyFloat | IntegerType, ConstraintVar("T")] - data = operand_def(TensorType[T]) - weight = operand_def(TensorType[T]) - bias = operand_def(base(TensorType[T]) | base(NoneType)) - res = result_def(TensorType[T]) - - auto_pad = attr_def(StringAttr) - dilations = attr_def(ArrayAttr[AnyIntegerAttr]) - group = attr_def(AnyIntegerAttr) - kernel_shape = attr_def(ArrayAttr[AnyIntegerAttr]) - pads = attr_def(ArrayAttr[AnyIntegerAttr]) - strides = attr_def(ArrayAttr[AnyIntegerAttr]) - - assembly_format = ( - "`(` $data `,` $weight `,`$bias`)` attr-dict `:` `(` type($data) `," - "` type($weight) `,`type($bias)`)` `->` type($res) " - ) - - def __init__( - self, - data: SSAValue, - weight: SSAValue, - bias: SSAValue, - auto_pad: Attribute, - dilations: Attribute, - group: Attribute, - kernel_shape: Attribute, - pads: Attribute, - strides: Attribute, - ): - super().__init__( - attributes={ - "auto_pad": auto_pad, - "dilations": dilations, - "group": group, - "kernel_shape": kernel_shape, - "pads": pads, - "strides": strides, - }, - operands=[data, weight, bias], - result_types=[data.type], - ) - - def verify_(self) -> None: - assert isinstance(data_type := self.data.type, TensorType) - assert isinstance(weight_type := self.weight.type, TensorType) - assert isinstance(bias_type := self.bias.type, TensorType | NoneType) - assert isinstance(res_type := self.res.type, TensorType) - - weight_type = cast(TensorType[Attribute], weight_type) - data_type = cast(TensorType[Attribute], data_type) - res_type = cast(TensorType[Attribute], res_type) - - # case that bias is a tensor type - if isinstance(bias_type, TensorType): - bias_type = bias_type.get_shape() - if len(bias_type) != 1: - raise VerifyException("bias must be 1D") - - weight_type = weight_type.get_shape() - # kernel_shape - kernel_shape_data: list[int] = [] - for value in self.kernel_shape: - val = value.value.data - kernel_shape_data.append(val) - if list(weight_type[-2:]) != kernel_shape_data: - raise VerifyException( - "kernel shape rank and weight tensor rank are not the same" - ) - - # dilations - for value in self.dilations: - val = value.value.data - if val <= 0: - raise VerifyException("dilation value must be non zero positive") - if len(self.dilations) != len(self.kernel_shape): - raise VerifyException( - "dilations rank and kernel shape rank are not the same" - ) - - # group - if self.group.value.data < 1: - raise VerifyException("group value must be nonnegative") - - # strides - for value in self.strides: - val = value.value.data - if val <= 0: - raise VerifyException("stride value must be non zero positive") - if len(self.strides) != len(self.kernel_shape): - raise VerifyException( - "strides rank and kernel shape rank are not the same " - ) - - # pads - for value in self.pads: - val = value.value.data - if val < 0: - raise VerifyException("pads value must be nonnegative") - if len(self.pads) != 2 * len(self.kernel_shape): - raise VerifyException("pads rank is not twice the kernel shape rank") - - auto_pad_strings = ["NOTSET", "SAME_UPPER", "SAME_LOWER", "VALID"] - if self.auto_pad.data not in auto_pad_strings: - raise VerifyException( - f"Invalid auto_pad string. Must be one of {auto_pad_strings}" - ) - - -@irdl_op_definition -class ConstantOp(IRDLOperation): - """ - Produce a constant tensor. - - Exactly one of the provided attributes, either value, sparse_value, or value_* must be specified. - - Attributes: - - sparse_value: sparse_tensor - The value for the elements of the output tensor in sparse format. (currently unsupported) - - value : tensor - The value for the elements of the output tensor. - - value_float: float - The value for the sole element for the scalar, float32, output tensor. - - value_floats: list of floats - The values for the elements for the 1D, float32, output tensor. - - value_int : int - The value for the sole element for the scalar, int64, output tensor. - - value_ints : list of ints - The values for the elements for the 1D, int64, output tensor. - - value_string : string - The value for the sole element for the scalar, UTF-8 string, output tensor. - - value_strings: list of strings - The values for the elements for the 1D, UTF-8 string, output tensor. - """ - - name = "onnx.Constant" - output = result_def(AnyTensorType) - - value = opt_attr_def(DenseIntOrFPElementsAttr) - value_float = opt_attr_def(FloatAttr[Float32Type]) - value_floats = opt_attr_def(ArrayAttr[FloatAttr[Float32Type]]) - value_int = opt_attr_def(IntegerAttr[IntegerType]) - value_ints = opt_attr_def(ArrayAttr[IntegerAttr[IntegerType]]) - value_string = opt_attr_def(StringAttr) - value_strings = opt_attr_def(ArrayAttr[StringAttr]) - - def __init__( - self, - value: Attribute | None, - value_float: Attribute | None, - value_floats: Attribute | None, - value_int: Attribute | None, - value_ints: Attribute | None, - value_string: Attribute | None, - value_strings: Attribute | None, - output_type: Attribute | None, - ): - super().__init__( - attributes={ - "value": value, - "value_float": value_float, - "value_floats": value_floats, - "value_int": value_int, - "value_ints": value_ints, - "value_string": value_string, - "value_strings": value_strings, - }, - operands=[], - result_types=[output_type], - ) - - def verify_(self) -> None: - if self.value is not None and not isinstance(self.value.type, TensorType): - raise VerifyException("value attribute type must be of type TensorType") - - if self.value_int is not None and self.value_int.type.width.data != 64: - raise VerifyException( - "value_int element type has to be a 64-bit signless integer" - ) - - if self.value_ints is not None: - for value in self.value_ints: - width = value.type.width.data - if width != 64: - raise VerifyException( - "value_ints elements type has to be a 64-bit signless integer" - ) - - attrs = [ - self.value, - self.value_float, - self.value_floats, - self.value_int, - self.value_ints, - self.value_string, - self.value_strings, - ] - used_attrs = sum(1 for attr in attrs if attr is not None) - if used_attrs != 1: - raise VerifyException( - f"Only one value attribute must be provided, but {used_attrs} were specified" - ) - - def print(self, printer: Printer): - if self.value is not None: - printer.print(" ") - printer.print(self.value) - - @classmethod - def parse(cls, parser: Parser) -> Self: - v = parser.parse_attribute() - if not isinstance(v, DenseIntOrFPElementsAttr): - raise NotImplementedError() - constant = cls(v, None, None, None, None, None, None, v.type) - return constant - - -@irdl_op_definition -class MaxPoolSingleOutOp(IRDLOperation): - """ - ONNX MaxPool operation with a single output. - - Attributes: - - - auto_pad string (default is 'NOTSET'): auto_pad must be either NOTSET, SAME_UPPER, SAME_LOWER or - VALID. Where default value is NOTSET, which means explicit padding is used. SAME_UPPER or SAME_LOWER mean pad the - input so that output_shape[i] = ceil(input_shape[i] / strides[i]) for each axis i. The padding is split between - the two sides equally or almost equally (depending on whether it is even or odd). In case the padding is an odd - number, the extra padding is added at the end for SAME_UPPER and at the beginning for SAME_LOWER. - - - ceil_mode int (default is '1'): Whether to use ceil or floor (default) to compute the output shape. - - - dilations list of ints: Dilation value along each spatial axis of filter. - - - kernel_shape list of ints: The size of the kernel along each axis. - - - pads list of ints: Padding for the beginning and ending along each spatial axis, it can take any value greater - than or equal to 0. The value represent the number of pixels added to the beginning and end part of the - corresponding axis. `pads` format should be as follow [x1_begin, x2_begin...x1_end, x2_end,...], where xi_begin - the number of pixels added at the beginning of axis `i` and xi_end, the number of pixels added at the end of axis - `i`. This attribute cannot be used simultaneously with auto_pad attribute. If not present, the padding defaults - to 0 along start and end of each spatial axis. - - - storage_order int (default is '0') : The storage order of the tensor. 0 is row major, and 1 is column major. - This attribute is used only to convert an n-tuple index value into a single integer value for producing the - second output. - - - strides list of ints: Stride along each spatial axis. If not present, the stride defaults to 1 along each spatial axis - - """ - - name = "onnx.MaxPoolSingleOut" - - T: ClassVar = VarConstraint("T", AnyFloatConstr | base(IntegerType)) - data = operand_def(TensorOrMemrefOf(T)) - output = result_def(TensorOrMemrefOf(T)) - - auto_pad = attr_def(StringAttr) - ceil_mode = attr_def(AnyIntegerAttr) - dilations = attr_def(ArrayAttr[AnyIntegerAttr]) - kernel_shape = attr_def(ArrayAttr[AnyIntegerAttr]) - pads = attr_def(ArrayAttr[AnyIntegerAttr]) - storage_order = attr_def(AnyIntegerAttr) - strides = attr_def(ArrayAttr[AnyIntegerAttr]) - - assembly_format = ( - "`(` $data`)` attr-dict `:` `(` type($data) `)` `->` type($output)" - ) - - def __init__( - self, - data: SSAValue, - auto_pad: Attribute, - ceil_mode: Attribute, - dilations: Attribute, - kernel_shape: Attribute, - pads: Attribute, - storage_order: Attribute, - strides: Attribute, - ): - super().__init__( - attributes={ - "auto_pad": auto_pad, - "ceil_mode": ceil_mode, - "dilations": dilations, - "kernel_shape": kernel_shape, - "pads": pads, - "storage_order": storage_order, - "strides": strides, - }, - operands=[data], - result_types=[data.type], - ) - - def verify_(self) -> None: - assert isinstance(data_type := self.data.type, TensorType | MemRefType) - assert isinstance(output_type := self.output.type, TensorType | MemRefType) - - data_type = cast(TensorType[Attribute], data_type) - output_type = cast(TensorType[Attribute], output_type) - - # auto pad - auto_pad_strings = ["NOTSET", "SAME_UPPER", "SAME_LOWER", "VALID"] - if self.auto_pad.data not in auto_pad_strings: - raise VerifyException( - f"Invalid auto_pad string. Must be one of {auto_pad_strings}" - ) - - # ceil mode - if self.ceil_mode.value.data < 0 or self.ceil_mode.value.data > 1: - raise VerifyException("ceil value must be either zero or one") - - # kernel shape - if (input_dims := len(data_type.get_shape()) - 2) != ( - kernel_dims := len(self.kernel_shape) - ): - raise VerifyException( - f"input data and kernel shape rank mismatch: ({input_dims}) vs ({kernel_dims})" - ) - - # dilations - for value in self.dilations: - val = value.value.data - if val <= 0: - raise VerifyException("dilation value must be non zero positive") - - if (dilations_dims := len(self.dilations)) != ( - kernel_dims := len(self.kernel_shape) - ): - raise VerifyException( - f"dilations rank ({dilations_dims}) and kernel shape rank ({kernel_dims}) are not the " - f"same " - ) - - # storage order - # Not supported for storage order in column major mode in onnx-mlir (therefore row major mode only considered) - if self.storage_order.value.data != 0: - raise VerifyException("column major storage order not implemented yet") - - # strides - for value in self.strides: - val = value.value.data - if val <= 0: - raise VerifyException("stride value must be non zero positive") - - if (strides_dims := len(self.strides)) != ( - kernel_dims := len(self.kernel_shape) - ): - raise VerifyException( - f"strides rank ({strides_dims}) and kernel shape rank ({kernel_dims}) are not the " - f"same " - ) - - # pads - for value in self.pads: - val = value.value.data - if val < 0: - raise VerifyException("pads value must be nonnegative") - - if (pads_dims := len(self.pads)) != 2 * len(self.kernel_shape): - raise VerifyException( - f"pads rank ({pads_dims}) is not twice the kernel shape rank ({len(self.kernel_shape)})" - ) - - -@irdl_op_definition -class EntryPointOp(IRDLOperation): - """ - Indicate ONNX entry point - The "onnx.EntryPoint" function indicates the main entry point of ONNX model. - """ - - name = "onnx.EntryPoint" - func = attr_def(SymbolRefAttr) - - def __init__(self, func: Attribute): - super().__init__( - attributes={ - "func": func, - }, - ) - - -@irdl_op_definition -class MatMulOp(IRDLOperation): - """ - The operation MatMul performs matrix multiplication between two input matrices, A and B, and returns the result as matrix Y. - Matrix multiplication is a fundamental operation in linear algebra, where each element of the resulting matrix Y is computed by taking the - dot product of the corresponding row of matrix A and column of matrix B. - """ - - name = "onnx.MatMul" - - # describe annotated type - T = Annotated[AnyFloat | IntegerType, ConstraintVar("T")] - - # input matrices - matrix_A = operand_def(TensorType[T]) - matrix_B = operand_def(TensorType[T]) - - # output matrices - matrix_Y = result_def(TensorType[T]) - - assembly_format = ( - "`(` $matrix_A `,` $matrix_B `)` attr-dict `:` `(` type($matrix_A) `," - "` type($matrix_B) `)` `->` type($matrix_Y) " - ) - - def __init__( - self, - matrix_A: SSAValue, - matrix_B: SSAValue, - matrix_Y_type: Attribute, - ): - super().__init__( - operands=[matrix_A, matrix_B], - result_types=[matrix_Y_type], - ) - - def verify_(self) -> None: - # store dimensions of tensor A and tensor B - res_shape: list[int] = [] - matrix_A_type = cast(TensorType[Any], self.matrix_A.type) - matrix_B_type = cast(TensorType[Any], self.matrix_B.type) - matrix_Y_type = cast(TensorType[Any], self.matrix_Y.type) - - # check shape compatibility - matrix_A_shape = matrix_A_type.get_shape() - matrix_B_shape = matrix_B_type.get_shape() - - if matrix_A_type.get_num_dims() != 2: - raise VerifyException("input matrix A should be a 2D tensor") - - if matrix_B_type.get_num_dims() != 2: - raise VerifyException("input matrix B should be a 2D tensor") - - if matrix_A_shape[1] != matrix_B_shape[0]: - raise VerifyException( - f"operands have incompatible shapes: {matrix_A_shape} and {matrix_B_shape}" - ) - else: - res_shape.append(matrix_A_shape[0]) - res_shape.append(matrix_B_shape[1]) - - matrix_Y_type_shape = list(matrix_Y_type.get_shape()) - if ( - len(res_shape) != len(matrix_Y_type_shape) - or res_shape != matrix_Y_type_shape - ): - raise VerifyException( - f"result shape {res_shape} does not match result type {matrix_Y_type_shape}" - ) - - -@irdl_op_definition -class TransposeOp(IRDLOperation): - """ - The transpose_tensor function takes a tensor as input and returns its transpose. - Transposing a tensor means flipping its dimensions, so that rows become columns and vice versa. - """ - - name = "onnx.Transpose" - - T = Annotated[AnyFloat | IntegerType, ConstraintVar("T")] - tensor_input = operand_def(TensorType[T]) - - perm = opt_attr_def(ArrayAttr[AnyIntegerAttr], attr_name="perm") - - tensor_output = result_def(TensorType[T]) - - assembly_format = ( - "`(` $tensor_input `)` attr-dict `:` `(` type($tensor_input) " - "`)` `->` type($tensor_output) " - ) - - def __init__(self, tensor_input: SSAValue, perm: Attribute): - super().__init__( - attributes={"perm": perm}, - operands=[tensor_input], - result_types=[tensor_input.type], - ) - - def verify_(self) -> None: - assert isinstance(tensor_input_type := self.tensor_input.type, TensorType) - assert isinstance(tensor_output_type := self.tensor_output.type, TensorType) - - tensor_input_shape = tensor_input_type.get_shape() - tensor_output_shape = tensor_output_type.get_shape() - - # numbers in perm cannot be repeated - if self.perm is not None: - for _, int_attr in enumerate(self.perm.data): - attr_value = int_attr.value.data - count = self.perm.data.count(int_attr) - if count != 1: - raise VerifyException( - f"permutation can not contain more than one occurrence of the same dimension: dimension #{attr_value} appears {count} times." - ) - - # numbers in perm must be between 0 and len(tensor_input_shape)-1 - perm_size = len(self.perm.data) - for int_index, int_attr in enumerate(self.perm.data): - int_index = int_index + 0 - int_attr_val = int_attr.value.data - if int_attr_val < 0 or int_attr_val >= perm_size: - raise VerifyException( - f"permutation can only contain values between 0 and {perm_size}-1: dimension #{int_index} value is {int_attr_val}" - ) - - # len(tensor_input_shape) must be equal to len(perm) - perm_size = len(self.perm.data) - input_size = len(tensor_input_shape) - if perm_size != input_size: - raise VerifyException( - f"permutation and inputs dimensions must have the same size: #dimensions input is {input_size}, #dimension perimutation is {perm_size}" - ) - - # check output shape - for index_attr, int_attr in enumerate(self.perm.data): - int_attr_val = int_attr.value.data - if tensor_output_shape[index_attr] != tensor_input_shape[int_attr_val]: - raise VerifyException( - f"incorrect output shape: output dimension #{index_attr} should be equal to {tensor_input_shape[int_attr_val]}" - ) - - -@irdl_op_definition -class SqueezeOp(IRDLOperation): - """ - Squeeze the input tensor along the specified axes. - - Squeezing a tensor removes dimensions of size 1, effectively reducing the rank of the tensor and collapsing those dimensions. - This operation is particularly useful for removing unnecessary singleton dimensions, which may arise from broadcasting or previous operations. - - Args: - input_tensor: The input tensor to be squeezed. This tensor should be a multi-dimensional array-like object. - axes: A list of axes along which to squeeze the tensor. If provided, only the specified axes will be squeezed. If not provided, all dimensions of size 1 will be squeezed. - - Returns: - output_tensor: The squeezed tensor. - """ - - name = "onnx.Squeeze" - - T = Annotated[AnyFloat | IntegerType, ConstraintVar("T")] - input_tensor = operand_def(TensorType[T]) - axes = opt_attr_def(base(AnyIntegerAttr), attr_name="axes") - - output_tensor = result_def(TensorType[T]) - - assembly_format = "`(` $input_tensor `)` attr-dict `:` `(` type($input_tensor) `)` `->` type($output_tensor) " - - def __init__( - self, - input_tensor: SSAValue, - axes: Attribute, - ): - super().__init__( - attributes={ - "axes": axes, - }, - operands=[input_tensor], - result_types=[input_tensor.type], - ) - - def verify_(self) -> None: - assert isinstance(input_tensor_type := self.input_tensor.type, TensorType), ( - "onnx elementwise operation operands and result must be of type TensorType" - ) - - input_tensor_shape = input_tensor_type.get_shape() - - if self.axes is not None: - axes_value = self.axes.value.data - - # axes out of bounds: the axes value must between 0 and len(input_tensor.shape)-1 - if axes_value < 0 or axes_value >= len(input_tensor_shape): - max_axes_value = len(input_tensor_shape) - 1 - raise VerifyException( - f"axes to squeeze must be between 0 and {max_axes_value}, axes: {axes_value}" - ) - - -@irdl_op_definition -class SigmoidOp(IRDLOperation): - """ - Applies the sigmoid function element-wise to all elements of the input tensor. - The sigmoid function, denoted by sigma(x), is a common mathematical function used in machine learning and neural networks. It is defined as: - sigma(x) = 1 / (1 + e^-x) - where e is the base of the natural logarithm. The sigmoid function maps any real-valued number to the range of [0, 1]. - The sigmoid function is used as an activation function. - - Args: - - input_tensor (TensorType): The input tensor to which the sigmoid function will be applied. - - Returns: - - output_tensor (TensorType): The output tensor after applying the sigmoid function element-wise to the input tensor. - """ - - name = "onnx.Sigmoid" - - T = Annotated[AnyFloat, ConstraintVar("T")] - input_tensor = operand_def(TensorType[T]) - output_tensor = result_def(TensorType[T]) - - assembly_format = "`(` $input_tensor`)` attr-dict `:` `(` type($input_tensor) `)` `->` type($output_tensor) " - - def __init__( - self, - input_tensor: SSAValue, - ): - super().__init__( - operands=[input_tensor], - result_types=[input_tensor.type], - ) - - def verify_(self) -> None: - assert isinstance(input_tensor_type := self.input_tensor.type, TensorType) - assert isinstance(output_tensor_type := self.output_tensor.type, TensorType) - - input_tensor_shape = input_tensor_type.get_shape() - output_tensor_shape = output_tensor_type.get_shape() - - # check if input tensor and output tensor have the same shape - if input_tensor_shape != output_tensor_shape: - raise VerifyException( - f"tensor input shape {input_tensor_shape} is not equal to tensor output shape {output_tensor_shape}" - ) - - -ONNX = Dialect( - "onnx", - [ - AbsOp, - AddOp, - ConstantOp, - ConvOp, - DivOp, - EntryPointOp, - GemmOp, - MatMulOp, - MaxPoolSingleOutOp, - MulOp, - ReluOp, - ReshapeOp, - SubOp, - TransposeOp, - SqueezeOp, - SigmoidOp, - ], -) diff --git a/xdsl/frontend/onnx/__init__.py b/xdsl/frontend/onnx/__init__.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/xdsl/frontend/onnx/ir_builder.py b/xdsl/frontend/onnx/ir_builder.py deleted file mode 100644 index 41c867f488..0000000000 --- a/xdsl/frontend/onnx/ir_builder.py +++ /dev/null @@ -1,93 +0,0 @@ -from onnx import GraphProto, NodeProto, ValueInfoProto - -from xdsl.dialects import func, onnx -from xdsl.dialects.builtin import ModuleOp -from xdsl.frontend.onnx.type import get_type -from xdsl.ir import Attribute, SSAValue -from xdsl.irdl import IRDLOperation - -OP_BY_OP_TYPE: dict[str, type[IRDLOperation]] = { - "Add": onnx.AddOp, - "Sub": onnx.SubOp, - "MatMul": onnx.MatMulOp, - "Transpose": onnx.TransposeOp, -} -"""Associate the name of the operations with the respective operation in ONNX dialect.""" - - -class OnnxXdslMapping: - """The representation of the onnx context.""" - - type_by_name: dict[str, Attribute] - value_by_name: dict[str, SSAValue] - - def __init__(self): - self.type_by_name = {} - self.value_by_name = {} - - -def visit_node(node: NodeProto, ctx: OnnxXdslMapping) -> IRDLOperation: - """Update the onnx context with the current node of the onnx graph.""" - if node.op_type not in OP_BY_OP_TYPE: - raise ValueError(f"Unknown ONNX op name {node.op_type}") - - op_class = OP_BY_OP_TYPE[node.op_type] - - operands = tuple(ctx.value_by_name[name] for name in node.input) - result_types = tuple(ctx.type_by_name[name] for name in node.output) - - op = op_class.build(operands=operands, result_types=result_types) - results = op.results - - for output_name, result in zip(node.output, results, strict=True): - ctx.value_by_name[output_name] = result - - return op - - -def visit_value_info(i: ValueInfoProto, ctx: OnnxXdslMapping) -> Attribute: - """Given the onnx ValueInforProto, it returns the corresponding Attribute stored in the context.""" - name = i.name - - if name in ctx.type_by_name: - return ctx.type_by_name[name] - - t = get_type(i.type) - ctx.type_by_name[name] = t - return t - - -def build_module(graph: GraphProto) -> ModuleOp: - """Create the ModuleOp based on the onnx graph provided.""" - - ctx = OnnxXdslMapping() - fn = visit_graph(graph, ctx) - module = ModuleOp([fn]) - - return module - - -def visit_graph(g: GraphProto, ctx: OnnxXdslMapping) -> IRDLOperation: - """ - Visit the onnx graph to update the onnx context. - """ - - name = g.name - - input_types = tuple(visit_value_info(input, ctx) for input in g.input) - output_types = tuple(visit_value_info(output, ctx) for output in g.output) - - fn = func.FuncOp(name, (input_types, output_types)) - - for input, arg in zip(g.input, fn.body.block.args, strict=True): - ctx.value_by_name[input.name] = arg - - for node in g.node: - results = visit_node(node, ctx) - fn.body.block.add_op(results) - - returned_values = tuple(ctx.value_by_name[output.name] for output in g.output) - retfn = func.ReturnOp(*returned_values) - fn.body.block.add_op(retfn) - - return fn diff --git a/xdsl/frontend/onnx/type.py b/xdsl/frontend/onnx/type.py deleted file mode 100644 index b09135bdce..0000000000 --- a/xdsl/frontend/onnx/type.py +++ /dev/null @@ -1,46 +0,0 @@ -from onnx import TensorShapeProto, TypeProto - -from xdsl.dialects.builtin import TensorType, f32, f64 -from xdsl.ir import Attribute - -ELEM_TYPE = { - 1: f32, - 11: f64, -} -""" -Dictionary containing information about the type. -""" - - -def get_elem_type(code: int) -> Attribute: - """ - It takes an ONNX tensor element type code as input and returns the corresponding Attribute. - """ - if code in ELEM_TYPE: - return ELEM_TYPE[code] - else: - raise ValueError(f"Unknown elem_type: {code}") - - -def get_type(type: TypeProto) -> Attribute: - """ - It takes the type in ONNX as input and returns the corresponding Attribute. - """ - tt = get_tensor_type(type.tensor_type) - return tt - - -def get_shape(shape: TensorShapeProto) -> tuple[int, ...]: - """ - It returns the shape of a tensor in ONNX. - """ - return tuple(dim.dim_value for dim in shape.dim) - - -def get_tensor_type(tensor: TypeProto.Tensor) -> TensorType[Attribute]: - """ - Function that returns the type of the tensor in ONNX. - """ - elem_type = get_elem_type(tensor.elem_type) - shape = get_shape(tensor.shape) - return TensorType(elem_type, shape) diff --git a/xdsl/interpreters/__init__.py b/xdsl/interpreters/__init__.py index 6cb1e68d80..1c597b8ad6 100644 --- a/xdsl/interpreters/__init__.py +++ b/xdsl/interpreters/__init__.py @@ -25,12 +25,7 @@ ) -def register_implementations( - interpreter: Interpreter, - ctx: MLContext, - *, - include_onnx: bool = True, -): +def register_implementations(interpreter: Interpreter, ctx: MLContext): interpreter.register_implementations(affine.AffineFunctions()) interpreter.register_implementations(arith.ArithFunctions()) interpreter.register_implementations(builtin.BuiltinFunctions()) @@ -52,7 +47,3 @@ def register_implementations( interpreter.register_implementations(scf.ScfFunctions()) interpreter.register_implementations(snitch_stream.SnitchStreamFunctions()) interpreter.register_implementations(tensor.TensorFunctions()) - if include_onnx: - from xdsl.interpreters import onnx - - interpreter.register_implementations(onnx.OnnxFunctions()) diff --git a/xdsl/interpreters/onnx.py b/xdsl/interpreters/onnx.py deleted file mode 100644 index 2d26cc0b38..0000000000 --- a/xdsl/interpreters/onnx.py +++ /dev/null @@ -1,440 +0,0 @@ -from typing import Any, cast, overload - -import numpy as np -import numpy.typing as npt - -from xdsl.dialects import onnx -from xdsl.dialects.builtin import ( - Float32Type, - Float64Type, - IntAttr, - IntegerType, - PackableType, - TensorType, -) -from xdsl.interpreter import ( - Interpreter, - InterpreterFunctions, - ReturnedValues, - impl, - register_impls, -) -from xdsl.interpreters.builtin import xtype_for_el_type -from xdsl.interpreters.shaped_array import ShapedArray -from xdsl.interpreters.utils import ptr -from xdsl.utils.exceptions import InterpretationError - - -def to_dtype( - xtype: PackableType[int] | PackableType[float], -) -> type[np.int32] | type[np.int64] | type[np.float32] | type[np.float64]: - match xtype: - case IntegerType(width=IntAttr(data=32)): - return np.int32 - case IntegerType(width=IntAttr(data=64)): - return np.int64 - case Float32Type(): - return np.float32 - case Float64Type(): - return np.float64 - case _: - raise NotImplementedError() - - -def from_dtype( - dtype: np.dtype[np.float32 | np.float64 | np.int32 | np.int64], -) -> PackableType[float] | PackableType[int]: - if dtype == np.float32: - return ptr.float32 - elif dtype == np.float64: - return ptr.float64 - elif dtype == np.float32: - return ptr.int32 - elif dtype == np.float64: - return ptr.int64 - else: - raise NotImplementedError() - - -@overload -def to_ndarray( - shaped_array: ShapedArray[float], -) -> npt.NDArray[np.float32 | np.float64]: ... - - -@overload -def to_ndarray( - shaped_array: ShapedArray[int], -) -> npt.NDArray[np.int32 | np.int64]: ... - - -def to_ndarray( - shaped_array: ShapedArray[int] | ShapedArray[float], -) -> npt.NDArray[np.float32 | np.float64 | np.int32 | np.int64]: - dtype = to_dtype(shaped_array.data_ptr.xtype) - flat = np.frombuffer(shaped_array.data_ptr.raw.memory, dtype) - shaped = flat.reshape(shaped_array.shape) - return shaped - - -def from_ndarray( - ndarray: npt.NDArray[np.float32 | np.float64 | np.int32 | np.int64], -) -> ShapedArray[float] | ShapedArray[int]: - xtype = from_dtype(np.dtype(ndarray.dtype)) - # TypedPtr's generic parameter is invariant, so ambiguous here - typed_ptr: ptr.TypedPtr[float] | ptr.TypedPtr[int] = ptr.TypedPtr( - ptr.RawPtr(bytearray(ndarray.data)), - xtype=xtype, # pyright: ignore[reportArgumentType] - ) - return ShapedArray( - typed_ptr, - list(ndarray.shape), - ) - - -@register_impls -class OnnxFunctions(InterpreterFunctions): - @impl(onnx.AddOp) - def run_add( - self, interpreter: Interpreter, op: onnx.AddOp, args: tuple[Any, ...] - ) -> tuple[Any, ...]: - lhs, rhs = args[0], args[1] - assert isinstance(lhs, ShapedArray) - assert isinstance(rhs, ShapedArray) - lhs = cast(ShapedArray[float], lhs) - rhs = cast(ShapedArray[float], rhs) - result = to_ndarray(lhs) + to_ndarray(rhs) - return (from_ndarray(result),) - - @impl(onnx.SubOp) - def run_sub( - self, interpreter: Interpreter, op: onnx.SubOp, args: tuple[Any, ...] - ) -> tuple[Any, ...]: - lhs, rhs = args[0], args[1] - assert isinstance(lhs, ShapedArray) - assert isinstance(rhs, ShapedArray) - lhs = cast(ShapedArray[float], lhs) - rhs = cast(ShapedArray[float], rhs) - result = to_ndarray(lhs) - to_ndarray(rhs) - return (from_ndarray(result),) - - @impl(onnx.MulOp) - def run_mul( - self, interpreter: Interpreter, op: onnx.MulOp, args: tuple[Any, ...] - ) -> tuple[Any, ...]: - lhs, rhs = args[0], args[1] - assert isinstance(lhs, ShapedArray) - assert isinstance(rhs, ShapedArray) - lhs = cast(ShapedArray[float], lhs) - rhs = cast(ShapedArray[float], rhs) - result = to_ndarray(lhs) * to_ndarray(rhs) - return (from_ndarray(result),) - - @impl(onnx.DivOp) - def run_div( - self, interpreter: Interpreter, op: onnx.DivOp, args: tuple[Any, ...] - ) -> tuple[Any, ...]: - lhs, rhs = args[0], args[1] - assert isinstance(lhs, ShapedArray) - assert isinstance(rhs, ShapedArray) - lhs = cast(ShapedArray[float], lhs) - rhs = cast(ShapedArray[float], rhs) - result = to_ndarray(lhs) / to_ndarray(rhs) - return (from_ndarray(result),) - - @impl(onnx.ReluOp) - def run_relu( - self, interpreter: Interpreter, op: onnx.ReluOp, args: tuple[Any, ...] - ) -> tuple[Any, ...]: - operand = args[0] - assert isinstance(operand, ShapedArray) - operand = cast(ShapedArray[float], operand) - operand_data = to_ndarray(operand) - result = operand_data * (operand_data > 0) - return (from_ndarray(result),) - - @impl(onnx.ConstantOp) - def run_constant( - self, interpreter: Interpreter, op: onnx.ConstantOp, args: tuple[Any, ...] - ): - if op.value is None: - raise NotImplementedError("Only dense constant values implemented") - shape = op.value.get_shape() - data = op.value.get_values() - data_ptr = ptr.TypedPtr[Any].new( - data, - xtype=xtype_for_el_type( - op.value.get_element_type(), interpreter.index_bitwidth - ), - ) - return (ShapedArray(data_ptr, list(shape)),) - - @impl(onnx.ReshapeOp) - def run_reshape( - self, interpreter: Interpreter, op: onnx.ReshapeOp, args: tuple[Any, ...] - ): - if op.allow_zero is not None and op.allow_zero.value.data == 1: - raise NotImplementedError( - "allow_zero not yet supported in onnx.reshape interpreter" - ) - input, new_shape = args - assert isinstance(input, ShapedArray) - assert isinstance(new_shape, ShapedArray) - input = cast(ShapedArray[float], input) - new_shape = cast(ShapedArray[int], new_shape) - result_type = op.reshaped.type - assert isinstance(result_type, TensorType) - static_shape = list(result_type.get_shape()) - assert static_shape is not None - if static_shape != new_shape.data: - raise InterpretationError("Mismatch between static shape and new shape") - return (input.with_shape(new_shape.data),) - - @impl(onnx.GemmOp) - def run_gemm( - self, interpreter: Interpreter, op: onnx.GemmOp, args: tuple[Any, ...] - ): - a, b, c = args[0], args[1], args[2] - - alpha = op.alpha.value.data if op.alpha is not None else 1.0 - beta = op.beta.value.data if op.beta is not None else 1.0 - - assert isinstance(a, ShapedArray) - assert isinstance(b, ShapedArray) - assert isinstance(c, ShapedArray) - - a = cast(ShapedArray[float], a) - b = cast(ShapedArray[float], b) - c = cast(ShapedArray[float], c) - - nd_a = to_ndarray(a) - nd_b = to_ndarray(b) - nd_c = to_ndarray(c) - - if op.trans_a is not None and op.trans_a.value.data == 1: - nd_a = np.transpose(nd_a) - - if op.trans_b is not None and op.trans_b.value.data == 1: - nd_b = np.transpose(nd_b) - - result = alpha * nd_a @ nd_b + beta * nd_c - - return (from_ndarray(result),) - - @impl(onnx.ConvOp) - def run_conv( - self, interpreter: Interpreter, op: onnx.ConvOp, args: tuple[Any, ...] - ): - # initialise the attributes used - auto_pad = op.auto_pad.data - strides: list[int] = [value.value.data for value in op.strides] - matrix, kernel, bias = args[0], args[1], args[2] - pads: list[int] = [value.value.data for value in op.pads] - - matrix = cast(ShapedArray[float], matrix) - kernel = cast(ShapedArray[float], kernel) - bias = cast(ShapedArray[float], bias) - - matrix = to_ndarray(matrix) - kernel = to_ndarray(kernel) - - if auto_pad != "NOTSET": - if auto_pad == "SAME_UPPER" or auto_pad == "SAME_LOWER": - out_height = int(np.ceil(matrix.shape[2] / strides[0])) - out_width = int(np.ceil(matrix.shape[3] / strides[1])) - - pad_along_height = max( - (out_height - 1) * strides[0] + kernel.shape[2] - matrix.shape[2], 0 - ) - pad_along_width = max( - (out_width - 1) * strides[1] + kernel.shape[3] - matrix.shape[3], 0 - ) - - if auto_pad == "SAME_UPPER": - pad_top = pad_along_height // 2 - pad_bottom = pad_along_height - pad_top - pad_left = pad_along_width // 2 - pad_right = pad_along_width - pad_left - else: - pad_bottom = pad_along_height // 2 - pad_top = pad_along_height - pad_bottom - pad_right = pad_along_width // 2 - pad_left = pad_along_width - pad_right - - pads = [pad_top, pad_bottom, pad_left, pad_right] - - elif auto_pad == "VALID": - pads = [0, 0, 0, 0] # set padding to all zeros - - if pads: - # case of asymmetric padding - pad_values = [ - (pads[i], pads[i + len(pads) // 2]) for i in range(len(pads) // 2) - ] - - # pad input matrix - padded_matrix = np.pad( - matrix, - ( - (0, 0), - (0, 0), - (pad_values[0][0], pad_values[0][1]), - (pad_values[1][0], pad_values[1][1]), - ), - mode="constant", - ) - - # padded shape case - m_height, m_width = padded_matrix.shape[2:] - - else: - m_height, m_width = matrix.shape[2:] - - padded_matrix = matrix - - # based on strides calculate the output shape - out_height = int((m_height - kernel.shape[2]) // strides[0] + 1) - out_width = int((m_width - kernel.shape[3]) // strides[1] + 1) - - output = np.zeros( - (matrix.shape[0], matrix.shape[1], out_height, out_width), - dtype=matrix.dtype, - ) - - # do convolution - for k in range(matrix.shape[0]): - for l in range(matrix.shape[1]): - for i in range(0, m_height - kernel.shape[2] + 1, strides[0]): - for j in range(0, m_width - kernel.shape[3] + 1, strides[1]): - output[k, l, i // strides[0], j // strides[1]] = np.sum( - padded_matrix[ - k, l, i : i + kernel.shape[2], j : j + kernel.shape[3] - ] - * kernel[k, l] - ) - - output += to_ndarray(bias) - - # the number of channels is not always fixed to one - result_type = op.res.type - assert isinstance(result_type, TensorType) - static_shape = list(result_type.get_shape()) - - result = np.array(output) - assert tuple(result.shape) == ( - 1, - static_shape[1], - output.shape[2], - output.shape[3], - ) - return (from_ndarray(result),) - - @impl(onnx.MaxPoolSingleOutOp) - def run_max_pool_single_out( - self, - interpreter: Interpreter, - op: onnx.MaxPoolSingleOutOp, - args: tuple[Any, ...], - ): - kernel_shape = tuple(value.value.data for value in op.kernel_shape) - - if len(kernel_shape) != 2: - raise NotImplementedError("Only 2d max pooling supported") - ky, kx = kernel_shape - - strides = tuple(value.value.data for value in op.strides) - - if len(strides) != 2: - raise NotImplementedError("Only 2d max pooling supported") - - # initialise the attributes used - auto_pad = op.auto_pad.data - (matrix,) = args - pads: list[int] = [value.value.data for value in op.pads] - - matrix = cast(ShapedArray[float], matrix) - - matrix = to_ndarray(matrix) - - if auto_pad != "NOTSET": - if auto_pad == "SAME_UPPER" or auto_pad == "SAME_LOWER": - out_height = int(np.ceil(matrix.shape[2] / strides[0])) - out_width = int(np.ceil(matrix.shape[3] / strides[1])) - - pad_along_height = max( - (out_height - 1) * strides[0] + ky - matrix.shape[2], 0 - ) - pad_along_width = max( - (out_width - 1) * strides[1] + kx - matrix.shape[3], 0 - ) - - if auto_pad == "SAME_UPPER": - pad_top = pad_along_height // 2 - pad_bottom = pad_along_height - pad_top - pad_left = pad_along_width // 2 - pad_right = pad_along_width - pad_left - else: - pad_bottom = pad_along_height // 2 - pad_top = pad_along_height - pad_bottom - pad_right = pad_along_width // 2 - pad_left = pad_along_width - pad_right - - pads = [pad_top, pad_bottom, pad_left, pad_right] - - elif auto_pad == "VALID": - pads = [0, 0, 0, 0] # set padding to all zeros - - if pads: - # case of asymmetric padding - pad_values = [ - (pads[i], pads[i + len(pads) // 2]) for i in range(len(pads) // 2) - ] - - # pad input matrix - padded_matrix = np.pad( - matrix, - ( - (0, 0), - (0, 0), - (pad_values[0][0], pad_values[0][1]), - (pad_values[1][0], pad_values[1][1]), - ), - mode="constant", - ) - - # padded shape case - m_height, m_width = padded_matrix.shape[2:] - - else: - m_height, m_width = matrix.shape[2:] - - padded_matrix = matrix - - # based on strides calculate the output shape - out_height = int((m_height - ky) // strides[0] + 1) - out_width = int((m_width - kx) // strides[1] + 1) - - result = np.zeros( - (matrix.shape[0], matrix.shape[1], out_height, out_width), - dtype=matrix.dtype, - ) - - # do maxpool computation - for k in range(matrix.shape[0]): - for l in range(matrix.shape[1]): - for i in range(0, m_height - ky + 1, strides[0]): - for j in range(0, m_width - kx + 1, strides[1]): - result[k, l, i // strides[0], j // strides[1]] = np.nanmax( - padded_matrix[k, l, i : i + ky, j : j + kx] - ) - - # Numpy has two types of ndarray: ndarray and NDArray, weirdly they don't seem - # to be compatible, despite one being a typealias for the other... - output: Any = result - return (from_ndarray(output),) - - @impl(onnx.EntryPointOp) - def run_entry_point( - self, interpreter: Interpreter, op: onnx.EntryPointOp, args: tuple[Any, ...] - ): - return ReturnedValues(args), () diff --git a/xdsl/tools/xdsl_run.py b/xdsl/tools/xdsl_run.py index 96c66b302d..71dc56f302 100644 --- a/xdsl/tools/xdsl_run.py +++ b/xdsl/tools/xdsl_run.py @@ -35,12 +35,6 @@ def __init__( self.ctx.allow_unregistered = self.args.allow_unregistered_dialect def register_all_arguments(self, arg_parser: argparse.ArgumentParser): - arg_parser.add_argument( - "--onnx", - default=False, - action="store_true", - help="Enable the onnx-compilation interpreter.", - ) arg_parser.add_argument( "--verbose", default=False, @@ -71,11 +65,7 @@ def register_all_arguments(self, arg_parser: argparse.ArgumentParser): return super().register_all_arguments(arg_parser) def register_implementations(self, interpreter: Interpreter): - register_implementations( - interpreter, - self.ctx, - include_onnx=self.args.onnx, - ) + register_implementations(interpreter, self.ctx) def run(self): input, file_extension = self.get_input_stream() diff --git a/xdsl/transforms/__init__.py b/xdsl/transforms/__init__.py index 81a12e06e1..aabe89eda9 100644 --- a/xdsl/transforms/__init__.py +++ b/xdsl/transforms/__init__.py @@ -98,11 +98,6 @@ def get_convert_ml_program_to_memref(): return convert_ml_program_to_memref.ConvertMlProgramToMemrefPass - def get_convert_onnx_to_linalg(): - from xdsl.transforms import convert_onnx_to_linalg - - return convert_onnx_to_linalg.ConvertOnnxToLinalgPass - def get_convert_print_format_to_riscv_debug(): from xdsl.backend.riscv.lowering import convert_print_format_to_riscv_debug @@ -514,7 +509,6 @@ def get_varith_fuse_repeated_operands(): "convert-memref-to-ptr": get_convert_memref_to_ptr, "convert-memref-to-riscv": get_convert_memref_to_riscv, "convert-ml-program-to-memref": get_convert_ml_program_to_memref, - "convert-onnx-to-linalg": get_convert_onnx_to_linalg, "convert-print-format-to-riscv-debug": get_convert_print_format_to_riscv_debug, "convert-ptr-to-riscv": get_convert_ptr_to_riscv, "convert-qref-to-qssa": get_convert_qref_to_qssa, diff --git a/xdsl/transforms/constant_fold_interp.py b/xdsl/transforms/constant_fold_interp.py index acb4c621f4..d550d5fc09 100644 --- a/xdsl/transforms/constant_fold_interp.py +++ b/xdsl/transforms/constant_fold_interp.py @@ -85,7 +85,6 @@ class ConstantFoldInterpPass(ModulePass): def apply(self, ctx: MLContext, op: builtin.ModuleOp) -> None: interpreter = Interpreter(op) - # Do not call onnx interpreter function for this pass - register_implementations(interpreter, ctx, include_onnx=False) + register_implementations(interpreter, ctx) pattern = ConstantFoldInterpPattern(interpreter) PatternRewriteWalker(pattern).rewrite_module(op) diff --git a/xdsl/transforms/convert_onnx_to_linalg.py b/xdsl/transforms/convert_onnx_to_linalg.py deleted file mode 100644 index 59845e5f54..0000000000 --- a/xdsl/transforms/convert_onnx_to_linalg.py +++ /dev/null @@ -1,415 +0,0 @@ -from dataclasses import dataclass -from typing import cast - -from xdsl.builder import ImplicitBuilder -from xdsl.context import MLContext -from xdsl.dialects import arith, linalg, ml_program, onnx, tensor -from xdsl.dialects.builtin import ( - AffineMapAttr, - AnyFloat, - DenseArrayBase, - DenseIntOrFPElementsAttr, - FloatAttr, - ModuleOp, - NoneType, - StringAttr, - SymbolRefAttr, - TensorType, - f32, - f64, - i64, -) -from xdsl.ir import Attribute, Block, Operation, Region -from xdsl.ir.affine import AffineMap -from xdsl.passes import ModulePass -from xdsl.pattern_rewriter import ( - GreedyRewritePatternApplier, - PatternRewriter, - PatternRewriteWalker, - RewritePattern, - op_type_rewrite_pattern, -) -from xdsl.traits import SymbolTable - - -def get_root_op(op: Operation | None) -> Operation | None: - """ - Recursively finds and returns the root operation associated with the given operation. - """ - return op if op is None or op.parent_op() is None else get_root_op(op.parent_op()) - - -@dataclass -class AddOpLowering(RewritePattern): - @op_type_rewrite_pattern - def match_and_rewrite(self, add: onnx.AddOp, rewriter: PatternRewriter, /): - lhs_type = add.lhs.type - rhs_type = add.rhs.type - if isinstance(lhs_type, TensorType) and isinstance(rhs_type, TensorType): - lhs_shape = lhs_type.get_shape() - rhs_shape = rhs_type.get_shape() - - if -1 in lhs_shape or -1 in rhs_shape: - raise NotImplementedError() - - rewriter.replace_matched_op( - ( - empty := tensor.EmptyOp((), add.res.type), - linalg.AddOp((add.lhs, add.rhs), (empty.tensor,), res=(add.res.type,)), - ) - ) - - -@dataclass -class SubOpLowering(RewritePattern): - @op_type_rewrite_pattern - def match_and_rewrite(self, sub: onnx.SubOp, rewriter: PatternRewriter, /): - lhs_type = sub.lhs.type - rhs_type = sub.rhs.type - if isinstance(lhs_type, TensorType) and isinstance(rhs_type, TensorType): - lhs_shape = lhs_type.get_shape() - rhs_shape = rhs_type.get_shape() - - if -1 in lhs_shape or -1 in rhs_shape: - raise NotImplementedError() - - rewriter.replace_matched_op( - ( - empty := tensor.EmptyOp((), sub.res.type), - linalg.SubOp((sub.lhs, sub.rhs), (empty.tensor,), res=(sub.res.type,)), - ) - ) - - -@dataclass -class ReluOpLowering(RewritePattern): - @op_type_rewrite_pattern - def match_and_rewrite(self, relu: onnx.ReluOp, rewriter: PatternRewriter, /): - operand = relu.operand.type - assert isinstance(operand, TensorType) - operand = cast(TensorType[Attribute], operand) - operand_rank = len(operand.get_shape()) - body = Region(Block(arg_types=(operand.element_type, operand.element_type))) - affine_map = AffineMapAttr(AffineMap.identity(operand_rank)) - rewriter.replace_matched_op( - ( - empty := tensor.EmptyOp((), relu.res.type), - zero := arith.ConstantOp( - FloatAttr(0.0, cast(AnyFloat, operand.element_type)) - ), - linalg.GenericOp( - (relu.operand,), - (empty.tensor,), - body, - (affine_map, affine_map), - (linalg.IteratorTypeAttr.parallel(),) * operand_rank, - (relu.res.type,), - ), - ) - ) - with ImplicitBuilder(body) as (a, _): - max_op = arith.MaximumfOp(a, zero.result) - linalg.YieldOp(max_op.result) - - -@dataclass -class ConstantOpLowering(RewritePattern): - constant_count: int = 0 - - def make_unique_name(self): - self.constant_count += 1 - return f"onnx_constant_{self.constant_count}" - - @op_type_rewrite_pattern - def match_and_rewrite( - self, constant: onnx.ConstantOp, rewriter: PatternRewriter, / - ): - attr_value = list(constant.attributes.values())[1] - constant_name = self.make_unique_name() - global_op = ml_program.GlobalOp( - StringAttr(constant_name), - constant.output.type, - None, - attr_value, - StringAttr("private"), - ) - root_op = get_root_op(constant) - if root_op is not None and root_op.has_trait(SymbolTable): - SymbolTable.insert_or_update(root_op, global_op) - rewriter.replace_matched_op( - ( - ml_program.GlobalLoadConstantOp( - SymbolRefAttr(global_op.sym_name), - global_op.type, - ), - ) - ) - - -@dataclass -class ReshapeOpLowering(RewritePattern): - @op_type_rewrite_pattern - def match_and_rewrite(self, reshape: onnx.ReshapeOp, rewriter: PatternRewriter, /): - # Dynamic shapes not currently supported - source_type = reshape.data.type - shape_type = reshape.shape.type - if isinstance(source_type, TensorType) and isinstance(shape_type, TensorType): - source_shape = source_type.get_shape() - shape_shape = shape_type.get_shape() - - if -1 in source_shape or -1 in shape_shape: - raise NotImplementedError() - - # Lowering with `allowzero = 1` attribute not supported" - if reshape.allow_zero is not None and reshape.allow_zero.value.data == 1: - raise NotImplementedError() - - rewriter.replace_matched_op( - ( - tensor.ReshapeOp( - reshape.data, - reshape.shape, - reshape.reshaped.type, - ), - ) - ) - - -@dataclass -class GemmOpLowering(RewritePattern): - @op_type_rewrite_pattern - def match_and_rewrite(self, gemm: onnx.GemmOp, rewriter: PatternRewriter, /): - assert isinstance(tensor_a_type := gemm.tensor_a.type, TensorType) - assert isinstance(tensor_b_type := gemm.tensor_b.type, TensorType) - assert isinstance(tensor_c_type := gemm.tensor_c.type, TensorType) - - tensor_a_type = cast(TensorType[Attribute], tensor_a_type) - tensor_b_type = cast(TensorType[Attribute], tensor_b_type) - tensor_c_type = cast(TensorType[Attribute], tensor_c_type) - - tensor_a_shape = tensor_a_type.get_shape() - tensor_b_shape = tensor_b_type.get_shape() - tensor_c_shape = tensor_c_type.get_shape() - - # Dynamic shapes not currently supported - if any( - -1 in shape for shape in [tensor_a_shape, tensor_b_shape, tensor_c_shape] - ): - raise NotImplementedError() - - perm: list[int] = [1, 0] - permutation = DenseArrayBase.create_dense_int(i64, perm) - - # if transA is set, trans_a is changed to this op - trans_a_res = None - if gemm.trans_a is not None and gemm.trans_a.value.data == 1: - shape_type = tensor_a_type.element_type - # onnx.gemm supports only 2D tensors, hence reversing is acceptable - shape = tuple(reversed(tensor_a_shape)) - empty_shape = TensorType(shape_type, shape) - empty = tensor.EmptyOp((), empty_shape) - trans_a = linalg.TransposeOp( - gemm.tensor_a, empty.tensor, permutation, empty.tensor.type - ) - # save the result - trans_a_res = trans_a.result[0] - rewriter.insert_op_before_matched_op([empty, trans_a]) - - # if transB is set, trans_b is changed to this op - trans_b_res = None - if gemm.trans_b is not None and gemm.trans_b.value.data == 1: - shape_type = tensor_b_type.element_type - # onnx.gemm supports only 2D tensors, hence reversing is acceptable - shape = tuple(reversed(tensor_b_shape)) - empty_shape = TensorType(shape_type, shape) - empty = tensor.EmptyOp((), empty_shape) - trans_b = linalg.TransposeOp( - gemm.tensor_b, empty.tensor, permutation, empty.tensor.type - ) - # save the result - trans_b_res = trans_b.result[0] - rewriter.insert_op_before_matched_op([empty, trans_b]) - - # if trans_a occurs, else remain - if trans_a_res is not None: - trans_a = trans_a_res - else: - trans_a = gemm.tensor_a - - # if trans_b occurs, else remain - if trans_b_res is not None: - trans_b = trans_b_res - else: - trans_b = gemm.tensor_b - - # alpha * A - alpha_res = None - if gemm.alpha is not None and gemm.alpha.value.data != 1: - constant = arith.ConstantOp( - FloatAttr(gemm.alpha.value.data, gemm.alpha.type) - ) - alpha_mul_result = linalg.MulOp( - (constant.result, trans_a), - (trans_a,), - (trans_a.type,), - ) - alpha_res = alpha_mul_result.res[0] - rewriter.insert_op_before_matched_op([constant, alpha_mul_result]) - - # if alpha * a does not occur remain on previous trans_a else switch - if alpha_res is not None: - trans_a = alpha_res - - # beta * C - beta_mul_result = gemm.tensor_c - beta_res = None - if gemm.beta is not None and gemm.beta.value.data != 1: - constant = arith.ConstantOp(FloatAttr(gemm.beta.value.data, gemm.beta.type)) - beta_mul_result = linalg.MulOp( - ( - constant.result, - beta_mul_result, - ), - (gemm.tensor_c,), - (gemm.tensor_c.type,), - ) - beta_res = beta_mul_result.res[0] - rewriter.insert_op_before_matched_op([constant, beta_mul_result]) - - # this is beta * c result else its just c - if beta_res is not None: - beta_mul_result = beta_res - else: - beta_mul_result = gemm.tensor_c - - # (A * B) + beta * C - rewriter.replace_matched_op( - ( - empty := tensor.EmptyOp( - (), - gemm.res_tensor.type, - ), - # A * B - mat_mul_res := linalg.MatmulOp( - (trans_a, trans_b), - (empty.tensor,), - res=(gemm.res_tensor.type,), - ), - # (A * B) + beta * C - linalg.AddOp( - (mat_mul_res.results[0], beta_mul_result), - (mat_mul_res.results[0],), - res=(mat_mul_res.results[0].type,), - ), - ) - ) - - -@dataclass -class MaxPoolSingleOutOpLowering(RewritePattern): - @op_type_rewrite_pattern - def match_and_rewrite( - self, max_pool_single_out: onnx.MaxPoolSingleOutOp, rewriter: PatternRewriter, / - ): - kernel: list[int] = [ - value.value.data for value in max_pool_single_out.kernel_shape.data - ] - dilations: list[int] = [ - value.value.data for value in max_pool_single_out.dilations.data - ] - strides: list[int] = [ - value.value.data for value in max_pool_single_out.strides.data - ] - kernel_shape = TensorType(f32, kernel) - - # Lowering with `storage_order = 1` attribute not supported" - if ( - max_pool_single_out.storage_order.value.data != 0 - and max_pool_single_out.storage_order - ): - raise NotImplementedError() - - rewriter.replace_matched_op( - ( - empty := tensor.EmptyOp((), kernel_shape), - init := tensor.EmptyOp((), max_pool_single_out.output.type), - # Since we're unable to represent +/- infinity, - # we currently use the maximum value by sys - cst := arith.ConstantOp(FloatAttr(-1e308, f64)), - fill := linalg.FillOp( - (cst.result,), - (init.tensor,), - (max_pool_single_out.output.type,), - ), - linalg.PoolingNchwMaxOp( - DenseIntOrFPElementsAttr.tensor_from_list(dilations, i64, [2]), - DenseIntOrFPElementsAttr.tensor_from_list(strides, i64, [2]), - ( - max_pool_single_out.data, - empty.tensor, - ), - (fill.results[0],), - (max_pool_single_out.output.type,), - ), - ) - ) - - -@dataclass -class ConvOpLowering(RewritePattern): - @op_type_rewrite_pattern - def match_and_rewrite(self, conv: onnx.ConvOp, rewriter: PatternRewriter, /): - dilations = tuple(value.value.data for value in conv.dilations.data) - strides = tuple(value.value.data for value in conv.strides.data) - - if conv.group.value.data != 1: - raise NotImplementedError("Only 1 group supported") - - if not all(dilation == 1 for dilation in dilations): - raise NotImplementedError("Only 1 dilation supported") - - empty = tensor.EmptyOp((), conv.res.type) - conv_op = linalg.Conv2DNchwFchwOp( - DenseIntOrFPElementsAttr.tensor_from_list(dilations, i64, [2]), - DenseIntOrFPElementsAttr.tensor_from_list(strides, i64, [2]), - ( - conv.data, - conv.weight, - ), - (empty.tensor,), - (conv.res.type,), - ) - conv_ops = ( - empty, - conv_op, - ) - if not isinstance(conv.bias.type, NoneType): - add_bias = linalg.AddOp( - (conv.bias,), - (conv_op.results[0],), - res=(conv.res.type,), - ) - conv_ops += (add_bias,) - rewriter.replace_matched_op(conv_ops) - - -@dataclass(frozen=True) -class ConvertOnnxToLinalgPass(ModulePass): - name = "convert-onnx-to-linalg" - - def apply(self, ctx: MLContext, op: ModuleOp) -> None: - PatternRewriteWalker( - GreedyRewritePatternApplier( - [ - AddOpLowering(), - SubOpLowering(), - ReluOpLowering(), - ConstantOpLowering(), - ReshapeOpLowering(), - GemmOpLowering(), - MaxPoolSingleOutOpLowering(), - ConvOpLowering(), - ] - ), - apply_recursively=False, - ).rewrite_module(op)