From 26afae52297e3b39b04f227927e1953783a0c733 Mon Sep 17 00:00:00 2001 From: pinto0309 Date: Fri, 10 Jun 2022 22:06:58 +0900 Subject: [PATCH] Support for inf or -inf --- sog4onnx/__init__.py | 2 +- sog4onnx/onnx_operation_generator.py | 37 +++++++++++++++++++++++++--- 2 files changed, 34 insertions(+), 5 deletions(-) diff --git a/sog4onnx/__init__.py b/sog4onnx/__init__.py index 6c79ab7..85580de 100644 --- a/sog4onnx/__init__.py +++ b/sog4onnx/__init__.py @@ -1,3 +1,3 @@ from sog4onnx.onnx_operation_generator import generate, main -__version__ = '1.0.12' +__version__ = '1.0.13' diff --git a/sog4onnx/onnx_operation_generator.py b/sog4onnx/onnx_operation_generator.py index 9b1a1d8..2a57b6d 100644 --- a/sog4onnx/onnx_operation_generator.py +++ b/sog4onnx/onnx_operation_generator.py @@ -176,7 +176,7 @@ def generate( value_info = onnx.helper.make_tensor_value_info( constant_name, dtype, - attr_values.shape + attr_values.shape, ) node = onnx.helper.make_node( op_type, @@ -362,10 +362,39 @@ def main(): attr_type = attribute[1] if attr_type == 'string': attr_type = 'str' - if attr_type != 'str': - attr_value = ast.literal_eval(attribute[2]) - else: + if attr_type == 'str': attr_value = attribute[2] + else: + if ('-inf' in attribute[2].lower()) or ('-infinity' in attribute[2].lower()): + lower_attr = attribute[2].lower() + inf_count = lower_attr.count('-inf') + infinity_count = lower_attr.count('-infinity') + if (inf_count + infinity_count) > 1: + print( + f'{Color.RED}ERROR:{Color.RESET} '+ + f'Values containing "inf" or "-inf" can only be 1D tensors. \n'+ + f'e.g. [inf] or [-inf]' + ) + sys.exit(1) + else: + attr_value = [-np.inf] + + elif ('inf' in attribute[2].lower()) or ('infinity' in attribute[2].lower()): + lower_attr = attribute[2].lower() + inf_count = lower_attr.count('inf') + infinity_count = lower_attr.count('infinity') + if (inf_count + infinity_count) > 1: + print( + f'{Color.RED}ERROR:{Color.RESET} '+ + f'Values containing "inf" or "-inf" can only be 1D tensors. \n'+ + f'e.g. [inf] or [-inf]' + ) + sys.exit(1) + else: + attr_value = [np.inf] + + else: + attr_value = ast.literal_eval(attribute[2]) # dtype check if attr_type not in AVAILABLE_DTYPES: