Skip to content

Commit

Permalink
[Relay][fix] Stack should take exprs that evaluate to tuples (apache#…
Browse files Browse the repository at this point in the history
…7130)

* Fix stack to take Relay exprs that evaluate to tuples

* Doc tweak

* Linting fix
  • Loading branch information
slyubomirsky authored Dec 30, 2020
1 parent cfdbf0e commit 466383a
Show file tree
Hide file tree
Showing 2 changed files with 50 additions and 19 deletions.
9 changes: 5 additions & 4 deletions python/tvm/relay/op/tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -1105,8 +1105,8 @@ def stack(data, axis):
Parameters
----------
data : Union(List[relay.Expr], Tuple(relay.Expr))
A list of tensors.
data : Union(List[relay.Expr], relay.Expr)
A list of tensors or a Relay expression that evaluates to a tuple of tensors.
axis : int
The axis in the result array along which the input arrays are stacked.
Expand All @@ -1116,12 +1116,13 @@ def stack(data, axis):
ret : relay.Expr
The stacked tensor.
"""
data = list(data)
if not data:
raise ValueError("relay.stack requires data to be non-empty.")
if not isinstance(axis, int):
raise ValueError("For now, we only support integer axis")
return _make.stack(Tuple(data), axis)
if not isinstance(data, Expr):
data = Tuple(list(data))
return _make.stack(data, axis)


def copy(data):
Expand Down
60 changes: 45 additions & 15 deletions tests/python/relay/test_op_level3.py
Original file line number Diff line number Diff line change
Expand Up @@ -787,28 +787,58 @@ def verify_repeat(dshape, repeats, axis):

@tvm.testing.uses_gpu
def test_stack():
def verify_stack(dshapes, axis):
y = []
for shape in dshapes:
y.append(relay.var("input", relay.TensorType(shape, "float32")))
x = relay.Tuple(y)
z = relay.stack(x, axis=axis)
def produce_input_tuple(dshapes):
y = [relay.var("input", relay.TensorType(shape, "float32")) for shape in dshapes]
return relay.Tuple(y)

func = relay.Function(y, z)
x_data = [np.random.normal(size=shape).astype("float32") for shape in dshapes]
ref_res = np.stack(x_data, axis=axis)
def ref_stack(inputs, axis):
return np.stack(inputs, axis=axis)

def verify_stack(input_expr, relay_args, ref_res, axis):
z = relay.stack(input_expr, axis=axis)
inp_vars = relay.analysis.free_vars(z)
func = relay.Function(inp_vars, z)

for target, ctx in tvm.testing.enabled_targets():
for kind in ["graph", "debug"]:
intrp = relay.create_executor(kind, ctx=ctx, target=target)
op_res = intrp.evaluate(func)(*x_data)
op_res = intrp.evaluate(func)(*relay_args)
tvm.testing.assert_allclose(op_res.asnumpy(), ref_res, rtol=1e-5)

verify_stack([(2,), (2,), (2,)], -1)
verify_stack([(2,), (2,), (2,)], 0)
verify_stack([(2, 2, 4), (2, 2, 4), (2, 2, 4)], 1)
verify_stack([(2, 2, 3, 4), (2, 2, 3, 4), (2, 2, 3, 4), (2, 2, 3, 4)], -1)
verify_stack([(2, 2, 3, 4), (2, 2, 3, 4), (2, 2, 3, 4), (2, 2, 3, 4)], 4)
def verify_tup_lit_stack(dshapes, axis):
input_tuple = produce_input_tuple(dshapes)
input_data = [np.random.normal(size=shape).astype("float32") for shape in dshapes]
ref_res = ref_stack(input_data, axis)
verify_stack(input_tuple, input_data, ref_res, axis)

def verify_list_lit_stack(dshapes, axis):
input_list = produce_input_tuple(dshapes).fields
input_data = [np.random.normal(size=shape).astype("float32") for shape in dshapes]
ref_res = ref_stack(input_data, axis)
verify_stack(input_list, input_data, ref_res, axis)

def verify_tup_expr_stack(dshapes, axis):
input_data = [np.random.normal(size=shape).astype("float32") for shape in dshapes]
ref_res = ref_stack(input_data, axis)

# expression that evaluates to a tuple
# but is not a tuple literal
x = relay.Var("x")
input_expr = relay.Let(x, relay.Tuple([relay.const(inp) for inp in input_data]), x)
verify_stack(input_expr, [], ref_res, axis)

dshape_axis_combos = [
([(2,), (2,), (2,)], -1),
([(2,), (2,), (2,)], 0),
([(2, 2, 4), (2, 2, 4), (2, 2, 4)], 1),
([(2, 2, 3, 4), (2, 2, 3, 4), (2, 2, 3, 4), (2, 2, 3, 4)], -1),
([(2, 2, 3, 4), (2, 2, 3, 4), (2, 2, 3, 4), (2, 2, 3, 4)], 4),
]

for dshapes, axis in dshape_axis_combos:
verify_tup_lit_stack(dshapes, axis)
verify_list_lit_stack(dshapes, axis)
verify_tup_expr_stack(dshapes, axis)


@tvm.testing.uses_gpu
Expand Down

0 comments on commit 466383a

Please sign in to comment.