Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

speed up reference resize kernel #8592

Merged
merged 1 commit into from
Jul 30, 2021
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
100 changes: 41 additions & 59 deletions python/tvm/topi/testing/resize_python.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,51 +66,52 @@ def resize3d_nearest(arr, scale, coordinate_transformation_mode):

def resize3d_linear(data_in, scale, coordinate_transformation_mode):
"""Trilinear 3d scaling using python"""
dtype = data_in.dtype
d, h, w = data_in.shape
new_d, new_h, new_w = [int(round(i * s)) for i, s in zip(data_in.shape, scale)]
data_out = np.ones((new_d, new_h, new_w))

def _lerp(A, B, t):
return A * (1.0 - t) + B * t
indexes = np.mgrid[0:2, 0:2, 0:2]

def _in_coord(new_coord, in_shape, out_shape):
in_coord = get_inx(new_coord, in_shape, out_shape, coordinate_transformation_mode)
coord0 = int(math.floor(in_coord))
coord1 = max(min(coord0 + 1, in_shape - 1), 0)
coord0 = max(coord0, 0)
coord_lerp = in_coord - math.floor(in_coord)
return coord0, coord1, coord_lerp
def _get_patch(zint, yint, xint):
# Get the surrounding values
indices = indexes.copy()
indices[0] = np.maximum(np.minimum(indexes[0] + zint, d - 1), 0)
indices[1] = np.maximum(np.minimum(indexes[1] + yint, h - 1), 0)
indices[2] = np.maximum(np.minimum(indexes[2] + xint, w - 1), 0)
p = data_in[indices[0], indices[1], indices[2]]
return p

for m in range(new_d):
for j in range(new_h):
for k in range(new_w):
z0, z1, z_lerp = _in_coord(m, d, new_d)
y0, y1, y_lerp = _in_coord(j, h, new_h)
x0, x1, x_lerp = _in_coord(k, w, new_w)

A0 = data_in[z0][y0][x0]
B0 = data_in[z0][y0][x1]
C0 = data_in[z0][y1][x0]
D0 = data_in[z0][y1][x1]
A1 = data_in[z1][y0][x0]
B1 = data_in[z1][y0][x1]
C1 = data_in[z1][y1][x0]
D1 = data_in[z1][y1][x1]

A = _lerp(A0, A1, z_lerp)
B = _lerp(B0, B1, z_lerp)
C = _lerp(C0, C1, z_lerp)
D = _lerp(D0, D1, z_lerp)
top = _lerp(A, B, x_lerp)
bottom = _lerp(C, D, x_lerp)

data_out[m][j][k] = np.float32(_lerp(top, bottom, y_lerp))
in_z = get_inx(m, d, new_d, coordinate_transformation_mode)
in_y = get_inx(j, h, new_h, coordinate_transformation_mode)
in_x = get_inx(k, w, new_w, coordinate_transformation_mode)
zint = math.floor(in_z)
zfract = in_z - math.floor(in_z)

yint = math.floor(in_y)
yfract = in_y - math.floor(in_y)

xint = math.floor(in_x)
xfract = in_x - math.floor(in_x)

wz = np.array([1.0 - zfract, zfract], dtype=dtype)
wy = np.array([1.0 - yfract, yfract], dtype=dtype)
wx = np.array([1.0 - xfract, xfract], dtype=dtype)

p = _get_patch(zint, yint, xint)
l = np.sum(p * wx, axis=-1)
col = np.sum(l * wy, axis=-1)
data_out[m, j, k] = np.sum(col * wz)

return data_out


def resize3d_cubic(data_in, scale, coordinate_transformation_mode):
"""Tricubic 3d scaling using python"""
dtype = data_in.dtype
d, h, w = data_in.shape
new_d, new_h, new_w = [int(round(i * s)) for i, s in zip(data_in.shape, scale)]
data_out = np.ones((new_d, new_h, new_w))
Expand All @@ -123,29 +124,17 @@ def _cubic_spline_weights(t, alpha=-0.5):
w2 = (alpha + 2) * t3 - (3 + alpha) * t2 + 1
w3 = -(alpha + 2) * t3 + (3 + 2 * alpha) * t2 - alpha * t
w4 = -alpha * t3 + alpha * t2
return [w1, w2, w3, w4]
return np.array([w1, w2, w3, w4])

def _cubic_kernel(inputs, w):
"""perform cubic interpolation in 1D"""
return sum([a_i * w_i for a_i, w_i in zip(inputs, w)])

def _get_input_value(z, y, x):
z = max(min(z, d - 1), 0)
y = max(min(y, h - 1), 0)
x = max(min(x, w - 1), 0)
return data_in[z][y][x]
indexes = np.mgrid[-1:3, -1:3, -1:3]

def _get_patch(zint, yint, xint):
# Get the surrounding values
p = [[[0 for i in range(4)] for j in range(4)] for k in range(4)]
for kk in range(4):
for jj in range(4):
for ii in range(4):
p[kk][jj][ii] = _get_input_value(
zint + kk - 1,
yint + jj - 1,
xint + ii - 1,
)
indices = indexes.copy()
indices[0] = np.maximum(np.minimum(indexes[0] + zint, d - 1), 0)
indices[1] = np.maximum(np.minimum(indexes[1] + yint, h - 1), 0)
indices[2] = np.maximum(np.minimum(indexes[2] + xint, w - 1), 0)
p = data_in[indices[0], indices[1], indices[2]]
return p

for m in range(new_d):
Expand All @@ -169,16 +158,9 @@ def _get_patch(zint, yint, xint):

p = _get_patch(zint, yint, xint)

l = [[0 for i in range(4)] for j in range(4)]
for jj in range(4):
for ii in range(4):
l[jj][ii] = _cubic_kernel(p[jj][ii], wx)

col0 = _cubic_kernel(l[0], wy)
col1 = _cubic_kernel(l[1], wy)
col2 = _cubic_kernel(l[2], wy)
col3 = _cubic_kernel(l[3], wy)
data_out[m][j][k] = _cubic_kernel([col0, col1, col2, col3], wz)
l = np.sum(p * wx, axis=-1)
col = np.sum(l * wy, axis=-1)
data_out[m, j, k] = np.sum(col * wz)

return data_out

Expand Down