Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support PyTorch grid_sample #10184

Merged
merged 4 commits into from
Feb 8, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions include/tvm/relay/attrs/image.h
Original file line number Diff line number Diff line change
Expand Up @@ -275,6 +275,7 @@ struct AffineGridAttrs : public tvm::AttrsNode<AffineGridAttrs> {
struct GridSampleAttrs : public tvm::AttrsNode<GridSampleAttrs> {
String method;
String layout;
String padding_mode;

TVM_DECLARE_ATTRS(GridSampleAttrs, "relay.attrs.GridSampleAttrs") {
TVM_ATTR_FIELD(method)
Expand All @@ -287,6 +288,11 @@ struct GridSampleAttrs : public tvm::AttrsNode<GridSampleAttrs> {
"'N', 'C', 'H', 'W' stands for batch, channel, height, and width"
"dimensions respectively. Resize is applied on the 'H' and"
"'W' dimensions.");
TVM_ATTR_FIELD(padding_mode)
.set_default("zeros")
.describe(
"Specify the padding mode to use."
"zeros, border etc.");
}
};

Expand Down
20 changes: 20 additions & 0 deletions python/tvm/relay/frontend/pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -2896,6 +2896,25 @@ def mv(self, inputs, _):
# Chop off the extra result dimension
return _op.transform.squeeze(dense_result)

def grid_sampler(self, inputs, input_types):
if inputs[2] == 0:
mode = "bilinear"
else:
msg = "Only bilinear mode is supported in grid_sampler"
raise NotImplementedError(msg)

if inputs[3] == 0:
padding_mode = "zeros"
elif inputs[3] == 1:
padding_mode = "border"
else:
msg = "Only zeros and border padding mode are supported in grid_sampler"
raise NotImplementedError(msg)

axes = [0, 3, 1, 2]
grid = _op.transform.transpose(inputs[1], axes)
return _op.image.grid_sample(inputs[0], grid, mode, "NCHW", padding_mode)

# Operator mappings
def create_convert_map(self):
self.convert_map = {
Expand Down Expand Up @@ -3124,6 +3143,7 @@ def create_convert_map(self):
"aten::einsum": self.einsum,
"aten::dot": self.dot,
"aten::mv": self.mv,
"aten::grid_sampler": self.grid_sampler,
}

def update_convert_map(self, custom_map):
Expand Down
3 changes: 2 additions & 1 deletion python/tvm/relay/op/image/_image.py
Original file line number Diff line number Diff line change
Expand Up @@ -365,7 +365,8 @@ def affine_grid_func(attrs, inputs, _):
def compute_grid_sample(attrs, inputs, out_dtype):
method = attrs.method
layout = attrs.layout
return [topi.image.grid_sample(inputs[0], inputs[1], method, layout)]
padding_mode = attrs.padding_mode
return [topi.image.grid_sample(inputs[0], inputs[1], method, layout, padding_mode)]


reg.register_injective_schedule("image.grid_sample")
Expand Down
10 changes: 7 additions & 3 deletions python/tvm/relay/op/image/image.py
Original file line number Diff line number Diff line change
Expand Up @@ -455,7 +455,7 @@ def affine_grid(data, target_shape=None):
return _make.affine_grid(data, target_shape)


def grid_sample(data, grid, method="bilinear", layout="NCHW"):
def grid_sample(data, grid, method="bilinear", layout="NCHW", padding_mode="zeros"):
"""Applies bilinear sampling to input feature map.

Given :math:`data` and :math:`grid`, then the output is computed by
Expand All @@ -468,7 +468,8 @@ def grid_sample(data, grid, method="bilinear", layout="NCHW"):

:math:`x_{dst}`, :math:`y_{dst}` enumerate all spatial locations in :math:`output`, and
:math:`G()` denotes the interpolation function.
The out-boundary points will be padded with zeros. The shape of the output will be
The out-boundary points will be padded with zeros if padding_mode is "zeros".
The shape of the output will be
(data.shape[0], data.shape[1], grid.shape[2], grid.shape[3]).

The operator assumes that :math:`grid` has been normalized to [-1, 1].
Expand All @@ -489,9 +490,12 @@ def grid_sample(data, grid, method="bilinear", layout="NCHW"):
layout : str
The layout of input data and the output.

padding_mode : str
The padding mode for outside grid values.

Returns
-------
Output : tvm.Tensor
4-D with shape [batch, 2, out_height, out_width]
"""
return _make.grid_sample(data, grid, method, layout)
return _make.grid_sample(data, grid, method, layout, padding_mode)
24 changes: 16 additions & 8 deletions python/tvm/topi/image/grid_sample.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ def _compute(n, dim, i, j):
return te.compute(oshape, _compute, tag="affine_grid")


def grid_sample(data, grid, method="bilinear", layout="NCHW"):
def grid_sample(data, grid, method="bilinear", layout="NCHW", padding_mode="zeros"):
"""Applies bilinear sampling to input feature map.

Given :math:`data` and :math:`grid`, assuming NCHW layout, then the output is computed by
Expand All @@ -72,7 +72,8 @@ def grid_sample(data, grid, method="bilinear", layout="NCHW"):

:math:`x_{dst}`, :math:`y_{dst}` enumerate all spatial locations in :math:`output`, and
:math:`G()` denotes the interpolation method.
The out-boundary points will be padded with zeros. The shape of the output will be
The out-boundary points will be padded with zeros if the padding_mode is "zeros".
The shape of the output will be
(data.shape[0], data.shape[1], grid.shape[2], grid.shape[3]).

The operator assumes that :math:`grid` has been normalized to [-1, 1].
Expand All @@ -96,19 +97,26 @@ def grid_sample(data, grid, method="bilinear", layout="NCHW"):
Returns
-------
Output : tvm.Tensor
4-D with shape [batch, 2, out_height, out_width]
4-D with shape [batch, in_channel, out_height, out_width]
"""
batch, in_channel, in_height, in_width = data.shape
out_height, out_width = grid.shape[2:]
assert method == "bilinear", "Only bilinear is supported"
assert layout == "NCHW", "Only NCHW is supported"

def _get_pixel_value(n, c, h, w):
return te.if_then_else(
te.all(h >= 0, w >= 0, h < in_height, w < in_width),
data[n, c, h, w],
tir.const(0.0, dtype=data.dtype),
)
if padding_mode == "zeros":
return te.if_then_else(
te.all(h >= 0, w >= 0, h < in_height, w < in_width),
data[n, c, h, w],
tir.const(0.0, dtype=data.dtype),
)
if padding_mode == "border":
h_b = te.max(te.min(h, in_height - 1), 0)
w_b = te.max(te.min(w, in_width - 1), 0)
return data[n, c, h_b, w_b]

raise AssertionError("unsupported padding_mode")

def _bilinear_sample(n, c, h, w):
x = grid[n, 0, h, w]
Expand Down
73 changes: 54 additions & 19 deletions python/tvm/topi/testing/grid_sample_python.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,36 +29,71 @@ def affine_grid_python(data, target_shape):
return data.reshape(-1, 3).dot(grid).reshape(data.shape[0], 2, *target_shape)


def _bilinear_sample_nchw_python(data, grid):
def _bilinear_sample_nchw_python(data, grid, padding_mode):
batch, in_channel, in_height, in_width = data.shape
_, _, out_height, out_width = grid.shape
out = np.zeros((batch, in_channel, out_height, out_width), dtype=data.dtype)

def _within_bound(y, x):
return 0 <= y < in_height and 0 <= x < in_width

for n in range(0, batch):
for h in range(0, out_height):
for w in range(0, out_width):
x, y = grid[n, :, h, w]
y = (y + 1) * (in_height - 1) / 2
x = (x + 1) * (in_width - 1) / 2
y0 = int(math.floor(y))
x0 = int(math.floor(x))
y1 = y0 + 1
x1 = x0 + 1
if _within_bound(y0, x0):
out[n, :, h, w] += data[n, :, y0, x0] * (1.0 - (y - y0)) * (1.0 - (x - x0))
if _within_bound(y0, x1):
def compute_padding_mode_zeros():
for n in range(0, batch):
for h in range(0, out_height):
for w in range(0, out_width):
x, y = grid[n, :, h, w]
y = (y + 1) * (in_height - 1) / 2
x = (x + 1) * (in_width - 1) / 2
y0 = int(math.floor(y))
x0 = int(math.floor(x))
y1 = y0 + 1
x1 = x0 + 1
if _within_bound(y0, x0):
out[n, :, h, w] += data[n, :, y0, x0] * (1.0 - (y - y0)) * (1.0 - (x - x0))
if _within_bound(y0, x1):
out[n, :, h, w] += data[n, :, y0, x1] * (1.0 - (y - y0)) * (x - x0)
if _within_bound(y1, x0):
out[n, :, h, w] += data[n, :, y1, x0] * (y - y0) * (1.0 - (x - x0))
if _within_bound(y1, x1):
out[n, :, h, w] += data[n, :, y1, x1] * (y - y0) * (x - x0)

return out

def get_pixel_value(x, x_max):
return max(min(x, x_max - 1), 0)

def compute_padding_mode_border():
for n in range(0, batch):
for h in range(0, out_height):
for w in range(0, out_width):
x, y = grid[n, :, h, w]
y = (y + 1) * (in_height - 1) / 2
x = (x + 1) * (in_width - 1) / 2
y0 = int(math.floor(y))
x0 = int(math.floor(x))
y1 = y0 + 1
x1 = x0 + 1
y0 = get_pixel_value(y0, in_height)
y1 = get_pixel_value(y1, in_height)
x0 = get_pixel_value(x0, in_width)
x1 = get_pixel_value(x1, in_width)
out[n, :, h, w] = data[n, :, y0, x0] * (1.0 - (y - y0)) * (1.0 - (x - x0))
out[n, :, h, w] += data[n, :, y0, x1] * (1.0 - (y - y0)) * (x - x0)
if _within_bound(y1, x0):
out[n, :, h, w] += data[n, :, y1, x0] * (y - y0) * (1.0 - (x - x0))
if _within_bound(y1, x1):
out[n, :, h, w] += data[n, :, y1, x1] * (y - y0) * (x - x0)
return out

return out

if padding_mode == "zeros":
return compute_padding_mode_zeros()
if padding_mode == "border":
return compute_padding_mode_border()

def grid_sample_nchw_python(data, grid, method="bilinear"):
raise ValueError("invalid padding_mode")


def grid_sample_nchw_python(data, grid, method="bilinear", padding_mode="zeros"):
if method == "bilinear":
return _bilinear_sample_nchw_python(data, grid)
return _bilinear_sample_nchw_python(data, grid, padding_mode)

raise ValueError("invalid method")
3 changes: 2 additions & 1 deletion src/relay/op/image/grid_sample.cc
Original file line number Diff line number Diff line change
Expand Up @@ -116,10 +116,11 @@ bool GridSampleRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,

// Positional relay function to create affine_grid operator
// used by frontend FFI.
Expr MakeGridSample(Expr data, Expr grid, String method, String layout) {
Expr MakeGridSample(Expr data, Expr grid, String method, String layout, String padding_mode) {
auto attrs = make_object<GridSampleAttrs>();
attrs->method = std::move(method);
attrs->layout = std::move(layout);
attrs->padding_mode = std::move(padding_mode);
static const Op& op = Op::Get("image.grid_sample");
return Call(op, {data, grid}, Attrs(attrs), {});
}
Expand Down
45 changes: 42 additions & 3 deletions tests/python/frontend/pytorch/test_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -216,9 +216,8 @@ def verify_model(
if not tvm.runtime.enabled(target):
continue
dev = tvm.device(target, 0)
relay_graph, relay_lib, relay_params = relay.build(mod, target=target, params=params)
relay_model = graph_executor.create(relay_graph, relay_lib, dev)
relay_model.set_input(**relay_params)
lib = relay.build(mod, target=target, params=params)
relay_model = graph_executor.GraphModule(lib["default"](dev))
for name, inp in compiled_input.items():
relay_model.set_input(name, inp)
relay_model.run()
Expand Down Expand Up @@ -4079,5 +4078,45 @@ def test_fn(m, v):
verify_model(test_fn, [torch.randn(3, 8), torch.randn(8)])


def test_grid_sample():
class Grid_sample_zeros(Module):
def forward(self, x, y):
return torch.nn.functional.grid_sample(
input=x, grid=y, mode="bilinear", padding_mode="zeros", align_corners=True
)

class Grid_sample_border(Module):
def forward(self, x, y):
return torch.nn.functional.grid_sample(
input=x, grid=y, mode="bilinear", padding_mode="border", align_corners=True
)

data = torch.rand([4, 4, 16, 32]).float()
grid = torch.rand([4, 8, 8, 2]).float()
verify_model(Grid_sample_zeros(), input_data=[data, grid])
verify_model(Grid_sample_border(), input_data=[data, grid])


def test_list_tuple():
"""test compilation error for a Python list followed by a prim::TupleConstruct."""
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you give me more details on what this test is about? I don't see a relevant change in pytorch.py.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There was a bug in PyTorch frontend on list of tuples. And the bug is now already fixed upstream. I removed my fix in this patch but decided to keep the unit test.


class List_tuple(Module):
def forward(self, x):
merged = []
mask_list = []
for i in range(3):
w0 = torch.sigmoid(x)
merged.append((w0, w0))
mask_list.append(x)

for i in range(3):
merged[i] = merged[i][0] + merged[i][1]
return mask_list[2], merged

x = torch.rand([4, 4, 16, 32]).float()
script_module = torch.jit.trace(List_tuple(), x, strict=False).eval()
mod, params = relay.frontend.from_pytorch(script_module, [("x", x.shape)])


if __name__ == "__main__":
pytest.main([__file__])
12 changes: 9 additions & 3 deletions tests/python/relay/test_op_level5.py
Original file line number Diff line number Diff line change
Expand Up @@ -1392,20 +1392,24 @@ def verify_affine_grid(num_batch, target_shape):

@tvm.testing.uses_gpu
def test_grid_sample():
def verify_grid_sample(data_shape, grid_shape):
def verify_grid_sample(data_shape, grid_shape, padding_mode="zeros"):
dtype = "float32"
batch, channel, _, _ = data_shape
_, _, out_height, out_width = grid_shape
data = relay.var("data", relay.ty.TensorType(data_shape, dtype))
grid = relay.var("grid", relay.ty.TensorType(grid_shape, dtype))
y = relay.image.grid_sample(data, grid, method="bilinear", layout="NCHW")
y = relay.image.grid_sample(
data, grid, method="bilinear", layout="NCHW", padding_mode=padding_mode
)
yy = run_infer_type(y)
assert yy.checked_type == relay.TensorType((batch, channel, out_height, out_width), dtype)
func = relay.Function([data, grid], y)

data_np = np.random.uniform(size=data_shape).astype(dtype)
grid_np = np.random.uniform(size=grid_shape, low=-1.5, high=1.5).astype(dtype)
ref_res = tvm.topi.testing.grid_sample_nchw_python(data_np, grid_np, method="bilinear")
ref_res = tvm.topi.testing.grid_sample_nchw_python(
data_np, grid_np, method="bilinear", padding_mode=padding_mode
)

for target, dev in tvm.testing.enabled_targets():
for kind in ["graph", "debug"]:
Expand All @@ -1416,6 +1420,8 @@ def verify_grid_sample(data_shape, grid_shape):

verify_grid_sample((4, 4, 16, 32), (4, 2, 8, 8))
verify_grid_sample((4, 4, 16, 32), (4, 2, 32, 32))
verify_grid_sample((4, 4, 16, 32), (4, 2, 8, 8), "border")
verify_grid_sample((4, 4, 16, 32), (4, 2, 32, 32), "border")


@tvm.testing.uses_gpu
Expand Down
11 changes: 7 additions & 4 deletions tests/python/topi/python/test_topi_image.py
Original file line number Diff line number Diff line change
Expand Up @@ -274,18 +274,20 @@ def check_target(target, dev):

@tvm.testing.uses_gpu
def test_grid_sample():
def verify_grid_sample(data_shape, grid_shape):
def verify_grid_sample(data_shape, grid_shape, padding_mode="zeros"):
dtype = "float32"
data = te.placeholder(data_shape, dtype=dtype)
grid = te.placeholder(grid_shape, dtype=dtype)
out = topi.image.grid_sample(data, grid, "bilinear")
out = topi.image.grid_sample(data, grid, "bilinear", padding_mode=padding_mode)

@memoize("topi.tests.test_grid_sample.verify_grid_sample")
def get_ref_data():
data_np = np.random.uniform(size=data_shape).astype(dtype)
# allow grid values to be out-of-bound
grid_np = np.random.uniform(size=grid_shape, low=-1.5, high=1.5).astype(dtype)
out_np = tvm.topi.testing.grid_sample_nchw_python(data_np, grid_np, "bilinear")
out_np = tvm.topi.testing.grid_sample_nchw_python(
data_np, grid_np, "bilinear", padding_mode
)
return data_np, grid_np, out_np

data_np, grid_np, out_np = get_ref_data()
Expand All @@ -306,7 +308,8 @@ def check_target(target, dev):
check_target(target, dev)

verify_grid_sample((4, 4, 16, 32), (4, 2, 8, 8))
verify_grid_sample((4, 4, 16, 32), (4, 2, 32, 32))
verify_grid_sample((4, 4, 16, 32), (4, 2, 32, 32), "border")
verify_grid_sample((4, 4, 16, 32), (4, 2, 8, 8), "border")


if __name__ == "__main__":
Expand Down