diff --git a/tests/python/contrib/test_hexagon/hmx_qnn_conv2d_generator.py b/tests/python/contrib/test_hexagon/hmx_qnn_conv2d_generator.py index 137c056f99305..d52e4d5bcaeb2 100644 --- a/tests/python/contrib/test_hexagon/hmx_qnn_conv2d_generator.py +++ b/tests/python/contrib/test_hexagon/hmx_qnn_conv2d_generator.py @@ -137,7 +137,9 @@ def make_hexagon_conv2d_quantized_no_layout_transform_nhwc( BH = AH + PH_low + PH_high - FH + 1 BW = AW + PW_low + PW_high - FW + 1 - output_shape = [1, K, BH // strides[0], BW // strides[1]] + SBH = (AH + PH_low + PH_high - FH) // strides[0] + 1 + SBW = (AW + PW_low + PW_high - FW) // strides[1] + 1 + output_shape = [1, K, SBH, SBW] BC_pad_w_low = (AC_pad_w_low - PW_low + (FW - 1)) % TW BC_pad_h_low = (AC_pad_h_low - PH_low) % TH @@ -254,18 +256,15 @@ def output_tiled_to_NCHW(B_handle: T.handle, BC_handle: T.handle): B = T.match_buffer(B_handle, shape=output_shape, dtype=b_dtype) BC = T.match_buffer(BC_handle, shape=transformed_output_shape, dtype=B.dtype, scope=mem_scope) with T.block("output_tiled_to_NCHW"): - for n, k, bh, bw in T.grid(1, K, BH, BW): + for n, k, sbh, sbw in T.grid(1, K, SBH, SBW): ko = k // 32 ki = k % 32 - sbh = T.floordiv(bh, strides[0]) - sbw = T.floordiv(bw, strides[1]) + bho = (sbh * strides[0] + BC_pad_h_low) // TH + bhi = (sbh * strides[0] + BC_pad_h_low) % TH - bho = (sbh + BC_pad_h_low) // TH - bhi = (sbh + BC_pad_h_low) % TH - - bwo = (sbw + BC_pad_w_low) // TW - bwi = (sbw + BC_pad_w_low) % TW + bwo = (sbw * strides[1] + BC_pad_w_low) // TW + bwi = (sbw * strides[1] + BC_pad_w_low) % TW B[n, k, sbh, sbw] = BC[n, bho, bwo, ko, bhi, bwi, ki]