From da9c2dec9b023ca573df8a1be557c46e7959d4cb Mon Sep 17 00:00:00 2001 From: "Li-Huai (Allan) Lin" Date: Wed, 31 May 2023 00:26:57 +0800 Subject: [PATCH 01/28] Draft --- CMakeLists.txt | 9 ++ setup.py | 6 ++ test/common_utils.py | 5 + torchvision/csrc/ops/mps/nms_kernel.mm | 107 ++++++++++++++++++++++ torchvision/csrc/ops/mps/vision_kernels.h | 96 +++++++++++++++++++ 5 files changed, 223 insertions(+) create mode 100644 torchvision/csrc/ops/mps/nms_kernel.mm create mode 100644 torchvision/csrc/ops/mps/vision_kernels.h diff --git a/CMakeLists.txt b/CMakeLists.txt index 405f947c233..0cd485d7a24 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -4,6 +4,7 @@ set(CMAKE_CXX_STANDARD 17) file(STRINGS version.txt TORCHVISION_VERSION) option(WITH_CUDA "Enable CUDA support" OFF) +option(WITH_MPS "Enable MPS support" OFF) option(WITH_PNG "Enable features requiring LibPNG." ON) option(WITH_JPEG "Enable features requiring LibJPEG." ON) option(USE_PYTHON "Link to Python when building" OFF) @@ -15,6 +16,11 @@ if(WITH_CUDA) set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} --expt-relaxed-constexpr") endif() +if(WITH_MPS) + enable_language(OBJC OBJCXX) + add_definitions(-DWITH_MPS) +endif() + find_package(Torch REQUIRED) if (WITH_PNG) @@ -79,6 +85,9 @@ list(APPEND ALLOW_LISTED ${TVCPP} ${TVCPP}/io/image ${TVCPP}/io/image/cpu ${TVCP if(WITH_CUDA) list(APPEND ALLOW_LISTED ${TVCPP}/ops/cuda ${TVCPP}/ops/autocast) endif() +if(WITH_MPS) + list(APPEND ALLOW_LISTED ${TVCPP}/ops/mps) +endif() FOREACH(DIR ${ALLOW_LISTED}) file(GLOB ALL_SOURCES ${ALL_SOURCES} ${DIR}/*.*) diff --git a/setup.py b/setup.py index 8b8ddcde1b9..8124783feb5 100644 --- a/setup.py +++ b/setup.py @@ -137,10 +137,13 @@ def get_extensions(): + glob.glob(os.path.join(extensions_dir, "ops", "cpu", "*.cpp")) + glob.glob(os.path.join(extensions_dir, "ops", "quantized", "cpu", "*.cpp")) ) + source_mps = glob.glob(os.path.join(extensions_dir, "ops", "mps", "*.mm")) print("Compiling extensions with following flags:") force_cuda = os.getenv("FORCE_CUDA", "0") == "1" print(f" FORCE_CUDA: {force_cuda}") + force_mps = os.getenv("FORCE_MPS", "0") == "1" + print(f" FORCE_MPS: {force_mps}") debug_mode = os.getenv("DEBUG", "0") == "1" print(f" DEBUG: {debug_mode}") use_png = os.getenv("TORCHVISION_USE_PNG", "1") == "1" @@ -202,6 +205,9 @@ def get_extensions(): define_macros += [("WITH_HIP", None)] nvcc_flags = [] extra_compile_args["nvcc"] = nvcc_flags + elif torch.backends.mps.is_available() or force_mps: + sources += source_mps + define_macros += [("WITH_MPS", None)] if sys.platform == "win32": define_macros += [("torchvision_EXPORTS", None)] diff --git a/test/common_utils.py b/test/common_utils.py index 1d0b82a827c..f2c8125eddf 100644 --- a/test/common_utils.py +++ b/test/common_utils.py @@ -133,6 +133,11 @@ def needs_cuda(test_func): return pytest.mark.needs_cuda(test_func) +def needs_mps(test_func): + import pytest + + return pytest.mark.needs_mps(test_func) + def _create_data(height=3, width=3, channels=3, device="cpu"): # TODO: When all relevant tests are ported to pytest, turn this into a module-level fixture diff --git a/torchvision/csrc/ops/mps/nms_kernel.mm b/torchvision/csrc/ops/mps/nms_kernel.mm new file mode 100644 index 00000000000..e85acbc18d8 --- /dev/null +++ b/torchvision/csrc/ops/mps/nms_kernel.mm @@ -0,0 +1,107 @@ +//#include +#include +#include "vision_kernels.h" + +namespace vision { +namespace ops { + +namespace { + +at::Tensor nms_kernel( + const at::Tensor& dets, + const at::Tensor& scores, + double iou_threshold) { + + using namespace 
at::native::mps; + TORCH_CHECK(dets.is_mps(), "dets must be a MPS tensor"); + TORCH_CHECK(scores.is_mps(), "scores must be a MPS tensor"); + + TORCH_CHECK( + dets.dim() == 2, "boxes should be a 2d tensor, got ", dets.dim(), "D"); + TORCH_CHECK( + dets.size(1) == 4, + "boxes should have 4 elements in dimension 1, got ", + dets.size(1)); + TORCH_CHECK( + scores.dim() == 1, + "scores should be a 1d tensor, got ", + scores.dim(), + "D"); + TORCH_CHECK( + dets.size(0) == scores.size(0), + "boxes and scores should have same number of elements in ", + "dimension 0, got ", + dets.size(0), + " and ", + scores.size(0)) + + //at::Tensor input = at::arange({10}, at::kFloat, c10::nullopt, at::kMPS, c10::nullopt); + //at::Tensor other = at::arange({10}, at::kFloat, c10::nullopt, at::kMPS, c10::nullopt); + //at::Tensor out = at::zeros({10}, at::kFloat, c10::nullopt, at::kMPS, c10::nullopt); + + if (dets.numel() == 0) { + return at::empty({0}, dets.options().dtype(at::kLong)); + } + + auto order_t = std::get<1>( + scores.sort(/*stable=*/true, /*dim=*/0, /* descending=*/true)); + auto dets_sorted = dets.index_select(0, order_t).contiguous(); + int dets_num = dets.size(0); + float iou_threshold_f = static_cast(iou_threshold); + + //TODO: ceil_div + //const int col_blocks = ceil_div(dets_num, threadsPerBlock); + //at::Tensor mask = + // at::empty({dets_num * col_blocks}, dets.options().dtype(at::kLong)); + at::Tensor mask = + at::empty({dets_num}, dets.options().dtype(at::kLong)); + + id inputBuffer = getMTLBufferStorage(dets_sorted); + id outputBuffer = getMTLBufferStorage(mask); + id device = MPSDevice::getInstance()->device(); + MPSStream* mpsStream = getCurrentMPSStream(); + //const uint32_t nDim = iter.ndim(); + //constexpr uint32_t nOffsets = 3; + const uint32_t numThreads = dets_num; + dispatch_sync(mpsStream->queue(), ^() { + @autoreleasepool { + NSError* error = nil; + id computeEncoder = mpsStream->commandEncoder(); + MTLSize gridSize = MTLSizeMake(numThreads, 1, 1); + + + const std::string kernel = "nms_" + scalarToMetalTypeString(dets_sorted.scalar_type()); + id binaryPSO = mps::binaryPipelineState(device, kernel); + + // this function call is a no-op if MPS Profiler is not enabled + //getMPSProfiler().beginProfileKernel(binaryPSO, kernel, {input, other}); + + [computeEncoder setComputePipelineState:binaryPSO]; + [computeEncoder setBuffer:inputBuffer offset:dets_sorted.storage_offset() * dets_sorted.element_size() atIndex:0]; + [computeEncoder setBuffer:outputBuffer offset:mask.storage_offset() * mask.element_size() atIndex:1]; + [computeEncoder setBytes:&dets_num length:sizeof(int) atIndex:2]; + [computeEncoder setBytes:&iou_threshold_f length:sizeof(float) atIndex:3]; + + NSUInteger tgSize = binaryPSO.maxTotalThreadsPerThreadgroup; + if (tgSize > numThreads) { + tgSize = numThreads; + } + + MTLSize threadGroupSize = MTLSizeMake(tgSize, 1, 1); + [computeEncoder dispatchThreads:gridSize threadsPerThreadgroup:threadGroupSize]; + + //getMPSProfiler().endProfileKernel(binaryPSO); + } + }); + return mask; + +} + +} // namespace + +TORCH_LIBRARY_IMPL(torchvision, MPS, m) { + m.impl(TORCH_SELECTIVE_NAME("torchvision::nms"), TORCH_FN(nms_kernel)); +} + +} // namespace ops +} // namespace vision diff --git a/torchvision/csrc/ops/mps/vision_kernels.h b/torchvision/csrc/ops/mps/vision_kernels.h new file mode 100644 index 00000000000..9cd22cae7ac --- /dev/null +++ b/torchvision/csrc/ops/mps/vision_kernels.h @@ -0,0 +1,96 @@ +#include + +namespace vision { +namespace ops { + +namespace mps { + +static const 
char* METAL_VISION = R"VISION_METAL( + +#include +using namespace metal; + +template +bool IoU( + constant T & a, + constant T & b, + scalar_t threshold) { + auto xx1 = max(a.x, b.x); + auto yy1 = max(a.y, b.y); + auto xx2 = min(a.z, b.z); + auto yy2 = min(a.w, b.w); + auto w = max(static_cast(0), xx2 - xx1); + auto h = max(static_cast(0), yy2 - yy1); + auto inter = w * h; + auto area_a = (a.z - a.x) * (a.w - a.y); + auto area_b = (b.z - b.x) * (b.w - b.y); + return (inter / (area_a + area_b - inter)) > threshold; +} + +template +kernel void nms(constant T * input [[buffer(0)]], + device int64_t * out [[buffer(1)]], + constant int & dets_num [[buffer(2)]], + constant float & iou_threshold [[buffer(3)]], + uint tid [[thread_position_in_grid]]) { + int t = 0; + for (int i = tid + 1; i < dets_num; i++){ + if (IoU(input[tid], input[i], iou_threshold)){ + t |= static_cast(1) << i; + } + } + out[tid] = static_cast(t); +} + +#define REGISTER_NMS_OP(DTYPE) \ +template \ +[[host_name("nms_" #DTYPE)]] \ +kernel void nms( \ + constant DTYPE ## 4 * input [[buffer(0)]], \ + device int64_t * out [[buffer(1)]], \ + constant int & dets_num [[buffer(2)]], \ + constant float & iou_threshold [[buffer(3)]], \ + uint tid [[thread_position_in_grid]]); + +REGISTER_NMS_OP(float); +REGISTER_NMS_OP(half); + +)VISION_METAL"; + +static id compileBinaryOpsLibrary(id device) { + static id binaryLibrary = nil; + if (binaryLibrary) { + return binaryLibrary; + } + + NSError* error = nil; + MTLCompileOptions* options = [[MTLCompileOptions new] autorelease]; + [options setLanguageVersion:MTLLanguageVersion2_3]; + binaryLibrary = [device newLibraryWithSource:[NSString stringWithCString:METAL_VISION encoding:NSASCIIStringEncoding] + options:options + error:&error]; + TORCH_CHECK(binaryLibrary, "Failed to create metal binary library, error: ", [[error description] UTF8String]); + return binaryLibrary; +} + +static id binaryPipelineState(id device, const std::string& kernel) { + static std::unordered_map> psoCache; + id pso = psoCache[kernel]; + if (pso) { + return pso; + } + + NSError* error = nil; + id binaryLib = compileBinaryOpsLibrary(device); + id binaryFunc = [binaryLib newFunctionWithName:[NSString stringWithUTF8String:kernel.c_str()]]; + TORCH_CHECK(binaryFunc, "Failed to create function state object for: ", kernel); + pso = [device newComputePipelineStateWithFunction:binaryFunc error:&error]; + TORCH_CHECK(pso, "Failed to created pipeline state object, error: ", [[error description] UTF8String]); + + psoCache[kernel] = pso; + return pso; +} + +} +} +} // namespace \ No newline at end of file From 7f0d4ce4db5e7fc3a3047fd02d0280d4f62cde7f Mon Sep 17 00:00:00 2001 From: "Li-Huai (Allan) Lin" Date: Fri, 9 Jun 2023 18:40:12 +0800 Subject: [PATCH 02/28] NMS f32 --- test/test_ops.py | 37 +++++++++++- torchvision/csrc/ops/cpu/nms_kernel.cpp | 4 +- torchvision/csrc/ops/mps/nms_kernel.mm | 66 ++++++++++++++++----- torchvision/csrc/ops/mps/vision_kernels.h | 70 +++++++++++++++++------ 4 files changed, 140 insertions(+), 37 deletions(-) diff --git a/test/test_ops.py b/test/test_ops.py index 463ebb333ff..4a5ba5f84f1 100644 --- a/test/test_ops.py +++ b/test/test_ops.py @@ -10,7 +10,7 @@ import torch import torch.fx import torch.nn.functional as F -from common_utils import assert_equal, cpu_and_gpu, needs_cuda +from common_utils import assert_equal, cpu_and_gpu, needs_cuda, needs_mps from PIL import Image from torch import nn, Tensor from torch.autograd import gradcheck @@ -722,6 +722,24 @@ def test_nms_cuda(self, iou, 
dtype=torch.float64): is_eq = torch.allclose(scores[r_cpu], scores[r_cuda.cpu()], rtol=tol, atol=tol) assert is_eq, err_msg.format(iou) + @needs_mps + @pytest.mark.parametrize("iou", (0.2, 0.5, 0.8)) + def test_nms_mps(self, iou, dtype=torch.float32): + tol = 1e-3 if dtype is torch.half else 1e-5 + err_msg = "NMS incompatible between CPU and MPS for IoU={}" + + boxes, scores = self._create_tensors_with_iou(1000, iou) + r_cpu = ops.nms(boxes, scores, iou) + r_mps = ops.nms(boxes.to("mps"), scores.to("mps"), iou) + + print(r_cpu.size(), r_mps.size()) + is_eq = torch.allclose(r_cpu, r_mps.cpu()) + if not is_eq: + # if the indices are not the same, ensure that it's because the scores + # are duplicate + is_eq = torch.allclose(scores[r_cpu], scores[r_mps.cpu()], rtol=tol, atol=tol) + assert is_eq, err_msg.format(iou) + @needs_cuda @pytest.mark.parametrize("iou", (0.2, 0.5, 0.8)) @pytest.mark.parametrize("dtype", (torch.float, torch.half)) @@ -745,6 +763,23 @@ def test_nms_cuda_float16(self): keep16 = ops.nms(boxes.to(torch.float16), scores.to(torch.float16), iou_thres) assert_equal(keep32, keep16) + @needs_mps + @pytest.mark.xfail + def test_nms_mps_float16(self): + boxes = torch.tensor( + [ + [285.3538, 185.5758, 1193.5110, 851.4551], + [285.1472, 188.7374, 1192.4984, 851.0669], + [279.2440, 197.9812, 1189.4746, 849.2019], + ] + ).to("mps") + scores = torch.tensor([0.6370, 0.7569, 0.3966]).to("mps") + + iou_thres = 0.2 + keep32 = ops.nms(boxes, scores, iou_thres) + keep16 = ops.nms(boxes.to(torch.float16), scores.to(torch.float16), iou_thres) + assert_equal(keep32, keep16) + @pytest.mark.parametrize("seed", range(10)) def test_batched_nms_implementations(self, seed): """Make sure that both implementations of batched_nms yield identical results""" diff --git a/torchvision/csrc/ops/cpu/nms_kernel.cpp b/torchvision/csrc/ops/cpu/nms_kernel.cpp index c54d1f00148..50479066cbd 100644 --- a/torchvision/csrc/ops/cpu/nms_kernel.cpp +++ b/torchvision/csrc/ops/cpu/nms_kernel.cpp @@ -11,8 +11,8 @@ at::Tensor nms_kernel_impl( const at::Tensor& dets, const at::Tensor& scores, double iou_threshold) { - TORCH_CHECK(!dets.is_cuda(), "dets must be a CPU tensor"); - TORCH_CHECK(!scores.is_cuda(), "scores must be a CPU tensor"); + TORCH_CHECK(dets.is_cpu(), "dets must be a CPU tensor"); + TORCH_CHECK(scores.is_cpu(), "scores must be a CPU tensor"); TORCH_CHECK( dets.scalar_type() == scores.scalar_type(), "dets should have the same type as scores"); diff --git a/torchvision/csrc/ops/mps/nms_kernel.mm b/torchvision/csrc/ops/mps/nms_kernel.mm index e85acbc18d8..f2c30fe29e4 100644 --- a/torchvision/csrc/ops/mps/nms_kernel.mm +++ b/torchvision/csrc/ops/mps/nms_kernel.mm @@ -1,12 +1,18 @@ -//#include +#include #include #include "vision_kernels.h" +#include +#include + namespace vision { namespace ops { namespace { +// This should be in sync with the one in metal kernel. 
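+// Boxes are handled in blocks of 64: each row box gets one 64-bit mask word per column block, and
+// bit i of that word records whether column box i overlaps the row box above the IoU threshold.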
+int const threadsPerBlock = sizeof(uint64_t) * 8; + at::Tensor nms_kernel( const at::Tensor& dets, const at::Tensor& scores, @@ -49,12 +55,9 @@ int dets_num = dets.size(0); float iou_threshold_f = static_cast(iou_threshold); - //TODO: ceil_div - //const int col_blocks = ceil_div(dets_num, threadsPerBlock); - //at::Tensor mask = - // at::empty({dets_num * col_blocks}, dets.options().dtype(at::kLong)); + const int col_blocks = (dets_num + threadsPerBlock - 1) / threadsPerBlock; at::Tensor mask = - at::empty({dets_num}, dets.options().dtype(at::kLong)); + at::empty({dets_num * col_blocks}, dets.options().dtype(at::kLong)); id inputBuffer = getMTLBufferStorage(dets_sorted); id outputBuffer = getMTLBufferStorage(mask); @@ -65,16 +68,14 @@ const uint32_t numThreads = dets_num; dispatch_sync(mpsStream->queue(), ^() { @autoreleasepool { - NSError* error = nil; id computeEncoder = mpsStream->commandEncoder(); - MTLSize gridSize = MTLSizeMake(numThreads, 1, 1); - + MTLSize threadgroupsPerGrid = MTLSizeMake(col_blocks, col_blocks, 1); const std::string kernel = "nms_" + scalarToMetalTypeString(dets_sorted.scalar_type()); id binaryPSO = mps::binaryPipelineState(device, kernel); // this function call is a no-op if MPS Profiler is not enabled - //getMPSProfiler().beginProfileKernel(binaryPSO, kernel, {input, other}); + getMPSProfiler().beginProfileKernel(binaryPSO, kernel, {dets, scores}); [computeEncoder setComputePipelineState:binaryPSO]; [computeEncoder setBuffer:inputBuffer offset:dets_sorted.storage_offset() * dets_sorted.element_size() atIndex:0]; @@ -82,18 +83,53 @@ [computeEncoder setBytes:&dets_num length:sizeof(int) atIndex:2]; [computeEncoder setBytes:&iou_threshold_f length:sizeof(float) atIndex:3]; + // A threadGroup is equivalent to a cuda's block. 
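+      // One threadgroup per (row block, column block) pair: the grid is col_blocks x col_blocks, and
+      // each group runs at most threadsPerBlock (64) threads, one per row box in its block.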
NSUInteger tgSize = binaryPSO.maxTotalThreadsPerThreadgroup; - if (tgSize > numThreads) { - tgSize = numThreads; + if (tgSize > threadsPerBlock) { + tgSize = threadsPerBlock; } MTLSize threadGroupSize = MTLSizeMake(tgSize, 1, 1); - [computeEncoder dispatchThreads:gridSize threadsPerThreadgroup:threadGroupSize]; + [computeEncoder dispatchThreadgroups:threadgroupsPerGrid threadsPerThreadgroup:threadGroupSize]; - //getMPSProfiler().endProfileKernel(binaryPSO); + getMPSProfiler().endProfileKernel(binaryPSO); } }); - return mask; + + // out[det] = + int64_t num_to_keep = 0; + + at::Tensor mask_cpu = mask.to(at::kCPU); // tid or + unsigned long long* mask_host = + (unsigned long long*)mask_cpu.data_ptr(); + + std::vector remv(col_blocks); + memset(&remv[0], 0, sizeof(unsigned long long) * col_blocks); + + at::Tensor keep = + at::empty({dets_num}, dets.options().dtype(at::kLong).device(at::kCPU)); + int64_t* keep_out = keep.data_ptr(); + + for (int i = 0; i < dets_num; i++) { + int nblock = i / threadsPerBlock; + int inblock = i % threadsPerBlock; + + //std::cout << "remv:" << remv[nblock] << "cur:" << (1ULL << i) << std::endl; + if (!(remv[nblock] & (1ULL << inblock))) { + keep_out[num_to_keep++] = i; + unsigned long long* p = mask_host + i * col_blocks; + for (int j = nblock; j < col_blocks; j++) { + remv[j] |= p[j]; + } + } else { + //std::cout << "SKIP at:" << i << std::endl; + } + } + //std::cout << "NTK: " << num_to_keep << std::endl; + //std::cout << "SUM mask: " << mask_cpu.sum() << std::endl; + return order_t.index( + {keep.narrow(/*dim=*/0, /*start=*/0, /*length=*/num_to_keep) + .to(order_t.device(), keep.scalar_type())}); } diff --git a/torchvision/csrc/ops/mps/vision_kernels.h b/torchvision/csrc/ops/mps/vision_kernels.h index 9cd22cae7ac..4bfd2603e2a 100644 --- a/torchvision/csrc/ops/mps/vision_kernels.h +++ b/torchvision/csrc/ops/mps/vision_kernels.h @@ -10,11 +10,18 @@ static const char* METAL_VISION = R"VISION_METAL( #include using namespace metal; +constant uint threadsPerBlock = sizeof(uint64_t) * 8; + +template +inline T ceil_div(T n, T m) { + return (n + m - 1) / m; +} + template -bool IoU( +bool inline IoU( constant T & a, - constant T & b, - scalar_t threshold) { + threadgroup T & b, + const float threshold) { auto xx1 = max(a.x, b.x); auto yy1 = max(a.y, b.y); auto xx2 = min(a.z, b.z); @@ -28,29 +35,54 @@ bool IoU( } template -kernel void nms(constant T * input [[buffer(0)]], - device int64_t * out [[buffer(1)]], - constant int & dets_num [[buffer(2)]], +kernel void nms(constant T * dev_boxes [[buffer(0)]], + device uint64_t * mask [[buffer(1)]], + constant int & n_boxes [[buffer(2)]], constant float & iou_threshold [[buffer(3)]], - uint tid [[thread_position_in_grid]]) { - int t = 0; - for (int i = tid + 1; i < dets_num; i++){ - if (IoU(input[tid], input[i], iou_threshold)){ - t |= static_cast(1) << i; + uint2 tgid [[threadgroup_position_in_grid]], + uint2 tid2 [[thread_position_in_threadgroup]]) { + + const uint row_start = tgid.y; + const uint col_start = tgid.x; + const uint tid = tid2.x; + const uint row_size = + min(n_boxes - row_start * threadsPerBlock, threadsPerBlock); + const uint col_size = + min(n_boxes - col_start * threadsPerBlock, threadsPerBlock); + + threadgroup T block_boxes[threadsPerBlock]; + block_boxes[tid] = dev_boxes[threadsPerBlock * col_start + tid]; + threadgroup_barrier(mem_flags::mem_threadgroup); + + if (tid < row_size) { + const uint cur_box_idx = threadsPerBlock * row_start + tid; + uint64_t t = 0; + uint start = 0; + + if (row_start == 
col_start) { + start = tid + 1; + } + + for (uint i = start; i < col_size; i++){ + if (IoU(dev_boxes[cur_box_idx], block_boxes[i], iou_threshold)){ + t |= static_cast(1) << i; // discard 1 keep 0 + } } + const uint col_blocks = ceil_div(static_cast(n_boxes), threadsPerBlock); + mask[cur_box_idx * col_blocks + col_start] = t; } - out[tid] = static_cast(t); } #define REGISTER_NMS_OP(DTYPE) \ -template \ +template \ [[host_name("nms_" #DTYPE)]] \ -kernel void nms( \ - constant DTYPE ## 4 * input [[buffer(0)]], \ - device int64_t * out [[buffer(1)]], \ - constant int & dets_num [[buffer(2)]], \ - constant float & iou_threshold [[buffer(3)]], \ - uint tid [[thread_position_in_grid]]); +kernel void nms( \ + constant DTYPE ## 4 * dev_boxes [[buffer(0)]], \ + device uint64_t * mask [[buffer(1)]], \ + constant int & n_boxes [[buffer(2)]], \ + constant float & iou_threshold [[buffer(3)]], \ + uint2 tgid [[threadgroup_position_in_grid]], \ + uint2 tid2 [[thread_position_in_threadgroup]]); REGISTER_NMS_OP(float); REGISTER_NMS_OP(half); From c7c43dccc313373e7123167b82f88204a018689e Mon Sep 17 00:00:00 2001 From: "Li-Huai (Allan) Lin" Date: Sat, 10 Jun 2023 16:30:27 +0800 Subject: [PATCH 03/28] roi_align fw --- test/common_utils.py | 2 + test/test_ops.py | 16 +- torchvision/csrc/ops/mps/mps_helpers.h | 4 + torchvision/csrc/ops/mps/roi_align_kernel.mm | 199 +++++++++++++++++ torchvision/csrc/ops/mps/vision_kernels.h | 217 +++++++++++++++++++ 5 files changed, 431 insertions(+), 7 deletions(-) create mode 100644 torchvision/csrc/ops/mps/mps_helpers.h create mode 100644 torchvision/csrc/ops/mps/roi_align_kernel.mm diff --git a/test/common_utils.py b/test/common_utils.py index f2c8125eddf..c3fd743cf16 100644 --- a/test/common_utils.py +++ b/test/common_utils.py @@ -127,6 +127,8 @@ def cpu_and_gpu(): return ("cpu", pytest.param("cuda", marks=pytest.mark.needs_cuda)) +def cpu_and_gpu_and_mps(): + return cpu_and_gpu() + (pytest.param("mps", marks=pytest.mark.needs_mps),) def needs_cuda(test_func): import pytest # noqa diff --git a/test/test_ops.py b/test/test_ops.py index 4a5ba5f84f1..4194b4d760b 100644 --- a/test/test_ops.py +++ b/test/test_ops.py @@ -10,7 +10,7 @@ import torch import torch.fx import torch.nn.functional as F -from common_utils import assert_equal, cpu_and_gpu, needs_cuda, needs_mps +from common_utils import assert_equal, cpu_and_gpu, cpu_and_gpu_and_mps, needs_cuda, needs_mps from PIL import Image from torch import nn, Tensor from torch.autograd import gradcheck @@ -96,12 +96,14 @@ def forward(self, imgs: Tensor, boxes: List[Tensor]) -> Tensor: class RoIOpTester(ABC): dtype = torch.float64 + mps_dtype = torch.float32 - @pytest.mark.parametrize("device", cpu_and_gpu()) + @pytest.mark.parametrize("device", cpu_and_gpu_and_mps()) @pytest.mark.parametrize("contiguous", (True, False)) def test_forward(self, device, contiguous, x_dtype=None, rois_dtype=None, deterministic=False, **kwargs): - x_dtype = self.dtype if x_dtype is None else x_dtype - rois_dtype = self.dtype if rois_dtype is None else rois_dtype + dtype = self.mps_dtype if device == "mps" else self.dtype + x_dtype = dtype if x_dtype is None else x_dtype + rois_dtype = dtype if rois_dtype is None else rois_dtype pool_size = 5 # n_channels % (pool_size ** 2) == 0 required for PS operations. n_channels = 2 * (pool_size**2) @@ -120,7 +122,7 @@ def test_forward(self, device, contiguous, x_dtype=None, rois_dtype=None, determ # the following should be true whether we're running an autocast test or not. 
assert y.dtype == x.dtype gt_y = self.expected_fn( - x, rois, pool_h, pool_w, spatial_scale=1, sampling_ratio=-1, device=device, dtype=self.dtype, **kwargs + x, rois, pool_h, pool_w, spatial_scale=1, sampling_ratio=-1, device=device, dtype=dtype, **kwargs ) tol = 1e-3 if (x_dtype is torch.half or rois_dtype is torch.half) else 1e-5 @@ -418,7 +420,7 @@ def test_boxes_shape(self): self._helper_boxes_shape(ops.roi_align) @pytest.mark.parametrize("aligned", (True, False)) - @pytest.mark.parametrize("device", cpu_and_gpu()) + @pytest.mark.parametrize("device", cpu_and_gpu_and_mps()) @pytest.mark.parametrize("contiguous", (True, False)) @pytest.mark.parametrize("deterministic", (True, False)) def test_forward(self, device, contiguous, deterministic, aligned, x_dtype=None, rois_dtype=None): @@ -450,7 +452,7 @@ def test_autocast(self, aligned, deterministic, x_dtype, rois_dtype): ) @pytest.mark.parametrize("seed", range(10)) - @pytest.mark.parametrize("device", cpu_and_gpu()) + @pytest.mark.parametrize("device", cpu_and_gpu_and_mps()) @pytest.mark.parametrize("contiguous", (True, False)) @pytest.mark.parametrize("deterministic", (True, False)) def test_backward(self, seed, device, contiguous, deterministic): diff --git a/torchvision/csrc/ops/mps/mps_helpers.h b/torchvision/csrc/ops/mps/mps_helpers.h new file mode 100644 index 00000000000..48afe765d93 --- /dev/null +++ b/torchvision/csrc/ops/mps/mps_helpers.h @@ -0,0 +1,4 @@ +template +constexpr inline T ceil_div(T n, T m) { + return (n + m - 1) / m; +} \ No newline at end of file diff --git a/torchvision/csrc/ops/mps/roi_align_kernel.mm b/torchvision/csrc/ops/mps/roi_align_kernel.mm new file mode 100644 index 00000000000..32fe7a96520 --- /dev/null +++ b/torchvision/csrc/ops/mps/roi_align_kernel.mm @@ -0,0 +1,199 @@ +#include +#include +#include "vision_kernels.h" +#include "mps_helpers.h" + +#include +#include + +namespace vision { +namespace ops { + +namespace { + +// This should be in sync with the one in metal kernel. 
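+// Unlike the NMS kernel, roi_align launches a 1-D grid over the flattened output and each thread
+// strides over elements via MPS_1D_KERNEL_LOOP, so 512 here only bounds the threadgroup size.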
+int const threadsPerBlock = 512; + +at::Tensor roi_align_forward_kernel( + const at::Tensor& input, + const at::Tensor& rois, + double spatial_scale, + int64_t pooled_height, + int64_t pooled_width, + int64_t sampling_ratio, + bool aligned) { + + using namespace at::native::mps; + TORCH_CHECK(input.is_mps(), "input must be a MPS tensor"); + TORCH_CHECK(rois.is_mps(), "rois must be a MPS tensor"); + TORCH_CHECK(rois.size(1) == 5, "rois must have shape as Tensor[K, 5]"); + + at::TensorArg input_t{input, "input", 1}, rois_t{rois, "rois", 2}; + + at::CheckedFrom c = "roi_align_forward_kernel"; + at::checkAllSameGPU(c, {input_t, rois_t}); + at::checkAllSameType(c, {input_t, rois_t}); + + int64_t num_rois = rois.size(0); + int64_t channels = input.size(1); + int64_t height = input.size(2); + int64_t width = input.size(3); + float spatial_scale_f = static_cast(spatial_scale); + + at::Tensor output = at::zeros( + {num_rois, channels, pooled_height, pooled_width}, input.options()); + + int64_t output_size = num_rois * pooled_height * pooled_width * channels; + + if (output.numel() == 0) { + return output; + } + + auto input_ = input.contiguous(); + auto rois_ = rois.contiguous(); + + id inputBuffer = getMTLBufferStorage(input_); + id roisBuffer = getMTLBufferStorage(rois_); + id outputBuffer = getMTLBufferStorage(output); + id device = MPSDevice::getInstance()->device(); + MPSStream* mpsStream = getCurrentMPSStream(); + dispatch_sync(mpsStream->queue(), ^() { + @autoreleasepool { + id computeEncoder = mpsStream->commandEncoder(); + MTLSize threadgroupsPerGrid = MTLSizeMake(std::min(ceil_div(static_cast(output_size), static_cast(512)), static_cast(4096)), 1, 1); + + const std::string kernel = "roi_align_" + scalarToMetalTypeString(input.scalar_type()); + id binaryPSO = mps::binaryPipelineState(device, kernel); + + // this function call is a no-op if MPS Profiler is not enabled + getMPSProfiler().beginProfileKernel(binaryPSO, kernel, {input_, rois_}); + + [computeEncoder setComputePipelineState:binaryPSO]; + // [N, C, H, W] + [computeEncoder setBuffer:inputBuffer offset:input_.storage_offset() * input_.element_size() atIndex:0]; + [computeEncoder setBuffer:roisBuffer offset:rois_.storage_offset() * rois_.element_size() atIndex:1]; + [computeEncoder setBuffer:outputBuffer offset:output.storage_offset() * output.element_size() atIndex:2]; + + [computeEncoder setBytes:&output_size length:sizeof(int64_t) atIndex:3]; + [computeEncoder setBytes:&channels length:sizeof(int64_t) atIndex:4]; + [computeEncoder setBytes:&height length:sizeof(int64_t) atIndex:5]; + [computeEncoder setBytes:&width length:sizeof(int64_t) atIndex:6]; + [computeEncoder setBytes:&pooled_height length:sizeof(int64_t) atIndex:7]; + [computeEncoder setBytes:&pooled_width length:sizeof(int64_t) atIndex:8]; + [computeEncoder setBytes:&sampling_ratio length:sizeof(int64_t) atIndex:9]; + [computeEncoder setBytes:&aligned length:sizeof(bool) atIndex:10]; + [computeEncoder setBytes:&spatial_scale_f length:sizeof(float) atIndex:11]; + + // A threadGroup is equivalent to a cuda's block. 
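+        // The grid is capped at min(ceil_div(output_size, 512), 4096) threadgroups; leftover output
+        // elements are handled by the striding MPS_1D_KERNEL_LOOP inside the Metal kernel.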
+ NSUInteger tgSize = binaryPSO.maxTotalThreadsPerThreadgroup; + if (tgSize > threadsPerBlock) { + tgSize = threadsPerBlock; + } + + MTLSize threadGroupSize = MTLSizeMake(tgSize, 1, 1); + [computeEncoder dispatchThreadgroups:threadgroupsPerGrid threadsPerThreadgroup:threadGroupSize]; + + getMPSProfiler().endProfileKernel(binaryPSO); + } + }); + return output; +} + +at::Tensor roi_align_backward_kernel( + const at::Tensor& input, + const at::Tensor& rois, + double spatial_scale, + int64_t pooled_height, + int64_t pooled_width, + int64_t sampling_ratio, + bool aligned) { + + using namespace at::native::mps; + TORCH_CHECK(input.is_mps(), "input must be a MPS tensor"); + TORCH_CHECK(rois.is_mps(), "rois must be a MPS tensor"); + TORCH_CHECK(rois.size(1) == 5, "rois must have shape as Tensor[K, 5]"); + + at::TensorArg input_t{input, "input", 1}, rois_t{rois, "rois", 2}; + + at::CheckedFrom c = "roi_align_forward_kernel"; + at::checkAllSameGPU(c, {input_t, rois_t}); + at::checkAllSameType(c, {input_t, rois_t}); + + int64_t num_rois = rois.size(0); + int64_t channels = input.size(1); + int64_t height = input.size(2); + int64_t width = input.size(3); + float spatial_scale_f = static_cast(spatial_scale); + + at::Tensor output = at::zeros( + {num_rois, channels, pooled_height, pooled_width}, input.options()); + + int64_t output_size = num_rois * pooled_height * pooled_width * channels; + + if (output.numel() == 0) { + return output; + } + + auto input_ = input.contiguous(); + auto rois_ = rois.contiguous(); + + id inputBuffer = getMTLBufferStorage(input_); + id roisBuffer = getMTLBufferStorage(rois_); + id outputBuffer = getMTLBufferStorage(output); + id device = MPSDevice::getInstance()->device(); + MPSStream* mpsStream = getCurrentMPSStream(); + dispatch_sync(mpsStream->queue(), ^() { + @autoreleasepool { + id computeEncoder = mpsStream->commandEncoder(); + MTLSize threadgroupsPerGrid = MTLSizeMake(std::min(ceil_div(static_cast(output_size), static_cast(512)), static_cast(4096)), 1, 1); + + const std::string kernel = "roi_align_" + scalarToMetalTypeString(input.scalar_type()); + id binaryPSO = mps::binaryPipelineState(device, kernel); + + // this function call is a no-op if MPS Profiler is not enabled + getMPSProfiler().beginProfileKernel(binaryPSO, kernel, {input_, rois_}); + + [computeEncoder setComputePipelineState:binaryPSO]; + // [N, C, H, W] + [computeEncoder setBuffer:inputBuffer offset:input_.storage_offset() * input_.element_size() atIndex:0]; + [computeEncoder setBuffer:roisBuffer offset:rois_.storage_offset() * rois_.element_size() atIndex:1]; + [computeEncoder setBuffer:outputBuffer offset:output.storage_offset() * output.element_size() atIndex:2]; + + [computeEncoder setBytes:&output_size length:sizeof(int64_t) atIndex:3]; + [computeEncoder setBytes:&channels length:sizeof(int64_t) atIndex:4]; + [computeEncoder setBytes:&height length:sizeof(int64_t) atIndex:5]; + [computeEncoder setBytes:&width length:sizeof(int64_t) atIndex:6]; + [computeEncoder setBytes:&pooled_height length:sizeof(int64_t) atIndex:7]; + [computeEncoder setBytes:&pooled_width length:sizeof(int64_t) atIndex:8]; + [computeEncoder setBytes:&sampling_ratio length:sizeof(int64_t) atIndex:9]; + [computeEncoder setBytes:&aligned length:sizeof(bool) atIndex:10]; + [computeEncoder setBytes:&spatial_scale_f length:sizeof(float) atIndex:11]; + + // A threadGroup is equivalent to a cuda's block. 
+ NSUInteger tgSize = binaryPSO.maxTotalThreadsPerThreadgroup; + if (tgSize > threadsPerBlock) { + tgSize = threadsPerBlock; + } + + MTLSize threadGroupSize = MTLSizeMake(tgSize, 1, 1); + [computeEncoder dispatchThreadgroups:threadgroupsPerGrid threadsPerThreadgroup:threadGroupSize]; + + getMPSProfiler().endProfileKernel(binaryPSO); + } + }); + return output; +} + +} // namespace + +TORCH_LIBRARY_IMPL(torchvision, MPS, m) { + m.impl( + TORCH_SELECTIVE_NAME("torchvision::roi_align"), + TORCH_FN(roi_align_forward_kernel)); + //m.impl( + // TORCH_SELECTIVE_NAME("torchvision::_roi_align_backward"), + // TORCH_FN(roi_align_backward_kernel)); +} + +} // namespace ops +} // namespace vision diff --git a/torchvision/csrc/ops/mps/vision_kernels.h b/torchvision/csrc/ops/mps/vision_kernels.h index 4bfd2603e2a..4ec16503b7c 100644 --- a/torchvision/csrc/ops/mps/vision_kernels.h +++ b/torchvision/csrc/ops/mps/vision_kernels.h @@ -10,6 +10,15 @@ static const char* METAL_VISION = R"VISION_METAL( #include using namespace metal; +#define MPS_1D_KERNEL_LOOP_T(i, n, n_grids, index_t) \ + for (index_t i = (tgid.x * tptg.x) + tid2.x; i < (n); \ + i += (tptg.x * n_grids)) + +#define MPS_1D_KERNEL_LOOP(i, n, n_grids) MPS_1D_KERNEL_LOOP_T(i, n, n_grids, int) + +// This should be in sync with the one in nms_kernel.mm. +// Since metal does not support dynamic array, +// we need to make it static instead of deriving it from [[threads_per_threadgroup]]. constant uint threadsPerBlock = sizeof(uint64_t) * 8; template @@ -17,6 +26,60 @@ inline T ceil_div(T n, T m) { return (n + m - 1) / m; } +template +inline T bilinear_interpolate( + constant T* input, + int64_t height, + int64_t width, + T y, + T x, + uint index /* index for debug only*/) { + // deal with cases that inverse elements are out of feature map boundary + if (y < -1.0 || y > height || x < -1.0 || x > width) { + // empty + return 0; + } + + if (y <= 0) + y = 0; + if (x <= 0) + x = 0; + + int y_low = (int)y; + int x_low = (int)x; + int y_high; + int x_high; + + if (y_low >= height - 1) { + y_high = y_low = height - 1; + y = (T)y_low; + } else { + y_high = y_low + 1; + } + + if (x_low >= width - 1) { + x_high = x_low = width - 1; + x = (T)x_low; + } else { + x_high = x_low + 1; + } + + T ly = y - y_low; + T lx = x - x_low; + T hy = 1. - ly, hx = 1. 
- lx; + + // do bilinear interpolation + T v1 = input[y_low * width + x_low]; + T v2 = input[y_low * width + x_high]; + T v3 = input[y_high * width + x_low]; + T v4 = input[y_high * width + x_high]; + T w1 = hy * hx, w2 = hy * lx, w3 = ly * hx, w4 = ly * lx; + + T val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4); + + return val; +} + template bool inline IoU( constant T & a, @@ -87,6 +150,160 @@ kernel void nms( \ REGISTER_NMS_OP(float); REGISTER_NMS_OP(half); +template +kernel void roi_align( + constant T * input [[buffer(0)]], + constant T * rois [[buffer(1)]], + device T * output [[buffer(2)]], + constant int64_t & output_size [[buffer(3)]], + constant int64_t & channels [[buffer(4)]], + constant int64_t & height [[buffer(5)]], + constant int64_t & width [[buffer(6)]], + constant int64_t & pooled_height [[buffer(7)]], + constant int64_t & pooled_width [[buffer(8)]], + constant int64_t & sampling_ratio [[buffer(9)]], + constant bool & aligned [[buffer(10)]], + constant float & spatial_scale [[buffer(11)]], + uint2 tgid [[threadgroup_position_in_grid]], + uint2 tptg [[threads_per_threadgroup]], + uint2 tid2 [[thread_position_in_threadgroup]]){ + MPS_1D_KERNEL_LOOP(index, output_size, 1) { + // (n, c, ph, pw) is an element in the pooled output + int pw = index % pooled_width; + int ph = (index / pooled_width) % pooled_height; + int c = (index / pooled_width / pooled_height) % channels; + int n = index / pooled_width / pooled_height / channels; + + constant T* offset_rois = rois + n * 5; + int roi_batch_ind = offset_rois[0]; + + // Do not using rounding; this implementation detail is critical + T offset = aligned ? (T)0.5 : (T)0.0; + T roi_start_w = offset_rois[1] * spatial_scale - offset; + T roi_start_h = offset_rois[2] * spatial_scale - offset; + T roi_end_w = offset_rois[3] * spatial_scale - offset; + T roi_end_h = offset_rois[4] * spatial_scale - offset; + + T roi_width = roi_end_w - roi_start_w; + T roi_height = roi_end_h - roi_start_h; + if (!aligned) { + // Force malformed ROIs to be 1x1 + roi_width = max(roi_width, (T)1.); + roi_height = max(roi_height, (T)1.); + } + + T bin_size_h = static_cast(roi_height) / static_cast(pooled_height); + T bin_size_w = static_cast(roi_width) / static_cast(pooled_width); + + constant T* offset_input = + input + (roi_batch_ind * channels + c) * height * width; + + // We use roi_bin_grid to sample the grid and mimic integral + int roi_bin_grid_h = (sampling_ratio > 0) + ? sampling_ratio + : ceil(roi_height / pooled_height); // e.g., = 2 + int roi_bin_grid_w = + (sampling_ratio > 0) ? sampling_ratio : ceil(roi_width / pooled_width); + + // We do average (integral) pooling inside a bin + // When the grid is empty, output zeros. + const T count = max(roi_bin_grid_h * roi_bin_grid_w, 1); // e.g. 
= 4 + + T output_val = 0.; + for (int iy = 0; iy < roi_bin_grid_h; iy++) // e.g., iy = 0, 1 + { + const T y = roi_start_h + ph * bin_size_h + + static_cast(iy + .5f) * bin_size_h / + static_cast(roi_bin_grid_h); // e.g., 0.5, 1.5 + for (int ix = 0; ix < roi_bin_grid_w; ix++) { + const T x = roi_start_w + pw * bin_size_w + + static_cast(ix + .5f) * bin_size_w / + static_cast(roi_bin_grid_w); + + T val = bilinear_interpolate(offset_input, height, width, y, x, index); + output_val += val; + } + } + output_val /= count; + + output[index] = output_val; + } +} + +#define REGISTER_ROI_ALIGN_OP(DTYPE) \ +template \ +[[host_name("roi_align_" #DTYPE)]] \ +kernel void roi_align( \ + constant DTYPE * input [[buffer(0)]], \ + constant DTYPE * rois [[buffer(1)]], \ + device DTYPE * output [[buffer(2)]], \ + constant int64_t & output_size [[buffer(3)]], \ + constant int64_t & channels [[buffer(4)]], \ + constant int64_t & height [[buffer(5)]], \ + constant int64_t & width [[buffer(6)]], \ + constant int64_t & pooled_height [[buffer(7)]], \ + constant int64_t & pooled_width [[buffer(8)]], \ + constant int64_t & sampling_ratio [[buffer(9)]], \ + constant bool & aligned [[buffer(10)]], \ + constant float & spatial_scale [[buffer(11)]], \ + uint2 tgid [[threadgroup_position_in_grid]], \ + uint2 tptg [[threads_per_threadgroup]], \ + uint2 tid2 [[thread_position_in_threadgroup]]); + +REGISTER_ROI_ALIGN_OP(float); +REGISTER_ROI_ALIGN_OP(half); + +template +kernel void roi_align_backward( + constant T * grad_output [[buffer(0)]], + constant T * rois [[buffer(1)]], + device T * grad_input [[buffer(2)]], + constant int64_t & output_size [[buffer(3)]], + constant int64_t & channels [[buffer(4)]], + constant int64_t & height [[buffer(5)]], + constant int64_t & width [[buffer(6)]], + constant int64_t & pooled_height [[buffer(7)]], + constant int64_t & pooled_width [[buffer(8)]], + constant int64_t & sampling_ratio [[buffer(9)]], + constant bool & aligned [[buffer(10)]], + constant int64_t & n_stride [[buffer(11)]], + constant int64_t & c_stride [[buffer(12)]], + constant int64_t & h_stride [[buffer(13)]], + constant int64_t & w_stride [[buffer(14)]], + uint2 tgid [[threadgroup_position_in_grid]], + uint2 tptg [[threads_per_threadgroup]], + uint2 tid2 [[thread_position_in_threadgroup]]){ + MPS_1D_KERNEL_LOOP(index, output_size, 1) { + + } +} + +#define REGISTER_ROI_ALIGN_BACKWARD_OP(DTYPE) \ +template \ +[[host_name("roi_align_backward_" #DTYPE)]] \ +kernel void roi_align_backward( \ + constant DTYPE * grad_output [[buffer(0)]], \ + constant DTYPE * rois [[buffer(1)]], \ + device DTYPE * grad_input [[buffer(2)]], \ + constant int64_t & output_size [[buffer(3)]], \ + constant int64_t & channels [[buffer(4)]], \ + constant int64_t & height [[buffer(5)]], \ + constant int64_t & width [[buffer(6)]], \ + constant int64_t & pooled_height [[buffer(7)]], \ + constant int64_t & pooled_width [[buffer(8)]], \ + constant int64_t & sampling_ratio [[buffer(9)]], \ + constant bool & aligned [[buffer(10)]], \ + constant int64_t & n_stride [[buffer(11)]], \ + constant int64_t & c_stride [[buffer(12)]], \ + constant int64_t & h_stride [[buffer(13)]], \ + constant int64_t & w_stride [[buffer(14)]], \ + uint2 tgid [[threadgroup_position_in_grid]], \ + uint2 tptg [[threads_per_threadgroup]], \ + uint2 tid2 [[thread_position_in_threadgroup]]); + +REGISTER_ROI_ALIGN_BACKWARD_OP(float); +REGISTER_ROI_ALIGN_BACKWARD_OP(half); + )VISION_METAL"; static id compileBinaryOpsLibrary(id device) { From ccde29cc8c50bd43908074649f6abdbbece37e3c Mon Sep 
17 00:00:00 2001 From: "Li-Huai (Allan) Lin" Date: Tue, 13 Jun 2023 16:40:11 +0800 Subject: [PATCH 04/28] roi_align bw (failed) --- test/test_ops.py | 8 +- torchvision/csrc/ops/mps/roi_align_kernel.mm | 65 +++--- torchvision/csrc/ops/mps/vision_kernels.h | 207 ++++++++++++++++++- 3 files changed, 238 insertions(+), 42 deletions(-) diff --git a/test/test_ops.py b/test/test_ops.py index 4194b4d760b..aa5e0c95a0d 100644 --- a/test/test_ops.py +++ b/test/test_ops.py @@ -157,16 +157,18 @@ def test_torch_fx_trace(self, device, x_dtype=torch.float, rois_dtype=torch.floa torch.testing.assert_close(output_gt, output_fx, rtol=tol, atol=tol) @pytest.mark.parametrize("seed", range(10)) - @pytest.mark.parametrize("device", cpu_and_gpu()) + @pytest.mark.parametrize("device", cpu_and_gpu_and_mps()) @pytest.mark.parametrize("contiguous", (True, False)) def test_backward(self, seed, device, contiguous, deterministic=False): + dtype = self.mps_dtype if device == "mps" else self.dtype + torch.random.manual_seed(seed) pool_size = 2 - x = torch.rand(1, 2 * (pool_size**2), 5, 5, dtype=self.dtype, device=device, requires_grad=True) + x = torch.rand(1, 2 * (pool_size**2), 5, 5, dtype=dtype, device=device, requires_grad=True) if not contiguous: x = x.permute(0, 1, 3, 2) rois = torch.tensor( - [[0, 0, 0, 4, 4], [0, 0, 2, 3, 4], [0, 2, 2, 4, 4]], dtype=self.dtype, device=device # format is (xyxy) + [[0, 0, 0, 4, 4], [0, 0, 2, 3, 4], [0, 2, 2, 4, 4]], dtype=dtype, device=device # format is (xyxy) ) def func(z): diff --git a/torchvision/csrc/ops/mps/roi_align_kernel.mm b/torchvision/csrc/ops/mps/roi_align_kernel.mm index 32fe7a96520..afbd4ce3aaf 100644 --- a/torchvision/csrc/ops/mps/roi_align_kernel.mm +++ b/torchvision/csrc/ops/mps/roi_align_kernel.mm @@ -100,64 +100,67 @@ } at::Tensor roi_align_backward_kernel( - const at::Tensor& input, + const at::Tensor& grad, const at::Tensor& rois, double spatial_scale, int64_t pooled_height, int64_t pooled_width, + int64_t batch_size, + int64_t channels, + int64_t height, + int64_t width, int64_t sampling_ratio, bool aligned) { using namespace at::native::mps; - TORCH_CHECK(input.is_mps(), "input must be a MPS tensor"); + TORCH_CHECK(grad.is_mps(), "grad must be a MPS tensor"); TORCH_CHECK(rois.is_mps(), "rois must be a MPS tensor"); - TORCH_CHECK(rois.size(1) == 5, "rois must have shape as Tensor[K, 5]"); - at::TensorArg input_t{input, "input", 1}, rois_t{rois, "rois", 2}; + at::TensorArg grad_t{grad, "input", 1}, rois_t{rois, "rois", 2}; - at::CheckedFrom c = "roi_align_forward_kernel"; - at::checkAllSameGPU(c, {input_t, rois_t}); - at::checkAllSameType(c, {input_t, rois_t}); + at::CheckedFrom c = "roi_align_backward_kernel"; + at::checkAllSameGPU(c, {grad_t, rois_t}); + at::checkAllSameType(c, {grad_t, rois_t}); - int64_t num_rois = rois.size(0); - int64_t channels = input.size(1); - int64_t height = input.size(2); - int64_t width = input.size(3); float spatial_scale_f = static_cast(spatial_scale); - at::Tensor output = at::zeros( - {num_rois, channels, pooled_height, pooled_width}, input.options()); - - int64_t output_size = num_rois * pooled_height * pooled_width * channels; + at::Tensor grad_input = at::zeros( + {batch_size, channels, height, width}, grad.options()); - if (output.numel() == 0) { - return output; + if (grad.numel() == 0) { + return grad_input; } - auto input_ = input.contiguous(); + int64_t n_stride = grad.stride(0); + int64_t c_stride = grad.stride(1); + int64_t h_stride = grad.stride(2); + int64_t w_stride = grad.stride(3); + int64_t output_size = 
grad.numel(); + + at::globalContext().alertNotDeterministic("roi_align_backward_kernel"); auto rois_ = rois.contiguous(); - id inputBuffer = getMTLBufferStorage(input_); + id inputBuffer = getMTLBufferStorage(grad); id roisBuffer = getMTLBufferStorage(rois_); - id outputBuffer = getMTLBufferStorage(output); + id outputBuffer = getMTLBufferStorage(grad_input); id device = MPSDevice::getInstance()->device(); MPSStream* mpsStream = getCurrentMPSStream(); dispatch_sync(mpsStream->queue(), ^() { @autoreleasepool { id computeEncoder = mpsStream->commandEncoder(); - MTLSize threadgroupsPerGrid = MTLSizeMake(std::min(ceil_div(static_cast(output_size), static_cast(512)), static_cast(4096)), 1, 1); + MTLSize threadgroupsPerGrid = MTLSizeMake(std::min(ceil_div(static_cast(grad.numel()), static_cast(512)), static_cast(4096)), 1, 1); - const std::string kernel = "roi_align_" + scalarToMetalTypeString(input.scalar_type()); + const std::string kernel = "roi_align_backward_" + scalarToMetalTypeString(grad.scalar_type()); id binaryPSO = mps::binaryPipelineState(device, kernel); // this function call is a no-op if MPS Profiler is not enabled - getMPSProfiler().beginProfileKernel(binaryPSO, kernel, {input_, rois_}); + getMPSProfiler().beginProfileKernel(binaryPSO, kernel, {grad, rois_}); [computeEncoder setComputePipelineState:binaryPSO]; // [N, C, H, W] - [computeEncoder setBuffer:inputBuffer offset:input_.storage_offset() * input_.element_size() atIndex:0]; + [computeEncoder setBuffer:inputBuffer offset:grad.storage_offset() * grad.element_size() atIndex:0]; [computeEncoder setBuffer:roisBuffer offset:rois_.storage_offset() * rois_.element_size() atIndex:1]; - [computeEncoder setBuffer:outputBuffer offset:output.storage_offset() * output.element_size() atIndex:2]; + [computeEncoder setBuffer:outputBuffer offset:grad_input.storage_offset() * grad_input.element_size() atIndex:2]; [computeEncoder setBytes:&output_size length:sizeof(int64_t) atIndex:3]; [computeEncoder setBytes:&channels length:sizeof(int64_t) atIndex:4]; @@ -168,6 +171,10 @@ [computeEncoder setBytes:&sampling_ratio length:sizeof(int64_t) atIndex:9]; [computeEncoder setBytes:&aligned length:sizeof(bool) atIndex:10]; [computeEncoder setBytes:&spatial_scale_f length:sizeof(float) atIndex:11]; + [computeEncoder setBytes:&n_stride length:sizeof(int64_t) atIndex:12]; + [computeEncoder setBytes:&c_stride length:sizeof(int64_t) atIndex:13]; + [computeEncoder setBytes:&h_stride length:sizeof(int64_t) atIndex:14]; + [computeEncoder setBytes:&w_stride length:sizeof(int64_t) atIndex:15]; // A threadGroup is equivalent to a cuda's block. 
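      // Gradients are scattered with the emulated float atomics (atomic_add_float) in the Metal kernel,
      // which is why alertNotDeterministic is raised above.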
NSUInteger tgSize = binaryPSO.maxTotalThreadsPerThreadgroup; @@ -181,7 +188,7 @@ getMPSProfiler().endProfileKernel(binaryPSO); } }); - return output; + return grad_input; } } // namespace @@ -190,9 +197,9 @@ m.impl( TORCH_SELECTIVE_NAME("torchvision::roi_align"), TORCH_FN(roi_align_forward_kernel)); - //m.impl( - // TORCH_SELECTIVE_NAME("torchvision::_roi_align_backward"), - // TORCH_FN(roi_align_backward_kernel)); + m.impl( + TORCH_SELECTIVE_NAME("torchvision::_roi_align_backward"), + TORCH_FN(roi_align_backward_kernel)); } } // namespace ops diff --git a/torchvision/csrc/ops/mps/vision_kernels.h b/torchvision/csrc/ops/mps/vision_kernels.h index 4ec16503b7c..afadb777361 100644 --- a/torchvision/csrc/ops/mps/vision_kernels.h +++ b/torchvision/csrc/ops/mps/vision_kernels.h @@ -8,6 +8,7 @@ namespace mps { static const char* METAL_VISION = R"VISION_METAL( #include +#include using namespace metal; #define MPS_1D_KERNEL_LOOP_T(i, n, n_grids, index_t) \ @@ -21,11 +22,41 @@ using namespace metal; // we need to make it static instead of deriving it from [[threads_per_threadgroup]]. constant uint threadsPerBlock = sizeof(uint64_t) * 8; +// Utility functions + template inline T ceil_div(T n, T m) { return (n + m - 1) / m; } +// https://github.com/ShoYamanishi/AppleNumericalComputing/blob/053f06c1f5a831095c4bcc29aaf11366fce5231e/03_dot/metal/dot.metal#L447-L472 +void atomic_add_float( device atomic_uint* atom_var, const float val ) +{ + uint fetched_uint, assigning_uint; + float fetched_float, assigning_float; + + fetched_uint = atomic_exchange_explicit( atom_var, 0, memory_order_relaxed ); + + fetched_float = *( (thread float*) &fetched_uint ); + + assigning_float = fetched_float + val; + + assigning_uint = *( (thread uint*) &assigning_float ); + + while ( (fetched_uint = atomic_exchange_explicit( atom_var, assigning_uint, memory_order_relaxed ) ) != 0 ) { + + uint fetched_uint_again = atomic_exchange_explicit( atom_var, 0, memory_order_relaxed ); + + float fetched_float_again = *( (thread float*) &fetched_uint_again ); + + fetched_float = *( (thread float*) &(fetched_uint) ); + + assigning_float = fetched_float_again + fetched_float; + + assigning_uint = *( (thread uint*) &assigning_float ); + } +} + template inline T bilinear_interpolate( constant T* input, @@ -80,6 +111,65 @@ inline T bilinear_interpolate( return val; } +template +void bilinear_interpolate_gradient( + int height, + int width, + T y, + T x, + thread T& w1, + thread T& w2, + thread T& w3, + thread T& w4, + thread int& x_low, + thread int& x_high, + thread int& y_low, + thread int& y_high, + uint index /* index for debug only*/) { + // deal with cases that inverse elements are out of feature map boundary + if (y < -1.0 || y > height || x < -1.0 || x > width) { + // empty + w1 = w2 = w3 = w4 = 0.; + x_low = x_high = y_low = y_high = -1; + return; + } + + if (y <= 0) + y = 0; + if (x <= 0) + x = 0; + + y_low = (int)y; + x_low = (int)x; + + if (y_low >= height - 1) { + y_high = y_low = height - 1; + y = (T)y_low; + } else { + y_high = y_low + 1; + } + + if (x_low >= width - 1) { + x_high = x_low = width - 1; + x = (T)x_low; + } else { + x_high = x_low + 1; + } + + T ly = y - y_low; + T lx = x - x_low; + T hy = 1. - ly, hx = 1. 
- lx; + + // reference in forward + // T v1 = input[y_low * width + x_low]; + // T v2 = input[y_low * width + x_high]; + // T v3 = input[y_high * width + x_low]; + // T v4 = input[y_high * width + x_high]; + // T val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4); + + w1 = hy * hx, w2 = hy * lx, w3 = ly * hx, w4 = ly * lx; +} + template bool inline IoU( constant T & a, @@ -266,16 +356,112 @@ kernel void roi_align_backward( constant int64_t & pooled_width [[buffer(8)]], constant int64_t & sampling_ratio [[buffer(9)]], constant bool & aligned [[buffer(10)]], - constant int64_t & n_stride [[buffer(11)]], - constant int64_t & c_stride [[buffer(12)]], - constant int64_t & h_stride [[buffer(13)]], - constant int64_t & w_stride [[buffer(14)]], + constant float & spatial_scale [[buffer(11)]], + constant int64_t & n_stride [[buffer(12)]], + constant int64_t & c_stride [[buffer(13)]], + constant int64_t & h_stride [[buffer(14)]], + constant int64_t & w_stride [[buffer(15)]], uint2 tgid [[threadgroup_position_in_grid]], uint2 tptg [[threads_per_threadgroup]], uint2 tid2 [[thread_position_in_threadgroup]]){ + MPS_1D_KERNEL_LOOP(index, output_size, 1) { - - } + // (n, c, ph, pw) is an element in the pooled output + int pw = index % pooled_width; + int ph = (index / pooled_width) % pooled_height; + int c = (index / pooled_width / pooled_height) % channels; + int n = index / pooled_width / pooled_height / channels; + + constant T* offset_rois = rois + n * 5; + int roi_batch_ind = offset_rois[0]; + + // Do not using rounding; this implementation detail is critical + T offset = aligned ? (T)0.5 : (T)0.0; + T roi_start_w = offset_rois[1] * spatial_scale - offset; + T roi_start_h = offset_rois[2] * spatial_scale - offset; + T roi_end_w = offset_rois[3] * spatial_scale - offset; + T roi_end_h = offset_rois[4] * spatial_scale - offset; + + T roi_width = roi_end_w - roi_start_w; + T roi_height = roi_end_h - roi_start_h; + if (!aligned) { + // Force malformed ROIs to be 1x1 + roi_width = max(roi_width, (T)1.); + roi_height = max(roi_height, (T)1.); + } + + T bin_size_h = static_cast(roi_height) / static_cast(pooled_height); + T bin_size_w = static_cast(roi_width) / static_cast(pooled_width); + + // We need to index the gradient using the tensor strides to access the + // correct values. + const int output_offset = n * n_stride + c * c_stride; + constant T* offset_grad_output = grad_output + output_offset; + const T grad_output_this_bin = + offset_grad_output[ph * h_stride + pw * w_stride]; + + // We use roi_bin_grid to sample the grid and mimic integral + int roi_bin_grid_h = (sampling_ratio > 0) + ? sampling_ratio + : ceil(roi_height / pooled_height); // e.g., = 2 + int roi_bin_grid_w = + (sampling_ratio > 0) ? sampling_ratio : ceil(roi_width / pooled_width); + + // We do average (integral) pooling inside a bin + const T count = roi_bin_grid_h * roi_bin_grid_w; // e.g. 
= 4 + + const int input_offset = (roi_batch_ind * channels + c) * height * width; + + for (int iy = 0; iy < roi_bin_grid_h; iy++) // e.g., iy = 0, 1 + { + const T y = roi_start_h + ph * bin_size_h + + static_cast(iy + .5f) * bin_size_h / + static_cast(roi_bin_grid_h); // e.g., 0.5, 1.5 + for (int ix = 0; ix < roi_bin_grid_w; ix++) { + const T x = roi_start_w + pw * bin_size_w + + static_cast(ix + .5f) * bin_size_w / + static_cast(roi_bin_grid_w); + + T w1, w2, w3, w4; + int x_low, x_high, y_low, y_high; + + bilinear_interpolate_gradient( + height, + width, + y, + x, + w1, + w2, + w3, + w4, + x_low, + x_high, + y_low, + y_high, + index); + + T g1 = grad_output_this_bin * w1 / count; + T g2 = grad_output_this_bin * w2 / count; + T g3 = grad_output_this_bin * w3 / count; + T g4 = grad_output_this_bin * w4 / count; + + if (x_low >= 0 && x_high >= 0 && y_low >= 0 && y_high >= 0) { + device atomic_uint* xAtomic = (device atomic_uint*)(grad_input + input_offset + y_low * width + x_low); + device atomic_uint* yAtomic = (device atomic_uint*)(grad_input + input_offset + y_low * width + x_high); + device atomic_uint* zAtomic = (device atomic_uint*)(grad_input + input_offset + y_high * width + x_low); + device atomic_uint* wAtomic = (device atomic_uint*)(grad_input + input_offset + y_high * width + x_high); + + // atomic_float data type is supported on Metal 3 onward. + // TODO: Use native atomic_fetch_add_explicit for Metal 3. + atomic_add_float(xAtomic, static_cast(g1)); + atomic_add_float(yAtomic, static_cast(g2)); + atomic_add_float(zAtomic, static_cast(g3)); + atomic_add_float(wAtomic, static_cast(g4)); + + } // if + } // ix + } // iy + } // MPS_1D_KERNEL_LOOP } #define REGISTER_ROI_ALIGN_BACKWARD_OP(DTYPE) \ @@ -293,10 +479,11 @@ kernel void roi_align_backward( \ constant int64_t & pooled_width [[buffer(8)]], \ constant int64_t & sampling_ratio [[buffer(9)]], \ constant bool & aligned [[buffer(10)]], \ - constant int64_t & n_stride [[buffer(11)]], \ - constant int64_t & c_stride [[buffer(12)]], \ - constant int64_t & h_stride [[buffer(13)]], \ - constant int64_t & w_stride [[buffer(14)]], \ + constant float & spatial_scale [[buffer(11)]], \ + constant int64_t & n_stride [[buffer(12)]], \ + constant int64_t & c_stride [[buffer(13)]], \ + constant int64_t & h_stride [[buffer(14)]], \ + constant int64_t & w_stride [[buffer(15)]], \ uint2 tgid [[threadgroup_position_in_grid]], \ uint2 tptg [[threads_per_threadgroup]], \ uint2 tid2 [[thread_position_in_threadgroup]]); From c930e5419d37f749fbc79ecb76ba9a71d6e7cdfe Mon Sep 17 00:00:00 2001 From: "Li-Huai (Allan) Lin" Date: Tue, 13 Jun 2023 17:14:07 +0800 Subject: [PATCH 05/28] roi_pool fw --- torchvision/csrc/ops/mps/roi_pool_kernel.mm | 207 ++++++++++++++++++++ torchvision/csrc/ops/mps/vision_kernels.h | 90 +++++++++ 2 files changed, 297 insertions(+) create mode 100644 torchvision/csrc/ops/mps/roi_pool_kernel.mm diff --git a/torchvision/csrc/ops/mps/roi_pool_kernel.mm b/torchvision/csrc/ops/mps/roi_pool_kernel.mm new file mode 100644 index 00000000000..2b08a8b7bfd --- /dev/null +++ b/torchvision/csrc/ops/mps/roi_pool_kernel.mm @@ -0,0 +1,207 @@ +#include +#include +#include "vision_kernels.h" +#include "mps_helpers.h" + +#include +#include + +namespace vision { +namespace ops { + +namespace { + +// This should be in sync with the one in metal kernel. 
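+// roi_pool follows the same launch scheme as roi_align; its forward pass additionally records an
+// argmax tensor so that a backward kernel can route gradients to the max locations.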
+int const threadsPerBlock = 512; + +std::tuple roi_pool_forward_kernel( + const at::Tensor& input, + const at::Tensor& rois, + double spatial_scale, + int64_t pooled_height, + int64_t pooled_width) { + + using namespace at::native::mps; + TORCH_CHECK(input.is_mps(), "input must be a MPS tensor"); + TORCH_CHECK(rois.is_mps(), "rois must be a MPS tensor"); + TORCH_CHECK(rois.size(1) == 5, "rois must have shape as Tensor[K, 5]"); + + at::TensorArg input_t{input, "input", 1}, rois_t{rois, "rois", 2}; + + at::CheckedFrom c = "roi_pool_forward_kernel"; + at::checkAllSameGPU(c, {input_t, rois_t}); + at::checkAllSameType(c, {input_t, rois_t}); + + int64_t num_rois = rois.size(0); + int64_t channels = input.size(1); + int64_t height = input.size(2); + int64_t width = input.size(3); + float spatial_scale_f = static_cast(spatial_scale); + + at::Tensor output = at::zeros( + {num_rois, channels, pooled_height, pooled_width}, input.options()); + at::Tensor argmax = at::zeros( + {num_rois, channels, pooled_height, pooled_width}, + input.options().dtype(at::kLong)); + + int64_t output_size = num_rois * pooled_height * pooled_width * channels; + + if (output.numel() == 0) { + return std::make_tuple(output, argmax); + } + + auto input_ = input.contiguous(); + auto rois_ = rois.contiguous(); + + id inputBuffer = getMTLBufferStorage(input_); + id roisBuffer = getMTLBufferStorage(rois_); + id outputBuffer = getMTLBufferStorage(output); + id argmaxBuffer = getMTLBufferStorage(argmax); + id device = MPSDevice::getInstance()->device(); + MPSStream* mpsStream = getCurrentMPSStream(); + dispatch_sync(mpsStream->queue(), ^() { + @autoreleasepool { + id computeEncoder = mpsStream->commandEncoder(); + MTLSize threadgroupsPerGrid = MTLSizeMake(std::min(ceil_div(static_cast(output_size), static_cast(512)), static_cast(4096)), 1, 1); + + const std::string kernel = "roi_pool_" + scalarToMetalTypeString(input.scalar_type()); + id binaryPSO = mps::binaryPipelineState(device, kernel); + + // this function call is a no-op if MPS Profiler is not enabled + getMPSProfiler().beginProfileKernel(binaryPSO, kernel, {input_, rois_}); + + [computeEncoder setComputePipelineState:binaryPSO]; + // [N, C, H, W] + [computeEncoder setBuffer:inputBuffer offset:input_.storage_offset() * input_.element_size() atIndex:0]; + [computeEncoder setBuffer:roisBuffer offset:rois_.storage_offset() * rois_.element_size() atIndex:1]; + [computeEncoder setBuffer:outputBuffer offset:output.storage_offset() * output.element_size() atIndex:2]; + [computeEncoder setBuffer:argmaxBuffer offset:argmax.storage_offset() * argmax.element_size() atIndex:3]; + + [computeEncoder setBytes:&output_size length:sizeof(int64_t) atIndex:4]; + [computeEncoder setBytes:&channels length:sizeof(int64_t) atIndex:5]; + [computeEncoder setBytes:&height length:sizeof(int64_t) atIndex:6]; + [computeEncoder setBytes:&width length:sizeof(int64_t) atIndex:7]; + [computeEncoder setBytes:&pooled_height length:sizeof(int64_t) atIndex:8]; + [computeEncoder setBytes:&pooled_width length:sizeof(int64_t) atIndex:9]; + [computeEncoder setBytes:&spatial_scale_f length:sizeof(float) atIndex:10]; + + // A threadGroup is equivalent to a cuda's block. 
+ NSUInteger tgSize = binaryPSO.maxTotalThreadsPerThreadgroup; + if (tgSize > threadsPerBlock) { + tgSize = threadsPerBlock; + } + + MTLSize threadGroupSize = MTLSizeMake(tgSize, 1, 1); + [computeEncoder dispatchThreadgroups:threadgroupsPerGrid threadsPerThreadgroup:threadGroupSize]; + + getMPSProfiler().endProfileKernel(binaryPSO); + } + }); + return std::make_tuple(output, argmax); +} + +at::Tensor roi_pool_backward_kernel( + const at::Tensor& grad, + const at::Tensor& rois, + double spatial_scale, + int64_t pooled_height, + int64_t pooled_width, + int64_t batch_size, + int64_t channels, + int64_t height, + int64_t width, + int64_t sampling_ratio, + bool aligned) { + + using namespace at::native::mps; + TORCH_CHECK(grad.is_mps(), "grad must be a MPS tensor"); + TORCH_CHECK(rois.is_mps(), "rois must be a MPS tensor"); + + at::TensorArg grad_t{grad, "input", 1}, rois_t{rois, "rois", 2}; + + at::CheckedFrom c = "roi_pool_backward_kernel"; + at::checkAllSameGPU(c, {grad_t, rois_t}); + at::checkAllSameType(c, {grad_t, rois_t}); + + float spatial_scale_f = static_cast(spatial_scale); + + at::Tensor grad_input = at::zeros( + {batch_size, channels, height, width}, grad.options()); + + if (grad.numel() == 0) { + return grad_input; + } + + int64_t n_stride = grad.stride(0); + int64_t c_stride = grad.stride(1); + int64_t h_stride = grad.stride(2); + int64_t w_stride = grad.stride(3); + int64_t output_size = grad.numel(); + + at::globalContext().alertNotDeterministic("roi_pool_backward_kernel"); + auto rois_ = rois.contiguous(); + + id inputBuffer = getMTLBufferStorage(grad); + id roisBuffer = getMTLBufferStorage(rois_); + id outputBuffer = getMTLBufferStorage(grad_input); + id device = MPSDevice::getInstance()->device(); + MPSStream* mpsStream = getCurrentMPSStream(); + dispatch_sync(mpsStream->queue(), ^() { + @autoreleasepool { + id computeEncoder = mpsStream->commandEncoder(); + MTLSize threadgroupsPerGrid = MTLSizeMake(std::min(ceil_div(static_cast(grad.numel()), static_cast(512)), static_cast(4096)), 1, 1); + + const std::string kernel = "roi_pool_backward_" + scalarToMetalTypeString(grad.scalar_type()); + id binaryPSO = mps::binaryPipelineState(device, kernel); + + // this function call is a no-op if MPS Profiler is not enabled + getMPSProfiler().beginProfileKernel(binaryPSO, kernel, {grad, rois_}); + + [computeEncoder setComputePipelineState:binaryPSO]; + // [N, C, H, W] + [computeEncoder setBuffer:inputBuffer offset:grad.storage_offset() * grad.element_size() atIndex:0]; + [computeEncoder setBuffer:roisBuffer offset:rois_.storage_offset() * rois_.element_size() atIndex:1]; + [computeEncoder setBuffer:outputBuffer offset:grad_input.storage_offset() * grad_input.element_size() atIndex:2]; + + [computeEncoder setBytes:&output_size length:sizeof(int64_t) atIndex:3]; + [computeEncoder setBytes:&channels length:sizeof(int64_t) atIndex:4]; + [computeEncoder setBytes:&height length:sizeof(int64_t) atIndex:5]; + [computeEncoder setBytes:&width length:sizeof(int64_t) atIndex:6]; + [computeEncoder setBytes:&pooled_height length:sizeof(int64_t) atIndex:7]; + [computeEncoder setBytes:&pooled_width length:sizeof(int64_t) atIndex:8]; + [computeEncoder setBytes:&sampling_ratio length:sizeof(int64_t) atIndex:9]; + [computeEncoder setBytes:&aligned length:sizeof(bool) atIndex:10]; + [computeEncoder setBytes:&spatial_scale_f length:sizeof(float) atIndex:11]; + [computeEncoder setBytes:&n_stride length:sizeof(int64_t) atIndex:12]; + [computeEncoder setBytes:&c_stride length:sizeof(int64_t) atIndex:13]; + 
[computeEncoder setBytes:&h_stride length:sizeof(int64_t) atIndex:14]; + [computeEncoder setBytes:&w_stride length:sizeof(int64_t) atIndex:15]; + + // A threadGroup is equivalent to a cuda's block. + NSUInteger tgSize = binaryPSO.maxTotalThreadsPerThreadgroup; + if (tgSize > threadsPerBlock) { + tgSize = threadsPerBlock; + } + + MTLSize threadGroupSize = MTLSizeMake(tgSize, 1, 1); + [computeEncoder dispatchThreadgroups:threadgroupsPerGrid threadsPerThreadgroup:threadGroupSize]; + + getMPSProfiler().endProfileKernel(binaryPSO); + } + }); + return grad_input; +} + +} // namespace + +TORCH_LIBRARY_IMPL(torchvision, MPS, m) { + m.impl( + TORCH_SELECTIVE_NAME("torchvision::roi_pool"), + TORCH_FN(roi_pool_forward_kernel)); + //m.impl( + // TORCH_SELECTIVE_NAME("torchvision::_roi_pool_backward"), + // TORCH_FN(roi_pool_backward_kernel)); +} + +} // namespace ops +} // namespace vision diff --git a/torchvision/csrc/ops/mps/vision_kernels.h b/torchvision/csrc/ops/mps/vision_kernels.h index afadb777361..6ef6c89ccd1 100644 --- a/torchvision/csrc/ops/mps/vision_kernels.h +++ b/torchvision/csrc/ops/mps/vision_kernels.h @@ -491,6 +491,96 @@ kernel void roi_align_backward( \ REGISTER_ROI_ALIGN_BACKWARD_OP(float); REGISTER_ROI_ALIGN_BACKWARD_OP(half); +template +kernel void roi_pool( + constant T * input [[buffer(0)]], + constant T * rois [[buffer(1)]], + device T * output [[buffer(2)]], + device int64_t * argmax [[buffer(3)]], + constant int64_t & output_size [[buffer(4)]], + constant int64_t & channels [[buffer(5)]], + constant int64_t & height [[buffer(6)]], + constant int64_t & width [[buffer(7)]], + constant int64_t & pooled_height [[buffer(8)]], + constant int64_t & pooled_width [[buffer(9)]], + constant float & spatial_scale [[buffer(10)]], + uint2 tgid [[threadgroup_position_in_grid]], + uint2 tptg [[threads_per_threadgroup]], + uint2 tid2 [[thread_position_in_threadgroup]]){ + MPS_1D_KERNEL_LOOP(index, output_size, 1) { + // (n, c, ph, pw) is an element in the pooled output + int pw = index % pooled_width; + int ph = (index / pooled_width) % pooled_height; + int c = (index / pooled_width / pooled_height) % channels; + int n = index / pooled_width / pooled_height / channels; + + constant T* offset_rois = rois + n * 5; + int roi_batch_ind = offset_rois[0]; + int roi_start_w = round(offset_rois[1] * spatial_scale); + int roi_start_h = round(offset_rois[2] * spatial_scale); + int roi_end_w = round(offset_rois[3] * spatial_scale); + int roi_end_h = round(offset_rois[4] * spatial_scale); + + // Force malformed ROIs to be 1x1 + int roi_width = max(roi_end_w - roi_start_w + 1, 1); + int roi_height = max(roi_end_h - roi_start_h + 1, 1); + T bin_size_h = static_cast(roi_height) / static_cast(pooled_height); + T bin_size_w = static_cast(roi_width) / static_cast(pooled_width); + + int hstart = static_cast(floor(static_cast(ph) * bin_size_h)); + int wstart = static_cast(floor(static_cast(pw) * bin_size_w)); + int hend = static_cast(ceil(static_cast(ph + 1) * bin_size_h)); + int wend = static_cast(ceil(static_cast(pw + 1) * bin_size_w)); + + // Add roi offsets and clip to input boundaries + hstart = min(max(hstart + roi_start_h, 0), static_cast(height)); + hend = min(max(hend + roi_start_h, 0), static_cast(height)); + wstart = min(max(wstart + roi_start_w, 0), static_cast(width)); + wend = min(max(wend + roi_start_w, 0), static_cast(width)); + bool is_empty = (hend <= hstart) || (wend <= wstart); + + // Define an empty pooling region to be zero + T maxval = is_empty ? 
0 : -FLT_MAX; + // If nothing is pooled, argmax = -1 causes nothing to be backprop'd + int maxidx = -1; + constant T* offset_input = + input + (roi_batch_ind * channels + c) * height * width; + for (int h = hstart; h < hend; ++h) { + for (int w = wstart; w < wend; ++w) { + int input_index = h * width + w; + if (offset_input[input_index] > maxval) { + maxval = offset_input[input_index]; + maxidx = input_index; + } + } + } + output[index] = maxval; + argmax[index] = maxidx; + } +} + +#define REGISTER_ROI_POOL_OP(DTYPE) \ +template \ +[[host_name("roi_pool_" #DTYPE)]] \ +kernel void roi_pool( \ + constant DTYPE * input [[buffer(0)]], \ + constant DTYPE * rois [[buffer(1)]], \ + device DTYPE * output [[buffer(2)]], \ + device int64_t * argmax [[buffer(3)]], \ + constant int64_t & output_size [[buffer(4)]], \ + constant int64_t & channels [[buffer(5)]], \ + constant int64_t & height [[buffer(6)]], \ + constant int64_t & width [[buffer(7)]], \ + constant int64_t & pooled_height [[buffer(8)]], \ + constant int64_t & pooled_width [[buffer(9)]], \ + constant float & spatial_scale [[buffer(10)]], \ + uint2 tgid [[threadgroup_position_in_grid]], \ + uint2 tptg [[threads_per_threadgroup]], \ + uint2 tid2 [[thread_position_in_threadgroup]]); + +REGISTER_ROI_POOL_OP(float); +REGISTER_ROI_POOL_OP(half); + )VISION_METAL"; static id compileBinaryOpsLibrary(id device) { From 3305cc151d558d444d992075da930e7455124b70 Mon Sep 17 00:00:00 2001 From: "Li-Huai (Allan) Lin" Date: Tue, 13 Jun 2023 17:33:29 +0800 Subject: [PATCH 06/28] roi_pool bw (failed prec) --- torchvision/csrc/ops/mps/roi_pool_kernel.mm | 49 +++++++------- torchvision/csrc/ops/mps/vision_kernels.h | 75 ++++++++++++++++++++- 2 files changed, 99 insertions(+), 25 deletions(-) diff --git a/torchvision/csrc/ops/mps/roi_pool_kernel.mm b/torchvision/csrc/ops/mps/roi_pool_kernel.mm index 2b08a8b7bfd..d37d679dd4f 100644 --- a/torchvision/csrc/ops/mps/roi_pool_kernel.mm +++ b/torchvision/csrc/ops/mps/roi_pool_kernel.mm @@ -103,27 +103,28 @@ at::Tensor roi_pool_backward_kernel( const at::Tensor& grad, const at::Tensor& rois, + const at::Tensor& argmax, double spatial_scale, int64_t pooled_height, int64_t pooled_width, int64_t batch_size, int64_t channels, int64_t height, - int64_t width, - int64_t sampling_ratio, - bool aligned) { + int64_t width) { using namespace at::native::mps; TORCH_CHECK(grad.is_mps(), "grad must be a MPS tensor"); TORCH_CHECK(rois.is_mps(), "rois must be a MPS tensor"); + TORCH_CHECK(argmax.is_mps(), "argmax must be a MPS tensor"); - at::TensorArg grad_t{grad, "input", 1}, rois_t{rois, "rois", 2}; + at::TensorArg grad_t{grad, "input", 1}, rois_t{rois, "rois", 2}, argmax_t{argmax, "argmax", 3}; at::CheckedFrom c = "roi_pool_backward_kernel"; - at::checkAllSameGPU(c, {grad_t, rois_t}); + at::checkAllSameGPU(c, {grad_t, rois_t, argmax_t}); at::checkAllSameType(c, {grad_t, rois_t}); float spatial_scale_f = static_cast(spatial_scale); + auto num_rois = rois.size(0); at::Tensor grad_input = at::zeros( {batch_size, channels, height, width}, grad.options()); @@ -139,10 +140,11 @@ int64_t output_size = grad.numel(); at::globalContext().alertNotDeterministic("roi_pool_backward_kernel"); - auto rois_ = rois.contiguous(); + auto argmax_ = argmax.contiguous(), rois_ = rois.contiguous(); id inputBuffer = getMTLBufferStorage(grad); id roisBuffer = getMTLBufferStorage(rois_); + id argmaxBuffer = getMTLBufferStorage(argmax_); id outputBuffer = getMTLBufferStorage(grad_input); id device = MPSDevice::getInstance()->device(); MPSStream* mpsStream = 
getCurrentMPSStream(); @@ -155,27 +157,26 @@ id binaryPSO = mps::binaryPipelineState(device, kernel); // this function call is a no-op if MPS Profiler is not enabled - getMPSProfiler().beginProfileKernel(binaryPSO, kernel, {grad, rois_}); + getMPSProfiler().beginProfileKernel(binaryPSO, kernel, {grad, rois_, argmax_}); [computeEncoder setComputePipelineState:binaryPSO]; // [N, C, H, W] [computeEncoder setBuffer:inputBuffer offset:grad.storage_offset() * grad.element_size() atIndex:0]; [computeEncoder setBuffer:roisBuffer offset:rois_.storage_offset() * rois_.element_size() atIndex:1]; - [computeEncoder setBuffer:outputBuffer offset:grad_input.storage_offset() * grad_input.element_size() atIndex:2]; + [computeEncoder setBuffer:argmaxBuffer offset:argmax_.storage_offset() * argmax_.element_size() atIndex:2]; + [computeEncoder setBuffer:outputBuffer offset:grad_input.storage_offset() * grad_input.element_size() atIndex:3]; - [computeEncoder setBytes:&output_size length:sizeof(int64_t) atIndex:3]; - [computeEncoder setBytes:&channels length:sizeof(int64_t) atIndex:4]; - [computeEncoder setBytes:&height length:sizeof(int64_t) atIndex:5]; - [computeEncoder setBytes:&width length:sizeof(int64_t) atIndex:6]; - [computeEncoder setBytes:&pooled_height length:sizeof(int64_t) atIndex:7]; - [computeEncoder setBytes:&pooled_width length:sizeof(int64_t) atIndex:8]; - [computeEncoder setBytes:&sampling_ratio length:sizeof(int64_t) atIndex:9]; - [computeEncoder setBytes:&aligned length:sizeof(bool) atIndex:10]; - [computeEncoder setBytes:&spatial_scale_f length:sizeof(float) atIndex:11]; - [computeEncoder setBytes:&n_stride length:sizeof(int64_t) atIndex:12]; - [computeEncoder setBytes:&c_stride length:sizeof(int64_t) atIndex:13]; - [computeEncoder setBytes:&h_stride length:sizeof(int64_t) atIndex:14]; - [computeEncoder setBytes:&w_stride length:sizeof(int64_t) atIndex:15]; + [computeEncoder setBytes:&output_size length:sizeof(int64_t) atIndex:4]; + [computeEncoder setBytes:&channels length:sizeof(int64_t) atIndex:5]; + [computeEncoder setBytes:&height length:sizeof(int64_t) atIndex:6]; + [computeEncoder setBytes:&width length:sizeof(int64_t) atIndex:7]; + [computeEncoder setBytes:&pooled_height length:sizeof(int64_t) atIndex:8]; + [computeEncoder setBytes:&pooled_width length:sizeof(int64_t) atIndex:9]; + [computeEncoder setBytes:&spatial_scale_f length:sizeof(float) atIndex:10]; + [computeEncoder setBytes:&n_stride length:sizeof(int64_t) atIndex:11]; + [computeEncoder setBytes:&c_stride length:sizeof(int64_t) atIndex:12]; + [computeEncoder setBytes:&h_stride length:sizeof(int64_t) atIndex:13]; + [computeEncoder setBytes:&w_stride length:sizeof(int64_t) atIndex:14]; // A threadGroup is equivalent to a cuda's block. 
NSUInteger tgSize = binaryPSO.maxTotalThreadsPerThreadgroup; @@ -198,9 +199,9 @@ m.impl( TORCH_SELECTIVE_NAME("torchvision::roi_pool"), TORCH_FN(roi_pool_forward_kernel)); - //m.impl( - // TORCH_SELECTIVE_NAME("torchvision::_roi_pool_backward"), - // TORCH_FN(roi_pool_backward_kernel)); + m.impl( + TORCH_SELECTIVE_NAME("torchvision::_roi_pool_backward"), + TORCH_FN(roi_pool_backward_kernel)); } } // namespace ops diff --git a/torchvision/csrc/ops/mps/vision_kernels.h b/torchvision/csrc/ops/mps/vision_kernels.h index 6ef6c89ccd1..1caf2a73a68 100644 --- a/torchvision/csrc/ops/mps/vision_kernels.h +++ b/torchvision/csrc/ops/mps/vision_kernels.h @@ -566,7 +566,7 @@ kernel void roi_pool( \ constant DTYPE * input [[buffer(0)]], \ constant DTYPE * rois [[buffer(1)]], \ device DTYPE * output [[buffer(2)]], \ - device int64_t * argmax [[buffer(3)]], \ + device int64_t * argmax_data [[buffer(3)]], \ constant int64_t & output_size [[buffer(4)]], \ constant int64_t & channels [[buffer(5)]], \ constant int64_t & height [[buffer(6)]], \ @@ -581,6 +581,79 @@ kernel void roi_pool( \ REGISTER_ROI_POOL_OP(float); REGISTER_ROI_POOL_OP(half); +template +kernel void roi_pool_backward( + constant T * grad_output [[buffer(0)]], + constant T * rois [[buffer(1)]], + constant int64_t * argmax_data [[buffer(2)]], + device T * grad_input [[buffer(3)]], + constant int64_t & output_size [[buffer(4)]], + constant int64_t & channels [[buffer(5)]], + constant int64_t & height [[buffer(6)]], + constant int64_t & width [[buffer(7)]], + constant int64_t & pooled_height [[buffer(8)]], + constant int64_t & pooled_width [[buffer(9)]], + constant float & spatial_scale [[buffer(10)]], + constant int64_t & n_stride [[buffer(11)]], + constant int64_t & c_stride [[buffer(12)]], + constant int64_t & h_stride [[buffer(13)]], + constant int64_t & w_stride [[buffer(14)]], + uint2 tgid [[threadgroup_position_in_grid]], + uint2 tptg [[threads_per_threadgroup]], + uint2 tid2 [[thread_position_in_threadgroup]]){ + + MPS_1D_KERNEL_LOOP(index, output_size, 1) { + // (n, c, ph, pw) is an element in the pooled output + int pw = index % pooled_width; + int ph = (index / pooled_width) % pooled_height; + int c = (index / pooled_width / pooled_height) % channels; + int n = index / pooled_width / pooled_height / channels; + + constant T* offset_rois = rois + n * 5; + int roi_batch_ind = offset_rois[0]; + + const int output_offset = n * n_stride + c * c_stride; + constant int64_t * argmax_data_offset = + argmax_data + (n * channels + c) * pooled_height * pooled_width; + const int argmax = argmax_data_offset[ph * pooled_width + pw]; + const int offset = (roi_batch_ind * channels + c) * height * width; + + if (argmax != -1) { + device atomic_uint* xAtomic = (device atomic_uint*)(grad_input + offset + argmax); + // atomic_float data type is supported on Metal 3 onward. + // TODO: Use native atomic_fetch_add_explicit for Metal 3. 
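+      // Scatter this bin's gradient into the input element recorded by argmax
+      // in the forward pass; atomic_add_float emulates a float atomic add by
+      // round-tripping the bits through an atomic_uint, so concurrent updates
+      // from overlapping ROIs are not lost.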
+ atomic_add_float(xAtomic, static_cast(grad_output[output_offset + ph * h_stride + pw * w_stride])); + } + + } // MPS_1D_KERNEL_LOOP +} + +#define REGISTER_ROI_POOL_BACKWARD_OP(DTYPE) \ +template \ +[[host_name("roi_pool_backward_" #DTYPE)]] \ +kernel void roi_pool_backward( \ + constant DTYPE * grad_output [[buffer(0)]], \ + constant DTYPE * rois [[buffer(1)]], \ + constant int64_t * argmax_data [[buffer(2)]], \ + device DTYPE * grad_input [[buffer(3)]], \ + constant int64_t & output_size [[buffer(4)]], \ + constant int64_t & channels [[buffer(5)]], \ + constant int64_t & height [[buffer(6)]], \ + constant int64_t & width [[buffer(7)]], \ + constant int64_t & pooled_height [[buffer(8)]], \ + constant int64_t & pooled_width [[buffer(9)]], \ + constant float & spatial_scale [[buffer(10)]], \ + constant int64_t & n_stride [[buffer(11)]], \ + constant int64_t & c_stride [[buffer(12)]], \ + constant int64_t & h_stride [[buffer(13)]], \ + constant int64_t & w_stride [[buffer(14)]], \ + uint2 tgid [[threadgroup_position_in_grid]], \ + uint2 tptg [[threads_per_threadgroup]], \ + uint2 tid2 [[thread_position_in_threadgroup]]); + +REGISTER_ROI_POOL_BACKWARD_OP(float); +REGISTER_ROI_POOL_BACKWARD_OP(half); + )VISION_METAL"; static id compileBinaryOpsLibrary(id device) { From 0f8d2c39e1f157feca144b806d6061dc21a4fd03 Mon Sep 17 00:00:00 2001 From: "Li-Huai (Allan) Lin" Date: Tue, 13 Jun 2023 18:13:05 +0800 Subject: [PATCH 07/28] ps_roi_align fw --- .../csrc/ops/mps/ps_roi_align_kernel.mm | 215 ++++++++++++++++++ torchvision/csrc/ops/mps/vision_kernels.h | 118 +++++++++- 2 files changed, 321 insertions(+), 12 deletions(-) create mode 100644 torchvision/csrc/ops/mps/ps_roi_align_kernel.mm diff --git a/torchvision/csrc/ops/mps/ps_roi_align_kernel.mm b/torchvision/csrc/ops/mps/ps_roi_align_kernel.mm new file mode 100644 index 00000000000..9e63067d2cb --- /dev/null +++ b/torchvision/csrc/ops/mps/ps_roi_align_kernel.mm @@ -0,0 +1,215 @@ +#include +#include +#include "vision_kernels.h" +#include "mps_helpers.h" + +#include +#include + +namespace vision { +namespace ops { + +namespace { + +// This should be in sync with the one in metal kernel. 
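+// Upper bound on threads per threadgroup for the dispatches in this file;
+// the effective size is min(maxTotalThreadsPerThreadgroup, threadsPerBlock).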
+int const threadsPerBlock = 512; + +std::tuple ps_roi_align_forward_kernel( + const at::Tensor& input, + const at::Tensor& rois, + double spatial_scale, + int64_t pooled_height, + int64_t pooled_width, + int64_t sampling_ratio) { + + using namespace at::native::mps; + TORCH_CHECK(input.is_mps(), "input must be a MPS tensor"); + TORCH_CHECK(rois.is_mps(), "rois must be a MPS tensor"); + TORCH_CHECK(rois.size(1) == 5, "rois must have shape as Tensor[K, 5]"); + + at::TensorArg input_t{input, "input", 1}, rois_t{rois, "rois", 2}; + + at::CheckedFrom c = "ps_roi_align_forward_kernel"; + at::checkAllSameGPU(c, {input_t, rois_t}); + at::checkAllSameType(c, {input_t, rois_t}); + + int64_t num_rois = rois.size(0); + int64_t channels = input.size(1); + int64_t height = input.size(2); + int64_t width = input.size(3); + float spatial_scale_f = static_cast(spatial_scale); + + TORCH_CHECK( + channels % (pooled_height * pooled_width) == 0, + "input channels must be a multiple of pooling height * pooling width"); + + int64_t channels_out = channels / (pooled_height * pooled_width); + + auto output = at::zeros( + {num_rois, channels_out, pooled_height, pooled_width}, input.options()); + auto channel_mapping = + at::zeros(output.sizes(), input.options().dtype(at::kLong)); + + int64_t output_size = output.numel(); + + if (output_size == 0) { + return std::make_tuple(output, channel_mapping); + } + + auto input_ = input.contiguous(); + auto rois_ = rois.contiguous(); + + id inputBuffer = getMTLBufferStorage(input_); + id roisBuffer = getMTLBufferStorage(rois_); + id outputBuffer = getMTLBufferStorage(output); + id channelMappingBuffer = getMTLBufferStorage(channel_mapping); + id device = MPSDevice::getInstance()->device(); + MPSStream* mpsStream = getCurrentMPSStream(); + dispatch_sync(mpsStream->queue(), ^() { + @autoreleasepool { + id computeEncoder = mpsStream->commandEncoder(); + MTLSize threadgroupsPerGrid = MTLSizeMake(std::min(ceil_div(static_cast(output_size), static_cast(512)), static_cast(4096)), 1, 1); + + const std::string kernel = "ps_roi_align_" + scalarToMetalTypeString(input.scalar_type()); + id binaryPSO = mps::binaryPipelineState(device, kernel); + + // this function call is a no-op if MPS Profiler is not enabled + getMPSProfiler().beginProfileKernel(binaryPSO, kernel, {input_, rois_}); + + [computeEncoder setComputePipelineState:binaryPSO]; + // [N, C, H, W] + [computeEncoder setBuffer:inputBuffer offset:input_.storage_offset() * input_.element_size() atIndex:0]; + [computeEncoder setBuffer:roisBuffer offset:rois_.storage_offset() * rois_.element_size() atIndex:1]; + [computeEncoder setBuffer:outputBuffer offset:output.storage_offset() * output.element_size() atIndex:2]; + [computeEncoder setBuffer:channelMappingBuffer offset:channel_mapping.storage_offset() * channel_mapping.element_size() atIndex:3]; + + [computeEncoder setBytes:&output_size length:sizeof(int64_t) atIndex:4]; + [computeEncoder setBytes:&channels length:sizeof(int64_t) atIndex:5]; + [computeEncoder setBytes:&height length:sizeof(int64_t) atIndex:6]; + [computeEncoder setBytes:&width length:sizeof(int64_t) atIndex:7]; + [computeEncoder setBytes:&pooled_height length:sizeof(int64_t) atIndex:8]; + [computeEncoder setBytes:&pooled_width length:sizeof(int64_t) atIndex:9]; + [computeEncoder setBytes:&sampling_ratio length:sizeof(int64_t) atIndex:10]; + [computeEncoder setBytes:&channels_out length:sizeof(int64_t) atIndex:11]; + [computeEncoder setBytes:&spatial_scale_f length:sizeof(float) atIndex:12]; + + // A threadGroup is 
equivalent to a cuda's block. + NSUInteger tgSize = binaryPSO.maxTotalThreadsPerThreadgroup; + if (tgSize > threadsPerBlock) { + tgSize = threadsPerBlock; + } + + MTLSize threadGroupSize = MTLSizeMake(tgSize, 1, 1); + [computeEncoder dispatchThreadgroups:threadgroupsPerGrid threadsPerThreadgroup:threadGroupSize]; + + getMPSProfiler().endProfileKernel(binaryPSO); + } + }); + return std::make_tuple(output, channel_mapping); +} + +at::Tensor ps_roi_align_backward_kernel( + const at::Tensor& grad, + const at::Tensor& rois, + double spatial_scale, + int64_t pooled_height, + int64_t pooled_width, + int64_t batch_size, + int64_t channels, + int64_t height, + int64_t width, + int64_t sampling_ratio, + bool aligned) { + + using namespace at::native::mps; + TORCH_CHECK(grad.is_mps(), "grad must be a MPS tensor"); + TORCH_CHECK(rois.is_mps(), "rois must be a MPS tensor"); + + at::TensorArg grad_t{grad, "input", 1}, rois_t{rois, "rois", 2}; + + at::CheckedFrom c = "ps_roi_align_backward_kernel"; + at::checkAllSameGPU(c, {grad_t, rois_t}); + at::checkAllSameType(c, {grad_t, rois_t}); + + float spatial_scale_f = static_cast(spatial_scale); + + at::Tensor grad_input = at::zeros( + {batch_size, channels, height, width}, grad.options()); + + if (grad.numel() == 0) { + return grad_input; + } + + int64_t n_stride = grad.stride(0); + int64_t c_stride = grad.stride(1); + int64_t h_stride = grad.stride(2); + int64_t w_stride = grad.stride(3); + int64_t output_size = grad.numel(); + + at::globalContext().alertNotDeterministic("ps_roi_align_backward_kernel"); + auto rois_ = rois.contiguous(); + + id inputBuffer = getMTLBufferStorage(grad); + id roisBuffer = getMTLBufferStorage(rois_); + id outputBuffer = getMTLBufferStorage(grad_input); + id device = MPSDevice::getInstance()->device(); + MPSStream* mpsStream = getCurrentMPSStream(); + dispatch_sync(mpsStream->queue(), ^() { + @autoreleasepool { + id computeEncoder = mpsStream->commandEncoder(); + MTLSize threadgroupsPerGrid = MTLSizeMake(std::min(ceil_div(static_cast(grad.numel()), static_cast(512)), static_cast(4096)), 1, 1); + + const std::string kernel = "ps_roi_align_backward_" + scalarToMetalTypeString(grad.scalar_type()); + id binaryPSO = mps::binaryPipelineState(device, kernel); + + // this function call is a no-op if MPS Profiler is not enabled + getMPSProfiler().beginProfileKernel(binaryPSO, kernel, {grad, rois_}); + + [computeEncoder setComputePipelineState:binaryPSO]; + // [N, C, H, W] + [computeEncoder setBuffer:inputBuffer offset:grad.storage_offset() * grad.element_size() atIndex:0]; + [computeEncoder setBuffer:roisBuffer offset:rois_.storage_offset() * rois_.element_size() atIndex:1]; + [computeEncoder setBuffer:outputBuffer offset:grad_input.storage_offset() * grad_input.element_size() atIndex:2]; + + [computeEncoder setBytes:&output_size length:sizeof(int64_t) atIndex:3]; + [computeEncoder setBytes:&channels length:sizeof(int64_t) atIndex:4]; + [computeEncoder setBytes:&height length:sizeof(int64_t) atIndex:5]; + [computeEncoder setBytes:&width length:sizeof(int64_t) atIndex:6]; + [computeEncoder setBytes:&pooled_height length:sizeof(int64_t) atIndex:7]; + [computeEncoder setBytes:&pooled_width length:sizeof(int64_t) atIndex:8]; + [computeEncoder setBytes:&sampling_ratio length:sizeof(int64_t) atIndex:9]; + [computeEncoder setBytes:&aligned length:sizeof(bool) atIndex:10]; + [computeEncoder setBytes:&spatial_scale_f length:sizeof(float) atIndex:11]; + [computeEncoder setBytes:&n_stride length:sizeof(int64_t) atIndex:12]; + [computeEncoder 
setBytes:&c_stride length:sizeof(int64_t) atIndex:13]; + [computeEncoder setBytes:&h_stride length:sizeof(int64_t) atIndex:14]; + [computeEncoder setBytes:&w_stride length:sizeof(int64_t) atIndex:15]; + + // A threadGroup is equivalent to a cuda's block. + NSUInteger tgSize = binaryPSO.maxTotalThreadsPerThreadgroup; + if (tgSize > threadsPerBlock) { + tgSize = threadsPerBlock; + } + + MTLSize threadGroupSize = MTLSizeMake(tgSize, 1, 1); + [computeEncoder dispatchThreadgroups:threadgroupsPerGrid threadsPerThreadgroup:threadGroupSize]; + + getMPSProfiler().endProfileKernel(binaryPSO); + } + }); + return grad_input; +} + +} // namespace + +TORCH_LIBRARY_IMPL(torchvision, MPS, m) { + m.impl( + TORCH_SELECTIVE_NAME("torchvision::ps_roi_align"), + TORCH_FN(ps_roi_align_forward_kernel)); + //m.impl( + // TORCH_SELECTIVE_NAME("torchvision::_ps_roi_align_backward"), + // TORCH_FN(ps_roi_align_backward_kernel)); +} + +} // namespace ops +} // namespace vision diff --git a/torchvision/csrc/ops/mps/vision_kernels.h b/torchvision/csrc/ops/mps/vision_kernels.h index 1caf2a73a68..c83ced529ea 100644 --- a/torchvision/csrc/ops/mps/vision_kernels.h +++ b/torchvision/csrc/ops/mps/vision_kernels.h @@ -36,24 +36,17 @@ void atomic_add_float( device atomic_uint* atom_var, const float val ) float fetched_float, assigning_float; fetched_uint = atomic_exchange_explicit( atom_var, 0, memory_order_relaxed ); - fetched_float = *( (thread float*) &fetched_uint ); assigning_float = fetched_float + val; - assigning_uint = *( (thread uint*) &assigning_float ); while ( (fetched_uint = atomic_exchange_explicit( atom_var, assigning_uint, memory_order_relaxed ) ) != 0 ) { - - uint fetched_uint_again = atomic_exchange_explicit( atom_var, 0, memory_order_relaxed ); - - float fetched_float_again = *( (thread float*) &fetched_uint_again ); - - fetched_float = *( (thread float*) &(fetched_uint) ); - - assigning_float = fetched_float_again + fetched_float; - - assigning_uint = *( (thread uint*) &assigning_float ); + uint fetched_uint_again = atomic_exchange_explicit( atom_var, 0, memory_order_relaxed ); + float fetched_float_again = *( (thread float*) &fetched_uint_again ); + fetched_float = *( (thread float*) &(fetched_uint) ); + assigning_float = fetched_float_again + fetched_float; + assigning_uint = *( (thread uint*) &assigning_float ); } } @@ -654,6 +647,107 @@ kernel void roi_pool_backward( \ REGISTER_ROI_POOL_BACKWARD_OP(float); REGISTER_ROI_POOL_BACKWARD_OP(half); +template +kernel void ps_roi_align( + constant T * input [[buffer(0)]], + constant T * rois [[buffer(1)]], + device T * output [[buffer(2)]], + device int64_t * channel_mapping [[buffer(3)]], + constant int64_t & output_size [[buffer(4)]], + constant int64_t & channels [[buffer(5)]], + constant int64_t & height [[buffer(6)]], + constant int64_t & width [[buffer(7)]], + constant int64_t & pooled_height [[buffer(8)]], + constant int64_t & pooled_width [[buffer(9)]], + constant int64_t & sampling_ratio [[buffer(10)]], + constant int64_t & channels_out [[buffer(11)]], + constant float & spatial_scale [[buffer(12)]], + uint2 tgid [[threadgroup_position_in_grid]], + uint2 tptg [[threads_per_threadgroup]], + uint2 tid2 [[thread_position_in_threadgroup]]){ + MPS_1D_KERNEL_LOOP(index, output_size, 1) { + // (n, c_out, ph, pw) is an element in the pooled output + int pw = index % pooled_width; + int ph = (index / pooled_width) % pooled_height; + int c_out = (index / pooled_width / pooled_height) % channels_out; + int n = index / pooled_width / pooled_height / channels_out; 
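+    // The index decomposition above, together with c_in below, implements
+    // position-sensitive pooling: each output channel draws from a distinct
+    // input channel chosen by the bin position (ph, pw).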
+ + // (n, c_in, ph, pw) is the associated element in the input + int c_in = (c_out * pooled_height + ph) * pooled_width + pw; + + // [start, end) interval for spatial sampling + constant T* offset_rois = rois + n * 5; + int roi_batch_ind = offset_rois[0]; + + // Do not using rounding; this implementation detail is critical + T roi_start_w = offset_rois[1] * spatial_scale - static_cast(0.5); + T roi_start_h = offset_rois[2] * spatial_scale - static_cast(0.5); + T roi_end_w = offset_rois[3] * spatial_scale - static_cast(0.5); + T roi_end_h = offset_rois[4] * spatial_scale - static_cast(0.5); + + T roi_width = roi_end_w - roi_start_w; + T roi_height = roi_end_h - roi_start_h; + T bin_size_h = static_cast(roi_height) / static_cast(pooled_height); + T bin_size_w = static_cast(roi_width) / static_cast(pooled_width); + + // Do not using floor/ceil; this implementation detail is critical + T hstart = static_cast(ph) * bin_size_h + roi_start_h; + T wstart = static_cast(pw) * bin_size_w + roi_start_w; + + // We use roi_bin_grid to sample the grid and mimic integral + int roi_bin_grid_h = (sampling_ratio > 0) + ? sampling_ratio + : ceil(roi_height / pooled_height); + int roi_bin_grid_w = + (sampling_ratio > 0) ? sampling_ratio : ceil(roi_width / pooled_width); + const T count = roi_bin_grid_h * roi_bin_grid_w; + + constant T* offset_input = + input + (roi_batch_ind * channels + c_in) * height * width; + T out_sum = 0; + for (int iy = 0; iy < roi_bin_grid_h; iy++) { + const T y = hstart + + static_cast(iy + .5f) * bin_size_h / + static_cast(roi_bin_grid_h); + for (int ix = 0; ix < roi_bin_grid_w; ix++) { + const T x = wstart + + static_cast(ix + .5f) * bin_size_w / + static_cast(roi_bin_grid_w); + T val = bilinear_interpolate(offset_input, height, width, y, x, index); + out_sum += val; + } + } + + out_sum /= count; + output[index] = out_sum; + channel_mapping[index] = c_in; + } +} + +#define REGISTER_PS_ROI_ALIGN_OP(DTYPE) \ +template \ +[[host_name("ps_roi_align_" #DTYPE)]] \ +kernel void ps_roi_align( \ + constant DTYPE * input [[buffer(0)]], \ + constant DTYPE * rois [[buffer(1)]], \ + device DTYPE * output [[buffer(2)]], \ + device int64_t * channel_mapping [[buffer(3)]], \ + constant int64_t & output_size [[buffer(4)]], \ + constant int64_t & channels [[buffer(5)]], \ + constant int64_t & height [[buffer(6)]], \ + constant int64_t & width [[buffer(7)]], \ + constant int64_t & pooled_height [[buffer(8)]], \ + constant int64_t & pooled_width [[buffer(9)]], \ + constant int64_t & sampling_ratio [[buffer(10)]], \ + constant int64_t & channels_out [[buffer(11)]], \ + constant float & spatial_scale [[buffer(12)]], \ + uint2 tgid [[threadgroup_position_in_grid]], \ + uint2 tptg [[threads_per_threadgroup]], \ + uint2 tid2 [[thread_position_in_threadgroup]]); + +REGISTER_PS_ROI_ALIGN_OP(float); +REGISTER_PS_ROI_ALIGN_OP(half); + )VISION_METAL"; static id compileBinaryOpsLibrary(id device) { From e157c7c87e29c95f7816bfca12165b160a98988e Mon Sep 17 00:00:00 2001 From: "Li-Huai (Allan) Lin" Date: Tue, 13 Jun 2023 19:21:54 +0800 Subject: [PATCH 08/28] ps_roi_align bw (failed prec) --- .../csrc/ops/mps/ps_roi_align_kernel.mm | 53 +++---- torchvision/csrc/ops/mps/vision_kernels.h | 132 ++++++++++++++++++ 2 files changed, 159 insertions(+), 26 deletions(-) diff --git a/torchvision/csrc/ops/mps/ps_roi_align_kernel.mm b/torchvision/csrc/ops/mps/ps_roi_align_kernel.mm index 9e63067d2cb..1803ec680f0 100644 --- a/torchvision/csrc/ops/mps/ps_roi_align_kernel.mm +++ b/torchvision/csrc/ops/mps/ps_roi_align_kernel.mm @@ 
-111,29 +111,30 @@ at::Tensor ps_roi_align_backward_kernel( const at::Tensor& grad, const at::Tensor& rois, + const at::Tensor& channel_mapping, double spatial_scale, int64_t pooled_height, int64_t pooled_width, + int64_t sampling_ratio, int64_t batch_size, int64_t channels, int64_t height, - int64_t width, - int64_t sampling_ratio, - bool aligned) { + int64_t width) { using namespace at::native::mps; TORCH_CHECK(grad.is_mps(), "grad must be a MPS tensor"); TORCH_CHECK(rois.is_mps(), "rois must be a MPS tensor"); + TORCH_CHECK(channel_mapping.is_mps(), "channel_mapping must be a MPS tensor"); - at::TensorArg grad_t{grad, "input", 1}, rois_t{rois, "rois", 2}; + at::TensorArg grad_t{grad, "input", 1}, rois_t{rois, "rois", 2}, channel_mapping_t{channel_mapping, "channel_mapping", 3}; at::CheckedFrom c = "ps_roi_align_backward_kernel"; - at::checkAllSameGPU(c, {grad_t, rois_t}); + at::checkAllSameGPU(c, {grad_t, rois_t, channel_mapping_t}); at::checkAllSameType(c, {grad_t, rois_t}); float spatial_scale_f = static_cast(spatial_scale); - at::Tensor grad_input = at::zeros( + auto grad_input = at::zeros( {batch_size, channels, height, width}, grad.options()); if (grad.numel() == 0) { @@ -146,11 +147,14 @@ int64_t w_stride = grad.stride(3); int64_t output_size = grad.numel(); + int64_t channels_out = channels / (pooled_height * pooled_width); + at::globalContext().alertNotDeterministic("ps_roi_align_backward_kernel"); - auto rois_ = rois.contiguous(); + auto grad_ = grad.contiguous(), rois_ = rois.contiguous(); - id inputBuffer = getMTLBufferStorage(grad); + id inputBuffer = getMTLBufferStorage(grad_); id roisBuffer = getMTLBufferStorage(rois_); + id channelMappingBuffer = getMTLBufferStorage(channel_mapping); id outputBuffer = getMTLBufferStorage(grad_input); id device = MPSDevice::getInstance()->device(); MPSStream* mpsStream = getCurrentMPSStream(); @@ -167,23 +171,20 @@ [computeEncoder setComputePipelineState:binaryPSO]; // [N, C, H, W] - [computeEncoder setBuffer:inputBuffer offset:grad.storage_offset() * grad.element_size() atIndex:0]; + [computeEncoder setBuffer:inputBuffer offset:grad_.storage_offset() * grad_.element_size() atIndex:0]; [computeEncoder setBuffer:roisBuffer offset:rois_.storage_offset() * rois_.element_size() atIndex:1]; - [computeEncoder setBuffer:outputBuffer offset:grad_input.storage_offset() * grad_input.element_size() atIndex:2]; + [computeEncoder setBuffer:channelMappingBuffer offset:channel_mapping.storage_offset() * channel_mapping.element_size() atIndex:2]; + [computeEncoder setBuffer:outputBuffer offset:grad_input.storage_offset() * grad_input.element_size() atIndex:3]; - [computeEncoder setBytes:&output_size length:sizeof(int64_t) atIndex:3]; - [computeEncoder setBytes:&channels length:sizeof(int64_t) atIndex:4]; - [computeEncoder setBytes:&height length:sizeof(int64_t) atIndex:5]; - [computeEncoder setBytes:&width length:sizeof(int64_t) atIndex:6]; - [computeEncoder setBytes:&pooled_height length:sizeof(int64_t) atIndex:7]; - [computeEncoder setBytes:&pooled_width length:sizeof(int64_t) atIndex:8]; - [computeEncoder setBytes:&sampling_ratio length:sizeof(int64_t) atIndex:9]; - [computeEncoder setBytes:&aligned length:sizeof(bool) atIndex:10]; - [computeEncoder setBytes:&spatial_scale_f length:sizeof(float) atIndex:11]; - [computeEncoder setBytes:&n_stride length:sizeof(int64_t) atIndex:12]; - [computeEncoder setBytes:&c_stride length:sizeof(int64_t) atIndex:13]; - [computeEncoder setBytes:&h_stride length:sizeof(int64_t) atIndex:14]; - [computeEncoder 
setBytes:&w_stride length:sizeof(int64_t) atIndex:15]; + [computeEncoder setBytes:&output_size length:sizeof(int64_t) atIndex:4]; + [computeEncoder setBytes:&channels length:sizeof(int64_t) atIndex:5]; + [computeEncoder setBytes:&height length:sizeof(int64_t) atIndex:6]; + [computeEncoder setBytes:&width length:sizeof(int64_t) atIndex:7]; + [computeEncoder setBytes:&pooled_height length:sizeof(int64_t) atIndex:8]; + [computeEncoder setBytes:&pooled_width length:sizeof(int64_t) atIndex:9]; + [computeEncoder setBytes:&sampling_ratio length:sizeof(int64_t) atIndex:10]; + [computeEncoder setBytes:&channels_out length:sizeof(int64_t) atIndex:11]; + [computeEncoder setBytes:&spatial_scale_f length:sizeof(float) atIndex:12]; // A threadGroup is equivalent to a cuda's block. NSUInteger tgSize = binaryPSO.maxTotalThreadsPerThreadgroup; @@ -206,9 +207,9 @@ m.impl( TORCH_SELECTIVE_NAME("torchvision::ps_roi_align"), TORCH_FN(ps_roi_align_forward_kernel)); - //m.impl( - // TORCH_SELECTIVE_NAME("torchvision::_ps_roi_align_backward"), - // TORCH_FN(ps_roi_align_backward_kernel)); + m.impl( + TORCH_SELECTIVE_NAME("torchvision::_ps_roi_align_backward"), + TORCH_FN(ps_roi_align_backward_kernel)); } } // namespace ops diff --git a/torchvision/csrc/ops/mps/vision_kernels.h b/torchvision/csrc/ops/mps/vision_kernels.h index c83ced529ea..f8f5033bcd8 100644 --- a/torchvision/csrc/ops/mps/vision_kernels.h +++ b/torchvision/csrc/ops/mps/vision_kernels.h @@ -748,6 +748,138 @@ kernel void ps_roi_align( \ REGISTER_PS_ROI_ALIGN_OP(float); REGISTER_PS_ROI_ALIGN_OP(half); +template +kernel void ps_roi_align_backward( + constant T * grad_output [[buffer(0)]], + constant T * rois [[buffer(1)]], + constant int64_t * channel_mapping [[buffer(2)]], + device T * grad_input [[buffer(3)]], + constant int64_t & output_size [[buffer(4)]], + constant int64_t & channels [[buffer(5)]], + constant int64_t & height [[buffer(6)]], + constant int64_t & width [[buffer(7)]], + constant int64_t & pooled_height [[buffer(8)]], + constant int64_t & pooled_width [[buffer(9)]], + constant int64_t & sampling_ratio [[buffer(10)]], + constant int64_t & channels_out [[buffer(11)]], + constant float & spatial_scale [[buffer(12)]], + uint2 tgid [[threadgroup_position_in_grid]], + uint2 tptg [[threads_per_threadgroup]], + uint2 tid2 [[thread_position_in_threadgroup]]){ + + MPS_1D_KERNEL_LOOP(index, output_size, 1) { + // (n, *, ph, pw) is an element in the pooled output + int pw = index % pooled_width; + int ph = (index / pooled_width) % pooled_height; + int n = index / pooled_width / pooled_height / channels_out; + + constant T* offset_rois = rois + n * 5; + int roi_batch_ind = offset_rois[0]; + + // Do not using rounding; this implementation detail is critical + T roi_start_w = offset_rois[1] * spatial_scale - static_cast(0.5); + T roi_start_h = offset_rois[2] * spatial_scale - static_cast(0.5); + T roi_end_w = offset_rois[3] * spatial_scale - static_cast(0.5); + T roi_end_h = offset_rois[4] * spatial_scale - static_cast(0.5); + + // Force too small ROIs to be 1x1 + T roi_width = roi_end_w - roi_start_w; + T roi_height = roi_end_h - roi_start_h; + T bin_size_h = roi_height / static_cast(pooled_height); + T bin_size_w = roi_width / static_cast(pooled_width); + + int c_in = channel_mapping[index]; + + // Do not using floor/ceil; this implementation detail is critical + T hstart = static_cast(ph) * bin_size_h + roi_start_h; + T wstart = static_cast(pw) * bin_size_w + roi_start_w; + + const T grad_output_this_bin = grad_output[index]; + + // We use 
roi_bin_grid to sample the grid and mimic integral + int roi_bin_grid_h = (sampling_ratio > 0) + ? sampling_ratio + : ceil(roi_height / pooled_height); // e.g., = 2 + int roi_bin_grid_w = + (sampling_ratio > 0) ? sampling_ratio : ceil(roi_width / pooled_width); + const T count = roi_bin_grid_h * roi_bin_grid_w; + + const int offset = (roi_batch_ind * channels + c_in) * height * width; + + for (int iy = 0; iy < roi_bin_grid_h; iy++) { + const T y = hstart + + static_cast(iy + .5f) * bin_size_h / + static_cast(roi_bin_grid_h); + for (int ix = 0; ix < roi_bin_grid_w; ix++) { + const T x = wstart + + static_cast(ix + .5f) * bin_size_w / + static_cast(roi_bin_grid_w); + + T w1, w2, w3, w4; + int x_low, x_high, y_low, y_high; + + bilinear_interpolate_gradient( + height, + width, + y, + x, + w1, + w2, + w3, + w4, + x_low, + x_high, + y_low, + y_high, + index); + + T g1 = grad_output_this_bin * w1 / count; + T g2 = grad_output_this_bin * w2 / count; + T g3 = grad_output_this_bin * w3 / count; + T g4 = grad_output_this_bin * w4 / count; + + if (x_low >= 0 && x_high >= 0 && y_low >= 0 && y_high >= 0) { + device atomic_uint* xAtomic = (device atomic_uint*)(grad_input + offset + y_low * width + x_low); + device atomic_uint* yAtomic = (device atomic_uint*)(grad_input + offset + y_low * width + x_high); + device atomic_uint* zAtomic = (device atomic_uint*)(grad_input + offset + y_high * width + x_low); + device atomic_uint* wAtomic = (device atomic_uint*)(grad_input + offset + y_high * width + x_high); + + // atomic_float data type is supported on Metal 3 onward. + // TODO: Use native atomic_fetch_add_explicit for Metal 3. + atomic_add_float(xAtomic, static_cast(g1)); + atomic_add_float(yAtomic, static_cast(g2)); + atomic_add_float(zAtomic, static_cast(g3)); + atomic_add_float(wAtomic, static_cast(g4)); + } // if + } // ix + } // iy + } +} + +#define REGISTER_PS_ROI_ALIGN_BACKWARD_OP(DTYPE) \ +template \ +[[host_name("ps_roi_align_backward_" #DTYPE)]] \ +kernel void ps_roi_align_backward( \ + constant DTYPE * grad_output [[buffer(0)]], \ + constant DTYPE * rois [[buffer(1)]], \ + constant int64_t * channel_mapping [[buffer(2)]], \ + device DTYPE * grad_input [[buffer(3)]], \ + constant int64_t & output_size [[buffer(4)]], \ + constant int64_t & channels [[buffer(5)]], \ + constant int64_t & height [[buffer(6)]], \ + constant int64_t & width [[buffer(7)]], \ + constant int64_t & pooled_height [[buffer(8)]], \ + constant int64_t & pooled_width [[buffer(9)]], \ + constant int64_t & sampling_ratio [[buffer(10)]], \ + constant int64_t & channels_out [[buffer(11)]], \ + constant float & spatial_scale [[buffer(12)]], \ + uint2 tgid [[threadgroup_position_in_grid]], \ + uint2 tptg [[threads_per_threadgroup]], \ + uint2 tid2 [[thread_position_in_threadgroup]]); + +REGISTER_PS_ROI_ALIGN_BACKWARD_OP(float); +REGISTER_PS_ROI_ALIGN_BACKWARD_OP(half); + )VISION_METAL"; static id compileBinaryOpsLibrary(id device) { From 160d5b576fcadb001b576ee384e997dc2b0241cc Mon Sep 17 00:00:00 2001 From: "Li-Huai (Allan) Lin" Date: Wed, 14 Jun 2023 17:30:02 +0800 Subject: [PATCH 09/28] Several improvements --- torchvision/csrc/ops/mps/mps_helpers.h | 2 + torchvision/csrc/ops/mps/nms_kernel.mm | 31 ++-- .../csrc/ops/mps/ps_roi_align_kernel.mm | 8 -- torchvision/csrc/ops/mps/roi_align_kernel.mm | 3 - torchvision/csrc/ops/mps/roi_pool_kernel.mm | 4 - torchvision/csrc/ops/mps/vision_kernels.h | 133 ++++++++---------- torchvision/ops/roi_align.py | 2 +- 7 files changed, 72 insertions(+), 111 deletions(-) diff --git 
a/torchvision/csrc/ops/mps/mps_helpers.h b/torchvision/csrc/ops/mps/mps_helpers.h index 48afe765d93..a18f51a900f 100644 --- a/torchvision/csrc/ops/mps/mps_helpers.h +++ b/torchvision/csrc/ops/mps/mps_helpers.h @@ -1,3 +1,5 @@ +constexpr int threadsPerBlock = 512; + template constexpr inline T ceil_div(T n, T m) { return (n + m - 1) / m; diff --git a/torchvision/csrc/ops/mps/nms_kernel.mm b/torchvision/csrc/ops/mps/nms_kernel.mm index f2c30fe29e4..e697ef80fe3 100644 --- a/torchvision/csrc/ops/mps/nms_kernel.mm +++ b/torchvision/csrc/ops/mps/nms_kernel.mm @@ -10,8 +10,8 @@ namespace { -// This should be in sync with the one in metal kernel. -int const threadsPerBlock = sizeof(uint64_t) * 8; +// This should be in sync with `nmsThreadsPerBlock` in the metal kernel. +constexpr int nmsThreadsPerBlock = sizeof(uint64_t) * 8; at::Tensor nms_kernel( const at::Tensor& dets, @@ -41,10 +41,6 @@ " and ", scores.size(0)) - //at::Tensor input = at::arange({10}, at::kFloat, c10::nullopt, at::kMPS, c10::nullopt); - //at::Tensor other = at::arange({10}, at::kFloat, c10::nullopt, at::kMPS, c10::nullopt); - //at::Tensor out = at::zeros({10}, at::kFloat, c10::nullopt, at::kMPS, c10::nullopt); - if (dets.numel() == 0) { return at::empty({0}, dets.options().dtype(at::kLong)); } @@ -55,7 +51,7 @@ int dets_num = dets.size(0); float iou_threshold_f = static_cast(iou_threshold); - const int col_blocks = (dets_num + threadsPerBlock - 1) / threadsPerBlock; + const int col_blocks = (dets_num + nmsThreadsPerBlock - 1) / nmsThreadsPerBlock; at::Tensor mask = at::empty({dets_num * col_blocks}, dets.options().dtype(at::kLong)); @@ -63,9 +59,6 @@ id outputBuffer = getMTLBufferStorage(mask); id device = MPSDevice::getInstance()->device(); MPSStream* mpsStream = getCurrentMPSStream(); - //const uint32_t nDim = iter.ndim(); - //constexpr uint32_t nOffsets = 3; - const uint32_t numThreads = dets_num; dispatch_sync(mpsStream->queue(), ^() { @autoreleasepool { id computeEncoder = mpsStream->commandEncoder(); @@ -85,8 +78,8 @@ // A threadGroup is equivalent to a cuda's block. 
NSUInteger tgSize = binaryPSO.maxTotalThreadsPerThreadgroup; - if (tgSize > threadsPerBlock) { - tgSize = threadsPerBlock; + if (tgSize > nmsThreadsPerBlock) { + tgSize = nmsThreadsPerBlock; } MTLSize threadGroupSize = MTLSizeMake(tgSize, 1, 1); @@ -96,10 +89,9 @@ } }); - // out[det] = int64_t num_to_keep = 0; - at::Tensor mask_cpu = mask.to(at::kCPU); // tid or + at::Tensor mask_cpu = mask.to(at::kCPU); unsigned long long* mask_host = (unsigned long long*)mask_cpu.data_ptr(); @@ -111,26 +103,21 @@ int64_t* keep_out = keep.data_ptr(); for (int i = 0; i < dets_num; i++) { - int nblock = i / threadsPerBlock; - int inblock = i % threadsPerBlock; + int nblock = i / nmsThreadsPerBlock; + int inblock = i % nmsThreadsPerBlock; - //std::cout << "remv:" << remv[nblock] << "cur:" << (1ULL << i) << std::endl; if (!(remv[nblock] & (1ULL << inblock))) { keep_out[num_to_keep++] = i; unsigned long long* p = mask_host + i * col_blocks; for (int j = nblock; j < col_blocks; j++) { remv[j] |= p[j]; } - } else { - //std::cout << "SKIP at:" << i << std::endl; } } - //std::cout << "NTK: " << num_to_keep << std::endl; - //std::cout << "SUM mask: " << mask_cpu.sum() << std::endl; + return order_t.index( {keep.narrow(/*dim=*/0, /*start=*/0, /*length=*/num_to_keep) .to(order_t.device(), keep.scalar_type())}); - } } // namespace diff --git a/torchvision/csrc/ops/mps/ps_roi_align_kernel.mm b/torchvision/csrc/ops/mps/ps_roi_align_kernel.mm index 1803ec680f0..15d28f23b90 100644 --- a/torchvision/csrc/ops/mps/ps_roi_align_kernel.mm +++ b/torchvision/csrc/ops/mps/ps_roi_align_kernel.mm @@ -11,9 +11,6 @@ namespace { -// This should be in sync with the one in metal kernel. -int const threadsPerBlock = 512; - std::tuple ps_roi_align_forward_kernel( const at::Tensor& input, const at::Tensor& rois, @@ -141,12 +138,7 @@ return grad_input; } - int64_t n_stride = grad.stride(0); - int64_t c_stride = grad.stride(1); - int64_t h_stride = grad.stride(2); - int64_t w_stride = grad.stride(3); int64_t output_size = grad.numel(); - int64_t channels_out = channels / (pooled_height * pooled_width); at::globalContext().alertNotDeterministic("ps_roi_align_backward_kernel"); diff --git a/torchvision/csrc/ops/mps/roi_align_kernel.mm b/torchvision/csrc/ops/mps/roi_align_kernel.mm index afbd4ce3aaf..d135709bb8b 100644 --- a/torchvision/csrc/ops/mps/roi_align_kernel.mm +++ b/torchvision/csrc/ops/mps/roi_align_kernel.mm @@ -11,9 +11,6 @@ namespace { -// This should be in sync with the one in metal kernel. -int const threadsPerBlock = 512; - at::Tensor roi_align_forward_kernel( const at::Tensor& input, const at::Tensor& rois, diff --git a/torchvision/csrc/ops/mps/roi_pool_kernel.mm b/torchvision/csrc/ops/mps/roi_pool_kernel.mm index d37d679dd4f..85a3b356598 100644 --- a/torchvision/csrc/ops/mps/roi_pool_kernel.mm +++ b/torchvision/csrc/ops/mps/roi_pool_kernel.mm @@ -11,9 +11,6 @@ namespace { -// This should be in sync with the one in metal kernel. 
-int const threadsPerBlock = 512; - std::tuple roi_pool_forward_kernel( const at::Tensor& input, const at::Tensor& rois, @@ -124,7 +121,6 @@ at::checkAllSameType(c, {grad_t, rois_t}); float spatial_scale_f = static_cast(spatial_scale); - auto num_rois = rois.size(0); at::Tensor grad_input = at::zeros( {batch_size, channels, height, width}, grad.options()); diff --git a/torchvision/csrc/ops/mps/vision_kernels.h b/torchvision/csrc/ops/mps/vision_kernels.h index f8f5033bcd8..a4d250db651 100644 --- a/torchvision/csrc/ops/mps/vision_kernels.h +++ b/torchvision/csrc/ops/mps/vision_kernels.h @@ -11,43 +11,46 @@ static const char* METAL_VISION = R"VISION_METAL( #include using namespace metal; +/*----------Macros----------*/ + #define MPS_1D_KERNEL_LOOP_T(i, n, n_grids, index_t) \ for (index_t i = (tgid.x * tptg.x) + tid2.x; i < (n); \ i += (tptg.x * n_grids)) #define MPS_1D_KERNEL_LOOP(i, n, n_grids) MPS_1D_KERNEL_LOOP_T(i, n, n_grids, int) -// This should be in sync with the one in nms_kernel.mm. -// Since metal does not support dynamic array, -// we need to make it static instead of deriving it from [[threads_per_threadgroup]]. -constant uint threadsPerBlock = sizeof(uint64_t) * 8; - -// Utility functions +/*----------Utils----------*/ template inline T ceil_div(T n, T m) { return (n + m - 1) / m; } -// https://github.com/ShoYamanishi/AppleNumericalComputing/blob/053f06c1f5a831095c4bcc29aaf11366fce5231e/03_dot/metal/dot.metal#L447-L472 -void atomic_add_float( device atomic_uint* atom_var, const float val ) +template +void atomic_add_float( device T* data_ptr, const float val) { - uint fetched_uint, assigning_uint; - float fetched_float, assigning_float; - - fetched_uint = atomic_exchange_explicit( atom_var, 0, memory_order_relaxed ); - fetched_float = *( (thread float*) &fetched_uint ); - - assigning_float = fetched_float + val; +#if __METAL_VERSION__ >= 300 + device atomic_fetch_add_explicit((device atomic_float*) data_ptr, val, memory_order_relaxed); +#else + // https://github.com/ShoYamanishi/AppleNumericalComputing/blob/053f06c1f5a831095c4bcc29aaf11366fce5231e/03_dot/metal/dot.metal#L447-L472 + device atomic_uint* atom_var = (device atomic_uint*)data_ptr; + uint fetched_uint, assigning_uint; + float fetched_float, assigning_float; + + fetched_uint = atomic_exchange_explicit( atom_var, 0, memory_order_relaxed ); + fetched_float = *( (thread float*) &fetched_uint ); + + assigning_float = fetched_float + val; + assigning_uint = *( (thread uint*) &assigning_float ); + + while ((fetched_uint = atomic_exchange_explicit( atom_var, assigning_uint, memory_order_relaxed)) != 0) { + uint fetched_uint_again = atomic_exchange_explicit( atom_var, 0, memory_order_relaxed); + float fetched_float_again = *( (thread float*) &fetched_uint_again ); + fetched_float = *( (thread float*) &(fetched_uint) ); + assigning_float = fetched_float_again + fetched_float; assigning_uint = *( (thread uint*) &assigning_float ); - - while ( (fetched_uint = atomic_exchange_explicit( atom_var, assigning_uint, memory_order_relaxed ) ) != 0 ) { - uint fetched_uint_again = atomic_exchange_explicit( atom_var, 0, memory_order_relaxed ); - float fetched_float_again = *( (thread float*) &fetched_uint_again ); - fetched_float = *( (thread float*) &(fetched_uint) ); - assigning_float = fetched_float_again + fetched_float; - assigning_uint = *( (thread uint*) &assigning_float ); - } + } +#endif } template @@ -180,6 +183,13 @@ bool inline IoU( return (inter / (area_a + area_b - inter)) > threshold; } +/*----------Kernels----------*/ + +// 
This should be in sync with the one in nms_kernel.mm. +// Since metal does not support dynamic array, +// we need to make it static instead of deriving it from [[threads_per_threadgroup]]. +constant uint nmsThreadsPerBlock = sizeof(uint64_t) * 8; + template kernel void nms(constant T * dev_boxes [[buffer(0)]], device uint64_t * mask [[buffer(1)]], @@ -192,16 +202,16 @@ kernel void nms(constant T * dev_boxes [[buffer(0)]], const uint col_start = tgid.x; const uint tid = tid2.x; const uint row_size = - min(n_boxes - row_start * threadsPerBlock, threadsPerBlock); + min(n_boxes - row_start * nmsThreadsPerBlock, nmsThreadsPerBlock); const uint col_size = - min(n_boxes - col_start * threadsPerBlock, threadsPerBlock); + min(n_boxes - col_start * nmsThreadsPerBlock, nmsThreadsPerBlock); - threadgroup T block_boxes[threadsPerBlock]; - block_boxes[tid] = dev_boxes[threadsPerBlock * col_start + tid]; + threadgroup T block_boxes[nmsThreadsPerBlock]; + block_boxes[tid] = dev_boxes[nmsThreadsPerBlock * col_start + tid]; threadgroup_barrier(mem_flags::mem_threadgroup); if (tid < row_size) { - const uint cur_box_idx = threadsPerBlock * row_start + tid; + const uint cur_box_idx = nmsThreadsPerBlock * row_start + tid; uint64_t t = 0; uint start = 0; @@ -214,7 +224,7 @@ kernel void nms(constant T * dev_boxes [[buffer(0)]], t |= static_cast(1) << i; // discard 1 keep 0 } } - const uint col_blocks = ceil_div(static_cast(n_boxes), threadsPerBlock); + const uint col_blocks = ceil_div(static_cast(n_boxes), nmsThreadsPerBlock); mask[cur_box_idx * col_blocks + col_start] = t; } } @@ -230,9 +240,6 @@ kernel void nms( \ uint2 tgid [[threadgroup_position_in_grid]], \ uint2 tid2 [[thread_position_in_threadgroup]]); -REGISTER_NMS_OP(float); -REGISTER_NMS_OP(half); - template kernel void roi_align( constant T * input [[buffer(0)]], @@ -333,9 +340,6 @@ kernel void roi_align( \ uint2 tptg [[threads_per_threadgroup]], \ uint2 tid2 [[thread_position_in_threadgroup]]); -REGISTER_ROI_ALIGN_OP(float); -REGISTER_ROI_ALIGN_OP(half); - template kernel void roi_align_backward( constant T * grad_output [[buffer(0)]], @@ -439,17 +443,10 @@ kernel void roi_align_backward( T g4 = grad_output_this_bin * w4 / count; if (x_low >= 0 && x_high >= 0 && y_low >= 0 && y_high >= 0) { - device atomic_uint* xAtomic = (device atomic_uint*)(grad_input + input_offset + y_low * width + x_low); - device atomic_uint* yAtomic = (device atomic_uint*)(grad_input + input_offset + y_low * width + x_high); - device atomic_uint* zAtomic = (device atomic_uint*)(grad_input + input_offset + y_high * width + x_low); - device atomic_uint* wAtomic = (device atomic_uint*)(grad_input + input_offset + y_high * width + x_high); - - // atomic_float data type is supported on Metal 3 onward. - // TODO: Use native atomic_fetch_add_explicit for Metal 3. 
- atomic_add_float(xAtomic, static_cast(g1)); - atomic_add_float(yAtomic, static_cast(g2)); - atomic_add_float(zAtomic, static_cast(g3)); - atomic_add_float(wAtomic, static_cast(g4)); + atomic_add_float(grad_input + input_offset + y_low * width + x_low, static_cast(g1)); + atomic_add_float(grad_input + input_offset + y_low * width + x_high, static_cast(g2)); + atomic_add_float(grad_input + input_offset + y_high * width + x_low, static_cast(g3)); + atomic_add_float(grad_input + input_offset + y_high * width + x_high, static_cast(g4)); } // if } // ix @@ -481,9 +478,6 @@ kernel void roi_align_backward( \ uint2 tptg [[threads_per_threadgroup]], \ uint2 tid2 [[thread_position_in_threadgroup]]); -REGISTER_ROI_ALIGN_BACKWARD_OP(float); -REGISTER_ROI_ALIGN_BACKWARD_OP(half); - template kernel void roi_pool( constant T * input [[buffer(0)]], @@ -571,9 +565,6 @@ kernel void roi_pool( \ uint2 tptg [[threads_per_threadgroup]], \ uint2 tid2 [[thread_position_in_threadgroup]]); -REGISTER_ROI_POOL_OP(float); -REGISTER_ROI_POOL_OP(half); - template kernel void roi_pool_backward( constant T * grad_output [[buffer(0)]], @@ -612,10 +603,7 @@ kernel void roi_pool_backward( const int offset = (roi_batch_ind * channels + c) * height * width; if (argmax != -1) { - device atomic_uint* xAtomic = (device atomic_uint*)(grad_input + offset + argmax); - // atomic_float data type is supported on Metal 3 onward. - // TODO: Use native atomic_fetch_add_explicit for Metal 3. - atomic_add_float(xAtomic, static_cast(grad_output[output_offset + ph * h_stride + pw * w_stride])); + atomic_add_float(grad_input + offset + argmax, static_cast(grad_output[output_offset + ph * h_stride + pw * w_stride])); } } // MPS_1D_KERNEL_LOOP @@ -644,9 +632,6 @@ kernel void roi_pool_backward( \ uint2 tptg [[threads_per_threadgroup]], \ uint2 tid2 [[thread_position_in_threadgroup]]); -REGISTER_ROI_POOL_BACKWARD_OP(float); -REGISTER_ROI_POOL_BACKWARD_OP(half); - template kernel void ps_roi_align( constant T * input [[buffer(0)]], @@ -745,9 +730,6 @@ kernel void ps_roi_align( \ uint2 tptg [[threads_per_threadgroup]], \ uint2 tid2 [[thread_position_in_threadgroup]]); -REGISTER_PS_ROI_ALIGN_OP(float); -REGISTER_PS_ROI_ALIGN_OP(half); - template kernel void ps_roi_align_backward( constant T * grad_output [[buffer(0)]], @@ -839,17 +821,10 @@ kernel void ps_roi_align_backward( T g4 = grad_output_this_bin * w4 / count; if (x_low >= 0 && x_high >= 0 && y_low >= 0 && y_high >= 0) { - device atomic_uint* xAtomic = (device atomic_uint*)(grad_input + offset + y_low * width + x_low); - device atomic_uint* yAtomic = (device atomic_uint*)(grad_input + offset + y_low * width + x_high); - device atomic_uint* zAtomic = (device atomic_uint*)(grad_input + offset + y_high * width + x_low); - device atomic_uint* wAtomic = (device atomic_uint*)(grad_input + offset + y_high * width + x_high); - - // atomic_float data type is supported on Metal 3 onward. - // TODO: Use native atomic_fetch_add_explicit for Metal 3. 
- atomic_add_float(xAtomic, static_cast(g1)); - atomic_add_float(yAtomic, static_cast(g2)); - atomic_add_float(zAtomic, static_cast(g3)); - atomic_add_float(wAtomic, static_cast(g4)); + atomic_add_float(grad_input + offset + y_low * width + x_low, static_cast(g1)); + atomic_add_float(grad_input + offset + y_low * width + x_high, static_cast(g2)); + atomic_add_float(grad_input + offset + y_high * width + x_low, static_cast(g3)); + atomic_add_float(grad_input + offset + y_high * width + x_high, static_cast(g4)); } // if } // ix } // iy @@ -877,6 +852,18 @@ kernel void ps_roi_align_backward( \ uint2 tptg [[threads_per_threadgroup]], \ uint2 tid2 [[thread_position_in_threadgroup]]); +REGISTER_NMS_OP(float); +REGISTER_NMS_OP(half); +REGISTER_ROI_ALIGN_OP(float); +REGISTER_ROI_ALIGN_OP(half); +REGISTER_ROI_ALIGN_BACKWARD_OP(float); +REGISTER_ROI_ALIGN_BACKWARD_OP(half); +REGISTER_ROI_POOL_OP(float); +REGISTER_ROI_POOL_OP(half); +REGISTER_ROI_POOL_BACKWARD_OP(float); +REGISTER_ROI_POOL_BACKWARD_OP(half); +REGISTER_PS_ROI_ALIGN_OP(float); +REGISTER_PS_ROI_ALIGN_OP(half); REGISTER_PS_ROI_ALIGN_BACKWARD_OP(float); REGISTER_PS_ROI_ALIGN_BACKWARD_OP(half); diff --git a/torchvision/ops/roi_align.py b/torchvision/ops/roi_align.py index be8ec8aea74..3e839b05526 100644 --- a/torchvision/ops/roi_align.py +++ b/torchvision/ops/roi_align.py @@ -232,7 +232,7 @@ def roi_align( if not isinstance(rois, torch.Tensor): rois = convert_boxes_to_roi_format(rois) if not torch.jit.is_scripting(): - if not _has_ops() or (torch.are_deterministic_algorithms_enabled() and input.is_cuda): + if not _has_ops() or (torch.are_deterministic_algorithms_enabled() and (input.is_cuda or input.is_mps)): return _roi_align(input, rois, spatial_scale, output_size[0], output_size[1], sampling_ratio, aligned) _assert_has_ops() return torch.ops.torchvision.roi_align( From 40ea5257f032de9fa8b2478c8da418d33e9493a9 Mon Sep 17 00:00:00 2001 From: "Li-Huai (Allan) Lin" Date: Sat, 17 Jun 2023 15:38:26 +0800 Subject: [PATCH 10/28] ps_roi_pool fw --- .../csrc/ops/mps/ps_roi_pool_kernel.mm | 208 ++++++++++++++++++ torchvision/csrc/ops/mps/vision_kernels.h | 91 ++++++++ 2 files changed, 299 insertions(+) create mode 100644 torchvision/csrc/ops/mps/ps_roi_pool_kernel.mm diff --git a/torchvision/csrc/ops/mps/ps_roi_pool_kernel.mm b/torchvision/csrc/ops/mps/ps_roi_pool_kernel.mm new file mode 100644 index 00000000000..683b812b5b0 --- /dev/null +++ b/torchvision/csrc/ops/mps/ps_roi_pool_kernel.mm @@ -0,0 +1,208 @@ +#include +#include +#include "vision_kernels.h" +#include "mps_helpers.h" + +#include +#include + +namespace vision { +namespace ops { + +namespace { + +std::tuple ps_roi_pool_forward_kernel( + const at::Tensor& input, + const at::Tensor& rois, + double spatial_scale, + int64_t pooled_height, + int64_t pooled_width) { + + using namespace at::native::mps; + TORCH_CHECK(input.is_mps(), "input must be a MPS tensor"); + TORCH_CHECK(rois.is_mps(), "rois must be a MPS tensor"); + TORCH_CHECK(rois.size(1) == 5, "rois must have shape as Tensor[K, 5]"); + + at::TensorArg input_t{input, "input", 1}, rois_t{rois, "rois", 2}; + + at::CheckedFrom c = "ps_roi_pool_forward_kernel"; + at::checkAllSameGPU(c, {input_t, rois_t}); + at::checkAllSameType(c, {input_t, rois_t}); + + int64_t num_rois = rois.size(0); + int64_t channels = input.size(1); + int64_t height = input.size(2); + int64_t width = input.size(3); + float spatial_scale_f = static_cast(spatial_scale); + + TORCH_CHECK( + channels % (pooled_height * pooled_width) == 0, + "input channels must be a 
multiple of pooling height * pooling width"); + int64_t channels_out = channels / (pooled_height * pooled_width); + + auto output = at::zeros( + {num_rois, channels_out, pooled_height, pooled_width}, input.options()); + auto channel_mapping = + at::zeros(output.sizes(), input.options().dtype(at::kLong)); + auto output_size = output.numel(); + + if (output_size == 0) { + return std::make_tuple(output, channel_mapping); + } + + auto input_ = input.contiguous(); + auto rois_ = rois.contiguous(); + + id inputBuffer = getMTLBufferStorage(input_); + id roisBuffer = getMTLBufferStorage(rois_); + id outputBuffer = getMTLBufferStorage(output); + id channelMappingBuffer = getMTLBufferStorage(channel_mapping); + id device = MPSDevice::getInstance()->device(); + MPSStream* mpsStream = getCurrentMPSStream(); + dispatch_sync(mpsStream->queue(), ^() { + @autoreleasepool { + id computeEncoder = mpsStream->commandEncoder(); + MTLSize threadgroupsPerGrid = MTLSizeMake(std::min(ceil_div(static_cast(output_size), static_cast(512)), static_cast(4096)), 1, 1); + + const std::string kernel = "ps_roi_pool_" + scalarToMetalTypeString(input.scalar_type()); + id binaryPSO = mps::binaryPipelineState(device, kernel); + + // this function call is a no-op if MPS Profiler is not enabled + getMPSProfiler().beginProfileKernel(binaryPSO, kernel, {input_, rois_}); + + [computeEncoder setComputePipelineState:binaryPSO]; + // [N, C, H, W] + [computeEncoder setBuffer:inputBuffer offset:input_.storage_offset() * input_.element_size() atIndex:0]; + [computeEncoder setBuffer:roisBuffer offset:rois_.storage_offset() * rois_.element_size() atIndex:1]; + [computeEncoder setBuffer:outputBuffer offset:output.storage_offset() * output.element_size() atIndex:2]; + [computeEncoder setBuffer:channelMappingBuffer offset:channel_mapping.storage_offset() * channel_mapping.element_size() atIndex:3]; + + [computeEncoder setBytes:&output_size length:sizeof(int64_t) atIndex:4]; + [computeEncoder setBytes:&channels length:sizeof(int64_t) atIndex:5]; + [computeEncoder setBytes:&height length:sizeof(int64_t) atIndex:6]; + [computeEncoder setBytes:&width length:sizeof(int64_t) atIndex:7]; + [computeEncoder setBytes:&pooled_height length:sizeof(int64_t) atIndex:8]; + [computeEncoder setBytes:&pooled_width length:sizeof(int64_t) atIndex:9]; + [computeEncoder setBytes:&channels_out length:sizeof(int64_t) atIndex:10]; + [computeEncoder setBytes:&spatial_scale_f length:sizeof(float) atIndex:11]; + + // A threadGroup is equivalent to a cuda's block. 
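// (Launch configuration mirrors the CUDA kernels: ceil_div(output_size, 512)
// threadgroups capped at 4096, with MPS_1D_KERNEL_LOOP allowing a thread to
// process more than one output element.)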
+ NSUInteger tgSize = binaryPSO.maxTotalThreadsPerThreadgroup; + if (tgSize > threadsPerBlock) { + tgSize = threadsPerBlock; + } + + MTLSize threadGroupSize = MTLSizeMake(tgSize, 1, 1); + [computeEncoder dispatchThreadgroups:threadgroupsPerGrid threadsPerThreadgroup:threadGroupSize]; + + getMPSProfiler().endProfileKernel(binaryPSO); + } + }); + return std::make_tuple(output, channel_mapping); +} + +at::Tensor ps_roi_pool_backward_kernel( + const at::Tensor& grad, + const at::Tensor& rois, + const at::Tensor& argmax, + double spatial_scale, + int64_t pooled_height, + int64_t pooled_width, + int64_t batch_size, + int64_t channels, + int64_t height, + int64_t width) { + + using namespace at::native::mps; + TORCH_CHECK(grad.is_mps(), "grad must be a MPS tensor"); + TORCH_CHECK(rois.is_mps(), "rois must be a MPS tensor"); + TORCH_CHECK(argmax.is_mps(), "argmax must be a MPS tensor"); + + at::TensorArg grad_t{grad, "input", 1}, rois_t{rois, "rois", 2}, argmax_t{argmax, "argmax", 3}; + + at::CheckedFrom c = "ps_roi_pool_backward_kernel"; + at::checkAllSameGPU(c, {grad_t, rois_t, argmax_t}); + at::checkAllSameType(c, {grad_t, rois_t}); + + float spatial_scale_f = static_cast(spatial_scale); + + at::Tensor grad_input = at::zeros( + {batch_size, channels, height, width}, grad.options()); + + if (grad.numel() == 0) { + return grad_input; + } + + int64_t n_stride = grad.stride(0); + int64_t c_stride = grad.stride(1); + int64_t h_stride = grad.stride(2); + int64_t w_stride = grad.stride(3); + int64_t output_size = grad.numel(); + + at::globalContext().alertNotDeterministic("ps_roi_pool_backward_kernel"); + auto argmax_ = argmax.contiguous(), rois_ = rois.contiguous(); + + id inputBuffer = getMTLBufferStorage(grad); + id roisBuffer = getMTLBufferStorage(rois_); + id argmaxBuffer = getMTLBufferStorage(argmax_); + id outputBuffer = getMTLBufferStorage(grad_input); + id device = MPSDevice::getInstance()->device(); + MPSStream* mpsStream = getCurrentMPSStream(); + dispatch_sync(mpsStream->queue(), ^() { + @autoreleasepool { + id computeEncoder = mpsStream->commandEncoder(); + MTLSize threadgroupsPerGrid = MTLSizeMake(std::min(ceil_div(static_cast(grad.numel()), static_cast(512)), static_cast(4096)), 1, 1); + + const std::string kernel = "ps_roi_pool_backward_" + scalarToMetalTypeString(grad.scalar_type()); + id binaryPSO = mps::binaryPipelineState(device, kernel); + + // this function call is a no-op if MPS Profiler is not enabled + getMPSProfiler().beginProfileKernel(binaryPSO, kernel, {grad, rois_, argmax_}); + + [computeEncoder setComputePipelineState:binaryPSO]; + // [N, C, H, W] + [computeEncoder setBuffer:inputBuffer offset:grad.storage_offset() * grad.element_size() atIndex:0]; + [computeEncoder setBuffer:roisBuffer offset:rois_.storage_offset() * rois_.element_size() atIndex:1]; + [computeEncoder setBuffer:argmaxBuffer offset:argmax_.storage_offset() * argmax_.element_size() atIndex:2]; + [computeEncoder setBuffer:outputBuffer offset:grad_input.storage_offset() * grad_input.element_size() atIndex:3]; + + [computeEncoder setBytes:&output_size length:sizeof(int64_t) atIndex:4]; + [computeEncoder setBytes:&channels length:sizeof(int64_t) atIndex:5]; + [computeEncoder setBytes:&height length:sizeof(int64_t) atIndex:6]; + [computeEncoder setBytes:&width length:sizeof(int64_t) atIndex:7]; + [computeEncoder setBytes:&pooled_height length:sizeof(int64_t) atIndex:8]; + [computeEncoder setBytes:&pooled_width length:sizeof(int64_t) atIndex:9]; + [computeEncoder setBytes:&spatial_scale_f length:sizeof(float) 
atIndex:10]; + [computeEncoder setBytes:&n_stride length:sizeof(int64_t) atIndex:11]; + [computeEncoder setBytes:&c_stride length:sizeof(int64_t) atIndex:12]; + [computeEncoder setBytes:&h_stride length:sizeof(int64_t) atIndex:13]; + [computeEncoder setBytes:&w_stride length:sizeof(int64_t) atIndex:14]; + + // A threadGroup is equivalent to a cuda's block. + NSUInteger tgSize = binaryPSO.maxTotalThreadsPerThreadgroup; + if (tgSize > threadsPerBlock) { + tgSize = threadsPerBlock; + } + + MTLSize threadGroupSize = MTLSizeMake(tgSize, 1, 1); + [computeEncoder dispatchThreadgroups:threadgroupsPerGrid threadsPerThreadgroup:threadGroupSize]; + + getMPSProfiler().endProfileKernel(binaryPSO); + } + }); + return grad_input; +} + +} // namespace + +TORCH_LIBRARY_IMPL(torchvision, MPS, m) { + m.impl( + TORCH_SELECTIVE_NAME("torchvision::ps_roi_pool"), + TORCH_FN(ps_roi_pool_forward_kernel)); + m.impl( + TORCH_SELECTIVE_NAME("torchvision::_ps_roi_pool_backward"), + TORCH_FN(ps_roi_pool_backward_kernel)); +} + +} // namespace ops +} // namespace vision diff --git a/torchvision/csrc/ops/mps/vision_kernels.h b/torchvision/csrc/ops/mps/vision_kernels.h index a4d250db651..5eec0360341 100644 --- a/torchvision/csrc/ops/mps/vision_kernels.h +++ b/torchvision/csrc/ops/mps/vision_kernels.h @@ -852,6 +852,95 @@ kernel void ps_roi_align_backward( \ uint2 tptg [[threads_per_threadgroup]], \ uint2 tid2 [[thread_position_in_threadgroup]]); +template +kernel void ps_roi_pool( + constant T * input [[buffer(0)]], + constant T * rois [[buffer(1)]], + device T * output [[buffer(2)]], + device int64_t * channel_mapping [[buffer(3)]], + constant int64_t & output_size [[buffer(4)]], + constant int64_t & channels [[buffer(5)]], + constant int64_t & height [[buffer(6)]], + constant int64_t & width [[buffer(7)]], + constant int64_t & pooled_height [[buffer(8)]], + constant int64_t & pooled_width [[buffer(9)]], + constant int64_t & channels_out [[buffer(10)]], + constant float & spatial_scale [[buffer(11)]], + uint2 tgid [[threadgroup_position_in_grid]], + uint2 tptg [[threads_per_threadgroup]], + uint2 tid2 [[thread_position_in_threadgroup]]){ + MPS_1D_KERNEL_LOOP(index, output_size, 1) { + // (n, c_out, ph, pw) is an element in the pooled output + int pw = index % pooled_width; + int ph = (index / pooled_width) % pooled_height; + int c_out = (index / (pooled_width * pooled_height)) % channels_out; + int n = index / pooled_width / pooled_height / channels_out; + + // (n, c_in, ph, pw) is the associated element in the input + int c_in = (c_out * pooled_height + ph) * pooled_width + pw; + + // [start, end) interval for spatial sampling + constant T* offset_rois = rois + n * 5; + int roi_batch_ind = offset_rois[0]; + int roi_start_w = round(offset_rois[1] * spatial_scale); + int roi_start_h = round(offset_rois[2] * spatial_scale); + int roi_end_w = round(offset_rois[3] * spatial_scale); + int roi_end_h = round(offset_rois[4] * spatial_scale); + + // Force too small ROIs to be 1x1 + int roi_width = max(roi_end_w - roi_start_w, 1); + int roi_height = max(roi_end_h - roi_start_h, 1); + T bin_size_h = static_cast(roi_height) / static_cast(pooled_height); + T bin_size_w = static_cast(roi_width) / static_cast(pooled_width); + + int hstart = static_cast(floor(static_cast(ph) * bin_size_h)); + int wstart = static_cast(floor(static_cast(pw) * bin_size_w)); + int hend = static_cast(ceil(static_cast(ph + 1) * bin_size_h)); + int wend = static_cast(ceil(static_cast(pw + 1) * bin_size_w)); + + // Add roi offsets and clip to input boundaries + 
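// Each pooled output bin then averages the input over the clipped window
// [hstart, hend) x [wstart, wend); an empty window yields 0, and channel_mapping
// records c_in so the backward pass can route gradients to the matching channel.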
hstart = min(max(hstart + roi_start_h, 0), static_cast(height - 1)); + hend = min(max(hend + roi_start_h, 0), static_cast(height - 1)); + wstart = min(max(wstart + roi_start_w, 0), static_cast(width - 1)); + wend = min(max(wend + roi_start_w, 0), static_cast(width - 1)); + bool is_empty = (hend <= hstart) || (wend <= wstart); + + constant T* offset_input = + input + (roi_batch_ind * channels + c_in) * height * width; + T out_sum = 0; + for (int h = hstart; h < hend; ++h) { + for (int w = wstart; w < wend; ++w) { + int input_index = h * width + w; + out_sum += offset_input[input_index]; + } + } + + T bin_area = (hend - hstart) * (wend - wstart); + output[index] = is_empty ? static_cast(0) : out_sum / bin_area; + channel_mapping[index] = c_in; + } +} + +#define REGISTER_PS_ROI_POOL_OP(DTYPE) \ +template \ +[[host_name("ps_roi_pool_" #DTYPE)]] \ +kernel void ps_roi_pool( \ + constant DTYPE * input [[buffer(0)]], \ + constant DTYPE * rois [[buffer(1)]], \ + device DTYPE * output [[buffer(2)]], \ + device int64_t * channel_mapping [[buffer(3)]], \ + constant int64_t & output_size [[buffer(4)]], \ + constant int64_t & channels [[buffer(5)]], \ + constant int64_t & height [[buffer(6)]], \ + constant int64_t & width [[buffer(7)]], \ + constant int64_t & pooled_height [[buffer(8)]], \ + constant int64_t & pooled_width [[buffer(9)]], \ + constant int64_t & channels_out [[buffer(10)]], \ + constant float & spatial_scale [[buffer(11)]], \ + uint2 tgid [[threadgroup_position_in_grid]], \ + uint2 tptg [[threads_per_threadgroup]], \ + uint2 tid2 [[thread_position_in_threadgroup]]); + REGISTER_NMS_OP(float); REGISTER_NMS_OP(half); REGISTER_ROI_ALIGN_OP(float); @@ -866,6 +955,8 @@ REGISTER_PS_ROI_ALIGN_OP(float); REGISTER_PS_ROI_ALIGN_OP(half); REGISTER_PS_ROI_ALIGN_BACKWARD_OP(float); REGISTER_PS_ROI_ALIGN_BACKWARD_OP(half); +REGISTER_PS_ROI_POOL_OP(float); +REGISTER_PS_ROI_POOL_OP(half); )VISION_METAL"; From 2c20036fd70c57876b02274490e375bc432e45fa Mon Sep 17 00:00:00 2001 From: "Li-Huai (Allan) Lin" Date: Sat, 17 Jun 2023 16:00:20 +0800 Subject: [PATCH 11/28] ps_roi_pool bw --- .../csrc/ops/mps/ps_roi_pool_kernel.mm | 35 ++++---- torchvision/csrc/ops/mps/vision_kernels.h | 87 +++++++++++++++++++ 2 files changed, 102 insertions(+), 20 deletions(-) diff --git a/torchvision/csrc/ops/mps/ps_roi_pool_kernel.mm b/torchvision/csrc/ops/mps/ps_roi_pool_kernel.mm index 683b812b5b0..20821f5fbd2 100644 --- a/torchvision/csrc/ops/mps/ps_roi_pool_kernel.mm +++ b/torchvision/csrc/ops/mps/ps_roi_pool_kernel.mm @@ -104,7 +104,7 @@ at::Tensor ps_roi_pool_backward_kernel( const at::Tensor& grad, const at::Tensor& rois, - const at::Tensor& argmax, + const at::Tensor& channel_mapping, double spatial_scale, int64_t pooled_height, int64_t pooled_width, @@ -116,35 +116,33 @@ using namespace at::native::mps; TORCH_CHECK(grad.is_mps(), "grad must be a MPS tensor"); TORCH_CHECK(rois.is_mps(), "rois must be a MPS tensor"); - TORCH_CHECK(argmax.is_mps(), "argmax must be a MPS tensor"); + TORCH_CHECK(channel_mapping.is_mps(), "channel_mapping must be a MPS tensor"); - at::TensorArg grad_t{grad, "input", 1}, rois_t{rois, "rois", 2}, argmax_t{argmax, "argmax", 3}; + at::TensorArg grad_t{grad, "grad", 1}, rois_t{rois, "rois", 2}, channel_mapping_t{channel_mapping, "channel_mapping", 3}; at::CheckedFrom c = "ps_roi_pool_backward_kernel"; - at::checkAllSameGPU(c, {grad_t, rois_t, argmax_t}); + at::checkAllSameGPU(c, {grad_t, rois_t, channel_mapping_t}); at::checkAllSameType(c, {grad_t, rois_t}); float spatial_scale_f = 
static_cast(spatial_scale); - at::Tensor grad_input = at::zeros( + auto num_rois = rois.size(0); + auto grad_input = at::zeros( {batch_size, channels, height, width}, grad.options()); if (grad.numel() == 0) { return grad_input; } - int64_t n_stride = grad.stride(0); - int64_t c_stride = grad.stride(1); - int64_t h_stride = grad.stride(2); - int64_t w_stride = grad.stride(3); + int64_t channels_out = channels / (pooled_height * pooled_width); int64_t output_size = grad.numel(); at::globalContext().alertNotDeterministic("ps_roi_pool_backward_kernel"); - auto argmax_ = argmax.contiguous(), rois_ = rois.contiguous(); + auto grad_ = grad.contiguous(), rois_ = rois.contiguous(); - id inputBuffer = getMTLBufferStorage(grad); + id inputBuffer = getMTLBufferStorage(grad_); id roisBuffer = getMTLBufferStorage(rois_); - id argmaxBuffer = getMTLBufferStorage(argmax_); + id channelMappingBuffer = getMTLBufferStorage(channel_mapping); id outputBuffer = getMTLBufferStorage(grad_input); id device = MPSDevice::getInstance()->device(); MPSStream* mpsStream = getCurrentMPSStream(); @@ -157,13 +155,13 @@ id binaryPSO = mps::binaryPipelineState(device, kernel); // this function call is a no-op if MPS Profiler is not enabled - getMPSProfiler().beginProfileKernel(binaryPSO, kernel, {grad, rois_, argmax_}); + getMPSProfiler().beginProfileKernel(binaryPSO, kernel, {grad_, rois_, channel_mapping}); [computeEncoder setComputePipelineState:binaryPSO]; // [N, C, H, W] - [computeEncoder setBuffer:inputBuffer offset:grad.storage_offset() * grad.element_size() atIndex:0]; + [computeEncoder setBuffer:inputBuffer offset:grad_.storage_offset() * grad_.element_size() atIndex:0]; [computeEncoder setBuffer:roisBuffer offset:rois_.storage_offset() * rois_.element_size() atIndex:1]; - [computeEncoder setBuffer:argmaxBuffer offset:argmax_.storage_offset() * argmax_.element_size() atIndex:2]; + [computeEncoder setBuffer:channelMappingBuffer offset:channel_mapping.storage_offset() * channel_mapping.element_size() atIndex:2]; [computeEncoder setBuffer:outputBuffer offset:grad_input.storage_offset() * grad_input.element_size() atIndex:3]; [computeEncoder setBytes:&output_size length:sizeof(int64_t) atIndex:4]; @@ -172,11 +170,8 @@ [computeEncoder setBytes:&width length:sizeof(int64_t) atIndex:7]; [computeEncoder setBytes:&pooled_height length:sizeof(int64_t) atIndex:8]; [computeEncoder setBytes:&pooled_width length:sizeof(int64_t) atIndex:9]; - [computeEncoder setBytes:&spatial_scale_f length:sizeof(float) atIndex:10]; - [computeEncoder setBytes:&n_stride length:sizeof(int64_t) atIndex:11]; - [computeEncoder setBytes:&c_stride length:sizeof(int64_t) atIndex:12]; - [computeEncoder setBytes:&h_stride length:sizeof(int64_t) atIndex:13]; - [computeEncoder setBytes:&w_stride length:sizeof(int64_t) atIndex:14]; + [computeEncoder setBytes:&channels_out length:sizeof(int64_t) atIndex:10]; + [computeEncoder setBytes:&spatial_scale_f length:sizeof(float) atIndex:11]; // A threadGroup is equivalent to a cuda's block. 
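// (Gradients are scattered into grad_input with atomic float adds, so the backward
// pass is non-deterministic; this is why alertNotDeterministic is raised above.)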
NSUInteger tgSize = binaryPSO.maxTotalThreadsPerThreadgroup; diff --git a/torchvision/csrc/ops/mps/vision_kernels.h b/torchvision/csrc/ops/mps/vision_kernels.h index 5eec0360341..07612fe2d6e 100644 --- a/torchvision/csrc/ops/mps/vision_kernels.h +++ b/torchvision/csrc/ops/mps/vision_kernels.h @@ -941,6 +941,91 @@ kernel void ps_roi_pool( \ uint2 tptg [[threads_per_threadgroup]], \ uint2 tid2 [[thread_position_in_threadgroup]]); +template +kernel void ps_roi_pool_backward( + constant T * grad_output [[buffer(0)]], + constant T * rois [[buffer(1)]], + constant int64_t * channel_mapping [[buffer(2)]], + device T * grad_input [[buffer(3)]], + constant int64_t & output_size [[buffer(4)]], + constant int64_t & channels [[buffer(5)]], + constant int64_t & height [[buffer(6)]], + constant int64_t & width [[buffer(7)]], + constant int64_t & pooled_height [[buffer(8)]], + constant int64_t & pooled_width [[buffer(9)]], + constant int64_t & channels_out [[buffer(10)]], + constant float & spatial_scale [[buffer(11)]], + uint2 tgid [[threadgroup_position_in_grid]], + uint2 tptg [[threads_per_threadgroup]], + uint2 tid2 [[thread_position_in_threadgroup]]){ + + MPS_1D_KERNEL_LOOP(index, output_size, 1) { + // (n, *, ph, pw) is an element in the pooled output + int pw = index % pooled_width; + int ph = (index / pooled_width) % pooled_height; + int n = index / pooled_width / pooled_height / channels_out; + + constant T* offset_rois = rois + n * 5; + int roi_batch_ind = offset_rois[0]; + int roi_start_w = round(offset_rois[1] * spatial_scale); + int roi_start_h = round(offset_rois[2] * spatial_scale); + int roi_end_w = round(offset_rois[3] * spatial_scale); + int roi_end_h = round(offset_rois[4] * spatial_scale); + + // Force too small ROIs to be 1x1 + int roi_width = max(roi_end_w - roi_start_w, 1); + int roi_height = max(roi_end_h - roi_start_h, 1); + T bin_size_h = static_cast(roi_height) / static_cast(pooled_height); + T bin_size_w = static_cast(roi_width) / static_cast(pooled_width); + + int hstart = static_cast(floor(static_cast(ph) * bin_size_h)); + int wstart = static_cast(floor(static_cast(pw) * bin_size_w)); + int hend = static_cast(ceil(static_cast(ph + 1) * bin_size_h)); + int wend = static_cast(ceil(static_cast(pw + 1) * bin_size_w)); + + // Add roi offsets and clip to input boundaries + hstart = min(max(hstart + roi_start_h, 0), static_cast(height)); + hend = min(max(hend + roi_start_h, 0), static_cast(height)); + wstart = min(max(wstart + roi_start_w, 0), static_cast(width)); + wend = min(max(wend + roi_start_w, 0), static_cast(width)); + bool is_empty = (hend <= hstart) || (wend <= wstart); + + int c_in = channel_mapping[index]; + T bin_area = (hend - hstart) * (wend - wstart); + T diff_val = is_empty ? 
static_cast(0) : grad_output[index] / bin_area; + + const int offset = (roi_batch_ind * channels + c_in) * height * width; + + for (int h = hstart; h < hend; ++h) { + for (int w = wstart; w < wend; ++w) { + int grad_input_index = h * width + w; + atomic_add_float(grad_input + offset + grad_input_index, diff_val); + } + } + + } // MPS_1D_KERNEL_LOOP +} + +#define REGISTER_PS_ROI_POOL_BACKWARD_OP(DTYPE) \ +template \ +[[host_name("ps_roi_pool_backward_" #DTYPE)]] \ +kernel void ps_roi_pool_backward( \ + constant DTYPE * grad_output [[buffer(0)]], \ + constant DTYPE * rois [[buffer(1)]], \ + constant int64_t * channel_mapping [[buffer(2)]], \ + device DTYPE * grad_input [[buffer(3)]], \ + constant int64_t & output_size [[buffer(4)]], \ + constant int64_t & channels [[buffer(5)]], \ + constant int64_t & height [[buffer(6)]], \ + constant int64_t & width [[buffer(7)]], \ + constant int64_t & pooled_height [[buffer(8)]], \ + constant int64_t & pooled_width [[buffer(9)]], \ + constant int64_t & channels_out [[buffer(10)]], \ + constant float & spatial_scale [[buffer(11)]], \ + uint2 tgid [[threadgroup_position_in_grid]], \ + uint2 tptg [[threads_per_threadgroup]], \ + uint2 tid2 [[thread_position_in_threadgroup]]); + REGISTER_NMS_OP(float); REGISTER_NMS_OP(half); REGISTER_ROI_ALIGN_OP(float); @@ -957,6 +1042,8 @@ REGISTER_PS_ROI_ALIGN_BACKWARD_OP(float); REGISTER_PS_ROI_ALIGN_BACKWARD_OP(half); REGISTER_PS_ROI_POOL_OP(float); REGISTER_PS_ROI_POOL_OP(half); +REGISTER_PS_ROI_POOL_BACKWARD_OP(float); +REGISTER_PS_ROI_POOL_BACKWARD_OP(half); )VISION_METAL"; From a427c2a744e8b7d69b7997379018a84dcb88d95b Mon Sep 17 00:00:00 2001 From: "Li-Huai (Allan) Lin" Date: Sat, 17 Jun 2023 16:11:50 +0800 Subject: [PATCH 12/28] Rename kernels header --- .../mps/{vision_kernels.h => mps_kernels.h} | 24 +++++++++---------- torchvision/csrc/ops/mps/nms_kernel.mm | 4 ++-- .../csrc/ops/mps/ps_roi_align_kernel.mm | 6 ++--- .../csrc/ops/mps/ps_roi_pool_kernel.mm | 6 ++--- torchvision/csrc/ops/mps/roi_align_kernel.mm | 6 ++--- torchvision/csrc/ops/mps/roi_pool_kernel.mm | 6 ++--- 6 files changed, 26 insertions(+), 26 deletions(-) rename torchvision/csrc/ops/mps/{vision_kernels.h => mps_kernels.h} (98%) diff --git a/torchvision/csrc/ops/mps/vision_kernels.h b/torchvision/csrc/ops/mps/mps_kernels.h similarity index 98% rename from torchvision/csrc/ops/mps/vision_kernels.h rename to torchvision/csrc/ops/mps/mps_kernels.h index 07612fe2d6e..5aff9cf0378 100644 --- a/torchvision/csrc/ops/mps/vision_kernels.h +++ b/torchvision/csrc/ops/mps/mps_kernels.h @@ -1047,23 +1047,23 @@ REGISTER_PS_ROI_POOL_BACKWARD_OP(half); )VISION_METAL"; -static id compileBinaryOpsLibrary(id device) { - static id binaryLibrary = nil; - if (binaryLibrary) { - return binaryLibrary; +static id compileVisionOpsLibrary(id device) { + static id visionLibrary = nil; + if (visionLibrary) { + return visionLibrary; } NSError* error = nil; MTLCompileOptions* options = [[MTLCompileOptions new] autorelease]; [options setLanguageVersion:MTLLanguageVersion2_3]; - binaryLibrary = [device newLibraryWithSource:[NSString stringWithCString:METAL_VISION encoding:NSASCIIStringEncoding] + visionLibrary = [device newLibraryWithSource:[NSString stringWithCString:METAL_VISION encoding:NSASCIIStringEncoding] options:options error:&error]; - TORCH_CHECK(binaryLibrary, "Failed to create metal binary library, error: ", [[error description] UTF8String]); - return binaryLibrary; + TORCH_CHECK(visionLibrary, "Failed to create metal vision library, error: ", [[error description] 
UTF8String]); + return visionLibrary; } -static id binaryPipelineState(id device, const std::string& kernel) { +static id visionPipelineState(id device, const std::string& kernel) { static std::unordered_map> psoCache; id pso = psoCache[kernel]; if (pso) { @@ -1071,10 +1071,10 @@ static id binaryPipelineState(id device, con } NSError* error = nil; - id binaryLib = compileBinaryOpsLibrary(device); - id binaryFunc = [binaryLib newFunctionWithName:[NSString stringWithUTF8String:kernel.c_str()]]; - TORCH_CHECK(binaryFunc, "Failed to create function state object for: ", kernel); - pso = [device newComputePipelineStateWithFunction:binaryFunc error:&error]; + id visionLib = compileVisionOpsLibrary(device); + id visionFunc = [visionLib newFunctionWithName:[NSString stringWithUTF8String:kernel.c_str()]]; + TORCH_CHECK(visionFunc, "Failed to create function state object for: ", kernel); + pso = [device newComputePipelineStateWithFunction:visionFunc error:&error]; TORCH_CHECK(pso, "Failed to created pipeline state object, error: ", [[error description] UTF8String]); psoCache[kernel] = pso; diff --git a/torchvision/csrc/ops/mps/nms_kernel.mm b/torchvision/csrc/ops/mps/nms_kernel.mm index e697ef80fe3..64730511590 100644 --- a/torchvision/csrc/ops/mps/nms_kernel.mm +++ b/torchvision/csrc/ops/mps/nms_kernel.mm @@ -1,6 +1,6 @@ #include #include -#include "vision_kernels.h" +#include "mps_kernels.h" #include #include @@ -65,7 +65,7 @@ MTLSize threadgroupsPerGrid = MTLSizeMake(col_blocks, col_blocks, 1); const std::string kernel = "nms_" + scalarToMetalTypeString(dets_sorted.scalar_type()); - id binaryPSO = mps::binaryPipelineState(device, kernel); + id binaryPSO = mps::visionPipelineState(device, kernel); // this function call is a no-op if MPS Profiler is not enabled getMPSProfiler().beginProfileKernel(binaryPSO, kernel, {dets, scores}); diff --git a/torchvision/csrc/ops/mps/ps_roi_align_kernel.mm b/torchvision/csrc/ops/mps/ps_roi_align_kernel.mm index 15d28f23b90..26d2e4ed44d 100644 --- a/torchvision/csrc/ops/mps/ps_roi_align_kernel.mm +++ b/torchvision/csrc/ops/mps/ps_roi_align_kernel.mm @@ -1,6 +1,6 @@ #include #include -#include "vision_kernels.h" +#include "mps_kernels.h" #include "mps_helpers.h" #include @@ -68,7 +68,7 @@ MTLSize threadgroupsPerGrid = MTLSizeMake(std::min(ceil_div(static_cast(output_size), static_cast(512)), static_cast(4096)), 1, 1); const std::string kernel = "ps_roi_align_" + scalarToMetalTypeString(input.scalar_type()); - id binaryPSO = mps::binaryPipelineState(device, kernel); + id binaryPSO = mps::visionPipelineState(device, kernel); // this function call is a no-op if MPS Profiler is not enabled getMPSProfiler().beginProfileKernel(binaryPSO, kernel, {input_, rois_}); @@ -156,7 +156,7 @@ MTLSize threadgroupsPerGrid = MTLSizeMake(std::min(ceil_div(static_cast(grad.numel()), static_cast(512)), static_cast(4096)), 1, 1); const std::string kernel = "ps_roi_align_backward_" + scalarToMetalTypeString(grad.scalar_type()); - id binaryPSO = mps::binaryPipelineState(device, kernel); + id binaryPSO = mps::visionPipelineState(device, kernel); // this function call is a no-op if MPS Profiler is not enabled getMPSProfiler().beginProfileKernel(binaryPSO, kernel, {grad, rois_}); diff --git a/torchvision/csrc/ops/mps/ps_roi_pool_kernel.mm b/torchvision/csrc/ops/mps/ps_roi_pool_kernel.mm index 20821f5fbd2..0ce2e6800fc 100644 --- a/torchvision/csrc/ops/mps/ps_roi_pool_kernel.mm +++ b/torchvision/csrc/ops/mps/ps_roi_pool_kernel.mm @@ -1,6 +1,6 @@ #include #include -#include "vision_kernels.h" 
+#include "mps_kernels.h" #include "mps_helpers.h" #include @@ -65,7 +65,7 @@ MTLSize threadgroupsPerGrid = MTLSizeMake(std::min(ceil_div(static_cast(output_size), static_cast(512)), static_cast(4096)), 1, 1); const std::string kernel = "ps_roi_pool_" + scalarToMetalTypeString(input.scalar_type()); - id binaryPSO = mps::binaryPipelineState(device, kernel); + id binaryPSO = mps::visionPipelineState(device, kernel); // this function call is a no-op if MPS Profiler is not enabled getMPSProfiler().beginProfileKernel(binaryPSO, kernel, {input_, rois_}); @@ -152,7 +152,7 @@ MTLSize threadgroupsPerGrid = MTLSizeMake(std::min(ceil_div(static_cast(grad.numel()), static_cast(512)), static_cast(4096)), 1, 1); const std::string kernel = "ps_roi_pool_backward_" + scalarToMetalTypeString(grad.scalar_type()); - id binaryPSO = mps::binaryPipelineState(device, kernel); + id binaryPSO = mps::visionPipelineState(device, kernel); // this function call is a no-op if MPS Profiler is not enabled getMPSProfiler().beginProfileKernel(binaryPSO, kernel, {grad_, rois_, channel_mapping}); diff --git a/torchvision/csrc/ops/mps/roi_align_kernel.mm b/torchvision/csrc/ops/mps/roi_align_kernel.mm index d135709bb8b..a57bbc018dc 100644 --- a/torchvision/csrc/ops/mps/roi_align_kernel.mm +++ b/torchvision/csrc/ops/mps/roi_align_kernel.mm @@ -1,6 +1,6 @@ #include #include -#include "vision_kernels.h" +#include "mps_kernels.h" #include "mps_helpers.h" #include @@ -60,7 +60,7 @@ MTLSize threadgroupsPerGrid = MTLSizeMake(std::min(ceil_div(static_cast(output_size), static_cast(512)), static_cast(4096)), 1, 1); const std::string kernel = "roi_align_" + scalarToMetalTypeString(input.scalar_type()); - id binaryPSO = mps::binaryPipelineState(device, kernel); + id binaryPSO = mps::visionPipelineState(device, kernel); // this function call is a no-op if MPS Profiler is not enabled getMPSProfiler().beginProfileKernel(binaryPSO, kernel, {input_, rois_}); @@ -148,7 +148,7 @@ MTLSize threadgroupsPerGrid = MTLSizeMake(std::min(ceil_div(static_cast(grad.numel()), static_cast(512)), static_cast(4096)), 1, 1); const std::string kernel = "roi_align_backward_" + scalarToMetalTypeString(grad.scalar_type()); - id binaryPSO = mps::binaryPipelineState(device, kernel); + id binaryPSO = mps::visionPipelineState(device, kernel); // this function call is a no-op if MPS Profiler is not enabled getMPSProfiler().beginProfileKernel(binaryPSO, kernel, {grad, rois_}); diff --git a/torchvision/csrc/ops/mps/roi_pool_kernel.mm b/torchvision/csrc/ops/mps/roi_pool_kernel.mm index 85a3b356598..5f376a2f091 100644 --- a/torchvision/csrc/ops/mps/roi_pool_kernel.mm +++ b/torchvision/csrc/ops/mps/roi_pool_kernel.mm @@ -1,6 +1,6 @@ #include #include -#include "vision_kernels.h" +#include "mps_kernels.h" #include "mps_helpers.h" #include @@ -62,7 +62,7 @@ MTLSize threadgroupsPerGrid = MTLSizeMake(std::min(ceil_div(static_cast(output_size), static_cast(512)), static_cast(4096)), 1, 1); const std::string kernel = "roi_pool_" + scalarToMetalTypeString(input.scalar_type()); - id binaryPSO = mps::binaryPipelineState(device, kernel); + id binaryPSO = mps::visionPipelineState(device, kernel); // this function call is a no-op if MPS Profiler is not enabled getMPSProfiler().beginProfileKernel(binaryPSO, kernel, {input_, rois_}); @@ -150,7 +150,7 @@ MTLSize threadgroupsPerGrid = MTLSizeMake(std::min(ceil_div(static_cast(grad.numel()), static_cast(512)), static_cast(4096)), 1, 1); const std::string kernel = "roi_pool_backward_" + scalarToMetalTypeString(grad.scalar_type()); - id 
binaryPSO = mps::binaryPipelineState(device, kernel); + id binaryPSO = mps::visionPipelineState(device, kernel); // this function call is a no-op if MPS Profiler is not enabled getMPSProfiler().beginProfileKernel(binaryPSO, kernel, {grad, rois_, argmax_}); From 8036dc2fb63552b5075ce1b95ab474700e1777f3 Mon Sep 17 00:00:00 2001 From: "Li-Huai (Allan) Lin" Date: Sat, 17 Jun 2023 16:12:19 +0800 Subject: [PATCH 13/28] Add atol to RoI backward tests --- test/test_ops.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/test/test_ops.py b/test/test_ops.py index aa5e0c95a0d..940ad8aee3b 100644 --- a/test/test_ops.py +++ b/test/test_ops.py @@ -160,6 +160,7 @@ def test_torch_fx_trace(self, device, x_dtype=torch.float, rois_dtype=torch.floa @pytest.mark.parametrize("device", cpu_and_gpu_and_mps()) @pytest.mark.parametrize("contiguous", (True, False)) def test_backward(self, seed, device, contiguous, deterministic=False): + atol = 5e-2 if device == "mps" else None dtype = self.mps_dtype if device == "mps" else self.dtype torch.random.manual_seed(seed) @@ -177,9 +178,9 @@ def func(z): script_func = self.get_script_fn(rois, pool_size) with DeterministicGuard(deterministic): - gradcheck(func, (x,)) + gradcheck(func, (x,), atol=atol) - gradcheck(script_func, (x,)) + gradcheck(script_func, (x,), atol=atol) @needs_cuda @pytest.mark.parametrize("x_dtype", (torch.float, torch.half)) From 0ae9124ebe826eb4690d871d57f06a97fd59687e Mon Sep 17 00:00:00 2001 From: "Li-Huai (Allan) Lin" Date: Sat, 17 Jun 2023 16:56:53 +0800 Subject: [PATCH 14/28] mps kernels formatting --- torchvision/csrc/ops/mps/mps_kernels.h | 492 ++++++++++++------------- 1 file changed, 246 insertions(+), 246 deletions(-) diff --git a/torchvision/csrc/ops/mps/mps_kernels.h b/torchvision/csrc/ops/mps/mps_kernels.h index 5aff9cf0378..8aaf55df15f 100644 --- a/torchvision/csrc/ops/mps/mps_kernels.h +++ b/torchvision/csrc/ops/mps/mps_kernels.h @@ -13,11 +13,11 @@ using namespace metal; /*----------Macros----------*/ -#define MPS_1D_KERNEL_LOOP_T(i, n, n_grids, index_t) \ +#define MPS_1D_KERNEL_LOOP_T(i, n, n_tgs, index_t) \ for (index_t i = (tgid.x * tptg.x) + tid2.x; i < (n); \ - i += (tptg.x * n_grids)) + i += (tptg.x * n_tgs)) -#define MPS_1D_KERNEL_LOOP(i, n, n_grids) MPS_1D_KERNEL_LOOP_T(i, n, n_grids, int) +#define MPS_1D_KERNEL_LOOP(i, n, n_tgs) MPS_1D_KERNEL_LOOP_T(i, n, n_tgs, int) /*----------Utils----------*/ @@ -191,12 +191,12 @@ bool inline IoU( constant uint nmsThreadsPerBlock = sizeof(uint64_t) * 8; template -kernel void nms(constant T * dev_boxes [[buffer(0)]], - device uint64_t * mask [[buffer(1)]], - constant int & n_boxes [[buffer(2)]], - constant float & iou_threshold [[buffer(3)]], - uint2 tgid [[threadgroup_position_in_grid]], - uint2 tid2 [[thread_position_in_threadgroup]]) { +kernel void nms(constant T * dev_boxes [[buffer(0)]], + device uint64_t * mask [[buffer(1)]], + constant int & n_boxes [[buffer(2)]], + constant float & iou_threshold [[buffer(3)]], + uint2 tgid [[threadgroup_position_in_grid]], + uint2 tid2 [[thread_position_in_threadgroup]]) { const uint row_start = tgid.y; const uint col_start = tgid.x; @@ -229,30 +229,30 @@ kernel void nms(constant T * dev_boxes [[buffer(0)]], } } -#define REGISTER_NMS_OP(DTYPE) \ -template \ -[[host_name("nms_" #DTYPE)]] \ -kernel void nms( \ - constant DTYPE ## 4 * dev_boxes [[buffer(0)]], \ - device uint64_t * mask [[buffer(1)]], \ - constant int & n_boxes [[buffer(2)]], \ - constant float & iou_threshold [[buffer(3)]], \ - uint2 tgid 
[[threadgroup_position_in_grid]], \ - uint2 tid2 [[thread_position_in_threadgroup]]); +#define REGISTER_NMS_OP(DTYPE) \ +template \ +[[host_name("nms_" #DTYPE)]] \ +kernel void nms( \ + constant DTYPE ## 4 * dev_boxes [[buffer(0)]], \ + device uint64_t * mask [[buffer(1)]], \ + constant int & n_boxes [[buffer(2)]], \ + constant float & iou_threshold [[buffer(3)]], \ + uint2 tgid [[threadgroup_position_in_grid]], \ + uint2 tid2 [[thread_position_in_threadgroup]]); template kernel void roi_align( - constant T * input [[buffer(0)]], - constant T * rois [[buffer(1)]], - device T * output [[buffer(2)]], - constant int64_t & output_size [[buffer(3)]], - constant int64_t & channels [[buffer(4)]], - constant int64_t & height [[buffer(5)]], - constant int64_t & width [[buffer(6)]], + constant T * input [[buffer(0)]], + constant T * rois [[buffer(1)]], + device T * output [[buffer(2)]], + constant int64_t & output_size [[buffer(3)]], + constant int64_t & channels [[buffer(4)]], + constant int64_t & height [[buffer(5)]], + constant int64_t & width [[buffer(6)]], constant int64_t & pooled_height [[buffer(7)]], - constant int64_t & pooled_width [[buffer(8)]], - constant int64_t & sampling_ratio [[buffer(9)]], - constant bool & aligned [[buffer(10)]], + constant int64_t & pooled_width [[buffer(8)]], + constant int64_t & sampling_ratio [[buffer(9)]], + constant bool & aligned [[buffer(10)]], constant float & spatial_scale [[buffer(11)]], uint2 tgid [[threadgroup_position_in_grid]], uint2 tptg [[threads_per_threadgroup]], @@ -324,17 +324,17 @@ kernel void roi_align( template \ [[host_name("roi_align_" #DTYPE)]] \ kernel void roi_align( \ - constant DTYPE * input [[buffer(0)]], \ - constant DTYPE * rois [[buffer(1)]], \ - device DTYPE * output [[buffer(2)]], \ - constant int64_t & output_size [[buffer(3)]], \ - constant int64_t & channels [[buffer(4)]], \ - constant int64_t & height [[buffer(5)]], \ - constant int64_t & width [[buffer(6)]], \ + constant DTYPE * input [[buffer(0)]], \ + constant DTYPE * rois [[buffer(1)]], \ + device DTYPE * output [[buffer(2)]], \ + constant int64_t & output_size [[buffer(3)]], \ + constant int64_t & channels [[buffer(4)]], \ + constant int64_t & height [[buffer(5)]], \ + constant int64_t & width [[buffer(6)]], \ constant int64_t & pooled_height [[buffer(7)]], \ - constant int64_t & pooled_width [[buffer(8)]], \ - constant int64_t & sampling_ratio [[buffer(9)]], \ - constant bool & aligned [[buffer(10)]], \ + constant int64_t & pooled_width [[buffer(8)]], \ + constant int64_t & sampling_ratio [[buffer(9)]], \ + constant bool & aligned [[buffer(10)]], \ constant float & spatial_scale [[buffer(11)]], \ uint2 tgid [[threadgroup_position_in_grid]], \ uint2 tptg [[threads_per_threadgroup]], \ @@ -342,22 +342,22 @@ kernel void roi_align( \ template kernel void roi_align_backward( - constant T * grad_output [[buffer(0)]], - constant T * rois [[buffer(1)]], - device T * grad_input [[buffer(2)]], - constant int64_t & output_size [[buffer(3)]], - constant int64_t & channels [[buffer(4)]], - constant int64_t & height [[buffer(5)]], - constant int64_t & width [[buffer(6)]], + constant T * grad_output [[buffer(0)]], + constant T * rois [[buffer(1)]], + device T * grad_input [[buffer(2)]], + constant int64_t & output_size [[buffer(3)]], + constant int64_t & channels [[buffer(4)]], + constant int64_t & height [[buffer(5)]], + constant int64_t & width [[buffer(6)]], constant int64_t & pooled_height [[buffer(7)]], - constant int64_t & pooled_width [[buffer(8)]], - constant int64_t & 
sampling_ratio [[buffer(9)]], - constant bool & aligned [[buffer(10)]], + constant int64_t & pooled_width [[buffer(8)]], + constant int64_t & sampling_ratio [[buffer(9)]], + constant bool & aligned [[buffer(10)]], constant float & spatial_scale [[buffer(11)]], - constant int64_t & n_stride [[buffer(12)]], - constant int64_t & c_stride [[buffer(13)]], - constant int64_t & h_stride [[buffer(14)]], - constant int64_t & w_stride [[buffer(15)]], + constant int64_t & n_stride [[buffer(12)]], + constant int64_t & c_stride [[buffer(13)]], + constant int64_t & h_stride [[buffer(14)]], + constant int64_t & w_stride [[buffer(15)]], uint2 tgid [[threadgroup_position_in_grid]], uint2 tptg [[threads_per_threadgroup]], uint2 tid2 [[thread_position_in_threadgroup]]){ @@ -454,26 +454,26 @@ kernel void roi_align_backward( } // MPS_1D_KERNEL_LOOP } -#define REGISTER_ROI_ALIGN_BACKWARD_OP(DTYPE) \ -template \ -[[host_name("roi_align_backward_" #DTYPE)]] \ -kernel void roi_align_backward( \ - constant DTYPE * grad_output [[buffer(0)]], \ - constant DTYPE * rois [[buffer(1)]], \ - device DTYPE * grad_input [[buffer(2)]], \ - constant int64_t & output_size [[buffer(3)]], \ - constant int64_t & channels [[buffer(4)]], \ - constant int64_t & height [[buffer(5)]], \ - constant int64_t & width [[buffer(6)]], \ +#define REGISTER_ROI_ALIGN_BACKWARD_OP(DTYPE) \ +template \ +[[host_name("roi_align_backward_" #DTYPE)]] \ +kernel void roi_align_backward( \ + constant DTYPE * grad_output [[buffer(0)]], \ + constant DTYPE * rois [[buffer(1)]], \ + device DTYPE * grad_input [[buffer(2)]], \ + constant int64_t & output_size [[buffer(3)]], \ + constant int64_t & channels [[buffer(4)]], \ + constant int64_t & height [[buffer(5)]], \ + constant int64_t & width [[buffer(6)]], \ constant int64_t & pooled_height [[buffer(7)]], \ - constant int64_t & pooled_width [[buffer(8)]], \ - constant int64_t & sampling_ratio [[buffer(9)]], \ - constant bool & aligned [[buffer(10)]], \ - constant float & spatial_scale [[buffer(11)]], \ - constant int64_t & n_stride [[buffer(12)]], \ - constant int64_t & c_stride [[buffer(13)]], \ - constant int64_t & h_stride [[buffer(14)]], \ - constant int64_t & w_stride [[buffer(15)]], \ + constant int64_t & pooled_width [[buffer(8)]], \ + constant int64_t & sampling_ratio [[buffer(9)]], \ + constant bool & aligned [[buffer(10)]], \ + constant float & spatial_scale [[buffer(11)]], \ + constant int64_t & n_stride [[buffer(12)]], \ + constant int64_t & c_stride [[buffer(13)]], \ + constant int64_t & h_stride [[buffer(14)]], \ + constant int64_t & w_stride [[buffer(15)]], \ uint2 tgid [[threadgroup_position_in_grid]], \ uint2 tptg [[threads_per_threadgroup]], \ uint2 tid2 [[thread_position_in_threadgroup]]); @@ -481,16 +481,16 @@ kernel void roi_align_backward( \ template kernel void roi_pool( constant T * input [[buffer(0)]], - constant T * rois [[buffer(1)]], + constant T * rois [[buffer(1)]], device T * output [[buffer(2)]], device int64_t * argmax [[buffer(3)]], - constant int64_t & output_size [[buffer(4)]], - constant int64_t & channels [[buffer(5)]], - constant int64_t & height [[buffer(6)]], - constant int64_t & width [[buffer(7)]], - constant int64_t & pooled_height [[buffer(8)]], + constant int64_t & output_size [[buffer(4)]], + constant int64_t & channels [[buffer(5)]], + constant int64_t & height [[buffer(6)]], + constant int64_t & width [[buffer(7)]], + constant int64_t & pooled_height [[buffer(8)]], constant int64_t & pooled_width [[buffer(9)]], - constant float & spatial_scale [[buffer(10)]], + 
constant float & spatial_scale [[buffer(10)]], uint2 tgid [[threadgroup_position_in_grid]], uint2 tptg [[threads_per_threadgroup]], uint2 tid2 [[thread_position_in_threadgroup]]){ @@ -546,42 +546,42 @@ kernel void roi_pool( } } -#define REGISTER_ROI_POOL_OP(DTYPE) \ -template \ -[[host_name("roi_pool_" #DTYPE)]] \ -kernel void roi_pool( \ - constant DTYPE * input [[buffer(0)]], \ - constant DTYPE * rois [[buffer(1)]], \ - device DTYPE * output [[buffer(2)]], \ - device int64_t * argmax_data [[buffer(3)]], \ - constant int64_t & output_size [[buffer(4)]], \ - constant int64_t & channels [[buffer(5)]], \ - constant int64_t & height [[buffer(6)]], \ - constant int64_t & width [[buffer(7)]], \ - constant int64_t & pooled_height [[buffer(8)]], \ - constant int64_t & pooled_width [[buffer(9)]], \ - constant float & spatial_scale [[buffer(10)]], \ - uint2 tgid [[threadgroup_position_in_grid]], \ - uint2 tptg [[threads_per_threadgroup]], \ +#define REGISTER_ROI_POOL_OP(DTYPE) \ +template \ +[[host_name("roi_pool_" #DTYPE)]] \ +kernel void roi_pool( \ + constant DTYPE * input [[buffer(0)]], \ + constant DTYPE * rois [[buffer(1)]], \ + device DTYPE * output [[buffer(2)]], \ + device int64_t * argmax_data [[buffer(3)]], \ + constant int64_t & output_size [[buffer(4)]], \ + constant int64_t & channels [[buffer(5)]], \ + constant int64_t & height [[buffer(6)]], \ + constant int64_t & width [[buffer(7)]], \ + constant int64_t & pooled_height [[buffer(8)]], \ + constant int64_t & pooled_width [[buffer(9)]], \ + constant float & spatial_scale [[buffer(10)]], \ + uint2 tgid [[threadgroup_position_in_grid]], \ + uint2 tptg [[threads_per_threadgroup]], \ uint2 tid2 [[thread_position_in_threadgroup]]); template kernel void roi_pool_backward( - constant T * grad_output [[buffer(0)]], - constant T * rois [[buffer(1)]], - constant int64_t * argmax_data [[buffer(2)]], - device T * grad_input [[buffer(3)]], - constant int64_t & output_size [[buffer(4)]], - constant int64_t & channels [[buffer(5)]], - constant int64_t & height [[buffer(6)]], - constant int64_t & width [[buffer(7)]], - constant int64_t & pooled_height [[buffer(8)]], + constant T * grad_output [[buffer(0)]], + constant T * rois [[buffer(1)]], + constant int64_t * argmax_data [[buffer(2)]], + device T * grad_input [[buffer(3)]], + constant int64_t & output_size [[buffer(4)]], + constant int64_t & channels [[buffer(5)]], + constant int64_t & height [[buffer(6)]], + constant int64_t & width [[buffer(7)]], + constant int64_t & pooled_height [[buffer(8)]], constant int64_t & pooled_width [[buffer(9)]], - constant float & spatial_scale [[buffer(10)]], - constant int64_t & n_stride [[buffer(11)]], - constant int64_t & c_stride [[buffer(12)]], - constant int64_t & h_stride [[buffer(13)]], - constant int64_t & w_stride [[buffer(14)]], + constant float & spatial_scale [[buffer(10)]], + constant int64_t & n_stride [[buffer(11)]], + constant int64_t & c_stride [[buffer(12)]], + constant int64_t & h_stride [[buffer(13)]], + constant int64_t & w_stride [[buffer(14)]], uint2 tgid [[threadgroup_position_in_grid]], uint2 tptg [[threads_per_threadgroup]], uint2 tid2 [[thread_position_in_threadgroup]]){ @@ -609,44 +609,44 @@ kernel void roi_pool_backward( } // MPS_1D_KERNEL_LOOP } -#define REGISTER_ROI_POOL_BACKWARD_OP(DTYPE) \ -template \ -[[host_name("roi_pool_backward_" #DTYPE)]] \ -kernel void roi_pool_backward( \ - constant DTYPE * grad_output [[buffer(0)]], \ - constant DTYPE * rois [[buffer(1)]], \ - constant int64_t * argmax_data [[buffer(2)]], \ - device DTYPE 
* grad_input [[buffer(3)]], \ - constant int64_t & output_size [[buffer(4)]], \ - constant int64_t & channels [[buffer(5)]], \ - constant int64_t & height [[buffer(6)]], \ - constant int64_t & width [[buffer(7)]], \ - constant int64_t & pooled_height [[buffer(8)]], \ +#define REGISTER_ROI_POOL_BACKWARD_OP(DTYPE) \ +template \ +[[host_name("roi_pool_backward_" #DTYPE)]] \ +kernel void roi_pool_backward( \ + constant DTYPE * grad_output [[buffer(0)]], \ + constant DTYPE * rois [[buffer(1)]], \ + constant int64_t * argmax_data [[buffer(2)]], \ + device DTYPE * grad_input [[buffer(3)]], \ + constant int64_t & output_size [[buffer(4)]], \ + constant int64_t & channels [[buffer(5)]], \ + constant int64_t & height [[buffer(6)]], \ + constant int64_t & width [[buffer(7)]], \ + constant int64_t & pooled_height [[buffer(8)]], \ constant int64_t & pooled_width [[buffer(9)]], \ - constant float & spatial_scale [[buffer(10)]], \ - constant int64_t & n_stride [[buffer(11)]], \ - constant int64_t & c_stride [[buffer(12)]], \ - constant int64_t & h_stride [[buffer(13)]], \ - constant int64_t & w_stride [[buffer(14)]], \ + constant float & spatial_scale [[buffer(10)]], \ + constant int64_t & n_stride [[buffer(11)]], \ + constant int64_t & c_stride [[buffer(12)]], \ + constant int64_t & h_stride [[buffer(13)]], \ + constant int64_t & w_stride [[buffer(14)]], \ uint2 tgid [[threadgroup_position_in_grid]], \ uint2 tptg [[threads_per_threadgroup]], \ uint2 tid2 [[thread_position_in_threadgroup]]); template kernel void ps_roi_align( - constant T * input [[buffer(0)]], - constant T * rois [[buffer(1)]], - device T * output [[buffer(2)]], - device int64_t * channel_mapping [[buffer(3)]], - constant int64_t & output_size [[buffer(4)]], - constant int64_t & channels [[buffer(5)]], - constant int64_t & height [[buffer(6)]], - constant int64_t & width [[buffer(7)]], - constant int64_t & pooled_height [[buffer(8)]], - constant int64_t & pooled_width [[buffer(9)]], + constant T * input [[buffer(0)]], + constant T * rois [[buffer(1)]], + device T * output [[buffer(2)]], + device int64_t * channel_mapping [[buffer(3)]], + constant int64_t & output_size [[buffer(4)]], + constant int64_t & channels [[buffer(5)]], + constant int64_t & height [[buffer(6)]], + constant int64_t & width [[buffer(7)]], + constant int64_t & pooled_height [[buffer(8)]], + constant int64_t & pooled_width [[buffer(9)]], constant int64_t & sampling_ratio [[buffer(10)]], - constant int64_t & channels_out [[buffer(11)]], - constant float & spatial_scale [[buffer(12)]], + constant int64_t & channels_out [[buffer(11)]], + constant float & spatial_scale [[buffer(12)]], uint2 tgid [[threadgroup_position_in_grid]], uint2 tptg [[threads_per_threadgroup]], uint2 tid2 [[thread_position_in_threadgroup]]){ @@ -709,42 +709,42 @@ kernel void ps_roi_align( } } -#define REGISTER_PS_ROI_ALIGN_OP(DTYPE) \ -template \ -[[host_name("ps_roi_align_" #DTYPE)]] \ -kernel void ps_roi_align( \ - constant DTYPE * input [[buffer(0)]], \ - constant DTYPE * rois [[buffer(1)]], \ - device DTYPE * output [[buffer(2)]], \ - device int64_t * channel_mapping [[buffer(3)]], \ - constant int64_t & output_size [[buffer(4)]], \ - constant int64_t & channels [[buffer(5)]], \ - constant int64_t & height [[buffer(6)]], \ - constant int64_t & width [[buffer(7)]], \ - constant int64_t & pooled_height [[buffer(8)]], \ - constant int64_t & pooled_width [[buffer(9)]], \ +#define REGISTER_PS_ROI_ALIGN_OP(DTYPE) \ +template \ +[[host_name("ps_roi_align_" #DTYPE)]] \ +kernel void ps_roi_align( \ + 
constant DTYPE * input [[buffer(0)]], \ + constant DTYPE * rois [[buffer(1)]], \ + device DTYPE * output [[buffer(2)]], \ + device int64_t * channel_mapping [[buffer(3)]], \ + constant int64_t & output_size [[buffer(4)]], \ + constant int64_t & channels [[buffer(5)]], \ + constant int64_t & height [[buffer(6)]], \ + constant int64_t & width [[buffer(7)]], \ + constant int64_t & pooled_height [[buffer(8)]], \ + constant int64_t & pooled_width [[buffer(9)]], \ constant int64_t & sampling_ratio [[buffer(10)]], \ - constant int64_t & channels_out [[buffer(11)]], \ - constant float & spatial_scale [[buffer(12)]], \ - uint2 tgid [[threadgroup_position_in_grid]], \ - uint2 tptg [[threads_per_threadgroup]], \ + constant int64_t & channels_out [[buffer(11)]], \ + constant float & spatial_scale [[buffer(12)]], \ + uint2 tgid [[threadgroup_position_in_grid]], \ + uint2 tptg [[threads_per_threadgroup]], \ uint2 tid2 [[thread_position_in_threadgroup]]); template kernel void ps_roi_align_backward( - constant T * grad_output [[buffer(0)]], - constant T * rois [[buffer(1)]], - constant int64_t * channel_mapping [[buffer(2)]], - device T * grad_input [[buffer(3)]], - constant int64_t & output_size [[buffer(4)]], - constant int64_t & channels [[buffer(5)]], - constant int64_t & height [[buffer(6)]], - constant int64_t & width [[buffer(7)]], - constant int64_t & pooled_height [[buffer(8)]], - constant int64_t & pooled_width [[buffer(9)]], + constant T * grad_output [[buffer(0)]], + constant T * rois [[buffer(1)]], + constant int64_t * channel_mapping [[buffer(2)]], + device T * grad_input [[buffer(3)]], + constant int64_t & output_size [[buffer(4)]], + constant int64_t & channels [[buffer(5)]], + constant int64_t & height [[buffer(6)]], + constant int64_t & width [[buffer(7)]], + constant int64_t & pooled_height [[buffer(8)]], + constant int64_t & pooled_width [[buffer(9)]], constant int64_t & sampling_ratio [[buffer(10)]], - constant int64_t & channels_out [[buffer(11)]], - constant float & spatial_scale [[buffer(12)]], + constant int64_t & channels_out [[buffer(11)]], + constant float & spatial_scale [[buffer(12)]], uint2 tgid [[threadgroup_position_in_grid]], uint2 tptg [[threads_per_threadgroup]], uint2 tid2 [[thread_position_in_threadgroup]]){ @@ -831,41 +831,41 @@ kernel void ps_roi_align_backward( } } -#define REGISTER_PS_ROI_ALIGN_BACKWARD_OP(DTYPE) \ -template \ -[[host_name("ps_roi_align_backward_" #DTYPE)]] \ -kernel void ps_roi_align_backward( \ - constant DTYPE * grad_output [[buffer(0)]], \ - constant DTYPE * rois [[buffer(1)]], \ - constant int64_t * channel_mapping [[buffer(2)]], \ - device DTYPE * grad_input [[buffer(3)]], \ - constant int64_t & output_size [[buffer(4)]], \ - constant int64_t & channels [[buffer(5)]], \ - constant int64_t & height [[buffer(6)]], \ - constant int64_t & width [[buffer(7)]], \ - constant int64_t & pooled_height [[buffer(8)]], \ - constant int64_t & pooled_width [[buffer(9)]], \ +#define REGISTER_PS_ROI_ALIGN_BACKWARD_OP(DTYPE) \ +template \ +[[host_name("ps_roi_align_backward_" #DTYPE)]] \ +kernel void ps_roi_align_backward( \ + constant DTYPE * grad_output [[buffer(0)]], \ + constant DTYPE * rois [[buffer(1)]], \ + constant int64_t * channel_mapping [[buffer(2)]], \ + device DTYPE * grad_input [[buffer(3)]], \ + constant int64_t & output_size [[buffer(4)]], \ + constant int64_t & channels [[buffer(5)]], \ + constant int64_t & height [[buffer(6)]], \ + constant int64_t & width [[buffer(7)]], \ + constant int64_t & pooled_height [[buffer(8)]], \ + constant 
int64_t & pooled_width [[buffer(9)]], \ constant int64_t & sampling_ratio [[buffer(10)]], \ - constant int64_t & channels_out [[buffer(11)]], \ - constant float & spatial_scale [[buffer(12)]], \ + constant int64_t & channels_out [[buffer(11)]], \ + constant float & spatial_scale [[buffer(12)]], \ uint2 tgid [[threadgroup_position_in_grid]], \ - uint2 tptg [[threads_per_threadgroup]], \ + uint2 tptg [[threads_per_threadgroup]], \ uint2 tid2 [[thread_position_in_threadgroup]]); template kernel void ps_roi_pool( - constant T * input [[buffer(0)]], - constant T * rois [[buffer(1)]], - device T * output [[buffer(2)]], - device int64_t * channel_mapping [[buffer(3)]], - constant int64_t & output_size [[buffer(4)]], - constant int64_t & channels [[buffer(5)]], - constant int64_t & height [[buffer(6)]], - constant int64_t & width [[buffer(7)]], - constant int64_t & pooled_height [[buffer(8)]], - constant int64_t & pooled_width [[buffer(9)]], - constant int64_t & channels_out [[buffer(10)]], - constant float & spatial_scale [[buffer(11)]], + constant T * input [[buffer(0)]], + constant T * rois [[buffer(1)]], + device T * output [[buffer(2)]], + device int64_t * channel_mapping [[buffer(3)]], + constant int64_t & output_size [[buffer(4)]], + constant int64_t & channels [[buffer(5)]], + constant int64_t & height [[buffer(6)]], + constant int64_t & width [[buffer(7)]], + constant int64_t & pooled_height [[buffer(8)]], + constant int64_t & pooled_width [[buffer(9)]], + constant int64_t & channels_out [[buffer(10)]], + constant float & spatial_scale [[buffer(11)]], uint2 tgid [[threadgroup_position_in_grid]], uint2 tptg [[threads_per_threadgroup]], uint2 tid2 [[thread_position_in_threadgroup]]){ @@ -922,39 +922,39 @@ kernel void ps_roi_pool( } #define REGISTER_PS_ROI_POOL_OP(DTYPE) \ -template \ -[[host_name("ps_roi_pool_" #DTYPE)]] \ -kernel void ps_roi_pool( \ - constant DTYPE * input [[buffer(0)]], \ - constant DTYPE * rois [[buffer(1)]], \ - device DTYPE * output [[buffer(2)]], \ - device int64_t * channel_mapping [[buffer(3)]], \ - constant int64_t & output_size [[buffer(4)]], \ - constant int64_t & channels [[buffer(5)]], \ - constant int64_t & height [[buffer(6)]], \ - constant int64_t & width [[buffer(7)]], \ - constant int64_t & pooled_height [[buffer(8)]], \ - constant int64_t & pooled_width [[buffer(9)]], \ - constant int64_t & channels_out [[buffer(10)]], \ - constant float & spatial_scale [[buffer(11)]], \ - uint2 tgid [[threadgroup_position_in_grid]], \ - uint2 tptg [[threads_per_threadgroup]], \ - uint2 tid2 [[thread_position_in_threadgroup]]); +template \ +[[host_name("ps_roi_pool_" #DTYPE)]] \ +kernel void ps_roi_pool( \ + constant DTYPE * input [[buffer(0)]], \ + constant DTYPE * rois [[buffer(1)]], \ + device DTYPE * output [[buffer(2)]], \ + device int64_t * channel_mapping [[buffer(3)]], \ + constant int64_t & output_size [[buffer(4)]], \ + constant int64_t & channels [[buffer(5)]], \ + constant int64_t & height [[buffer(6)]], \ + constant int64_t & width [[buffer(7)]], \ + constant int64_t & pooled_height [[buffer(8)]], \ + constant int64_t & pooled_width [[buffer(9)]], \ + constant int64_t & channels_out [[buffer(10)]], \ + constant float & spatial_scale [[buffer(11)]], \ + uint2 tgid [[threadgroup_position_in_grid]], \ + uint2 tptg [[threads_per_threadgroup]], \ + uint2 tid2 [[thread_position_in_threadgroup]]); template kernel void ps_roi_pool_backward( - constant T * grad_output [[buffer(0)]], - constant T * rois [[buffer(1)]], - constant int64_t * channel_mapping 
[[buffer(2)]], - device T * grad_input [[buffer(3)]], - constant int64_t & output_size [[buffer(4)]], - constant int64_t & channels [[buffer(5)]], - constant int64_t & height [[buffer(6)]], - constant int64_t & width [[buffer(7)]], - constant int64_t & pooled_height [[buffer(8)]], - constant int64_t & pooled_width [[buffer(9)]], - constant int64_t & channels_out [[buffer(10)]], - constant float & spatial_scale [[buffer(11)]], + constant T * grad_output [[buffer(0)]], + constant T * rois [[buffer(1)]], + constant int64_t * channel_mapping [[buffer(2)]], + device T * grad_input [[buffer(3)]], + constant int64_t & output_size [[buffer(4)]], + constant int64_t & channels [[buffer(5)]], + constant int64_t & height [[buffer(6)]], + constant int64_t & width [[buffer(7)]], + constant int64_t & pooled_height [[buffer(8)]], + constant int64_t & pooled_width [[buffer(9)]], + constant int64_t & channels_out [[buffer(10)]], + constant float & spatial_scale [[buffer(11)]], uint2 tgid [[threadgroup_position_in_grid]], uint2 tptg [[threads_per_threadgroup]], uint2 tid2 [[thread_position_in_threadgroup]]){ @@ -1006,24 +1006,24 @@ kernel void ps_roi_pool_backward( } // MPS_1D_KERNEL_LOOP } -#define REGISTER_PS_ROI_POOL_BACKWARD_OP(DTYPE) \ -template \ -[[host_name("ps_roi_pool_backward_" #DTYPE)]] \ -kernel void ps_roi_pool_backward( \ - constant DTYPE * grad_output [[buffer(0)]], \ - constant DTYPE * rois [[buffer(1)]], \ - constant int64_t * channel_mapping [[buffer(2)]], \ - device DTYPE * grad_input [[buffer(3)]], \ - constant int64_t & output_size [[buffer(4)]], \ - constant int64_t & channels [[buffer(5)]], \ - constant int64_t & height [[buffer(6)]], \ - constant int64_t & width [[buffer(7)]], \ - constant int64_t & pooled_height [[buffer(8)]], \ - constant int64_t & pooled_width [[buffer(9)]], \ - constant int64_t & channels_out [[buffer(10)]], \ - constant float & spatial_scale [[buffer(11)]], \ +#define REGISTER_PS_ROI_POOL_BACKWARD_OP(DTYPE) \ +template \ +[[host_name("ps_roi_pool_backward_" #DTYPE)]] \ +kernel void ps_roi_pool_backward( \ + constant DTYPE * grad_output [[buffer(0)]], \ + constant DTYPE * rois [[buffer(1)]], \ + constant int64_t * channel_mapping [[buffer(2)]], \ + device DTYPE * grad_input [[buffer(3)]], \ + constant int64_t & output_size [[buffer(4)]], \ + constant int64_t & channels [[buffer(5)]], \ + constant int64_t & height [[buffer(6)]], \ + constant int64_t & width [[buffer(7)]], \ + constant int64_t & pooled_height [[buffer(8)]], \ + constant int64_t & pooled_width [[buffer(9)]], \ + constant int64_t & channels_out [[buffer(10)]], \ + constant float & spatial_scale [[buffer(11)]], \ uint2 tgid [[threadgroup_position_in_grid]], \ - uint2 tptg [[threads_per_threadgroup]], \ + uint2 tptg [[threads_per_threadgroup]], \ uint2 tid2 [[thread_position_in_threadgroup]]); REGISTER_NMS_OP(float); From 1d21cfc6bea998f1071add595d3b73f1e5a36a8f Mon Sep 17 00:00:00 2001 From: "Li-Huai (Allan) Lin" Date: Mon, 19 Jun 2023 02:07:16 +0800 Subject: [PATCH 15/28] binaryPSO -> visionPSO --- torchvision/csrc/ops/mps/nms_kernel.mm | 10 +++++----- .../csrc/ops/mps/ps_roi_align_kernel.mm | 20 +++++++++---------- .../csrc/ops/mps/ps_roi_pool_kernel.mm | 20 +++++++++---------- torchvision/csrc/ops/mps/roi_align_kernel.mm | 20 +++++++++---------- torchvision/csrc/ops/mps/roi_pool_kernel.mm | 20 +++++++++---------- 5 files changed, 45 insertions(+), 45 deletions(-) diff --git a/torchvision/csrc/ops/mps/nms_kernel.mm b/torchvision/csrc/ops/mps/nms_kernel.mm index 64730511590..8ff58abd6b8 100644 --- 
a/torchvision/csrc/ops/mps/nms_kernel.mm +++ b/torchvision/csrc/ops/mps/nms_kernel.mm @@ -65,19 +65,19 @@ MTLSize threadgroupsPerGrid = MTLSizeMake(col_blocks, col_blocks, 1); const std::string kernel = "nms_" + scalarToMetalTypeString(dets_sorted.scalar_type()); - id binaryPSO = mps::visionPipelineState(device, kernel); + id visionPSO = mps::visionPipelineState(device, kernel); // this function call is a no-op if MPS Profiler is not enabled - getMPSProfiler().beginProfileKernel(binaryPSO, kernel, {dets, scores}); + getMPSProfiler().beginProfileKernel(visionPSO, kernel, {dets, scores}); - [computeEncoder setComputePipelineState:binaryPSO]; + [computeEncoder setComputePipelineState:visionPSO]; [computeEncoder setBuffer:inputBuffer offset:dets_sorted.storage_offset() * dets_sorted.element_size() atIndex:0]; [computeEncoder setBuffer:outputBuffer offset:mask.storage_offset() * mask.element_size() atIndex:1]; [computeEncoder setBytes:&dets_num length:sizeof(int) atIndex:2]; [computeEncoder setBytes:&iou_threshold_f length:sizeof(float) atIndex:3]; // A threadGroup is equivalent to a cuda's block. - NSUInteger tgSize = binaryPSO.maxTotalThreadsPerThreadgroup; + NSUInteger tgSize = visionPSO.maxTotalThreadsPerThreadgroup; if (tgSize > nmsThreadsPerBlock) { tgSize = nmsThreadsPerBlock; } @@ -85,7 +85,7 @@ MTLSize threadGroupSize = MTLSizeMake(tgSize, 1, 1); [computeEncoder dispatchThreadgroups:threadgroupsPerGrid threadsPerThreadgroup:threadGroupSize]; - getMPSProfiler().endProfileKernel(binaryPSO); + getMPSProfiler().endProfileKernel(visionPSO); } }); diff --git a/torchvision/csrc/ops/mps/ps_roi_align_kernel.mm b/torchvision/csrc/ops/mps/ps_roi_align_kernel.mm index 26d2e4ed44d..a2946a47056 100644 --- a/torchvision/csrc/ops/mps/ps_roi_align_kernel.mm +++ b/torchvision/csrc/ops/mps/ps_roi_align_kernel.mm @@ -68,12 +68,12 @@ MTLSize threadgroupsPerGrid = MTLSizeMake(std::min(ceil_div(static_cast(output_size), static_cast(512)), static_cast(4096)), 1, 1); const std::string kernel = "ps_roi_align_" + scalarToMetalTypeString(input.scalar_type()); - id binaryPSO = mps::visionPipelineState(device, kernel); + id visionPSO = mps::visionPipelineState(device, kernel); // this function call is a no-op if MPS Profiler is not enabled - getMPSProfiler().beginProfileKernel(binaryPSO, kernel, {input_, rois_}); + getMPSProfiler().beginProfileKernel(visionPSO, kernel, {input_, rois_}); - [computeEncoder setComputePipelineState:binaryPSO]; + [computeEncoder setComputePipelineState:visionPSO]; // [N, C, H, W] [computeEncoder setBuffer:inputBuffer offset:input_.storage_offset() * input_.element_size() atIndex:0]; [computeEncoder setBuffer:roisBuffer offset:rois_.storage_offset() * rois_.element_size() atIndex:1]; @@ -91,7 +91,7 @@ [computeEncoder setBytes:&spatial_scale_f length:sizeof(float) atIndex:12]; // A threadGroup is equivalent to a cuda's block. 
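All of the ROI kernels in this series share the launch geometry visible in the hunks around here: a 1-D grid of at most 4096 threadgroups, each sized by ceil_div against the 512-thread threadsPerBlock constant from mps_helpers.h, with a grid-stride loop inside the Metal kernel covering any leftover elements. A minimal Python sketch of that arithmetic (illustrative only, not part of the patch):

def ceil_div(n, m):
    # Mirrors the ceil_div helper in mps_helpers.h.
    return (n + m - 1) // m

def launch_geometry(output_size, threads_per_block=512, max_threadgroups=4096):
    # Matches the MTLSizeMake(std::min(ceil_div(output_size, 512), 4096), 1, 1)
    # pattern used by the forward/backward kernels in these files.
    threadgroups = min(ceil_div(output_size, threads_per_block), max_threadgroups)
    # Each thread then strides by threadgroups * threads_per_block
    # (the MPS_1D_KERNEL_LOOP pattern) until output_size is covered.
    return threadgroups, threads_per_block

print(launch_geometry(50 * 256 * 7 * 7))  # e.g. 50 RoIs x 256 channels x 7x7 pool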
- NSUInteger tgSize = binaryPSO.maxTotalThreadsPerThreadgroup; + NSUInteger tgSize = visionPSO.maxTotalThreadsPerThreadgroup; if (tgSize > threadsPerBlock) { tgSize = threadsPerBlock; } @@ -99,7 +99,7 @@ MTLSize threadGroupSize = MTLSizeMake(tgSize, 1, 1); [computeEncoder dispatchThreadgroups:threadgroupsPerGrid threadsPerThreadgroup:threadGroupSize]; - getMPSProfiler().endProfileKernel(binaryPSO); + getMPSProfiler().endProfileKernel(visionPSO); } }); return std::make_tuple(output, channel_mapping); @@ -156,12 +156,12 @@ MTLSize threadgroupsPerGrid = MTLSizeMake(std::min(ceil_div(static_cast(grad.numel()), static_cast(512)), static_cast(4096)), 1, 1); const std::string kernel = "ps_roi_align_backward_" + scalarToMetalTypeString(grad.scalar_type()); - id binaryPSO = mps::visionPipelineState(device, kernel); + id visionPSO = mps::visionPipelineState(device, kernel); // this function call is a no-op if MPS Profiler is not enabled - getMPSProfiler().beginProfileKernel(binaryPSO, kernel, {grad, rois_}); + getMPSProfiler().beginProfileKernel(visionPSO, kernel, {grad, rois_}); - [computeEncoder setComputePipelineState:binaryPSO]; + [computeEncoder setComputePipelineState:visionPSO]; // [N, C, H, W] [computeEncoder setBuffer:inputBuffer offset:grad_.storage_offset() * grad_.element_size() atIndex:0]; [computeEncoder setBuffer:roisBuffer offset:rois_.storage_offset() * rois_.element_size() atIndex:1]; @@ -179,7 +179,7 @@ [computeEncoder setBytes:&spatial_scale_f length:sizeof(float) atIndex:12]; // A threadGroup is equivalent to a cuda's block. - NSUInteger tgSize = binaryPSO.maxTotalThreadsPerThreadgroup; + NSUInteger tgSize = visionPSO.maxTotalThreadsPerThreadgroup; if (tgSize > threadsPerBlock) { tgSize = threadsPerBlock; } @@ -187,7 +187,7 @@ MTLSize threadGroupSize = MTLSizeMake(tgSize, 1, 1); [computeEncoder dispatchThreadgroups:threadgroupsPerGrid threadsPerThreadgroup:threadGroupSize]; - getMPSProfiler().endProfileKernel(binaryPSO); + getMPSProfiler().endProfileKernel(visionPSO); } }); return grad_input; diff --git a/torchvision/csrc/ops/mps/ps_roi_pool_kernel.mm b/torchvision/csrc/ops/mps/ps_roi_pool_kernel.mm index 0ce2e6800fc..65c4dc43e24 100644 --- a/torchvision/csrc/ops/mps/ps_roi_pool_kernel.mm +++ b/torchvision/csrc/ops/mps/ps_roi_pool_kernel.mm @@ -65,12 +65,12 @@ MTLSize threadgroupsPerGrid = MTLSizeMake(std::min(ceil_div(static_cast(output_size), static_cast(512)), static_cast(4096)), 1, 1); const std::string kernel = "ps_roi_pool_" + scalarToMetalTypeString(input.scalar_type()); - id binaryPSO = mps::visionPipelineState(device, kernel); + id visionPSO = mps::visionPipelineState(device, kernel); // this function call is a no-op if MPS Profiler is not enabled - getMPSProfiler().beginProfileKernel(binaryPSO, kernel, {input_, rois_}); + getMPSProfiler().beginProfileKernel(visionPSO, kernel, {input_, rois_}); - [computeEncoder setComputePipelineState:binaryPSO]; + [computeEncoder setComputePipelineState:visionPSO]; // [N, C, H, W] [computeEncoder setBuffer:inputBuffer offset:input_.storage_offset() * input_.element_size() atIndex:0]; [computeEncoder setBuffer:roisBuffer offset:rois_.storage_offset() * rois_.element_size() atIndex:1]; @@ -87,7 +87,7 @@ [computeEncoder setBytes:&spatial_scale_f length:sizeof(float) atIndex:11]; // A threadGroup is equivalent to a cuda's block. 
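The ps_roi_* hunks above keep the position-sensitive contract: input channels must be divisible by pooled_height * pooled_width, and channels_out is that quotient. A small CPU-side sketch of the resulting shapes (illustrative only; the MPS kernels implement the same op):

import torch
from torchvision.ops import ps_roi_pool

x = torch.rand(1, 49, 32, 32)                       # channels = 49 = 7 * 7
rois = torch.tensor([[0.0, 0.0, 0.0, 15.0, 15.0]])  # (batch_idx, x1, y1, x2, y2)
out = ps_roi_pool(x, rois, output_size=(7, 7), spatial_scale=1.0)
print(out.shape)  # torch.Size([1, 1, 7, 7]); channels_out = 49 // (7 * 7) = 1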
- NSUInteger tgSize = binaryPSO.maxTotalThreadsPerThreadgroup; + NSUInteger tgSize = visionPSO.maxTotalThreadsPerThreadgroup; if (tgSize > threadsPerBlock) { tgSize = threadsPerBlock; } @@ -95,7 +95,7 @@ MTLSize threadGroupSize = MTLSizeMake(tgSize, 1, 1); [computeEncoder dispatchThreadgroups:threadgroupsPerGrid threadsPerThreadgroup:threadGroupSize]; - getMPSProfiler().endProfileKernel(binaryPSO); + getMPSProfiler().endProfileKernel(visionPSO); } }); return std::make_tuple(output, channel_mapping); @@ -152,12 +152,12 @@ MTLSize threadgroupsPerGrid = MTLSizeMake(std::min(ceil_div(static_cast(grad.numel()), static_cast(512)), static_cast(4096)), 1, 1); const std::string kernel = "ps_roi_pool_backward_" + scalarToMetalTypeString(grad.scalar_type()); - id binaryPSO = mps::visionPipelineState(device, kernel); + id visionPSO = mps::visionPipelineState(device, kernel); // this function call is a no-op if MPS Profiler is not enabled - getMPSProfiler().beginProfileKernel(binaryPSO, kernel, {grad_, rois_, channel_mapping}); + getMPSProfiler().beginProfileKernel(visionPSO, kernel, {grad_, rois_, channel_mapping}); - [computeEncoder setComputePipelineState:binaryPSO]; + [computeEncoder setComputePipelineState:visionPSO]; // [N, C, H, W] [computeEncoder setBuffer:inputBuffer offset:grad_.storage_offset() * grad_.element_size() atIndex:0]; [computeEncoder setBuffer:roisBuffer offset:rois_.storage_offset() * rois_.element_size() atIndex:1]; @@ -174,7 +174,7 @@ [computeEncoder setBytes:&spatial_scale_f length:sizeof(float) atIndex:11]; // A threadGroup is equivalent to a cuda's block. - NSUInteger tgSize = binaryPSO.maxTotalThreadsPerThreadgroup; + NSUInteger tgSize = visionPSO.maxTotalThreadsPerThreadgroup; if (tgSize > threadsPerBlock) { tgSize = threadsPerBlock; } @@ -182,7 +182,7 @@ MTLSize threadGroupSize = MTLSizeMake(tgSize, 1, 1); [computeEncoder dispatchThreadgroups:threadgroupsPerGrid threadsPerThreadgroup:threadGroupSize]; - getMPSProfiler().endProfileKernel(binaryPSO); + getMPSProfiler().endProfileKernel(visionPSO); } }); return grad_input; diff --git a/torchvision/csrc/ops/mps/roi_align_kernel.mm b/torchvision/csrc/ops/mps/roi_align_kernel.mm index a57bbc018dc..3e1a623db98 100644 --- a/torchvision/csrc/ops/mps/roi_align_kernel.mm +++ b/torchvision/csrc/ops/mps/roi_align_kernel.mm @@ -60,12 +60,12 @@ MTLSize threadgroupsPerGrid = MTLSizeMake(std::min(ceil_div(static_cast(output_size), static_cast(512)), static_cast(4096)), 1, 1); const std::string kernel = "roi_align_" + scalarToMetalTypeString(input.scalar_type()); - id binaryPSO = mps::visionPipelineState(device, kernel); + id visionPSO = mps::visionPipelineState(device, kernel); // this function call is a no-op if MPS Profiler is not enabled - getMPSProfiler().beginProfileKernel(binaryPSO, kernel, {input_, rois_}); + getMPSProfiler().beginProfileKernel(visionPSO, kernel, {input_, rois_}); - [computeEncoder setComputePipelineState:binaryPSO]; + [computeEncoder setComputePipelineState:visionPSO]; // [N, C, H, W] [computeEncoder setBuffer:inputBuffer offset:input_.storage_offset() * input_.element_size() atIndex:0]; [computeEncoder setBuffer:roisBuffer offset:rois_.storage_offset() * rois_.element_size() atIndex:1]; @@ -82,7 +82,7 @@ [computeEncoder setBytes:&spatial_scale_f length:sizeof(float) atIndex:11]; // A threadGroup is equivalent to a cuda's block. 
- NSUInteger tgSize = binaryPSO.maxTotalThreadsPerThreadgroup; + NSUInteger tgSize = visionPSO.maxTotalThreadsPerThreadgroup; if (tgSize > threadsPerBlock) { tgSize = threadsPerBlock; } @@ -90,7 +90,7 @@ MTLSize threadGroupSize = MTLSizeMake(tgSize, 1, 1); [computeEncoder dispatchThreadgroups:threadgroupsPerGrid threadsPerThreadgroup:threadGroupSize]; - getMPSProfiler().endProfileKernel(binaryPSO); + getMPSProfiler().endProfileKernel(visionPSO); } }); return output; @@ -148,12 +148,12 @@ MTLSize threadgroupsPerGrid = MTLSizeMake(std::min(ceil_div(static_cast(grad.numel()), static_cast(512)), static_cast(4096)), 1, 1); const std::string kernel = "roi_align_backward_" + scalarToMetalTypeString(grad.scalar_type()); - id binaryPSO = mps::visionPipelineState(device, kernel); + id visionPSO = mps::visionPipelineState(device, kernel); // this function call is a no-op if MPS Profiler is not enabled - getMPSProfiler().beginProfileKernel(binaryPSO, kernel, {grad, rois_}); + getMPSProfiler().beginProfileKernel(visionPSO, kernel, {grad, rois_}); - [computeEncoder setComputePipelineState:binaryPSO]; + [computeEncoder setComputePipelineState:visionPSO]; // [N, C, H, W] [computeEncoder setBuffer:inputBuffer offset:grad.storage_offset() * grad.element_size() atIndex:0]; [computeEncoder setBuffer:roisBuffer offset:rois_.storage_offset() * rois_.element_size() atIndex:1]; @@ -174,7 +174,7 @@ [computeEncoder setBytes:&w_stride length:sizeof(int64_t) atIndex:15]; // A threadGroup is equivalent to a cuda's block. - NSUInteger tgSize = binaryPSO.maxTotalThreadsPerThreadgroup; + NSUInteger tgSize = visionPSO.maxTotalThreadsPerThreadgroup; if (tgSize > threadsPerBlock) { tgSize = threadsPerBlock; } @@ -182,7 +182,7 @@ MTLSize threadGroupSize = MTLSizeMake(tgSize, 1, 1); [computeEncoder dispatchThreadgroups:threadgroupsPerGrid threadsPerThreadgroup:threadGroupSize]; - getMPSProfiler().endProfileKernel(binaryPSO); + getMPSProfiler().endProfileKernel(visionPSO); } }); return grad_input; diff --git a/torchvision/csrc/ops/mps/roi_pool_kernel.mm b/torchvision/csrc/ops/mps/roi_pool_kernel.mm index 5f376a2f091..da1271bad51 100644 --- a/torchvision/csrc/ops/mps/roi_pool_kernel.mm +++ b/torchvision/csrc/ops/mps/roi_pool_kernel.mm @@ -62,12 +62,12 @@ MTLSize threadgroupsPerGrid = MTLSizeMake(std::min(ceil_div(static_cast(output_size), static_cast(512)), static_cast(4096)), 1, 1); const std::string kernel = "roi_pool_" + scalarToMetalTypeString(input.scalar_type()); - id binaryPSO = mps::visionPipelineState(device, kernel); + id visionPSO = mps::visionPipelineState(device, kernel); // this function call is a no-op if MPS Profiler is not enabled - getMPSProfiler().beginProfileKernel(binaryPSO, kernel, {input_, rois_}); + getMPSProfiler().beginProfileKernel(visionPSO, kernel, {input_, rois_}); - [computeEncoder setComputePipelineState:binaryPSO]; + [computeEncoder setComputePipelineState:visionPSO]; // [N, C, H, W] [computeEncoder setBuffer:inputBuffer offset:input_.storage_offset() * input_.element_size() atIndex:0]; [computeEncoder setBuffer:roisBuffer offset:rois_.storage_offset() * rois_.element_size() atIndex:1]; @@ -83,7 +83,7 @@ [computeEncoder setBytes:&spatial_scale_f length:sizeof(float) atIndex:10]; // A threadGroup is equivalent to a cuda's block. 
- NSUInteger tgSize = binaryPSO.maxTotalThreadsPerThreadgroup; + NSUInteger tgSize = visionPSO.maxTotalThreadsPerThreadgroup; if (tgSize > threadsPerBlock) { tgSize = threadsPerBlock; } @@ -91,7 +91,7 @@ MTLSize threadGroupSize = MTLSizeMake(tgSize, 1, 1); [computeEncoder dispatchThreadgroups:threadgroupsPerGrid threadsPerThreadgroup:threadGroupSize]; - getMPSProfiler().endProfileKernel(binaryPSO); + getMPSProfiler().endProfileKernel(visionPSO); } }); return std::make_tuple(output, argmax); @@ -150,12 +150,12 @@ MTLSize threadgroupsPerGrid = MTLSizeMake(std::min(ceil_div(static_cast(grad.numel()), static_cast(512)), static_cast(4096)), 1, 1); const std::string kernel = "roi_pool_backward_" + scalarToMetalTypeString(grad.scalar_type()); - id binaryPSO = mps::visionPipelineState(device, kernel); + id visionPSO = mps::visionPipelineState(device, kernel); // this function call is a no-op if MPS Profiler is not enabled - getMPSProfiler().beginProfileKernel(binaryPSO, kernel, {grad, rois_, argmax_}); + getMPSProfiler().beginProfileKernel(visionPSO, kernel, {grad, rois_, argmax_}); - [computeEncoder setComputePipelineState:binaryPSO]; + [computeEncoder setComputePipelineState:visionPSO]; // [N, C, H, W] [computeEncoder setBuffer:inputBuffer offset:grad.storage_offset() * grad.element_size() atIndex:0]; [computeEncoder setBuffer:roisBuffer offset:rois_.storage_offset() * rois_.element_size() atIndex:1]; @@ -175,7 +175,7 @@ [computeEncoder setBytes:&w_stride length:sizeof(int64_t) atIndex:14]; // A threadGroup is equivalent to a cuda's block. - NSUInteger tgSize = binaryPSO.maxTotalThreadsPerThreadgroup; + NSUInteger tgSize = visionPSO.maxTotalThreadsPerThreadgroup; if (tgSize > threadsPerBlock) { tgSize = threadsPerBlock; } @@ -183,7 +183,7 @@ MTLSize threadGroupSize = MTLSizeMake(tgSize, 1, 1); [computeEncoder dispatchThreadgroups:threadgroupsPerGrid threadsPerThreadgroup:threadGroupSize]; - getMPSProfiler().endProfileKernel(binaryPSO); + getMPSProfiler().endProfileKernel(visionPSO); } }); return grad_input; From 3018b25257320e567f0587c6eafe8ff6781b8c69 Mon Sep 17 00:00:00 2001 From: "Li-Huai (Allan) Lin" Date: Tue, 20 Jun 2023 17:18:01 +0800 Subject: [PATCH 16/28] Formatting --- .github/scripts/run-clang-format.py | 2 +- test/common_utils.py | 3 + torchvision/csrc/ops/mps/mps_helpers.h | 2 +- torchvision/csrc/ops/mps/nms_kernel.mm | 56 ++++------- .../csrc/ops/mps/ps_roi_align_kernel.mm | 95 +++++++++---------- .../csrc/ops/mps/ps_roi_pool_kernel.mm | 91 +++++++++--------- torchvision/csrc/ops/mps/roi_align_kernel.mm | 78 ++++++++------- torchvision/csrc/ops/mps/roi_pool_kernel.mm | 76 +++++++-------- 8 files changed, 188 insertions(+), 215 deletions(-) diff --git a/.github/scripts/run-clang-format.py b/.github/scripts/run-clang-format.py index 5c61b2519e0..670fd97833a 100755 --- a/.github/scripts/run-clang-format.py +++ b/.github/scripts/run-clang-format.py @@ -48,7 +48,7 @@ DEVNULL = open(os.devnull, "wb") -DEFAULT_EXTENSIONS = "c,h,C,H,cpp,hpp,cc,hh,c++,h++,cxx,hxx,cu" +DEFAULT_EXTENSIONS = "c,h,C,H,cpp,hpp,cc,hh,c++,h++,cxx,hxx,cu,mm" class ExitStatus: diff --git a/test/common_utils.py b/test/common_utils.py index c3fd743cf16..91d8a59b9a3 100644 --- a/test/common_utils.py +++ b/test/common_utils.py @@ -127,14 +127,17 @@ def cpu_and_gpu(): return ("cpu", pytest.param("cuda", marks=pytest.mark.needs_cuda)) + def cpu_and_gpu_and_mps(): return cpu_and_gpu() + (pytest.param("mps", marks=pytest.mark.needs_mps),) + def needs_cuda(test_func): import pytest # noqa return 
pytest.mark.needs_cuda(test_func) + def needs_mps(test_func): import pytest diff --git a/torchvision/csrc/ops/mps/mps_helpers.h b/torchvision/csrc/ops/mps/mps_helpers.h index a18f51a900f..d3c0e8d94b7 100644 --- a/torchvision/csrc/ops/mps/mps_helpers.h +++ b/torchvision/csrc/ops/mps/mps_helpers.h @@ -3,4 +3,4 @@ constexpr int threadsPerBlock = 512; template constexpr inline T ceil_div(T n, T m) { return (n + m - 1) / m; -} \ No newline at end of file +} diff --git a/torchvision/csrc/ops/mps/nms_kernel.mm b/torchvision/csrc/ops/mps/nms_kernel.mm index 8ff58abd6b8..97064cf9aa3 100644 --- a/torchvision/csrc/ops/mps/nms_kernel.mm +++ b/torchvision/csrc/ops/mps/nms_kernel.mm @@ -2,8 +2,8 @@ #include #include "mps_kernels.h" -#include #include +#include namespace vision { namespace ops { @@ -13,47 +13,32 @@ // This should be in sync with `nmsThreadsPerBlock` in the metal kernel. constexpr int nmsThreadsPerBlock = sizeof(uint64_t) * 8; -at::Tensor nms_kernel( - const at::Tensor& dets, - const at::Tensor& scores, - double iou_threshold) { - +at::Tensor nms_kernel(const at::Tensor& dets, const at::Tensor& scores, double iou_threshold) { using namespace at::native::mps; TORCH_CHECK(dets.is_mps(), "dets must be a MPS tensor"); TORCH_CHECK(scores.is_mps(), "scores must be a MPS tensor"); - TORCH_CHECK( - dets.dim() == 2, "boxes should be a 2d tensor, got ", dets.dim(), "D"); - TORCH_CHECK( - dets.size(1) == 4, - "boxes should have 4 elements in dimension 1, got ", - dets.size(1)); - TORCH_CHECK( - scores.dim() == 1, - "scores should be a 1d tensor, got ", - scores.dim(), - "D"); - TORCH_CHECK( - dets.size(0) == scores.size(0), - "boxes and scores should have same number of elements in ", - "dimension 0, got ", - dets.size(0), - " and ", - scores.size(0)) - + TORCH_CHECK(dets.dim() == 2, "boxes should be a 2d tensor, got ", dets.dim(), "D"); + TORCH_CHECK(dets.size(1) == 4, "boxes should have 4 elements in dimension 1, got ", dets.size(1)); + TORCH_CHECK(scores.dim() == 1, "scores should be a 1d tensor, got ", scores.dim(), "D"); + TORCH_CHECK(dets.size(0) == scores.size(0), + "boxes and scores should have same number of elements in ", + "dimension 0, got ", + dets.size(0), + " and ", + scores.size(0)) + if (dets.numel() == 0) { return at::empty({0}, dets.options().dtype(at::kLong)); } - auto order_t = std::get<1>( - scores.sort(/*stable=*/true, /*dim=*/0, /* descending=*/true)); + auto order_t = std::get<1>(scores.sort(/*stable=*/true, /*dim=*/0, /* descending=*/true)); auto dets_sorted = dets.index_select(0, order_t).contiguous(); int dets_num = dets.size(0); float iou_threshold_f = static_cast(iou_threshold); const int col_blocks = (dets_num + nmsThreadsPerBlock - 1) / nmsThreadsPerBlock; - at::Tensor mask = - at::empty({dets_num * col_blocks}, dets.options().dtype(at::kLong)); + at::Tensor mask = at::empty({dets_num * col_blocks}, dets.options().dtype(at::kLong)); id inputBuffer = getMTLBufferStorage(dets_sorted); id outputBuffer = getMTLBufferStorage(mask); @@ -92,15 +77,13 @@ int64_t num_to_keep = 0; at::Tensor mask_cpu = mask.to(at::kCPU); - unsigned long long* mask_host = - (unsigned long long*)mask_cpu.data_ptr(); + unsigned long long* mask_host = (unsigned long long*)mask_cpu.data_ptr(); std::vector remv(col_blocks); memset(&remv[0], 0, sizeof(unsigned long long) * col_blocks); - - at::Tensor keep = - at::empty({dets_num}, dets.options().dtype(at::kLong).device(at::kCPU)); - int64_t* keep_out = keep.data_ptr(); + + at::Tensor keep = at::empty({dets_num}, 
dets.options().dtype(at::kLong).device(at::kCPU)); + int64_t* keep_out = keep.data_ptr(); for (int i = 0; i < dets_num; i++) { int nblock = i / nmsThreadsPerBlock; @@ -116,8 +99,7 @@ } return order_t.index( - {keep.narrow(/*dim=*/0, /*start=*/0, /*length=*/num_to_keep) - .to(order_t.device(), keep.scalar_type())}); + {keep.narrow(/*dim=*/0, /*start=*/0, /*length=*/num_to_keep).to(order_t.device(), keep.scalar_type())}); } } // namespace diff --git a/torchvision/csrc/ops/mps/ps_roi_align_kernel.mm b/torchvision/csrc/ops/mps/ps_roi_align_kernel.mm index a2946a47056..1e0a2d902ee 100644 --- a/torchvision/csrc/ops/mps/ps_roi_align_kernel.mm +++ b/torchvision/csrc/ops/mps/ps_roi_align_kernel.mm @@ -1,24 +1,22 @@ #include #include -#include "mps_kernels.h" #include "mps_helpers.h" +#include "mps_kernels.h" -#include #include +#include namespace vision { namespace ops { namespace { -std::tuple ps_roi_align_forward_kernel( - const at::Tensor& input, - const at::Tensor& rois, - double spatial_scale, - int64_t pooled_height, - int64_t pooled_width, - int64_t sampling_ratio) { - +std::tuple ps_roi_align_forward_kernel(const at::Tensor& input, + const at::Tensor& rois, + double spatial_scale, + int64_t pooled_height, + int64_t pooled_width, + int64_t sampling_ratio) { using namespace at::native::mps; TORCH_CHECK(input.is_mps(), "input must be a MPS tensor"); TORCH_CHECK(rois.is_mps(), "rois must be a MPS tensor"); @@ -29,23 +27,20 @@ at::CheckedFrom c = "ps_roi_align_forward_kernel"; at::checkAllSameGPU(c, {input_t, rois_t}); at::checkAllSameType(c, {input_t, rois_t}); - + int64_t num_rois = rois.size(0); int64_t channels = input.size(1); int64_t height = input.size(2); int64_t width = input.size(3); float spatial_scale_f = static_cast(spatial_scale); - TORCH_CHECK( - channels % (pooled_height * pooled_width) == 0, - "input channels must be a multiple of pooling height * pooling width"); - + TORCH_CHECK(channels % (pooled_height * pooled_width) == 0, + "input channels must be a multiple of pooling height * pooling width"); + int64_t channels_out = channels / (pooled_height * pooled_width); - auto output = at::zeros( - {num_rois, channels_out, pooled_height, pooled_width}, input.options()); - auto channel_mapping = - at::zeros(output.sizes(), input.options().dtype(at::kLong)); + auto output = at::zeros({num_rois, channels_out, pooled_height, pooled_width}, input.options()); + auto channel_mapping = at::zeros(output.sizes(), input.options().dtype(at::kLong)); int64_t output_size = output.numel(); @@ -65,7 +60,10 @@ dispatch_sync(mpsStream->queue(), ^() { @autoreleasepool { id computeEncoder = mpsStream->commandEncoder(); - MTLSize threadgroupsPerGrid = MTLSizeMake(std::min(ceil_div(static_cast(output_size), static_cast(512)), static_cast(4096)), 1, 1); + MTLSize threadgroupsPerGrid = MTLSizeMake( + std::min(ceil_div(static_cast(output_size), static_cast(512)), static_cast(4096)), + 1, + 1); const std::string kernel = "ps_roi_align_" + scalarToMetalTypeString(input.scalar_type()); id visionPSO = mps::visionPipelineState(device, kernel); @@ -78,8 +76,10 @@ [computeEncoder setBuffer:inputBuffer offset:input_.storage_offset() * input_.element_size() atIndex:0]; [computeEncoder setBuffer:roisBuffer offset:rois_.storage_offset() * rois_.element_size() atIndex:1]; [computeEncoder setBuffer:outputBuffer offset:output.storage_offset() * output.element_size() atIndex:2]; - [computeEncoder setBuffer:channelMappingBuffer offset:channel_mapping.storage_offset() * channel_mapping.element_size() atIndex:3]; - + 
[computeEncoder setBuffer:channelMappingBuffer + offset:channel_mapping.storage_offset() * channel_mapping.element_size() + atIndex:3]; + [computeEncoder setBytes:&output_size length:sizeof(int64_t) atIndex:4]; [computeEncoder setBytes:&channels length:sizeof(int64_t) atIndex:5]; [computeEncoder setBytes:&height length:sizeof(int64_t) atIndex:6]; @@ -105,34 +105,32 @@ return std::make_tuple(output, channel_mapping); } -at::Tensor ps_roi_align_backward_kernel( - const at::Tensor& grad, - const at::Tensor& rois, - const at::Tensor& channel_mapping, - double spatial_scale, - int64_t pooled_height, - int64_t pooled_width, - int64_t sampling_ratio, - int64_t batch_size, - int64_t channels, - int64_t height, - int64_t width) { - +at::Tensor ps_roi_align_backward_kernel(const at::Tensor& grad, + const at::Tensor& rois, + const at::Tensor& channel_mapping, + double spatial_scale, + int64_t pooled_height, + int64_t pooled_width, + int64_t sampling_ratio, + int64_t batch_size, + int64_t channels, + int64_t height, + int64_t width) { using namespace at::native::mps; TORCH_CHECK(grad.is_mps(), "grad must be a MPS tensor"); TORCH_CHECK(rois.is_mps(), "rois must be a MPS tensor"); TORCH_CHECK(channel_mapping.is_mps(), "channel_mapping must be a MPS tensor"); - at::TensorArg grad_t{grad, "input", 1}, rois_t{rois, "rois", 2}, channel_mapping_t{channel_mapping, "channel_mapping", 3}; + at::TensorArg grad_t{grad, "input", 1}, rois_t{rois, "rois", 2}, + channel_mapping_t{channel_mapping, "channel_mapping", 3}; at::CheckedFrom c = "ps_roi_align_backward_kernel"; at::checkAllSameGPU(c, {grad_t, rois_t, channel_mapping_t}); at::checkAllSameType(c, {grad_t, rois_t}); - + float spatial_scale_f = static_cast(spatial_scale); - auto grad_input = at::zeros( - {batch_size, channels, height, width}, grad.options()); + auto grad_input = at::zeros({batch_size, channels, height, width}, grad.options()); if (grad.numel() == 0) { return grad_input; @@ -153,7 +151,10 @@ dispatch_sync(mpsStream->queue(), ^() { @autoreleasepool { id computeEncoder = mpsStream->commandEncoder(); - MTLSize threadgroupsPerGrid = MTLSizeMake(std::min(ceil_div(static_cast(grad.numel()), static_cast(512)), static_cast(4096)), 1, 1); + MTLSize threadgroupsPerGrid = MTLSizeMake( + std::min(ceil_div(static_cast(grad.numel()), static_cast(512)), static_cast(4096)), + 1, + 1); const std::string kernel = "ps_roi_align_backward_" + scalarToMetalTypeString(grad.scalar_type()); id visionPSO = mps::visionPipelineState(device, kernel); @@ -165,9 +166,11 @@ // [N, C, H, W] [computeEncoder setBuffer:inputBuffer offset:grad_.storage_offset() * grad_.element_size() atIndex:0]; [computeEncoder setBuffer:roisBuffer offset:rois_.storage_offset() * rois_.element_size() atIndex:1]; - [computeEncoder setBuffer:channelMappingBuffer offset:channel_mapping.storage_offset() * channel_mapping.element_size() atIndex:2]; + [computeEncoder setBuffer:channelMappingBuffer + offset:channel_mapping.storage_offset() * channel_mapping.element_size() + atIndex:2]; [computeEncoder setBuffer:outputBuffer offset:grad_input.storage_offset() * grad_input.element_size() atIndex:3]; - + [computeEncoder setBytes:&output_size length:sizeof(int64_t) atIndex:4]; [computeEncoder setBytes:&channels length:sizeof(int64_t) atIndex:5]; [computeEncoder setBytes:&height length:sizeof(int64_t) atIndex:6]; @@ -196,12 +199,8 @@ } // namespace TORCH_LIBRARY_IMPL(torchvision, MPS, m) { - m.impl( - TORCH_SELECTIVE_NAME("torchvision::ps_roi_align"), - TORCH_FN(ps_roi_align_forward_kernel)); - m.impl( - 
TORCH_SELECTIVE_NAME("torchvision::_ps_roi_align_backward"), - TORCH_FN(ps_roi_align_backward_kernel)); + m.impl(TORCH_SELECTIVE_NAME("torchvision::ps_roi_align"), TORCH_FN(ps_roi_align_forward_kernel)); + m.impl(TORCH_SELECTIVE_NAME("torchvision::_ps_roi_align_backward"), TORCH_FN(ps_roi_align_backward_kernel)); } } // namespace ops diff --git a/torchvision/csrc/ops/mps/ps_roi_pool_kernel.mm b/torchvision/csrc/ops/mps/ps_roi_pool_kernel.mm index 65c4dc43e24..070e1fee6f4 100644 --- a/torchvision/csrc/ops/mps/ps_roi_pool_kernel.mm +++ b/torchvision/csrc/ops/mps/ps_roi_pool_kernel.mm @@ -1,23 +1,21 @@ #include #include -#include "mps_kernels.h" #include "mps_helpers.h" +#include "mps_kernels.h" -#include #include +#include namespace vision { namespace ops { namespace { -std::tuple ps_roi_pool_forward_kernel( - const at::Tensor& input, - const at::Tensor& rois, - double spatial_scale, - int64_t pooled_height, - int64_t pooled_width) { - +std::tuple ps_roi_pool_forward_kernel(const at::Tensor& input, + const at::Tensor& rois, + double spatial_scale, + int64_t pooled_height, + int64_t pooled_width) { using namespace at::native::mps; TORCH_CHECK(input.is_mps(), "input must be a MPS tensor"); TORCH_CHECK(rois.is_mps(), "rois must be a MPS tensor"); @@ -28,22 +26,19 @@ at::CheckedFrom c = "ps_roi_pool_forward_kernel"; at::checkAllSameGPU(c, {input_t, rois_t}); at::checkAllSameType(c, {input_t, rois_t}); - + int64_t num_rois = rois.size(0); int64_t channels = input.size(1); int64_t height = input.size(2); int64_t width = input.size(3); float spatial_scale_f = static_cast(spatial_scale); - - TORCH_CHECK( - channels % (pooled_height * pooled_width) == 0, - "input channels must be a multiple of pooling height * pooling width"); + + TORCH_CHECK(channels % (pooled_height * pooled_width) == 0, + "input channels must be a multiple of pooling height * pooling width"); int64_t channels_out = channels / (pooled_height * pooled_width); - auto output = at::zeros( - {num_rois, channels_out, pooled_height, pooled_width}, input.options()); - auto channel_mapping = - at::zeros(output.sizes(), input.options().dtype(at::kLong)); + auto output = at::zeros({num_rois, channels_out, pooled_height, pooled_width}, input.options()); + auto channel_mapping = at::zeros(output.sizes(), input.options().dtype(at::kLong)); auto output_size = output.numel(); if (output_size == 0) { @@ -62,7 +57,10 @@ dispatch_sync(mpsStream->queue(), ^() { @autoreleasepool { id computeEncoder = mpsStream->commandEncoder(); - MTLSize threadgroupsPerGrid = MTLSizeMake(std::min(ceil_div(static_cast(output_size), static_cast(512)), static_cast(4096)), 1, 1); + MTLSize threadgroupsPerGrid = MTLSizeMake( + std::min(ceil_div(static_cast(output_size), static_cast(512)), static_cast(4096)), + 1, + 1); const std::string kernel = "ps_roi_pool_" + scalarToMetalTypeString(input.scalar_type()); id visionPSO = mps::visionPipelineState(device, kernel); @@ -75,8 +73,10 @@ [computeEncoder setBuffer:inputBuffer offset:input_.storage_offset() * input_.element_size() atIndex:0]; [computeEncoder setBuffer:roisBuffer offset:rois_.storage_offset() * rois_.element_size() atIndex:1]; [computeEncoder setBuffer:outputBuffer offset:output.storage_offset() * output.element_size() atIndex:2]; - [computeEncoder setBuffer:channelMappingBuffer offset:channel_mapping.storage_offset() * channel_mapping.element_size() atIndex:3]; - + [computeEncoder setBuffer:channelMappingBuffer + offset:channel_mapping.storage_offset() * channel_mapping.element_size() + atIndex:3]; + 
[computeEncoder setBytes:&output_size length:sizeof(int64_t) atIndex:4]; [computeEncoder setBytes:&channels length:sizeof(int64_t) atIndex:5]; [computeEncoder setBytes:&height length:sizeof(int64_t) atIndex:6]; @@ -101,34 +101,32 @@ return std::make_tuple(output, channel_mapping); } -at::Tensor ps_roi_pool_backward_kernel( - const at::Tensor& grad, - const at::Tensor& rois, - const at::Tensor& channel_mapping, - double spatial_scale, - int64_t pooled_height, - int64_t pooled_width, - int64_t batch_size, - int64_t channels, - int64_t height, - int64_t width) { - +at::Tensor ps_roi_pool_backward_kernel(const at::Tensor& grad, + const at::Tensor& rois, + const at::Tensor& channel_mapping, + double spatial_scale, + int64_t pooled_height, + int64_t pooled_width, + int64_t batch_size, + int64_t channels, + int64_t height, + int64_t width) { using namespace at::native::mps; TORCH_CHECK(grad.is_mps(), "grad must be a MPS tensor"); TORCH_CHECK(rois.is_mps(), "rois must be a MPS tensor"); TORCH_CHECK(channel_mapping.is_mps(), "channel_mapping must be a MPS tensor"); - at::TensorArg grad_t{grad, "grad", 1}, rois_t{rois, "rois", 2}, channel_mapping_t{channel_mapping, "channel_mapping", 3}; + at::TensorArg grad_t{grad, "grad", 1}, rois_t{rois, "rois", 2}, + channel_mapping_t{channel_mapping, "channel_mapping", 3}; at::CheckedFrom c = "ps_roi_pool_backward_kernel"; at::checkAllSameGPU(c, {grad_t, rois_t, channel_mapping_t}); at::checkAllSameType(c, {grad_t, rois_t}); - + float spatial_scale_f = static_cast(spatial_scale); auto num_rois = rois.size(0); - auto grad_input = at::zeros( - {batch_size, channels, height, width}, grad.options()); + auto grad_input = at::zeros({batch_size, channels, height, width}, grad.options()); if (grad.numel() == 0) { return grad_input; @@ -149,7 +147,10 @@ dispatch_sync(mpsStream->queue(), ^() { @autoreleasepool { id computeEncoder = mpsStream->commandEncoder(); - MTLSize threadgroupsPerGrid = MTLSizeMake(std::min(ceil_div(static_cast(grad.numel()), static_cast(512)), static_cast(4096)), 1, 1); + MTLSize threadgroupsPerGrid = MTLSizeMake( + std::min(ceil_div(static_cast(grad.numel()), static_cast(512)), static_cast(4096)), + 1, + 1); const std::string kernel = "ps_roi_pool_backward_" + scalarToMetalTypeString(grad.scalar_type()); id visionPSO = mps::visionPipelineState(device, kernel); @@ -161,9 +162,11 @@ // [N, C, H, W] [computeEncoder setBuffer:inputBuffer offset:grad_.storage_offset() * grad_.element_size() atIndex:0]; [computeEncoder setBuffer:roisBuffer offset:rois_.storage_offset() * rois_.element_size() atIndex:1]; - [computeEncoder setBuffer:channelMappingBuffer offset:channel_mapping.storage_offset() * channel_mapping.element_size() atIndex:2]; + [computeEncoder setBuffer:channelMappingBuffer + offset:channel_mapping.storage_offset() * channel_mapping.element_size() + atIndex:2]; [computeEncoder setBuffer:outputBuffer offset:grad_input.storage_offset() * grad_input.element_size() atIndex:3]; - + [computeEncoder setBytes:&output_size length:sizeof(int64_t) atIndex:4]; [computeEncoder setBytes:&channels length:sizeof(int64_t) atIndex:5]; [computeEncoder setBytes:&height length:sizeof(int64_t) atIndex:6]; @@ -191,12 +194,8 @@ } // namespace TORCH_LIBRARY_IMPL(torchvision, MPS, m) { - m.impl( - TORCH_SELECTIVE_NAME("torchvision::ps_roi_pool"), - TORCH_FN(ps_roi_pool_forward_kernel)); - m.impl( - TORCH_SELECTIVE_NAME("torchvision::_ps_roi_pool_backward"), - TORCH_FN(ps_roi_pool_backward_kernel)); + m.impl(TORCH_SELECTIVE_NAME("torchvision::ps_roi_pool"), 
TORCH_FN(ps_roi_pool_forward_kernel)); + m.impl(TORCH_SELECTIVE_NAME("torchvision::_ps_roi_pool_backward"), TORCH_FN(ps_roi_pool_backward_kernel)); } } // namespace ops diff --git a/torchvision/csrc/ops/mps/roi_align_kernel.mm b/torchvision/csrc/ops/mps/roi_align_kernel.mm index 3e1a623db98..cd6483847a1 100644 --- a/torchvision/csrc/ops/mps/roi_align_kernel.mm +++ b/torchvision/csrc/ops/mps/roi_align_kernel.mm @@ -1,25 +1,23 @@ #include #include -#include "mps_kernels.h" #include "mps_helpers.h" +#include "mps_kernels.h" -#include #include +#include namespace vision { namespace ops { namespace { -at::Tensor roi_align_forward_kernel( - const at::Tensor& input, - const at::Tensor& rois, - double spatial_scale, - int64_t pooled_height, - int64_t pooled_width, - int64_t sampling_ratio, - bool aligned) { - +at::Tensor roi_align_forward_kernel(const at::Tensor& input, + const at::Tensor& rois, + double spatial_scale, + int64_t pooled_height, + int64_t pooled_width, + int64_t sampling_ratio, + bool aligned) { using namespace at::native::mps; TORCH_CHECK(input.is_mps(), "input must be a MPS tensor"); TORCH_CHECK(rois.is_mps(), "rois must be a MPS tensor"); @@ -30,16 +28,15 @@ at::CheckedFrom c = "roi_align_forward_kernel"; at::checkAllSameGPU(c, {input_t, rois_t}); at::checkAllSameType(c, {input_t, rois_t}); - + int64_t num_rois = rois.size(0); int64_t channels = input.size(1); int64_t height = input.size(2); int64_t width = input.size(3); float spatial_scale_f = static_cast(spatial_scale); - at::Tensor output = at::zeros( - {num_rois, channels, pooled_height, pooled_width}, input.options()); - + at::Tensor output = at::zeros({num_rois, channels, pooled_height, pooled_width}, input.options()); + int64_t output_size = num_rois * pooled_height * pooled_width * channels; if (output.numel() == 0) { @@ -57,7 +54,10 @@ dispatch_sync(mpsStream->queue(), ^() { @autoreleasepool { id computeEncoder = mpsStream->commandEncoder(); - MTLSize threadgroupsPerGrid = MTLSizeMake(std::min(ceil_div(static_cast(output_size), static_cast(512)), static_cast(4096)), 1, 1); + MTLSize threadgroupsPerGrid = MTLSizeMake( + std::min(ceil_div(static_cast(output_size), static_cast(512)), static_cast(4096)), + 1, + 1); const std::string kernel = "roi_align_" + scalarToMetalTypeString(input.scalar_type()); id visionPSO = mps::visionPipelineState(device, kernel); @@ -70,7 +70,7 @@ [computeEncoder setBuffer:inputBuffer offset:input_.storage_offset() * input_.element_size() atIndex:0]; [computeEncoder setBuffer:roisBuffer offset:rois_.storage_offset() * rois_.element_size() atIndex:1]; [computeEncoder setBuffer:outputBuffer offset:output.storage_offset() * output.element_size() atIndex:2]; - + [computeEncoder setBytes:&output_size length:sizeof(int64_t) atIndex:3]; [computeEncoder setBytes:&channels length:sizeof(int64_t) atIndex:4]; [computeEncoder setBytes:&height length:sizeof(int64_t) atIndex:5]; @@ -96,19 +96,17 @@ return output; } -at::Tensor roi_align_backward_kernel( - const at::Tensor& grad, - const at::Tensor& rois, - double spatial_scale, - int64_t pooled_height, - int64_t pooled_width, - int64_t batch_size, - int64_t channels, - int64_t height, - int64_t width, - int64_t sampling_ratio, - bool aligned) { - +at::Tensor roi_align_backward_kernel(const at::Tensor& grad, + const at::Tensor& rois, + double spatial_scale, + int64_t pooled_height, + int64_t pooled_width, + int64_t batch_size, + int64_t channels, + int64_t height, + int64_t width, + int64_t sampling_ratio, + bool aligned) { using namespace at::native::mps; 
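With both directions registered, the op can be exercised end to end from Python; the backward pass dispatches the roi_align_backward kernel being reformatted here. An illustrative usage sketch (not part of the patch; it falls back to CPU when MPS is unavailable):

import torch
from torchvision.ops import roi_align

device = "mps" if torch.backends.mps.is_available() else "cpu"
x = torch.rand(2, 16, 32, 32, device=device, requires_grad=True)
rois = torch.tensor([[0.0, 1.0, 1.0, 20.0, 20.0],
                     [1.0, 0.0, 0.0, 10.0, 15.0]], device=device)
out = roi_align(x, rois, output_size=(7, 7), spatial_scale=1.0,
                sampling_ratio=2, aligned=True)
out.sum().backward()  # exercises the registered roi_align backward kernel
print(out.shape, x.grad.shape)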
TORCH_CHECK(grad.is_mps(), "grad must be a MPS tensor"); TORCH_CHECK(rois.is_mps(), "rois must be a MPS tensor"); @@ -118,11 +116,10 @@ at::CheckedFrom c = "roi_align_backward_kernel"; at::checkAllSameGPU(c, {grad_t, rois_t}); at::checkAllSameType(c, {grad_t, rois_t}); - + float spatial_scale_f = static_cast(spatial_scale); - at::Tensor grad_input = at::zeros( - {batch_size, channels, height, width}, grad.options()); + at::Tensor grad_input = at::zeros({batch_size, channels, height, width}, grad.options()); if (grad.numel() == 0) { return grad_input; @@ -145,7 +142,10 @@ dispatch_sync(mpsStream->queue(), ^() { @autoreleasepool { id computeEncoder = mpsStream->commandEncoder(); - MTLSize threadgroupsPerGrid = MTLSizeMake(std::min(ceil_div(static_cast(grad.numel()), static_cast(512)), static_cast(4096)), 1, 1); + MTLSize threadgroupsPerGrid = MTLSizeMake( + std::min(ceil_div(static_cast(grad.numel()), static_cast(512)), static_cast(4096)), + 1, + 1); const std::string kernel = "roi_align_backward_" + scalarToMetalTypeString(grad.scalar_type()); id visionPSO = mps::visionPipelineState(device, kernel); @@ -158,7 +158,7 @@ [computeEncoder setBuffer:inputBuffer offset:grad.storage_offset() * grad.element_size() atIndex:0]; [computeEncoder setBuffer:roisBuffer offset:rois_.storage_offset() * rois_.element_size() atIndex:1]; [computeEncoder setBuffer:outputBuffer offset:grad_input.storage_offset() * grad_input.element_size() atIndex:2]; - + [computeEncoder setBytes:&output_size length:sizeof(int64_t) atIndex:3]; [computeEncoder setBytes:&channels length:sizeof(int64_t) atIndex:4]; [computeEncoder setBytes:&height length:sizeof(int64_t) atIndex:5]; @@ -191,12 +191,8 @@ } // namespace TORCH_LIBRARY_IMPL(torchvision, MPS, m) { - m.impl( - TORCH_SELECTIVE_NAME("torchvision::roi_align"), - TORCH_FN(roi_align_forward_kernel)); - m.impl( - TORCH_SELECTIVE_NAME("torchvision::_roi_align_backward"), - TORCH_FN(roi_align_backward_kernel)); + m.impl(TORCH_SELECTIVE_NAME("torchvision::roi_align"), TORCH_FN(roi_align_forward_kernel)); + m.impl(TORCH_SELECTIVE_NAME("torchvision::_roi_align_backward"), TORCH_FN(roi_align_backward_kernel)); } } // namespace ops diff --git a/torchvision/csrc/ops/mps/roi_pool_kernel.mm b/torchvision/csrc/ops/mps/roi_pool_kernel.mm index da1271bad51..aa01e4f4fc0 100644 --- a/torchvision/csrc/ops/mps/roi_pool_kernel.mm +++ b/torchvision/csrc/ops/mps/roi_pool_kernel.mm @@ -1,23 +1,21 @@ #include #include -#include "mps_kernels.h" #include "mps_helpers.h" +#include "mps_kernels.h" -#include #include +#include namespace vision { namespace ops { namespace { -std::tuple roi_pool_forward_kernel( - const at::Tensor& input, - const at::Tensor& rois, - double spatial_scale, - int64_t pooled_height, - int64_t pooled_width) { - +std::tuple roi_pool_forward_kernel(const at::Tensor& input, + const at::Tensor& rois, + double spatial_scale, + int64_t pooled_height, + int64_t pooled_width) { using namespace at::native::mps; TORCH_CHECK(input.is_mps(), "input must be a MPS tensor"); TORCH_CHECK(rois.is_mps(), "rois must be a MPS tensor"); @@ -28,19 +26,16 @@ at::CheckedFrom c = "roi_pool_forward_kernel"; at::checkAllSameGPU(c, {input_t, rois_t}); at::checkAllSameType(c, {input_t, rois_t}); - + int64_t num_rois = rois.size(0); int64_t channels = input.size(1); int64_t height = input.size(2); int64_t width = input.size(3); float spatial_scale_f = static_cast(spatial_scale); - at::Tensor output = at::zeros( - {num_rois, channels, pooled_height, pooled_width}, input.options()); - at::Tensor argmax = 
at::zeros( - {num_rois, channels, pooled_height, pooled_width}, - input.options().dtype(at::kLong)); - + at::Tensor output = at::zeros({num_rois, channels, pooled_height, pooled_width}, input.options()); + at::Tensor argmax = at::zeros({num_rois, channels, pooled_height, pooled_width}, input.options().dtype(at::kLong)); + int64_t output_size = num_rois * pooled_height * pooled_width * channels; if (output.numel() == 0) { @@ -59,7 +54,10 @@ dispatch_sync(mpsStream->queue(), ^() { @autoreleasepool { id computeEncoder = mpsStream->commandEncoder(); - MTLSize threadgroupsPerGrid = MTLSizeMake(std::min(ceil_div(static_cast(output_size), static_cast(512)), static_cast(4096)), 1, 1); + MTLSize threadgroupsPerGrid = MTLSizeMake( + std::min(ceil_div(static_cast(output_size), static_cast(512)), static_cast(4096)), + 1, + 1); const std::string kernel = "roi_pool_" + scalarToMetalTypeString(input.scalar_type()); id visionPSO = mps::visionPipelineState(device, kernel); @@ -73,7 +71,7 @@ [computeEncoder setBuffer:roisBuffer offset:rois_.storage_offset() * rois_.element_size() atIndex:1]; [computeEncoder setBuffer:outputBuffer offset:output.storage_offset() * output.element_size() atIndex:2]; [computeEncoder setBuffer:argmaxBuffer offset:argmax.storage_offset() * argmax.element_size() atIndex:3]; - + [computeEncoder setBytes:&output_size length:sizeof(int64_t) atIndex:4]; [computeEncoder setBytes:&channels length:sizeof(int64_t) atIndex:5]; [computeEncoder setBytes:&height length:sizeof(int64_t) atIndex:6]; @@ -97,18 +95,16 @@ return std::make_tuple(output, argmax); } -at::Tensor roi_pool_backward_kernel( - const at::Tensor& grad, - const at::Tensor& rois, - const at::Tensor& argmax, - double spatial_scale, - int64_t pooled_height, - int64_t pooled_width, - int64_t batch_size, - int64_t channels, - int64_t height, - int64_t width) { - +at::Tensor roi_pool_backward_kernel(const at::Tensor& grad, + const at::Tensor& rois, + const at::Tensor& argmax, + double spatial_scale, + int64_t pooled_height, + int64_t pooled_width, + int64_t batch_size, + int64_t channels, + int64_t height, + int64_t width) { using namespace at::native::mps; TORCH_CHECK(grad.is_mps(), "grad must be a MPS tensor"); TORCH_CHECK(rois.is_mps(), "rois must be a MPS tensor"); @@ -119,11 +115,10 @@ at::CheckedFrom c = "roi_pool_backward_kernel"; at::checkAllSameGPU(c, {grad_t, rois_t, argmax_t}); at::checkAllSameType(c, {grad_t, rois_t}); - + float spatial_scale_f = static_cast(spatial_scale); - at::Tensor grad_input = at::zeros( - {batch_size, channels, height, width}, grad.options()); + at::Tensor grad_input = at::zeros({batch_size, channels, height, width}, grad.options()); if (grad.numel() == 0) { return grad_input; @@ -147,7 +142,10 @@ dispatch_sync(mpsStream->queue(), ^() { @autoreleasepool { id computeEncoder = mpsStream->commandEncoder(); - MTLSize threadgroupsPerGrid = MTLSizeMake(std::min(ceil_div(static_cast(grad.numel()), static_cast(512)), static_cast(4096)), 1, 1); + MTLSize threadgroupsPerGrid = MTLSizeMake( + std::min(ceil_div(static_cast(grad.numel()), static_cast(512)), static_cast(4096)), + 1, + 1); const std::string kernel = "roi_pool_backward_" + scalarToMetalTypeString(grad.scalar_type()); id visionPSO = mps::visionPipelineState(device, kernel); @@ -161,7 +159,7 @@ [computeEncoder setBuffer:roisBuffer offset:rois_.storage_offset() * rois_.element_size() atIndex:1]; [computeEncoder setBuffer:argmaxBuffer offset:argmax_.storage_offset() * argmax_.element_size() atIndex:2]; [computeEncoder setBuffer:outputBuffer 
offset:grad_input.storage_offset() * grad_input.element_size() atIndex:3]; - + [computeEncoder setBytes:&output_size length:sizeof(int64_t) atIndex:4]; [computeEncoder setBytes:&channels length:sizeof(int64_t) atIndex:5]; [computeEncoder setBytes:&height length:sizeof(int64_t) atIndex:6]; @@ -192,12 +190,8 @@ } // namespace TORCH_LIBRARY_IMPL(torchvision, MPS, m) { - m.impl( - TORCH_SELECTIVE_NAME("torchvision::roi_pool"), - TORCH_FN(roi_pool_forward_kernel)); - m.impl( - TORCH_SELECTIVE_NAME("torchvision::_roi_pool_backward"), - TORCH_FN(roi_pool_backward_kernel)); + m.impl(TORCH_SELECTIVE_NAME("torchvision::roi_pool"), TORCH_FN(roi_pool_forward_kernel)); + m.impl(TORCH_SELECTIVE_NAME("torchvision::_roi_pool_backward"), TORCH_FN(roi_pool_backward_kernel)); } } // namespace ops From d609da45587a4d3ee5ea8a56ecbf6bd39aabb615 Mon Sep 17 00:00:00 2001 From: "Li-Huai (Allan) Lin" Date: Tue, 20 Jun 2023 18:05:29 +0800 Subject: [PATCH 17/28] Testing --- test/common_utils.py | 4 ++- test/conftest.py | 10 +++++- test/test_ops.py | 10 ++++-- torchvision/csrc/ops/mps/mps_kernels.h | 49 ++++++++++++++++---------- 4 files changed, 51 insertions(+), 22 deletions(-) diff --git a/test/common_utils.py b/test/common_utils.py index 91d8a59b9a3..75105adfe34 100644 --- a/test/common_utils.py +++ b/test/common_utils.py @@ -32,7 +32,9 @@ IN_RE_WORKER = os.environ.get("INSIDE_RE_WORKER") is not None IN_FBCODE = os.environ.get("IN_FBCODE_TORCHVISION") == "1" CUDA_NOT_AVAILABLE_MSG = "CUDA device not available" +MPS_NOT_AVAILABLE_MSG = "MPS device not available" OSS_CI_GPU_NO_CUDA_MSG = "We're in an OSS GPU machine, and this test doesn't need cuda." +OSS_CI_GPU_NO_MPS_MSG = "We're in an OSS M1 machine, and this test doesn't need mps." @contextlib.contextmanager @@ -139,7 +141,7 @@ def needs_cuda(test_func): def needs_mps(test_func): - import pytest + import pytest # noqa return pytest.mark.needs_mps(test_func) diff --git a/test/conftest.py b/test/conftest.py index a9e8f1cda52..94670c1ae35 100644 --- a/test/conftest.py +++ b/test/conftest.py @@ -8,12 +8,13 @@ torchvision.disable_beta_transforms_warning() -from common_utils import CUDA_NOT_AVAILABLE_MSG, IN_FBCODE, IN_OSS_CI, IN_RE_WORKER, OSS_CI_GPU_NO_CUDA_MSG +from common_utils import CUDA_NOT_AVAILABLE_MSG, MPS_NOT_AVAILABLE_MSG, IN_FBCODE, IN_OSS_CI, IN_RE_WORKER, OSS_CI_GPU_NO_CUDA_MSG, OSS_CI_GPU_NO_MPS_MSG def pytest_configure(config): # register an additional marker (see pytest_collection_modifyitems) config.addinivalue_line("markers", "needs_cuda: mark for tests that rely on a CUDA device") + config.addinivalue_line("markers", "needs_mps: mark for tests that rely on a MPS device") config.addinivalue_line("markers", "dont_collect: mark for tests that should not be collected") @@ -37,11 +38,16 @@ def pytest_collection_modifyitems(items): # the "instances" of the tests where device == 'cuda' will have the 'needs_cuda' mark, # and the ones with device == 'cpu' won't have the mark. 
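The comment above describes how parametrized device values carry marks; the MPS path follows the same route. A device produced by the cpu_and_gpu_and_mps() helper (renamed cpu_and_cuda_and_mps in a later patch) is tagged with pytest.mark.needs_mps, and the hook below skips it with MPS_NOT_AVAILABLE_MSG when no MPS device is present. A condensed, illustrative sketch of a test that rides this mechanism:

import pytest
import torch
from common_utils import cpu_and_cuda_and_mps  # helper added in this series

@pytest.mark.parametrize("device", cpu_and_cuda_and_mps())
def test_dummy(device):
    # The "mps" instance of this test carries pytest.mark.needs_mps, so
    # pytest_collection_modifyitems skips it unless
    # torch.backends.mps.is_available() is True.
    assert torch.zeros(1, device=device).device.type == device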
needs_cuda = item.get_closest_marker("needs_cuda") is not None + needs_mps = item.get_closest_marker("needs_mps") is not None + if needs_cuda and not torch.cuda.is_available(): # In general, we skip cuda tests on machines without a GPU # There are special cases though, see below item.add_marker(pytest.mark.skip(reason=CUDA_NOT_AVAILABLE_MSG)) + + if needs_mps and not torch.backends.mps.is_available(): + item.add_marker(pytest.mark.skip(reason=MPS_NOT_AVAILABLE_MSG)) if IN_FBCODE: # fbcode doesn't like skipping tests, so instead we just don't collect the test @@ -60,6 +66,8 @@ def pytest_collection_modifyitems(items): # Similar to what happens in RE workers: we don't need the OSS CI GPU machines # to run the CPU-only tests. item.add_marker(pytest.mark.skip(reason=OSS_CI_GPU_NO_CUDA_MSG)) + if not needs_mps and torch.backends.mps.is_available(): + item.add_marker(pytest.mark.skip(reason=OSS_CI_GPU_NO_MPS_MSG)) if item.get_closest_marker("dont_collect") is not None: # currently, this is only used for some tests we're sure we don't want to run on fbcode diff --git a/test/test_ops.py b/test/test_ops.py index 940ad8aee3b..f41923d55ef 100644 --- a/test/test_ops.py +++ b/test/test_ops.py @@ -97,6 +97,7 @@ def forward(self, imgs: Tensor, boxes: List[Tensor]) -> Tensor: class RoIOpTester(ABC): dtype = torch.float64 mps_dtype = torch.float32 + mps_backward_atol = 2e-2 @pytest.mark.parametrize("device", cpu_and_gpu_and_mps()) @pytest.mark.parametrize("contiguous", (True, False)) @@ -160,7 +161,7 @@ def test_torch_fx_trace(self, device, x_dtype=torch.float, rois_dtype=torch.floa @pytest.mark.parametrize("device", cpu_and_gpu_and_mps()) @pytest.mark.parametrize("contiguous", (True, False)) def test_backward(self, seed, device, contiguous, deterministic=False): - atol = 5e-2 if device == "mps" else None + atol = self.mps_backward_atol if device == "mps" else 1e-05 dtype = self.mps_dtype if device == "mps" else self.dtype torch.random.manual_seed(seed) @@ -276,6 +277,8 @@ def test_jit_boxes_list(self): class TestPSRoIPool(RoIOpTester): + mps_backward_atol = 5e-2 + def fn(self, x, rois, pool_h, pool_w, spatial_scale=1, sampling_ratio=-1, **kwargs): return ops.PSRoIPool((pool_h, pool_w), 1)(x, rois) @@ -357,6 +360,8 @@ def bilinear_interpolate(data, y, x, snap_border=False): class TestRoIAlign(RoIOpTester): + mps_backward_atol = 6e-2 + def fn(self, x, rois, pool_h, pool_w, spatial_scale=1, sampling_ratio=-1, aligned=False, **kwargs): return ops.RoIAlign( (pool_h, pool_w), spatial_scale=spatial_scale, sampling_ratio=sampling_ratio, aligned=aligned @@ -542,6 +547,8 @@ def test_jit_boxes_list(self): class TestPSRoIAlign(RoIOpTester): + mps_backward_atol = 5e-2 + def fn(self, x, rois, pool_h, pool_w, spatial_scale=1, sampling_ratio=-1, **kwargs): return ops.PSRoIAlign((pool_h, pool_w), spatial_scale=spatial_scale, sampling_ratio=sampling_ratio)(x, rois) @@ -769,7 +776,6 @@ def test_nms_cuda_float16(self): assert_equal(keep32, keep16) @needs_mps - @pytest.mark.xfail def test_nms_mps_float16(self): boxes = torch.tensor( [ diff --git a/torchvision/csrc/ops/mps/mps_kernels.h b/torchvision/csrc/ops/mps/mps_kernels.h index 8aaf55df15f..c6d5888a6e8 100644 --- a/torchvision/csrc/ops/mps/mps_kernels.h +++ b/torchvision/csrc/ops/mps/mps_kernels.h @@ -13,7 +13,7 @@ using namespace metal; /*----------Macros----------*/ -#define MPS_1D_KERNEL_LOOP_T(i, n, n_tgs, index_t) \ +#define MPS_1D_KERNEL_LOOP_T(i, n, n_tgs, index_t) \ for (index_t i = (tgid.x * tptg.x) + tid2.x; i < (n); \ i += (tptg.x * n_tgs)) @@ -27,28 +27,40 @@ 
inline T ceil_div(T n, T m) { } template -void atomic_add_float( device T* data_ptr, const float val) +void atomic_add_float( device T* data_ptr, const T val) { #if __METAL_VERSION__ >= 300 + // atomic_float is supported in Metal 3 (macOS Ventura) onward. device atomic_fetch_add_explicit((device atomic_float*) data_ptr, val, memory_order_relaxed); #else // https://github.com/ShoYamanishi/AppleNumericalComputing/blob/053f06c1f5a831095c4bcc29aaf11366fce5231e/03_dot/metal/dot.metal#L447-L472 + // https://forums.developer.nvidia.com/t/atomicadd-float-float-atomicmul-float-float/14639 + // Create an atomic uint pointer for atomic checking. device atomic_uint* atom_var = (device atomic_uint*)data_ptr; + // Create necessary storage. uint fetched_uint, assigning_uint; - float fetched_float, assigning_float; + T fetched_float, assigning_float; - fetched_uint = atomic_exchange_explicit( atom_var, 0, memory_order_relaxed ); - fetched_float = *( (thread float*) &fetched_uint ); + // Replace the value in atom_var with 0 and return the previous value in atom_var. + fetched_uint = atomic_exchange_explicit( atom_var, 0 /*desired*/, memory_order_relaxed); + // Read out the previous value as float. + fetched_float = *( (thread T*) &fetched_uint ); + // Do addition and represent the addition result in uint for atomic checking. assigning_float = fetched_float + val; - assigning_uint = *( (thread uint*) &assigning_float ); - - while ((fetched_uint = atomic_exchange_explicit( atom_var, assigning_uint, memory_order_relaxed)) != 0) { - uint fetched_uint_again = atomic_exchange_explicit( atom_var, 0, memory_order_relaxed); - float fetched_float_again = *( (thread float*) &fetched_uint_again ); - fetched_float = *( (thread float*) &(fetched_uint) ); + assigning_uint = *((thread uint*) &assigning_float); + + // atom_var should be 0 now, try to assign the addition result back to the atom_var (data_ptr). + while ((fetched_uint = atomic_exchange_explicit( atom_var, assigning_uint /*desired*/, memory_order_relaxed)) != 0) { + // If atom_var is not 0, i.e. fetched_uint != 0, it means that the data has been modified by other threads. + // Try to assign 0 and get the previous assigned addition result. + uint fetched_uint_again = atomic_exchange_explicit(atom_var, 0 /*desired*/, memory_order_relaxed); + T fetched_float_again = *( (thread T*) &fetched_uint_again ); + // Re-add again + fetched_float = *((thread T*) &(fetched_uint)); + // Previous assigned addition result + addition result from other threads. assigning_float = fetched_float_again + fetched_float; - assigning_uint = *( (thread uint*) &assigning_float ); + assigning_uint = *( (thread uint*) &assigning_float); } #endif } @@ -177,9 +189,10 @@ bool inline IoU( auto yy2 = min(a.w, b.w); auto w = max(static_cast(0), xx2 - xx1); auto h = max(static_cast(0), yy2 - yy1); - auto inter = w * h; - auto area_a = (a.z - a.x) * (a.w - a.y); - auto area_b = (b.z - b.x) * (b.w - b.y); + // Upcast to float before multiplications to circumvent precision issues in half. 
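The comment above is about float16's limited range and precision: box areas routinely exceed what half can represent, so the multiplications in the added lines below are done in float before the IoU division. A quick illustration in plain PyTorch (not part of the patch):

import torch

w, h = torch.tensor(300.0), torch.tensor(400.0)
print((w.half() * h.half()).item())    # inf -- 120000 overflows float16 (max ~65504)
print((w.float() * h.float()).item())  # 120000.0 -- exact in float32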
+ auto inter = static_cast(w) * static_cast(h); + auto area_b = static_cast(b.z - b.x) * static_cast(b.w - b.y); + auto area_a = static_cast(a.z - a.x) * static_cast(a.w - a.y); return (inter / (area_a + area_b - inter)) > threshold; } @@ -1081,6 +1094,6 @@ static id visionPipelineState(id device, con return pso; } -} -} -} // namespace \ No newline at end of file +} // namespace mps +} // namespace ops +} // namespace vision From 256bd569b0dbc8fce8d04e416eb0e2d3ff2ce979 Mon Sep 17 00:00:00 2001 From: "Li-Huai (Allan) Lin" Date: Tue, 20 Jun 2023 18:08:28 +0800 Subject: [PATCH 18/28] Rename cpu_and_gpu to cpu_and_cuda --- test/common_utils.py | 6 +- test/conftest.py | 15 +++-- test/test_functional_tensor.py | 90 +++++++++++++-------------- test/test_models.py | 12 ++-- test/test_ops.py | 42 ++++++------- test/test_prototype_models.py | 6 +- test/test_transforms_tensor.py | 60 +++++++++--------- test/test_transforms_v2.py | 6 +- test/test_transforms_v2_functional.py | 50 +++++++-------- 9 files changed, 147 insertions(+), 140 deletions(-) diff --git a/test/common_utils.py b/test/common_utils.py index 75105adfe34..f83e7ca5d99 100644 --- a/test/common_utils.py +++ b/test/common_utils.py @@ -124,14 +124,14 @@ def disable_console_output(): yield -def cpu_and_gpu(): +def cpu_and_cuda(): import pytest # noqa return ("cpu", pytest.param("cuda", marks=pytest.mark.needs_cuda)) -def cpu_and_gpu_and_mps(): - return cpu_and_gpu() + (pytest.param("mps", marks=pytest.mark.needs_mps),) +def cpu_and_cuda_and_mps(): + return cpu_and_cuda() + (pytest.param("mps", marks=pytest.mark.needs_mps),) def needs_cuda(test_func): diff --git a/test/conftest.py b/test/conftest.py index 94670c1ae35..819ca7ce229 100644 --- a/test/conftest.py +++ b/test/conftest.py @@ -8,7 +8,15 @@ torchvision.disable_beta_transforms_warning() -from common_utils import CUDA_NOT_AVAILABLE_MSG, MPS_NOT_AVAILABLE_MSG, IN_FBCODE, IN_OSS_CI, IN_RE_WORKER, OSS_CI_GPU_NO_CUDA_MSG, OSS_CI_GPU_NO_MPS_MSG +from common_utils import ( + CUDA_NOT_AVAILABLE_MSG, + IN_FBCODE, + IN_OSS_CI, + IN_RE_WORKER, + MPS_NOT_AVAILABLE_MSG, + OSS_CI_GPU_NO_CUDA_MSG, + OSS_CI_GPU_NO_MPS_MSG, +) def pytest_configure(config): @@ -34,18 +42,17 @@ def pytest_collection_modifyitems(items): # The needs_cuda mark will exist if the test was explicitly decorated with # the @needs_cuda decorator. It will also exist if it was parametrized with a # parameter that has the mark: for example if a test is parametrized with - # @pytest.mark.parametrize('device', cpu_and_gpu()) + # @pytest.mark.parametrize('device', cpu_and_cuda()) # the "instances" of the tests where device == 'cuda' will have the 'needs_cuda' mark, # and the ones with device == 'cpu' won't have the mark. 
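With the helpers above in place, a test only has to parametrize over them for the device marks to be attached; the collection hook below then turns a missing backend into a skip. A minimal sketch of such a test (the op and shapes here are illustrative, not taken from the patch):

    import pytest
    import torch
    import torchvision.ops as ops
    from common_utils import cpu_and_cuda_and_mps

    @pytest.mark.parametrize("device", cpu_and_cuda_and_mps())
    def test_roi_align_smoke(device):
        # The "mps" instance of this test carries pytest.mark.needs_mps, so the
        # hook skips it when torch.backends.mps.is_available() is False.
        x = torch.rand(1, 1, 8, 8, device=device)
        rois = torch.tensor([[0.0, 0.0, 0.0, 4.0, 4.0]], device=device)
        out = ops.roi_align(x, rois, output_size=(2, 2))
        assert out.shape == (1, 1, 2, 2)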
needs_cuda = item.get_closest_marker("needs_cuda") is not None needs_mps = item.get_closest_marker("needs_mps") is not None - if needs_cuda and not torch.cuda.is_available(): # In general, we skip cuda tests on machines without a GPU # There are special cases though, see below item.add_marker(pytest.mark.skip(reason=CUDA_NOT_AVAILABLE_MSG)) - + if needs_mps and not torch.backends.mps.is_available(): item.add_marker(pytest.mark.skip(reason=MPS_NOT_AVAILABLE_MSG)) diff --git a/test/test_functional_tensor.py b/test/test_functional_tensor.py index de9d10d6bde..43f54e6f107 100644 --- a/test/test_functional_tensor.py +++ b/test/test_functional_tensor.py @@ -21,7 +21,7 @@ _create_data_batch, _test_fn_on_batch, assert_equal, - cpu_and_gpu, + cpu_and_cuda, needs_cuda, ) from torchvision.transforms import InterpolationMode @@ -34,7 +34,7 @@ ) -@pytest.mark.parametrize("device", cpu_and_gpu()) +@pytest.mark.parametrize("device", cpu_and_cuda()) @pytest.mark.parametrize("fn", [F.get_image_size, F.get_image_num_channels, F.get_dimensions]) def test_image_sizes(device, fn): script_F = torch.jit.script(fn) @@ -72,7 +72,7 @@ class TestRotate: scripted_rotate = torch.jit.script(F.rotate) IMG_W = 26 - @pytest.mark.parametrize("device", cpu_and_gpu()) + @pytest.mark.parametrize("device", cpu_and_cuda()) @pytest.mark.parametrize("height, width", [(7, 33), (26, IMG_W), (32, IMG_W)]) @pytest.mark.parametrize( "center", @@ -131,7 +131,7 @@ def test_rotate(self, device, height, width, center, dt, angle, expand, fill, fn f"{out_pil_tensor[0, :7, :7]}" ) - @pytest.mark.parametrize("device", cpu_and_gpu()) + @pytest.mark.parametrize("device", cpu_and_cuda()) @pytest.mark.parametrize("dt", ALL_DTYPES) def test_rotate_batch(self, device, dt): if dt == torch.float16 and device == "cpu": @@ -157,7 +157,7 @@ class TestAffine: ALL_DTYPES = [None, torch.float32, torch.float64, torch.float16] scripted_affine = torch.jit.script(F.affine) - @pytest.mark.parametrize("device", cpu_and_gpu()) + @pytest.mark.parametrize("device", cpu_and_cuda()) @pytest.mark.parametrize("height, width", [(26, 26), (32, 26)]) @pytest.mark.parametrize("dt", ALL_DTYPES) def test_identity_map(self, device, height, width, dt): @@ -180,7 +180,7 @@ def test_identity_map(self, device, height, width, dt): ) assert_equal(tensor, out_tensor, msg=f"{out_tensor[0, :5, :5]} vs {tensor[0, :5, :5]}") - @pytest.mark.parametrize("device", cpu_and_gpu()) + @pytest.mark.parametrize("device", cpu_and_cuda()) @pytest.mark.parametrize("height, width", [(26, 26)]) @pytest.mark.parametrize("dt", ALL_DTYPES) @pytest.mark.parametrize( @@ -224,7 +224,7 @@ def test_square_rotations(self, device, height, width, dt, angle, config, fn): # Tolerance : less than 6% of different pixels assert ratio_diff_pixels < 0.06 - @pytest.mark.parametrize("device", cpu_and_gpu()) + @pytest.mark.parametrize("device", cpu_and_cuda()) @pytest.mark.parametrize("height, width", [(32, 26)]) @pytest.mark.parametrize("dt", ALL_DTYPES) @pytest.mark.parametrize("angle", [90, 45, 15, -30, -60, -120]) @@ -258,7 +258,7 @@ def test_rect_rotations(self, device, height, width, dt, angle, fn, center): # Tolerance : less than 3% of different pixels assert ratio_diff_pixels < 0.03 - @pytest.mark.parametrize("device", cpu_and_gpu()) + @pytest.mark.parametrize("device", cpu_and_cuda()) @pytest.mark.parametrize("height, width", [(26, 26), (32, 26)]) @pytest.mark.parametrize("dt", ALL_DTYPES) @pytest.mark.parametrize("t", [[10, 12], (-12, -13)]) @@ -283,7 +283,7 @@ def test_translations(self, device, height, width, 
dt, t, fn): _assert_equal_tensor_to_pil(out_tensor, out_pil_img) - @pytest.mark.parametrize("device", cpu_and_gpu()) + @pytest.mark.parametrize("device", cpu_and_cuda()) @pytest.mark.parametrize("height, width", [(26, 26), (32, 26)]) @pytest.mark.parametrize("dt", ALL_DTYPES) @pytest.mark.parametrize( @@ -344,7 +344,7 @@ def test_all_ops(self, device, height, width, dt, a, t, s, sh, f, fn): tol = 0.06 if device == "cuda" else 0.05 assert ratio_diff_pixels < tol - @pytest.mark.parametrize("device", cpu_and_gpu()) + @pytest.mark.parametrize("device", cpu_and_cuda()) @pytest.mark.parametrize("dt", ALL_DTYPES) def test_batches(self, device, dt): if dt == torch.float16 and device == "cpu": @@ -357,7 +357,7 @@ def test_batches(self, device, dt): _test_fn_on_batch(batch_tensors, F.affine, angle=-43, translate=[-3, 4], scale=1.2, shear=[4.0, 5.0]) - @pytest.mark.parametrize("device", cpu_and_gpu()) + @pytest.mark.parametrize("device", cpu_and_cuda()) def test_interpolation_type(self, device): tensor, pil_img = _create_data(26, 26, device=device) @@ -389,7 +389,7 @@ def _get_data_dims_and_points_for_perspective(): return dims_and_points -@pytest.mark.parametrize("device", cpu_and_gpu()) +@pytest.mark.parametrize("device", cpu_and_cuda()) @pytest.mark.parametrize("dims_and_points", _get_data_dims_and_points_for_perspective()) @pytest.mark.parametrize("dt", [None, torch.float32, torch.float64, torch.float16]) @pytest.mark.parametrize( @@ -435,7 +435,7 @@ def test_perspective_pil_vs_tensor(device, dims_and_points, dt, fill, fn): assert ratio_diff_pixels < 0.05 -@pytest.mark.parametrize("device", cpu_and_gpu()) +@pytest.mark.parametrize("device", cpu_and_cuda()) @pytest.mark.parametrize("dims_and_points", _get_data_dims_and_points_for_perspective()) @pytest.mark.parametrize("dt", [None, torch.float32, torch.float64, torch.float16]) def test_perspective_batch(device, dims_and_points, dt): @@ -473,7 +473,7 @@ def test_perspective_interpolation_type(): assert_equal(res1, res2) -@pytest.mark.parametrize("device", cpu_and_gpu()) +@pytest.mark.parametrize("device", cpu_and_cuda()) @pytest.mark.parametrize("dt", [None, torch.float32, torch.float64, torch.float16]) @pytest.mark.parametrize( "size", @@ -539,7 +539,7 @@ def test_resize(device, dt, size, max_size, interpolation): ) -@pytest.mark.parametrize("device", cpu_and_gpu()) +@pytest.mark.parametrize("device", cpu_and_cuda()) def test_resize_asserts(device): tensor, pil_img = _create_data(26, 36, device=device) @@ -556,7 +556,7 @@ def test_resize_asserts(device): F.resize(img, size=32, max_size=32) -@pytest.mark.parametrize("device", cpu_and_gpu()) +@pytest.mark.parametrize("device", cpu_and_cuda()) @pytest.mark.parametrize("dt", [None, torch.float32, torch.float64, torch.float16]) @pytest.mark.parametrize("size", [[96, 72], [96, 420], [420, 72]]) @pytest.mark.parametrize("interpolation", [BILINEAR, BICUBIC]) @@ -663,7 +663,7 @@ def check_functional_vs_PIL_vs_scripted( _test_fn_on_batch(batch_tensors, fn, scripted_fn_atol=atol, **config) -@pytest.mark.parametrize("device", cpu_and_gpu()) +@pytest.mark.parametrize("device", cpu_and_cuda()) @pytest.mark.parametrize("dtype", (None, torch.float32, torch.float64)) @pytest.mark.parametrize("config", [{"brightness_factor": f} for f in (0.1, 0.5, 1.0, 1.34, 2.5)]) @pytest.mark.parametrize("channels", [1, 3]) @@ -679,7 +679,7 @@ def test_adjust_brightness(device, dtype, config, channels): ) -@pytest.mark.parametrize("device", cpu_and_gpu()) +@pytest.mark.parametrize("device", cpu_and_cuda()) 
@pytest.mark.parametrize("dtype", (None, torch.float32, torch.float64)) @pytest.mark.parametrize("channels", [1, 3]) def test_invert(device, dtype, channels): @@ -688,7 +688,7 @@ def test_invert(device, dtype, channels): ) -@pytest.mark.parametrize("device", cpu_and_gpu()) +@pytest.mark.parametrize("device", cpu_and_cuda()) @pytest.mark.parametrize("config", [{"bits": bits} for bits in range(0, 8)]) @pytest.mark.parametrize("channels", [1, 3]) def test_posterize(device, config, channels): @@ -705,7 +705,7 @@ def test_posterize(device, config, channels): ) -@pytest.mark.parametrize("device", cpu_and_gpu()) +@pytest.mark.parametrize("device", cpu_and_cuda()) @pytest.mark.parametrize("config", [{"threshold": threshold} for threshold in [0, 64, 128, 192, 255]]) @pytest.mark.parametrize("channels", [1, 3]) def test_solarize1(device, config, channels): @@ -722,7 +722,7 @@ def test_solarize1(device, config, channels): ) -@pytest.mark.parametrize("device", cpu_and_gpu()) +@pytest.mark.parametrize("device", cpu_and_cuda()) @pytest.mark.parametrize("dtype", (torch.float32, torch.float64)) @pytest.mark.parametrize("config", [{"threshold": threshold} for threshold in [0.0, 0.25, 0.5, 0.75, 1.0]]) @pytest.mark.parametrize("channels", [1, 3]) @@ -754,7 +754,7 @@ def test_solarize2(device, dtype, config, channels): *[(torch.int64, threshold) for threshold in [0, 2**32, 2**63 - 1]], ], ) -@pytest.mark.parametrize("device", cpu_and_gpu()) +@pytest.mark.parametrize("device", cpu_and_cuda()) def test_solarize_threshold_within_bound(threshold, dtype, device): make_img = torch.rand if dtype.is_floating_point else partial(torch.randint, 0, torch.iinfo(dtype).max) img = make_img((3, 12, 23), dtype=dtype, device=device) @@ -770,7 +770,7 @@ def test_solarize_threshold_within_bound(threshold, dtype, device): (torch.int64, 2**64), ], ) -@pytest.mark.parametrize("device", cpu_and_gpu()) +@pytest.mark.parametrize("device", cpu_and_cuda()) def test_solarize_threshold_above_bound(threshold, dtype, device): make_img = torch.rand if dtype.is_floating_point else partial(torch.randint, 0, torch.iinfo(dtype).max) img = make_img((3, 12, 23), dtype=dtype, device=device) @@ -778,7 +778,7 @@ def test_solarize_threshold_above_bound(threshold, dtype, device): F_t.solarize(img, threshold) -@pytest.mark.parametrize("device", cpu_and_gpu()) +@pytest.mark.parametrize("device", cpu_and_cuda()) @pytest.mark.parametrize("dtype", (None, torch.float32, torch.float64)) @pytest.mark.parametrize("config", [{"sharpness_factor": f} for f in [0.2, 0.5, 1.0, 1.5, 2.0]]) @pytest.mark.parametrize("channels", [1, 3]) @@ -794,7 +794,7 @@ def test_adjust_sharpness(device, dtype, config, channels): ) -@pytest.mark.parametrize("device", cpu_and_gpu()) +@pytest.mark.parametrize("device", cpu_and_cuda()) @pytest.mark.parametrize("dtype", (None, torch.float32, torch.float64)) @pytest.mark.parametrize("channels", [1, 3]) def test_autocontrast(device, dtype, channels): @@ -803,7 +803,7 @@ def test_autocontrast(device, dtype, channels): ) -@pytest.mark.parametrize("device", cpu_and_gpu()) +@pytest.mark.parametrize("device", cpu_and_cuda()) @pytest.mark.parametrize("dtype", (None, torch.float32, torch.float64)) @pytest.mark.parametrize("channels", [1, 3]) def test_autocontrast_equal_minmax(device, dtype, channels): @@ -815,7 +815,7 @@ def test_autocontrast_equal_minmax(device, dtype, channels): assert (F.autocontrast(a)[0] == F.autocontrast(a[0])).all() -@pytest.mark.parametrize("device", cpu_and_gpu()) +@pytest.mark.parametrize("device", cpu_and_cuda()) 
@pytest.mark.parametrize("channels", [1, 3]) def test_equalize(device, channels): torch.use_deterministic_algorithms(False) @@ -832,7 +832,7 @@ def test_equalize(device, channels): ) -@pytest.mark.parametrize("device", cpu_and_gpu()) +@pytest.mark.parametrize("device", cpu_and_cuda()) @pytest.mark.parametrize("dtype", (None, torch.float32, torch.float64)) @pytest.mark.parametrize("config", [{"contrast_factor": f} for f in [0.2, 0.5, 1.0, 1.5, 2.0]]) @pytest.mark.parametrize("channels", [1, 3]) @@ -842,7 +842,7 @@ def test_adjust_contrast(device, dtype, config, channels): ) -@pytest.mark.parametrize("device", cpu_and_gpu()) +@pytest.mark.parametrize("device", cpu_and_cuda()) @pytest.mark.parametrize("dtype", (None, torch.float32, torch.float64)) @pytest.mark.parametrize("config", [{"saturation_factor": f} for f in [0.5, 0.75, 1.0, 1.5, 2.0]]) @pytest.mark.parametrize("channels", [1, 3]) @@ -852,7 +852,7 @@ def test_adjust_saturation(device, dtype, config, channels): ) -@pytest.mark.parametrize("device", cpu_and_gpu()) +@pytest.mark.parametrize("device", cpu_and_cuda()) @pytest.mark.parametrize("dtype", (None, torch.float32, torch.float64)) @pytest.mark.parametrize("config", [{"hue_factor": f} for f in [-0.45, -0.25, 0.0, 0.25, 0.45]]) @pytest.mark.parametrize("channels", [1, 3]) @@ -862,7 +862,7 @@ def test_adjust_hue(device, dtype, config, channels): ) -@pytest.mark.parametrize("device", cpu_and_gpu()) +@pytest.mark.parametrize("device", cpu_and_cuda()) @pytest.mark.parametrize("dtype", (None, torch.float32, torch.float64)) @pytest.mark.parametrize("config", [{"gamma": g1, "gain": g2} for g1, g2 in zip([0.8, 1.0, 1.2], [0.7, 1.0, 1.3])]) @pytest.mark.parametrize("channels", [1, 3]) @@ -878,7 +878,7 @@ def test_adjust_gamma(device, dtype, config, channels): ) -@pytest.mark.parametrize("device", cpu_and_gpu()) +@pytest.mark.parametrize("device", cpu_and_cuda()) @pytest.mark.parametrize("dt", [None, torch.float32, torch.float64, torch.float16]) @pytest.mark.parametrize("pad", [2, [3], [0, 3], (3, 3), [4, 2, 4, 3]]) @pytest.mark.parametrize( @@ -928,7 +928,7 @@ def test_pad(device, dt, pad, config): _test_fn_on_batch(batch_tensors, F.pad, padding=script_pad, **config) -@pytest.mark.parametrize("device", cpu_and_gpu()) +@pytest.mark.parametrize("device", cpu_and_cuda()) @pytest.mark.parametrize("mode", [NEAREST, NEAREST_EXACT, BILINEAR, BICUBIC]) def test_resized_crop(device, mode): # test values of F.resized_crop in several cases: @@ -963,7 +963,7 @@ def test_resized_crop(device, mode): ) -@pytest.mark.parametrize("device", cpu_and_gpu()) +@pytest.mark.parametrize("device", cpu_and_cuda()) @pytest.mark.parametrize( "func, args", [ @@ -996,7 +996,7 @@ def test_assert_image_tensor(device, func, args): func(tensor, *args) -@pytest.mark.parametrize("device", cpu_and_gpu()) +@pytest.mark.parametrize("device", cpu_and_cuda()) def test_vflip(device): script_vflip = torch.jit.script(F.vflip) @@ -1013,7 +1013,7 @@ def test_vflip(device): _test_fn_on_batch(batch_tensors, F.vflip) -@pytest.mark.parametrize("device", cpu_and_gpu()) +@pytest.mark.parametrize("device", cpu_and_cuda()) def test_hflip(device): script_hflip = torch.jit.script(F.hflip) @@ -1030,7 +1030,7 @@ def test_hflip(device): _test_fn_on_batch(batch_tensors, F.hflip) -@pytest.mark.parametrize("device", cpu_and_gpu()) +@pytest.mark.parametrize("device", cpu_and_cuda()) @pytest.mark.parametrize( "top, left, height, width", [ @@ -1059,7 +1059,7 @@ def test_crop(device, top, left, height, width): _test_fn_on_batch(batch_tensors, F.crop, 
top=top, left=left, height=height, width=width) -@pytest.mark.parametrize("device", cpu_and_gpu()) +@pytest.mark.parametrize("device", cpu_and_cuda()) @pytest.mark.parametrize("image_size", ("small", "large")) @pytest.mark.parametrize("dt", [None, torch.float32, torch.float64, torch.float16]) @pytest.mark.parametrize("ksize", [(3, 3), [3, 5], (23, 23)]) @@ -1113,7 +1113,7 @@ def test_gaussian_blur(device, image_size, dt, ksize, sigma, fn): torch.testing.assert_close(out, true_out, rtol=0.0, atol=1.0, msg=f"{ksize}, {sigma}") -@pytest.mark.parametrize("device", cpu_and_gpu()) +@pytest.mark.parametrize("device", cpu_and_cuda()) def test_hsv2rgb(device): scripted_fn = torch.jit.script(F_t._hsv2rgb) shape = (3, 100, 150) @@ -1144,7 +1144,7 @@ def test_hsv2rgb(device): _test_fn_on_batch(batch_tensors, F_t._hsv2rgb) -@pytest.mark.parametrize("device", cpu_and_gpu()) +@pytest.mark.parametrize("device", cpu_and_cuda()) def test_rgb2hsv(device): scripted_fn = torch.jit.script(F_t._rgb2hsv) shape = (3, 150, 100) @@ -1183,7 +1183,7 @@ def test_rgb2hsv(device): _test_fn_on_batch(batch_tensors, F_t._rgb2hsv) -@pytest.mark.parametrize("device", cpu_and_gpu()) +@pytest.mark.parametrize("device", cpu_and_cuda()) @pytest.mark.parametrize("num_output_channels", (3, 1)) def test_rgb_to_grayscale(device, num_output_channels): script_rgb_to_grayscale = torch.jit.script(F.rgb_to_grayscale) @@ -1202,7 +1202,7 @@ def test_rgb_to_grayscale(device, num_output_channels): _test_fn_on_batch(batch_tensors, F.rgb_to_grayscale, num_output_channels=num_output_channels) -@pytest.mark.parametrize("device", cpu_and_gpu()) +@pytest.mark.parametrize("device", cpu_and_cuda()) def test_center_crop(device): script_center_crop = torch.jit.script(F.center_crop) @@ -1220,7 +1220,7 @@ def test_center_crop(device): _test_fn_on_batch(batch_tensors, F.center_crop, output_size=[10, 11]) -@pytest.mark.parametrize("device", cpu_and_gpu()) +@pytest.mark.parametrize("device", cpu_and_cuda()) def test_five_crop(device): script_five_crop = torch.jit.script(F.five_crop) @@ -1254,7 +1254,7 @@ def test_five_crop(device): assert_equal(transformed_batch, s_transformed_batch) -@pytest.mark.parametrize("device", cpu_and_gpu()) +@pytest.mark.parametrize("device", cpu_and_cuda()) def test_ten_crop(device): script_ten_crop = torch.jit.script(F.ten_crop) @@ -1300,7 +1300,7 @@ def test_elastic_transform_asserts(): _ = F.elastic_transform(img_tensor, displacement=torch.rand(1, 2)) -@pytest.mark.parametrize("device", cpu_and_gpu()) +@pytest.mark.parametrize("device", cpu_and_cuda()) @pytest.mark.parametrize("interpolation", [NEAREST, BILINEAR, BICUBIC]) @pytest.mark.parametrize("dt", [None, torch.float32, torch.float64, torch.float16]) @pytest.mark.parametrize( diff --git a/test/test_models.py b/test/test_models.py index f6eeb7c28c8..67eb2115c85 100644 --- a/test/test_models.py +++ b/test/test_models.py @@ -15,7 +15,7 @@ import torch.fx import torch.nn as nn from _utils_internal import get_relative_path -from common_utils import cpu_and_gpu, freeze_rng_state, map_nested_tensor_object, needs_cuda, set_rng_seed +from common_utils import cpu_and_cuda, freeze_rng_state, map_nested_tensor_object, needs_cuda, set_rng_seed from PIL import Image from torchvision import models, transforms from torchvision.models import get_model_builder, list_models @@ -676,14 +676,14 @@ def vitc_b_16(**kwargs: Any): @pytest.mark.parametrize("model_fn", [vitc_b_16]) -@pytest.mark.parametrize("dev", cpu_and_gpu()) +@pytest.mark.parametrize("dev", cpu_and_cuda()) def 
test_vitc_models(model_fn, dev): test_classification_model(model_fn, dev) @disable_tf32() # see: https://github.com/pytorch/vision/issues/7618 @pytest.mark.parametrize("model_fn", list_model_fns(models)) -@pytest.mark.parametrize("dev", cpu_and_gpu()) +@pytest.mark.parametrize("dev", cpu_and_cuda()) def test_classification_model(model_fn, dev): set_rng_seed(0) defaults = { @@ -726,7 +726,7 @@ def test_classification_model(model_fn, dev): @pytest.mark.parametrize("model_fn", list_model_fns(models.segmentation)) -@pytest.mark.parametrize("dev", cpu_and_gpu()) +@pytest.mark.parametrize("dev", cpu_and_cuda()) def test_segmentation_model(model_fn, dev): set_rng_seed(0) defaults = { @@ -791,7 +791,7 @@ def check_out(out): @pytest.mark.parametrize("model_fn", list_model_fns(models.detection)) -@pytest.mark.parametrize("dev", cpu_and_gpu()) +@pytest.mark.parametrize("dev", cpu_and_cuda()) def test_detection_model(model_fn, dev): set_rng_seed(0) defaults = { @@ -923,7 +923,7 @@ def test_detection_model_validation(model_fn): @pytest.mark.parametrize("model_fn", list_model_fns(models.video)) -@pytest.mark.parametrize("dev", cpu_and_gpu()) +@pytest.mark.parametrize("dev", cpu_and_cuda()) def test_video_model(model_fn, dev): set_rng_seed(0) # the default input shape is diff --git a/test/test_ops.py b/test/test_ops.py index f41923d55ef..a6e851e9943 100644 --- a/test/test_ops.py +++ b/test/test_ops.py @@ -10,7 +10,7 @@ import torch import torch.fx import torch.nn.functional as F -from common_utils import assert_equal, cpu_and_gpu, cpu_and_gpu_and_mps, needs_cuda, needs_mps +from common_utils import assert_equal, cpu_and_cuda, cpu_and_cuda_and_mps, needs_cuda, needs_mps from PIL import Image from torch import nn, Tensor from torch.autograd import gradcheck @@ -99,7 +99,7 @@ class RoIOpTester(ABC): mps_dtype = torch.float32 mps_backward_atol = 2e-2 - @pytest.mark.parametrize("device", cpu_and_gpu_and_mps()) + @pytest.mark.parametrize("device", cpu_and_cuda_and_mps()) @pytest.mark.parametrize("contiguous", (True, False)) def test_forward(self, device, contiguous, x_dtype=None, rois_dtype=None, deterministic=False, **kwargs): dtype = self.mps_dtype if device == "mps" else self.dtype @@ -129,7 +129,7 @@ def test_forward(self, device, contiguous, x_dtype=None, rois_dtype=None, determ tol = 1e-3 if (x_dtype is torch.half or rois_dtype is torch.half) else 1e-5 torch.testing.assert_close(gt_y.to(y), y, rtol=tol, atol=tol) - @pytest.mark.parametrize("device", cpu_and_gpu()) + @pytest.mark.parametrize("device", cpu_and_cuda()) def test_is_leaf_node(self, device): op_obj = self.make_obj(wrap=True).to(device=device) graph_node_names = get_graph_node_names(op_obj) @@ -138,7 +138,7 @@ def test_is_leaf_node(self, device): assert len(graph_node_names[0]) == len(graph_node_names[1]) assert len(graph_node_names[0]) == 1 + op_obj.n_inputs - @pytest.mark.parametrize("device", cpu_and_gpu()) + @pytest.mark.parametrize("device", cpu_and_cuda()) def test_torch_fx_trace(self, device, x_dtype=torch.float, rois_dtype=torch.float): op_obj = self.make_obj().to(device=device) graph_module = torch.fx.symbolic_trace(op_obj) @@ -158,7 +158,7 @@ def test_torch_fx_trace(self, device, x_dtype=torch.float, rois_dtype=torch.floa torch.testing.assert_close(output_gt, output_fx, rtol=tol, atol=tol) @pytest.mark.parametrize("seed", range(10)) - @pytest.mark.parametrize("device", cpu_and_gpu_and_mps()) + @pytest.mark.parametrize("device", cpu_and_cuda_and_mps()) @pytest.mark.parametrize("contiguous", (True, False)) def test_backward(self, seed, 
device, contiguous, deterministic=False): atol = self.mps_backward_atol if device == "mps" else 1e-05 @@ -428,7 +428,7 @@ def test_boxes_shape(self): self._helper_boxes_shape(ops.roi_align) @pytest.mark.parametrize("aligned", (True, False)) - @pytest.mark.parametrize("device", cpu_and_gpu_and_mps()) + @pytest.mark.parametrize("device", cpu_and_cuda_and_mps()) @pytest.mark.parametrize("contiguous", (True, False)) @pytest.mark.parametrize("deterministic", (True, False)) def test_forward(self, device, contiguous, deterministic, aligned, x_dtype=None, rois_dtype=None): @@ -460,7 +460,7 @@ def test_autocast(self, aligned, deterministic, x_dtype, rois_dtype): ) @pytest.mark.parametrize("seed", range(10)) - @pytest.mark.parametrize("device", cpu_and_gpu_and_mps()) + @pytest.mark.parametrize("device", cpu_and_cuda_and_mps()) @pytest.mark.parametrize("contiguous", (True, False)) @pytest.mark.parametrize("deterministic", (True, False)) def test_backward(self, seed, device, contiguous, deterministic): @@ -624,7 +624,7 @@ def test_msroialign_repr(self): ) assert repr(t) == expected_string - @pytest.mark.parametrize("device", cpu_and_gpu()) + @pytest.mark.parametrize("device", cpu_and_cuda()) def test_is_leaf_node(self, device): op_obj = self.make_obj(wrap=True).to(device=device) graph_node_names = get_graph_node_names(op_obj) @@ -931,7 +931,7 @@ def make_obj(self, in_channels=6, out_channels=2, kernel_size=(3, 2), groups=2, ) return DeformConvModuleWrapper(obj) if wrap else obj - @pytest.mark.parametrize("device", cpu_and_gpu()) + @pytest.mark.parametrize("device", cpu_and_cuda()) def test_is_leaf_node(self, device): op_obj = self.make_obj(wrap=True).to(device=device) graph_node_names = get_graph_node_names(op_obj) @@ -940,7 +940,7 @@ def test_is_leaf_node(self, device): assert len(graph_node_names[0]) == len(graph_node_names[1]) assert len(graph_node_names[0]) == 1 + op_obj.n_inputs - @pytest.mark.parametrize("device", cpu_and_gpu()) + @pytest.mark.parametrize("device", cpu_and_cuda()) @pytest.mark.parametrize("contiguous", (True, False)) @pytest.mark.parametrize("batch_sz", (0, 33)) def test_forward(self, device, contiguous, batch_sz, dtype=None): @@ -992,7 +992,7 @@ def test_wrong_sizes(self): wrong_mask = torch.rand_like(mask[:, :2]) layer(x, offset, wrong_mask) - @pytest.mark.parametrize("device", cpu_and_gpu()) + @pytest.mark.parametrize("device", cpu_and_cuda()) @pytest.mark.parametrize("contiguous", (True, False)) @pytest.mark.parametrize("batch_sz", (0, 33)) def test_backward(self, device, contiguous, batch_sz): @@ -1457,7 +1457,7 @@ def assert_empty_loss(iou_fn, dtype, device): class TestGeneralizedBoxIouLoss: # We refer to original test: https://github.com/facebookresearch/fvcore/blob/main/tests/test_giou_loss.py - @pytest.mark.parametrize("device", cpu_and_gpu()) + @pytest.mark.parametrize("device", cpu_and_cuda()) @pytest.mark.parametrize("dtype", [torch.float32, torch.half]) def test_giou_loss(self, dtype, device): box1, box2, box3, box4, box1s, box2s = get_boxes(dtype, device) @@ -1485,7 +1485,7 @@ def test_giou_loss(self, dtype, device): with pytest.raises(ValueError, match="Invalid"): ops.generalized_box_iou_loss(box1s, box2s, reduction="xyz") - @pytest.mark.parametrize("device", cpu_and_gpu()) + @pytest.mark.parametrize("device", cpu_and_cuda()) @pytest.mark.parametrize("dtype", [torch.float32, torch.half]) def test_empty_inputs(self, dtype, device): assert_empty_loss(ops.generalized_box_iou_loss, dtype, device) @@ -1493,7 +1493,7 @@ def test_empty_inputs(self, dtype, device): class 
TestCompleteBoxIouLoss: @pytest.mark.parametrize("dtype", [torch.float32, torch.half]) - @pytest.mark.parametrize("device", cpu_and_gpu()) + @pytest.mark.parametrize("device", cpu_and_cuda()) def test_ciou_loss(self, dtype, device): box1, box2, box3, box4, box1s, box2s = get_boxes(dtype, device) @@ -1507,14 +1507,14 @@ def test_ciou_loss(self, dtype, device): with pytest.raises(ValueError, match="Invalid"): ops.complete_box_iou_loss(box1s, box2s, reduction="xyz") - @pytest.mark.parametrize("device", cpu_and_gpu()) + @pytest.mark.parametrize("device", cpu_and_cuda()) @pytest.mark.parametrize("dtype", [torch.float32, torch.half]) def test_empty_inputs(self, dtype, device): assert_empty_loss(ops.complete_box_iou_loss, dtype, device) class TestDistanceBoxIouLoss: - @pytest.mark.parametrize("device", cpu_and_gpu()) + @pytest.mark.parametrize("device", cpu_and_cuda()) @pytest.mark.parametrize("dtype", [torch.float32, torch.half]) def test_distance_iou_loss(self, dtype, device): box1, box2, box3, box4, box1s, box2s = get_boxes(dtype, device) @@ -1529,7 +1529,7 @@ def test_distance_iou_loss(self, dtype, device): with pytest.raises(ValueError, match="Invalid"): ops.distance_box_iou_loss(box1s, box2s, reduction="xyz") - @pytest.mark.parametrize("device", cpu_and_gpu()) + @pytest.mark.parametrize("device", cpu_and_cuda()) @pytest.mark.parametrize("dtype", [torch.float32, torch.half]) def test_empty_distance_iou_inputs(self, dtype, device): assert_empty_loss(ops.distance_box_iou_loss, dtype, device) @@ -1574,7 +1574,7 @@ def generate_tensor_with_range_type(shape, range_type, **kwargs): @pytest.mark.parametrize("alpha", [-1.0, 0.0, 0.58, 1.0]) @pytest.mark.parametrize("gamma", [0, 2]) - @pytest.mark.parametrize("device", cpu_and_gpu()) + @pytest.mark.parametrize("device", cpu_and_cuda()) @pytest.mark.parametrize("dtype", [torch.float32, torch.half]) @pytest.mark.parametrize("seed", [0, 1]) def test_correct_ratio(self, alpha, gamma, device, dtype, seed): @@ -1603,7 +1603,7 @@ def test_correct_ratio(self, alpha, gamma, device, dtype, seed): torch.testing.assert_close(correct_ratio, loss_ratio, atol=tol, rtol=tol) @pytest.mark.parametrize("reduction", ["mean", "sum"]) - @pytest.mark.parametrize("device", cpu_and_gpu()) + @pytest.mark.parametrize("device", cpu_and_cuda()) @pytest.mark.parametrize("dtype", [torch.float32, torch.half]) @pytest.mark.parametrize("seed", [2, 3]) def test_equal_ce_loss(self, reduction, device, dtype, seed): @@ -1630,7 +1630,7 @@ def test_equal_ce_loss(self, reduction, device, dtype, seed): @pytest.mark.parametrize("alpha", [-1.0, 0.0, 0.58, 1.0]) @pytest.mark.parametrize("gamma", [0, 2]) @pytest.mark.parametrize("reduction", ["none", "mean", "sum"]) - @pytest.mark.parametrize("device", cpu_and_gpu()) + @pytest.mark.parametrize("device", cpu_and_cuda()) @pytest.mark.parametrize("dtype", [torch.float32, torch.half]) @pytest.mark.parametrize("seed", [4, 5]) def test_jit(self, alpha, gamma, reduction, device, dtype, seed): @@ -1646,7 +1646,7 @@ def test_jit(self, alpha, gamma, reduction, device, dtype, seed): torch.testing.assert_close(focal_loss, scripted_focal_loss, rtol=tol, atol=tol) # Raise ValueError for anonymous reduction mode - @pytest.mark.parametrize("device", cpu_and_gpu()) + @pytest.mark.parametrize("device", cpu_and_cuda()) @pytest.mark.parametrize("dtype", [torch.float32, torch.half]) def test_reduction_mode(self, device, dtype, reduction="xyz"): if device == "cpu" and dtype is torch.half: diff --git a/test/test_prototype_models.py b/test/test_prototype_models.py 
index 6d9f22c1543..d32df68f1f4 100644 --- a/test/test_prototype_models.py +++ b/test/test_prototype_models.py @@ -1,13 +1,13 @@ import pytest import test_models as TM import torch -from common_utils import cpu_and_gpu, set_rng_seed +from common_utils import cpu_and_cuda, set_rng_seed from torchvision.prototype import models @pytest.mark.parametrize("model_fn", (models.depth.stereo.raft_stereo_base,)) @pytest.mark.parametrize("model_mode", ("standard", "scripted")) -@pytest.mark.parametrize("dev", cpu_and_gpu()) +@pytest.mark.parametrize("dev", cpu_and_cuda()) def test_raft_stereo(model_fn, model_mode, dev): # A simple test to make sure the model can do forward pass and jit scriptable set_rng_seed(0) @@ -40,7 +40,7 @@ def test_raft_stereo(model_fn, model_mode, dev): @pytest.mark.parametrize("model_fn", (models.depth.stereo.crestereo_base,)) @pytest.mark.parametrize("model_mode", ("standard", "scripted")) -@pytest.mark.parametrize("dev", cpu_and_gpu()) +@pytest.mark.parametrize("dev", cpu_and_cuda()) def test_crestereo(model_fn, model_mode, dev): set_rng_seed(0) diff --git a/test/test_transforms_tensor.py b/test/test_transforms_tensor.py index 077a12af490..e2ab5673f1e 100644 --- a/test/test_transforms_tensor.py +++ b/test/test_transforms_tensor.py @@ -12,7 +12,7 @@ _create_data, _create_data_batch, assert_equal, - cpu_and_gpu, + cpu_and_cuda, float_dtypes, get_tmp_dir, int_dtypes, @@ -105,7 +105,7 @@ def _test_fn_save_load(fn, tmpdir): _ = torch.jit.load(p) -@pytest.mark.parametrize("device", cpu_and_gpu()) +@pytest.mark.parametrize("device", cpu_and_cuda()) @pytest.mark.parametrize( "func,method,fn_kwargs,match_kwargs", [ @@ -130,7 +130,7 @@ def test_random(func, method, device, channels, fn_kwargs, match_kwargs): @pytest.mark.parametrize("seed", range(10)) -@pytest.mark.parametrize("device", cpu_and_gpu()) +@pytest.mark.parametrize("device", cpu_and_cuda()) @pytest.mark.parametrize("channels", [1, 3]) class TestColorJitter: @pytest.fixture(autouse=True) @@ -206,7 +206,7 @@ def test_color_jitter_all(self, device, channels): ) -@pytest.mark.parametrize("device", cpu_and_gpu()) +@pytest.mark.parametrize("device", cpu_and_cuda()) @pytest.mark.parametrize("m", ["constant", "edge", "reflect", "symmetric"]) @pytest.mark.parametrize("mul", [1, -1]) def test_pad(m, mul, device): @@ -229,7 +229,7 @@ def test_pad(m, mul, device): _test_op(F.pad, T.Pad, device=device, fn_kwargs=fn_kwargs, meth_kwargs=meth_kwargs) -@pytest.mark.parametrize("device", cpu_and_gpu()) +@pytest.mark.parametrize("device", cpu_and_cuda()) def test_crop(device): fn_kwargs = {"top": 2, "left": 3, "height": 4, "width": 5} # Test transforms.RandomCrop with size and padding as tuple @@ -257,7 +257,7 @@ def test_crop(device): _test_functional_op(F.crop, fn_kwargs=fn_kwargs, device=device) -@pytest.mark.parametrize("device", cpu_and_gpu()) +@pytest.mark.parametrize("device", cpu_and_cuda()) @pytest.mark.parametrize( "padding_config", [ @@ -283,7 +283,7 @@ def test_random_crop_save_load(tmpdir): _test_fn_save_load(fn, tmpdir) -@pytest.mark.parametrize("device", cpu_and_gpu()) +@pytest.mark.parametrize("device", cpu_and_cuda()) def test_center_crop(device, tmpdir): fn_kwargs = {"output_size": (4, 5)} meth_kwargs = {"size": (4, 5)} @@ -313,7 +313,7 @@ def test_center_crop_save_load(tmpdir): _test_fn_save_load(fn, tmpdir) -@pytest.mark.parametrize("device", cpu_and_gpu()) +@pytest.mark.parametrize("device", cpu_and_cuda()) @pytest.mark.parametrize( "fn, method, out_length", [ @@ -380,7 +380,7 @@ def test_resize_int(self, size): assert 
y.shape[1] == size assert y.shape[2] == int(size * 46 / 32) - @pytest.mark.parametrize("device", cpu_and_gpu()) + @pytest.mark.parametrize("device", cpu_and_cuda()) @pytest.mark.parametrize("dt", [None, torch.float32, torch.float64]) @pytest.mark.parametrize("size", [[32], [32, 32], (32, 32), [34, 35]]) @pytest.mark.parametrize("max_size", [None, 35, 1000]) @@ -404,7 +404,7 @@ def test_resize_save_load(self, tmpdir): fn = T.Resize(size=[32], antialias=True) _test_fn_save_load(fn, tmpdir) - @pytest.mark.parametrize("device", cpu_and_gpu()) + @pytest.mark.parametrize("device", cpu_and_cuda()) @pytest.mark.parametrize("scale", [(0.7, 1.2), [0.7, 1.2]]) @pytest.mark.parametrize("ratio", [(0.75, 1.333), [0.75, 1.333]]) @pytest.mark.parametrize("size", [(32,), [44], [32], [32, 32], (32, 32), [44, 55]]) @@ -460,42 +460,42 @@ def test_random_affine_save_load(tmpdir): _test_fn_save_load(fn, tmpdir) -@pytest.mark.parametrize("device", cpu_and_gpu()) +@pytest.mark.parametrize("device", cpu_and_cuda()) @pytest.mark.parametrize("interpolation", [NEAREST, BILINEAR]) @pytest.mark.parametrize("shear", [15, 10.0, (5.0, 10.0), [-15, 15], [-10.0, 10.0, -11.0, 11.0]]) def test_random_affine_shear(device, interpolation, shear): _test_random_affine_helper(device, degrees=0.0, interpolation=interpolation, shear=shear) -@pytest.mark.parametrize("device", cpu_and_gpu()) +@pytest.mark.parametrize("device", cpu_and_cuda()) @pytest.mark.parametrize("interpolation", [NEAREST, BILINEAR]) @pytest.mark.parametrize("scale", [(0.7, 1.2), [0.7, 1.2]]) def test_random_affine_scale(device, interpolation, scale): _test_random_affine_helper(device, degrees=0.0, interpolation=interpolation, scale=scale) -@pytest.mark.parametrize("device", cpu_and_gpu()) +@pytest.mark.parametrize("device", cpu_and_cuda()) @pytest.mark.parametrize("interpolation", [NEAREST, BILINEAR]) @pytest.mark.parametrize("translate", [(0.1, 0.2), [0.2, 0.1]]) def test_random_affine_translate(device, interpolation, translate): _test_random_affine_helper(device, degrees=0.0, interpolation=interpolation, translate=translate) -@pytest.mark.parametrize("device", cpu_and_gpu()) +@pytest.mark.parametrize("device", cpu_and_cuda()) @pytest.mark.parametrize("interpolation", [NEAREST, BILINEAR]) @pytest.mark.parametrize("degrees", [45, 35.0, (-45, 45), [-90.0, 90.0]]) def test_random_affine_degrees(device, interpolation, degrees): _test_random_affine_helper(device, degrees=degrees, interpolation=interpolation) -@pytest.mark.parametrize("device", cpu_and_gpu()) +@pytest.mark.parametrize("device", cpu_and_cuda()) @pytest.mark.parametrize("interpolation", [NEAREST, BILINEAR]) @pytest.mark.parametrize("fill", [85, (10, -10, 10), 0.7, [0.0, 0.0, 0.0], [1], 1]) def test_random_affine_fill(device, interpolation, fill): _test_random_affine_helper(device, degrees=0.0, interpolation=interpolation, fill=fill) -@pytest.mark.parametrize("device", cpu_and_gpu()) +@pytest.mark.parametrize("device", cpu_and_cuda()) @pytest.mark.parametrize("center", [(0, 0), [10, 10], None, (56, 44)]) @pytest.mark.parametrize("expand", [True, False]) @pytest.mark.parametrize("degrees", [45, 35.0, (-45, 45), [-90.0, 90.0]]) @@ -517,7 +517,7 @@ def test_random_rotate_save_load(tmpdir): _test_fn_save_load(fn, tmpdir) -@pytest.mark.parametrize("device", cpu_and_gpu()) +@pytest.mark.parametrize("device", cpu_and_cuda()) @pytest.mark.parametrize("distortion_scale", np.linspace(0.1, 1.0, num=20)) @pytest.mark.parametrize("interpolation", [NEAREST, BILINEAR]) @pytest.mark.parametrize("fill", [85, (10, -10, 
10), 0.7, [0.0, 0.0, 0.0], [1], 1]) @@ -537,7 +537,7 @@ def test_random_perspective_save_load(tmpdir): _test_fn_save_load(fn, tmpdir) -@pytest.mark.parametrize("device", cpu_and_gpu()) +@pytest.mark.parametrize("device", cpu_and_cuda()) @pytest.mark.parametrize( "Klass, meth_kwargs", [(T.Grayscale, {"num_output_channels": 1}), (T.Grayscale, {"num_output_channels": 3}), (T.RandomGrayscale, {})], @@ -547,7 +547,7 @@ def test_to_grayscale(device, Klass, meth_kwargs): _test_class_op(Klass, meth_kwargs=meth_kwargs, test_exact_match=False, device=device, tol=tol, agg_method="max") -@pytest.mark.parametrize("device", cpu_and_gpu()) +@pytest.mark.parametrize("device", cpu_and_cuda()) @pytest.mark.parametrize("in_dtype", int_dtypes() + float_dtypes()) @pytest.mark.parametrize("out_dtype", int_dtypes() + float_dtypes()) def test_convert_image_dtype(device, in_dtype, out_dtype): @@ -578,7 +578,7 @@ def test_convert_image_dtype_save_load(tmpdir): _test_fn_save_load(fn, tmpdir) -@pytest.mark.parametrize("device", cpu_and_gpu()) +@pytest.mark.parametrize("device", cpu_and_cuda()) @pytest.mark.parametrize("policy", [policy for policy in T.AutoAugmentPolicy]) @pytest.mark.parametrize("fill", [None, 85, (10, -10, 10), 0.7, [0.0, 0.0, 0.0], [1], 1]) def test_autoaugment(device, policy, fill): @@ -592,7 +592,7 @@ def test_autoaugment(device, policy, fill): _test_transform_vs_scripted_on_batch(transform, s_transform, batch_tensors) -@pytest.mark.parametrize("device", cpu_and_gpu()) +@pytest.mark.parametrize("device", cpu_and_cuda()) @pytest.mark.parametrize("num_ops", [1, 2, 3]) @pytest.mark.parametrize("magnitude", [7, 9, 11]) @pytest.mark.parametrize("fill", [None, 85, (10, -10, 10), 0.7, [0.0, 0.0, 0.0], [1], 1]) @@ -607,7 +607,7 @@ def test_randaugment(device, num_ops, magnitude, fill): _test_transform_vs_scripted_on_batch(transform, s_transform, batch_tensors) -@pytest.mark.parametrize("device", cpu_and_gpu()) +@pytest.mark.parametrize("device", cpu_and_cuda()) @pytest.mark.parametrize("fill", [None, 85, (10, -10, 10), 0.7, [0.0, 0.0, 0.0], [1], 1]) def test_trivialaugmentwide(device, fill): tensor = torch.randint(0, 256, size=(3, 44, 56), dtype=torch.uint8, device=device) @@ -620,7 +620,7 @@ def test_trivialaugmentwide(device, fill): _test_transform_vs_scripted_on_batch(transform, s_transform, batch_tensors) -@pytest.mark.parametrize("device", cpu_and_gpu()) +@pytest.mark.parametrize("device", cpu_and_cuda()) @pytest.mark.parametrize("fill", [None, 85, (10, -10, 10), 0.7, [0.0, 0.0, 0.0], [1], 1]) def test_augmix(device, fill): tensor = torch.randint(0, 256, size=(3, 44, 56), dtype=torch.uint8, device=device) @@ -686,7 +686,7 @@ def shear(pil_img, level, mode, resample): _assert_approx_equal_tensor_to_pil(out, expected_out) -@pytest.mark.parametrize("device", cpu_and_gpu()) +@pytest.mark.parametrize("device", cpu_and_cuda()) @pytest.mark.parametrize( "config", [ @@ -724,7 +724,7 @@ def test_random_erasing_with_invalid_data(): random_erasing(img) -@pytest.mark.parametrize("device", cpu_and_gpu()) +@pytest.mark.parametrize("device", cpu_and_cuda()) def test_normalize(device, tmpdir): fn = T.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) tensor, _ = _create_data(26, 34, device=device) @@ -743,7 +743,7 @@ def test_normalize(device, tmpdir): scripted_fn.save(os.path.join(tmpdir, "t_norm.pt")) -@pytest.mark.parametrize("device", cpu_and_gpu()) +@pytest.mark.parametrize("device", cpu_and_cuda()) def test_linear_transformation(device, tmpdir): c, h, w = 3, 24, 32 @@ -769,7 +769,7 @@ def 
test_linear_transformation(device, tmpdir): scripted_fn.save(os.path.join(tmpdir, "t_norm.pt")) -@pytest.mark.parametrize("device", cpu_and_gpu()) +@pytest.mark.parametrize("device", cpu_and_cuda()) def test_compose(device): tensor, _ = _create_data(26, 34, device=device) tensor = tensor.to(dtype=torch.float32) / 255.0 @@ -797,7 +797,7 @@ def test_compose(device): torch.jit.script(t) -@pytest.mark.parametrize("device", cpu_and_gpu()) +@pytest.mark.parametrize("device", cpu_and_cuda()) def test_random_apply(device): tensor, _ = _create_data(26, 34, device=device) tensor = tensor.to(dtype=torch.float32) / 255.0 @@ -839,7 +839,7 @@ def test_random_apply(device): torch.jit.script(transforms) -@pytest.mark.parametrize("device", cpu_and_gpu()) +@pytest.mark.parametrize("device", cpu_and_cuda()) @pytest.mark.parametrize( "meth_kwargs", [ @@ -877,7 +877,7 @@ def test_gaussian_blur(device, channels, meth_kwargs): ) -@pytest.mark.parametrize("device", cpu_and_gpu()) +@pytest.mark.parametrize("device", cpu_and_cuda()) @pytest.mark.parametrize( "fill", [ diff --git a/test/test_transforms_v2.py b/test/test_transforms_v2.py index 02e3e1e569a..71df9ad72d8 100644 --- a/test/test_transforms_v2.py +++ b/test/test_transforms_v2.py @@ -16,7 +16,7 @@ from common_utils import ( assert_equal, assert_run_python_script, - cpu_and_gpu, + cpu_and_cuda, make_bounding_box, make_bounding_boxes, make_detection_mask, @@ -173,7 +173,7 @@ class TestSmoke: next(make_vanilla_tensor_images()), ], ) - @pytest.mark.parametrize("device", cpu_and_gpu()) + @pytest.mark.parametrize("device", cpu_and_cuda()) def test_common(self, transform, adapter, container_type, image_or_video, device): spatial_size = F.get_spatial_size(image_or_video) input = dict( @@ -1364,7 +1364,7 @@ def test_assertions(self): class TestRandomIoUCrop: - @pytest.mark.parametrize("device", cpu_and_gpu()) + @pytest.mark.parametrize("device", cpu_and_cuda()) @pytest.mark.parametrize("options", [[0.5, 0.9], [2.0]]) def test__get_params(self, device, options, mocker): image = mocker.MagicMock(spec=datapoints.Image) diff --git a/test/test_transforms_v2_functional.py b/test/test_transforms_v2_functional.py index 60a06f571b1..9a2ea37a4ae 100644 --- a/test/test_transforms_v2_functional.py +++ b/test/test_transforms_v2_functional.py @@ -14,7 +14,7 @@ from common_utils import ( assert_close, cache, - cpu_and_gpu, + cpu_and_cuda, DEFAULT_SQUARE_SPATIAL_SIZE, make_bounding_boxes, needs_cuda, @@ -120,7 +120,7 @@ class TestKernels: [info for info in KERNEL_INFOS if info.logs_usage], args_kwargs_fn=lambda info: info.sample_inputs_fn(), ) - @pytest.mark.parametrize("device", cpu_and_gpu()) + @pytest.mark.parametrize("device", cpu_and_cuda()) def test_logging(self, spy_on, info, args_kwargs, device): spy = spy_on(torch._C._log_api_usage_once) @@ -131,7 +131,7 @@ def test_logging(self, spy_on, info, args_kwargs, device): @ignore_jit_warning_no_profile @sample_inputs - @pytest.mark.parametrize("device", cpu_and_gpu()) + @pytest.mark.parametrize("device", cpu_and_cuda()) def test_scripted_vs_eager(self, test_id, info, args_kwargs, device): kernel_eager = info.kernel kernel_scripted = script(kernel_eager) @@ -167,7 +167,7 @@ def _unbatch(self, batch, *, data_dims): ] @sample_inputs - @pytest.mark.parametrize("device", cpu_and_gpu()) + @pytest.mark.parametrize("device", cpu_and_cuda()) def test_batched_vs_single(self, test_id, info, args_kwargs, device): (batched_input, *other_args), kwargs = args_kwargs.load(device) @@ -208,7 +208,7 @@ def test_batched_vs_single(self, test_id, 
info, args_kwargs, device): ) @sample_inputs - @pytest.mark.parametrize("device", cpu_and_gpu()) + @pytest.mark.parametrize("device", cpu_and_cuda()) def test_no_inplace(self, info, args_kwargs, device): (input, *other_args), kwargs = args_kwargs.load(device) input = input.as_subclass(torch.Tensor) @@ -240,7 +240,7 @@ def test_cuda_vs_cpu(self, test_id, info, args_kwargs): ) @sample_inputs - @pytest.mark.parametrize("device", cpu_and_gpu()) + @pytest.mark.parametrize("device", cpu_and_cuda()) def test_dtype_and_device_consistency(self, info, args_kwargs, device): (input, *other_args), kwargs = args_kwargs.load(device) input = input.as_subclass(torch.Tensor) @@ -320,7 +320,7 @@ class TestDispatchers: DISPATCHER_INFOS, args_kwargs_fn=lambda info: info.sample_inputs(), ) - @pytest.mark.parametrize("device", cpu_and_gpu()) + @pytest.mark.parametrize("device", cpu_and_cuda()) def test_logging(self, spy_on, info, args_kwargs, device): spy = spy_on(torch._C._log_api_usage_once) @@ -331,7 +331,7 @@ def test_logging(self, spy_on, info, args_kwargs, device): @ignore_jit_warning_no_profile @image_sample_inputs - @pytest.mark.parametrize("device", cpu_and_gpu()) + @pytest.mark.parametrize("device", cpu_and_cuda()) def test_scripted_smoke(self, info, args_kwargs, device): dispatcher = script(info.dispatcher) @@ -553,7 +553,7 @@ def test_alias(alias, target): args_kwargs_fn=lambda info: info.sample_inputs_fn(), ), ) -@pytest.mark.parametrize("device", cpu_and_gpu()) +@pytest.mark.parametrize("device", cpu_and_cuda()) def test_convert_dtype_image_tensor_dtype_and_device(info, args_kwargs, device): (input, *other_args), kwargs = args_kwargs.load(device) dtype = other_args[0] if other_args else kwargs.get("dtype", torch.float32) @@ -564,7 +564,7 @@ def test_convert_dtype_image_tensor_dtype_and_device(info, args_kwargs, device): assert output.device == input.device -@pytest.mark.parametrize("device", cpu_and_gpu()) +@pytest.mark.parametrize("device", cpu_and_cuda()) @pytest.mark.parametrize("num_channels", [1, 3]) def test_normalize_image_tensor_stats(device, num_channels): stats = pytest.importorskip("scipy.stats", reason="SciPy is not available") @@ -664,7 +664,7 @@ def _compute_affine_matrix(angle_, translate_, scale_, shear_, center_): return true_matrix -@pytest.mark.parametrize("device", cpu_and_gpu()) +@pytest.mark.parametrize("device", cpu_and_cuda()) def test_correctness_affine_bounding_box_on_fixed_input(device): # Check transformation against known expected output format = datapoints.BoundingBoxFormat.XYXY @@ -715,7 +715,7 @@ def test_correctness_affine_bounding_box_on_fixed_input(device): torch.testing.assert_close(output_boxes.tolist(), expected_bboxes) -@pytest.mark.parametrize("device", cpu_and_gpu()) +@pytest.mark.parametrize("device", cpu_and_cuda()) def test_correctness_affine_segmentation_mask_on_fixed_input(device): # Check transformation against known expected output and CPU/CUDA devices @@ -820,7 +820,7 @@ def _compute_expected_bbox(bbox, angle_, expand_, center_): torch.testing.assert_close(output_spatial_size, expected_spatial_size, atol=1, rtol=0) -@pytest.mark.parametrize("device", cpu_and_gpu()) +@pytest.mark.parametrize("device", cpu_and_cuda()) @pytest.mark.parametrize("expand", [False]) # expand=True does not match D2 def test_correctness_rotate_bounding_box_on_fixed_input(device, expand): # Check transformation against known expected output @@ -876,7 +876,7 @@ def test_correctness_rotate_bounding_box_on_fixed_input(device, expand): 
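The fixed-input rotation checks above rely on a reference computation: rotate the four corners of each box about the rotation center, then re-fit an axis-aligned box around the rotated corners. A self-contained sketch of that reference (illustrative only; the actual test helper also matches torchvision's angle and coordinate conventions and handles the expand=True case):

    import math

    def rotate_xyxy_box(box, angle_deg, center):
        # Rotate the four corners of an XYXY box about `center`, then take the
        # min/max over the rotated corners to recover an axis-aligned box.
        x1, y1, x2, y2 = box
        cx, cy = center
        theta = math.radians(angle_deg)
        cos_t, sin_t = math.cos(theta), math.sin(theta)
        corners = [(x1, y1), (x2, y1), (x2, y2), (x1, y2)]
        rotated = [
            (cx + (x - cx) * cos_t - (y - cy) * sin_t,
             cy + (x - cx) * sin_t + (y - cy) * cos_t)
            for x, y in corners
        ]
        xs = [px for px, _ in rotated]
        ys = [py for _, py in rotated]
        return [min(xs), min(ys), max(xs), max(ys)]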
torch.testing.assert_close(output_boxes.tolist(), expected_bboxes) -@pytest.mark.parametrize("device", cpu_and_gpu()) +@pytest.mark.parametrize("device", cpu_and_cuda()) def test_correctness_rotate_segmentation_mask_on_fixed_input(device): # Check transformation against known expected output and CPU/CUDA devices @@ -892,7 +892,7 @@ def test_correctness_rotate_segmentation_mask_on_fixed_input(device): torch.testing.assert_close(out_mask, expected_mask) -@pytest.mark.parametrize("device", cpu_and_gpu()) +@pytest.mark.parametrize("device", cpu_and_cuda()) @pytest.mark.parametrize( "format", [datapoints.BoundingBoxFormat.XYXY, datapoints.BoundingBoxFormat.XYWH, datapoints.BoundingBoxFormat.CXCYWH], @@ -949,7 +949,7 @@ def test_correctness_crop_bounding_box(device, format, top, left, height, width, torch.testing.assert_close(output_spatial_size, spatial_size) -@pytest.mark.parametrize("device", cpu_and_gpu()) +@pytest.mark.parametrize("device", cpu_and_cuda()) def test_correctness_horizontal_flip_segmentation_mask_on_fixed_input(device): mask = torch.zeros((3, 3, 3), dtype=torch.long, device=device) mask[:, :, 0] = 1 @@ -961,7 +961,7 @@ def test_correctness_horizontal_flip_segmentation_mask_on_fixed_input(device): torch.testing.assert_close(out_mask, expected_mask) -@pytest.mark.parametrize("device", cpu_and_gpu()) +@pytest.mark.parametrize("device", cpu_and_cuda()) def test_correctness_vertical_flip_segmentation_mask_on_fixed_input(device): mask = torch.zeros((3, 3, 3), dtype=torch.long, device=device) mask[:, 0, :] = 1 @@ -973,7 +973,7 @@ def test_correctness_vertical_flip_segmentation_mask_on_fixed_input(device): torch.testing.assert_close(out_mask, expected_mask) -@pytest.mark.parametrize("device", cpu_and_gpu()) +@pytest.mark.parametrize("device", cpu_and_cuda()) @pytest.mark.parametrize( "format", [datapoints.BoundingBoxFormat.XYXY, datapoints.BoundingBoxFormat.XYWH, datapoints.BoundingBoxFormat.CXCYWH], @@ -1032,7 +1032,7 @@ def _parse_padding(padding): return padding -@pytest.mark.parametrize("device", cpu_and_gpu()) +@pytest.mark.parametrize("device", cpu_and_cuda()) @pytest.mark.parametrize("padding", [[1], [1, 1], [1, 1, 2, 2]]) def test_correctness_pad_bounding_box(device, padding): def _compute_expected_bbox(bbox, padding_): @@ -1087,7 +1087,7 @@ def _compute_expected_spatial_size(bbox, padding_): torch.testing.assert_close(output_boxes, expected_bboxes, atol=1, rtol=0) -@pytest.mark.parametrize("device", cpu_and_gpu()) +@pytest.mark.parametrize("device", cpu_and_cuda()) def test_correctness_pad_segmentation_mask_on_fixed_input(device): mask = torch.ones((1, 3, 3), dtype=torch.long, device=device) @@ -1098,7 +1098,7 @@ def test_correctness_pad_segmentation_mask_on_fixed_input(device): torch.testing.assert_close(out_mask, expected_mask) -@pytest.mark.parametrize("device", cpu_and_gpu()) +@pytest.mark.parametrize("device", cpu_and_cuda()) @pytest.mark.parametrize( "startpoints, endpoints", [ @@ -1182,7 +1182,7 @@ def _compute_expected_bbox(bbox, pcoeffs_): torch.testing.assert_close(output_bboxes, expected_bboxes, rtol=0, atol=1) -@pytest.mark.parametrize("device", cpu_and_gpu()) +@pytest.mark.parametrize("device", cpu_and_cuda()) @pytest.mark.parametrize( "output_size", [(18, 18), [18, 15], (16, 19), [12], [46, 48]], @@ -1236,7 +1236,7 @@ def _compute_expected_bbox(bbox, output_size_): torch.testing.assert_close(output_spatial_size, output_size) -@pytest.mark.parametrize("device", cpu_and_gpu()) +@pytest.mark.parametrize("device", cpu_and_cuda()) 
@pytest.mark.parametrize("output_size", [[4, 2], [4], [7, 6]]) def test_correctness_center_crop_mask(device, output_size): def _compute_expected_mask(mask, output_size): @@ -1260,7 +1260,7 @@ def _compute_expected_mask(mask, output_size): # Copied from test/test_functional_tensor.py -@pytest.mark.parametrize("device", cpu_and_gpu()) +@pytest.mark.parametrize("device", cpu_and_cuda()) @pytest.mark.parametrize("spatial_size", ("small", "large")) @pytest.mark.parametrize("dt", [None, torch.float32, torch.float64, torch.float16]) @pytest.mark.parametrize("ksize", [(3, 3), [3, 5], (23, 23)]) @@ -1357,7 +1357,7 @@ def test_equalize_image_tensor_edge_cases(): assert output.unique().tolist() == [0, 255] -@pytest.mark.parametrize("device", cpu_and_gpu()) +@pytest.mark.parametrize("device", cpu_and_cuda()) def test_correctness_uniform_temporal_subsample(device): video = torch.arange(10, device=device)[:, None, None, None].expand(-1, 3, 8, 8) out_video = F.uniform_temporal_subsample(video, 5) From 990685f0efae4fd70ac7433e3436993523775cdd Mon Sep 17 00:00:00 2001 From: "Li-Huai (Allan) Lin" Date: Tue, 20 Jun 2023 18:16:19 +0800 Subject: [PATCH 19/28] formatting --- torchvision/csrc/ops/mps/mps_kernels.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchvision/csrc/ops/mps/mps_kernels.h b/torchvision/csrc/ops/mps/mps_kernels.h index c6d5888a6e8..ec7b825d3a3 100644 --- a/torchvision/csrc/ops/mps/mps_kernels.h +++ b/torchvision/csrc/ops/mps/mps_kernels.h @@ -7,8 +7,8 @@ namespace mps { static const char* METAL_VISION = R"VISION_METAL( -#include #include +#include using namespace metal; /*----------Macros----------*/ From 5dce2d78be6b09d2853f26cbb2f53a1736246df7 Mon Sep 17 00:00:00 2001 From: "Li-Huai (Allan) Lin" Date: Tue, 20 Jun 2023 20:15:37 +0800 Subject: [PATCH 20/28] mps kernel dtype consistency --- torchvision/csrc/ops/mps/mps_kernels.h | 370 ++++++++++++------------- torchvision/csrc/ops/mps/nms_kernel.mm | 14 +- 2 files changed, 192 insertions(+), 192 deletions(-) diff --git a/torchvision/csrc/ops/mps/mps_kernels.h b/torchvision/csrc/ops/mps/mps_kernels.h index ec7b825d3a3..f9589f646ea 100644 --- a/torchvision/csrc/ops/mps/mps_kernels.h +++ b/torchvision/csrc/ops/mps/mps_kernels.h @@ -17,7 +17,7 @@ using namespace metal; for (index_t i = (tgid.x * tptg.x) + tid2.x; i < (n); \ i += (tptg.x * n_tgs)) -#define MPS_1D_KERNEL_LOOP(i, n, n_tgs) MPS_1D_KERNEL_LOOP_T(i, n, n_tgs, int) +#define MPS_1D_KERNEL_LOOP(i, n, n_tgs) MPS_1D_KERNEL_LOOP_T(i, n, n_tgs, uint) /*----------Utils----------*/ @@ -52,24 +52,24 @@ void atomic_add_float( device T* data_ptr, const T val) // atom_var should be 0 now, try to assign the addition result back to the atom_var (data_ptr). while ((fetched_uint = atomic_exchange_explicit( atom_var, assigning_uint /*desired*/, memory_order_relaxed)) != 0) { - // If atom_var is not 0, i.e. fetched_uint != 0, it means that the data has been modified by other threads. - // Try to assign 0 and get the previous assigned addition result. + // If atom_var was not 0, i.e. fetched_uint != 0, it means that the data has been modified by other threads. + // Try to assign 0 and get the previously assigned addition result. uint fetched_uint_again = atomic_exchange_explicit(atom_var, 0 /*desired*/, memory_order_relaxed); T fetched_float_again = *( (thread T*) &fetched_uint_again ); // Re-add again fetched_float = *((thread T*) &(fetched_uint)); - // Previous assigned addition result + addition result from other threads. 
+ // Previously assigned addition result + addition result from other threads. assigning_float = fetched_float_again + fetched_float; assigning_uint = *( (thread uint*) &assigning_float); } #endif } -template +template inline T bilinear_interpolate( constant T* input, - int64_t height, - int64_t width, + integer_t height, + integer_t width, T y, T x, uint index /* index for debug only*/) { @@ -84,10 +84,10 @@ inline T bilinear_interpolate( if (x <= 0) x = 0; - int y_low = (int)y; - int x_low = (int)x; - int y_high; - int x_high; + integer_t y_low = (integer_t)y; + integer_t x_low = (integer_t)x; + integer_t y_high; + integer_t x_high; if (y_low >= height - 1) { y_high = y_low = height - 1; @@ -119,20 +119,20 @@ inline T bilinear_interpolate( return val; } -template +template void bilinear_interpolate_gradient( - int height, - int width, + integer_t height, + integer_t width, T y, T x, thread T& w1, thread T& w2, thread T& w3, thread T& w4, - thread int& x_low, - thread int& x_high, - thread int& y_low, - thread int& y_high, + thread integer_t& x_low, + thread integer_t& x_high, + thread integer_t& y_low, + thread integer_t& y_high, uint index /* index for debug only*/) { // deal with cases that inverse elements are out of feature map boundary if (y < -1.0 || y > height || x < -1.0 || x > width) { @@ -147,8 +147,8 @@ void bilinear_interpolate_gradient( if (x <= 0) x = 0; - y_low = (int)y; - x_low = (int)x; + y_low = (integer_t)y; + x_low = (integer_t)x; if (y_low >= height - 1) { y_high = y_low = height - 1; @@ -201,12 +201,12 @@ bool inline IoU( // This should be in sync with the one in nms_kernel.mm. // Since metal does not support dynamic array, // we need to make it static instead of deriving it from [[threads_per_threadgroup]]. -constant uint nmsThreadsPerBlock = sizeof(uint64_t) * 8; +constant int64_t nmsThreadsPerBlock = sizeof(uint64_t) * 8; template kernel void nms(constant T * dev_boxes [[buffer(0)]], device uint64_t * mask [[buffer(1)]], - constant int & n_boxes [[buffer(2)]], + constant int64_t & n_boxes [[buffer(2)]], constant float & iou_threshold [[buffer(3)]], uint2 tgid [[threadgroup_position_in_grid]], uint2 tid2 [[thread_position_in_threadgroup]]) { @@ -237,7 +237,7 @@ kernel void nms(constant T * dev_boxes [[buffer(0)]], t |= static_cast(1) << i; // discard 1 keep 0 } } - const uint col_blocks = ceil_div(static_cast(n_boxes), nmsThreadsPerBlock); + const uint col_blocks = ceil_div(n_boxes, nmsThreadsPerBlock); mask[cur_box_idx * col_blocks + col_start] = t; } } @@ -248,12 +248,12 @@ template \ kernel void nms( \ constant DTYPE ## 4 * dev_boxes [[buffer(0)]], \ device uint64_t * mask [[buffer(1)]], \ - constant int & n_boxes [[buffer(2)]], \ + constant int64_t & n_boxes [[buffer(2)]], \ constant float & iou_threshold [[buffer(3)]], \ uint2 tgid [[threadgroup_position_in_grid]], \ uint2 tid2 [[thread_position_in_threadgroup]]); -template +template kernel void roi_align( constant T * input [[buffer(0)]], constant T * rois [[buffer(1)]], @@ -272,13 +272,13 @@ kernel void roi_align( uint2 tid2 [[thread_position_in_threadgroup]]){ MPS_1D_KERNEL_LOOP(index, output_size, 1) { // (n, c, ph, pw) is an element in the pooled output - int pw = index % pooled_width; - int ph = (index / pooled_width) % pooled_height; - int c = (index / pooled_width / pooled_height) % channels; - int n = index / pooled_width / pooled_height / channels; + integer_t pw = index % pooled_width; + integer_t ph = (index / pooled_width) % pooled_height; + integer_t c = (index / pooled_width / pooled_height) % 
channels; + integer_t n = index / pooled_width / pooled_height / channels; constant T* offset_rois = rois + n * 5; - int roi_batch_ind = offset_rois[0]; + integer_t roi_batch_ind = offset_rois[0]; // Do not using rounding; this implementation detail is critical T offset = aligned ? (T)0.5 : (T)0.0; @@ -302,23 +302,23 @@ kernel void roi_align( input + (roi_batch_ind * channels + c) * height * width; // We use roi_bin_grid to sample the grid and mimic integral - int roi_bin_grid_h = (sampling_ratio > 0) + integer_t roi_bin_grid_h = (sampling_ratio > 0) ? sampling_ratio : ceil(roi_height / pooled_height); // e.g., = 2 - int roi_bin_grid_w = + integer_t roi_bin_grid_w = (sampling_ratio > 0) ? sampling_ratio : ceil(roi_width / pooled_width); // We do average (integral) pooling inside a bin // When the grid is empty, output zeros. - const T count = max(roi_bin_grid_h * roi_bin_grid_w, 1); // e.g. = 4 + const T count = max(roi_bin_grid_h * roi_bin_grid_w, static_cast(1)); // e.g. = 4 T output_val = 0.; - for (int iy = 0; iy < roi_bin_grid_h; iy++) // e.g., iy = 0, 1 + for (integer_t iy = 0; iy < roi_bin_grid_h; iy++) // e.g., iy = 0, 1 { const T y = roi_start_h + ph * bin_size_h + static_cast(iy + .5f) * bin_size_h / static_cast(roi_bin_grid_h); // e.g., 0.5, 1.5 - for (int ix = 0; ix < roi_bin_grid_w; ix++) { + for (integer_t ix = 0; ix < roi_bin_grid_w; ix++) { const T x = roi_start_w + pw * bin_size_w + static_cast(ix + .5f) * bin_size_w / static_cast(roi_bin_grid_w); @@ -333,10 +333,10 @@ kernel void roi_align( } } -#define REGISTER_ROI_ALIGN_OP(DTYPE) \ +#define REGISTER_ROI_ALIGN_OP(DTYPE, INT_DTYPE) \ template \ [[host_name("roi_align_" #DTYPE)]] \ -kernel void roi_align( \ +kernel void roi_align( \ constant DTYPE * input [[buffer(0)]], \ constant DTYPE * rois [[buffer(1)]], \ device DTYPE * output [[buffer(2)]], \ @@ -353,7 +353,7 @@ kernel void roi_align( \ uint2 tptg [[threads_per_threadgroup]], \ uint2 tid2 [[thread_position_in_threadgroup]]); -template +template kernel void roi_align_backward( constant T * grad_output [[buffer(0)]], constant T * rois [[buffer(1)]], @@ -377,13 +377,13 @@ kernel void roi_align_backward( MPS_1D_KERNEL_LOOP(index, output_size, 1) { // (n, c, ph, pw) is an element in the pooled output - int pw = index % pooled_width; - int ph = (index / pooled_width) % pooled_height; - int c = (index / pooled_width / pooled_height) % channels; - int n = index / pooled_width / pooled_height / channels; + integer_t pw = index % pooled_width; + integer_t ph = (index / pooled_width) % pooled_height; + integer_t c = (index / pooled_width / pooled_height) % channels; + integer_t n = index / pooled_width / pooled_height / channels; constant T* offset_rois = rois + n * 5; - int roi_batch_ind = offset_rois[0]; + integer_t roi_batch_ind = offset_rois[0]; // Do not using rounding; this implementation detail is critical T offset = aligned ? (T)0.5 : (T)0.0; @@ -405,35 +405,35 @@ kernel void roi_align_backward( // We need to index the gradient using the tensor strides to access the // correct values. - const int output_offset = n * n_stride + c * c_stride; + const integer_t output_offset = n * n_stride + c * c_stride; constant T* offset_grad_output = grad_output + output_offset; const T grad_output_this_bin = offset_grad_output[ph * h_stride + pw * w_stride]; // We use roi_bin_grid to sample the grid and mimic integral - int roi_bin_grid_h = (sampling_ratio > 0) + integer_t roi_bin_grid_h = (sampling_ratio > 0) ? 
sampling_ratio : ceil(roi_height / pooled_height); // e.g., = 2 - int roi_bin_grid_w = + integer_t roi_bin_grid_w = (sampling_ratio > 0) ? sampling_ratio : ceil(roi_width / pooled_width); // We do average (integral) pooling inside a bin const T count = roi_bin_grid_h * roi_bin_grid_w; // e.g. = 4 - const int input_offset = (roi_batch_ind * channels + c) * height * width; + const integer_t input_offset = (roi_batch_ind * channels + c) * height * width; - for (int iy = 0; iy < roi_bin_grid_h; iy++) // e.g., iy = 0, 1 + for (integer_t iy = 0; iy < roi_bin_grid_h; iy++) // e.g., iy = 0, 1 { const T y = roi_start_h + ph * bin_size_h + static_cast(iy + .5f) * bin_size_h / static_cast(roi_bin_grid_h); // e.g., 0.5, 1.5 - for (int ix = 0; ix < roi_bin_grid_w; ix++) { + for (integer_t ix = 0; ix < roi_bin_grid_w; ix++) { const T x = roi_start_w + pw * bin_size_w + static_cast(ix + .5f) * bin_size_w / static_cast(roi_bin_grid_w); T w1, w2, w3, w4; - int x_low, x_high, y_low, y_high; + integer_t x_low, x_high, y_low, y_high; bilinear_interpolate_gradient( height, @@ -467,10 +467,10 @@ kernel void roi_align_backward( } // MPS_1D_KERNEL_LOOP } -#define REGISTER_ROI_ALIGN_BACKWARD_OP(DTYPE) \ +#define REGISTER_ROI_ALIGN_BACKWARD_OP(DTYPE, INT_DTYPE) \ template \ [[host_name("roi_align_backward_" #DTYPE)]] \ -kernel void roi_align_backward( \ +kernel void roi_align_backward( \ constant DTYPE * grad_output [[buffer(0)]], \ constant DTYPE * rois [[buffer(1)]], \ device DTYPE * grad_input [[buffer(2)]], \ @@ -491,7 +491,7 @@ kernel void roi_align_backward( \ uint2 tptg [[threads_per_threadgroup]], \ uint2 tid2 [[thread_position_in_threadgroup]]); -template +template kernel void roi_pool( constant T * input [[buffer(0)]], constant T * rois [[buffer(1)]], @@ -509,45 +509,45 @@ kernel void roi_pool( uint2 tid2 [[thread_position_in_threadgroup]]){ MPS_1D_KERNEL_LOOP(index, output_size, 1) { // (n, c, ph, pw) is an element in the pooled output - int pw = index % pooled_width; - int ph = (index / pooled_width) % pooled_height; - int c = (index / pooled_width / pooled_height) % channels; - int n = index / pooled_width / pooled_height / channels; + integer_t pw = index % pooled_width; + integer_t ph = (index / pooled_width) % pooled_height; + integer_t c = (index / pooled_width / pooled_height) % channels; + integer_t n = index / pooled_width / pooled_height / channels; constant T* offset_rois = rois + n * 5; - int roi_batch_ind = offset_rois[0]; - int roi_start_w = round(offset_rois[1] * spatial_scale); - int roi_start_h = round(offset_rois[2] * spatial_scale); - int roi_end_w = round(offset_rois[3] * spatial_scale); - int roi_end_h = round(offset_rois[4] * spatial_scale); + integer_t roi_batch_ind = offset_rois[0]; + integer_t roi_start_w = round(offset_rois[1] * spatial_scale); + integer_t roi_start_h = round(offset_rois[2] * spatial_scale); + integer_t roi_end_w = round(offset_rois[3] * spatial_scale); + integer_t roi_end_h = round(offset_rois[4] * spatial_scale); // Force malformed ROIs to be 1x1 - int roi_width = max(roi_end_w - roi_start_w + 1, 1); - int roi_height = max(roi_end_h - roi_start_h + 1, 1); + integer_t roi_width = max(roi_end_w - roi_start_w + 1, static_cast(1)); + integer_t roi_height = max(roi_end_h - roi_start_h + 1, static_cast(1)); T bin_size_h = static_cast(roi_height) / static_cast(pooled_height); T bin_size_w = static_cast(roi_width) / static_cast(pooled_width); - int hstart = static_cast(floor(static_cast(ph) * bin_size_h)); - int wstart = static_cast(floor(static_cast(pw) * 
bin_size_w)); - int hend = static_cast(ceil(static_cast(ph + 1) * bin_size_h)); - int wend = static_cast(ceil(static_cast(pw + 1) * bin_size_w)); + integer_t hstart = static_cast(floor(static_cast(ph) * bin_size_h)); + integer_t wstart = static_cast(floor(static_cast(pw) * bin_size_w)); + integer_t hend = static_cast(ceil(static_cast(ph + 1) * bin_size_h)); + integer_t wend = static_cast(ceil(static_cast(pw + 1) * bin_size_w)); // Add roi offsets and clip to input boundaries - hstart = min(max(hstart + roi_start_h, 0), static_cast(height)); - hend = min(max(hend + roi_start_h, 0), static_cast(height)); - wstart = min(max(wstart + roi_start_w, 0), static_cast(width)); - wend = min(max(wend + roi_start_w, 0), static_cast(width)); + hstart = min(max(hstart + roi_start_h, static_cast(0)), static_cast(height)); + hend = min(max(hend + roi_start_h, static_cast(0)), static_cast(height)); + wstart = min(max(wstart + roi_start_w, static_cast(0)), static_cast(width)); + wend = min(max(wend + roi_start_w, static_cast(0)), static_cast(width)); bool is_empty = (hend <= hstart) || (wend <= wstart); // Define an empty pooling region to be zero T maxval = is_empty ? 0 : -FLT_MAX; // If nothing is pooled, argmax = -1 causes nothing to be backprop'd - int maxidx = -1; + integer_t maxidx = -1; constant T* offset_input = input + (roi_batch_ind * channels + c) * height * width; - for (int h = hstart; h < hend; ++h) { - for (int w = wstart; w < wend; ++w) { - int input_index = h * width + w; + for (integer_t h = hstart; h < hend; ++h) { + for (integer_t w = wstart; w < wend; ++w) { + integer_t input_index = h * width + w; if (offset_input[input_index] > maxval) { maxval = offset_input[input_index]; maxidx = input_index; @@ -559,10 +559,10 @@ kernel void roi_pool( } } -#define REGISTER_ROI_POOL_OP(DTYPE) \ +#define REGISTER_ROI_POOL_OP(DTYPE, INT_DTYPE) \ template \ [[host_name("roi_pool_" #DTYPE)]] \ -kernel void roi_pool( \ +kernel void roi_pool( \ constant DTYPE * input [[buffer(0)]], \ constant DTYPE * rois [[buffer(1)]], \ device DTYPE * output [[buffer(2)]], \ @@ -578,7 +578,7 @@ kernel void roi_pool( \ uint2 tptg [[threads_per_threadgroup]], \ uint2 tid2 [[thread_position_in_threadgroup]]); -template +template kernel void roi_pool_backward( constant T * grad_output [[buffer(0)]], constant T * rois [[buffer(1)]], @@ -601,19 +601,19 @@ kernel void roi_pool_backward( MPS_1D_KERNEL_LOOP(index, output_size, 1) { // (n, c, ph, pw) is an element in the pooled output - int pw = index % pooled_width; - int ph = (index / pooled_width) % pooled_height; - int c = (index / pooled_width / pooled_height) % channels; - int n = index / pooled_width / pooled_height / channels; + integer_t pw = index % pooled_width; + integer_t ph = (index / pooled_width) % pooled_height; + integer_t c = (index / pooled_width / pooled_height) % channels; + integer_t n = index / pooled_width / pooled_height / channels; constant T* offset_rois = rois + n * 5; - int roi_batch_ind = offset_rois[0]; + integer_t roi_batch_ind = offset_rois[0]; - const int output_offset = n * n_stride + c * c_stride; - constant int64_t * argmax_data_offset = + const integer_t output_offset = n * n_stride + c * c_stride; + constant integer_t * argmax_data_offset = argmax_data + (n * channels + c) * pooled_height * pooled_width; - const int argmax = argmax_data_offset[ph * pooled_width + pw]; - const int offset = (roi_batch_ind * channels + c) * height * width; + const integer_t argmax = argmax_data_offset[ph * pooled_width + pw]; + const integer_t offset = 
(roi_batch_ind * channels + c) * height * width; if (argmax != -1) { atomic_add_float(grad_input + offset + argmax, static_cast(grad_output[output_offset + ph * h_stride + pw * w_stride])); @@ -622,10 +622,10 @@ kernel void roi_pool_backward( } // MPS_1D_KERNEL_LOOP } -#define REGISTER_ROI_POOL_BACKWARD_OP(DTYPE) \ +#define REGISTER_ROI_POOL_BACKWARD_OP(DTYPE, INT_DTYPE) \ template \ [[host_name("roi_pool_backward_" #DTYPE)]] \ -kernel void roi_pool_backward( \ +kernel void roi_pool_backward( \ constant DTYPE * grad_output [[buffer(0)]], \ constant DTYPE * rois [[buffer(1)]], \ constant int64_t * argmax_data [[buffer(2)]], \ @@ -645,7 +645,7 @@ kernel void roi_pool_backward( \ uint2 tptg [[threads_per_threadgroup]], \ uint2 tid2 [[thread_position_in_threadgroup]]); -template +template kernel void ps_roi_align( constant T * input [[buffer(0)]], constant T * rois [[buffer(1)]], @@ -665,17 +665,17 @@ kernel void ps_roi_align( uint2 tid2 [[thread_position_in_threadgroup]]){ MPS_1D_KERNEL_LOOP(index, output_size, 1) { // (n, c_out, ph, pw) is an element in the pooled output - int pw = index % pooled_width; - int ph = (index / pooled_width) % pooled_height; - int c_out = (index / pooled_width / pooled_height) % channels_out; - int n = index / pooled_width / pooled_height / channels_out; + integer_t pw = index % pooled_width; + integer_t ph = (index / pooled_width) % pooled_height; + integer_t c_out = (index / pooled_width / pooled_height) % channels_out; + integer_t n = index / pooled_width / pooled_height / channels_out; // (n, c_in, ph, pw) is the associated element in the input - int c_in = (c_out * pooled_height + ph) * pooled_width + pw; + integer_t c_in = (c_out * pooled_height + ph) * pooled_width + pw; // [start, end) interval for spatial sampling constant T* offset_rois = rois + n * 5; - int roi_batch_ind = offset_rois[0]; + integer_t roi_batch_ind = offset_rois[0]; // Do not using rounding; this implementation detail is critical T roi_start_w = offset_rois[1] * spatial_scale - static_cast(0.5); @@ -693,21 +693,21 @@ kernel void ps_roi_align( T wstart = static_cast(pw) * bin_size_w + roi_start_w; // We use roi_bin_grid to sample the grid and mimic integral - int roi_bin_grid_h = (sampling_ratio > 0) + integer_t roi_bin_grid_h = (sampling_ratio > 0) ? sampling_ratio : ceil(roi_height / pooled_height); - int roi_bin_grid_w = + integer_t roi_bin_grid_w = (sampling_ratio > 0) ? 
sampling_ratio : ceil(roi_width / pooled_width); const T count = roi_bin_grid_h * roi_bin_grid_w; constant T* offset_input = input + (roi_batch_ind * channels + c_in) * height * width; T out_sum = 0; - for (int iy = 0; iy < roi_bin_grid_h; iy++) { + for (integer_t iy = 0; iy < roi_bin_grid_h; iy++) { const T y = hstart + static_cast(iy + .5f) * bin_size_h / static_cast(roi_bin_grid_h); - for (int ix = 0; ix < roi_bin_grid_w; ix++) { + for (integer_t ix = 0; ix < roi_bin_grid_w; ix++) { const T x = wstart + static_cast(ix + .5f) * bin_size_w / static_cast(roi_bin_grid_w); @@ -722,10 +722,10 @@ kernel void ps_roi_align( } } -#define REGISTER_PS_ROI_ALIGN_OP(DTYPE) \ +#define REGISTER_PS_ROI_ALIGN_OP(DTYPE, INT_DTYPE) \ template \ [[host_name("ps_roi_align_" #DTYPE)]] \ -kernel void ps_roi_align( \ +kernel void ps_roi_align( \ constant DTYPE * input [[buffer(0)]], \ constant DTYPE * rois [[buffer(1)]], \ device DTYPE * output [[buffer(2)]], \ @@ -743,7 +743,7 @@ kernel void ps_roi_align( \ uint2 tptg [[threads_per_threadgroup]], \ uint2 tid2 [[thread_position_in_threadgroup]]); -template +template kernel void ps_roi_align_backward( constant T * grad_output [[buffer(0)]], constant T * rois [[buffer(1)]], @@ -764,12 +764,12 @@ kernel void ps_roi_align_backward( MPS_1D_KERNEL_LOOP(index, output_size, 1) { // (n, *, ph, pw) is an element in the pooled output - int pw = index % pooled_width; - int ph = (index / pooled_width) % pooled_height; - int n = index / pooled_width / pooled_height / channels_out; + integer_t pw = index % pooled_width; + integer_t ph = (index / pooled_width) % pooled_height; + integer_t n = index / pooled_width / pooled_height / channels_out; constant T* offset_rois = rois + n * 5; - int roi_batch_ind = offset_rois[0]; + integer_t roi_batch_ind = offset_rois[0]; // Do not using rounding; this implementation detail is critical T roi_start_w = offset_rois[1] * spatial_scale - static_cast(0.5); @@ -783,7 +783,7 @@ kernel void ps_roi_align_backward( T bin_size_h = roi_height / static_cast(pooled_height); T bin_size_w = roi_width / static_cast(pooled_width); - int c_in = channel_mapping[index]; + integer_t c_in = channel_mapping[index]; // Do not using floor/ceil; this implementation detail is critical T hstart = static_cast(ph) * bin_size_h + roi_start_h; @@ -792,26 +792,26 @@ kernel void ps_roi_align_backward( const T grad_output_this_bin = grad_output[index]; // We use roi_bin_grid to sample the grid and mimic integral - int roi_bin_grid_h = (sampling_ratio > 0) + integer_t roi_bin_grid_h = (sampling_ratio > 0) ? sampling_ratio : ceil(roi_height / pooled_height); // e.g., = 2 - int roi_bin_grid_w = + integer_t roi_bin_grid_w = (sampling_ratio > 0) ? 
sampling_ratio : ceil(roi_width / pooled_width); const T count = roi_bin_grid_h * roi_bin_grid_w; - const int offset = (roi_batch_ind * channels + c_in) * height * width; + const integer_t offset = (roi_batch_ind * channels + c_in) * height * width; - for (int iy = 0; iy < roi_bin_grid_h; iy++) { + for (integer_t iy = 0; iy < roi_bin_grid_h; iy++) { const T y = hstart + static_cast(iy + .5f) * bin_size_h / static_cast(roi_bin_grid_h); - for (int ix = 0; ix < roi_bin_grid_w; ix++) { + for (integer_t ix = 0; ix < roi_bin_grid_w; ix++) { const T x = wstart + static_cast(ix + .5f) * bin_size_w / static_cast(roi_bin_grid_w); T w1, w2, w3, w4; - int x_low, x_high, y_low, y_high; + integer_t x_low, x_high, y_low, y_high; bilinear_interpolate_gradient( height, @@ -844,10 +844,10 @@ kernel void ps_roi_align_backward( } } -#define REGISTER_PS_ROI_ALIGN_BACKWARD_OP(DTYPE) \ +#define REGISTER_PS_ROI_ALIGN_BACKWARD_OP(DTYPE, INT_DTYPE) \ template \ [[host_name("ps_roi_align_backward_" #DTYPE)]] \ -kernel void ps_roi_align_backward( \ +kernel void ps_roi_align_backward( \ constant DTYPE * grad_output [[buffer(0)]], \ constant DTYPE * rois [[buffer(1)]], \ constant int64_t * channel_mapping [[buffer(2)]], \ @@ -865,7 +865,7 @@ kernel void ps_roi_align_backward( \ uint2 tptg [[threads_per_threadgroup]], \ uint2 tid2 [[thread_position_in_threadgroup]]); -template +template kernel void ps_roi_pool( constant T * input [[buffer(0)]], constant T * rois [[buffer(1)]], @@ -884,46 +884,46 @@ kernel void ps_roi_pool( uint2 tid2 [[thread_position_in_threadgroup]]){ MPS_1D_KERNEL_LOOP(index, output_size, 1) { // (n, c_out, ph, pw) is an element in the pooled output - int pw = index % pooled_width; - int ph = (index / pooled_width) % pooled_height; - int c_out = (index / (pooled_width * pooled_height)) % channels_out; - int n = index / pooled_width / pooled_height / channels_out; + integer_t pw = index % pooled_width; + integer_t ph = (index / pooled_width) % pooled_height; + integer_t c_out = (index / (pooled_width * pooled_height)) % channels_out; + integer_t n = index / pooled_width / pooled_height / channels_out; // (n, c_in, ph, pw) is the associated element in the input - int c_in = (c_out * pooled_height + ph) * pooled_width + pw; + integer_t c_in = (c_out * pooled_height + ph) * pooled_width + pw; // [start, end) interval for spatial sampling constant T* offset_rois = rois + n * 5; - int roi_batch_ind = offset_rois[0]; - int roi_start_w = round(offset_rois[1] * spatial_scale); - int roi_start_h = round(offset_rois[2] * spatial_scale); - int roi_end_w = round(offset_rois[3] * spatial_scale); - int roi_end_h = round(offset_rois[4] * spatial_scale); + integer_t roi_batch_ind = offset_rois[0]; + integer_t roi_start_w = round(offset_rois[1] * spatial_scale); + integer_t roi_start_h = round(offset_rois[2] * spatial_scale); + integer_t roi_end_w = round(offset_rois[3] * spatial_scale); + integer_t roi_end_h = round(offset_rois[4] * spatial_scale); // Force too small ROIs to be 1x1 - int roi_width = max(roi_end_w - roi_start_w, 1); - int roi_height = max(roi_end_h - roi_start_h, 1); + integer_t roi_width = max(roi_end_w - roi_start_w, static_cast(1)); + integer_t roi_height = max(roi_end_h - roi_start_h, static_cast(1)); T bin_size_h = static_cast(roi_height) / static_cast(pooled_height); T bin_size_w = static_cast(roi_width) / static_cast(pooled_width); - int hstart = static_cast(floor(static_cast(ph) * bin_size_h)); - int wstart = static_cast(floor(static_cast(pw) * bin_size_w)); - int hend = 
static_cast(ceil(static_cast(ph + 1) * bin_size_h)); - int wend = static_cast(ceil(static_cast(pw + 1) * bin_size_w)); + integer_t hstart = static_cast(floor(static_cast(ph) * bin_size_h)); + integer_t wstart = static_cast(floor(static_cast(pw) * bin_size_w)); + integer_t hend = static_cast(ceil(static_cast(ph + 1) * bin_size_h)); + integer_t wend = static_cast(ceil(static_cast(pw + 1) * bin_size_w)); // Add roi offsets and clip to input boundaries - hstart = min(max(hstart + roi_start_h, 0), static_cast(height - 1)); - hend = min(max(hend + roi_start_h, 0), static_cast(height - 1)); - wstart = min(max(wstart + roi_start_w, 0), static_cast(width - 1)); - wend = min(max(wend + roi_start_w, 0), static_cast(width - 1)); + hstart = min(max(hstart + roi_start_h, static_cast(0)), static_cast(height - 1)); + hend = min(max(hend + roi_start_h, static_cast(0)), static_cast(height - 1)); + wstart = min(max(wstart + roi_start_w, static_cast(0)), static_cast(width - 1)); + wend = min(max(wend + roi_start_w, static_cast(0)), static_cast(width - 1)); bool is_empty = (hend <= hstart) || (wend <= wstart); constant T* offset_input = input + (roi_batch_ind * channels + c_in) * height * width; T out_sum = 0; - for (int h = hstart; h < hend; ++h) { - for (int w = wstart; w < wend; ++w) { - int input_index = h * width + w; + for (integer_t h = hstart; h < hend; ++h) { + for (integer_t w = wstart; w < wend; ++w) { + integer_t input_index = h * width + w; out_sum += offset_input[input_index]; } } @@ -934,10 +934,10 @@ kernel void ps_roi_pool( } } -#define REGISTER_PS_ROI_POOL_OP(DTYPE) \ +#define REGISTER_PS_ROI_POOL_OP(DTYPE, INT_DTYPE) \ template \ [[host_name("ps_roi_pool_" #DTYPE)]] \ -kernel void ps_roi_pool( \ +kernel void ps_roi_pool( \ constant DTYPE * input [[buffer(0)]], \ constant DTYPE * rois [[buffer(1)]], \ device DTYPE * output [[buffer(2)]], \ @@ -954,7 +954,7 @@ kernel void ps_roi_pool( \ uint2 tptg [[threads_per_threadgroup]], \ uint2 tid2 [[thread_position_in_threadgroup]]); -template +template kernel void ps_roi_pool_backward( constant T * grad_output [[buffer(0)]], constant T * rois [[buffer(1)]], @@ -974,44 +974,44 @@ kernel void ps_roi_pool_backward( MPS_1D_KERNEL_LOOP(index, output_size, 1) { // (n, *, ph, pw) is an element in the pooled output - int pw = index % pooled_width; - int ph = (index / pooled_width) % pooled_height; - int n = index / pooled_width / pooled_height / channels_out; + integer_t pw = index % pooled_width; + integer_t ph = (index / pooled_width) % pooled_height; + integer_t n = index / pooled_width / pooled_height / channels_out; constant T* offset_rois = rois + n * 5; - int roi_batch_ind = offset_rois[0]; - int roi_start_w = round(offset_rois[1] * spatial_scale); - int roi_start_h = round(offset_rois[2] * spatial_scale); - int roi_end_w = round(offset_rois[3] * spatial_scale); - int roi_end_h = round(offset_rois[4] * spatial_scale); + integer_t roi_batch_ind = offset_rois[0]; + integer_t roi_start_w = round(offset_rois[1] * spatial_scale); + integer_t roi_start_h = round(offset_rois[2] * spatial_scale); + integer_t roi_end_w = round(offset_rois[3] * spatial_scale); + integer_t roi_end_h = round(offset_rois[4] * spatial_scale); // Force too small ROIs to be 1x1 - int roi_width = max(roi_end_w - roi_start_w, 1); - int roi_height = max(roi_end_h - roi_start_h, 1); + integer_t roi_width = max(roi_end_w - roi_start_w, static_cast(1)); + integer_t roi_height = max(roi_end_h - roi_start_h, static_cast(1)); T bin_size_h = static_cast(roi_height) / static_cast(pooled_height); 
T bin_size_w = static_cast(roi_width) / static_cast(pooled_width); - int hstart = static_cast(floor(static_cast(ph) * bin_size_h)); - int wstart = static_cast(floor(static_cast(pw) * bin_size_w)); - int hend = static_cast(ceil(static_cast(ph + 1) * bin_size_h)); - int wend = static_cast(ceil(static_cast(pw + 1) * bin_size_w)); + integer_t hstart = static_cast(floor(static_cast(ph) * bin_size_h)); + integer_t wstart = static_cast(floor(static_cast(pw) * bin_size_w)); + integer_t hend = static_cast(ceil(static_cast(ph + 1) * bin_size_h)); + integer_t wend = static_cast(ceil(static_cast(pw + 1) * bin_size_w)); // Add roi offsets and clip to input boundaries - hstart = min(max(hstart + roi_start_h, 0), static_cast(height)); - hend = min(max(hend + roi_start_h, 0), static_cast(height)); - wstart = min(max(wstart + roi_start_w, 0), static_cast(width)); - wend = min(max(wend + roi_start_w, 0), static_cast(width)); + hstart = min(max(hstart + roi_start_h, static_cast(0)), static_cast(height)); + hend = min(max(hend + roi_start_h, static_cast(0)), static_cast(height)); + wstart = min(max(wstart + roi_start_w, static_cast(0)), static_cast(width)); + wend = min(max(wend + roi_start_w, static_cast(0)), static_cast(width)); bool is_empty = (hend <= hstart) || (wend <= wstart); - int c_in = channel_mapping[index]; + integer_t c_in = channel_mapping[index]; T bin_area = (hend - hstart) * (wend - wstart); T diff_val = is_empty ? static_cast(0) : grad_output[index] / bin_area; - const int offset = (roi_batch_ind * channels + c_in) * height * width; + const integer_t offset = (roi_batch_ind * channels + c_in) * height * width; - for (int h = hstart; h < hend; ++h) { - for (int w = wstart; w < wend; ++w) { - int grad_input_index = h * width + w; + for (integer_t h = hstart; h < hend; ++h) { + for (integer_t w = wstart; w < wend; ++w) { + integer_t grad_input_index = h * width + w; atomic_add_float(grad_input + offset + grad_input_index, diff_val); } } @@ -1019,10 +1019,10 @@ kernel void ps_roi_pool_backward( } // MPS_1D_KERNEL_LOOP } -#define REGISTER_PS_ROI_POOL_BACKWARD_OP(DTYPE) \ +#define REGISTER_PS_ROI_POOL_BACKWARD_OP(DTYPE, INT_DTYPE) \ template \ [[host_name("ps_roi_pool_backward_" #DTYPE)]] \ -kernel void ps_roi_pool_backward( \ +kernel void ps_roi_pool_backward( \ constant DTYPE * grad_output [[buffer(0)]], \ constant DTYPE * rois [[buffer(1)]], \ constant int64_t * channel_mapping [[buffer(2)]], \ @@ -1041,22 +1041,22 @@ kernel void ps_roi_pool_backward( \ REGISTER_NMS_OP(float); REGISTER_NMS_OP(half); -REGISTER_ROI_ALIGN_OP(float); -REGISTER_ROI_ALIGN_OP(half); -REGISTER_ROI_ALIGN_BACKWARD_OP(float); -REGISTER_ROI_ALIGN_BACKWARD_OP(half); -REGISTER_ROI_POOL_OP(float); -REGISTER_ROI_POOL_OP(half); -REGISTER_ROI_POOL_BACKWARD_OP(float); -REGISTER_ROI_POOL_BACKWARD_OP(half); -REGISTER_PS_ROI_ALIGN_OP(float); -REGISTER_PS_ROI_ALIGN_OP(half); -REGISTER_PS_ROI_ALIGN_BACKWARD_OP(float); -REGISTER_PS_ROI_ALIGN_BACKWARD_OP(half); -REGISTER_PS_ROI_POOL_OP(float); -REGISTER_PS_ROI_POOL_OP(half); -REGISTER_PS_ROI_POOL_BACKWARD_OP(float); -REGISTER_PS_ROI_POOL_BACKWARD_OP(half); +REGISTER_ROI_ALIGN_OP(float, int64_t); +REGISTER_ROI_ALIGN_OP(half, int64_t); +REGISTER_ROI_ALIGN_BACKWARD_OP(float, int64_t); +REGISTER_ROI_ALIGN_BACKWARD_OP(half, int64_t); +REGISTER_ROI_POOL_OP(float, int64_t); +REGISTER_ROI_POOL_OP(half, int64_t); +REGISTER_ROI_POOL_BACKWARD_OP(float, int64_t); +REGISTER_ROI_POOL_BACKWARD_OP(half, int64_t); +REGISTER_PS_ROI_ALIGN_OP(float, int64_t); +REGISTER_PS_ROI_ALIGN_OP(half, int64_t); 
+REGISTER_PS_ROI_ALIGN_BACKWARD_OP(float, int64_t); +REGISTER_PS_ROI_ALIGN_BACKWARD_OP(half, int64_t); +REGISTER_PS_ROI_POOL_OP(float, int64_t); +REGISTER_PS_ROI_POOL_OP(half, int64_t); +REGISTER_PS_ROI_POOL_BACKWARD_OP(float, int64_t); +REGISTER_PS_ROI_POOL_BACKWARD_OP(half, int64_t); )VISION_METAL"; diff --git a/torchvision/csrc/ops/mps/nms_kernel.mm b/torchvision/csrc/ops/mps/nms_kernel.mm index 97064cf9aa3..c405ef11754 100644 --- a/torchvision/csrc/ops/mps/nms_kernel.mm +++ b/torchvision/csrc/ops/mps/nms_kernel.mm @@ -11,7 +11,7 @@ namespace { // This should be in sync with `nmsThreadsPerBlock` in the metal kernel. -constexpr int nmsThreadsPerBlock = sizeof(uint64_t) * 8; +constexpr int64_t nmsThreadsPerBlock = sizeof(uint64_t) * 8; at::Tensor nms_kernel(const at::Tensor& dets, const at::Tensor& scores, double iou_threshold) { using namespace at::native::mps; @@ -34,7 +34,7 @@ auto order_t = std::get<1>(scores.sort(/*stable=*/true, /*dim=*/0, /* descending=*/true)); auto dets_sorted = dets.index_select(0, order_t).contiguous(); - int dets_num = dets.size(0); + int64_t dets_num = dets.size(0); float iou_threshold_f = static_cast(iou_threshold); const int col_blocks = (dets_num + nmsThreadsPerBlock - 1) / nmsThreadsPerBlock; @@ -58,7 +58,7 @@ [computeEncoder setComputePipelineState:visionPSO]; [computeEncoder setBuffer:inputBuffer offset:dets_sorted.storage_offset() * dets_sorted.element_size() atIndex:0]; [computeEncoder setBuffer:outputBuffer offset:mask.storage_offset() * mask.element_size() atIndex:1]; - [computeEncoder setBytes:&dets_num length:sizeof(int) atIndex:2]; + [computeEncoder setBytes:&dets_num length:sizeof(int64_t) atIndex:2]; [computeEncoder setBytes:&iou_threshold_f length:sizeof(float) atIndex:3]; // A threadGroup is equivalent to a cuda's block. @@ -85,14 +85,14 @@ at::Tensor keep = at::empty({dets_num}, dets.options().dtype(at::kLong).device(at::kCPU)); int64_t* keep_out = keep.data_ptr(); - for (int i = 0; i < dets_num; i++) { - int nblock = i / nmsThreadsPerBlock; - int inblock = i % nmsThreadsPerBlock; + for (int64_t i = 0; i < dets_num; i++) { + int64_t nblock = i / nmsThreadsPerBlock; + int64_t inblock = i % nmsThreadsPerBlock; if (!(remv[nblock] & (1ULL << inblock))) { keep_out[num_to_keep++] = i; unsigned long long* p = mask_host + i * col_blocks; - for (int j = nblock; j < col_blocks; j++) { + for (int64_t j = nblock; j < col_blocks; j++) { remv[j] |= p[j]; } } From 40ebde5ca27a6f16e7dd8392c02edac1aa15a697 Mon Sep 17 00:00:00 2001 From: "Li-Huai (Allan) Lin" Date: Wed, 21 Jun 2023 00:39:54 +0800 Subject: [PATCH 21/28] Kernel improvements --- torchvision/csrc/ops/mps/mps_kernels.h | 319 +++++++++++++------------ 1 file changed, 161 insertions(+), 158 deletions(-) diff --git a/torchvision/csrc/ops/mps/mps_kernels.h b/torchvision/csrc/ops/mps/mps_kernels.h index f9589f646ea..e720a1608f1 100644 --- a/torchvision/csrc/ops/mps/mps_kernels.h +++ b/torchvision/csrc/ops/mps/mps_kernels.h @@ -19,7 +19,7 @@ using namespace metal; #define MPS_1D_KERNEL_LOOP(i, n, n_tgs) MPS_1D_KERNEL_LOOP_T(i, n, n_tgs, uint) -/*----------Utils----------*/ +/*----------Helpers--------*/ template inline T ceil_div(T n, T m) { @@ -27,15 +27,18 @@ inline T ceil_div(T n, T m) { } template -void atomic_add_float( device T* data_ptr, const T val) +inline void atomic_add_float( device T* data_ptr, const T val) { #if __METAL_VERSION__ >= 300 // atomic_float is supported in Metal 3 (macOS Ventura) onward. 
device atomic_fetch_add_explicit((device atomic_float*) data_ptr, val, memory_order_relaxed); #else + // Custom atomic addition implementation // https://github.com/ShoYamanishi/AppleNumericalComputing/blob/053f06c1f5a831095c4bcc29aaf11366fce5231e/03_dot/metal/dot.metal#L447-L472 // https://forums.developer.nvidia.com/t/atomicadd-float-float-atomicmul-float-float/14639 - // Create an atomic uint pointer for atomic checking. + // https://on-demand.gputechconf.com/gtc/2013/presentations/S3101-Atomic-Memory-Operations.pdf (See the last slide) + + // Create an atomic uint pointer for atomic transaction. device atomic_uint* atom_var = (device atomic_uint*)data_ptr; // Create necessary storage. uint fetched_uint, assigning_uint; @@ -46,7 +49,7 @@ void atomic_add_float( device T* data_ptr, const T val) // Read out the previous value as float. fetched_float = *( (thread T*) &fetched_uint ); - // Do addition and represent the addition result in uint for atomic checking. + // Do addition and represent the addition result in uint for atomic transaction. assigning_float = fetched_float + val; assigning_uint = *((thread uint*) &assigning_float); @@ -120,7 +123,7 @@ inline T bilinear_interpolate( } template -void bilinear_interpolate_gradient( +inline void bilinear_interpolate_gradient( integer_t height, integer_t width, T y, @@ -179,7 +182,7 @@ void bilinear_interpolate_gradient( } template -bool inline IoU( +inline bool IoU( constant T & a, threadgroup T & b, const float threshold) { @@ -333,24 +336,24 @@ kernel void roi_align( } } -#define REGISTER_ROI_ALIGN_OP(DTYPE, INT_DTYPE) \ -template \ -[[host_name("roi_align_" #DTYPE)]] \ -kernel void roi_align( \ - constant DTYPE * input [[buffer(0)]], \ - constant DTYPE * rois [[buffer(1)]], \ - device DTYPE * output [[buffer(2)]], \ - constant int64_t & output_size [[buffer(3)]], \ - constant int64_t & channels [[buffer(4)]], \ - constant int64_t & height [[buffer(5)]], \ - constant int64_t & width [[buffer(6)]], \ - constant int64_t & pooled_height [[buffer(7)]], \ - constant int64_t & pooled_width [[buffer(8)]], \ - constant int64_t & sampling_ratio [[buffer(9)]], \ - constant bool & aligned [[buffer(10)]], \ - constant float & spatial_scale [[buffer(11)]], \ - uint2 tgid [[threadgroup_position_in_grid]], \ - uint2 tptg [[threads_per_threadgroup]], \ +#define REGISTER_ROI_ALIGN_OP(DTYPE, INT_DTYPE) \ +template \ +[[host_name("roi_align_" #DTYPE)]] \ +kernel void roi_align( \ + constant DTYPE * input [[buffer(0)]], \ + constant DTYPE * rois [[buffer(1)]], \ + device DTYPE * output [[buffer(2)]], \ + constant int64_t & output_size [[buffer(3)]], \ + constant int64_t & channels [[buffer(4)]], \ + constant int64_t & height [[buffer(5)]], \ + constant int64_t & width [[buffer(6)]], \ + constant int64_t & pooled_height [[buffer(7)]], \ + constant int64_t & pooled_width [[buffer(8)]], \ + constant int64_t & sampling_ratio [[buffer(9)]], \ + constant bool & aligned [[buffer(10)]], \ + constant float & spatial_scale [[buffer(11)]], \ + uint2 tgid [[threadgroup_position_in_grid]], \ + uint2 tptg [[threads_per_threadgroup]], \ uint2 tid2 [[thread_position_in_threadgroup]]); template @@ -467,28 +470,28 @@ kernel void roi_align_backward( } // MPS_1D_KERNEL_LOOP } -#define REGISTER_ROI_ALIGN_BACKWARD_OP(DTYPE, INT_DTYPE) \ -template \ -[[host_name("roi_align_backward_" #DTYPE)]] \ -kernel void roi_align_backward( \ - constant DTYPE * grad_output [[buffer(0)]], \ - constant DTYPE * rois [[buffer(1)]], \ - device DTYPE * grad_input [[buffer(2)]], \ - constant int64_t & 
output_size [[buffer(3)]], \ - constant int64_t & channels [[buffer(4)]], \ - constant int64_t & height [[buffer(5)]], \ - constant int64_t & width [[buffer(6)]], \ - constant int64_t & pooled_height [[buffer(7)]], \ - constant int64_t & pooled_width [[buffer(8)]], \ - constant int64_t & sampling_ratio [[buffer(9)]], \ - constant bool & aligned [[buffer(10)]], \ - constant float & spatial_scale [[buffer(11)]], \ - constant int64_t & n_stride [[buffer(12)]], \ - constant int64_t & c_stride [[buffer(13)]], \ - constant int64_t & h_stride [[buffer(14)]], \ - constant int64_t & w_stride [[buffer(15)]], \ - uint2 tgid [[threadgroup_position_in_grid]], \ - uint2 tptg [[threads_per_threadgroup]], \ +#define REGISTER_ROI_ALIGN_BACKWARD_OP(DTYPE, INT_DTYPE) \ +template \ +[[host_name("roi_align_backward_" #DTYPE)]] \ +kernel void roi_align_backward( \ + constant DTYPE * grad_output [[buffer(0)]], \ + constant DTYPE * rois [[buffer(1)]], \ + device DTYPE * grad_input [[buffer(2)]], \ + constant int64_t & output_size [[buffer(3)]], \ + constant int64_t & channels [[buffer(4)]], \ + constant int64_t & height [[buffer(5)]], \ + constant int64_t & width [[buffer(6)]], \ + constant int64_t & pooled_height [[buffer(7)]], \ + constant int64_t & pooled_width [[buffer(8)]], \ + constant int64_t & sampling_ratio [[buffer(9)]], \ + constant bool & aligned [[buffer(10)]], \ + constant float & spatial_scale [[buffer(11)]], \ + constant int64_t & n_stride [[buffer(12)]], \ + constant int64_t & c_stride [[buffer(13)]], \ + constant int64_t & h_stride [[buffer(14)]], \ + constant int64_t & w_stride [[buffer(15)]], \ + uint2 tgid [[threadgroup_position_in_grid]], \ + uint2 tptg [[threads_per_threadgroup]], \ uint2 tid2 [[thread_position_in_threadgroup]]); template @@ -559,23 +562,23 @@ kernel void roi_pool( } } -#define REGISTER_ROI_POOL_OP(DTYPE, INT_DTYPE) \ -template \ -[[host_name("roi_pool_" #DTYPE)]] \ -kernel void roi_pool( \ - constant DTYPE * input [[buffer(0)]], \ - constant DTYPE * rois [[buffer(1)]], \ - device DTYPE * output [[buffer(2)]], \ - device int64_t * argmax_data [[buffer(3)]], \ - constant int64_t & output_size [[buffer(4)]], \ - constant int64_t & channels [[buffer(5)]], \ - constant int64_t & height [[buffer(6)]], \ - constant int64_t & width [[buffer(7)]], \ - constant int64_t & pooled_height [[buffer(8)]], \ - constant int64_t & pooled_width [[buffer(9)]], \ - constant float & spatial_scale [[buffer(10)]], \ - uint2 tgid [[threadgroup_position_in_grid]], \ - uint2 tptg [[threads_per_threadgroup]], \ +#define REGISTER_ROI_POOL_OP(DTYPE, INT_DTYPE) \ +template \ +[[host_name("roi_pool_" #DTYPE)]] \ +kernel void roi_pool( \ + constant DTYPE * input [[buffer(0)]], \ + constant DTYPE * rois [[buffer(1)]], \ + device DTYPE * output [[buffer(2)]], \ + device int64_t * argmax_data [[buffer(3)]], \ + constant int64_t & output_size [[buffer(4)]], \ + constant int64_t & channels [[buffer(5)]], \ + constant int64_t & height [[buffer(6)]], \ + constant int64_t & width [[buffer(7)]], \ + constant int64_t & pooled_height [[buffer(8)]], \ + constant int64_t & pooled_width [[buffer(9)]], \ + constant float & spatial_scale [[buffer(10)]], \ + uint2 tgid [[threadgroup_position_in_grid]], \ + uint2 tptg [[threads_per_threadgroup]], \ uint2 tid2 [[thread_position_in_threadgroup]]); template @@ -622,27 +625,27 @@ kernel void roi_pool_backward( } // MPS_1D_KERNEL_LOOP } -#define REGISTER_ROI_POOL_BACKWARD_OP(DTYPE, INT_DTYPE) \ -template \ -[[host_name("roi_pool_backward_" #DTYPE)]] \ -kernel void 
roi_pool_backward( \ - constant DTYPE * grad_output [[buffer(0)]], \ - constant DTYPE * rois [[buffer(1)]], \ - constant int64_t * argmax_data [[buffer(2)]], \ - device DTYPE * grad_input [[buffer(3)]], \ - constant int64_t & output_size [[buffer(4)]], \ - constant int64_t & channels [[buffer(5)]], \ - constant int64_t & height [[buffer(6)]], \ - constant int64_t & width [[buffer(7)]], \ - constant int64_t & pooled_height [[buffer(8)]], \ - constant int64_t & pooled_width [[buffer(9)]], \ - constant float & spatial_scale [[buffer(10)]], \ - constant int64_t & n_stride [[buffer(11)]], \ - constant int64_t & c_stride [[buffer(12)]], \ - constant int64_t & h_stride [[buffer(13)]], \ - constant int64_t & w_stride [[buffer(14)]], \ - uint2 tgid [[threadgroup_position_in_grid]], \ - uint2 tptg [[threads_per_threadgroup]], \ +#define REGISTER_ROI_POOL_BACKWARD_OP(DTYPE, INT_DTYPE) \ +template \ +[[host_name("roi_pool_backward_" #DTYPE)]] \ +kernel void roi_pool_backward( \ + constant DTYPE * grad_output [[buffer(0)]], \ + constant DTYPE * rois [[buffer(1)]], \ + constant int64_t * argmax_data [[buffer(2)]], \ + device DTYPE * grad_input [[buffer(3)]], \ + constant int64_t & output_size [[buffer(4)]], \ + constant int64_t & channels [[buffer(5)]], \ + constant int64_t & height [[buffer(6)]], \ + constant int64_t & width [[buffer(7)]], \ + constant int64_t & pooled_height [[buffer(8)]], \ + constant int64_t & pooled_width [[buffer(9)]], \ + constant float & spatial_scale [[buffer(10)]], \ + constant int64_t & n_stride [[buffer(11)]], \ + constant int64_t & c_stride [[buffer(12)]], \ + constant int64_t & h_stride [[buffer(13)]], \ + constant int64_t & w_stride [[buffer(14)]], \ + uint2 tgid [[threadgroup_position_in_grid]], \ + uint2 tptg [[threads_per_threadgroup]], \ uint2 tid2 [[thread_position_in_threadgroup]]); template @@ -722,25 +725,25 @@ kernel void ps_roi_align( } } -#define REGISTER_PS_ROI_ALIGN_OP(DTYPE, INT_DTYPE) \ -template \ -[[host_name("ps_roi_align_" #DTYPE)]] \ -kernel void ps_roi_align( \ - constant DTYPE * input [[buffer(0)]], \ - constant DTYPE * rois [[buffer(1)]], \ - device DTYPE * output [[buffer(2)]], \ - device int64_t * channel_mapping [[buffer(3)]], \ - constant int64_t & output_size [[buffer(4)]], \ - constant int64_t & channels [[buffer(5)]], \ - constant int64_t & height [[buffer(6)]], \ - constant int64_t & width [[buffer(7)]], \ - constant int64_t & pooled_height [[buffer(8)]], \ - constant int64_t & pooled_width [[buffer(9)]], \ - constant int64_t & sampling_ratio [[buffer(10)]], \ - constant int64_t & channels_out [[buffer(11)]], \ - constant float & spatial_scale [[buffer(12)]], \ - uint2 tgid [[threadgroup_position_in_grid]], \ - uint2 tptg [[threads_per_threadgroup]], \ +#define REGISTER_PS_ROI_ALIGN_OP(DTYPE, INT_DTYPE) \ +template \ +[[host_name("ps_roi_align_" #DTYPE)]] \ +kernel void ps_roi_align( \ + constant DTYPE * input [[buffer(0)]], \ + constant DTYPE * rois [[buffer(1)]], \ + device DTYPE * output [[buffer(2)]], \ + device int64_t * channel_mapping [[buffer(3)]], \ + constant int64_t & output_size [[buffer(4)]], \ + constant int64_t & channels [[buffer(5)]], \ + constant int64_t & height [[buffer(6)]], \ + constant int64_t & width [[buffer(7)]], \ + constant int64_t & pooled_height [[buffer(8)]], \ + constant int64_t & pooled_width [[buffer(9)]], \ + constant int64_t & sampling_ratio [[buffer(10)]], \ + constant int64_t & channels_out [[buffer(11)]], \ + constant float & spatial_scale [[buffer(12)]], \ + uint2 tgid [[threadgroup_position_in_grid]], 
\ + uint2 tptg [[threads_per_threadgroup]], \ uint2 tid2 [[thread_position_in_threadgroup]]); template @@ -844,25 +847,25 @@ kernel void ps_roi_align_backward( } } -#define REGISTER_PS_ROI_ALIGN_BACKWARD_OP(DTYPE, INT_DTYPE) \ -template \ -[[host_name("ps_roi_align_backward_" #DTYPE)]] \ -kernel void ps_roi_align_backward( \ - constant DTYPE * grad_output [[buffer(0)]], \ - constant DTYPE * rois [[buffer(1)]], \ - constant int64_t * channel_mapping [[buffer(2)]], \ - device DTYPE * grad_input [[buffer(3)]], \ - constant int64_t & output_size [[buffer(4)]], \ - constant int64_t & channels [[buffer(5)]], \ - constant int64_t & height [[buffer(6)]], \ - constant int64_t & width [[buffer(7)]], \ - constant int64_t & pooled_height [[buffer(8)]], \ - constant int64_t & pooled_width [[buffer(9)]], \ - constant int64_t & sampling_ratio [[buffer(10)]], \ - constant int64_t & channels_out [[buffer(11)]], \ - constant float & spatial_scale [[buffer(12)]], \ - uint2 tgid [[threadgroup_position_in_grid]], \ - uint2 tptg [[threads_per_threadgroup]], \ +#define REGISTER_PS_ROI_ALIGN_BACKWARD_OP(DTYPE, INT_DTYPE) \ +template \ +[[host_name("ps_roi_align_backward_" #DTYPE)]] \ +kernel void ps_roi_align_backward( \ + constant DTYPE * grad_output [[buffer(0)]], \ + constant DTYPE * rois [[buffer(1)]], \ + constant int64_t * channel_mapping [[buffer(2)]], \ + device DTYPE * grad_input [[buffer(3)]], \ + constant int64_t & output_size [[buffer(4)]], \ + constant int64_t & channels [[buffer(5)]], \ + constant int64_t & height [[buffer(6)]], \ + constant int64_t & width [[buffer(7)]], \ + constant int64_t & pooled_height [[buffer(8)]], \ + constant int64_t & pooled_width [[buffer(9)]], \ + constant int64_t & sampling_ratio [[buffer(10)]], \ + constant int64_t & channels_out [[buffer(11)]], \ + constant float & spatial_scale [[buffer(12)]], \ + uint2 tgid [[threadgroup_position_in_grid]], \ + uint2 tptg [[threads_per_threadgroup]], \ uint2 tid2 [[thread_position_in_threadgroup]]); template @@ -934,24 +937,24 @@ kernel void ps_roi_pool( } } -#define REGISTER_PS_ROI_POOL_OP(DTYPE, INT_DTYPE) \ -template \ -[[host_name("ps_roi_pool_" #DTYPE)]] \ -kernel void ps_roi_pool( \ - constant DTYPE * input [[buffer(0)]], \ - constant DTYPE * rois [[buffer(1)]], \ - device DTYPE * output [[buffer(2)]], \ - device int64_t * channel_mapping [[buffer(3)]], \ - constant int64_t & output_size [[buffer(4)]], \ - constant int64_t & channels [[buffer(5)]], \ - constant int64_t & height [[buffer(6)]], \ - constant int64_t & width [[buffer(7)]], \ - constant int64_t & pooled_height [[buffer(8)]], \ - constant int64_t & pooled_width [[buffer(9)]], \ - constant int64_t & channels_out [[buffer(10)]], \ - constant float & spatial_scale [[buffer(11)]], \ - uint2 tgid [[threadgroup_position_in_grid]], \ - uint2 tptg [[threads_per_threadgroup]], \ +#define REGISTER_PS_ROI_POOL_OP(DTYPE, INT_DTYPE) \ +template \ +[[host_name("ps_roi_pool_" #DTYPE)]] \ +kernel void ps_roi_pool( \ + constant DTYPE * input [[buffer(0)]], \ + constant DTYPE * rois [[buffer(1)]], \ + device DTYPE * output [[buffer(2)]], \ + device int64_t * channel_mapping [[buffer(3)]], \ + constant int64_t & output_size [[buffer(4)]], \ + constant int64_t & channels [[buffer(5)]], \ + constant int64_t & height [[buffer(6)]], \ + constant int64_t & width [[buffer(7)]], \ + constant int64_t & pooled_height [[buffer(8)]], \ + constant int64_t & pooled_width [[buffer(9)]], \ + constant int64_t & channels_out [[buffer(10)]], \ + constant float & spatial_scale [[buffer(11)]], \ + uint2 
tgid [[threadgroup_position_in_grid]], \ + uint2 tptg [[threads_per_threadgroup]], \ uint2 tid2 [[thread_position_in_threadgroup]]); template @@ -1019,24 +1022,24 @@ kernel void ps_roi_pool_backward( } // MPS_1D_KERNEL_LOOP } -#define REGISTER_PS_ROI_POOL_BACKWARD_OP(DTYPE, INT_DTYPE) \ -template \ -[[host_name("ps_roi_pool_backward_" #DTYPE)]] \ -kernel void ps_roi_pool_backward( \ - constant DTYPE * grad_output [[buffer(0)]], \ - constant DTYPE * rois [[buffer(1)]], \ - constant int64_t * channel_mapping [[buffer(2)]], \ - device DTYPE * grad_input [[buffer(3)]], \ - constant int64_t & output_size [[buffer(4)]], \ - constant int64_t & channels [[buffer(5)]], \ - constant int64_t & height [[buffer(6)]], \ - constant int64_t & width [[buffer(7)]], \ - constant int64_t & pooled_height [[buffer(8)]], \ - constant int64_t & pooled_width [[buffer(9)]], \ - constant int64_t & channels_out [[buffer(10)]], \ - constant float & spatial_scale [[buffer(11)]], \ - uint2 tgid [[threadgroup_position_in_grid]], \ - uint2 tptg [[threads_per_threadgroup]], \ +#define REGISTER_PS_ROI_POOL_BACKWARD_OP(DTYPE, INT_DTYPE) \ +template \ +[[host_name("ps_roi_pool_backward_" #DTYPE)]] \ +kernel void ps_roi_pool_backward( \ + constant DTYPE * grad_output [[buffer(0)]], \ + constant DTYPE * rois [[buffer(1)]], \ + constant int64_t * channel_mapping [[buffer(2)]], \ + device DTYPE * grad_input [[buffer(3)]], \ + constant int64_t & output_size [[buffer(4)]], \ + constant int64_t & channels [[buffer(5)]], \ + constant int64_t & height [[buffer(6)]], \ + constant int64_t & width [[buffer(7)]], \ + constant int64_t & pooled_height [[buffer(8)]], \ + constant int64_t & pooled_width [[buffer(9)]], \ + constant int64_t & channels_out [[buffer(10)]], \ + constant float & spatial_scale [[buffer(11)]], \ + uint2 tgid [[threadgroup_position_in_grid]], \ + uint2 tptg [[threads_per_threadgroup]], \ uint2 tid2 [[thread_position_in_threadgroup]]); REGISTER_NMS_OP(float); From efbb52e27ca792d6e943111f910c1acbc5985a5a Mon Sep 17 00:00:00 2001 From: "Li-Huai (Allan) Lin" Date: Tue, 4 Jul 2023 18:07:31 +0800 Subject: [PATCH 22/28] Apply suggestions from code review --- setup.py | 1 - test/common_utils.py | 1 - test/conftest.py | 3 --- 3 files changed, 5 deletions(-) diff --git a/setup.py b/setup.py index 8124783feb5..cd41081142d 100644 --- a/setup.py +++ b/setup.py @@ -207,7 +207,6 @@ def get_extensions(): extra_compile_args["nvcc"] = nvcc_flags elif torch.backends.mps.is_available() or force_mps: sources += source_mps - define_macros += [("WITH_MPS", None)] if sys.platform == "win32": define_macros += [("torchvision_EXPORTS", None)] diff --git a/test/common_utils.py b/test/common_utils.py index c19f1c05070..32f36cf5a21 100644 --- a/test/common_utils.py +++ b/test/common_utils.py @@ -36,7 +36,6 @@ CUDA_NOT_AVAILABLE_MSG = "CUDA device not available" MPS_NOT_AVAILABLE_MSG = "MPS device not available" OSS_CI_GPU_NO_CUDA_MSG = "We're in an OSS GPU machine, and this test doesn't need cuda." -OSS_CI_GPU_NO_MPS_MSG = "We're in an OSS M1 machine, and this test doesn't need mps." @contextlib.contextmanager diff --git a/test/conftest.py b/test/conftest.py index 819ca7ce229..4c5c97fb7af 100644 --- a/test/conftest.py +++ b/test/conftest.py @@ -15,7 +15,6 @@ IN_RE_WORKER, MPS_NOT_AVAILABLE_MSG, OSS_CI_GPU_NO_CUDA_MSG, - OSS_CI_GPU_NO_MPS_MSG, ) @@ -73,8 +72,6 @@ def pytest_collection_modifyitems(items): # Similar to what happens in RE workers: we don't need the OSS CI GPU machines # to run the CPU-only tests. 
item.add_marker(pytest.mark.skip(reason=OSS_CI_GPU_NO_CUDA_MSG)) - if not needs_mps and torch.backends.mps.is_available(): - item.add_marker(pytest.mark.skip(reason=OSS_CI_GPU_NO_MPS_MSG)) if item.get_closest_marker("dont_collect") is not None: # currently, this is only used for some tests we're sure we don't want to run on fbcode From b36cafaa0661246663483b9deabcbdd0d870106d Mon Sep 17 00:00:00 2001 From: "Li-Huai (Allan) Lin" Date: Tue, 4 Jul 2023 18:39:30 +0800 Subject: [PATCH 23/28] Test more dtypes for roi forward functions and assert half inputs in MPS roi backward kernels --- test/test_ops.py | 65 ++++++++++--------- .../csrc/ops/mps/ps_roi_align_kernel.mm | 1 + .../csrc/ops/mps/ps_roi_pool_kernel.mm | 1 + torchvision/csrc/ops/mps/roi_align_kernel.mm | 1 + torchvision/csrc/ops/mps/roi_pool_kernel.mm | 1 + torchvision/ops/roi_align.py | 4 +- 6 files changed, 41 insertions(+), 32 deletions(-) diff --git a/test/test_ops.py b/test/test_ops.py index a6e851e9943..ac9a43cce0f 100644 --- a/test/test_ops.py +++ b/test/test_ops.py @@ -101,19 +101,35 @@ class RoIOpTester(ABC): @pytest.mark.parametrize("device", cpu_and_cuda_and_mps()) @pytest.mark.parametrize("contiguous", (True, False)) - def test_forward(self, device, contiguous, x_dtype=None, rois_dtype=None, deterministic=False, **kwargs): - dtype = self.mps_dtype if device == "mps" else self.dtype - x_dtype = dtype if x_dtype is None else x_dtype - rois_dtype = dtype if rois_dtype is None else rois_dtype + @pytest.mark.parametrize( + "dtype", + ( + torch.float16, + torch.float32, + torch.float64, + ), + ids=str, + ) + def test_forward(self, device, contiguous, dtype, deterministic=False, **kwargs): + if device == "mps" and dtype is torch.float64: + pytest.skip("MPS does not support float64") + + tol = 1e-5 + if dtype is torch.half: + if device == "mps": + tol = 5e-3 + else: + tol = 4e-3 + pool_size = 5 # n_channels % (pool_size ** 2) == 0 required for PS operations. 
n_channels = 2 * (pool_size**2) - x = torch.rand(2, n_channels, 10, 10, dtype=x_dtype, device=device) + x = torch.rand(2, n_channels, 10, 10, dtype=dtype, device=device) if not contiguous: x = x.permute(0, 1, 3, 2) rois = torch.tensor( [[0, 0, 0, 9, 9], [0, 0, 5, 4, 9], [0, 5, 5, 9, 9], [1, 0, 0, 9, 9]], # format is (xyxy) - dtype=rois_dtype, + dtype=dtype, device=device, ) @@ -126,7 +142,6 @@ def test_forward(self, device, contiguous, x_dtype=None, rois_dtype=None, determ x, rois, pool_h, pool_w, spatial_scale=1, sampling_ratio=-1, device=device, dtype=dtype, **kwargs ) - tol = 1e-3 if (x_dtype is torch.half or rois_dtype is torch.half) else 1e-5 torch.testing.assert_close(gt_y.to(y), y, rtol=tol, atol=tol) @pytest.mark.parametrize("device", cpu_and_cuda()) @@ -429,17 +444,17 @@ def test_boxes_shape(self): @pytest.mark.parametrize("aligned", (True, False)) @pytest.mark.parametrize("device", cpu_and_cuda_and_mps()) + @pytest.mark.parametrize("dtype", (torch.float16, torch.float32, torch.float64), ids=str) @pytest.mark.parametrize("contiguous", (True, False)) @pytest.mark.parametrize("deterministic", (True, False)) - def test_forward(self, device, contiguous, deterministic, aligned, x_dtype=None, rois_dtype=None): + def test_forward(self, device, contiguous, deterministic, aligned, dtype): if deterministic and device == "cpu": pytest.skip("cpu is always deterministic, don't retest") super().test_forward( device=device, contiguous=contiguous, deterministic=deterministic, - x_dtype=x_dtype, - rois_dtype=rois_dtype, + dtype=dtype, aligned=aligned, ) @@ -759,32 +774,22 @@ def test_autocast(self, iou, dtype): with torch.cuda.amp.autocast(): self.test_nms_cuda(iou=iou, dtype=dtype) - @needs_cuda - def test_nms_cuda_float16(self): - boxes = torch.tensor( - [ - [285.3538, 185.5758, 1193.5110, 851.4551], - [285.1472, 188.7374, 1192.4984, 851.0669], - [279.2440, 197.9812, 1189.4746, 849.2019], - ] - ).cuda() - scores = torch.tensor([0.6370, 0.7569, 0.3966]).cuda() - - iou_thres = 0.2 - keep32 = ops.nms(boxes, scores, iou_thres) - keep16 = ops.nms(boxes.to(torch.float16), scores.to(torch.float16), iou_thres) - assert_equal(keep32, keep16) - - @needs_mps - def test_nms_mps_float16(self): + @pytest.mark.parametrize( + "device", + ( + pytest.param("cuda", marks=pytest.mark.needs_cuda), + pytest.param("mps", marks=pytest.mark.needs_mps), + ), + ) + def test_nms_float16(self, device): boxes = torch.tensor( [ [285.3538, 185.5758, 1193.5110, 851.4551], [285.1472, 188.7374, 1192.4984, 851.0669], [279.2440, 197.9812, 1189.4746, 849.2019], ] - ).to("mps") - scores = torch.tensor([0.6370, 0.7569, 0.3966]).to("mps") + ).to(device) + scores = torch.tensor([0.6370, 0.7569, 0.3966]).to(device) iou_thres = 0.2 keep32 = ops.nms(boxes, scores, iou_thres) diff --git a/torchvision/csrc/ops/mps/ps_roi_align_kernel.mm b/torchvision/csrc/ops/mps/ps_roi_align_kernel.mm index 1e0a2d902ee..e3fae9bc818 100644 --- a/torchvision/csrc/ops/mps/ps_roi_align_kernel.mm +++ b/torchvision/csrc/ops/mps/ps_roi_align_kernel.mm @@ -119,6 +119,7 @@ using namespace at::native::mps; TORCH_CHECK(grad.is_mps(), "grad must be a MPS tensor"); TORCH_CHECK(rois.is_mps(), "rois must be a MPS tensor"); + TORCH_CHECK(grad.scalar_type() != at::kHalf, "MPS does not support ps_roi_align backward with float16 inputs."); TORCH_CHECK(channel_mapping.is_mps(), "channel_mapping must be a MPS tensor"); at::TensorArg grad_t{grad, "input", 1}, rois_t{rois, "rois", 2}, diff --git a/torchvision/csrc/ops/mps/ps_roi_pool_kernel.mm 
b/torchvision/csrc/ops/mps/ps_roi_pool_kernel.mm index 070e1fee6f4..f0592030264 100644 --- a/torchvision/csrc/ops/mps/ps_roi_pool_kernel.mm +++ b/torchvision/csrc/ops/mps/ps_roi_pool_kernel.mm @@ -114,6 +114,7 @@ using namespace at::native::mps; TORCH_CHECK(grad.is_mps(), "grad must be a MPS tensor"); TORCH_CHECK(rois.is_mps(), "rois must be a MPS tensor"); + TORCH_CHECK(grad.scalar_type() != at::kHalf, "MPS does not support ps_roi_pool backward with float16 inputs."); TORCH_CHECK(channel_mapping.is_mps(), "channel_mapping must be a MPS tensor"); at::TensorArg grad_t{grad, "grad", 1}, rois_t{rois, "rois", 2}, diff --git a/torchvision/csrc/ops/mps/roi_align_kernel.mm b/torchvision/csrc/ops/mps/roi_align_kernel.mm index cd6483847a1..5a0e93c43a9 100644 --- a/torchvision/csrc/ops/mps/roi_align_kernel.mm +++ b/torchvision/csrc/ops/mps/roi_align_kernel.mm @@ -110,6 +110,7 @@ using namespace at::native::mps; TORCH_CHECK(grad.is_mps(), "grad must be a MPS tensor"); TORCH_CHECK(rois.is_mps(), "rois must be a MPS tensor"); + TORCH_CHECK(grad.scalar_type() != at::kHalf, "MPS does not support roi_align backward with float16 inputs."); at::TensorArg grad_t{grad, "input", 1}, rois_t{rois, "rois", 2}; diff --git a/torchvision/csrc/ops/mps/roi_pool_kernel.mm b/torchvision/csrc/ops/mps/roi_pool_kernel.mm index aa01e4f4fc0..2c42ed6056a 100644 --- a/torchvision/csrc/ops/mps/roi_pool_kernel.mm +++ b/torchvision/csrc/ops/mps/roi_pool_kernel.mm @@ -108,6 +108,7 @@ using namespace at::native::mps; TORCH_CHECK(grad.is_mps(), "grad must be a MPS tensor"); TORCH_CHECK(rois.is_mps(), "rois must be a MPS tensor"); + TORCH_CHECK(grad.scalar_type() != at::kHalf, "MPS does not support roi_pool backward with float16 inputs."); TORCH_CHECK(argmax.is_mps(), "argmax must be a MPS tensor"); at::TensorArg grad_t{grad, "input", 1}, rois_t{rois, "rois", 2}, argmax_t{argmax, "argmax", 3}; diff --git a/torchvision/ops/roi_align.py b/torchvision/ops/roi_align.py index 3e839b05526..0d505c140ee 100644 --- a/torchvision/ops/roi_align.py +++ b/torchvision/ops/roi_align.py @@ -158,12 +158,12 @@ def from_K(t): y = ( from_K(roi_start_h) + ph[None, :, None] * from_K(bin_size_h) - + (iy[None, None, :] + 0.5) * from_K(bin_size_h / roi_bin_grid_h) + + (iy[None, None, :] + 0.5).to(input.dtype) * from_K(bin_size_h / roi_bin_grid_h) ) # [K, PH, IY] x = ( from_K(roi_start_w) + pw[None, :, None] * from_K(bin_size_w) - + (ix[None, None, :] + 0.5) * from_K(bin_size_w / roi_bin_grid_w) + + (ix[None, None, :] + 0.5).to(input.dtype) * from_K(bin_size_w / roi_bin_grid_w) ) # [K, PW, IX] val = _bilinear_interpolate(input, roi_batch_ind, y, x, ymask, xmask) # [K, C, PH, PW, IY, IX] From fad54f685cfa5fadf2c847ad3fc6f4ec388e081e Mon Sep 17 00:00:00 2001 From: "Li-Huai (Allan) Lin" Date: Fri, 7 Jul 2023 16:07:47 +0800 Subject: [PATCH 24/28] Add mps error inputs check --- test/test_ops.py | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/test/test_ops.py b/test/test_ops.py index ac9a43cce0f..4dc6e82bd86 100644 --- a/test/test_ops.py +++ b/test/test_ops.py @@ -198,6 +198,22 @@ def func(z): gradcheck(script_func, (x,), atol=atol) + @needs_mps + def test_mps_error_inputs(self): + pool_size = 2 + x = torch.rand(1, 2 * (pool_size**2), 5, 5, dtype=torch.float16, device="mps", requires_grad=True) + rois = torch.tensor( + [[0, 0, 0, 4, 4], [0, 0, 2, 3, 4], [0, 2, 2, 4, 4]], dtype=torch.float16, device="mps" # format is (xyxy) + ) + + def func(z): + return self.fn(z, rois, pool_size, pool_size, spatial_scale=1, sampling_ratio=1) + + with 
pytest.raises( + RuntimeError, match="MPS does not support (?:ps_)?roi_(?:align|pool)? backward with float16 inputs." + ): + gradcheck(func, (x,)) + @needs_cuda @pytest.mark.parametrize("x_dtype", (torch.float, torch.half)) @pytest.mark.parametrize("rois_dtype", (torch.float, torch.half)) From 66a00fc7d7eb19ac6e440dc44ccee090319327e5 Mon Sep 17 00:00:00 2001 From: "Li-Huai (Allan) Lin" Date: Tue, 11 Jul 2023 00:43:45 +0800 Subject: [PATCH 25/28] Clean up headers --- torchvision/csrc/ops/mps/nms_kernel.mm | 3 --- torchvision/csrc/ops/mps/ps_roi_align_kernel.mm | 3 --- torchvision/csrc/ops/mps/ps_roi_pool_kernel.mm | 3 --- torchvision/csrc/ops/mps/roi_align_kernel.mm | 3 --- torchvision/csrc/ops/mps/roi_pool_kernel.mm | 3 --- 5 files changed, 15 deletions(-) diff --git a/torchvision/csrc/ops/mps/nms_kernel.mm b/torchvision/csrc/ops/mps/nms_kernel.mm index c405ef11754..5ee9b5cbeae 100644 --- a/torchvision/csrc/ops/mps/nms_kernel.mm +++ b/torchvision/csrc/ops/mps/nms_kernel.mm @@ -2,9 +2,6 @@ #include #include "mps_kernels.h" -#include -#include - namespace vision { namespace ops { diff --git a/torchvision/csrc/ops/mps/ps_roi_align_kernel.mm b/torchvision/csrc/ops/mps/ps_roi_align_kernel.mm index e3fae9bc818..16b711ad5ef 100644 --- a/torchvision/csrc/ops/mps/ps_roi_align_kernel.mm +++ b/torchvision/csrc/ops/mps/ps_roi_align_kernel.mm @@ -3,9 +3,6 @@ #include "mps_helpers.h" #include "mps_kernels.h" -#include -#include - namespace vision { namespace ops { diff --git a/torchvision/csrc/ops/mps/ps_roi_pool_kernel.mm b/torchvision/csrc/ops/mps/ps_roi_pool_kernel.mm index f0592030264..fc24f6990fa 100644 --- a/torchvision/csrc/ops/mps/ps_roi_pool_kernel.mm +++ b/torchvision/csrc/ops/mps/ps_roi_pool_kernel.mm @@ -3,9 +3,6 @@ #include "mps_helpers.h" #include "mps_kernels.h" -#include -#include - namespace vision { namespace ops { diff --git a/torchvision/csrc/ops/mps/roi_align_kernel.mm b/torchvision/csrc/ops/mps/roi_align_kernel.mm index 5a0e93c43a9..d4ed8b43fd2 100644 --- a/torchvision/csrc/ops/mps/roi_align_kernel.mm +++ b/torchvision/csrc/ops/mps/roi_align_kernel.mm @@ -3,9 +3,6 @@ #include "mps_helpers.h" #include "mps_kernels.h" -#include -#include - namespace vision { namespace ops { diff --git a/torchvision/csrc/ops/mps/roi_pool_kernel.mm b/torchvision/csrc/ops/mps/roi_pool_kernel.mm index 2c42ed6056a..816d8d70863 100644 --- a/torchvision/csrc/ops/mps/roi_pool_kernel.mm +++ b/torchvision/csrc/ops/mps/roi_pool_kernel.mm @@ -3,9 +3,6 @@ #include "mps_helpers.h" #include "mps_kernels.h" -#include -#include - namespace vision { namespace ops { From 70f3906f76b890b2ad268b4d746a5677a6b1a9cf Mon Sep 17 00:00:00 2001 From: "Li-Huai (Allan) Lin" Date: Mon, 17 Jul 2023 16:47:58 +0800 Subject: [PATCH 26/28] Fix dtype parameters --- test/test_ops.py | 23 +++++++++++++---------- 1 file changed, 13 insertions(+), 10 deletions(-) diff --git a/test/test_ops.py b/test/test_ops.py index 4dc6e82bd86..3ddc1706296 100644 --- a/test/test_ops.py +++ b/test/test_ops.py @@ -102,7 +102,7 @@ class RoIOpTester(ABC): @pytest.mark.parametrize("device", cpu_and_cuda_and_mps()) @pytest.mark.parametrize("contiguous", (True, False)) @pytest.mark.parametrize( - "dtype", + "x_dtype", ( torch.float16, torch.float32, @@ -110,12 +110,14 @@ class RoIOpTester(ABC): ), ids=str, ) - def test_forward(self, device, contiguous, dtype, deterministic=False, **kwargs): - if device == "mps" and dtype is torch.float64: + def test_forward(self, device, contiguous, x_dtype, rois_dtype=None, deterministic=False, **kwargs): + if device == "mps" 
and x_dtype is torch.float64: pytest.skip("MPS does not support float64") + rois_dtype = x_dtype if rois_dtype is None else rois_dtype + tol = 1e-5 - if dtype is torch.half: + if x_dtype is torch.half: if device == "mps": tol = 5e-3 else: @@ -124,12 +126,12 @@ def test_forward(self, device, contiguous, dtype, deterministic=False, **kwargs) pool_size = 5 # n_channels % (pool_size ** 2) == 0 required for PS operations. n_channels = 2 * (pool_size**2) - x = torch.rand(2, n_channels, 10, 10, dtype=dtype, device=device) + x = torch.rand(2, n_channels, 10, 10, dtype=x_dtype, device=device) if not contiguous: x = x.permute(0, 1, 3, 2) rois = torch.tensor( [[0, 0, 0, 9, 9], [0, 0, 5, 4, 9], [0, 5, 5, 9, 9], [1, 0, 0, 9, 9]], # format is (xyxy) - dtype=dtype, + dtype=rois_dtype, device=device, ) @@ -139,7 +141,7 @@ def test_forward(self, device, contiguous, dtype, deterministic=False, **kwargs) # the following should be true whether we're running an autocast test or not. assert y.dtype == x.dtype gt_y = self.expected_fn( - x, rois, pool_h, pool_w, spatial_scale=1, sampling_ratio=-1, device=device, dtype=dtype, **kwargs + x, rois, pool_h, pool_w, spatial_scale=1, sampling_ratio=-1, device=device, dtype=x_dtype, **kwargs ) torch.testing.assert_close(gt_y.to(y), y, rtol=tol, atol=tol) @@ -460,17 +462,18 @@ def test_boxes_shape(self): @pytest.mark.parametrize("aligned", (True, False)) @pytest.mark.parametrize("device", cpu_and_cuda_and_mps()) - @pytest.mark.parametrize("dtype", (torch.float16, torch.float32, torch.float64), ids=str) + @pytest.mark.parametrize("x_dtype", (torch.float16, torch.float32, torch.float64), ids=str) @pytest.mark.parametrize("contiguous", (True, False)) @pytest.mark.parametrize("deterministic", (True, False)) - def test_forward(self, device, contiguous, deterministic, aligned, dtype): + def test_forward(self, device, contiguous, deterministic, aligned, x_dtype, rois_dtype=None): if deterministic and device == "cpu": pytest.skip("cpu is always deterministic, don't retest") super().test_forward( device=device, contiguous=contiguous, deterministic=deterministic, - dtype=dtype, + x_dtype=x_dtype, + rois_dtype=rois_dtype, aligned=aligned, ) From b1cf61975f531ebafa074cae9442289f147d0e02 Mon Sep 17 00:00:00 2001 From: "Li-Huai (Allan) Lin" Date: Mon, 17 Jul 2023 17:07:40 +0800 Subject: [PATCH 27/28] parameterize nms gpu test --- test/test_ops.py | 37 +++++++++++++------------------------ 1 file changed, 13 insertions(+), 24 deletions(-) diff --git a/test/test_ops.py b/test/test_ops.py index 3ddc1706296..743fe159e37 100644 --- a/test/test_ops.py +++ b/test/test_ops.py @@ -751,39 +751,28 @@ def test_qnms(self, iou, scale, zero_point): torch.testing.assert_close(qkeep, keep, msg=err_msg.format(iou)) - @needs_cuda + @pytest.mark.parametrize( + "device", + ( + pytest.param("cuda", marks=pytest.mark.needs_cuda), + pytest.param("mps", marks=pytest.mark.needs_mps), + ), + ) @pytest.mark.parametrize("iou", (0.2, 0.5, 0.8)) - def test_nms_cuda(self, iou, dtype=torch.float64): + def test_nms_gpu(self, iou, device, dtype=torch.float64): + dtype = torch.float32 if device == "mps" else dtype tol = 1e-3 if dtype is torch.half else 1e-5 err_msg = "NMS incompatible between CPU and CUDA for IoU={}" boxes, scores = self._create_tensors_with_iou(1000, iou) r_cpu = ops.nms(boxes, scores, iou) - r_cuda = ops.nms(boxes.cuda(), scores.cuda(), iou) - - is_eq = torch.allclose(r_cpu, r_cuda.cpu()) - if not is_eq: - # if the indices are not the same, ensure that it's because the scores - # are duplicate - is_eq 
= torch.allclose(scores[r_cpu], scores[r_cuda.cpu()], rtol=tol, atol=tol) - assert is_eq, err_msg.format(iou) - - @needs_mps - @pytest.mark.parametrize("iou", (0.2, 0.5, 0.8)) - def test_nms_mps(self, iou, dtype=torch.float32): - tol = 1e-3 if dtype is torch.half else 1e-5 - err_msg = "NMS incompatible between CPU and MPS for IoU={}" - - boxes, scores = self._create_tensors_with_iou(1000, iou) - r_cpu = ops.nms(boxes, scores, iou) - r_mps = ops.nms(boxes.to("mps"), scores.to("mps"), iou) + r_gpu = ops.nms(boxes.to(device), scores.to(device), iou) - print(r_cpu.size(), r_mps.size()) - is_eq = torch.allclose(r_cpu, r_mps.cpu()) + is_eq = torch.allclose(r_cpu, r_gpu.cpu()) if not is_eq: # if the indices are not the same, ensure that it's because the scores # are duplicate - is_eq = torch.allclose(scores[r_cpu], scores[r_mps.cpu()], rtol=tol, atol=tol) + is_eq = torch.allclose(scores[r_cpu], scores[r_gpu.cpu()], rtol=tol, atol=tol) assert is_eq, err_msg.format(iou) @needs_cuda @@ -791,7 +780,7 @@ def test_nms_mps(self, iou, dtype=torch.float32): @pytest.mark.parametrize("dtype", (torch.float, torch.half)) def test_autocast(self, iou, dtype): with torch.cuda.amp.autocast(): - self.test_nms_cuda(iou=iou, dtype=dtype) + self.test_nms_gpu(iou=iou, dtype=dtype, device="cuda") @pytest.mark.parametrize( "device", From 108bc155453d4e16a4101623b7f3baf6e8b8a9af Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Tue, 1 Aug 2023 12:24:20 +0100 Subject: [PATCH 28/28] Allow to skip MPS tests internally on non-MPS machines --- test/conftest.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/test/conftest.py b/test/conftest.py index 4c5c97fb7af..a54028bc70d 100644 --- a/test/conftest.py +++ b/test/conftest.py @@ -66,6 +66,9 @@ def pytest_collection_modifyitems(items): # TODO: something more robust would be to do that only in a sandcastle instance, # so that we can still see the test being skipped when testing locally from a devvm continue + if needs_mps and not torch.backends.mps.is_available(): + # Same as above, but for MPS + continue elif IN_OSS_CI: # Here we're not in fbcode, so we can safely collect and skip tests. if not needs_cuda and torch.cuda.is_available():
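Taken together, the test-side changes above hang off a single mechanism: tests opt into MPS either through the needs_mps decorator or through pytest.param("mps", marks=pytest.mark.needs_mps), and the conftest.py hook in this last patch drops those tests on machines without an MPS device. The sketch below is a minimal, self-contained illustration of that pattern, not the patch's exact code: it uses the collect-and-skip style rather than the internal "continue" path, and the test name test_nms_smoke is invented for the example.

    # conftest.py (sketch): skip needs_mps-marked tests when no MPS backend is present
    import pytest
    import torch

    def pytest_collection_modifyitems(items):
        for item in items:
            if item.get_closest_marker("needs_mps") and not torch.backends.mps.is_available():
                # Simplification: always collect, then skip; the real hook also
                # distinguishes internal CI, where it avoids collecting such tests at all.
                item.add_marker(pytest.mark.skip(reason="MPS not available"))

    # test_example.py (sketch): a marked test that only runs on MPS machines
    from torchvision import ops

    @pytest.mark.needs_mps
    def test_nms_smoke():
        boxes = torch.tensor([[0.0, 0.0, 10.0, 10.0], [1.0, 1.0, 11.0, 11.0]], device="mps")
        scores = torch.tensor([0.9, 0.8], device="mps")
        keep = ops.nms(boxes, scores, iou_threshold=0.5)
        assert keep.numel() >= 1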