From 195d03a06e6e1934ad682f4da0df12cdfd65c42a Mon Sep 17 00:00:00 2001
From: "Li-Huai (Allan) Lin"
Date: Tue, 13 Jun 2023 18:13:05 +0800
Subject: [PATCH] ps_roi_align fw

---
 .../csrc/ops/mps/ps_roi_align_kernel.mm   | 215 ++++++++++++++++++
 torchvision/csrc/ops/mps/vision_kernels.h | 118 +++++++++-
 2 files changed, 321 insertions(+), 12 deletions(-)
 create mode 100644 torchvision/csrc/ops/mps/ps_roi_align_kernel.mm

diff --git a/torchvision/csrc/ops/mps/ps_roi_align_kernel.mm b/torchvision/csrc/ops/mps/ps_roi_align_kernel.mm
new file mode 100644
index 00000000000..9e63067d2cb
--- /dev/null
+++ b/torchvision/csrc/ops/mps/ps_roi_align_kernel.mm
@@ -0,0 +1,215 @@
+#include <ATen/ATen.h>
+#include <ATen/mps/MPSProfiler.h>
+#include "vision_kernels.h"
+#include "mps_helpers.h"
+
+#include <cmath>
+#include <iostream>
+
+namespace vision {
+namespace ops {
+
+namespace {
+
+// This should be kept in sync with the value in the Metal kernel.
+int const threadsPerBlock = 512;
+
+std::tuple<at::Tensor, at::Tensor> ps_roi_align_forward_kernel(
+    const at::Tensor& input,
+    const at::Tensor& rois,
+    double spatial_scale,
+    int64_t pooled_height,
+    int64_t pooled_width,
+    int64_t sampling_ratio) {
+
+  using namespace at::native::mps;
+  TORCH_CHECK(input.is_mps(), "input must be a MPS tensor");
+  TORCH_CHECK(rois.is_mps(), "rois must be a MPS tensor");
+  TORCH_CHECK(rois.size(1) == 5, "rois must have shape as Tensor[K, 5]");
+
+  at::TensorArg input_t{input, "input", 1}, rois_t{rois, "rois", 2};
+
+  at::CheckedFrom c = "ps_roi_align_forward_kernel";
+  at::checkAllSameGPU(c, {input_t, rois_t});
+  at::checkAllSameType(c, {input_t, rois_t});
+
+  int64_t num_rois = rois.size(0);
+  int64_t channels = input.size(1);
+  int64_t height = input.size(2);
+  int64_t width = input.size(3);
+  float spatial_scale_f = static_cast<float>(spatial_scale);
+
+  TORCH_CHECK(
+      channels % (pooled_height * pooled_width) == 0,
+      "input channels must be a multiple of pooling height * pooling width");
+
+  int64_t channels_out = channels / (pooled_height * pooled_width);
+
+  auto output = at::zeros(
+      {num_rois, channels_out, pooled_height, pooled_width}, input.options());
+  auto channel_mapping =
+      at::zeros(output.sizes(), input.options().dtype(at::kLong));
+
+  int64_t output_size = output.numel();
+
+  if (output_size == 0) {
+    return std::make_tuple(output, channel_mapping);
+  }
+
+  auto input_ = input.contiguous();
+  auto rois_ = rois.contiguous();
+
+  id<MTLBuffer> inputBuffer = getMTLBufferStorage(input_);
+  id<MTLBuffer> roisBuffer = getMTLBufferStorage(rois_);
+  id<MTLBuffer> outputBuffer = getMTLBufferStorage(output);
+  id<MTLBuffer> channelMappingBuffer = getMTLBufferStorage(channel_mapping);
+  id<MTLDevice> device = MPSDevice::getInstance()->device();
+  MPSStream* mpsStream = getCurrentMPSStream();
+  dispatch_sync(mpsStream->queue(), ^() {
+    @autoreleasepool {
+      id<MTLComputeCommandEncoder> computeEncoder = mpsStream->commandEncoder();
+      MTLSize threadgroupsPerGrid = MTLSizeMake(
+          std::min(ceil_div(static_cast<int64_t>(output_size), static_cast<int64_t>(512)), static_cast<int64_t>(4096)),
+          1,
+          1);
+
+      const std::string kernel = "ps_roi_align_" + scalarToMetalTypeString(input.scalar_type());
+      id<MTLComputePipelineState> binaryPSO = mps::binaryPipelineState(device, kernel);
+
+      // this function call is a no-op if MPS Profiler is not enabled
+      getMPSProfiler().beginProfileKernel(binaryPSO, kernel, {input_, rois_});
+
+      [computeEncoder setComputePipelineState:binaryPSO];
+      // [N, C, H, W]
+      [computeEncoder setBuffer:inputBuffer offset:input_.storage_offset() * input_.element_size() atIndex:0];
+      [computeEncoder setBuffer:roisBuffer offset:rois_.storage_offset() * rois_.element_size() atIndex:1];
+      [computeEncoder setBuffer:outputBuffer
+                      offset:output.storage_offset() * output.element_size()
+                      atIndex:2];
+      [computeEncoder setBuffer:channelMappingBuffer offset:channel_mapping.storage_offset() * channel_mapping.element_size() atIndex:3];
+
+      [computeEncoder setBytes:&output_size length:sizeof(int64_t) atIndex:4];
+      [computeEncoder setBytes:&channels length:sizeof(int64_t) atIndex:5];
+      [computeEncoder setBytes:&height length:sizeof(int64_t) atIndex:6];
+      [computeEncoder setBytes:&width length:sizeof(int64_t) atIndex:7];
+      [computeEncoder setBytes:&pooled_height length:sizeof(int64_t) atIndex:8];
+      [computeEncoder setBytes:&pooled_width length:sizeof(int64_t) atIndex:9];
+      [computeEncoder setBytes:&sampling_ratio length:sizeof(int64_t) atIndex:10];
+      [computeEncoder setBytes:&channels_out length:sizeof(int64_t) atIndex:11];
+      [computeEncoder setBytes:&spatial_scale_f length:sizeof(float) atIndex:12];
+
+      // A threadgroup is the Metal equivalent of a CUDA block.
+      NSUInteger tgSize = binaryPSO.maxTotalThreadsPerThreadgroup;
+      if (tgSize > threadsPerBlock) {
+        tgSize = threadsPerBlock;
+      }
+
+      MTLSize threadGroupSize = MTLSizeMake(tgSize, 1, 1);
+      [computeEncoder dispatchThreadgroups:threadgroupsPerGrid threadsPerThreadgroup:threadGroupSize];
+
+      getMPSProfiler().endProfileKernel(binaryPSO);
+    }
+  });
+  return std::make_tuple(output, channel_mapping);
+}
+
+at::Tensor ps_roi_align_backward_kernel(
+    const at::Tensor& grad,
+    const at::Tensor& rois,
+    double spatial_scale,
+    int64_t pooled_height,
+    int64_t pooled_width,
+    int64_t batch_size,
+    int64_t channels,
+    int64_t height,
+    int64_t width,
+    int64_t sampling_ratio,
+    bool aligned) {
+
+  using namespace at::native::mps;
+  TORCH_CHECK(grad.is_mps(), "grad must be a MPS tensor");
+  TORCH_CHECK(rois.is_mps(), "rois must be a MPS tensor");
+
+  at::TensorArg grad_t{grad, "input", 1}, rois_t{rois, "rois", 2};
+
+  at::CheckedFrom c = "ps_roi_align_backward_kernel";
+  at::checkAllSameGPU(c, {grad_t, rois_t});
+  at::checkAllSameType(c, {grad_t, rois_t});
+
+  float spatial_scale_f = static_cast<float>(spatial_scale);
+
+  at::Tensor grad_input = at::zeros(
+      {batch_size, channels, height, width}, grad.options());
+
+  if (grad.numel() == 0) {
+    return grad_input;
+  }
+
+  int64_t n_stride = grad.stride(0);
+  int64_t c_stride = grad.stride(1);
+  int64_t h_stride = grad.stride(2);
+  int64_t w_stride = grad.stride(3);
+  int64_t output_size = grad.numel();
+
+  at::globalContext().alertNotDeterministic("ps_roi_align_backward_kernel");
+  auto rois_ = rois.contiguous();
+
+  id<MTLBuffer> inputBuffer = getMTLBufferStorage(grad);
+  id<MTLBuffer> roisBuffer = getMTLBufferStorage(rois_);
+  id<MTLBuffer> outputBuffer = getMTLBufferStorage(grad_input);
+  id<MTLDevice> device = MPSDevice::getInstance()->device();
+  MPSStream* mpsStream = getCurrentMPSStream();
+  dispatch_sync(mpsStream->queue(), ^() {
+    @autoreleasepool {
+      id<MTLComputeCommandEncoder> computeEncoder = mpsStream->commandEncoder();
+      MTLSize threadgroupsPerGrid = MTLSizeMake(
+          std::min(ceil_div(static_cast<int64_t>(grad.numel()), static_cast<int64_t>(512)), static_cast<int64_t>(4096)),
+          1,
+          1);
+
+      const std::string kernel = "ps_roi_align_backward_" + scalarToMetalTypeString(grad.scalar_type());
+      id<MTLComputePipelineState> binaryPSO = mps::binaryPipelineState(device, kernel);
+
+      // this function call is a no-op if MPS Profiler is not enabled
+      getMPSProfiler().beginProfileKernel(binaryPSO, kernel, {grad, rois_});
+
+      [computeEncoder setComputePipelineState:binaryPSO];
+      // [N, C, H, W]
+      [computeEncoder setBuffer:inputBuffer offset:grad.storage_offset() * grad.element_size() atIndex:0];
+      [computeEncoder setBuffer:roisBuffer offset:rois_.storage_offset() * rois_.element_size() atIndex:1];
+      [computeEncoder setBuffer:outputBuffer offset:grad_input.storage_offset() * grad_input.element_size() atIndex:2];
+
+      [computeEncoder setBytes:&output_size length:sizeof(int64_t) atIndex:3];
+      [computeEncoder setBytes:&channels length:sizeof(int64_t) atIndex:4];
+      [computeEncoder setBytes:&height length:sizeof(int64_t) atIndex:5];
+      [computeEncoder setBytes:&width length:sizeof(int64_t) atIndex:6];
+      [computeEncoder setBytes:&pooled_height length:sizeof(int64_t) atIndex:7];
+      [computeEncoder setBytes:&pooled_width length:sizeof(int64_t) atIndex:8];
+      [computeEncoder setBytes:&sampling_ratio length:sizeof(int64_t) atIndex:9];
+      [computeEncoder setBytes:&aligned length:sizeof(bool) atIndex:10];
+      [computeEncoder setBytes:&spatial_scale_f length:sizeof(float) atIndex:11];
+      [computeEncoder setBytes:&n_stride length:sizeof(int64_t) atIndex:12];
+      [computeEncoder setBytes:&c_stride length:sizeof(int64_t) atIndex:13];
+      [computeEncoder setBytes:&h_stride length:sizeof(int64_t) atIndex:14];
+      [computeEncoder setBytes:&w_stride length:sizeof(int64_t) atIndex:15];
+
+      // A threadgroup is the Metal equivalent of a CUDA block.
+      NSUInteger tgSize = binaryPSO.maxTotalThreadsPerThreadgroup;
+      if (tgSize > threadsPerBlock) {
+        tgSize = threadsPerBlock;
+      }
+
+      MTLSize threadGroupSize = MTLSizeMake(tgSize, 1, 1);
+      [computeEncoder dispatchThreadgroups:threadgroupsPerGrid threadsPerThreadgroup:threadGroupSize];
+
+      getMPSProfiler().endProfileKernel(binaryPSO);
+    }
+  });
+  return grad_input;
+}
+
+} // namespace
+
+TORCH_LIBRARY_IMPL(torchvision, MPS, m) {
+  m.impl(
+      TORCH_SELECTIVE_NAME("torchvision::ps_roi_align"),
+      TORCH_FN(ps_roi_align_forward_kernel));
+  //m.impl(
+  //    TORCH_SELECTIVE_NAME("torchvision::_ps_roi_align_backward"),
+  //    TORCH_FN(ps_roi_align_backward_kernel));
+}
+
+} // namespace ops
+} // namespace vision
diff --git a/torchvision/csrc/ops/mps/vision_kernels.h b/torchvision/csrc/ops/mps/vision_kernels.h
index 1caf2a73a68..c83ced529ea 100644
--- a/torchvision/csrc/ops/mps/vision_kernels.h
+++ b/torchvision/csrc/ops/mps/vision_kernels.h
@@ -36,24 +36,17 @@ void atomic_add_float( device atomic_uint* atom_var, const float val )
   float fetched_float, assigning_float;
 
   fetched_uint = atomic_exchange_explicit( atom_var, 0, memory_order_relaxed );
-
  fetched_float = *( (thread float*) &fetched_uint );
 
   assigning_float = fetched_float + val;
-
  assigning_uint = *( (thread uint*) &assigning_float );
 
   while ( (fetched_uint = atomic_exchange_explicit( atom_var, assigning_uint, memory_order_relaxed ) ) != 0 ) {
-
-      uint fetched_uint_again = atomic_exchange_explicit( atom_var, 0, memory_order_relaxed );
-
-      float fetched_float_again = *( (thread float*) &fetched_uint_again );
-
-      fetched_float = *( (thread float*) &(fetched_uint) );
-
-      assigning_float = fetched_float_again + fetched_float;
-
-      assigning_uint = *( (thread uint*) &assigning_float );
+      uint fetched_uint_again = atomic_exchange_explicit( atom_var, 0, memory_order_relaxed );
+      float fetched_float_again = *( (thread float*) &fetched_uint_again );
+      fetched_float = *( (thread float*) &(fetched_uint) );
+      assigning_float = fetched_float_again + fetched_float;
+      assigning_uint = *( (thread uint*) &assigning_float );
   }
 }
 
@@ -654,6 +647,107 @@ kernel void roi_pool_backward( \
 REGISTER_ROI_POOL_BACKWARD_OP(float);
 REGISTER_ROI_POOL_BACKWARD_OP(half);
 
+template <typename T>
+kernel void ps_roi_align(
+    constant T * input [[buffer(0)]],
+    constant T * rois [[buffer(1)]],
+    device T * output [[buffer(2)]],
+    device int64_t * channel_mapping [[buffer(3)]],
+    constant int64_t & output_size [[buffer(4)]],
+    constant int64_t & channels [[buffer(5)]],
+    constant int64_t & height [[buffer(6)]],
+    constant int64_t & width [[buffer(7)]],
+    constant int64_t & pooled_height [[buffer(8)]],
+    constant int64_t & pooled_width [[buffer(9)]],
+    constant int64_t & sampling_ratio [[buffer(10)]],
+    constant int64_t & channels_out [[buffer(11)]],
+    constant float & spatial_scale [[buffer(12)]],
+    uint2 tgid [[threadgroup_position_in_grid]],
+    uint2 tptg [[threads_per_threadgroup]],
+    uint2 tid2 [[thread_position_in_threadgroup]]){
+  MPS_1D_KERNEL_LOOP(index, output_size, 1) {
+    // (n, c_out, ph, pw) is an element in the pooled output
+    int pw = index % pooled_width;
+    int ph = (index / pooled_width) % pooled_height;
+    int c_out = (index / pooled_width / pooled_height) % channels_out;
+    int n = index / pooled_width / pooled_height / channels_out;
+
+    // (n, c_in, ph, pw) is the associated element in the input
+    int c_in = (c_out * pooled_height + ph) * pooled_width + pw;
+
+    // [start, end) interval for spatial sampling
+    constant T* offset_rois = rois + n * 5;
+    int roi_batch_ind = offset_rois[0];
+
+    // Do not use rounding; this implementation detail is critical
+    T roi_start_w = offset_rois[1] * spatial_scale - static_cast<T>(0.5);
+    T roi_start_h = offset_rois[2] * spatial_scale - static_cast<T>(0.5);
+    T roi_end_w = offset_rois[3] * spatial_scale - static_cast<T>(0.5);
+    T roi_end_h = offset_rois[4] * spatial_scale - static_cast<T>(0.5);
+
+    T roi_width = roi_end_w - roi_start_w;
+    T roi_height = roi_end_h - roi_start_h;
+    T bin_size_h = static_cast<T>(roi_height) / static_cast<T>(pooled_height);
+    T bin_size_w = static_cast<T>(roi_width) / static_cast<T>(pooled_width);
+
+    // Do not use floor/ceil; this implementation detail is critical
+    T hstart = static_cast<T>(ph) * bin_size_h + roi_start_h;
+    T wstart = static_cast<T>(pw) * bin_size_w + roi_start_w;
+
+    // We use roi_bin_grid to sample the grid and mimic the integral
+    int roi_bin_grid_h = (sampling_ratio > 0)
+        ? sampling_ratio
+        : ceil(roi_height / pooled_height);
+    int roi_bin_grid_w =
+        (sampling_ratio > 0) ? sampling_ratio : ceil(roi_width / pooled_width);
+    const T count = roi_bin_grid_h * roi_bin_grid_w;
+
+    constant T* offset_input =
+        input + (roi_batch_ind * channels + c_in) * height * width;
+    T out_sum = 0;
+    for (int iy = 0; iy < roi_bin_grid_h; iy++) {
+      const T y = hstart +
+          static_cast<T>(iy + .5f) * bin_size_h /
+              static_cast<T>(roi_bin_grid_h);
+      for (int ix = 0; ix < roi_bin_grid_w; ix++) {
+        const T x = wstart +
+            static_cast<T>(ix + .5f) * bin_size_w /
+                static_cast<T>(roi_bin_grid_w);
+        T val = bilinear_interpolate(offset_input, height, width, y, x, index);
+        out_sum += val;
+      }
+    }
+
+    out_sum /= count;
+    output[index] = out_sum;
+    channel_mapping[index] = c_in;
+  }
+}
+
+#define REGISTER_PS_ROI_ALIGN_OP(DTYPE)                         \
+template                                                        \
+[[host_name("ps_roi_align_" #DTYPE)]]                           \
+kernel void ps_roi_align<DTYPE>(                                \
+    constant DTYPE * input [[buffer(0)]],                       \
+    constant DTYPE * rois [[buffer(1)]],                        \
+    device DTYPE * output [[buffer(2)]],                        \
+    device int64_t * channel_mapping [[buffer(3)]],             \
+    constant int64_t & output_size [[buffer(4)]],               \
+    constant int64_t & channels [[buffer(5)]],                  \
+    constant int64_t & height [[buffer(6)]],                    \
+    constant int64_t & width [[buffer(7)]],                     \
+    constant int64_t & pooled_height [[buffer(8)]],             \
+    constant int64_t & pooled_width [[buffer(9)]],              \
+    constant int64_t & sampling_ratio [[buffer(10)]],           \
+    constant int64_t & channels_out [[buffer(11)]],             \
+    constant float & spatial_scale [[buffer(12)]],              \
+    uint2 tgid [[threadgroup_position_in_grid]],                \
+    uint2 tptg [[threads_per_threadgroup]],                     \
+    uint2 tid2 [[thread_position_in_threadgroup]]);
+
+REGISTER_PS_ROI_ALIGN_OP(float);
+REGISTER_PS_ROI_ALIGN_OP(half);
+
 )VISION_METAL";
 
 static id<MTLLibrary> compileBinaryOpsLibrary(id<MTLDevice> device) {
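
---
Usage note (not part of the patch): this change registers only the forward
op for MPS; the backward registration is left commented out. Below is a
minimal smoke-test sketch in Python, assuming a torchvision build that
includes this patch and a machine with an MPS device; shapes, ROI values,
and tolerances are illustrative only.

    # Hypothetical smoke test; not part of the patch.
    import torch
    from torchvision.ops import ps_roi_align

    pooled = (2, 2)
    # Input channels must be a multiple of pooled_height * pooled_width (4 here).
    x = torch.rand(1, 4, 16, 16)
    # Each ROI row is [batch_index, x1, y1, x2, y2].
    rois = torch.tensor([[0.0, 1.0, 1.0, 10.0, 10.0]])

    out_cpu = ps_roi_align(x, rois, pooled, spatial_scale=1.0, sampling_ratio=2)
    out_mps = ps_roi_align(
        x.to("mps"), rois.to("mps"), pooled, spatial_scale=1.0, sampling_ratio=2)

    # Only the forward kernel is wired up for MPS, so compare forward
    # results against the CPU reference implementation.
    torch.testing.assert_close(out_cpu, out_mps.cpu(), rtol=1e-4, atol=1e-4)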