[TFLite] Strided slice handling of shrink_axis_mask improved (#6998)
* [TFLite] Strided slice handling of shrink_axis_mask improved

1. Added removal of dimensions when the result is a scalar,
to mimic TensorFlow behaviour. E.g.:
    tf.strided_slice([1,2,3], [0], [1], [1], shrink_axis_mask=0)
    <tf.Tensor: shape=(1,), dtype=int32, numpy=array([1], dtype=int32)>

    tf.strided_slice([[[1,2,3],[4,5,6],[7,8,9]]], [0, 0, 0], [3, 3, 3], [1, 1, 1], shrink_axis_mask=7)
    <tf.Tensor: shape=(), dtype=int32, numpy=1>
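
Note: shrink_axis_mask is a bit field in which bit i marks axis i to be sliced to a single element and then removed. A tiny illustrative helper (hypothetical, not part of this commit) that decodes the mask:

    def shrunk_axes(shrink_axis_mask, ndim):
        # Axes removed by the mask: bit i set -> axis i is dropped.
        return [axis for axis in range(ndim) if shrink_axis_mask & (1 << axis)]

    print(shrunk_axes(7, 3))  # [0, 1, 2] -- all three axes dropped, result is a scalar
    print(shrunk_axes(5, 3))  # [0, 2] -- axis 1 survives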

2. Added an extra shape-equality check to assert_allclose,
since np.testing.assert_allclose() does not distinguish between cases like:

    np.testing.assert_allclose(1, np.array(1))
    np.testing.assert_allclose(1, np.array([1]))
    np.testing.assert_allclose(np.array(1), np.array([1]))
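
Note: all three calls above pass because np.testing.assert_allclose broadcasts its arguments; only the shapes differ. A minimal numpy demonstration:

    import numpy as np

    print(np.asanyarray(1).shape)    # ()   -- 0-d scalar
    print(np.asanyarray([1]).shape)  # (1,) -- rank-1 array of length 1
    np.testing.assert_allclose(np.array(1), np.array([1]))  # passes via broadcasting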

* unit tests fixed
d-smirnov authored Jan 21, 2021
1 parent 727345e commit e8ab607
Showing 5 changed files with 36 additions and 10 deletions.
7 changes: 6 additions & 1 deletion python/tvm/relay/frontend/tflite.py
@@ -1613,14 +1613,19 @@ def _transform_mask(stride_dim, ellipsis_mask):
 
         # Create final output shape.
         final_output = []
+        final_len = len(fshape_indices)
         for gather_index in fshape_indices:
             if gather_index == -1:
                 final_output.append(1)
+                final_len += 1
             elif gather_index == -2:
-                pass
+                final_len -= 1
             else:
                 final_output.append(out_shape[gather_index])
 
+        if final_len == 0:
+            return _op.squeeze(out, axis=tuple(range(len(fshape_indices))))
+
         if not final_output:
             return out
         return _op.reshape(out, newshape=tuple(final_output))
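
A standalone sketch of the new bookkeeping, assuming the frontend convention that fshape_indices holds -1 for axes added by new_axis_mask and -2 for axes removed by shrink_axis_mask:

    def final_shape(fshape_indices, out_shape):
        # Mirrors the diff: None means every axis was shrunk, so squeeze to a scalar.
        final_output = []
        final_len = len(fshape_indices)
        for gather_index in fshape_indices:
            if gather_index == -1:    # new axis: keep a size-1 dim
                final_output.append(1)
                final_len += 1
            elif gather_index == -2:  # shrunk axis: drop the dim
                final_len -= 1
            else:
                final_output.append(out_shape[gather_index])
        return None if final_len == 0 else tuple(final_output)

    print(final_shape([-2, -2, -2], [1, 3, 3]))  # None -> scalar (shrink_axis_mask=7)
    print(final_shape([-2, 1, -2], [1, 3, 3]))   # (3,)  (shrink_axis_mask=5)
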
3 changes: 3 additions & 0 deletions python/tvm/testing.py
@@ -76,6 +76,9 @@ def assert_allclose(actual, desired, rtol=1e-7, atol=1e-7):
     compares the `abs(actual-desired)` with `atol+rtol*abs(desired)`. Since we
     often allow `desired` to be close to zero, we generally want non-zero `atol`.
     """
+    actual = np.asanyarray(actual)
+    desired = np.asanyarray(desired)
+    np.testing.assert_allclose(actual.shape, desired.shape)
     np.testing.assert_allclose(actual, desired, rtol=rtol, atol=atol, verbose=True)
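
With the added check, a shape mismatch now fails before any value comparison; previously the broadcast made it pass. Illustrative use:

    import numpy as np
    import tvm.testing

    tvm.testing.assert_allclose(np.array(1), np.array(1))    # passes: shapes match
    tvm.testing.assert_allclose(np.array(1), np.array([1]))  # now raises: () vs (1,)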


18 changes: 18 additions & 0 deletions tests/python/frontend/tflite/test_forward.py
@@ -583,6 +583,24 @@ def _test_stridedslice(
 def test_forward_stridedslice():
     """test StridedSlice"""
     for quantized in [False, True]:
+        _test_stridedslice(
+            (1, 3, 3),
+            [0, 0, 0],
+            [3, 3, 3],
+            [1, 1, 1],
+            "float32",
+            shrink_axis_mask=7,
+            quantized=quantized,
+        )
+        _test_stridedslice(
+            (1, 3, 3),
+            [0, 0, 0],
+            [3, 3, 3],
+            [1, 1, 1],
+            "float32",
+            shrink_axis_mask=5,
+            quantized=quantized,
+        )
         _test_stridedslice((2), [1], [1], [1], "float32", shrink_axis_mask=1, quantized=quantized)
         _test_stridedslice(
             (3, 4, 3), [1, -1, 0], [4, -5, 3], [2, -1, 1], "float32", quantized=quantized
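
For reference, the expected TensorFlow behaviour for the two new cases on a (1, 3, 3) input: mask 7 (0b111) shrinks all three axes to a scalar, while mask 5 (0b101) shrinks axes 0 and 2, leaving shape (3,). Checked against plain TF (assuming TF 2.x eager mode):

    import tensorflow as tf

    x = tf.reshape(tf.range(9), (1, 3, 3))
    print(tf.strided_slice(x, [0, 0, 0], [3, 3, 3], [1, 1, 1], shrink_axis_mask=7).shape)  # ()
    print(tf.strided_slice(x, [0, 0, 0], [3, 3, 3], [1, 1, 1], shrink_axis_mask=5).shape)  # (3,)
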
4 changes: 2 additions & 2 deletions tests/python/integration/test_dot.py
@@ -27,7 +27,7 @@ def test_dot():
     A = te.placeholder((n,), name="A")
     B = te.placeholder((n,), name="B")
     k = te.reduce_axis((0, n), "k")
-    C = te.compute((1,), lambda _: te.sum(A[k] * B[k], axis=k), name="C")
+    C = te.compute((), lambda: te.sum(A[k] * B[k], axis=k), name="C")
     s = te.create_schedule(C.op)
 
@@ -36,7 +36,7 @@ def verify(target):
         ctx = tvm.cpu(0)
         a = tvm.nd.array(np.random.uniform(size=(nn,)).astype(A.dtype), ctx)
         b = tvm.nd.array(np.random.uniform(size=(nn,)).astype(B.dtype), ctx)
-        c = tvm.nd.array(np.zeros((1,), dtype=C.dtype), ctx)
+        c = tvm.nd.array(np.zeros((), dtype=C.dtype), ctx)
         f(a, b, c)
         tvm.testing.assert_allclose(c.asnumpy(), np.dot(a.asnumpy(), b.asnumpy()), rtol=1e-4)
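
The dot product of two vectors is a scalar, so the output buffer is now allocated as 0-d; with the stricter assert_allclose, a (1,)-shaped buffer would no longer match. In numpy terms:

    import numpy as np

    a = np.random.uniform(size=(8,)).astype("float32")
    b = np.random.uniform(size=(8,)).astype("float32")
    print(np.asanyarray(np.dot(a, b)).shape)  # () -- matches np.zeros(()), not np.zeros((1,))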

14 changes: 7 additions & 7 deletions tests/python/integration/test_reduce.py
@@ -73,7 +73,7 @@ def test_init_imm():
     n = tvm.runtime.convert(1027)
     A = te.placeholder((n,), name="A")
     k = te.reduce_axis((0, n))
-    B = te.compute((1,), lambda i: te.sum(A[k], axis=k, init=10.0), name="B")
+    B = te.compute((), lambda: te.sum(A[k], axis=k, init=10.0), name="B")
     # schedule
     s = te.create_schedule(B.op)
     # one line to build the function.
@@ -86,7 +86,7 @@ def check_target(target="llvm"):
         # launch the kernel.
         n = 1027
         a = tvm.nd.array(np.random.uniform(size=(n,)).astype(A.dtype), ctx)
-        b = tvm.nd.array(np.zeros(1, dtype=B.dtype), ctx)
+        b = tvm.nd.array(np.zeros((), dtype=B.dtype), ctx)
         fsum(a, b)
         res = 10.0 + np.sum(a.asnumpy(), axis=0)
         tvm.testing.assert_allclose(b.asnumpy(), res, rtol=1e-4)
@@ -129,7 +129,7 @@ def test_rfactor():
     n = tvm.runtime.convert(1027)
     A = te.placeholder((n,), name="A")
     k = te.reduce_axis((0, n))
-    B = te.compute((1,), lambda i: te.sum(A[k], axis=k), name="B")
+    B = te.compute((), lambda: te.sum(A[k], axis=k), name="B")
     # schedule
     s = te.create_schedule(B.op)
     kf, ki = s[B].split(k, nparts=4)
@@ -145,7 +145,7 @@ def check_target(target="llvm"):
         # launch the kernel.
         n = 1027
         a = tvm.nd.array(np.random.uniform(size=(n,)).astype(A.dtype), ctx)
-        b = tvm.nd.array(np.zeros(1, dtype=B.dtype), ctx)
+        b = tvm.nd.array(np.zeros((), dtype=B.dtype), ctx)
         fsum(a, b)
         res = np.sum(a.asnumpy(), axis=0)
         tvm.testing.assert_allclose(b.asnumpy(), res, rtol=1e-4)
@@ -191,11 +191,11 @@ def test_rfactor_factor_axis():
     n = tvm.runtime.convert(1027)
     A = te.placeholder((n,), name="A")
     k = te.reduce_axis((0, n))
-    B = te.compute((1,), lambda i: te.sum(A[k], axis=k), name="B")
+    B = te.compute((), lambda: te.sum(A[k], axis=k), name="B")
     # schedule
     s = te.create_schedule(B.op)
     kf, ki = s[B].split(k, nparts=4)
-    BF = s.rfactor(B, kf, 1)
+    BF = s.rfactor(B, kf, 0)
     s[BF].parallel(BF.op.axis[0])
     # one line to build the function.
     def check_target(target="llvm"):
@@ -207,7 +207,7 @@ def check_target(target="llvm"):
         # launch the kernel.
         n = 1027
         a = tvm.nd.array(np.random.uniform(size=(n,)).astype(A.dtype), ctx)
-        b = tvm.nd.array(np.zeros(1, dtype=B.dtype), ctx)
+        b = tvm.nd.array(np.zeros((), dtype=B.dtype), ctx)
         fsum(a, b)
         res = np.sum(a.asnumpy(), axis=0)
         tvm.testing.assert_allclose(b.asnumpy(), res, rtol=1e-4)
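
A minimal sketch of the scalar-reduction pattern after this change, using only the calls shown in the diff: with B now zero-dimensional, the factored tensor BF consists solely of the split axis, so factor_axis moves from 1 to 0 (BF is assumed here to have shape (4,)):

    import tvm
    from tvm import te

    n = tvm.runtime.convert(1027)
    A = te.placeholder((n,), name="A")
    k = te.reduce_axis((0, n))
    B = te.compute((), lambda: te.sum(A[k], axis=k), name="B")  # 0-d reduction
    s = te.create_schedule(B.op)
    kf, ki = s[B].split(k, nparts=4)
    BF = s.rfactor(B, kf, 0)  # axis 0 is the only valid factor_axis for a 0-d B
    s[BF].parallel(BF.op.axis[0])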
