From 66973601d479340b9da33f1a3ce2f89f0ffea713 Mon Sep 17 00:00:00 2001 From: "Edward Z. Yang" Date: Fri, 12 May 2023 07:31:28 -0700 Subject: [PATCH] Alert non-deterministic on kernels that use gpuAtomicAdd Signed-off-by: Edward Z. Yang --- torchvision/csrc/ops/cuda/deform_conv2d_kernel.cu | 2 ++ torchvision/csrc/ops/cuda/ps_roi_align_kernel.cu | 2 ++ torchvision/csrc/ops/cuda/ps_roi_pool_kernel.cu | 2 ++ torchvision/csrc/ops/cuda/roi_align_kernel.cu | 2 ++ torchvision/csrc/ops/cuda/roi_pool_kernel.cu | 2 ++ 5 files changed, 10 insertions(+) diff --git a/torchvision/csrc/ops/cuda/deform_conv2d_kernel.cu b/torchvision/csrc/ops/cuda/deform_conv2d_kernel.cu index 5fd039a3103..b664bf11b55 100644 --- a/torchvision/csrc/ops/cuda/deform_conv2d_kernel.cu +++ b/torchvision/csrc/ops/cuda/deform_conv2d_kernel.cu @@ -426,6 +426,8 @@ void compute_grad_input( // Checks if num_kernels or columns numel larger than 2 ** 31 use_64bits_indexing |= num_kernels > (1 << 31); + at::globalContext().alertNotDeterministic("compute_grad_input"); + if (use_64bits_indexing) { AT_DISPATCH_FLOATING_TYPES_AND_HALF( columns.scalar_type(), "compute_grad_input", ([&] { diff --git a/torchvision/csrc/ops/cuda/ps_roi_align_kernel.cu b/torchvision/csrc/ops/cuda/ps_roi_align_kernel.cu index b9c624b09c8..17cc188cd68 100644 --- a/torchvision/csrc/ops/cuda/ps_roi_align_kernel.cu +++ b/torchvision/csrc/ops/cuda/ps_roi_align_kernel.cu @@ -412,6 +412,8 @@ at::Tensor ps_roi_align_backward_kernel( int channels_out = channels / (pooled_height * pooled_width); + at::globalContext().alertNotDeterministic("ps_roi_align_backward_kernel"); + auto grad_ = grad.contiguous(), rois_ = rois.contiguous(); AT_DISPATCH_FLOATING_TYPES_AND_HALF( grad.scalar_type(), "ps_roi_align_backward_kernel", [&] { diff --git a/torchvision/csrc/ops/cuda/ps_roi_pool_kernel.cu b/torchvision/csrc/ops/cuda/ps_roi_pool_kernel.cu index 917fff03e8d..3789a2b7dfa 100644 --- a/torchvision/csrc/ops/cuda/ps_roi_pool_kernel.cu +++ b/torchvision/csrc/ops/cuda/ps_roi_pool_kernel.cu @@ -251,6 +251,8 @@ at::Tensor ps_roi_pool_backward_kernel( int channels_out = channels / (pooled_height * pooled_width); + at::globalContext().alertNotDeterministic("ps_roi_pool_backward_kernel"); + auto grad_ = grad.contiguous(), rois_ = rois.contiguous(); AT_DISPATCH_FLOATING_TYPES_AND_HALF( grad.scalar_type(), "ps_roi_pool_backward_kernel", [&] { diff --git a/torchvision/csrc/ops/cuda/roi_align_kernel.cu b/torchvision/csrc/ops/cuda/roi_align_kernel.cu index f1f886c4738..2622edec1fc 100644 --- a/torchvision/csrc/ops/cuda/roi_align_kernel.cu +++ b/torchvision/csrc/ops/cuda/roi_align_kernel.cu @@ -421,6 +421,8 @@ at::Tensor roi_align_backward_kernel( int h_stride = grad.stride(2); int w_stride = grad.stride(3); + at::globalContext().alertNotDeterministic("roi_align_backward_kernel"); + auto rois_ = rois.contiguous(); AT_DISPATCH_FLOATING_TYPES_AND_HALF( grad.scalar_type(), "roi_align_backward_kernel", [&] { diff --git a/torchvision/csrc/ops/cuda/roi_pool_kernel.cu b/torchvision/csrc/ops/cuda/roi_pool_kernel.cu index e29c4438ed4..74952bba047 100644 --- a/torchvision/csrc/ops/cuda/roi_pool_kernel.cu +++ b/torchvision/csrc/ops/cuda/roi_pool_kernel.cu @@ -232,6 +232,8 @@ at::Tensor roi_pool_backward_kernel( int h_stride = grad.stride(2); int w_stride = grad.stride(3); + at::globalContext().alertNotDeterministic("roi_pool_backward_kernel"); + auto argmax_ = argmax.contiguous(), rois_ = rois.contiguous(); AT_DISPATCH_FLOATING_TYPES_AND_HALF( grad.scalar_type(), "roi_pool_backward_kernel", [&] {