
Commit

nhwc test also worked
masahi committed Jan 18, 2022
1 parent c0609ab commit c2a34d4
Showing 2 changed files with 17 additions and 10 deletions.
5 changes: 4 additions & 1 deletion python/tvm/contrib/cudnn.py
@@ -606,7 +606,10 @@ def conv_backward_data(x, w, pad, stride, dilation, conv_mode, tensor_format, co
     assert isinstance(x.shape[0], tvm.tir.expr.IntImm), "Dynamic batch is not supported for cudnn conv2d backward data yet."
     # TODO: fix oshape
     oshape = x_shape
-    oshape[1] = w.shape[1]
+    if tensor_format == 0:
+        oshape[1] = w.shape[1]
+    else:
+        oshape[3] = w.shape[3]
 
     algo = conv_backward_data_find_algo(
         tensor_format,
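
For context, a minimal standalone sketch of the shape rule this hunk introduces. The helper name is hypothetical (not a tvm.contrib.cudnn API), tensor_format 0 is taken to mean NCHW and anything else NHWC, and the spatial dimensions are still copied from dy, as the TODO above notes:

    # Illustrative sketch only; "infer_dgrad_shape" is not part of the commit.
    def infer_dgrad_shape(dy_shape, w_shape, tensor_format):
        oshape = list(dy_shape)
        if tensor_format == 0:   # NCHW: channels on axis 1, weights OIHW
            oshape[1] = w_shape[1]
        else:                    # NHWC: channels on axis 3, weights OHWI
            oshape[3] = w_shape[3]
        return oshape            # spatial dims copied from dy (see TODO above)

    # NCHW: dy (3, 16, 32, 32), OIHW weights (16, 4, 3, 3) -> dx has 4 channels
    assert infer_dgrad_shape([3, 16, 32, 32], [16, 4, 3, 3], 0)[1] == 4
    # NHWC: dy (3, 32, 32, 16), OHWI weights (16, 3, 3, 4) -> dx has 4 channels
    assert infer_dgrad_shape([3, 32, 32, 16], [16, 3, 3, 4], 1)[3] == 4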
22 changes: 13 additions & 9 deletions tests/python/contrib/test_cudnn.py
@@ -236,7 +236,7 @@ def test_softmax():
     verify_softmax_4d((1, 16, 256, 256), "float64", log_softmax=True)
 
 
-def verify_conv2d_backward_data(data_dtype, conv_dtype, tensor_format=0):
+def verify_conv2d_backward_data(data_dtype, conv_dtype, tensor_format=0, tol=1e-5):
     batch = 3
     in_channel = 4
     out_channel = 16
@@ -249,19 +249,22 @@ def verify_conv2d_backward_data(data_dtype, conv_dtype, tensor_format=0):
     if tensor_format == 0:
         xshape = [batch, in_channel, height, width]
         wshape = [out_channel, in_channel, filter_h, filter_w]
+        oshape = xshape
+        oshape[1] = out_channel
+        ref_func = tvm.topi.testing.conv2d_transpose_nchw_python
     else:
         xshape = [batch, height, width, in_channel]
         wshape = [out_channel, filter_h, filter_w, in_channel]
-
-    oshape = xshape
-    oshape[1] = out_channel
+        oshape = xshape
+        oshape[3] = out_channel
+        ref_func = lambda dy_np, w_np, strides, padding, out_pad: tvm.topi.testing.conv2d_transpose_nhwc_python(
+            dy_np, np.transpose(w_np, [1, 2, 3, 0]), "HWOI", strides, padding, out_pad
+        )
 
     dy_np = np.random.uniform(-1, 1, oshape).astype(data_dtype)
     w_np = np.random.uniform(-1, 1, wshape).astype(data_dtype)
 
-    dx_np = tvm.topi.testing.conv2d_transpose_nchw_python(
-        dy_np, w_np, (stride_h, stride_w), (pad_h, pad_w), (0, 0)
-    )
+    dx_np = ref_func(dy_np, w_np, (stride_h, stride_w), (pad_h, pad_w), (0, 0))
 
     dy = te.placeholder(oshape, name="dy", dtype=data_dtype)
     w = te.placeholder(wshape, name="dw", dtype=data_dtype)
@@ -289,13 +292,14 @@ def verify_conv2d_backward_data(data_dtype, conv_dtype, tensor_format=0):
 
     f(dy, w, dx)
     print(np.max(np.abs(dx.numpy() - dx_np)))
-    tvm.testing.assert_allclose(dx.numpy(), dx_np, atol=1e-5, rtol=1e-5)
+    tvm.testing.assert_allclose(dx.numpy(), dx_np, atol=tol, rtol=tol)
 
 
 @tvm.testing.requires_gpu
 @requires_cudnn
 def test_conv2d_backward_data():
-    verify_conv2d_backward_data("float32", "float32", tensor_format=0)
+    verify_conv2d_backward_data("float32", "float32", tensor_format=0, tol=1e-5)
+    verify_conv2d_backward_data("float32", "float32", tensor_format=1, tol=1e-2)
 
 
 test_kwargs_default_2d = {
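
A note on the NHWC reference path above: the test stores weights as OHWI ([out_channel, filter_h, filter_w, in_channel]), while conv2d_transpose_nhwc_python is handed the layout string "HWOI", so the lambda permutes axes with np.transpose(w_np, [1, 2, 3, 0]). The resulting (filter_h, filter_w, in_channel, out_channel) array reads as HWOI because a transposed convolution swaps the input- and output-channel roles of the forward weights. A quick shape check (a sketch, not part of the commit):

    import numpy as np

    out_channel, filter_h, filter_w, in_channel = 16, 3, 3, 4
    w_np = np.zeros((out_channel, filter_h, filter_w, in_channel))  # OHWI

    # Same permutation as ref_func uses above.
    w_t = np.transpose(w_np, [1, 2, 3, 0])
    assert w_t.shape == (filter_h, filter_w, in_channel, out_channel)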
