Support PyTorch grid_sample (apache#10184)
* [relay] Fix stack overflow in device_planner observed on windows due to recursive function calls.

* Revert "[relay] Fix stack overflow in device_planner observed on windows due to recursive function calls."

This reverts commit 7058136.

* [PyTorch] Add grid_sample with zeros and border padding mode for PyTorch.
mei-ye authored and ylc committed Feb 16, 2022
1 parent cfef88b commit af52b86
Showing 10 changed files with 165 additions and 42 deletions.
6 changes: 6 additions & 0 deletions include/tvm/relay/attrs/image.h
@@ -275,6 +275,7 @@ struct AffineGridAttrs : public tvm::AttrsNode<AffineGridAttrs> {
struct GridSampleAttrs : public tvm::AttrsNode<GridSampleAttrs> {
String method;
String layout;
String padding_mode;

TVM_DECLARE_ATTRS(GridSampleAttrs, "relay.attrs.GridSampleAttrs") {
TVM_ATTR_FIELD(method)
@@ -287,6 +288,11 @@ struct GridSampleAttrs : public tvm::AttrsNode<GridSampleAttrs> {
"'N', 'C', 'H', 'W' stands for batch, channel, height, and width"
"dimensions respectively. Resize is applied on the 'H' and"
"'W' dimensions.");
TVM_ATTR_FIELD(padding_mode)
.set_default("zeros")
.describe(
"Specify the padding mode to use."
"zeros, border etc.");
}
};
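
A minimal sketch of how the new attribute surfaces from Python once the Relay API change later in this diff is in place (illustrative only; "border" is just an example value):

from tvm import relay

data = relay.var("data", shape=(1, 3, 16, 16), dtype="float32")
grid = relay.var("grid", shape=(1, 2, 8, 8), dtype="float32")
call = relay.image.grid_sample(
    data, grid, method="bilinear", layout="NCHW", padding_mode="border"
)
print(call.attrs.padding_mode)  # prints "border"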

20 changes: 20 additions & 0 deletions python/tvm/relay/frontend/pytorch.py
@@ -2896,6 +2896,25 @@ def mv(self, inputs, _):
# Chop off the extra result dimension
return _op.transform.squeeze(dense_result)

def grid_sampler(self, inputs, input_types):
if inputs[2] == 0:
mode = "bilinear"
else:
msg = "Only bilinear mode is supported in grid_sampler"
raise NotImplementedError(msg)

if inputs[3] == 0:
padding_mode = "zeros"
elif inputs[3] == 1:
padding_mode = "border"
else:
msg = "Only zeros and border padding mode are supported in grid_sampler"
raise NotImplementedError(msg)

axes = [0, 3, 1, 2]
grid = _op.transform.transpose(inputs[1], axes)
return _op.image.grid_sample(inputs[0], grid, mode, "NCHW", padding_mode)
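
# For reference, the aten::grid_sampler argument encoding assumed by the integer checks
# above (taken from PyTorch's enum convention; treat the exact values as an assumption):
#   inputs[2] (interpolation_mode): 0 = "bilinear", 1 = "nearest"
#   inputs[3] (padding_mode):       0 = "zeros", 1 = "border", 2 = "reflection"
#   inputs[4] (align_corners):      boolean flag, not read by the converter above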

# Operator mappings
def create_convert_map(self):
self.convert_map = {
@@ -3124,6 +3143,7 @@ def create_convert_map(self):
"aten::einsum": self.einsum,
"aten::dot": self.dot,
"aten::mv": self.mv,
"aten::grid_sampler": self.grid_sampler,
}

def update_convert_map(self, custom_map):
3 changes: 2 additions & 1 deletion python/tvm/relay/op/image/_image.py
@@ -365,7 +365,8 @@ def affine_grid_func(attrs, inputs, _):
def compute_grid_sample(attrs, inputs, out_dtype):
method = attrs.method
layout = attrs.layout
return [topi.image.grid_sample(inputs[0], inputs[1], method, layout)]
padding_mode = attrs.padding_mode
return [topi.image.grid_sample(inputs[0], inputs[1], method, layout, padding_mode)]


reg.register_injective_schedule("image.grid_sample")
10 changes: 7 additions & 3 deletions python/tvm/relay/op/image/image.py
@@ -455,7 +455,7 @@ def affine_grid(data, target_shape=None):
return _make.affine_grid(data, target_shape)


def grid_sample(data, grid, method="bilinear", layout="NCHW"):
def grid_sample(data, grid, method="bilinear", layout="NCHW", padding_mode="zeros"):
"""Applies bilinear sampling to input feature map.
Given :math:`data` and :math:`grid`, then the output is computed by
@@ -468,7 +468,8 @@ def grid_sample(data, grid, method="bilinear", layout="NCHW"):
:math:`x_{dst}`, :math:`y_{dst}` enumerate all spatial locations in :math:`output`, and
:math:`G()` denotes the interpolation function.
The out-boundary points will be padded with zeros. The shape of the output will be
The out-boundary points will be padded with zeros if padding_mode is "zeros".
The shape of the output will be
(data.shape[0], data.shape[1], grid.shape[2], grid.shape[3]).
The operator assumes that :math:`grid` has been normalized to [-1, 1].
@@ -489,9 +490,12 @@ def grid_sample(data, grid, method="bilinear", layout="NCHW"):
layout : str
The layout of input data and the output.
padding_mode : str
The padding mode for outside grid values.
Returns
-------
Output : tvm.Tensor
4-D with shape [batch, 2, out_height, out_width]
"""
return _make.grid_sample(data, grid, method, layout)
return _make.grid_sample(data, grid, method, layout, padding_mode)
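
A short usage sketch of the extended signature, evaluated through the Relay executor (a sketch; the grid deliberately reaches outside [-1, 1] so the padding mode matters):

import numpy as np
import tvm
from tvm import relay

data = relay.var("data", shape=(1, 1, 4, 4), dtype="float32")
grid = relay.var("grid", shape=(1, 2, 3, 3), dtype="float32")
out = relay.image.grid_sample(
    data, grid, method="bilinear", layout="NCHW", padding_mode="border"
)
func = relay.Function([data, grid], out)

data_np = np.random.uniform(size=(1, 1, 4, 4)).astype("float32")
# sample points outside [-1, 1] fall beyond the input border
grid_np = np.random.uniform(low=-1.5, high=1.5, size=(1, 2, 3, 3)).astype("float32")

executor = relay.create_executor("graph", device=tvm.cpu(), target="llvm")
result = executor.evaluate(func)(data_np, grid_np)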
24 changes: 16 additions & 8 deletions python/tvm/topi/image/grid_sample.py
@@ -59,7 +59,7 @@ def _compute(n, dim, i, j):
return te.compute(oshape, _compute, tag="affine_grid")


def grid_sample(data, grid, method="bilinear", layout="NCHW"):
def grid_sample(data, grid, method="bilinear", layout="NCHW", padding_mode="zeros"):
"""Applies bilinear sampling to input feature map.
Given :math:`data` and :math:`grid`, assuming NCHW layout, then the output is computed by
@@ -72,7 +72,8 @@ def grid_sample(data, grid, method="bilinear", layout="NCHW"):
:math:`x_{dst}`, :math:`y_{dst}` enumerate all spatial locations in :math:`output`, and
:math:`G()` denotes the interpolation method.
The out-boundary points will be padded with zeros. The shape of the output will be
The out-boundary points will be padded with zeros if the padding_mode is "zeros".
The shape of the output will be
(data.shape[0], data.shape[1], grid.shape[2], grid.shape[3]).
The operator assumes that :math:`grid` has been normalized to [-1, 1].
@@ -96,19 +97,26 @@ def grid_sample(data, grid, method="bilinear", layout="NCHW"):
Returns
-------
Output : tvm.Tensor
4-D with shape [batch, 2, out_height, out_width]
4-D with shape [batch, in_channel, out_height, out_width]
"""
batch, in_channel, in_height, in_width = data.shape
out_height, out_width = grid.shape[2:]
assert method == "bilinear", "Only bilinear is supported"
assert layout == "NCHW", "Only NCHW is supported"

def _get_pixel_value(n, c, h, w):
return te.if_then_else(
te.all(h >= 0, w >= 0, h < in_height, w < in_width),
data[n, c, h, w],
tir.const(0.0, dtype=data.dtype),
)
if padding_mode == "zeros":
return te.if_then_else(
te.all(h >= 0, w >= 0, h < in_height, w < in_width),
data[n, c, h, w],
tir.const(0.0, dtype=data.dtype),
)
if padding_mode == "border":
h_b = te.max(te.min(h, in_height - 1), 0)
w_b = te.max(te.min(w, in_width - 1), 0)
return data[n, c, h_b, w_b]

raise AssertionError("unsupported padding_mode")
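
# For intuition: the border clamp above maps any out-of-range index to the nearest valid
# one, e.g. with in_height = 4 an index of -2 clamps to max(min(-2, 3), 0) = 0 and an
# index of 7 clamps to max(min(7, 3), 0) = 3.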

def _bilinear_sample(n, c, h, w):
x = grid[n, 0, h, w]
73 changes: 54 additions & 19 deletions python/tvm/topi/testing/grid_sample_python.py
@@ -29,36 +29,71 @@ def affine_grid_python(data, target_shape):
return data.reshape(-1, 3).dot(grid).reshape(data.shape[0], 2, *target_shape)


def _bilinear_sample_nchw_python(data, grid):
def _bilinear_sample_nchw_python(data, grid, padding_mode):
batch, in_channel, in_height, in_width = data.shape
_, _, out_height, out_width = grid.shape
out = np.zeros((batch, in_channel, out_height, out_width), dtype=data.dtype)

def _within_bound(y, x):
return 0 <= y < in_height and 0 <= x < in_width

for n in range(0, batch):
for h in range(0, out_height):
for w in range(0, out_width):
x, y = grid[n, :, h, w]
y = (y + 1) * (in_height - 1) / 2
x = (x + 1) * (in_width - 1) / 2
y0 = int(math.floor(y))
x0 = int(math.floor(x))
y1 = y0 + 1
x1 = x0 + 1
if _within_bound(y0, x0):
out[n, :, h, w] += data[n, :, y0, x0] * (1.0 - (y - y0)) * (1.0 - (x - x0))
if _within_bound(y0, x1):
def compute_padding_mode_zeros():
for n in range(0, batch):
for h in range(0, out_height):
for w in range(0, out_width):
x, y = grid[n, :, h, w]
y = (y + 1) * (in_height - 1) / 2
x = (x + 1) * (in_width - 1) / 2
y0 = int(math.floor(y))
x0 = int(math.floor(x))
y1 = y0 + 1
x1 = x0 + 1
if _within_bound(y0, x0):
out[n, :, h, w] += data[n, :, y0, x0] * (1.0 - (y - y0)) * (1.0 - (x - x0))
if _within_bound(y0, x1):
out[n, :, h, w] += data[n, :, y0, x1] * (1.0 - (y - y0)) * (x - x0)
if _within_bound(y1, x0):
out[n, :, h, w] += data[n, :, y1, x0] * (y - y0) * (1.0 - (x - x0))
if _within_bound(y1, x1):
out[n, :, h, w] += data[n, :, y1, x1] * (y - y0) * (x - x0)

return out

def get_pixel_value(x, x_max):
return max(min(x, x_max - 1), 0)

def compute_padding_mode_border():
for n in range(0, batch):
for h in range(0, out_height):
for w in range(0, out_width):
x, y = grid[n, :, h, w]
y = (y + 1) * (in_height - 1) / 2
x = (x + 1) * (in_width - 1) / 2
y0 = int(math.floor(y))
x0 = int(math.floor(x))
y1 = y0 + 1
x1 = x0 + 1
y0 = get_pixel_value(y0, in_height)
y1 = get_pixel_value(y1, in_height)
x0 = get_pixel_value(x0, in_width)
x1 = get_pixel_value(x1, in_width)
out[n, :, h, w] = data[n, :, y0, x0] * (1.0 - (y - y0)) * (1.0 - (x - x0))
out[n, :, h, w] += data[n, :, y0, x1] * (1.0 - (y - y0)) * (x - x0)
if _within_bound(y1, x0):
out[n, :, h, w] += data[n, :, y1, x0] * (y - y0) * (1.0 - (x - x0))
if _within_bound(y1, x1):
out[n, :, h, w] += data[n, :, y1, x1] * (y - y0) * (x - x0)
return out

return out

if padding_mode == "zeros":
return compute_padding_mode_zeros()
if padding_mode == "border":
return compute_padding_mode_border()

def grid_sample_nchw_python(data, grid, method="bilinear"):
raise ValueError("invalid padding_mode")


def grid_sample_nchw_python(data, grid, method="bilinear", padding_mode="zeros"):
if method == "bilinear":
return _bilinear_sample_nchw_python(data, grid)
return _bilinear_sample_nchw_python(data, grid, padding_mode)

raise ValueError("invalid method")
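
A quick way to see the two padding modes diverge using the reference implementation above (a sketch; the grid deliberately reaches outside [-1, 1]):

import numpy as np
import tvm.topi.testing

data = np.random.uniform(size=(1, 1, 4, 4)).astype("float32")
grid = np.random.uniform(low=-1.5, high=1.5, size=(1, 2, 3, 3)).astype("float32")

out_zeros = tvm.topi.testing.grid_sample_nchw_python(data, grid, "bilinear", "zeros")
out_border = tvm.topi.testing.grid_sample_nchw_python(data, grid, "bilinear", "border")
# contributions sampled outside the input count as 0 in out_zeros,
# while out_border reuses the nearest edge pixel instead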
3 changes: 2 additions & 1 deletion src/relay/op/image/grid_sample.cc
@@ -116,10 +116,11 @@ bool GridSampleRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,

// Positional relay function to create affine_grid operator
// used by frontend FFI.
Expr MakeGridSample(Expr data, Expr grid, String method, String layout) {
Expr MakeGridSample(Expr data, Expr grid, String method, String layout, String padding_mode) {
auto attrs = make_object<GridSampleAttrs>();
attrs->method = std::move(method);
attrs->layout = std::move(layout);
attrs->padding_mode = std::move(padding_mode);
static const Op& op = Op::Get("image.grid_sample");
return Call(op, {data, grid}, Attrs(attrs), {});
}
45 changes: 42 additions & 3 deletions tests/python/frontend/pytorch/test_forward.py
@@ -216,9 +216,8 @@ def verify_model(
if not tvm.runtime.enabled(target):
continue
dev = tvm.device(target, 0)
relay_graph, relay_lib, relay_params = relay.build(mod, target=target, params=params)
relay_model = graph_executor.create(relay_graph, relay_lib, dev)
relay_model.set_input(**relay_params)
lib = relay.build(mod, target=target, params=params)
relay_model = graph_executor.GraphModule(lib["default"](dev))
for name, inp in compiled_input.items():
relay_model.set_input(name, inp)
relay_model.run()
@@ -4079,5 +4078,45 @@ def test_fn(m, v):
verify_model(test_fn, [torch.randn(3, 8), torch.randn(8)])


def test_grid_sample():
class Grid_sample_zeros(Module):
def forward(self, x, y):
return torch.nn.functional.grid_sample(
input=x, grid=y, mode="bilinear", padding_mode="zeros", align_corners=True
)

class Grid_sample_border(Module):
def forward(self, x, y):
return torch.nn.functional.grid_sample(
input=x, grid=y, mode="bilinear", padding_mode="border", align_corners=True
)

data = torch.rand([4, 4, 16, 32]).float()
grid = torch.rand([4, 8, 8, 2]).float()
verify_model(Grid_sample_zeros(), input_data=[data, grid])
verify_model(Grid_sample_border(), input_data=[data, grid])


def test_list_tuple():
"""test compilation error for a Python list followed by a prim::TupleConstruct."""

class List_tuple(Module):
def forward(self, x):
merged = []
mask_list = []
for i in range(3):
w0 = torch.sigmoid(x)
merged.append((w0, w0))
mask_list.append(x)

for i in range(3):
merged[i] = merged[i][0] + merged[i][1]
return mask_list[2], merged

x = torch.rand([4, 4, 16, 32]).float()
script_module = torch.jit.trace(List_tuple(), x, strict=False).eval()
mod, params = relay.frontend.from_pytorch(script_module, [("x", x.shape)])


if __name__ == "__main__":
pytest.main([__file__])
12 changes: 9 additions & 3 deletions tests/python/relay/test_op_level5.py
@@ -1392,20 +1392,24 @@ def verify_affine_grid(num_batch, target_shape):

@tvm.testing.uses_gpu
def test_grid_sample():
def verify_grid_sample(data_shape, grid_shape):
def verify_grid_sample(data_shape, grid_shape, padding_mode="zeros"):
dtype = "float32"
batch, channel, _, _ = data_shape
_, _, out_height, out_width = grid_shape
data = relay.var("data", relay.ty.TensorType(data_shape, dtype))
grid = relay.var("grid", relay.ty.TensorType(grid_shape, dtype))
y = relay.image.grid_sample(data, grid, method="bilinear", layout="NCHW")
y = relay.image.grid_sample(
data, grid, method="bilinear", layout="NCHW", padding_mode=padding_mode
)
yy = run_infer_type(y)
assert yy.checked_type == relay.TensorType((batch, channel, out_height, out_width), dtype)
func = relay.Function([data, grid], y)

data_np = np.random.uniform(size=data_shape).astype(dtype)
grid_np = np.random.uniform(size=grid_shape, low=-1.5, high=1.5).astype(dtype)
ref_res = tvm.topi.testing.grid_sample_nchw_python(data_np, grid_np, method="bilinear")
ref_res = tvm.topi.testing.grid_sample_nchw_python(
data_np, grid_np, method="bilinear", padding_mode=padding_mode
)

for target, dev in tvm.testing.enabled_targets():
for kind in ["graph", "debug"]:
@@ -1416,6 +1420,8 @@ def verify_grid_sample(data_shape, grid_shape):

verify_grid_sample((4, 4, 16, 32), (4, 2, 8, 8))
verify_grid_sample((4, 4, 16, 32), (4, 2, 32, 32))
verify_grid_sample((4, 4, 16, 32), (4, 2, 8, 8), "border")
verify_grid_sample((4, 4, 16, 32), (4, 2, 32, 32), "border")


@tvm.testing.uses_gpu
11 changes: 7 additions & 4 deletions tests/python/topi/python/test_topi_image.py
@@ -274,18 +274,20 @@ def check_target(target, dev):

@tvm.testing.uses_gpu
def test_grid_sample():
def verify_grid_sample(data_shape, grid_shape):
def verify_grid_sample(data_shape, grid_shape, padding_mode="zeros"):
dtype = "float32"
data = te.placeholder(data_shape, dtype=dtype)
grid = te.placeholder(grid_shape, dtype=dtype)
out = topi.image.grid_sample(data, grid, "bilinear")
out = topi.image.grid_sample(data, grid, "bilinear", padding_mode=padding_mode)

@memoize("topi.tests.test_grid_sample.verify_grid_sample")
def get_ref_data():
data_np = np.random.uniform(size=data_shape).astype(dtype)
# allow grid values to be out-of-bound
grid_np = np.random.uniform(size=grid_shape, low=-1.5, high=1.5).astype(dtype)
out_np = tvm.topi.testing.grid_sample_nchw_python(data_np, grid_np, "bilinear")
out_np = tvm.topi.testing.grid_sample_nchw_python(
data_np, grid_np, "bilinear", padding_mode
)
return data_np, grid_np, out_np

data_np, grid_np, out_np = get_ref_data()
@@ -306,7 +308,8 @@ def check_target(target, dev):
check_target(target, dev)

verify_grid_sample((4, 4, 16, 32), (4, 2, 8, 8))
verify_grid_sample((4, 4, 16, 32), (4, 2, 32, 32))
verify_grid_sample((4, 4, 16, 32), (4, 2, 32, 32), "border")
verify_grid_sample((4, 4, 16, 32), (4, 2, 8, 8), "border")


if __name__ == "__main__":
