Skip to content

Commit

Permalink
Merge pull request apache#12 from octoml/hmx_stride_support
Browse files Browse the repository at this point in the history
HMX conv2d stride support
  • Loading branch information
supersat authored Feb 16, 2023
2 parents 88c40b9 + cac9e62 commit 584c90f
Show file tree
Hide file tree
Showing 3 changed files with 60 additions and 22 deletions.
6 changes: 4 additions & 2 deletions tests/python/contrib/test_hexagon/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ def input_bias(filter_shape):

@tvm.testing.fixture
def reference_output_quantized(
input_activation, input_weights, input_bias, workload_padding, quantization_parameters
input_activation, input_weights, input_bias, workload_padding, strides, quantization_parameters
):
data = relay.var("data", shape=input_activation.shape, dtype="uint8")
weight = relay.var("weight", shape=input_weights.shape, dtype="int8")
Expand All @@ -91,7 +91,7 @@ def reference_output_quantized(
kernel_size=(input_weights.shape[2], input_weights.shape[3]),
channels=input_weights.shape[0],
padding=workload_padding,
strides=(1, 1),
strides=strides,
out_dtype="int32",
)

Expand Down Expand Up @@ -123,6 +123,7 @@ def compose_one_convolution_quantized_separate_layout_transforms(
filter_shape,
input_tile_offset,
workload_padding,
strides,
mem_scope,
):
shape_A = input_shape
Expand All @@ -133,6 +134,7 @@ def compose_one_convolution_quantized_separate_layout_transforms(
filter_shape,
tile_offset_A,
workload_padding=workload_padding,
strides=strides,
mem_scope=mem_scope
)

Expand Down
29 changes: 18 additions & 11 deletions tests/python/contrib/test_hexagon/hmx_qnn_conv2d_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ def make_hexagon_conv2d_quantized_no_layout_transform_nhwc(
filter_shape,
input_tile_offset,
workload_padding=(0, 0, 0, 0),
strides=(1, 1),
tile_shape_hw=(8, 8),
mem_scope="global",
):
Expand Down Expand Up @@ -42,6 +43,10 @@ def make_hexagon_conv2d_quantized_no_layout_transform_nhwc(
Only padding with zeroes is supported.
strides
A tuple of length 1, specifying the strides in (height,width).
tile_shape_hw
A tuple of length 2, specifying the tile size in (height,width).
Expand Down Expand Up @@ -132,7 +137,9 @@ def make_hexagon_conv2d_quantized_no_layout_transform_nhwc(

BH = AH + PH_low + PH_high - FH + 1
BW = AW + PW_low + PW_high - FW + 1
output_shape = [1, K, BH, BW]
SBH = (AH + PH_low + PH_high - FH) // strides[0] + 1
SBW = (AW + PW_low + PW_high - FW) // strides[1] + 1
output_shape = [1, K, SBH, SBW]

BC_pad_w_low = (AC_pad_w_low - PW_low + (FW - 1)) % TW
BC_pad_h_low = (AC_pad_h_low - PH_low) % TH
Expand All @@ -145,13 +152,13 @@ def make_hexagon_conv2d_quantized_no_layout_transform_nhwc(
CO = tir.ceildiv(C, 32)
CII = 4
CIO = tir.ceildiv(CI, CII)
AHO = tir.ceildiv((AC_pad_h_low + AH), TH)
AWO = tir.ceildiv((AC_pad_w_low + AW), TW)
AHO = tir.ceildiv(AC_pad_h_low + AH, TH)
AWO = tir.ceildiv(AC_pad_w_low + AW, TW)

KI = 32
KO = tir.ceildiv(K, KI)
BHO = tir.ceildiv((BC_pad_h_low + BH), TH)
BWO = tir.ceildiv((BC_pad_w_low + BW), TW)
BHO = tir.ceildiv(BC_pad_h_low + BH, TH)
BWO = tir.ceildiv(BC_pad_w_low + BW, TW)

transformed_input_shape = [
1,
Expand Down Expand Up @@ -268,17 +275,17 @@ def output_tiled_to_NCHW(B_handle: T.handle, BC_handle: T.handle):
B = T.match_buffer(B_handle, shape=output_shape, dtype=b_dtype)
BC = T.match_buffer(BC_handle, shape=transformed_output_shape, dtype=B.dtype, scope=mem_scope)
with T.block("output_tiled_to_NCHW"):
for n, k, bh, bw in T.grid(1, K, BH, BW):
for n, k, sbh, sbw in T.grid(1, K, SBH, SBW):
ko = k // 32
ki = k % 32

bho = (bh + BC_pad_h_low) // TH
bhi = (bh + BC_pad_h_low) % TH
bho = (sbh * strides[0] + BC_pad_h_low) // TH
bhi = (sbh * strides[0] + BC_pad_h_low) % TH

bwo = (bw + BC_pad_w_low) // TW
bwi = (bw + BC_pad_w_low) % TW
bwo = (sbw * strides[1] + BC_pad_w_low) // TW
bwi = (sbw * strides[1] + BC_pad_w_low) % TW

B[n, k, bh, bw] = BC[n, bho, bwo, ko, bhi, bwi, ki]
B[n, k, sbh, sbw] = BC[n, bho, bwo, ko, bhi, bwi, ki]

@T.prim_func
def bias_and_filter_constexpr(
Expand Down
47 changes: 38 additions & 9 deletions tests/python/contrib/test_hexagon/infrastructure.py
Original file line number Diff line number Diff line change
Expand Up @@ -371,16 +371,45 @@ class Conv2dHMXTestingBase:
}
)

strides = tvm.testing.parameter(
by_dict={
"stride-1": (1, 1),
"stride-2": (2, 2),
}
)

quantization_parameters = tvm.testing.parameter(
{
"activation_zero_point": 2,
"filter_zero_point": 0,
"activation_scale": 1.0,
"filter_scale": 1.0,
"requantize_input_scale": 1.0,
"requantize_input_zero_point": 0,
"requantize_output_scale": 1.0,
"requantize_output_zero_point": 0,
by_dict = {
"qp-none": {
"activation_zero_point": 0,
"filter_zero_point": 0,
"activation_scale": 1.0,
"filter_scale": 1.0,
"requantize_input_scale": 1.0,
"requantize_input_zero_point": 0,
"requantize_output_scale": 1.0,
"requantize_output_zero_point": 0,
},
"qp-azp-2": {
"activation_zero_point": 2,
"filter_zero_point": 0,
"activation_scale": 1.0,
"filter_scale": 1.0,
"requantize_input_scale": 1.0,
"requantize_input_zero_point": 0,
"requantize_output_scale": 1.0,
"requantize_output_zero_point": 0,
},
"qp-rq": {
"activation_zero_point": 0,
"filter_zero_point": 0,
"activation_scale": 1.0,
"filter_scale": 1.0,
"requantize_input_scale": 1.0,
"requantize_input_zero_point": 0,
"requantize_output_scale": 2.0,
"requantize_output_zero_point": 3,
},
}
)

Expand Down

0 comments on commit 584c90f

Please sign in to comment.