[microNPU] Support different constant datatypes
Currently only the uint8 datatype is supported for constants, as this
was all that was necessary until now. This PR allows different datatypes
to be used for constants, including different datatypes within the
same graph.

A workaround previously added for Mean legalization has also been
removed and replaced with the expected datatype of the constant.

Change-Id: I99e34fe17905b1bb7d916e346cebfc324e3a2a0c
lhutton1 committed Dec 1, 2021
1 parent e950ce5 commit 7d6e8ef
Showing 5 changed files with 90 additions and 53 deletions.
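
As a minimal illustration of the change described above (not code taken from this commit), a Relay graph targeting the microNPU can now carry constants of more than one datatype at once; relay.const and numpy are the only APIs assumed here:

import numpy as np
from tvm import relay

# Hypothetical constants of different datatypes that may now coexist in one graph.
scalar_int16 = relay.const(np.ones([1, 1, 1, 1], dtype="int16"), dtype="int16")
weights_uint8 = relay.const(
    np.random.randint(0, 256, size=(3, 3, 8, 8)).astype("uint8"), dtype="uint8"
)
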
3 changes: 2 additions & 1 deletion python/tvm/relay/backend/contrib/ethosu/legalize.py
@@ -1040,7 +1040,7 @@ def callback(
n = int(filter_height * filter_width)
eps = 1 / (256 * (n + 1)) if n % 2 == 0 else 0

scalar_tensor = relay.const(np.ones([1, 1, 1, 1], dtype="uint8"), dtype="uint8")
scalar_tensor = relay.const(np.ones([1, 1, 1, 1], dtype="int16"), dtype="int16")

reduced_op = ethosu_ops.ethosu_binary_elementwise(
ifm=reduced_op,
@@ -1156,6 +1156,7 @@ def transform_module(
mod = LegalizeReshape()(mod)
mod = LegalizeStridedSlice()(mod)
mod = LegalizeNoOps()(mod)
print(mod)
return mod

def __call__(self, *args, **kwargs):
45 changes: 22 additions & 23 deletions python/tvm/relay/backend/contrib/ethosu/tir_to_cs_translator.py
@@ -110,12 +110,11 @@ def translate(tir_module, params):
_npu_ops = list()
for call_extern in call_extern_list:
_npu_ops.append(translate_ethosu_tir_call_extern(call_extern))
_npu_ops, constant_tensor, scratch_size = assign_addresses(buffer_info, _npu_ops)
_npu_ops, constant_data, scratch_size = assign_addresses(buffer_info, _npu_ops)
target_accel_config = vela_api.get_accelerator_config()
cmds = vapi.npu_generate_register_command_stream(_npu_ops, target_accel_config)
payload = vapi.npu_create_driver_payload(cmds, target_accel_config)
hex_value = "" if constant_tensor is None else constant_tensor.tobytes().hex()
return payload.hex(), hex_value, scratch_size
return payload.hex(), constant_data, scratch_size


def extract_call_extern_list(mod):
@@ -277,27 +276,24 @@ def classify_io(buffer):
raise ValueError(f"Unused IO : {buffer} in tir module.")

scratch_size = 0
constant_tensor = None
constant_hex_data = []
total_constant_len = 0
buffer_addresses = dict()
for _buffer, info in buffer_info.items():
dtype_bytes = np.iinfo(np.dtype(info.dtype)).bits // 8
if info.values is not None:
assert np.dtype(info.dtype) == np.uint8
assert info.btype == BufferType.constant
assert len(info.shape) == 1
if constant_tensor is None:
buffer_addresses[_buffer] = (0, info.btype)
assert info.values.dtype == np.uint8
size_in_bytes = info.values.size
# Every memory address the NPU accesses has to be 16-byte aligned
size_in_bytes = util.round_up(size_in_bytes, 16)
constant_tensor = np.resize(info.values, size_in_bytes)
else:
buffer_addresses[_buffer] = (constant_tensor.size, info.btype)
assert info.values.dtype == np.uint8
size_in_bytes = info.values.size
# Every memory address the NPU accesses has to be 16-byte aligned
size_in_bytes = util.round_up(size_in_bytes, 16)
constant_tensor = np.append(constant_tensor, np.resize(info.values, size_in_bytes))
buffer_addresses[_buffer] = (
(total_constant_len, info.btype) if constant_hex_data else (0, info.btype)
)
size_in_bytes = dtype_bytes * np.prod(list(info.shape))
# Every memory address the NPU accesses has to be 16-byte aligned
size_in_bytes = util.round_up(size_in_bytes, 16)
constant_tensor = np.resize(info.values, size_in_bytes // dtype_bytes)
constant_tensor = constant_tensor.tobytes().hex()
constant_hex_data.append(constant_tensor)
total_constant_len += len(constant_tensor) // 2
else:
if info.btype == BufferType.input_or_output:
buffer_type = classify_io(_buffer)
@@ -310,9 +306,7 @@ def classify_io(buffer):
address = arch_config.lut_start_address
buffer_addresses[_buffer] = (address, info.btype)
else:
size_in_bytes = int(
(np.iinfo(np.dtype(info.dtype)).bits // 8) * np.prod(list(info.shape))
)
size_in_bytes = int(dtype_bytes * np.prod(list(info.shape)))
# Every memory address the NPU accesses has to be 16-byte aligned
size_in_bytes = util.round_up(size_in_bytes, 16)
assert info.btype == BufferType.scratch
@@ -330,7 +324,12 @@ def classify_io(buffer):
else:
setattr(npu_op, attr_name, replace_tir_loads(attr))

return npu_ops, constant_tensor, scratch_size
constant_data = "".join(constant_hex_data)
return (
npu_ops,
constant_data,
scratch_size,
)


def translate_ethosu_tir_call_extern(tir_call_extern):
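
The new assign_addresses logic above can be summarised by a standalone numpy sketch (illustrative only; round_up and pack_constants here are stand-ins, not the TVM helpers): each constant is padded to a 16-byte boundary in its own datatype, hex-encoded, and appended to a single string, while byte offsets are accumulated from the hex length (two hex characters per byte):

import numpy as np

def round_up(value, align):
    return ((value + align - 1) // align) * align

def pack_constants(constants):
    # constants: list of 1-D numpy integer arrays, possibly of different dtypes
    hex_chunks = []
    offsets = []       # byte offset of each constant inside the packed buffer
    total_bytes = 0
    for values in constants:
        offsets.append(total_bytes)
        dtype_bytes = np.iinfo(values.dtype).bits // 8
        size_in_bytes = round_up(dtype_bytes * values.size, 16)
        padded = np.resize(values, size_in_bytes // dtype_bytes)
        chunk = padded.tobytes().hex()
        hex_chunks.append(chunk)
        total_bytes += len(chunk) // 2   # two hex characters per byte
    return "".join(hex_chunks), offsets

packed, offsets = pack_constants(
    [np.arange(10, dtype="uint8"), np.arange(3, dtype="int16")]
)
assert offsets == [0, 16]   # the uint8 constant occupies one aligned 16-byte block
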
23 changes: 5 additions & 18 deletions src/relay/op/contrib/ethosu/binary_elementwise.cc
@@ -128,21 +128,6 @@ struct EthosuBinaryElementwiseAttrs : public tvm::AttrsNode<EthosuBinaryElementw

TVM_REGISTER_NODE_TYPE(EthosuBinaryElementwiseAttrs);

bool IsScalarTensor(const Array<PrimExpr>& ifm_shape, const DataType& ifm_dtype) {
if (ifm_dtype != DataType::UInt(8)) {
return false;
}

for (const auto& expr : ifm_shape) {
const auto& dim_int_node = expr.as<IntImmNode>();
CHECK(dim_int_node) << "Expected IntImmNode for shape dimensions.";
int dim = dim_int_node->value;
if (dim != 1) return false;
}

return true;
}

bool EthosuBinaryElementwiseRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
const TypeReporter& reporter) {
const int ifm_index = 0;
@@ -167,11 +152,13 @@ bool EthosuBinaryElementwiseRel(const Array<Type>& types, int num_inputs, const
ofm_dtype = DataType::Int(8);
} else if (param->ofm_dtype == "uint8") {
ofm_dtype = DataType::UInt(8);
} else if (param->ofm_dtype == "int16") {
ofm_dtype = DataType::Int(16);
} else if (param->ofm_dtype == "int32") {
ofm_dtype = DataType::Int(32);
}

if (ifm_dtype != ifm2_dtype && !IsScalarTensor(ifm2->shape, ifm2_dtype)) {
if (ifm_dtype != ifm2_dtype) {
reporter->GetDiagCtx().EmitFatal(Diagnostic::Error(reporter->GetSpan())
<< "Invalid operator: expected ethosu_binary_elementwise "
<< "type for ifm2 be the same of ifm but was " << ifm2_dtype
@@ -189,11 +176,11 @@ bool EthosuBinaryElementwiseRel(const Array<Type>& types, int num_inputs, const
return false;
}
if (ofm_dtype != DataType::UInt(8) && ofm_dtype != DataType::Int(8) &&
ofm_dtype != DataType::Int(32)) {
ofm_dtype != DataType::Int(16) && ofm_dtype != DataType::Int(32)) {
reporter->GetDiagCtx().EmitFatal(
Diagnostic::Error(reporter->GetSpan())
<< "Invalid operator: expected ethosu_binary_elementwise " << operator_type
<< " type(uint8) or type(int8) or type(int32) for ofm but was " << ofm_dtype);
<< " type(uint8), type(int8), type(int16) or type(int32) for ofm but was " << ofm_dtype);
return false;
}
} else if (operator_type == "MIN" || operator_type == "MAX") {
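
A small Python sketch of the revised type checking in binary_elementwise.cc (illustrative only, not the actual C++ type relation): ifm2 must now match ifm exactly, since the uint8-scalar exception (IsScalarTensor) was removed, and int16 joins the accepted ofm datatypes for ADD, SUB and MUL:

VALID_OFM_DTYPES = {"uint8", "int8", "int16", "int32"}

def check_binary_elementwise_dtypes(ifm_dtype, ifm2_dtype, ofm_dtype, operator_type):
    # ifm2 must have the same datatype as ifm; the scalar special case is gone.
    if ifm_dtype != ifm2_dtype:
        raise ValueError(f"expected ifm2 dtype {ifm_dtype} but was {ifm2_dtype}")
    # ADD/SUB/MUL now also accept an int16 output datatype.
    if operator_type in ("ADD", "SUB", "MUL") and ofm_dtype not in VALID_OFM_DTYPES:
        raise ValueError(f"unsupported ofm dtype {ofm_dtype} for {operator_type}")
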
10 changes: 7 additions & 3 deletions tests/python/contrib/test_ethosu/test_codegen.py
@@ -607,8 +607,8 @@ def create_mod_from_relay():


@pytest.mark.parametrize("accel_type", ACCEL_TYPES)
def test_binary_add_from_constant_scalar(accel_type):
dtype = "uint8"
@pytest.mark.parametrize("dtype", ["int8", "uint8"])
def test_elementwise_add_from_constant_scalar(accel_type, dtype):
ifm_shape = (1, 4, 4, 8)

def create_relay_graph():
@@ -631,7 +631,11 @@ def create_relay_graph():
partitioned_mod = partition_for_ethosu(mod)

# Generate reference data
input_data = {"input": np.random.randint(low=0, high=255, size=ifm_shape, dtype=dtype)}
input_data = {
"input": np.random.randint(
low=np.iinfo(dtype).min, high=np.iinfo(dtype).max, size=ifm_shape, dtype=dtype
),
}
output_data = generate_ref_data(mod, input_data)

compiled_models = infra.build_source(
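
The test above now derives the random input range from the datatype under test instead of hard-coding 0–255; in isolation the pattern looks like this (shape and loop are illustrative):

import numpy as np

ifm_shape = (1, 4, 4, 8)
for dtype in ("int8", "uint8"):
    info = np.iinfo(dtype)
    # np.random.randint's upper bound is exclusive, mirroring the test code above.
    input_data = np.random.randint(info.min, info.max, size=ifm_shape, dtype=dtype)
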
62 changes: 54 additions & 8 deletions tests/python/contrib/test_ethosu/test_tir_to_cs_translator.py
@@ -651,6 +651,31 @@ def populate_ethosu_copy_calls(stmt):
assert npu_dma_op.dest.length == test_case["ref"][idx]["length"]


# fmt: off
@tvm.script.ir_module
class MixedConstantDatatypes:
@T.prim_func
def main(placeholder: T.handle, placeholder_1: T.handle, placeholder_2: T.handle, ethosu_write: T.handle, placeholder_3: T.handle) -> None:
# function attr dict
T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True})
placeholder_4 = T.match_buffer(placeholder, [1, 8, 16, 16], dtype="int8")
buffer = T.match_buffer(placeholder_1, [160], dtype="uint8")
placeholder_5 = T.match_buffer(placeholder_2, [1, 1, 1, 1], dtype="int16")
ethosu_write_1 = T.match_buffer(ethosu_write, [1, 1, 1, 16], dtype="int8")
buffer_1 = T.match_buffer(placeholder_3, [272], dtype="uint8")
# body
placeholder_global = T.allocate([272], "uint8", "global")
placeholder_d_global = T.allocate([160], "uint8", "global")
ethosu_write_2 = T.allocate([16], "int16", "global")
placeholder_d_global_1 = T.allocate([1], "int16", "global")
T.evaluate(T.call_extern("ethosu_copy", T.load("uint8", buffer_1.data, 0), 272, T.load("uint8", placeholder_global, 0), dtype="uint8"))
T.evaluate(T.call_extern("ethosu_copy", T.load("uint8", buffer.data, 0), 160, T.load("uint8", placeholder_d_global, 0), dtype="uint8"))
T.evaluate(T.call_extern("ethosu_depthwise_conv2d", "int8", 8, 16, 16, 8, 0, 16, T.load("int8", placeholder_4.data, 0), 0, 0, 0, T.float32(0.0039215548895299435), -128, "NHWC", 256, 16, 1, "int16", 1, 1, 16, 1, 0, 1, T.load("int16", ethosu_write_2, 0), 0, 0, 0, T.float32(0.0023205536417663097), -128, "NHWC", 1, 1, 1, 16, 8, 1, 1, 1, 1, T.load("uint8", placeholder_global, 0), 272, 0, T.load("uint8", placeholder_d_global, 0), 160, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", dtype="int16"))
T.evaluate(T.call_extern("ethosu_copy", T.load("int16", placeholder_5.data, 0), 1, T.load("int16", placeholder_d_global_1, 0), dtype="int16"))
T.evaluate(T.call_extern("ethosu_binary_elementwise", "int16", 1, 1, 16, 1, 0, 1, T.load("int16", ethosu_write_2, 0), 0, 0, 0, T.float32(0.0023205536417663097), -128, "NHWC", 1, 1, 1, "int16", 1, 1, 1, 1, 0, 1, T.load("int16", placeholder_d_global_1, 0), 0, 0, 0, T.float32(0.0078125018482064768), 0, "NHWC", 1, 1, 1, "int8", 1, 1, 16, 1, 0, 1, T.load("int8", ethosu_write_1.data, 0), 0, 0, 0, T.float32(0.0023205536417663097), -128, "NHWC", 1, 1, 1, "MUL", 0, "NONE", 0, 0, "NATURAL", dtype="int8"))
# fmt: on


def test_assign_addresses():
test_cases = [
{
@@ -683,6 +708,15 @@ def test_assign_addresses():
11: np.random.randint(np.iinfo("uint8").min, np.iinfo("uint8").max, [20], "uint8"),
},
},
{
# Stimulus
"tir_module": MixedConstantDatatypes,
"param_dict": {
1: np.random.randint(np.iinfo("uint8").min, np.iinfo("uint8").max, [160], "uint8"),
2: np.random.randint(np.iinfo("int16").min, np.iinfo("int16").max, [1], "int16"),
4: np.random.randint(np.iinfo("uint8").min, np.iinfo("uint8").max, [272], "uint8"),
},
},
]

def extract_call_extern_list(mod):
@@ -747,24 +781,36 @@ def _check_buffer(address, region, length, buffer_var):
4: tir_to_cs_translator.BufferType.output,
}
buffer_type = inverse_region_map[region]
buffer_dtype = buffer_var.type_annotation.element_type.dtype
dtype_bytes = np.iinfo(np.dtype(buffer_dtype)).bits // 8
if buffer_type == tir_to_cs_translator.BufferType.constant:
ref = buffer_info[buffer_var].values
assert (constant_tensor[address : address + length] == ref).all()
hex_from = address * dtype_bytes * 2
hex_to = hex_from + length * dtype_bytes * 2
constant_hex = constant_hex_string[hex_from:hex_to]
constant_tensor = np.frombuffer(bytearray.fromhex(constant_hex), dtype=buffer_dtype)
assert np.array_equal(constant_tensor, ref)
# Every buffer is adjusted to align to 16 bytes
length = util.round_up(length, 16)
# Mark these constants as read at least once
constant_tensor_read_mask[address : address + length] = np.ones(length, dtype="uint8")
constant_tensor_read_mask[address : address + length] = np.ones(
length, dtype=buffer_dtype
)
elif buffer_type == tir_to_cs_translator.BufferType.scratch:
shape = list(buffer_info[buffer_var].shape)
assert length == np.prod(shape)
assert address < scratch_size

size_in_bytes = int(np.prod(shape)) * dtype_bytes
# Every buffer is adjusted to align to 16 bytes
length = util.round_up(length, 16)
assert address + length <= scratch_size
size_in_bytes = util.round_up(size_in_bytes, 16)
assert address + size_in_bytes <= scratch_size
# The scratch area should not be used by any other buffer
assert not scratch_allocation_mask[address : address + length].any()
assert not scratch_allocation_mask[address : address + size_in_bytes].any()
# The scratch area is marked as used
scratch_allocation_mask[address : address + length] = np.ones(length, dtype="uint8")
scratch_allocation_mask[address : address + size_in_bytes] = np.ones(
size_in_bytes, dtype="uint8"
)
elif buffer_type == tir_to_cs_translator.BufferType.input:
assert address == 0
else:
@@ -841,11 +887,11 @@ def check_buffer(address, region, length, buffer_var):
for extern_call in extern_calls:
_npu_ops.append(tir_to_cs_translator.translate_ethosu_tir_call_extern(extern_call))
npu_op_tir_buffers = collect_tir_buffer_info(_npu_ops)
_npu_ops, constant_tensor, scratch_size = tir_to_cs_translator.assign_addresses(
_npu_ops, constant_hex_string, scratch_size = tir_to_cs_translator.assign_addresses(
buffer_info, _npu_ops
)
scratch_allocation_mask = np.zeros(scratch_size, dtype="uint8")
constant_tensor_read_mask = np.zeros(constant_tensor.size, dtype="uint8")
constant_tensor_read_mask = np.zeros(len(constant_hex_string) // 2, dtype="uint8")
verify(_npu_ops)
# This will be only 1 if all allocated scratch is used.
assert np.prod(scratch_allocation_mask) == 1
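
Because assign_addresses now returns all constants as a single hex string, the updated _check_buffer converts an element address and length into hex-character offsets before decoding. A self-contained sketch of that conversion, with made-up values:

import numpy as np

# Hypothetical packed buffer holding a single int16 constant at element address 0.
ref = np.array([1, 2, 3], dtype="int16")
constant_hex_string = ref.tobytes().hex()

address, length = 0, 3                     # element address and element count
dtype_bytes = np.iinfo("int16").bits // 8
hex_from = address * dtype_bytes * 2       # two hex characters per byte
hex_to = hex_from + length * dtype_bytes * 2
decoded = np.frombuffer(
    bytearray.fromhex(constant_hex_string[hex_from:hex_to]), dtype="int16"
)
assert np.array_equal(decoded, ref)
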
