UpSample-nearest cuda kernel update (#21694)
Summary:
Updating the upsampling kernel:
1. Avoids atomicAdd for better fp16 performance (see the sketch below).
2. Better launch configuration for 2D input.
Pull Request resolved: pytorch/pytorch#21694

Differential Revision: D15875791

Pulled By: ezyang

fbshipit-source-id: 426fc5d5f0c0cdf58bfa1a2b564f17a9ea286fa4
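
For context on point 1 of the summary: the previous backward kernel scattered each grad_output element into grad_input with atomicAdd, and contended atomics on half tensors are slow. The new backward kernel gathers instead: each thread owns one grad_input element and sums the grad_output range that maps back to it, so no atomics are needed. The following is a minimal sketch of the two strategies, not the committed kernels: batch and channel handling are omitted and plain float is used for brevity.

// (a) Scatter, the old pattern: one thread per grad_output element; threads
//     that map to the same grad_input slot must serialize through atomicAdd.
//     Assumes grad_in was zero-initialized.
__global__ void nearest1d_backward_scatter(
    const float* grad_out, float* grad_in, int out_w, int in_w) {
  const float scale = (float)in_w / (float)out_w;
  int dst = blockIdx.x * blockDim.x + threadIdx.x;
  if (dst >= out_w) return;
  int src = min((int)floorf(dst * scale), in_w - 1);
  atomicAdd(&grad_in[src], grad_out[dst]);
}

// (b) Gather, the new pattern: one thread per grad_input element; it sums the
//     half-open range of grad_output columns that map back to it (the same
//     src_x / src_x_up computation the committed kernel uses), so no atomics.
__global__ void nearest1d_backward_gather(
    const float* grad_out, float* grad_in, int out_w, int in_w) {
  const float scale = (float)out_w / (float)in_w;
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i >= in_w) return;
  int lo = min((int)floorf(i * scale), out_w - 1);
  int hi = min((int)floorf((i + 1) * scale), out_w);
  float acc = 0.f;
  for (int x = lo; x < hi; x++) acc += grad_out[x];
  grad_in[i] = acc;
}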
jjsjann123 authored and facebook-github-bot committed Jun 18, 2019
1 parent aa9e51d commit 427e298
Showing 6 changed files with 360 additions and 356 deletions.
17 changes: 17 additions & 0 deletions aten/src/ATen/native/cuda/LaunchUtils.h
@@ -0,0 +1,17 @@
#pragma once

namespace at {
namespace native {

// returns 2**floor(log2(n))
static int lastPow2(unsigned int n) {
n |= (n >> 1);
n |= (n >> 2);
n |= (n >> 4);
n |= (n >> 8);
n |= (n >> 16);
return n - (n >> 1);
}

} // namespace native
} // namespace at
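
The new helper rounds its argument down to the nearest power of two by smearing the highest set bit into every lower position and then clearing all but that bit: lastPow2(7) == 4, lastPow2(64) == 64, lastPow2(300) == 256. Per the summary it feeds the improved 2D launch configuration; the 2D kernel file is not shown in this excerpt, so the snippet below is only a hypothetical sketch of how a block shape could be derived from it (pick_2d_block and its parameters are illustrative, not committed code):

#include <algorithm>
#include <ATen/native/cuda/LaunchUtils.h>

// Hypothetical sketch (assumes out_w > 0 and out_h > 0): clamp each block
// dimension to a power of two that fits the size and the thread budget.
static dim3 pick_2d_block(int out_w, int out_h, int max_threads = 512) {
  int block_x = std::min(at::native::lastPow2(out_w), 32);
  int block_y = std::min(at::native::lastPow2(out_h),
                         std::max(max_threads / block_x, 1));
  return dim3(block_x, block_y, 1);
}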
10 changes: 1 addition & 9 deletions aten/src/ATen/native/cuda/Normalization.cuh
@@ -7,6 +7,7 @@
#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/CUDAApplyUtils.cuh>
#include <ATen/native/cuda/DeviceSqrt.cuh>
#include <ATen/native/cuda/LaunchUtils.h>

namespace at { namespace native {

@@ -38,15 +39,6 @@ static int getNumThreads(int nElem) {
return MAX_BLOCK_SIZE;
}

static int lastPow2(unsigned int n) {
n |= (n >> 1);
n |= (n >> 2);
n |= (n >> 4);
n |= (n >> 8);
n |= (n >> 16);
return n - (n >> 1);
}

// Returns the index of the most significant 1 bit in `val`.
__device__ __forceinline__ int getMSB(int val) {
return 31 - __clz(val);
10 changes: 5 additions & 5 deletions aten/src/ATen/native/cuda/UpSample.cuh
@@ -159,7 +159,7 @@ __device__ __forceinline__ static int nearest_neighbor_compute_source_index(
int dst_index,
int input_size) {
const int src_index =
min<int>(static_cast<int>(floorf(dst_index * scale)), input_size - 1);
min(static_cast<int>(floorf(dst_index * scale)), input_size - 1);
return src_index;
}
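
A quick worked example of the mapping above (illustration only): upsampling a width of 3 to a width of 5 gives scale = 3/5 = 0.6, so the five output columns read input columns 0, 0, 1, 1, 2. The same arithmetic on the host:

#include <algorithm>
#include <cmath>
#include <cstdio>

int main() {
  const int in_w = 3, out_w = 5;
  const float scale = (float)in_w / (float)out_w;  // 0.6f
  for (int dst = 0; dst < out_w; ++dst) {
    int src = std::min((int)std::floor(dst * scale), in_w - 1);
    std::printf("output %d <- input %d\n", dst, src);  // prints 0, 0, 1, 1, 2
  }
  return 0;
}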

@@ -173,8 +173,8 @@ __device__ __forceinline__ static scalar_t upsample_get_value_bounded(
int width,
int y,
int x) {
int access_y = max<int>(min<int>(y, height - 1), 0);
int access_x = max<int>(min<int>(x, width - 1), 0);
int access_y = max(min(y, height - 1), 0);
int access_x = max(min(x, width - 1), 0);
return data[batch][channel][access_y][access_x];
}

@@ -189,8 +189,8 @@ __device__ __forceinline__ static void upsample_increment_value_bounded(
int y,
int x,
accscalar_t value) {
int access_y = max<int>(min<int>(y, height - 1), 0);
int access_x = max<int>(min<int>(x, width - 1), 0);
int access_y = max(min(y, height - 1), 0);
int access_x = max(min(x, width - 1), 0);
/* TODO: result here is truncated to scalar_t,
check: https://github.com/pytorch/pytorch/pull/19630#discussion_r281426912
*/
176 changes: 79 additions & 97 deletions aten/src/ATen/native/cuda/UpSampleNearest1d.cu
@@ -11,92 +11,78 @@ namespace at {
namespace native {
namespace {

template <typename scalar_t, typename accscalar_t>
#define MAX_THREADS 512

template <typename scalar_t>
C10_LAUNCH_BOUNDS_1(1024)
__global__ void upsample_nearest1d_out_frame(
const int n,
const PackedTensorAccessor<scalar_t, 3> idata,
PackedTensorAccessor<scalar_t, 3> odata) {
int index = threadIdx.x + blockIdx.x * blockDim.x;

const int batchsize = idata.size(0);
const int channels = idata.size(1);
const int width1 = idata.size(2);
const int width2 = odata.size(2);

const float scale = (float)width1 / (float)width2;

if (index < n) {
const int w2 = index % width2;
// special case: just copy
if (width1 == width2) {
const int w1 = w2;
for (int n = 0; n < batchsize; n++) {
for (int c = 0; c < channels; ++c) {
const scalar_t val = idata[n][c][w1];
odata[n][c][w2] = val;
}
}
return;
}
//
const int w1 = nearest_neighbor_compute_source_index(scale, w2, width1);

for (int n = 0; n < batchsize; n++) {
for (int c = 0; c < channels; ++c) {
const scalar_t val = idata[n][c][w1];
odata[n][c][w2] = val;
}
}
const scalar_t* input,
size_t dim_b,
size_t dim_c,
size_t src_dim_w,
size_t dst_dim_w,
scalar_t* output) {
size_t dst_idx = blockIdx.x * blockDim.x + threadIdx.x;
if (dst_idx >= dim_c * dst_dim_w)
return;

float scale_factor = (float)src_dim_w / (float)dst_dim_w;

int c = (dst_idx / dst_dim_w) % dim_c;

int dst_x = dst_idx % dst_dim_w;
int src_x = nearest_neighbor_compute_source_index(scale_factor, dst_x, src_dim_w);

size_t src_idx = c * src_dim_w + src_x;
int src_stride = dim_c * src_dim_w;
int dst_stride = dim_c * dst_dim_w;

for (int b = 0; b < dim_b; b++) {
output[dst_idx] = input[src_idx];
src_idx += src_stride;
dst_idx += dst_stride;
}
}

// Backward operation
template <typename scalar_t, typename accscalar_t>
C10_LAUNCH_BOUNDS_1(1024)
__global__ void upsample_nearest1d_backward_out_frame(
const int n,
PackedTensorAccessor<scalar_t, 3> idata,
const PackedTensorAccessor<scalar_t, 3> odata) {
int index = threadIdx.x + blockIdx.x * blockDim.x;

const int batchsize = idata.size(0);
const int channels = idata.size(1);
const int width1 = idata.size(2);
const int width2 = odata.size(2);

const float scale = (float)width1 / (float)width2;

if (index < n) {
const int w2 = index % width2;
// special case: just copy
if (width1 == width2) {
const int w1 = w2;
for (int n = 0; n < batchsize; n++) {
for (int c = 0; c < channels; ++c) {
const scalar_t val = odata[n][c][w1];
idata[n][c][w2] = val;
}
}
return;
}
//
const int w1 = nearest_neighbor_compute_source_index(scale, w2, width1);

for (int n = 0; n < batchsize; n++) {
for (int c = 0; c < channels; ++c) {
const scalar_t d2val = odata[n][c][w2];
atomicAdd(&idata[n][c][w1], d2val);
}
const scalar_t* grad_o,
size_t dim_b,
size_t dim_c,
size_t src_dim_w,
size_t dst_dim_w,
scalar_t* grad_i) {

size_t dst_idx = blockIdx.x * blockDim.x + threadIdx.x;
if (dst_idx >= dim_c * dst_dim_w)
return;

float scale_factor = (float)src_dim_w / (float)dst_dim_w;

int c = (dst_idx / (dst_dim_w)) % dim_c;

int dst_x = dst_idx % dst_dim_w;
int src_x = nearest_neighbor_compute_source_index(scale_factor, dst_x, src_dim_w);
int src_x_up = nearest_neighbor_compute_source_index(scale_factor, dst_x+1, src_dim_w+1);

for (int b = 0; b < dim_b; b++) {
accscalar_t grad = 0;
size_t src_idx = b * dim_c * src_dim_w + c * src_dim_w + src_x;
for (int x = src_x; x < src_x_up; x++) {
grad += grad_o[src_idx++];
}
grad_i[dst_idx] = grad;
dst_idx += dim_c * dst_dim_w;
}
}

static void upsample_nearest1d_out_cuda_template(
Tensor& output,
const Tensor& input,
const Tensor& input_,
IntArrayRef output_size) {
TensorArg input_arg{input, "input", 1}, output_arg{output, "output", 2};
TensorArg input_arg{input_, "input_", 1}, output_arg{output, "output", 2};
checkAllSameGPU("upsample_nearest1d_out_cuda", {input_arg, output_arg});

TORCH_CHECK(
@@ -106,35 +92,33 @@ static void upsample_nearest1d_out_cuda_template(

int output_width = output_size[0];

int nbatch = input.size(0);
int channels = input.size(1);
int input_width = input.size(2);
int nbatch = input_.size(0);
int channels = input_.size(1);
int input_width = input_.size(2);

upsample_1d_shape_check(
input, Tensor(), nbatch, channels, input_width, output_width);
input_, Tensor(), nbatch, channels, input_width, output_width);

AT_ASSERT(input_width > 0 && output_width > 0);

Tensor input = input_.contiguous();
output.resize_({input.size(0), input.size(1), output_width});
output.zero_();

const int num_kernels = output_width;
const int num_threads = std::min(
at::cuda::getCurrentDeviceProperties()->maxThreadsPerBlock, 1024);
// upsample_1d_shape_check makes sure `nbatch != 0`
unsigned int n = output.numel() / nbatch;
dim3 bdim{std::min<unsigned int>(
at::cuda::getCurrentDeviceProperties()->maxThreadsPerBlock, MAX_THREADS)};
dim3 gdim{cuda::ATenCeilDiv(n, bdim.x)};
cudaStream_t stream = at::cuda::getCurrentCUDAStream();

AT_DISPATCH_FLOATING_TYPES_AND_HALF(
input.scalar_type(), "upsample_nearest1d_out_frame", [&] {
using accscalar_t = at::acc_type<scalar_t, true>;

auto idata = input.packed_accessor<scalar_t, 3>();
auto odata = output.packed_accessor<scalar_t, 3>();
auto idata = input.data<scalar_t>();
auto odata = output.data<scalar_t>();

upsample_nearest1d_out_frame<scalar_t, accscalar_t>
<<<cuda::ATenCeilDiv(num_kernels, num_threads),
num_threads,
0,
stream>>>(num_kernels, idata, odata);
upsample_nearest1d_out_frame<scalar_t><<<gdim, bdim, 0, stream>>>(
idata, nbatch, channels, input_width, output_width, odata);
});

AT_CUDA_CHECK(cudaGetLastError());
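
A worked example of the launch shape computed above (illustration only, with made-up sizes): the grid covers only the channel-by-width positions, and the batch dimension is walked inside the kernel.

// For an output of shape (nbatch = 8, channels = 16, output_width = 1000):
//   n    = output.numel() / nbatch          = 16 * 1000 = 16000
//   bdim = min(maxThreadsPerBlock, MAX_THREADS = 512)   = 512
//   gdim = ATenCeilDiv(16000, 512)                      = 32 blocks
// Each of the 16000 threads then loops over the 8 batches, advancing its
// src/dst offsets by dim_c * dst_dim_w = 16000 elements per iteration.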
@@ -172,25 +156,23 @@ static void upsample_nearest1d_backward_out_cuda_template(

Tensor grad_output = grad_output_.contiguous();
grad_input.resize_({nbatch, channels, input_width});
grad_input.zero_();

const int num_kernels = output_width;
const int num_threads = std::min(
at::cuda::getCurrentDeviceProperties()->maxThreadsPerBlock, 1024);
// upsample_1d_shape_check makes sure `nbatch != 0`
unsigned int n = grad_input.numel() / nbatch;
dim3 bdim{std::min<unsigned int>(
at::cuda::getCurrentDeviceProperties()->maxThreadsPerBlock, MAX_THREADS)};
dim3 gdim{cuda::ATenCeilDiv(n, bdim.x)};
cudaStream_t stream = at::cuda::getCurrentCUDAStream();

AT_DISPATCH_FLOATING_TYPES_AND_HALF(
grad_output.scalar_type(), "upsample_nearest1d_backward_out_frame", [&] {
using accscalar_t = at::acc_type<scalar_t, true>;

auto idata = grad_input.packed_accessor<scalar_t, 3>();
auto odata = grad_output.packed_accessor<scalar_t, 3>();
auto idata = grad_input.data<scalar_t>();
auto odata = grad_output.data<scalar_t>();

upsample_nearest1d_backward_out_frame<scalar_t, accscalar_t>
<<<cuda::ATenCeilDiv(num_kernels, num_threads),
num_threads,
0,
stream>>>(num_kernels, idata, odata);
<<<gdim, bdim, 0, stream>>>(
odata, nbatch, channels, output_width, input_width, idata);
});

AT_CUDA_CHECK(cudaGetLastError());
(Diffs for the remaining 2 of the 6 changed files did not load on this page.)