diff --git a/.github/workflows/cc_bot.yml b/.github/workflows/cc_bot.yml index dd50eba79358..873fafa82a60 100644 --- a/.github/workflows/cc_bot.yml +++ b/.github/workflows/cc_bot.yml @@ -43,4 +43,4 @@ jobs: GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} run: | set -eux - python tests/scripts/github_cc_reviewers.py + python tests/scripts/github_cc_reviewers.py || echo step failed diff --git a/docker/Dockerfile.ci_qemu b/docker/Dockerfile.ci_qemu index bba458458efc..2cae59c35d67 100644 --- a/docker/Dockerfile.ci_qemu +++ b/docker/Dockerfile.ci_qemu @@ -42,10 +42,7 @@ COPY install/ubuntu_install_rust.sh /install/ubuntu_install_rust.sh RUN bash /install/ubuntu_install_rust.sh ENV RUSTUP_HOME /opt/rust ENV CARGO_HOME /opt/rust - -# wasmtime -COPY install/ubuntu_install_wasmtime.sh /install/ubuntu_install_wasmtime.sh -RUN bash /install/ubuntu_install_wasmtime.sh +ENV PATH $PATH:$CARGO_HOME/bin # AutoTVM deps COPY install/ubuntu_install_redis.sh /install/ubuntu_install_redis.sh diff --git a/docs/contribute/git_howto.rst b/docs/contribute/git_howto.rst index 0fa904ff2ef6..1271aad8a268 100644 --- a/docs/contribute/git_howto.rst +++ b/docs/contribute/git_howto.rst @@ -24,7 +24,7 @@ Git Usage Tips Here are some tips for git workflow. How to resolve a conflict with ``main`` -------------------------------------- +--------------------------------------- - First rebase to most recent main diff --git a/python/tvm/contrib/ethosu/cascader/device_config.py b/python/tvm/contrib/ethosu/cascader/device_config.py index 5ad7fde1ed52..68a218da2616 100644 --- a/python/tvm/contrib/ethosu/cascader/device_config.py +++ b/python/tvm/contrib/ethosu/cascader/device_config.py @@ -16,7 +16,7 @@ # under the License. # pylint: disable=invalid-name """Device config class to hold information about the target hardware""" -from typing import Tuple, List, Dict +from typing import Tuple, List, Dict, Optional from functools import reduce import math @@ -332,6 +332,7 @@ def _get_input_block( def get_kernel_steps( self, + op_type: str, dilated_kernel_h: int, dilated_kernel_w: int, ifm_dtype: str, @@ -341,6 +342,9 @@ def get_kernel_steps( Parameters ---------- + op_type : str + The NPU primitive operator + "ethosu_pooling" dilated_kernel_h: int Height of dilated kernel dilated_kernel_w: int @@ -355,18 +359,23 @@ def get_kernel_steps( List[int] List where each entry contains the amount of elements in one of the subkernels """ + if op_type == "ethosu_binary_elementwise": + return [1] + subkernels = self._get_subkernels(dilated_kernel_h, dilated_kernel_w) # Determine the number of kernel steps per subkernel kernel_steps = [] for y, x in subkernels: subkernel_elements = x * y - if is_partkernel: - # Part-kernel-first traversal + if op_type == "ethosu_conv2d" and is_partkernel: + # Part-kernel-first traversal conv2d divisor = 4 if ifm_dtype == "int8" else 2 kernel_steps.append(int(_round_up_div(subkernel_elements, divisor))) + elif op_type == "ethosu_depthwise_conv2d": + kernel_steps.append(int(_round_up_div(subkernel_elements, 4))) else: - # Depth-first traversal + # Depth-first traversal conv2d or pooling kernel_steps.append(int(subkernel_elements)) return kernel_steps @@ -430,11 +439,133 @@ def is_partkernel( return part_kernel_first_utilization > depth_first_utilization or ifm_channels <= 8 + def get_elementwise_block_config( + self, + ifm_propagator: Propagator, + ifm2_propagator: Optional[Propagator], + op_attrs: Dict, + ofm_shape: List[int], + output_layout: str, + input_layout: str, + input2_layout: Optional[str], + ifm_dtype: str, + ofm_dtype: str, + ) -> List[BlockConfig]: + """Get a suitable block config for an elementwise operator + + Parameters + ---------- + ifm_propagator: Propagator, + The propagator containing the data dependencies between input and output + ifm2_propagator: Propagator, + The propagator containing the data dependencies between input2 and output + op_attrs: Dict, + Dictionary containing operator attributes + ofm_shape: List[int], + Shape of the output tensor + output_layout: str, + The layout of the Output Feature Map tensor. Can be "NHWC" or "NHCWB16". + input_layout: str, + The layout of the Input Feature Map tensor. Can be "NHWC" or "NHCWB16". + input2_layout: str, + The layout of the Input2 Feature Map tensor. Can be "NHWC" or "NHCWB16". + ifm_dtype: str, + Datatype of the Input Feature Map tensor (IFM) + ofm_dtype: str, + Datatype of the Output Feature Map tensor (OFM) + + Returns + ---------- + List[BlockConfig] + List containing a single suitable block config + """ + block_config = [] + output_shape = [int(a) for a in ofm_shape] + + op_type = op_attrs.get("op") + op_str = op_attrs.get("op_str") + activation = op_attrs.get("activation", "NONE") + + input_bytewidth = 1 if ifm_dtype == "int8" else 2 if ifm_dtype == "int16" else 4 + banks_available = self._total_banks - self._reserved_banks + if activation == "LUT" and not self._lut_reserved: + banks_available -= 2 + + # Split the block in half until it fits into SHRAM + if output_layout == "NHCWB16": + split_order = (a for a in [1, 3, 2]) + output_block = [ + output_shape[0], + min(output_shape[1], self._max_block_shape.height), + min(output_shape[2] * output_shape[4], self._max_block_shape.depth), + min(output_shape[3], self._max_block_shape.width), + 16, + ] + else: + split_order = (a for a in [1, 2, 3]) + output_block = [ + output_shape[0], + min(output_shape[1], self._max_block_shape.height), + min(output_shape[2], self._max_block_shape.width), + min(output_shape[3], self._max_block_shape.depth), + ] + split_axis = next(split_order) + while True: + # Create stripe config for output block + offset = [0] * len(output_block) + stripes = [1] * len(output_block) + order = [1, 2, 4, 3, 0] if output_layout == "NHCWB16" else [1, 2, 3, 4] + output_stripe_config = StripeConfig( + output_block, output_block, output_block, order, stripes, offset + ) + + # Propagate the output to obtain the two input blocks + input_block = _Shape(ifm_propagator.propagate(output_stripe_config).shape, input_layout) + if ifm2_propagator: + input2_block = _Shape( + ifm2_propagator.propagate(output_stripe_config).shape, input2_layout + ) + else: + # Unary elementwise + input2_block = _Shape([0, 0, 0, 0]) + + input_block.round_up(self._input_micro_block) + input2_block.round_up(self._input_micro_block) + + # Banks required for input block + input_bytes = input_block.area() * self._align(input_block.depth * input_bytewidth, 8) + input_banks = _round_up_div(input_bytes, self._bank_size_bytes) * 2 + input_banks = _round_up(input_banks, self._input_granularity) + + # Banks required for input2 block + input2_bytes = input2_block.area() * self._align( + input2_block.depth * input_bytewidth, 8 + ) + input2_banks = _round_up_div(input2_bytes, self._bank_size_bytes) * 2 + input2_banks = _round_up(input2_banks, self._input_granularity) + + # Check whether or not both IFMs fit into SHRAM + if (input_banks + input2_banks) <= banks_available: + output_cycles = self._get_output_cycles( + op_type, op_str, ifm_dtype, ofm_dtype, activation + ) + output_cycles *= reduce(lambda a, b: a * b, output_block, 1) + output_cycles = int(math.ceil(output_cycles)) + block_config.append(BlockConfig(output_block, 0, output_cycles)) + break + + if output_block[split_axis] == 1: + split_axis = next(split_order) + + output_block[split_axis] = _round_up_div(output_block[split_axis], 2) + + return block_config + def get_valid_block_configs( self, ifm_propagator: Propagator, op_attrs: Dict, - output_shape: List[int], + ofm_shape: List[int], ofm_channels: int, ifm_channels: int, output_layout: str, @@ -452,7 +583,7 @@ def get_valid_block_configs( The propagator containing the data dependencies between input and output op_attrs: Dict, Dictionary containing operator attributes - output_shape: List[int], + ofm_shape: List[int], Shape of the output tensor ofm_channels: int, Number of output channels @@ -487,9 +618,9 @@ def get_valid_block_configs( subkernel_transform = ifm_propagator.transform if output_layout == "NHCWB16": - output_shape = _Shape([1, output_shape[1], output_shape[3], ofm_channels]) + output_shape = _Shape([1, ofm_shape[1], ofm_shape[3], ofm_channels]) else: - output_shape = _Shape(output_shape) + output_shape = _Shape(ofm_shape) if input_layout == "NHCWB16": subkernel_transform[1][-1] = min( @@ -571,6 +702,7 @@ def get_valid_block_configs( input_block_shape = _Shape(input_block.shape, input_layout) input_block_shape.round_up(self._input_micro_block) + output_block_shape = _Shape(output_block, output_layout) if op_type == "ethosu_conv2d": @@ -592,12 +724,11 @@ def get_valid_block_configs( acc_banks = _round_up(acc_banks, self._accumulator_granularity[acc_bytewidth]) if (input_banks + acc_banks) <= banks_available: - output_cycles = self._get_output_cycles( op_type, op_str, ifm_dtype, ofm_dtype, activation ) output_cycles *= reduce(lambda a, b: a * b, output_block, 1) - output_cycles = int(_round_up(output_cycles, 1)) + output_cycles = int(math.ceil(output_cycles)) compute_cycles = self._estimate_compute_cycles_per_block( op_type, output_block_shape, @@ -634,7 +765,7 @@ def _estimate_compute_cycles_per_block( num_quantum_z = _round_up_div(block_shape.depth, self._micro_block.depth) num_quantum_xy = num_quantum_x * num_quantum_y - kernel_steps = self.get_kernel_steps(kernel_h, kernel_w, ifm_dtype, is_partkernel) + kernel_steps = self.get_kernel_steps(op_type, kernel_h, kernel_w, ifm_dtype, is_partkernel) wd_cycles = self._get_weight_decoder_cycles(op_type) delay_cycles = self._get_delay_cycles(op_type, ifm_dtype) @@ -642,8 +773,9 @@ def _estimate_compute_cycles_per_block( compute_cycles = 0 for subkernel_steps in kernel_steps: + subkernel_cycles = 1 if op_type == "ethosu_pooling" else subkernel_steps compute_cycles += ( - max(wd_cycles, cycle_quantum * num_quantum_xy) * subkernel_steps * num_quantum_z + max(wd_cycles, cycle_quantum * num_quantum_xy) * subkernel_cycles * num_quantum_z ) if num_quantum_xy == 1: diff --git a/python/tvm/relay/backend/contrib/ethosu/codegen.py b/python/tvm/relay/backend/contrib/ethosu/codegen.py index 0884b249df48..7666691aa19f 100644 --- a/python/tvm/relay/backend/contrib/ethosu/codegen.py +++ b/python/tvm/relay/backend/contrib/ethosu/codegen.py @@ -18,13 +18,13 @@ import tvm from tvm import relay +from tvm import ir from tvm.relay.backend.contrib.ethosu.tir.compiler import lower_to_tir from tvm.relay.backend.contrib.ethosu.tir.scheduler import copy_constants from tvm.relay.backend.contrib.ethosu.legalize import LegalizeEthosU from tvm.relay.backend.contrib.ethosu import tir_to_cs_translator from tvm.relay.backend.contrib.ethosu import util from tvm.relay.expr_functor import ExprMutator -from tvm.ir.transform import Pass # pylint: disable=unused-import from tvm.relay.backend.contrib.ethosu.op import op_attrs @@ -109,13 +109,11 @@ def visit_call(self, call: tvm.relay.expr.Call) -> tvm.relay.expr.Call: return new_call -@relay.transform.function_pass(opt_level=1, name="LUTsOptimizer") -class LUTsOptimizer(Pass): +@ir.transform.module_pass(opt_level=1, name="LUTsOptimizer") +class LUTsOptimizer: """Register LUTsOptimizer as a relay pass.""" - def transform_function( - self, func: tvm.relay.function.Function, mod: tvm.IRModule, _ - ) -> tvm.IRModule: + def transform_module(self, mod: tvm.ir.IRModule, _) -> tvm.IRModule: """Visit relay nodes in the given module. Parameters @@ -131,7 +129,13 @@ def transform_function( New module with optimized LUTs. """ assert len(mod.functions.items()) == 1, "Module can only contain one function." - return OptimizeLUTs().visit(func) + global_var, func = mod.functions.items()[0] + optimized_func = OptimizeLUTs().visit(func) + mod.update_func(global_var, optimized_func) + return mod + + def __call__(self, *args, **kwargs): + pass class LayoutOptimization(ExprMutator): @@ -247,19 +251,23 @@ def visit_call(self, call: tvm.relay.expr.Call) -> tvm.relay.expr.Call: return super().visit_call(call) -@relay.transform.function_pass(opt_level=1, name="LayoutOptimizer") -class LayoutOptimizer(Pass): +@ir.transform.module_pass(opt_level=1, name="LayoutOptimizer") +class LayoutOptimizer: """Register LayoutOptimizer as a Relay pass.""" - def transform_function( - self, func: tvm.relay.function.Function, mod: tvm.IRModule, _ - ) -> tvm.IRModule: + def transform_module(self, mod: tvm.ir.IRModule, _) -> tvm.IRModule: """A pass to optimize the layout of NPU operations. If both the producer and consumer of a tensor are NPU operators, then the layout is converted from NHWC to NHCWB16 as this is the layout NPU uses internally.""" assert len(mod.functions.items()) == 1, "Module can only contain one function." - return LayoutOptimization().visit(func) + global_var, func = mod.functions.items()[0] + optimized_func = LayoutOptimization().visit(func) + mod.update_func(global_var, optimized_func) + return mod + + def __call__(self, *args, **kwargs): + pass @tvm._ffi.register_func("relay.ext.ethos-u.constant_updater") diff --git a/python/tvm/relay/backend/contrib/ethosu/legalize.py b/python/tvm/relay/backend/contrib/ethosu/legalize.py index 8f5d6c24f0f6..d52f3ba6eca5 100644 --- a/python/tvm/relay/backend/contrib/ethosu/legalize.py +++ b/python/tvm/relay/backend/contrib/ethosu/legalize.py @@ -230,7 +230,7 @@ def __call__(self, *args, **kwargs): def sigmoid_calc_func(x: float) -> float: """Function to calculate the values for sigmoid""" - # Thse limits are inherited from TFLite + # These limits are inherited from TFLite upper_limit = 8.0 lower_limit = -8.0 diff --git a/python/tvm/relay/backend/contrib/ethosu/te/binary_elementwise.py b/python/tvm/relay/backend/contrib/ethosu/te/binary_elementwise.py index c1d39556d11d..8446b0c2e4ad 100644 --- a/python/tvm/relay/backend/contrib/ethosu/te/binary_elementwise.py +++ b/python/tvm/relay/backend/contrib/ethosu/te/binary_elementwise.py @@ -17,7 +17,10 @@ # pylint: disable=invalid-name,unused-argument """Tensor Expressions for binary_elementwise""" import operator +import numpy as np from tvm import te +from tvm.contrib.ethosu.cascader import TESubgraph, EthosuPart, Propagator, register_matcher + from .dma import dma_ofm_compute, dma_ifm_compute @@ -123,6 +126,12 @@ def binary_elementwise_compute( te.Tensor The Output Feature Map tensor. """ + assert ifm.shape[0] == 1 + assert ifm2.shape[0] == 1 + assert ifm_layout in {"NHWC", "NHCWB16"} + assert ifm2_layout in {"NHWC", "NHCWB16"} + assert ofm_layout in {"NHWC", "NHCWB16"} + # Compute operation for the IFM DMA pipeline dmaed_ifm = dma_ifm_compute( ifm, ifm_layout, ifm_zero_point, ifm_scale, ifm_channels, (0, 0, 0, 0) @@ -187,5 +196,147 @@ def binary_elementwise_compute( attrs=binary_elementwise_attrs, ) + nhwc_to_nhcwb16 = [ + [1, 0, 0, 0, 0], + [0, 1, 0, 0, 0], + [0, 0, 0, 1 / 16, 0], + [0, 0, 1, 0, 0], + [0, 0, 0, 0, 16], + [0, 0, 0, 0, 1], + ] + nhcwb16_to_nhwc = [ + [1, 0, 0, 0, 0, 0], + [0, 1, 0, 0, 0, 0], + [0, 0, 0, 1, 0, 0], + [0, 0, 16, 0, 1, -16], + [0, 0, 0, 0, 0, 1], + ] + ifm_matrix = [ + [1, 0, 0, 0, 0], + [0, 1, 0, 0, 0], + [0, 0, 1, 0, 0], + [0, 0, 0, 1, 0], + [0, 0, 0, 0, 1], + ] + ifm2_matrix = [ + [1, 0, 0, 0, 0], + [0, (1 - int(broadcast[1])), 0, 0, int(broadcast[1])], + [0, 0, (1 - int(broadcast[2])), 0, int(broadcast[2])], + [0, 0, 0, (1 - int(broadcast[3])), int(broadcast[3])], + [0, 0, 0, 0, 1], + ] + if ofm_layout == "NHCWB16": + ifm_matrix = np.matmul(ifm_matrix, nhcwb16_to_nhwc).tolist() + ifm2_matrix = np.matmul(ifm2_matrix, nhcwb16_to_nhwc).tolist() + if ifm_layout == "NHCWB16": + ifm_matrix = np.matmul(nhwc_to_nhcwb16, ifm_matrix).tolist() + if ifm2_layout == "NHCWB16": + ifm2_matrix = np.matmul(nhwc_to_nhcwb16, ifm2_matrix).tolist() + ifm_propagator = Propagator( + ifm_matrix, + [0, 0, 0, 0] if ifm_layout == "NHWC" else [0, 0, 0, 0, 0], + ) + ifm2_propagator = Propagator( + ifm2_matrix, + [0, 0, 0, 0] if ifm2_layout == "NHWC" else [0, 0, 0, 0, 0], + ) + propagator_attrs = { + "ifm_propagator": ifm_propagator, + "ifm2_propagator": ifm2_propagator, + } + # Compute operation for the OFM DMA pipeline - return dma_ofm_compute(binary_elementwise, ofm_layout, ofm_zero_point, ofm_scale, ifm_channels) + return dma_ofm_compute( + binary_elementwise, + ofm_layout, + ofm_zero_point, + ofm_scale, + ifm_channels, + attrs=propagator_attrs, + ) + + +@register_matcher +def match_ethosu_binary_elementwise(output_tensor, device_config): + """Match a Tensor Expression corresponding to an NPU Binary Elementwise. + + If the Tensor Expression matches, an EthosuPart will be created that models the + matched Tensor Expression. Otherwise, None will be returned. + + Parameters + ---------- + output_tensor : tvm.te.Tensor + The tensor to attempt to match with. + device_config : EthosuDeviceConfig + Target device configuration + + Returns + ------- + Union[None, EthosuPart] + The created EthosuPart if there was a match, otherwise None. + + """ + write = output_tensor + if write.op.name != "ethosu_write": + return None + convert_to_nhcwb16 = write.op.input_tensors[0] + if convert_to_nhcwb16.op.name != "ethosu_convert_to_nhcwb16": + return None + binary_elementwise = convert_to_nhcwb16.op.input_tensors[0] + if binary_elementwise.op.name != "ethosu_binary_elementwise": + return None + pad = binary_elementwise.op.input_tensors[0] + if pad.op.name != "ethosu_pad": + return None + convert_to_nhwc = pad.op.input_tensors[0] + if convert_to_nhwc.op.name != "ethosu_convert_to_nhwc": + return None + read = convert_to_nhwc.op.input_tensors[0] + if read.op.name != "ethosu_read": + return None + pad2 = binary_elementwise.op.input_tensors[1] + if pad2.op.name != "ethosu_pad": + return None + convert_to_nhwc2 = pad2.op.input_tensors[0] + if convert_to_nhwc2.op.name != "ethosu_convert_to_nhwc": + return None + read2 = convert_to_nhwc2.op.input_tensors[0] + if read2.op.name != "ethosu_read": + return None + + input_tensors = [ + read.op.input_tensors[0], + read2.op.input_tensors[0], + ] + subgraph = TESubgraph(input_tensors, output_tensor) + propagators = [ + write.op.attrs["ifm_propagator"], + write.op.attrs["ifm2_propagator"], + ] + ifm_dtype = input_tensors[0].dtype + ofm_dtype = output_tensor.dtype + + output_layout = convert_to_nhcwb16.op.attrs["layout"] + input_layout = convert_to_nhwc.op.attrs["layout"] + input2_layout = convert_to_nhwc2.op.attrs["layout"] + output_quantum = device_config.get_output_quantum(output_layout) + + block_config = device_config.get_elementwise_block_config( + propagators[0], + propagators[1], + binary_elementwise.op.attrs, + output_tensor.shape, + output_layout, + input_layout, + input2_layout, + ifm_dtype, + ofm_dtype, + ) + + return EthosuPart( + subgraph, + propagators, + output_quantum, + 1, + block_config, + ) diff --git a/python/tvm/relay/backend/contrib/ethosu/te/convolution.py b/python/tvm/relay/backend/contrib/ethosu/te/convolution.py index c61082beb737..ea2290ef1e5f 100644 --- a/python/tvm/relay/backend/contrib/ethosu/te/convolution.py +++ b/python/tvm/relay/backend/contrib/ethosu/te/convolution.py @@ -297,7 +297,9 @@ def match_ethosu_conv2d(output_tensor, device_config): conv2d.op.name, ifm_channels, ifm_dtype, kernel_elements ) subkernels = len( - device_config.get_kernel_steps(kernel_height, kernel_width, ifm_dtype, is_part_kernel) + device_config.get_kernel_steps( + conv2d.op.name, kernel_height, kernel_width, ifm_dtype, is_part_kernel + ) ) output_layout = convert_to_nhcwb16.op.attrs["layout"] diff --git a/python/tvm/relay/backend/contrib/ethosu/te/depthwise.py b/python/tvm/relay/backend/contrib/ethosu/te/depthwise.py index f54f2f3654e2..ff09662cc14a 100644 --- a/python/tvm/relay/backend/contrib/ethosu/te/depthwise.py +++ b/python/tvm/relay/backend/contrib/ethosu/te/depthwise.py @@ -17,8 +17,11 @@ # pylint: disable=invalid-name,unused-argument """Tensor Expressions for depthwise convolutions""" from typing import Tuple, Union, List +import numpy as np from tvm import te +from tvm.contrib.ethosu.cascader import TESubgraph, EthosuPart, Propagator, register_matcher + from .dma import dma_ofm_compute, dma_ifm_compute @@ -110,9 +113,10 @@ def depthwise_conv2d_compute( assert ifm_layout in {"NHWC", "NHCWB16"} assert ofm_layout in {"NHWC", "NHCWB16"} - stride_h, stride_w = strides - dilation_h, dilation_w = dilation - channels, kernel_h, kernel_w, _ = weight.shape + padding = [int(v) for v in padding] + stride_h, stride_w = [int(v) for v in strides] + dilation_h, dilation_w = [int(v) for v in dilation] + channels, kernel_h, kernel_w, _ = [int(v) for v in weight.shape] # Compute operation for the IFM DMA pipeline dmaed_ifm = dma_ifm_compute(ifm, ifm_layout, ifm_zero_point, ifm_scale, channels, padding) @@ -165,5 +169,155 @@ def depthwise_conv2d_compute( attrs=depthwise_conv2d_attrs, ) + nhwc_to_nhcwb16 = [ + [1, 0, 0, 0, 0], + [0, 1, 0, 0, 0], + [0, 0, 0, 1 / 16, 0], + [0, 0, 1, 0, 0], + [0, 0, 0, 0, 16], + [0, 0, 0, 0, 1], + ] + nhcwb16_to_nhwc = [ + [1, 0, 0, 0, 0, 0], + [0, 1, 0, 0, 0, 0], + [0, 0, 0, 1, 0, 0], + [0, 0, 16, 0, 1, -16], + [0, 0, 0, 0, 0, 1], + ] + ifm_matrix = [ + [1, 0, 0, 0, 0], + [0, stride_h, 0, 0, (dilated_kernel_h - stride_h)], + [0, 0, stride_w, 0, (dilated_kernel_w - stride_w)], + [0, 0, 0, 1, 0], + [0, 0, 0, 0, 1], + ] + weights_matrix = [ + [0, 0, 0, 1, 0], + [0, 0, 0, 0, kernel_h], + [0, 0, 0, 0, kernel_w], + [0, 0, 0, 0, 1], + [0, 0, 0, 0, 1], + ] + bias_matrix = [ + [0, 0, 0, 1, 0], + [0, 0, 0, 0, 10], + [0, 0, 0, 0, 1], + ] + if ofm_layout == "NHCWB16": + ifm_matrix = np.matmul(ifm_matrix, nhcwb16_to_nhwc).tolist() + weights_matrix = np.matmul(weights_matrix, nhcwb16_to_nhwc).tolist() + bias_matrix = np.matmul(bias_matrix, nhcwb16_to_nhwc).tolist() + if ifm_layout == "NHCWB16": + ifm_matrix = np.matmul(nhwc_to_nhcwb16, ifm_matrix).tolist() + ifm_propagator = Propagator( + ifm_matrix, + [0, -padding[0], -padding[1], 0] + if ifm_layout == "NHWC" + else [0, -padding[0], 0, -padding[1], 0], + ) + weights_propagator = Propagator( + weights_matrix, + [0, 0, 0, 0], + ) + bias_propagator = Propagator( + bias_matrix, + [0, 0], + ) + propagator_attrs = { + "ifm_propagator": ifm_propagator, + "weights_propagator": weights_propagator, + "bias_propagator": bias_propagator, + } + # Compute operation for the OFM DMA pipeline - return dma_ofm_compute(depthwise, ofm_layout, ofm_zero_point, ofm_scale, channels) + return dma_ofm_compute( + depthwise, ofm_layout, ofm_zero_point, ofm_scale, channels, attrs=propagator_attrs + ) + + +@register_matcher +def match_ethosu_depthwise_conv2d(output_tensor, device_config): + """Match a Tensor Expression corresponding to an NPU Depthwise Conv2D. + + If the Tensor Expression matches, an EthosuPart will be created that models the + matched Tensor Expression. Otherwise, None will be returned. + + Parameters + ---------- + output_tensor : tvm.te.Tensor + The tensor to attempt to match with. + device_config : EthosuDeviceConfig + Target device configuration. + + Returns + ------- + Union[None, EthosuPart] + The created EthosuPart if there was a match, otherwise None. + + """ + write = output_tensor + if write.op.name != "ethosu_write": + return None + convert_to_nhcwb16 = write.op.input_tensors[0] + if convert_to_nhcwb16.op.name != "ethosu_convert_to_nhcwb16": + return None + depthwise2d = convert_to_nhcwb16.op.input_tensors[0] + if depthwise2d.op.name != "ethosu_depthwise_conv2d": + return None + pad = depthwise2d.op.input_tensors[0] + if pad.op.name != "ethosu_pad": + return None + convert_to_nhwc = pad.op.input_tensors[0] + if convert_to_nhwc.op.name != "ethosu_convert_to_nhwc": + return None + read = convert_to_nhwc.op.input_tensors[0] + if read.op.name != "ethosu_read": + return None + + input_tensors = [ + read.op.input_tensors[0], + depthwise2d.op.input_tensors[1], + depthwise2d.op.input_tensors[2], + ] + subgraph = TESubgraph(input_tensors, output_tensor) + propagators = [ + write.op.attrs["ifm_propagator"], + write.op.attrs["weights_propagator"], + write.op.attrs["bias_propagator"], + ] + ifm_dtype = input_tensors[0].dtype + ofm_dtype = output_tensor.dtype + + ifm_channels = int(input_tensors[0].shape[3]) + ofm_channels, kernel_height, kernel_width = (int(axis) for axis in input_tensors[1].shape[0:3]) + + subkernels = len( + device_config.get_kernel_steps(depthwise2d.op.name, kernel_height, kernel_width, ifm_dtype) + ) + + output_layout = convert_to_nhcwb16.op.attrs["layout"] + input_layout = convert_to_nhwc.op.attrs["layout"] + output_quantum = device_config.get_output_quantum(output_layout) + + valid_block_configs = device_config.get_valid_block_configs( + propagators[0], + depthwise2d.op.attrs, + output_tensor.shape, + ofm_channels, + ifm_channels, + output_layout, + input_layout, + ifm_dtype, + ofm_dtype, + kernel_height, + kernel_width, + ) + + return EthosuPart( + subgraph, + propagators, + output_quantum, + subkernels, + valid_block_configs, + 1, + ) diff --git a/python/tvm/relay/backend/contrib/ethosu/te/pooling.py b/python/tvm/relay/backend/contrib/ethosu/te/pooling.py index e98a72db7f02..aaf79e8a8c8d 100644 --- a/python/tvm/relay/backend/contrib/ethosu/te/pooling.py +++ b/python/tvm/relay/backend/contrib/ethosu/te/pooling.py @@ -18,7 +18,10 @@ """Tensor Expressions for poolings""" from typing import Tuple +import numpy as np from tvm import te +from tvm.contrib.ethosu.cascader import TESubgraph, EthosuPart, Propagator, register_matcher + from .dma import dma_ofm_compute, dma_ifm_compute @@ -99,8 +102,13 @@ def pooling_compute( te.Tensor The OFM tensor. """ - stride_h, stride_w = strides - pool_shape_h, pool_shape_w = pool_shape + assert ifm.shape[0] == 1 + assert ifm_layout in {"NHWC", "NHCWB16"} + assert ofm_layout in {"NHWC", "NHCWB16"} + + padding = [int(v) for v in padding] + stride_h, stride_w = [int(v) for v in strides] + pool_shape_h, pool_shape_w = [int(v) for v in pool_shape] # Compute operation for the IFM DMA pipeline dmaed_ifm = dma_ifm_compute(ifm, ifm_layout, ifm_zero_point, ifm_scale, ofm_channels, padding) @@ -114,6 +122,8 @@ def pooling_compute( pooling_attrs = { "op": "ethosu_pooling", "pooling_type": pooling_type, + "pool_shape_h": pool_shape_h, + "pool_shape_w": pool_shape_w, "stride_h": stride_h, "stride_w": stride_w, "activation": activation, @@ -144,5 +154,128 @@ def pooling_compute( attrs=pooling_attrs, ) + nhwc_to_nhcwb16 = [ + [1, 0, 0, 0, 0], + [0, 1, 0, 0, 0], + [0, 0, 0, 1 / 16, 0], + [0, 0, 1, 0, 0], + [0, 0, 0, 0, 16], + [0, 0, 0, 0, 1], + ] + nhcwb16_to_nhwc = [ + [1, 0, 0, 0, 0, 0], + [0, 1, 0, 0, 0, 0], + [0, 0, 0, 1, 0, 0], + [0, 0, 16, 0, 1, -16], + [0, 0, 0, 0, 0, 1], + ] + ifm_matrix = [ + [1, 0, 0, 0, 0], + [0, stride_h, 0, 0, (pool_shape_h - stride_h)], + [0, 0, stride_w, 0, (pool_shape_w - stride_w)], + [0, 0, 0, 1, 0], + [0, 0, 0, 0, 1], + ] + if ofm_layout == "NHCWB16": + ifm_matrix = np.matmul(ifm_matrix, nhcwb16_to_nhwc).tolist() + if ifm_layout == "NHCWB16": + ifm_matrix = np.matmul(nhwc_to_nhcwb16, ifm_matrix).tolist() + ifm_propagator = Propagator( + ifm_matrix, + [0, -padding[0], -padding[1], 0] + if ifm_layout == "NHWC" + else [0, -padding[0], 0, -padding[1], 0], + ) + propagator_attrs = { + "ifm_propagator": ifm_propagator, + } + # Compute operation for the OFM DMA pipeline - return dma_ofm_compute(pooling, ofm_layout, ofm_zero_point, ofm_scale, ofm_channels) + return dma_ofm_compute( + pooling, ofm_layout, ofm_zero_point, ofm_scale, ofm_channels, attrs=propagator_attrs + ) + + +@register_matcher +def match_ethosu_pooling(output_tensor, device_config): + """Match a Tensor Expression corresponding to an NPU Pooling. + + If the Tensor Expression matches, an EthosuPart will be created that models the + matched Tensor Expression. Otherwise, None will be returned. + + Parameters + ---------- + output_tensor : tvm.te.Tensor + The tensor to attempt to match with. + device_config : EthosuDeviceConfig + Target device configuration + + Returns + ------- + Union[None, EthosuPart] + The created EthosuPart if there was a match, otherwise None. + + """ + write = output_tensor + if write.op.name != "ethosu_write": + return None + convert_to_nhcwb16 = write.op.input_tensors[0] + if convert_to_nhcwb16.op.name != "ethosu_convert_to_nhcwb16": + return None + pool2d = convert_to_nhcwb16.op.input_tensors[0] + if pool2d.op.name != "ethosu_pooling": + return None + pad = pool2d.op.input_tensors[0] + if pad.op.name != "ethosu_pad": + return None + convert_to_nhwc = pad.op.input_tensors[0] + if convert_to_nhwc.op.name != "ethosu_convert_to_nhwc": + return None + read = convert_to_nhwc.op.input_tensors[0] + if read.op.name != "ethosu_read": + return None + + input_tensors = [ + read.op.input_tensors[0], + ] + subgraph = TESubgraph(input_tensors, output_tensor) + propagators = [ + write.op.attrs["ifm_propagator"], + ] + ifm_dtype = input_tensors[0].dtype + ofm_dtype = output_tensor.dtype + + ifm_channels = int(input_tensors[0].shape[3]) + ofm_channels = ifm_channels + pool_shape_h = int(pool2d.op.attrs["pool_shape_h"]) + pool_shape_w = int(pool2d.op.attrs["pool_shape_w"]) + + subkernels = len( + device_config.get_kernel_steps(pool2d.op.name, pool_shape_h, pool_shape_w, ifm_dtype) + ) + + output_layout = convert_to_nhcwb16.op.attrs["layout"] + input_layout = convert_to_nhwc.op.attrs["layout"] + output_quantum = device_config.get_output_quantum(output_layout) + + valid_block_configs = device_config.get_valid_block_configs( + propagators[0], + pool2d.op.attrs, + output_tensor.shape, + ofm_channels, + ifm_channels, + output_layout, + input_layout, + ifm_dtype, + ofm_dtype, + pool_shape_h, + pool_shape_w, + ) + + return EthosuPart( + subgraph, + propagators, + output_quantum, + subkernels, + valid_block_configs, + ) diff --git a/python/tvm/relay/backend/contrib/ethosu/te/unary_elementwise.py b/python/tvm/relay/backend/contrib/ethosu/te/unary_elementwise.py index 0aefc1c35d4c..68d1c603ad98 100644 --- a/python/tvm/relay/backend/contrib/ethosu/te/unary_elementwise.py +++ b/python/tvm/relay/backend/contrib/ethosu/te/unary_elementwise.py @@ -17,7 +17,9 @@ # pylint: disable=invalid-name,unused-argument """Tensor Expressions for unary_elementwise for the NPU""" +import numpy as np from tvm import te +from tvm.contrib.ethosu.cascader import TESubgraph, EthosuPart, Propagator, register_matcher from .dma import dma_ofm_compute, dma_ifm_compute @@ -127,5 +129,119 @@ def clz_imp(inp): attrs=unary_elementwise_attrs, ) + nhwc_to_nhcwb16 = [ + [1, 0, 0, 0, 0], + [0, 1, 0, 0, 0], + [0, 0, 0, 1 / 16, 0], + [0, 0, 1, 0, 0], + [0, 0, 0, 0, 16], + [0, 0, 0, 0, 1], + ] + nhcwb16_to_nhwc = [ + [1, 0, 0, 0, 0, 0], + [0, 1, 0, 0, 0, 0], + [0, 0, 0, 1, 0, 0], + [0, 0, 16, 0, 1, -16], + [0, 0, 0, 0, 0, 1], + ] + ifm_matrix = [ + [1, 0, 0, 0, 0], + [0, 1, 0, 0, 0], + [0, 0, 1, 0, 0], + [0, 0, 0, 1, 0], + [0, 0, 0, 0, 1], + ] + if ofm_layout == "NHCWB16": + ifm_matrix = np.matmul(ifm_matrix, nhcwb16_to_nhwc).tolist() + if ifm_layout == "NHCWB16": + ifm_matrix = np.matmul(nhwc_to_nhcwb16, ifm_matrix).tolist() + + ifm_propagator = Propagator( + ifm_matrix, + [0, 0, 0, 0] if ifm_layout == "NHWC" else [0, 0, 0, 0, 0], + ) + propagator_attrs = {"ifm_propagator": ifm_propagator} + # Compute operation for the OFM DMA pipeline - return dma_ofm_compute(unary_elementwise, ofm_layout, ofm_zero_point, ofm_scale, ofm_channels) + return dma_ofm_compute( + unary_elementwise, + ofm_layout, + ofm_zero_point, + ofm_scale, + ofm_channels, + attrs=propagator_attrs, + ) + + +@register_matcher +def match_ethosu_unary_elementwise(output_tensor, device_config): + """Match a Tensor Expression corresponding to an NPU Unary Elementwise. + + If the Tensor Expression matches, an EthosuPart will be created that models the + matched Tensor Expression. Otherwise, None will be returned. + + Parameters + ---------- + output_tensor : tvm.te.Tensor + The tensor to attempt to match with. + device_config : EthosuDeviceConfig + Target device configuration + + Returns + ------- + Union[None, EthosuPart] + The created EthosuPart if there was a match, otherwise None. + + """ + write = output_tensor + if write.op.name != "ethosu_write": + return None + convert_to_nhcwb16 = write.op.input_tensors[0] + if convert_to_nhcwb16.op.name != "ethosu_convert_to_nhcwb16": + return None + unary_elementwise = convert_to_nhcwb16.op.input_tensors[0] + if unary_elementwise.op.name != "ethosu_unary_elementwise": + return None + pad = unary_elementwise.op.input_tensors[0] + if pad.op.name != "ethosu_pad": + return None + convert_to_nhwc = pad.op.input_tensors[0] + if convert_to_nhwc.op.name != "ethosu_convert_to_nhwc": + return None + read = convert_to_nhwc.op.input_tensors[0] + if read.op.name != "ethosu_read": + return None + + input_tensors = [ + read.op.input_tensors[0], + ] + subgraph = TESubgraph(input_tensors, output_tensor) + propagators = [ + write.op.attrs["ifm_propagator"], + ] + ifm_dtype = input_tensors[0].dtype + ofm_dtype = output_tensor.dtype + + output_layout = convert_to_nhcwb16.op.attrs["layout"] + input_layout = convert_to_nhwc.op.attrs["layout"] + output_quantum = device_config.get_output_quantum(output_layout) + + block_config = device_config.get_elementwise_block_config( + propagators[0], + None, + unary_elementwise.op.attrs, + output_tensor.shape, + output_layout, + input_layout, + None, + ifm_dtype, + ofm_dtype, + ) + + return EthosuPart( + subgraph, + propagators, + output_quantum, + 1, + block_config, + ) diff --git a/python/tvm/relay/frontend/common.py b/python/tvm/relay/frontend/common.py index 1d4c8ad75762..f8c12ff334db 100755 --- a/python/tvm/relay/frontend/common.py +++ b/python/tvm/relay/frontend/common.py @@ -716,11 +716,11 @@ def gru_cell( b_inp, b_hid : relay.Expr bias matrices. The same order of internal parts as for weights. shape = (3 * hidden_size) r_act : relay.op - activation funtion for reset gate. it is sigmoid by default + activation function for reset gate. it is sigmoid by default z_act : relay.op - activation funtion for update gate. it is sigmoid by default + activation function for update gate. it is sigmoid by default n_act : relay.op - activation funtion for new gate. it is tanh by default + activation function for new gate. it is tanh by default backwards : bool Flag for reverse pass of GRU @@ -812,7 +812,7 @@ def lstm_cell( p_i, p_f, p_o : relay.Expr peephole LSTM matrices. shape = (batch, hidden_size) f_act, g_act, h_act : relay.op - activation funtions + activation functions backwards : bool Flag for reverse pass of LSTM diff --git a/python/tvm/relay/prelude.py b/python/tvm/relay/prelude.py index 542980561e78..f21e3eaf2c3c 100644 --- a/python/tvm/relay/prelude.py +++ b/python/tvm/relay/prelude.py @@ -73,7 +73,7 @@ def get_tensor_array_shape(expr, dtype, prelude): return None -def _get_name_static(canonical, dtype, shape, batch_dim=None): +def _get_name_static(canonical, dtype, shape, batch_dim=None, extra_shapes=None): """Get name for static shape tensor array op By design, static ADT tensor in TVM has type name in the format @@ -100,14 +100,12 @@ def _get_name_static(canonical, dtype, shape, batch_dim=None): name : String The tensor array op name """ - dim_names = [] - for dim in shape: - if isinstance(dim, Any): - dim_names.append("any") - else: - dim_names.append(str(dim)) + shape_str = _to_str(shape) - shape_str = "_".join(dim_names) + if extra_shapes is not None: + for n, s in extra_shapes.items(): + extra_shape_str = "_{}_{}".format(n, _to_str(s)) + shape_str += extra_shape_str if len(shape_str) == 0: shape_str = "scalar" @@ -120,6 +118,16 @@ def _get_name_static(canonical, dtype, shape, batch_dim=None): return "{}_{}_batch{}_{}".format(canonical, dtype, str(batch_dim), shape_str) +def _to_str(shape): + dim_names = [] + for dim in shape: + if isinstance(dim, Any): + dim_names.append("any") + else: + dim_names.append(str(dim)) + return "_".join(dim_names) + + class StaticTensorArrayOps(object): """Contains tensor array related ops for fixed rank tensor array""" @@ -131,9 +139,9 @@ def __init__(self, prelude, dtype, shape, batch_dim=None): self.batch_dim = batch_dim self.list, self.cons, self.nil = self.prelude.mod.get_type("List") - def get_name(self, canonical): + def get_name(self, canonical, extra_shapes=None): """Get name corresponding to the canonical name""" - return _get_name_static(canonical, self.dtype, self.shape, self.batch_dim) + return _get_name_static(canonical, self.dtype, self.shape, self.batch_dim, extra_shapes) def get_global_var(self, canonical): """Get global corresponding to the canonical name""" @@ -408,11 +416,16 @@ def define_tensor_array_scatter(self, indices_shape=None, force_update=False): # When this operator has already been registered, only update # when force_update is set. This should be used only when we need to # redefine this op for static indices shape. - tensor_array_scatter_name = self.get_name("tensor_array_scatter") + + extra_shapes = {"indices": indices_shape} if indices_shape is not None else None + tensor_array_scatter_name = self.get_name("tensor_array_scatter", extra_shapes) if hasattr(self.prelude, tensor_array_scatter_name) and not force_update: return - tensor_array_scatter_helper_name = self.get_name("tensor_array_scatter_helper") + tensor_array_scatter_helper_name = self.get_name( + "tensor_array_scatter_helper", extra_shapes + ) + tensor_array_scatter_helper_var = self._create_global_var(tensor_array_scatter_helper_name) ta = Var("ta", self.list(self.tensor_type_var())) current = Var("current", scalar_type("int32")) diff --git a/python/tvm/runtime/ndarray.py b/python/tvm/runtime/ndarray.py index 2b9f7f9446ba..8400a5998e39 100644 --- a/python/tvm/runtime/ndarray.py +++ b/python/tvm/runtime/ndarray.py @@ -127,7 +127,7 @@ def __setitem__(self, in_slice, value): raise TypeError("type %s not supported" % str(type(value))) def copyfrom(self, source_array): - """Peform an synchronize copy from the array. + """Perform an synchronize copy from the array. Parameters ---------- diff --git a/python/tvm/topi/cuda/batch_matmul_tensorcore.py b/python/tvm/topi/cuda/batch_matmul_tensorcore.py index d05984b91393..ac16dd7b65b4 100644 --- a/python/tvm/topi/cuda/batch_matmul_tensorcore.py +++ b/python/tvm/topi/cuda/batch_matmul_tensorcore.py @@ -203,7 +203,7 @@ def _schedule(cfg, s, C): s[BF].reorder(bs, o, i, o_ii, i_ii) # Schedule for A's(B's) shared memory load - def shared_shedule(stage, strides): + def shared_schedule(stage, strides): s[stage].compute_at(s[CF], ko) bs, xo, yo = stage.op.axis s[stage].storage_align(xo, strides - 1, strides) @@ -217,8 +217,8 @@ def shared_shedule(stage, strides): s[stage].bind(tx, thread_x) s[stage].vectorize(vi) - shared_shedule(AS, AS_align) - shared_shedule(BS, BS_align) + shared_schedule(AS, AS_align) + shared_schedule(BS, BS_align) shape = (wmma_m, wmma_n, wmma_k) AL_gemm = te.placeholder((wmma_m, wmma_k), name="AL_gemm", dtype=data_dtype) diff --git a/python/tvm/topi/cuda/conv2d_nhwc_winograd.py b/python/tvm/topi/cuda/conv2d_nhwc_winograd.py index 1e368f585354..698beeac6dc4 100644 --- a/python/tvm/topi/cuda/conv2d_nhwc_winograd.py +++ b/python/tvm/topi/cuda/conv2d_nhwc_winograd.py @@ -165,7 +165,7 @@ def schedule_bgemm_tensorcore(cfg, s, bgemm, data_pack, kernel_pack): s[BF].reorder(i, o, i_ii, o_ii) # Schedule for A's(B's) shared memory load - def shared_shedule(stage, strides): + def shared_schedule(stage, strides): s[stage].compute_at(s[CF], ko) _, _, xo, yo = stage.op.axis s[stage].storage_align(xo, strides - 1, strides) @@ -179,8 +179,8 @@ def shared_shedule(stage, strides): s[stage].bind(tx, thread_x) s[stage].vectorize(vi) - shared_shedule(AS, AS_align) - shared_shedule(BS, BS_align) + shared_schedule(AS, AS_align) + shared_schedule(BS, BS_align) shape = (wmma_m, wmma_n, wmma_k) in_dtype = "float16" diff --git a/python/tvm/topi/cuda/dense_tensorcore.py b/python/tvm/topi/cuda/dense_tensorcore.py index 20ff1aaccc5f..7acc1307f84c 100644 --- a/python/tvm/topi/cuda/dense_tensorcore.py +++ b/python/tvm/topi/cuda/dense_tensorcore.py @@ -238,7 +238,7 @@ def _schedule_dense_tensorcore(cfg, s, C): s[BF].reorder(o, i, o_ii, i_ii) # Schedule for A's(B's) shared memory load - def shared_shedule(stage, strides): + def shared_schedule(stage, strides): s[stage].compute_at(s[CF], ko) xo, yo = stage.op.axis s[stage].storage_align(xo, strides - 1, strides) @@ -252,8 +252,8 @@ def shared_shedule(stage, strides): s[stage].bind(tx, thread_x) s[stage].vectorize(vi) - shared_shedule(AS, AS_align) - shared_shedule(BS, BS_align) + shared_schedule(AS, AS_align) + shared_schedule(BS, BS_align) shape = (wmma_m, wmma_n, wmma_k) AL_gemm = te.placeholder((wmma_m, wmma_k), name="AL_gemm", dtype=data_dtype) diff --git a/src/arith/int_operator.h b/src/arith/int_operator.h index eff52308f389..6dec9a5502e1 100644 --- a/src/arith/int_operator.h +++ b/src/arith/int_operator.h @@ -78,7 +78,7 @@ inline bool WillOverflow(int64_t x, int64_t y, int64_t min_value, } /*! - * \brief Peform trunc division of two integers. + * \brief Perform trunc division of two integers. * \param x The left operand. * \param y The right operand. * \return the result. @@ -94,7 +94,7 @@ inline int64_t truncdiv(int64_t x, int64_t y) { return x / y; } inline int64_t truncmod(int64_t x, int64_t y) { return x % y; } /*! - * \brief Peform floor division of two integers. + * \brief Perform floor division of two integers. * \param x The left operand. * \param y The right operand. * \return the result. diff --git a/src/contrib/ethosu/cascader/parts/ethosu.cc b/src/contrib/ethosu/cascader/parts/ethosu.cc index c5f236761ba0..cdbbda18c142 100644 --- a/src/contrib/ethosu/cascader/parts/ethosu.cc +++ b/src/contrib/ethosu/cascader/parts/ethosu.cc @@ -56,7 +56,6 @@ const std::vector EthosuPartNode::GetBytesRead(const std::vector& int i = 0; for (const auto& input_block_config : input_block_configs) { std::map, int> input_blocks = CountStripes(input_block_config, false); - for (const auto& block : input_blocks) { bytes_per_input[i] += mul_reduce(block.first) * block.second; } @@ -82,8 +81,8 @@ const BlockConfig EthosuPartNode::GetBlockConfig(const StripeConfig& output_stri bytes_per_input[0] *= subkernels_; // Calculate bytes read per output element - float relative_cost = - (bytes_per_input[0] + bytes_per_input[1]) / mul_reduce(output_stripe_shape); + float relative_cost = static_cast(bytes_per_input[0] + bytes_per_input[1]) / + mul_reduce(output_stripe_shape); // Single buffering hardware optimization if (mul_reduce(output_stripe_shape) <= 2 * mul_reduce(output_block)) { @@ -116,7 +115,8 @@ const PerformanceInfo EthosuPartNode::GetPerformanceInfo(const StripeConfig& out output_stripe_config->GetStripes()[i]) / block_shape[i]; } else { - num_blocks *= static_cast(output_stripe_config->GetExtent()[i]) / block_shape[i]; + num_blocks *= + std::max(static_cast(output_stripe_config->GetExtent()[i]) / block_shape[i], 1.0f); } } float num_stripes = mul_reduce(output_stripe_config->GetStripes()) - 1.0f; diff --git a/src/relay/backend/te_compiler_cache.cc b/src/relay/backend/te_compiler_cache.cc index 5ed93914ac53..4b77cb14d48b 100644 --- a/src/relay/backend/te_compiler_cache.cc +++ b/src/relay/backend/te_compiler_cache.cc @@ -192,7 +192,7 @@ class ScheduleBuilder : public backend::MemoizedExprTranslator } } - // Use TOPI schdule if user specificed, or the function has no auto_scheduler schedule. + // Use TOPI schedule if user specificed, or the function has no auto_scheduler schedule. if (!schedule.defined() && !prim_func.defined()) { ICHECK(anchor_implementation_.defined()); schedule = anchor_implementation_.Schedule(anchor_attrs_, tensor_outs, target_); diff --git a/src/relay/ir/expr_functor.cc b/src/relay/ir/expr_functor.cc index 2d6f75ae3948..b710c2791acf 100644 --- a/src/relay/ir/expr_functor.cc +++ b/src/relay/ir/expr_functor.cc @@ -166,7 +166,7 @@ Expr ExprMutator::VisitExpr_(const VarNode* var_node) { if (var_node->type_annotation.defined()) { type_annotation = this->VisitType(var_node->type_annotation); } - return WithFields(GetRef(var_node), std::move(var_node->vid), std::move(type_annotation)); + return WithFields(GetRef(var_node), var_node->vid, type_annotation); } Expr ExprMutator::VisitExpr_(const ConstantNode* op) { return GetRef(op); } @@ -183,7 +183,7 @@ Expr ExprMutator::VisitExpr_(const TupleNode* tuple_node) { auto new_field = this->Mutate(field); fields.push_back(new_field); } - return WithFields(GetRef(tuple_node), std::move(fields)); + return WithFields(GetRef(tuple_node), fields); } Expr ExprMutator::VisitExpr_(const FunctionNode* func_node) { @@ -203,8 +203,7 @@ Expr ExprMutator::VisitExpr_(const FunctionNode* func_node) { auto ret_type = this->VisitType(func_node->ret_type); auto body = this->Mutate(func_node->body); - return WithFields(GetRef(func_node), std::move(params), std::move(body), - std::move(ret_type), std::move(ty_params)); + return WithFields(GetRef(func_node), params, body, ret_type, ty_params); } Expr ExprMutator::VisitExpr_(const CallNode* call_node) { @@ -225,8 +224,7 @@ Expr ExprMutator::VisitExpr_(const CallNode* call_node) { call_args.push_back(new_arg); } - return WithFields(GetRef(call_node), std::move(new_op), std::move(call_args), {}, - std::move(ty_args)); + return WithFields(GetRef(call_node), new_op, call_args, {}, ty_args); } Expr ExprMutator::VisitExpr_(const LetNode* let_node) { @@ -234,7 +232,7 @@ Expr ExprMutator::VisitExpr_(const LetNode* let_node) { auto value = this->Mutate(let_node->value); auto body = this->Mutate(let_node->body); - return WithFields(GetRef(let_node), std::move(var), std::move(value), std::move(body)); + return WithFields(GetRef(let_node), var, value, body); } Expr ExprMutator::VisitExpr_(const IfNode* if_node) { @@ -242,28 +240,28 @@ Expr ExprMutator::VisitExpr_(const IfNode* if_node) { auto true_b = this->Mutate(if_node->true_branch); auto false_b = this->Mutate(if_node->false_branch); - return WithFields(GetRef(if_node), std::move(cond), std::move(true_b), std::move(false_b)); + return WithFields(GetRef(if_node), cond, true_b, false_b); } Expr ExprMutator::VisitExpr_(const TupleGetItemNode* get_item) { Expr tuple = this->Mutate(get_item->tuple); - return WithFields(GetRef(get_item), std::move(tuple)); + return WithFields(GetRef(get_item), tuple); } Expr ExprMutator::VisitExpr_(const RefCreateNode* ref_create) { Expr value = this->Mutate(ref_create->value); - return WithFields(GetRef(ref_create), std::move(value)); + return WithFields(GetRef(ref_create), value); } Expr ExprMutator::VisitExpr_(const RefReadNode* ref_read) { Expr ref = this->Mutate(ref_read->ref); - return WithFields(GetRef(ref_read), std::move(ref)); + return WithFields(GetRef(ref_read), ref); } Expr ExprMutator::VisitExpr_(const RefWriteNode* ref_write) { Expr ref = this->Mutate(ref_write->ref); Expr value = this->Mutate(ref_write->value); - return WithFields(GetRef(ref_write), std::move(ref), std::move(value)); + return WithFields(GetRef(ref_write), ref, value); } Expr ExprMutator::VisitExpr_(const ConstructorNode* c) { return GetRef(c); } @@ -275,13 +273,13 @@ Expr ExprMutator::VisitExpr_(const MatchNode* match_node) { } Expr data = Mutate(match_node->data); - return WithFields(GetRef(match_node), std::move(data), std::move(clauses)); + return WithFields(GetRef(match_node), data, clauses); } Clause ExprMutator::VisitClause(const Clause& clause) { Pattern lhs = VisitPattern(clause->lhs); Expr rhs = Mutate(clause->rhs); - return WithFields(std::move(clause), std::move(lhs), std::move(rhs)); + return WithFields(clause, lhs, rhs); } Pattern ExprMutator::VisitPattern(const Pattern& p) { return p; } @@ -462,7 +460,7 @@ class ExprBinder : public MixedModeMutator, PatternMutator { Clause VisitClause(const Clause& clause) final { Pattern lhs = VisitPattern(clause->lhs); - return WithFields(std::move(clause), std::move(lhs), VisitExpr(clause->rhs)); + return WithFields(clause, lhs, VisitExpr(clause->rhs)); } Var VisitVar(const Var& v) final { diff --git a/src/relay/transforms/annotate_target.cc b/src/relay/transforms/annotate_target.cc index df1a858f8d0b..6e4ab88ea326 100644 --- a/src/relay/transforms/annotate_target.cc +++ b/src/relay/transforms/annotate_target.cc @@ -270,7 +270,7 @@ class AnnotateTargetRewriter : public ExprRewriter { auto tuple = Downcast(post); auto target_n_args = AnnotateArgs(tuple->fields); - auto new_expr = WithFields(std::move(tuple), std::move(std::get<1>(target_n_args))); + auto new_expr = WithFields(tuple, std::get<1>(target_n_args)); op_expr_to_target_[new_expr] = std::get<0>(target_n_args); return std::move(new_expr); } @@ -378,7 +378,7 @@ class CallOpsTargetRewriter : public AnnotateTargetRewriter { for (auto f : tuple->fields) { new_fields.push_back(InsertCompilerEndAndPropogateTarget(f)); } - return WithFields(std::move(tuple), std::move(new_fields)); + return WithFields(tuple, new_fields); } Expr Rewrite_(const TupleGetItemNode* op, const Expr& post) override { diff --git a/src/relay/transforms/device_planner.cc b/src/relay/transforms/device_planner.cc index d40dd6c95089..c7ca2227fd90 100644 --- a/src/relay/transforms/device_planner.cc +++ b/src/relay/transforms/device_planner.cc @@ -334,7 +334,7 @@ class RewriteOnDevices : public ExprMutator { Expr tuple = VisitExpr(tuple_get_item_node->tuple); OnDeviceProps props = GetOnDeviceProps(tuple); - Expr tuple_get_item = WithFields(GetRef(tuple_get_item_node), std::move(tuple)); + Expr tuple_get_item = WithFields(GetRef(tuple_get_item_node), tuple); if (props.body.defined() && props.is_normal()) { VLOG(2) << "wrapping tuple get item:" << std::endl << PrettyPrint(GetRef(tuple_get_item_node)) << std::endl @@ -363,8 +363,8 @@ class RewriteOnDevices : public ExprMutator { } expr = VisitExpr(expr); for (auto itr = bindings.rbegin(); itr != bindings.rend(); ++itr) { - expr = WithFields(/*let=*/std::move(std::get<0>(*itr)), /*opt_var=*/{}, - /*opt_value=*/std::move(std::get<1>(*itr)), /*opt_body=*/std::move(expr)); + expr = WithFields(/*let=*/std::get<0>(*itr), /*opt_var=*/{}, + /*opt_value=*/std::get<1>(*itr), /*opt_body=*/expr); } return expr; } @@ -378,7 +378,7 @@ class RewriteOnDevices : public ExprMutator { << "to be fixed to VirtualDevice " << props.virtual_device; body = MaybeOnDeviceFixed(props.body, props.virtual_device); } - return WithFields(GetRef(function_node), function_node->params, std::move(body)); + return WithFields(GetRef(function_node), function_node->params, body); } Expr VisitExpr_(const CallNode* call_node) final { @@ -990,7 +990,7 @@ class DeviceCapturer : public ExprMutator { for (const auto& field : tuple_node->fields) { fields.push_back(VisitChild(tuple, field)); } - return WithFields(std::move(tuple), std::move(fields)); + return WithFields(tuple, fields); } Expr VisitExpr_(const FunctionNode* function_node) final { @@ -1025,8 +1025,7 @@ class DeviceCapturer : public ExprMutator { /*expected_virtual_device=*/result_virtual_device, /*child_virtual_device=*/GetVirtualDevice(function_node->body), function_node->body); - Function func = WithFields(GetRef(function_node), std::move(function_node->params), - std::move(body)); + Function func = WithFields(GetRef(function_node), function_node->params, body); return FunctionOnDevice(func, std::move(param_virtual_devices), std::move(result_virtual_device)); } @@ -1102,9 +1101,9 @@ class DeviceCapturer : public ExprMutator { if (call_node->op == CallLoweredOp()) { Call new_call = CallLowered(Downcast(op), args, /*call_lowered_attrs=*/{}, /*span=*/{}); - return WithFields(call, std::move(new_call->op), std::move(new_call->args)); + return WithFields(call, new_call->op, new_call->args); } else { - return WithFields(call, std::move(op), std::move(args)); + return WithFields(call, op, args); } } @@ -1145,33 +1144,32 @@ class DeviceCapturer : public ExprMutator { Expr cond = VisitChild(ife, if_node->cond); Expr true_branch = VisitChild(ife, if_node->true_branch); Expr false_branch = VisitChild(ife, if_node->false_branch); - return WithFields(std::move(ife), std::move(cond), std::move(true_branch), - std::move(false_branch)); + return WithFields(ife, cond, true_branch, false_branch); } Expr VisitExpr_(const TupleGetItemNode* tuple_get_item_node) final { auto tuple_get_item = GetRef(tuple_get_item_node); Expr tuple = VisitChild(tuple_get_item, tuple_get_item_node->tuple); - return WithFields(std::move(tuple_get_item), std::move(tuple)); + return WithFields(tuple_get_item, tuple); } Expr VisitExpr_(const RefCreateNode* ref_create_node) final { auto ref_create = GetRef(ref_create_node); Expr value = VisitChild(ref_create, ref_create_node->value); - return WithFields(std::move(ref_create), std::move(value)); + return WithFields(ref_create, value); } Expr VisitExpr_(const RefReadNode* ref_read_node) final { auto ref_read = GetRef(ref_read_node); Expr ref = VisitChild(ref_read, ref_read_node->ref); - return WithFields(std::move(ref_read), std::move(ref)); + return WithFields(ref_read, ref); } Expr VisitExpr_(const RefWriteNode* ref_write_node) final { auto ref_write = GetRef(ref_write_node); Expr ref = VisitChild(ref_write, ref_write_node->ref); Expr value = VisitChild(ref_write, ref_write_node->value); - return WithFields(std::move(ref_write), std::move(ref), std::move(value)); + return WithFields(ref_write, ref, value); } Expr VisitExpr_(const MatchNode* match_node) final { @@ -1184,7 +1182,7 @@ class DeviceCapturer : public ExprMutator { Expr rhs = VisitChild(match, clause->rhs); clauses.push_back(Clause(lhs, rhs)); } - return WithFields(std::move(match), std::move(data), std::move(clauses)); + return WithFields(match, data, clauses); } VirtualDevice GetVirtualDevice(const Expr& expr) { diff --git a/src/relay/transforms/first_order_gradient.cc b/src/relay/transforms/first_order_gradient.cc index 9408d16d87e9..d695c6dc491d 100644 --- a/src/relay/transforms/first_order_gradient.cc +++ b/src/relay/transforms/first_order_gradient.cc @@ -211,7 +211,7 @@ struct FirstOrderReverseAD : ExprFunctor { field_bindings.push_back(f_ad->get().forward); } // reconstruct tuple using let-bound variables to avoid duplication - auto orig = WithFields(GetRef(tuple_node), std::move(field_bindings)); + auto orig = WithFields(GetRef(tuple_node), field_bindings); orig->checked_type_ = tt; auto ret = std::make_shared(ll, orig, diag_ctx); // for orig = tuple(x1, ..., xn), tuple_grad(x1, ..., xn, G) = [pi(G, 1), ..., pi(G, n)] diff --git a/src/relay/transforms/forward_rewrite.cc b/src/relay/transforms/forward_rewrite.cc index 23c45a90a5e3..0e7e9076ae0e 100644 --- a/src/relay/transforms/forward_rewrite.cc +++ b/src/relay/transforms/forward_rewrite.cc @@ -122,7 +122,7 @@ class ForwardRewriter : private MixedModeMutator { fields.push_back(this->GetTempExpr(tuple_node->fields[i], post_tuple_node->fields[i])); } - return WithFields(GetRef(tuple_node), std::move(fields)); + return WithFields(GetRef(tuple_node), fields); } Expr Rewrite_(const CallNode* call_node, const Expr& post) final { diff --git a/src/relay/transforms/fuse_ops.cc b/src/relay/transforms/fuse_ops.cc index f2fc0af4f9c1..5037b32ce615 100644 --- a/src/relay/transforms/fuse_ops.cc +++ b/src/relay/transforms/fuse_ops.cc @@ -905,7 +905,7 @@ class FuseMutator : private MixedModeMutator { } // This tuple is an intermediate node in the group Array new_fields = GetNewArguments(tuple_node->fields, ret_group); - return WithFields(GetRef(tuple_node), std::move(new_fields)); + return WithFields(GetRef(tuple_node), new_fields); } Expr Rewrite_(const TupleGetItemNode* tuple_get, const Expr& post) { diff --git a/src/relay/transforms/memory_alloc.cc b/src/relay/transforms/memory_alloc.cc index d1f0f69c5e93..900442e9b9a8 100644 --- a/src/relay/transforms/memory_alloc.cc +++ b/src/relay/transforms/memory_alloc.cc @@ -92,7 +92,7 @@ class DialectRewriter : public transform::DeviceAwareExprMutator { } new_fields.push_back(new_field); } - return WithFields(GetRef(tuple_node), std::move(new_fields)); + return WithFields(GetRef(tuple_node), new_fields); } void PreVisitLetBlock_(const LetNode* let_node) final { scopes_.emplace_back(); } diff --git a/src/relay/transforms/partition_graph.cc b/src/relay/transforms/partition_graph.cc index c22519c441c2..d1b9b563e932 100644 --- a/src/relay/transforms/partition_graph.cc +++ b/src/relay/transforms/partition_graph.cc @@ -470,7 +470,7 @@ IRModule FlattenTupleOutputs(IRModule module) { // Return a tuple of compiler_ends in the place of the tuple that was // annotated with a compiler_end. - return WithFields(GetRef(tuple_node), std::move(new_fields)); + return WithFields(GetRef(tuple_node), new_fields); } } return post; diff --git a/src/relay/transforms/split_args.cc b/src/relay/transforms/split_args.cc index fbb2d73d1db0..a5266df8b057 100644 --- a/src/relay/transforms/split_args.cc +++ b/src/relay/transforms/split_args.cc @@ -59,12 +59,12 @@ class ArgumentSplitter : public ExprRewriter { for (int j = 0; j < argsCount; ++j) { args.push_back(tuple_node->fields[j + startIdx]); } - Tuple new_tuple = WithFields(GetRef(tuple_node), std::move(args)); + Tuple new_tuple = WithFields(GetRef(tuple_node), args); Expr body = MakeConcatenate(new_tuple, param->axis); splitted[i] = StopFusion(body); } tvm::Array tuple_args(splitted); - Tuple new_tuple = WithFields(GetRef(tuple_node), std::move(tuple_args)); + Tuple new_tuple = WithFields(GetRef(tuple_node), tuple_args); return MakeConcatenate(new_tuple, param->axis); } return post; diff --git a/src/relay/transforms/to_a_normal_form.cc b/src/relay/transforms/to_a_normal_form.cc index 321839d81e3e..f6d5ac9cf8bb 100644 --- a/src/relay/transforms/to_a_normal_form.cc +++ b/src/relay/transforms/to_a_normal_form.cc @@ -255,7 +255,7 @@ class Fill : ExprFunctor, private transform::Lexi for (const auto& a : tuple_node->fields) { fields.push_back(VisitExpr(a)); } - return Compound(e, WithFields(GetRef(tuple_node), std::move(fields)), v); + return Compound(e, WithFields(GetRef(tuple_node), fields), v); } Expr VisitExpr_(const TupleGetItemNode* t, const Var& v) final { diff --git a/src/relay/transforms/to_cps.cc b/src/relay/transforms/to_cps.cc index 0f889cd6ff7f..c5d17fbfbef7 100644 --- a/src/relay/transforms/to_cps.cc +++ b/src/relay/transforms/to_cps.cc @@ -216,7 +216,7 @@ Function ToCPS(const Function& f, const IRModule& m, CPSMap* cm, VarMap* vm, std::function next; next = [&]() { return (fields.size() == tuple_node->fields.size()) - ? k(WithFields(GetRef(tuple_node), std::move(fields))) + ? k(WithFields(GetRef(tuple_node), fields)) : VisitExpr(tuple_node->fields[fields.size()], [&](const Expr& v) { fields.push_back(v); return next(); diff --git a/src/relay/transforms/transform_layout.h b/src/relay/transforms/transform_layout.h index 56affb581fd1..3dbf10e0611b 100644 --- a/src/relay/transforms/transform_layout.h +++ b/src/relay/transforms/transform_layout.h @@ -300,7 +300,7 @@ Expr LayoutRewriter(const Call& ref_call, const Array& new_args, const Obj Expr tmp = push_back_one_arg(x); fields.push_back(tmp); } - normal_new_args.push_back(WithFields(tuple_new_arg, std::move(fields))); + normal_new_args.push_back(WithFields(tuple_new_arg, fields)); } else { Expr tmp = push_back_one_arg(new_arg); normal_new_args.push_back(tmp); @@ -383,7 +383,7 @@ Expr LayoutRewriter(const Call& ref_call, const Array& new_args, const Obj transformed_tuple_arg.push_back(memorizer.Transform(arg_item, new_in[pt], new_in2[pt])); pt++; } - transformed_args.push_back(WithFields(tuple_arg, std::move(transformed_tuple_arg))); + transformed_args.push_back(WithFields(tuple_arg, transformed_tuple_arg)); } else { transformed_args.push_back(memorizer.Transform(arg, new_in[pt], new_in2[pt])); pt++; diff --git a/src/runtime/rpc/rpc_local_session.h b/src/runtime/rpc/rpc_local_session.h index d1b54d5be65b..a081cf97db4a 100644 --- a/src/runtime/rpc/rpc_local_session.h +++ b/src/runtime/rpc/rpc_local_session.h @@ -60,7 +60,7 @@ class LocalSession : public RPCSession { protected: /*! - * \brief internal encode return fucntion. + * \brief internal encode return function. * \param rv The return value. * \param encode_return The encoding function. */ diff --git a/src/support/ring_buffer.h b/src/support/ring_buffer.h index af814158f7b6..1c6a6f8b4350 100644 --- a/src/support/ring_buffer.h +++ b/src/support/ring_buffer.h @@ -87,7 +87,7 @@ class RingBuffer { } /*! - * \brief Peform a non-blocking read from buffer + * \brief Perform a non-blocking read from buffer * size must be smaller than this->bytes_available() * \param data the data pointer. * \param size The number of bytes to read. diff --git a/src/support/socket.h b/src/support/socket.h index a83a67c85d76..42d5d9004c15 100644 --- a/src/support/socket.h +++ b/src/support/socket.h @@ -516,7 +516,7 @@ class TCPSocket : public Socket { [&]() { return recv(sockfd, buf, static_cast(len), flags); }); } /*! - * \brief peform block write that will attempt to send all data out + * \brief perform block write that will attempt to send all data out * can still return smaller than request when error occurs * \param buf_ the pointer to the buffer * \param len the size of the buffer @@ -538,7 +538,7 @@ class TCPSocket : public Socket { return ndone; } /*! - * \brief peform block read that will attempt to read all data + * \brief perform block read that will attempt to read all data * can still return smaller than request when error occurs * \param buf_ the buffer pointer * \param len length of data to recv @@ -654,7 +654,7 @@ struct PollHelper { } /*! - * \brief peform poll on the set defined, read, write, exception + * \brief perform poll on the set defined, read, write, exception * \param timeout specify timeout in milliseconds(ms) if negative, means poll will block * \return */ diff --git a/src/te/operation/create_primfunc.cc b/src/te/operation/create_primfunc.cc index 5de0538960fc..94dd0b044d71 100644 --- a/src/te/operation/create_primfunc.cc +++ b/src/te/operation/create_primfunc.cc @@ -282,7 +282,7 @@ PrimFunc CreatePrimFunc(const Array& arg_list) { << "Only te.placeholder and te.compute are allowed for now."; } - // Infomations used in CreatePrimFunc and its sub-funtions. + // Infomations used in CreatePrimFunc and its sub-functions. CreateFuncInfo info(arg_list); // Root body stmts. Array root_stmts; diff --git a/src/te/operation/tensorize.cc b/src/te/operation/tensorize.cc index 9e2d3d0e725f..b31b61b739c1 100644 --- a/src/te/operation/tensorize.cc +++ b/src/te/operation/tensorize.cc @@ -35,7 +35,7 @@ namespace te { using namespace tir; -// Detect the region of input and output to be tensrized. +// Detect the region of input and output to be tensorized. // out_dom: the domain of root iter vars in output op // in_region: region of each input tensor. // return The location of the tensorized scope start. diff --git a/src/tir/schedule/primitive/cache_read_write.cc b/src/tir/schedule/primitive/cache_read_write.cc index 159171ecae31..4a80279d97cb 100644 --- a/src/tir/schedule/primitive/cache_read_write.cc +++ b/src/tir/schedule/primitive/cache_read_write.cc @@ -635,7 +635,7 @@ StmtSRef CacheRead(ScheduleState self, const StmtSRef& block_sref, int read_buff /*require_subtree_compact_dataflow=*/false); const BlockNode* scope_block = TVM_SREF_TO_BLOCK(scope_block, scope_sref); - // Step 2. Creat CacheStageInfo + // Step 2. Create CacheStageInfo CacheStageInfo info; info.read_buffer = read_buffer; // Create the corresponding buffer to be written, i.e. result of cache_read diff --git a/src/tir/transforms/coproc_sync.cc b/src/tir/transforms/coproc_sync.cc index 424a1bbb0ae6..7a6d2d37c376 100644 --- a/src/tir/transforms/coproc_sync.cc +++ b/src/tir/transforms/coproc_sync.cc @@ -450,7 +450,7 @@ class CoProcInstDepDetector : public StmtVisitor { std::unordered_set exit_ctx; // existing pop performed at enter std::vector > enter_pop; - // existing push peformed at exit + // existing push performed at exit std::vector > exit_push; // clear the state void clear() { diff --git a/tests/python/contrib/test_cmsisnn/test_conv2d.py b/tests/python/contrib/test_cmsisnn/test_conv2d.py index d8c559cec6e0..e9eb6fb3a145 100644 --- a/tests/python/contrib/test_cmsisnn/test_conv2d.py +++ b/tests/python/contrib/test_cmsisnn/test_conv2d.py @@ -124,22 +124,118 @@ def make_model( @tvm.testing.requires_cmsisnn -@pytest.mark.parametrize("ifm_shape", [(1, 25, 25, 12), (1, 64, 100, 4)]) -@pytest.mark.parametrize("kernel_size", [(5, 5)]) @pytest.mark.parametrize("padding", ["SAME", "VALID"]) -@pytest.mark.parametrize("strides, dilation", [((2, 2), (1, 1))]) @pytest.mark.parametrize("relu_type", ["RELU"]) @pytest.mark.parametrize("enable_bias", [True, False]) @pytest.mark.parametrize( "input_zero_point, input_scale, kernel_scale, out_channels", [(10, 0.0128, [0.11, 0.22], 2), (-64, 1, [1, 0.0256, 1.37], 3)], ) -def test_conv2d_int8( - ifm_shape, - kernel_size, +def test_conv2d_symmetric_padding_int8( + padding, + enable_bias, + relu_type, + input_zero_point, + input_scale, + kernel_scale, + out_channels, +): + interface_api = "c" + use_unpacked_api = True + test_runner = AOT_CORSTONE300_RUNNER + + ifm_shape = (1, 64, 100, 4) + kernel_size = (3, 3) + strides = (1, 1) + dilation = (1, 1) + dtype = "int8" + groups = 1 + weight_format = "HWIO" + kernel_h = kernel_size[0] + kernel_w = kernel_size[1] + kernel_shape = (kernel_h, kernel_w, ifm_shape[3] // groups, out_channels) + kernel_zero_point = 0 + in_min, in_max = get_range_for_dtype_str(dtype) + + output_scale, output_zero_point = get_conv2d_qnn_params( + kernel_shape, + input_scale, + input_zero_point, + kernel_scale, + kernel_zero_point, + dtype, + dtype, + dtype, + ) + + model, params = make_model( + ifm_shape, + kernel_shape, + input_zero_point, + input_scale, + kernel_zero_point, + kernel_scale, + output_zero_point, + output_scale, + padding, + strides, + dilation, + groups, + dtype, + dtype, + out_channels, + weight_format, + enable_bias, + relu_type, + ) + orig_mod = make_module(model) + cmsisnn_mod = cmsisnn.partition_for_cmsisnn(orig_mod, params) + + # validate pattern matching + attrs = [ + cmsisnn_mod[var.name_hint].attrs + for var in cmsisnn_mod.get_global_vars() + if cmsisnn_mod[var.name_hint].attrs + ] + assert any(attrs), "At least one function with external attributes was expected." + + compilers = [ + key == "Compiler" and value == "cmsis-nn" for attr in attrs for key, value in attr.items() + ] + assert any(compilers), "Module does not contain function for cmsis-nn target." + + assert count_num_calls(orig_mod) == count_num_calls( + cmsisnn_mod + ), "Number of calls changed during partitioning" + + # validate the output + rng = np.random.default_rng(12345) + inputs = {"input": rng.integers(in_min, high=in_max, size=ifm_shape, dtype=dtype)} + output_list = generate_ref_data(orig_mod["main"], inputs, params) + compile_and_run( + AOTTestModel( + module=cmsisnn_mod, + inputs=inputs, + outputs=output_list, + params=params, + output_tolerance=1, + ), + test_runner, + interface_api, + use_unpacked_api, + ) + + +@tvm.testing.requires_cmsisnn +@pytest.mark.parametrize("padding", ["SAME", "VALID"]) +@pytest.mark.parametrize("relu_type", ["RELU", "NONE"]) +@pytest.mark.parametrize("enable_bias", [True, False]) +@pytest.mark.parametrize( + "input_zero_point, input_scale, kernel_scale, out_channels", + [(10, 0.0128, [0.11, 0.22], 2), (-64, 1, [1, 0.0256, 1.37], 3)], +) +def test_conv2d_asymmetric_padding_int8( padding, - strides, - dilation, enable_bias, relu_type, input_zero_point, @@ -151,6 +247,10 @@ def test_conv2d_int8( use_unpacked_api = True test_runner = AOT_CORSTONE300_RUNNER + ifm_shape = (1, 25, 25, 12) + kernel_size = (5, 5) + strides = (2, 2) + dilation = (1, 1) dtype = "int8" groups = 1 weight_format = "HWIO" diff --git a/tests/python/contrib/test_ethosu/cascader/conftest.py b/tests/python/contrib/test_ethosu/cascader/conftest.py index 58ffb51a5967..eacf57c251a8 100644 --- a/tests/python/contrib/test_ethosu/cascader/conftest.py +++ b/tests/python/contrib/test_ethosu/cascader/conftest.py @@ -29,7 +29,11 @@ from tvm.relay.testing import run_opt_pass from .infra import create_te_graph - from ..infra import make_ethosu_conv2d + from ..infra import ( + make_ethosu_conv2d, + make_ethosu_depthwise_conv2d, + make_ethosu_binary_elementwise, + ) def make_TwoConv2DWithSliceTE(): def _get_func(): @@ -71,3 +75,62 @@ def _get_func(): @pytest.fixture def TwoConv2DWithSliceTE(): return make_TwoConv2DWithSliceTE() + + def make_MobileNetv2DiamondTE(): + def _get_func(): + ifm = relay.var("ifm", shape=(1, 56, 56, 96), dtype="int8") + conv1 = make_ethosu_conv2d( + ifm=ifm, + ifm_channels=96, + ofm_channels=24, + kernel_shape=(1, 1), + padding=(0, 0, 0, 0), + strides=(1, 1), + dilation=(1, 1), + ) + conv2 = make_ethosu_conv2d( + ifm=conv1, + ifm_channels=24, + ofm_channels=144, + kernel_shape=(1, 1), + padding=(0, 0, 0, 0), + strides=(1, 1), + dilation=(1, 1), + ) + depth1 = make_ethosu_depthwise_conv2d( + ifm=conv2, + channels=144, + kernel_shape=(3, 3), + padding=(1, 1, 1, 1), + strides=(1, 1), + dilation=(1, 1), + ) + conv3 = make_ethosu_conv2d( + ifm=depth1, + ifm_channels=144, + ofm_channels=24, + kernel_shape=(1, 1), + padding=(0, 0, 0, 0), + strides=(1, 1), + dilation=(1, 1), + ) + add1 = make_ethosu_binary_elementwise( + ifm=conv1, + ifm2=conv3, + ifm_channels=24, + ifm2_channels=24, + operator_type="ADD", + ofm_dtype="int8", + ) + func = relay.Function(relay.analysis.free_vars(add1), add1) + func = run_opt_pass(func, relay.transform.InferType()) + return func + + func = _get_func() + te_graph, const_dict = create_te_graph(func) + sch = tvm.te.create_schedule([t.op for t in te_graph.outputs]) + return sch, te_graph, const_dict + + @pytest.fixture + def MobileNetv2DiamondTE(): + return make_MobileNetv2DiamondTE() diff --git a/tests/python/contrib/test_ethosu/cascader/infra.py b/tests/python/contrib/test_ethosu/cascader/infra.py index c2b6073fb62e..5f41dce30147 100644 --- a/tests/python/contrib/test_ethosu/cascader/infra.py +++ b/tests/python/contrib/test_ethosu/cascader/infra.py @@ -29,7 +29,9 @@ def create_te_graph(func): return te_graph, consts -def make_matrices(kernel, stride, dilation, padding, ifm_channels, ifm_layout, ofm_layout): +def make_matrices( + op_type, kernel, stride, padding, ifm_layout, ofm_layout, dilation=(1, 1), ifm_channels=1 +): kernel_h, kernel_w = kernel stride_h, stride_w = stride dilation_h, dilation_w = dilation @@ -50,20 +52,51 @@ def make_matrices(kernel, stride, dilation, padding, ifm_channels, ifm_layout, o [0, 0, 16, 0, 1, -16], [0, 0, 0, 0, 0, 1], ] - ifm_matrix = [ - [1, 0, 0, 0, 0], - [0, stride_h, 0, 0, (dilated_kernel_h - stride_h)], - [0, 0, stride_w, 0, (dilated_kernel_w - stride_w)], - [0, 0, 0, 0, ifm_channels], - [0, 0, 0, 0, 1], - ] - weight_matrix = [ - [0, 0, 0, 1, 0], - [0, 0, 0, 0, kernel_h], - [0, 0, 0, 0, kernel_w], - [0, 0, 0, 0, ifm_channels], - [0, 0, 0, 0, 1], - ] + if op_type == "ethosu_conv2d": + ifm_matrix = [ + [1, 0, 0, 0, 0], + [0, stride_h, 0, 0, (dilated_kernel_h - stride_h)], + [0, 0, stride_w, 0, (dilated_kernel_w - stride_w)], + [0, 0, 0, 0, ifm_channels], + [0, 0, 0, 0, 1], + ] + weight_matrix = [ + [0, 0, 0, 1, 0], + [0, 0, 0, 0, kernel_h], + [0, 0, 0, 0, kernel_w], + [0, 0, 0, 0, ifm_channels], + [0, 0, 0, 0, 1], + ] + elif op_type == "ethosu_depthwise_conv2d": + ifm_matrix = [ + [1, 0, 0, 0, 0], + [0, stride_h, 0, 0, (dilated_kernel_h - stride_h)], + [0, 0, stride_w, 0, (dilated_kernel_w - stride_w)], + [0, 0, 0, 1, 0], + [0, 0, 0, 0, 1], + ] + weight_matrix = [ + [0, 0, 0, 1, 0], + [0, 0, 0, 0, kernel_h], + [0, 0, 0, 0, kernel_w], + [0, 0, 0, 0, 1], + [0, 0, 0, 0, 1], + ] + elif op_type == "ethosu_pooling": + ifm_matrix = [ + [1, 0, 0, 0, 0], + [0, stride_h, 0, 0, (dilated_kernel_h - stride_h)], + [0, 0, stride_w, 0, (dilated_kernel_w - stride_w)], + [0, 0, 0, 1, 0], + [0, 0, 0, 0, 1], + ] + weight_matrix = [ + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + ] scale_bias_matrix = [ [0, 0, 0, 1, 0], [0, 0, 0, 0, 10], diff --git a/tests/python/contrib/test_ethosu/cascader/test_ethosu_binary_elementwise_matcher.py b/tests/python/contrib/test_ethosu/cascader/test_ethosu_binary_elementwise_matcher.py new file mode 100644 index 000000000000..bb1be7b8e251 --- /dev/null +++ b/tests/python/contrib/test_ethosu/cascader/test_ethosu_binary_elementwise_matcher.py @@ -0,0 +1,209 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +import pytest + +pytest.importorskip("ethosu.vela") + +import numpy as np +import math + +from tvm import te +import tvm.contrib.ethosu.cascader as cs +from tvm.relay.backend.contrib.ethosu.te.binary_elementwise import ( + match_ethosu_binary_elementwise, + binary_elementwise_compute, +) + + +def _make_matrices(broadcast, ifm_layout, ifm2_layout, ofm_layout): + broadcast_h, broadcast_w, broadcast_c = broadcast + nhwc_to_nhcwb16 = [ + [1, 0, 0, 0, 0], + [0, 1, 0, 0, 0], + [0, 0, 0, 1 / 16, 0], + [0, 0, 1, 0, 0], + [0, 0, 0, 0, 16], + [0, 0, 0, 0, 1], + ] + nhcwb16_to_nhwc = [ + [1, 0, 0, 0, 0, 0], + [0, 1, 0, 0, 0, 0], + [0, 0, 0, 1, 0, 0], + [0, 0, 16, 0, 1, -16], + [0, 0, 0, 0, 0, 1], + ] + ifm_matrix = [ + [1, 0, 0, 0, 0], + [0, 1, 0, 0, 0], + [0, 0, 1, 0, 0], + [0, 0, 0, 1, 0], + [0, 0, 0, 0, 1], + ] + ifm2_matrix = [ + [1, 0, 0, 0, 0], + [0, (1 - broadcast_h), 0, 0, broadcast_h], + [0, 0, (1 - broadcast_w), 0, broadcast_w], + [0, 0, 0, (1 - broadcast_c), broadcast_c], + [0, 0, 0, 0, 1], + ] + if ofm_layout == "NHCWB16": + ifm_matrix = np.matmul(ifm_matrix, nhcwb16_to_nhwc).tolist() + ifm2_matrix = np.matmul(ifm2_matrix, nhcwb16_to_nhwc).tolist() + if ifm_layout == "NHCWB16": + ifm_matrix = np.matmul(nhwc_to_nhcwb16, ifm_matrix).tolist() + if ifm2_layout == "NHCWB16": + ifm2_matrix = np.matmul(nhwc_to_nhcwb16, ifm2_matrix).tolist() + + return (ifm_matrix, ifm2_matrix) + + +@pytest.mark.parametrize( + "ofm_shape", + [ + [1, 12, 15, 128], + [1, 16, 16, 16], + [1, 1, 1, 1024], + [1, 73, 51, 20], + [1, 124, 172, 5], + ], +) +@pytest.mark.parametrize("ifm2_broadcast", [[0, 0, 0], [1, 0, 0], [0, 1, 0], [1, 1, 0]]) +@pytest.mark.parametrize("ifm_layout", ["NHWC", "NHCWB16"]) +@pytest.mark.parametrize("ifm2_layout", ["NHWC", "NHCWB16"]) +@pytest.mark.parametrize("ofm_layout", ["NHWC", "NHCWB16"]) +@pytest.mark.parametrize("op_type", ["MUL", "ADD", "MIN"]) +def test_ethosu_binary_elementwise_matcher( + ofm_shape, ifm2_broadcast, ifm_layout, ifm2_layout, ofm_layout, op_type +): + ifm_shape = ofm_shape.copy() + ifm2_shape = [1] + [1 if (b == 1) else a for a, b in zip(ofm_shape[1:], ifm2_broadcast)] + ifm_channels = ifm_shape[3] + ifm2_channels = ifm2_shape[3] + nhwc_to_nhcwb16 = [ + [1, 0, 0, 0, 0], + [0, 1, 0, 0, 0], + [0, 0, 0, 1 / 16, 0], + [0, 0, 1, 0, 0], + [0, 0, 0, 0, 16], + [0, 0, 0, 0, 1], + ] + broadcast = [1 if a == 1 else 0 for a in ifm2_shape[1:]] + if ifm_layout == "NHCWB16": + ifm_shape = [ + int(math.ceil(n)) + for n in np.matmul( + nhwc_to_nhcwb16, + ifm_shape + + [ + 1, + ], + ).tolist()[:-1] + ] + if ifm2_layout == "NHCWB16": + ifm2_shape = [ + int(math.ceil(n)) + for n in np.matmul( + nhwc_to_nhcwb16, + ifm2_shape + + [ + 1, + ], + ).tolist()[:-1] + ] + if ofm_layout == "NHCWB16": + ofm_shape = [ + int(math.ceil(n)) + for n in np.matmul( + nhwc_to_nhcwb16, + ofm_shape + + [ + 1, + ], + ).tolist()[:-1] + ] + order = [1, 2, 4, 3, 0] + else: + order = [1, 2, 3, 4] + + ifm = te.placeholder(ifm_shape, dtype="int8") + ifm2 = te.placeholder(ifm2_shape, dtype="int8") + lut = te.placeholder((), dtype="uint8") + out = binary_elementwise_compute( + ifm=ifm, + ifm2=ifm2, + lut=lut, + operator_type=op_type, + ifm_scale=1, + ifm_zero_point=0, + ifm2_scale=1, + ifm2_zero_point=0, + ofm_scale=1, + ofm_zero_point=0, + ifm_channels=ifm_channels, + ifm2_channels=ifm2_channels, + reversed_operands=False, + activation="NONE", + clip_min=0, + clip_max=0, + rounding_mode="TFL", + ifm_layout=ifm_layout, + ifm2_layout=ifm2_layout, + ofm_layout=ofm_layout, + ofm_dtype="int8", + ) + ifm_propagator = out.op.attrs["ifm_propagator"] + ifm2_propagator = out.op.attrs["ifm2_propagator"] + + offset = [0] * len(ofm_shape) + stripes = [0] * len(ofm_shape) + output_stripe_config = cs.StripeConfig(ofm_shape, ofm_shape, ofm_shape, order, stripes, offset) + + (ifm_transform, ifm2_transform) = _make_matrices( + broadcast, + ifm_layout, + ifm2_layout, + ofm_layout, + ) + + device_config = cs.EthosuDeviceConfig("ethos-u55-256") + part = match_ethosu_binary_elementwise(out, device_config) + + assert isinstance(part, cs.EthosuPart) + assert len(part.propagators) == 2 + assert part.propagators[0].transform == ifm_transform + assert part.propagators[1].transform == ifm2_transform + + propagated_ifm = ifm_propagator.propagate(output_stripe_config).shape + propagated_ifm2 = ifm2_propagator.propagate(output_stripe_config).shape + + # Layout conversions will align the propagated IFMs to the brick, i.e. 16 + # so the expected ifm(2)_shape needs to be rounded up to 16 + if ifm_layout != ofm_layout: + assert ifm_shape[:-1] == propagated_ifm[:-1] + assert ((ifm_shape[-1] + 16 - 1) // 16) * 16 == propagated_ifm[-1] + else: + assert ifm_shape == propagated_ifm + + if ifm2_layout != ofm_layout: + assert ifm2_shape[:-1] == propagated_ifm2[:-1] + assert ((ifm2_shape[-1] + 16 - 1) // 16) * 16 == propagated_ifm2[-1] + else: + assert ifm2_shape == propagated_ifm2 + + +if __name__ == "__main__": + pytest.main([__file__]) diff --git a/tests/python/contrib/test_ethosu/cascader/test_ethosu_block_config.py b/tests/python/contrib/test_ethosu/cascader/test_ethosu_block_config.py index 3418bb58351e..3f3935fff1f9 100644 --- a/tests/python/contrib/test_ethosu/cascader/test_ethosu_block_config.py +++ b/tests/python/contrib/test_ethosu/cascader/test_ethosu_block_config.py @@ -27,8 +27,9 @@ @pytest.mark.parametrize( - "id, op_type, activation, kernel, stride, dilation, padding, in_shape, out_shape", + "test_id, op_type, activation, kernel, stride, dilation, padding, in_shape, out_shape", [ + # Conv2D ( 0, "ethosu_conv2d", @@ -95,6 +96,52 @@ (1, 62, 94, 32), (1, 58, 90, 16), ), + # Depthwise Conv2D + ( + 6, + "ethosu_depthwise_conv2d", + "NONE", + (3, 5), + (1, 1), + (1, 1), + (0, 0, 0, 0), + (1, 77, 23, 18), + (1, 75, 19, 18), + ), + ( + 7, + "ethosu_depthwise_conv2d", + "NONE", + (3, 3), + (2, 2), + (1, 1), + (1, 1, 1, 1), + (1, 25, 10, 276), + (1, 13, 5, 276), + ), + # Pooling + ( + 8, + "ethosu_pooling", + "NONE", + (13, 5), + (1, 1), + (1, 1), + (0, 0, 0, 0), + (1, 13, 5, 276), + (1, 1, 1, 276), + ), + ( + 9, + "ethosu_pooling", + "NONE", + (7, 3), + (2, 1), + (1, 1), + (0, 0, 0, 0), + (1, 317, 14, 21), + (1, 156, 12, 21), + ), ], ) @pytest.mark.parametrize( @@ -112,51 +159,79 @@ ( "ethos-u55-32", [ + # Conv2D ((1, 8, 4, 16), (1, 8, 1, 4, 16)), ((1, 6, 5, 16), (1, 6, 1, 5, 16)), ((1, 4, 4, 16), (1, 4, 1, 4, 16)), ((1, 8, 4, 16), (1, 8, 1, 4, 16)), - ((1, 10, 6, 4), (1, 16, 1, 4, 4)), - ((1, 10, 3, 16), (1, 10, 1, 3, 16)), + ((1, 10, 6, 4), (1, 5, 1, 12, 4), (1, 16, 1, 4, 4)), + ((1, 6, 5, 16), (1, 6, 1, 5, 16)), + # Depthwise Conv2D + ((1, 6, 10, 16), (1, 6, 1, 10, 16)), + ((1, 7, 5, 16), (1, 7, 1, 5, 16)), + # Pooling + ((1, 1, 1, 16), (1, 1, 1, 1, 16)), + ((1, 9, 6, 16), (1, 9, 1, 6, 16)), ], ), ( "ethos-u55-64", [ + # Conv2D ((1, 8, 4, 16), (1, 8, 1, 4, 16)), ((1, 6, 5, 16), (1, 6, 1, 5, 16)), ((1, 4, 4, 16), (1, 4, 1, 4, 16)), ((1, 8, 4, 16), (1, 8, 1, 4, 16)), ((1, 10, 6, 8), (1, 16, 1, 4, 8)), - ((1, 10, 3, 16), (1, 10, 1, 3, 16)), + ((1, 6, 5, 16), (1, 6, 1, 5, 16)), + # Depthwise Conv2D + ((1, 6, 10, 16), (1, 6, 1, 10, 16)), + ((1, 7, 5, 16), (1, 7, 1, 5, 16)), + # Pooling + ((1, 1, 1, 16), (1, 1, 1, 1, 16)), + ((1, 9, 6, 16), (1, 9, 1, 6, 16)), ], ), ( "ethos-u55-128", [ + # Conv2D ((1, 7, 6, 16), (1, 7, 1, 6, 16)), ((1, 5, 8, 16), (1, 5, 1, 8, 16)), ((1, 4, 4, 16), (1, 4, 1, 4, 16)), ((1, 16, 4, 16), (1, 16, 1, 4, 16)), ((1, 8, 12, 8), (1, 8, 1, 12, 8)), ((1, 10, 6, 16), (1, 10, 1, 6, 16)), + # Depthwise Conv2D + ((1, 7, 10, 16), (1, 7, 1, 10, 16)), + ((1, 7, 6, 16), (1, 7, 1, 6, 16)), + # Pooling + ((1, 1, 2, 80), (1, 1, 5, 2, 16)), + ((1, 10, 6, 16), (1, 10, 1, 6, 16)), ], ), ( "ethos-u55-256", [ + # Conv2D ((1, 14, 8, 16), (1, 14, 1, 8, 16)), ((1, 16, 8, 16), (1, 16, 1, 8, 16)), ((1, 4, 4, 16), (1, 4, 1, 4, 16)), - ((1, 32, 4, 16), (1, 32, 1, 4, 16)), + ((1, 32, 4, 16), (1, 10, 12, 16), (1, 32, 1, 4, 16), (1, 10, 1, 12, 16)), ((1, 20, 12, 8), (1, 20, 1, 12, 8)), - ((1, 20, 6, 16), (1, 20, 1, 6, 16)), + ((1, 12, 10, 16), (1, 12, 1, 10, 16)), + # Depthwise Conv2D + ((1, 8, 20, 16), (1, 8, 1, 20, 16)), + ((1, 14, 6, 16), (1, 14, 1, 6, 16)), + # Pooling + ((1, 2, 2, 48), (1, 2, 3, 2, 16)), + ((1, 10, 12, 16), (1, 10, 1, 12, 16)), ], ), ], ) def test_best_block_config( - id, + test_id, op_type, activation, kernel, @@ -185,7 +260,7 @@ def test_best_block_config( [0, 0, 0, 0, 0, 1], ] ifm_matrix, ifm_offset, weight_matrix, weight_offset, _, _ = make_matrices( - kernel, stride, dilation, padding, in_shape[3], layouts[0], layouts[1] + op_type, kernel, stride, padding, layouts[0], layouts[1], dilation, in_shape[3] ) ofm_channels = out_shape[3] @@ -252,10 +327,8 @@ def test_best_block_config( block = part.get_block_config(stripe_config) block_shape = tuple(int(a) for a in block.output_shape) - if layouts[1] == "NHCWB16": - assert block_shape == expected_block_configs[id][1] - else: - assert block_shape == expected_block_configs[id][0] + + assert block_shape in expected_block_configs[test_id] if __name__ == "__main__": diff --git a/tests/python/contrib/test_ethosu/cascader/test_ethosu_conv2d_matcher.py b/tests/python/contrib/test_ethosu/cascader/test_ethosu_conv2d_matcher.py index 8ff5ef09fdc3..5bd2be49f620 100644 --- a/tests/python/contrib/test_ethosu/cascader/test_ethosu_conv2d_matcher.py +++ b/tests/python/contrib/test_ethosu/cascader/test_ethosu_conv2d_matcher.py @@ -24,8 +24,6 @@ from .infra import make_matrices -import pytest - @pytest.mark.parametrize("kernel", [(3, 3), (2, 1), (3, 5)]) @pytest.mark.parametrize("stride", [(1, 1), (2, 1), (3, 2)]) @@ -76,13 +74,14 @@ def test_ethosu_conv2d_matcher( scale_bias_transform, scale_bias_offset, ) = make_matrices( + "ethosu_conv2d", kernel, stride, - dilation, padding, - ifm_channels, ifm_layout, ofm_layout, + dilation, + ifm_channels, ) device_config = cs.EthosuDeviceConfig("ethos-u55-256") diff --git a/tests/python/contrib/test_ethosu/cascader/test_ethosu_depthwise2d_matcher.py b/tests/python/contrib/test_ethosu/cascader/test_ethosu_depthwise2d_matcher.py new file mode 100644 index 000000000000..c2c45b6524f1 --- /dev/null +++ b/tests/python/contrib/test_ethosu/cascader/test_ethosu_depthwise2d_matcher.py @@ -0,0 +1,102 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +import pytest + +pytest.importorskip("ethosu.vela") + +import numpy as np + +from tvm import te +import tvm.contrib.ethosu.cascader as cs +from tvm.relay.backend.contrib.ethosu.te.depthwise import ( + match_ethosu_depthwise_conv2d, + depthwise_conv2d_compute, +) +from .infra import make_matrices + + +@pytest.mark.parametrize("kernel", [(3, 3), (2, 1), (3, 5)]) +@pytest.mark.parametrize("stride", [(1, 1), (2, 1), (3, 2)]) +@pytest.mark.parametrize("dilation", [(1, 1), (2, 1), (3, 2)]) +@pytest.mark.parametrize("padding", [(0, 0, 0, 0), (3, 2, 3, 2), (2, 1, 0, 1)]) +@pytest.mark.parametrize("ifm_layout", ["NHWC", "NHCWB16"]) +@pytest.mark.parametrize("ofm_layout", ["NHWC", "NHCWB16"]) +def test_ethosu_depthwise2d_matcher(kernel, stride, dilation, padding, ifm_layout, ofm_layout): + ofm_channels = 57 + if ifm_layout == "NHWC": + ifm_shape = (1, 12, 15, ofm_channels) + else: + ifm_shape = (1, 12, 1 + ((ofm_channels - 1) // 16), 15, 16) + kernel_h, kernel_w = kernel + ifm = te.placeholder(ifm_shape, dtype="int8") + weight = te.placeholder((ofm_channels, kernel_h, kernel_w, 1), dtype="int8") + scale_bias = te.placeholder((ofm_channels, 10), dtype="uint8") + lut = te.placeholder((), dtype="uint8") + out = depthwise_conv2d_compute( + ifm=ifm, + weight=weight, + scale_bias=scale_bias, + lut=lut, + ifm_scale=1, + ifm_zero_point=0, + ofm_scale=1, + ofm_zero_point=0, + weight_zero_point=0, + strides=stride, + padding=padding, + dilation=dilation, + activation="NONE", + clip_min=0, + clip_max=0, + rounding_mode="TFL", + upscale="NONE", + ifm_layout=ifm_layout, + ofm_layout=ofm_layout, + ofm_dtype=ifm.dtype, + ) + ( + ifm_transform, + ifm_offset, + weight_transform, + weight_offset, + scale_bias_transform, + scale_bias_offset, + ) = make_matrices( + "ethosu_depthwise_conv2d", + kernel, + stride, + padding, + ifm_layout, + ofm_layout, + dilation, + ) + + device_config = cs.EthosuDeviceConfig("ethos-u55-256") + part = match_ethosu_depthwise_conv2d(out, device_config) + + assert isinstance(part, cs.EthosuPart) + assert len(part.propagators) == 3 + assert part.propagators[0].transform == ifm_transform + assert part.propagators[0].offset == ifm_offset + assert part.propagators[1].transform == weight_transform + assert part.propagators[1].offset == weight_offset + assert part.propagators[2].transform == scale_bias_transform + assert part.propagators[2].offset == scale_bias_offset + + +if __name__ == "__main__": + pytest.main([__file__]) diff --git a/tests/python/contrib/test_ethosu/cascader/test_ethosu_part_performance.py b/tests/python/contrib/test_ethosu/cascader/test_ethosu_part_performance.py index 297fbaa89059..ba6346afa5d5 100644 --- a/tests/python/contrib/test_ethosu/cascader/test_ethosu_part_performance.py +++ b/tests/python/contrib/test_ethosu/cascader/test_ethosu_part_performance.py @@ -122,6 +122,34 @@ def test_device_config_cycles(acc_config, expected): (1, 18, 14, 8), 174105, ), + ( + "ethos-u55-128", + "ethosu_depthwise_conv2d", + "NONE", + (3, 3), + (2, 2), + (1, 1), + (1, 1, 1, 1), + (1, 25, 10, 276), + (1, 13, 5, 276), + (1, 7, 6, 16), + (1, 15, 14, 16), + 17590, + ), + ( + "ethos-u55-128", + "ethosu_depthwise_conv2d", + "NONE", + (4, 9), + (1, 1), + (1, 1), + (0, 0, 0, 0), + (1, 28, 81, 42), + (1, 25, 73, 41), + (1, 4, 16, 16), + (1, 7, 24, 16), + 173414, + ), ], ) def test_conv_performance( @@ -138,16 +166,17 @@ def test_conv_performance( input_block_shape, expected, ): + ifm_channels = in_shape[3] ifm_matrix, ifm_offset, weight_matrix, weight_offset, _, _ = make_matrices( + op_type, kernel, stride, - dilation, padding, - in_shape[3], "NHWC", "NHWC", + dilation, + ifm_channels, ) - ifm_channels = in_shape[3] propagator = cs.Propagator(ifm_matrix, ifm_offset) weight_propagator = cs.Propagator(weight_matrix, weight_offset) @@ -191,7 +220,7 @@ def test_conv_performance( stripe_config = cs.StripeConfig(out_shape, out_shape, out_shape, order, stripes, offset) compute_cycles = part.get_performance_info(stripe_config, cs.BufferMode.ROLLING).compute_cycles - tolerance = expected * 0.05 + tolerance = expected * 0.1 assert expected - tolerance <= compute_cycles <= expected + tolerance diff --git a/tests/python/contrib/test_ethosu/cascader/test_ethosu_pooling_matcher.py b/tests/python/contrib/test_ethosu/cascader/test_ethosu_pooling_matcher.py new file mode 100644 index 000000000000..6ce8ee9a2986 --- /dev/null +++ b/tests/python/contrib/test_ethosu/cascader/test_ethosu_pooling_matcher.py @@ -0,0 +1,81 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +import pytest + +pytest.importorskip("ethosu.vela") + +import numpy as np + +from tvm import te +import tvm.contrib.ethosu.cascader as cs +from tvm.relay.backend.contrib.ethosu.te.pooling import match_ethosu_pooling, pooling_compute +from .infra import make_matrices + + +@pytest.mark.parametrize("pool_shape", [(3, 3), (2, 1), (3, 5)]) +@pytest.mark.parametrize("stride", [(1, 1), (2, 1), (3, 2)]) +@pytest.mark.parametrize("padding", [(0, 0, 0, 0), (3, 2, 3, 2), (2, 1, 0, 1)]) +@pytest.mark.parametrize("ifm_layout", ["NHWC", "NHCWB16"]) +@pytest.mark.parametrize("ofm_layout", ["NHWC", "NHCWB16"]) +def test_ethosu_pooling_matcher(pool_shape, stride, padding, ifm_layout, ofm_layout): + ofm_channels = 21 + if ifm_layout == "NHWC": + ifm_shape = (1, 12, 15, ofm_channels) + else: + ifm_shape = (1, 12, 1 + ((ofm_channels - 1) // 16), 15, 16) + ifm = te.placeholder(ifm_shape, dtype="int8") + lut = te.placeholder((), dtype="uint8") + out = pooling_compute( + ifm=ifm, + lut=lut, + pooling_type="MAX", + ifm_scale=1, + ifm_zero_point=0, + ofm_scale=1, + ofm_zero_point=0, + pool_shape=pool_shape, + ofm_channels=ofm_channels, + strides=stride, + padding=padding, + activation="NONE", + clip_min=0, + clip_max=0, + rounding_mode="TFL", + upscale="NONE", + ifm_layout=ifm_layout, + ofm_layout=ofm_layout, + ) + (ifm_transform, ifm_offset, _, _, _, _) = make_matrices( + "ethosu_pooling", + pool_shape, + stride, + padding, + ifm_layout, + ofm_layout, + ) + + device_config = cs.EthosuDeviceConfig("ethos-u55-256") + part = match_ethosu_pooling(out, device_config) + + assert isinstance(part, cs.EthosuPart) + assert len(part.propagators) == 1 + assert part.propagators[0].transform == ifm_transform + assert part.propagators[0].offset == ifm_offset + + +if __name__ == "__main__": + pytest.main([__file__]) diff --git a/tests/python/contrib/test_ethosu/cascader/test_ethosu_unary_elementwise_matcher.py b/tests/python/contrib/test_ethosu/cascader/test_ethosu_unary_elementwise_matcher.py new file mode 100644 index 000000000000..0570524e0907 --- /dev/null +++ b/tests/python/contrib/test_ethosu/cascader/test_ethosu_unary_elementwise_matcher.py @@ -0,0 +1,158 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +import pytest + +pytest.importorskip("ethosu.vela") + +import numpy as np +import math + +from tvm import te +import tvm.contrib.ethosu.cascader as cs +from tvm.relay.backend.contrib.ethosu.te.unary_elementwise import ( + match_ethosu_unary_elementwise, + unary_elementwise_compute, +) + + +def _make_matrices(ifm_layout, ofm_layout): + nhwc_to_nhcwb16 = [ + [1, 0, 0, 0, 0], + [0, 1, 0, 0, 0], + [0, 0, 0, 1 / 16, 0], + [0, 0, 1, 0, 0], + [0, 0, 0, 0, 16], + [0, 0, 0, 0, 1], + ] + nhcwb16_to_nhwc = [ + [1, 0, 0, 0, 0, 0], + [0, 1, 0, 0, 0, 0], + [0, 0, 0, 1, 0, 0], + [0, 0, 16, 0, 1, -16], + [0, 0, 0, 0, 0, 1], + ] + ifm_matrix = [ + [1, 0, 0, 0, 0], + [0, 1, 0, 0, 0], + [0, 0, 1, 0, 0], + [0, 0, 0, 1, 0], + [0, 0, 0, 0, 1], + ] + if ofm_layout == "NHCWB16": + ifm_matrix = np.matmul(ifm_matrix, nhcwb16_to_nhwc).tolist() + if ifm_layout == "NHCWB16": + ifm_matrix = np.matmul(nhwc_to_nhcwb16, ifm_matrix).tolist() + + return ifm_matrix + + +@pytest.mark.parametrize( + "ofm_shape", + [ + [1, 12, 15, 128], + [1, 16, 16, 16], + [1, 1, 1, 1024], + [1, 53, 91, 7], + [1, 182, 12, 72], + ], +) +@pytest.mark.parametrize("ifm_layout", ["NHWC", "NHCWB16"]) +@pytest.mark.parametrize("ofm_layout", ["NHWC", "NHCWB16"]) +@pytest.mark.parametrize("op_type", ["ABS", "CLZ"]) +def test_ethosu_unary_elementwise_matcher(ofm_shape, ifm_layout, ofm_layout, op_type): + ifm_shape = ofm_shape.copy() + ofm_channels = ofm_shape[3] + nhwc_to_nhcwb16 = [ + [1, 0, 0, 0, 0], + [0, 1, 0, 0, 0], + [0, 0, 0, 1 / 16, 0], + [0, 0, 1, 0, 0], + [0, 0, 0, 0, 16], + [0, 0, 0, 0, 1], + ] + if ifm_layout == "NHCWB16": + ifm_shape = [ + int(math.ceil(n)) + for n in np.matmul( + nhwc_to_nhcwb16, + ifm_shape + + [ + 1, + ], + ).tolist()[:-1] + ] + if ofm_layout == "NHCWB16": + ofm_shape = [ + int(math.ceil(n)) + for n in np.matmul( + nhwc_to_nhcwb16, + ofm_shape + + [ + 1, + ], + ).tolist()[:-1] + ] + order = [1, 2, 4, 3, 0] + else: + order = [1, 2, 3, 4] + + ifm = te.placeholder(ifm_shape, dtype="int8") + lut = te.placeholder((), dtype="uint8") + out = unary_elementwise_compute( + ifm=ifm, + lut=lut, + operator_type=op_type, + ifm_scale=1, + ifm_zero_point=0, + ofm_scale=1, + ofm_zero_point=0, + ofm_channels=ofm_channels, + activation="NONE", + clip_min=0, + clip_max=0, + rounding_mode="TFL", + ifm_layout=ifm_layout, + ofm_layout=ofm_layout, + ) + ifm_propagator = out.op.attrs["ifm_propagator"] + + offset = [0] * len(ofm_shape) + stripes = [0] * len(ofm_shape) + output_stripe_config = cs.StripeConfig(ofm_shape, ofm_shape, ofm_shape, order, stripes, offset) + + ifm_transform = _make_matrices(ifm_layout, ofm_layout) + + device_config = cs.EthosuDeviceConfig("ethos-u55-256") + part = match_ethosu_unary_elementwise(out, device_config) + + assert isinstance(part, cs.EthosuPart) + assert len(part.propagators) == 1 + assert part.propagators[0].transform == ifm_transform + + propagated_ifm = ifm_propagator.propagate(output_stripe_config).shape + + # Layout conversions will align the propagated IFMs to the brick, i.e. 16 + # so the expected ifm_shape needs to be rounded up to 16 + if ifm_layout != ofm_layout: + assert ifm_shape[:-1] == propagated_ifm[:-1] + assert ((ifm_shape[-1] + 16 - 1) // 16) * 16 == propagated_ifm[-1] + else: + assert ifm_shape == propagated_ifm + + +if __name__ == "__main__": + pytest.main([__file__]) diff --git a/tests/python/contrib/test_ethosu/cascader/test_graph.py b/tests/python/contrib/test_ethosu/cascader/test_graph.py index da31ad346b4f..616800f69d7e 100644 --- a/tests/python/contrib/test_ethosu/cascader/test_graph.py +++ b/tests/python/contrib/test_ethosu/cascader/test_graph.py @@ -176,5 +176,29 @@ def test_create_cascader_graph(TwoConv2DWithSliceTE): assert conv1_part.input_tensors[2].is_constant +def test_create_diamond_graph(MobileNetv2DiamondTE): + _, te_graph, const_dict = MobileNetv2DiamondTE + device_config = cs.EthosuDeviceConfig("ethos-u55-256") + graph = cs.create_cascader_graph(te_graph, const_dict, device_config) + + output_tensor = graph.output_tensors[0] + assert output_tensor.shape == [1, 56, 56, 24] + assert len(output_tensor.producers) == 1 + assert not output_tensor.is_constant + + add1_part = output_tensor.producers[0] + assert isinstance(add1_part, cs.EthosuPart) + assert len(add1_part.input_tensors) == 2 + assert graph.get_part_id(add1_part) == 0 + + assert add1_part.input_tensors[0].shape == [1, 56, 56, 24] + assert len(add1_part.input_tensors[0].producers) == 1 + assert not add1_part.input_tensors[0].is_constant + + assert add1_part.input_tensors[1].shape == [1, 56, 56, 24] + assert len(add1_part.input_tensors[0].producers) == 1 + assert not add1_part.input_tensors[0].is_constant + + if __name__ == "__main__": pytest.main([__file__]) diff --git a/tests/python/contrib/test_ethosu/test_layout_optimizer.py b/tests/python/contrib/test_ethosu/test_layout_optimizer.py index aafae1497ea4..62a1fabe0b98 100644 --- a/tests/python/contrib/test_ethosu/test_layout_optimizer.py +++ b/tests/python/contrib/test_ethosu/test_layout_optimizer.py @@ -33,14 +33,17 @@ from tvm import relay from tvm.relay.op.contrib.ethosu import partition_for_ethosu from tvm.relay.backend.contrib.ethosu.codegen import LayoutOptimizer +from tvm.relay.backend.contrib.ethosu.codegen import relay_to_tir_func from . import infra -def _run_pass(expr, relay_pass): - """Create IRModule and run Relay pass.""" +def _optimize(expr, optimize=True): + """Create IRModule and run layout optimizer pass.""" mod = tvm.IRModule.from_expr(expr) - mod = relay_pass(mod) + mod = relay.transform.InferType()(mod) + if optimize: + mod = LayoutOptimizer()(mod) entry = mod["main"] return entry if isinstance(expr, relay.Function) else entry.body @@ -111,8 +114,8 @@ def get_graph(): ) return relay.Function(relay.analysis.free_vars(x), x) - a = _run_pass(get_graph(), LayoutOptimizer()) - b = _run_pass(get_graph(), relay.transform.InferType()) + a = _optimize(get_graph()) + b = _optimize(get_graph(), optimize=False) _assert_structural_equal(a, b) @@ -144,8 +147,8 @@ def get_graph(get_expected=False): ) return relay.Function(relay.analysis.free_vars(x), x) - a = _run_pass(get_graph(), LayoutOptimizer()) - b = _run_pass(get_graph(get_expected=True), relay.transform.InferType()) + a = _optimize(get_graph()) + b = _optimize(get_graph(get_expected=True), optimize=False) _assert_structural_equal(a, b) @@ -176,8 +179,8 @@ def get_graph(get_expected=False): ) return relay.Function(relay.analysis.free_vars(x), x) - a = _run_pass(get_graph(), LayoutOptimizer()) - b = _run_pass(get_graph(get_expected=True), relay.transform.InferType()) + a = _optimize(get_graph()) + b = _optimize(get_graph(get_expected=True), optimize=False) _assert_structural_equal(a, b) @@ -222,8 +225,8 @@ def get_graph(): ) return relay.Function(relay.analysis.free_vars(conv_2), conv_2) - a = _run_pass(get_graph(), LayoutOptimizer()) - b = _run_pass(get_graph(), relay.transform.InferType()) + a = _optimize(get_graph()) + b = _optimize(get_graph(), optimize=False) _assert_structural_equal(a, b) @@ -268,8 +271,8 @@ def get_graph(): ) return relay.Function(relay.analysis.free_vars(conv_2), conv_2) - a = _run_pass(get_graph(), LayoutOptimizer()) - b = _run_pass(get_graph(), relay.transform.InferType()) + a = _optimize(get_graph()) + b = _optimize(get_graph(), optimize=False) _assert_structural_equal(a, b) @@ -322,8 +325,8 @@ def get_graph(): ) return relay.Function(relay.analysis.free_vars(pool_3), pool_3) - a = _run_pass(get_graph(), LayoutOptimizer()) - b = _run_pass(get_graph(), relay.transform.InferType()) + a = _optimize(get_graph()) + b = _optimize(get_graph(), optimize=False) _assert_structural_equal(a, b) @@ -368,8 +371,8 @@ def get_graph(): ) return relay.Function(relay.analysis.free_vars(conv), conv) - a = _run_pass(get_graph(), LayoutOptimizer()) - b = _run_pass(get_graph(), relay.transform.InferType()) + a = _optimize(get_graph()) + b = _optimize(get_graph(), optimize=False) _assert_structural_equal(a, b) @@ -413,8 +416,8 @@ def get_graph(get_expected=False): concat = relay.concatenate(poolings, axis=0) return relay.Function(relay.analysis.free_vars(concat), concat) - a = _run_pass(get_graph(), LayoutOptimizer()) - b = _run_pass(get_graph(get_expected=True), relay.transform.InferType()) + a = _optimize(get_graph()) + b = _optimize(get_graph(get_expected=True), optimize=False) _assert_structural_equal(a, b) @@ -467,8 +470,8 @@ def get_graph(get_expected=False): ) return relay.Function(relay.analysis.free_vars(add_3), add_3) - a = _run_pass(get_graph(), LayoutOptimizer()) - b = _run_pass(get_graph(get_expected=True), relay.transform.InferType()) + a = _optimize(get_graph()) + b = _optimize(get_graph(get_expected=True), optimize=False) _assert_structural_equal(a, b) @@ -500,8 +503,8 @@ def get_graph(get_expected=False): ) return relay.Function(relay.analysis.free_vars(x), x) - a = _run_pass(get_graph(), LayoutOptimizer()) - b = _run_pass(get_graph(get_expected=True), relay.transform.InferType()) + a = _optimize(get_graph()) + b = _optimize(get_graph(get_expected=True), optimize=False) _assert_structural_equal(a, b) @@ -530,8 +533,8 @@ def get_graph(get_expected=False): ) return relay.Function(relay.analysis.free_vars(x), x) - a = _run_pass(get_graph(), LayoutOptimizer()) - b = _run_pass(get_graph(get_expected=True), relay.transform.InferType()) + a = _optimize(get_graph()) + b = _optimize(get_graph(get_expected=True), optimize=False) _assert_structural_equal(a, b) @@ -619,5 +622,32 @@ def representative_dataset(): _compile_and_compare_model(create_model(), ifm_shape, dtype) +def test_layout_optimizer_runs_in_compilation_pipeline(): + """Checks that the layout optimization pass runs as part of the NPU compilation + pipeline.""" + + def get_graph(): + x = relay.var("x", shape=(1, 4, 4, 4), dtype="int8") + for _ in range(2): + x = relay.nn.max_pool2d(x, layout="NHWC") + + func = relay.Function(relay.analysis.free_vars(x), x) + return tvm.IRModule.from_expr(func) + + mod = get_graph() + mod = partition_for_ethosu(mod) + + external_gv_name = mod["main"].body.op.name_hint + external_func = mod[external_gv_name] + prim_func = relay_to_tir_func(external_func) + + # Check for hints in the TIR prim func that the layout optimization pass has ran + ops = prim_func.body.body.seq + max_pool1, max_pool2 = ops + + assert str(max_pool1.value.args[31]) == '"NHCWB16"' + assert str(max_pool2.value.args[14]) == '"NHCWB16"' + + if __name__ == "__main__": pytest.main([__file__] + sys.argv[1:]) diff --git a/tests/python/contrib/test_ethosu/test_lut_optimizer.py b/tests/python/contrib/test_ethosu/test_lut_optimizer.py index 16835ce94ed7..d9a543c1a771 100644 --- a/tests/python/contrib/test_ethosu/test_lut_optimizer.py +++ b/tests/python/contrib/test_ethosu/test_lut_optimizer.py @@ -21,9 +21,16 @@ pytest.importorskip("ethosu.vela") +import tensorflow as tf +import numpy as np + import tvm from tvm import relay from tvm.relay.backend.contrib.ethosu.codegen import LUTsOptimizer +from tvm.relay.backend.contrib.ethosu.codegen import relay_to_tir_func +from tvm.relay.op.contrib.ethosu import partition_for_ethosu + +from .test_codegen import _get_tflite_graph from . import infra @@ -59,6 +66,7 @@ def after(): return mod mod = LUTsOptimizer()(before()) + mod = relay.transform.InferType()(mod) assert tvm.ir.structural_equal(mod, after()) @@ -91,5 +99,35 @@ def after(): return mod mod = LUTsOptimizer()(before()) + mod = relay.transform.InferType()(mod) assert tvm.ir.structural_equal(mod, after()) + + +def test_lut_optimizer_runs_in_compilation_pipeline(): + """Test that the LUT optimization pass runs as part of the NPU compilation pipeline.""" + ifm_shape = (1, 4, 4, 4) + + @tf.function + def get_graph(x): + weight1 = tf.constant(np.random.uniform(size=(1, 1, 4, 4)), dtype=tf.float32) + op = tf.nn.conv2d(x, weight1, (1, 1), "VALID") + op = tf.nn.tanh(op) + weight2 = tf.constant(np.random.uniform(size=(1, 1, 4, 1)), dtype=tf.float32) + op = tf.nn.depthwise_conv2d(op, weight2, (1, 1, 1, 1), "VALID") + return tf.nn.tanh(op) + + mod, _ = _get_tflite_graph(get_graph, [ifm_shape]) + mod = partition_for_ethosu(mod) + + external_gv_name = mod["main"].body.op.name_hint + external_func = mod[external_gv_name] + prim_func = relay_to_tir_func(external_func) + + # Check for hints in the TIR prim func that the LUT optimization pass has ran. + # If the module was optimized, there should be no identity operations. + def check_identity(stmt): + if isinstance(stmt, tvm.tir.expr.Call): + assert stmt.args[0] != "ethosu_identity" + + tvm.tir.stmt_functor.post_order_visit(prim_func.body, check_identity) diff --git a/tests/python/frontend/tensorflow/test_forward.py b/tests/python/frontend/tensorflow/test_forward.py index be32ca308ba1..c76803b8fb3c 100644 --- a/tests/python/frontend/tensorflow/test_forward.py +++ b/tests/python/frontend/tensorflow/test_forward.py @@ -1537,19 +1537,29 @@ def run(dtype_str, infer_shape): element_shape = tf.TensorShape([tf.Dimension(None)]) else: element_shape = None - t = tf.constant(np.array([[1.0], [2.0], [3.0]]).astype(dtype_str), dtype=dtype) - indices = tf.constant([2, 1, 0]) - ta1 = tf.TensorArray( - dtype=dtype, size=3, infer_shape=infer_shape, element_shape=element_shape - ) - ta2 = ta1.scatter(indices, t) - out0 = ta2.read(0) - out1 = ta2.read(1) - out2 = ta2.read(2) + ta0 = _construct_scatter(dtype, dtype_str, element_shape, infer_shape, 3) + out0 = ta0.read(0) + out1 = ta0.read(1) + out2 = ta0.read(2) + ta1 = _construct_scatter(dtype, dtype_str, element_shape, infer_shape, 4) + out4 = ta1.read(0) g = tf.get_default_graph() compare_tf_with_tvm([], [], ["TensorArrayReadV3:0"], mode="vm") compare_tf_with_tvm([], [], ["TensorArrayReadV3_1:0"], mode="vm") compare_tf_with_tvm([], [], ["TensorArrayReadV3_2:0"], mode="vm") + compare_tf_with_tvm([], [], ["TensorArrayReadV3_2:0", out4.name], mode="vm") + + def _construct_scatter(dtype, dtype_str, element_shape, infer_shape, size): + arr = [[float(i)] for i in range(size)] + indices_arr = [i for i in range(size - 1, -1, -1)] + + t = tf.constant(np.array(arr).astype(dtype_str), dtype=dtype) + indices = tf.constant(indices_arr) + ta1 = tf.TensorArray( + dtype=dtype, size=size, infer_shape=infer_shape, element_shape=element_shape + ) + ta2 = ta1.scatter(indices, t) + return ta2 for dtype in ["float32", "int8"]: run(dtype, False) diff --git a/tests/scripts/git_utils.py b/tests/scripts/git_utils.py index f2927f1e3ab7..530abe8029a6 100644 --- a/tests/scripts/git_utils.py +++ b/tests/scripts/git_utils.py @@ -39,7 +39,7 @@ def graphql(self, query: str) -> Dict[str, Any]: return self._post("https://api.github.com/graphql", {"query": query}) def _post(self, full_url: str, body: Dict[str, Any]) -> Dict[str, Any]: - print("Requesting", full_url) + print("Requesting POST to", full_url, "with", body) req = request.Request(full_url, headers=self.headers(), method="POST") req.add_header("Content-Type", "application/json; charset=utf-8") data = json.dumps(body) @@ -55,7 +55,7 @@ def post(self, url: str, data: Dict[str, Any]) -> Dict[str, Any]: def get(self, url: str) -> Dict[str, Any]: url = self.base + url - print("Requesting", url) + print("Requesting GET to", url) req = request.Request(url, headers=self.headers()) with request.urlopen(req) as response: response = json.loads(response.read()) @@ -63,7 +63,7 @@ def get(self, url: str) -> Dict[str, Any]: def delete(self, url: str) -> Dict[str, Any]: url = self.base + url - print("Requesting", url) + print("Requesting DELETE to", url) req = request.Request(url, headers=self.headers(), method="DELETE") with request.urlopen(req) as response: response = json.loads(response.read()) diff --git a/tests/scripts/github_cc_reviewers.py b/tests/scripts/github_cc_reviewers.py index 48420822ad55..8e7198aa7b8f 100755 --- a/tests/scripts/github_cc_reviewers.py +++ b/tests/scripts/github_cc_reviewers.py @@ -20,6 +20,7 @@ import json import argparse import re +from urllib import error from typing import Dict, Any, List @@ -70,4 +71,11 @@ def find_reviewers(body: str) -> List[str]: if not args.dry_run: github = GitHubRepo(token=os.environ["GITHUB_TOKEN"], user=user, repo=repo) - github.post(f"pulls/{number}/requested_reviewers", {"reviewers": to_add}) + + # Add reviewers 1 by 1 since GitHub will error out if any of the + # requested reviewers aren't members / contributors + for reviewer in to_add: + try: + github.post(f"pulls/{number}/requested_reviewers", {"reviewers": [reviewer]}) + except error.HTTPError as e: + print(f"Failed to add reviewer {reviewer}: {e}") diff --git a/web/src/compact.ts b/web/src/compact.ts index 29569b5d005d..ac6af35abeff 100644 --- a/web/src/compact.ts +++ b/web/src/compact.ts @@ -19,9 +19,9 @@ /** NodeJS and Web compact layer */ /** - * Get performance masurement. + * Get performance measurement. */ -export function getPeformance(): Performance { +export function getPerformance(): Performance { if (typeof performance == "undefined") { // eslint-disable-next-line @typescript-eslint/no-var-requires const performanceNode = require("perf_hooks"); diff --git a/web/src/runtime.ts b/web/src/runtime.ts index 60a28d53f361..b0e71d945f8a 100644 --- a/web/src/runtime.ts +++ b/web/src/runtime.ts @@ -653,7 +653,7 @@ class GraphExecutor implements Disposable { */ async benchmarkRuns(dev: DLDevice, number=10, repeat=4): Promise { // Skip first run as it can involve GPU warmup and module loading time. - const perf = compact.getPeformance(); + const perf = compact.getPerformance(); const results = []; this.run(); await dev.sync(); @@ -1049,7 +1049,7 @@ export class Instance implements Disposable { /** Register global packed functions needed by the backend to the env. */ private registerEnvGlobalPackedFuncs(): void { // Register the timer function to enable the time_evaluator. - const perf = compact.getPeformance(); + const perf = compact.getPerformance(); // Helper function to time the finvoke const timeExecution = async (