From 1264f743b67d36ba419dc96386d39caf59252251 Mon Sep 17 00:00:00 2001 From: Matthew Date: Thu, 29 Jul 2021 12:14:13 -0600 Subject: [PATCH] speed up reference resize kernel --- python/tvm/topi/testing/resize_python.py | 100 ++++++++++------------- 1 file changed, 41 insertions(+), 59 deletions(-) diff --git a/python/tvm/topi/testing/resize_python.py b/python/tvm/topi/testing/resize_python.py index e8d5c0599887..13b460f07e1d 100644 --- a/python/tvm/topi/testing/resize_python.py +++ b/python/tvm/topi/testing/resize_python.py @@ -66,51 +66,52 @@ def resize3d_nearest(arr, scale, coordinate_transformation_mode): def resize3d_linear(data_in, scale, coordinate_transformation_mode): """Trilinear 3d scaling using python""" + dtype = data_in.dtype d, h, w = data_in.shape new_d, new_h, new_w = [int(round(i * s)) for i, s in zip(data_in.shape, scale)] data_out = np.ones((new_d, new_h, new_w)) - def _lerp(A, B, t): - return A * (1.0 - t) + B * t + indexes = np.mgrid[0:2, 0:2, 0:2] - def _in_coord(new_coord, in_shape, out_shape): - in_coord = get_inx(new_coord, in_shape, out_shape, coordinate_transformation_mode) - coord0 = int(math.floor(in_coord)) - coord1 = max(min(coord0 + 1, in_shape - 1), 0) - coord0 = max(coord0, 0) - coord_lerp = in_coord - math.floor(in_coord) - return coord0, coord1, coord_lerp + def _get_patch(zint, yint, xint): + # Get the surrounding values + indices = indexes.copy() + indices[0] = np.maximum(np.minimum(indexes[0] + zint, d - 1), 0) + indices[1] = np.maximum(np.minimum(indexes[1] + yint, h - 1), 0) + indices[2] = np.maximum(np.minimum(indexes[2] + xint, w - 1), 0) + p = data_in[indices[0], indices[1], indices[2]] + return p for m in range(new_d): for j in range(new_h): for k in range(new_w): - z0, z1, z_lerp = _in_coord(m, d, new_d) - y0, y1, y_lerp = _in_coord(j, h, new_h) - x0, x1, x_lerp = _in_coord(k, w, new_w) - - A0 = data_in[z0][y0][x0] - B0 = data_in[z0][y0][x1] - C0 = data_in[z0][y1][x0] - D0 = data_in[z0][y1][x1] - A1 = data_in[z1][y0][x0] - B1 = data_in[z1][y0][x1] - C1 = data_in[z1][y1][x0] - D1 = data_in[z1][y1][x1] - - A = _lerp(A0, A1, z_lerp) - B = _lerp(B0, B1, z_lerp) - C = _lerp(C0, C1, z_lerp) - D = _lerp(D0, D1, z_lerp) - top = _lerp(A, B, x_lerp) - bottom = _lerp(C, D, x_lerp) - - data_out[m][j][k] = np.float32(_lerp(top, bottom, y_lerp)) + in_z = get_inx(m, d, new_d, coordinate_transformation_mode) + in_y = get_inx(j, h, new_h, coordinate_transformation_mode) + in_x = get_inx(k, w, new_w, coordinate_transformation_mode) + zint = math.floor(in_z) + zfract = in_z - math.floor(in_z) + + yint = math.floor(in_y) + yfract = in_y - math.floor(in_y) + + xint = math.floor(in_x) + xfract = in_x - math.floor(in_x) + + wz = np.array([1.0 - zfract, zfract], dtype=dtype) + wy = np.array([1.0 - yfract, yfract], dtype=dtype) + wx = np.array([1.0 - xfract, xfract], dtype=dtype) + + p = _get_patch(zint, yint, xint) + l = np.sum(p * wx, axis=-1) + col = np.sum(l * wy, axis=-1) + data_out[m, j, k] = np.sum(col * wz) return data_out def resize3d_cubic(data_in, scale, coordinate_transformation_mode): """Tricubic 3d scaling using python""" + dtype = data_in.dtype d, h, w = data_in.shape new_d, new_h, new_w = [int(round(i * s)) for i, s in zip(data_in.shape, scale)] data_out = np.ones((new_d, new_h, new_w)) @@ -123,29 +124,17 @@ def _cubic_spline_weights(t, alpha=-0.5): w2 = (alpha + 2) * t3 - (3 + alpha) * t2 + 1 w3 = -(alpha + 2) * t3 + (3 + 2 * alpha) * t2 - alpha * t w4 = -alpha * t3 + alpha * t2 - return [w1, w2, w3, w4] + return np.array([w1, w2, w3, w4]) - def _cubic_kernel(inputs, w): - """perform cubic interpolation in 1D""" - return sum([a_i * w_i for a_i, w_i in zip(inputs, w)]) - - def _get_input_value(z, y, x): - z = max(min(z, d - 1), 0) - y = max(min(y, h - 1), 0) - x = max(min(x, w - 1), 0) - return data_in[z][y][x] + indexes = np.mgrid[-1:3, -1:3, -1:3] def _get_patch(zint, yint, xint): # Get the surrounding values - p = [[[0 for i in range(4)] for j in range(4)] for k in range(4)] - for kk in range(4): - for jj in range(4): - for ii in range(4): - p[kk][jj][ii] = _get_input_value( - zint + kk - 1, - yint + jj - 1, - xint + ii - 1, - ) + indices = indexes.copy() + indices[0] = np.maximum(np.minimum(indexes[0] + zint, d - 1), 0) + indices[1] = np.maximum(np.minimum(indexes[1] + yint, h - 1), 0) + indices[2] = np.maximum(np.minimum(indexes[2] + xint, w - 1), 0) + p = data_in[indices[0], indices[1], indices[2]] return p for m in range(new_d): @@ -169,16 +158,9 @@ def _get_patch(zint, yint, xint): p = _get_patch(zint, yint, xint) - l = [[0 for i in range(4)] for j in range(4)] - for jj in range(4): - for ii in range(4): - l[jj][ii] = _cubic_kernel(p[jj][ii], wx) - - col0 = _cubic_kernel(l[0], wy) - col1 = _cubic_kernel(l[1], wy) - col2 = _cubic_kernel(l[2], wy) - col3 = _cubic_kernel(l[3], wy) - data_out[m][j][k] = _cubic_kernel([col0, col1, col2, col3], wz) + l = np.sum(p * wx, axis=-1) + col = np.sum(l * wy, axis=-1) + data_out[m, j, k] = np.sum(col * wz) return data_out