Skip to content

Commit

Permalink
[microNPU] Add support for unary elementwise CLZ (apache#9577)
Browse files Browse the repository at this point in the history
Add support for the CLZ (count leading zeros) operator
and the codegen test.


Co-authored-by: Rishabh Jain <rishabh.jain2@arm.com>
  • Loading branch information
2 people authored and masahi committed Dec 1, 2021
1 parent af9a446 commit 1498239
Show file tree
Hide file tree
Showing 9 changed files with 77 additions and 18 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,7 @@ def ethosu_unary_elementwise(
operator_type: str
The type of the unary elementwise operator.
"ABS"
"CLZ"
ifm_scale : float
The quantization scale for the Input Feature Map tensor.
ifm_zero_point : int
Expand Down Expand Up @@ -143,7 +144,7 @@ def ethosu_unary_elementwise(
Returns
-------
out : tvm.relay.Call
A call to the ethosu_binary_elementwise op.
A call to the ethosu_unary_elementwise op.
"""
return _make.ethosu_unary_elementwise(
ifm,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ def unary_elementwise_compute(
operator_type: str
The type of the unary elementwise operator.
"ABS"
"CLZ"
ifm_scale : float
The quantization scale for the Input Feature Map tensor.
ifm_zero_point : int
Expand Down Expand Up @@ -111,7 +112,11 @@ def unary_elementwise_compute(
"rounding_mode": rounding_mode,
}

operators = {"ABS": te.abs}
def clz_imp(inp):
# Assuming that it's a 32 bit int
return 32 - te.log2(inp)

operators = {"ABS": te.abs, "CLZ": clz_imp}

unary_elementwise = te.compute(
(1, ofm_height, ofm_width, ofm_channels),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,11 @@ def get_unary_elementwise_params(stmt, producers, consumers):
_, _, _, _, _, inner = get_outer_loops(body, "NHWC")
input_pointer = None
if isinstance(inner.value, tir.expr.Select):
# ABS
input_pointer = inner.value.condition.b.buffer_var
if isinstance(inner.value, tir.expr.Sub):
# CLZ
input_pointer = inner.value.b.args[0].buffer_var
output_pointer = inner.buffer_var
# Get feature map info
serial_ifm, _ = get_ifm_params(input_pointer, producers)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -799,6 +799,8 @@ def _create_npu_op_unary_elementwise(serial_unary_elementwise):
operator_type = serial_unary_elementwise.operator_type
if operator_type == "ABS":
op = vapi.NpuElementWiseOp.ABS
if operator_type == "CLZ":
op = vapi.NpuElementWiseOp.CLZ

npu_unary_elementwise_op = vapi.NpuElementWiseOperation(op)
npu_unary_elementwise_op.ifm = _create_npu_feature_map(serial_unary_elementwise.ifm)
Expand Down
27 changes: 18 additions & 9 deletions src/relay/op/contrib/ethosu/unary_elementwise.cc
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,8 @@ struct EthosuUnaryElementwiseAttrs : public tvm::AttrsNode<EthosuUnaryElementwis
TVM_ATTR_FIELD(operator_type)
.describe(
"The type of the unary elementwise operator."
"'ABS'");
"'ABS'"
"'CLZ'");
TVM_ATTR_FIELD(ifm_scale).describe("The quantization scale for the Input Feature Map tensor.");
TVM_ATTR_FIELD(ifm_zero_point)
.describe("The quantization zero point for the Input Feature Map tensor.");
Expand Down Expand Up @@ -104,20 +105,28 @@ bool EthosuUnaryElementwiseRel(const Array<Type>& types, int num_inputs, const A
CHECK(param != nullptr) << "EthosuUnaryElementwiseAttrs cannot be nullptr.";

String operator_type = param->operator_type;
if (operator_type != "ABS") {
reporter->GetDiagCtx().EmitFatal(
Diagnostic::Error(reporter->GetSpan())
<< "Invalid operator: expected ethosu_unary_elementwise 'ABS' for operator_type but was"
<< operator_type);
if (operator_type != "ABS" && operator_type != "CLZ") {
reporter->GetDiagCtx().EmitFatal(Diagnostic::Error(reporter->GetSpan())
<< "Invalid operator: expected ethosu_unary_elementwise 'ABS' "
"or 'CLZ' for operator_type but was"
<< operator_type);
return false;
}

auto ifm_dtype = ifm->dtype;
if (ifm_dtype != DataType::UInt(8) && ifm_dtype != DataType::Int(8)) {
if (ifm_dtype != DataType::UInt(8) && ifm_dtype != DataType::Int(8) && operator_type == "ABS") {
reporter->GetDiagCtx().EmitFatal(Diagnostic::Error(reporter->GetSpan())
<< "Invalid operator: expected ethosu_unary_elementwise "
<< operator_type << "input data type "
<< "of type(uint8) or type(int8) but was " << ifm_dtype);
return false;
}

if (ifm_dtype != DataType::Int(32) && operator_type == "CLZ") {
reporter->GetDiagCtx().EmitFatal(
Diagnostic::Error(reporter->GetSpan())
<< "Invalid operator: expected ethosu_unary_elementwise input data type "
<< "of type(uint8) or type(int8) but was " << ifm_dtype);
<< "Invalid operator: expected ethosu_unary_elementwise CLZ input data type "
<< "of type(int32) but was " << ifm_dtype);
return false;
}

Expand Down
3 changes: 2 additions & 1 deletion tests/python/contrib/test_ethosu/infra.py
Original file line number Diff line number Diff line change
Expand Up @@ -312,8 +312,9 @@ def make_partitioned_function(relay_op):

ifm0 = relay.analysis.free_vars(relay_op)
ifm_shape = ifm0[0].type_annotation.shape
ifm_dtype = ifm0[0].type_annotation.dtype

ifm = relay.var("ifm", shape=ifm_shape, dtype="int8")
ifm = relay.var("ifm", shape=ifm_shape, dtype=ifm_dtype)

glb_ethosu = relay.GlobalVar("tvmgen_default_ethosu_main_0")

Expand Down
34 changes: 34 additions & 0 deletions tests/python/contrib/test_ethosu/test_codegen.py
Original file line number Diff line number Diff line change
Expand Up @@ -969,5 +969,39 @@ def create_graph_single(input_tensor_name, input_tensor_shape, input_tensor_dtyp
assert '__attribute__((section(".rodata.tvm"), aligned(16))) static int8_t weights' in source


@pytest.mark.parametrize("accel_type", ACCEL_TYPES)
def test_ethosu_clz(accel_type):
ifm_shape = (1, 42, 5, 4)
# Create a "partitioned" Relay function
ifm0 = relay.var("ifm0", shape=ifm_shape, dtype="int32")
clz = infra.make_ethosu_unary_elementwise(ifm0, 4, "CLZ")
mod = infra.make_partitioned_function(clz)

in_data = np.random.randint(-500000, high=500000, size=ifm_shape, dtype="int32")

def clz_comp(n):
n_bin = np.binary_repr(n)
if n_bin[0] == "-":
return 0
else:
return 32 - len(n_bin)

out_data = np.array([clz_comp(i) for i in in_data.ravel()]).reshape(ifm_shape).astype("int32")

compiled_model = infra.build_source(mod, {"ifm": in_data}, [out_data], accel_type)

imported_modules = compiled_model[0].executor_factory.lib.imported_modules
assert len(imported_modules) == 2
ethosu_module = imported_modules[0]

# Verify generated C source
get_cs = tvm._ffi.get_global_func("runtime.module.ethos-u.getcs")
cmms = get_cs(ethosu_module)
cmms = bytes.fromhex(cmms)

infra.print_payload(cmms)
infra.verify_source(compiled_model, accel_type)


if __name__ == "__main__":
pytest.main([__file__])
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ def _get_unary_elementwise_args(call, include_buffers=False, remove_constants=Fa
((1, 8, 9, 40), 40, "NHWC", "NHCWB16", "TFL"),
],
)
@pytest.mark.parametrize("operator_type", ["ABS"])
@pytest.mark.parametrize("operator_type, data_type", [("ABS", "int8"), ("CLZ", "int32")])
@pytest.mark.parametrize("activation", ["NONE"])
def test_unary_elementwise_single(
ifm_shape,
Expand All @@ -60,8 +60,9 @@ def test_unary_elementwise_single(
rounding_mode,
operator_type,
activation,
data_type,
):
ifm = relay.var("ifm", shape=ifm_shape, dtype="int8")
ifm = relay.var("ifm", shape=ifm_shape, dtype=data_type)

unary_elementwise = make_ethosu_unary_elementwise(
ifm, ifm_channels, operator_type, activation, ifm_layout, ofm_layout, rounding_mode
Expand Down Expand Up @@ -102,7 +103,7 @@ def _visit(stmt):

serial_unary_elementwise = spec.SerialUnaryElementwise(
ifm=spec.SerialFeatureMap(
data_type="int8",
data_type=data_type,
height=ifm_shape[1],
width=ifm_shape[2] if ifm_layout == "NHWC" else ifm_shape[3],
channels=ifm_channels,
Expand All @@ -121,7 +122,7 @@ def _visit(stmt):
stride_c=ifm_stride_c,
),
ofm=spec.SerialFeatureMap(
data_type="int8",
data_type=data_type,
height=ofm_height,
width=ofm_width,
channels=ifm_channels,
Expand Down
6 changes: 4 additions & 2 deletions tests/python/contrib/test_ethosu/test_type_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -381,14 +381,16 @@ def test_ethosu_identity_invalid_dtype():
@pytest.mark.parametrize(
"ofm_shape, ofm_layout", [((1, 4, 5, 33), "NHWC"), ((1, 4, 3, 5, 16), "NHCWB16")]
)
@pytest.mark.parametrize("operator_type, data_type", [("ABS", "int8"), ("CLZ", "int32")])
def test_ethosu_unary_elementwise_type_inference(
ifm_shape,
ifm_layout,
ofm_shape,
ofm_layout,
operator_type,
data_type,
):
ifm = relay.var("ifm", shape=ifm_shape, dtype="int8")
operator_type = "ABS"
ifm = relay.var("ifm", shape=ifm_shape, dtype=data_type)
ofm_channels = 33
unary_elementwise = make_ethosu_unary_elementwise(
ifm,
Expand Down

0 comments on commit 1498239

Please sign in to comment.