diff --git a/python/tvm/relay/backend/contrib/ethosu/op/unary_elementwise.py b/python/tvm/relay/backend/contrib/ethosu/op/unary_elementwise.py index a339561d97e3..35104da92e8b 100644 --- a/python/tvm/relay/backend/contrib/ethosu/op/unary_elementwise.py +++ b/python/tvm/relay/backend/contrib/ethosu/op/unary_elementwise.py @@ -109,6 +109,7 @@ def ethosu_unary_elementwise( operator_type: str The type of the unary elementwise operator. "ABS" + "CLZ" ifm_scale : float The quantization scale for the Input Feature Map tensor. ifm_zero_point : int @@ -143,7 +144,7 @@ def ethosu_unary_elementwise( Returns ------- out : tvm.relay.Call - A call to the ethosu_binary_elementwise op. + A call to the ethosu_unary_elementwise op. """ return _make.ethosu_unary_elementwise( ifm, diff --git a/python/tvm/relay/backend/contrib/ethosu/te/unary_elementwise.py b/python/tvm/relay/backend/contrib/ethosu/te/unary_elementwise.py index d45a8f4fc43d..0aefc1c35d4c 100644 --- a/python/tvm/relay/backend/contrib/ethosu/te/unary_elementwise.py +++ b/python/tvm/relay/backend/contrib/ethosu/te/unary_elementwise.py @@ -48,6 +48,7 @@ def unary_elementwise_compute( operator_type: str The type of the unary elementwise operator. "ABS" + "CLZ" ifm_scale : float The quantization scale for the Input Feature Map tensor. ifm_zero_point : int @@ -111,7 +112,11 @@ def unary_elementwise_compute( "rounding_mode": rounding_mode, } - operators = {"ABS": te.abs} + def clz_imp(inp): + # Assuming that it's a 32 bit int + return 32 - te.log2(inp) + + operators = {"ABS": te.abs, "CLZ": clz_imp} unary_elementwise = te.compute( (1, ofm_height, ofm_width, ofm_channels), diff --git a/python/tvm/relay/backend/contrib/ethosu/tir/unary_elementwise.py b/python/tvm/relay/backend/contrib/ethosu/tir/unary_elementwise.py index 6dc801f2b28c..4910330a67f4 100644 --- a/python/tvm/relay/backend/contrib/ethosu/tir/unary_elementwise.py +++ b/python/tvm/relay/backend/contrib/ethosu/tir/unary_elementwise.py @@ -52,7 +52,11 @@ def get_unary_elementwise_params(stmt, producers, consumers): _, _, _, _, _, inner = get_outer_loops(body, "NHWC") input_pointer = None if isinstance(inner.value, tir.expr.Select): + # ABS input_pointer = inner.value.condition.b.buffer_var + if isinstance(inner.value, tir.expr.Sub): + # CLZ + input_pointer = inner.value.b.args[0].buffer_var output_pointer = inner.buffer_var # Get feature map info serial_ifm, _ = get_ifm_params(input_pointer, producers) diff --git a/python/tvm/relay/backend/contrib/ethosu/tir_to_cs_translator.py b/python/tvm/relay/backend/contrib/ethosu/tir_to_cs_translator.py index d276417bde3b..4e84febe5e48 100644 --- a/python/tvm/relay/backend/contrib/ethosu/tir_to_cs_translator.py +++ b/python/tvm/relay/backend/contrib/ethosu/tir_to_cs_translator.py @@ -799,6 +799,8 @@ def _create_npu_op_unary_elementwise(serial_unary_elementwise): operator_type = serial_unary_elementwise.operator_type if operator_type == "ABS": op = vapi.NpuElementWiseOp.ABS + if operator_type == "CLZ": + op = vapi.NpuElementWiseOp.CLZ npu_unary_elementwise_op = vapi.NpuElementWiseOperation(op) npu_unary_elementwise_op.ifm = _create_npu_feature_map(serial_unary_elementwise.ifm) diff --git a/src/relay/op/contrib/ethosu/unary_elementwise.cc b/src/relay/op/contrib/ethosu/unary_elementwise.cc index 60f1eefaa6b2..9dc07e031d75 100644 --- a/src/relay/op/contrib/ethosu/unary_elementwise.cc +++ b/src/relay/op/contrib/ethosu/unary_elementwise.cc @@ -50,7 +50,8 @@ struct EthosuUnaryElementwiseAttrs : public tvm::AttrsNode& types, int num_inputs, const A CHECK(param != nullptr) << "EthosuUnaryElementwiseAttrs cannot be nullptr."; String operator_type = param->operator_type; - if (operator_type != "ABS") { - reporter->GetDiagCtx().EmitFatal( - Diagnostic::Error(reporter->GetSpan()) - << "Invalid operator: expected ethosu_unary_elementwise 'ABS' for operator_type but was" - << operator_type); + if (operator_type != "ABS" && operator_type != "CLZ") { + reporter->GetDiagCtx().EmitFatal(Diagnostic::Error(reporter->GetSpan()) + << "Invalid operator: expected ethosu_unary_elementwise 'ABS' " + "or 'CLZ' for operator_type but was" + << operator_type); return false; } auto ifm_dtype = ifm->dtype; - if (ifm_dtype != DataType::UInt(8) && ifm_dtype != DataType::Int(8)) { + if (ifm_dtype != DataType::UInt(8) && ifm_dtype != DataType::Int(8) && operator_type == "ABS") { + reporter->GetDiagCtx().EmitFatal(Diagnostic::Error(reporter->GetSpan()) + << "Invalid operator: expected ethosu_unary_elementwise " + << operator_type << "input data type " + << "of type(uint8) or type(int8) but was " << ifm_dtype); + return false; + } + + if (ifm_dtype != DataType::Int(32) && operator_type == "CLZ") { reporter->GetDiagCtx().EmitFatal( Diagnostic::Error(reporter->GetSpan()) - << "Invalid operator: expected ethosu_unary_elementwise input data type " - << "of type(uint8) or type(int8) but was " << ifm_dtype); + << "Invalid operator: expected ethosu_unary_elementwise CLZ input data type " + << "of type(int32) but was " << ifm_dtype); return false; } diff --git a/tests/python/contrib/test_ethosu/infra.py b/tests/python/contrib/test_ethosu/infra.py index 38bd88c10e48..5f339267e0b8 100644 --- a/tests/python/contrib/test_ethosu/infra.py +++ b/tests/python/contrib/test_ethosu/infra.py @@ -312,8 +312,9 @@ def make_partitioned_function(relay_op): ifm0 = relay.analysis.free_vars(relay_op) ifm_shape = ifm0[0].type_annotation.shape + ifm_dtype = ifm0[0].type_annotation.dtype - ifm = relay.var("ifm", shape=ifm_shape, dtype="int8") + ifm = relay.var("ifm", shape=ifm_shape, dtype=ifm_dtype) glb_ethosu = relay.GlobalVar("tvmgen_default_ethosu_main_0") diff --git a/tests/python/contrib/test_ethosu/test_codegen.py b/tests/python/contrib/test_ethosu/test_codegen.py index 92a1ad71deda..b6cf873cb6f3 100644 --- a/tests/python/contrib/test_ethosu/test_codegen.py +++ b/tests/python/contrib/test_ethosu/test_codegen.py @@ -969,5 +969,39 @@ def create_graph_single(input_tensor_name, input_tensor_shape, input_tensor_dtyp assert '__attribute__((section(".rodata.tvm"), aligned(16))) static int8_t weights' in source +@pytest.mark.parametrize("accel_type", ACCEL_TYPES) +def test_ethosu_clz(accel_type): + ifm_shape = (1, 42, 5, 4) + # Create a "partitioned" Relay function + ifm0 = relay.var("ifm0", shape=ifm_shape, dtype="int32") + clz = infra.make_ethosu_unary_elementwise(ifm0, 4, "CLZ") + mod = infra.make_partitioned_function(clz) + + in_data = np.random.randint(-500000, high=500000, size=ifm_shape, dtype="int32") + + def clz_comp(n): + n_bin = np.binary_repr(n) + if n_bin[0] == "-": + return 0 + else: + return 32 - len(n_bin) + + out_data = np.array([clz_comp(i) for i in in_data.ravel()]).reshape(ifm_shape).astype("int32") + + compiled_model = infra.build_source(mod, {"ifm": in_data}, [out_data], accel_type) + + imported_modules = compiled_model[0].executor_factory.lib.imported_modules + assert len(imported_modules) == 2 + ethosu_module = imported_modules[0] + + # Verify generated C source + get_cs = tvm._ffi.get_global_func("runtime.module.ethos-u.getcs") + cmms = get_cs(ethosu_module) + cmms = bytes.fromhex(cmms) + + infra.print_payload(cmms) + infra.verify_source(compiled_model, accel_type) + + if __name__ == "__main__": pytest.main([__file__]) diff --git a/tests/python/contrib/test_ethosu/test_replace_unary_elementwise.py b/tests/python/contrib/test_ethosu/test_replace_unary_elementwise.py index eff81c4e6cbd..e1c633e1d569 100644 --- a/tests/python/contrib/test_ethosu/test_replace_unary_elementwise.py +++ b/tests/python/contrib/test_ethosu/test_replace_unary_elementwise.py @@ -50,7 +50,7 @@ def _get_unary_elementwise_args(call, include_buffers=False, remove_constants=Fa ((1, 8, 9, 40), 40, "NHWC", "NHCWB16", "TFL"), ], ) -@pytest.mark.parametrize("operator_type", ["ABS"]) +@pytest.mark.parametrize("operator_type, data_type", [("ABS", "int8"), ("CLZ", "int32")]) @pytest.mark.parametrize("activation", ["NONE"]) def test_unary_elementwise_single( ifm_shape, @@ -60,8 +60,9 @@ def test_unary_elementwise_single( rounding_mode, operator_type, activation, + data_type, ): - ifm = relay.var("ifm", shape=ifm_shape, dtype="int8") + ifm = relay.var("ifm", shape=ifm_shape, dtype=data_type) unary_elementwise = make_ethosu_unary_elementwise( ifm, ifm_channels, operator_type, activation, ifm_layout, ofm_layout, rounding_mode @@ -102,7 +103,7 @@ def _visit(stmt): serial_unary_elementwise = spec.SerialUnaryElementwise( ifm=spec.SerialFeatureMap( - data_type="int8", + data_type=data_type, height=ifm_shape[1], width=ifm_shape[2] if ifm_layout == "NHWC" else ifm_shape[3], channels=ifm_channels, @@ -121,7 +122,7 @@ def _visit(stmt): stride_c=ifm_stride_c, ), ofm=spec.SerialFeatureMap( - data_type="int8", + data_type=data_type, height=ofm_height, width=ofm_width, channels=ifm_channels, diff --git a/tests/python/contrib/test_ethosu/test_type_inference.py b/tests/python/contrib/test_ethosu/test_type_inference.py index 778e4efc4b24..9b606562c5c0 100644 --- a/tests/python/contrib/test_ethosu/test_type_inference.py +++ b/tests/python/contrib/test_ethosu/test_type_inference.py @@ -381,14 +381,16 @@ def test_ethosu_identity_invalid_dtype(): @pytest.mark.parametrize( "ofm_shape, ofm_layout", [((1, 4, 5, 33), "NHWC"), ((1, 4, 3, 5, 16), "NHCWB16")] ) +@pytest.mark.parametrize("operator_type, data_type", [("ABS", "int8"), ("CLZ", "int32")]) def test_ethosu_unary_elementwise_type_inference( ifm_shape, ifm_layout, ofm_shape, ofm_layout, + operator_type, + data_type, ): - ifm = relay.var("ifm", shape=ifm_shape, dtype="int8") - operator_type = "ABS" + ifm = relay.var("ifm", shape=ifm_shape, dtype=data_type) ofm_channels = 33 unary_elementwise = make_ethosu_unary_elementwise( ifm,