Add ONNX-MLIR dialect support #40

Open · wants to merge 18 commits into base: main (changes shown from 1 commit)
3 changes: 2 additions & 1 deletion nebullvm/base.py
@@ -1,6 +1,6 @@
 from dataclasses import dataclass
 from enum import Enum
-from typing import Tuple, List, Dict, Union
+from typing import Dict, List, Tuple, Union


class DataType(str, Enum):
@@ -106,6 +106,7 @@ class ModelCompiler(Enum):
     OPENVINO = "openvino"
     APACHE_TVM = "tvm"
     ONNX_RUNTIME = "onnxruntime"
+    ONNX_MLIR_RUNTIME = "onnx_mlir"


class QuantizationType(Enum):
4 changes: 4 additions & 0 deletions nebullvm/config.py
@@ -37,3 +37,7 @@
     "description_file": "description.xml",
     "weights": "weights.bin",
 }
+
+ONNX_MLIR_FILENAMES = {
+    "model_name": "mlir_model.so",
+}
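
For context, this new mapping is the filename contract shared by the save and load methods introduced below. A minimal sketch of how it resolves the on-disk artifact (the target directory name is a hypothetical example):

import os

from nebullvm.config import ONNX_MLIR_FILENAMES

# Resolve the compiled-model path inside a hypothetical save directory.
model_path = os.path.join("saved_model_dir", ONNX_MLIR_FILENAMES["model_name"])
print(model_path)  # saved_model_dir/mlir_model.so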
211 changes: 211 additions & 0 deletions nebullvm/inference_learners/onnx_mlir.py
@@ -0,0 +1,211 @@
import copy
import os
import shutil
import sys
import warnings
from abc import ABC
from pathlib import Path
from typing import Dict, Generator, List, Tuple, Type, Union

import cpuinfo
import numpy as np
import tensorflow as tf
import torch
from nebullvm.base import DeepLearningFramework, ModelParams
from nebullvm.config import ONNX_MLIR_FILENAMES
from nebullvm.inference_learners.base import (
    BaseInferenceLearner,
    LearnerMetadata,
    PytorchBaseInferenceLearner,
    TensorflowBaseInferenceLearner,
)

try:
    # Point ONNX_MLIR_HOME at the directory where ONNX-MLIR is built and
    # add its lib directory to sys.path so that PyRuntime can be imported.
    MLIR_INSTALLATION_ROOT = Path.home()

    os.environ["ONNX_MLIR_HOME"] = os.path.join(
        MLIR_INSTALLATION_ROOT,
        "onnx-mlir",
        "build",
        "Debug",
    )

    sys.path.append(
        os.path.join(
            os.environ.get("ONNX_MLIR_HOME", ""),
            "lib",
        )
    )
    import PyRuntime
except ImportError:
    warnings.warn(
        "No valid ONNX-MLIR installation found. Trying to install it..."
    )
    from nebullvm.installers.installers import install_onnx_mlir

    install_onnx_mlir(
        working_dir=MLIR_INSTALLATION_ROOT,
    )
    import PyRuntime


class ONNXMlirInferenceLearner(BaseInferenceLearner, ABC):
    """Model converted from ONNX to a shared object file using the ONNX-MLIR
    dialect and run with ONNX-MLIR's PyRuntime, built at
    onnx-mlir/build/Debug/lib/PyRuntime.cpython-<target>.so.

    Attributes:
        onnx_mlir_model_path (str or Path): Path to the shared object
            mlir model.
        network_parameters (ModelParams): The model parameters, such as
            batch size and input and output sizes.
    """

    def __init__(
        self,
        onnx_mlir_model_path: Union[str, Path],
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.onnx_mlir_model_path = onnx_mlir_model_path
        self._session = PyRuntime.ExecutionSession(
            os.path.abspath(str(self.onnx_mlir_model_path)),
        )

    def save(self, path: Union[str, Path], **kwargs):
        """Save the model.

        Args:
            path (Path or str): Path to the directory where the model will
                be stored.
            kwargs (Dict): Dictionary of key-value pairs that will be saved
                in the model metadata file.
        """
        metadata = LearnerMetadata.from_model(self, **kwargs)
        metadata.save(path)

        shutil.copy(
            self.onnx_mlir_model_path,
            os.path.join(str(path), ONNX_MLIR_FILENAMES["model_name"]),
        )

    @classmethod
    def load(cls, path: Union[Path, str], **kwargs):
        """Load the model.

        Args:
            path (Path or str): Path to the directory where the model is
                stored.
            kwargs (Dict): Dictionary of additional arguments for
                consistency with other Learners.

        Returns:
            ONNXMlirInferenceLearner: The optimized model.
        """
        if len(kwargs) > 0:
            warnings.warn(
                f"No extra keywords expected for the load method. "
                f"Got {kwargs}."
            )
        onnx_mlir_model_path = os.path.join(
            str(path), ONNX_MLIR_FILENAMES["model_name"]
        )
        metadata = LearnerMetadata.read(path)

        return cls(
            network_parameters=ModelParams(**metadata.network_parameters),
            onnx_mlir_model_path=onnx_mlir_model_path,
        )

    def _predict_arrays(self, input_arrays: Generator[np.ndarray, None, None]):
        outputs = self._session.run(list(input_arrays))
        return outputs


class PytorchONNXMlirInferenceLearner(
    ONNXMlirInferenceLearner, PytorchBaseInferenceLearner
):
    """Model run with ONNX-MLIR's PyRuntime using a PyTorch interface.

    Attributes:
        onnx_mlir_model_path (str or Path): Path to the shared object
            mlir model.
        network_parameters (ModelParams): The model parameters, such as
            batch size and input and output sizes.
    """

    def predict(self, *input_tensors: torch.Tensor) -> Tuple[torch.Tensor]:
        """Predict on the input tensors.

        Note that the input tensors must belong to the same batch. If a
        sequence of tensors is given when the model expects a single input
        tensor (with batch size >= 1), an error is raised.

        Args:
            input_tensors (Tuple[Tensor]): Input tensors belonging to the
                same batch. The tensors are expected to have dimensions
                (batch_size, dim1, dim2, ...).

        Returns:
            Tuple[Tensor]: Output tensors. Note that the output tensors do
                not correspond one-to-one to the input tensors; rather, they
                are the (possibly multiple) outputs produced by the model
                for the given (multi-)tensor input.
        """
        input_arrays = (
            input_tensor.cpu().detach().numpy()
            for input_tensor in input_tensors
        )
        outputs = self._predict_arrays(input_arrays)
        return tuple(torch.from_numpy(output) for output in outputs)


class TensorflowONNXMlirInferenceLearner(
    ONNXMlirInferenceLearner, TensorflowBaseInferenceLearner
):
    """Model run with ONNX-MLIR's PyRuntime using a TensorFlow interface.

    Attributes:
        onnx_mlir_model_path (str or Path): Path to the shared object
            mlir model.
        network_parameters (ModelParams): The model parameters, such as
            batch size and input and output sizes.
    """

    def predict(self, *input_tensors: tf.Tensor) -> Tuple[tf.Tensor]:
        """Predict on the input tensors.

        Note that the input tensors must belong to the same batch. If a
        sequence of tensors is given when the model expects a single input
        tensor (with batch size >= 1), an error is raised.

        Args:
            input_tensors (Tuple[Tensor]): Input tensors belonging to the
                same batch. The tensors are expected to have dimensions
                (batch_size, dim1, dim2, ...).

        Returns:
            Tuple[Tensor]: Output tensors. Note that the output tensors do
                not correspond one-to-one to the input tensors; rather, they
                are the (possibly multiple) outputs produced by the model
                for the given (multi-)tensor input.
        """
        input_arrays = (
            input_tensor.numpy() for input_tensor in input_tensors
        )
        outputs = self._predict_arrays(input_arrays)
        # noinspection PyTypeChecker
        return tuple(tf.convert_to_tensor(output) for output in outputs)


ONNX_MLIR_INFERENCE_LEARNERS: Dict[
    DeepLearningFramework, Type[ONNXMlirInferenceLearner]
] = {
    DeepLearningFramework.PYTORCH: PytorchONNXMlirInferenceLearner,
    DeepLearningFramework.TENSORFLOW: TensorflowONNXMlirInferenceLearner,
}
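
For reviewers, a minimal usage sketch (not part of the diff): it assumes a model has already been compiled to mlir_model.so and saved through save(); the directory name and input shape are hypothetical.

import torch

from nebullvm.inference_learners.onnx_mlir import (
    PytorchONNXMlirInferenceLearner,
)

# Restore the learner from the directory written by save();
# "saved_model_dir" is an illustrative name.
learner = PytorchONNXMlirInferenceLearner.load("saved_model_dir")

# Run inference; tensors must share the (batch_size, dim1, dim2, ...) layout.
dummy_input = torch.randn(1, 3, 224, 224)  # illustrative shape
outputs = learner.predict(dummy_input)
print([tuple(o.shape) for o in outputs])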
53 changes: 53 additions & 0 deletions nebullvm/installers/install_onnx_mlir.sh
@@ -0,0 +1,53 @@
#!/bin/bash

# Installation steps to build the ONNX-MLIR from source

# Set non interactive mode for apt-get
export DEBIAN_FRONTEND=noninteractive

# Build ONNX-MLIR

if [ ! -d "onnx-mlir" ]
then

git clone --recursive https://github.com/onnx/onnx-mlir.git onnx-mlir
fi


if [ -z "$NPROC" ]
then
NPROC=4
fi


# Export environment variables pointing to the LLVM project build.
export MLIR_DIR=$(pwd)/llvm-project/build/lib/cmake/mlir

# Get the python interpreter path
export PYTHON_LOCATION=$(which python3)

mkdir onnx-mlir/build && cd onnx-mlir/build

if [[ -z "$PYTHON_LOCATION" ]]; then
cmake -G Ninja \
-DCMAKE_CXX_COMPILER=/usr/bin/c++ \
-DMLIR_DIR=${MLIR_DIR} \
..
else
echo "Using python path " $PYTHON_LOCATION
echo "Using MLIR_DIR " $MLIR_DIR

cmake -G Ninja \
-DCMAKE_CXX_COMPILER=/usr/bin/c++ \
-DPython3_ROOT_DIR=${PYTHON_LOCATION} \
-DPython3_EXECUTABLE=${PYTHON_LOCATION} \
-DMLIR_DIR=${MLIR_DIR} \
..

fi

cmake --build . --parallel $NPROC

# Run lit tests:
export LIT_OPTS=-v
cmake --build . --parallel $NPROC --target check-onnx-lit
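
As a reference for how the resulting build would be used, a hedged sketch (not part of the diff) of compiling an ONNX model into the shared object the learner expects; "my_model.onnx" is an illustrative path, and --EmitLib is the onnx-mlir driver flag that emits a shared library.

import os
import subprocess

# Sketch: compile an ONNX model to a shared library using the onnx-mlir
# driver built by this script. "my_model.onnx" is a hypothetical input.
onnx_mlir_bin = os.path.join(os.environ["ONNX_MLIR_HOME"], "bin", "onnx-mlir")
subprocess.run(
    [onnx_mlir_bin, "--EmitLib", "my_model.onnx"],  # writes my_model.so
    check=True,
)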
77 changes: 77 additions & 0 deletions nebullvm/installers/install_onnx_mlir_prerequisites.sh
@@ -0,0 +1,77 @@
#!/bin/bash

# Installation steps to build and install the llvm-project from source

# Set non interactive mode for apt-get
export DEBIAN_FRONTEND=noninteractive

if [ -z "$NPROC" ]
then
export NPROC=4
fi

# Install the OS dependent required packages
if [[ $OSTYPE == "darwin"* ]]
then
brew install gcc git cmake ninja pybind11
elif [[ "$(grep '^ID_LIKE' /etc/os-release)" == *"centos"* ]]
then
sudo yum update -q -y && \
sudo yum install -q -y \
autoconf automake ca-certificates cmake diffutils \
file java-11-openjdk-devel java-11-openjdk-headless \
gcc gcc-c++ git libtool make ncurses-devel \
zlib-devel && \
# Install ninja
git clone -b v1.10.2 https://github.com/ninja-build/ninja.git && \
cd ninja && mkdir -p build && cd build && \
cmake .. && \
make -j$NPROC install && \
cd ../.. && rm -rf ninja;
else
sudo apt-get update && sudo apt-get install -y --no-install-recommends \
autoconf automake ca-certificates cmake curl \
default-jdk-headless gcc g++ git libncurses-dev \
libtool make maven ninja-build openjdk-11-jdk-headless \
zlib1g-dev

fi

# Install protobuf
PROTOBUF_VERSION=3.14.0
git clone -b v$PROTOBUF_VERSION --recursive https://github.com/google/protobuf.git \
&& cd protobuf && ./autogen.sh \
&& ./configure --enable-static=no \
&& make -j$NPROC install && ldconfig \
&& cd python && python setup.py install \
&& cd ../.. && rm -rf protobuf

# Install jsoniter
JSONITER_VERSION=0.9.23
JSONITER_URL=https://repo1.maven.org/maven2/com/jsoniter/jsoniter/$JSONITER_VERSION \
&& JSONITER_FILE=jsoniter-$JSONITER_VERSION.jar \
&& curl -s $JSONITER_URL/$JSONITER_FILE -o /usr/share/java/$JSONITER_FILE


# ONNX-MLIR needs llvm-project built from source.

# First, build MLIR (as part of llvm-project):
git clone -n https://github.com/llvm/llvm-project.git


# Check out a specific commit that is known to work with ONNX-MLIR.
# TBD: Option to set the commit hash dynamically
cd llvm-project && git checkout a7ac120a9ad784998a5527fc0a71b2d0fd55eccb && cd ..

mkdir llvm-project/build
cd llvm-project/build

cmake -G Ninja ../llvm \
-DLLVM_ENABLE_PROJECTS=mlir \
-DLLVM_TARGETS_TO_BUILD="host" \
-DCMAKE_BUILD_TYPE=Release \
-DLLVM_ENABLE_ASSERTIONS=ON \
-DLLVM_ENABLE_RTTI=ON

cmake --build . --parallel $NPROC -- ${MAKEFLAGS}
cmake --build . --parallel $NPROC --target check-mlir
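
After both builds complete, a quick sanity check (a sketch, assuming the onnx-mlir/build/Debug layout used by these scripts) is to confirm that the PyRuntime module can be imported:

import os
import sys
from pathlib import Path

# Sketch: verify the build produced an importable PyRuntime module.
# The Debug build layout below matches the scripts above.
onnx_mlir_home = os.path.join(str(Path.home()), "onnx-mlir", "build", "Debug")
sys.path.append(os.path.join(onnx_mlir_home, "lib"))

import PyRuntime  # noqa: E402

print("PyRuntime loaded from", PyRuntime.__file__)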