diff --git a/aten/src/ATen/cudnn/Descriptors.h b/aten/src/ATen/cudnn/Descriptors.h
index 6c3970e13d664..be00b9dd2b9dc 100644
--- a/aten/src/ATen/cudnn/Descriptors.h
+++ b/aten/src/ATen/cudnn/Descriptors.h
@@ -155,9 +155,6 @@ struct AT_CUDA_API ConvolutionDescriptor
     AT_CUDNN_CHECK(cudnnSetConvolutionNdDescriptor(mut_desc(), dim, pad, stride, upscale,
                                                    CUDNN_CROSS_CORRELATION, mathType));
     AT_CUDNN_CHECK(cudnnSetConvolutionGroupCount(mut_desc(), groups));
-    AT_CUDNN_CHECK(cudnnSetConvolutionMathType(mut_desc(), CUDNN_DEFAULT_MATH));
-    if(dataType == CUDNN_DATA_HALF)
-      AT_CUDNN_CHECK(cudnnSetConvolutionMathType(mut_desc(), CUDNN_TENSOR_OP_MATH));
   }
 };
 
diff --git a/aten/src/ATen/native/cudnn/Conv.cpp b/aten/src/ATen/native/cudnn/Conv.cpp
index ac657a7d31e56..10ff929353cda 100644
--- a/aten/src/ATen/native/cudnn/Conv.cpp
+++ b/aten/src/ATen/native/cudnn/Conv.cpp
@@ -94,6 +94,25 @@ std::tuple<at::Tensor,at::Tensor> cudnn_convolution_transpose_backwar
 #include <stdexcept>
 #include <unordered_map>
 
+// Note [chooseAlgorithm doesn't respect mathType]
+// You might be wondering why we call cudnnSetConvolutionMathType after
+// calling chooseAlgorithm. It turns out the mathType returned by
+// chooseAlgorithm can differ from the one we set in the descriptor, so we
+// have to explicitly update the descriptor after chooseAlgorithm has found
+// the best algorithm+mathType pair. Otherwise, even though we call
+// cudnnConvolutionForward with the fastest algorithm, under the hood cuDNN
+// runs it with a slower kernel, since it sees the fastest algorithm
+// combined with a suboptimal mathType.
+
+// Note [cudnnSetConvolutionMathType cannot be called in descriptor]
+// When cudnnSetConvolutionMathType is called before
+// cudnnGetConvolutionForwardAlgorithm_v7, cudnnGet finds an algorithm based
+// on the mathType set by cudnnSetConvolutionMathType. That is, if we set a
+// default in the descriptor's setter (e.g. CUDNN_TENSOR_OP_MATH for fp16),
+// cudnnGet*_v7 returns algo1 with CUDNN_TENSOR_OP_MATH instead of ignoring
+// the previously set mathType (and returning algo1 with CUDNN_DEFAULT_MATH,
+// which is performant). A bug has been filed internally at NVIDIA.
+
 namespace at { namespace native {
 
 // TODO: Go through all the checking code again and make sure
@@ -340,9 +359,9 @@ struct BenchmarkCache {
   }
 };
 
-BenchmarkCache<cudnnConvolutionFwdAlgo_t> fwd_algos;
-BenchmarkCache<cudnnConvolutionBwdDataAlgo_t> bwd_data_algos;
-BenchmarkCache<cudnnConvolutionBwdFilterAlgo_t> bwd_filter_algos;
+BenchmarkCache<cudnnConvolutionFwdAlgoPerf_t> fwd_algos;
+BenchmarkCache<cudnnConvolutionBwdDataAlgoPerf_t> bwd_data_algos;
+BenchmarkCache<cudnnConvolutionBwdFilterAlgoPerf_t> bwd_filter_algos;
 
 // TODO: Stop manually allocating CUDA memory; allocate an ATen byte
 // tensor instead.
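The BenchmarkCache hunk above is the backbone of the patch: the cache is re-keyed from a bare algorithm enum to the full cudnnConvolution*AlgoPerf_t, so the algorithm, its mathType, and its workspace size are looked up and reused as a single unit. A minimal sketch of that caching pattern, with a simplified string key standing in for the real ConvolutionParams key and its custom hash:

    // Sketch only: the real BenchmarkCache lives in Conv.cpp and keys on
    // ConvolutionParams with ParamsHash; a string key is used here so the
    // example stays self-contained.
    #include <cudnn.h>
    #include <mutex>
    #include <string>
    #include <unordered_map>

    struct PerfCacheSketch {
      std::mutex mutex;
      std::unordered_map<std::string, cudnnConvolutionFwdAlgoPerf_t> map;

      bool find(const std::string& key, cudnnConvolutionFwdAlgoPerf_t* result) {
        std::lock_guard<std::mutex> guard(mutex);
        auto it = map.find(key);
        if (it == map.end()) return false;
        *result = it->second;  // algo, mathType, and memory travel together
        return true;
      }

      void insert(const std::string& key, const cudnnConvolutionFwdAlgoPerf_t& result) {
        std::lock_guard<std::mutex> guard(mutex);
        map[key] = result;
      }
    };

Caching the whole perf struct is what lets the call sites further down recover the mathType that was chosen together with the algorithm.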
@@ -363,7 +382,7 @@ struct Workspace {
   void* data;
 };
 
-template<typename algo_t>
+template<typename perf_t>
 struct algorithm_search {
 };
 
@@ -452,14 +471,14 @@ perf_t getBestAlgorithm(perf_t *perfResults, bool deterministic, int n_algo) {
 }
 
 template<>
-struct algorithm_search<cudnnConvolutionFwdAlgo_t> {
+struct algorithm_search<cudnnConvolutionFwdAlgoPerf_t> {
   using perf_t = cudnnConvolutionFwdAlgoPerf_t;
   using algo_t = cudnnConvolutionFwdAlgo_t;
 
   static constexpr auto DEFAULT_ALGO = CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM;
-  static BenchmarkCache<algo_t>& cache() { return fwd_algos; }
+  static BenchmarkCache<perf_t>& cache() { return fwd_algos; }
 
-  static perf_t findAlgorithm(const ConvolutionArgs& args) {
+  static perf_t findAlgorithm(const cudnnDataType_t dataType, const ConvolutionArgs& args, bool benchmark) {
     static const algo_t algos[] = {
         CUDNN_CONVOLUTION_FWD_ALGO_GEMM,
         CUDNN_CONVOLUTION_FWD_ALGO_FFT,
@@ -475,36 +494,32 @@ struct algorithm_search<cudnnConvolutionFwdAlgo_t> {
         "Missing cuDNN convolution forward algorithms");
     int perf_count;
     std::unique_ptr<perf_t[]> perf_results(new perf_t[num_algos]);
-    size_t max_ws_size = getMaxWorkspaceSize(args, algos, num_algos);
-    Workspace ws(max_ws_size);
-    AT_CUDNN_CHECK(cudnnFindConvolutionForwardAlgorithmEx(
-        args.handle,
-        args.idesc.desc(), args.input.data_ptr(),
-        args.wdesc.desc(), args.weight.data_ptr(),
-        args.cdesc.desc(),
-        args.odesc.desc(), args.output.data_ptr(),
-        num_algos,
-        &perf_count,
-        perf_results.get(),
-        ws.data,
-        ws.size));
-    return getBestAlgorithm<perf_t>(perf_results.get(), args.params.deterministic, perf_count);
-  }
-
-  static void getAlgorithm(
-    const ConvolutionArgs& args,
-    algo_t* algo)
-  {
-    cudnnConvolutionFwdPreference_t pref = CUDNN_CONVOLUTION_FWD_PREFER_FASTEST;
-    AT_CUDNN_CHECK(cudnnGetConvolutionForwardAlgorithm(
-        args.handle,
-        args.idesc.desc(),
-        args.wdesc.desc(),
-        args.cdesc.desc(),
-        args.odesc.desc(),
-        pref,
-        0,
-        algo));
+    if (!benchmark) {
+      AT_CUDNN_CHECK(cudnnGetConvolutionForwardAlgorithm_v7(
+          args.handle,
+          args.idesc.desc(),
+          args.wdesc.desc(),
+          args.cdesc.desc(),
+          args.odesc.desc(),
+          num_algos,
+          &perf_count,
+          perf_results.get()));
+    } else {
+      size_t max_ws_size = getMaxWorkspaceSize(args, algos, num_algos);
+      Workspace ws(max_ws_size);
+      AT_CUDNN_CHECK(cudnnFindConvolutionForwardAlgorithmEx(
+          args.handle,
+          args.idesc.desc(), args.input.data_ptr(),
+          args.wdesc.desc(), args.weight.data_ptr(),
+          args.cdesc.desc(),
+          args.odesc.desc(), args.output.data_ptr(),
+          num_algos,
+          &perf_count,
+          perf_results.get(),
+          ws.data,
+          ws.size));
+    }
+    return getBestAlgorithm<perf_t>(perf_results.get(), args.params.deterministic, perf_count);
   }
 
   static void getWorkspaceSize(
@@ -523,14 +538,14 @@ struct algorithm_search<cudnnConvolutionFwdAlgo_t> {
 };
 
 template<>
-struct algorithm_search<cudnnConvolutionBwdDataAlgo_t> {
+struct algorithm_search<cudnnConvolutionBwdDataAlgoPerf_t> {
   using perf_t = cudnnConvolutionBwdDataAlgoPerf_t;
   using algo_t = cudnnConvolutionBwdDataAlgo_t;
 
   static constexpr auto DEFAULT_ALGO = CUDNN_CONVOLUTION_BWD_DATA_ALGO_1;
-  static BenchmarkCache<algo_t>& cache() { return bwd_data_algos; }
+  static BenchmarkCache<perf_t>& cache() { return bwd_data_algos; }
 
-  static perf_t findAlgorithm(const ConvolutionArgs& args) {
+  static perf_t findAlgorithm(const cudnnDataType_t dataType, const ConvolutionArgs& args, bool benchmark) {
     static const algo_t algos[] = {
         CUDNN_CONVOLUTION_BWD_DATA_ALGO_0,
         CUDNN_CONVOLUTION_BWD_DATA_ALGO_1,
@@ -544,32 +559,32 @@ struct algorithm_search<cudnnConvolutionBwdDataAlgo_t> {
       "Missing cuDNN convolution backward data algorithms.");
     int perf_count;
     std::unique_ptr<perf_t[]> perf_results(new perf_t[num_algos]);
-    size_t max_ws_size = getMaxWorkspaceSize(args, algos, num_algos);
-    Workspace ws(max_ws_size);
-    AT_CUDNN_CHECK(cudnnFindConvolutionBackwardDataAlgorithmEx(
-        args.handle,
-        args.wdesc.desc(), args.weight.data_ptr(),
-        args.odesc.desc(), args.output.data_ptr(),
-        args.cdesc.desc(),
-        args.idesc.desc(), args.input.data_ptr(),
-        num_algos,
-        &perf_count,
-        perf_results.get(),
-        ws.data,
-        ws.size));
-    return getBestAlgorithm<perf_t>(perf_results.get(), args.params.deterministic, perf_count);
-  }
-
-  static void getAlgorithm(const ConvolutionArgs& args, algo_t* algo) {
-    AT_CUDNN_CHECK(cudnnGetConvolutionBackwardDataAlgorithm(
-        args.handle,
-        args.wdesc.desc(),
-        args.odesc.desc(),
-        args.cdesc.desc(),
-        args.idesc.desc(),
-        CUDNN_CONVOLUTION_BWD_DATA_PREFER_FASTEST,
-        0,
-        algo));
+    if (!benchmark) {
+      AT_CUDNN_CHECK(cudnnGetConvolutionBackwardDataAlgorithm_v7(
+          args.handle,
+          args.wdesc.desc(),
+          args.odesc.desc(),
+          args.cdesc.desc(),
+          args.idesc.desc(),
+          num_algos,
+          &perf_count,
+          perf_results.get()));
+    } else {
+      size_t max_ws_size = getMaxWorkspaceSize(args, algos, num_algos);
+      Workspace ws(max_ws_size);
+      AT_CUDNN_CHECK(cudnnFindConvolutionBackwardDataAlgorithmEx(
+          args.handle,
+          args.wdesc.desc(), args.weight.data_ptr(),
+          args.odesc.desc(), args.output.data_ptr(),
+          args.cdesc.desc(),
+          args.idesc.desc(), args.input.data_ptr(),
+          num_algos,
+          &perf_count,
+          perf_results.get(),
+          ws.data,
+          ws.size));
+    }
+    return getBestAlgorithm<perf_t>(perf_results.get(), args.params.deterministic, perf_count);
   }
 
   static void getWorkspaceSize(
@@ -588,15 +603,15 @@ struct algorithm_search<cudnnConvolutionBwdDataAlgo_t> {
   }
 };
 
 template<>
-struct algorithm_search<cudnnConvolutionBwdFilterAlgo_t> {
+struct algorithm_search<cudnnConvolutionBwdFilterAlgoPerf_t> {
   using perf_t = cudnnConvolutionBwdFilterAlgoPerf_t;
   using algo_t = cudnnConvolutionBwdFilterAlgo_t;
 
   static constexpr auto DEFAULT_ALGO = CUDNN_CONVOLUTION_BWD_FILTER_ALGO_1;
-  static BenchmarkCache<algo_t>& cache() { return bwd_filter_algos; }
+  static BenchmarkCache<perf_t>& cache() { return bwd_filter_algos; }
 
-  static perf_t findAlgorithm(const ConvolutionArgs& args) {
+  static perf_t findAlgorithm(const cudnnDataType_t dataType, const ConvolutionArgs& args, bool benchmark) {
     static const algo_t algos[] = {
         CUDNN_CONVOLUTION_BWD_FILTER_ALGO_0,
         CUDNN_CONVOLUTION_BWD_FILTER_ALGO_1,
@@ -610,37 +625,35 @@ struct algorithm_search<cudnnConvolutionBwdFilterAlgo_t> {
     static_assert(sizeof(algos) / sizeof(algos[0]) == num_algos,
                   "Missing cuDNN convolution backward filter algorithms.");
     std::unique_ptr<perf_t[]> perf_results(new perf_t[num_algos]);
-    size_t max_ws_size = getMaxWorkspaceSize(args, algos, num_algos);
     int perf_count;
-    Workspace ws(max_ws_size);
-
-    AT_CUDNN_CHECK(cudnnFindConvolutionBackwardFilterAlgorithmEx(
-        args.handle,
-        args.idesc.desc(), args.input.data_ptr(),
-        args.odesc.desc(), args.output.data_ptr(),
-        args.cdesc.desc(),
-        args.wdesc.desc(), args.weight.data_ptr(),
-        num_algos,
-        &perf_count,
-        perf_results.get(),
-        ws.data,
-        ws.size));
+    if (!benchmark) {
+      AT_CUDNN_CHECK(cudnnGetConvolutionBackwardFilterAlgorithm_v7(
+          args.handle,
+          args.idesc.desc(),
+          args.odesc.desc(),
+          args.cdesc.desc(),
+          args.wdesc.desc(),
+          num_algos,
+          &perf_count,
+          perf_results.get()));
+    } else {
+      size_t max_ws_size = getMaxWorkspaceSize(args, algos, num_algos);
+      Workspace ws(max_ws_size);
+      AT_CUDNN_CHECK(cudnnFindConvolutionBackwardFilterAlgorithmEx(
+          args.handle,
+          args.idesc.desc(), args.input.data_ptr(),
+          args.odesc.desc(), args.output.data_ptr(),
+          args.cdesc.desc(),
+          args.wdesc.desc(), args.weight.data_ptr(),
+          num_algos,
+          &perf_count,
+          perf_results.get(),
+          ws.data,
+          ws.size));
+    }
     return getBestAlgorithm<perf_t>(perf_results.get(), args.params.deterministic, perf_count);
   }
 
-  static void
-  getAlgorithm(const ConvolutionArgs& args, algo_t* algo) {
-    AT_CUDNN_CHECK(cudnnGetConvolutionBackwardFilterAlgorithm(
-        args.handle,
-        args.idesc.desc(),
-        args.odesc.desc(),
-        args.cdesc.desc(),
-        args.wdesc.desc(),
-        CUDNN_CONVOLUTION_BWD_FILTER_PREFER_FASTEST,
-        0,
-        algo)
-    );
-  }
 
   static void getWorkspaceSize(const ConvolutionArgs& args, algo_t algo, size_t* workspaceSize) {
     AT_CUDNN_CHECK(cudnnGetConvolutionBackwardFilterWorkspaceSize(
@@ -654,70 +667,90 @@ struct algorithm_search<cudnnConvolutionBwdFilterAlgo_t> {
   }
 };
 
-template<typename algo_t>
-void findAlgorithm(const ConvolutionArgs& args, bool benchmark, algo_t* algo) {
-  using search = algorithm_search<algo_t>;
+template<typename perf_t>
+void findAlgorithm(const cudnnDataType_t dataType, const ConvolutionArgs& args, bool benchmark, perf_t* algoPerf) {
+  using search = algorithm_search<perf_t>;
   auto& cache = search::cache();
 
-  if (cache.find(args.params, algo)) {
+  if (cache.find(args.params, algoPerf)) {
     return;
   }
 
   if (args.params.deterministic && !benchmark) {
-    *algo = search::DEFAULT_ALGO;
-    return;
-  }
-
-  if (!benchmark) {
-    search::getAlgorithm(args, algo);
+    algoPerf->algo = search::DEFAULT_ALGO;
+    // Note [cudnnSetConvolutionMathType cannot be called in descriptor]
+    if (dataType == CUDNN_DATA_HALF) {
+      algoPerf->mathType = CUDNN_TENSOR_OP_MATH;
+    } else {
+      algoPerf->mathType = CUDNN_DEFAULT_MATH;
+    }
+    search::getWorkspaceSize(args, algoPerf->algo, &(algoPerf->memory));
     return;
   }
 
-  if (cache.find(args.params, algo)) {
-    // re-check cache since another thread may have benchmarked the algorithm
-    return;
-  }
+  if (benchmark) {
+    if (cache.find(args.params, algoPerf)) {
+      // re-check cache since another thread may have benchmarked the algorithm
+      return;
+    }
+  }
 
-  auto perfResults = search::findAlgorithm(args);
+  auto perfResults = search::findAlgorithm(dataType, args, benchmark);
   // for deterministic algo, look at all the perf results and return the best
   // deterministic algo
   if (perfResults.status == CUDNN_STATUS_SUCCESS &&
      !(args.params.deterministic && perfResults.determinism != CUDNN_DETERMINISTIC)) {
-    *algo = perfResults.algo;
+
+    // if benchmarking, cache the found algo + mathType under the original
+    // params for re-use
+    if (benchmark) {
+      cache.insert(args.params, perfResults);
+
+      // Free the cached blocks in our caching allocator. This is needed here
+      // because the above benchmarking uses a huge amount of memory,
+      // e.g. a few GBs.
+      THCCachingAllocator_emptyCache();
+    }
+
+    *algoPerf = perfResults;
   } else {
-    *algo = search::DEFAULT_ALGO;
+    algoPerf->algo = search::DEFAULT_ALGO;
+    // Note [cudnnSetConvolutionMathType cannot be called in descriptor]
+    if (dataType == CUDNN_DATA_HALF) {
+      algoPerf->mathType = CUDNN_TENSOR_OP_MATH;
+    } else {
+      algoPerf->mathType = CUDNN_DEFAULT_MATH;
+    }
+    search::getWorkspaceSize(args, algoPerf->algo, &(algoPerf->memory));
   }
-  cache.insert(args.params, *algo);
-
-  // Free the cached blocks in our caching allocator. They are
-  // needed here because the above benchmarking uses a huge amount of memory,
-  // e.g. a few GBs.
-  THCCachingAllocator_emptyCache();
 }
 
-template<typename algo_t>
+template<typename perf_t>
 Workspace chooseAlgorithm(
+    const cudnnDataType_t dataType,
     const ConvolutionArgs& args,
     bool benchmark,
-    algo_t* algo)
+    perf_t* algoPerf)
 {
-  findAlgorithm(args, benchmark, algo);
+  findAlgorithm(dataType, args, benchmark, algoPerf);
 
-  using search = algorithm_search<algo_t>;
-  size_t workspace_size;
-  search::getWorkspaceSize(args, *algo, &workspace_size);
+  using search = algorithm_search<perf_t>;
   try {
-    return Workspace(workspace_size);
+    return Workspace(algoPerf->memory);
   } catch (const std::exception& e) {
     cudaGetLastError(); // clear OOM error
 
     // switch to default algorithm and record it in the cache to prevent
     // further OOM errors
-    *algo = search::DEFAULT_ALGO;
-    search::cache().insert(args.params, *algo);
-
-    search::getWorkspaceSize(args, *algo, &workspace_size);
-    return Workspace(workspace_size);
+    algoPerf->algo = search::DEFAULT_ALGO;
+    // Note [cudnnSetConvolutionMathType cannot be called in descriptor]
+    if (dataType == CUDNN_DATA_HALF) {
+      algoPerf->mathType = CUDNN_TENSOR_OP_MATH;
+    } else {
+      algoPerf->mathType = CUDNN_DEFAULT_MATH;
+    }
+    search::getWorkspaceSize(args, algoPerf->algo, &(algoPerf->memory));
+    search::cache().insert(args.params, *algoPerf);
+    return Workspace(algoPerf->memory);
   }
 }
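The deterministic branch of findAlgorithm, its fallback branch, and the OOM handler in chooseAlgorithm above all fill the perf struct the same way: default algorithm, dtype-dependent mathType, matching workspace size. A hypothetical helper, not part of this diff, that captures that repeated pattern; it assumes the ConvolutionArgs and algorithm_search types defined in Conv.cpp:

    // Hypothetical consolidation of the three identical fallback blocks.
    // search is one of the algorithm_search<perf_t> specializations, e.g.
    // fillDefaultAlgorithm<algorithm_search<perf_t>>(dataType, args, algoPerf);
    template <typename search, typename perf_t>
    void fillDefaultAlgorithm(cudnnDataType_t dataType,
                              const ConvolutionArgs& args,
                              perf_t* algoPerf) {
      algoPerf->algo = search::DEFAULT_ALGO;
      // See Note [cudnnSetConvolutionMathType cannot be called in descriptor]:
      // fp16 defaults to Tensor Core math, everything else to default math.
      algoPerf->mathType =
          (dataType == CUDNN_DATA_HALF) ? CUDNN_TENSOR_OP_MATH : CUDNN_DEFAULT_MATH;
      search::getWorkspaceSize(args, algoPerf->algo, &(algoPerf->memory));
    }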
@@ -811,8 +844,13 @@ void raw_cudnn_convolution_forward_out(
   // wasteful; we'd rather reuse the workspace.  OTOH, legacy group
   // convolution support is already pretty slow, so this might not
   // matter.  (This applies to raw_cudnn_convolution_backward_input as well.)
-  cudnnConvolutionFwdAlgo_t fwdAlg;
-  Workspace workspace = chooseAlgorithm(args, benchmark, &fwdAlg);
+  cudnnConvolutionFwdAlgoPerf_t fwdAlgPerf;
+  Workspace workspace = chooseAlgorithm(dataType, args, benchmark, &fwdAlgPerf);
+
+  // update convDesc mathType since cudnn now requires both algo + mathType to figure out
+  // whether to use Tensor Cores or not
+  // See Note [chooseAlgorithm doesn't respect mathType]
+  AT_CUDNN_CHECK(cudnnSetConvolutionMathType(args.cdesc.mut_desc(), fwdAlgPerf.mathType));
 
   Constant one(dataType, 1);
   Constant zero(dataType, 0);
@@ -821,7 +859,7 @@ void raw_cudnn_convolution_forward_out(
   AT_CUDNN_CHECK(cudnnConvolutionForward(
     args.handle,
     &one, args.idesc.desc(), input.data_ptr(),
     args.wdesc.desc(), weight.data_ptr(),
-    args.cdesc.desc(), fwdAlg, workspace.data, workspace.size,
+    args.cdesc.desc(), fwdAlgPerf.algo, workspace.data, workspace.size,
     &zero, args.odesc.desc(), output.data_ptr()));
 }
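The forward call site above shows the full sequence: chooseAlgorithm fills the perf struct, and cudnnSetConvolutionMathType copies the chosen mathType back into the descriptor before dispatch. For reference, a standalone sketch (descriptor setup and error checking omitted) of the heuristic query the !benchmark branches rely on; cudnnGetConvolutionForwardAlgorithm_v7 returns perf results sorted by expected speed, each carrying its own mathType:

    #include <cudnn.h>

    // Ask cuDNN's heuristics for forward algorithms; the descriptors are
    // assumed to be fully configured by the caller.
    cudnnConvolutionFwdAlgoPerf_t pickForwardAlgo(
        cudnnHandle_t handle,
        cudnnTensorDescriptor_t idesc,
        cudnnFilterDescriptor_t wdesc,
        cudnnConvolutionDescriptor_t cdesc,
        cudnnTensorDescriptor_t odesc) {
      cudnnConvolutionFwdAlgoPerf_t results[CUDNN_CONVOLUTION_FWD_ALGO_COUNT];
      int returned = 0;
      cudnnGetConvolutionForwardAlgorithm_v7(
          handle, idesc, wdesc, cdesc, odesc,
          CUDNN_CONVOLUTION_FWD_ALGO_COUNT, &returned, results);
      // results[0] is the heuristically fastest pair. Its mathType may be
      // CUDNN_TENSOR_OP_MATH even when the descriptor is still at default
      // math, which is why the caller must copy it back with
      // cudnnSetConvolutionMathType before running the convolution.
      return results[0];
    }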
@@ -930,8 +968,13 @@ void raw_cudnn_convolution_backward_input_out(
   args.odesc.set(grad_output);
   args.cdesc.set(dataType, grad_output.dim() - 2, args.params.padding, args.params.stride, args.params.dilation, args.params.groups);
 
-  cudnnConvolutionBwdDataAlgo_t bwdDataAlg;
-  Workspace workspace = chooseAlgorithm(args, benchmark, &bwdDataAlg);
+  cudnnConvolutionBwdDataAlgoPerf_t bwdDataAlgPerf;
+  Workspace workspace = chooseAlgorithm(dataType, args, benchmark, &bwdDataAlgPerf);
+
+  // update convDesc mathType since cudnn now requires both algo + mathType to figure out
+  // whether to use Tensor Cores or not
+  // See Note [chooseAlgorithm doesn't respect mathType]
+  AT_CUDNN_CHECK(cudnnSetConvolutionMathType(args.cdesc.mut_desc(), bwdDataAlgPerf.mathType));
 
   Constant one(dataType, 1);
   Constant zero(dataType, 0);
@@ -940,7 +983,7 @@ void raw_cudnn_convolution_backward_input_out(
   AT_CUDNN_CHECK(cudnnConvolutionBackwardData(
       args.handle,
       &one, args.wdesc.desc(), weight.data_ptr(),
       args.odesc.desc(), grad_output.data_ptr(),
-      args.cdesc.desc(), bwdDataAlg, workspace.data, workspace.size,
+      args.cdesc.desc(), bwdDataAlgPerf.algo, workspace.data, workspace.size,
       &zero, args.idesc.desc(), grad_input.data_ptr()));
 }
@@ -1066,8 +1109,13 @@ void raw_cudnn_convolution_backward_weight_out(
   args.odesc.set(grad_output);
   args.cdesc.set(dataType, input.dim() - 2, args.params.padding, args.params.stride, args.params.dilation, args.params.groups);
 
-  cudnnConvolutionBwdFilterAlgo_t bwdFilterAlg;
-  Workspace workspace = chooseAlgorithm(args, benchmark, &bwdFilterAlg);
+  cudnnConvolutionBwdFilterAlgoPerf_t bwdFilterAlgPerf;
+  Workspace workspace = chooseAlgorithm(dataType, args, benchmark, &bwdFilterAlgPerf);
+
+  // update convDesc mathType since cudnn now requires both algo + mathType to figure out
+  // whether to use Tensor Cores or not
+  // See Note [chooseAlgorithm doesn't respect mathType]
+  AT_CUDNN_CHECK(cudnnSetConvolutionMathType(args.cdesc.mut_desc(), bwdFilterAlgPerf.mathType));
 
   Constant one(dataType, 1);
   Constant zero(dataType, 0);
@@ -1076,7 +1124,7 @@ void raw_cudnn_convolution_backward_weight_out(
   AT_CUDNN_CHECK(cudnnConvolutionBackwardFilter(
       args.handle,
       &one, args.idesc.desc(), input.data_ptr(),
       args.odesc.desc(), grad_output.data_ptr(),
-      args.cdesc.desc(), bwdFilterAlg, workspace.data, workspace.size,
+      args.cdesc.desc(), bwdFilterAlgPerf.algo, workspace.data, workspace.size,
       &zero, args.wdesc.desc(), grad_weight.data_ptr()));
 }
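With this, all three call sites (forward, backward-input, backward-weight) end in the same two-step epilogue: choose an algorithm, then force the descriptor's mathType to agree with the choice. A hypothetical helper, not in the diff, stating the invariant each site now maintains:

    // cuDNN dispatches on the algo + mathType pair, so the descriptor must be
    // updated to match the perf result before cudnnConvolutionForward,
    // cudnnConvolutionBackwardData, or cudnnConvolutionBackwardFilter runs.
    // See Note [chooseAlgorithm doesn't respect mathType].
    template <typename perf_t>
    void syncDescriptorMathType(ConvolutionArgs& args, const perf_t& algoPerf) {
      AT_CUDNN_CHECK(
          cudnnSetConvolutionMathType(args.cdesc.mut_desc(), algoPerf.mathType));
    }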