diff --git a/.github/scripts/run-clang-format.py b/.github/scripts/run-clang-format.py
index 5c61b2519e0..670fd97833a 100755
--- a/.github/scripts/run-clang-format.py
+++ b/.github/scripts/run-clang-format.py
@@ -48,7 +48,7 @@
 DEVNULL = open(os.devnull, "wb")
 
 
-DEFAULT_EXTENSIONS = "c,h,C,H,cpp,hpp,cc,hh,c++,h++,cxx,hxx,cu"
+DEFAULT_EXTENSIONS = "c,h,C,H,cpp,hpp,cc,hh,c++,h++,cxx,hxx,cu,mm"
 
 
 class ExitStatus:
diff --git a/CMakeLists.txt b/CMakeLists.txt
index 405f947c233..0cd485d7a24 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -4,6 +4,7 @@ set(CMAKE_CXX_STANDARD 17)
 file(STRINGS version.txt TORCHVISION_VERSION)
 
 option(WITH_CUDA "Enable CUDA support" OFF)
+option(WITH_MPS "Enable MPS support" OFF)
 option(WITH_PNG "Enable features requiring LibPNG." ON)
 option(WITH_JPEG "Enable features requiring LibJPEG." ON)
 option(USE_PYTHON "Link to Python when building" OFF)
@@ -15,6 +16,11 @@ if(WITH_CUDA)
   set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} --expt-relaxed-constexpr")
 endif()
 
+if(WITH_MPS)
+  enable_language(OBJC OBJCXX)
+  add_definitions(-DWITH_MPS)
+endif()
+
 find_package(Torch REQUIRED)
 
 if (WITH_PNG)
@@ -79,6 +85,9 @@ list(APPEND ALLOW_LISTED ${TVCPP} ${TVCPP}/io/image ${TVCPP}/io/image/cpu ${TVCP
 if(WITH_CUDA)
     list(APPEND ALLOW_LISTED ${TVCPP}/ops/cuda ${TVCPP}/ops/autocast)
 endif()
+if(WITH_MPS)
+    list(APPEND ALLOW_LISTED ${TVCPP}/ops/mps)
+endif()
 
 FOREACH(DIR ${ALLOW_LISTED})
     file(GLOB ALL_SOURCES ${ALL_SOURCES} ${DIR}/*.*)
diff --git a/setup.py b/setup.py
index 8b8ddcde1b9..cd41081142d 100644
--- a/setup.py
+++ b/setup.py
@@ -137,10 +137,13 @@ def get_extensions():
         + glob.glob(os.path.join(extensions_dir, "ops", "cpu", "*.cpp"))
         + glob.glob(os.path.join(extensions_dir, "ops", "quantized", "cpu", "*.cpp"))
     )
+    source_mps = glob.glob(os.path.join(extensions_dir, "ops", "mps", "*.mm"))
 
     print("Compiling extensions with following flags:")
     force_cuda = os.getenv("FORCE_CUDA", "0") == "1"
     print(f" FORCE_CUDA: {force_cuda}")
+    force_mps = os.getenv("FORCE_MPS", "0") == "1"
+    print(f" FORCE_MPS: {force_mps}")
     debug_mode = os.getenv("DEBUG", "0") == "1"
     print(f" DEBUG: {debug_mode}")
     use_png = os.getenv("TORCHVISION_USE_PNG", "1") == "1"
@@ -202,6 +205,8 @@ def get_extensions():
         define_macros += [("WITH_HIP", None)]
         nvcc_flags = []
         extra_compile_args["nvcc"] = nvcc_flags
+    elif torch.backends.mps.is_available() or force_mps:
+        sources += source_mps
 
     if sys.platform == "win32":
         define_macros += [("torchvision_EXPORTS", None)]
diff --git a/test/common_utils.py b/test/common_utils.py
index b5edda3edb2..3f8a12e161c 100644
--- a/test/common_utils.py
+++ b/test/common_utils.py
@@ -34,6 +34,7 @@
 IN_RE_WORKER = os.environ.get("INSIDE_RE_WORKER") is not None
 IN_FBCODE = os.environ.get("IN_FBCODE_TORCHVISION") == "1"
 CUDA_NOT_AVAILABLE_MSG = "CUDA device not available"
+MPS_NOT_AVAILABLE_MSG = "MPS device not available"
 OSS_CI_GPU_NO_CUDA_MSG = "We're in an OSS GPU machine, and this test doesn't need cuda."
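Both the build scripts and the test plumbing above gate MPS support on the same runtime availability check. A minimal sketch of that guard, assuming only the public torch.backends.mps API (pick_device is a hypothetical helper, not part of this patch):

import torch

# Prefer the MPS device only when the backend is both compiled in and usable,
# mirroring the checks used in setup.py and test/conftest.py above.
def pick_device() -> torch.device:
    if torch.backends.mps.is_built() and torch.backends.mps.is_available():
        return torch.device("mps")
    return torch.device("cpu")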
@@ -130,12 +131,22 @@ def cpu_and_cuda():
     return ("cpu", pytest.param("cuda", marks=pytest.mark.needs_cuda))
 
 
+def cpu_and_cuda_and_mps():
+    return cpu_and_cuda() + (pytest.param("mps", marks=pytest.mark.needs_mps),)
+
+
 def needs_cuda(test_func):
     import pytest  # noqa
 
     return pytest.mark.needs_cuda(test_func)
 
 
+def needs_mps(test_func):
+    import pytest  # noqa
+
+    return pytest.mark.needs_mps(test_func)
+
+
 def _create_data(height=3, width=3, channels=3, device="cpu"):
     # TODO: When all relevant tests are ported to pytest, turn this into a module-level fixture
     tensor = torch.randint(0, 256, (channels, height, width), dtype=torch.uint8, device=device)
diff --git a/test/conftest.py b/test/conftest.py
index 468587f1c9e..a54028bc70d 100644
--- a/test/conftest.py
+++ b/test/conftest.py
@@ -8,12 +8,20 @@
 torchvision.disable_beta_transforms_warning()
 
-from common_utils import CUDA_NOT_AVAILABLE_MSG, IN_FBCODE, IN_OSS_CI, IN_RE_WORKER, OSS_CI_GPU_NO_CUDA_MSG
+from common_utils import (
+    CUDA_NOT_AVAILABLE_MSG,
+    IN_FBCODE,
+    IN_OSS_CI,
+    IN_RE_WORKER,
+    MPS_NOT_AVAILABLE_MSG,
+    OSS_CI_GPU_NO_CUDA_MSG,
+)
 
 
 def pytest_configure(config):
     # register an additional marker (see pytest_collection_modifyitems)
     config.addinivalue_line("markers", "needs_cuda: mark for tests that rely on a CUDA device")
+    config.addinivalue_line("markers", "needs_mps: mark for tests that rely on a MPS device")
     config.addinivalue_line("markers", "dont_collect: mark for tests that should not be collected")
 
@@ -37,12 +45,16 @@ def pytest_collection_modifyitems(items):
         # the "instances" of the tests where device == 'cuda' will have the 'needs_cuda' mark,
         # and the ones with device == 'cpu' won't have the mark.
         needs_cuda = item.get_closest_marker("needs_cuda") is not None
+        needs_mps = item.get_closest_marker("needs_mps") is not None
 
         if needs_cuda and not torch.cuda.is_available():
             # In general, we skip cuda tests on machines without a GPU
             # There are special cases though, see below
             item.add_marker(pytest.mark.skip(reason=CUDA_NOT_AVAILABLE_MSG))
 
+        if needs_mps and not torch.backends.mps.is_available():
+            item.add_marker(pytest.mark.skip(reason=MPS_NOT_AVAILABLE_MSG))
+
         if IN_FBCODE:
             # fbcode doesn't like skipping tests, so instead we just don't collect the test
             # so that they don't even "exist", hence the continue statements.
@@ -54,6 +66,9 @@ def pytest_collection_modifyitems(items):
                 # TODO: something more robust would be to do that only in a sandcastle instance,
                 # so that we can still see the test being skipped when testing locally from a devvm
                 continue
+            if needs_mps and not torch.backends.mps.is_available():
+                # Same as above, but for MPS
+                continue
         elif IN_OSS_CI:
             # Here we're not in fbcode, so we can safely collect and skip tests.
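            # (needs_mps tests were already marked as skipped above; the logic below only
            # special-cases CUDA on OSS CI GPU machines.)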
if not needs_cuda and torch.cuda.is_available(): diff --git a/test/test_ops.py b/test/test_ops.py index b993bce65a2..743fe159e37 100644 --- a/test/test_ops.py +++ b/test/test_ops.py @@ -10,7 +10,7 @@ import torch import torch.fx import torch.nn.functional as F -from common_utils import assert_equal, cpu_and_cuda, needs_cuda +from common_utils import assert_equal, cpu_and_cuda, cpu_and_cuda_and_mps, needs_cuda, needs_mps from PIL import Image from torch import nn, Tensor from torch.autograd import gradcheck @@ -96,12 +96,33 @@ def forward(self, imgs: Tensor, boxes: List[Tensor]) -> Tensor: class RoIOpTester(ABC): dtype = torch.float64 + mps_dtype = torch.float32 + mps_backward_atol = 2e-2 - @pytest.mark.parametrize("device", cpu_and_cuda()) + @pytest.mark.parametrize("device", cpu_and_cuda_and_mps()) @pytest.mark.parametrize("contiguous", (True, False)) - def test_forward(self, device, contiguous, x_dtype=None, rois_dtype=None, deterministic=False, **kwargs): - x_dtype = self.dtype if x_dtype is None else x_dtype - rois_dtype = self.dtype if rois_dtype is None else rois_dtype + @pytest.mark.parametrize( + "x_dtype", + ( + torch.float16, + torch.float32, + torch.float64, + ), + ids=str, + ) + def test_forward(self, device, contiguous, x_dtype, rois_dtype=None, deterministic=False, **kwargs): + if device == "mps" and x_dtype is torch.float64: + pytest.skip("MPS does not support float64") + + rois_dtype = x_dtype if rois_dtype is None else rois_dtype + + tol = 1e-5 + if x_dtype is torch.half: + if device == "mps": + tol = 5e-3 + else: + tol = 4e-3 + pool_size = 5 # n_channels % (pool_size ** 2) == 0 required for PS operations. n_channels = 2 * (pool_size**2) @@ -120,10 +141,9 @@ def test_forward(self, device, contiguous, x_dtype=None, rois_dtype=None, determ # the following should be true whether we're running an autocast test or not. 
assert y.dtype == x.dtype gt_y = self.expected_fn( - x, rois, pool_h, pool_w, spatial_scale=1, sampling_ratio=-1, device=device, dtype=self.dtype, **kwargs + x, rois, pool_h, pool_w, spatial_scale=1, sampling_ratio=-1, device=device, dtype=x_dtype, **kwargs ) - tol = 1e-3 if (x_dtype is torch.half or rois_dtype is torch.half) else 1e-5 torch.testing.assert_close(gt_y.to(y), y, rtol=tol, atol=tol) @pytest.mark.parametrize("device", cpu_and_cuda()) @@ -155,16 +175,19 @@ def test_torch_fx_trace(self, device, x_dtype=torch.float, rois_dtype=torch.floa torch.testing.assert_close(output_gt, output_fx, rtol=tol, atol=tol) @pytest.mark.parametrize("seed", range(10)) - @pytest.mark.parametrize("device", cpu_and_cuda()) + @pytest.mark.parametrize("device", cpu_and_cuda_and_mps()) @pytest.mark.parametrize("contiguous", (True, False)) def test_backward(self, seed, device, contiguous, deterministic=False): + atol = self.mps_backward_atol if device == "mps" else 1e-05 + dtype = self.mps_dtype if device == "mps" else self.dtype + torch.random.manual_seed(seed) pool_size = 2 - x = torch.rand(1, 2 * (pool_size**2), 5, 5, dtype=self.dtype, device=device, requires_grad=True) + x = torch.rand(1, 2 * (pool_size**2), 5, 5, dtype=dtype, device=device, requires_grad=True) if not contiguous: x = x.permute(0, 1, 3, 2) rois = torch.tensor( - [[0, 0, 0, 4, 4], [0, 0, 2, 3, 4], [0, 2, 2, 4, 4]], dtype=self.dtype, device=device # format is (xyxy) + [[0, 0, 0, 4, 4], [0, 0, 2, 3, 4], [0, 2, 2, 4, 4]], dtype=dtype, device=device # format is (xyxy) ) def func(z): @@ -173,9 +196,25 @@ def func(z): script_func = self.get_script_fn(rois, pool_size) with DeterministicGuard(deterministic): - gradcheck(func, (x,)) + gradcheck(func, (x,), atol=atol) + + gradcheck(script_func, (x,), atol=atol) - gradcheck(script_func, (x,)) + @needs_mps + def test_mps_error_inputs(self): + pool_size = 2 + x = torch.rand(1, 2 * (pool_size**2), 5, 5, dtype=torch.float16, device="mps", requires_grad=True) + rois = torch.tensor( + [[0, 0, 0, 4, 4], [0, 0, 2, 3, 4], [0, 2, 2, 4, 4]], dtype=torch.float16, device="mps" # format is (xyxy) + ) + + def func(z): + return self.fn(z, rois, pool_size, pool_size, spatial_scale=1, sampling_ratio=1) + + with pytest.raises( + RuntimeError, match="MPS does not support (?:ps_)?roi_(?:align|pool)? backward with float16 inputs." 
+ ): + gradcheck(func, (x,)) @needs_cuda @pytest.mark.parametrize("x_dtype", (torch.float, torch.half)) @@ -271,6 +310,8 @@ def test_jit_boxes_list(self): class TestPSRoIPool(RoIOpTester): + mps_backward_atol = 5e-2 + def fn(self, x, rois, pool_h, pool_w, spatial_scale=1, sampling_ratio=-1, **kwargs): return ops.PSRoIPool((pool_h, pool_w), 1)(x, rois) @@ -352,6 +393,8 @@ def bilinear_interpolate(data, y, x, snap_border=False): class TestRoIAlign(RoIOpTester): + mps_backward_atol = 6e-2 + def fn(self, x, rois, pool_h, pool_w, spatial_scale=1, sampling_ratio=-1, aligned=False, **kwargs): return ops.RoIAlign( (pool_h, pool_w), spatial_scale=spatial_scale, sampling_ratio=sampling_ratio, aligned=aligned @@ -418,10 +461,11 @@ def test_boxes_shape(self): self._helper_boxes_shape(ops.roi_align) @pytest.mark.parametrize("aligned", (True, False)) - @pytest.mark.parametrize("device", cpu_and_cuda()) + @pytest.mark.parametrize("device", cpu_and_cuda_and_mps()) + @pytest.mark.parametrize("x_dtype", (torch.float16, torch.float32, torch.float64), ids=str) @pytest.mark.parametrize("contiguous", (True, False)) @pytest.mark.parametrize("deterministic", (True, False)) - def test_forward(self, device, contiguous, deterministic, aligned, x_dtype=None, rois_dtype=None): + def test_forward(self, device, contiguous, deterministic, aligned, x_dtype, rois_dtype=None): if deterministic and device == "cpu": pytest.skip("cpu is always deterministic, don't retest") super().test_forward( @@ -450,7 +494,7 @@ def test_autocast(self, aligned, deterministic, x_dtype, rois_dtype): ) @pytest.mark.parametrize("seed", range(10)) - @pytest.mark.parametrize("device", cpu_and_cuda()) + @pytest.mark.parametrize("device", cpu_and_cuda_and_mps()) @pytest.mark.parametrize("contiguous", (True, False)) @pytest.mark.parametrize("deterministic", (True, False)) def test_backward(self, seed, device, contiguous, deterministic): @@ -537,6 +581,8 @@ def test_jit_boxes_list(self): class TestPSRoIAlign(RoIOpTester): + mps_backward_atol = 5e-2 + def fn(self, x, rois, pool_h, pool_w, spatial_scale=1, sampling_ratio=-1, **kwargs): return ops.PSRoIAlign((pool_h, pool_w), spatial_scale=spatial_scale, sampling_ratio=sampling_ratio)(x, rois) @@ -705,21 +751,28 @@ def test_qnms(self, iou, scale, zero_point): torch.testing.assert_close(qkeep, keep, msg=err_msg.format(iou)) - @needs_cuda + @pytest.mark.parametrize( + "device", + ( + pytest.param("cuda", marks=pytest.mark.needs_cuda), + pytest.param("mps", marks=pytest.mark.needs_mps), + ), + ) @pytest.mark.parametrize("iou", (0.2, 0.5, 0.8)) - def test_nms_cuda(self, iou, dtype=torch.float64): + def test_nms_gpu(self, iou, device, dtype=torch.float64): + dtype = torch.float32 if device == "mps" else dtype tol = 1e-3 if dtype is torch.half else 1e-5 err_msg = "NMS incompatible between CPU and CUDA for IoU={}" boxes, scores = self._create_tensors_with_iou(1000, iou) r_cpu = ops.nms(boxes, scores, iou) - r_cuda = ops.nms(boxes.cuda(), scores.cuda(), iou) + r_gpu = ops.nms(boxes.to(device), scores.to(device), iou) - is_eq = torch.allclose(r_cpu, r_cuda.cpu()) + is_eq = torch.allclose(r_cpu, r_gpu.cpu()) if not is_eq: # if the indices are not the same, ensure that it's because the scores # are duplicate - is_eq = torch.allclose(scores[r_cpu], scores[r_cuda.cpu()], rtol=tol, atol=tol) + is_eq = torch.allclose(scores[r_cpu], scores[r_gpu.cpu()], rtol=tol, atol=tol) assert is_eq, err_msg.format(iou) @needs_cuda @@ -727,18 +780,24 @@ def test_nms_cuda(self, iou, dtype=torch.float64): 
@pytest.mark.parametrize("dtype", (torch.float, torch.half)) def test_autocast(self, iou, dtype): with torch.cuda.amp.autocast(): - self.test_nms_cuda(iou=iou, dtype=dtype) + self.test_nms_gpu(iou=iou, dtype=dtype, device="cuda") - @needs_cuda - def test_nms_cuda_float16(self): + @pytest.mark.parametrize( + "device", + ( + pytest.param("cuda", marks=pytest.mark.needs_cuda), + pytest.param("mps", marks=pytest.mark.needs_mps), + ), + ) + def test_nms_float16(self, device): boxes = torch.tensor( [ [285.3538, 185.5758, 1193.5110, 851.4551], [285.1472, 188.7374, 1192.4984, 851.0669], [279.2440, 197.9812, 1189.4746, 849.2019], ] - ).cuda() - scores = torch.tensor([0.6370, 0.7569, 0.3966]).cuda() + ).to(device) + scores = torch.tensor([0.6370, 0.7569, 0.3966]).to(device) iou_thres = 0.2 keep32 = ops.nms(boxes, scores, iou_thres) diff --git a/torchvision/csrc/ops/cpu/nms_kernel.cpp b/torchvision/csrc/ops/cpu/nms_kernel.cpp index c54d1f00148..50479066cbd 100644 --- a/torchvision/csrc/ops/cpu/nms_kernel.cpp +++ b/torchvision/csrc/ops/cpu/nms_kernel.cpp @@ -11,8 +11,8 @@ at::Tensor nms_kernel_impl( const at::Tensor& dets, const at::Tensor& scores, double iou_threshold) { - TORCH_CHECK(!dets.is_cuda(), "dets must be a CPU tensor"); - TORCH_CHECK(!scores.is_cuda(), "scores must be a CPU tensor"); + TORCH_CHECK(dets.is_cpu(), "dets must be a CPU tensor"); + TORCH_CHECK(scores.is_cpu(), "scores must be a CPU tensor"); TORCH_CHECK( dets.scalar_type() == scores.scalar_type(), "dets should have the same type as scores"); diff --git a/torchvision/csrc/ops/mps/mps_helpers.h b/torchvision/csrc/ops/mps/mps_helpers.h new file mode 100644 index 00000000000..d3c0e8d94b7 --- /dev/null +++ b/torchvision/csrc/ops/mps/mps_helpers.h @@ -0,0 +1,6 @@ +constexpr int threadsPerBlock = 512; + +template +constexpr inline T ceil_div(T n, T m) { + return (n + m - 1) / m; +} diff --git a/torchvision/csrc/ops/mps/mps_kernels.h b/torchvision/csrc/ops/mps/mps_kernels.h new file mode 100644 index 00000000000..e720a1608f1 --- /dev/null +++ b/torchvision/csrc/ops/mps/mps_kernels.h @@ -0,0 +1,1102 @@ +#include + +namespace vision { +namespace ops { + +namespace mps { + +static const char* METAL_VISION = R"VISION_METAL( + +#include +#include +using namespace metal; + +/*----------Macros----------*/ + +#define MPS_1D_KERNEL_LOOP_T(i, n, n_tgs, index_t) \ + for (index_t i = (tgid.x * tptg.x) + tid2.x; i < (n); \ + i += (tptg.x * n_tgs)) + +#define MPS_1D_KERNEL_LOOP(i, n, n_tgs) MPS_1D_KERNEL_LOOP_T(i, n, n_tgs, uint) + +/*----------Helpers--------*/ + +template +inline T ceil_div(T n, T m) { + return (n + m - 1) / m; +} + +template +inline void atomic_add_float( device T* data_ptr, const T val) +{ +#if __METAL_VERSION__ >= 300 + // atomic_float is supported in Metal 3 (macOS Ventura) onward. + device atomic_fetch_add_explicit((device atomic_float*) data_ptr, val, memory_order_relaxed); +#else + // Custom atomic addition implementation + // https://github.com/ShoYamanishi/AppleNumericalComputing/blob/053f06c1f5a831095c4bcc29aaf11366fce5231e/03_dot/metal/dot.metal#L447-L472 + // https://forums.developer.nvidia.com/t/atomicadd-float-float-atomicmul-float-float/14639 + // https://on-demand.gputechconf.com/gtc/2013/presentations/S3101-Atomic-Memory-Operations.pdf (See the last slide) + + // Create an atomic uint pointer for atomic transaction. + device atomic_uint* atom_var = (device atomic_uint*)data_ptr; + // Create necessary storage. 
+ uint fetched_uint, assigning_uint; + T fetched_float, assigning_float; + + // Replace the value in atom_var with 0 and return the previous value in atom_var. + fetched_uint = atomic_exchange_explicit( atom_var, 0 /*desired*/, memory_order_relaxed); + // Read out the previous value as float. + fetched_float = *( (thread T*) &fetched_uint ); + + // Do addition and represent the addition result in uint for atomic transaction. + assigning_float = fetched_float + val; + assigning_uint = *((thread uint*) &assigning_float); + + // atom_var should be 0 now, try to assign the addition result back to the atom_var (data_ptr). + while ((fetched_uint = atomic_exchange_explicit( atom_var, assigning_uint /*desired*/, memory_order_relaxed)) != 0) { + // If atom_var was not 0, i.e. fetched_uint != 0, it means that the data has been modified by other threads. + // Try to assign 0 and get the previously assigned addition result. + uint fetched_uint_again = atomic_exchange_explicit(atom_var, 0 /*desired*/, memory_order_relaxed); + T fetched_float_again = *( (thread T*) &fetched_uint_again ); + // Re-add again + fetched_float = *((thread T*) &(fetched_uint)); + // Previously assigned addition result + addition result from other threads. + assigning_float = fetched_float_again + fetched_float; + assigning_uint = *( (thread uint*) &assigning_float); + } +#endif +} + +template +inline T bilinear_interpolate( + constant T* input, + integer_t height, + integer_t width, + T y, + T x, + uint index /* index for debug only*/) { + // deal with cases that inverse elements are out of feature map boundary + if (y < -1.0 || y > height || x < -1.0 || x > width) { + // empty + return 0; + } + + if (y <= 0) + y = 0; + if (x <= 0) + x = 0; + + integer_t y_low = (integer_t)y; + integer_t x_low = (integer_t)x; + integer_t y_high; + integer_t x_high; + + if (y_low >= height - 1) { + y_high = y_low = height - 1; + y = (T)y_low; + } else { + y_high = y_low + 1; + } + + if (x_low >= width - 1) { + x_high = x_low = width - 1; + x = (T)x_low; + } else { + x_high = x_low + 1; + } + + T ly = y - y_low; + T lx = x - x_low; + T hy = 1. - ly, hx = 1. - lx; + + // do bilinear interpolation + T v1 = input[y_low * width + x_low]; + T v2 = input[y_low * width + x_high]; + T v3 = input[y_high * width + x_low]; + T v4 = input[y_high * width + x_high]; + T w1 = hy * hx, w2 = hy * lx, w3 = ly * hx, w4 = ly * lx; + + T val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4); + + return val; +} + +template +inline void bilinear_interpolate_gradient( + integer_t height, + integer_t width, + T y, + T x, + thread T& w1, + thread T& w2, + thread T& w3, + thread T& w4, + thread integer_t& x_low, + thread integer_t& x_high, + thread integer_t& y_low, + thread integer_t& y_high, + uint index /* index for debug only*/) { + // deal with cases that inverse elements are out of feature map boundary + if (y < -1.0 || y > height || x < -1.0 || x > width) { + // empty + w1 = w2 = w3 = w4 = 0.; + x_low = x_high = y_low = y_high = -1; + return; + } + + if (y <= 0) + y = 0; + if (x <= 0) + x = 0; + + y_low = (integer_t)y; + x_low = (integer_t)x; + + if (y_low >= height - 1) { + y_high = y_low = height - 1; + y = (T)y_low; + } else { + y_high = y_low + 1; + } + + if (x_low >= width - 1) { + x_high = x_low = width - 1; + x = (T)x_low; + } else { + x_high = x_low + 1; + } + + T ly = y - y_low; + T lx = x - x_low; + T hy = 1. - ly, hx = 1. 
- lx; + + // reference in forward + // T v1 = input[y_low * width + x_low]; + // T v2 = input[y_low * width + x_high]; + // T v3 = input[y_high * width + x_low]; + // T v4 = input[y_high * width + x_high]; + // T val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4); + + w1 = hy * hx, w2 = hy * lx, w3 = ly * hx, w4 = ly * lx; +} + +template +inline bool IoU( + constant T & a, + threadgroup T & b, + const float threshold) { + auto xx1 = max(a.x, b.x); + auto yy1 = max(a.y, b.y); + auto xx2 = min(a.z, b.z); + auto yy2 = min(a.w, b.w); + auto w = max(static_cast(0), xx2 - xx1); + auto h = max(static_cast(0), yy2 - yy1); + // Upcast to float before multiplications to circumvent precision issues in half. + auto inter = static_cast(w) * static_cast(h); + auto area_b = static_cast(b.z - b.x) * static_cast(b.w - b.y); + auto area_a = static_cast(a.z - a.x) * static_cast(a.w - a.y); + return (inter / (area_a + area_b - inter)) > threshold; +} + +/*----------Kernels----------*/ + +// This should be in sync with the one in nms_kernel.mm. +// Since metal does not support dynamic array, +// we need to make it static instead of deriving it from [[threads_per_threadgroup]]. +constant int64_t nmsThreadsPerBlock = sizeof(uint64_t) * 8; + +template +kernel void nms(constant T * dev_boxes [[buffer(0)]], + device uint64_t * mask [[buffer(1)]], + constant int64_t & n_boxes [[buffer(2)]], + constant float & iou_threshold [[buffer(3)]], + uint2 tgid [[threadgroup_position_in_grid]], + uint2 tid2 [[thread_position_in_threadgroup]]) { + + const uint row_start = tgid.y; + const uint col_start = tgid.x; + const uint tid = tid2.x; + const uint row_size = + min(n_boxes - row_start * nmsThreadsPerBlock, nmsThreadsPerBlock); + const uint col_size = + min(n_boxes - col_start * nmsThreadsPerBlock, nmsThreadsPerBlock); + + threadgroup T block_boxes[nmsThreadsPerBlock]; + block_boxes[tid] = dev_boxes[nmsThreadsPerBlock * col_start + tid]; + threadgroup_barrier(mem_flags::mem_threadgroup); + + if (tid < row_size) { + const uint cur_box_idx = nmsThreadsPerBlock * row_start + tid; + uint64_t t = 0; + uint start = 0; + + if (row_start == col_start) { + start = tid + 1; + } + + for (uint i = start; i < col_size; i++){ + if (IoU(dev_boxes[cur_box_idx], block_boxes[i], iou_threshold)){ + t |= static_cast(1) << i; // discard 1 keep 0 + } + } + const uint col_blocks = ceil_div(n_boxes, nmsThreadsPerBlock); + mask[cur_box_idx * col_blocks + col_start] = t; + } +} + +#define REGISTER_NMS_OP(DTYPE) \ +template \ +[[host_name("nms_" #DTYPE)]] \ +kernel void nms( \ + constant DTYPE ## 4 * dev_boxes [[buffer(0)]], \ + device uint64_t * mask [[buffer(1)]], \ + constant int64_t & n_boxes [[buffer(2)]], \ + constant float & iou_threshold [[buffer(3)]], \ + uint2 tgid [[threadgroup_position_in_grid]], \ + uint2 tid2 [[thread_position_in_threadgroup]]); + +template +kernel void roi_align( + constant T * input [[buffer(0)]], + constant T * rois [[buffer(1)]], + device T * output [[buffer(2)]], + constant int64_t & output_size [[buffer(3)]], + constant int64_t & channels [[buffer(4)]], + constant int64_t & height [[buffer(5)]], + constant int64_t & width [[buffer(6)]], + constant int64_t & pooled_height [[buffer(7)]], + constant int64_t & pooled_width [[buffer(8)]], + constant int64_t & sampling_ratio [[buffer(9)]], + constant bool & aligned [[buffer(10)]], + constant float & spatial_scale [[buffer(11)]], + uint2 tgid [[threadgroup_position_in_grid]], + uint2 tptg [[threads_per_threadgroup]], + uint2 tid2 [[thread_position_in_threadgroup]]){ + 
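// MPS_1D_KERNEL_LOOP strides each thread across the flattened output; every iteration
+  // decodes one (n, c, ph, pw) element and averages a grid of bilinear samples over its bin.
+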
MPS_1D_KERNEL_LOOP(index, output_size, 1) { + // (n, c, ph, pw) is an element in the pooled output + integer_t pw = index % pooled_width; + integer_t ph = (index / pooled_width) % pooled_height; + integer_t c = (index / pooled_width / pooled_height) % channels; + integer_t n = index / pooled_width / pooled_height / channels; + + constant T* offset_rois = rois + n * 5; + integer_t roi_batch_ind = offset_rois[0]; + + // Do not using rounding; this implementation detail is critical + T offset = aligned ? (T)0.5 : (T)0.0; + T roi_start_w = offset_rois[1] * spatial_scale - offset; + T roi_start_h = offset_rois[2] * spatial_scale - offset; + T roi_end_w = offset_rois[3] * spatial_scale - offset; + T roi_end_h = offset_rois[4] * spatial_scale - offset; + + T roi_width = roi_end_w - roi_start_w; + T roi_height = roi_end_h - roi_start_h; + if (!aligned) { + // Force malformed ROIs to be 1x1 + roi_width = max(roi_width, (T)1.); + roi_height = max(roi_height, (T)1.); + } + + T bin_size_h = static_cast(roi_height) / static_cast(pooled_height); + T bin_size_w = static_cast(roi_width) / static_cast(pooled_width); + + constant T* offset_input = + input + (roi_batch_ind * channels + c) * height * width; + + // We use roi_bin_grid to sample the grid and mimic integral + integer_t roi_bin_grid_h = (sampling_ratio > 0) + ? sampling_ratio + : ceil(roi_height / pooled_height); // e.g., = 2 + integer_t roi_bin_grid_w = + (sampling_ratio > 0) ? sampling_ratio : ceil(roi_width / pooled_width); + + // We do average (integral) pooling inside a bin + // When the grid is empty, output zeros. + const T count = max(roi_bin_grid_h * roi_bin_grid_w, static_cast(1)); // e.g. = 4 + + T output_val = 0.; + for (integer_t iy = 0; iy < roi_bin_grid_h; iy++) // e.g., iy = 0, 1 + { + const T y = roi_start_h + ph * bin_size_h + + static_cast(iy + .5f) * bin_size_h / + static_cast(roi_bin_grid_h); // e.g., 0.5, 1.5 + for (integer_t ix = 0; ix < roi_bin_grid_w; ix++) { + const T x = roi_start_w + pw * bin_size_w + + static_cast(ix + .5f) * bin_size_w / + static_cast(roi_bin_grid_w); + + T val = bilinear_interpolate(offset_input, height, width, y, x, index); + output_val += val; + } + } + output_val /= count; + + output[index] = output_val; + } +} + +#define REGISTER_ROI_ALIGN_OP(DTYPE, INT_DTYPE) \ +template \ +[[host_name("roi_align_" #DTYPE)]] \ +kernel void roi_align( \ + constant DTYPE * input [[buffer(0)]], \ + constant DTYPE * rois [[buffer(1)]], \ + device DTYPE * output [[buffer(2)]], \ + constant int64_t & output_size [[buffer(3)]], \ + constant int64_t & channels [[buffer(4)]], \ + constant int64_t & height [[buffer(5)]], \ + constant int64_t & width [[buffer(6)]], \ + constant int64_t & pooled_height [[buffer(7)]], \ + constant int64_t & pooled_width [[buffer(8)]], \ + constant int64_t & sampling_ratio [[buffer(9)]], \ + constant bool & aligned [[buffer(10)]], \ + constant float & spatial_scale [[buffer(11)]], \ + uint2 tgid [[threadgroup_position_in_grid]], \ + uint2 tptg [[threads_per_threadgroup]], \ + uint2 tid2 [[thread_position_in_threadgroup]]); + +template +kernel void roi_align_backward( + constant T * grad_output [[buffer(0)]], + constant T * rois [[buffer(1)]], + device T * grad_input [[buffer(2)]], + constant int64_t & output_size [[buffer(3)]], + constant int64_t & channels [[buffer(4)]], + constant int64_t & height [[buffer(5)]], + constant int64_t & width [[buffer(6)]], + constant int64_t & pooled_height [[buffer(7)]], + constant int64_t & pooled_width [[buffer(8)]], + constant int64_t & sampling_ratio 
[[buffer(9)]], + constant bool & aligned [[buffer(10)]], + constant float & spatial_scale [[buffer(11)]], + constant int64_t & n_stride [[buffer(12)]], + constant int64_t & c_stride [[buffer(13)]], + constant int64_t & h_stride [[buffer(14)]], + constant int64_t & w_stride [[buffer(15)]], + uint2 tgid [[threadgroup_position_in_grid]], + uint2 tptg [[threads_per_threadgroup]], + uint2 tid2 [[thread_position_in_threadgroup]]){ + + MPS_1D_KERNEL_LOOP(index, output_size, 1) { + // (n, c, ph, pw) is an element in the pooled output + integer_t pw = index % pooled_width; + integer_t ph = (index / pooled_width) % pooled_height; + integer_t c = (index / pooled_width / pooled_height) % channels; + integer_t n = index / pooled_width / pooled_height / channels; + + constant T* offset_rois = rois + n * 5; + integer_t roi_batch_ind = offset_rois[0]; + + // Do not using rounding; this implementation detail is critical + T offset = aligned ? (T)0.5 : (T)0.0; + T roi_start_w = offset_rois[1] * spatial_scale - offset; + T roi_start_h = offset_rois[2] * spatial_scale - offset; + T roi_end_w = offset_rois[3] * spatial_scale - offset; + T roi_end_h = offset_rois[4] * spatial_scale - offset; + + T roi_width = roi_end_w - roi_start_w; + T roi_height = roi_end_h - roi_start_h; + if (!aligned) { + // Force malformed ROIs to be 1x1 + roi_width = max(roi_width, (T)1.); + roi_height = max(roi_height, (T)1.); + } + + T bin_size_h = static_cast(roi_height) / static_cast(pooled_height); + T bin_size_w = static_cast(roi_width) / static_cast(pooled_width); + + // We need to index the gradient using the tensor strides to access the + // correct values. + const integer_t output_offset = n * n_stride + c * c_stride; + constant T* offset_grad_output = grad_output + output_offset; + const T grad_output_this_bin = + offset_grad_output[ph * h_stride + pw * w_stride]; + + // We use roi_bin_grid to sample the grid and mimic integral + integer_t roi_bin_grid_h = (sampling_ratio > 0) + ? sampling_ratio + : ceil(roi_height / pooled_height); // e.g., = 2 + integer_t roi_bin_grid_w = + (sampling_ratio > 0) ? sampling_ratio : ceil(roi_width / pooled_width); + + // We do average (integral) pooling inside a bin + const T count = roi_bin_grid_h * roi_bin_grid_w; // e.g. 
= 4 + + const integer_t input_offset = (roi_batch_ind * channels + c) * height * width; + + for (integer_t iy = 0; iy < roi_bin_grid_h; iy++) // e.g., iy = 0, 1 + { + const T y = roi_start_h + ph * bin_size_h + + static_cast(iy + .5f) * bin_size_h / + static_cast(roi_bin_grid_h); // e.g., 0.5, 1.5 + for (integer_t ix = 0; ix < roi_bin_grid_w; ix++) { + const T x = roi_start_w + pw * bin_size_w + + static_cast(ix + .5f) * bin_size_w / + static_cast(roi_bin_grid_w); + + T w1, w2, w3, w4; + integer_t x_low, x_high, y_low, y_high; + + bilinear_interpolate_gradient( + height, + width, + y, + x, + w1, + w2, + w3, + w4, + x_low, + x_high, + y_low, + y_high, + index); + + T g1 = grad_output_this_bin * w1 / count; + T g2 = grad_output_this_bin * w2 / count; + T g3 = grad_output_this_bin * w3 / count; + T g4 = grad_output_this_bin * w4 / count; + + if (x_low >= 0 && x_high >= 0 && y_low >= 0 && y_high >= 0) { + atomic_add_float(grad_input + input_offset + y_low * width + x_low, static_cast(g1)); + atomic_add_float(grad_input + input_offset + y_low * width + x_high, static_cast(g2)); + atomic_add_float(grad_input + input_offset + y_high * width + x_low, static_cast(g3)); + atomic_add_float(grad_input + input_offset + y_high * width + x_high, static_cast(g4)); + + } // if + } // ix + } // iy + } // MPS_1D_KERNEL_LOOP +} + +#define REGISTER_ROI_ALIGN_BACKWARD_OP(DTYPE, INT_DTYPE) \ +template \ +[[host_name("roi_align_backward_" #DTYPE)]] \ +kernel void roi_align_backward( \ + constant DTYPE * grad_output [[buffer(0)]], \ + constant DTYPE * rois [[buffer(1)]], \ + device DTYPE * grad_input [[buffer(2)]], \ + constant int64_t & output_size [[buffer(3)]], \ + constant int64_t & channels [[buffer(4)]], \ + constant int64_t & height [[buffer(5)]], \ + constant int64_t & width [[buffer(6)]], \ + constant int64_t & pooled_height [[buffer(7)]], \ + constant int64_t & pooled_width [[buffer(8)]], \ + constant int64_t & sampling_ratio [[buffer(9)]], \ + constant bool & aligned [[buffer(10)]], \ + constant float & spatial_scale [[buffer(11)]], \ + constant int64_t & n_stride [[buffer(12)]], \ + constant int64_t & c_stride [[buffer(13)]], \ + constant int64_t & h_stride [[buffer(14)]], \ + constant int64_t & w_stride [[buffer(15)]], \ + uint2 tgid [[threadgroup_position_in_grid]], \ + uint2 tptg [[threads_per_threadgroup]], \ + uint2 tid2 [[thread_position_in_threadgroup]]); + +template +kernel void roi_pool( + constant T * input [[buffer(0)]], + constant T * rois [[buffer(1)]], + device T * output [[buffer(2)]], + device int64_t * argmax [[buffer(3)]], + constant int64_t & output_size [[buffer(4)]], + constant int64_t & channels [[buffer(5)]], + constant int64_t & height [[buffer(6)]], + constant int64_t & width [[buffer(7)]], + constant int64_t & pooled_height [[buffer(8)]], + constant int64_t & pooled_width [[buffer(9)]], + constant float & spatial_scale [[buffer(10)]], + uint2 tgid [[threadgroup_position_in_grid]], + uint2 tptg [[threads_per_threadgroup]], + uint2 tid2 [[thread_position_in_threadgroup]]){ + MPS_1D_KERNEL_LOOP(index, output_size, 1) { + // (n, c, ph, pw) is an element in the pooled output + integer_t pw = index % pooled_width; + integer_t ph = (index / pooled_width) % pooled_height; + integer_t c = (index / pooled_width / pooled_height) % channels; + integer_t n = index / pooled_width / pooled_height / channels; + + constant T* offset_rois = rois + n * 5; + integer_t roi_batch_ind = offset_rois[0]; + integer_t roi_start_w = round(offset_rois[1] * spatial_scale); + integer_t roi_start_h = 
round(offset_rois[2] * spatial_scale); + integer_t roi_end_w = round(offset_rois[3] * spatial_scale); + integer_t roi_end_h = round(offset_rois[4] * spatial_scale); + + // Force malformed ROIs to be 1x1 + integer_t roi_width = max(roi_end_w - roi_start_w + 1, static_cast(1)); + integer_t roi_height = max(roi_end_h - roi_start_h + 1, static_cast(1)); + T bin_size_h = static_cast(roi_height) / static_cast(pooled_height); + T bin_size_w = static_cast(roi_width) / static_cast(pooled_width); + + integer_t hstart = static_cast(floor(static_cast(ph) * bin_size_h)); + integer_t wstart = static_cast(floor(static_cast(pw) * bin_size_w)); + integer_t hend = static_cast(ceil(static_cast(ph + 1) * bin_size_h)); + integer_t wend = static_cast(ceil(static_cast(pw + 1) * bin_size_w)); + + // Add roi offsets and clip to input boundaries + hstart = min(max(hstart + roi_start_h, static_cast(0)), static_cast(height)); + hend = min(max(hend + roi_start_h, static_cast(0)), static_cast(height)); + wstart = min(max(wstart + roi_start_w, static_cast(0)), static_cast(width)); + wend = min(max(wend + roi_start_w, static_cast(0)), static_cast(width)); + bool is_empty = (hend <= hstart) || (wend <= wstart); + + // Define an empty pooling region to be zero + T maxval = is_empty ? 0 : -FLT_MAX; + // If nothing is pooled, argmax = -1 causes nothing to be backprop'd + integer_t maxidx = -1; + constant T* offset_input = + input + (roi_batch_ind * channels + c) * height * width; + for (integer_t h = hstart; h < hend; ++h) { + for (integer_t w = wstart; w < wend; ++w) { + integer_t input_index = h * width + w; + if (offset_input[input_index] > maxval) { + maxval = offset_input[input_index]; + maxidx = input_index; + } + } + } + output[index] = maxval; + argmax[index] = maxidx; + } +} + +#define REGISTER_ROI_POOL_OP(DTYPE, INT_DTYPE) \ +template \ +[[host_name("roi_pool_" #DTYPE)]] \ +kernel void roi_pool( \ + constant DTYPE * input [[buffer(0)]], \ + constant DTYPE * rois [[buffer(1)]], \ + device DTYPE * output [[buffer(2)]], \ + device int64_t * argmax_data [[buffer(3)]], \ + constant int64_t & output_size [[buffer(4)]], \ + constant int64_t & channels [[buffer(5)]], \ + constant int64_t & height [[buffer(6)]], \ + constant int64_t & width [[buffer(7)]], \ + constant int64_t & pooled_height [[buffer(8)]], \ + constant int64_t & pooled_width [[buffer(9)]], \ + constant float & spatial_scale [[buffer(10)]], \ + uint2 tgid [[threadgroup_position_in_grid]], \ + uint2 tptg [[threads_per_threadgroup]], \ + uint2 tid2 [[thread_position_in_threadgroup]]); + +template +kernel void roi_pool_backward( + constant T * grad_output [[buffer(0)]], + constant T * rois [[buffer(1)]], + constant int64_t * argmax_data [[buffer(2)]], + device T * grad_input [[buffer(3)]], + constant int64_t & output_size [[buffer(4)]], + constant int64_t & channels [[buffer(5)]], + constant int64_t & height [[buffer(6)]], + constant int64_t & width [[buffer(7)]], + constant int64_t & pooled_height [[buffer(8)]], + constant int64_t & pooled_width [[buffer(9)]], + constant float & spatial_scale [[buffer(10)]], + constant int64_t & n_stride [[buffer(11)]], + constant int64_t & c_stride [[buffer(12)]], + constant int64_t & h_stride [[buffer(13)]], + constant int64_t & w_stride [[buffer(14)]], + uint2 tgid [[threadgroup_position_in_grid]], + uint2 tptg [[threads_per_threadgroup]], + uint2 tid2 [[thread_position_in_threadgroup]]){ + + MPS_1D_KERNEL_LOOP(index, output_size, 1) { + // (n, c, ph, pw) is an element in the pooled output + integer_t pw = index % 
pooled_width; + integer_t ph = (index / pooled_width) % pooled_height; + integer_t c = (index / pooled_width / pooled_height) % channels; + integer_t n = index / pooled_width / pooled_height / channels; + + constant T* offset_rois = rois + n * 5; + integer_t roi_batch_ind = offset_rois[0]; + + const integer_t output_offset = n * n_stride + c * c_stride; + constant integer_t * argmax_data_offset = + argmax_data + (n * channels + c) * pooled_height * pooled_width; + const integer_t argmax = argmax_data_offset[ph * pooled_width + pw]; + const integer_t offset = (roi_batch_ind * channels + c) * height * width; + + if (argmax != -1) { + atomic_add_float(grad_input + offset + argmax, static_cast(grad_output[output_offset + ph * h_stride + pw * w_stride])); + } + + } // MPS_1D_KERNEL_LOOP +} + +#define REGISTER_ROI_POOL_BACKWARD_OP(DTYPE, INT_DTYPE) \ +template \ +[[host_name("roi_pool_backward_" #DTYPE)]] \ +kernel void roi_pool_backward( \ + constant DTYPE * grad_output [[buffer(0)]], \ + constant DTYPE * rois [[buffer(1)]], \ + constant int64_t * argmax_data [[buffer(2)]], \ + device DTYPE * grad_input [[buffer(3)]], \ + constant int64_t & output_size [[buffer(4)]], \ + constant int64_t & channels [[buffer(5)]], \ + constant int64_t & height [[buffer(6)]], \ + constant int64_t & width [[buffer(7)]], \ + constant int64_t & pooled_height [[buffer(8)]], \ + constant int64_t & pooled_width [[buffer(9)]], \ + constant float & spatial_scale [[buffer(10)]], \ + constant int64_t & n_stride [[buffer(11)]], \ + constant int64_t & c_stride [[buffer(12)]], \ + constant int64_t & h_stride [[buffer(13)]], \ + constant int64_t & w_stride [[buffer(14)]], \ + uint2 tgid [[threadgroup_position_in_grid]], \ + uint2 tptg [[threads_per_threadgroup]], \ + uint2 tid2 [[thread_position_in_threadgroup]]); + +template +kernel void ps_roi_align( + constant T * input [[buffer(0)]], + constant T * rois [[buffer(1)]], + device T * output [[buffer(2)]], + device int64_t * channel_mapping [[buffer(3)]], + constant int64_t & output_size [[buffer(4)]], + constant int64_t & channels [[buffer(5)]], + constant int64_t & height [[buffer(6)]], + constant int64_t & width [[buffer(7)]], + constant int64_t & pooled_height [[buffer(8)]], + constant int64_t & pooled_width [[buffer(9)]], + constant int64_t & sampling_ratio [[buffer(10)]], + constant int64_t & channels_out [[buffer(11)]], + constant float & spatial_scale [[buffer(12)]], + uint2 tgid [[threadgroup_position_in_grid]], + uint2 tptg [[threads_per_threadgroup]], + uint2 tid2 [[thread_position_in_threadgroup]]){ + MPS_1D_KERNEL_LOOP(index, output_size, 1) { + // (n, c_out, ph, pw) is an element in the pooled output + integer_t pw = index % pooled_width; + integer_t ph = (index / pooled_width) % pooled_height; + integer_t c_out = (index / pooled_width / pooled_height) % channels_out; + integer_t n = index / pooled_width / pooled_height / channels_out; + + // (n, c_in, ph, pw) is the associated element in the input + integer_t c_in = (c_out * pooled_height + ph) * pooled_width + pw; + + // [start, end) interval for spatial sampling + constant T* offset_rois = rois + n * 5; + integer_t roi_batch_ind = offset_rois[0]; + + // Do not using rounding; this implementation detail is critical + T roi_start_w = offset_rois[1] * spatial_scale - static_cast(0.5); + T roi_start_h = offset_rois[2] * spatial_scale - static_cast(0.5); + T roi_end_w = offset_rois[3] * spatial_scale - static_cast(0.5); + T roi_end_h = offset_rois[4] * spatial_scale - static_cast(0.5); + + T roi_width = roi_end_w - 
roi_start_w; + T roi_height = roi_end_h - roi_start_h; + T bin_size_h = static_cast(roi_height) / static_cast(pooled_height); + T bin_size_w = static_cast(roi_width) / static_cast(pooled_width); + + // Do not using floor/ceil; this implementation detail is critical + T hstart = static_cast(ph) * bin_size_h + roi_start_h; + T wstart = static_cast(pw) * bin_size_w + roi_start_w; + + // We use roi_bin_grid to sample the grid and mimic integral + integer_t roi_bin_grid_h = (sampling_ratio > 0) + ? sampling_ratio + : ceil(roi_height / pooled_height); + integer_t roi_bin_grid_w = + (sampling_ratio > 0) ? sampling_ratio : ceil(roi_width / pooled_width); + const T count = roi_bin_grid_h * roi_bin_grid_w; + + constant T* offset_input = + input + (roi_batch_ind * channels + c_in) * height * width; + T out_sum = 0; + for (integer_t iy = 0; iy < roi_bin_grid_h; iy++) { + const T y = hstart + + static_cast(iy + .5f) * bin_size_h / + static_cast(roi_bin_grid_h); + for (integer_t ix = 0; ix < roi_bin_grid_w; ix++) { + const T x = wstart + + static_cast(ix + .5f) * bin_size_w / + static_cast(roi_bin_grid_w); + T val = bilinear_interpolate(offset_input, height, width, y, x, index); + out_sum += val; + } + } + + out_sum /= count; + output[index] = out_sum; + channel_mapping[index] = c_in; + } +} + +#define REGISTER_PS_ROI_ALIGN_OP(DTYPE, INT_DTYPE) \ +template \ +[[host_name("ps_roi_align_" #DTYPE)]] \ +kernel void ps_roi_align( \ + constant DTYPE * input [[buffer(0)]], \ + constant DTYPE * rois [[buffer(1)]], \ + device DTYPE * output [[buffer(2)]], \ + device int64_t * channel_mapping [[buffer(3)]], \ + constant int64_t & output_size [[buffer(4)]], \ + constant int64_t & channels [[buffer(5)]], \ + constant int64_t & height [[buffer(6)]], \ + constant int64_t & width [[buffer(7)]], \ + constant int64_t & pooled_height [[buffer(8)]], \ + constant int64_t & pooled_width [[buffer(9)]], \ + constant int64_t & sampling_ratio [[buffer(10)]], \ + constant int64_t & channels_out [[buffer(11)]], \ + constant float & spatial_scale [[buffer(12)]], \ + uint2 tgid [[threadgroup_position_in_grid]], \ + uint2 tptg [[threads_per_threadgroup]], \ + uint2 tid2 [[thread_position_in_threadgroup]]); + +template +kernel void ps_roi_align_backward( + constant T * grad_output [[buffer(0)]], + constant T * rois [[buffer(1)]], + constant int64_t * channel_mapping [[buffer(2)]], + device T * grad_input [[buffer(3)]], + constant int64_t & output_size [[buffer(4)]], + constant int64_t & channels [[buffer(5)]], + constant int64_t & height [[buffer(6)]], + constant int64_t & width [[buffer(7)]], + constant int64_t & pooled_height [[buffer(8)]], + constant int64_t & pooled_width [[buffer(9)]], + constant int64_t & sampling_ratio [[buffer(10)]], + constant int64_t & channels_out [[buffer(11)]], + constant float & spatial_scale [[buffer(12)]], + uint2 tgid [[threadgroup_position_in_grid]], + uint2 tptg [[threads_per_threadgroup]], + uint2 tid2 [[thread_position_in_threadgroup]]){ + + MPS_1D_KERNEL_LOOP(index, output_size, 1) { + // (n, *, ph, pw) is an element in the pooled output + integer_t pw = index % pooled_width; + integer_t ph = (index / pooled_width) % pooled_height; + integer_t n = index / pooled_width / pooled_height / channels_out; + + constant T* offset_rois = rois + n * 5; + integer_t roi_batch_ind = offset_rois[0]; + + // Do not using rounding; this implementation detail is critical + T roi_start_w = offset_rois[1] * spatial_scale - static_cast(0.5); + T roi_start_h = offset_rois[2] * spatial_scale - static_cast(0.5); + T 
roi_end_w = offset_rois[3] * spatial_scale - static_cast(0.5); + T roi_end_h = offset_rois[4] * spatial_scale - static_cast(0.5); + + // Force too small ROIs to be 1x1 + T roi_width = roi_end_w - roi_start_w; + T roi_height = roi_end_h - roi_start_h; + T bin_size_h = roi_height / static_cast(pooled_height); + T bin_size_w = roi_width / static_cast(pooled_width); + + integer_t c_in = channel_mapping[index]; + + // Do not using floor/ceil; this implementation detail is critical + T hstart = static_cast(ph) * bin_size_h + roi_start_h; + T wstart = static_cast(pw) * bin_size_w + roi_start_w; + + const T grad_output_this_bin = grad_output[index]; + + // We use roi_bin_grid to sample the grid and mimic integral + integer_t roi_bin_grid_h = (sampling_ratio > 0) + ? sampling_ratio + : ceil(roi_height / pooled_height); // e.g., = 2 + integer_t roi_bin_grid_w = + (sampling_ratio > 0) ? sampling_ratio : ceil(roi_width / pooled_width); + const T count = roi_bin_grid_h * roi_bin_grid_w; + + const integer_t offset = (roi_batch_ind * channels + c_in) * height * width; + + for (integer_t iy = 0; iy < roi_bin_grid_h; iy++) { + const T y = hstart + + static_cast(iy + .5f) * bin_size_h / + static_cast(roi_bin_grid_h); + for (integer_t ix = 0; ix < roi_bin_grid_w; ix++) { + const T x = wstart + + static_cast(ix + .5f) * bin_size_w / + static_cast(roi_bin_grid_w); + + T w1, w2, w3, w4; + integer_t x_low, x_high, y_low, y_high; + + bilinear_interpolate_gradient( + height, + width, + y, + x, + w1, + w2, + w3, + w4, + x_low, + x_high, + y_low, + y_high, + index); + + T g1 = grad_output_this_bin * w1 / count; + T g2 = grad_output_this_bin * w2 / count; + T g3 = grad_output_this_bin * w3 / count; + T g4 = grad_output_this_bin * w4 / count; + + if (x_low >= 0 && x_high >= 0 && y_low >= 0 && y_high >= 0) { + atomic_add_float(grad_input + offset + y_low * width + x_low, static_cast(g1)); + atomic_add_float(grad_input + offset + y_low * width + x_high, static_cast(g2)); + atomic_add_float(grad_input + offset + y_high * width + x_low, static_cast(g3)); + atomic_add_float(grad_input + offset + y_high * width + x_high, static_cast(g4)); + } // if + } // ix + } // iy + } +} + +#define REGISTER_PS_ROI_ALIGN_BACKWARD_OP(DTYPE, INT_DTYPE) \ +template \ +[[host_name("ps_roi_align_backward_" #DTYPE)]] \ +kernel void ps_roi_align_backward( \ + constant DTYPE * grad_output [[buffer(0)]], \ + constant DTYPE * rois [[buffer(1)]], \ + constant int64_t * channel_mapping [[buffer(2)]], \ + device DTYPE * grad_input [[buffer(3)]], \ + constant int64_t & output_size [[buffer(4)]], \ + constant int64_t & channels [[buffer(5)]], \ + constant int64_t & height [[buffer(6)]], \ + constant int64_t & width [[buffer(7)]], \ + constant int64_t & pooled_height [[buffer(8)]], \ + constant int64_t & pooled_width [[buffer(9)]], \ + constant int64_t & sampling_ratio [[buffer(10)]], \ + constant int64_t & channels_out [[buffer(11)]], \ + constant float & spatial_scale [[buffer(12)]], \ + uint2 tgid [[threadgroup_position_in_grid]], \ + uint2 tptg [[threads_per_threadgroup]], \ + uint2 tid2 [[thread_position_in_threadgroup]]); + +template +kernel void ps_roi_pool( + constant T * input [[buffer(0)]], + constant T * rois [[buffer(1)]], + device T * output [[buffer(2)]], + device int64_t * channel_mapping [[buffer(3)]], + constant int64_t & output_size [[buffer(4)]], + constant int64_t & channels [[buffer(5)]], + constant int64_t & height [[buffer(6)]], + constant int64_t & width [[buffer(7)]], + constant int64_t & pooled_height [[buffer(8)]], + constant 
int64_t & pooled_width [[buffer(9)]], + constant int64_t & channels_out [[buffer(10)]], + constant float & spatial_scale [[buffer(11)]], + uint2 tgid [[threadgroup_position_in_grid]], + uint2 tptg [[threads_per_threadgroup]], + uint2 tid2 [[thread_position_in_threadgroup]]){ + MPS_1D_KERNEL_LOOP(index, output_size, 1) { + // (n, c_out, ph, pw) is an element in the pooled output + integer_t pw = index % pooled_width; + integer_t ph = (index / pooled_width) % pooled_height; + integer_t c_out = (index / (pooled_width * pooled_height)) % channels_out; + integer_t n = index / pooled_width / pooled_height / channels_out; + + // (n, c_in, ph, pw) is the associated element in the input + integer_t c_in = (c_out * pooled_height + ph) * pooled_width + pw; + + // [start, end) interval for spatial sampling + constant T* offset_rois = rois + n * 5; + integer_t roi_batch_ind = offset_rois[0]; + integer_t roi_start_w = round(offset_rois[1] * spatial_scale); + integer_t roi_start_h = round(offset_rois[2] * spatial_scale); + integer_t roi_end_w = round(offset_rois[3] * spatial_scale); + integer_t roi_end_h = round(offset_rois[4] * spatial_scale); + + // Force too small ROIs to be 1x1 + integer_t roi_width = max(roi_end_w - roi_start_w, static_cast(1)); + integer_t roi_height = max(roi_end_h - roi_start_h, static_cast(1)); + T bin_size_h = static_cast(roi_height) / static_cast(pooled_height); + T bin_size_w = static_cast(roi_width) / static_cast(pooled_width); + + integer_t hstart = static_cast(floor(static_cast(ph) * bin_size_h)); + integer_t wstart = static_cast(floor(static_cast(pw) * bin_size_w)); + integer_t hend = static_cast(ceil(static_cast(ph + 1) * bin_size_h)); + integer_t wend = static_cast(ceil(static_cast(pw + 1) * bin_size_w)); + + // Add roi offsets and clip to input boundaries + hstart = min(max(hstart + roi_start_h, static_cast(0)), static_cast(height - 1)); + hend = min(max(hend + roi_start_h, static_cast(0)), static_cast(height - 1)); + wstart = min(max(wstart + roi_start_w, static_cast(0)), static_cast(width - 1)); + wend = min(max(wend + roi_start_w, static_cast(0)), static_cast(width - 1)); + bool is_empty = (hend <= hstart) || (wend <= wstart); + + constant T* offset_input = + input + (roi_batch_ind * channels + c_in) * height * width; + T out_sum = 0; + for (integer_t h = hstart; h < hend; ++h) { + for (integer_t w = wstart; w < wend; ++w) { + integer_t input_index = h * width + w; + out_sum += offset_input[input_index]; + } + } + + T bin_area = (hend - hstart) * (wend - wstart); + output[index] = is_empty ? 
static_cast(0) : out_sum / bin_area; + channel_mapping[index] = c_in; + } +} + +#define REGISTER_PS_ROI_POOL_OP(DTYPE, INT_DTYPE) \ +template \ +[[host_name("ps_roi_pool_" #DTYPE)]] \ +kernel void ps_roi_pool( \ + constant DTYPE * input [[buffer(0)]], \ + constant DTYPE * rois [[buffer(1)]], \ + device DTYPE * output [[buffer(2)]], \ + device int64_t * channel_mapping [[buffer(3)]], \ + constant int64_t & output_size [[buffer(4)]], \ + constant int64_t & channels [[buffer(5)]], \ + constant int64_t & height [[buffer(6)]], \ + constant int64_t & width [[buffer(7)]], \ + constant int64_t & pooled_height [[buffer(8)]], \ + constant int64_t & pooled_width [[buffer(9)]], \ + constant int64_t & channels_out [[buffer(10)]], \ + constant float & spatial_scale [[buffer(11)]], \ + uint2 tgid [[threadgroup_position_in_grid]], \ + uint2 tptg [[threads_per_threadgroup]], \ + uint2 tid2 [[thread_position_in_threadgroup]]); + +template +kernel void ps_roi_pool_backward( + constant T * grad_output [[buffer(0)]], + constant T * rois [[buffer(1)]], + constant int64_t * channel_mapping [[buffer(2)]], + device T * grad_input [[buffer(3)]], + constant int64_t & output_size [[buffer(4)]], + constant int64_t & channels [[buffer(5)]], + constant int64_t & height [[buffer(6)]], + constant int64_t & width [[buffer(7)]], + constant int64_t & pooled_height [[buffer(8)]], + constant int64_t & pooled_width [[buffer(9)]], + constant int64_t & channels_out [[buffer(10)]], + constant float & spatial_scale [[buffer(11)]], + uint2 tgid [[threadgroup_position_in_grid]], + uint2 tptg [[threads_per_threadgroup]], + uint2 tid2 [[thread_position_in_threadgroup]]){ + + MPS_1D_KERNEL_LOOP(index, output_size, 1) { + // (n, *, ph, pw) is an element in the pooled output + integer_t pw = index % pooled_width; + integer_t ph = (index / pooled_width) % pooled_height; + integer_t n = index / pooled_width / pooled_height / channels_out; + + constant T* offset_rois = rois + n * 5; + integer_t roi_batch_ind = offset_rois[0]; + integer_t roi_start_w = round(offset_rois[1] * spatial_scale); + integer_t roi_start_h = round(offset_rois[2] * spatial_scale); + integer_t roi_end_w = round(offset_rois[3] * spatial_scale); + integer_t roi_end_h = round(offset_rois[4] * spatial_scale); + + // Force too small ROIs to be 1x1 + integer_t roi_width = max(roi_end_w - roi_start_w, static_cast(1)); + integer_t roi_height = max(roi_end_h - roi_start_h, static_cast(1)); + T bin_size_h = static_cast(roi_height) / static_cast(pooled_height); + T bin_size_w = static_cast(roi_width) / static_cast(pooled_width); + + integer_t hstart = static_cast(floor(static_cast(ph) * bin_size_h)); + integer_t wstart = static_cast(floor(static_cast(pw) * bin_size_w)); + integer_t hend = static_cast(ceil(static_cast(ph + 1) * bin_size_h)); + integer_t wend = static_cast(ceil(static_cast(pw + 1) * bin_size_w)); + + // Add roi offsets and clip to input boundaries + hstart = min(max(hstart + roi_start_h, static_cast(0)), static_cast(height)); + hend = min(max(hend + roi_start_h, static_cast(0)), static_cast(height)); + wstart = min(max(wstart + roi_start_w, static_cast(0)), static_cast(width)); + wend = min(max(wend + roi_start_w, static_cast(0)), static_cast(width)); + bool is_empty = (hend <= hstart) || (wend <= wstart); + + integer_t c_in = channel_mapping[index]; + T bin_area = (hend - hstart) * (wend - wstart); + T diff_val = is_empty ? 
static_cast(0) : grad_output[index] / bin_area; + + const integer_t offset = (roi_batch_ind * channels + c_in) * height * width; + + for (integer_t h = hstart; h < hend; ++h) { + for (integer_t w = wstart; w < wend; ++w) { + integer_t grad_input_index = h * width + w; + atomic_add_float(grad_input + offset + grad_input_index, diff_val); + } + } + + } // MPS_1D_KERNEL_LOOP +} + +#define REGISTER_PS_ROI_POOL_BACKWARD_OP(DTYPE, INT_DTYPE) \ +template \ +[[host_name("ps_roi_pool_backward_" #DTYPE)]] \ +kernel void ps_roi_pool_backward( \ + constant DTYPE * grad_output [[buffer(0)]], \ + constant DTYPE * rois [[buffer(1)]], \ + constant int64_t * channel_mapping [[buffer(2)]], \ + device DTYPE * grad_input [[buffer(3)]], \ + constant int64_t & output_size [[buffer(4)]], \ + constant int64_t & channels [[buffer(5)]], \ + constant int64_t & height [[buffer(6)]], \ + constant int64_t & width [[buffer(7)]], \ + constant int64_t & pooled_height [[buffer(8)]], \ + constant int64_t & pooled_width [[buffer(9)]], \ + constant int64_t & channels_out [[buffer(10)]], \ + constant float & spatial_scale [[buffer(11)]], \ + uint2 tgid [[threadgroup_position_in_grid]], \ + uint2 tptg [[threads_per_threadgroup]], \ + uint2 tid2 [[thread_position_in_threadgroup]]); + +REGISTER_NMS_OP(float); +REGISTER_NMS_OP(half); +REGISTER_ROI_ALIGN_OP(float, int64_t); +REGISTER_ROI_ALIGN_OP(half, int64_t); +REGISTER_ROI_ALIGN_BACKWARD_OP(float, int64_t); +REGISTER_ROI_ALIGN_BACKWARD_OP(half, int64_t); +REGISTER_ROI_POOL_OP(float, int64_t); +REGISTER_ROI_POOL_OP(half, int64_t); +REGISTER_ROI_POOL_BACKWARD_OP(float, int64_t); +REGISTER_ROI_POOL_BACKWARD_OP(half, int64_t); +REGISTER_PS_ROI_ALIGN_OP(float, int64_t); +REGISTER_PS_ROI_ALIGN_OP(half, int64_t); +REGISTER_PS_ROI_ALIGN_BACKWARD_OP(float, int64_t); +REGISTER_PS_ROI_ALIGN_BACKWARD_OP(half, int64_t); +REGISTER_PS_ROI_POOL_OP(float, int64_t); +REGISTER_PS_ROI_POOL_OP(half, int64_t); +REGISTER_PS_ROI_POOL_BACKWARD_OP(float, int64_t); +REGISTER_PS_ROI_POOL_BACKWARD_OP(half, int64_t); + +)VISION_METAL"; + +static id compileVisionOpsLibrary(id device) { + static id visionLibrary = nil; + if (visionLibrary) { + return visionLibrary; + } + + NSError* error = nil; + MTLCompileOptions* options = [[MTLCompileOptions new] autorelease]; + [options setLanguageVersion:MTLLanguageVersion2_3]; + visionLibrary = [device newLibraryWithSource:[NSString stringWithCString:METAL_VISION encoding:NSASCIIStringEncoding] + options:options + error:&error]; + TORCH_CHECK(visionLibrary, "Failed to create metal vision library, error: ", [[error description] UTF8String]); + return visionLibrary; +} + +static id visionPipelineState(id device, const std::string& kernel) { + static std::unordered_map> psoCache; + id pso = psoCache[kernel]; + if (pso) { + return pso; + } + + NSError* error = nil; + id visionLib = compileVisionOpsLibrary(device); + id visionFunc = [visionLib newFunctionWithName:[NSString stringWithUTF8String:kernel.c_str()]]; + TORCH_CHECK(visionFunc, "Failed to create function state object for: ", kernel); + pso = [device newComputePipelineStateWithFunction:visionFunc error:&error]; + TORCH_CHECK(pso, "Failed to created pipeline state object, error: ", [[error description] UTF8String]); + + psoCache[kernel] = pso; + return pso; +} + +} // namespace mps +} // namespace ops +} // namespace vision diff --git a/torchvision/csrc/ops/mps/nms_kernel.mm b/torchvision/csrc/ops/mps/nms_kernel.mm new file mode 100644 index 00000000000..5ee9b5cbeae --- /dev/null +++ 
@@ -0,0 +1,109 @@
+#include <ATen/mps/MPSProfiler.h>
+#include <ATen/native/mps/OperationUtils.h>
+#include "mps_kernels.h"
+
+namespace vision {
+namespace ops {
+
+namespace {
+
+// This should be in sync with `nmsThreadsPerBlock` in the metal kernel.
+constexpr int64_t nmsThreadsPerBlock = sizeof(uint64_t) * 8;
+
+at::Tensor nms_kernel(const at::Tensor& dets, const at::Tensor& scores, double iou_threshold) {
+  using namespace at::native::mps;
+  TORCH_CHECK(dets.is_mps(), "dets must be a MPS tensor");
+  TORCH_CHECK(scores.is_mps(), "scores must be a MPS tensor");
+
+  TORCH_CHECK(dets.dim() == 2, "boxes should be a 2d tensor, got ", dets.dim(), "D");
+  TORCH_CHECK(dets.size(1) == 4, "boxes should have 4 elements in dimension 1, got ", dets.size(1));
+  TORCH_CHECK(scores.dim() == 1, "scores should be a 1d tensor, got ", scores.dim(), "D");
+  TORCH_CHECK(dets.size(0) == scores.size(0),
+              "boxes and scores should have same number of elements in ",
+              "dimension 0, got ",
+              dets.size(0),
+              " and ",
+              scores.size(0))
+
+  if (dets.numel() == 0) {
+    return at::empty({0}, dets.options().dtype(at::kLong));
+  }
+
+  auto order_t = std::get<1>(scores.sort(/*stable=*/true, /*dim=*/0, /* descending=*/true));
+  auto dets_sorted = dets.index_select(0, order_t).contiguous();
+  int64_t dets_num = dets.size(0);
+  float iou_threshold_f = static_cast<float>(iou_threshold);
+
+  const int col_blocks = (dets_num + nmsThreadsPerBlock - 1) / nmsThreadsPerBlock;
+  at::Tensor mask = at::empty({dets_num * col_blocks}, dets.options().dtype(at::kLong));
+
+  id<MTLBuffer> inputBuffer = getMTLBufferStorage(dets_sorted);
+  id<MTLBuffer> outputBuffer = getMTLBufferStorage(mask);
+  id<MTLDevice> device = MPSDevice::getInstance()->device();
+  MPSStream* mpsStream = getCurrentMPSStream();
+  dispatch_sync(mpsStream->queue(), ^() {
+    @autoreleasepool {
+      id<MTLComputeCommandEncoder> computeEncoder = mpsStream->commandEncoder();
+      MTLSize threadgroupsPerGrid = MTLSizeMake(col_blocks, col_blocks, 1);
+
+      const std::string kernel = "nms_" + scalarToMetalTypeString(dets_sorted.scalar_type());
+      id<MTLComputePipelineState> visionPSO = mps::visionPipelineState(device, kernel);
+
+      // this function call is a no-op if MPS Profiler is not enabled
+      getMPSProfiler().beginProfileKernel(visionPSO, kernel, {dets, scores});
+
+      [computeEncoder setComputePipelineState:visionPSO];
+      [computeEncoder setBuffer:inputBuffer offset:dets_sorted.storage_offset() * dets_sorted.element_size() atIndex:0];
+      [computeEncoder setBuffer:outputBuffer offset:mask.storage_offset() * mask.element_size() atIndex:1];
+      [computeEncoder setBytes:&dets_num length:sizeof(int64_t) atIndex:2];
+      [computeEncoder setBytes:&iou_threshold_f length:sizeof(float) atIndex:3];
+
+      // A threadGroup is equivalent to a cuda's block.
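The mask buffer written by the kernel follows the same layout as the CUDA implementation: each box i owns col_blocks 64-bit words, and bit j of word [i, b] is set when box i overlaps box b * 64 + j above iou_threshold. A rough Python sketch of the host-side decode performed after the dispatch (illustrative only; decode_suppression_mask is not part of the patch):

    import torch

    def decode_suppression_mask(mask: torch.Tensor, dets_num: int, block: int = 64):
        # mask: flattened [dets_num * col_blocks] int64 tensor filled by the Metal kernel
        col_blocks = (dets_num + block - 1) // block
        words = [w & 0xFFFFFFFFFFFFFFFF for w in mask.view(dets_num, col_blocks).flatten().tolist()]
        removed = [0] * col_blocks
        keep = []
        for i in range(dets_num):
            b, j = divmod(i, block)
            if not (removed[b] >> j) & 1:
                keep.append(i)  # box i survives, so mark every box it suppresses
                for c in range(col_blocks):
                    removed[c] |= words[i * col_blocks + c]
        return keep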
+ NSUInteger tgSize = visionPSO.maxTotalThreadsPerThreadgroup; + if (tgSize > nmsThreadsPerBlock) { + tgSize = nmsThreadsPerBlock; + } + + MTLSize threadGroupSize = MTLSizeMake(tgSize, 1, 1); + [computeEncoder dispatchThreadgroups:threadgroupsPerGrid threadsPerThreadgroup:threadGroupSize]; + + getMPSProfiler().endProfileKernel(visionPSO); + } + }); + + int64_t num_to_keep = 0; + + at::Tensor mask_cpu = mask.to(at::kCPU); + unsigned long long* mask_host = (unsigned long long*)mask_cpu.data_ptr(); + + std::vector remv(col_blocks); + memset(&remv[0], 0, sizeof(unsigned long long) * col_blocks); + + at::Tensor keep = at::empty({dets_num}, dets.options().dtype(at::kLong).device(at::kCPU)); + int64_t* keep_out = keep.data_ptr(); + + for (int64_t i = 0; i < dets_num; i++) { + int64_t nblock = i / nmsThreadsPerBlock; + int64_t inblock = i % nmsThreadsPerBlock; + + if (!(remv[nblock] & (1ULL << inblock))) { + keep_out[num_to_keep++] = i; + unsigned long long* p = mask_host + i * col_blocks; + for (int64_t j = nblock; j < col_blocks; j++) { + remv[j] |= p[j]; + } + } + } + + return order_t.index( + {keep.narrow(/*dim=*/0, /*start=*/0, /*length=*/num_to_keep).to(order_t.device(), keep.scalar_type())}); +} + +} // namespace + +TORCH_LIBRARY_IMPL(torchvision, MPS, m) { + m.impl(TORCH_SELECTIVE_NAME("torchvision::nms"), TORCH_FN(nms_kernel)); +} + +} // namespace ops +} // namespace vision diff --git a/torchvision/csrc/ops/mps/ps_roi_align_kernel.mm b/torchvision/csrc/ops/mps/ps_roi_align_kernel.mm new file mode 100644 index 00000000000..16b711ad5ef --- /dev/null +++ b/torchvision/csrc/ops/mps/ps_roi_align_kernel.mm @@ -0,0 +1,205 @@ +#include +#include +#include "mps_helpers.h" +#include "mps_kernels.h" + +namespace vision { +namespace ops { + +namespace { + +std::tuple ps_roi_align_forward_kernel(const at::Tensor& input, + const at::Tensor& rois, + double spatial_scale, + int64_t pooled_height, + int64_t pooled_width, + int64_t sampling_ratio) { + using namespace at::native::mps; + TORCH_CHECK(input.is_mps(), "input must be a MPS tensor"); + TORCH_CHECK(rois.is_mps(), "rois must be a MPS tensor"); + TORCH_CHECK(rois.size(1) == 5, "rois must have shape as Tensor[K, 5]"); + + at::TensorArg input_t{input, "input", 1}, rois_t{rois, "rois", 2}; + + at::CheckedFrom c = "ps_roi_align_forward_kernel"; + at::checkAllSameGPU(c, {input_t, rois_t}); + at::checkAllSameType(c, {input_t, rois_t}); + + int64_t num_rois = rois.size(0); + int64_t channels = input.size(1); + int64_t height = input.size(2); + int64_t width = input.size(3); + float spatial_scale_f = static_cast(spatial_scale); + + TORCH_CHECK(channels % (pooled_height * pooled_width) == 0, + "input channels must be a multiple of pooling height * pooling width"); + + int64_t channels_out = channels / (pooled_height * pooled_width); + + auto output = at::zeros({num_rois, channels_out, pooled_height, pooled_width}, input.options()); + auto channel_mapping = at::zeros(output.sizes(), input.options().dtype(at::kLong)); + + int64_t output_size = output.numel(); + + if (output_size == 0) { + return std::make_tuple(output, channel_mapping); + } + + auto input_ = input.contiguous(); + auto rois_ = rois.contiguous(); + + id inputBuffer = getMTLBufferStorage(input_); + id roisBuffer = getMTLBufferStorage(rois_); + id outputBuffer = getMTLBufferStorage(output); + id channelMappingBuffer = getMTLBufferStorage(channel_mapping); + id device = MPSDevice::getInstance()->device(); + MPSStream* mpsStream = getCurrentMPSStream(); + dispatch_sync(mpsStream->queue(), ^() { + 
@autoreleasepool { + id computeEncoder = mpsStream->commandEncoder(); + MTLSize threadgroupsPerGrid = MTLSizeMake( + std::min(ceil_div(static_cast(output_size), static_cast(512)), static_cast(4096)), + 1, + 1); + + const std::string kernel = "ps_roi_align_" + scalarToMetalTypeString(input.scalar_type()); + id visionPSO = mps::visionPipelineState(device, kernel); + + // this function call is a no-op if MPS Profiler is not enabled + getMPSProfiler().beginProfileKernel(visionPSO, kernel, {input_, rois_}); + + [computeEncoder setComputePipelineState:visionPSO]; + // [N, C, H, W] + [computeEncoder setBuffer:inputBuffer offset:input_.storage_offset() * input_.element_size() atIndex:0]; + [computeEncoder setBuffer:roisBuffer offset:rois_.storage_offset() * rois_.element_size() atIndex:1]; + [computeEncoder setBuffer:outputBuffer offset:output.storage_offset() * output.element_size() atIndex:2]; + [computeEncoder setBuffer:channelMappingBuffer + offset:channel_mapping.storage_offset() * channel_mapping.element_size() + atIndex:3]; + + [computeEncoder setBytes:&output_size length:sizeof(int64_t) atIndex:4]; + [computeEncoder setBytes:&channels length:sizeof(int64_t) atIndex:5]; + [computeEncoder setBytes:&height length:sizeof(int64_t) atIndex:6]; + [computeEncoder setBytes:&width length:sizeof(int64_t) atIndex:7]; + [computeEncoder setBytes:&pooled_height length:sizeof(int64_t) atIndex:8]; + [computeEncoder setBytes:&pooled_width length:sizeof(int64_t) atIndex:9]; + [computeEncoder setBytes:&sampling_ratio length:sizeof(int64_t) atIndex:10]; + [computeEncoder setBytes:&channels_out length:sizeof(int64_t) atIndex:11]; + [computeEncoder setBytes:&spatial_scale_f length:sizeof(float) atIndex:12]; + + // A threadGroup is equivalent to a cuda's block. + NSUInteger tgSize = visionPSO.maxTotalThreadsPerThreadgroup; + if (tgSize > threadsPerBlock) { + tgSize = threadsPerBlock; + } + + MTLSize threadGroupSize = MTLSizeMake(tgSize, 1, 1); + [computeEncoder dispatchThreadgroups:threadgroupsPerGrid threadsPerThreadgroup:threadGroupSize]; + + getMPSProfiler().endProfileKernel(visionPSO); + } + }); + return std::make_tuple(output, channel_mapping); +} + +at::Tensor ps_roi_align_backward_kernel(const at::Tensor& grad, + const at::Tensor& rois, + const at::Tensor& channel_mapping, + double spatial_scale, + int64_t pooled_height, + int64_t pooled_width, + int64_t sampling_ratio, + int64_t batch_size, + int64_t channels, + int64_t height, + int64_t width) { + using namespace at::native::mps; + TORCH_CHECK(grad.is_mps(), "grad must be a MPS tensor"); + TORCH_CHECK(rois.is_mps(), "rois must be a MPS tensor"); + TORCH_CHECK(grad.scalar_type() != at::kHalf, "MPS does not support ps_roi_align backward with float16 inputs."); + TORCH_CHECK(channel_mapping.is_mps(), "channel_mapping must be a MPS tensor"); + + at::TensorArg grad_t{grad, "input", 1}, rois_t{rois, "rois", 2}, + channel_mapping_t{channel_mapping, "channel_mapping", 3}; + + at::CheckedFrom c = "ps_roi_align_backward_kernel"; + at::checkAllSameGPU(c, {grad_t, rois_t, channel_mapping_t}); + at::checkAllSameType(c, {grad_t, rois_t}); + + float spatial_scale_f = static_cast(spatial_scale); + + auto grad_input = at::zeros({batch_size, channels, height, width}, grad.options()); + + if (grad.numel() == 0) { + return grad_input; + } + + int64_t output_size = grad.numel(); + int64_t channels_out = channels / (pooled_height * pooled_width); + + at::globalContext().alertNotDeterministic("ps_roi_align_backward_kernel"); + auto grad_ = grad.contiguous(), rois_ = 
rois.contiguous(); + + id inputBuffer = getMTLBufferStorage(grad_); + id roisBuffer = getMTLBufferStorage(rois_); + id channelMappingBuffer = getMTLBufferStorage(channel_mapping); + id outputBuffer = getMTLBufferStorage(grad_input); + id device = MPSDevice::getInstance()->device(); + MPSStream* mpsStream = getCurrentMPSStream(); + dispatch_sync(mpsStream->queue(), ^() { + @autoreleasepool { + id computeEncoder = mpsStream->commandEncoder(); + MTLSize threadgroupsPerGrid = MTLSizeMake( + std::min(ceil_div(static_cast(grad.numel()), static_cast(512)), static_cast(4096)), + 1, + 1); + + const std::string kernel = "ps_roi_align_backward_" + scalarToMetalTypeString(grad.scalar_type()); + id visionPSO = mps::visionPipelineState(device, kernel); + + // this function call is a no-op if MPS Profiler is not enabled + getMPSProfiler().beginProfileKernel(visionPSO, kernel, {grad, rois_}); + + [computeEncoder setComputePipelineState:visionPSO]; + // [N, C, H, W] + [computeEncoder setBuffer:inputBuffer offset:grad_.storage_offset() * grad_.element_size() atIndex:0]; + [computeEncoder setBuffer:roisBuffer offset:rois_.storage_offset() * rois_.element_size() atIndex:1]; + [computeEncoder setBuffer:channelMappingBuffer + offset:channel_mapping.storage_offset() * channel_mapping.element_size() + atIndex:2]; + [computeEncoder setBuffer:outputBuffer offset:grad_input.storage_offset() * grad_input.element_size() atIndex:3]; + + [computeEncoder setBytes:&output_size length:sizeof(int64_t) atIndex:4]; + [computeEncoder setBytes:&channels length:sizeof(int64_t) atIndex:5]; + [computeEncoder setBytes:&height length:sizeof(int64_t) atIndex:6]; + [computeEncoder setBytes:&width length:sizeof(int64_t) atIndex:7]; + [computeEncoder setBytes:&pooled_height length:sizeof(int64_t) atIndex:8]; + [computeEncoder setBytes:&pooled_width length:sizeof(int64_t) atIndex:9]; + [computeEncoder setBytes:&sampling_ratio length:sizeof(int64_t) atIndex:10]; + [computeEncoder setBytes:&channels_out length:sizeof(int64_t) atIndex:11]; + [computeEncoder setBytes:&spatial_scale_f length:sizeof(float) atIndex:12]; + + // A threadGroup is equivalent to a cuda's block. 
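Every ROI op in this file uses the same 1-D launch configuration: ceil_div(output_size, 512) threadgroups capped at 4096, with the threadgroup width clamped to the threadsPerBlock constant (presumably defined in mps_helpers.h), and the Metal kernels stride over any remainder via MPS_1D_KERNEL_LOOP. A quick sketch of that arithmetic, assuming threadsPerBlock is 512:

    def mps_launch_config(output_size: int, threads_per_block: int = 512, max_groups: int = 4096):
        # mirrors min(ceil_div(output_size, 512), 4096) used for threadgroupsPerGrid
        groups = min(-(-output_size // threads_per_block), max_groups)
        return groups, threads_per_block

    # e.g. a grad tensor of shape [2, 21, 7, 7] -> 2058 elements -> 5 threadgroups of 512 threads
    print(mps_launch_config(2 * 21 * 7 * 7))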
+ NSUInteger tgSize = visionPSO.maxTotalThreadsPerThreadgroup; + if (tgSize > threadsPerBlock) { + tgSize = threadsPerBlock; + } + + MTLSize threadGroupSize = MTLSizeMake(tgSize, 1, 1); + [computeEncoder dispatchThreadgroups:threadgroupsPerGrid threadsPerThreadgroup:threadGroupSize]; + + getMPSProfiler().endProfileKernel(visionPSO); + } + }); + return grad_input; +} + +} // namespace + +TORCH_LIBRARY_IMPL(torchvision, MPS, m) { + m.impl(TORCH_SELECTIVE_NAME("torchvision::ps_roi_align"), TORCH_FN(ps_roi_align_forward_kernel)); + m.impl(TORCH_SELECTIVE_NAME("torchvision::_ps_roi_align_backward"), TORCH_FN(ps_roi_align_backward_kernel)); +} + +} // namespace ops +} // namespace vision diff --git a/torchvision/csrc/ops/mps/ps_roi_pool_kernel.mm b/torchvision/csrc/ops/mps/ps_roi_pool_kernel.mm new file mode 100644 index 00000000000..fc24f6990fa --- /dev/null +++ b/torchvision/csrc/ops/mps/ps_roi_pool_kernel.mm @@ -0,0 +1,200 @@ +#include +#include +#include "mps_helpers.h" +#include "mps_kernels.h" + +namespace vision { +namespace ops { + +namespace { + +std::tuple ps_roi_pool_forward_kernel(const at::Tensor& input, + const at::Tensor& rois, + double spatial_scale, + int64_t pooled_height, + int64_t pooled_width) { + using namespace at::native::mps; + TORCH_CHECK(input.is_mps(), "input must be a MPS tensor"); + TORCH_CHECK(rois.is_mps(), "rois must be a MPS tensor"); + TORCH_CHECK(rois.size(1) == 5, "rois must have shape as Tensor[K, 5]"); + + at::TensorArg input_t{input, "input", 1}, rois_t{rois, "rois", 2}; + + at::CheckedFrom c = "ps_roi_pool_forward_kernel"; + at::checkAllSameGPU(c, {input_t, rois_t}); + at::checkAllSameType(c, {input_t, rois_t}); + + int64_t num_rois = rois.size(0); + int64_t channels = input.size(1); + int64_t height = input.size(2); + int64_t width = input.size(3); + float spatial_scale_f = static_cast(spatial_scale); + + TORCH_CHECK(channels % (pooled_height * pooled_width) == 0, + "input channels must be a multiple of pooling height * pooling width"); + int64_t channels_out = channels / (pooled_height * pooled_width); + + auto output = at::zeros({num_rois, channels_out, pooled_height, pooled_width}, input.options()); + auto channel_mapping = at::zeros(output.sizes(), input.options().dtype(at::kLong)); + auto output_size = output.numel(); + + if (output_size == 0) { + return std::make_tuple(output, channel_mapping); + } + + auto input_ = input.contiguous(); + auto rois_ = rois.contiguous(); + + id inputBuffer = getMTLBufferStorage(input_); + id roisBuffer = getMTLBufferStorage(rois_); + id outputBuffer = getMTLBufferStorage(output); + id channelMappingBuffer = getMTLBufferStorage(channel_mapping); + id device = MPSDevice::getInstance()->device(); + MPSStream* mpsStream = getCurrentMPSStream(); + dispatch_sync(mpsStream->queue(), ^() { + @autoreleasepool { + id computeEncoder = mpsStream->commandEncoder(); + MTLSize threadgroupsPerGrid = MTLSizeMake( + std::min(ceil_div(static_cast(output_size), static_cast(512)), static_cast(4096)), + 1, + 1); + + const std::string kernel = "ps_roi_pool_" + scalarToMetalTypeString(input.scalar_type()); + id visionPSO = mps::visionPipelineState(device, kernel); + + // this function call is a no-op if MPS Profiler is not enabled + getMPSProfiler().beginProfileKernel(visionPSO, kernel, {input_, rois_}); + + [computeEncoder setComputePipelineState:visionPSO]; + // [N, C, H, W] + [computeEncoder setBuffer:inputBuffer offset:input_.storage_offset() * input_.element_size() atIndex:0]; + [computeEncoder setBuffer:roisBuffer 
offset:rois_.storage_offset() * rois_.element_size() atIndex:1]; + [computeEncoder setBuffer:outputBuffer offset:output.storage_offset() * output.element_size() atIndex:2]; + [computeEncoder setBuffer:channelMappingBuffer + offset:channel_mapping.storage_offset() * channel_mapping.element_size() + atIndex:3]; + + [computeEncoder setBytes:&output_size length:sizeof(int64_t) atIndex:4]; + [computeEncoder setBytes:&channels length:sizeof(int64_t) atIndex:5]; + [computeEncoder setBytes:&height length:sizeof(int64_t) atIndex:6]; + [computeEncoder setBytes:&width length:sizeof(int64_t) atIndex:7]; + [computeEncoder setBytes:&pooled_height length:sizeof(int64_t) atIndex:8]; + [computeEncoder setBytes:&pooled_width length:sizeof(int64_t) atIndex:9]; + [computeEncoder setBytes:&channels_out length:sizeof(int64_t) atIndex:10]; + [computeEncoder setBytes:&spatial_scale_f length:sizeof(float) atIndex:11]; + + // A threadGroup is equivalent to a cuda's block. + NSUInteger tgSize = visionPSO.maxTotalThreadsPerThreadgroup; + if (tgSize > threadsPerBlock) { + tgSize = threadsPerBlock; + } + + MTLSize threadGroupSize = MTLSizeMake(tgSize, 1, 1); + [computeEncoder dispatchThreadgroups:threadgroupsPerGrid threadsPerThreadgroup:threadGroupSize]; + + getMPSProfiler().endProfileKernel(visionPSO); + } + }); + return std::make_tuple(output, channel_mapping); +} + +at::Tensor ps_roi_pool_backward_kernel(const at::Tensor& grad, + const at::Tensor& rois, + const at::Tensor& channel_mapping, + double spatial_scale, + int64_t pooled_height, + int64_t pooled_width, + int64_t batch_size, + int64_t channels, + int64_t height, + int64_t width) { + using namespace at::native::mps; + TORCH_CHECK(grad.is_mps(), "grad must be a MPS tensor"); + TORCH_CHECK(rois.is_mps(), "rois must be a MPS tensor"); + TORCH_CHECK(grad.scalar_type() != at::kHalf, "MPS does not support ps_roi_pool backward with float16 inputs."); + TORCH_CHECK(channel_mapping.is_mps(), "channel_mapping must be a MPS tensor"); + + at::TensorArg grad_t{grad, "grad", 1}, rois_t{rois, "rois", 2}, + channel_mapping_t{channel_mapping, "channel_mapping", 3}; + + at::CheckedFrom c = "ps_roi_pool_backward_kernel"; + at::checkAllSameGPU(c, {grad_t, rois_t, channel_mapping_t}); + at::checkAllSameType(c, {grad_t, rois_t}); + + float spatial_scale_f = static_cast(spatial_scale); + + auto num_rois = rois.size(0); + auto grad_input = at::zeros({batch_size, channels, height, width}, grad.options()); + + if (grad.numel() == 0) { + return grad_input; + } + + int64_t channels_out = channels / (pooled_height * pooled_width); + int64_t output_size = grad.numel(); + + at::globalContext().alertNotDeterministic("ps_roi_pool_backward_kernel"); + auto grad_ = grad.contiguous(), rois_ = rois.contiguous(); + + id inputBuffer = getMTLBufferStorage(grad_); + id roisBuffer = getMTLBufferStorage(rois_); + id channelMappingBuffer = getMTLBufferStorage(channel_mapping); + id outputBuffer = getMTLBufferStorage(grad_input); + id device = MPSDevice::getInstance()->device(); + MPSStream* mpsStream = getCurrentMPSStream(); + dispatch_sync(mpsStream->queue(), ^() { + @autoreleasepool { + id computeEncoder = mpsStream->commandEncoder(); + MTLSize threadgroupsPerGrid = MTLSizeMake( + std::min(ceil_div(static_cast(grad.numel()), static_cast(512)), static_cast(4096)), + 1, + 1); + + const std::string kernel = "ps_roi_pool_backward_" + scalarToMetalTypeString(grad.scalar_type()); + id visionPSO = mps::visionPipelineState(device, kernel); + + // this function call is a no-op if MPS Profiler is not enabled 
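For the forward path being wired up here, a user-level smoke test might look like the following (a sketch, assuming a torchvision build with MPS support and an Apple GPU; the TORCH_CHECK above requires channels to be divisible by pooled_height * pooled_width):

    import torch
    from torchvision.ops import ps_roi_pool

    device = "mps" if torch.backends.mps.is_available() else "cpu"
    x = torch.rand(1, 8, 16, 16, device=device)  # 8 channels = 2 * 2 * channels_out
    rois = torch.tensor([[0.0, 0.0, 0.0, 15.0, 15.0]], device=device)  # [batch_idx, x1, y1, x2, y2]
    out = ps_roi_pool(x, rois, output_size=(2, 2), spatial_scale=1.0)
    print(out.shape)  # torch.Size([1, 2, 2, 2]) since channels_out = 8 / (2 * 2)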
+ getMPSProfiler().beginProfileKernel(visionPSO, kernel, {grad_, rois_, channel_mapping}); + + [computeEncoder setComputePipelineState:visionPSO]; + // [N, C, H, W] + [computeEncoder setBuffer:inputBuffer offset:grad_.storage_offset() * grad_.element_size() atIndex:0]; + [computeEncoder setBuffer:roisBuffer offset:rois_.storage_offset() * rois_.element_size() atIndex:1]; + [computeEncoder setBuffer:channelMappingBuffer + offset:channel_mapping.storage_offset() * channel_mapping.element_size() + atIndex:2]; + [computeEncoder setBuffer:outputBuffer offset:grad_input.storage_offset() * grad_input.element_size() atIndex:3]; + + [computeEncoder setBytes:&output_size length:sizeof(int64_t) atIndex:4]; + [computeEncoder setBytes:&channels length:sizeof(int64_t) atIndex:5]; + [computeEncoder setBytes:&height length:sizeof(int64_t) atIndex:6]; + [computeEncoder setBytes:&width length:sizeof(int64_t) atIndex:7]; + [computeEncoder setBytes:&pooled_height length:sizeof(int64_t) atIndex:8]; + [computeEncoder setBytes:&pooled_width length:sizeof(int64_t) atIndex:9]; + [computeEncoder setBytes:&channels_out length:sizeof(int64_t) atIndex:10]; + [computeEncoder setBytes:&spatial_scale_f length:sizeof(float) atIndex:11]; + + // A threadGroup is equivalent to a cuda's block. + NSUInteger tgSize = visionPSO.maxTotalThreadsPerThreadgroup; + if (tgSize > threadsPerBlock) { + tgSize = threadsPerBlock; + } + + MTLSize threadGroupSize = MTLSizeMake(tgSize, 1, 1); + [computeEncoder dispatchThreadgroups:threadgroupsPerGrid threadsPerThreadgroup:threadGroupSize]; + + getMPSProfiler().endProfileKernel(visionPSO); + } + }); + return grad_input; +} + +} // namespace + +TORCH_LIBRARY_IMPL(torchvision, MPS, m) { + m.impl(TORCH_SELECTIVE_NAME("torchvision::ps_roi_pool"), TORCH_FN(ps_roi_pool_forward_kernel)); + m.impl(TORCH_SELECTIVE_NAME("torchvision::_ps_roi_pool_backward"), TORCH_FN(ps_roi_pool_backward_kernel)); +} + +} // namespace ops +} // namespace vision diff --git a/torchvision/csrc/ops/mps/roi_align_kernel.mm b/torchvision/csrc/ops/mps/roi_align_kernel.mm new file mode 100644 index 00000000000..d4ed8b43fd2 --- /dev/null +++ b/torchvision/csrc/ops/mps/roi_align_kernel.mm @@ -0,0 +1,197 @@ +#include +#include +#include "mps_helpers.h" +#include "mps_kernels.h" + +namespace vision { +namespace ops { + +namespace { + +at::Tensor roi_align_forward_kernel(const at::Tensor& input, + const at::Tensor& rois, + double spatial_scale, + int64_t pooled_height, + int64_t pooled_width, + int64_t sampling_ratio, + bool aligned) { + using namespace at::native::mps; + TORCH_CHECK(input.is_mps(), "input must be a MPS tensor"); + TORCH_CHECK(rois.is_mps(), "rois must be a MPS tensor"); + TORCH_CHECK(rois.size(1) == 5, "rois must have shape as Tensor[K, 5]"); + + at::TensorArg input_t{input, "input", 1}, rois_t{rois, "rois", 2}; + + at::CheckedFrom c = "roi_align_forward_kernel"; + at::checkAllSameGPU(c, {input_t, rois_t}); + at::checkAllSameType(c, {input_t, rois_t}); + + int64_t num_rois = rois.size(0); + int64_t channels = input.size(1); + int64_t height = input.size(2); + int64_t width = input.size(3); + float spatial_scale_f = static_cast(spatial_scale); + + at::Tensor output = at::zeros({num_rois, channels, pooled_height, pooled_width}, input.options()); + + int64_t output_size = num_rois * pooled_height * pooled_width * channels; + + if (output.numel() == 0) { + return output; + } + + auto input_ = input.contiguous(); + auto rois_ = rois.contiguous(); + + id inputBuffer = getMTLBufferStorage(input_); + id roisBuffer = 
getMTLBufferStorage(rois_); + id outputBuffer = getMTLBufferStorage(output); + id device = MPSDevice::getInstance()->device(); + MPSStream* mpsStream = getCurrentMPSStream(); + dispatch_sync(mpsStream->queue(), ^() { + @autoreleasepool { + id computeEncoder = mpsStream->commandEncoder(); + MTLSize threadgroupsPerGrid = MTLSizeMake( + std::min(ceil_div(static_cast(output_size), static_cast(512)), static_cast(4096)), + 1, + 1); + + const std::string kernel = "roi_align_" + scalarToMetalTypeString(input.scalar_type()); + id visionPSO = mps::visionPipelineState(device, kernel); + + // this function call is a no-op if MPS Profiler is not enabled + getMPSProfiler().beginProfileKernel(visionPSO, kernel, {input_, rois_}); + + [computeEncoder setComputePipelineState:visionPSO]; + // [N, C, H, W] + [computeEncoder setBuffer:inputBuffer offset:input_.storage_offset() * input_.element_size() atIndex:0]; + [computeEncoder setBuffer:roisBuffer offset:rois_.storage_offset() * rois_.element_size() atIndex:1]; + [computeEncoder setBuffer:outputBuffer offset:output.storage_offset() * output.element_size() atIndex:2]; + + [computeEncoder setBytes:&output_size length:sizeof(int64_t) atIndex:3]; + [computeEncoder setBytes:&channels length:sizeof(int64_t) atIndex:4]; + [computeEncoder setBytes:&height length:sizeof(int64_t) atIndex:5]; + [computeEncoder setBytes:&width length:sizeof(int64_t) atIndex:6]; + [computeEncoder setBytes:&pooled_height length:sizeof(int64_t) atIndex:7]; + [computeEncoder setBytes:&pooled_width length:sizeof(int64_t) atIndex:8]; + [computeEncoder setBytes:&sampling_ratio length:sizeof(int64_t) atIndex:9]; + [computeEncoder setBytes:&aligned length:sizeof(bool) atIndex:10]; + [computeEncoder setBytes:&spatial_scale_f length:sizeof(float) atIndex:11]; + + // A threadGroup is equivalent to a cuda's block. 
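A minimal end-to-end check of this op on MPS (a sketch assuming an MPS device is available; calling backward() exercises the roi_align_backward_kernel defined further down, which is flagged via alertNotDeterministic):

    import torch
    from torchvision.ops import roi_align

    x = torch.rand(1, 3, 32, 32, device="mps", requires_grad=True)
    boxes = torch.tensor([[0.0, 4.0, 4.0, 20.0, 20.0]], device="mps")  # [batch_idx, x1, y1, x2, y2]
    out = roi_align(x, boxes, output_size=(7, 7), spatial_scale=1.0, sampling_ratio=2, aligned=True)
    out.sum().backward()
    print(out.shape, x.grad.shape)  # torch.Size([1, 3, 7, 7]) torch.Size([1, 3, 32, 32])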
+ NSUInteger tgSize = visionPSO.maxTotalThreadsPerThreadgroup; + if (tgSize > threadsPerBlock) { + tgSize = threadsPerBlock; + } + + MTLSize threadGroupSize = MTLSizeMake(tgSize, 1, 1); + [computeEncoder dispatchThreadgroups:threadgroupsPerGrid threadsPerThreadgroup:threadGroupSize]; + + getMPSProfiler().endProfileKernel(visionPSO); + } + }); + return output; +} + +at::Tensor roi_align_backward_kernel(const at::Tensor& grad, + const at::Tensor& rois, + double spatial_scale, + int64_t pooled_height, + int64_t pooled_width, + int64_t batch_size, + int64_t channels, + int64_t height, + int64_t width, + int64_t sampling_ratio, + bool aligned) { + using namespace at::native::mps; + TORCH_CHECK(grad.is_mps(), "grad must be a MPS tensor"); + TORCH_CHECK(rois.is_mps(), "rois must be a MPS tensor"); + TORCH_CHECK(grad.scalar_type() != at::kHalf, "MPS does not support roi_align backward with float16 inputs."); + + at::TensorArg grad_t{grad, "input", 1}, rois_t{rois, "rois", 2}; + + at::CheckedFrom c = "roi_align_backward_kernel"; + at::checkAllSameGPU(c, {grad_t, rois_t}); + at::checkAllSameType(c, {grad_t, rois_t}); + + float spatial_scale_f = static_cast(spatial_scale); + + at::Tensor grad_input = at::zeros({batch_size, channels, height, width}, grad.options()); + + if (grad.numel() == 0) { + return grad_input; + } + + int64_t n_stride = grad.stride(0); + int64_t c_stride = grad.stride(1); + int64_t h_stride = grad.stride(2); + int64_t w_stride = grad.stride(3); + int64_t output_size = grad.numel(); + + at::globalContext().alertNotDeterministic("roi_align_backward_kernel"); + auto rois_ = rois.contiguous(); + + id inputBuffer = getMTLBufferStorage(grad); + id roisBuffer = getMTLBufferStorage(rois_); + id outputBuffer = getMTLBufferStorage(grad_input); + id device = MPSDevice::getInstance()->device(); + MPSStream* mpsStream = getCurrentMPSStream(); + dispatch_sync(mpsStream->queue(), ^() { + @autoreleasepool { + id computeEncoder = mpsStream->commandEncoder(); + MTLSize threadgroupsPerGrid = MTLSizeMake( + std::min(ceil_div(static_cast(grad.numel()), static_cast(512)), static_cast(4096)), + 1, + 1); + + const std::string kernel = "roi_align_backward_" + scalarToMetalTypeString(grad.scalar_type()); + id visionPSO = mps::visionPipelineState(device, kernel); + + // this function call is a no-op if MPS Profiler is not enabled + getMPSProfiler().beginProfileKernel(visionPSO, kernel, {grad, rois_}); + + [computeEncoder setComputePipelineState:visionPSO]; + // [N, C, H, W] + [computeEncoder setBuffer:inputBuffer offset:grad.storage_offset() * grad.element_size() atIndex:0]; + [computeEncoder setBuffer:roisBuffer offset:rois_.storage_offset() * rois_.element_size() atIndex:1]; + [computeEncoder setBuffer:outputBuffer offset:grad_input.storage_offset() * grad_input.element_size() atIndex:2]; + + [computeEncoder setBytes:&output_size length:sizeof(int64_t) atIndex:3]; + [computeEncoder setBytes:&channels length:sizeof(int64_t) atIndex:4]; + [computeEncoder setBytes:&height length:sizeof(int64_t) atIndex:5]; + [computeEncoder setBytes:&width length:sizeof(int64_t) atIndex:6]; + [computeEncoder setBytes:&pooled_height length:sizeof(int64_t) atIndex:7]; + [computeEncoder setBytes:&pooled_width length:sizeof(int64_t) atIndex:8]; + [computeEncoder setBytes:&sampling_ratio length:sizeof(int64_t) atIndex:9]; + [computeEncoder setBytes:&aligned length:sizeof(bool) atIndex:10]; + [computeEncoder setBytes:&spatial_scale_f length:sizeof(float) atIndex:11]; + [computeEncoder setBytes:&n_stride length:sizeof(int64_t) 
atIndex:12]; + [computeEncoder setBytes:&c_stride length:sizeof(int64_t) atIndex:13]; + [computeEncoder setBytes:&h_stride length:sizeof(int64_t) atIndex:14]; + [computeEncoder setBytes:&w_stride length:sizeof(int64_t) atIndex:15]; + + // A threadGroup is equivalent to a cuda's block. + NSUInteger tgSize = visionPSO.maxTotalThreadsPerThreadgroup; + if (tgSize > threadsPerBlock) { + tgSize = threadsPerBlock; + } + + MTLSize threadGroupSize = MTLSizeMake(tgSize, 1, 1); + [computeEncoder dispatchThreadgroups:threadgroupsPerGrid threadsPerThreadgroup:threadGroupSize]; + + getMPSProfiler().endProfileKernel(visionPSO); + } + }); + return grad_input; +} + +} // namespace + +TORCH_LIBRARY_IMPL(torchvision, MPS, m) { + m.impl(TORCH_SELECTIVE_NAME("torchvision::roi_align"), TORCH_FN(roi_align_forward_kernel)); + m.impl(TORCH_SELECTIVE_NAME("torchvision::_roi_align_backward"), TORCH_FN(roi_align_backward_kernel)); +} + +} // namespace ops +} // namespace vision diff --git a/torchvision/csrc/ops/mps/roi_pool_kernel.mm b/torchvision/csrc/ops/mps/roi_pool_kernel.mm new file mode 100644 index 00000000000..816d8d70863 --- /dev/null +++ b/torchvision/csrc/ops/mps/roi_pool_kernel.mm @@ -0,0 +1,196 @@ +#include +#include +#include "mps_helpers.h" +#include "mps_kernels.h" + +namespace vision { +namespace ops { + +namespace { + +std::tuple roi_pool_forward_kernel(const at::Tensor& input, + const at::Tensor& rois, + double spatial_scale, + int64_t pooled_height, + int64_t pooled_width) { + using namespace at::native::mps; + TORCH_CHECK(input.is_mps(), "input must be a MPS tensor"); + TORCH_CHECK(rois.is_mps(), "rois must be a MPS tensor"); + TORCH_CHECK(rois.size(1) == 5, "rois must have shape as Tensor[K, 5]"); + + at::TensorArg input_t{input, "input", 1}, rois_t{rois, "rois", 2}; + + at::CheckedFrom c = "roi_pool_forward_kernel"; + at::checkAllSameGPU(c, {input_t, rois_t}); + at::checkAllSameType(c, {input_t, rois_t}); + + int64_t num_rois = rois.size(0); + int64_t channels = input.size(1); + int64_t height = input.size(2); + int64_t width = input.size(3); + float spatial_scale_f = static_cast(spatial_scale); + + at::Tensor output = at::zeros({num_rois, channels, pooled_height, pooled_width}, input.options()); + at::Tensor argmax = at::zeros({num_rois, channels, pooled_height, pooled_width}, input.options().dtype(at::kLong)); + + int64_t output_size = num_rois * pooled_height * pooled_width * channels; + + if (output.numel() == 0) { + return std::make_tuple(output, argmax); + } + + auto input_ = input.contiguous(); + auto rois_ = rois.contiguous(); + + id inputBuffer = getMTLBufferStorage(input_); + id roisBuffer = getMTLBufferStorage(rois_); + id outputBuffer = getMTLBufferStorage(output); + id argmaxBuffer = getMTLBufferStorage(argmax); + id device = MPSDevice::getInstance()->device(); + MPSStream* mpsStream = getCurrentMPSStream(); + dispatch_sync(mpsStream->queue(), ^() { + @autoreleasepool { + id computeEncoder = mpsStream->commandEncoder(); + MTLSize threadgroupsPerGrid = MTLSizeMake( + std::min(ceil_div(static_cast(output_size), static_cast(512)), static_cast(4096)), + 1, + 1); + + const std::string kernel = "roi_pool_" + scalarToMetalTypeString(input.scalar_type()); + id visionPSO = mps::visionPipelineState(device, kernel); + + // this function call is a no-op if MPS Profiler is not enabled + getMPSProfiler().beginProfileKernel(visionPSO, kernel, {input_, rois_}); + + [computeEncoder setComputePipelineState:visionPSO]; + // [N, C, H, W] + [computeEncoder setBuffer:inputBuffer 
offset:input_.storage_offset() * input_.element_size() atIndex:0]; + [computeEncoder setBuffer:roisBuffer offset:rois_.storage_offset() * rois_.element_size() atIndex:1]; + [computeEncoder setBuffer:outputBuffer offset:output.storage_offset() * output.element_size() atIndex:2]; + [computeEncoder setBuffer:argmaxBuffer offset:argmax.storage_offset() * argmax.element_size() atIndex:3]; + + [computeEncoder setBytes:&output_size length:sizeof(int64_t) atIndex:4]; + [computeEncoder setBytes:&channels length:sizeof(int64_t) atIndex:5]; + [computeEncoder setBytes:&height length:sizeof(int64_t) atIndex:6]; + [computeEncoder setBytes:&width length:sizeof(int64_t) atIndex:7]; + [computeEncoder setBytes:&pooled_height length:sizeof(int64_t) atIndex:8]; + [computeEncoder setBytes:&pooled_width length:sizeof(int64_t) atIndex:9]; + [computeEncoder setBytes:&spatial_scale_f length:sizeof(float) atIndex:10]; + + // A threadGroup is equivalent to a cuda's block. + NSUInteger tgSize = visionPSO.maxTotalThreadsPerThreadgroup; + if (tgSize > threadsPerBlock) { + tgSize = threadsPerBlock; + } + + MTLSize threadGroupSize = MTLSizeMake(tgSize, 1, 1); + [computeEncoder dispatchThreadgroups:threadgroupsPerGrid threadsPerThreadgroup:threadGroupSize]; + + getMPSProfiler().endProfileKernel(visionPSO); + } + }); + return std::make_tuple(output, argmax); +} + +at::Tensor roi_pool_backward_kernel(const at::Tensor& grad, + const at::Tensor& rois, + const at::Tensor& argmax, + double spatial_scale, + int64_t pooled_height, + int64_t pooled_width, + int64_t batch_size, + int64_t channels, + int64_t height, + int64_t width) { + using namespace at::native::mps; + TORCH_CHECK(grad.is_mps(), "grad must be a MPS tensor"); + TORCH_CHECK(rois.is_mps(), "rois must be a MPS tensor"); + TORCH_CHECK(grad.scalar_type() != at::kHalf, "MPS does not support roi_pool backward with float16 inputs."); + TORCH_CHECK(argmax.is_mps(), "argmax must be a MPS tensor"); + + at::TensorArg grad_t{grad, "input", 1}, rois_t{rois, "rois", 2}, argmax_t{argmax, "argmax", 3}; + + at::CheckedFrom c = "roi_pool_backward_kernel"; + at::checkAllSameGPU(c, {grad_t, rois_t, argmax_t}); + at::checkAllSameType(c, {grad_t, rois_t}); + + float spatial_scale_f = static_cast(spatial_scale); + + at::Tensor grad_input = at::zeros({batch_size, channels, height, width}, grad.options()); + + if (grad.numel() == 0) { + return grad_input; + } + + int64_t n_stride = grad.stride(0); + int64_t c_stride = grad.stride(1); + int64_t h_stride = grad.stride(2); + int64_t w_stride = grad.stride(3); + int64_t output_size = grad.numel(); + + at::globalContext().alertNotDeterministic("roi_pool_backward_kernel"); + auto argmax_ = argmax.contiguous(), rois_ = rois.contiguous(); + + id inputBuffer = getMTLBufferStorage(grad); + id roisBuffer = getMTLBufferStorage(rois_); + id argmaxBuffer = getMTLBufferStorage(argmax_); + id outputBuffer = getMTLBufferStorage(grad_input); + id device = MPSDevice::getInstance()->device(); + MPSStream* mpsStream = getCurrentMPSStream(); + dispatch_sync(mpsStream->queue(), ^() { + @autoreleasepool { + id computeEncoder = mpsStream->commandEncoder(); + MTLSize threadgroupsPerGrid = MTLSizeMake( + std::min(ceil_div(static_cast(grad.numel()), static_cast(512)), static_cast(4096)), + 1, + 1); + + const std::string kernel = "roi_pool_backward_" + scalarToMetalTypeString(grad.scalar_type()); + id visionPSO = mps::visionPipelineState(device, kernel); + + // this function call is a no-op if MPS Profiler is not enabled + getMPSProfiler().beginProfileKernel(visionPSO, 
kernel, {grad, rois_, argmax_}); + + [computeEncoder setComputePipelineState:visionPSO]; + // [N, C, H, W] + [computeEncoder setBuffer:inputBuffer offset:grad.storage_offset() * grad.element_size() atIndex:0]; + [computeEncoder setBuffer:roisBuffer offset:rois_.storage_offset() * rois_.element_size() atIndex:1]; + [computeEncoder setBuffer:argmaxBuffer offset:argmax_.storage_offset() * argmax_.element_size() atIndex:2]; + [computeEncoder setBuffer:outputBuffer offset:grad_input.storage_offset() * grad_input.element_size() atIndex:3]; + + [computeEncoder setBytes:&output_size length:sizeof(int64_t) atIndex:4]; + [computeEncoder setBytes:&channels length:sizeof(int64_t) atIndex:5]; + [computeEncoder setBytes:&height length:sizeof(int64_t) atIndex:6]; + [computeEncoder setBytes:&width length:sizeof(int64_t) atIndex:7]; + [computeEncoder setBytes:&pooled_height length:sizeof(int64_t) atIndex:8]; + [computeEncoder setBytes:&pooled_width length:sizeof(int64_t) atIndex:9]; + [computeEncoder setBytes:&spatial_scale_f length:sizeof(float) atIndex:10]; + [computeEncoder setBytes:&n_stride length:sizeof(int64_t) atIndex:11]; + [computeEncoder setBytes:&c_stride length:sizeof(int64_t) atIndex:12]; + [computeEncoder setBytes:&h_stride length:sizeof(int64_t) atIndex:13]; + [computeEncoder setBytes:&w_stride length:sizeof(int64_t) atIndex:14]; + + // A threadGroup is equivalent to a cuda's block. + NSUInteger tgSize = visionPSO.maxTotalThreadsPerThreadgroup; + if (tgSize > threadsPerBlock) { + tgSize = threadsPerBlock; + } + + MTLSize threadGroupSize = MTLSizeMake(tgSize, 1, 1); + [computeEncoder dispatchThreadgroups:threadgroupsPerGrid threadsPerThreadgroup:threadGroupSize]; + + getMPSProfiler().endProfileKernel(visionPSO); + } + }); + return grad_input; +} + +} // namespace + +TORCH_LIBRARY_IMPL(torchvision, MPS, m) { + m.impl(TORCH_SELECTIVE_NAME("torchvision::roi_pool"), TORCH_FN(roi_pool_forward_kernel)); + m.impl(TORCH_SELECTIVE_NAME("torchvision::_roi_pool_backward"), TORCH_FN(roi_pool_backward_kernel)); +} + +} // namespace ops +} // namespace vision diff --git a/torchvision/ops/roi_align.py b/torchvision/ops/roi_align.py index be8ec8aea74..0d505c140ee 100644 --- a/torchvision/ops/roi_align.py +++ b/torchvision/ops/roi_align.py @@ -158,12 +158,12 @@ def from_K(t): y = ( from_K(roi_start_h) + ph[None, :, None] * from_K(bin_size_h) - + (iy[None, None, :] + 0.5) * from_K(bin_size_h / roi_bin_grid_h) + + (iy[None, None, :] + 0.5).to(input.dtype) * from_K(bin_size_h / roi_bin_grid_h) ) # [K, PH, IY] x = ( from_K(roi_start_w) + pw[None, :, None] * from_K(bin_size_w) - + (ix[None, None, :] + 0.5) * from_K(bin_size_w / roi_bin_grid_w) + + (ix[None, None, :] + 0.5).to(input.dtype) * from_K(bin_size_w / roi_bin_grid_w) ) # [K, PW, IX] val = _bilinear_interpolate(input, roi_batch_ind, y, x, ymask, xmask) # [K, C, PH, PW, IY, IX] @@ -232,7 +232,7 @@ def roi_align( if not isinstance(rois, torch.Tensor): rois = convert_boxes_to_roi_format(rois) if not torch.jit.is_scripting(): - if not _has_ops() or (torch.are_deterministic_algorithms_enabled() and input.is_cuda): + if not _has_ops() or (torch.are_deterministic_algorithms_enabled() and (input.is_cuda or input.is_mps)): return _roi_align(input, rois, spatial_scale, output_size[0], output_size[1], sampling_ratio, aligned) _assert_has_ops() return torch.ops.torchvision.roi_align(