Skip to content

Commit

Permalink
add test stub
Browse files Browse the repository at this point in the history
  • Loading branch information
masahi committed Jan 18, 2022
1 parent c86b128 commit 2bf68c7
Showing 1 changed file with 63 additions and 1 deletion.
64 changes: 63 additions & 1 deletion tests/python/contrib/test_cudnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -236,6 +236,67 @@ def test_softmax():
verify_softmax_4d((1, 16, 256, 256), "float64", log_softmax=True)


def verify_conv2d_backward_data(data_dtype, conv_dtype, tensor_format=0):
batch = 3
in_channel = 4
out_channel = 16
filter_h, filter_w = 3, 3
pad_h, pad_w = 1, 1
stride_h, stride_w = 1, 1
height, width = 32, 32

# schedule
if tensor_format == 0:
xshape = [batch, in_channel, height, width]
wshape = [in_channel, out_channel, filter_h, filter_w]
else:
xshape = [batch, height, width, in_channel]
wshape = [out_channel, filter_h, filter_w, in_channel]

oshape = xshape

dy_np = np.random.uniform(-1, 1, oshape).astype(data_dtype)
w_np = np.random.uniform(-1, 1, wshape).astype(data_dtype)
dx_np = tvm.topi.testing.conv2d_transpose_nchw_python(
dy_np, w_np, (stride_h, stride_w), (pad_h, pad_w), (0, 0)
)

dy = te.placeholder(oshape, name="dy", dtype=data_dtype)
w = te.placeholder(wshape, name="dw", dtype=data_dtype)
dx = cudnn.conv_forward_backward_data(
dy,
w,
[pad_h, pad_w],
[stride_h, stride_w],
[dilation_h, dilation_w],
conv_mode=1,
tensor_format=tensor_format,
conv_dtype=conv_dtype,
groups=1,
)

s = te.create_schedule(Y.op)

# validation
dev = tvm.cuda(0)
f = tvm.build(s, [dy, w, x], "cuda --host=llvm", name="conv2d_backward_data")

dx_np = np.zeros(oshape).astype(data_dtype)

dy = tvm.nd.array(dy_np, dev)
w = tvm.nd.array(w_np, dev)
dx = tvm.nd.array(dx_np, dev)

f(dy, w, dx)
tvm.testing.assert_allclose(dx.numpy(), dx_np, atol=1e-2, rtol=1e-2)


@tvm.testing.requires_gpu
@requires_cudnn
def test_conv2d_backward_data():
verify_conv2d_backward_data("float32", "float32", tensor_format=0)


test_kwargs_default_2d = {
"tensor_format": 0,
"pad": [1, 1],
Expand Down Expand Up @@ -308,4 +369,5 @@ def conv_output_shape_kwargs(request):


if __name__ == "__main__":
sys.exit(pytest.main(sys.argv))
# sys.exit(pytest.main(sys.argv))
test_conv2d_backward_data()

0 comments on commit 2bf68c7

Please sign in to comment.