Skip to content

Commit

Permalink
[microNPU] Fixing imports in the entry point (apache#9624)
Browse files Browse the repository at this point in the history
This commit fixes errornous reporting that Vela
is missing if other import errors.

Change-Id: I8db97be10018726cf5d9483508321a176c212516
  • Loading branch information
manupak authored and ylc committed Jan 13, 2022
1 parent eec0ef1 commit 9209391
Showing 1 changed file with 14 additions and 9 deletions.
23 changes: 14 additions & 9 deletions python/tvm/relay/op/contrib/ethosu.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=ungrouped-imports
# pylint: disable=ungrouped-imports, import-outside-toplevel
"""Arm(R) Ethos(TM)-U NPU supported operators."""
import functools

Expand All @@ -36,14 +36,6 @@
# rely on imports from ethos-u-vela, we protect them with the decorator @requires_vela
# implemented below.
from ethosu.vela import api as vapi # type: ignore
from tvm.relay.backend.contrib.ethosu import preprocess
from tvm.relay.backend.contrib.ethosu.util import QConv2DArgs # type: ignore
from tvm.relay.backend.contrib.ethosu.util import BiasAddArgs
from tvm.relay.backend.contrib.ethosu.util import RequantArgs
from tvm.relay.backend.contrib.ethosu.util import BinaryElementwiseArgs
from tvm.relay.backend.contrib.ethosu.util import DequantizeArgs
from tvm.relay.backend.contrib.ethosu.util import QuantizeArgs
from tvm.relay.backend.contrib.ethosu.util import get_dim_value
except ImportError:
vapi = None

Expand Down Expand Up @@ -116,6 +108,8 @@ def check_valid_dtypes(tensor_params: List[TensorParams], supported_dtypes: List

def check_weights(weights: TensorParams, dilation: List[int]):
"""This function checks whether weight tensor is compatible with the NPU"""
from tvm.relay.backend.contrib.ethosu.util import get_dim_value

dilated_height_range = (1, 64)
dilated_hxw_range = (1, 64 * 64)
weights_limit = 127 * 65536
Expand Down Expand Up @@ -200,6 +194,10 @@ class QnnConv2DParams:

@requires_vela
def __init__(self, func_body: tvm.relay.Function):
from tvm.relay.backend.contrib.ethosu.util import QConv2DArgs # type: ignore
from tvm.relay.backend.contrib.ethosu.util import BiasAddArgs
from tvm.relay.backend.contrib.ethosu.util import RequantArgs

activation = None
if str(func_body.op) in self.activation_map.keys():
activation = func_body
Expand Down Expand Up @@ -472,6 +470,8 @@ class BinaryElementwiseParams:
"""

def __init__(self, func_body: Call, operator_type: str, has_quantization_parameters: bool):
from tvm.relay.backend.contrib.ethosu.util import BinaryElementwiseArgs

clip = None
if str(func_body.op) == "clip":
clip = func_body
Expand Down Expand Up @@ -869,6 +869,9 @@ class AbsParams:
composite_name = "ethos-u.abs"

def __init__(self, func_body: Call):
from tvm.relay.backend.contrib.ethosu.util import QuantizeArgs
from tvm.relay.backend.contrib.ethosu.util import DequantizeArgs

quantize = func_body
abs_op = quantize.args[0]
dequantize = abs_op.args[0]
Expand Down Expand Up @@ -1037,6 +1040,8 @@ def partition_for_ethosu(
mod : IRModule
The partitioned IRModule with external global functions
"""
from tvm.relay.backend.contrib.ethosu import preprocess

if params:
mod["main"] = bind_params_by_name(mod["main"], params)

Expand Down

0 comments on commit 9209391

Please sign in to comment.