diff --git a/torchvision/__init__.py b/torchvision/__init__.py index 3ce050e4d85..590b32732ac 100644 --- a/torchvision/__init__.py +++ b/torchvision/__init__.py @@ -3,7 +3,7 @@ from modulefinder import Module import torch -from torchvision import _meta_registrations, datasets, io, models, ops, transforms, utils +from torchvision import datasets, io, models, ops, transforms, utils from .extension import _HAS_OPS diff --git a/torchvision/_meta_registrations.py b/torchvision/_meta_registrations.py deleted file mode 100644 index b33371671c7..00000000000 --- a/torchvision/_meta_registrations.py +++ /dev/null @@ -1,48 +0,0 @@ -import torch -import torch.library - -# Ensure that torch.ops.torchvision is visible -import torchvision.extension # noqa: F401 - -from torch._prims_common import check - -_meta_lib = torch.library.Library("torchvision", "IMPL", "Meta") - -vision = torch.ops.torchvision - - -def register_meta(op): - def wrapper(fn): - _meta_lib.impl(op, fn) - return fn - - return wrapper - - -@register_meta(vision.roi_align.default) -def meta_roi_align(input, rois, spatial_scale, pooled_height, pooled_width, sampling_ratio, aligned): - check(rois.size(1) == 5, lambda: "rois must have shape as Tensor[K, 5]") - check( - input.dtype == rois.dtype, - lambda: ( - "Expected tensor for input to have the same type as tensor for rois; " - f"but type {input.dtype} does not equal {rois.dtype}" - ), - ) - num_rois = rois.size(0) - _, channels, height, width = input.size() - return input.new_empty((num_rois, channels, pooled_height, pooled_width)) - - -@register_meta(vision._roi_align_backward.default) -def meta_roi_align_backward( - grad, rois, spatial_scale, pooled_height, pooled_width, batch_size, channels, height, width, sampling_ratio, aligned -): - check( - grad.dtype == rois.dtype, - lambda: ( - "Expected tensor for grad to have the same type as tensor for rois; " - f"but type {grad.dtype} does not equal {rois.dtype}" - ), - ) - return grad.new_empty((batch_size, channels, height, width)) diff --git a/torchvision/csrc/ops/autograd/roi_align_kernel.cpp b/torchvision/csrc/ops/autograd/roi_align_kernel.cpp index 6d792fe09d9..f26842b6428 100644 --- a/torchvision/csrc/ops/autograd/roi_align_kernel.cpp +++ b/torchvision/csrc/ops/autograd/roi_align_kernel.cpp @@ -15,8 +15,8 @@ class ROIAlignFunction : public torch::autograd::Function { const torch::autograd::Variable& input, const torch::autograd::Variable& rois, double spatial_scale, - c10::SymInt pooled_height, - c10::SymInt pooled_width, + int64_t pooled_height, + int64_t pooled_width, int64_t sampling_ratio, bool aligned) { ctx->saved_data["spatial_scale"] = spatial_scale; @@ -24,10 +24,10 @@ class ROIAlignFunction : public torch::autograd::Function { ctx->saved_data["pooled_width"] = pooled_width; ctx->saved_data["sampling_ratio"] = sampling_ratio; ctx->saved_data["aligned"] = aligned; - ctx->saved_data["input_shape"] = input.sym_sizes(); + ctx->saved_data["input_shape"] = input.sizes(); ctx->save_for_backward({rois}); at::AutoDispatchBelowADInplaceOrView g; - auto result = roi_align_symint( + auto result = roi_align( input, rois, spatial_scale, @@ -44,17 +44,17 @@ class ROIAlignFunction : public torch::autograd::Function { // Use data saved in forward auto saved = ctx->get_saved_variables(); auto rois = saved[0]; - auto input_shape = ctx->saved_data["input_shape"].toList(); - auto grad_in = detail::_roi_align_backward_symint( + auto input_shape = ctx->saved_data["input_shape"].toIntList(); + auto grad_in = detail::_roi_align_backward( grad_output[0], rois, ctx->saved_data["spatial_scale"].toDouble(), - ctx->saved_data["pooled_height"].toSymInt(), - ctx->saved_data["pooled_width"].toSymInt(), - input_shape[0].get().toSymInt(), - input_shape[1].get().toSymInt(), - input_shape[2].get().toSymInt(), - input_shape[3].get().toSymInt(), + ctx->saved_data["pooled_height"].toInt(), + ctx->saved_data["pooled_width"].toInt(), + input_shape[0], + input_shape[1], + input_shape[2], + input_shape[3], ctx->saved_data["sampling_ratio"].toInt(), ctx->saved_data["aligned"].toBool()); return { @@ -77,16 +77,16 @@ class ROIAlignBackwardFunction const torch::autograd::Variable& grad, const torch::autograd::Variable& rois, double spatial_scale, - c10::SymInt pooled_height, - c10::SymInt pooled_width, - c10::SymInt batch_size, - c10::SymInt channels, - c10::SymInt height, - c10::SymInt width, + int64_t pooled_height, + int64_t pooled_width, + int64_t batch_size, + int64_t channels, + int64_t height, + int64_t width, int64_t sampling_ratio, bool aligned) { at::AutoDispatchBelowADInplaceOrView g; - auto result = detail::_roi_align_backward_symint( + auto result = detail::_roi_align_backward( grad, rois, spatial_scale, @@ -112,8 +112,8 @@ at::Tensor roi_align_autograd( const at::Tensor& input, const at::Tensor& rois, double spatial_scale, - c10::SymInt pooled_height, - c10::SymInt pooled_width, + int64_t pooled_height, + int64_t pooled_width, int64_t sampling_ratio, bool aligned) { return ROIAlignFunction::apply( @@ -130,12 +130,12 @@ at::Tensor roi_align_backward_autograd( const at::Tensor& grad, const at::Tensor& rois, double spatial_scale, - c10::SymInt pooled_height, - c10::SymInt pooled_width, - c10::SymInt batch_size, - c10::SymInt channels, - c10::SymInt height, - c10::SymInt width, + int64_t pooled_height, + int64_t pooled_width, + int64_t batch_size, + int64_t channels, + int64_t height, + int64_t width, int64_t sampling_ratio, bool aligned) { return ROIAlignBackwardFunction::apply( diff --git a/torchvision/csrc/ops/roi_align.cpp b/torchvision/csrc/ops/roi_align.cpp index aa6dccb44f2..e2465d6261e 100644 --- a/torchvision/csrc/ops/roi_align.cpp +++ b/torchvision/csrc/ops/roi_align.cpp @@ -32,31 +32,6 @@ at::Tensor roi_align( aligned); } -at::Tensor roi_align_symint( - const at::Tensor& input, // Input feature map. - const at::Tensor& rois, // List of ROIs to pool over. - double spatial_scale, // The scale of the image features. ROIs will be - // scaled to this. - c10::SymInt pooled_height, // The height of the pooled feature map. - c10::SymInt pooled_width, // The width of the pooled feature - int64_t sampling_ratio, // The number of points to sample in each bin - bool aligned) // The flag for pixel shift -// along each axis. -{ - C10_LOG_API_USAGE_ONCE("torchvision.csrc.ops.roi_align.roi_align"); - static auto op = c10::Dispatcher::singleton() - .findSchemaOrThrow("torchvision::roi_align", "") - .typed(); - return op.call( - input, - rois, - spatial_scale, - pooled_height, - pooled_width, - sampling_ratio, - aligned); -} - namespace detail { at::Tensor _roi_align_backward( @@ -89,43 +64,13 @@ at::Tensor _roi_align_backward( aligned); } -at::Tensor _roi_align_backward_symint( - const at::Tensor& grad, - const at::Tensor& rois, - double spatial_scale, - c10::SymInt pooled_height, - c10::SymInt pooled_width, - c10::SymInt batch_size, - c10::SymInt channels, - c10::SymInt height, - c10::SymInt width, - int64_t sampling_ratio, - bool aligned) { - static auto op = - c10::Dispatcher::singleton() - .findSchemaOrThrow("torchvision::_roi_align_backward", "") - .typed(); - return op.call( - grad, - rois, - spatial_scale, - pooled_height, - pooled_width, - batch_size, - channels, - height, - width, - sampling_ratio, - aligned); -} - } // namespace detail TORCH_LIBRARY_FRAGMENT(torchvision, m) { m.def(TORCH_SELECTIVE_SCHEMA( - "torchvision::roi_align(Tensor input, Tensor rois, float spatial_scale, SymInt pooled_height, SymInt pooled_width, int sampling_ratio, bool aligned) -> Tensor")); + "torchvision::roi_align(Tensor input, Tensor rois, float spatial_scale, int pooled_height, int pooled_width, int sampling_ratio, bool aligned) -> Tensor")); m.def(TORCH_SELECTIVE_SCHEMA( - "torchvision::_roi_align_backward(Tensor grad, Tensor rois, float spatial_scale, SymInt pooled_height, SymInt pooled_width, SymInt batch_size, SymInt channels, SymInt height, SymInt width, int sampling_ratio, bool aligned) -> Tensor")); + "torchvision::_roi_align_backward(Tensor grad, Tensor rois, float spatial_scale, int pooled_height, int pooled_width, int batch_size, int channels, int height, int width, int sampling_ratio, bool aligned) -> Tensor")); } } // namespace ops diff --git a/torchvision/csrc/ops/roi_align.h b/torchvision/csrc/ops/roi_align.h index 072d6d4231c..2ddb6ac3945 100644 --- a/torchvision/csrc/ops/roi_align.h +++ b/torchvision/csrc/ops/roi_align.h @@ -15,15 +15,6 @@ VISION_API at::Tensor roi_align( int64_t sampling_ratio, bool aligned); -VISION_API at::Tensor roi_align_symint( - const at::Tensor& input, - const at::Tensor& rois, - double spatial_scale, - c10::SymInt pooled_height, - c10::SymInt pooled_width, - int64_t sampling_ratio, - bool aligned); - namespace detail { at::Tensor _roi_align_backward( @@ -39,19 +30,6 @@ at::Tensor _roi_align_backward( int64_t sampling_ratio, bool aligned); -at::Tensor _roi_align_backward_symint( - const at::Tensor& grad, - const at::Tensor& rois, - double spatial_scale, - c10::SymInt pooled_height, - c10::SymInt pooled_width, - c10::SymInt batch_size, - c10::SymInt channels, - c10::SymInt height, - c10::SymInt width, - int64_t sampling_ratio, - bool aligned); - } // namespace detail } // namespace ops