Skip to content

Commit

Permalink
Add a quantized conv2 unit test for the tflite front-end (apache#5558)
Browse files Browse the repository at this point in the history
Signed-off-by: Giuseppe Rossini <giuseppe.rossini@arm.com>
  • Loading branch information
giuseros authored and Trevor Morris committed Jun 9, 2020
1 parent 4d148f4 commit b45fa57
Showing 1 changed file with 37 additions and 11 deletions.
48 changes: 37 additions & 11 deletions tests/python/frontend/tflite/test_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -635,7 +635,7 @@ def test_forward_l2_pool2d():

def _test_convolution(tensor_in_sizes, filter_in_sizes,
dilations, strides, padding, data_format,
is_depthwise=False):
is_depthwise=False, quantized=False):
""" One iteration of convolution with given shapes and attributes """

total_size_1 = 1
Expand All @@ -646,12 +646,16 @@ def _test_convolution(tensor_in_sizes, filter_in_sizes,
total_size_2 *= s
# Initializes the input tensor with array containing incrementing
# numbers from 1.
data_array = [f * 1.0 for f in range(1, total_size_1 + 1)]
filter_array = [f * 1.0 for f in range(1, total_size_2 + 1)]
if quantized:
data_array = np.random.uniform(0, 255, tensor_in_sizes).astype('uint8')
filter_array = np.random.uniform(0, 255, filter_in_sizes).astype('uint8')
else:
data_array = [f * 1.0 for f in range(1, total_size_1 + 1)]
filter_array = [f * 1.0 for f in range(1, total_size_2 + 1)]

with tf.Graph().as_default():
in_data = array_ops.placeholder(shape=tensor_in_sizes, dtype='float32')
in_filter = constant_op.constant(filter_array, shape=filter_in_sizes, dtype='float32')
in_data = array_ops.placeholder(shape=tensor_in_sizes, dtype='float32', name='in_data')
in_filter = constant_op.constant(filter_array, shape=filter_in_sizes, dtype='float32', name='in_filter')
strides = [1] + strides + [1]
dilations = [1] + dilations + [1]

Expand All @@ -667,15 +671,37 @@ def _test_convolution(tensor_in_sizes, filter_in_sizes,
strides=strides,
padding=padding,
data_format=data_format)
data_array = np.reshape(data_array, tensor_in_sizes).astype('float32')
compare_tflite_with_tvm(data_array, 'Placeholder:0', [in_data], [out])

if quantized:
# For now only quantized conv2d is supported
assert not is_depthwise

# Quantized the inputs and feed them to the convolution
inq_data = tf.quantization.fake_quant_with_min_max_args(in_data, min=-100, max=100, name='inq_data')
inq_filter = tf.quantization.fake_quant_with_min_max_args(in_filter, min=-100, max=100, name='inq_filter')
out = nn_ops.conv2d(inq_data,
inq_filter,
strides=strides,
padding=padding,
data_format=data_format)
out = tf.quantization.fake_quant_with_min_max_args(out, min=-200, max=200, name="out")

# Set the input quantization range
input_range = {'in_data': (-100, 100)} if quantized else None

# Compare
compare_tflite_with_tvm(data_array, 'in_data', [in_data], [out], quantized=quantized, input_range=input_range)
else:
data_array = np.reshape(data_array, tensor_in_sizes).astype('float32')
compare_tflite_with_tvm(data_array, 'in_data', [in_data], [out])


def test_forward_convolution():
_test_convolution([4, 8, 8, 176], [1, 1, 176, 32], [1, 1], [1, 1], 'SAME', 'NHWC')
_test_convolution([4, 17, 17, 19], [3, 3, 19, 19], [1, 1], [2, 2], 'VALID', 'NHWC')
_test_convolution([4, 17, 17, 124], [1, 1, 124, 19], [1, 1], [1, 1], 'SAME', 'NHWC')
_test_convolution([4, 17, 17, 12], [3, 3, 12, 32], [1, 1], [2, 2], 'VALID', 'NHWC')
for quantized in [False, True]:
_test_convolution([4, 8, 8, 176], [1, 1, 176, 32], [1, 1], [1, 1], 'SAME', 'NHWC', quantized=quantized)
_test_convolution([4, 17, 17, 19], [3, 3, 19, 19], [1, 1], [2, 2], 'VALID', 'NHWC', quantized=quantized)
_test_convolution([4, 17, 17, 124], [1, 1, 124, 19], [1, 1], [1, 1], 'SAME', 'NHWC', quantized=quantized)
_test_convolution([4, 17, 17, 12], [3, 3, 12, 32], [1, 1], [2, 2], 'VALID', 'NHWC', quantized=quantized)

# depthwise convolution
_test_convolution([4, 8, 8, 176], [1, 1, 176, 1], [1, 1], [1, 1], 'SAME', 'NHWC', True)
Expand Down

0 comments on commit b45fa57

Please sign in to comment.