Skip to content

Commit

Permalink
dgrad compiled
Browse files Browse the repository at this point in the history
  • Loading branch information
masahi committed Jan 17, 2022
1 parent 47f35be commit e311ba3
Show file tree
Hide file tree
Showing 2 changed files with 47 additions and 45 deletions.
88 changes: 44 additions & 44 deletions src/runtime/contrib/cudnn/conv_backward.cc
Original file line number Diff line number Diff line change
Expand Up @@ -32,70 +32,70 @@ namespace contrib {
using namespace runtime;

void ConvolutionBackwardData(int mode, int format, int algo, int dims, int groups, const int pad[],
const int stride[], const int dilation[], DLTensor* x, DLTensor* w,
DLTensor* y, const std::string& conv_dtype) {
const int stride[], const int dilation[], DLTensor* dy, DLTensor* w,
DLTensor* dx, const std::string& conv_dtype) {
CuDNNThreadEntry* entry_ptr = CuDNNThreadEntry::ThreadLocal();
// Set Mode
entry_ptr->conv_entry.mode = static_cast<cudnnConvolutionMode_t>(mode);
SetConvDescriptors(entry_ptr, format, dims, groups, pad, stride, dilation, x->shape, w->shape,
y->shape, x->dtype, conv_dtype);
SetConvDescriptors(entry_ptr, format, dims, groups, pad, stride, dilation, dy->shape, w->shape,
dx->shape, dy->dtype, conv_dtype);
// Set Device
entry_ptr->conv_entry.device = x->device;
entry_ptr->conv_entry.device = dy->device;
// Set Algo
entry_ptr->conv_entry.fwd_algo = static_cast<cudnnConvolutionFwdAlgo_t>(algo);
entry_ptr->conv_entry.bwd_data_algo = static_cast<cudnnConvolutionBwdDataAlgo_t>(algo);

// Set workspace
size_t workspace_size = 0;
CUDNN_CALL(cudnnGetConvolutionForwardWorkspaceSize(
entry_ptr->handle, entry_ptr->conv_entry.input_desc, entry_ptr->conv_entry.filter_desc,
entry_ptr->conv_entry.conv_desc, entry_ptr->conv_entry.output_desc,
entry_ptr->conv_entry.fwd_algo, &workspace_size));
CUDNN_CALL(cudnnGetConvolutionBackwardDataWorkspaceSize(
entry_ptr->handle, entry_ptr->conv_entry.filter_desc, entry_ptr->conv_entry.output_desc,
entry_ptr->conv_entry.conv_desc, entry_ptr->conv_entry.input_desc,
entry_ptr->conv_entry.bwd_data_algo, &workspace_size));
entry_ptr->conv_entry.UpdateWorkspace(workspace_size);
CUDNN_CALL(cudnnConvolutionForward(
CUDNN_CALL(cudnnConvolutionBackwardData(
entry_ptr->handle, CuDNNDataType::GetConst<1>(entry_ptr->conv_entry.data_type),
entry_ptr->conv_entry.input_desc, x->data, entry_ptr->conv_entry.filter_desc, w->data,
entry_ptr->conv_entry.conv_desc, entry_ptr->conv_entry.fwd_algo,
entry_ptr->conv_entry.filter_desc, w->data, entry_ptr->conv_entry.output_desc, dy->data,
entry_ptr->conv_entry.conv_desc, entry_ptr->conv_entry.bwd_data_algo,
entry_ptr->conv_entry.workspace, workspace_size,
CuDNNDataType::GetConst<0>(entry_ptr->conv_entry.data_type),
entry_ptr->conv_entry.output_desc, y->data));
CuDNNDataType::GetConst<0>(entry_ptr->conv_entry.data_type), entry_ptr->conv_entry.input_desc,
dx->data));
}

void BackwardDataFindAlgo(int format, int dims, int groups, const int pad[], const int stride[],
const int dilation[], const int x_dim[], const int w_dim[],
const int y_dim[], const std::string& data_dtype,
const int dilation[], const int dy_dim[], const int w_dim[],
const int dx_dim[], const std::string& data_dtype,
const std::string& conv_dtype, TVMRetValue* ret) {
CuDNNThreadEntry* entry_ptr = CuDNNThreadEntry::ThreadLocal();
const int full_dims = dims + 2;
std::vector<int64_t> x_dim_int64(full_dims);
std::vector<int64_t> dy_dim_int64(full_dims);
std::vector<int64_t> w_dim_int64(full_dims);
std::vector<int64_t> y_dim_int64(full_dims);
std::vector<int64_t> dx_dim_int64(full_dims);
for (int i = 0; i < full_dims; ++i) {
x_dim_int64[i] = x_dim[i];
dy_dim_int64[i] = dy_dim[i];
w_dim_int64[i] = w_dim[i];
y_dim_int64[i] = y_dim[i];
dx_dim_int64[i] = dx_dim[i];
}
SetConvDescriptors(entry_ptr, format, dims, groups, pad, stride, dilation, x_dim_int64.data(),
w_dim_int64.data(), y_dim_int64.data(), String2DLDataType(data_dtype),
SetConvDescriptors(entry_ptr, format, dims, groups, pad, stride, dilation, dy_dim_int64.data(),
w_dim_int64.data(), dx_dim_int64.data(), String2DLDataType(data_dtype),
conv_dtype);

int returned_algo_count = 0;
cudnnConvolutionFwdAlgoPerf_t perf_results[CUDNN_CONVOLUTION_FWD_ALGO_COUNT];
CUDNN_CALL(cudnnFindConvolutionForwardAlgorithm(
entry_ptr->handle, entry_ptr->conv_entry.input_desc, entry_ptr->conv_entry.filter_desc,
entry_ptr->conv_entry.conv_desc, entry_ptr->conv_entry.output_desc,
CUDNN_CONVOLUTION_FWD_ALGO_COUNT, &returned_algo_count, perf_results));

const std::vector<std::string> fwd_algo_names{"CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_GEMM",
"CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM",
"CUDNN_CONVOLUTION_FWD_ALGO_GEMM",
"CUDNN_CONVOLUTION_FWD_ALGO_DIRECT",
"CUDNN_CONVOLUTION_FWD_ALGO_FFT",
"CUDNN_CONVOLUTION_FWD_ALGO_FFT_TILING",
"CUDNN_CONVOLUTION_FWD_ALGO_WINOGRAD",
"CUDNN_CONVOLUTION_FWD_ALGO_WINOGRAD_NONFUSED"};

cudnnConvolutionBwdDataAlgoPerf_t perf_results[CUDNN_CONVOLUTION_BWD_DATA_ALGO_COUNT];
CUDNN_CALL(cudnnFindConvolutionBackwardDataAlgorithm(
entry_ptr->handle, entry_ptr->conv_entry.filter_desc, entry_ptr->conv_entry.output_desc,
entry_ptr->conv_entry.conv_desc, entry_ptr->conv_entry.input_desc,
CUDNN_CONVOLUTION_BWD_DATA_ALGO_COUNT, &returned_algo_count, perf_results));

const std::vector<std::string> fwd_algo_names{
"CUDNN_CONVOLUTION_BWD_DATA_ALGO_0",
"CUDNN_CONVOLUTION_BWD_DATA_ALGO_1",
"CUDNN_CONVOLUTION_BWD_DATA_ALGO_FFT",
"CUDNN_CONVOLUTION_BWD_DATA_ALGO_FFT_TILING",
"CUDNN_CONVOLUTION_BWD_DATA_ALGO_WINOGRAD",
"CUDNN_CONVOLUTION_BWD_DATA_ALGO_WINOGRAD_NONFUSED"};

auto best_algo = perf_results[0].algo;
LOG(INFO) << "\tCUDNN Found " << returned_algo_count << " fwd algorithms, choosing "
LOG(INFO) << "\tCUDNN Found " << returned_algo_count << " bwd data algorithms, choosing "
<< fwd_algo_names[best_algo];
for (int i = 0; i < returned_algo_count; ++i) {
LOG(INFO) << "\t\t" << i << ") " << fwd_algo_names[perf_results[i].algo]
Expand All @@ -117,13 +117,13 @@ TVM_REGISTER_GLOBAL("tvm.contrib.cudnn.conv2d.backward_data")
stride_v[i] = args[5 + i];
dilation_v[i] = args[7 + i];
}
DLTensor* x = args[9];
DLTensor* dy = args[9];
DLTensor* w = args[10];
DLTensor* y = args[11];
DLTensor* dx = args[11];
std::string conv_dtype = args[12];
int groups = args[13];

ConvolutionBackwardData(mode, format, algo, 2, groups, pad_v, stride_v, dilation_v, x, w, y,
ConvolutionBackwardData(mode, format, algo, 2, groups, pad_v, stride_v, dilation_v, dy, w, dx,
conv_dtype);
});

Expand All @@ -134,14 +134,14 @@ TVM_REGISTER_GLOBAL("tvm.contrib.cudnn.conv.backward_data_find_algo")
int* pad = static_cast<int*>(static_cast<void*>(args[2]));
int* stride = static_cast<int*>(static_cast<void*>(args[3]));
int* dilation = static_cast<int*>(static_cast<void*>(args[4]));
int* x_dim = static_cast<int*>(static_cast<void*>(args[5]));
int* dy_dim = static_cast<int*>(static_cast<void*>(args[5]));
int* w_dim = static_cast<int*>(static_cast<void*>(args[6]));
int* y_dim = static_cast<int*>(static_cast<void*>(args[7]));
int* dx_dim = static_cast<int*>(static_cast<void*>(args[7]));
std::string data_dtype = args[8];
std::string conv_dtype = args[9];
int groups = args[10];

BackwardDataFindAlgo(format, dims, groups, pad, stride, dilation, x_dim, w_dim, y_dim,
BackwardDataFindAlgo(format, dims, groups, pad, stride, dilation, dy_dim, w_dim, dx_dim,
data_dtype, conv_dtype, ret);
});

Expand Down
4 changes: 3 additions & 1 deletion src/runtime/contrib/cudnn/cudnn_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -67,12 +67,14 @@ inline void GetCudnnStride(int nbdim, const int* dims, int* strides) {
struct ConvEntry {
cudnnConvolutionDescriptor_t conv_desc;
cudnnConvolutionMode_t mode{CUDNN_CROSS_CORRELATION};
cudnnFilterDescriptor_t filter_desc;
cudnnDataType_t data_type;
cudnnTensorFormat_t tensor_format;
cudnnTensorDescriptor_t input_desc;
cudnnFilterDescriptor_t filter_desc;
cudnnTensorDescriptor_t output_desc;
cudnnConvolutionFwdAlgo_t fwd_algo;
cudnnConvolutionBwdDataAlgo_t bwd_data_algo;
cudnnConvolutionBwdFilterAlgo_t bwd_filter_algo;
// cudnnMathType_t math_type;
Device device;
runtime::DeviceAPI* cuda_api;
Expand Down

0 comments on commit e311ba3

Please sign in to comment.