Skip to content

Commit

Permalink
[microNPU] Fixing imports in the entry point
Browse files Browse the repository at this point in the history
This commit fixes errornous reporting that Vela
is missing if other import errors.

Change-Id: I8db97be10018726cf5d9483508321a176c212516
  • Loading branch information
manupak committed Dec 1, 2021
1 parent 2275359 commit 429f02f
Showing 1 changed file with 14 additions and 9 deletions.
23 changes: 14 additions & 9 deletions python/tvm/relay/op/contrib/ethosu.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=ungrouped-imports
# pylint: disable=ungrouped-imports, import-outside-toplevel
"""Arm(R) Ethos(TM)-U NPU supported operators."""
import functools

Expand All @@ -36,14 +36,6 @@
# rely on imports from ethos-u-vela, we protect them with the decorator @requires_vela
# implemented below.
from ethosu.vela import api as vapi # type: ignore
from tvm.relay.backend.contrib.ethosu import preprocess
from tvm.relay.backend.contrib.ethosu.util import QConv2DArgs # type: ignore
from tvm.relay.backend.contrib.ethosu.util import BiasAddArgs
from tvm.relay.backend.contrib.ethosu.util import RequantArgs
from tvm.relay.backend.contrib.ethosu.util import BinaryElementwiseArgs
from tvm.relay.backend.contrib.ethosu.util import DequantizeArgs
from tvm.relay.backend.contrib.ethosu.util import QuantizeArgs
from tvm.relay.backend.contrib.ethosu.util import get_dim_value
except ImportError:
vapi = None

Expand Down Expand Up @@ -116,6 +108,8 @@ def check_valid_dtypes(tensor_params: List[TensorParams], supported_dtypes: List

def check_weights(weights: TensorParams, dilation: List[int]):
"""This function checks whether weight tensor is compatible with the NPU"""
from tvm.relay.backend.contrib.ethosu.util import get_dim_value

dilated_height_range = (1, 64)
dilated_hxw_range = (1, 64 * 64)
weights_limit = 127 * 65536
Expand Down Expand Up @@ -200,6 +194,10 @@ class QnnConv2DParams:

@requires_vela
def __init__(self, func_body: tvm.relay.Function):
from tvm.relay.backend.contrib.ethosu.util import QConv2DArgs # type: ignore
from tvm.relay.backend.contrib.ethosu.util import BiasAddArgs
from tvm.relay.backend.contrib.ethosu.util import RequantArgs

activation = None
if str(func_body.op) in self.activation_map.keys():
activation = func_body
Expand Down Expand Up @@ -472,6 +470,8 @@ class BinaryElementwiseParams:
"""

def __init__(self, func_body: Call, operator_type: str, has_quantization_parameters: bool):
from tvm.relay.backend.contrib.ethosu.util import BinaryElementwiseArgs

clip = None
if str(func_body.op) == "clip":
clip = func_body
Expand Down Expand Up @@ -869,6 +869,9 @@ class AbsParams:
composite_name = "ethos-u.abs"

def __init__(self, func_body: Call):
from tvm.relay.backend.contrib.ethosu.util import QuantizeArgs
from tvm.relay.backend.contrib.ethosu.util import DequantizeArgs

quantize = func_body
abs_op = quantize.args[0]
dequantize = abs_op.args[0]
Expand Down Expand Up @@ -1037,6 +1040,8 @@ def partition_for_ethosu(
mod : IRModule
The partitioned IRModule with external global functions
"""
from tvm.relay.backend.contrib.ethosu import preprocess

if params:
mod["main"] = bind_params_by_name(mod["main"], params)

Expand Down

0 comments on commit 429f02f

Please sign in to comment.