From e084caac860d45bae91bbb73d74a6678926b1d23 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Tue, 18 Jan 2022 12:54:02 +0900 Subject: [PATCH 01/21] Dgrad nchw, nhwc, fp16 working commit 426e5dca446a27da49270f45171b58f1bfa21fa9 Author: Masahiro Masuda Date: Tue Jan 18 11:48:53 2022 +0900 black commit 211a58b80f4d0f0b5b0230720e41f35e50cb1eaf Author: Masahiro Masuda Date: Tue Jan 18 11:43:52 2022 +0900 fp16 also works commit c2a34d473b063873628bff00e51a44cd8e4d0e4f Author: Masahiro Masuda Date: Tue Jan 18 11:36:36 2022 +0900 nhwc test also worked commit c0609ab147fef30c230a94d16b6c1ba35f7dd9c0 Author: Masahiro Masuda Date: Tue Jan 18 11:21:23 2022 +0900 nchw test worked commit 2bf68c72763708151e9f49f09916a210b2547be8 Author: Masahiro Masuda Date: Tue Jan 18 10:41:35 2022 +0900 add test stub commit c86b1288d5e371f12cba4e1b1866966cb9264401 Author: Masahiro Masuda Date: Tue Jan 18 10:32:09 2022 +0900 add python definition stub commit 3166952f9673376801bf4b5b39eeb6f89452f30a Author: Masahiro Masuda Date: Tue Jan 18 06:57:18 2022 +0900 bwd filter compiled commit e311ba3d05c5f9424ecb952cb5a520ce81a0828a Author: Masahiro Masuda Date: Tue Jan 18 06:27:55 2022 +0900 dgrad compiled commit 47f35beb5eeeb7cbf9f6ec7cf8f5c80c65e8da46 Author: Masahiro Masuda Date: Tue Jan 18 06:16:43 2022 +0900 add dgrad stub commit ebed032d15b1c3895f541c46ce5d80b6dd769034 Author: Masahiro Masuda Date: Mon Jan 17 17:01:56 2022 +0900 cpplint commit 834f54a8c13512130e7d91ca0f54268dc06c5481 Author: Masahiro Masuda Date: Mon Jan 17 16:55:58 2022 +0900 remove cudnn get output commit dcbd9c95fdb8ffef9db9c2350430b270461a31c3 Author: Masahiro Masuda Date: Mon Jan 17 16:28:07 2022 +0900 more refactor commit 146464e8496fff972bdb1687c4e9d432fe3278d5 Author: Masahiro Masuda Date: Mon Jan 17 15:57:35 2022 +0900 Introduce SetConvdescriptors to refactor cudnn/conv_forward.cc --- python/tvm/contrib/cudnn.py | 154 ++++++++++ .../topi/testing/conv2d_transpose_python.py | 4 +- src/runtime/contrib/cudnn/conv_backward.cc | 265 ++++++++++++++++++ src/runtime/contrib/cudnn/cudnn_utils.h | 4 +- tests/python/contrib/test_cudnn.py | 83 +++++- 5 files changed, 506 insertions(+), 4 deletions(-) create mode 100644 src/runtime/contrib/cudnn/conv_backward.cc diff --git a/python/tvm/contrib/cudnn.py b/python/tvm/contrib/cudnn.py index 9b92c7cc2773..47d77999007d 100644 --- a/python/tvm/contrib/cudnn.py +++ b/python/tvm/contrib/cudnn.py @@ -352,6 +352,73 @@ def conv_find_algo( ) +def conv_backward_data_find_algo( + tensor_format, + pad, + stride, + dilation, + x_shape, + w_shape, + y_shape, + data_dtype, + conv_dtype, + groups=1, +): + """Choose the best algo for the given input. 
+ + Paramters + --------- + tensor_format: int + 0: CUDNN_TENSOR_NCHW + 1: CUDNN_TENSOR_NHWC + 2: CUDNN_TENSOR_NCHW_VECT_C + pad: int or list + padding + stride: int or list + stride + dilation: int or list + dilation + x_shape: list + input shape + w_shape: list + weight shape + y_shape: list + output shape + data_dtype: str + data type + conv_dtype: str + convolution type + groups: int + number of groups + + Returns + ------- + algo: int + algo chosen by CUDNN + """ + dims = len(x_shape) + assert dims in (4, 5) + + pad, stride, dilation, xshape, wshape = _prepare_global_func_params( + dims - 2, pad, stride, dilation, x_shape, w_shape + ) + yshape = np.array(y_shape, dtype=np.int32) + func = tvm._ffi.get_global_func("tvm.contrib.cudnn.conv.backward_data_find_algo") + return func( + tensor_format, + dims - 2, + _get_np_int32_array_handle(pad), + _get_np_int32_array_handle(stride), + _get_np_int32_array_handle(dilation), + _get_np_int32_array_handle(xshape), + _get_np_int32_array_handle(wshape), + _get_np_int32_array_handle(yshape), + data_dtype, + conv_dtype, + groups, + ) + + def conv_forward(x, w, pad, stride, dilation, conv_mode, tensor_format, algo, conv_dtype, groups=1): """Create an extern op that compute 2D or 3D convolution with CuDNN @@ -496,6 +563,93 @@ def conv_forward(x, w, pad, stride, dilation, conv_mode, tensor_format, algo, co ) +def conv_backward_data(x, w, pad, stride, dilation, conv_mode, tensor_format, conv_dtype, groups=1): + """Create an extern op that compute 2D or 3D convolution with CuDNN + + Parameters + ---------- + x: Tensor + input feature map + w: Tensor + convolution weight + pad: int or list + padding + stride: int or list + stride + dilation: int or list + dilation + conv_mode: int + 0: CUDNN_CONVOLUTION + 1: CUDNN_CROSS_CORRELATION + tensor_format: int + 0: CUDNN_TENSOR_NCHW + 1: CUDNN_TENSOR_NHWC + 2: CUDNN_TENSOR_NCHW_VECT_C + conv_dtype: str + convolution type + groups: int + the number of groups + + Returns + ------- + y: Tensor + The result tensor + """ + dims = len(x.shape) + assert dims in (4, 5) + + conv_dtype = x.dtype if conv_dtype is None else conv_dtype + pad, stride, dilation, _, _ = _prepare_global_func_params(dims - 2, pad, stride, dilation) + + x_shape = list(x.shape) + + assert isinstance( + x.shape[0], tvm.tir.expr.IntImm + ), "Dynamic batch is not supported for cudnn conv2d backwad data yet." 
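+    # The hard-coded output shape below (see the TODO) reuses the spatial
+    # dimensions of the given tensor and only replaces the channel dimension,
+    # so it effectively assumes unit stride and shape-preserving padding. The
+    # general dgrad relation per spatial dimension is
+    #   dx_i = (dy_i - 1) * stride_i - 2 * pad_i + (w_i - 1) * dilation_i + 1,
+    # which is what the conv_dgrad_shape helper added later in this series
+    # computes; e.g. dy_i = 32, stride 1, pad 1, 3x3 kernel gives dx_i = 32.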
+ # TODO: fix oshape + oshape = x_shape + if tensor_format == 0: + oshape[1] = w.shape[1] + else: + oshape[3] = w.shape[3] + + algo = conv_backward_data_find_algo( + tensor_format, + pad, + stride, + dilation, + list(x.shape), + list(w.shape), + oshape, + x.dtype, + conv_dtype, + groups, + ) + + return te.extern( + oshape, + [x, w], + lambda ins, outs: tvm.tir.call_packed( + "tvm.contrib.cudnn.conv2d.backward_data", + conv_mode, + tensor_format, + algo, + pad[0], + pad[1], + stride[0], + stride[1], + dilation[0], + dilation[1], + ins[0], + ins[1], + outs[0], + conv_dtype, + groups, + ), + name="y", + ) + + def softmax(x, axis=-1): """Compute softmax using CuDNN diff --git a/python/tvm/topi/testing/conv2d_transpose_python.py b/python/tvm/topi/testing/conv2d_transpose_python.py index a38d8bc9f031..678b5fe5d003 100644 --- a/python/tvm/topi/testing/conv2d_transpose_python.py +++ b/python/tvm/topi/testing/conv2d_transpose_python.py @@ -73,7 +73,7 @@ def _conv2d_transpose_nchw_python(a_np, w_np, stride, padding, output_padding): dilated_a_np.shape[2] + bpad_top + bpad_bottom, dilated_a_np.shape[3] + bpad_left + bpad_right, ) - ) + ).astype(a_np.dtype) padded_a_np[ :, :, @@ -83,7 +83,7 @@ def _conv2d_transpose_nchw_python(a_np, w_np, stride, padding, output_padding): # convolution stage out_h = (in_h - 1) * stride_h - fpad_top - fpad_bottom + filter_h + opad_h out_w = (in_w - 1) * stride_w - fpad_left - fpad_right + filter_w + opad_w - b_np = np.zeros((batch, out_c, out_h, out_w)) + b_np = np.zeros((batch, out_c, out_h, out_w)).astype(a_np.dtype) for n in range(batch): for f in range(out_c): for c in range(in_c): diff --git a/src/runtime/contrib/cudnn/conv_backward.cc b/src/runtime/contrib/cudnn/conv_backward.cc new file mode 100644 index 000000000000..2537c98264ba --- /dev/null +++ b/src/runtime/contrib/cudnn/conv_backward.cc @@ -0,0 +1,265 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! 
+ * \file Use external cudnn utils function + */ +#include +#include +#include + +#include "cudnn_utils.h" + +namespace tvm { +namespace contrib { + +using namespace runtime; + +void ConvolutionBackwardData(int mode, int format, int algo, int dims, int groups, const int pad[], + const int stride[], const int dilation[], DLTensor* dy, DLTensor* w, + DLTensor* dx, const std::string& conv_dtype) { + CuDNNThreadEntry* entry_ptr = CuDNNThreadEntry::ThreadLocal(); + // Set Mode + entry_ptr->conv_entry.mode = static_cast(mode); + SetConvDescriptors(entry_ptr, format, dims, groups, pad, stride, dilation, dx->shape, w->shape, + dy->shape, dy->dtype, conv_dtype); + // Set Device + entry_ptr->conv_entry.device = dy->device; + // Set Algo + entry_ptr->conv_entry.bwd_data_algo = static_cast(algo); + + // Set workspace + size_t workspace_size = 0; + CUDNN_CALL(cudnnGetConvolutionBackwardDataWorkspaceSize( + entry_ptr->handle, entry_ptr->conv_entry.filter_desc, entry_ptr->conv_entry.output_desc, + entry_ptr->conv_entry.conv_desc, entry_ptr->conv_entry.input_desc, + entry_ptr->conv_entry.bwd_data_algo, &workspace_size)); + entry_ptr->conv_entry.UpdateWorkspace(workspace_size); + CUDNN_CALL(cudnnConvolutionBackwardData( + entry_ptr->handle, CuDNNDataType::GetConst<1>(entry_ptr->conv_entry.data_type), + entry_ptr->conv_entry.filter_desc, w->data, entry_ptr->conv_entry.output_desc, dy->data, + entry_ptr->conv_entry.conv_desc, entry_ptr->conv_entry.bwd_data_algo, + entry_ptr->conv_entry.workspace, workspace_size, + CuDNNDataType::GetConst<0>(entry_ptr->conv_entry.data_type), entry_ptr->conv_entry.input_desc, + dx->data)); +} + +void BackwardDataFindAlgo(int format, int dims, int groups, const int pad[], const int stride[], + const int dilation[], const int dy_dim[], const int w_dim[], + const int dx_dim[], const std::string& data_dtype, + const std::string& conv_dtype, TVMRetValue* ret) { + CuDNNThreadEntry* entry_ptr = CuDNNThreadEntry::ThreadLocal(); + const int full_dims = dims + 2; + std::vector dy_dim_int64(full_dims); + std::vector w_dim_int64(full_dims); + std::vector dx_dim_int64(full_dims); + for (int i = 0; i < full_dims; ++i) { + dy_dim_int64[i] = dy_dim[i]; + w_dim_int64[i] = w_dim[i]; + dx_dim_int64[i] = dx_dim[i]; + } + SetConvDescriptors(entry_ptr, format, dims, groups, pad, stride, dilation, dx_dim_int64.data(), + w_dim_int64.data(), dy_dim_int64.data(), String2DLDataType(data_dtype), + conv_dtype); + + int returned_algo_count = 0; + + cudnnConvolutionBwdDataAlgoPerf_t perf_results[CUDNN_CONVOLUTION_BWD_DATA_ALGO_COUNT]; + CUDNN_CALL(cudnnFindConvolutionBackwardDataAlgorithm( + entry_ptr->handle, entry_ptr->conv_entry.filter_desc, entry_ptr->conv_entry.output_desc, + entry_ptr->conv_entry.conv_desc, entry_ptr->conv_entry.input_desc, + CUDNN_CONVOLUTION_BWD_DATA_ALGO_COUNT, &returned_algo_count, perf_results)); + + const std::vector bwd_data_algo_names{ + "CUDNN_CONVOLUTION_BWD_DATA_ALGO_0", // non-deterministic + "CUDNN_CONVOLUTION_BWD_DATA_ALGO_1", + "CUDNN_CONVOLUTION_BWD_DATA_ALGO_FFT", + "CUDNN_CONVOLUTION_BWD_DATA_ALGO_FFT_TILING", + "CUDNN_CONVOLUTION_BWD_DATA_ALGO_WINOGRAD", + "CUDNN_CONVOLUTION_BWD_DATA_ALGO_WINOGRAD_NONFUSED"}; + + auto best_algo = perf_results[0].algo; + LOG(INFO) << "\tCUDNN Found " << returned_algo_count << " bwd data algorithms, choosing " + << bwd_data_algo_names[best_algo]; + for (int i = 0; i < returned_algo_count; ++i) { + LOG(INFO) << "\t\t" << i << ") " << bwd_data_algo_names[perf_results[i].algo] + << " - time: " << perf_results[i].time << " ms" + << ", 
Memory: " << perf_results[i].memory; + } + + ret[0] = best_algo; +} + +void ConvolutionBackwardFilter(int mode, int format, int algo, int dims, int groups, + const int pad[], const int stride[], const int dilation[], + DLTensor* x, DLTensor* dy, DLTensor* dw, + const std::string& conv_dtype) { + CuDNNThreadEntry* entry_ptr = CuDNNThreadEntry::ThreadLocal(); + // Set Mode + entry_ptr->conv_entry.mode = static_cast(mode); + SetConvDescriptors(entry_ptr, format, dims, groups, pad, stride, dilation, x->shape, dw->shape, + dy->shape, x->dtype, conv_dtype); + // Set Device + entry_ptr->conv_entry.device = x->device; + // Set Algo + entry_ptr->conv_entry.bwd_filter_algo = static_cast(algo); + + // Set workspace + size_t workspace_size = 0; + CUDNN_CALL(cudnnGetConvolutionBackwardFilterWorkspaceSize( + entry_ptr->handle, entry_ptr->conv_entry.input_desc, entry_ptr->conv_entry.output_desc, + entry_ptr->conv_entry.conv_desc, entry_ptr->conv_entry.filter_desc, + entry_ptr->conv_entry.bwd_filter_algo, &workspace_size)); + entry_ptr->conv_entry.UpdateWorkspace(workspace_size); + CUDNN_CALL(cudnnConvolutionBackwardFilter( + entry_ptr->handle, CuDNNDataType::GetConst<1>(entry_ptr->conv_entry.data_type), + entry_ptr->conv_entry.input_desc, x->data, entry_ptr->conv_entry.output_desc, dy->data, + entry_ptr->conv_entry.conv_desc, entry_ptr->conv_entry.bwd_filter_algo, + entry_ptr->conv_entry.workspace, workspace_size, + CuDNNDataType::GetConst<0>(entry_ptr->conv_entry.data_type), + entry_ptr->conv_entry.filter_desc, dw->data)); +} + +void BackwardFilterFindAlgo(int format, int dims, int groups, const int pad[], const int stride[], + const int dilation[], const int x_dim[], const int dy_dim[], + const int dw_dim[], const std::string& data_dtype, + const std::string& conv_dtype, TVMRetValue* ret) { + CuDNNThreadEntry* entry_ptr = CuDNNThreadEntry::ThreadLocal(); + const int full_dims = dims + 2; + std::vector x_dim_int64(full_dims); + std::vector dy_dim_int64(full_dims); + std::vector dw_dim_int64(full_dims); + for (int i = 0; i < full_dims; ++i) { + x_dim_int64[i] = x_dim[i]; + dy_dim_int64[i] = dy_dim[i]; + dw_dim_int64[i] = dw_dim[i]; + } + SetConvDescriptors(entry_ptr, format, dims, groups, pad, stride, dilation, x_dim_int64.data(), + dw_dim_int64.data(), dy_dim_int64.data(), String2DLDataType(data_dtype), + conv_dtype); + + int returned_algo_count = 0; + + cudnnConvolutionBwdFilterAlgoPerf_t perf_results[CUDNN_CONVOLUTION_BWD_FILTER_ALGO_COUNT]; + CUDNN_CALL(cudnnFindConvolutionBackwardFilterAlgorithm( + entry_ptr->handle, entry_ptr->conv_entry.input_desc, entry_ptr->conv_entry.output_desc, + entry_ptr->conv_entry.conv_desc, entry_ptr->conv_entry.filter_desc, + CUDNN_CONVOLUTION_BWD_FILTER_ALGO_COUNT, &returned_algo_count, perf_results)); + + const std::vector bwd_filter_algo_names{ + "CUDNN_CONVOLUTION_BWD_FILTER_ALGO_0", // non-deterministic + "CUDNN_CONVOLUTION_BWD_FILTER_ALGO_1", + "CUDNN_CONVOLUTION_BWD_FILTER_ALGO_FFT", + "CUDNN_CONVOLUTION_BWD_FILTER_ALGO_3", + "CUDNN_CONVOLUTION_BWD_FILTER_ALGO_WINOGRAD_NONFUSED", + "CUDNN_CONVOLUTION_BWD_FILTER_ALGO_FFT_TILING", + }; + + auto best_algo = perf_results[0].algo; + LOG(INFO) << "\tCUDNN Found " << returned_algo_count << " bwd filter algorithms, choosing " + << bwd_filter_algo_names[best_algo]; + for (int i = 0; i < returned_algo_count; ++i) { + LOG(INFO) << "\t\t" << i << ") " << bwd_filter_algo_names[perf_results[i].algo] + << " - time: " << perf_results[i].time << " ms" + << ", Memory: " << perf_results[i].memory; + } + + ret[0] = best_algo; +} + 
+TVM_REGISTER_GLOBAL("tvm.contrib.cudnn.conv2d.backward_data") + .set_body([](TVMArgs args, TVMRetValue* ret) { + int mode = args[0]; + int format = args[1]; + int algo = args[2]; + int pad_v[2], stride_v[2], dilation_v[2]; + for (int i = 0; i < 2; i++) { + pad_v[i] = args[3 + i]; + stride_v[i] = args[5 + i]; + dilation_v[i] = args[7 + i]; + } + DLTensor* dy = args[9]; + DLTensor* w = args[10]; + DLTensor* dx = args[11]; + std::string conv_dtype = args[12]; + int groups = args[13]; + + ConvolutionBackwardData(mode, format, algo, 2, groups, pad_v, stride_v, dilation_v, dy, w, dx, + conv_dtype); + }); + +TVM_REGISTER_GLOBAL("tvm.contrib.cudnn.conv.backward_data_find_algo") + .set_body([](TVMArgs args, TVMRetValue* ret) { + int format = args[0]; + int dims = args[1]; + int* pad = static_cast(static_cast(args[2])); + int* stride = static_cast(static_cast(args[3])); + int* dilation = static_cast(static_cast(args[4])); + int* dy_dim = static_cast(static_cast(args[5])); + int* w_dim = static_cast(static_cast(args[6])); + int* dx_dim = static_cast(static_cast(args[7])); + std::string data_dtype = args[8]; + std::string conv_dtype = args[9]; + int groups = args[10]; + + BackwardDataFindAlgo(format, dims, groups, pad, stride, dilation, dy_dim, w_dim, dx_dim, + data_dtype, conv_dtype, ret); + }); + +TVM_REGISTER_GLOBAL("tvm.contrib.cudnn.conv2d.backward_filter") + .set_body([](TVMArgs args, TVMRetValue* ret) { + int mode = args[0]; + int format = args[1]; + int algo = args[2]; + int pad_v[2], stride_v[2], dilation_v[2]; + for (int i = 0; i < 2; i++) { + pad_v[i] = args[3 + i]; + stride_v[i] = args[5 + i]; + dilation_v[i] = args[7 + i]; + } + DLTensor* x = args[9]; + DLTensor* dy = args[10]; + DLTensor* dw = args[11]; + std::string conv_dtype = args[12]; + int groups = args[13]; + + ConvolutionBackwardFilter(mode, format, algo, 2, groups, pad_v, stride_v, dilation_v, x, dy, + dw, conv_dtype); + }); + +TVM_REGISTER_GLOBAL("tvm.contrib.cudnn.conv.backward_filter_find_algo") + .set_body([](TVMArgs args, TVMRetValue* ret) { + int format = args[0]; + int dims = args[1]; + int* pad = static_cast(static_cast(args[2])); + int* stride = static_cast(static_cast(args[3])); + int* dilation = static_cast(static_cast(args[4])); + int* x_dim = static_cast(static_cast(args[5])); + int* dy_dim = static_cast(static_cast(args[6])); + int* dw_dim = static_cast(static_cast(args[7])); + std::string data_dtype = args[8]; + std::string conv_dtype = args[9]; + int groups = args[10]; + + BackwardFilterFindAlgo(format, dims, groups, pad, stride, dilation, x_dim, dy_dim, dw_dim, + data_dtype, conv_dtype, ret); + }); + +} // namespace contrib +} // namespace tvm diff --git a/src/runtime/contrib/cudnn/cudnn_utils.h b/src/runtime/contrib/cudnn/cudnn_utils.h index 89de0e90df90..426ccfdf37af 100644 --- a/src/runtime/contrib/cudnn/cudnn_utils.h +++ b/src/runtime/contrib/cudnn/cudnn_utils.h @@ -67,12 +67,14 @@ inline void GetCudnnStride(int nbdim, const int* dims, int* strides) { struct ConvEntry { cudnnConvolutionDescriptor_t conv_desc; cudnnConvolutionMode_t mode{CUDNN_CROSS_CORRELATION}; - cudnnFilterDescriptor_t filter_desc; cudnnDataType_t data_type; cudnnTensorFormat_t tensor_format; cudnnTensorDescriptor_t input_desc; + cudnnFilterDescriptor_t filter_desc; cudnnTensorDescriptor_t output_desc; cudnnConvolutionFwdAlgo_t fwd_algo; + cudnnConvolutionBwdDataAlgo_t bwd_data_algo; + cudnnConvolutionBwdFilterAlgo_t bwd_filter_algo; // cudnnMathType_t math_type; Device device; runtime::DeviceAPI* cuda_api; diff --git 
a/tests/python/contrib/test_cudnn.py b/tests/python/contrib/test_cudnn.py index bc2cc80f362d..bc7669d3d95c 100644 --- a/tests/python/contrib/test_cudnn.py +++ b/tests/python/contrib/test_cudnn.py @@ -236,6 +236,86 @@ def test_softmax(): verify_softmax_4d((1, 16, 256, 256), "float64", log_softmax=True) +def verify_conv2d_backward_data(data_dtype, conv_dtype, tensor_format=0, tol=1e-5): + batch = 3 + in_channel = 4 + out_channel = 16 + filter_h, filter_w = 3, 3 + pad_h, pad_w = 1, 1 + stride_h, stride_w = 1, 1 + height, width = 32, 32 + + # schedule + if tensor_format == 0: + xshape = [batch, in_channel, height, width] + wshape = [out_channel, in_channel, filter_h, filter_w] + oshape = xshape + oshape[1] = out_channel + ref_func = tvm.topi.testing.conv2d_transpose_nchw_python + else: + xshape = [batch, height, width, in_channel] + wshape = [out_channel, filter_h, filter_w, in_channel] + oshape = xshape + oshape[3] = out_channel + ref_func = lambda dy_np, w_np, strides, padding, out_pad: tvm.topi.testing.conv2d_transpose_nhwc_python( + dy_np, np.transpose(w_np, [1, 2, 3, 0]), "HWOI", strides, padding, out_pad + ) + + dy_np = np.random.uniform(-1, 1, oshape).astype(data_dtype) + w_np = np.random.uniform(-1, 1, wshape).astype(data_dtype) + + if data_dtype == "float16": + dx_np = ref_func( + dy_np.astype("float32"), + w_np.astype("float32"), + (stride_h, stride_w), + (pad_h, pad_w), + (0, 0), + ) + dx_np = dx_np.astype("float16") + else: + dx_np = ref_func(dy_np, w_np, (stride_h, stride_w), (pad_h, pad_w), (0, 0)) + + dy = te.placeholder(oshape, name="dy", dtype=data_dtype) + w = te.placeholder(wshape, name="dw", dtype=data_dtype) + dx = cudnn.conv_backward_data( + dy, + w, + [pad_h, pad_w], + [stride_h, stride_w], + [1, 1], + conv_mode=1, + tensor_format=tensor_format, + conv_dtype=conv_dtype, + groups=1, + ) + + s = te.create_schedule(dx.op) + + # validation + dev = tvm.cuda(0) + f = tvm.build(s, [dy, w, dx], "cuda --host=llvm", name="conv2d_backward_data") + + dy = tvm.nd.array(dy_np, dev) + w = tvm.nd.array(w_np, dev) + dx = tvm.nd.array(dx_np, dev) + + f(dy, w, dx) + print(np.max(np.abs(dx.numpy() - dx_np))) + print(np.mean(np.abs(dx.numpy() - dx_np))) + tvm.testing.assert_allclose(dx.numpy(), dx_np, atol=tol, rtol=tol) + + +@tvm.testing.requires_gpu +@requires_cudnn +def test_conv2d_backward_data(): + verify_conv2d_backward_data("float32", "float32", tensor_format=0, tol=1e-5) + verify_conv2d_backward_data("float32", "float32", tensor_format=1, tol=1e-2) + # The scipy convolve function does not support fp16, so the reference will be computed with + # fp32. Use larger tolerance to be on the safe side (1e-2 also seems mostly ok). 
+ verify_conv2d_backward_data("float16", "float16", tensor_format=1, tol=1e-1) + + test_kwargs_default_2d = { "tensor_format": 0, "pad": [1, 1], @@ -308,4 +388,5 @@ def conv_output_shape_kwargs(request): if __name__ == "__main__": - sys.exit(pytest.main(sys.argv)) + # sys.exit(pytest.main(sys.argv)) + test_conv2d_backward_data() From 204eb137f0938aace3d6c33741ecec9ec1e05609 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Tue, 18 Jan 2022 17:53:41 +0900 Subject: [PATCH 02/21] add python function for cudnn wgrad --- python/tvm/contrib/cudnn.py | 185 +++++++++++++++++++++++++++--------- 1 file changed, 139 insertions(+), 46 deletions(-) diff --git a/python/tvm/contrib/cudnn.py b/python/tvm/contrib/cudnn.py index 47d77999007d..57a49831a1af 100644 --- a/python/tvm/contrib/cudnn.py +++ b/python/tvm/contrib/cudnn.py @@ -357,9 +357,9 @@ def conv_backward_data_find_algo( pad, stride, dilation, - x_shape, + dy_shape, w_shape, - y_shape, + dx_shape, data_dtype, conv_dtype, groups=1, @@ -396,13 +396,13 @@ def conv_backward_data_find_algo( algo: int algo chosen by CUDNN """ - dims = len(x_shape) + dims = len(dy_shape) assert dims in (4, 5) - pad, stride, dilation, xshape, wshape = _prepare_global_func_params( - dims - 2, pad, stride, dilation, x_shape, w_shape + pad, stride, dilation, dy_shape, w_shape = _prepare_global_func_params( + dims - 2, pad, stride, dilation, dy_shape, w_shape ) - yshape = np.array(y_shape, dtype=np.int32) + dx_shape = np.array(dx_shape, dtype=np.int32) func = tvm._ffi.get_global_func("tvm.contrib.cudnn.conv.backward_data_find_algo") return func( tensor_format, @@ -410,9 +410,76 @@ def conv_backward_data_find_algo( _get_np_int32_array_handle(pad), _get_np_int32_array_handle(stride), _get_np_int32_array_handle(dilation), - _get_np_int32_array_handle(xshape), - _get_np_int32_array_handle(wshape), - _get_np_int32_array_handle(yshape), + _get_np_int32_array_handle(dy_shape), + _get_np_int32_array_handle(w_shape), + _get_np_int32_array_handle(dx_shape), + data_dtype, + conv_dtype, + groups, + ) + + +def conv_backward_filter_find_algo( + tensor_format, + pad, + stride, + dilation, + x_shape, + dy_shape, + dw_shape, + data_dtype, + conv_dtype, + groups=1, +): + """Choose the best algo for the given input. 
+ + Paramters + --------- + tensor_format: int + 0: CUDNN_TENSOR_NCHW + 1: CUDNN_TENSOR_NHWC + 2: CUDNN_TENSOR_NCHW_VECT_C + pad: int or list + padding + stride: int or list + stride + dilation: int or list + dilation + x_shape: list + input shape + w_shape: list + weight shape + y_shape: list + output shape + data_dtype: str + data type + conv_dtype: str + convolution type + groups: int + number of groups + + Returns + ------- + algo: int + algo chosen by CUDNN + """ + dims = len(x_shape) + assert dims in (4, 5) + + pad, stride, dilation, x_shape, dy_shape = _prepare_global_func_params( + dims - 2, pad, stride, dilation, x_shape, dy_shape + ) + dw_shape = np.array(dw_shape, dtype=np.int32) + func = tvm._ffi.get_global_func("tvm.contrib.cudnn.conv.backward_filter_find_algo") + return func( + tensor_format, + dims - 2, + _get_np_int32_array_handle(pad), + _get_np_int32_array_handle(stride), + _get_np_int32_array_handle(dilation), + _get_np_int32_array_handle(x_shape), + _get_np_int32_array_handle(dy_shape), + _get_np_int32_array_handle(dw_shape), data_dtype, conv_dtype, groups, @@ -563,38 +630,63 @@ def conv_forward(x, w, pad, stride, dilation, conv_mode, tensor_format, algo, co ) -def conv_backward_data(x, w, pad, stride, dilation, conv_mode, tensor_format, conv_dtype, groups=1): - """Create an extern op that compute 2D or 3D convolution with CuDNN +def conv_backward_data(dy, w, pad, stride, dilation, conv_mode, tensor_format, conv_dtype, groups=1): + dims = len(dy.shape) + assert dims in (4, 5) - Parameters - ---------- - x: Tensor - input feature map - w: Tensor - convolution weight - pad: int or list - padding - stride: int or list - stride - dilation: int or list - dilation - conv_mode: int - 0: CUDNN_CONVOLUTION - 1: CUDNN_CROSS_CORRELATION - tensor_format: int - 0: CUDNN_TENSOR_NCHW - 1: CUDNN_TENSOR_NHWC - 2: CUDNN_TENSOR_NCHW_VECT_C - conv_dtype: str - convolution type - groups: int - the number of groups + conv_dtype = dy.dtype if conv_dtype is None else conv_dtype + pad, stride, dilation, _, _ = _prepare_global_func_params(dims - 2, pad, stride, dilation) - Returns - ------- - y: Tensor - The result tensor - """ + x_shape = list(dy.shape) + + assert isinstance( + dy.shape[0], tvm.tir.expr.IntImm + ), "Dynamic batch is not supported for cudnn conv2d backwad data yet." 
+ # TODO: fix oshape + oshape = x_shape + if tensor_format == 0: + oshape[1] = w.shape[1] + else: + oshape[3] = w.shape[3] + + algo = conv_backward_data_find_algo( + tensor_format, + pad, + stride, + dilation, + list(dy.shape), + list(w.shape), + oshape, + dy.dtype, + conv_dtype, + groups, + ) + + return te.extern( + oshape, + [dy, w], + lambda ins, outs: tvm.tir.call_packed( + "tvm.contrib.cudnn.conv2d.backward_data", + conv_mode, + tensor_format, + algo, + pad[0], + pad[1], + stride[0], + stride[1], + dilation[0], + dilation[1], + ins[0], + ins[1], + outs[0], + conv_dtype, + groups, + ), + name="dx", + ) + + +def conv_backward_filter(x, dy, pad, stride, dilation, conv_mode, tensor_format, conv_dtype, groups=1): dims = len(x.shape) assert dims in (4, 5) @@ -609,17 +701,17 @@ def conv_backward_data(x, w, pad, stride, dilation, conv_mode, tensor_format, co # TODO: fix oshape oshape = x_shape if tensor_format == 0: - oshape[1] = w.shape[1] + oshape[1] = dy.shape[1] else: - oshape[3] = w.shape[3] + oshape[3] = dy.shape[3] - algo = conv_backward_data_find_algo( + algo = conv_backward_filter_find_algo( tensor_format, pad, stride, dilation, list(x.shape), - list(w.shape), + list(dy.shape), oshape, x.dtype, conv_dtype, @@ -628,9 +720,9 @@ def conv_backward_data(x, w, pad, stride, dilation, conv_mode, tensor_format, co return te.extern( oshape, - [x, w], + [x, dy], lambda ins, outs: tvm.tir.call_packed( - "tvm.contrib.cudnn.conv2d.backward_data", + "tvm.contrib.cudnn.conv2d.backward_filter", conv_mode, tensor_format, algo, @@ -646,10 +738,11 @@ def conv_backward_data(x, w, pad, stride, dilation, conv_mode, tensor_format, co conv_dtype, groups, ), - name="y", + name="dw", ) + def softmax(x, axis=-1): """Compute softmax using CuDNN From 03bb809d5d6073fda08442006e60565274b4a2d5 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Tue, 18 Jan 2022 18:05:48 +0900 Subject: [PATCH 03/21] adding wgrad test --- tests/python/contrib/test_cudnn.py | 95 ++++++++++++++++++++++++++++++ 1 file changed, 95 insertions(+) diff --git a/tests/python/contrib/test_cudnn.py b/tests/python/contrib/test_cudnn.py index bc7669d3d95c..dd5a8444c857 100644 --- a/tests/python/contrib/test_cudnn.py +++ b/tests/python/contrib/test_cudnn.py @@ -316,6 +316,100 @@ def test_conv2d_backward_data(): verify_conv2d_backward_data("float16", "float16", tensor_format=1, tol=1e-1) +def conv2d_backward_weight_nchw_python(dy_np, x_np, kernel_size, stride, padding): + N, C, H, W = x_np.shape + _, K, P, Q = dy_np.shape + R, S = kernel_size + pad_h, pad_w = padding + stride_h, stride_w = stride + dw = np.zeros((K, C, R, S)).astype(dy_np.dtype) + + for k in range(K): + for r in range(R): + for s in range(S): + for c in range(C): + acc = 0 + for n in range(N): + for p in range(P): + for q in range(Q): + coord = (n, c, p * stride_h - pad_h + r, q * stride_w - pad_w + s) + + if ( + coord[2] < H + and coord[2] >= 0 + and coord[3] < W + and coord[3] >= 0 + ): + acc += dy_np[n, k, p, q] * x_np[coord] + + dw[k, c, r, s] = acc + + return dw + + +def verify_conv2d_backward_filter(data_dtype, conv_dtype, tensor_format=0, tol=1e-5): + batch = 3 + in_channel = 4 + out_channel = 16 + filter_h, filter_w = 3, 3 + pad_h, pad_w = 1, 1 + stride_h, stride_w = 1, 1 + height, width = 32, 32 + + # schedule + if tensor_format == 0: + x_shape = [batch, in_channel, height, width] + dy_shape = [batch, out_channel, height, width] + else: + x_shape = [batch, height, width, in_channel] + dy_shape = [batch, height, width, out_channel] + + x_np = np.random.uniform(-1, 1, 
x_shape).astype(data_dtype) + dy_np = np.random.uniform(-1, 1, dy_shape).astype(data_dtype) + + dw_np = conv2d_backward_weight_nchw_python( + x_np, dy_np, (filter_h, filter_w), (stride_h, stride_w), (pad_h, pad_w) + ) + + x = te.placeholder(x_shape, name="x", dtype=data_dtype) + dy = te.placeholder(dy_shape, name="dy", dtype=data_dtype) + dw = cudnn.conv_backward_filter( + x, + dy, + [pad_h, pad_w], + [stride_h, stride_w], + [1, 1], + conv_mode=1, + tensor_format=tensor_format, + conv_dtype=conv_dtype, + ) + + s = te.create_schedule(dw.op) + + # validation + dev = tvm.cuda(0) + f = tvm.build(s, [x, dy, dw], "cuda --host=llvm", name="conv2d_backward_filter") + + x = tvm.nd.array(x_np, dev) + dy = tvm.nd.array(dy_np, dev) + dw = tvm.nd.array(dw_np, dev) + + f(x, dy, dw) + print(np.max(np.abs(dw.numpy() - dw_np))) + print(np.mean(np.abs(dw.numpy() - dw_np))) + tvm.testing.assert_allclose(dw.numpy(), dw_np, atol=tol, rtol=tol) + + +@tvm.testing.requires_gpu +@requires_cudnn +def test_conv2d_backward_filter(): + verify_conv2d_backward_filter("float32", "float32", tensor_format=0, tol=1e-5) + # verify_conv2d_backward_filter("float32", "float32", tensor_format=1, tol=1e-2) + # # The scipy convolve function does not support fp16, so the reference will be computed with + # # fp32. Use larger tolerance to be on the safe side (1e-2 also seems mostly ok). + # verify_conv2d_backward_filter("float16", "float16", tensor_format=1, tol=1e-1) + + test_kwargs_default_2d = { "tensor_format": 0, "pad": [1, 1], @@ -390,3 +484,4 @@ def conv_output_shape_kwargs(request): if __name__ == "__main__": # sys.exit(pytest.main(sys.argv)) test_conv2d_backward_data() + test_conv2d_backward_filter() From 61060bb0fe369754955b7a8f968fb92cb8660243 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Tue, 18 Jan 2022 18:07:07 +0900 Subject: [PATCH 04/21] black --- python/tvm/contrib/cudnn.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/python/tvm/contrib/cudnn.py b/python/tvm/contrib/cudnn.py index 57a49831a1af..5c2fd181fed3 100644 --- a/python/tvm/contrib/cudnn.py +++ b/python/tvm/contrib/cudnn.py @@ -630,7 +630,9 @@ def conv_forward(x, w, pad, stride, dilation, conv_mode, tensor_format, algo, co ) -def conv_backward_data(dy, w, pad, stride, dilation, conv_mode, tensor_format, conv_dtype, groups=1): +def conv_backward_data( + dy, w, pad, stride, dilation, conv_mode, tensor_format, conv_dtype, groups=1 +): dims = len(dy.shape) assert dims in (4, 5) @@ -686,7 +688,9 @@ def conv_backward_data(dy, w, pad, stride, dilation, conv_mode, tensor_format, c ) -def conv_backward_filter(x, dy, pad, stride, dilation, conv_mode, tensor_format, conv_dtype, groups=1): +def conv_backward_filter( + x, dy, pad, stride, dilation, conv_mode, tensor_format, conv_dtype, groups=1 +): dims = len(x.shape) assert dims in (4, 5) @@ -742,7 +746,6 @@ def conv_backward_filter(x, dy, pad, stride, dilation, conv_mode, tensor_format, ) - def softmax(x, axis=-1): """Compute softmax using CuDNN From 5f13439b999fd80f7e7819e5c6dcd9b9e52b6745 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Tue, 18 Jan 2022 19:24:06 +0900 Subject: [PATCH 05/21] wgrad nchw and nhwc worked --- python/tvm/contrib/cudnn.py | 4 +- tests/python/contrib/test_cudnn.py | 77 ++++++++++++++++++++++-------- 2 files changed, 58 insertions(+), 23 deletions(-) diff --git a/python/tvm/contrib/cudnn.py b/python/tvm/contrib/cudnn.py index 5c2fd181fed3..32cc03c88af1 100644 --- a/python/tvm/contrib/cudnn.py +++ b/python/tvm/contrib/cudnn.py @@ -705,9 +705,9 @@ def 
conv_backward_filter( # TODO: fix oshape oshape = x_shape if tensor_format == 0: - oshape[1] = dy.shape[1] + oshape = [dy.shape[1], x_shape[1], 3, 3] else: - oshape[3] = dy.shape[3] + oshape = [dy.shape[3], 3, 3, x_shape[3]] algo = conv_backward_filter_find_algo( tensor_format, diff --git a/tests/python/contrib/test_cudnn.py b/tests/python/contrib/test_cudnn.py index dd5a8444c857..e76953eaa57f 100644 --- a/tests/python/contrib/test_cudnn.py +++ b/tests/python/contrib/test_cudnn.py @@ -316,13 +316,21 @@ def test_conv2d_backward_data(): verify_conv2d_backward_data("float16", "float16", tensor_format=1, tol=1e-1) -def conv2d_backward_weight_nchw_python(dy_np, x_np, kernel_size, stride, padding): - N, C, H, W = x_np.shape - _, K, P, Q = dy_np.shape +def conv2d_backward_weight_python(dy_np, x_np, kernel_size, stride, padding, layout="NCHW"): R, S = kernel_size + if layout == "NCHW": + N, C, H, W = x_np.shape + _, K, P, Q = dy_np.shape + w_shape = (K, C, R, S) + else: + N, H, W, C = x_np.shape + _, P, Q, K = dy_np.shape + w_shape = (K, R, S, C) + pad_h, pad_w = padding stride_h, stride_w = stride - dw = np.zeros((K, C, R, S)).astype(dy_np.dtype) + + dw = np.zeros(w_shape).astype(dy_np.dtype) for k in range(K): for r in range(R): @@ -332,17 +340,42 @@ def conv2d_backward_weight_nchw_python(dy_np, x_np, kernel_size, stride, padding for n in range(N): for p in range(P): for q in range(Q): - coord = (n, c, p * stride_h - pad_h + r, q * stride_w - pad_w + s) - - if ( - coord[2] < H - and coord[2] >= 0 - and coord[3] < W - and coord[3] >= 0 - ): - acc += dy_np[n, k, p, q] * x_np[coord] - dw[k, c, r, s] = acc + if layout == "NCHW": + coord = ( + n, + c, + p * stride_h - pad_h + r, + q * stride_w - pad_w + s, + ) + + if ( + coord[2] < H + and coord[2] >= 0 + and coord[3] < W + and coord[3] >= 0 + ): + acc += dy_np[n, k, p, q] * x_np[coord] + else: + coord = ( + n, + p * stride_h - pad_h + r, + q * stride_w - pad_w + s, + c, + ) + + if ( + coord[1] < H + and coord[1] >= 0 + and coord[2] < W + and coord[2] >= 0 + ): + acc += dy_np[n, p, q, k] * x_np[coord] + + if layout == "NCHW": + dw[k, c, r, s] = acc + else: + dw[k, r, s, c] = acc return dw @@ -367,8 +400,13 @@ def verify_conv2d_backward_filter(data_dtype, conv_dtype, tensor_format=0, tol=1 x_np = np.random.uniform(-1, 1, x_shape).astype(data_dtype) dy_np = np.random.uniform(-1, 1, dy_shape).astype(data_dtype) - dw_np = conv2d_backward_weight_nchw_python( - x_np, dy_np, (filter_h, filter_w), (stride_h, stride_w), (pad_h, pad_w) + dw_np = conv2d_backward_weight_python( + dy_np, + x_np, + (filter_h, filter_w), + (stride_h, stride_w), + (pad_h, pad_w), + "NCHW" if tensor_format == 0 else "NHWC", ) x = te.placeholder(x_shape, name="x", dtype=data_dtype) @@ -404,10 +442,7 @@ def verify_conv2d_backward_filter(data_dtype, conv_dtype, tensor_format=0, tol=1 @requires_cudnn def test_conv2d_backward_filter(): verify_conv2d_backward_filter("float32", "float32", tensor_format=0, tol=1e-5) - # verify_conv2d_backward_filter("float32", "float32", tensor_format=1, tol=1e-2) - # # The scipy convolve function does not support fp16, so the reference will be computed with - # # fp32. Use larger tolerance to be on the safe side (1e-2 also seems mostly ok). 
- # verify_conv2d_backward_filter("float16", "float16", tensor_format=1, tol=1e-1) + verify_conv2d_backward_filter("float32", "float32", tensor_format=1, tol=1e-5) test_kwargs_default_2d = { @@ -483,5 +518,5 @@ def conv_output_shape_kwargs(request): if __name__ == "__main__": # sys.exit(pytest.main(sys.argv)) - test_conv2d_backward_data() + # test_conv2d_backward_data() test_conv2d_backward_filter() From 64f76c85378ea5b24b425bddd27f719cb4820125 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Wed, 19 Jan 2022 11:07:24 +0900 Subject: [PATCH 06/21] remove bwd algo name stuff --- python/tvm/contrib/cudnn.py | 27 --------------------------- 1 file changed, 27 deletions(-) diff --git a/python/tvm/contrib/cudnn.py b/python/tvm/contrib/cudnn.py index 32cc03c88af1..eb518225e22b 100644 --- a/python/tvm/contrib/cudnn.py +++ b/python/tvm/contrib/cudnn.py @@ -36,33 +36,6 @@ "CUDNN_CONVOLUTION_FWD_ALGO_COUNT", ] -_BWD_FILTER_ALGOS = [ - "CUDNN_CONVOLUTION_BWD_FILTER_ALGO_0", - # non-deterministic - "CUDNN_CONVOLUTION_BWD_FILTER_ALGO_1", - "CUDNN_CONVOLUTION_BWD_FILTER_ALGO_FFT", - "CUDNN_CONVOLUTION_BWD_FILTER_ALGO_3", - # non-deterministic, algo0 with workspaceS - "CUDNN_CONVOLUTION_BWD_FILTER_ALGO_WINOGRAD", - # not implemented - "CUDNN_CONVOLUTION_BWD_FILTER_ALGO_WINOGRAD_NONFUSED", - "CUDNN_CONVOLUTION_BWD_FILTER_ALGO_FFT_TILING", - "CUDNN_CONVOLUTION_BWD_FILTER_ALGO_COUNT", -] - -_BWD_DATA_ALGOS = [ - "CUDNN_CONVOLUTION_BWD_DATA_ALGO_0", - # non-deterministic - "CUDNN_CONVOLUTION_BWD_DATA_ALGO_1", - "CUDNN_CONVOLUTION_BWD_DATA_ALGO_FFT", - "CUDNN_CONVOLUTION_BWD_DATA_ALGO_FFT_TILING", - "CUDNN_CONVOLUTION_BWD_DATA_ALGO_WINOGRAD", - "CUDNN_CONVOLUTION_BWD_DATA_ALGO_WINOGRAD_NONFUSED", - "CUDNN_CONVOLUTION_BWD_DATA_ALGO_COUNT", -] - -_ALGO_TYPE = ["fwd", "bwd_filter", "bwd_data"] - def exists(): """ From 0ede2164fb494dc77a0ab2a297285e8062411db5 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Wed, 19 Jan 2022 14:50:24 +0900 Subject: [PATCH 07/21] compute output shape properly --- python/tvm/contrib/cudnn.py | 95 ++++++++++++++++++++++++++++++------- 1 file changed, 79 insertions(+), 16 deletions(-) diff --git a/python/tvm/contrib/cudnn.py b/python/tvm/contrib/cudnn.py index eb518225e22b..06fc8ffeaed3 100644 --- a/python/tvm/contrib/cudnn.py +++ b/python/tvm/contrib/cudnn.py @@ -258,6 +258,74 @@ def conv_output_shape( return output +def conv_dgrad_shape( + tensor_format, pad, stride, dilation, dy_shape, w_shape, data_dtype, conv_dtype, groups=1 +): + """Get output shape of conv2d gradient with respect to data + + Paramters + --------- + tensor_format: int + 0: CUDNN_TENSOR_NCHW + 1: CUDNN_TENSOR_NHWC + pad: int or list + padding + stride: int or list + stride + dilation: int or list + dilation + dy_shape: list + output gradient shape + w_shape: list + weight shape + data_dtype: str + data type + conv_dtype: str + convolution type + groups: int + number of groups + + Returns + ------- + oshape: list + output shape + """ + + assert len(dy_shape) == len(w_shape) + assert len(dy_shape) == 4 + + if tensor_format == 0: + N = dy_shape[0] + K = w_shape[0] + C = w_shape[1] + P, Q = dy_shape[2:] + R, S = w_shape[2:] + elif tensor_format == 1: + N = dy_shape[0] + K = w_shape[0] + C = w_shape[1] + P, Q = dy_shape[2:] + R, S = w_shape[2:] + else: + raise ValueError("Unknown CuDNN tensor format: '{}'".format(tensor_format)) + + input_dims = [] + for dy_shape_i, w_shape_i, pad_i, stride_i, dilation_i in zip( + dy_shape, w_shape, pad, stride, dilation + ): + input_dim = (dy_shape_i - 1) * stride_i - 2 * 
pad_i + (((w_shape_i - 1) * dilation_i) + 1) + input_dims.append(input_dim) + + if tensor_format == 0: + output = [N, C, *input_dims] + elif tensor_format == 1: + output = [N, *input_dims, C] + else: + raise ValueError("Unknown CuDNN tensor format: '{}'".format(tensor_format)) + + return output + + def conv_find_algo( tensor_format, pad, @@ -612,17 +680,11 @@ def conv_backward_data( conv_dtype = dy.dtype if conv_dtype is None else conv_dtype pad, stride, dilation, _, _ = _prepare_global_func_params(dims - 2, pad, stride, dilation) - x_shape = list(dy.shape) - assert isinstance( dy.shape[0], tvm.tir.expr.IntImm ), "Dynamic batch is not supported for cudnn conv2d backwad data yet." - # TODO: fix oshape - oshape = x_shape - if tensor_format == 0: - oshape[1] = w.shape[1] - else: - oshape[3] = w.shape[3] + + x_shape = conv_dgrad_shape(tensor_format, pad, stride, dilation, dy.shape, w.shape, dy.dtype, conv_dtype, groups) algo = conv_backward_data_find_algo( tensor_format, @@ -631,14 +693,14 @@ def conv_backward_data( dilation, list(dy.shape), list(w.shape), - oshape, + x_shape, dy.dtype, conv_dtype, groups, ) return te.extern( - oshape, + x_shape, [dy, w], lambda ins, outs: tvm.tir.call_packed( "tvm.contrib.cudnn.conv2d.backward_data", @@ -662,10 +724,10 @@ def conv_backward_data( def conv_backward_filter( - x, dy, pad, stride, dilation, conv_mode, tensor_format, conv_dtype, groups=1 + x, dy, kernel_size, pad, stride, dilation, conv_mode, tensor_format, conv_dtype, groups=1 ): dims = len(x.shape) - assert dims in (4, 5) + assert dims == 4 conv_dtype = x.dtype if conv_dtype is None else conv_dtype pad, stride, dilation, _, _ = _prepare_global_func_params(dims - 2, pad, stride, dilation) @@ -675,12 +737,13 @@ def conv_backward_filter( assert isinstance( x.shape[0], tvm.tir.expr.IntImm ), "Dynamic batch is not supported for cudnn conv2d backwad data yet." 
- # TODO: fix oshape + oshape = x_shape + filter_h, filter_w = kernel_size if tensor_format == 0: - oshape = [dy.shape[1], x_shape[1], 3, 3] + oshape = [dy.shape[1], x_shape[1], filter_h, filter_w] else: - oshape = [dy.shape[3], 3, 3, x_shape[3]] + oshape = [dy.shape[3], filter_h, filter_w, x_shape[3]] algo = conv_backward_filter_find_algo( tensor_format, @@ -697,7 +760,7 @@ def conv_backward_filter( return te.extern( oshape, - [x, dy], + [dy, x], lambda ins, outs: tvm.tir.call_packed( "tvm.contrib.cudnn.conv2d.backward_filter", conv_mode, From 0771b5f497ffb1ed8245010920bcc6332cbd8a32 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Wed, 19 Jan 2022 14:50:44 +0900 Subject: [PATCH 08/21] swap arg order in wgrad --- src/runtime/contrib/cudnn/conv_backward.cc | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/src/runtime/contrib/cudnn/conv_backward.cc b/src/runtime/contrib/cudnn/conv_backward.cc index 2537c98264ba..f615ffe3ff0d 100644 --- a/src/runtime/contrib/cudnn/conv_backward.cc +++ b/src/runtime/contrib/cudnn/conv_backward.cc @@ -108,7 +108,7 @@ void BackwardDataFindAlgo(int format, int dims, int groups, const int pad[], con void ConvolutionBackwardFilter(int mode, int format, int algo, int dims, int groups, const int pad[], const int stride[], const int dilation[], - DLTensor* x, DLTensor* dy, DLTensor* dw, + DLTensor* dy, DLTensor* x, DLTensor* dw, const std::string& conv_dtype) { CuDNNThreadEntry* entry_ptr = CuDNNThreadEntry::ThreadLocal(); // Set Mode @@ -137,7 +137,7 @@ void ConvolutionBackwardFilter(int mode, int format, int algo, int dims, int gro } void BackwardFilterFindAlgo(int format, int dims, int groups, const int pad[], const int stride[], - const int dilation[], const int x_dim[], const int dy_dim[], + const int dilation[], const int dy_dim[], const int x_dim[], const int dw_dim[], const std::string& data_dtype, const std::string& conv_dtype, TVMRetValue* ret) { CuDNNThreadEntry* entry_ptr = CuDNNThreadEntry::ThreadLocal(); @@ -233,13 +233,13 @@ TVM_REGISTER_GLOBAL("tvm.contrib.cudnn.conv2d.backward_filter") stride_v[i] = args[5 + i]; dilation_v[i] = args[7 + i]; } - DLTensor* x = args[9]; - DLTensor* dy = args[10]; + DLTensor* dy = args[9]; + DLTensor* x = args[10]; DLTensor* dw = args[11]; std::string conv_dtype = args[12]; int groups = args[13]; - ConvolutionBackwardFilter(mode, format, algo, 2, groups, pad_v, stride_v, dilation_v, x, dy, + ConvolutionBackwardFilter(mode, format, algo, 2, groups, pad_v, stride_v, dilation_v, dy, dx, dw, conv_dtype); }); @@ -250,14 +250,14 @@ TVM_REGISTER_GLOBAL("tvm.contrib.cudnn.conv.backward_filter_find_algo") int* pad = static_cast(static_cast(args[2])); int* stride = static_cast(static_cast(args[3])); int* dilation = static_cast(static_cast(args[4])); - int* x_dim = static_cast(static_cast(args[5])); - int* dy_dim = static_cast(static_cast(args[6])); + int* dy_dim = static_cast(static_cast(args[5])); + int* x_dim = static_cast(static_cast(args[6])); int* dw_dim = static_cast(static_cast(args[7])); std::string data_dtype = args[8]; std::string conv_dtype = args[9]; int groups = args[10]; - BackwardFilterFindAlgo(format, dims, groups, pad, stride, dilation, x_dim, dy_dim, dw_dim, + BackwardFilterFindAlgo(format, dims, groups, pad, stride, dilation, dy_dim, x_dim, dw_dim, data_dtype, conv_dtype, ret); }); From 282dd05849d0dd431d3ded87ee9f66577e42a16a Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Wed, 19 Jan 2022 14:51:30 +0900 Subject: [PATCH 09/21] add kernel size arg in test --- 
tests/python/contrib/test_cudnn.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/python/contrib/test_cudnn.py b/tests/python/contrib/test_cudnn.py index e76953eaa57f..c912440dea3c 100644 --- a/tests/python/contrib/test_cudnn.py +++ b/tests/python/contrib/test_cudnn.py @@ -414,6 +414,7 @@ def verify_conv2d_backward_filter(data_dtype, conv_dtype, tensor_format=0, tol=1 dw = cudnn.conv_backward_filter( x, dy, + (filter_h, filter_w), [pad_h, pad_w], [stride_h, stride_w], [1, 1], From ce8fde09fea6ee5cad10375e7d15a692122e0a45 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Wed, 19 Jan 2022 14:52:23 +0900 Subject: [PATCH 10/21] black --- python/tvm/contrib/cudnn.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/python/tvm/contrib/cudnn.py b/python/tvm/contrib/cudnn.py index 06fc8ffeaed3..5353e9b5a91b 100644 --- a/python/tvm/contrib/cudnn.py +++ b/python/tvm/contrib/cudnn.py @@ -684,7 +684,9 @@ def conv_backward_data( dy.shape[0], tvm.tir.expr.IntImm ), "Dynamic batch is not supported for cudnn conv2d backwad data yet." - x_shape = conv_dgrad_shape(tensor_format, pad, stride, dilation, dy.shape, w.shape, dy.dtype, conv_dtype, groups) + x_shape = conv_dgrad_shape( + tensor_format, pad, stride, dilation, dy.shape, w.shape, dy.dtype, conv_dtype, groups + ) algo = conv_backward_data_find_algo( tensor_format, @@ -724,7 +726,7 @@ def conv_backward_data( def conv_backward_filter( - x, dy, kernel_size, pad, stride, dilation, conv_mode, tensor_format, conv_dtype, groups=1 + x, dy, kernel_size, pad, stride, dilation, conv_mode, tensor_format, conv_dtype, groups=1 ): dims = len(x.shape) assert dims == 4 From cebc8d1770346d6a549fb2ea6ceef64a53e32019 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Wed, 19 Jan 2022 16:03:36 +0900 Subject: [PATCH 11/21] cleanup --- python/tvm/contrib/cudnn.py | 229 +++++++++++++++------- src/runtime/contrib/cudnn/conv_forward.cc | 2 +- tests/python/contrib/test_cudnn.py | 2 +- 3 files changed, 155 insertions(+), 78 deletions(-) diff --git a/python/tvm/contrib/cudnn.py b/python/tvm/contrib/cudnn.py index 5353e9b5a91b..9c4bd20abdba 100644 --- a/python/tvm/contrib/cudnn.py +++ b/python/tvm/contrib/cudnn.py @@ -326,7 +326,8 @@ def conv_dgrad_shape( return output -def conv_find_algo( +def _conv_find_algo( + func_name, tensor_format, pad, stride, @@ -338,7 +339,46 @@ def conv_find_algo( conv_dtype, groups=1, ): - """Choose the best algo for the given input. + """ + Common function to choose the best cudnn convolution algorithm for the given input + and the convolution type. + """ + dims = len(x_shape) + assert dims in (4, 5) + + pad, stride, dilation, xshape, wshape = _prepare_global_func_params( + dims - 2, pad, stride, dilation, x_shape, w_shape + ) + yshape = np.array(y_shape, dtype=np.int32) + func = tvm._ffi.get_global_func(func_name) + return func( + tensor_format, + dims - 2, + _get_np_int32_array_handle(pad), + _get_np_int32_array_handle(stride), + _get_np_int32_array_handle(dilation), + _get_np_int32_array_handle(xshape), + _get_np_int32_array_handle(wshape), + _get_np_int32_array_handle(yshape), + data_dtype, + conv_dtype, + groups, + ) + + +def conv_forward_find_algo( + tensor_format, + pad, + stride, + dilation, + x_shape, + w_shape, + y_shape, + data_dtype, + conv_dtype, + groups=1, +): + """Choose the best forward algorithm for the given input. 
Paramters --------- @@ -370,26 +410,18 @@ def conv_find_algo( algo: int algo chosen by CUDNN """ - dims = len(x_shape) - assert dims in (4, 5) - - pad, stride, dilation, xshape, wshape = _prepare_global_func_params( - dims - 2, pad, stride, dilation, x_shape, w_shape - ) - yshape = np.array(y_shape, dtype=np.int32) - func = tvm._ffi.get_global_func("tvm.contrib.cudnn.conv.find_algo") - return func( + return _conv_find_algo( + "tvm.contrib.cudnn.conv.forward_find_algo", tensor_format, - dims - 2, - _get_np_int32_array_handle(pad), - _get_np_int32_array_handle(stride), - _get_np_int32_array_handle(dilation), - _get_np_int32_array_handle(xshape), - _get_np_int32_array_handle(wshape), - _get_np_int32_array_handle(yshape), + pad, + stride, + dilation, + x_shape, + w_shape, + y_shape, data_dtype, conv_dtype, - groups, + groups=1, ) @@ -405,7 +437,7 @@ def conv_backward_data_find_algo( conv_dtype, groups=1, ): - """Choose the best algo for the given input. + """Choose the best backward data algorithm for the given input. Paramters --------- @@ -419,12 +451,12 @@ def conv_backward_data_find_algo( stride dilation: int or list dilation - x_shape: list - input shape + dy_shape: list + output gradient shape w_shape: list weight shape - y_shape: list - output shape + dx_shape: list + dgrad shape data_dtype: str data type conv_dtype: str @@ -437,26 +469,18 @@ def conv_backward_data_find_algo( algo: int algo chosen by CUDNN """ - dims = len(dy_shape) - assert dims in (4, 5) - - pad, stride, dilation, dy_shape, w_shape = _prepare_global_func_params( - dims - 2, pad, stride, dilation, dy_shape, w_shape - ) - dx_shape = np.array(dx_shape, dtype=np.int32) - func = tvm._ffi.get_global_func("tvm.contrib.cudnn.conv.backward_data_find_algo") - return func( + return _conv_find_algo( + "tvm.contrib.cudnn.conv.backward_data_find_algo", tensor_format, - dims - 2, - _get_np_int32_array_handle(pad), - _get_np_int32_array_handle(stride), - _get_np_int32_array_handle(dilation), - _get_np_int32_array_handle(dy_shape), - _get_np_int32_array_handle(w_shape), - _get_np_int32_array_handle(dx_shape), + pad, + stride, + dilation, + dy_shape, + w_shape, + dx_shape, data_dtype, conv_dtype, - groups, + groups=1, ) @@ -465,8 +489,8 @@ def conv_backward_filter_find_algo( pad, stride, dilation, - x_shape, dy_shape, + x_shape, dw_shape, data_dtype, conv_dtype, @@ -486,12 +510,12 @@ def conv_backward_filter_find_algo( stride dilation: int or list dilation + dy_shape: list + output gradient shape x_shape: list - input shape - w_shape: list weight shape - y_shape: list - output shape + dw_shape: list + wgrad shape data_dtype: str data type conv_dtype: str @@ -504,26 +528,18 @@ def conv_backward_filter_find_algo( algo: int algo chosen by CUDNN """ - dims = len(x_shape) - assert dims in (4, 5) - - pad, stride, dilation, x_shape, dy_shape = _prepare_global_func_params( - dims - 2, pad, stride, dilation, x_shape, dy_shape - ) - dw_shape = np.array(dw_shape, dtype=np.int32) - func = tvm._ffi.get_global_func("tvm.contrib.cudnn.conv.backward_filter_find_algo") - return func( + return _conv_find_algo( + "tvm.contrib.cudnn.conv.backward_filter_find_algo", tensor_format, - dims - 2, - _get_np_int32_array_handle(pad), - _get_np_int32_array_handle(stride), - _get_np_int32_array_handle(dilation), - _get_np_int32_array_handle(x_shape), - _get_np_int32_array_handle(dy_shape), - _get_np_int32_array_handle(dw_shape), + pad, + stride, + dilation, + dy_shape, + x_shape, + dw_shape, data_dtype, conv_dtype, - groups, + groups=1, ) @@ -589,7 +605,7 @@ def 
conv_forward(x, w, pad, stride, dilation, conv_mode, tensor_format, algo, co if tensor_format == 1 and conv_dtype == "int32": algo = 1 else: - algo = conv_find_algo( + algo = conv_forward_find_algo( tensor_format, pad, stride, @@ -674,8 +690,38 @@ def conv_forward(x, w, pad, stride, dilation, conv_mode, tensor_format, algo, co def conv_backward_data( dy, w, pad, stride, dilation, conv_mode, tensor_format, conv_dtype, groups=1 ): + """Create a CuDNN extern op that computes the gradient of 2D convolution with respect to data. + + Parameters + ---------- + dy: Tensor + output gradient + w: Tensor + convolution weight + pad: int or list + padding + stride: int or list + stride + dilation: int or list + dilation + conv_mode: int + 0: CUDNN_CONVOLUTION + 1: CUDNN_CROSS_CORRELATION + tensor_format: int + 0: CUDNN_TENSOR_NCHW + 1: CUDNN_TENSOR_NHWC + conv_dtype: str + convolution type + groups: int + the number of groups + + Returns + ------- + dx: Tensor + dgrad tensor + """ dims = len(dy.shape) - assert dims in (4, 5) + assert dims == 4 conv_dtype = dy.dtype if conv_dtype is None else conv_dtype pad, stride, dilation, _, _ = _prepare_global_func_params(dims - 2, pad, stride, dilation) @@ -684,7 +730,7 @@ def conv_backward_data( dy.shape[0], tvm.tir.expr.IntImm ), "Dynamic batch is not supported for cudnn conv2d backwad data yet." - x_shape = conv_dgrad_shape( + dx_shape = conv_dgrad_shape( tensor_format, pad, stride, dilation, dy.shape, w.shape, dy.dtype, conv_dtype, groups ) @@ -695,7 +741,7 @@ def conv_backward_data( dilation, list(dy.shape), list(w.shape), - x_shape, + dx_shape, dy.dtype, conv_dtype, groups, @@ -726,42 +772,73 @@ def conv_backward_data( def conv_backward_filter( - x, dy, kernel_size, pad, stride, dilation, conv_mode, tensor_format, conv_dtype, groups=1 + dy, x, kernel_size, pad, stride, dilation, conv_mode, tensor_format, conv_dtype, groups=1 ): + """Create a CuDNN extern op that computes the gradient of 2D convolution with respect to data. + + Parameters + ---------- + dy: Tensor + output gradient + x: Tensor + input tensor + kernel_size: a pair of int + The spatial size of the corresponding forward convolution kernel + pad: int or list + padding + stride: int or list + stride + dilation: int or list + dilation + conv_mode: int + 0: CUDNN_CONVOLUTION + 1: CUDNN_CROSS_CORRELATION + tensor_format: int + 0: CUDNN_TENSOR_NCHW + 1: CUDNN_TENSOR_NHWC + conv_dtype: str + convolution type + groups: int + the number of groups + + Returns + ------- + dw: Tensor + wgrad tensor + """ dims = len(x.shape) assert dims == 4 conv_dtype = x.dtype if conv_dtype is None else conv_dtype pad, stride, dilation, _, _ = _prepare_global_func_params(dims - 2, pad, stride, dilation) + filter_h, filter_w = kernel_size x_shape = list(x.shape) assert isinstance( x.shape[0], tvm.tir.expr.IntImm - ), "Dynamic batch is not supported for cudnn conv2d backwad data yet." + ), "Dynamic batch is not supported for cudnn conv2d backwad filter yet." 
- oshape = x_shape - filter_h, filter_w = kernel_size if tensor_format == 0: - oshape = [dy.shape[1], x_shape[1], filter_h, filter_w] + dw_shape = [dy.shape[1], x_shape[1], filter_h, filter_w] else: - oshape = [dy.shape[3], filter_h, filter_w, x_shape[3]] + dw_shape = [dy.shape[3], filter_h, filter_w, x_shape[3]] algo = conv_backward_filter_find_algo( tensor_format, pad, stride, dilation, - list(x.shape), list(dy.shape), - oshape, + list(x.shape), + dw_shape, x.dtype, conv_dtype, groups, ) return te.extern( - oshape, + dw_shape, [dy, x], lambda ins, outs: tvm.tir.call_packed( "tvm.contrib.cudnn.conv2d.backward_filter", diff --git a/src/runtime/contrib/cudnn/conv_forward.cc b/src/runtime/contrib/cudnn/conv_forward.cc index b7476e5106fa..4b53fefe9bcd 100644 --- a/src/runtime/contrib/cudnn/conv_forward.cc +++ b/src/runtime/contrib/cudnn/conv_forward.cc @@ -147,7 +147,7 @@ TVM_REGISTER_GLOBAL("tvm.contrib.cudnn.conv3d.forward") conv_dtype); }); -TVM_REGISTER_GLOBAL("tvm.contrib.cudnn.conv.find_algo") +TVM_REGISTER_GLOBAL("tvm.contrib.cudnn.conv.forward_find_algo") .set_body([](TVMArgs args, TVMRetValue* ret) { int format = args[0]; int dims = args[1]; diff --git a/tests/python/contrib/test_cudnn.py b/tests/python/contrib/test_cudnn.py index c912440dea3c..c2bf52ccdd19 100644 --- a/tests/python/contrib/test_cudnn.py +++ b/tests/python/contrib/test_cudnn.py @@ -412,8 +412,8 @@ def verify_conv2d_backward_filter(data_dtype, conv_dtype, tensor_format=0, tol=1 x = te.placeholder(x_shape, name="x", dtype=data_dtype) dy = te.placeholder(dy_shape, name="dy", dtype=data_dtype) dw = cudnn.conv_backward_filter( - x, dy, + x, (filter_h, filter_w), [pad_h, pad_w], [stride_h, stride_w], From 62c96c9cc51e969d660f0b99c07d0c1ca40fd800 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Wed, 19 Jan 2022 16:12:20 +0900 Subject: [PATCH 12/21] more fix --- python/tvm/contrib/cudnn.py | 16 +++++++--------- src/runtime/contrib/cudnn/conv_backward.cc | 2 +- src/runtime/contrib/cudnn/conv_forward.cc | 2 +- tests/python/contrib/test_cudnn.py | 10 +++------- 4 files changed, 12 insertions(+), 18 deletions(-) diff --git a/python/tvm/contrib/cudnn.py b/python/tvm/contrib/cudnn.py index 9c4bd20abdba..82d3fc64a1bd 100644 --- a/python/tvm/contrib/cudnn.py +++ b/python/tvm/contrib/cudnn.py @@ -307,7 +307,7 @@ def conv_dgrad_shape( P, Q = dy_shape[2:] R, S = w_shape[2:] else: - raise ValueError("Unknown CuDNN tensor format: '{}'".format(tensor_format)) + raise ValueError("Unsupported CuDNN tensor format: '{}'".format(tensor_format)) input_dims = [] for dy_shape_i, w_shape_i, pad_i, stride_i, dilation_i in zip( @@ -318,10 +318,8 @@ def conv_dgrad_shape( if tensor_format == 0: output = [N, C, *input_dims] - elif tensor_format == 1: - output = [N, *input_dims, C] else: - raise ValueError("Unknown CuDNN tensor format: '{}'".format(tensor_format)) + output = [N, *input_dims, C] return output @@ -421,7 +419,7 @@ def conv_forward_find_algo( y_shape, data_dtype, conv_dtype, - groups=1, + groups, ) @@ -480,7 +478,7 @@ def conv_backward_data_find_algo( dx_shape, data_dtype, conv_dtype, - groups=1, + groups, ) @@ -496,7 +494,7 @@ def conv_backward_filter_find_algo( conv_dtype, groups=1, ): - """Choose the best algo for the given input. + """Choose the best backward filter algorithm for the given input. 
Paramters --------- @@ -539,7 +537,7 @@ def conv_backward_filter_find_algo( dw_shape, data_dtype, conv_dtype, - groups=1, + groups, ) @@ -774,7 +772,7 @@ def conv_backward_data( def conv_backward_filter( dy, x, kernel_size, pad, stride, dilation, conv_mode, tensor_format, conv_dtype, groups=1 ): - """Create a CuDNN extern op that computes the gradient of 2D convolution with respect to data. + """Create a CuDNN extern op that computes the gradient of 2D convolution with respect to weight. Parameters ---------- diff --git a/src/runtime/contrib/cudnn/conv_backward.cc b/src/runtime/contrib/cudnn/conv_backward.cc index f615ffe3ff0d..417183335a64 100644 --- a/src/runtime/contrib/cudnn/conv_backward.cc +++ b/src/runtime/contrib/cudnn/conv_backward.cc @@ -18,7 +18,7 @@ */ /*! - * \file Use external cudnn utils function + * \file cuDNN kernel calls for backward algorithms. */ #include #include diff --git a/src/runtime/contrib/cudnn/conv_forward.cc b/src/runtime/contrib/cudnn/conv_forward.cc index 4b53fefe9bcd..f5e5ee889c55 100644 --- a/src/runtime/contrib/cudnn/conv_forward.cc +++ b/src/runtime/contrib/cudnn/conv_forward.cc @@ -18,7 +18,7 @@ */ /*! - * \file Use external cudnn utils function + * \file cuDNN kernel calls for the forward algorithm. */ #include #include diff --git a/tests/python/contrib/test_cudnn.py b/tests/python/contrib/test_cudnn.py index c2bf52ccdd19..d744cddfcb5e 100644 --- a/tests/python/contrib/test_cudnn.py +++ b/tests/python/contrib/test_cudnn.py @@ -245,7 +245,6 @@ def verify_conv2d_backward_data(data_dtype, conv_dtype, tensor_format=0, tol=1e- stride_h, stride_w = 1, 1 height, width = 32, 32 - # schedule if tensor_format == 0: xshape = [batch, in_channel, height, width] wshape = [out_channel, in_channel, filter_h, filter_w] @@ -292,7 +291,6 @@ def verify_conv2d_backward_data(data_dtype, conv_dtype, tensor_format=0, tol=1e- s = te.create_schedule(dx.op) - # validation dev = tvm.cuda(0) f = tvm.build(s, [dy, w, dx], "cuda --host=llvm", name="conv2d_backward_data") @@ -389,7 +387,6 @@ def verify_conv2d_backward_filter(data_dtype, conv_dtype, tensor_format=0, tol=1 stride_h, stride_w = 1, 1 height, width = 32, 32 - # schedule if tensor_format == 0: x_shape = [batch, in_channel, height, width] dy_shape = [batch, out_channel, height, width] @@ -425,15 +422,14 @@ def verify_conv2d_backward_filter(data_dtype, conv_dtype, tensor_format=0, tol=1 s = te.create_schedule(dw.op) - # validation dev = tvm.cuda(0) - f = tvm.build(s, [x, dy, dw], "cuda --host=llvm", name="conv2d_backward_filter") + f = tvm.build(s, [dy, x, dw], "cuda --host=llvm", name="conv2d_backward_filter") x = tvm.nd.array(x_np, dev) dy = tvm.nd.array(dy_np, dev) dw = tvm.nd.array(dw_np, dev) - f(x, dy, dw) + f(dy, x, dw) print(np.max(np.abs(dw.numpy() - dw_np))) print(np.mean(np.abs(dw.numpy() - dw_np))) tvm.testing.assert_allclose(dw.numpy(), dw_np, atol=tol, rtol=tol) @@ -519,5 +515,5 @@ def conv_output_shape_kwargs(request): if __name__ == "__main__": # sys.exit(pytest.main(sys.argv)) - # test_conv2d_backward_data() + test_conv2d_backward_data() test_conv2d_backward_filter() From 42bc825b2698dcc8ce4e6dba6a7c69a746586f0f Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Thu, 20 Jan 2022 04:37:20 +0900 Subject: [PATCH 13/21] fix dgrad test --- python/tvm/contrib/cudnn.py | 8 ++++++-- src/runtime/contrib/cudnn/conv_backward.cc | 2 +- 2 files changed, 7 insertions(+), 3 deletions(-) diff --git a/python/tvm/contrib/cudnn.py b/python/tvm/contrib/cudnn.py index 82d3fc64a1bd..96001f721cb2 100644 --- 
a/python/tvm/contrib/cudnn.py +++ b/python/tvm/contrib/cudnn.py @@ -300,12 +300,16 @@ def conv_dgrad_shape( C = w_shape[1] P, Q = dy_shape[2:] R, S = w_shape[2:] + dy_shape = dy_shape[2:] + w_shape = w_shape[2:] elif tensor_format == 1: N = dy_shape[0] K = w_shape[0] - C = w_shape[1] + C = w_shape[-1] P, Q = dy_shape[2:] R, S = w_shape[2:] + dy_shape = dy_shape[1:-1] + w_shape = w_shape[1:-1] else: raise ValueError("Unsupported CuDNN tensor format: '{}'".format(tensor_format)) @@ -746,7 +750,7 @@ def conv_backward_data( ) return te.extern( - x_shape, + dx_shape, [dy, w], lambda ins, outs: tvm.tir.call_packed( "tvm.contrib.cudnn.conv2d.backward_data", diff --git a/src/runtime/contrib/cudnn/conv_backward.cc b/src/runtime/contrib/cudnn/conv_backward.cc index 417183335a64..af190d7c8c90 100644 --- a/src/runtime/contrib/cudnn/conv_backward.cc +++ b/src/runtime/contrib/cudnn/conv_backward.cc @@ -239,7 +239,7 @@ TVM_REGISTER_GLOBAL("tvm.contrib.cudnn.conv2d.backward_filter") std::string conv_dtype = args[12]; int groups = args[13]; - ConvolutionBackwardFilter(mode, format, algo, 2, groups, pad_v, stride_v, dilation_v, dy, dx, + ConvolutionBackwardFilter(mode, format, algo, 2, groups, pad_v, stride_v, dilation_v, dy, x, dw, conv_dtype); }); From 367ad62bc4291e647f88cc378e5a1fa1d7abe19b Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Thu, 20 Jan 2022 05:19:01 +0900 Subject: [PATCH 14/21] support running relay conv2d_backward_weight directly with cudnn --- python/tvm/relay/op/nn/_nn.py | 3 ++ python/tvm/relay/op/strategy/cuda.py | 19 ++++++++++++ python/tvm/relay/op/strategy/generic.py | 38 +++++++++++++++++++++++ python/tvm/topi/cuda/conv2d.py | 17 ++++++++++ src/relay/op/nn/convolution.cc | 1 - tests/python/contrib/test_cudnn.py | 4 +-- tests/python/relay/test_op_grad_level2.py | 33 ++++++++++---------- 7 files changed, 95 insertions(+), 20 deletions(-) diff --git a/python/tvm/relay/op/nn/_nn.py b/python/tvm/relay/op/nn/_nn.py index 2a941cc8c28a..fc4f89fc02c0 100644 --- a/python/tvm/relay/op/nn/_nn.py +++ b/python/tvm/relay/op/nn/_nn.py @@ -1062,6 +1062,9 @@ def compute_space_to_depth(attrs, inputs, out_dtype): reg.register_injective_schedule("nn.batch_to_space_nd") +reg.register_strategy("nn.conv2d_backward_weight", strategy.conv2d_backward_weight_strategy) +reg.register_pattern("nn.conv2d_backward_weight", OpPattern.OUT_ELEMWISE_FUSABLE) + @reg.register_legalize("nn.conv2d_backward_weight") def legalize_conv2d_backward_weight(attrs, inputs, types): """Legalize conv2d_backward_weight op. diff --git a/python/tvm/relay/op/strategy/cuda.py b/python/tvm/relay/op/strategy/cuda.py index 69579f690c96..1020631ce2b2 100644 --- a/python/tvm/relay/op/strategy/cuda.py +++ b/python/tvm/relay/op/strategy/cuda.py @@ -564,6 +564,25 @@ def deformable_conv2d_strategy_cuda(attrs, inputs, out_type, target): return strategy +@conv2d_backward_weight_strategy.register(["cuda"]) +def conv2d_backward_weight_strategy_cuda(attrs, inputs, out_type, target): + """conv2d_backward_weight cuda strategy""" + strategy = _op.OpStrategy() + if target.kind.name == "cuda" and "cudnn" in target.libs: + strategy.add_implementation( + wrap_compute_conv2d_backward_weight(topi.cuda.conv2d_backward_weight), + wrap_topi_schedule(topi.generic.schedule_extern), + name="conv2d_backward_weight_strategy.cudnn", + plevel=15, + ) + else: + raise RuntimeError( + "conv2d_backward_weight on cuda is currently only supported with cudnn. " + "Please run Legalize pass to decompose this op into supported ops." 
+ ) + return strategy + + @conv2d_transpose_strategy.register(["cuda", "gpu"]) def conv2d_transpose_strategy_cuda(attrs, inputs, out_type, target): """conv2d_transpose cuda strategy""" diff --git a/python/tvm/relay/op/strategy/generic.py b/python/tvm/relay/op/strategy/generic.py index cc12fa127006..abd3e28bc3eb 100644 --- a/python/tvm/relay/op/strategy/generic.py +++ b/python/tvm/relay/op/strategy/generic.py @@ -1841,3 +1841,41 @@ def einsum_strategy(attrs, inputs, out_type, target): name="einsum.generic", ) return strategy + + +# conv2d_backward_weight +def wrap_compute_conv2d_backward_weight(topi_compute): + """wrap conv2d_backward_weight topi compute""" + + def _compute_conv2d_backward_weight(attrs, inputs, out_dtype): + kernel_size = get_const_tuple(attrs.kernel_size) + padding = get_const_tuple(attrs.padding) + strides = get_const_tuple(attrs.strides) + dilation = get_const_tuple(attrs.dilation) + groups = attrs.groups + out_dtype = attrs.out_dtype + layout = attrs.data_layout + out_dtype = inputs[0].dtype if out_dtype in ("same", "") else out_dtype + out = topi_compute( + inputs[0], + inputs[1], + kernel_size, + padding, + strides, + dilation, + groups, + layout, + out_dtype, + ) + return [out] + + return _compute_conv2d_backward_weight + + +@override_native_generic_func("conv2d_backward_weight_strategy") +def conv2d_backward_weight_strategy(attrs, inputs, out_type, target): + """wgrad generic strategy""" + raise RuntimeError( + "conv2d_backward_weight is currently only supported with cudnn. " + "Please run Legalize pass to decompose this op into supported ops." + ) diff --git a/python/tvm/topi/cuda/conv2d.py b/python/tvm/topi/cuda/conv2d.py index bd8d7ec19bb3..899d498caa44 100644 --- a/python/tvm/topi/cuda/conv2d.py +++ b/python/tvm/topi/cuda/conv2d.py @@ -123,3 +123,20 @@ def conv2d_cudnn( def schedule_conv2d_cudnn(cfg, outs): """Create the schedule for conv2d_cudnn""" return generic.schedule_extern(outs) + + +def conv2d_backward_weight(dy, x, kernel_size, padding, stride, dilation, groups, layout, output_dtype): + """Compute conv2d wgrad using CuDNN library""" + assert layout in ["NCHW", "NHWC"] + return cudnn.conv_backward_filter( + dy, + x, + kernel_size, + padding, + stride, + dilation, + conv_mode=1, + tensor_format=0 if layout == "NCHW" else 1, + conv_dtype=output_dtype, + groups=groups, + ) diff --git a/src/relay/op/nn/convolution.cc b/src/relay/op/nn/convolution.cc index f1d4eb3d87ea..30386bbf4415 100644 --- a/src/relay/op/nn/convolution.cc +++ b/src/relay/op/nn/convolution.cc @@ -665,7 +665,6 @@ given the original input data and the output gradient. 
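The net effect of registering this strategy is that relay.nn.conv2d_backward_weight can be built and executed directly on a cuDNN-enabled CUDA target instead of being legalized first. A minimal sketch of that path, following the relay test updated in this patch; the concrete shapes are illustrative assumptions:

# Editorial sketch: run nn.conv2d_backward_weight directly through cuDNN.
# Requires TVM built with CUDA + cuDNN; shapes below are assumed, not from the patch.
import numpy as np
import tvm
from tvm import relay

dy_shape, x_shape = (1, 8, 32, 32), (1, 4, 32, 32)  # NCHW
dy = relay.var("dy", shape=dy_shape, dtype="float32")
x = relay.var("x", shape=x_shape, dtype="float32")
dw = relay.nn.conv2d_backward_weight(
    dy, x, strides=(1, 1), padding=(1, 1), kernel_size=(3, 3)
)
func = relay.Function([dy, x], dw)

target = "cuda -libs=cudnn"  # without -libs=cudnn the generic strategy raises
dev = tvm.device(target, 0)
dy_np = np.random.randn(*dy_shape).astype("float32")
x_np = np.random.randn(*x_shape).astype("float32")
dw_np = (
    relay.create_executor(device=dev, target=target).evaluate(func)(dy_np, x_np).numpy()
)  # dw_np has shape (8, 4, 3, 3), i.e. OIHW
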
.add_argument("data", "Tensor", "The input tensor.") .set_support_level(2) .add_type_rel("Conv2DBackwardWeight", Conv2DBackwardWeightRel) - .set_attr("TNonComputational", true) .set_attr("FInferCorrectLayout", ConvInferCorrectLayout); } // namespace relay diff --git a/tests/python/contrib/test_cudnn.py b/tests/python/contrib/test_cudnn.py index d744cddfcb5e..0b96d571434a 100644 --- a/tests/python/contrib/test_cudnn.py +++ b/tests/python/contrib/test_cudnn.py @@ -514,6 +514,4 @@ def conv_output_shape_kwargs(request): if __name__ == "__main__": - # sys.exit(pytest.main(sys.argv)) - test_conv2d_backward_data() - test_conv2d_backward_filter() + sys.exit(pytest.main(sys.argv)) diff --git a/tests/python/relay/test_op_grad_level2.py b/tests/python/relay/test_op_grad_level2.py index 1efdb262245f..498ca1105a3b 100644 --- a/tests/python/relay/test_op_grad_level2.py +++ b/tests/python/relay/test_op_grad_level2.py @@ -233,27 +233,28 @@ def verify_conv2d_backward_weight(dy_shape, x_shape, kernel_size, stride, paddin dtype = "float32" dy = relay.var("dy", shape=dy_shape, dtype=dtype) x = relay.var("x", shape=x_shape, dtype=dtype) - dw = relay.nn.conv2d_backward_weight( - dy, x, strides=stride, padding=padding, kernel_size=kernel_size + dw_func = relay.Function( + [dy, x], + relay.nn.conv2d_backward_weight( + dy, x, strides=stride, padding=padding, kernel_size=kernel_size + ), ) - dw_func = relay.Function([dy, x], dw) dw_func_legalized = run_opt_pass(dw_func, relay.transform.Legalize()) - target = "llvm" - dev = tvm.device(target, 0) - dy_np = np.random.randn(*dy_shape).astype(dtype) - x_np = np.random.randn(*x_shape).astype(dtype) + for dw, target in [(dw_func_legalized, "llvm"), (dw_func, "cuda -libs=cudnn")]: + if "cudnn" in target and not tvm.contrib.cudnn.exists(): + continue - dw_np = ( - relay.create_executor(device=dev, target=target) - .evaluate(dw_func_legalized)(dy_np, x_np) - .numpy() - ) - ref_dw_np = tvm.topi.testing.conv2d_backward_weight_nchw_python( - dy_np, x_np, kernel_size, stride, padding - ) + dev = tvm.device(target, 0) + dy_np = np.random.randn(*dy_shape).astype(dtype) + x_np = np.random.randn(*x_shape).astype(dtype) + + dw_np = relay.create_executor(device=dev, target=target).evaluate(dw)(dy_np, x_np).numpy() + ref_dw_np = tvm.topi.testing.conv2d_backward_weight_nchw_python( + dy_np, x_np, kernel_size, stride, padding + ) - np.testing.assert_allclose(dw_np, ref_dw_np, rtol=1e-4, atol=1e-4) + np.testing.assert_allclose(dw_np, ref_dw_np, rtol=1e-4, atol=1e-4) def test_conv2d_backward_weight(): From 09b32ba91b5191152422866056fa9edbb079a4c8 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Thu, 20 Jan 2022 05:20:09 +0900 Subject: [PATCH 15/21] black --- python/tvm/relay/op/nn/_nn.py | 1 + python/tvm/topi/cuda/conv2d.py | 4 +++- 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/python/tvm/relay/op/nn/_nn.py b/python/tvm/relay/op/nn/_nn.py index fc4f89fc02c0..1fa909e748a0 100644 --- a/python/tvm/relay/op/nn/_nn.py +++ b/python/tvm/relay/op/nn/_nn.py @@ -1065,6 +1065,7 @@ def compute_space_to_depth(attrs, inputs, out_dtype): reg.register_strategy("nn.conv2d_backward_weight", strategy.conv2d_backward_weight_strategy) reg.register_pattern("nn.conv2d_backward_weight", OpPattern.OUT_ELEMWISE_FUSABLE) + @reg.register_legalize("nn.conv2d_backward_weight") def legalize_conv2d_backward_weight(attrs, inputs, types): """Legalize conv2d_backward_weight op. 
diff --git a/python/tvm/topi/cuda/conv2d.py b/python/tvm/topi/cuda/conv2d.py index 899d498caa44..b7368ab8454e 100644 --- a/python/tvm/topi/cuda/conv2d.py +++ b/python/tvm/topi/cuda/conv2d.py @@ -125,7 +125,9 @@ def schedule_conv2d_cudnn(cfg, outs): return generic.schedule_extern(outs) -def conv2d_backward_weight(dy, x, kernel_size, padding, stride, dilation, groups, layout, output_dtype): +def conv2d_backward_weight( + dy, x, kernel_size, padding, stride, dilation, groups, layout, output_dtype +): """Compute conv2d wgrad using CuDNN library""" assert layout in ["NCHW", "NHWC"] return cudnn.conv_backward_filter( From 64b42798965c0b9bebe005c86f3135ba9bbeff98 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Thu, 20 Jan 2022 05:31:45 +0900 Subject: [PATCH 16/21] refactor reference function to support nhwc --- python/tvm/topi/testing/__init__.py | 2 +- .../testing/conv2d_backcward_weight_python.py | 44 ++++++++++++++++++- tests/python/contrib/test_cudnn.py | 2 +- tests/python/relay/test_op_grad_level2.py | 2 +- 4 files changed, 46 insertions(+), 4 deletions(-) diff --git a/python/tvm/topi/testing/__init__.py b/python/tvm/topi/testing/__init__.py index 75eabffc957a..c3d222cfd120 100644 --- a/python/tvm/topi/testing/__init__.py +++ b/python/tvm/topi/testing/__init__.py @@ -75,4 +75,4 @@ from .nll_loss import nll_loss from .dense import dense from .searchsorted import searchsorted_ref -from .conv2d_backcward_weight_python import conv2d_backward_weight_nchw_python +from .conv2d_backcward_weight_python import conv2d_backward_weight_python diff --git a/python/tvm/topi/testing/conv2d_backcward_weight_python.py b/python/tvm/topi/testing/conv2d_backcward_weight_python.py index 587cd45b49c1..36a6b0616053 100644 --- a/python/tvm/topi/testing/conv2d_backcward_weight_python.py +++ b/python/tvm/topi/testing/conv2d_backcward_weight_python.py @@ -42,7 +42,7 @@ def conv2d_backward_weight_nchw_python(dy_np, x_np, kernel_size, stride, padding Returns ------- - b_np : np.ndarray + dw_np : np.ndarray 4-D with shape [num_filter, in_channel, filter_height, filter_width] """ @@ -74,3 +74,45 @@ def conv2d_backward_weight_nchw_python(dy_np, x_np, kernel_size, stride, padding dw[k, c, r, s] = acc return dw + + +def conv2d_backward_weight_python(dy_np, x_np, kernel_size, stride, padding, layout="NCHW"): + """Gradient of the conv2d op with respect to weight, in NCHW or NHWC layout. + + Parameters + ---------- + dy_np : numpy.ndarray + 4-D with shape [batch, in_channel, out_height, out_width] for NCHW layout + + x_np : numpy.ndarray + 4-D with shape [batch, in_channel, in_height, in_width] for NCHW layout + + kernel_size : tuple of two ints + Height and width of the weight + + stride : tuple of two ints + Stride size, or [stride_height, stride_width] + + padding : tuple of two ints + Spatial padding, or [pad_h, pad_w] + + layout: string + Layout of dy_np and x_np + + Returns + ------- + dw_np : np.ndarray + Tensor of shape [num_filter, in_channel, filter_height, filter_width] for NCHW layout, + [num_filter, filter_height, filter_width, in_channel] for NHWC layout. 
+ """ + if layout == "NCHW": + return conv2d_backward_weight_nchw_python(dy_np, x_np, kernel_size, stride, padding) + + dw_np_oihw = conv2d_backward_weight_nchw_python( + np.transpose(dy_np, [0, 3, 1, 2]), + np.transpose(x_np, [0, 3, 1, 2]), + kernel_size, + stride, + padding, + ) + return np.transpose(dw_np_oihw, [0, 2, 3, 1]) diff --git a/tests/python/contrib/test_cudnn.py b/tests/python/contrib/test_cudnn.py index 0b96d571434a..af3f8c0c9d06 100644 --- a/tests/python/contrib/test_cudnn.py +++ b/tests/python/contrib/test_cudnn.py @@ -397,7 +397,7 @@ def verify_conv2d_backward_filter(data_dtype, conv_dtype, tensor_format=0, tol=1 x_np = np.random.uniform(-1, 1, x_shape).astype(data_dtype) dy_np = np.random.uniform(-1, 1, dy_shape).astype(data_dtype) - dw_np = conv2d_backward_weight_python( + dw_np = tvm.topi.testing.conv2d_backward_weight_python( dy_np, x_np, (filter_h, filter_w), diff --git a/tests/python/relay/test_op_grad_level2.py b/tests/python/relay/test_op_grad_level2.py index 498ca1105a3b..a5fc630f61dc 100644 --- a/tests/python/relay/test_op_grad_level2.py +++ b/tests/python/relay/test_op_grad_level2.py @@ -250,7 +250,7 @@ def verify_conv2d_backward_weight(dy_shape, x_shape, kernel_size, stride, paddin x_np = np.random.randn(*x_shape).astype(dtype) dw_np = relay.create_executor(device=dev, target=target).evaluate(dw)(dy_np, x_np).numpy() - ref_dw_np = tvm.topi.testing.conv2d_backward_weight_nchw_python( + ref_dw_np = tvm.topi.testing.conv2d_backward_weight_python( dy_np, x_np, kernel_size, stride, padding ) From 06870e446b45166a28890dc33b40dbd5b5e6ce42 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Thu, 20 Jan 2022 05:37:17 +0900 Subject: [PATCH 17/21] removed unused function --- tests/python/contrib/test_cudnn.py | 64 ------------------------------ 1 file changed, 64 deletions(-) diff --git a/tests/python/contrib/test_cudnn.py b/tests/python/contrib/test_cudnn.py index af3f8c0c9d06..d45624c3ab33 100644 --- a/tests/python/contrib/test_cudnn.py +++ b/tests/python/contrib/test_cudnn.py @@ -314,70 +314,6 @@ def test_conv2d_backward_data(): verify_conv2d_backward_data("float16", "float16", tensor_format=1, tol=1e-1) -def conv2d_backward_weight_python(dy_np, x_np, kernel_size, stride, padding, layout="NCHW"): - R, S = kernel_size - if layout == "NCHW": - N, C, H, W = x_np.shape - _, K, P, Q = dy_np.shape - w_shape = (K, C, R, S) - else: - N, H, W, C = x_np.shape - _, P, Q, K = dy_np.shape - w_shape = (K, R, S, C) - - pad_h, pad_w = padding - stride_h, stride_w = stride - - dw = np.zeros(w_shape).astype(dy_np.dtype) - - for k in range(K): - for r in range(R): - for s in range(S): - for c in range(C): - acc = 0 - for n in range(N): - for p in range(P): - for q in range(Q): - - if layout == "NCHW": - coord = ( - n, - c, - p * stride_h - pad_h + r, - q * stride_w - pad_w + s, - ) - - if ( - coord[2] < H - and coord[2] >= 0 - and coord[3] < W - and coord[3] >= 0 - ): - acc += dy_np[n, k, p, q] * x_np[coord] - else: - coord = ( - n, - p * stride_h - pad_h + r, - q * stride_w - pad_w + s, - c, - ) - - if ( - coord[1] < H - and coord[1] >= 0 - and coord[2] < W - and coord[2] >= 0 - ): - acc += dy_np[n, p, q, k] * x_np[coord] - - if layout == "NCHW": - dw[k, c, r, s] = acc - else: - dw[k, r, s, c] = acc - - return dw - - def verify_conv2d_backward_filter(data_dtype, conv_dtype, tensor_format=0, tol=1e-5): batch = 3 in_channel = 4 From 615b90d41312ea2b04ddc2ef29a525108210b701 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Thu, 20 Jan 2022 11:11:37 +0900 Subject: [PATCH 18/21] lint --- 
python/tvm/contrib/cudnn.py | 14 ++------------ 1 file changed, 2 insertions(+), 12 deletions(-) diff --git a/python/tvm/contrib/cudnn.py b/python/tvm/contrib/cudnn.py index 96001f721cb2..4ca90ccb574d 100644 --- a/python/tvm/contrib/cudnn.py +++ b/python/tvm/contrib/cudnn.py @@ -258,9 +258,7 @@ def conv_output_shape( return output -def conv_dgrad_shape( - tensor_format, pad, stride, dilation, dy_shape, w_shape, data_dtype, conv_dtype, groups=1 -): +def conv_dgrad_shape(tensor_format, pad, stride, dilation, dy_shape, w_shape): """Get output shape of conv2d gradient with respect to data Paramters @@ -296,18 +294,12 @@ def conv_dgrad_shape( if tensor_format == 0: N = dy_shape[0] - K = w_shape[0] C = w_shape[1] - P, Q = dy_shape[2:] - R, S = w_shape[2:] dy_shape = dy_shape[2:] w_shape = w_shape[2:] elif tensor_format == 1: N = dy_shape[0] - K = w_shape[0] C = w_shape[-1] - P, Q = dy_shape[2:] - R, S = w_shape[2:] dy_shape = dy_shape[1:-1] w_shape = w_shape[1:-1] else: @@ -732,9 +724,7 @@ def conv_backward_data( dy.shape[0], tvm.tir.expr.IntImm ), "Dynamic batch is not supported for cudnn conv2d backwad data yet." - dx_shape = conv_dgrad_shape( - tensor_format, pad, stride, dilation, dy.shape, w.shape, dy.dtype, conv_dtype, groups - ) + dx_shape = conv_dgrad_shape(tensor_format, pad, stride, dilation, dy.shape, w.shape) algo = conv_backward_data_find_algo( tensor_format, From 5b5992ca5cf969179d501f733b98d49acd9fbe8b Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Thu, 20 Jan 2022 13:39:20 +0900 Subject: [PATCH 19/21] enable offloading conv2d_transpose to cudnn dgrad --- python/tvm/contrib/cudnn.py | 27 ++++++++++++++----- python/tvm/relay/op/strategy/cuda.py | 9 +++++++ python/tvm/topi/cuda/conv2d_transpose_nchw.py | 8 ++++++ python/tvm/topi/nn/conv2d_transpose.py | 1 - tests/python/relay/test_op_level2.py | 17 ++++++++---- 5 files changed, 50 insertions(+), 12 deletions(-) diff --git a/python/tvm/contrib/cudnn.py b/python/tvm/contrib/cudnn.py index 4ca90ccb574d..c897de74b250 100644 --- a/python/tvm/contrib/cudnn.py +++ b/python/tvm/contrib/cudnn.py @@ -258,7 +258,9 @@ def conv_output_shape( return output -def conv_dgrad_shape(tensor_format, pad, stride, dilation, dy_shape, w_shape): +def conv_dgrad_shape( + tensor_format, pad, stride, dilation, dy_shape, w_shape, output_padding=(0, 0) +): """Get output shape of conv2d gradient with respect to data Paramters @@ -306,10 +308,12 @@ def conv_dgrad_shape(tensor_format, pad, stride, dilation, dy_shape, w_shape): raise ValueError("Unsupported CuDNN tensor format: '{}'".format(tensor_format)) input_dims = [] - for dy_shape_i, w_shape_i, pad_i, stride_i, dilation_i in zip( - dy_shape, w_shape, pad, stride, dilation + for dy_shape_i, w_shape_i, pad_i, stride_i, dilation_i, out_pad in zip( + dy_shape, w_shape, pad, stride, dilation, output_padding ): - input_dim = (dy_shape_i - 1) * stride_i - 2 * pad_i + (((w_shape_i - 1) * dilation_i) + 1) + input_dim = ( + (dy_shape_i - 1) * stride_i - 2 * pad_i + (((w_shape_i - 1) * dilation_i) + 1) + out_pad + ) input_dims.append(input_dim) if tensor_format == 0: @@ -682,7 +686,16 @@ def conv_forward(x, w, pad, stride, dilation, conv_mode, tensor_format, algo, co def conv_backward_data( - dy, w, pad, stride, dilation, conv_mode, tensor_format, conv_dtype, groups=1 + dy, + w, + pad, + stride, + dilation, + conv_mode, + tensor_format, + conv_dtype, + groups=1, + output_padding=(0, 0), ): """Create a CuDNN extern op that computes the gradient of 2D convolution with respect to data. 
@@ -724,7 +737,9 @@ def conv_backward_data( dy.shape[0], tvm.tir.expr.IntImm ), "Dynamic batch is not supported for cudnn conv2d backwad data yet." - dx_shape = conv_dgrad_shape(tensor_format, pad, stride, dilation, dy.shape, w.shape) + dx_shape = conv_dgrad_shape( + tensor_format, pad, stride, dilation, dy.shape, w.shape, output_padding + ) algo = conv_backward_data_find_algo( tensor_format, diff --git a/python/tvm/relay/op/strategy/cuda.py b/python/tvm/relay/op/strategy/cuda.py index 1020631ce2b2..017b3edf81ec 100644 --- a/python/tvm/relay/op/strategy/cuda.py +++ b/python/tvm/relay/op/strategy/cuda.py @@ -598,6 +598,15 @@ def conv2d_transpose_strategy_cuda(attrs, inputs, out_type, target): wrap_topi_schedule(topi.cuda.schedule_conv2d_transpose_nchw), name="conv2d_transpose_nchw.cuda", ) + + if target.kind.name == "cuda" and "cudnn" in target.libs and attrs.kernel_layout == "IOHW": + strategy.add_implementation( + wrap_compute_conv2d_transpose(topi.cuda.conv2d_transpose_cudnn), + wrap_topi_schedule(topi.generic.schedule_extern), + name="conv2d_transpose.cudnn.cuda", + plevel=25, + ) + # TODO(masahi): Support conv2d_transpose NHWC. return strategy diff --git a/python/tvm/topi/cuda/conv2d_transpose_nchw.py b/python/tvm/topi/cuda/conv2d_transpose_nchw.py index 3b704170a2e9..36ce3a3d2454 100644 --- a/python/tvm/topi/cuda/conv2d_transpose_nchw.py +++ b/python/tvm/topi/cuda/conv2d_transpose_nchw.py @@ -19,6 +19,7 @@ import tvm from tvm import te +from tvm.contrib import cudnn from tvm import autotvm from tvm.autotvm.task.space import SplitEntity, OtherOptionEntity from .. import nn @@ -286,3 +287,10 @@ def _callback(op): traverse_inline(s, outs[0].op, _callback) return s + + +def conv2d_transpose_cudnn(x, w, stride, padding, out_dtype, output_padding=(0, 0)): + """Compute conv2d_tranpose using cudnn dgrad kernel""" + return cudnn.conv_backward_data( + x, w, padding, stride, (1, 1), 1, 0, out_dtype, groups=1, output_padding=output_padding + ) diff --git a/python/tvm/topi/nn/conv2d_transpose.py b/python/tvm/topi/nn/conv2d_transpose.py index 2871699350ed..c408095eb7ab 100644 --- a/python/tvm/topi/nn/conv2d_transpose.py +++ b/python/tvm/topi/nn/conv2d_transpose.py @@ -298,7 +298,6 @@ def conv2d_transpose_legalize(attrs, inputs, types): result : tvm.relay.Expr The legalized expr """ - data, kernel = inputs kernel_layout = attrs["kernel_layout"] if attrs["data_layout"] == "NHWC": diff --git a/tests/python/relay/test_op_level2.py b/tests/python/relay/test_op_level2.py index db712be4262e..6d428bfde21b 100644 --- a/tests/python/relay/test_op_level2.py +++ b/tests/python/relay/test_op_level2.py @@ -24,7 +24,7 @@ import tvm.testing import tvm.topi.testing from tvm import autotvm, relay, te -from tvm.contrib import utils +from tvm.contrib import utils, cudnn from tvm.ir.module import IRModule from tvm.relay import transform from tvm.relay.testing import run_infer_type @@ -838,10 +838,10 @@ def test_conv2d_transpose_infer_type(): @tvm.testing.uses_gpu def test_conv2d_transpose_nchw_run(): k_layouts = {"OIHW": (10, 3, 3, 3), "IOHW": (3, 10, 3, 3)} + output_padding = (1, 1) for k_layout, kshape in k_layouts.items(): dshape = (1, 3, 18, 18) - oshape = (1, 10, 36, 36) x = relay.var("x", shape=dshape) w = relay.var("w") y = relay.nn.conv2d_transpose( @@ -851,7 +851,7 @@ def test_conv2d_transpose_nchw_run(): kernel_size=(3, 3), strides=(2, 2), padding=(1, 1), - output_padding=(1, 1), + output_padding=output_padding, kernel_layout=k_layout, data_layout="NCHW", ) @@ -866,9 +866,16 @@ def 
test_conv2d_transpose_nchw_run(): else: kernel_iohw = kernel - ref_res = tvm.topi.testing.conv2d_transpose_nchw_python(data, kernel_iohw, 2, 1, (1, 1)) + ref_res = tvm.topi.testing.conv2d_transpose_nchw_python( + data, kernel_iohw, 2, 1, output_padding + ) - for target, dev in tvm.testing.enabled_targets(): + enabled_targets = tvm.testing.enabled_targets() + + if cudnn.exists() and k_layout == "IOHW": + enabled_targets.append(("cuda -libs=cudnn", tvm.cuda(0))) + + for target, dev in enabled_targets: op_res1 = relay.create_executor("graph", device=dev, target=target).evaluate(func)( data, kernel ) From c59567694ee2d5cbe811511557741a3ac0480b67 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Fri, 21 Jan 2022 09:06:23 +0900 Subject: [PATCH 20/21] relax tol --- tests/python/contrib/test_cudnn.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/python/contrib/test_cudnn.py b/tests/python/contrib/test_cudnn.py index d45624c3ab33..718f3878270b 100644 --- a/tests/python/contrib/test_cudnn.py +++ b/tests/python/contrib/test_cudnn.py @@ -374,8 +374,8 @@ def verify_conv2d_backward_filter(data_dtype, conv_dtype, tensor_format=0, tol=1 @tvm.testing.requires_gpu @requires_cudnn def test_conv2d_backward_filter(): - verify_conv2d_backward_filter("float32", "float32", tensor_format=0, tol=1e-5) - verify_conv2d_backward_filter("float32", "float32", tensor_format=1, tol=1e-5) + verify_conv2d_backward_filter("float32", "float32", tensor_format=0, tol=1e-4) + verify_conv2d_backward_filter("float32", "float32", tensor_format=1, tol=1e-4) test_kwargs_default_2d = { From 7ed295c9863219997566ed66a1bc41ece858ac16 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Sat, 22 Jan 2022 04:40:58 +0900 Subject: [PATCH 21/21] name fix, remove print --- python/tvm/relay/op/strategy/cuda.py | 2 +- python/tvm/topi/cuda/conv2d.py | 2 +- tests/python/contrib/test_cudnn.py | 4 ---- 3 files changed, 2 insertions(+), 6 deletions(-) diff --git a/python/tvm/relay/op/strategy/cuda.py b/python/tvm/relay/op/strategy/cuda.py index 017b3edf81ec..af7451408d27 100644 --- a/python/tvm/relay/op/strategy/cuda.py +++ b/python/tvm/relay/op/strategy/cuda.py @@ -570,7 +570,7 @@ def conv2d_backward_weight_strategy_cuda(attrs, inputs, out_type, target): strategy = _op.OpStrategy() if target.kind.name == "cuda" and "cudnn" in target.libs: strategy.add_implementation( - wrap_compute_conv2d_backward_weight(topi.cuda.conv2d_backward_weight), + wrap_compute_conv2d_backward_weight(topi.cuda.conv2d_backward_weight_cudnn), wrap_topi_schedule(topi.generic.schedule_extern), name="conv2d_backward_weight_strategy.cudnn", plevel=15, diff --git a/python/tvm/topi/cuda/conv2d.py b/python/tvm/topi/cuda/conv2d.py index b7368ab8454e..15fcaaa02134 100644 --- a/python/tvm/topi/cuda/conv2d.py +++ b/python/tvm/topi/cuda/conv2d.py @@ -125,7 +125,7 @@ def schedule_conv2d_cudnn(cfg, outs): return generic.schedule_extern(outs) -def conv2d_backward_weight( +def conv2d_backward_weight_cudnn( dy, x, kernel_size, padding, stride, dilation, groups, layout, output_dtype ): """Compute conv2d wgrad using CuDNN library""" diff --git a/tests/python/contrib/test_cudnn.py b/tests/python/contrib/test_cudnn.py index 718f3878270b..0c39a1a2428d 100644 --- a/tests/python/contrib/test_cudnn.py +++ b/tests/python/contrib/test_cudnn.py @@ -299,8 +299,6 @@ def verify_conv2d_backward_data(data_dtype, conv_dtype, tensor_format=0, tol=1e- dx = tvm.nd.array(dx_np, dev) f(dy, w, dx) - print(np.max(np.abs(dx.numpy() - dx_np))) - print(np.mean(np.abs(dx.numpy() - 
dx_np))) tvm.testing.assert_allclose(dx.numpy(), dx_np, atol=tol, rtol=tol) @@ -366,8 +364,6 @@ def verify_conv2d_backward_filter(data_dtype, conv_dtype, tensor_format=0, tol=1 dw = tvm.nd.array(dw_np, dev) f(dy, x, dw) - print(np.max(np.abs(dw.numpy() - dw_np))) - print(np.mean(np.abs(dw.numpy() - dw_np))) tvm.testing.assert_allclose(dw.numpy(), dw_np, atol=tol, rtol=tol)
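Taken together, the later patches make the cuDNN dgrad kernel usable both for the explicit gradient op and for forward conv2d_transpose when the kernel layout is IOHW. A closing sketch of the conv2d_transpose path, mirroring the relay test updated above; shapes and attributes are the test's values, the rest is illustrative:

# Editorial sketch: offload conv2d_transpose to the cuDNN dgrad kernel.
# Requires TVM built with CUDA + cuDNN; only NCHW data with IOHW kernels is offloaded.
import numpy as np
import tvm
from tvm import relay

dshape, kshape = (1, 3, 18, 18), (3, 10, 3, 3)  # NCHW data, IOHW kernel
x = relay.var("x", shape=dshape, dtype="float32")
w = relay.var("w", shape=kshape, dtype="float32")
y = relay.nn.conv2d_transpose(
    x, w,
    channels=10,
    kernel_size=(3, 3),
    strides=(2, 2),
    padding=(1, 1),
    output_padding=(1, 1),
    data_layout="NCHW",
    kernel_layout="IOHW",
)
func = relay.Function([x, w], y)

target = "cuda -libs=cudnn"
dev = tvm.device(target, 0)
data = np.random.uniform(size=dshape).astype("float32")
kernel = np.random.uniform(size=kshape).astype("float32")
out = relay.create_executor("graph", device=dev, target=target).evaluate(func)(data, kernel)
# Expected spatial size per conv_dgrad_shape: (18 - 1) * 2 - 2 * 1 + 3 + 1 = 36, so (1, 10, 36, 36).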