Skip to content

Commit

Permalink
[fbsync] Only do meta registrations if we have the ops (#7500)
Browse files Browse the repository at this point in the history
Summary: Signed-off-by: Edward Z. Yang <ezyang@meta.com>

Reviewed By: vmoens

Differential Revision: D45183661

fbshipit-source-id: 808a5d461b4217e7a00295675311fedd99e3d23a
  • Loading branch information
NicolasHug authored and facebook-github-bot committed Apr 24, 2023
1 parent 82ac911 commit 2ad9c15
Showing 1 changed file with 10 additions and 6 deletions.
16 changes: 10 additions & 6 deletions torchvision/_meta_registrations.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import functools

import torch
import torch.library

Expand All @@ -6,20 +8,22 @@

from torch._prims_common import check

_meta_lib = torch.library.Library("torchvision", "IMPL", "Meta")

vision = torch.ops.torchvision
@functools.lru_cache(None)
def get_meta_lib():
return torch.library.Library("torchvision", "IMPL", "Meta")


def register_meta(op):
def register_meta(op_name, overload_name="default"):
def wrapper(fn):
_meta_lib.impl(op, fn)
if torchvision.extension._has_ops():
get_meta_lib().impl(getattr(getattr(torch.ops.torchvision, op_name), overload_name), fn)
return fn

return wrapper


@register_meta(vision.roi_align.default)
@register_meta("roi_align")
def meta_roi_align(input, rois, spatial_scale, pooled_height, pooled_width, sampling_ratio, aligned):
check(rois.size(1) == 5, lambda: "rois must have shape as Tensor[K, 5]")
check(
Expand All @@ -34,7 +38,7 @@ def meta_roi_align(input, rois, spatial_scale, pooled_height, pooled_width, samp
return input.new_empty((num_rois, channels, pooled_height, pooled_width))


@register_meta(vision._roi_align_backward.default)
@register_meta("_roi_align_backward")
def meta_roi_align_backward(
grad, rois, spatial_scale, pooled_height, pooled_width, batch_size, channels, height, width, sampling_ratio, aligned
):
Expand Down

0 comments on commit 2ad9c15

Please sign in to comment.