roi_pool bw (failed prec)
qqaatw committed Jun 13, 2023
1 parent 4e47e24 commit da30d23
Showing 2 changed files with 99 additions and 25 deletions.
49 changes: 25 additions & 24 deletions torchvision/csrc/ops/mps/roi_pool_kernel.mm
@@ -103,27 +103,28 @@
 at::Tensor roi_pool_backward_kernel(
     const at::Tensor& grad,
     const at::Tensor& rois,
+    const at::Tensor& argmax,
     double spatial_scale,
     int64_t pooled_height,
     int64_t pooled_width,
     int64_t batch_size,
     int64_t channels,
     int64_t height,
-    int64_t width,
-    int64_t sampling_ratio,
-    bool aligned) {
+    int64_t width) {

   using namespace at::native::mps;
   TORCH_CHECK(grad.is_mps(), "grad must be a MPS tensor");
   TORCH_CHECK(rois.is_mps(), "rois must be a MPS tensor");
+  TORCH_CHECK(argmax.is_mps(), "argmax must be a MPS tensor");

-  at::TensorArg grad_t{grad, "input", 1}, rois_t{rois, "rois", 2};
+  at::TensorArg grad_t{grad, "input", 1}, rois_t{rois, "rois", 2}, argmax_t{argmax, "argmax", 3};

   at::CheckedFrom c = "roi_pool_backward_kernel";
-  at::checkAllSameGPU(c, {grad_t, rois_t});
+  at::checkAllSameGPU(c, {grad_t, rois_t, argmax_t});
   at::checkAllSameType(c, {grad_t, rois_t});

   float spatial_scale_f = static_cast<float>(spatial_scale);
   auto num_rois = rois.size(0);

   at::Tensor grad_input = at::zeros(
       {batch_size, channels, height, width}, grad.options());
@@ -139,10 +140,11 @@
   int64_t output_size = grad.numel();

   at::globalContext().alertNotDeterministic("roi_pool_backward_kernel");
-  auto rois_ = rois.contiguous();
+  auto argmax_ = argmax.contiguous(), rois_ = rois.contiguous();

   id<MTLBuffer> inputBuffer = getMTLBufferStorage(grad);
   id<MTLBuffer> roisBuffer = getMTLBufferStorage(rois_);
+  id<MTLBuffer> argmaxBuffer = getMTLBufferStorage(argmax_);
   id<MTLBuffer> outputBuffer = getMTLBufferStorage(grad_input);
   id<MTLDevice> device = MPSDevice::getInstance()->device();
   MPSStream* mpsStream = getCurrentMPSStream();
@@ -155,27 +157,26 @@
   id<MTLComputePipelineState> binaryPSO = mps::binaryPipelineState(device, kernel);

   // this function call is a no-op if MPS Profiler is not enabled
-  getMPSProfiler().beginProfileKernel(binaryPSO, kernel, {grad, rois_});
+  getMPSProfiler().beginProfileKernel(binaryPSO, kernel, {grad, rois_, argmax_});

   [computeEncoder setComputePipelineState:binaryPSO];
   // [N, C, H, W]
   [computeEncoder setBuffer:inputBuffer offset:grad.storage_offset() * grad.element_size() atIndex:0];
   [computeEncoder setBuffer:roisBuffer offset:rois_.storage_offset() * rois_.element_size() atIndex:1];
-  [computeEncoder setBuffer:outputBuffer offset:grad_input.storage_offset() * grad_input.element_size() atIndex:2];
+  [computeEncoder setBuffer:argmaxBuffer offset:argmax_.storage_offset() * argmax_.element_size() atIndex:2];
+  [computeEncoder setBuffer:outputBuffer offset:grad_input.storage_offset() * grad_input.element_size() atIndex:3];

-  [computeEncoder setBytes:&output_size length:sizeof(int64_t) atIndex:3];
-  [computeEncoder setBytes:&channels length:sizeof(int64_t) atIndex:4];
-  [computeEncoder setBytes:&height length:sizeof(int64_t) atIndex:5];
-  [computeEncoder setBytes:&width length:sizeof(int64_t) atIndex:6];
-  [computeEncoder setBytes:&pooled_height length:sizeof(int64_t) atIndex:7];
-  [computeEncoder setBytes:&pooled_width length:sizeof(int64_t) atIndex:8];
-  [computeEncoder setBytes:&sampling_ratio length:sizeof(int64_t) atIndex:9];
-  [computeEncoder setBytes:&aligned length:sizeof(bool) atIndex:10];
-  [computeEncoder setBytes:&spatial_scale_f length:sizeof(float) atIndex:11];
-  [computeEncoder setBytes:&n_stride length:sizeof(int64_t) atIndex:12];
-  [computeEncoder setBytes:&c_stride length:sizeof(int64_t) atIndex:13];
-  [computeEncoder setBytes:&h_stride length:sizeof(int64_t) atIndex:14];
-  [computeEncoder setBytes:&w_stride length:sizeof(int64_t) atIndex:15];
+  [computeEncoder setBytes:&output_size length:sizeof(int64_t) atIndex:4];
+  [computeEncoder setBytes:&channels length:sizeof(int64_t) atIndex:5];
+  [computeEncoder setBytes:&height length:sizeof(int64_t) atIndex:6];
+  [computeEncoder setBytes:&width length:sizeof(int64_t) atIndex:7];
+  [computeEncoder setBytes:&pooled_height length:sizeof(int64_t) atIndex:8];
+  [computeEncoder setBytes:&pooled_width length:sizeof(int64_t) atIndex:9];
+  [computeEncoder setBytes:&spatial_scale_f length:sizeof(float) atIndex:10];
+  [computeEncoder setBytes:&n_stride length:sizeof(int64_t) atIndex:11];
+  [computeEncoder setBytes:&c_stride length:sizeof(int64_t) atIndex:12];
+  [computeEncoder setBytes:&h_stride length:sizeof(int64_t) atIndex:13];
+  [computeEncoder setBytes:&w_stride length:sizeof(int64_t) atIndex:14];

   // A threadgroup is equivalent to a CUDA block.
   NSUInteger tgSize = binaryPSO.maxTotalThreadsPerThreadgroup;
@@ -198,9 +199,9 @@
   m.impl(
       TORCH_SELECTIVE_NAME("torchvision::roi_pool"),
       TORCH_FN(roi_pool_forward_kernel));
-  //m.impl(
-  //    TORCH_SELECTIVE_NAME("torchvision::_roi_pool_backward"),
-  //    TORCH_FN(roi_pool_backward_kernel));
+  m.impl(
+      TORCH_SELECTIVE_NAME("torchvision::_roi_pool_backward"),
+      TORCH_FN(roi_pool_backward_kernel));
 }

 } // namespace ops
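With the backward registration uncommented above, dispatch for torchvision::_roi_pool_backward now reaches this kernel on MPS devices. As a rough illustration of the new argument order (argmax now sits between rois and spatial_scale), a hypothetical call could look like the sketch below; every tensor name and value here is an assumption for illustration, not code from the repository.

// Hypothetical call site; all names and values are assumptions.
// grad:   gradient of the pooled output, shape [K, C, PH, PW], on MPS
// rois:   K boxes, shape [K, 5], each row (batch_idx, x1, y1, x2, y2)
// argmax: per-bin max-element indices saved by the forward pass, [K, C, PH, PW]
at::Tensor grad_input = roi_pool_backward_kernel(
    grad, rois, argmax,
    /*spatial_scale=*/0.0625,  // 1/16, typical for a stride-16 backbone
    /*pooled_height=*/7, /*pooled_width=*/7,
    /*batch_size=*/2, /*channels=*/256,
    /*height=*/50, /*width=*/50);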
75 changes: 74 additions & 1 deletion torchvision/csrc/ops/mps/vision_kernels.h
@@ -566,7 +566,7 @@ kernel void roi_pool<DTYPE>( \
     constant DTYPE * input [[buffer(0)]], \
     constant DTYPE * rois [[buffer(1)]], \
     device DTYPE * output [[buffer(2)]], \
-    device int64_t * argmax [[buffer(3)]], \
+    device int64_t * argmax_data [[buffer(3)]], \
     constant int64_t & output_size [[buffer(4)]], \
     constant int64_t & channels [[buffer(5)]], \
     constant int64_t & height [[buffer(6)]], \
@@ -581,6 +581,79 @@
 REGISTER_ROI_POOL_OP(float);
 REGISTER_ROI_POOL_OP(half);

+template<typename T>
+kernel void roi_pool_backward(
+    constant T * grad_output [[buffer(0)]],
+    constant T * rois [[buffer(1)]],
+    constant int64_t * argmax_data [[buffer(2)]],
+    device T * grad_input [[buffer(3)]],
+    constant int64_t & output_size [[buffer(4)]],
+    constant int64_t & channels [[buffer(5)]],
+    constant int64_t & height [[buffer(6)]],
+    constant int64_t & width [[buffer(7)]],
+    constant int64_t & pooled_height [[buffer(8)]],
+    constant int64_t & pooled_width [[buffer(9)]],
+    constant float & spatial_scale [[buffer(10)]],
+    constant int64_t & n_stride [[buffer(11)]],
+    constant int64_t & c_stride [[buffer(12)]],
+    constant int64_t & h_stride [[buffer(13)]],
+    constant int64_t & w_stride [[buffer(14)]],
+    uint2 tgid [[threadgroup_position_in_grid]],
+    uint2 tptg [[threads_per_threadgroup]],
+    uint2 tid2 [[thread_position_in_threadgroup]]) {
+  MPS_1D_KERNEL_LOOP(index, output_size, 1) {
+    // (n, c, ph, pw) is an element in the pooled output
+    int pw = index % pooled_width;
+    int ph = (index / pooled_width) % pooled_height;
+    int c = (index / pooled_width / pooled_height) % channels;
+    int n = index / pooled_width / pooled_height / channels;
+
+    constant T* offset_rois = rois + n * 5;
+    int roi_batch_ind = offset_rois[0];
+
+    const int output_offset = n * n_stride + c * c_stride;
+    constant int64_t* argmax_data_offset =
+        argmax_data + (n * channels + c) * pooled_height * pooled_width;
+    const int argmax = argmax_data_offset[ph * pooled_width + pw];
+    const int offset = (roi_batch_ind * channels + c) * height * width;
+
+    if (argmax != -1) {
+      device atomic_uint* xAtomic = (device atomic_uint*)(grad_input + offset + argmax);
+      // atomic_float data type is supported on Metal 3 onward.
+      // TODO: Use native atomic_fetch_add_explicit for Metal 3.
+      atomic_add_float(xAtomic, static_cast<T>(grad_output[output_offset + ph * h_stride + pw * w_stride]));
+    }
+  } // MPS_1D_KERNEL_LOOP
+}
+
+#define REGISTER_ROI_POOL_BACKWARD_OP(DTYPE) \
+template \
+[[host_name("roi_pool_backward_" #DTYPE)]] \
+kernel void roi_pool_backward<DTYPE>( \
+    constant DTYPE * grad_output [[buffer(0)]], \
+    constant DTYPE * rois [[buffer(1)]], \
+    constant int64_t * argmax_data [[buffer(2)]], \
+    device DTYPE * grad_input [[buffer(3)]], \
+    constant int64_t & output_size [[buffer(4)]], \
+    constant int64_t & channels [[buffer(5)]], \
+    constant int64_t & height [[buffer(6)]], \
+    constant int64_t & width [[buffer(7)]], \
+    constant int64_t & pooled_height [[buffer(8)]], \
+    constant int64_t & pooled_width [[buffer(9)]], \
+    constant float & spatial_scale [[buffer(10)]], \
+    constant int64_t & n_stride [[buffer(11)]], \
+    constant int64_t & c_stride [[buffer(12)]], \
+    constant int64_t & h_stride [[buffer(13)]], \
+    constant int64_t & w_stride [[buffer(14)]], \
+    uint2 tgid [[threadgroup_position_in_grid]], \
+    uint2 tptg [[threads_per_threadgroup]], \
+    uint2 tid2 [[thread_position_in_threadgroup]]);
+
+REGISTER_ROI_POOL_BACKWARD_OP(float);
+REGISTER_ROI_POOL_BACKWARD_OP(half);

 )VISION_METAL";

 static id<MTLLibrary> compileBinaryOpsLibrary(id<MTLDevice> device) {
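The new kernel delegates its scatter-add to an atomic_add_float helper defined elsewhere in vision_kernels.h. As a hedged sketch of the usual technique behind such a helper on devices without native float atomics (pre-Metal 3, per the TODO in the kernel), a compare-exchange loop over the value's raw bits looks like this; the helper actually shipped in the header may differ:

// Sketch under assumption: emulate a float atomic add with a CAS loop,
// reinterpreting the stored float's bits as a uint for the exchange.
inline void atomic_add_float(device atomic_uint* addr, const float val) {
  uint expected = atomic_load_explicit(addr, memory_order_relaxed);
  // Try to publish (old + val); on failure, `expected` is refreshed with the
  // value another thread wrote, and the sum is recomputed and retried.
  while (!atomic_compare_exchange_weak_explicit(
             addr,
             &expected,
             as_type<uint>(as_type<float>(expected) + val),
             memory_order_relaxed,
             memory_order_relaxed)) {
  }
}

Metal 3 exposes atomic_fetch_add_explicit on atomic_float directly, which is what the kernel's TODO refers to.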
