From 7f22f15dc0fe70bbe83c0dde64d4d91e8486aae1 Mon Sep 17 00:00:00 2001 From: Kurt Mohler Date: Wed, 24 May 2023 19:16:17 -0500 Subject: [PATCH] Avoid `_prims_common.check` in favor of `torch._check` --- torchvision/_meta_registrations.py | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/torchvision/_meta_registrations.py b/torchvision/_meta_registrations.py index 7285e15ced8..9831cfdcb45 100644 --- a/torchvision/_meta_registrations.py +++ b/torchvision/_meta_registrations.py @@ -6,8 +6,6 @@ # Ensure that torch.ops.torchvision is visible import torchvision.extension # noqa: F401 -from torch._prims_common import check - @functools.lru_cache(None) def get_meta_lib(): @@ -25,8 +23,8 @@ def wrapper(fn): @register_meta("roi_align") def meta_roi_align(input, rois, spatial_scale, pooled_height, pooled_width, sampling_ratio, aligned): - check(rois.size(1) == 5, lambda: "rois must have shape as Tensor[K, 5]") - check( + torch._check(rois.size(1) == 5, lambda: "rois must have shape as Tensor[K, 5]") + torch._check( input.dtype == rois.dtype, lambda: ( "Expected tensor for input to have the same type as tensor for rois; " @@ -42,7 +40,7 @@ def meta_roi_align(input, rois, spatial_scale, pooled_height, pooled_width, samp def meta_roi_align_backward( grad, rois, spatial_scale, pooled_height, pooled_width, batch_size, channels, height, width, sampling_ratio, aligned ): - check( + torch._check( grad.dtype == rois.dtype, lambda: ( "Expected tensor for grad to have the same type as tensor for rois; "