From e30ec3b89d5a6ec2c33dccff5d0060d162e6de0d Mon Sep 17 00:00:00 2001 From: Sasha Lopoukhine Date: Mon, 13 Jan 2025 15:16:21 +0000 Subject: [PATCH] dialects: (onnx) remove ONNX-related code (#3738) ONNX, which I was hoping would be a convenient way to define NNs in Python to then compile using xDSL, has proven to be a massive maintenance burden, and generally not very good. They are slow to support new Python versions, poor typing support, and the import mechanism is not very stable, so I couldn't even reliably import networks. Our code for it is fairly buggy, and the bug fixes weren't all merged before Kayode finished his masters project. Given Quidditch and IREE's support for importing ONNX, I don't see the point of having our implementation in xDSL, hence this PR. --- .github/workflows/ci-mlir.yml | 7 - Makefile | 22 +- README.md | 2 +- docs/marimo/__marimo__/linalg_snitch.ipynb | 4 +- docs/marimo/linalg_snitch.py | 4 +- docs/marimo/onnx/README.md | 4 - docs/marimo/onnx/__marimo__/onnx_demo.ipynb | 421 ------- docs/marimo/onnx/onnx_demo.py | 271 ---- pyproject.toml | 5 - tests/dialects/onnx/test_broadcast.py | 50 - .../filecheck/dialects/onnx/onnx_invalid.mlir | 571 --------- tests/filecheck/dialects/onnx/onnx_ops.mlir | 91 -- .../transforms/convert_onnx_to_linalg.mlir | 117 -- tests/frontend/onnx/test_build_onnx_ir.py | 377 ------ tests/frontend/onnx/test_type.py | 89 -- tests/interpreters/test_onnx_interpreter.py | 554 --------- uv.lock | 47 - xdsl/dialects/__init__.py | 6 - xdsl/dialects/onnx.py | 1106 ----------------- xdsl/frontend/onnx/__init__.py | 0 xdsl/frontend/onnx/ir_builder.py | 93 -- xdsl/frontend/onnx/type.py | 46 - xdsl/interpreters/__init__.py | 11 +- xdsl/interpreters/onnx.py | 440 ------- xdsl/tools/xdsl_run.py | 12 +- xdsl/transforms/__init__.py | 6 - xdsl/transforms/constant_fold_interp.py | 3 +- xdsl/transforms/convert_onnx_to_linalg.py | 415 ------- 28 files changed, 9 insertions(+), 4765 deletions(-) delete mode 100644 docs/marimo/onnx/README.md delete mode 100644 docs/marimo/onnx/__marimo__/onnx_demo.ipynb delete mode 100644 docs/marimo/onnx/onnx_demo.py delete mode 100644 tests/dialects/onnx/test_broadcast.py delete mode 100644 tests/filecheck/dialects/onnx/onnx_invalid.mlir delete mode 100644 tests/filecheck/dialects/onnx/onnx_ops.mlir delete mode 100644 tests/filecheck/transforms/convert_onnx_to_linalg.mlir delete mode 100644 tests/frontend/onnx/test_build_onnx_ir.py delete mode 100644 tests/frontend/onnx/test_type.py delete mode 100644 tests/interpreters/test_onnx_interpreter.py delete mode 100644 xdsl/dialects/onnx.py delete mode 100644 xdsl/frontend/onnx/__init__.py delete mode 100644 xdsl/frontend/onnx/ir_builder.py delete mode 100644 xdsl/frontend/onnx/type.py delete mode 100644 xdsl/interpreters/onnx.py delete mode 100644 xdsl/transforms/convert_onnx_to_linalg.py diff --git a/.github/workflows/ci-mlir.yml b/.github/workflows/ci-mlir.yml index ff7d69e298..6d4b1bb0d7 100644 --- a/.github/workflows/ci-mlir.yml +++ b/.github/workflows/ci-mlir.yml @@ -109,13 +109,6 @@ jobs: export PATH=$PATH:${GITHUB_WORKSPACE}/llvm-project/build/bin/ uv run pytest --nbval docs/mlir_interoperation.ipynb --maxfail 1 -vv - - name: Test ONNX-dependent marimo notebooks - run: | - cd xdsl - # Add mlir-opt to the path - export PATH=$PATH:${GITHUB_WORKSPACE}/llvm-project/build/bin/ - uv run make tests-marimo-onnx - - name: Combine coverage data run: | cd xdsl diff --git a/Makefile b/Makefile index 80089b02ee..026a193278 100644 --- a/Makefile +++ b/Makefile @@ -98,30 +98,10 @@ tests-marimo: uv-installed done @echo "All marimo tests passed successfully." -.PHONY: tests-marimo-onnx -tests-marimo-onnx: uv-installed - @if uv run python -c "import onnx" > /dev/null 2>&1; then \ - echo "onnx is installed, running tests."; \ - if ! command -v mlir-opt > /dev/null 2>&1; then \ - echo "MLIR is not installed, skipping tests."; \ - exit 0; \ - fi; \ - for file in docs/marimo/onnx/*.py; do \ - echo "Running $$file"; \ - error_message=$$(uv run python3 "$$file" 2>&1) || { \ - echo "Error running $$file"; \ - echo "$$error_message"; \ - exit 1; \ - }; \ - done; \ - echo "All marimo onnx tests passed successfully."; \ - else \ - echo "onnx is not installed, skipping tests."; \ - fi # run all tests .PHONY: tests-functional -tests-functional: pytest tests-toy filecheck pytest-nb tests-marimo tests-marimo-onnx +tests-functional: pytest tests-toy filecheck pytest-nb tests-marimo @echo All functional tests done. # run all tests diff --git a/README.md b/README.md index 7ccf4e4e54..9886b3a809 100644 --- a/README.md +++ b/README.md @@ -53,7 +53,7 @@ In order to keep the set of dependencies ot a minimum, these extra dependencies specified explicitly. To install these, use: ``` bash -pip install xdsl[gui,jax,riscv,onnx] +pip install xdsl[gui,jax,riscv] ``` To install the testing/development dependencies, use: diff --git a/docs/marimo/__marimo__/linalg_snitch.ipynb b/docs/marimo/__marimo__/linalg_snitch.ipynb index 255bf55707..b3490c91e6 100644 --- a/docs/marimo/__marimo__/linalg_snitch.ipynb +++ b/docs/marimo/__marimo__/linalg_snitch.ipynb @@ -591,7 +591,7 @@ "riscv_op_counter = OpCounter()\n", "riscv_interpreter = Interpreter(riscv_module, listeners=(riscv_op_counter,))\n", "\n", - "register_implementations(riscv_interpreter, riscv_ctx, include_onnx=False)\n", + "register_implementations(riscv_interpreter, riscv_ctx)\n", "\n", "riscv_interpreter.call_op(\"matmul\", (a_shaped.data_ptr.raw, b_shaped.data_ptr.raw, riscv_c_shaped.data_ptr.raw))\n", "\n", @@ -633,7 +633,7 @@ "\n", "snitch_c_shaped = ShapedArray(TypedPtr.new_float64([0.0] * c_len), c_shape)\n", "\n", - "register_implementations(snitch_interpreter, snitch_stream_ctx, include_onnx=False)\n", + "register_implementations(snitch_interpreter, snitch_stream_ctx)\n", "\n", "snitch_interpreter.call_op(\n", " \"matmul\",\n", diff --git a/docs/marimo/linalg_snitch.py b/docs/marimo/linalg_snitch.py index f1120277da..18d58e2f15 100644 --- a/docs/marimo/linalg_snitch.py +++ b/docs/marimo/linalg_snitch.py @@ -513,7 +513,7 @@ def _(TypedPtr, a_shape, b_shape, c_shape, mo, riscv_ctx, riscv_module): riscv_op_counter = OpCounter() riscv_interpreter = Interpreter(riscv_module, listeners=(riscv_op_counter,)) - register_implementations(riscv_interpreter, riscv_ctx, include_onnx=False) + register_implementations(riscv_interpreter, riscv_ctx) riscv_interpreter.call_op("matmul", (a_shaped.data_ptr.raw, b_shaped.data_ptr.raw, riscv_c_shaped.data_ptr.raw)) @@ -566,7 +566,7 @@ def _( snitch_c_shaped = ShapedArray(TypedPtr.new_float64([0.0] * c_len), c_shape) - register_implementations(snitch_interpreter, snitch_stream_ctx, include_onnx=False) + register_implementations(snitch_interpreter, snitch_stream_ctx) snitch_interpreter.call_op( "matmul", diff --git a/docs/marimo/onnx/README.md b/docs/marimo/onnx/README.md deleted file mode 100644 index 648b1d0dc4..0000000000 --- a/docs/marimo/onnx/README.md +++ /dev/null @@ -1,4 +0,0 @@ -# Marimo notebooks that depend on ONNX and mlir-opt - -For these notebooks to run as intended, the `onnx` package needs to be installed and `mlir-opt` needs to be in the path. -Please see the [MLIR Interoperation](../../mlir_interoperation.md) document for more information. diff --git a/docs/marimo/onnx/__marimo__/onnx_demo.ipynb b/docs/marimo/onnx/__marimo__/onnx_demo.ipynb deleted file mode 100644 index fc4a9e1416..0000000000 --- a/docs/marimo/onnx/__marimo__/onnx_demo.ipynb +++ /dev/null @@ -1,421 +0,0 @@ -{ - "cells": [ - { - "cell_type": "markdown", - "id": "Hbol", - "metadata": { - "marimo": { - "config": { - "hide_code": true - } - } - }, - "source": [ - "# ONNX to Snitch\n", - "\n", - "This notebook uses Marimo, a Jupyter-like notebook with interactive UI elements and reactive state." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "MJUe", - "metadata": { - "marimo": { - "config": { - "hide_code": true - } - } - }, - "outputs": [ - { - "data": { - "text/html": [ - "For example, here is a slider, which can take on values from 1 to 4.\n", - "" - ] - }, - "metadata": {}, - "output_type": "display_data" - } - ], - "source": [ - "rank = mo.ui.slider(1, 4, value=2, label=\"Rank\")\n", - "\n", - "mo.md(\n", - " f\"\"\"\n", - " For example, here is a slider, which can take on values from 1 to 4.\n", - "\n", - " {rank}\n", - " \"\"\"\n", - ")" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "vblA", - "metadata": { - "marimo": { - "config": { - "hide_code": true - } - } - }, - "outputs": [ - { - "data": { - "text/html": [ - "We use the slider to determine the shape of our inputs and outputs:\n", - "
A: 2x3xf64\n",
-       "B: 2x3xf64\n",
-       "C: 2x3xf64\n",
-       "
" - ] - }, - "metadata": {}, - "output_type": "display_data" - } - ], - "source": [ - "shape = tuple(range(2, 2 + rank.value))\n", - "\n", - "mo.md(\n", - " f\"\"\"\n", - " We use the slider to determine the shape of our inputs and outputs:\n", - "\n", - " ```\n", - " A: {'x'.join(str(dim) for dim in shape)}xf64\n", - " B: {'x'.join(str(dim) for dim in shape)}xf64\n", - " C: {'x'.join(str(dim) for dim in shape)}xf64\n", - " ```\n", - " \"\"\"\n", - ")" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "bkHC", - "metadata": { - "marimo": { - "config": { - "hide_code": true - } - } - }, - "outputs": [ - { - "data": { - "text/html": [ - "

The ONNX model

\n", - "We use the ONNX API to build a simple function, one that returns the elementwise sum of two arrays of shape (2, 3)
" - ] - }, - "metadata": {}, - "output_type": "display_data" - } - ], - "source": [ - "mo.md(\n", - " f\"\"\"\n", - " ### The ONNX model\n", - "\n", - " We use the ONNX API to build a simple function, one that returns the elementwise sum of two arrays of shape {shape}\n", - " \"\"\"\n", - ")" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "lEQa", - "metadata": { - "marimo": { - "config": { - "hide_code": true - } - } - }, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "
Traceback (most recent call last):\n",
-      "  File "/Users/sasha/Developer/xdslproject/xdsl/.venv/lib/python3.12/site-packages/marimo/_runtime/executor.py", line 141, in execute_cell\n",
-      "    exec(cell.body, glbls)\n",
-      "  File "/var/folders/84/ql679qw90tdc6pkg78v59jl40000gn/T/marimo_84608/__marimo__cell_lEQa_.py", line 1, in <module>\n",
-      "    import onnx\n",
-      "ModuleNotFoundError: No module named 'onnx'\n",
-      "
\n", - "
" - ] - } - ], - "source": [ - "import onnx\n", - "from onnx import AttributeProto, GraphProto, TensorProto, ValueInfoProto, helper" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "PKri", - "metadata": {}, - "outputs": [], - "source": [ - "# Create one input (ValueInfoProto)\n", - "X1 = helper.make_tensor_value_info(\"X1\", TensorProto.DOUBLE, shape)\n", - "X2 = helper.make_tensor_value_info(\"X2\", TensorProto.DOUBLE, shape)\n", - "\n", - "# Create one output (ValueInfoProto)\n", - "Y = helper.make_tensor_value_info(\"Y\", TensorProto.DOUBLE, shape)\n", - "\n", - "# Create a node (NodeProto) - This is based on Pad-11\n", - "node_def = helper.make_node(\n", - " \"Sub\", # node name\n", - " [\"X1\", \"X2\"], # inputs\n", - " [\"Y\"], # outputs\n", - ")\n", - "\n", - "# Create the graph (GraphProto)\n", - "graph_def = helper.make_graph(\n", - " [node_def],\n", - " \"main_graph\",\n", - " [X1, X2],\n", - " [Y],\n", - ")\n", - "\n", - "# Set opset version to 18\n", - "opset_import = [helper.make_operatorsetid(\"\", 18)]\n", - "\n", - "# Create the model (ModelProto) without using helper.make_model\n", - "model_def = helper.make_model(\n", - " graph_def, producer_name=\"onnx-example\", opset_imports=opset_import\n", - ")\n", - "\n", - "onnx.checker.check_model(model_def)" - ] - }, - { - "cell_type": "markdown", - "id": "Xref", - "metadata": { - "marimo": { - "config": { - "hide_code": true - } - } - }, - "source": [ - "ONNX uses a serialized binary format for neural networks, but can also print a string format, which can be useful for debugging.\n", - "Here is the textual format of our model:" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "SFPL", - "metadata": { - "marimo": { - "config": { - "hide_code": true - } - } - }, - "outputs": [], - "source": [ - "mo.accordion(\n", - " {\n", - " \"ONNX Graph\": mo.plain_text(f\"{model_def}\"),\n", - " }\n", - ")" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "BYtC", - "metadata": {}, - "outputs": [], - "source": [ - "mo.md(f\"\"\"\n", - "### Converting to `linalg`\n", - "\n", - "Here is the xDSL representation of the function, it takes two `tensor` values of our chosen shape, passes them as operands to the `onnx.Add` operation, and returns it:\n", - "\n", - "{xmo.module_html(init_module)}\n", - "\"\"\"\n", - ")" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "RGSE", - "metadata": {}, - "outputs": [], - "source": [ - "init_module = build_module(model_def.graph)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "Kclp", - "metadata": { - "marimo": { - "config": { - "hide_code": true - } - } - }, - "outputs": [], - "source": [ - "ctx = MLContext()\n", - "\n", - "for dialect_name, dialect_factory in get_all_dialects().items():\n", - " ctx.register_dialect(dialect_name, dialect_factory)" - ] - }, - { - "cell_type": "markdown", - "id": "emfo", - "metadata": {}, - "source": [ - "xDSL seamlessly interoperates with MLIR, we the `mlir-opt` tool to compile the input to a form that we want to process:" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "Hstk", - "metadata": {}, - "outputs": [], - "source": [ - "bufferized_ctx, bufferized_module, linalg_html = xmo.pipeline_html(\n", - " ctx,\n", - " init_module,\n", - " (\n", - " (\n", - " mo.md(\n", - " \"\"\"\\\n", - "We can use a pass implemented in xDSL to convert the ONNX operations to builtin operations, here we can use the `tensor.empty` op to create our output buffer, and `linalg.add` to represent the addition in destination-passing style:\n", - "\"\"\"\n", - " ),\n", - " ConvertOnnxToLinalgPass()\n", - " ),\n", - " (\n", - " mo.md(\n", - " \"\"\"\n", - "We can also call into MLIR, here to convert `linalg.add` to `linalg.generic`, a representation of Einstein summation:\n", - "\"\"\"\n", - " ),\n", - " MLIROptPass(\n", - " generic=False,\n", - " arguments=[\"--linalg-generalize-named-ops\"]\n", - " )\n", - " ),\n", - " (\n", - " mo.md(\n", - " \"\"\"We prepare the result tensors for bufferization:\"\"\"\n", - " ),\n", - " EmptyTensorToAllocTensorPass()\n", - " ),\n", - " (\n", - " mo.md(\n", - " \"\"\"We then use MLIR to bufferize our function:\"\"\"\n", - " ),\n", - " MLIROptPass(\n", - " arguments=[\n", - " \"--one-shot-bufferize=bufferize-function-boundaries function-boundary-type-conversion=identity-layout-map\",\n", - " ]\n", - " )\n", - " )\n", - " )\n", - ")\n", - "\n", - "linalg_html" - ] - }, - { - "cell_type": "markdown", - "id": "nWHF", - "metadata": {}, - "source": [ - "From here we can use a number of backends to generate executable code, like LLVM, or RISC-V assembly directly.\n", - "Please see other notebooks for details" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "iLit", - "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "
Traceback (most recent call last):\n",
-      "  File "/Users/sasha/Developer/xdslproject/xdsl/.venv/lib/python3.12/site-packages/marimo/_runtime/executor.py", line 141, in execute_cell\n",
-      "    exec(cell.body, glbls)\n",
-      "  File "/var/folders/84/ql679qw90tdc6pkg78v59jl40000gn/T/marimo_84608/__marimo__cell_iLit_.py", line 2, in <module>\n",
-      "    from xdsl.frontend.onnx.ir_builder import build_module\n",
-      "  File "/Users/sasha/Developer/xdslproject/xdsl/xdsl/frontend/onnx/ir_builder.py", line 1, in <module>\n",
-      "    from onnx import GraphProto, NodeProto, ValueInfoProto\n",
-      "ModuleNotFoundError: No module named 'onnx'\n",
-      "
\n", - "
" - ] - } - ], - "source": [ - "from xdsl.context import MLContext\n", - "from xdsl.frontend.onnx.ir_builder import build_module\n", - "from xdsl.ir import Attribute, SSAValue\n", - "from xdsl.passes import PipelinePass\n", - "from xdsl.tools.command_line_tool import get_all_dialects\n", - "from xdsl.transforms.convert_onnx_to_linalg import ConvertOnnxToLinalgPass\n", - "from xdsl.transforms.empty_tensor_to_alloc_tensor import EmptyTensorToAllocTensorPass\n", - "from xdsl.transforms.mlir_opt import MLIROptPass" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "ZHCJ", - "metadata": { - "marimo": { - "config": { - "hide_code": true - } - } - }, - "outputs": [], - "source": [ - "import marimo as mo" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "ROlb", - "metadata": { - "marimo": { - "config": { - "hide_code": true - } - } - }, - "outputs": [], - "source": [ - "import xdsl.utils.marimo as xmo" - ] - } - ], - "metadata": {}, - "nbformat": 4, - "nbformat_minor": 5 -} diff --git a/docs/marimo/onnx/onnx_demo.py b/docs/marimo/onnx/onnx_demo.py deleted file mode 100644 index c9b3ac6a77..0000000000 --- a/docs/marimo/onnx/onnx_demo.py +++ /dev/null @@ -1,271 +0,0 @@ -import marimo - -__generated_with = "0.10.9" -app = marimo.App(auto_download=["ipynb"]) - - -@app.cell(hide_code=True) -def _(mo): - mo.md( - """ - # ONNX to Snitch - - This notebook uses Marimo, a Jupyter-like notebook with interactive UI elements and reactive state. - """ - ) - return - - -@app.cell(hide_code=True) -def _(mo): - rank = mo.ui.slider(1, 4, value=2, label="Rank") - - mo.md( - f""" - For example, here is a slider, which can take on values from 1 to 4. - - {rank} - """ - ) - return (rank,) - - -@app.cell(hide_code=True) -def _(mo, rank): - shape = tuple(range(2, 2 + rank.value)) - - mo.md( - f""" - We use the slider to determine the shape of our inputs and outputs: - - ``` - A: {'x'.join(str(dim) for dim in shape)}xf64 - B: {'x'.join(str(dim) for dim in shape)}xf64 - C: {'x'.join(str(dim) for dim in shape)}xf64 - ``` - """ - ) - return (shape,) - - -@app.cell(hide_code=True) -def _(mo, shape): - mo.md( - f""" - ### The ONNX model - - We use the ONNX API to build a simple function, one that returns the elementwise sum of two arrays of shape {shape} - """ - ) - return - - -@app.cell(hide_code=True) -def _(): - import onnx - from onnx import AttributeProto, GraphProto, TensorProto, ValueInfoProto, helper - return ( - AttributeProto, - GraphProto, - TensorProto, - ValueInfoProto, - helper, - onnx, - ) - - -@app.cell -def _(TensorProto, helper, onnx, shape): - # Create one input (ValueInfoProto) - X1 = helper.make_tensor_value_info("X1", TensorProto.DOUBLE, shape) - X2 = helper.make_tensor_value_info("X2", TensorProto.DOUBLE, shape) - - # Create one output (ValueInfoProto) - Y = helper.make_tensor_value_info("Y", TensorProto.DOUBLE, shape) - - # Create a node (NodeProto) - This is based on Pad-11 - node_def = helper.make_node( - "Sub", # node name - ["X1", "X2"], # inputs - ["Y"], # outputs - ) - - # Create the graph (GraphProto) - graph_def = helper.make_graph( - [node_def], - "main_graph", - [X1, X2], - [Y], - ) - - # Set opset version to 18 - opset_import = [helper.make_operatorsetid("", 18)] - - # Create the model (ModelProto) without using helper.make_model - model_def = helper.make_model( - graph_def, producer_name="onnx-example", opset_imports=opset_import - ) - - onnx.checker.check_model(model_def) - return X1, X2, Y, graph_def, model_def, node_def, opset_import - - -@app.cell(hide_code=True) -def _(mo): - mo.md( - """ - ONNX uses a serialized binary format for neural networks, but can also print a string format, which can be useful for debugging. - Here is the textual format of our model: - """ - ) - return - - -@app.cell(hide_code=True) -def _(mo, model_def): - mo.accordion( - { - "ONNX Graph": mo.plain_text(f"{model_def}"), - } - ) - return - - -@app.cell -def _(init_module, mo, xmo): - mo.md(f""" - ### Converting to `linalg` - - Here is the xDSL representation of the function, it takes two `tensor` values of our chosen shape, passes them as operands to the `onnx.Add` operation, and returns it: - - {xmo.module_html(init_module)} - """ - ) - return - - -@app.cell -def _(build_module, model_def): - init_module = build_module(model_def.graph) - return (init_module,) - - -@app.cell(hide_code=True) -def _(MLContext, get_all_dialects): - ctx = MLContext() - - for dialect_name, dialect_factory in get_all_dialects().items(): - ctx.register_dialect(dialect_name, dialect_factory) - return ctx, dialect_factory, dialect_name - - -@app.cell -def _(mo): - mo.md("""xDSL seamlessly interoperates with MLIR, we the `mlir-opt` tool to compile the input to a form that we want to process:""") - return - - -@app.cell -def _( - ConvertOnnxToLinalgPass, - EmptyTensorToAllocTensorPass, - MLIROptPass, - ctx, - init_module, - mo, - xmo, -): - bufferized_ctx, bufferized_module, linalg_html = xmo.pipeline_html( - ctx, - init_module, - ( - ( - mo.md( - """\ - We can use a pass implemented in xDSL to convert the ONNX operations to builtin operations, here we can use the `tensor.empty` op to create our output buffer, and `linalg.add` to represent the addition in destination-passing style: - """ - ), - ConvertOnnxToLinalgPass() - ), - ( - mo.md( - """ - We can also call into MLIR, here to convert `linalg.add` to `linalg.generic`, a representation of Einstein summation: - """ - ), - MLIROptPass( - generic=False, - arguments=["--linalg-generalize-named-ops"] - ) - ), - ( - mo.md( - """We prepare the result tensors for bufferization:""" - ), - EmptyTensorToAllocTensorPass() - ), - ( - mo.md( - """We then use MLIR to bufferize our function:""" - ), - MLIROptPass( - arguments=[ - "--one-shot-bufferize=bufferize-function-boundaries function-boundary-type-conversion=identity-layout-map", - ] - ) - ) - ) - ) - - linalg_html - return bufferized_ctx, bufferized_module, linalg_html - - -@app.cell -def _(mo): - mo.md( - """ - From here we can use a number of backends to generate executable code, like LLVM, or RISC-V assembly directly. - Please see other notebooks for details - """ - ) - return - - -@app.cell -def _(): - from xdsl.context import MLContext - from xdsl.frontend.onnx.ir_builder import build_module - from xdsl.ir import Attribute, SSAValue - from xdsl.passes import PipelinePass - from xdsl.tools.command_line_tool import get_all_dialects - from xdsl.transforms.convert_onnx_to_linalg import ConvertOnnxToLinalgPass - from xdsl.transforms.empty_tensor_to_alloc_tensor import EmptyTensorToAllocTensorPass - from xdsl.transforms.mlir_opt import MLIROptPass - return ( - Attribute, - ConvertOnnxToLinalgPass, - EmptyTensorToAllocTensorPass, - MLContext, - MLIROptPass, - PipelinePass, - SSAValue, - build_module, - get_all_dialects, - ) - - -@app.cell(hide_code=True) -def _(): - import marimo as mo - return (mo,) - - -@app.cell(hide_code=True) -def _(): - import xdsl.utils.marimo as xmo - return (xmo,) - - -if __name__ == "__main__": - app.run() diff --git a/pyproject.toml b/pyproject.toml index 7b2deed320..0384b532ec 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -35,7 +35,6 @@ dev = [ ] gui = ["textual==1.0.0", "pyclip==0.7"] jax = ["jax==0.4.38", "numpy==2.2.1"] -onnx = ["onnx==1.17.0", "numpy==2.2.1"] riscv = ["riscemu==2.2.7"] [project.urls] @@ -72,12 +71,8 @@ extraPaths = ["tests"] "include" = ["docs", "xdsl", "tests"] "exclude" = [ "tests/dialects/test_memref.py", - "tests/frontend/onnx/test_build_onnx_ir.py", - "tests/frontend/onnx/test_type.py", "tests/test_frontend_op_resolver.py", "tests/test_frontend_python_code_check.py", - "xdsl/frontend/onnx/ir_builder.py", - "xdsl/frontend/onnx/type.py", ] "ignore" = [ "docs/marimo", diff --git a/tests/dialects/onnx/test_broadcast.py b/tests/dialects/onnx/test_broadcast.py deleted file mode 100644 index c6e49b0c2a..0000000000 --- a/tests/dialects/onnx/test_broadcast.py +++ /dev/null @@ -1,50 +0,0 @@ -import pytest - -from xdsl.dialects.onnx import ( - multidirectional_broadcast_shape, - unidirectional_broadcast_shape, -) - -m_test_cases: list[tuple[tuple[list[int], list[int]], list[int] | None]] = [ - (([2, 3, 4, 5], []), [2, 3, 4, 5]), - (([2, 3, 4, 5], [5]), [2, 3, 4, 5]), - (([4, 5], [2, 3, 4, 5]), [2, 3, 4, 5]), - (([1, 4, 5], [2, 3, 1, 5]), [2, 3, 4, 5]), - (([3, 4, 5], [2, 1, 1, 1]), [2, 3, 4, 5]), - (([2, 3, 4, 5], [6, 1]), None), - (([1, 3, 4, 1], [5]), None), - (([4, 5, 6], [2, 3, 4, 5]), None), - (([1, 4, 5], [2, 3, 1, 5]), None), - (([3, 4, 5], [2, 1, 1, 1]), None), -] - -u_test_cases: list[tuple[tuple[list[int], list[int]], list[int] | None]] = [ - (([2, 3, 4, 5], []), [2, 3, 4, 5]), - (([2, 3, 4, 5], [5]), [2, 3, 4, 5]), - (([2, 3, 4, 5], [2, 1, 1, 5]), [2, 3, 4, 5]), - (([2, 3, 4, 5], [1, 3, 1, 5]), [2, 3, 4, 5]), - (([], [2, 3, 4, 5]), None), - (([5], [2, 3, 4, 5]), None), - (([2, 1, 1, 5], [2, 3, 4, 5]), None), - (([2, 3, 5], [1, 3, 1, 5]), None), -] - - -# Multidirectional Broadcasting Tests -@pytest.mark.parametrize("input_shapes, expected_result", m_test_cases) -def multi_test_broadcast_shape( - input_shapes: tuple[list[int], list[int]], expected_result: list[int] | None -): - lhs, rhs = input_shapes - result = multidirectional_broadcast_shape(lhs, rhs) - assert result == expected_result - - -# Unidirectional Broadcasting Tests -@pytest.mark.parametrize("input_shapes, expected_result", u_test_cases) -def uni_test_broadcast_shape( - input_shapes: tuple[list[int], list[int]], expected_result: list[int] | None -): - lhs, rhs = input_shapes - result = unidirectional_broadcast_shape(lhs, rhs) - assert result == expected_result diff --git a/tests/filecheck/dialects/onnx/onnx_invalid.mlir b/tests/filecheck/dialects/onnx/onnx_invalid.mlir deleted file mode 100644 index b15ba25761..0000000000 --- a/tests/filecheck/dialects/onnx/onnx_invalid.mlir +++ /dev/null @@ -1,571 +0,0 @@ -// RUN: xdsl-opt --verify-diagnostics --split-input-file %s | filecheck %s - -// Non-broadcastable operands are not allowed. - -builtin.module { - %t0, %t1 = "test.op"() : () -> (tensor<2x4xf32>, tensor<3x2xf32>) - - // CHECK: Operation does not verify: operands have incompatible shapes: (2, 4) and (3, 2) - %res_add = "onnx.Add"(%t0, %t1) {onnx_node_name = "/Add"} : (tensor<2x4xf32>, tensor<3x2xf32>) -> tensor<2x4xf32> -} - -// ----- - -builtin.module { - %t0, %t1 = "test.op"() : () -> (tensor<2x4xf32>, tensor<1x4xf32>) - - // CHECK: Operation does not verify: result shape [2, 4] does not match result type tensor<1x4xf32> - %res_sub = "onnx.Sub"(%t0, %t1) {onnx_node_name = "/Sub"} : (tensor<2x4xf32>, tensor<1x4xf32>) -> tensor<1x4xf32> -} - -// ----- - -builtin.module { - %t0, %t1 = "test.op"() : () -> (f32, tensor<2x4xf32>) - - // CHECK: operand at position 0 does not verify: - // CHECK: f32 should be of base attribute tensor - %res_mul = "onnx.Mul"(%t0, %t1) {onnx_node_name = "/Mul"} : (f32, tensor<2x4xf32>) -> tensor<2x4xf32> -} - -// ----- - -builtin.module { - %t0, %t1 = "test.op"() : () -> (tensor<2x4xf32>, tensor<2x4xi32>) - - // CHECK: operand at position 1 does not verify: - // CHECK: attribute f32 expected from variable 'T', but got i32 - %res_div = "onnx.Div"(%t0, %t1) {onnx_node_name = "/Div"} : (tensor<2x4xf32>, tensor<2x4xi32>) -> tensor<2x4xf32> -} - -// ----- - -builtin.module { - %t0 = "test.op"() : () -> (tensor<2x4xf32>) - - // CHECK: operand at position 1 does not verify: - // CHECK: Operation does not verify: Mismatch between operand type and res type of onnx.Relu - %res_relu = "onnx.Relu"(%t0) {onnx_node_name = "/Relu"} : (tensor<2x4xf32>) -> tensor<3x4xf32> -} - -// ----- - -builtin.module { - %t0, %t1, %t2 = "test.op"() : () -> (tensor<2x4xf32>, tensor<3x2xf32>, tensor<3x2xf32>) - - // CHECK: Operation does not verify: operands have incompatible shapes: (2, 4) and (3, 2) - %res_gemm = "onnx.Gemm"(%t0, %t1, %t2) {onnx_node_name = "/Gemm"} : (tensor<2x4xf32>, tensor<3x2xf32>, tensor<3x2xf32>) -> tensor<2x4xf32> -} - -// ----- - -builtin.module { - %t0, %t1, %t2 = "test.op"() : () -> (f32, tensor<2x4xf32>,tensor<2x4xf32>) - - // CHECK: operand at position 0 does not verify: - // CHECK: f32 should be of base attribute tensor - %res_gemm= "onnx.Gemm"(%t0, %t1, %t2) {onnx_node_name = "/Gemm"} : (f32, tensor<2x4xf32>, tensor<2x4xf32>) -> tensor<2x4xf32> -} - -// ----- - -builtin.module { - %t0, %t1, %t2 = "test.op"(): () -> (tensor<5x2xf32>, tensor<2x1xf32>, tensor<5x4xf32>) - - //CHECK: Operation does not verify: result shape [5, 4] does not match result type tensor<5x2xf32> - %res_gemm = "onnx.Gemm"(%t0, %t1, %t2) : (tensor<5x2xf32>, tensor<2x1xf32>, tensor<5x4xf32>) -> tensor<5x2xf32> -} - -// ----- - -builtin.module { - %t0, %t1, %t2 = "test.op"() : () -> (tensor<2x4xf32>, tensor<2x4xi32>, tensor<2x4xf32>) - - // CHECK: operand at position 1 does not verify: - // CHECK: attribute f32 expected from variable 'T', but got i32 - %res_gemm = "onnx.Gemm"(%t0, %t1, %t2) {onnx_node_name = "/Gemm"} : (tensor<2x4xf32>, tensor<2x4xi32>, tensor<2x4xf32>) -> tensor<2x4xf32> - - } - - // ----- - -builtin.module { - %t0, %t1, %t2 = "test.op"() : () -> (tensor<5x3x3xf32>, tensor<3x2xf32>, tensor<5x2xf32>) - // CHECK: Operation does not verify: tensor A should be a 2D tensor - %res_gemm = "onnx.Gemm"(%t0, %t1, %t2) {onnx_node_name = "/Gemm"}: (tensor<5x3x3xf32>, tensor<3x2xf32>, tensor<5x2xf32>) -> tensor<5x2xf32> -} - -// ----- - -builtin.module { - %t0, %t1, %t2 = "test.op"() : () -> (tensor<5x3xf32>, tensor<3x2x3xf32>, tensor<5x2xf32>) - // CHECK: Operation does not verify: tensor B should be a 2D tensor - %res_gemm = "onnx.Gemm"(%t0, %t1, %t2) {onnx_node_name = "/Gemm"}: (tensor<5x3xf32>, tensor<3x2x3xf32>, tensor<5x2xf32>) -> tensor<5x2xf32> -} - -// ----- - -builtin.module { - %t0, %t1, %t2 = "test.op"() : () -> (tensor<5x3xf32>, tensor<3x2xf32>, tensor<5x2x7xf32>) - // CHECK: Operation does not verify: tensor C should be a 1D tensor or 2D tensor - %res_gemm = "onnx.Gemm"(%t0, %t1, %t2) {onnx_node_name = "/Gemm", beta = 47.0 : f32}: (tensor<5x3xf32>, tensor<3x2xf32>, tensor<5x2x7xf32>) -> tensor<5x3xf32> -} - -// ----- - -builtin.module { - %t0, %t1 = "test.op"() : () -> (f32, tensor<2x4xi64>) - - // CHECK: operand at position 0 does not verify: - // CHECK: f32 should be of base attribute tensor - %res_reshape = "onnx.Reshape"(%t0, %t1) {onnx_node_name = "/Reshape"} : (f32, tensor<2x4xi64>) -> tensor<2x4xi64> -} - -// ----- - -builtin.module { - %t0, %t1 = "test.op"() : () -> (tensor<4x3x2xf32>, tensor<1xi64>) - - // CHECK: result at position 0 does not verify: - // CHECK: attribute f32 expected from variable 'T', but got i32 - %res_reshape = "onnx.Reshape"(%t0, %t1) {"onnx_node_name" = "/Reshape"} : (tensor<4x3x2xf32>, tensor<1xi64>) -> tensor<4x3x2xi32> -} - -// ----- - -builtin.module { - %t0, %t1 = "test.op"() : () -> (tensor<6x9x5xf32>, tensor<3xi64>) - - // CHECK: result at position 0 does not verify: - // CHECK: Operation does not verify: Input tensor's shape and output tensor's shape must have the same number of elements - %res_reshape = "onnx.Reshape"(%t0, %t1) {"onnx_node_name" = "/Reshape"} : (tensor<6x9x5xf32>, tensor<3xi64>) -> tensor<6x9xf32> -} - -// ----- - -builtin.module { - %t0, %t1 = "test.op"() : () -> (tensor<6x9x5xf32>, tensor<3xi32>) - - // CHECK: Operation does not verify: shape element type has to be a 64-bit signless integer - %res_reshape = "onnx.Reshape"(%t0, %t1) {"onnx_node_name" = "/Reshape"} : (tensor<6x9x5xf32>, tensor<3xi32>) -> tensor<6x9xf32> -} - -// ----- - -builtin.module { - %t0, %t1 = "test.op"() : () -> (tensor<6x9x5xf32>, tensor<3xi64>) - - // CHECK: result at position 0 does not verify: - // CHECK: vector<6x9x5xf32> should be of base attribute tensor - %res_reshape = "onnx.Reshape"(%t0, %t1) {"onnx_node_name" = "/Reshape"} : (tensor<6x9x5xf32>, tensor<3xi64>) -> vector<6x9x5xf32> -} - -// ----- - -builtin.module { - %t0, %t1 = "test.op"() : () -> (vector<6x9x5xf32>, tensor<3xi64>) - - // CHECK: operand at position 0 does not verify: - // CHECK: vector<6x9x5xf32> should be of base attribute tensor - %res_reshape = "onnx.Reshape"(%t0, %t1) {"onnx_node_name" = "/Reshape"} : (vector<6x9x5xf32>, tensor<3xi64>) -> tensor<6x9x5xf32> -} - -// ----- - -builtin.module { - %t0, %t1 = "test.op"() : () -> (tensor<5x5xf32>, tensor<2x2xi64>) - - //CHECK: Operation does not verify: Shape tensor must have a rank one - %res_reshape = "onnx.Reshape"(%t0, %t1) {onnx_node_name = "/Reshape"}: (tensor<5x5xf32>, tensor<2x2xi64>) -> tensor<5x5xf32> - -} - -// ----- - -builtin.module { - %t0, %t1 = "test.op"() : () -> (tensor<3x3xf32>, tensor) - - //CHECK: Operation does not verify: Shape tensor rank must not be equal to -1 - %res_reshape = "onnx.Reshape"(%t0, %t1) {onnx_node_name = "/Reshape"}: (tensor<3x3xf32>, tensor) -> tensor<3x3xf32> - -} - -// ----- - -builtin.module { - %t0 = "test.op"() : () -> (tensor<3x3xf32>) - - //CHECK: Operation does not verify: Mismatch between operand type and res type of onnx.Abs - %res_abs = "onnx.Abs"(%t0) {onnx_node_name = "/Abs"}: (tensor<3x3xf32>) -> tensor<2x3xf32> - -} - - -// ----- - -builtin.module { - %t0,%t1,%t2 = "test.op"(): () -> (f32, tensor<1x1x3x3xf32>, none) - - // CHECK: operand at position 0 does not verify: - // CHECK: f32 should be of base attribute tensor - %res_conv = "onnx.Conv"(%t0, %t1, %t2) {onnx_node_name = "/Conv"} : (f32, tensor<1x1x3x3xf32>, none) -> tensor<1x1x3x3xf32> -} - -// ----- - -builtin.module { - %t0,%t1,%t2 = "test.op"(): () -> (tensor<1x1x5x5xf32>, tensor<1x1x3x3xf32>, none) - - // CHECK: result at position 0 does not verify: - // CHECK: attribute f32 expected from variable 'T', but got i32 - %res_conv = "onnx.Conv"(%t0, %t1, %t2) {"onnx_node_name" = "/Conv"} : (tensor<1x1x5x5xf32>, tensor<1x1x3x3xf32>, none) -> tensor<1x1x3x3xi32> - -} - -// ----- - -builtin.module { - %t0,%t1,%t2 = "test.op"(): () -> (tensor<1x1x5x5xf32>, tensor<1x1x3x3xf32>, tensor<4x2xf32>) - - // CHECK: Operation does not verify: bias must be 1D - %res_conv = "onnx.Conv"(%t0, %t1, %t2) {"onnx_node_name" = "/Conv", "auto_pad" = "NOTSET", "dilations" = [1 : i64, 1: i64], "group" = 1 : i64, "kernel_shape" = [4 : i64, 4 : i64], "pads" = [0 : i64, 0: i64, 0: i64, 0: i64], "strides" = [1: i64, 1: i64]} : (tensor<1x1x5x5xf32>, tensor<1x1x3x3xf32>, tensor<4x2xf32>) -> tensor<1x1x3x3xf32> - -} - -// ----- - -builtin.module { - %t0,%t1,%t2 = "test.op"(): () -> (tensor<1x1x5x5xf32>, tensor<1x1x3x3xf32>, none) - - // CHECK: Operation does not verify: kernel shape rank and weight tensor rank are not the same - %res_conv = "onnx.Conv"(%t0, %t1, %t2) {"onnx_node_name" = "/Conv", "auto_pad" = "NOTSET", "dilations" = [1 : i64, 1: i64], "group" = 1 : i64, "kernel_shape" = [4 : i64, 4 : i64], "pads" = [0 : i64, 0: i64, 0: i64, 0: i64], "strides" = [1: i64, 1: i64]} : (tensor<1x1x5x5xf32>, tensor<1x1x3x3xf32>, none) -> tensor<1x1x3x3xf32> - -} - -// ----- - -builtin.module { - %t0,%t1,%t2 = "test.op"(): () -> (tensor<1x1x5x5xf32>, tensor<1x1x3x3xf32>, none) - - // CHECK: Operation does not verify: dilation value must be non zero positive - %res_conv = "onnx.Conv"(%t0, %t1, %t2) {"onnx_node_name" = "/Conv", "auto_pad" = "NOTSET", "dilations" = [-2 : i64, -2: i64], "group" = 1 : i64, "kernel_shape" = [3 : i64, 3 : i64], "pads" = [0 : i64, 0: i64, 0: i64, 0: i64], "strides" = [1: i64, 1: i64]} : (tensor<1x1x5x5xf32>, tensor<1x1x3x3xf32>, none) -> tensor<1x1x3x3xf32> -} - -// ----- - -builtin.module { - %t0,%t1,%t2 = "test.op"(): () -> (tensor<1x1x5x5xf32>, tensor<1x1x3x3xf32>, none) - - // CHECK: Operation does not verify: dilations rank and kernel shape rank are not the same - %res_conv = "onnx.Conv"(%t0, %t1, %t2) {"onnx_node_name" = "/Conv", "auto_pad" = "NOTSET", "dilations" = [1 : i64, 1: i64, 3: i64], "group" = 1 : i64, "kernel_shape" = [3 : i64, 3 : i64], "pads" = [0 : i64, 0: i64, 0: i64, 0: i64], "strides" = [1: i64, 1: i64]} : (tensor<1x1x5x5xf32>, tensor<1x1x3x3xf32>, none) -> tensor<1x1x3x3xf32> -} - -// ----- - -builtin.module { - %t0,%t1,%t2 = "test.op"(): () -> (tensor<1x1x5x5xf32>, tensor<1x1x3x3xf32>, none) - - // CHECK: Operation does not verify: group value must be nonnegative - %res_conv = "onnx.Conv"(%t0, %t1, %t2) {"onnx_node_name" = "/Conv", "auto_pad" = "NOTSET", "dilations" = [1 : i64, 1: i64], "group" = 0 : i64, "kernel_shape" = [3 : i64, 3 : i64], "pads" = [0 : i64, 0: i64, 0: i64, 0: i64], "strides" = [1: i64, 1: i64]} : (tensor<1x1x5x5xf32>, tensor<1x1x3x3xf32>, none) -> tensor<1x1x3x3xf32> -} - - -// ----- - -builtin.module { - %t0,%t1,%t2 = "test.op"(): () -> (tensor<1x1x5x5xf32>, tensor<1x1x3x3xf32>, none) - - // CHECK: Operation does not verify: stride value must be non zero positive - %res_conv = "onnx.Conv"(%t0, %t1, %t2) {"onnx_node_name" = "/Conv", "auto_pad" = "NOTSET", "dilations" = [1 : i64, 1: i64], "group" = 1 : i64, "kernel_shape" = [3 : i64, 3 : i64], "pads" = [0 : i64, 0: i64, 0: i64, 0: i64], "strides" = [-2 : i64, -2: i64]} : (tensor<1x1x5x5xf32>, tensor<1x1x3x3xf32>, none) -> tensor<1x1x3x3xf32> -} - -// ----- - -builtin.module { - %t0,%t1,%t2 = "test.op"(): () -> (tensor<1x1x5x5xf32>, tensor<1x1x3x3xf32>, none) - - // CHECK: Operation does not verify: strides rank and kernel shape rank are not the same - %res_conv = "onnx.Conv"(%t0, %t1, %t2) {"onnx_node_name" = "/Conv", "auto_pad" = "NOTSET", "dilations" = [1 : i64, 1: i64], "group" = 1 : i64, "kernel_shape" = [3 : i64, 3 : i64], "pads" = [0 : i64, 0: i64, 0: i64, 0: i64], "strides" = [1: i64, 1: i64, 1: i64]} : (tensor<1x1x5x5xf32>, tensor<1x1x3x3xf32>, none) -> tensor<1x1x3x3xf32> -} - -// ----- - -builtin.module { - %t0,%t1,%t2 = "test.op"(): () -> (tensor<1x1x5x5xf32>, tensor<1x1x3x3xf32>, none) - - // CHECK: Operation does not verify: pads value must be nonnegative - %res_conv = "onnx.Conv"(%t0, %t1, %t2) {"onnx_node_name" = "/Conv", "auto_pad" = "NOTSET", "dilations" = [1 : i64, 1: i64], "group" = 1 : i64, "kernel_shape" = [3 : i64, 3 : i64], "pads" = [-1 : i64, -1: i64, -1: i64, -1: i64], "strides" = [1: i64, 1: i64]} : (tensor<1x1x5x5xf32>, tensor<1x1x3x3xf32>, none) -> tensor<1x1x3x3xf32> -} - - -// ----- - -builtin.module { - %t0,%t1,%t2 = "test.op"(): () -> (tensor<1x1x5x5xf32>, tensor<1x1x3x3xf32>, none) - - // CHECK: Operation does not verify: pads rank is not twice the kernel shape rank - %res_conv = "onnx.Conv"(%t0, %t1, %t2) {"onnx_node_name" = "/Conv", "auto_pad" = "NOTSET", "dilations" = [1 : i64, 1: i64], "group" = 1 : i64, "kernel_shape" = [3 : i64, 3 : i64], "pads" = [1 : i64, 1: i64, 1: i64], "strides" = [1: i64, 1: i64]} : (tensor<1x1x5x5xf32>, tensor<1x1x3x3xf32>, none) -> tensor<1x1x3x3xf32> -} - -// ----- - -builtin.module { - %t0,%t1,%t2 = "test.op"(): () -> (tensor<1x1x5x5xf32>, tensor<1x1x3x3xf32>, none) - - // CHECK: Operation does not verify: Invalid auto_pad string. Must be one of ['NOTSET', 'SAME_UPPER', 'SAME_LOWER', 'VALID'] - %res_conv = "onnx.Conv"(%t0, %t1, %t2) {"onnx_node_name" = "/Conv", "auto_pad" = "INVALID", "dilations" = [1 : i64, 1: i64], "group" = 1 : i64, "kernel_shape" = [3 : i64, 3 : i64], "pads" = [0 : i64, 0: i64, 0: i64, 0: i64], "strides" = [1: i64, 1: i64]} : (tensor<1x1x5x5xf32>, tensor<1x1x3x3xf32>, none) -> tensor<1x1x3x3xf32> - - } - - - // ----- - -builtin.module { - - // CHECK: f32 should be of base attribute tensor - %res_constant = "onnx.Constant"() {value = dense<[3.0]> : tensor<1xf32>} : () -> f32 - - } - -// ----- - -builtin.module { - - // CHECK: Operation does not verify: value attribute type must be of type TensorType - %res_constant = "onnx.Constant"() {value = dense<[3.0]> : vector<1xf32>} : () -> tensor<1xf32> - - } - -// ----- - -builtin.module { - - // CHECK: Operation does not verify: value_int element type has to be a 64-bit signless integer - %res_constant = "onnx.Constant"() {value_int = 4 : i32} : () -> tensor<1xf32> - - } - -// ----- - -builtin.module { - - // CHECK: Operation does not verify: value_ints elements type has to be a 64-bit signless integer - %res_constant = "onnx.Constant"() {value_ints = [1: i64, 2: i32, 3: i64]} : () -> tensor<3xi32> - - } - -// ----- - -builtin.module { - - // CHECK: Operation does not verify: Only one value attribute must be provided, but 2 were specified - %res_constant = "onnx.Constant"() {value_ints = [1: i64, 1: i64], value_int = 3: i64} : () -> tensor<3xi64> - - } - -// ----- - -builtin.module { - %t0 = "test.op"(): () -> (f32) - - // CHECK: operand at position 0 does not verify: - // CHECK: Expected tensor or memref type, got f32 - %res_max_pool_single_out = "onnx.MaxPoolSingleOut"(%t0) {onnx_node_name = "/MaxPoolSingleOut"} : (f32) -> tensor<5x5x32x32xf32> -} - -// ----- - -builtin.module { - %t0= "test.op"(): () -> (tensor<5x5x32x32xf32>) - - // CHECK: result at position 0 does not verify: - // CHECK: attribute f32 expected from variable 'T', but got i32 - %res_max_pool_single_out = "onnx.MaxPoolSingleOut"(%t0) {"onnx_node_name" = "/MaxPoolSingleOut"} : (tensor<5x5x32x32xf32>) -> tensor<5x5x32x32xi32> - -} - -// ----- - -builtin.module { - %t0 = "test.op"(): () -> (tensor<5x5x32x32xf32>) - - // CHECK: Operation does not verify: Invalid auto_pad string. Must be one of ['NOTSET', 'SAME_UPPER', 'SAME_LOWER', 'VALID'] - %res_max_pool_single_out = "onnx.MaxPoolSingleOut"(%t0) {"onnx_node_name" = "/MaxPoolSingleOut", "auto_pad" = "INVALID", "ceil_mode" = 0 : i64, "kernel_shape" = [3 : i64, 3 : i64], "dilations" = [1 : i64, 1 : i64], "pads" = [0 : i64, 0 : i64, 0 : i64, 0 : i64], "storage_order" = 0 : i64, "strides" = [1 : i64, 1 : i64]} : (tensor<5x5x32x32xf32>) -> tensor<5x5x32x32xf32> - -} - -// ----- - -builtin.module { - %t0 = "test.op"(): () -> (tensor<5x5x32x32xf32>) - - // CHECK: Operation does not verify: ceil value must be either zero or one - %res_max_pool_single_out = "onnx.MaxPoolSingleOut"(%t0) {"onnx_node_name" = "/MaxPoolSingleOut", "auto_pad" = "NOTSET", "ceil_mode" = 2 : i64, "kernel_shape" = [3 : i64, 3 : i64], "dilations" = [1 : i64, 1 : i64], "pads" = [0 : i64, 0 : i64, 0 : i64, 0 : i64], "storage_order" = 0 : i64, "strides" = [1 : i64, 1 : i64]} : (tensor<5x5x32x32xf32>) -> tensor<5x5x32x32xf32> - -} - -// ----- - -builtin.module { - %t0 = "test.op"(): () -> (tensor<5x5x32x32xf32>) - - // CHECK: Operation does not verify: input data and kernel shape rank mismatch: (2) vs (1) - %res_max_pool_single_out = "onnx.MaxPoolSingleOut"(%t0) {"onnx_node_name" = "/MaxPoolSingleOut", "auto_pad" = "NOTSET", "ceil_mode" = 0 : i64, "kernel_shape" = [3 : i64], "dilations" = [1 : i64, 1 : i64], "pads" = [0 : i64, 0 : i64, 0 : i64, 0 : i64], "storage_order" = 0 : i64, "strides" = [1 : i64, 1 : i64]} : (tensor<5x5x32x32xf32>) -> tensor<5x5x32x32xf32> - -} - -// ----- - -builtin.module { - %t0 = "test.op"(): () -> (tensor<5x5x32x32xf32>) - - // CHECK: Operation does not verify: dilation value must be non zero positive - %res_max_pool_single_out = "onnx.MaxPoolSingleOut"(%t0) {"onnx_node_name" = "/MaxPoolSingleOut", "auto_pad" = "NOTSET", "ceil_mode" = 0 : i64, "kernel_shape" = [3 : i64, 3 : i64], "dilations" = [-1 : i64, -1 : i64], "pads" = [0 : i64, 0 : i64, 0 : i64, 0 : i64], "storage_order" = 0 : i64, "strides" = [1 : i64, 1 : i64]} : (tensor<5x5x32x32xf32>) -> tensor<5x5x32x32xf32> - -} - -// ----- - -builtin.module { - %t0 = "test.op"(): () -> (tensor<5x5x32x32xf32>) - - // CHECK: Operation does not verify: dilations rank (3) and kernel shape rank (2) are not the same - %res_max_pool_single_out = "onnx.MaxPoolSingleOut"(%t0) {"onnx_node_name" = "/MaxPoolSingleOut", "auto_pad" = "NOTSET", "ceil_mode" = 0 : i64, "kernel_shape" = [3 : i64, 3 : i64], "dilations" = [1 : i64, 1 : i64, 1: i64], "pads" = [0 : i64, 0 : i64, 0 : i64, 0 : i64], "storage_order" = 0 : i64, "strides" = [1 : i64, 1 : i64]} : (tensor<5x5x32x32xf32>) -> tensor<5x5x32x32xf32> - -} - -// ----- - -builtin.module { - %t0 = "test.op"(): () -> (tensor<5x5x32x32xf32>) - - // CHECK: Operation does not verify: column major storage order not implemented yet - %res_max_pool_single_out = "onnx.MaxPoolSingleOut"(%t0) {"onnx_node_name" = "/MaxPoolSingleOut", "auto_pad" = "NOTSET", "ceil_mode" = 0 : i64, "kernel_shape" = [3 : i64, 3 : i64], "dilations" = [1 : i64, 1 : i64], "pads" = [0 : i64, 0 : i64, 0 : i64, 0 : i64], "storage_order" = 1 : i64, "strides" = [1 : i64, 1 : i64]} : (tensor<5x5x32x32xf32>) -> tensor<5x5x32x32xf32> - -} - -// ----- - -builtin.module { - %t0 = "test.op"(): () -> (tensor<5x5x32x32xf32>) - - // CHECK: Operation does not verify: stride value must be non zero positive - %res_max_pool_single_out = "onnx.MaxPoolSingleOut"(%t0) {"onnx_node_name" = "/MaxPoolSingleOut", "auto_pad" = "NOTSET", "ceil_mode" = 0 : i64, "kernel_shape" = [3 : i64, 3 : i64], "dilations" = [1 : i64, 1 : i64], "pads" = [0 : i64, 0 : i64, 0 : i64, 0 : i64], "storage_order" = 0 : i64, "strides" = [-1 : i64, -1 : i64]} : (tensor<5x5x32x32xf32>) -> tensor<5x5x32x32xf32> - -} - -// ----- - -builtin.module { - %t0 = "test.op"(): () -> (tensor<5x5x32x32xf32>) - - // CHECK: Operation does not verify: strides rank (3) and kernel shape rank (2) are not the same - %res_max_pool_single_out = "onnx.MaxPoolSingleOut"(%t0) {"onnx_node_name" = "/MaxPoolSingleOut", "auto_pad" = "NOTSET", "ceil_mode" = 0 : i64, "kernel_shape" = [3 : i64, 3 : i64], "dilations" = [1 : i64, 1 : i64], "pads" = [0 : i64, 0 : i64, 0 : i64, 0 : i64], "storage_order" = 0 : i64, "strides" = [1 : i64, 1 : i64, 1: i64]} : (tensor<5x5x32x32xf32>) -> tensor<5x5x32x32xf32> - -} - -// ----- - -builtin.module { - %t0 = "test.op"(): () -> (tensor<5x5x32x32xf32>) - - // CHECK: Operation does not verify: pads value must be nonnegative - %res_max_pool_single_out = "onnx.MaxPoolSingleOut"(%t0) {"onnx_node_name" = "/MaxPoolSingleOut", "auto_pad" = "NOTSET", "ceil_mode" = 0 : i64, "kernel_shape" = [3 : i64, 3 : i64], "dilations" = [1 : i64, 1 : i64], "pads" = [-2 : i64, 0 : i64, 0 : i64, 0 : i64], "storage_order" = 0 : i64, "strides" = [1 : i64, 1 : i64]} : (tensor<5x5x32x32xf32>) -> tensor<5x5x32x32xf32> - -} - -// ----- - -builtin.module { - %t0 = "test.op"(): () -> (tensor<5x5x32x32xf32>) - - // CHECK: Operation does not verify: pads rank (5) is not twice the kernel shape rank (2) - %res_max_pool_single_out = "onnx.MaxPoolSingleOut"(%t0) {"onnx_node_name" = "/MaxPoolSingleOut", "auto_pad" = "NOTSET", "ceil_mode" = 0 : i64, "kernel_shape" = [3 : i64, 3 : i64], "dilations" = [1 : i64, 1 : i64], "pads" = [0 : i64, 0 : i64, 0 : i64, 0 : i64, 0: i64], "storage_order" = 0 : i64, "strides" = [1 : i64, 1 : i64]} : (tensor<5x5x32x32xf32>) -> tensor<5x5x32x32xf32> - -} - -// ----- - -builtin.module { - %t0, %t1 = "test.op"() : () -> (tensor<2x4x3xf32>, tensor<4x2xf32>) - - // CHECK: Operation does not verify: input matrix A should be a 2D tensor - %res_matmul = "onnx.MatMul"(%t0, %t1) {onnx_node_name = "/MatMul"} : (tensor<2x4x3xf32>, tensor<4x2xf32>) -> tensor<2x2xf32> -} - -// ----- - -builtin.module { - %t0, %t1 = "test.op"() : () -> (tensor<2x4xf32>, tensor<4x2x3xf32>) - - // CHECK: Operation does not verify: input matrix B should be a 2D tensor - %res_matmul = "onnx.MatMul"(%t0, %t1) {onnx_node_name = "/MatMul"} : (tensor<2x4xf32>, tensor<4x2x3xf32>) -> tensor<2x2xf32> -} - -// ----- - -builtin.module { - %t0, %t1 = "test.op"() : () -> (tensor<2x4xf32>, tensor<5x2xf32>) - - // CHECK: Operation does not verify: operands have incompatible shapes: (2, 4) and (5, 2) - %res_matmul = "onnx.MatMul"(%t0, %t1) {onnx_node_name = "/MatMul"} : (tensor<2x4xf32>, tensor<5x2xf32>) -> tensor<2x2xf32> -} - -// ----- - -builtin.module { - %t0, %t1 = "test.op"() : () -> (tensor<2x4xf32>, tensor<4x2xf32>) - - // CHECK: Operation does not verify: result shape [2, 2] does not match result type [2, 3] - %res_matmul = "onnx.MatMul"(%t0, %t1) {onnx_node_name = "/MatMul"} : (tensor<2x4xf32>, tensor<4x2xf32>) -> tensor<2x3xf32> -} - -// ----- - -builtin.module { - %t0 = "test.op"() : () -> (tensor<3x4xf32>) - // CHECK: Operation does not verify: permutation can not contain more than one occurrence of the same dimension: dimension #1 appears 2 times. - %res_transpose = "onnx.Transpose"(%t0) {onnx_node_name = "/Transpose", "perm" = [1 : i64, 1 : i64]}: (tensor<3x4xf32>) -> tensor<4x3xf32> -} - -// ----- - -builtin.module { - %t0 = "test.op"() : () -> (tensor<3x4xf32>) - // CHECK: Operation does not verify: permutation can only contain values between 0 and 2-1: dimension #1 value is 2 - %res_transpose = "onnx.Transpose"(%t0) {onnx_node_name = "/Transpose", "perm" = [1 : i64, 2 : i64]}: (tensor<3x4xf32>) -> tensor<4x3xf32> -} - -// ----- - -builtin.module { - %t0 = "test.op"() : () -> (tensor<1x3x4xf32>) - // CHECK: Operation does not verify: permutation and inputs dimensions must have the same size: #dimensions input is 3, #dimension perimutation is 2 - %res_transpose = "onnx.Transpose"(%t0) {onnx_node_name = "/Transpose", "perm" = [1 : i64, 0 : i64]}: (tensor<1x3x4xf32>) -> tensor<3x1x4xf32> -} - -// ----- - -builtin.module { - %t0 = "test.op"() : () -> (tensor<3x4xf32>) - // CHECK: Operation does not verify: incorrect output shape: output dimension #0 should be equal to 4 - %res_transpose = "onnx.Transpose"(%t0) {onnx_node_name = "/Transpose", "perm" = [1 : i64, 0 : i64]}: (tensor<3x4xf32>) -> tensor<3x3xf32> -} - -// ----- - -builtin.module { - %t0 = "test.op"() : () -> (tensor<1x2x4xf32>) - - // CHECK: Operation does not verify: axes to squeeze must be between 0 and 2, axes: 3 - %res_squeeze = "onnx.Squeeze"(%t0) {onnx_node_name = "/Squeeze", "axes" = 3 : i64} : (tensor<1x2x4xf32>) -> tensor<2x4xf32> -} - -// ----- - -builtin.module { - %t0 = "test.op"() : () -> (tensor<3x4xf32>) - - // CHECK: Operation does not verify: tensor input shape (3, 4) is not equal to tensor output shape (7, 3) - %res_sigmoid = "onnx.Sigmoid"(%t0) {onnx_node_name = "/Sigmoid"} : (tensor<3x4xf32>) -> tensor<7x3xf32> -} diff --git a/tests/filecheck/dialects/onnx/onnx_ops.mlir b/tests/filecheck/dialects/onnx/onnx_ops.mlir deleted file mode 100644 index ed12de520a..0000000000 --- a/tests/filecheck/dialects/onnx/onnx_ops.mlir +++ /dev/null @@ -1,91 +0,0 @@ -// RUN: XDSL_ROUNDTRIP - -%t0, %t1 = "test.op"(): () -> (tensor<1x2x6xf32>, tensor<1x2x6xf32>) -%t2, %t3 = "test.op"(): () -> (tensor<3x2xf32>, tensor<1x2xf32>) -%t4, %t5 = "test.op"(): () -> (tensor<3x1x2xf32>, tensor<3x4x1xf32>) -%t6, %t7 = "test.op"(): () -> (tensor<1x5x1x3xf32>, tensor<2x1x6x3xf32>) -%t8 = "test.op"(): () -> (tensor<3x4xf32>) -%t9, %t10, %t11 = "test.op"(): () -> (tensor<2x2xf32>, tensor<2x2xf32>, tensor<2x2xf32>) -%t12, %t13, %t14 = "test.op"(): () -> (tensor<5x3xf32>, tensor<3x2xf32>, tensor<5x2xf32>) -%t15,%t16 = "test.op"(): () -> (tensor<48x256x64xf32>, tensor<3xi64>) -%t17,%t18 = "test.op"(): () -> (tensor<1x2x3x4x5xf32>, tensor<5xi64>) -%t19 = "test.op"(): () -> (tensor<10x10xf32>) -%t20,%t21,%t22 = "test.op"(): () -> (tensor<1x1x5x5xf32>, tensor<1x1x3x3xf32>, none) -%t23,%t24,%t25 = "test.op"(): () -> (tensor<1x1x7x5xf32>, tensor<1x1x3x3xf32>, none) -%t26 = "test.op"(): () -> (tensor<5x5x32x32xf32>) -%t27, %t28, %t29 = "test.op"(): () -> (tensor<1x320xf32>, tensor<50x320xf32>, tensor<50xf32>) - -%res_add = "onnx.Add"(%t0, %t1) {onnx_node_name = "/Add"} : (tensor<1x2x6xf32>, tensor<1x2x6xf32>) -> tensor<1x2x6xf32> -// CHECK: %res_add = onnx.Add(%t0, %t1) {onnx_node_name = "/Add"} : (tensor<1x2x6xf32>, tensor<1x2x6xf32>) -> tensor<1x2x6xf32> - -%res_sub = "onnx.Sub"(%t2, %t3) {onnx_node_name = "/Sub"}: (tensor<3x2xf32>, tensor<1x2xf32>) -> tensor<3x2xf32> -// CHECK: %res_sub = onnx.Sub(%t2, %t3) {onnx_node_name = "/Sub"} : (tensor<3x2xf32>, tensor<1x2xf32>) -> tensor<3x2xf32> - -%res_mul = "onnx.Mul"(%t4, %t5) {onnx_node_name = "/Mul"}: (tensor<3x1x2xf32>, tensor<3x4x1xf32>) -> tensor<3x4x2xf32> -// CHECK: %res_mul = onnx.Mul(%t4, %t5) {onnx_node_name = "/Mul"} : (tensor<3x1x2xf32>, tensor<3x4x1xf32>) -> tensor<3x4x2xf32> - -%res_div = "onnx.Div"(%t6, %t7) {onnx_node_name = "/Div"}: (tensor<1x5x1x3xf32>, tensor<2x1x6x3xf32>) -> tensor<2x5x6x3xf32> -// CHECK: %res_div = onnx.Div(%t6, %t7) {onnx_node_name = "/Div"} : (tensor<1x5x1x3xf32>, tensor<2x1x6x3xf32>) -> tensor<2x5x6x3xf32> - -%res_relu = "onnx.Relu"(%t8) {onnx_node_name = "/Relu"}: (tensor<3x4xf32>) -> tensor<3x4xf32> -// CHECK: %res_relu = onnx.Relu(%t8) {onnx_node_name = "/Relu"} : (tensor<3x4xf32>) -> tensor<3x4xf32> - -%res_gemm = "onnx.Gemm"(%t9, %t10, %t11) {onnx_node_name = "/Gemm"}: (tensor<2x2xf32>, tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32> -// CHECK: %res_gemm = onnx.Gemm(%t9, %t10, %t11) {onnx_node_name = "/Gemm"} : (tensor<2x2xf32>, tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32> - -%res_gemm_1 = "onnx.Gemm"(%t12, %t13, %t14) {onnx_node_name = "/Gemm"}: (tensor<5x3xf32>, tensor<3x2xf32>, tensor<5x2xf32>) -> tensor<5x2xf32> -// CHECK: %res_gemm_1 = onnx.Gemm(%t12, %t13, %t14) {onnx_node_name = "/Gemm"} : (tensor<5x3xf32>, tensor<3x2xf32>, tensor<5x2xf32>) -> tensor<5x2xf32> - -%res_reshape = "onnx.Reshape"(%t15, %t16) {onnx_node_name = "/Reshape", "allowzero" = 1 : i64}: (tensor<48x256x64xf32>, tensor<3xi64>) -> tensor<48x256x64xf32> -//CHECK: res_reshape = onnx.Reshape(%t15, %t16) {onnx_node_name = "/Reshape", allowzero = 1 : i64} : (tensor<48x256x64xf32>, tensor<3xi64>) -> tensor<48x256x64xf32> - -%res_reshape_1 = "onnx.Reshape"(%t17, %t18) {onnx_node_name = "/Reshape"}: (tensor<1x2x3x4x5xf32>, tensor<5xi64>) -> tensor<1x120xf32> -//CHECK: %res_reshape_1 = onnx.Reshape(%t17, %t18) {onnx_node_name = "/Reshape"} : (tensor<1x2x3x4x5xf32>, tensor<5xi64>) -> tensor<1x120xf32> - -%res_abs = "onnx.Abs"(%t19) {onnx_node_name = "/Abs"}: (tensor<10x10xf32>) -> tensor<10x10xf32> -// CHECK: %res_abs = onnx.Abs(%t19) {onnx_node_name = "/Abs"} : (tensor<10x10xf32>) -> tensor<10x10xf32> - -%res_conv = "onnx.Conv"(%t20, %t21, %t22) {onnx_node_name = "/Conv", "auto_pad" = "NOTSET", "group" = 1 : i64, "kernel_shape" = [3 : i64, 3 : i64], "dilations" = [1 : i64, 1 : i64], "strides" = [1 : i64, 1 : i64], "pads" = [1 : i64, 1 : i64, 1: i64, 1 : i64]}: (tensor<1x1x5x5xf32>, tensor<1x1x3x3xf32>, none) -> tensor<1x1x5x5xf32> -//CHECK: %res_conv = onnx.Conv(%t20, %t21, %t22) {onnx_node_name = "/Conv", auto_pad = "NOTSET", group = 1 : i64, kernel_shape = [3 : i64, 3 : i64], dilations = [1 : i64, 1 : i64], strides = [1 : i64, 1 : i64], pads = [1 : i64, 1 : i64, 1 : i64, 1 : i64]} : (tensor<1x1x5x5xf32>, tensor<1x1x3x3xf32>, none) -> tensor<1x1x5x5xf32> - -%res_conv_1 = "onnx.Conv"(%t20, %t21, %t22) {onnx_node_name = "/Conv", "auto_pad" = "NOTSET", "group" = 1 : i64, "kernel_shape" = [3 : i64, 3 : i64], "dilations" = [1 : i64, 1 : i64], "strides" = [1 : i64, 1 : i64], "pads" = [0 : i64, 0 : i64, 0: i64, 0 : i64]}: (tensor<1x1x5x5xf32>, tensor<1x1x3x3xf32>, none) -> tensor<1x1x3x3xf32> -//CHECK: %res_conv_1 = onnx.Conv(%t20, %t21, %t22) {onnx_node_name = "/Conv", auto_pad = "NOTSET", group = 1 : i64, kernel_shape = [3 : i64, 3 : i64], dilations = [1 : i64, 1 : i64], strides = [1 : i64, 1 : i64], pads = [0 : i64, 0 : i64, 0 : i64, 0 : i64]} : (tensor<1x1x5x5xf32>, tensor<1x1x3x3xf32>, none) -> tensor<1x1x3x3xf32> - -%res_conv_2 = "onnx.Conv"(%t20, %t21, %t22) {onnx_node_name = "/Conv", "auto_pad" = "SAME_LOWER", "group" = 1 : i64, "kernel_shape" = [3 : i64, 3 : i64], "dilations" = [1 : i64, 1 : i64], "strides" = [2 : i64, 2 : i64], "pads" = [0 : i64, 0 : i64, 0: i64, 0 : i64]}: (tensor<1x1x5x5xf32>, tensor<1x1x3x3xf32>, none) -> tensor<1x1x3x3xf32> -//CHECK: %res_conv_2 = onnx.Conv(%t20, %t21, %t22) {onnx_node_name = "/Conv", auto_pad = "SAME_LOWER", group = 1 : i64, kernel_shape = [3 : i64, 3 : i64], dilations = [1 : i64, 1 : i64], strides = [2 : i64, 2 : i64], pads = [0 : i64, 0 : i64, 0 : i64, 0 : i64]} : (tensor<1x1x5x5xf32>, tensor<1x1x3x3xf32>, none) -> tensor<1x1x3x3xf32> - -%res_conv_3 = "onnx.Conv"(%t23, %t24, %t25) {onnx_node_name = "/Conv", "auto_pad" = "NOTSET", "group" = 1 : i64, "kernel_shape" = [3 : i64, 3 : i64], "dilations" = [1 : i64, 1 : i64], "strides" = [2 : i64, 2 : i64], "pads" = [1 : i64, 1 : i64, 1 : i64, 1 : i64]}: (tensor<1x1x7x5xf32>, tensor<1x1x3x3xf32>, none) -> tensor<1x1x4x3xf32> -//CHECK: %res_conv_3 = onnx.Conv(%t23, %t24, %t25) {onnx_node_name = "/Conv", auto_pad = "NOTSET", group = 1 : i64, kernel_shape = [3 : i64, 3 : i64], dilations = [1 : i64, 1 : i64], strides = [2 : i64, 2 : i64], pads = [1 : i64, 1 : i64, 1 : i64, 1 : i64]} : (tensor<1x1x7x5xf32>, tensor<1x1x3x3xf32>, none) -> tensor<1x1x4x3xf32> - -%res_conv_4 = "onnx.Conv"(%t23, %t24, %t25) {onnx_node_name = "/Conv", "auto_pad" = "NOTSET", "group" = 1 : i64, "kernel_shape" = [3 : i64, 3 : i64], "dilations" = [1 : i64, 1 : i64], "strides" = [2 : i64, 2 : i64], "pads" = [0 : i64, 0 : i64, 0 : i64, 0 : i64]}: (tensor<1x1x7x5xf32>, tensor<1x1x3x3xf32>, none) -> tensor<1x1x3x2xf32> -//CHECK: %res_conv_4 = onnx.Conv(%t23, %t24, %t25) {onnx_node_name = "/Conv", auto_pad = "NOTSET", group = 1 : i64, kernel_shape = [3 : i64, 3 : i64], dilations = [1 : i64, 1 : i64], strides = [2 : i64, 2 : i64], pads = [0 : i64, 0 : i64, 0 : i64, 0 : i64]} : (tensor<1x1x7x5xf32>, tensor<1x1x3x3xf32>, none) -> tensor<1x1x3x2xf32> - -%res_conv_5 = "onnx.Conv"(%t23, %t24, %t25) {onnx_node_name = "/Conv", "auto_pad" = "NOTSET", "group" = 1 : i64, "kernel_shape" = [3 : i64, 3 : i64], "dilations" = [1 : i64, 1 : i64], "strides" = [2 : i64, 2 : i64], "pads" = [1 : i64, 0 : i64, 1 : i64, 0 : i64]}: (tensor<1x1x7x5xf32>, tensor<1x1x3x3xf32>, none) -> tensor<1x1x4x2xf32> -//CHECK: %res_conv_5 = onnx.Conv(%t23, %t24, %t25) {onnx_node_name = "/Conv", auto_pad = "NOTSET", group = 1 : i64, kernel_shape = [3 : i64, 3 : i64], dilations = [1 : i64, 1 : i64], strides = [2 : i64, 2 : i64], pads = [1 : i64, 0 : i64, 1 : i64, 0 : i64]} : (tensor<1x1x7x5xf32>, tensor<1x1x3x3xf32>, none) -> tensor<1x1x4x2xf32> - -%res_max_pool_single_out = "onnx.MaxPoolSingleOut"(%t26) {onnx_node_name = "/MaxPoolSingleOut", "auto_pad" = "VALID", "ceil_mode" = 0 : i64, "kernel_shape" = [3 : i64, 3 : i64], "dilations" = [1 : i64, 1 : i64], "pads" = [0 : i64, 0 : i64, 0 : i64, 0 : i64], "storage_order" = 0 : i64, "strides" = [1 : i64, 1 : i64]}: (tensor<5x5x32x32xf32>) -> tensor<5x5x30x30xf32> -//CHECK: %res_max_pool_single_out = onnx.MaxPoolSingleOut(%t26) {onnx_node_name = "/MaxPoolSingleOut", auto_pad = "VALID", ceil_mode = 0 : i64, kernel_shape = [3 : i64, 3 : i64], dilations = [1 : i64, 1 : i64], pads = [0 : i64, 0 : i64, 0 : i64, 0 : i64], storage_order = 0 : i64, strides = [1 : i64, 1 : i64]} : (tensor<5x5x32x32xf32>) -> tensor<5x5x30x30xf32> - -"onnx.EntryPoint"() {onnx_node_name = "/EntryPoint", "func" = @main_graph} : () -> () -//CHECK: "onnx.EntryPoint"() {onnx_node_name = "/EntryPoint", func = @main_graph} : () -> () - -%res_constant = onnx.Constant dense<1> : tensor<1xi64> -//CHECK: %res_constant = onnx.Constant dense<1> : tensor<1xi64> - -%res_constant_1 = onnx.Constant dense<[5, 5, 16, 2]> : tensor<4xi64> -//CHECK: %res_constant_1 = onnx.Constant dense<[5, 5, 16, 2]> : tensor<4xi64> - -%res_gemm_2 = "onnx.Gemm"(%t27, %t28, %t29) {onnx_node_name = "/Gemm", "alpha" = 1.000000e+00 : f32, "beta" = 1.000000e+00 : f32, "transA" = 0 : si64, "transB" = 1 : si64}: (tensor<1x320xf32>, tensor<50x320xf32>, tensor<50xf32>) -> tensor<1x50xf32> -// CHECK: %res_gemm_2 = onnx.Gemm(%t27, %t28, %t29) {onnx_node_name = "/Gemm", alpha = 1.000000e+00 : f32, beta = 1.000000e+00 : f32, transA = 0 : si64, transB = 1 : si64} : (tensor<1x320xf32>, tensor<50x320xf32>, tensor<50xf32>) -> tensor<1x50xf32> - -%res_matmul = "onnx.MatMul"(%t9, %t10) {onnx_node_name = "/MatMul"}: (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32> -// CHECK: %res_matmul = onnx.MatMul(%t9, %t10) {onnx_node_name = "/MatMul"} : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32> - -%res_transpose = "onnx.Transpose"(%t8) {onnx_node_name = "/Transpose", "perm" = [1 : i64, 0 : i64]}: (tensor<3x4xf32>) -> tensor<4x3xf32> -// CHECK: %res_transpose = onnx.Transpose(%t8) {onnx_node_name = "/Transpose", perm = [1 : i64, 0 : i64]} : (tensor<3x4xf32>) -> tensor<4x3xf32> - -%res_squeeze = "onnx.Squeeze"(%t0) {onnx_node_name = "/Squeeze", "axes" = 0}: (tensor<1x2x6xf32>) -> tensor<2x6xf32> -// CHECK: %res_squeeze = onnx.Squeeze(%t0) {onnx_node_name = "/Squeeze", axes = 0 : i64} : (tensor<1x2x6xf32>) -> tensor<2x6xf32> - -%res_sigmoid = "onnx.Sigmoid"(%t8) {onnx_node_name = "/Sigmoid"}: (tensor<3x4xf32>) -> tensor<3x4xf32> -// CHECK: %res_sigmoid = onnx.Sigmoid(%t8) {onnx_node_name = "/Sigmoid"} : (tensor<3x4xf32>) -> tensor<3x4xf32> diff --git a/tests/filecheck/transforms/convert_onnx_to_linalg.mlir b/tests/filecheck/transforms/convert_onnx_to_linalg.mlir deleted file mode 100644 index 76ff2fef9a..0000000000 --- a/tests/filecheck/transforms/convert_onnx_to_linalg.mlir +++ /dev/null @@ -1,117 +0,0 @@ -// RUN: xdsl-opt -p convert-onnx-to-linalg %s | filecheck %s - -// CHECK: builtin.module { - -%t0, %t1 = "test.op"() : () -> (tensor<3x2xf32>, tensor<3x2xf32>) -%res_add = onnx.Add(%t0, %t1) {onnx_node_name = "/Add"} : (tensor<3x2xf32>, tensor<3x2xf32>) -> tensor<3x2xf32> -%res_sub = onnx.Sub(%t0, %t1) {onnx_node_name = "/Sub"} : (tensor<3x2xf32>, tensor<3x2xf32>) -> tensor<3x2xf32> - -// CHECK-NEXT: %t0, %t1 = "test.op"() : () -> (tensor<3x2xf32>, tensor<3x2xf32>) -// CHECK-NEXT: %res_add = tensor.empty() : tensor<3x2xf32> -// CHECK-NEXT: %res_add_1 = linalg.add ins(%t0, %t1 : tensor<3x2xf32>, tensor<3x2xf32>) outs(%res_add : tensor<3x2xf32>) -> tensor<3x2xf32> -// CHECK-NEXT: %res_sub = tensor.empty() : tensor<3x2xf32> -// CHECK-NEXT: %res_sub_1 = linalg.sub ins(%t0, %t1 : tensor<3x2xf32>, tensor<3x2xf32>) outs(%res_sub : tensor<3x2xf32>) -> tensor<3x2xf32> - - -%t2 = "test.op"() : () -> (tensor<3x4xf32>) -%res_relu = "onnx.Relu"(%t2) {onnx_node_name = "/Relu"}: (tensor<3x4xf32>) -> tensor<3x4xf32> - -// CHECK-NEXT: %t2 = "test.op"() : () -> tensor<3x4xf32> -// CHECK-NEXT: %res_relu = tensor.empty() : tensor<3x4xf32> -// CHECK-NEXT: %res_relu_1 = arith.constant 0.000000e+00 : f32 -// CHECK-NEXT: %res_relu_2 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%t2 : tensor<3x4xf32>) outs(%res_relu : tensor<3x4xf32>) { -// CHECK-NEXT: ^0(%0 : f32, %1 : f32): -// CHECK-NEXT: %2 = arith.maximumf %0, %res_relu_1 : f32 -// CHECK-NEXT: linalg.yield %2 : f32 -// CHECK-NEXT: } -> tensor<3x4xf32> - -%t27 = "test.op"() : () -> (tensor<3x4xf64>) -%res_relu_3 = "onnx.Relu"(%t27) {onnx_node_name = "/Relu"}: (tensor<3x4xf64>) -> tensor<3x4xf64> - -// CHECK-NEXT: %t27 = "test.op"() : () -> tensor<3x4xf64> -// CHECK-NEXT: %res_relu_3 = tensor.empty() : tensor<3x4xf64> -// CHECK-NEXT: %res_relu_4 = arith.constant 0.000000e+00 : f64 -// CHECK-NEXT: %res_relu_5 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%t27 : tensor<3x4xf64>) outs(%res_relu_3 : tensor<3x4xf64>) { -// CHECK-NEXT: ^1(%3 : f64, %4 : f64): -// CHECK-NEXT: %5 = arith.maximumf %3, %res_relu_4 : f64 -// CHECK-NEXT: linalg.yield %5 : f64 -// CHECK-NEXT: } -> tensor<3x4xf64> - - -%t3,%t4 = "test.op"(): () -> (tensor<20x2xf32>, tensor<2xi64>) -%res_reshape = "onnx.Reshape"(%t3, %t4) {onnx_node_name = "/Reshape"}: (tensor<20x2xf32>, tensor<2xi64>) -> tensor<1x40xf32> - -// CHECK-NEXT: %t3, %t4 = "test.op"() : () -> (tensor<20x2xf32>, tensor<2xi64>) -// CHECK-NEXT: %res_reshape = tensor.reshape %t3(%t4) : (tensor<20x2xf32>, tensor<2xi64>) -> tensor<1x40xf32> - -%t5, %t6, %t7 = "test.op"(): () -> (tensor<1x320xf32>, tensor<50x320xf32>, tensor<50xf32>) -%res_gemm= "onnx.Gemm"(%t5, %t6, %t7) {onnx_node_name = "/Gemm", "alpha" = 1.000000e+00 : f32, "beta" = 1.000000e+00 : f32, "transA" = 0 : si64, "transB" = 1 : si64}: (tensor<1x320xf32>, tensor<50x320xf32>, tensor<50xf32>) -> tensor<1x50xf32> - -// CHECK-NEXT: %t5, %t6, %t7 = "test.op"() : () -> (tensor<1x320xf32>, tensor<50x320xf32>, tensor<50xf32>) -// CHECK-NEXT: %6 = tensor.empty() : tensor<320x50xf32> -// CHECK-NEXT: %7 = linalg.transpose ins(%t6:tensor<50x320xf32>) outs(%6:tensor<320x50xf32>) permutation = [1, 0] -// CHECK-NEXT: %res_gemm = tensor.empty() : tensor<1x50xf32> -// CHECK-NEXT: %res_gemm_1 = linalg.matmul ins(%t5, %7 : tensor<1x320xf32>, tensor<320x50xf32>) outs(%res_gemm : tensor<1x50xf32>) -> tensor<1x50xf32> -// CHECK-NEXT: %res_gemm_2 = linalg.add ins(%res_gemm_1, %t7 : tensor<1x50xf32>, tensor<50xf32>) outs(%res_gemm_1 : tensor<1x50xf32>) -> tensor<1x50xf32> - - - -%t8, %t9, %t10 = "test.op"(): () -> (tensor<5x3xf32>, tensor<3x2xf32>, tensor<5x2xf32>) -%res_gemm_1 = "onnx.Gemm"(%t8, %t9, %t10) {onnx_node_name = "/Gemm"}: (tensor<5x3xf32>, tensor<3x2xf32>, tensor<5x2xf32>) -> tensor<5x2xf32> - - -// CHECK-NEXT: %t8, %t9, %t10 = "test.op"() : () -> (tensor<5x3xf32>, tensor<3x2xf32>, tensor<5x2xf32>) -// CHECK-NEXT: %res_gemm_3 = tensor.empty() : tensor<5x2xf32> -// CHECK-NEXT: %res_gemm_4 = linalg.matmul ins(%t8, %t9 : tensor<5x3xf32>, tensor<3x2xf32>) outs(%res_gemm_3 : tensor<5x2xf32>) -> tensor<5x2xf32> -// CHECK-NEXT: %res_gemm_5 = linalg.add ins(%res_gemm_4, %t10 : tensor<5x2xf32>, tensor<5x2xf32>) outs(%res_gemm_4 : tensor<5x2xf32>) -> tensor<5x2xf32> - - -%t11, %t12, %t13 = "test.op"(): () -> (tensor<10x5xf32>, tensor<10x3xf32>, tensor<5x3xf32>) -%res_gemm_2 = "onnx.Gemm"(%t11, %t12, %t13) {onnx_node_name = "/Gemm", "alpha" = 0.500000e+00 : f32, "beta" = 0.500000e+00 : f32, "transA" = 1 : si64}: (tensor<10x5xf32>, tensor<10x3xf32>, tensor<5x3xf32>) -> tensor<5x3xf32> - - -// CHECK-NEXT: %t11, %t12, %t13 = "test.op"() : () -> (tensor<10x5xf32>, tensor<10x3xf32>, tensor<5x3xf32>) -// CHECK-NEXT: %8 = tensor.empty() : tensor<5x10xf32> -// CHECK-NEXT: %9 = linalg.transpose ins(%t11:tensor<10x5xf32>) outs(%8:tensor<5x10xf32>) permutation = [1, 0] -// CHECK-NEXT: %10 = arith.constant 5.000000e-01 : f32 -// CHECK-NEXT: %11 = linalg.mul ins(%10, %9 : f32, tensor<5x10xf32>) outs(%9 : tensor<5x10xf32>) -> tensor<5x10xf32> -// CHECK-NEXT: %12 = arith.constant 5.000000e-01 : f32 -// CHECK-NEXT: %13 = linalg.mul ins(%12, %t13 : f32, tensor<5x3xf32>) outs(%t13 : tensor<5x3xf32>) -> tensor<5x3xf32> -// CHECK-NEXT: %res_gemm_6 = tensor.empty() : tensor<5x3xf32> -// CHECK-NEXT: %res_gemm_7 = linalg.matmul ins(%11, %t12 : tensor<5x10xf32>, tensor<10x3xf32>) outs(%res_gemm_6 : tensor<5x3xf32>) -> tensor<5x3xf32> -// CHECK-NEXT: %res_gemm_8 = linalg.add ins(%res_gemm_7, %13 : tensor<5x3xf32>, tensor<5x3xf32>) outs(%res_gemm_7 : tensor<5x3xf32>) -> tensor<5x3xf32> - -%t26 = "test.op"(): () -> (tensor<1x16x14x14xf32>) -%res_max_pool_single_out = "onnx.MaxPoolSingleOut"(%t26) {onnx_node_name = "/MaxPoolSingleOut", "auto_pad" = "NOTSET", "ceil_mode" = 0 : si64, "kernel_shape" = [3 : i64, 3 : i64], "dilations" = [1 : i64, 1 : i64], "pads" = [0 : i64, 0 : i64, 0 : i64, 0 : i64], "storage_order" = 0 : si64, strides = [3 : i64, 3 : i64]} : (tensor<1x16x14x14xf32>) -> tensor<1x16x4x4xf32> - -// CHECK-NEXT: %t26 = "test.op"() : () -> tensor<1x16x14x14xf32> -// CHECK-NEXT: %res_max_pool_single_out = tensor.empty() : tensor<3x3xf32> -// CHECK-NEXT: %res_max_pool_single_out_1 = tensor.empty() : tensor<1x16x4x4xf32> -// CHECK-NEXT: %res_max_pool_single_out_2 = arith.constant -1.000000e+308 : f64 -// CHECK-NEXT: %res_max_pool_single_out_3 = linalg.fill ins(%res_max_pool_single_out_2 : f64) outs(%res_max_pool_single_out_1 : tensor<1x16x4x4xf32>) -> tensor<1x16x4x4xf32> -// CHECK-NEXT: %res_max_pool_single_out_4 = linalg.pooling_nchw_max {dilations = dense<1> : tensor<2xi64>, strides = dense<3> : tensor<2xi64>} ins(%t26, %res_max_pool_single_out : tensor<1x16x14x14xf32>, tensor<3x3xf32>) outs(%res_max_pool_single_out_3 : tensor<1x16x4x4xf32>) -> tensor<1x16x4x4xf32> - -%t20, %t21, %t22 = "test.op"() : () -> (tensor<1x1x5x5xf32>, tensor<1x1x3x3xf32>, none) -%res_conv_2 = "onnx.Conv"(%t20, %t21, %t22) {onnx_node_name = "/Conv", "auto_pad" = "NOTSET", "group" = 1 : i64, "kernel_shape" = [3 : i64, 3 : i64], "dilations" = [1 : i64, 1 : i64], "strides" = [1 : i64, 1 : i64], "pads" = [0 : i64, 0 : i64, 0: i64, 0 : i64]}: (tensor<1x1x5x5xf32>, tensor<1x1x3x3xf32>, none) -> tensor<1x1x3x3xf32> - -// CHECK-NEXT: %t20, %t21, %t22 = "test.op"() : () -> (tensor<1x1x5x5xf32>, tensor<1x1x3x3xf32>, none) -// CHECK-NEXT: %res_conv = tensor.empty() : tensor<1x1x3x3xf32> -// CHECK-NEXT: %res_conv_1 = linalg.conv_2d_nchw_fchw {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%t20, %t21 : tensor<1x1x5x5xf32>, tensor<1x1x3x3xf32>) outs(%res_conv : tensor<1x1x3x3xf32>) -> tensor<1x1x3x3xf32> - -%t23, %t24, %t25 = "test.op"() : () -> (tensor<1x8x14x14xf32>, tensor<16x8x5x5xf32>, tensor<16xf32>) -%res_conv_3 = "onnx.Conv"(%t23, %t24, %t25) {onnx_node_name = "/Conv", "auto_pad" = "SAME_UPPER", "group" = 1 : i64, "kernel_shape" = [5 : i64, 5 : i64], "dilations" = [1 : i64, 1 : i64], "strides" = [1 : i64, 1 : i64], "pads" = [0 : i64, 0 : i64, 0: i64, 0 : i64]} : (tensor<1x8x14x14xf32>, tensor<16x8x5x5xf32>, tensor<16xf32>) -> tensor<1x16x14x14xf32> - -// CHECK-NEXT: %t23, %t24, %t25 = "test.op"() : () -> (tensor<1x8x14x14xf32>, tensor<16x8x5x5xf32>, tensor<16xf32>) -// CHECK-NEXT: %res_conv_2 = tensor.empty() : tensor<1x16x14x14xf32> -// CHECK-NEXT: %res_conv_3 = linalg.conv_2d_nchw_fchw {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%t23, %t24 : tensor<1x8x14x14xf32>, tensor<16x8x5x5xf32>) outs(%res_conv_2 : tensor<1x16x14x14xf32>) -> tensor<1x16x14x14xf32> -// CHECK-NEXT: %res_conv_4 = linalg.add ins(%t25 : tensor<16xf32>) outs(%res_conv_3 : tensor<1x16x14x14xf32>) -> tensor<1x16x14x14xf32> - -%res_constant = "onnx.Constant"() {onnx_node_name = "/Constant", "value" = dense<1> : tensor<1xi64>}: () -> tensor<1xi64> -%res_constant_2 = "onnx.Constant"() {onnx_node_name = "/Constant", "value" = dense<2.0> : tensor<1x5xf32>} : () -> tensor<1x5xf32> - -// CHECK-NEXT: %res_constant = ml_program.global_load_const @onnx_constant_1 : tensor<1xi64> -// CHECK-NEXT: %res_constant_1 = ml_program.global_load_const @onnx_constant_2 : tensor<1x5xf32> -// CHECK-NEXT: ml_program.global private @onnx_constant_1(dense<1> : tensor<1xi64>) : tensor<1xi64> -// CHECK-NEXT: ml_program.global private @onnx_constant_2(dense<2.000000e+00> : tensor<1x5xf32>) : tensor<1x5xf32> - -// CHECK-NEXT: } diff --git a/tests/frontend/onnx/test_build_onnx_ir.py b/tests/frontend/onnx/test_build_onnx_ir.py deleted file mode 100644 index 7a0be0340b..0000000000 --- a/tests/frontend/onnx/test_build_onnx_ir.py +++ /dev/null @@ -1,377 +0,0 @@ -import onnx -import pytest - -from xdsl.dialects.builtin import TensorType, f32 -from xdsl.dialects.onnx import AddOp, MatMulOp, SubOp, TransposeOp -from xdsl.ir import Attribute -from xdsl.utils.test_value import TestSSAValue - -try: - from onnx import TensorProto, ValueInfoProto, helper - - from xdsl.frontend.onnx.ir_builder import ( - OnnxXdslMapping, - build_module, - visit_graph, - visit_node, - visit_value_info, - ) -except ImportError as exc: - print(exc) - pytest.skip("onnx is an optional dependency", allow_module_level=True) - - -def test_visit_value_info(): - # initialize context - ctx = OnnxXdslMapping() - - # create ValueInfoProto input tensor - input_value_info = ValueInfoProto() - input_value_info.name = "input_tensor" - - # define the type of the input tensor - input_tensor_type = input_value_info.type.tensor_type - input_tensor_type.elem_type = TensorProto.FLOAT - input_tensor_type.shape.dim.extend( - [ - onnx.TensorShapeProto.Dimension(dim_value=1), - onnx.TensorShapeProto.Dimension(dim_value=3), - onnx.TensorShapeProto.Dimension(dim_value=224), - onnx.TensorShapeProto.Dimension(dim_value=224), - ] - ) - - # run visit_value_info with empty context - t1 = visit_value_info(input_value_info, ctx) - - # check type info - assert isinstance(t1, Attribute) - assert str(t1) == "tensor<1x3x224x224xf32>" - - # check keys in context - keys = list(ctx.type_by_name.keys()) - assert keys == ["input_tensor"] - - # run visit_value_info again - t2 = visit_value_info(input_value_info, ctx) - - # check type info - assert isinstance(t2, Attribute) - assert str(t2) == "tensor<1x3x224x224xf32>" - - # check keys in context - keys = list(ctx.type_by_name.keys()) - assert keys == ["input_tensor"] - - # check it is returned the same reference - assert t1 is t2 - - -def test_visit_node_unknown_op_name(): - """ - Test for unknown expected onnx op - """ - node_attributes = { - "name": "dummy_name", - "op_type": "dummy_op", - } - - node = helper.make_node( - **node_attributes, inputs=["input1", "input2"], outputs=["output"] - ) - ctx = OnnxXdslMapping() - with pytest.raises(ValueError, match="Unknown ONNX op name dummy_op"): - visit_node(node=node, ctx=ctx) - - -def test_visit_node_add(): - # initialize context - ctx = OnnxXdslMapping() - - # create graph composed only of one Add operation - _, add_node = _create_graph_binary_op("Add", "add_graph", [64], [64], [64]) - - lhs = TestSSAValue(TensorType(f32, [64])) - rhs = TestSSAValue(TensorType(f32, [64])) - ctx.value_by_name["input1"] = lhs - ctx.value_by_name["input2"] = rhs - - lhs_type = TensorType(f32, [64]) - rhs_type = TensorType(f32, [64]) - out_type = TensorType(f32, [64]) - ctx.type_by_name["input1"] = lhs_type - ctx.type_by_name["input2"] = rhs_type - ctx.type_by_name["output"] = out_type - - # visit node - op = visit_node(add_node, ctx) - - assert isinstance(op, AddOp) - assert op.lhs is lhs - assert op.rhs is rhs - assert op.res is ctx.value_by_name["output"] - assert not op.attributes - assert not op.regions - - -def test_visit_node_sub(): - # initialize context - ctx = OnnxXdslMapping() - - # create graph composed only of one Sub operation - _, sub_node = _create_graph_binary_op("Sub", "sub_graph", [64], [64], [64]) - - lhs = TestSSAValue(TensorType(f32, [64])) - rhs = TestSSAValue(TensorType(f32, [64])) - ctx.value_by_name["input1"] = lhs - ctx.value_by_name["input2"] = rhs - - lhs_type = TensorType(f32, [64]) - rhs_type = TensorType(f32, [64]) - out_type = TensorType(f32, [64]) - ctx.type_by_name["input1"] = lhs_type - ctx.type_by_name["input2"] = rhs_type - ctx.type_by_name["output"] = out_type - - op = visit_node(sub_node, ctx) - - assert isinstance(op, SubOp) - assert op.lhs is lhs - assert op.rhs is rhs - assert op.res is ctx.value_by_name["output"] - assert not op.attributes - assert not op.regions - - -def test_visit_node_matmul(): - # initialize context - ctx = OnnxXdslMapping() - - # create graph composed only of one Sub operation - _, matmul_node = _create_graph_binary_op( - "MatMul", "matmul_graph", [64, 128], [128, 64], [64, 64] - ) - - lhs = TestSSAValue(TensorType(f32, [64, 128])) - rhs = TestSSAValue(TensorType(f32, [128, 64])) - ctx.value_by_name["input1"] = lhs - ctx.value_by_name["input2"] = rhs - - lhs_type = TensorType(f32, [64, 128]) - rhs_type = TensorType(f32, [128, 64]) - out_type = TensorType(f32, [64, 64]) - ctx.type_by_name["input1"] = lhs_type - ctx.type_by_name["input2"] = rhs_type - ctx.type_by_name["output"] = out_type - - op = visit_node(matmul_node, ctx) - op.verify() - - assert isinstance(op, MatMulOp) - assert op.matrix_A is lhs - assert op.matrix_B is rhs - assert op.matrix_Y is ctx.value_by_name["output"] - assert not op.attributes - assert not op.regions - - -def test_visit_node_transpose(): - # initialize context - ctx = OnnxXdslMapping() - - # create graph composed only of one Sub operation - _, transpose_node = _create_transpose_op( - graph_name="transpose_graph", dim_in=[64, 128], dim_out=[128, 64], perm=[1, 0] - ) - - in_value = TestSSAValue(TensorType(f32, [64, 128])) - ctx.value_by_name["input"] = in_value - - in_type = TensorType(f32, [64, 128]) - out_type = TensorType(f32, [128, 64]) - ctx.type_by_name["input"] = in_type - ctx.type_by_name["output"] = out_type - - op = visit_node(transpose_node, ctx) - op.verify() - - assert isinstance(op, TransposeOp) - assert op.tensor_input is in_value - assert op.tensor_output is ctx.value_by_name["output"] - assert not op.attributes - assert not op.regions - - -def test_visit_graph_add(): - # initialize context - ctx = OnnxXdslMapping() - - # create graph composed only of one Add operation - graph, _ = _create_graph_binary_op("Add", "add_graph", [64], [64], [64]) - - # visit graph - visit_graph(graph, ctx) - - # check value_by_name keys - keys = list(ctx.value_by_name.keys()) - assert keys == ["input1", "input2", "output"] - - # check expected generated ir - gen_ir = ctx.value_by_name[keys[2]].owner - assert ( - str(gen_ir) - == "%0 = onnx.Add(%1, %2) : (tensor<64xf32>, tensor<64xf32>) -> tensor<64xf32>" - ) - - -def test_visit_graph_sub(): - # initialize context - ctx = OnnxXdslMapping() - - # create graph composed only of one Sub operation - graph, _ = _create_graph_binary_op("Sub", "sub_graph", [64], [64], [64]) - - # run visit graph - visit_graph(graph, ctx) - - # check value_by_names keys - keys = list(ctx.value_by_name.keys()) - assert keys == ["input1", "input2", "output"] - - # check generated ir - gen_ir = ctx.value_by_name[keys[2]].owner - assert ( - str(gen_ir) - == "%0 = onnx.Sub(%1, %2) : (tensor<64xf32>, tensor<64xf32>) -> tensor<64xf32>" - ) - - -def test_visit_graph_matmul(): - # initialize context - ctx = OnnxXdslMapping() - - # create graph composed only of one MatMul operation - graph, _ = _create_graph_binary_op("MatMul", "matmul_graph", [64], [64], [64]) - - # run visit graph - visit_graph(graph, ctx) - - # check value_by_names keys - keys = list(ctx.value_by_name.keys()) - assert keys == ["input1", "input2", "output"] - - # check generated ir - gen_ir = ctx.value_by_name[keys[2]].owner - assert ( - str(gen_ir) - == "%0 = onnx.MatMul(%1, %2) : (tensor<64xf32>, tensor<64xf32>) -> tensor<64xf32>" - ) - - -def test_visit_graph_transpose(): - # initialize context - ctx = OnnxXdslMapping() - - # create graph composed only of one Transpose operation - graph, _ = _create_transpose_op("transpose_graph", [64, 128], [128, 64], [1, 0]) - - # run visit graph - visit_graph(graph, ctx) - - # check value_by_names keys - keys = list(ctx.value_by_name.keys()) - assert keys == ["input", "output"] - - # check generated ir - gen_ir = ctx.value_by_name[keys[1]].owner - assert ( - str(gen_ir) - == "%0 = onnx.Transpose(%1) : (tensor<64x128xf32>) -> tensor<128x64xf32>" - ) - - -def test_build_module(): - # create graph composed only of one Add operation - graph, _ = _create_graph_binary_op("Add", "add_graph", [64], [64], [64]) - - # create module - module = build_module(graph) - - # define expected output - expected = """ -builtin.module { - func.func @add_graph(%0 : tensor<64xf32>, %1 : tensor<64xf32>) -> tensor<64xf32> { - %2 = onnx.Add(%0, %1) : (tensor<64xf32>, tensor<64xf32>) -> tensor<64xf32> - func.return %2 : tensor<64xf32> - } -}""" - - # remove first new line - expected = expected[1:] - - # check output - assert str(module) == expected - - -def _create_graph_binary_op( - op_name: str, - graph_name: str, - dim_in1: list[int], - dim_in2: list[int], - dim_out: list[int], -): - # define input and output names - input1_name = "input1" - input2_name = "input2" - output_name = "output" - - # define input shapes - input1_shape = dim_in1 - input2_shape = dim_in2 - output_shape = dim_out - - # define op node - op_node = helper.make_node( - op_type=op_name, - inputs=[input1_name, input2_name], - outputs=[output_name], - ) - - # create graph (composed of just one operation) - graph = helper.make_graph( - nodes=[op_node], - name=graph_name, - inputs=[ - helper.make_tensor_value_info(input1_name, TensorProto.FLOAT, input1_shape), - helper.make_tensor_value_info(input2_name, TensorProto.FLOAT, input2_shape), - ], - outputs=[ - helper.make_tensor_value_info(output_name, TensorProto.FLOAT, output_shape), - ], - ) - - return graph, op_node - - -def _create_transpose_op( - graph_name: str, dim_in: list[int], dim_out: list[int], perm: list[int] -): - # create input tensor - input_tensor = helper.make_tensor_value_info( - "input", onnx.TensorProto.FLOAT, dim_in - ) - - # create output tensor - output_tensor = helper.make_tensor_value_info( - "output", onnx.TensorProto.FLOAT, dim_out - ) - - # create transpose operation - transpose_node = helper.make_node("Transpose", ["input"], ["output"], perm=perm) - - # create onnx graph - graph_def = helper.make_graph( - [transpose_node], graph_name, [input_tensor], [output_tensor] - ) - - return graph_def, transpose_node diff --git a/tests/frontend/onnx/test_type.py b/tests/frontend/onnx/test_type.py deleted file mode 100644 index 8b8a96a347..0000000000 --- a/tests/frontend/onnx/test_type.py +++ /dev/null @@ -1,89 +0,0 @@ -import onnx -import pytest - -from xdsl.dialects.builtin import Float32Type, TensorType, f32, f64 - -try: - from onnx import TensorShapeProto, TypeProto - - from xdsl.frontend.onnx.type import ( - get_elem_type, - get_shape, - get_tensor_type, - get_type, - ) - from xdsl.utils.hints import isa -except ImportError as exc: - print(exc) - pytest.skip("onnx is an optional dependency", allow_module_level=True) - - -def test_get_elem_type(): - # test case 1: check if 1 corresponds to f32 - assert get_elem_type(1) == f32 - - # test case 11: check if 11 corresponds to f64 - assert get_elem_type(11) == f64 - - # test case -1: check if -1 (or other illegal values) corresponds to None - with pytest.raises(ValueError, match="Unknown elem_type: -1"): - get_elem_type(-1) - - -def test_get_type(): - tensor_type = onnx.TypeProto() - tensor_type.tensor_type.elem_type = onnx.TensorProto.FLOAT - tensor_type.tensor_type.shape.dim.extend( - [ - onnx.TensorShapeProto.Dimension(dim_value=3), - onnx.TensorShapeProto.Dimension(dim_value=4), - onnx.TensorShapeProto.Dimension(dim_value=5), - ] - ) - tt = get_type(tensor_type) - - assert isa(tt, TensorType[Float32Type]) - - assert tt.get_num_dims() == 3 - assert tt.get_shape() == (3, 4, 5) - assert tt.get_element_type().name == "f32" - - -def test_get_shape(): - assert get_shape(TensorShapeProto()) == () - assert get_shape( - TensorShapeProto(dim=(TensorShapeProto.Dimension(dim_value=1),)) - ) == (1,) - assert get_shape( - TensorShapeProto( - dim=( - TensorShapeProto.Dimension(dim_value=1), - TensorShapeProto.Dimension(dim_value=2), - ) - ) - ) == (1, 2) - - -def test_get_tensor_type(): - assert get_tensor_type( - TypeProto.Tensor( - elem_type=1, - shape=TensorShapeProto( - dim=( - TensorShapeProto.Dimension(dim_value=2), - TensorShapeProto.Dimension(dim_value=3), - ) - ), - ) - ) == TensorType(f32, (2, 3)) - assert get_tensor_type( - TypeProto.Tensor( - elem_type=11, - shape=TensorShapeProto( - dim=( - TensorShapeProto.Dimension(dim_value=4), - TensorShapeProto.Dimension(dim_value=5), - ) - ), - ) - ) == TensorType(f64, (4, 5)) diff --git a/tests/interpreters/test_onnx_interpreter.py b/tests/interpreters/test_onnx_interpreter.py deleted file mode 100644 index 01c1e11013..0000000000 --- a/tests/interpreters/test_onnx_interpreter.py +++ /dev/null @@ -1,554 +0,0 @@ -import pytest - -from xdsl.dialects import onnx -from xdsl.dialects.builtin import ( - AnyIntegerAttr, - ArrayAttr, - DenseIntOrFPElementsAttr, - FloatAttr, - ModuleOp, - NoneType, - StringAttr, - TensorType, - f32, - i64, -) -from xdsl.interpreter import Interpreter -from xdsl.interpreters.builtin import BuiltinFunctions -from xdsl.interpreters.shaped_array import ShapedArray -from xdsl.interpreters.utils.ptr import TypedPtr -from xdsl.utils.exceptions import InterpretationError -from xdsl.utils.test_value import TestSSAValue - -pytest.importorskip("numpy", reason="numpy is an optional dependency in xDSL") - -from xdsl.interpreters.onnx import OnnxFunctions # noqa: E402 - - -def test_onnx_add(): - interpreter = Interpreter(ModuleOp([])) - interpreter.register_implementations(OnnxFunctions()) - op = onnx.AddOp( - TestSSAValue(TensorType(f32, [2, 3])), - TestSSAValue(TensorType(f32, [2, 3])), - res_type=TensorType(f32, [2, 3]), - ) - - a = ShapedArray(TypedPtr.new_float32([1, 2, 3, 4, 5, 6]), [2, 3]) - b = ShapedArray(TypedPtr.new_float32([1, 4, 2, 5, 3, 6]), [2, 3]) - - (c,) = interpreter.run_op(op, (a, b)) - assert c == ShapedArray(TypedPtr.new_float32([2, 6, 5, 9, 8, 12]), [2, 3]) - - -def test_onnx_sub(): - interpreter = Interpreter(ModuleOp([])) - interpreter.register_implementations(OnnxFunctions()) - op = onnx.SubOp( - TestSSAValue(TensorType(f32, [2, 3])), - TestSSAValue(TensorType(f32, [2, 3])), - res_type=TensorType(f32, [2, 3]), - ) - - a = ShapedArray(TypedPtr.new_float32([1, 2, 3, 4, 5, 6]), [2, 3]) - b = ShapedArray(TypedPtr.new_float32([1, 4, 2, 5, 3, 6]), [2, 3]) - - (c,) = interpreter.run_op(op, (a, b)) - assert c == ShapedArray(TypedPtr.new_float32([0, -2, 1, -1, 2, 0]), [2, 3]) - - -def test_onnx_mul(): - interpreter = Interpreter(ModuleOp([])) - interpreter.register_implementations(OnnxFunctions()) - op = onnx.MulOp( - TestSSAValue(TensorType(f32, [2, 2])), - TestSSAValue(TensorType(f32, [2, 2])), - res_type=TensorType(f32, [2, 2]), - ) - - a = ShapedArray(TypedPtr.new_float32([1, 4, 7, 1]), [2, 2]) - b = ShapedArray(TypedPtr.new_float32([2, 3, 1, 8]), [2, 2]) - - (c,) = interpreter.run_op(op, (a, b)) - assert c == ShapedArray(TypedPtr.new_float32([2, 12, 7, 8]), [2, 2]) - - -def test_onnx_div(): - interpreter = Interpreter(ModuleOp([])) - interpreter.register_implementations(OnnxFunctions()) - op = onnx.DivOp( - TestSSAValue(TensorType(f32, [2, 2])), - TestSSAValue(TensorType(f32, [2, 2])), - res_type=TensorType(f32, [2, 2]), - ) - - a = ShapedArray(TypedPtr.new_float32([1, 1, 1, 1]), [2, 2]) - b = ShapedArray(TypedPtr.new_float32([5, 2, 1, 2]), [2, 2]) - - (c,) = interpreter.run_op(op, (a, b)) - assert c == ShapedArray(TypedPtr.new_float32([0.2, 0.5, 1.0, 0.5]), [2, 2]) - - -def test_onnx_relu(): - interpreter = Interpreter(ModuleOp([])) - interpreter.register_implementations(OnnxFunctions()) - op = onnx.ReluOp( - TestSSAValue(TensorType(f32, [2, 2])), - ) - - a = ShapedArray(TypedPtr.new_float32([-1, 0, 1, 2]), [2, 2]) - (b,) = interpreter.run_op(op, (a,)) - assert b == ShapedArray(TypedPtr.new_float32([-0.0, 0, 1, 2]), [2, 2]) - - -def test_onnx_constant(): - interpreter = Interpreter(ModuleOp([])) - interpreter.register_implementations(OnnxFunctions()) - interpreter.register_implementations(BuiltinFunctions()) - op = onnx.ConstantOp( - ( - DenseIntOrFPElementsAttr.create_dense_int( - TensorType(i64, [4]), [5, 5, 16, 2] - ) - ), - None, - None, - None, - None, - None, - None, - output_type=TensorType(i64, [4]), - ) - - (a,) = interpreter.run_op(op, ()) - assert a == ShapedArray(TypedPtr.new_int64([5, 5, 16, 2]), [4]) - - -def test_onnx_reshape(): - interpreter = Interpreter(ModuleOp([])) - interpreter.register_implementations(OnnxFunctions()) - op = onnx.ReshapeOp( - (TestSSAValue(TensorType(f32, [1, 10]))), - (TestSSAValue(TensorType(i64, [2]))), - AnyIntegerAttr(0, i64), - ) - a = ShapedArray(TypedPtr.new_float32([1, 2, 3, 4, 5, 6, 7, 8, 9, 10]), [1, 10]) - b = ShapedArray(TypedPtr.new_float32([1, 10]), [2]) - (c,) = interpreter.run_op(op, (a, b)) - assert c == ShapedArray( - TypedPtr.new_float32([1, 2, 3, 4, 5, 6, 7, 8, 9, 10]), [1, 10] - ) - - -def test_onnx_reshape_error(): - interpreter = Interpreter(ModuleOp([])) - interpreter.register_implementations(OnnxFunctions()) - op = onnx.ReshapeOp( - (TestSSAValue(TensorType(f32, [1, 10]))), - (TestSSAValue(TensorType(i64, [2]))), - AnyIntegerAttr(0, i64), - ) - a = ShapedArray(TypedPtr.new_float32([1, 2, 3, 4]), [1, 4]) - b = ShapedArray(TypedPtr.new_float32([2, 2]), [2]) - with pytest.raises( - InterpretationError, match="Mismatch between static shape and new shape" - ): - interpreter.run_op(op, (a, b)) - - -def test_onnx_gemm(): - interpreter = Interpreter(ModuleOp([])) - interpreter.register_implementations(OnnxFunctions()) - op = onnx.GemmOp( - TestSSAValue(TensorType(f32, [2, 2])), - TestSSAValue(TensorType(f32, [2, 2])), - TestSSAValue(TensorType(f32, [2, 2])), - FloatAttr(1, f32), - AnyIntegerAttr(0, i64), - AnyIntegerAttr(0, i64), - FloatAttr(1, f32), - ) - - a = ShapedArray(TypedPtr.new_float32([1, 2, 3, 4]), [2, 2]) - b = ShapedArray(TypedPtr.new_float32([2, 4, 6, 8]), [2, 2]) - c = ShapedArray(TypedPtr.new_float32([1, 1, 1, 1]), [2, 2]) - (d,) = interpreter.run_op(op, (a, b, c)) - assert d == ShapedArray(TypedPtr.new_float32([15, 21, 31, 45]), [2, 2]) - - -def test_onnx_gemm_transpose_b(): - interpreter = Interpreter(ModuleOp([])) - interpreter.register_implementations(OnnxFunctions()) - op = onnx.GemmOp( - TestSSAValue(TensorType(f32, [2, 1])), - TestSSAValue(TensorType(f32, [2, 1])), - TestSSAValue(TensorType(f32, [2, 2])), - FloatAttr(1, f32), - AnyIntegerAttr(0, i64), - AnyIntegerAttr(1, i64), - FloatAttr(1, f32), - ) - - a = ShapedArray(TypedPtr.new_float32([1, 2]), [2, 1]) - b = ShapedArray(TypedPtr.new_float32([4, 9]), [2, 1]) - c = ShapedArray(TypedPtr.new_float32([1, 2, 3, 4]), [2, 2]) - (d,) = interpreter.run_op(op, (a, b, c)) - assert d == ShapedArray(TypedPtr.new_float32([5, 11, 11, 22]), [2, 2]) - - -def test_onnx_gemm_alpha(): - interpreter = Interpreter(ModuleOp([])) - interpreter.register_implementations(OnnxFunctions()) - op = onnx.GemmOp( - TestSSAValue(TensorType(f32, [2, 1])), - TestSSAValue(TensorType(f32, [1, 2])), - TestSSAValue(TensorType(f32, [2, 2])), - FloatAttr(2, f32), - AnyIntegerAttr(0, i64), - AnyIntegerAttr(0, i64), - FloatAttr(1, f32), - ) - - a = ShapedArray(TypedPtr.new_float32([1, 2]), [2, 1]) - b = ShapedArray(TypedPtr.new_float32([4, 9]), [1, 2]) - c = ShapedArray(TypedPtr.new_float32([1, 2, 3, 4]), [2, 2]) - (d,) = interpreter.run_op(op, (a, b, c)) - assert d == ShapedArray(TypedPtr.new_float32([9, 20, 19, 40]), [2, 2]) - - -def test_onnx_conv_no_padding(): - interpreter = Interpreter(ModuleOp([])) - interpreter.register_implementations(OnnxFunctions()) - op = onnx.ConvOp( - TestSSAValue(TensorType(f32, [1, 1, 5, 5])), - TestSSAValue(TensorType(f32, [1, 1, 3, 3])), - TestSSAValue(NoneType()), - StringAttr("NOTSET"), - ArrayAttr([AnyIntegerAttr(1, i64), AnyIntegerAttr(1, i64)]), - AnyIntegerAttr(1, i64), - ArrayAttr([AnyIntegerAttr(3, i64), AnyIntegerAttr(3, i64)]), - ArrayAttr( - [ - AnyIntegerAttr(0, i64), - AnyIntegerAttr(0, i64), - AnyIntegerAttr(0, i64), - AnyIntegerAttr(0, i64), - ] - ), - ArrayAttr([AnyIntegerAttr(1, i64), AnyIntegerAttr(1, i64)]), - ) - a = ShapedArray(TypedPtr.new_float32(range(25)), [1, 1, 5, 5]) - b = ShapedArray( - TypedPtr.new_float32( - [ - 1, - ] - * 9 - ), - [1, 1, 3, 3], - ) - c = ShapedArray(TypedPtr.new_float32([0]), [1]) - (d,) = interpreter.run_op(op, (a, b, c)) - - assert d == ShapedArray( - TypedPtr.new_float32([54, 63, 72, 99, 108, 117, 144, 153, 162]), [1, 1, 3, 3] - ) - - -def test_onnx_conv_with_padding(): - interpreter = Interpreter(ModuleOp([])) - interpreter.register_implementations(OnnxFunctions()) - op = onnx.ConvOp( - TestSSAValue(TensorType(f32, [1, 1, 5, 5])), - TestSSAValue(TensorType(f32, [1, 1, 3, 3])), - TestSSAValue(NoneType()), - StringAttr("NOTSET"), - ArrayAttr([AnyIntegerAttr(1, i64), AnyIntegerAttr(1, i64)]), - AnyIntegerAttr(1, i64), - ArrayAttr([AnyIntegerAttr(3, i64), AnyIntegerAttr(3, i64)]), - ArrayAttr( - [ - AnyIntegerAttr(1, i64), - AnyIntegerAttr(1, i64), - AnyIntegerAttr(1, i64), - AnyIntegerAttr(1, i64), - ] - ), - ArrayAttr([AnyIntegerAttr(1, i64), AnyIntegerAttr(1, i64)]), - ) - a = ShapedArray(TypedPtr.new_float32(range(25)), [1, 1, 5, 5]) - b = ShapedArray( - TypedPtr.new_float32( - [ - 1, - ] - * 9 - ), - [1, 1, 3, 3], - ) - c = ShapedArray(TypedPtr.new_float32([0]), [1]) - (d,) = interpreter.run_op(op, (a, b, c)) - - assert d == ShapedArray( - TypedPtr.new_float32( - [ - 12.0, - 21.0, - 27.0, - 33.0, - 24.0, - 33.0, - 54.0, - 63.0, - 72.0, - 51.0, - 63.0, - 99.0, - 108.0, - 117.0, - 81.0, - 93.0, - 144.0, - 153.0, - 162.0, - 111.0, - 72.0, - 111.0, - 117.0, - 123.0, - 84.0, - ] - ), - [1, 1, 5, 5], - ) - - -def test_onnx_conv_with_same_lower_strides(): - interpreter = Interpreter(ModuleOp([])) - interpreter.register_implementations(OnnxFunctions()) - op = onnx.ConvOp( - TestSSAValue(TensorType(f32, [1, 1, 5, 5])), - TestSSAValue(TensorType(f32, [1, 1, 3, 3])), - TestSSAValue(NoneType()), - StringAttr("SAME_LOWER"), - ArrayAttr([AnyIntegerAttr(1, i64), AnyIntegerAttr(1, i64)]), - AnyIntegerAttr(1, i64), - ArrayAttr([AnyIntegerAttr(3, i64), AnyIntegerAttr(3, i64)]), - ArrayAttr( - [ - AnyIntegerAttr(0, i64), - AnyIntegerAttr(0, i64), - AnyIntegerAttr(0, i64), - AnyIntegerAttr(0, i64), - ] - ), - ArrayAttr([AnyIntegerAttr(2, i64), AnyIntegerAttr(2, i64)]), - ) - a = ShapedArray(TypedPtr.new_float32(range(25)), [1, 1, 5, 5]) - b = ShapedArray( - TypedPtr.new_float32( - [ - 1, - ] - * 9 - ), - [1, 1, 3, 3], - ) - c = ShapedArray(TypedPtr.new_float32([0]), [1]) - (d,) = interpreter.run_op(op, (a, b, c)) - - assert d == ShapedArray( - TypedPtr.new_float32([12.0, 27.0, 24.0, 63.0, 108.0, 81.0, 72.0, 117.0, 84.0]), - [1, 1, 3, 3], - ) - - -def test_onnx_conv_with_strides_padding(): - interpreter = Interpreter(ModuleOp([])) - interpreter.register_implementations(OnnxFunctions()) - op = onnx.ConvOp( - TestSSAValue(TensorType(f32, [1, 1, 5, 5])), - TestSSAValue(TensorType(f32, [1, 1, 3, 3])), - TestSSAValue(NoneType()), - StringAttr("NOTSET"), - ArrayAttr([AnyIntegerAttr(1, i64), AnyIntegerAttr(1, i64)]), - AnyIntegerAttr(1, i64), - ArrayAttr([AnyIntegerAttr(3, i64), AnyIntegerAttr(3, i64)]), - ArrayAttr( - [ - AnyIntegerAttr(1, i64), - AnyIntegerAttr(1, i64), - AnyIntegerAttr(1, i64), - AnyIntegerAttr(1, i64), - ] - ), - ArrayAttr([AnyIntegerAttr(2, i64), AnyIntegerAttr(2, i64)]), - ) - a = ShapedArray(TypedPtr.new_float32(range(35)), [1, 1, 7, 5]) - b = ShapedArray( - TypedPtr.new_float32( - [ - 1, - ] - * 9 - ), - [1, 1, 3, 3], - ) - c = ShapedArray(TypedPtr.new_float32([0]), [1]) - (d,) = interpreter.run_op(op, (a, b, c)) - - assert d == ShapedArray( - TypedPtr.new_float32( - [ - 12.0, - 27.0, - 24.0, - 63.0, - 108.0, - 81.0, - 123.0, - 198.0, - 141.0, - 112.0, - 177.0, - 124.0, - ] - ), - [1, 1, 4, 3], - ) - - -def test_onnx_conv_with_strides_no_padding(): - interpreter = Interpreter(ModuleOp([])) - interpreter.register_implementations(OnnxFunctions()) - op = onnx.ConvOp( - TestSSAValue(TensorType(f32, [1, 1, 5, 5])), - TestSSAValue(TensorType(f32, [1, 1, 3, 3])), - TestSSAValue(NoneType()), - StringAttr("NOTSET"), - ArrayAttr([AnyIntegerAttr(1, i64), AnyIntegerAttr(1, i64)]), - AnyIntegerAttr(1, i64), - ArrayAttr([AnyIntegerAttr(3, i64), AnyIntegerAttr(3, i64)]), - ArrayAttr( - [ - AnyIntegerAttr(0, i64), - AnyIntegerAttr(0, i64), - AnyIntegerAttr(0, i64), - AnyIntegerAttr(0, i64), - ] - ), - ArrayAttr([AnyIntegerAttr(2, i64), AnyIntegerAttr(2, i64)]), - ) - a = ShapedArray(TypedPtr.new_float32(range(35)), [1, 1, 7, 5]) - b = ShapedArray( - TypedPtr.new_float32( - [ - 1, - ] - * 9 - ), - [1, 1, 3, 3], - ) - c = ShapedArray(TypedPtr.new_float32([0]), [1]) - (d,) = interpreter.run_op(op, (a, b, c)) - - assert d == ShapedArray( - TypedPtr.new_float32([54.0, 72.0, 144.0, 162.0, 234.0, 252.0]), [1, 1, 3, 2] - ) - - -def test_onnx_conv_with_strides_asy_padding(): - interpreter = Interpreter(ModuleOp([])) - interpreter.register_implementations(OnnxFunctions()) - op = onnx.ConvOp( - TestSSAValue(TensorType(f32, [1, 1, 5, 5])), - TestSSAValue(TensorType(f32, [1, 1, 3, 3])), - TestSSAValue(NoneType()), - StringAttr("NOTSET"), - ArrayAttr([AnyIntegerAttr(1, i64), AnyIntegerAttr(1, i64)]), - AnyIntegerAttr(1, i64), - ArrayAttr([AnyIntegerAttr(3, i64), AnyIntegerAttr(3, i64)]), - ArrayAttr( - [ - AnyIntegerAttr(1, i64), - AnyIntegerAttr(0, i64), - AnyIntegerAttr(1, i64), - AnyIntegerAttr(0, i64), - ] - ), - ArrayAttr([AnyIntegerAttr(2, i64), AnyIntegerAttr(2, i64)]), - ) - a = ShapedArray(TypedPtr.new_float32(range(35)), [1, 1, 7, 5]) - b = ShapedArray( - TypedPtr.new_float32( - [ - 1, - ] - * 9 - ), - [1, 1, 3, 3], - ) - c = ShapedArray(TypedPtr.new_float32([0]), [1]) - (d,) = interpreter.run_op(op, (a, b, c)) - - assert d == ShapedArray( - TypedPtr.new_float32([21.0, 33.0, 99.0, 117.0, 189.0, 207.0, 171.0, 183.0]), - [1, 1, 4, 2], - ) - - -def test_onnx_max_pool_single_out(): - interpreter = Interpreter(ModuleOp([])) - interpreter.register_implementations(OnnxFunctions()) - op = onnx.MaxPoolSingleOutOp( - TestSSAValue(TensorType(f32, [1, 1, 4, 4])), - StringAttr("NOTSET"), - AnyIntegerAttr(0, i64), - ArrayAttr([AnyIntegerAttr(1, i64), AnyIntegerAttr(1, i64)]), - ArrayAttr([AnyIntegerAttr(2, i64), AnyIntegerAttr(2, i64)]), - pads=ArrayAttr( - [ - AnyIntegerAttr(0, i64), - AnyIntegerAttr(0, i64), - AnyIntegerAttr(0, i64), - AnyIntegerAttr(0, i64), - ] - ), - storage_order=AnyIntegerAttr(0, i64), - strides=ArrayAttr([AnyIntegerAttr(1, i64), AnyIntegerAttr(1, i64)]), - ) - a = ShapedArray(TypedPtr.new_float32(range(1, 17)), [1, 1, 4, 4]) - (b,) = interpreter.run_op(op, (a,)) - - assert b == ShapedArray( - TypedPtr.new_float32([6, 7, 8, 10, 11, 12, 14, 15, 16]), [1, 1, 3, 3] - ) - - -def test_onnx_max_pool_single_out_strides_two(): - interpreter = Interpreter(ModuleOp([])) - interpreter.register_implementations(OnnxFunctions()) - op = onnx.MaxPoolSingleOutOp( - TestSSAValue(TensorType(f32, [1, 1, 4, 4])), - StringAttr("NOTSET"), - AnyIntegerAttr(0, i64), - ArrayAttr([AnyIntegerAttr(1, i64), AnyIntegerAttr(1, i64)]), - ArrayAttr([AnyIntegerAttr(2, i64), AnyIntegerAttr(2, i64)]), - pads=ArrayAttr( - [ - AnyIntegerAttr(0, i64), - AnyIntegerAttr(0, i64), - AnyIntegerAttr(0, i64), - AnyIntegerAttr(0, i64), - ] - ), - storage_order=AnyIntegerAttr(0, i64), - strides=ArrayAttr([AnyIntegerAttr(2, i64), AnyIntegerAttr(2, i64)]), - ) - a = ShapedArray( - TypedPtr.new_float32([1, 1, 2, 4, 5, 6, 7, 8, 3, 2, 1, 0, 1, 2, 3, 4]), - [1, 1, 4, 4], - ) - (b,) = interpreter.run_op(op, (a,)) - - assert b == ShapedArray(TypedPtr.new_float32([6, 8, 3, 4]), [1, 1, 2, 2]) diff --git a/uv.lock b/uv.lock index 694dd59a9c..c7a8751286 100644 --- a/uv.lock +++ b/uv.lock @@ -1373,33 +1373,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/13/3e/1959d5219a9e6d200638d924cedda6a606392f7186a4ed56478252e70d55/numpy-2.2.1-pp310-pypy310_pp73-win_amd64.whl", hash = "sha256:5c5cc0cbabe9452038ed984d05ac87910f89370b9242371bd9079cb4af61811e", size = 12820057 }, ] -[[package]] -name = "onnx" -version = "1.17.0" -source = { registry = "https://pypi.org/simple" } -dependencies = [ - { name = "numpy" }, - { name = "protobuf" }, -] -sdist = { url = "https://files.pythonhosted.org/packages/9a/54/0e385c26bf230d223810a9c7d06628d954008a5e5e4b73ee26ef02327282/onnx-1.17.0.tar.gz", hash = "sha256:48ca1a91ff73c1d5e3ea2eef20ae5d0e709bb8a2355ed798ffc2169753013fd3", size = 12165120 } -wheels = [ - { url = "https://files.pythonhosted.org/packages/2e/29/57053ba7787788ac75efb095cfc1ae290436b6d3a26754693cd7ed1b4fac/onnx-1.17.0-cp310-cp310-macosx_12_0_universal2.whl", hash = "sha256:38b5df0eb22012198cdcee527cc5f917f09cce1f88a69248aaca22bd78a7f023", size = 16645616 }, - { url = "https://files.pythonhosted.org/packages/75/0d/831807a18db2a5e8f7813848c59272b904a4ef3939fe4d1288cbce9ea735/onnx-1.17.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d545335cb49d4d8c47cc803d3a805deb7ad5d9094dc67657d66e568610a36d7d", size = 15908420 }, - { url = "https://files.pythonhosted.org/packages/dd/5b/c4f95dbe652d14aeba9afaceb177e9ffc48ac3c03048dd3f872f26f07e34/onnx-1.17.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3193a3672fc60f1a18c0f4c93ac81b761bc72fd8a6c2035fa79ff5969f07713e", size = 16046244 }, - { url = "https://files.pythonhosted.org/packages/08/a9/c1f218085043dccc6311460239e253fa6957cf12ee4b0a56b82014938d0b/onnx-1.17.0-cp310-cp310-win32.whl", hash = "sha256:0141c2ce806c474b667b7e4499164227ef594584da432fd5613ec17c1855e311", size = 14423516 }, - { url = "https://files.pythonhosted.org/packages/0e/d3/d26ebf590a65686dde6b27fef32493026c5be9e42083340d947395f93405/onnx-1.17.0-cp310-cp310-win_amd64.whl", hash = "sha256:dfd777d95c158437fda6b34758f0877d15b89cbe9ff45affbedc519b35345cf9", size = 14528496 }, - { url = "https://files.pythonhosted.org/packages/e5/a9/8d1b1d53aec70df53e0f57e9f9fcf47004276539e29230c3d5f1f50719ba/onnx-1.17.0-cp311-cp311-macosx_12_0_universal2.whl", hash = "sha256:d6fc3a03fc0129b8b6ac03f03bc894431ffd77c7d79ec023d0afd667b4d35869", size = 16647991 }, - { url = "https://files.pythonhosted.org/packages/7b/e3/cc80110e5996ca61878f7b4c73c7a286cd88918ff35eacb60dc75ab11ef5/onnx-1.17.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f01a4b63d4e1d8ec3e2f069e7b798b2955810aa434f7361f01bc8ca08d69cce4", size = 15908949 }, - { url = "https://files.pythonhosted.org/packages/b1/2f/91092557ed478e323a2b4471e2081fdf88d1dd52ae988ceaf7db4e4506ff/onnx-1.17.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:4a183c6178be001bf398260e5ac2c927dc43e7746e8638d6c05c20e321f8c949", size = 16048190 }, - { url = "https://files.pythonhosted.org/packages/ac/59/9ea23fc22d0bb853133f363e6248e31bcbc6c1c90543a3938c00412ac02a/onnx-1.17.0-cp311-cp311-win32.whl", hash = "sha256:081ec43a8b950171767d99075b6b92553901fa429d4bc5eb3ad66b36ef5dbe3a", size = 14424299 }, - { url = "https://files.pythonhosted.org/packages/51/a5/19b0dfcb567b62e7adf1a21b08b23224f0c2d13842aee4d0abc6f07f9cf5/onnx-1.17.0-cp311-cp311-win_amd64.whl", hash = "sha256:95c03e38671785036bb704c30cd2e150825f6ab4763df3a4f1d249da48525957", size = 14529142 }, - { url = "https://files.pythonhosted.org/packages/b4/dd/c416a11a28847fafb0db1bf43381979a0f522eb9107b831058fde012dd56/onnx-1.17.0-cp312-cp312-macosx_12_0_universal2.whl", hash = "sha256:0e906e6a83437de05f8139ea7eaf366bf287f44ae5cc44b2850a30e296421f2f", size = 16651271 }, - { url = "https://files.pythonhosted.org/packages/f0/6c/f040652277f514ecd81b7251841f96caa5538365af7df07f86c6018cda2b/onnx-1.17.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:3d955ba2939878a520a97614bcf2e79c1df71b29203e8ced478fa78c9a9c63c2", size = 15907522 }, - { url = "https://files.pythonhosted.org/packages/3d/7c/67f4952d1b56b3f74a154b97d0dd0630d525923b354db117d04823b8b49b/onnx-1.17.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:4f3fb5cc4e2898ac5312a7dc03a65133dd2abf9a5e520e69afb880a7251ec97a", size = 16046307 }, - { url = "https://files.pythonhosted.org/packages/ae/20/6da11042d2ab870dfb4ce4a6b52354d7651b6b4112038b6d2229ab9904c4/onnx-1.17.0-cp312-cp312-win32.whl", hash = "sha256:317870fca3349d19325a4b7d1b5628f6de3811e9710b1e3665c68b073d0e68d7", size = 14424235 }, - { url = "https://files.pythonhosted.org/packages/35/55/c4d11bee1fdb0c4bd84b4e3562ff811a19b63266816870ae1f95567aa6e1/onnx-1.17.0-cp312-cp312-win_amd64.whl", hash = "sha256:659b8232d627a5460d74fd3c96947ae83db6d03f035ac633e20cd69cfa029227", size = 14530453 }, -] - [[package]] name = "opt-einsum" version = "3.4.0" @@ -1591,20 +1564,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/41/b6/c5319caea262f4821995dca2107483b94a3345d4607ad797c76cb9c36bcc/propcache-0.2.1-py3-none-any.whl", hash = "sha256:52277518d6aae65536e9cea52d4e7fd2f7a66f4aa2d30ed3f2fcea620ace3c54", size = 11818 }, ] -[[package]] -name = "protobuf" -version = "5.29.3" -source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/f7/d1/e0a911544ca9993e0f17ce6d3cc0932752356c1b0a834397f28e63479344/protobuf-5.29.3.tar.gz", hash = "sha256:5da0f41edaf117bde316404bad1a486cb4ededf8e4a54891296f648e8e076620", size = 424945 } -wheels = [ - { url = "https://files.pythonhosted.org/packages/dc/7a/1e38f3cafa022f477ca0f57a1f49962f21ad25850c3ca0acd3b9d0091518/protobuf-5.29.3-cp310-abi3-win32.whl", hash = "sha256:3ea51771449e1035f26069c4c7fd51fba990d07bc55ba80701c78f886bf9c888", size = 422708 }, - { url = "https://files.pythonhosted.org/packages/61/fa/aae8e10512b83de633f2646506a6d835b151edf4b30d18d73afd01447253/protobuf-5.29.3-cp310-abi3-win_amd64.whl", hash = "sha256:a4fa6f80816a9a0678429e84973f2f98cbc218cca434abe8db2ad0bffc98503a", size = 434508 }, - { url = "https://files.pythonhosted.org/packages/dd/04/3eaedc2ba17a088961d0e3bd396eac764450f431621b58a04ce898acd126/protobuf-5.29.3-cp38-abi3-macosx_10_9_universal2.whl", hash = "sha256:a8434404bbf139aa9e1300dbf989667a83d42ddda9153d8ab76e0d5dcaca484e", size = 417825 }, - { url = "https://files.pythonhosted.org/packages/4f/06/7c467744d23c3979ce250397e26d8ad8eeb2bea7b18ca12ad58313c1b8d5/protobuf-5.29.3-cp38-abi3-manylinux2014_aarch64.whl", hash = "sha256:daaf63f70f25e8689c072cfad4334ca0ac1d1e05a92fc15c54eb9cf23c3efd84", size = 319573 }, - { url = "https://files.pythonhosted.org/packages/a8/45/2ebbde52ad2be18d3675b6bee50e68cd73c9e0654de77d595540b5129df8/protobuf-5.29.3-cp38-abi3-manylinux2014_x86_64.whl", hash = "sha256:c027e08a08be10b67c06bf2370b99c811c466398c357e615ca88c91c07f0910f", size = 319672 }, - { url = "https://files.pythonhosted.org/packages/fd/b2/ab07b09e0f6d143dfb839693aa05765257bceaa13d03bf1a696b78323e7a/protobuf-5.29.3-py3-none-any.whl", hash = "sha256:0a18ed4a24198528f2333802eb075e59dea9d679ab7a6c5efb017a59004d849f", size = 172550 }, -] - [[package]] name = "psutil" version = "6.1.1" @@ -2536,10 +2495,6 @@ jax = [ { name = "jax" }, { name = "numpy" }, ] -onnx = [ - { name = "numpy" }, - { name = "onnx" }, -] riscv = [ { name = "riscemu" }, ] @@ -2557,8 +2512,6 @@ requires-dist = [ { name = "nbconvert", marker = "extra == 'dev'", specifier = ">=7.7.2,<8.0.0" }, { name = "nbval", marker = "extra == 'dev'", specifier = "<0.12" }, { name = "numpy", marker = "extra == 'jax'", specifier = "==2.2.1" }, - { name = "numpy", marker = "extra == 'onnx'", specifier = "==2.2.1" }, - { name = "onnx", marker = "extra == 'onnx'", specifier = "==1.17.0" }, { name = "ordered-set", specifier = "==4.1.0" }, { name = "pip", marker = "extra == 'dev'", specifier = "<25.0" }, { name = "pre-commit", marker = "extra == 'dev'", specifier = "==4.0.1" }, diff --git a/xdsl/dialects/__init__.py b/xdsl/dialects/__init__.py index 256f28802c..0c8eb9e589 100644 --- a/xdsl/dialects/__init__.py +++ b/xdsl/dialects/__init__.py @@ -178,11 +178,6 @@ def get_omp(): return OMP - def get_onnx(): - from xdsl.dialects.onnx import ONNX - - return ONNX - def get_pdl(): from xdsl.dialects.pdl import PDL @@ -363,7 +358,6 @@ def get_transform(): "mod_arith": get_mod_arith, "mpi": get_mpi, "omp": get_omp, - "onnx": get_onnx, "pdl": get_pdl, "printf": get_printf, "ptr_xdsl": get_ptr_xdsl, diff --git a/xdsl/dialects/onnx.py b/xdsl/dialects/onnx.py deleted file mode 100644 index 346975baea..0000000000 --- a/xdsl/dialects/onnx.py +++ /dev/null @@ -1,1106 +0,0 @@ -from __future__ import annotations - -import math -from abc import ABC -from typing import Annotated, ClassVar, cast - -from typing_extensions import Self - -from xdsl.dialects.builtin import ( - Any, - AnyFloat, - AnyFloatConstr, - AnyIntegerAttr, - AnyTensorType, - ArrayAttr, - DenseIntOrFPElementsAttr, - Float32Type, - FloatAttr, - IntegerAttr, - IntegerType, - MemRefType, - NoneType, - SSAValue, - StringAttr, - SymbolRefAttr, - TensorOrMemrefOf, - TensorType, -) -from xdsl.ir import ( - Attribute, - Dialect, -) -from xdsl.irdl import ( - ConstraintVar, - IRDLOperation, - VarConstraint, - attr_def, - base, - irdl_op_definition, - operand_def, - opt_attr_def, - result_def, -) -from xdsl.parser import Parser -from xdsl.printer import Printer -from xdsl.utils.exceptions import VerifyException - - -def verify_unidirectional_broadcast_shape( - lhs: TensorType[Attribute], rhs: TensorType[Attribute], res: TensorType[Attribute] -) -> None: - """ - Returns a unidirectional broadcastable shape - """ - lhs_shape = lhs.get_shape() - rhs_shape = rhs.get_shape() - expected_shape = unidirectional_broadcast_shape(list(lhs_shape), list(rhs_shape)) - if expected_shape is None: - raise VerifyException( - f"operands have incompatible shapes: {lhs_shape} and {rhs_shape}" - ) - res_type_shape = res.get_shape() - if ( - len(expected_shape) != len(res_type_shape) - or tuple(expected_shape) != res_type_shape - ): - raise VerifyException( - f"result shape {expected_shape} does not match result type {res}" - ) - - -def verify_multidirectional_broadcast_shape( - lhs: TensorType[Attribute], rhs: TensorType[Attribute], res: TensorType[Attribute] -) -> None: - """ - Returns a multidirectional broadcastable shape - """ - lhs_shape = lhs.get_shape() - rhs_shape = rhs.get_shape() - expected_shape = multidirectional_broadcast_shape(list(lhs_shape), list(rhs_shape)) - if expected_shape is None: - raise VerifyException( - f"operands have incompatible shapes: {lhs_shape} and {rhs_shape}" - ) - res_type_shape = res.get_shape() - if ( - len(expected_shape) != len(res_type_shape) - or tuple(expected_shape) != res_type_shape - ): - raise VerifyException( - f"result shape {expected_shape} does not match result type {res}" - ) - - -def unidirectional_broadcast_shape(lhs: list[int], rhs: list[int]) -> list[int] | None: - """ - In ONNX, tensor B is unidirectional broadcastable to tensor A if one of the following is true: - - 1. Tensor A and B both have exactly the same shape. - 2. Tensor A and B all have the same number of dimensions and - the length of each dimensions is either a common length or B's length is 1. - - 3.Tensor B has too few dimensions, and B can have its shapes prepended with a dimension of length 1 to satisfy - property 2. - - When unidirectional broadcasting happens, the output's shape is the same as the shape of A (i.e., - the larger shape of two input tensors) - - https://github.com/onnx/onnx/blob/main/docs/Broadcasting.md - """ - # Check if Tensor A and B both have exactly the same shape - if lhs == rhs: - return lhs - - lhs_len = len(lhs) - rhs_len = len(rhs) - prefix_len = lhs_len - rhs_len - if prefix_len < 0: - # lhs must not be shorter than rhs - return None - res_shape = lhs[:prefix_len] - for dl, dr in zip(lhs[prefix_len:], rhs): - if dl == dr or dr == 1 or dl == 1: - res_shape.append(max(dl, dr)) - else: - return None - return res_shape - - -def multidirectional_broadcast_shape( - lhs: list[int], rhs: list[int] -) -> list[int] | None: - """ - In ONNX, a set of tensors are multidirectional broadcastable to the same shape if one of the following is true: - - 1.The tensors all have exactly the same shape. - 2.The tensors all have the same number of dimensions and the length of each dimensions is either a common length or 1. - 3.The tensors that have too few dimensions can have their shapes prepended with a dimension of length 1 to satisfy property 2. - - https://github.com/onnx/onnx/blob/main/docs/Broadcasting.md - """ - - if len(lhs) > len(rhs): - return unidirectional_broadcast_shape(rhs, lhs) - else: - return unidirectional_broadcast_shape(lhs, rhs) - - -class ElementwiseBinOpBase(IRDLOperation, ABC): - """Base class for element-wise binary operations on tensors with Numpy-style broadcasting.""" - - T = Annotated[AnyFloat | IntegerType, ConstraintVar("T")] - lhs = operand_def(TensorType[T]) - rhs = operand_def(TensorType[T]) - res = result_def(TensorType[T]) - assembly_format = "`(` $lhs `,` $rhs `)` attr-dict `:` `(` type($lhs) `,` type($rhs) `)` `->` type($res)" - - def __init__(self, lhs: SSAValue, rhs: SSAValue, res_type: Attribute): - super().__init__( - operands=[lhs, rhs], - result_types=[res_type], - ) - - def verify_(self) -> None: - # Check that the arguments are broadcastable (using Numpy semantics) and that the result type is correct. - assert isinstance(lhs_type := self.lhs.type, TensorType) - assert isinstance(rhs_type := self.rhs.type, TensorType) - assert isinstance(res_type := self.res.type, TensorType) - verify_multidirectional_broadcast_shape(lhs_type, rhs_type, res_type) # pyright: ignore[reportUnknownArgumentType] - - -@irdl_op_definition -class AddOp(ElementwiseBinOpBase): - name = "onnx.Add" - - -@irdl_op_definition -class SubOp(ElementwiseBinOpBase): - name = "onnx.Sub" - - -@irdl_op_definition -class MulOp(ElementwiseBinOpBase): - name = "onnx.Mul" - - -@irdl_op_definition -class DivOp(ElementwiseBinOpBase): - name = "onnx.Div" - - -@irdl_op_definition -class ReluOp(IRDLOperation): - """ - Relu takes one input data (Tensor) and produces one output data (Tensor) where the rectified linear function, - y = max(0, x), is applied to the tensor elementwise. - """ - - name = "onnx.Relu" - T = Annotated[AnyFloat | IntegerType, ConstraintVar("T")] - operand = operand_def(TensorType[T]) - res = result_def(TensorType[T]) - assembly_format = ( - "`(` $operand`)` attr-dict `:` `(` type($operand) `)` `->` type($res)" - ) - - def __init__(self, operand: SSAValue): - super().__init__( - operands=[operand], - result_types=[operand.type], - ) - - def verify_(self) -> None: - assert isinstance(operand_type := self.operand.type, TensorType) - assert isinstance(res_type := self.res.type, TensorType) - operand_type = cast(TensorType[Attribute], operand_type) - res_type = cast(TensorType[Attribute], res_type) - - if operand_type != res_type: - raise VerifyException( - "Mismatch between operand type and res type of onnx.Relu" - ) - - -@irdl_op_definition -class GemmOp(IRDLOperation): - """ - General Matrix multiplication: https://en.wikipedia.org/wiki/Basic_Linear_Algebra_Subprograms#Level_3 - A' = transpose(A) if transA else A - B' = transpose(B) if transB else B - Compute Y = alpha * A' * B' + beta * C, - where input tensor A has shape (M, K) or (K, M), input tensor B has shape (K, N) or (N, K), - input tensor C is broadcastable to shape (M, N), and output tensor Y has shape (M, N). - """ - - name = "onnx.Gemm" - T = Annotated[AnyFloat | IntegerType, ConstraintVar("T")] - tensor_a = operand_def(TensorType[T]) - tensor_b = operand_def(TensorType[T]) - tensor_c = operand_def(TensorType[T]) - - alpha = opt_attr_def(FloatAttr[AnyFloat]) - beta = opt_attr_def(FloatAttr[AnyFloat]) - - trans_a = opt_attr_def(AnyIntegerAttr, attr_name="transA") - trans_b = opt_attr_def(AnyIntegerAttr, attr_name="transB") - - res_tensor = result_def(TensorType[T]) - assembly_format = ( - "`(` $tensor_a `,` $tensor_b `,`$tensor_c`)` attr-dict `:` `(` type($tensor_a) `," - "` type($tensor_b) `,`type($tensor_c)`)` `->` type($res_tensor) " - ) - - def __init__( - self, - tensor_a: SSAValue, - tensor_b: SSAValue, - tensor_c: SSAValue, - alpha: Attribute, - trans_a: Attribute, - trans_b: Attribute, - beta: Attribute, - ): - super().__init__( - attributes={ - "transA": trans_a, - "transB": trans_b, - "alpha": alpha, - "beta": beta, - }, - operands=[tensor_a, tensor_b, tensor_c], - result_types=[tensor_c.type], - ) - - def verify_(self) -> None: - assert isinstance(tensor_a_type := self.tensor_a.type, TensorType) - assert isinstance(tensor_b_type := self.tensor_b.type, TensorType) - assert isinstance(tensor_c_type := self.tensor_c.type, TensorType) - assert isinstance(res_tensor_type := self.res_tensor.type, TensorType) - - tensor_a_type = cast(TensorType[Attribute], tensor_a_type) - tensor_b_type = cast(TensorType[Attribute], tensor_b_type) - tensor_c_type = cast(TensorType[Attribute], tensor_c_type) - res_tensor_type = cast(TensorType[Attribute], res_tensor_type) - - # store dimensions of tensor A and tensor B - res_shape: list[int] = [] - tensor_a_shape = tensor_a_type.get_shape() - tensor_b_shape = tensor_b_type.get_shape() - - if tensor_a_type.get_num_dims() != 2: - raise VerifyException("tensor A should be a 2D tensor") - - if tensor_b_type.get_num_dims() != 2: - raise VerifyException("tensor B should be a 2D tensor") - - if self.trans_a is not None and self.trans_a.value.data == 1: - tensor_a_shape = tuple(reversed(tensor_a_shape)) - - if self.trans_b is not None and self.trans_b.value.data == 1: - tensor_b_shape = tuple(reversed(tensor_b_shape)) - - if self.beta is not None: - c_dims = tensor_c_type.get_num_dims() - if c_dims > 2: - raise VerifyException("tensor C should be a 1D tensor or 2D tensor") - - if tensor_a_shape[1] != tensor_b_shape[0]: - raise VerifyException( - f"operands have incompatible shapes: {tensor_a_shape} and {tensor_b_shape}" - ) - else: - res_shape.append(tensor_a_shape[0]) - res_shape.append(tensor_b_shape[1]) - - # Build tensor of tensor (A * B) computation - tensors_res = TensorType(tensor_a_type.element_type, res_shape) - verify_unidirectional_broadcast_shape( - tensors_res, tensor_c_type, res_tensor_type - ) - - -@irdl_op_definition -class ReshapeOp(IRDLOperation): - """ - Reshape the input tensor similar to numpy.reshape. - First input is the data tensor, second input is a shape tensor which specifies the output shape. It outputs the reshaped tensor. - At most one dimension of the new shape can be -1. In this case, the value is - inferred from the size of the tensor and the remaining dimensions. A dimension - could also be 0, in which case the actual dimension value is unchanged (i.e. taken - from the input tensor). Shape (second input) could be an empty shape, which means converting to a scalar. - The input tensor's shape and the output tensor's shape are required to have the same number of elements. - - Attributes: - - allowzero int (default is 0): By default, when any value in the 'shape' input is equal to zero - the corresponding dimension value is copied from the input tensor dynamically. allowzero=1 indicates that if any - value in the 'shape' input is set to zero, the zero value is honoured, similar to NumPy. - - """ - - name = "onnx.Reshape" - T = Annotated[AnyFloat | IntegerType, ConstraintVar("T")] - data = operand_def(TensorType[T]) - shape = operand_def(TensorType[IntegerType]) - reshaped = result_def(TensorType[T]) - - allow_zero = opt_attr_def(AnyIntegerAttr, attr_name="allowzero") - - assembly_format = "`(` $data `,` $shape `)` attr-dict `:` `(` type($data) `,` type($shape) `)` `->` type($reshaped)" - - def __init__(self, data: SSAValue, shape: SSAValue, allow_zero: Attribute): - super().__init__( - attributes={"allowzero": allow_zero}, - operands=[data, shape], - result_types=[data.type], - ) - - def verify_(self) -> None: - assert isinstance(data_type := self.data.type, TensorType) - assert isinstance(shape_type := self.shape.type, TensorType) - assert isinstance(reshaped_type := self.reshaped.type, TensorType) - - data_type = cast(TensorType[Attribute], data_type) - reshaped_type = cast(TensorType[Attribute], reshaped_type) - shape_type = cast(TensorType[Attribute], shape_type) - - if shape_type.element_type != IntegerType(64): - raise VerifyException( - "shape element type has to be a 64-bit signless integer" - ) - - data_type = data_type.get_shape() - shape_type = shape_type.get_shape() - reshaped_type = reshaped_type.get_shape() - - # Shape tensor rank can't be -1. - if shape_type[0] == -1: - raise VerifyException("Shape tensor rank must not be equal to -1") - - # There is currently only support for rank one shape tensors in onnx-mlir - # Shape tensor must have a constant shape. - if len(shape_type) != 1: - raise VerifyException("Shape tensor must have a rank one") - - # The input tensor's shape and the output tensor's shape are required to have the same number of elements. - if math.prod(data_type) != math.prod(reshaped_type): - raise VerifyException( - "Input tensor's shape and output tensor's shape must have the same number of elements" - ) - - -@irdl_op_definition -class AbsOp(IRDLOperation): - """ - Absolute takes one input data (Tensor) and produces one output data (Tensor) where absolute value, - y = abs(x), is applied to the tensor elementwise. - """ - - name = "onnx.Abs" - T = Annotated[AnyFloat | IntegerType, ConstraintVar("T")] - operand = operand_def(TensorType[T]) - res = result_def(TensorType[T]) - assembly_format = ( - "`(` $operand`)` attr-dict `:` `(` type($operand) `)` `->` type($res)" - ) - - def __init__(self, operand: SSAValue): - super().__init__( - operands=[operand], - result_types=[operand.type], - ) - - def verify_(self) -> None: - assert isinstance(operand_type := self.operand.type, TensorType) - assert isinstance(res_type := self.res.type, TensorType) - operand_type = cast(TensorType[Attribute], operand_type) - res_type = cast(TensorType[Attribute], res_type) - if operand_type != res_type: - raise VerifyException( - "Mismatch between operand type and res type of onnx.Abs" - ) - - -@irdl_op_definition -class ConvOp(IRDLOperation): - """ - The convolution operator consumes an input tensor and a filter, and computes the output. - - Attributes: - - - auto_pad string (default is NOTSET): auto_pad must be either NOTSET, SAME_UPPER, SAME_LOWER or - VALID. Where default value is NOTSET, which means explicit padding is used. SAME_UPPER or SAME_LOWER mean pad the - input so that output_shape[i] = ceil(input_shape[i] / strides[i]) for each axis i. The padding is split between - the two sides equally or almost equally (depending on whether it is even or odd). In case the padding is an odd - number, the extra padding is added at the end for SAME_UPPER and at the beginning for SAME_LOWER. - - -dilations list of ints: dilation value along each spatial axis of the filter. If not present, the dilation - defaults is 1 along each spatial axis. - - -group int (default is '1'): number of groups input channels and output channels are divided into. - - -kernel_shape list of ints: The shape of the convolution kernel. If not present, should be inferred from input W. - - -pads list of ints: Padding for the beginning and ending along each spatial axis, it can take any value greater - than or equal to 0. The value represent the number of pixels added to the beginning and end part of the - corresponding axis. `pads` format should be as follow [x1_begin, x2_begin...x1_end, x2_end,...], where xi_begin - the number of pixels added at the beginning of axis `i` and xi_end, the number of pixels added at the end of axis - `i`. This attribute cannot be used simultaneously with auto_pad attribute. If not present, the padding defaults - to 0 along start and end of each spatial axis. - - -strides list of ints: Stride along each spatial axis. If not present, the stride defaults is 1 along each spatial axis. - - """ - - name = "onnx.Conv" - T = Annotated[AnyFloat | IntegerType, ConstraintVar("T")] - data = operand_def(TensorType[T]) - weight = operand_def(TensorType[T]) - bias = operand_def(base(TensorType[T]) | base(NoneType)) - res = result_def(TensorType[T]) - - auto_pad = attr_def(StringAttr) - dilations = attr_def(ArrayAttr[AnyIntegerAttr]) - group = attr_def(AnyIntegerAttr) - kernel_shape = attr_def(ArrayAttr[AnyIntegerAttr]) - pads = attr_def(ArrayAttr[AnyIntegerAttr]) - strides = attr_def(ArrayAttr[AnyIntegerAttr]) - - assembly_format = ( - "`(` $data `,` $weight `,`$bias`)` attr-dict `:` `(` type($data) `," - "` type($weight) `,`type($bias)`)` `->` type($res) " - ) - - def __init__( - self, - data: SSAValue, - weight: SSAValue, - bias: SSAValue, - auto_pad: Attribute, - dilations: Attribute, - group: Attribute, - kernel_shape: Attribute, - pads: Attribute, - strides: Attribute, - ): - super().__init__( - attributes={ - "auto_pad": auto_pad, - "dilations": dilations, - "group": group, - "kernel_shape": kernel_shape, - "pads": pads, - "strides": strides, - }, - operands=[data, weight, bias], - result_types=[data.type], - ) - - def verify_(self) -> None: - assert isinstance(data_type := self.data.type, TensorType) - assert isinstance(weight_type := self.weight.type, TensorType) - assert isinstance(bias_type := self.bias.type, TensorType | NoneType) - assert isinstance(res_type := self.res.type, TensorType) - - weight_type = cast(TensorType[Attribute], weight_type) - data_type = cast(TensorType[Attribute], data_type) - res_type = cast(TensorType[Attribute], res_type) - - # case that bias is a tensor type - if isinstance(bias_type, TensorType): - bias_type = bias_type.get_shape() - if len(bias_type) != 1: - raise VerifyException("bias must be 1D") - - weight_type = weight_type.get_shape() - # kernel_shape - kernel_shape_data: list[int] = [] - for value in self.kernel_shape: - val = value.value.data - kernel_shape_data.append(val) - if list(weight_type[-2:]) != kernel_shape_data: - raise VerifyException( - "kernel shape rank and weight tensor rank are not the same" - ) - - # dilations - for value in self.dilations: - val = value.value.data - if val <= 0: - raise VerifyException("dilation value must be non zero positive") - if len(self.dilations) != len(self.kernel_shape): - raise VerifyException( - "dilations rank and kernel shape rank are not the same" - ) - - # group - if self.group.value.data < 1: - raise VerifyException("group value must be nonnegative") - - # strides - for value in self.strides: - val = value.value.data - if val <= 0: - raise VerifyException("stride value must be non zero positive") - if len(self.strides) != len(self.kernel_shape): - raise VerifyException( - "strides rank and kernel shape rank are not the same " - ) - - # pads - for value in self.pads: - val = value.value.data - if val < 0: - raise VerifyException("pads value must be nonnegative") - if len(self.pads) != 2 * len(self.kernel_shape): - raise VerifyException("pads rank is not twice the kernel shape rank") - - auto_pad_strings = ["NOTSET", "SAME_UPPER", "SAME_LOWER", "VALID"] - if self.auto_pad.data not in auto_pad_strings: - raise VerifyException( - f"Invalid auto_pad string. Must be one of {auto_pad_strings}" - ) - - -@irdl_op_definition -class ConstantOp(IRDLOperation): - """ - Produce a constant tensor. - - Exactly one of the provided attributes, either value, sparse_value, or value_* must be specified. - - Attributes: - - sparse_value: sparse_tensor - The value for the elements of the output tensor in sparse format. (currently unsupported) - - value : tensor - The value for the elements of the output tensor. - - value_float: float - The value for the sole element for the scalar, float32, output tensor. - - value_floats: list of floats - The values for the elements for the 1D, float32, output tensor. - - value_int : int - The value for the sole element for the scalar, int64, output tensor. - - value_ints : list of ints - The values for the elements for the 1D, int64, output tensor. - - value_string : string - The value for the sole element for the scalar, UTF-8 string, output tensor. - - value_strings: list of strings - The values for the elements for the 1D, UTF-8 string, output tensor. - """ - - name = "onnx.Constant" - output = result_def(AnyTensorType) - - value = opt_attr_def(DenseIntOrFPElementsAttr) - value_float = opt_attr_def(FloatAttr[Float32Type]) - value_floats = opt_attr_def(ArrayAttr[FloatAttr[Float32Type]]) - value_int = opt_attr_def(IntegerAttr[IntegerType]) - value_ints = opt_attr_def(ArrayAttr[IntegerAttr[IntegerType]]) - value_string = opt_attr_def(StringAttr) - value_strings = opt_attr_def(ArrayAttr[StringAttr]) - - def __init__( - self, - value: Attribute | None, - value_float: Attribute | None, - value_floats: Attribute | None, - value_int: Attribute | None, - value_ints: Attribute | None, - value_string: Attribute | None, - value_strings: Attribute | None, - output_type: Attribute | None, - ): - super().__init__( - attributes={ - "value": value, - "value_float": value_float, - "value_floats": value_floats, - "value_int": value_int, - "value_ints": value_ints, - "value_string": value_string, - "value_strings": value_strings, - }, - operands=[], - result_types=[output_type], - ) - - def verify_(self) -> None: - if self.value is not None and not isinstance(self.value.type, TensorType): - raise VerifyException("value attribute type must be of type TensorType") - - if self.value_int is not None and self.value_int.type.width.data != 64: - raise VerifyException( - "value_int element type has to be a 64-bit signless integer" - ) - - if self.value_ints is not None: - for value in self.value_ints: - width = value.type.width.data - if width != 64: - raise VerifyException( - "value_ints elements type has to be a 64-bit signless integer" - ) - - attrs = [ - self.value, - self.value_float, - self.value_floats, - self.value_int, - self.value_ints, - self.value_string, - self.value_strings, - ] - used_attrs = sum(1 for attr in attrs if attr is not None) - if used_attrs != 1: - raise VerifyException( - f"Only one value attribute must be provided, but {used_attrs} were specified" - ) - - def print(self, printer: Printer): - if self.value is not None: - printer.print(" ") - printer.print(self.value) - - @classmethod - def parse(cls, parser: Parser) -> Self: - v = parser.parse_attribute() - if not isinstance(v, DenseIntOrFPElementsAttr): - raise NotImplementedError() - constant = cls(v, None, None, None, None, None, None, v.type) - return constant - - -@irdl_op_definition -class MaxPoolSingleOutOp(IRDLOperation): - """ - ONNX MaxPool operation with a single output. - - Attributes: - - - auto_pad string (default is 'NOTSET'): auto_pad must be either NOTSET, SAME_UPPER, SAME_LOWER or - VALID. Where default value is NOTSET, which means explicit padding is used. SAME_UPPER or SAME_LOWER mean pad the - input so that output_shape[i] = ceil(input_shape[i] / strides[i]) for each axis i. The padding is split between - the two sides equally or almost equally (depending on whether it is even or odd). In case the padding is an odd - number, the extra padding is added at the end for SAME_UPPER and at the beginning for SAME_LOWER. - - - ceil_mode int (default is '1'): Whether to use ceil or floor (default) to compute the output shape. - - - dilations list of ints: Dilation value along each spatial axis of filter. - - - kernel_shape list of ints: The size of the kernel along each axis. - - - pads list of ints: Padding for the beginning and ending along each spatial axis, it can take any value greater - than or equal to 0. The value represent the number of pixels added to the beginning and end part of the - corresponding axis. `pads` format should be as follow [x1_begin, x2_begin...x1_end, x2_end,...], where xi_begin - the number of pixels added at the beginning of axis `i` and xi_end, the number of pixels added at the end of axis - `i`. This attribute cannot be used simultaneously with auto_pad attribute. If not present, the padding defaults - to 0 along start and end of each spatial axis. - - - storage_order int (default is '0') : The storage order of the tensor. 0 is row major, and 1 is column major. - This attribute is used only to convert an n-tuple index value into a single integer value for producing the - second output. - - - strides list of ints: Stride along each spatial axis. If not present, the stride defaults to 1 along each spatial axis - - """ - - name = "onnx.MaxPoolSingleOut" - - T: ClassVar = VarConstraint("T", AnyFloatConstr | base(IntegerType)) - data = operand_def(TensorOrMemrefOf(T)) - output = result_def(TensorOrMemrefOf(T)) - - auto_pad = attr_def(StringAttr) - ceil_mode = attr_def(AnyIntegerAttr) - dilations = attr_def(ArrayAttr[AnyIntegerAttr]) - kernel_shape = attr_def(ArrayAttr[AnyIntegerAttr]) - pads = attr_def(ArrayAttr[AnyIntegerAttr]) - storage_order = attr_def(AnyIntegerAttr) - strides = attr_def(ArrayAttr[AnyIntegerAttr]) - - assembly_format = ( - "`(` $data`)` attr-dict `:` `(` type($data) `)` `->` type($output)" - ) - - def __init__( - self, - data: SSAValue, - auto_pad: Attribute, - ceil_mode: Attribute, - dilations: Attribute, - kernel_shape: Attribute, - pads: Attribute, - storage_order: Attribute, - strides: Attribute, - ): - super().__init__( - attributes={ - "auto_pad": auto_pad, - "ceil_mode": ceil_mode, - "dilations": dilations, - "kernel_shape": kernel_shape, - "pads": pads, - "storage_order": storage_order, - "strides": strides, - }, - operands=[data], - result_types=[data.type], - ) - - def verify_(self) -> None: - assert isinstance(data_type := self.data.type, TensorType | MemRefType) - assert isinstance(output_type := self.output.type, TensorType | MemRefType) - - data_type = cast(TensorType[Attribute], data_type) - output_type = cast(TensorType[Attribute], output_type) - - # auto pad - auto_pad_strings = ["NOTSET", "SAME_UPPER", "SAME_LOWER", "VALID"] - if self.auto_pad.data not in auto_pad_strings: - raise VerifyException( - f"Invalid auto_pad string. Must be one of {auto_pad_strings}" - ) - - # ceil mode - if self.ceil_mode.value.data < 0 or self.ceil_mode.value.data > 1: - raise VerifyException("ceil value must be either zero or one") - - # kernel shape - if (input_dims := len(data_type.get_shape()) - 2) != ( - kernel_dims := len(self.kernel_shape) - ): - raise VerifyException( - f"input data and kernel shape rank mismatch: ({input_dims}) vs ({kernel_dims})" - ) - - # dilations - for value in self.dilations: - val = value.value.data - if val <= 0: - raise VerifyException("dilation value must be non zero positive") - - if (dilations_dims := len(self.dilations)) != ( - kernel_dims := len(self.kernel_shape) - ): - raise VerifyException( - f"dilations rank ({dilations_dims}) and kernel shape rank ({kernel_dims}) are not the " - f"same " - ) - - # storage order - # Not supported for storage order in column major mode in onnx-mlir (therefore row major mode only considered) - if self.storage_order.value.data != 0: - raise VerifyException("column major storage order not implemented yet") - - # strides - for value in self.strides: - val = value.value.data - if val <= 0: - raise VerifyException("stride value must be non zero positive") - - if (strides_dims := len(self.strides)) != ( - kernel_dims := len(self.kernel_shape) - ): - raise VerifyException( - f"strides rank ({strides_dims}) and kernel shape rank ({kernel_dims}) are not the " - f"same " - ) - - # pads - for value in self.pads: - val = value.value.data - if val < 0: - raise VerifyException("pads value must be nonnegative") - - if (pads_dims := len(self.pads)) != 2 * len(self.kernel_shape): - raise VerifyException( - f"pads rank ({pads_dims}) is not twice the kernel shape rank ({len(self.kernel_shape)})" - ) - - -@irdl_op_definition -class EntryPointOp(IRDLOperation): - """ - Indicate ONNX entry point - The "onnx.EntryPoint" function indicates the main entry point of ONNX model. - """ - - name = "onnx.EntryPoint" - func = attr_def(SymbolRefAttr) - - def __init__(self, func: Attribute): - super().__init__( - attributes={ - "func": func, - }, - ) - - -@irdl_op_definition -class MatMulOp(IRDLOperation): - """ - The operation MatMul performs matrix multiplication between two input matrices, A and B, and returns the result as matrix Y. - Matrix multiplication is a fundamental operation in linear algebra, where each element of the resulting matrix Y is computed by taking the - dot product of the corresponding row of matrix A and column of matrix B. - """ - - name = "onnx.MatMul" - - # describe annotated type - T = Annotated[AnyFloat | IntegerType, ConstraintVar("T")] - - # input matrices - matrix_A = operand_def(TensorType[T]) - matrix_B = operand_def(TensorType[T]) - - # output matrices - matrix_Y = result_def(TensorType[T]) - - assembly_format = ( - "`(` $matrix_A `,` $matrix_B `)` attr-dict `:` `(` type($matrix_A) `," - "` type($matrix_B) `)` `->` type($matrix_Y) " - ) - - def __init__( - self, - matrix_A: SSAValue, - matrix_B: SSAValue, - matrix_Y_type: Attribute, - ): - super().__init__( - operands=[matrix_A, matrix_B], - result_types=[matrix_Y_type], - ) - - def verify_(self) -> None: - # store dimensions of tensor A and tensor B - res_shape: list[int] = [] - matrix_A_type = cast(TensorType[Any], self.matrix_A.type) - matrix_B_type = cast(TensorType[Any], self.matrix_B.type) - matrix_Y_type = cast(TensorType[Any], self.matrix_Y.type) - - # check shape compatibility - matrix_A_shape = matrix_A_type.get_shape() - matrix_B_shape = matrix_B_type.get_shape() - - if matrix_A_type.get_num_dims() != 2: - raise VerifyException("input matrix A should be a 2D tensor") - - if matrix_B_type.get_num_dims() != 2: - raise VerifyException("input matrix B should be a 2D tensor") - - if matrix_A_shape[1] != matrix_B_shape[0]: - raise VerifyException( - f"operands have incompatible shapes: {matrix_A_shape} and {matrix_B_shape}" - ) - else: - res_shape.append(matrix_A_shape[0]) - res_shape.append(matrix_B_shape[1]) - - matrix_Y_type_shape = list(matrix_Y_type.get_shape()) - if ( - len(res_shape) != len(matrix_Y_type_shape) - or res_shape != matrix_Y_type_shape - ): - raise VerifyException( - f"result shape {res_shape} does not match result type {matrix_Y_type_shape}" - ) - - -@irdl_op_definition -class TransposeOp(IRDLOperation): - """ - The transpose_tensor function takes a tensor as input and returns its transpose. - Transposing a tensor means flipping its dimensions, so that rows become columns and vice versa. - """ - - name = "onnx.Transpose" - - T = Annotated[AnyFloat | IntegerType, ConstraintVar("T")] - tensor_input = operand_def(TensorType[T]) - - perm = opt_attr_def(ArrayAttr[AnyIntegerAttr], attr_name="perm") - - tensor_output = result_def(TensorType[T]) - - assembly_format = ( - "`(` $tensor_input `)` attr-dict `:` `(` type($tensor_input) " - "`)` `->` type($tensor_output) " - ) - - def __init__(self, tensor_input: SSAValue, perm: Attribute): - super().__init__( - attributes={"perm": perm}, - operands=[tensor_input], - result_types=[tensor_input.type], - ) - - def verify_(self) -> None: - assert isinstance(tensor_input_type := self.tensor_input.type, TensorType) - assert isinstance(tensor_output_type := self.tensor_output.type, TensorType) - - tensor_input_shape = tensor_input_type.get_shape() - tensor_output_shape = tensor_output_type.get_shape() - - # numbers in perm cannot be repeated - if self.perm is not None: - for _, int_attr in enumerate(self.perm.data): - attr_value = int_attr.value.data - count = self.perm.data.count(int_attr) - if count != 1: - raise VerifyException( - f"permutation can not contain more than one occurrence of the same dimension: dimension #{attr_value} appears {count} times." - ) - - # numbers in perm must be between 0 and len(tensor_input_shape)-1 - perm_size = len(self.perm.data) - for int_index, int_attr in enumerate(self.perm.data): - int_index = int_index + 0 - int_attr_val = int_attr.value.data - if int_attr_val < 0 or int_attr_val >= perm_size: - raise VerifyException( - f"permutation can only contain values between 0 and {perm_size}-1: dimension #{int_index} value is {int_attr_val}" - ) - - # len(tensor_input_shape) must be equal to len(perm) - perm_size = len(self.perm.data) - input_size = len(tensor_input_shape) - if perm_size != input_size: - raise VerifyException( - f"permutation and inputs dimensions must have the same size: #dimensions input is {input_size}, #dimension perimutation is {perm_size}" - ) - - # check output shape - for index_attr, int_attr in enumerate(self.perm.data): - int_attr_val = int_attr.value.data - if tensor_output_shape[index_attr] != tensor_input_shape[int_attr_val]: - raise VerifyException( - f"incorrect output shape: output dimension #{index_attr} should be equal to {tensor_input_shape[int_attr_val]}" - ) - - -@irdl_op_definition -class SqueezeOp(IRDLOperation): - """ - Squeeze the input tensor along the specified axes. - - Squeezing a tensor removes dimensions of size 1, effectively reducing the rank of the tensor and collapsing those dimensions. - This operation is particularly useful for removing unnecessary singleton dimensions, which may arise from broadcasting or previous operations. - - Args: - input_tensor: The input tensor to be squeezed. This tensor should be a multi-dimensional array-like object. - axes: A list of axes along which to squeeze the tensor. If provided, only the specified axes will be squeezed. If not provided, all dimensions of size 1 will be squeezed. - - Returns: - output_tensor: The squeezed tensor. - """ - - name = "onnx.Squeeze" - - T = Annotated[AnyFloat | IntegerType, ConstraintVar("T")] - input_tensor = operand_def(TensorType[T]) - axes = opt_attr_def(base(AnyIntegerAttr), attr_name="axes") - - output_tensor = result_def(TensorType[T]) - - assembly_format = "`(` $input_tensor `)` attr-dict `:` `(` type($input_tensor) `)` `->` type($output_tensor) " - - def __init__( - self, - input_tensor: SSAValue, - axes: Attribute, - ): - super().__init__( - attributes={ - "axes": axes, - }, - operands=[input_tensor], - result_types=[input_tensor.type], - ) - - def verify_(self) -> None: - assert isinstance(input_tensor_type := self.input_tensor.type, TensorType), ( - "onnx elementwise operation operands and result must be of type TensorType" - ) - - input_tensor_shape = input_tensor_type.get_shape() - - if self.axes is not None: - axes_value = self.axes.value.data - - # axes out of bounds: the axes value must between 0 and len(input_tensor.shape)-1 - if axes_value < 0 or axes_value >= len(input_tensor_shape): - max_axes_value = len(input_tensor_shape) - 1 - raise VerifyException( - f"axes to squeeze must be between 0 and {max_axes_value}, axes: {axes_value}" - ) - - -@irdl_op_definition -class SigmoidOp(IRDLOperation): - """ - Applies the sigmoid function element-wise to all elements of the input tensor. - The sigmoid function, denoted by sigma(x), is a common mathematical function used in machine learning and neural networks. It is defined as: - sigma(x) = 1 / (1 + e^-x) - where e is the base of the natural logarithm. The sigmoid function maps any real-valued number to the range of [0, 1]. - The sigmoid function is used as an activation function. - - Args: - - input_tensor (TensorType): The input tensor to which the sigmoid function will be applied. - - Returns: - - output_tensor (TensorType): The output tensor after applying the sigmoid function element-wise to the input tensor. - """ - - name = "onnx.Sigmoid" - - T = Annotated[AnyFloat, ConstraintVar("T")] - input_tensor = operand_def(TensorType[T]) - output_tensor = result_def(TensorType[T]) - - assembly_format = "`(` $input_tensor`)` attr-dict `:` `(` type($input_tensor) `)` `->` type($output_tensor) " - - def __init__( - self, - input_tensor: SSAValue, - ): - super().__init__( - operands=[input_tensor], - result_types=[input_tensor.type], - ) - - def verify_(self) -> None: - assert isinstance(input_tensor_type := self.input_tensor.type, TensorType) - assert isinstance(output_tensor_type := self.output_tensor.type, TensorType) - - input_tensor_shape = input_tensor_type.get_shape() - output_tensor_shape = output_tensor_type.get_shape() - - # check if input tensor and output tensor have the same shape - if input_tensor_shape != output_tensor_shape: - raise VerifyException( - f"tensor input shape {input_tensor_shape} is not equal to tensor output shape {output_tensor_shape}" - ) - - -ONNX = Dialect( - "onnx", - [ - AbsOp, - AddOp, - ConstantOp, - ConvOp, - DivOp, - EntryPointOp, - GemmOp, - MatMulOp, - MaxPoolSingleOutOp, - MulOp, - ReluOp, - ReshapeOp, - SubOp, - TransposeOp, - SqueezeOp, - SigmoidOp, - ], -) diff --git a/xdsl/frontend/onnx/__init__.py b/xdsl/frontend/onnx/__init__.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/xdsl/frontend/onnx/ir_builder.py b/xdsl/frontend/onnx/ir_builder.py deleted file mode 100644 index 41c867f488..0000000000 --- a/xdsl/frontend/onnx/ir_builder.py +++ /dev/null @@ -1,93 +0,0 @@ -from onnx import GraphProto, NodeProto, ValueInfoProto - -from xdsl.dialects import func, onnx -from xdsl.dialects.builtin import ModuleOp -from xdsl.frontend.onnx.type import get_type -from xdsl.ir import Attribute, SSAValue -from xdsl.irdl import IRDLOperation - -OP_BY_OP_TYPE: dict[str, type[IRDLOperation]] = { - "Add": onnx.AddOp, - "Sub": onnx.SubOp, - "MatMul": onnx.MatMulOp, - "Transpose": onnx.TransposeOp, -} -"""Associate the name of the operations with the respective operation in ONNX dialect.""" - - -class OnnxXdslMapping: - """The representation of the onnx context.""" - - type_by_name: dict[str, Attribute] - value_by_name: dict[str, SSAValue] - - def __init__(self): - self.type_by_name = {} - self.value_by_name = {} - - -def visit_node(node: NodeProto, ctx: OnnxXdslMapping) -> IRDLOperation: - """Update the onnx context with the current node of the onnx graph.""" - if node.op_type not in OP_BY_OP_TYPE: - raise ValueError(f"Unknown ONNX op name {node.op_type}") - - op_class = OP_BY_OP_TYPE[node.op_type] - - operands = tuple(ctx.value_by_name[name] for name in node.input) - result_types = tuple(ctx.type_by_name[name] for name in node.output) - - op = op_class.build(operands=operands, result_types=result_types) - results = op.results - - for output_name, result in zip(node.output, results, strict=True): - ctx.value_by_name[output_name] = result - - return op - - -def visit_value_info(i: ValueInfoProto, ctx: OnnxXdslMapping) -> Attribute: - """Given the onnx ValueInforProto, it returns the corresponding Attribute stored in the context.""" - name = i.name - - if name in ctx.type_by_name: - return ctx.type_by_name[name] - - t = get_type(i.type) - ctx.type_by_name[name] = t - return t - - -def build_module(graph: GraphProto) -> ModuleOp: - """Create the ModuleOp based on the onnx graph provided.""" - - ctx = OnnxXdslMapping() - fn = visit_graph(graph, ctx) - module = ModuleOp([fn]) - - return module - - -def visit_graph(g: GraphProto, ctx: OnnxXdslMapping) -> IRDLOperation: - """ - Visit the onnx graph to update the onnx context. - """ - - name = g.name - - input_types = tuple(visit_value_info(input, ctx) for input in g.input) - output_types = tuple(visit_value_info(output, ctx) for output in g.output) - - fn = func.FuncOp(name, (input_types, output_types)) - - for input, arg in zip(g.input, fn.body.block.args, strict=True): - ctx.value_by_name[input.name] = arg - - for node in g.node: - results = visit_node(node, ctx) - fn.body.block.add_op(results) - - returned_values = tuple(ctx.value_by_name[output.name] for output in g.output) - retfn = func.ReturnOp(*returned_values) - fn.body.block.add_op(retfn) - - return fn diff --git a/xdsl/frontend/onnx/type.py b/xdsl/frontend/onnx/type.py deleted file mode 100644 index b09135bdce..0000000000 --- a/xdsl/frontend/onnx/type.py +++ /dev/null @@ -1,46 +0,0 @@ -from onnx import TensorShapeProto, TypeProto - -from xdsl.dialects.builtin import TensorType, f32, f64 -from xdsl.ir import Attribute - -ELEM_TYPE = { - 1: f32, - 11: f64, -} -""" -Dictionary containing information about the type. -""" - - -def get_elem_type(code: int) -> Attribute: - """ - It takes an ONNX tensor element type code as input and returns the corresponding Attribute. - """ - if code in ELEM_TYPE: - return ELEM_TYPE[code] - else: - raise ValueError(f"Unknown elem_type: {code}") - - -def get_type(type: TypeProto) -> Attribute: - """ - It takes the type in ONNX as input and returns the corresponding Attribute. - """ - tt = get_tensor_type(type.tensor_type) - return tt - - -def get_shape(shape: TensorShapeProto) -> tuple[int, ...]: - """ - It returns the shape of a tensor in ONNX. - """ - return tuple(dim.dim_value for dim in shape.dim) - - -def get_tensor_type(tensor: TypeProto.Tensor) -> TensorType[Attribute]: - """ - Function that returns the type of the tensor in ONNX. - """ - elem_type = get_elem_type(tensor.elem_type) - shape = get_shape(tensor.shape) - return TensorType(elem_type, shape) diff --git a/xdsl/interpreters/__init__.py b/xdsl/interpreters/__init__.py index 6cb1e68d80..1c597b8ad6 100644 --- a/xdsl/interpreters/__init__.py +++ b/xdsl/interpreters/__init__.py @@ -25,12 +25,7 @@ ) -def register_implementations( - interpreter: Interpreter, - ctx: MLContext, - *, - include_onnx: bool = True, -): +def register_implementations(interpreter: Interpreter, ctx: MLContext): interpreter.register_implementations(affine.AffineFunctions()) interpreter.register_implementations(arith.ArithFunctions()) interpreter.register_implementations(builtin.BuiltinFunctions()) @@ -52,7 +47,3 @@ def register_implementations( interpreter.register_implementations(scf.ScfFunctions()) interpreter.register_implementations(snitch_stream.SnitchStreamFunctions()) interpreter.register_implementations(tensor.TensorFunctions()) - if include_onnx: - from xdsl.interpreters import onnx - - interpreter.register_implementations(onnx.OnnxFunctions()) diff --git a/xdsl/interpreters/onnx.py b/xdsl/interpreters/onnx.py deleted file mode 100644 index 2d26cc0b38..0000000000 --- a/xdsl/interpreters/onnx.py +++ /dev/null @@ -1,440 +0,0 @@ -from typing import Any, cast, overload - -import numpy as np -import numpy.typing as npt - -from xdsl.dialects import onnx -from xdsl.dialects.builtin import ( - Float32Type, - Float64Type, - IntAttr, - IntegerType, - PackableType, - TensorType, -) -from xdsl.interpreter import ( - Interpreter, - InterpreterFunctions, - ReturnedValues, - impl, - register_impls, -) -from xdsl.interpreters.builtin import xtype_for_el_type -from xdsl.interpreters.shaped_array import ShapedArray -from xdsl.interpreters.utils import ptr -from xdsl.utils.exceptions import InterpretationError - - -def to_dtype( - xtype: PackableType[int] | PackableType[float], -) -> type[np.int32] | type[np.int64] | type[np.float32] | type[np.float64]: - match xtype: - case IntegerType(width=IntAttr(data=32)): - return np.int32 - case IntegerType(width=IntAttr(data=64)): - return np.int64 - case Float32Type(): - return np.float32 - case Float64Type(): - return np.float64 - case _: - raise NotImplementedError() - - -def from_dtype( - dtype: np.dtype[np.float32 | np.float64 | np.int32 | np.int64], -) -> PackableType[float] | PackableType[int]: - if dtype == np.float32: - return ptr.float32 - elif dtype == np.float64: - return ptr.float64 - elif dtype == np.float32: - return ptr.int32 - elif dtype == np.float64: - return ptr.int64 - else: - raise NotImplementedError() - - -@overload -def to_ndarray( - shaped_array: ShapedArray[float], -) -> npt.NDArray[np.float32 | np.float64]: ... - - -@overload -def to_ndarray( - shaped_array: ShapedArray[int], -) -> npt.NDArray[np.int32 | np.int64]: ... - - -def to_ndarray( - shaped_array: ShapedArray[int] | ShapedArray[float], -) -> npt.NDArray[np.float32 | np.float64 | np.int32 | np.int64]: - dtype = to_dtype(shaped_array.data_ptr.xtype) - flat = np.frombuffer(shaped_array.data_ptr.raw.memory, dtype) - shaped = flat.reshape(shaped_array.shape) - return shaped - - -def from_ndarray( - ndarray: npt.NDArray[np.float32 | np.float64 | np.int32 | np.int64], -) -> ShapedArray[float] | ShapedArray[int]: - xtype = from_dtype(np.dtype(ndarray.dtype)) - # TypedPtr's generic parameter is invariant, so ambiguous here - typed_ptr: ptr.TypedPtr[float] | ptr.TypedPtr[int] = ptr.TypedPtr( - ptr.RawPtr(bytearray(ndarray.data)), - xtype=xtype, # pyright: ignore[reportArgumentType] - ) - return ShapedArray( - typed_ptr, - list(ndarray.shape), - ) - - -@register_impls -class OnnxFunctions(InterpreterFunctions): - @impl(onnx.AddOp) - def run_add( - self, interpreter: Interpreter, op: onnx.AddOp, args: tuple[Any, ...] - ) -> tuple[Any, ...]: - lhs, rhs = args[0], args[1] - assert isinstance(lhs, ShapedArray) - assert isinstance(rhs, ShapedArray) - lhs = cast(ShapedArray[float], lhs) - rhs = cast(ShapedArray[float], rhs) - result = to_ndarray(lhs) + to_ndarray(rhs) - return (from_ndarray(result),) - - @impl(onnx.SubOp) - def run_sub( - self, interpreter: Interpreter, op: onnx.SubOp, args: tuple[Any, ...] - ) -> tuple[Any, ...]: - lhs, rhs = args[0], args[1] - assert isinstance(lhs, ShapedArray) - assert isinstance(rhs, ShapedArray) - lhs = cast(ShapedArray[float], lhs) - rhs = cast(ShapedArray[float], rhs) - result = to_ndarray(lhs) - to_ndarray(rhs) - return (from_ndarray(result),) - - @impl(onnx.MulOp) - def run_mul( - self, interpreter: Interpreter, op: onnx.MulOp, args: tuple[Any, ...] - ) -> tuple[Any, ...]: - lhs, rhs = args[0], args[1] - assert isinstance(lhs, ShapedArray) - assert isinstance(rhs, ShapedArray) - lhs = cast(ShapedArray[float], lhs) - rhs = cast(ShapedArray[float], rhs) - result = to_ndarray(lhs) * to_ndarray(rhs) - return (from_ndarray(result),) - - @impl(onnx.DivOp) - def run_div( - self, interpreter: Interpreter, op: onnx.DivOp, args: tuple[Any, ...] - ) -> tuple[Any, ...]: - lhs, rhs = args[0], args[1] - assert isinstance(lhs, ShapedArray) - assert isinstance(rhs, ShapedArray) - lhs = cast(ShapedArray[float], lhs) - rhs = cast(ShapedArray[float], rhs) - result = to_ndarray(lhs) / to_ndarray(rhs) - return (from_ndarray(result),) - - @impl(onnx.ReluOp) - def run_relu( - self, interpreter: Interpreter, op: onnx.ReluOp, args: tuple[Any, ...] - ) -> tuple[Any, ...]: - operand = args[0] - assert isinstance(operand, ShapedArray) - operand = cast(ShapedArray[float], operand) - operand_data = to_ndarray(operand) - result = operand_data * (operand_data > 0) - return (from_ndarray(result),) - - @impl(onnx.ConstantOp) - def run_constant( - self, interpreter: Interpreter, op: onnx.ConstantOp, args: tuple[Any, ...] - ): - if op.value is None: - raise NotImplementedError("Only dense constant values implemented") - shape = op.value.get_shape() - data = op.value.get_values() - data_ptr = ptr.TypedPtr[Any].new( - data, - xtype=xtype_for_el_type( - op.value.get_element_type(), interpreter.index_bitwidth - ), - ) - return (ShapedArray(data_ptr, list(shape)),) - - @impl(onnx.ReshapeOp) - def run_reshape( - self, interpreter: Interpreter, op: onnx.ReshapeOp, args: tuple[Any, ...] - ): - if op.allow_zero is not None and op.allow_zero.value.data == 1: - raise NotImplementedError( - "allow_zero not yet supported in onnx.reshape interpreter" - ) - input, new_shape = args - assert isinstance(input, ShapedArray) - assert isinstance(new_shape, ShapedArray) - input = cast(ShapedArray[float], input) - new_shape = cast(ShapedArray[int], new_shape) - result_type = op.reshaped.type - assert isinstance(result_type, TensorType) - static_shape = list(result_type.get_shape()) - assert static_shape is not None - if static_shape != new_shape.data: - raise InterpretationError("Mismatch between static shape and new shape") - return (input.with_shape(new_shape.data),) - - @impl(onnx.GemmOp) - def run_gemm( - self, interpreter: Interpreter, op: onnx.GemmOp, args: tuple[Any, ...] - ): - a, b, c = args[0], args[1], args[2] - - alpha = op.alpha.value.data if op.alpha is not None else 1.0 - beta = op.beta.value.data if op.beta is not None else 1.0 - - assert isinstance(a, ShapedArray) - assert isinstance(b, ShapedArray) - assert isinstance(c, ShapedArray) - - a = cast(ShapedArray[float], a) - b = cast(ShapedArray[float], b) - c = cast(ShapedArray[float], c) - - nd_a = to_ndarray(a) - nd_b = to_ndarray(b) - nd_c = to_ndarray(c) - - if op.trans_a is not None and op.trans_a.value.data == 1: - nd_a = np.transpose(nd_a) - - if op.trans_b is not None and op.trans_b.value.data == 1: - nd_b = np.transpose(nd_b) - - result = alpha * nd_a @ nd_b + beta * nd_c - - return (from_ndarray(result),) - - @impl(onnx.ConvOp) - def run_conv( - self, interpreter: Interpreter, op: onnx.ConvOp, args: tuple[Any, ...] - ): - # initialise the attributes used - auto_pad = op.auto_pad.data - strides: list[int] = [value.value.data for value in op.strides] - matrix, kernel, bias = args[0], args[1], args[2] - pads: list[int] = [value.value.data for value in op.pads] - - matrix = cast(ShapedArray[float], matrix) - kernel = cast(ShapedArray[float], kernel) - bias = cast(ShapedArray[float], bias) - - matrix = to_ndarray(matrix) - kernel = to_ndarray(kernel) - - if auto_pad != "NOTSET": - if auto_pad == "SAME_UPPER" or auto_pad == "SAME_LOWER": - out_height = int(np.ceil(matrix.shape[2] / strides[0])) - out_width = int(np.ceil(matrix.shape[3] / strides[1])) - - pad_along_height = max( - (out_height - 1) * strides[0] + kernel.shape[2] - matrix.shape[2], 0 - ) - pad_along_width = max( - (out_width - 1) * strides[1] + kernel.shape[3] - matrix.shape[3], 0 - ) - - if auto_pad == "SAME_UPPER": - pad_top = pad_along_height // 2 - pad_bottom = pad_along_height - pad_top - pad_left = pad_along_width // 2 - pad_right = pad_along_width - pad_left - else: - pad_bottom = pad_along_height // 2 - pad_top = pad_along_height - pad_bottom - pad_right = pad_along_width // 2 - pad_left = pad_along_width - pad_right - - pads = [pad_top, pad_bottom, pad_left, pad_right] - - elif auto_pad == "VALID": - pads = [0, 0, 0, 0] # set padding to all zeros - - if pads: - # case of asymmetric padding - pad_values = [ - (pads[i], pads[i + len(pads) // 2]) for i in range(len(pads) // 2) - ] - - # pad input matrix - padded_matrix = np.pad( - matrix, - ( - (0, 0), - (0, 0), - (pad_values[0][0], pad_values[0][1]), - (pad_values[1][0], pad_values[1][1]), - ), - mode="constant", - ) - - # padded shape case - m_height, m_width = padded_matrix.shape[2:] - - else: - m_height, m_width = matrix.shape[2:] - - padded_matrix = matrix - - # based on strides calculate the output shape - out_height = int((m_height - kernel.shape[2]) // strides[0] + 1) - out_width = int((m_width - kernel.shape[3]) // strides[1] + 1) - - output = np.zeros( - (matrix.shape[0], matrix.shape[1], out_height, out_width), - dtype=matrix.dtype, - ) - - # do convolution - for k in range(matrix.shape[0]): - for l in range(matrix.shape[1]): - for i in range(0, m_height - kernel.shape[2] + 1, strides[0]): - for j in range(0, m_width - kernel.shape[3] + 1, strides[1]): - output[k, l, i // strides[0], j // strides[1]] = np.sum( - padded_matrix[ - k, l, i : i + kernel.shape[2], j : j + kernel.shape[3] - ] - * kernel[k, l] - ) - - output += to_ndarray(bias) - - # the number of channels is not always fixed to one - result_type = op.res.type - assert isinstance(result_type, TensorType) - static_shape = list(result_type.get_shape()) - - result = np.array(output) - assert tuple(result.shape) == ( - 1, - static_shape[1], - output.shape[2], - output.shape[3], - ) - return (from_ndarray(result),) - - @impl(onnx.MaxPoolSingleOutOp) - def run_max_pool_single_out( - self, - interpreter: Interpreter, - op: onnx.MaxPoolSingleOutOp, - args: tuple[Any, ...], - ): - kernel_shape = tuple(value.value.data for value in op.kernel_shape) - - if len(kernel_shape) != 2: - raise NotImplementedError("Only 2d max pooling supported") - ky, kx = kernel_shape - - strides = tuple(value.value.data for value in op.strides) - - if len(strides) != 2: - raise NotImplementedError("Only 2d max pooling supported") - - # initialise the attributes used - auto_pad = op.auto_pad.data - (matrix,) = args - pads: list[int] = [value.value.data for value in op.pads] - - matrix = cast(ShapedArray[float], matrix) - - matrix = to_ndarray(matrix) - - if auto_pad != "NOTSET": - if auto_pad == "SAME_UPPER" or auto_pad == "SAME_LOWER": - out_height = int(np.ceil(matrix.shape[2] / strides[0])) - out_width = int(np.ceil(matrix.shape[3] / strides[1])) - - pad_along_height = max( - (out_height - 1) * strides[0] + ky - matrix.shape[2], 0 - ) - pad_along_width = max( - (out_width - 1) * strides[1] + kx - matrix.shape[3], 0 - ) - - if auto_pad == "SAME_UPPER": - pad_top = pad_along_height // 2 - pad_bottom = pad_along_height - pad_top - pad_left = pad_along_width // 2 - pad_right = pad_along_width - pad_left - else: - pad_bottom = pad_along_height // 2 - pad_top = pad_along_height - pad_bottom - pad_right = pad_along_width // 2 - pad_left = pad_along_width - pad_right - - pads = [pad_top, pad_bottom, pad_left, pad_right] - - elif auto_pad == "VALID": - pads = [0, 0, 0, 0] # set padding to all zeros - - if pads: - # case of asymmetric padding - pad_values = [ - (pads[i], pads[i + len(pads) // 2]) for i in range(len(pads) // 2) - ] - - # pad input matrix - padded_matrix = np.pad( - matrix, - ( - (0, 0), - (0, 0), - (pad_values[0][0], pad_values[0][1]), - (pad_values[1][0], pad_values[1][1]), - ), - mode="constant", - ) - - # padded shape case - m_height, m_width = padded_matrix.shape[2:] - - else: - m_height, m_width = matrix.shape[2:] - - padded_matrix = matrix - - # based on strides calculate the output shape - out_height = int((m_height - ky) // strides[0] + 1) - out_width = int((m_width - kx) // strides[1] + 1) - - result = np.zeros( - (matrix.shape[0], matrix.shape[1], out_height, out_width), - dtype=matrix.dtype, - ) - - # do maxpool computation - for k in range(matrix.shape[0]): - for l in range(matrix.shape[1]): - for i in range(0, m_height - ky + 1, strides[0]): - for j in range(0, m_width - kx + 1, strides[1]): - result[k, l, i // strides[0], j // strides[1]] = np.nanmax( - padded_matrix[k, l, i : i + ky, j : j + kx] - ) - - # Numpy has two types of ndarray: ndarray and NDArray, weirdly they don't seem - # to be compatible, despite one being a typealias for the other... - output: Any = result - return (from_ndarray(output),) - - @impl(onnx.EntryPointOp) - def run_entry_point( - self, interpreter: Interpreter, op: onnx.EntryPointOp, args: tuple[Any, ...] - ): - return ReturnedValues(args), () diff --git a/xdsl/tools/xdsl_run.py b/xdsl/tools/xdsl_run.py index 96c66b302d..71dc56f302 100644 --- a/xdsl/tools/xdsl_run.py +++ b/xdsl/tools/xdsl_run.py @@ -35,12 +35,6 @@ def __init__( self.ctx.allow_unregistered = self.args.allow_unregistered_dialect def register_all_arguments(self, arg_parser: argparse.ArgumentParser): - arg_parser.add_argument( - "--onnx", - default=False, - action="store_true", - help="Enable the onnx-compilation interpreter.", - ) arg_parser.add_argument( "--verbose", default=False, @@ -71,11 +65,7 @@ def register_all_arguments(self, arg_parser: argparse.ArgumentParser): return super().register_all_arguments(arg_parser) def register_implementations(self, interpreter: Interpreter): - register_implementations( - interpreter, - self.ctx, - include_onnx=self.args.onnx, - ) + register_implementations(interpreter, self.ctx) def run(self): input, file_extension = self.get_input_stream() diff --git a/xdsl/transforms/__init__.py b/xdsl/transforms/__init__.py index 81a12e06e1..aabe89eda9 100644 --- a/xdsl/transforms/__init__.py +++ b/xdsl/transforms/__init__.py @@ -98,11 +98,6 @@ def get_convert_ml_program_to_memref(): return convert_ml_program_to_memref.ConvertMlProgramToMemrefPass - def get_convert_onnx_to_linalg(): - from xdsl.transforms import convert_onnx_to_linalg - - return convert_onnx_to_linalg.ConvertOnnxToLinalgPass - def get_convert_print_format_to_riscv_debug(): from xdsl.backend.riscv.lowering import convert_print_format_to_riscv_debug @@ -514,7 +509,6 @@ def get_varith_fuse_repeated_operands(): "convert-memref-to-ptr": get_convert_memref_to_ptr, "convert-memref-to-riscv": get_convert_memref_to_riscv, "convert-ml-program-to-memref": get_convert_ml_program_to_memref, - "convert-onnx-to-linalg": get_convert_onnx_to_linalg, "convert-print-format-to-riscv-debug": get_convert_print_format_to_riscv_debug, "convert-ptr-to-riscv": get_convert_ptr_to_riscv, "convert-qref-to-qssa": get_convert_qref_to_qssa, diff --git a/xdsl/transforms/constant_fold_interp.py b/xdsl/transforms/constant_fold_interp.py index acb4c621f4..d550d5fc09 100644 --- a/xdsl/transforms/constant_fold_interp.py +++ b/xdsl/transforms/constant_fold_interp.py @@ -85,7 +85,6 @@ class ConstantFoldInterpPass(ModulePass): def apply(self, ctx: MLContext, op: builtin.ModuleOp) -> None: interpreter = Interpreter(op) - # Do not call onnx interpreter function for this pass - register_implementations(interpreter, ctx, include_onnx=False) + register_implementations(interpreter, ctx) pattern = ConstantFoldInterpPattern(interpreter) PatternRewriteWalker(pattern).rewrite_module(op) diff --git a/xdsl/transforms/convert_onnx_to_linalg.py b/xdsl/transforms/convert_onnx_to_linalg.py deleted file mode 100644 index 59845e5f54..0000000000 --- a/xdsl/transforms/convert_onnx_to_linalg.py +++ /dev/null @@ -1,415 +0,0 @@ -from dataclasses import dataclass -from typing import cast - -from xdsl.builder import ImplicitBuilder -from xdsl.context import MLContext -from xdsl.dialects import arith, linalg, ml_program, onnx, tensor -from xdsl.dialects.builtin import ( - AffineMapAttr, - AnyFloat, - DenseArrayBase, - DenseIntOrFPElementsAttr, - FloatAttr, - ModuleOp, - NoneType, - StringAttr, - SymbolRefAttr, - TensorType, - f32, - f64, - i64, -) -from xdsl.ir import Attribute, Block, Operation, Region -from xdsl.ir.affine import AffineMap -from xdsl.passes import ModulePass -from xdsl.pattern_rewriter import ( - GreedyRewritePatternApplier, - PatternRewriter, - PatternRewriteWalker, - RewritePattern, - op_type_rewrite_pattern, -) -from xdsl.traits import SymbolTable - - -def get_root_op(op: Operation | None) -> Operation | None: - """ - Recursively finds and returns the root operation associated with the given operation. - """ - return op if op is None or op.parent_op() is None else get_root_op(op.parent_op()) - - -@dataclass -class AddOpLowering(RewritePattern): - @op_type_rewrite_pattern - def match_and_rewrite(self, add: onnx.AddOp, rewriter: PatternRewriter, /): - lhs_type = add.lhs.type - rhs_type = add.rhs.type - if isinstance(lhs_type, TensorType) and isinstance(rhs_type, TensorType): - lhs_shape = lhs_type.get_shape() - rhs_shape = rhs_type.get_shape() - - if -1 in lhs_shape or -1 in rhs_shape: - raise NotImplementedError() - - rewriter.replace_matched_op( - ( - empty := tensor.EmptyOp((), add.res.type), - linalg.AddOp((add.lhs, add.rhs), (empty.tensor,), res=(add.res.type,)), - ) - ) - - -@dataclass -class SubOpLowering(RewritePattern): - @op_type_rewrite_pattern - def match_and_rewrite(self, sub: onnx.SubOp, rewriter: PatternRewriter, /): - lhs_type = sub.lhs.type - rhs_type = sub.rhs.type - if isinstance(lhs_type, TensorType) and isinstance(rhs_type, TensorType): - lhs_shape = lhs_type.get_shape() - rhs_shape = rhs_type.get_shape() - - if -1 in lhs_shape or -1 in rhs_shape: - raise NotImplementedError() - - rewriter.replace_matched_op( - ( - empty := tensor.EmptyOp((), sub.res.type), - linalg.SubOp((sub.lhs, sub.rhs), (empty.tensor,), res=(sub.res.type,)), - ) - ) - - -@dataclass -class ReluOpLowering(RewritePattern): - @op_type_rewrite_pattern - def match_and_rewrite(self, relu: onnx.ReluOp, rewriter: PatternRewriter, /): - operand = relu.operand.type - assert isinstance(operand, TensorType) - operand = cast(TensorType[Attribute], operand) - operand_rank = len(operand.get_shape()) - body = Region(Block(arg_types=(operand.element_type, operand.element_type))) - affine_map = AffineMapAttr(AffineMap.identity(operand_rank)) - rewriter.replace_matched_op( - ( - empty := tensor.EmptyOp((), relu.res.type), - zero := arith.ConstantOp( - FloatAttr(0.0, cast(AnyFloat, operand.element_type)) - ), - linalg.GenericOp( - (relu.operand,), - (empty.tensor,), - body, - (affine_map, affine_map), - (linalg.IteratorTypeAttr.parallel(),) * operand_rank, - (relu.res.type,), - ), - ) - ) - with ImplicitBuilder(body) as (a, _): - max_op = arith.MaximumfOp(a, zero.result) - linalg.YieldOp(max_op.result) - - -@dataclass -class ConstantOpLowering(RewritePattern): - constant_count: int = 0 - - def make_unique_name(self): - self.constant_count += 1 - return f"onnx_constant_{self.constant_count}" - - @op_type_rewrite_pattern - def match_and_rewrite( - self, constant: onnx.ConstantOp, rewriter: PatternRewriter, / - ): - attr_value = list(constant.attributes.values())[1] - constant_name = self.make_unique_name() - global_op = ml_program.GlobalOp( - StringAttr(constant_name), - constant.output.type, - None, - attr_value, - StringAttr("private"), - ) - root_op = get_root_op(constant) - if root_op is not None and root_op.has_trait(SymbolTable): - SymbolTable.insert_or_update(root_op, global_op) - rewriter.replace_matched_op( - ( - ml_program.GlobalLoadConstantOp( - SymbolRefAttr(global_op.sym_name), - global_op.type, - ), - ) - ) - - -@dataclass -class ReshapeOpLowering(RewritePattern): - @op_type_rewrite_pattern - def match_and_rewrite(self, reshape: onnx.ReshapeOp, rewriter: PatternRewriter, /): - # Dynamic shapes not currently supported - source_type = reshape.data.type - shape_type = reshape.shape.type - if isinstance(source_type, TensorType) and isinstance(shape_type, TensorType): - source_shape = source_type.get_shape() - shape_shape = shape_type.get_shape() - - if -1 in source_shape or -1 in shape_shape: - raise NotImplementedError() - - # Lowering with `allowzero = 1` attribute not supported" - if reshape.allow_zero is not None and reshape.allow_zero.value.data == 1: - raise NotImplementedError() - - rewriter.replace_matched_op( - ( - tensor.ReshapeOp( - reshape.data, - reshape.shape, - reshape.reshaped.type, - ), - ) - ) - - -@dataclass -class GemmOpLowering(RewritePattern): - @op_type_rewrite_pattern - def match_and_rewrite(self, gemm: onnx.GemmOp, rewriter: PatternRewriter, /): - assert isinstance(tensor_a_type := gemm.tensor_a.type, TensorType) - assert isinstance(tensor_b_type := gemm.tensor_b.type, TensorType) - assert isinstance(tensor_c_type := gemm.tensor_c.type, TensorType) - - tensor_a_type = cast(TensorType[Attribute], tensor_a_type) - tensor_b_type = cast(TensorType[Attribute], tensor_b_type) - tensor_c_type = cast(TensorType[Attribute], tensor_c_type) - - tensor_a_shape = tensor_a_type.get_shape() - tensor_b_shape = tensor_b_type.get_shape() - tensor_c_shape = tensor_c_type.get_shape() - - # Dynamic shapes not currently supported - if any( - -1 in shape for shape in [tensor_a_shape, tensor_b_shape, tensor_c_shape] - ): - raise NotImplementedError() - - perm: list[int] = [1, 0] - permutation = DenseArrayBase.create_dense_int(i64, perm) - - # if transA is set, trans_a is changed to this op - trans_a_res = None - if gemm.trans_a is not None and gemm.trans_a.value.data == 1: - shape_type = tensor_a_type.element_type - # onnx.gemm supports only 2D tensors, hence reversing is acceptable - shape = tuple(reversed(tensor_a_shape)) - empty_shape = TensorType(shape_type, shape) - empty = tensor.EmptyOp((), empty_shape) - trans_a = linalg.TransposeOp( - gemm.tensor_a, empty.tensor, permutation, empty.tensor.type - ) - # save the result - trans_a_res = trans_a.result[0] - rewriter.insert_op_before_matched_op([empty, trans_a]) - - # if transB is set, trans_b is changed to this op - trans_b_res = None - if gemm.trans_b is not None and gemm.trans_b.value.data == 1: - shape_type = tensor_b_type.element_type - # onnx.gemm supports only 2D tensors, hence reversing is acceptable - shape = tuple(reversed(tensor_b_shape)) - empty_shape = TensorType(shape_type, shape) - empty = tensor.EmptyOp((), empty_shape) - trans_b = linalg.TransposeOp( - gemm.tensor_b, empty.tensor, permutation, empty.tensor.type - ) - # save the result - trans_b_res = trans_b.result[0] - rewriter.insert_op_before_matched_op([empty, trans_b]) - - # if trans_a occurs, else remain - if trans_a_res is not None: - trans_a = trans_a_res - else: - trans_a = gemm.tensor_a - - # if trans_b occurs, else remain - if trans_b_res is not None: - trans_b = trans_b_res - else: - trans_b = gemm.tensor_b - - # alpha * A - alpha_res = None - if gemm.alpha is not None and gemm.alpha.value.data != 1: - constant = arith.ConstantOp( - FloatAttr(gemm.alpha.value.data, gemm.alpha.type) - ) - alpha_mul_result = linalg.MulOp( - (constant.result, trans_a), - (trans_a,), - (trans_a.type,), - ) - alpha_res = alpha_mul_result.res[0] - rewriter.insert_op_before_matched_op([constant, alpha_mul_result]) - - # if alpha * a does not occur remain on previous trans_a else switch - if alpha_res is not None: - trans_a = alpha_res - - # beta * C - beta_mul_result = gemm.tensor_c - beta_res = None - if gemm.beta is not None and gemm.beta.value.data != 1: - constant = arith.ConstantOp(FloatAttr(gemm.beta.value.data, gemm.beta.type)) - beta_mul_result = linalg.MulOp( - ( - constant.result, - beta_mul_result, - ), - (gemm.tensor_c,), - (gemm.tensor_c.type,), - ) - beta_res = beta_mul_result.res[0] - rewriter.insert_op_before_matched_op([constant, beta_mul_result]) - - # this is beta * c result else its just c - if beta_res is not None: - beta_mul_result = beta_res - else: - beta_mul_result = gemm.tensor_c - - # (A * B) + beta * C - rewriter.replace_matched_op( - ( - empty := tensor.EmptyOp( - (), - gemm.res_tensor.type, - ), - # A * B - mat_mul_res := linalg.MatmulOp( - (trans_a, trans_b), - (empty.tensor,), - res=(gemm.res_tensor.type,), - ), - # (A * B) + beta * C - linalg.AddOp( - (mat_mul_res.results[0], beta_mul_result), - (mat_mul_res.results[0],), - res=(mat_mul_res.results[0].type,), - ), - ) - ) - - -@dataclass -class MaxPoolSingleOutOpLowering(RewritePattern): - @op_type_rewrite_pattern - def match_and_rewrite( - self, max_pool_single_out: onnx.MaxPoolSingleOutOp, rewriter: PatternRewriter, / - ): - kernel: list[int] = [ - value.value.data for value in max_pool_single_out.kernel_shape.data - ] - dilations: list[int] = [ - value.value.data for value in max_pool_single_out.dilations.data - ] - strides: list[int] = [ - value.value.data for value in max_pool_single_out.strides.data - ] - kernel_shape = TensorType(f32, kernel) - - # Lowering with `storage_order = 1` attribute not supported" - if ( - max_pool_single_out.storage_order.value.data != 0 - and max_pool_single_out.storage_order - ): - raise NotImplementedError() - - rewriter.replace_matched_op( - ( - empty := tensor.EmptyOp((), kernel_shape), - init := tensor.EmptyOp((), max_pool_single_out.output.type), - # Since we're unable to represent +/- infinity, - # we currently use the maximum value by sys - cst := arith.ConstantOp(FloatAttr(-1e308, f64)), - fill := linalg.FillOp( - (cst.result,), - (init.tensor,), - (max_pool_single_out.output.type,), - ), - linalg.PoolingNchwMaxOp( - DenseIntOrFPElementsAttr.tensor_from_list(dilations, i64, [2]), - DenseIntOrFPElementsAttr.tensor_from_list(strides, i64, [2]), - ( - max_pool_single_out.data, - empty.tensor, - ), - (fill.results[0],), - (max_pool_single_out.output.type,), - ), - ) - ) - - -@dataclass -class ConvOpLowering(RewritePattern): - @op_type_rewrite_pattern - def match_and_rewrite(self, conv: onnx.ConvOp, rewriter: PatternRewriter, /): - dilations = tuple(value.value.data for value in conv.dilations.data) - strides = tuple(value.value.data for value in conv.strides.data) - - if conv.group.value.data != 1: - raise NotImplementedError("Only 1 group supported") - - if not all(dilation == 1 for dilation in dilations): - raise NotImplementedError("Only 1 dilation supported") - - empty = tensor.EmptyOp((), conv.res.type) - conv_op = linalg.Conv2DNchwFchwOp( - DenseIntOrFPElementsAttr.tensor_from_list(dilations, i64, [2]), - DenseIntOrFPElementsAttr.tensor_from_list(strides, i64, [2]), - ( - conv.data, - conv.weight, - ), - (empty.tensor,), - (conv.res.type,), - ) - conv_ops = ( - empty, - conv_op, - ) - if not isinstance(conv.bias.type, NoneType): - add_bias = linalg.AddOp( - (conv.bias,), - (conv_op.results[0],), - res=(conv.res.type,), - ) - conv_ops += (add_bias,) - rewriter.replace_matched_op(conv_ops) - - -@dataclass(frozen=True) -class ConvertOnnxToLinalgPass(ModulePass): - name = "convert-onnx-to-linalg" - - def apply(self, ctx: MLContext, op: ModuleOp) -> None: - PatternRewriteWalker( - GreedyRewritePatternApplier( - [ - AddOpLowering(), - SubOpLowering(), - ReluOpLowering(), - ConstantOpLowering(), - ReshapeOpLowering(), - GemmOpLowering(), - MaxPoolSingleOutOpLowering(), - ConvOpLowering(), - ] - ), - apply_recursively=False, - ).rewrite_module(op)