From 166c3fb9ba9441701b3f4d720f594f068c2705a0 Mon Sep 17 00:00:00 2001 From: Ubuntu Date: Mon, 12 Oct 2020 23:21:15 +0000 Subject: [PATCH] TF argmax - handling int64 datatype --- python/tvm/relay/frontend/tensorflow.py | 6 +++++- tests/python/frontend/tensorflow/test_forward.py | 12 ++++++------ 2 files changed, 11 insertions(+), 7 deletions(-) diff --git a/python/tvm/relay/frontend/tensorflow.py b/python/tvm/relay/frontend/tensorflow.py index c7e8c0084db2..3df582a0c76a 100644 --- a/python/tvm/relay/frontend/tensorflow.py +++ b/python/tvm/relay/frontend/tensorflow.py @@ -146,7 +146,11 @@ def _impl(inputs, attr, params, mod): raise TypeError( "Unsupported argument for `{}` : `axis` should be a constant".format(func_name) ) - return func(inputs[0], axis=axis_input_value, keepdims=False) + out = func(inputs[0], axis=axis_input_value, keepdims=False) + dtype = attr["output_type"].name + if dtype != "int32": + out = _op.cast(out, dtype=dtype) + return out return _impl diff --git a/tests/python/frontend/tensorflow/test_forward.py b/tests/python/frontend/tensorflow/test_forward.py index fb4c10465f22..8e347e754b98 100644 --- a/tests/python/frontend/tensorflow/test_forward.py +++ b/tests/python/frontend/tensorflow/test_forward.py @@ -1601,16 +1601,16 @@ def _test_argx(func, data, **kwargs): with tf.Graph().as_default(): inp = array_ops.placeholder(shape=data.shape, dtype=data.dtype, name="c0") - func(inp, name="argx0", output_type=tf.int32, **kwargs) - + func(inp, name="argx0", **kwargs) compare_tf_with_tvm(data, "c0:0", "argx0:0") def test_forward_argminmax(): - for axis in [None, 0, 1, 2]: - data = np.random.uniform(size=(8, 4, 9)).astype("float32") - _test_argx(tf.argmax, data=data, axis=axis) - _test_argx(tf.argmin, data=data, axis=axis) + for output_type in [tf.int64, tf.int32]: + for axis in [None, 0, 1, 2]: + data = np.random.uniform(size=(8, 4, 9)).astype("float32") + _test_argx(tf.argmax, data=data, axis=axis, output_type=output_type) + _test_argx(tf.argmin, data=data, axis=axis, output_type=output_type) #######################################################################