Skip to content

Commit

Permalink
Fix striding support
Browse files Browse the repository at this point in the history
  • Loading branch information
supersat committed Feb 16, 2023
1 parent 8738657 commit cac9e62
Showing 1 changed file with 8 additions and 9 deletions.
17 changes: 8 additions & 9 deletions tests/python/contrib/test_hexagon/hmx_qnn_conv2d_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,9 @@ def make_hexagon_conv2d_quantized_no_layout_transform_nhwc(

BH = AH + PH_low + PH_high - FH + 1
BW = AW + PW_low + PW_high - FW + 1
output_shape = [1, K, BH // strides[0], BW // strides[1]]
SBH = (AH + PH_low + PH_high - FH) // strides[0] + 1
SBW = (AW + PW_low + PW_high - FW) // strides[1] + 1
output_shape = [1, K, SBH, SBW]

BC_pad_w_low = (AC_pad_w_low - PW_low + (FW - 1)) % TW
BC_pad_h_low = (AC_pad_h_low - PH_low) % TH
Expand Down Expand Up @@ -254,18 +256,15 @@ def output_tiled_to_NCHW(B_handle: T.handle, BC_handle: T.handle):
B = T.match_buffer(B_handle, shape=output_shape, dtype=b_dtype)
BC = T.match_buffer(BC_handle, shape=transformed_output_shape, dtype=B.dtype, scope=mem_scope)
with T.block("output_tiled_to_NCHW"):
for n, k, bh, bw in T.grid(1, K, BH, BW):
for n, k, sbh, sbw in T.grid(1, K, SBH, SBW):
ko = k // 32
ki = k % 32

sbh = T.floordiv(bh, strides[0])
sbw = T.floordiv(bw, strides[1])
bho = (sbh * strides[0] + BC_pad_h_low) // TH
bhi = (sbh * strides[0] + BC_pad_h_low) % TH

bho = (sbh + BC_pad_h_low) // TH
bhi = (sbh + BC_pad_h_low) % TH

bwo = (sbw + BC_pad_w_low) // TW
bwi = (sbw + BC_pad_w_low) % TW
bwo = (sbw * strides[1] + BC_pad_w_low) // TW
bwi = (sbw * strides[1] + BC_pad_w_low) % TW

B[n, k, sbh, sbw] = BC[n, bho, bwo, ko, bhi, bwi, ki]

Expand Down

0 comments on commit cac9e62

Please sign in to comment.