Replaced gpuAtomicAdd by fastAtomicAdd
vfdev-5 committed May 22, 2023
1 parent fc838ad commit 76e3419
Showing 5 changed files with 52 additions and 43 deletions.
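
All five files apply the same substitution: a call of the form gpuAtomicAdd(ptr + offset, value) becomes at::native::fastAtomicAdd(base_ptr, offset, numel, value, true), which takes the destination tensor's total element count as an extra argument so it can bounds-check its fast path. The sketch below is not part of this commit: the kernel and its names are hypothetical, and only the fastAtomicAdd signature is assumed from ATen/native/cuda/KernelUtils.cuh. The usual motivation for the swap, not stated in the commit message, is that fastAtomicAdd can pack at::Half additions into wider hardware atomics where gpuAtomicAdd cannot.

#include <ATen/cuda/Atomic.cuh>             // gpuAtomicAdd (old call style)
#include <ATen/native/cuda/KernelUtils.cuh> // at::native::fastAtomicAdd (new call style)

// Hypothetical scatter-add kernel illustrating the call-site change.
template <typename scalar_t, typename index_t>
__global__ void scatter_add_sketch(
    scalar_t* out,           // destination buffer
    const index_t out_numel, // total number of elements in `out`
    const scalar_t* values,
    const index_t* indices,
    const index_t n) {
  const index_t i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i >= n)
    return;

  // Old pattern: atomic add through a raw pointer.
  //   gpuAtomicAdd(out + indices[i], values[i]);

  // New pattern: base pointer, element index, total element count, value,
  // and a flag enabling the fast path; unsupported dtypes or layouts fall
  // back to a plain atomic add inside fastAtomicAdd.
  at::native::fastAtomicAdd(out, indices[i], out_numel, values[i], true);
}
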
5 changes: 3 additions & 2 deletions torchvision/csrc/ops/cuda/deform_conv2d_kernel.cu
@@ -70,7 +70,7 @@
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <torch/library.h>
-#include <ATen/cuda/Atomic.cuh>
+#include <ATen/native/cuda/KernelUtils.cuh>

#include "cuda_helpers.h"

@@ -372,6 +372,7 @@ __global__ void deformable_col2im_kernel(

const scalar_t y = (out_y * stride_h - pad_h) + i * dilation_h + offset_h;
const scalar_t x = (out_x * stride_w - pad_w) + j * dilation_w + offset_w;
+const index_t grad_im_numel = width * height * channels * batch_sz;

for (index_t dy = -1; dy <= 1; dy++) {
for (index_t dx = -1; dx <= 1; dx++) {
@@ -381,7 +382,7 @@ __global__ void deformable_col2im_kernel(
std::abs(y - yp) < 1 && std::abs(x - xp) < 1) {
index_t grad_pos = ((b * channels + c) * height + yp) * width + xp;
scalar_t weight = (1 - std::abs(y - yp)) * (1 - std::abs(x - xp));
-gpuAtomicAdd(grad_im + grad_pos, mask_value * weight * col[index]);
+at::native::fastAtomicAdd(grad_im, grad_pos, grad_im_numel, mask_value * weight * col[index], true);
}
}
}
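One difference worth noting between this first file and the four that follow: deform_conv2d_kernel.cu computes the destination element count (grad_im_numel) inside the kernel from width, height, channels, and batch_sz, whereas the ROI kernels below thread it through as a new memory_span parameter and pass grad_input.numel() from the host launcher. A minimal sketch of that plumbing, with hypothetical kernel and variable names (only fastAtomicAdd and grad_input.numel() come from the diff):

#include <ATen/native/cuda/KernelUtils.cuh>

// Hypothetical backward kernel: the real kernels compute a per-element
// offset into grad_input and an interpolated gradient; here both are
// placeholders so that the memory_span plumbing stands out.
template <typename T>
__global__ void roi_backward_sketch(
    const int nthreads,
    const T* grad_output,
    T* grad_input,
    const int memory_span) { // = grad_input.numel(), supplied by the launcher
  const int index = blockIdx.x * blockDim.x + threadIdx.x;
  if (index >= nthreads)
    return;
  const int offset = index;       // stands in for the real index arithmetic
  const T g = grad_output[index]; // stands in for the real gradient value
  at::native::fastAtomicAdd(grad_input, offset, memory_span, g, true);
}

// Launcher side (inside the AT_DISPATCH_* lambda of each backward op):
//   roi_backward_sketch<scalar_t><<<grid, block, 0, stream>>>(
//       nthreads,
//       grad_output.data_ptr<scalar_t>(),
//       grad_input.data_ptr<scalar_t>(),
//       grad_input.numel());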
24 changes: 15 additions & 9 deletions torchvision/csrc/ops/cuda/ps_roi_align_kernel.cu
@@ -2,7 +2,7 @@
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <torch/library.h>
-#include <ATen/cuda/Atomic.cuh>
+#include <ATen/native/cuda/KernelUtils.cuh>

#include "cuda_helpers.h"

@@ -212,7 +212,8 @@ __global__ void ps_roi_align_backward_kernel_impl(
int sampling_ratio,
int channels_out,
T* grad_input,
-const T* rois) {
+const T* rois,
+const int memory_span) {
CUDA_1D_KERNEL_LOOP(index, nthreads) {
// (n, *, ph, pw) is an element in the pooled output
int pw = index % pooled_width;
@@ -235,8 +236,6 @@ __global__ void ps_roi_align_backward_kernel_impl(
T bin_size_w = roi_width / static_cast<T>(pooled_width);

int c_in = channel_mapping[index];
-T* grad_input_offset =
-grad_input + (roi_batch_ind * channels + c_in) * height * width;

// Do not using floor/ceil; this implementation detail is critical
T hstart = static_cast<T>(ph) * bin_size_h + roi_start_h;
@@ -252,6 +251,8 @@
(sampling_ratio > 0) ? sampling_ratio : ceil(roi_width / pooled_width);
const T count = roi_bin_grid_h * roi_bin_grid_w;

+const int offset = (roi_batch_ind * channels + c_in) * height * width;

for (int iy = 0; iy < roi_bin_grid_h; iy++) {
const T y = hstart +
static_cast<T>(iy + .5f) * bin_size_h /
@@ -285,10 +286,14 @@
T g4 = grad_output_this_bin * w4 / count;

if (x_low >= 0 && x_high >= 0 && y_low >= 0 && y_high >= 0) {
-gpuAtomicAdd(grad_input_offset + y_low * width + x_low, g1);
-gpuAtomicAdd(grad_input_offset + y_low * width + x_high, g2);
-gpuAtomicAdd(grad_input_offset + y_high * width + x_low, g3);
-gpuAtomicAdd(grad_input_offset + y_high * width + x_high, g4);
+at::native::fastAtomicAdd(
+grad_input, offset + y_low * width + x_low, memory_span, static_cast<T>(g1), true);
+at::native::fastAtomicAdd(
+grad_input, offset + y_low * width + x_high, memory_span, static_cast<T>(g2), true);
+at::native::fastAtomicAdd(
+grad_input, offset + y_high * width + x_low, memory_span, static_cast<T>(g3), true);
+at::native::fastAtomicAdd(
+grad_input, offset + y_high * width + x_high, memory_span, static_cast<T>(g4), true);
} // if
} // ix
} // iy
@@ -430,7 +435,8 @@ at::Tensor ps_roi_align_backward_kernel(
sampling_ratio,
channels_out,
grad_input.data_ptr<scalar_t>(),
-rois_.data_ptr<scalar_t>());
+rois_.data_ptr<scalar_t>(),
+grad_input.numel());
});
AT_CUDA_CHECK(cudaGetLastError());
return grad_input;
14 changes: 8 additions & 6 deletions torchvision/csrc/ops/cuda/ps_roi_pool_kernel.cu
@@ -2,7 +2,7 @@
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <torch/library.h>
-#include <ATen/cuda/Atomic.cuh>
+#include <ATen/native/cuda/KernelUtils.cuh>

#include "cuda_helpers.h"

@@ -91,7 +91,8 @@ __global__ void ps_roi_pool_backward_kernel_impl(
int pooled_width,
int channels_out,
T* grad_input,
-const T* rois) {
+const T* rois,
+const int memory_span) {
CUDA_1D_KERNEL_LOOP(index, nthreads) {
// (n, *, ph, pw) is an element in the pooled output
int pw = index % pooled_width;
@@ -124,14 +125,14 @@
bool is_empty = (hend <= hstart) || (wend <= wstart);

int c_in = channel_mapping[index];
-T* grad_input_offset =
-grad_input + (roi_batch_ind * channels + c_in) * height * width;
T bin_area = (hend - hstart) * (wend - wstart);
T diff_val = is_empty ? static_cast<T>(0) : grad_output[index] / bin_area;

+const int offset = (roi_batch_ind * channels + c_in) * height * width;
for (int h = hstart; h < hend; ++h) {
for (int w = wstart; w < wend; ++w) {
int grad_input_index = h * width + w;
-gpuAtomicAdd(grad_input_offset + grad_input_index, diff_val);
+at::native::fastAtomicAdd(grad_input, offset + grad_input_index, memory_span, diff_val, true);
}
}
}
@@ -269,7 +270,8 @@ at::Tensor ps_roi_pool_backward_kernel(
pooled_width,
channels_out,
grad_input.data_ptr<scalar_t>(),
-rois_.data_ptr<scalar_t>());
+rois_.data_ptr<scalar_t>(),
+grad_input.numel());
});
AT_CUDA_CHECK(cudaGetLastError());
return grad_input;
31 changes: 16 additions & 15 deletions torchvision/csrc/ops/cuda/roi_align_kernel.cu
@@ -2,7 +2,7 @@
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <torch/library.h>
-#include <ATen/cuda/Atomic.cuh>
+#include <ATen/native/cuda/KernelUtils.cuh>

#include "cuda_helpers.h"

@@ -218,7 +218,8 @@ __global__ void roi_align_backward_kernel_impl(
int n_stride,
int c_stride,
int h_stride,
-int w_stride) {
+int w_stride,
+const int memory_span) {
CUDA_1D_KERNEL_LOOP(index, nthreads) {
// (n, c, ph, pw) is an element in the pooled output
int pw = index % pooled_width;
@@ -247,12 +248,9 @@
T bin_size_h = static_cast<T>(roi_height) / static_cast<T>(pooled_height);
T bin_size_w = static_cast<T>(roi_width) / static_cast<T>(pooled_width);

-T* offset_grad_input =
-grad_input + ((roi_batch_ind * channels + c) * height * width);

// We need to index the gradient using the tensor strides to access the
// correct values.
-int output_offset = n * n_stride + c * c_stride;
+const int output_offset = n * n_stride + c * c_stride;
const T* offset_grad_output = grad_output + output_offset;
const T grad_output_this_bin =
offset_grad_output[ph * h_stride + pw * w_stride];
@@ -267,6 +265,8 @@
// We do average (integral) pooling inside a bin
const T count = roi_bin_grid_h * roi_bin_grid_w; // e.g. = 4

+const int input_offset = (roi_batch_ind * channels + c) * height * width;

for (int iy = 0; iy < roi_bin_grid_h; iy++) // e.g., iy = 0, 1
{
const T y = roi_start_h + ph * bin_size_h +
@@ -301,14 +301,14 @@
T g4 = grad_output_this_bin * w4 / count;

if (x_low >= 0 && x_high >= 0 && y_low >= 0 && y_high >= 0) {
-gpuAtomicAdd(
-offset_grad_input + y_low * width + x_low, static_cast<T>(g1));
-gpuAtomicAdd(
-offset_grad_input + y_low * width + x_high, static_cast<T>(g2));
-gpuAtomicAdd(
-offset_grad_input + y_high * width + x_low, static_cast<T>(g3));
-gpuAtomicAdd(
-offset_grad_input + y_high * width + x_high, static_cast<T>(g4));
+at::native::fastAtomicAdd(
+grad_input, input_offset + y_low * width + x_low, memory_span, static_cast<T>(g1), true);
+at::native::fastAtomicAdd(
+grad_input, input_offset + y_low * width + x_high, memory_span, static_cast<T>(g2), true);
+at::native::fastAtomicAdd(
+grad_input, input_offset + y_high * width + x_low, memory_span, static_cast<T>(g3), true);
+at::native::fastAtomicAdd(
+grad_input, input_offset + y_high * width + x_high, memory_span, static_cast<T>(g4), true);
} // if
} // ix
} // iy
@@ -442,7 +442,8 @@ at::Tensor roi_align_backward_kernel(
n_stride,
c_stride,
h_stride,
-w_stride);
+w_stride,
+grad_input.numel());
});
AT_CUDA_CHECK(cudaGetLastError());
return grad_input;
21 changes: 10 additions & 11 deletions torchvision/csrc/ops/cuda/roi_pool_kernel.cu
@@ -3,7 +3,7 @@
#include <c10/cuda/CUDAGuard.h>
#include <float.h>
#include <torch/library.h>
-#include <ATen/cuda/Atomic.cuh>
+#include <ATen/native/cuda/KernelUtils.cuh>

#include "cuda_helpers.h"

@@ -94,7 +94,8 @@ __global__ void roi_pool_backward_kernel_impl(
int n_stride,
int c_stride,
int h_stride,
-int w_stride) {
+int w_stride,
+const int memory_span) {
CUDA_1D_KERNEL_LOOP(index, nthreads) {
// (n, c, ph, pw) is an element in the pooled output
int pw = index % pooled_width;
@@ -104,19 +105,16 @@

const T* offset_rois = rois + n * 5;
int roi_batch_ind = offset_rois[0];
-T* grad_input_offset =
-grad_input + ((roi_batch_ind * channels + c) * height * width);

-int output_offset = n * n_stride + c * c_stride;
+const int output_offset = n * n_stride + c * c_stride;
const int* argmax_data_offset =
argmax_data + (n * channels + c) * pooled_height * pooled_width;
-int argmax = argmax_data_offset[ph * pooled_width + pw];
+const int argmax = argmax_data_offset[ph * pooled_width + pw];
+const int offset = (roi_batch_ind * channels + c) * height * width;

if (argmax != -1) {
-gpuAtomicAdd(
-grad_input_offset + argmax,
-static_cast<T>(
-grad_output[output_offset + ph * h_stride + pw * w_stride]));
+at::native::fastAtomicAdd(grad_input, offset + argmax, memory_span,
+static_cast<T>(grad_output[output_offset + ph * h_stride + pw * w_stride]), true);
}
}
}
@@ -253,7 +251,8 @@ at::Tensor roi_pool_backward_kernel(
n_stride,
c_stride,
h_stride,
-w_stride);
+w_stride,
+grad_input.numel());
});
AT_CUDA_CHECK(cudaGetLastError());
return grad_input;
