[Feature] Add cuda ops: UpFirDn2d and fused_bias_leakyrelu (#900)
* add upfirdn2d op
* fix bug in pybind
* add fused bias leakyrelu
* fix bug in fused-bias-leakyrelu
* fix lint error
* fix bug in build cpu version
* fix bug in build cpu version
* fix lint
* fix comment from zww

Co-authored-by: zhangshilong <zhangshilong@sensetime.com>
Showing 10 changed files with 983 additions and 1 deletion.
@@ -0,0 +1,26 @@
// Modified from
// https://github.com/rosinality/stylegan2-pytorch/blob/master/op/fused_bias_act.cpp
#include "pytorch_cpp_helper.hpp"

#ifdef MMCV_WITH_CUDA
torch::Tensor fused_bias_leakyrelu_op(const torch::Tensor& input,
                                      const torch::Tensor& bias,
                                      const torch::Tensor& refer, int act,
                                      int grad, float alpha, float scale);
#endif

torch::Tensor fused_bias_leakyrelu(const torch::Tensor& input,
                                   const torch::Tensor& bias,
                                   const torch::Tensor& refer, int act,
                                   int grad, float alpha, float scale) {
#ifdef MMCV_WITH_CUDA
  CHECK_CUDA(input);
  CHECK_CUDA(bias);

  return fused_bias_leakyrelu_op(input, bias, refer, act, grad, alpha, scale);
#else
  AT_ERROR("Fused bias leakyrelu is not compiled with GPU support");
#endif
}
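
The commit log mentions a pybind fix, but the registration file itself is not among the hunks shown here. As a rough illustration only, a standalone PyTorch extension could expose the dispatcher above roughly as follows; the module setup and argument spelling are assumptions, not MMCV's actual binding:

// Hypothetical standalone binding -- MMCV's real pybind file is not shown in
// this diff, so everything below is illustrative only.
#include <torch/extension.h>

torch::Tensor fused_bias_leakyrelu(const torch::Tensor& input,
                                   const torch::Tensor& bias,
                                   const torch::Tensor& refer, int act,
                                   int grad, float alpha, float scale);

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("fused_bias_leakyrelu", &fused_bias_leakyrelu,
        "fused bias leakyrelu (CUDA)", py::arg("input"), py::arg("bias"),
        py::arg("refer"), py::arg("act"), py::arg("grad"), py::arg("alpha"),
        py::arg("scale"));
}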
@@ -0,0 +1,109 @@
// from
// https://github.com/rosinality/stylegan2-pytorch/blob/master/op/fused_bias_act_kernel.cu
// Copyright (c) 2019, NVIDIA Corporation. All rights reserved.
//
// This work is made available under the Nvidia Source Code License-NC.
// To view a copy of this license, visit
// https://nvlabs.github.io/stylegan2/license.html

#include <ATen/ATen.h>
#include <ATen/AccumulateType.h>
#include <ATen/cuda/CUDAContext.h>
#include <cuda.h>
#include <cuda_runtime.h>
#include <torch/types.h>

#include <ATen/cuda/CUDAApplyUtils.cuh>

template <typename scalar_t>
static __global__ void fused_bias_act_kernel(
    scalar_t* out, const scalar_t* p_x, const scalar_t* p_b,
    const scalar_t* p_ref, int act, int grad, scalar_t alpha, scalar_t scale,
    int loop_x, int size_x, int step_b, int size_b, int use_bias, int use_ref) {
  // Each thread handles loop_x elements, strided by blockDim.x.
  int xi = blockIdx.x * loop_x * blockDim.x + threadIdx.x;

  scalar_t zero = 0.0;

  for (int loop_idx = 0; loop_idx < loop_x && xi < size_x;
       loop_idx++, xi += blockDim.x) {
    scalar_t x = p_x[xi];

    if (use_bias) {
      x += p_b[(xi / step_b) % size_b];
    }

    scalar_t ref = use_ref ? p_ref[xi] : zero;

    scalar_t y;

    // act = 1: linear layer
    // act = 3: leaky relu layer
    // grad = 0: direct forward path
    // grad = 1: first order derivative
    // grad = 2: second order derivative
    switch (act * 10 + grad) {
      default:
      case 10:
        y = x;
        break;
      case 11:
        y = x;
        break;
      case 12:
        y = 0.0;
        break;

      case 30:
        y = (x > 0.0) ? x : x * alpha;
        break;
      case 31:
        y = (ref > 0.0) ? x : x * alpha;
        break;
      case 32:
        y = 0.0;
        break;
    }

    out[xi] = y * scale;
  }
}

torch::Tensor fused_bias_leakyrelu_op(const torch::Tensor& input,
                                      const torch::Tensor& bias,
                                      const torch::Tensor& refer, int act,
                                      int grad, float alpha, float scale) {
  int curDevice = -1;
  cudaGetDevice(&curDevice);
  cudaStream_t stream = at::cuda::getCurrentCUDAStream(curDevice);

  auto x = input.contiguous();
  auto b = bias.contiguous();
  auto ref = refer.contiguous();

  int use_bias = b.numel() ? 1 : 0;
  int use_ref = ref.numel() ? 1 : 0;

  int size_x = x.numel();
  int size_b = b.numel();
  int step_b = 1;

  // The bias broadcasts over dim 1 (channels), so consecutive bias entries
  // are step_b = prod(x.size(2..)) elements apart in the flattened input.
  for (int i = 2; i < x.dim(); i++) {
    step_b *= x.size(i);
  }

  int loop_x = 4;
  int block_size = 4 * 32;
  // Each block covers loop_x * block_size = 512 elements; round the grid up.
  int grid_size = (size_x - 1) / (loop_x * block_size) + 1;

  auto y = torch::empty_like(x);

  AT_DISPATCH_FLOATING_TYPES_AND_HALF(
      x.scalar_type(), "fused_bias_act_kernel", [&] {
        fused_bias_act_kernel<scalar_t><<<grid_size, block_size, 0, stream>>>(
            y.data_ptr<scalar_t>(), x.data_ptr<scalar_t>(),
            b.data_ptr<scalar_t>(), ref.data_ptr<scalar_t>(), act, grad, alpha,
            scale, loop_x, size_x, step_b, size_b, use_bias, use_ref);
      });

  return y;
}
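
For the fused forward path (act = 3, grad = 0), the switch reduces to y = leaky_relu(x + b) * scale; the grad = 1 case applies the same slope selection to the incoming gradient but keys the comparison on refer instead of on the input. A hypothetical CPU reference for the forward path, useful only as a sanity check (fused_bias_leakyrelu_ref is our name, not part of this commit; it assumes the bias broadcasts over dim 1, as step_b/size_b encode):

// Hypothetical CPU reference for the act = 3, grad = 0 path above.
// Not part of this commit; for illustration and testing only.
#include <vector>

#include <torch/torch.h>

torch::Tensor fused_bias_leakyrelu_ref(const torch::Tensor& input,
                                       const torch::Tensor& bias,
                                       float alpha, float scale) {
  // Reshape bias to (1, C, 1, ..., 1) so it broadcasts over dim 1, matching
  // the (xi / step_b) % size_b indexing in the kernel.
  std::vector<int64_t> shape(input.dim(), 1);
  shape[1] = bias.numel();
  auto x = input + bias.view(shape);
  // case 30: y = x > 0 ? x : alpha * x, then scale.
  return torch::where(x > 0, x, alpha * x) * scale;
}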
@@ -0,0 +1,25 @@
// from
// https://github.com/rosinality/stylegan2-pytorch/blob/master/op/upfirdn2d.cpp
#include "pytorch_cpp_helper.hpp"

#ifdef MMCV_WITH_CUDA
torch::Tensor upfirdn2d_op(const torch::Tensor& input,
                           const torch::Tensor& kernel, int up_x, int up_y,
                           int down_x, int down_y, int pad_x0, int pad_x1,
                           int pad_y0, int pad_y1);
#endif

torch::Tensor upfirdn2d(const torch::Tensor& input, const torch::Tensor& kernel,
                        int up_x, int up_y, int down_x, int down_y, int pad_x0,
                        int pad_x1, int pad_y0, int pad_y1) {
#ifdef MMCV_WITH_CUDA
  CHECK_CUDA(input);
  CHECK_CUDA(kernel);

  return upfirdn2d_op(input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1,
                      pad_y0, pad_y1);
#else
  AT_ERROR("UpFirDn2d is not compiled with GPU support");
#endif
}
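
UpFirDn stands for upsample, FIR filter, downsample: zero-stuff the input by the up factor, pad by pad0/pad1, convolve with a small FIR kernel, and keep every down-th sample; upfirdn2d applies this pattern separably along both spatial axes. A hypothetical 1-D sketch of the same operation (upfirdn1d and its signature are ours, not part of this commit; non-negative pads assumed):

// Hypothetical 1-D illustration of the upsample -> FIR -> downsample pattern
// that upfirdn2d applies along each spatial axis. Not part of this commit.
#include <vector>

std::vector<float> upfirdn1d(const std::vector<float>& in,
                             const std::vector<float>& k, int up, int down,
                             int pad0, int pad1) {
  // Zero-stuff by `up`, then pad pad0 zeros on the left and pad1 on the right.
  int mid = static_cast<int>(in.size()) * up + pad0 + pad1;
  std::vector<float> stuffed(mid, 0.f);
  for (size_t i = 0; i < in.size(); i++) stuffed[i * up + pad0] = in[i];
  // Convolve with the flipped FIR kernel, keeping every `down`-th sample.
  int out_len = (mid - static_cast<int>(k.size())) / down + 1;
  std::vector<float> out(out_len > 0 ? out_len : 0, 0.f);
  for (int o = 0; o < out_len; o++)
    for (size_t j = 0; j < k.size(); j++)
      out[o] += stuffed[o * down + j] * k[k.size() - 1 - j];
  return out;
}

With up = 2, down = 1 and a small binomial blur kernel (e.g. [1, 3, 3, 1], suitably normalized) this gives the smoothed 2x upsampling StyleGAN2 uses; with up = 1, down = 2 it gives the matching anti-aliased downsampling.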