From 5da4e604fd3e0844a2cd51d54cbbb583b79603b5 Mon Sep 17 00:00:00 2001 From: Vinayak Dev Date: Wed, 23 Oct 2024 03:24:48 +0530 Subject: [PATCH 1/9] [onnx][importer] Add support for externalized params --- .../compiler/tools/import_onnx/__main__.py | 260 +++++++++++++++++- 1 file changed, 259 insertions(+), 1 deletion(-) diff --git a/compiler/bindings/python/iree/compiler/tools/import_onnx/__main__.py b/compiler/bindings/python/iree/compiler/tools/import_onnx/__main__.py index d25f40dfb347..6e3b4334a17f 100644 --- a/compiler/bindings/python/iree/compiler/tools/import_onnx/__main__.py +++ b/compiler/bindings/python/iree/compiler/tools/import_onnx/__main__.py @@ -19,6 +19,12 @@ from pathlib import Path import sys import tempfile +import warnings +import random +import iree.runtime as rt + +from ...dialects import util +from typing import Optional, Tuple, Any try: import onnx @@ -28,6 +34,8 @@ f"(typically `{sys.executable} -m pip install onnx`)" ) from e +from onnx import numpy_helper + try: from ...extras import onnx_importer except ModuleNotFoundError as e: @@ -37,19 +45,241 @@ from ...ir import ( Context, + Type as IrType, + TypeAttr, + RankedTensorType, + StringAttr, + Attribute, + Operation, + Location, + InsertionPoint, + Value, + SymbolTable, ) +class IREENodeImporter(onnx_importer.NodeImporter): + def __init__( + self, + graph_info: onnx_importer.GraphInfo, + *, + parent_op: Operation, + block: onnx_importer.Block, + context_cache: "onnx_importer.ContextCache", + module_op: Operation, + module_cache: "onnx_importer.ModuleCache", + max_numel: int, + ): + super().__init__( + graph_info, + parent_op=parent_op, + block=block, + context_cache=context_cache, + module_op=module_op, + module_cache=module_cache, + ) + self.last_global_op = None + self.symbol_table = SymbolTable(module_op) + self.symbol_table.insert(parent_op) + self.max_numel = max_numel + self.param_archive = rt.ParameterIndex() + + def sanitize_name(self, name: str) -> str: + new_name: str = "" + for c in range(len(name)): + if name[c] == ":": + new_name += "_" + else: + new_name += name[c] + + if len(new_name) == 0: + alpha = [chr(v) for v in range(ord("a"), ord("a") + 26)] + ch = random.choice(alpha) + new_name = str(random.randrange(1, 1000)) + "__" + ch + return new_name + + def create_tensor_global( + self, + t: onnx.TensorProto, + ) -> Tuple[str, IrType]: + # Always create globals at the top. Then after created, if there was + # a prior one, move the new one to after it to maintain declaration + # order. + name = self.sanitize_name(t.name) + with InsertionPoint.at_block_begin( + self._m.regions[0].blocks[0] + ), Location.unknown(): + vtensor_type = RankedTensorType.get( + tuple(t.dims), self._cc.tensor_element_type(t.data_type) + ) + ir_attrs = { + "sym_name": StringAttr.get(name), + "sym_visibility": StringAttr.get("private"), + "type": TypeAttr.get(vtensor_type), + } + + external_scope_attr = StringAttr.get("model") + external_name_attr = StringAttr.get(name) + ir_attrs["initial_value"] = Attribute.parse( + f"#stream.parameter.named<{external_scope_attr}::{external_name_attr}> : {vtensor_type}" + ) + global_op = util.GlobalOp( + ir_attrs["sym_name"], + ir_attrs["type"], + sym_visibility=ir_attrs["sym_visibility"], + initial_value=ir_attrs["initial_value"], + ) + self.symbol_table.insert(global_op) + if self.last_global_op is not None: + global_op.move_after(self.last_global_op) + self.last_global_op = global_op + actual_symbol_name = StringAttr(global_op.attributes["sym_name"]).value + return actual_symbol_name, vtensor_type + + @classmethod + def define_function( + cls, + graph_info: onnx_importer.GraphInfo, + module_op: Operation, + max_numel: int, + context_cache: Optional["onnx_importer.ContextCache"] = None, + module_cache: Optional["onnx_importer.ModuleCache"] = None, + private: bool = False, + ) -> "IREENodeImporter": + cc = ( + context_cache + if context_cache is not None + else onnx_importer.ContextCache(module_op.context) + ) + mc = ( + module_cache + if module_cache is not None + else onnx_importer.ModuleCache(module_op, cc) + ) + with module_op.context, Location.name(f"graph:{graph_info.graph_proto.name}"): + body = module_op.regions[0].blocks[0] + func_name = graph_info.graph_proto.name + input_types = [ + cc.type_proto_to_type(inp.type) for inp in graph_info.input_map.values() + ] + output_types = [ + cc.type_proto_to_type(out.type) + for out in graph_info.output_map.values() + ] + ftype = onnx_importer.FunctionType.get(input_types, output_types) + func_op = onnx_importer.func_dialect.FuncOp( + func_name, + ftype, + ip=InsertionPoint(body), + visibility="private" if private else None, + ) + block = func_op.add_entry_block( + [Location.name(k) for k in graph_info.input_map.keys()] + ) + imp = IREENodeImporter( + graph_info, + parent_op=func_op, + block=block, + context_cache=cc, + module_op=module_op, + module_cache=mc, + max_numel=max_numel, + ) + for node_name, input_value in zip(graph_info.input_map.keys(), block.arguments): + imp._nv_map[node_name] = input_value + imp._populate_graph_attrs(func_op) + return imp + + def import_all(self, func=True): + for init in self._gi.initializer_map.values(): + self.import_initializer(init) + + self.get_none() + for node in self._gi.graph_proto.node: + self.import_node(node) + + outputs = [] + for output_name in self._gi.output_map.keys(): + try: + outputs.append(self._nv_map[output_name]) + except KeyError: + raise onnx_importer.OnnxImportError( + f"Non topologically produced ONNX graph output '{output_name}'" + ) + with InsertionPoint(self._b), Location.unknown(): + if func: + onnx_importer.func_dialect.ReturnOp(outputs) + else: + Operation.create(name="torch.operator_terminator", operands=outputs) + + def import_initializer( + self, initializer: onnx.TensorProto, extern_name: Optional[str] = None + ) -> Value: + # If an explicitly specified name is given, use that; otherwise, pick + # up the name from the tensor proto itself + iname = extern_name if extern_name else initializer.name + dims = list(initializer.dims) + acc = 1 + for d in dims: + acc = acc * d + if acc < self.max_numel: + with InsertionPoint(self._b), Location.name(iname): + value_attr = self._cc.tensor_proto_to_attr(initializer) + vtensor_type = self._cc.tensor_proto_to_type(initializer) + attrs = { + "name": StringAttr.get(f"onnx.Constant"), + "torch.onnx.value": value_attr, + } + literal_op = Operation.create( + name="torch.operator", + results=[vtensor_type], + attributes=attrs, + ) + self._nv_map[iname] = literal_op.result + return literal_op.result + + x, t = self.create_tensor_global(initializer) + vtensor_type = self._cc.get_vtensor_type( + tuple(initializer.dims), self._cc.tensor_element_type(initializer.data_type) + ) + + with InsertionPoint(self._b), Location.name(iname): + old_op = util.GlobalLoadOp(t, x) + converted_value = Operation.create( + "torch_c.from_builtin_tensor", + results=[vtensor_type], + operands=[old_op.result], + ).result + + self._nv_map[iname] = converted_value + tensor_as_array = numpy_helper.to_array(initializer) + self.param_archive.add_buffer(x, tensor_as_array) + return converted_value + + def main(args: argparse.Namespace): model_proto = load_onnx_model(args) context = Context() model_info = onnx_importer.ModelInfo(model_proto) m = model_info.create_module(context=context).operation - imp = onnx_importer.NodeImporter.define_function(model_info.main_graph, m) + + imp: Any = None + if args.externalize_params: + imp = IREENodeImporter.define_function(model_info.main_graph, m, args.max_numel) + else: + if args.max_numel: + warnings.warn( + "'--max-numel' has no effect until externalization is enabled with '--externalize-params'" + ) + imp = onnx_importer.NodeImporter.define_function(model_info.main_graph, m) imp.import_all() + if not args.no_verify: m.verify() + if args.externalize_params: + imp.param_archive.create_archive_file(args.save_params_to) + # TODO: This isn't very efficient output. If these files ever # get large, enable bytecode and direct binary emission to save # some copies. @@ -71,6 +301,11 @@ def load_onnx_model(args: argparse.Namespace) -> onnx.ModelProto: raw_model = onnx.load(args.input_file, load_external_data=False) onnx.load_external_data_for_model(raw_model, str(args.data_dir)) + if args.opset_version: + raw_model = onnx.version_converter.convert_version( + raw_model, args.opset_version + ) + # Do shape inference two ways. First, attempt in-memory to avoid redundant # loading and the need for writing a temporary file somewhere. If that # fails, typically because of the 2 GB protobuf size limit, try again via @@ -132,6 +367,29 @@ def parse_arguments(argv=None) -> argparse.Namespace: " Defaults to the directory of the input file.", type=Path, ) + parser.add_argument( + "--opset-version", + help="Allows specification of a newer opset_version to update the model" + " to before importing to MLIR. This can sometime assist with shape inference.", + type=int, + ) + parser.add_argument( + "--max-numel", + help="Maximum number of elements allowed in an inlined parameter constant.", + type=int, + default=100, + ) + parser.add_argument( + "--externalize-params", + help="Externalize large parameters and store them on the disk, to load at runtime.", + action=argparse.BooleanOptionalAction, + default=False, + ) + parser.add_argument( + "--save-params-to", + help="Location to save the externalized parameters", + default="params.irpa", + ) args = parser.parse_args(argv) return args From 61a74387cfc1631a85105ed84ffb041dfbcfb914 Mon Sep 17 00:00:00 2001 From: Vinayak Dev Date: Thu, 24 Oct 2024 22:01:35 +0530 Subject: [PATCH 2/9] Refactor code to remove duplications --- .../compiler/tools/import_onnx/__main__.py | 47 ++++--------------- 1 file changed, 10 insertions(+), 37 deletions(-) diff --git a/compiler/bindings/python/iree/compiler/tools/import_onnx/__main__.py b/compiler/bindings/python/iree/compiler/tools/import_onnx/__main__.py index 6e3b4334a17f..9e2354c24af2 100644 --- a/compiler/bindings/python/iree/compiler/tools/import_onnx/__main__.py +++ b/compiler/bindings/python/iree/compiler/tools/import_onnx/__main__.py @@ -22,6 +22,7 @@ import warnings import random import iree.runtime as rt +import string from ...dialects import util from typing import Optional, Tuple, Any @@ -85,6 +86,11 @@ def __init__( self.param_archive = rt.ParameterIndex() def sanitize_name(self, name: str) -> str: + # There are often some initializers in the models that have no name + # labels, or contain substrings like '::', which can cause conflicts, + # and invalid symbol names for symbolic references. This function will + # remove substrings like '::' when the name is not empty, and generate + # a random string when it is, as a placeholder. new_name: str = "" for c in range(len(name)): if name[c] == ":": @@ -93,7 +99,7 @@ def sanitize_name(self, name: str) -> str: new_name += name[c] if len(new_name) == 0: - alpha = [chr(v) for v in range(ord("a"), ord("a") + 26)] + alpha = string.ascii_lowercase ch = random.choice(alpha) new_name = str(random.randrange(1, 1000)) + "__" + ch return new_name @@ -190,28 +196,6 @@ def define_function( imp._populate_graph_attrs(func_op) return imp - def import_all(self, func=True): - for init in self._gi.initializer_map.values(): - self.import_initializer(init) - - self.get_none() - for node in self._gi.graph_proto.node: - self.import_node(node) - - outputs = [] - for output_name in self._gi.output_map.keys(): - try: - outputs.append(self._nv_map[output_name]) - except KeyError: - raise onnx_importer.OnnxImportError( - f"Non topologically produced ONNX graph output '{output_name}'" - ) - with InsertionPoint(self._b), Location.unknown(): - if func: - onnx_importer.func_dialect.ReturnOp(outputs) - else: - Operation.create(name="torch.operator_terminator", operands=outputs) - def import_initializer( self, initializer: onnx.TensorProto, extern_name: Optional[str] = None ) -> Value: @@ -223,20 +207,9 @@ def import_initializer( for d in dims: acc = acc * d if acc < self.max_numel: - with InsertionPoint(self._b), Location.name(iname): - value_attr = self._cc.tensor_proto_to_attr(initializer) - vtensor_type = self._cc.tensor_proto_to_type(initializer) - attrs = { - "name": StringAttr.get(f"onnx.Constant"), - "torch.onnx.value": value_attr, - } - literal_op = Operation.create( - name="torch.operator", - results=[vtensor_type], - attributes=attrs, - ) - self._nv_map[iname] = literal_op.result - return literal_op.result + imported_tensor = super().import_initializer(initializer) + self._nv_map[iname] = imported_tensor + return imported_tensor x, t = self.create_tensor_global(initializer) vtensor_type = self._cc.get_vtensor_type( From 5e7313d48c6567529d4c1674bb7c7eeb5ad2d4cf Mon Sep 17 00:00:00 2001 From: Vinayak Dev Date: Fri, 25 Oct 2024 03:00:31 +0530 Subject: [PATCH 3/9] Fix IR import for integer to signless integer conversions --- .../compiler/tools/import_onnx/__main__.py | 46 ++++++++++++++++--- 1 file changed, 40 insertions(+), 6 deletions(-) diff --git a/compiler/bindings/python/iree/compiler/tools/import_onnx/__main__.py b/compiler/bindings/python/iree/compiler/tools/import_onnx/__main__.py index 9e2354c24af2..a92ff49dbb63 100644 --- a/compiler/bindings/python/iree/compiler/tools/import_onnx/__main__.py +++ b/compiler/bindings/python/iree/compiler/tools/import_onnx/__main__.py @@ -19,7 +19,7 @@ from pathlib import Path import sys import tempfile -import warnings +import copy import random import iree.runtime as rt import string @@ -56,6 +56,7 @@ InsertionPoint, Value, SymbolTable, + IntegerType, ) @@ -115,8 +116,11 @@ def create_tensor_global( with InsertionPoint.at_block_begin( self._m.regions[0].blocks[0] ), Location.unknown(): + # After lowering to linalg-on-tensors, the data type need to be signless. + # So, we construct the globals to have signless types, and use + # torch_c.from_builtin_tensor to convert to the correct frontend type. vtensor_type = RankedTensorType.get( - tuple(t.dims), self._cc.tensor_element_type(t.data_type) + tuple(t.dims), ELEM_TYPE_TO_SIGNLESS_IR_TYPE[t.data_type]() ) ir_attrs = { "sym_name": StringAttr.get(name), @@ -240,10 +244,6 @@ def main(args: argparse.Namespace): if args.externalize_params: imp = IREENodeImporter.define_function(model_info.main_graph, m, args.max_numel) else: - if args.max_numel: - warnings.warn( - "'--max-numel' has no effect until externalization is enabled with '--externalize-params'" - ) imp = onnx_importer.NodeImporter.define_function(model_info.main_graph, m) imp.import_all() @@ -317,6 +317,40 @@ def load_onnx_model(args: argparse.Namespace) -> onnx.ModelProto: return inferred_model +ELEM_TYPE_TO_SIGNLESS_IR_TYPE = copy.deepcopy(onnx_importer.ELEM_TYPE_TO_IR_TYPE_CB) + +ELEM_TYPE_TO_SIGNLESS_IR_TYPE[ + onnx.TensorProto.DataType.INT64 +] = lambda: IntegerType.get_signless(64) +ELEM_TYPE_TO_SIGNLESS_IR_TYPE[ + onnx.TensorProto.DataType.INT32 +] = lambda: IntegerType.get_signless(32) +ELEM_TYPE_TO_SIGNLESS_IR_TYPE[ + onnx.TensorProto.DataType.INT16 +] = lambda: IntegerType.get_signless(16) +ELEM_TYPE_TO_SIGNLESS_IR_TYPE[ + onnx.TensorProto.DataType.INT8 +] = lambda: IntegerType.get_signless(8) +ELEM_TYPE_TO_SIGNLESS_IR_TYPE[ + onnx.TensorProto.DataType.INT4 +] = lambda: IntegerType.get_signless(4) +ELEM_TYPE_TO_SIGNLESS_IR_TYPE[ + onnx.TensorProto.DataType.UINT8 +] = lambda: IntegerType.get_signless(8) +ELEM_TYPE_TO_SIGNLESS_IR_TYPE[ + onnx.TensorProto.DataType.UINT4 +] = lambda: IntegerType.get_signless(4) +ELEM_TYPE_TO_SIGNLESS_IR_TYPE[ + onnx.TensorProto.DataType.UINT16 +] = lambda: IntegerType.get_signless(16) +ELEM_TYPE_TO_SIGNLESS_IR_TYPE[ + onnx.TensorProto.DataType.UINT64 +] = lambda: IntegerType.get_signless(64) +ELEM_TYPE_TO_SIGNLESS_IR_TYPE[ + onnx.TensorProto.DataType.UINT32 +] = lambda: IntegerType.get_signless(32) + + def parse_arguments(argv=None) -> argparse.Namespace: parser = argparse.ArgumentParser(description="IREE ONNX import tool") parser.add_argument("input_file", help="ONNX protobuf input", type=Path) From eefb662f08dc3675c8dde795569f85bf8a40d4ea Mon Sep 17 00:00:00 2001 From: Vinayak Dev Date: Fri, 25 Oct 2024 23:21:40 +0530 Subject: [PATCH 4/9] Address comments --- .../compiler/tools/import_onnx/__main__.py | 39 ++++++++++++------- 1 file changed, 24 insertions(+), 15 deletions(-) diff --git a/compiler/bindings/python/iree/compiler/tools/import_onnx/__main__.py b/compiler/bindings/python/iree/compiler/tools/import_onnx/__main__.py index a92ff49dbb63..2059c417b99b 100644 --- a/compiler/bindings/python/iree/compiler/tools/import_onnx/__main__.py +++ b/compiler/bindings/python/iree/compiler/tools/import_onnx/__main__.py @@ -70,7 +70,7 @@ def __init__( context_cache: "onnx_importer.ContextCache", module_op: Operation, module_cache: "onnx_importer.ModuleCache", - max_numel: int, + numel_threshold: int, ): super().__init__( graph_info, @@ -83,7 +83,7 @@ def __init__( self.last_global_op = None self.symbol_table = SymbolTable(module_op) self.symbol_table.insert(parent_op) - self.max_numel = max_numel + self.numel_threshold = numel_threshold self.param_archive = rt.ParameterIndex() def sanitize_name(self, name: str) -> str: @@ -116,7 +116,7 @@ def create_tensor_global( with InsertionPoint.at_block_begin( self._m.regions[0].blocks[0] ), Location.unknown(): - # After lowering to linalg-on-tensors, the data type need to be signless. + # After lowering to linalg-on-tensors, the data type needs to be signless. # So, we construct the globals to have signless types, and use # torch_c.from_builtin_tensor to convert to the correct frontend type. vtensor_type = RankedTensorType.get( @@ -151,7 +151,7 @@ def define_function( cls, graph_info: onnx_importer.GraphInfo, module_op: Operation, - max_numel: int, + numel_threshold: int, context_cache: Optional["onnx_importer.ContextCache"] = None, module_cache: Optional["onnx_importer.ModuleCache"] = None, private: bool = False, @@ -193,7 +193,7 @@ def define_function( context_cache=cc, module_op=module_op, module_cache=mc, - max_numel=max_numel, + numel_threshold=numel_threshold, ) for node_name, input_value in zip(graph_info.input_map.keys(), block.arguments): imp._nv_map[node_name] = input_value @@ -207,10 +207,10 @@ def import_initializer( # up the name from the tensor proto itself iname = extern_name if extern_name else initializer.name dims = list(initializer.dims) - acc = 1 + numel = 1 for d in dims: - acc = acc * d - if acc < self.max_numel: + numel = numel * d + if numel < self.numel_threshold: imported_tensor = super().import_initializer(initializer) self._nv_map[iname] = imported_tensor return imported_tensor @@ -242,7 +242,9 @@ def main(args: argparse.Namespace): imp: Any = None if args.externalize_params: - imp = IREENodeImporter.define_function(model_info.main_graph, m, args.max_numel) + imp = IREENodeImporter.define_function( + model_info.main_graph, m, args.numel_threshold + ) else: imp = onnx_importer.NodeImporter.define_function(model_info.main_graph, m) imp.import_all() @@ -251,7 +253,13 @@ def main(args: argparse.Namespace): m.verify() if args.externalize_params: - imp.param_archive.create_archive_file(args.save_params_to) + default_param_path = Path(args.output_file).parent / Path(args.output_file).stem + param_path = ( + (str(default_param_path) + "_params.irpa") + if args.save_params_to is None + else args.save_params_to + ) + imp.param_archive.create_archive_file(param_path) # TODO: This isn't very efficient output. If these files ever # get large, enable bytecode and direct binary emission to save @@ -274,7 +282,8 @@ def load_onnx_model(args: argparse.Namespace) -> onnx.ModelProto: raw_model = onnx.load(args.input_file, load_external_data=False) onnx.load_external_data_for_model(raw_model, str(args.data_dir)) - if args.opset_version: + # Only change the opset version if it is greater than the current one. + if args.opset_version and args.opset_version > raw_model.opset_import[0].version: raw_model = onnx.version_converter.convert_version( raw_model, args.opset_version ) @@ -381,8 +390,8 @@ def parse_arguments(argv=None) -> argparse.Namespace: type=int, ) parser.add_argument( - "--max-numel", - help="Maximum number of elements allowed in an inlined parameter constant.", + "--numel-threshold", + help="Minimum number of elements for an initializer to be externalized. Only has an effect if 'externalize-params' is true.", type=int, default=100, ) @@ -394,8 +403,8 @@ def parse_arguments(argv=None) -> argparse.Namespace: ) parser.add_argument( "--save-params-to", - help="Location to save the externalized parameters", - default="params.irpa", + help="Location to save the externalized parameters. When not set, the parameters will be written to '_params.irpa'.", + default=None, ) args = parser.parse_args(argv) return args From cf08fce60083de42e91f4eb9d3304341d5e55661 Mon Sep 17 00:00:00 2001 From: Vinayak Dev Date: Mon, 28 Oct 2024 14:52:26 +0530 Subject: [PATCH 5/9] Move overrides to separate file --- compiler/bindings/python/CMakeLists.txt | 1 + .../compiler/tools/import_onnx/__main__.py | 248 +----------------- .../importer_externalization_overrides.py | 247 +++++++++++++++++ 3 files changed, 250 insertions(+), 246 deletions(-) create mode 100644 compiler/bindings/python/iree/compiler/tools/import_onnx/importer_externalization_overrides.py diff --git a/compiler/bindings/python/CMakeLists.txt b/compiler/bindings/python/CMakeLists.txt index fe033b2b0299..e606e7309aef 100644 --- a/compiler/bindings/python/CMakeLists.txt +++ b/compiler/bindings/python/CMakeLists.txt @@ -145,6 +145,7 @@ declare_mlir_python_sources(IREECompilerAPIPythonTools tools/tf.py tools/tflite.py tools/import_onnx/__main__.py + tools/import_onnx/importer_externalization_overrides.py tools/ir_tool/__main__.py tools/scripts/iree_compile/__main__.py tools/scripts/iree_opt/__main__.py diff --git a/compiler/bindings/python/iree/compiler/tools/import_onnx/__main__.py b/compiler/bindings/python/iree/compiler/tools/import_onnx/__main__.py index 2059c417b99b..ce708da8b7cb 100644 --- a/compiler/bindings/python/iree/compiler/tools/import_onnx/__main__.py +++ b/compiler/bindings/python/iree/compiler/tools/import_onnx/__main__.py @@ -14,224 +14,14 @@ python -m iree.compiler.tools.import_onnx ... """ + import argparse import os from pathlib import Path import sys import tempfile -import copy -import random -import iree.runtime as rt -import string - -from ...dialects import util -from typing import Optional, Tuple, Any - -try: - import onnx -except ModuleNotFoundError as e: - raise ModuleNotFoundError( - f"iree-import-onnx requires that the `onnx` Python package is installed " - f"(typically `{sys.executable} -m pip install onnx`)" - ) from e - -from onnx import numpy_helper - -try: - from ...extras import onnx_importer -except ModuleNotFoundError as e: - raise ModuleNotFoundError( - "iree-import-onnx is only available if IREE was built with Torch support" - ) from e - -from ...ir import ( - Context, - Type as IrType, - TypeAttr, - RankedTensorType, - StringAttr, - Attribute, - Operation, - Location, - InsertionPoint, - Value, - SymbolTable, - IntegerType, -) - - -class IREENodeImporter(onnx_importer.NodeImporter): - def __init__( - self, - graph_info: onnx_importer.GraphInfo, - *, - parent_op: Operation, - block: onnx_importer.Block, - context_cache: "onnx_importer.ContextCache", - module_op: Operation, - module_cache: "onnx_importer.ModuleCache", - numel_threshold: int, - ): - super().__init__( - graph_info, - parent_op=parent_op, - block=block, - context_cache=context_cache, - module_op=module_op, - module_cache=module_cache, - ) - self.last_global_op = None - self.symbol_table = SymbolTable(module_op) - self.symbol_table.insert(parent_op) - self.numel_threshold = numel_threshold - self.param_archive = rt.ParameterIndex() - - def sanitize_name(self, name: str) -> str: - # There are often some initializers in the models that have no name - # labels, or contain substrings like '::', which can cause conflicts, - # and invalid symbol names for symbolic references. This function will - # remove substrings like '::' when the name is not empty, and generate - # a random string when it is, as a placeholder. - new_name: str = "" - for c in range(len(name)): - if name[c] == ":": - new_name += "_" - else: - new_name += name[c] - - if len(new_name) == 0: - alpha = string.ascii_lowercase - ch = random.choice(alpha) - new_name = str(random.randrange(1, 1000)) + "__" + ch - return new_name - - def create_tensor_global( - self, - t: onnx.TensorProto, - ) -> Tuple[str, IrType]: - # Always create globals at the top. Then after created, if there was - # a prior one, move the new one to after it to maintain declaration - # order. - name = self.sanitize_name(t.name) - with InsertionPoint.at_block_begin( - self._m.regions[0].blocks[0] - ), Location.unknown(): - # After lowering to linalg-on-tensors, the data type needs to be signless. - # So, we construct the globals to have signless types, and use - # torch_c.from_builtin_tensor to convert to the correct frontend type. - vtensor_type = RankedTensorType.get( - tuple(t.dims), ELEM_TYPE_TO_SIGNLESS_IR_TYPE[t.data_type]() - ) - ir_attrs = { - "sym_name": StringAttr.get(name), - "sym_visibility": StringAttr.get("private"), - "type": TypeAttr.get(vtensor_type), - } - external_scope_attr = StringAttr.get("model") - external_name_attr = StringAttr.get(name) - ir_attrs["initial_value"] = Attribute.parse( - f"#stream.parameter.named<{external_scope_attr}::{external_name_attr}> : {vtensor_type}" - ) - global_op = util.GlobalOp( - ir_attrs["sym_name"], - ir_attrs["type"], - sym_visibility=ir_attrs["sym_visibility"], - initial_value=ir_attrs["initial_value"], - ) - self.symbol_table.insert(global_op) - if self.last_global_op is not None: - global_op.move_after(self.last_global_op) - self.last_global_op = global_op - actual_symbol_name = StringAttr(global_op.attributes["sym_name"]).value - return actual_symbol_name, vtensor_type - - @classmethod - def define_function( - cls, - graph_info: onnx_importer.GraphInfo, - module_op: Operation, - numel_threshold: int, - context_cache: Optional["onnx_importer.ContextCache"] = None, - module_cache: Optional["onnx_importer.ModuleCache"] = None, - private: bool = False, - ) -> "IREENodeImporter": - cc = ( - context_cache - if context_cache is not None - else onnx_importer.ContextCache(module_op.context) - ) - mc = ( - module_cache - if module_cache is not None - else onnx_importer.ModuleCache(module_op, cc) - ) - with module_op.context, Location.name(f"graph:{graph_info.graph_proto.name}"): - body = module_op.regions[0].blocks[0] - func_name = graph_info.graph_proto.name - input_types = [ - cc.type_proto_to_type(inp.type) for inp in graph_info.input_map.values() - ] - output_types = [ - cc.type_proto_to_type(out.type) - for out in graph_info.output_map.values() - ] - ftype = onnx_importer.FunctionType.get(input_types, output_types) - func_op = onnx_importer.func_dialect.FuncOp( - func_name, - ftype, - ip=InsertionPoint(body), - visibility="private" if private else None, - ) - block = func_op.add_entry_block( - [Location.name(k) for k in graph_info.input_map.keys()] - ) - imp = IREENodeImporter( - graph_info, - parent_op=func_op, - block=block, - context_cache=cc, - module_op=module_op, - module_cache=mc, - numel_threshold=numel_threshold, - ) - for node_name, input_value in zip(graph_info.input_map.keys(), block.arguments): - imp._nv_map[node_name] = input_value - imp._populate_graph_attrs(func_op) - return imp - - def import_initializer( - self, initializer: onnx.TensorProto, extern_name: Optional[str] = None - ) -> Value: - # If an explicitly specified name is given, use that; otherwise, pick - # up the name from the tensor proto itself - iname = extern_name if extern_name else initializer.name - dims = list(initializer.dims) - numel = 1 - for d in dims: - numel = numel * d - if numel < self.numel_threshold: - imported_tensor = super().import_initializer(initializer) - self._nv_map[iname] = imported_tensor - return imported_tensor - - x, t = self.create_tensor_global(initializer) - vtensor_type = self._cc.get_vtensor_type( - tuple(initializer.dims), self._cc.tensor_element_type(initializer.data_type) - ) - - with InsertionPoint(self._b), Location.name(iname): - old_op = util.GlobalLoadOp(t, x) - converted_value = Operation.create( - "torch_c.from_builtin_tensor", - results=[vtensor_type], - operands=[old_op.result], - ).result - - self._nv_map[iname] = converted_value - tensor_as_array = numpy_helper.to_array(initializer) - self.param_archive.add_buffer(x, tensor_as_array) - return converted_value +from .importer_externalization_overrides import * def main(args: argparse.Namespace): @@ -326,40 +116,6 @@ def load_onnx_model(args: argparse.Namespace) -> onnx.ModelProto: return inferred_model -ELEM_TYPE_TO_SIGNLESS_IR_TYPE = copy.deepcopy(onnx_importer.ELEM_TYPE_TO_IR_TYPE_CB) - -ELEM_TYPE_TO_SIGNLESS_IR_TYPE[ - onnx.TensorProto.DataType.INT64 -] = lambda: IntegerType.get_signless(64) -ELEM_TYPE_TO_SIGNLESS_IR_TYPE[ - onnx.TensorProto.DataType.INT32 -] = lambda: IntegerType.get_signless(32) -ELEM_TYPE_TO_SIGNLESS_IR_TYPE[ - onnx.TensorProto.DataType.INT16 -] = lambda: IntegerType.get_signless(16) -ELEM_TYPE_TO_SIGNLESS_IR_TYPE[ - onnx.TensorProto.DataType.INT8 -] = lambda: IntegerType.get_signless(8) -ELEM_TYPE_TO_SIGNLESS_IR_TYPE[ - onnx.TensorProto.DataType.INT4 -] = lambda: IntegerType.get_signless(4) -ELEM_TYPE_TO_SIGNLESS_IR_TYPE[ - onnx.TensorProto.DataType.UINT8 -] = lambda: IntegerType.get_signless(8) -ELEM_TYPE_TO_SIGNLESS_IR_TYPE[ - onnx.TensorProto.DataType.UINT4 -] = lambda: IntegerType.get_signless(4) -ELEM_TYPE_TO_SIGNLESS_IR_TYPE[ - onnx.TensorProto.DataType.UINT16 -] = lambda: IntegerType.get_signless(16) -ELEM_TYPE_TO_SIGNLESS_IR_TYPE[ - onnx.TensorProto.DataType.UINT64 -] = lambda: IntegerType.get_signless(64) -ELEM_TYPE_TO_SIGNLESS_IR_TYPE[ - onnx.TensorProto.DataType.UINT32 -] = lambda: IntegerType.get_signless(32) - - def parse_arguments(argv=None) -> argparse.Namespace: parser = argparse.ArgumentParser(description="IREE ONNX import tool") parser.add_argument("input_file", help="ONNX protobuf input", type=Path) diff --git a/compiler/bindings/python/iree/compiler/tools/import_onnx/importer_externalization_overrides.py b/compiler/bindings/python/iree/compiler/tools/import_onnx/importer_externalization_overrides.py new file mode 100644 index 000000000000..7f6e7ea08e92 --- /dev/null +++ b/compiler/bindings/python/iree/compiler/tools/import_onnx/importer_externalization_overrides.py @@ -0,0 +1,247 @@ +import copy +import random +import string +import iree.runtime as rt + +from ...dialects import util +from typing import Optional, Tuple, Any + +try: + import onnx +except ModuleNotFoundError as e: + raise ModuleNotFoundError( + f"iree-import-onnx requires that the `onnx` Python package is installed " + f"(typically `{sys.executable} -m pip install onnx`)" + ) from e + +try: + from ...extras import onnx_importer +except ModuleNotFoundError as e: + raise ModuleNotFoundError( + "iree-import-onnx is only available if IREE was built with Torch support" + ) from e + +from onnx import numpy_helper + +from ...ir import ( + Context, + Type as IrType, + TypeAttr, + RankedTensorType, + StringAttr, + Attribute, + Operation, + Location, + InsertionPoint, + Value, + SymbolTable, + IntegerType, +) + + +class IREENodeImporter(onnx_importer.NodeImporter): + def __init__( + self, + graph_info: onnx_importer.GraphInfo, + *, + parent_op: Operation, + block: onnx_importer.Block, + context_cache: "onnx_importer.ContextCache", + module_op: Operation, + module_cache: "onnx_importer.ModuleCache", + numel_threshold: int, + ): + super().__init__( + graph_info, + parent_op=parent_op, + block=block, + context_cache=context_cache, + module_op=module_op, + module_cache=module_cache, + ) + self.last_global_op = None + self.symbol_table = SymbolTable(module_op) + self.symbol_table.insert(parent_op) + self.numel_threshold = numel_threshold + self.param_archive = rt.ParameterIndex() + + def sanitize_name(self, name: str) -> str: + # There are often some initializers in the models that have no name + # labels, or contain substrings like '::', which can cause conflicts, + # and invalid symbol names for symbolic references. This function will + # remove substrings like '::' when the name is not empty, and generate + # a random string when it is, as a placeholder. + new_name: str = "" + for c in range(len(name)): + if name[c] == ":": + new_name += "_" + else: + new_name += name[c] + + if len(new_name) == 0: + alpha = string.ascii_lowercase + ch = random.choice(alpha) + new_name = str(random.randrange(1, 1000)) + "__" + ch + return new_name + + def create_tensor_global( + self, + t: onnx.TensorProto, + ) -> Tuple[str, IrType]: + # Always create globals at the top. Then after created, if there was + # a prior one, move the new one to after it to maintain declaration + # order. + name = self.sanitize_name(t.name) + with InsertionPoint.at_block_begin( + self._m.regions[0].blocks[0] + ), Location.unknown(): + # After lowering to linalg-on-tensors, the data type needs to be signless. + # So, we construct the globals to have signless types, and use + # torch_c.from_builtin_tensor to convert to the correct frontend type. + vtensor_type = RankedTensorType.get( + tuple(t.dims), ELEM_TYPE_TO_SIGNLESS_IR_TYPE[t.data_type]() + ) + ir_attrs = { + "sym_name": StringAttr.get(name), + "sym_visibility": StringAttr.get("private"), + "type": TypeAttr.get(vtensor_type), + } + + external_scope_attr = StringAttr.get("model") + external_name_attr = StringAttr.get(name) + ir_attrs["initial_value"] = Attribute.parse( + f"#stream.parameter.named<{external_scope_attr}::{external_name_attr}> : {vtensor_type}" + ) + global_op = util.GlobalOp( + ir_attrs["sym_name"], + ir_attrs["type"], + sym_visibility=ir_attrs["sym_visibility"], + initial_value=ir_attrs["initial_value"], + ) + self.symbol_table.insert(global_op) + if self.last_global_op is not None: + global_op.move_after(self.last_global_op) + self.last_global_op = global_op + actual_symbol_name = StringAttr(global_op.attributes["sym_name"]).value + return actual_symbol_name, vtensor_type + + @classmethod + def define_function( + cls, + graph_info: onnx_importer.GraphInfo, + module_op: Operation, + numel_threshold: int, + context_cache: Optional["onnx_importer.ContextCache"] = None, + module_cache: Optional["onnx_importer.ModuleCache"] = None, + private: bool = False, + ) -> "IREENodeImporter": + cc = ( + context_cache + if context_cache is not None + else onnx_importer.ContextCache(module_op.context) + ) + mc = ( + module_cache + if module_cache is not None + else onnx_importer.ModuleCache(module_op, cc) + ) + with module_op.context, Location.name(f"graph:{graph_info.graph_proto.name}"): + body = module_op.regions[0].blocks[0] + func_name = graph_info.graph_proto.name + input_types = [ + cc.type_proto_to_type(inp.type) for inp in graph_info.input_map.values() + ] + output_types = [ + cc.type_proto_to_type(out.type) + for out in graph_info.output_map.values() + ] + ftype = onnx_importer.FunctionType.get(input_types, output_types) + func_op = onnx_importer.func_dialect.FuncOp( + func_name, + ftype, + ip=InsertionPoint(body), + visibility="private" if private else None, + ) + block = func_op.add_entry_block( + [Location.name(k) for k in graph_info.input_map.keys()] + ) + imp = IREENodeImporter( + graph_info, + parent_op=func_op, + block=block, + context_cache=cc, + module_op=module_op, + module_cache=mc, + numel_threshold=numel_threshold, + ) + for node_name, input_value in zip(graph_info.input_map.keys(), block.arguments): + imp._nv_map[node_name] = input_value + imp._populate_graph_attrs(func_op) + return imp + + def import_initializer( + self, initializer: onnx.TensorProto, extern_name: Optional[str] = None + ) -> Value: + # If an explicitly specified name is given, use that; otherwise, pick + # up the name from the tensor proto itself + iname = extern_name if extern_name else initializer.name + dims = list(initializer.dims) + numel = 1 + for d in dims: + numel = numel * d + if numel < self.numel_threshold: + imported_tensor = super().import_initializer(initializer) + self._nv_map[iname] = imported_tensor + return imported_tensor + + x, t = self.create_tensor_global(initializer) + vtensor_type = self._cc.get_vtensor_type( + tuple(initializer.dims), self._cc.tensor_element_type(initializer.data_type) + ) + + with InsertionPoint(self._b), Location.name(iname): + old_op = util.GlobalLoadOp(t, x) + converted_value = Operation.create( + "torch_c.from_builtin_tensor", + results=[vtensor_type], + operands=[old_op.result], + ).result + + self._nv_map[iname] = converted_value + tensor_as_array = numpy_helper.to_array(initializer) + self.param_archive.add_buffer(x, tensor_as_array) + return converted_value + + +ELEM_TYPE_TO_SIGNLESS_IR_TYPE = copy.deepcopy(onnx_importer.ELEM_TYPE_TO_IR_TYPE_CB) + +ELEM_TYPE_TO_SIGNLESS_IR_TYPE[ + onnx.TensorProto.DataType.INT64 +] = lambda: IntegerType.get_signless(64) +ELEM_TYPE_TO_SIGNLESS_IR_TYPE[ + onnx.TensorProto.DataType.INT32 +] = lambda: IntegerType.get_signless(32) +ELEM_TYPE_TO_SIGNLESS_IR_TYPE[ + onnx.TensorProto.DataType.INT16 +] = lambda: IntegerType.get_signless(16) +ELEM_TYPE_TO_SIGNLESS_IR_TYPE[ + onnx.TensorProto.DataType.INT8 +] = lambda: IntegerType.get_signless(8) +ELEM_TYPE_TO_SIGNLESS_IR_TYPE[ + onnx.TensorProto.DataType.INT4 +] = lambda: IntegerType.get_signless(4) +ELEM_TYPE_TO_SIGNLESS_IR_TYPE[ + onnx.TensorProto.DataType.UINT8 +] = lambda: IntegerType.get_signless(8) +ELEM_TYPE_TO_SIGNLESS_IR_TYPE[ + onnx.TensorProto.DataType.UINT4 +] = lambda: IntegerType.get_signless(4) +ELEM_TYPE_TO_SIGNLESS_IR_TYPE[ + onnx.TensorProto.DataType.UINT16 +] = lambda: IntegerType.get_signless(16) +ELEM_TYPE_TO_SIGNLESS_IR_TYPE[ + onnx.TensorProto.DataType.UINT64 +] = lambda: IntegerType.get_signless(64) +ELEM_TYPE_TO_SIGNLESS_IR_TYPE[ + onnx.TensorProto.DataType.UINT32 +] = lambda: IntegerType.get_signless(32) From 40cbcb755cf58ff52c4046d705da33b93b4cc093 Mon Sep 17 00:00:00 2001 From: Vinayak Dev Date: Mon, 4 Nov 2024 18:31:43 +0530 Subject: [PATCH 6/9] Address comments --- .../iree/compiler/tools/import_onnx/__main__.py | 16 ++++++++++++---- .../importer_externalization_overrides.py | 6 +++++- 2 files changed, 17 insertions(+), 5 deletions(-) diff --git a/compiler/bindings/python/iree/compiler/tools/import_onnx/__main__.py b/compiler/bindings/python/iree/compiler/tools/import_onnx/__main__.py index ce708da8b7cb..fc108d8a65b2 100644 --- a/compiler/bindings/python/iree/compiler/tools/import_onnx/__main__.py +++ b/compiler/bindings/python/iree/compiler/tools/import_onnx/__main__.py @@ -33,7 +33,7 @@ def main(args: argparse.Namespace): imp: Any = None if args.externalize_params: imp = IREENodeImporter.define_function( - model_info.main_graph, m, args.numel_threshold + model_info.main_graph, m, args.num_elements_threshold, args.params_scope ) else: imp = onnx_importer.NodeImporter.define_function(model_info.main_graph, m) @@ -47,7 +47,7 @@ def main(args: argparse.Namespace): param_path = ( (str(default_param_path) + "_params.irpa") if args.save_params_to is None - else args.save_params_to + else str(args.save_params_to) ) imp.param_archive.create_archive_file(param_path) @@ -146,7 +146,7 @@ def parse_arguments(argv=None) -> argparse.Namespace: type=int, ) parser.add_argument( - "--numel-threshold", + "--num-elements-threshold", help="Minimum number of elements for an initializer to be externalized. Only has an effect if 'externalize-params' is true.", type=int, default=100, @@ -159,8 +159,16 @@ def parse_arguments(argv=None) -> argparse.Namespace: ) parser.add_argument( "--save-params-to", - help="Location to save the externalized parameters. When not set, the parameters will be written to '_params.irpa'.", + help="Location to save the externalized parameters. When not set, the parameters will be written to '_params.irpa'" + " under the namespace 'model', which can be configured by passing the namespace string to 'params-scope'.", default=None, + type=Path, + ) + parser.add_argument( + "--params-scope", + help="The namespace or the scope in which the externalized parameters are placed. Default is 'model'.", + type=str, + default="model", ) args = parser.parse_args(argv) return args diff --git a/compiler/bindings/python/iree/compiler/tools/import_onnx/importer_externalization_overrides.py b/compiler/bindings/python/iree/compiler/tools/import_onnx/importer_externalization_overrides.py index 7f6e7ea08e92..44ade2dca98d 100644 --- a/compiler/bindings/python/iree/compiler/tools/import_onnx/importer_externalization_overrides.py +++ b/compiler/bindings/python/iree/compiler/tools/import_onnx/importer_externalization_overrides.py @@ -50,6 +50,7 @@ def __init__( module_op: Operation, module_cache: "onnx_importer.ModuleCache", numel_threshold: int, + params_scope: str, ): super().__init__( graph_info, @@ -64,6 +65,7 @@ def __init__( self.symbol_table.insert(parent_op) self.numel_threshold = numel_threshold self.param_archive = rt.ParameterIndex() + self.params_scope = params_scope def sanitize_name(self, name: str) -> str: # There are often some initializers in the models that have no name @@ -107,7 +109,7 @@ def create_tensor_global( "type": TypeAttr.get(vtensor_type), } - external_scope_attr = StringAttr.get("model") + external_scope_attr = StringAttr.get(self.params_scope) external_name_attr = StringAttr.get(name) ir_attrs["initial_value"] = Attribute.parse( f"#stream.parameter.named<{external_scope_attr}::{external_name_attr}> : {vtensor_type}" @@ -131,6 +133,7 @@ def define_function( graph_info: onnx_importer.GraphInfo, module_op: Operation, numel_threshold: int, + params_scope: str, context_cache: Optional["onnx_importer.ContextCache"] = None, module_cache: Optional["onnx_importer.ModuleCache"] = None, private: bool = False, @@ -173,6 +176,7 @@ def define_function( module_op=module_op, module_cache=mc, numel_threshold=numel_threshold, + params_scope=params_scope, ) for node_name, input_value in zip(graph_info.input_map.keys(), block.arguments): imp._nv_map[node_name] = input_value From 8f3274653ea4947c2b97705a0c54ffe1f787b577 Mon Sep 17 00:00:00 2001 From: Vinayak Dev Date: Mon, 11 Nov 2024 11:23:58 +0530 Subject: [PATCH 7/9] Address comments-2 --- .../importer_externalization_overrides.py | 40 ++++++++++++------- 1 file changed, 26 insertions(+), 14 deletions(-) diff --git a/compiler/bindings/python/iree/compiler/tools/import_onnx/importer_externalization_overrides.py b/compiler/bindings/python/iree/compiler/tools/import_onnx/importer_externalization_overrides.py index 44ade2dca98d..d99b26974354 100644 --- a/compiler/bindings/python/iree/compiler/tools/import_onnx/importer_externalization_overrides.py +++ b/compiler/bindings/python/iree/compiler/tools/import_onnx/importer_externalization_overrides.py @@ -1,3 +1,9 @@ +# Copyright 2023 The IREE Authors +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + import copy import random import string @@ -49,7 +55,7 @@ def __init__( context_cache: "onnx_importer.ContextCache", module_op: Operation, module_cache: "onnx_importer.ModuleCache", - numel_threshold: int, + num_elements_threshold: int, params_scope: str, ): super().__init__( @@ -63,7 +69,7 @@ def __init__( self.last_global_op = None self.symbol_table = SymbolTable(module_op) self.symbol_table.insert(parent_op) - self.numel_threshold = numel_threshold + self.num_elements_threshold = num_elements_threshold self.param_archive = rt.ParameterIndex() self.params_scope = params_scope @@ -132,17 +138,23 @@ def define_function( cls, graph_info: onnx_importer.GraphInfo, module_op: Operation, - numel_threshold: int, + num_elements_threshold: int, params_scope: str, context_cache: Optional["onnx_importer.ContextCache"] = None, module_cache: Optional["onnx_importer.ModuleCache"] = None, private: bool = False, ) -> "IREENodeImporter": + # Recover per-context caches of various attributes. + # Allows modifications in the same context without + # loss of current state. cc = ( context_cache if context_cache is not None else onnx_importer.ContextCache(module_op.context) ) + # Recover per-module caches of various attributes. + # Allows modification in the same module_op without + # loss of current state. mc = ( module_cache if module_cache is not None @@ -175,7 +187,7 @@ def define_function( context_cache=cc, module_op=module_op, module_cache=mc, - numel_threshold=numel_threshold, + num_elements_threshold=num_elements_threshold, params_scope=params_scope, ) for node_name, input_value in zip(graph_info.input_map.keys(), block.arguments): @@ -188,32 +200,32 @@ def import_initializer( ) -> Value: # If an explicitly specified name is given, use that; otherwise, pick # up the name from the tensor proto itself - iname = extern_name if extern_name else initializer.name + initializer_name = extern_name if extern_name else initializer.name dims = list(initializer.dims) - numel = 1 + num_elements = 1 for d in dims: - numel = numel * d - if numel < self.numel_threshold: + num_elements = num_elements * d + if num_elements < self.num_elements_threshold: imported_tensor = super().import_initializer(initializer) - self._nv_map[iname] = imported_tensor + self._nv_map[initializer_name] = imported_tensor return imported_tensor - x, t = self.create_tensor_global(initializer) + actual_symbol_name, tensor_type = self.create_tensor_global(initializer) vtensor_type = self._cc.get_vtensor_type( tuple(initializer.dims), self._cc.tensor_element_type(initializer.data_type) ) - with InsertionPoint(self._b), Location.name(iname): - old_op = util.GlobalLoadOp(t, x) + with InsertionPoint(self._b), Location.name(initializer_name): + old_op = util.GlobalLoadOp(tensor_type, actual_symbol_name) converted_value = Operation.create( "torch_c.from_builtin_tensor", results=[vtensor_type], operands=[old_op.result], ).result - self._nv_map[iname] = converted_value + self._nv_map[initializer_name] = converted_value tensor_as_array = numpy_helper.to_array(initializer) - self.param_archive.add_buffer(x, tensor_as_array) + self.param_archive.add_buffer(actual_symbol_name, tensor_as_array) return converted_value From 0836529c08ac8e948f63ab0cbcfc1287363873fb Mon Sep 17 00:00:00 2001 From: Vinayak Dev Date: Wed, 13 Nov 2024 01:28:27 +0530 Subject: [PATCH 8/9] Add tests for externalization --- .../python/test/tools/import_onnx_test.py | 78 ++++++++++++++++++ .../python/test/tools/testdata/conv.onnx | Bin 0 -> 4935 bytes 2 files changed, 78 insertions(+) create mode 100644 compiler/bindings/python/test/tools/testdata/conv.onnx diff --git a/compiler/bindings/python/test/tools/import_onnx_test.py b/compiler/bindings/python/test/tools/import_onnx_test.py index 73eb56ff899e..5c797e9be281 100644 --- a/compiler/bindings/python/test/tools/import_onnx_test.py +++ b/compiler/bindings/python/test/tools/import_onnx_test.py @@ -22,6 +22,84 @@ def run_tool(*argv: str): ONNX_FILE_PATH = os.path.join(os.path.dirname(__file__), "testdata", "LeakyReLU.onnx") +LARGE_WEIGHTS_ONNX_FILE_PATH = os.path.join( + os.path.dirname(__file__), "testdata", "conv.onnx" +) + + +class ImportOnnxwithExternalizationTest(unittest.TestCase): + def setUp(self): + with tempfile.NamedTemporaryFile(delete=False) as f: + self.outputPath = f.name + + def tearDown(self) -> None: + if os.path.exists(self.outputPath): + os.unlink(self.outputPath) + if os.path.exists("custom_params_file.irpa"): + os.unlink("custom_params_file.irpa") + if os.path.exists(str(self.outputPath) + "_params.irpa"): + os.unlink(str(self.outputPath) + "_params.irpa") + + def testExternalizeWeightsDefaultThreshold(self): + run_tool( + LARGE_WEIGHTS_ONNX_FILE_PATH, "--externalize-params", "-o", self.outputPath + ) + with open(self.outputPath, "rt") as f: + contents = f.read() + self.assertIn("util.global", contents) + self.assertIn("util.global.load", contents) + # The bias is smaller in volume than the default 100 elements, + # so it should still be inlined. + self.assertIn("onnx.Constant", contents) + assert os.path.isfile(str(self.outputPath) + "_params.irpa") + + def testExternalizeParamsSaveCustomPath(self): + run_tool( + LARGE_WEIGHTS_ONNX_FILE_PATH, + "--externalize-params", + "--save-params-to", + "custom_params_file.irpa", + "-o", + self.outputPath, + ) + with open(self.outputPath, "rt") as f: + contents = f.read() + self.assertIn("util.global", contents) + self.assertIn("util.global.load", contents) + assert os.path.isfile("custom_params_file.irpa") + + def testExternalizeTooHighThreshold(self): + num_elements_weights = 1 * 256 * 100 * 100 + 1 + run_tool( + LARGE_WEIGHTS_ONNX_FILE_PATH, + "--externalize-params", + "--num-elements-threshold", + str(num_elements_weights), + "-o", + self.outputPath, + ) + with open(self.outputPath, "rt") as f: + contents = f.read() + self.assertNotIn("util.global", contents) + self.assertNotIn("util.global.load", contents) + self.assertIn("onnx.Constant", contents) + + def testExternalizeMinimumThreshold(self): + run_tool( + LARGE_WEIGHTS_ONNX_FILE_PATH, + "--externalize-params", + "--num-elements-threshold", + "0", + "-o", + self.outputPath, + ) + with open(self.outputPath, "rt") as f: + contents = f.read() + self.assertIn("util.global", contents) + self.assertIn("util.global.load", contents) + # As the max allowed number of elements for inlining is 0 elements, + # there should be no inlined constants. + self.assertNotIn("onnx.Constant", contents) class ImportOnnxTest(unittest.TestCase): diff --git a/compiler/bindings/python/test/tools/testdata/conv.onnx b/compiler/bindings/python/test/tools/testdata/conv.onnx new file mode 100644 index 0000000000000000000000000000000000000000..9cbebb85c5a509213e6d605a6b9446a831404ee5 GIT binary patch literal 4935 zcmZvgv5wwE6otpDWb^A+ieW_%5{(r_6kEBbVY0k{yAXtggR+7MNUSVSPl>2pQu;}d zAH|Qt_j&I~TOf3F&YXMBo$D{^a~^B^p88x55K?dYB}GyfSr8XRnaN> z#}0HQv%xd$UB`Rl?2dx=9qvtnv+uCue0^|suD9cDY}=mL9Bh287i zp6B+y9A{_lwxcJ^iUpM58Lq&4-IIXc2zIo2cN{Zx`#n9IOR(Jq_y%{`&A}`4gf?r} zOu=(**Zc+x&t{!XU;!0xXc``YS$7p!^yv2ep3T@-yTiUPHv_#4%=q@^xg&*l{2u6) z9Tk|du!nd2Hh2Zhz_&Q7G1tq$p1Y0=vmU*| z_6FeGx9gkR9?S~by*1{Z7qq?I8CZk2n3*l;9b~k5Gjf96=Uo-Oh41Y6{h0Y(?LGEV zoJ(YPJuu=sKi(JYo_2PR*#R!VP6c{X%fa4dM!z(0_Kofn891*%-#gr4Ps{=P30=eA zB5&6no>x%7!dXwQoR4>-lXLj^`^Il&e#GlR>)G+_4)gZhvBPFsZbHvs#WySLnzJ{6 zg7$YXr*5;h-yJKy(6iIdzO&%B%*eU*e2??QT)n~0AAakO=iT-fv%r4Ech+{Pz1Oq3 zY=)h30TXzbb>|qfMEB6}20WWJlY*-n-+jJ$0BU&z6WEy-cm{Tb^B(B8eu+H`-io$s z?)aN>cF)M!_a8lU2KEm>xMl1wz`gFXQ{e%2HiMmWAKILHgDWhY%ld2cgA-oc;hPe^ zZ}t1^=&8+2z$}r;INR~{xn_Yl8*}l35=6zoV zDRgD--Xh}-!0aBLJLbTOC(I?dz@r*$oGYx?1AAs$*ZgiK{x-5_`@Y-U@i)Pn&~2VV z>$Sa`dpO&*)B4W!oVCBd_YRy3Y%il{u%h+6C&A7Z`~=LnYYHu!ub|BfTHn5$Vy17# z`$Z21u;IJ6dFDLZFEVGf`8~85eb4Hd*~8yxwhK zcc_XYn3 hcJTtA89rwC_~s%%`wxEG5Gw!x literal 0 HcmV?d00001 From c58603405ab101d707626306d9e4e5e646b85036 Mon Sep 17 00:00:00 2001 From: Vinayak Dev Date: Fri, 15 Nov 2024 00:48:54 +0530 Subject: [PATCH 9/9] Move externalization tests in test file --- .../python/test/tools/import_onnx_test.py | 48 +++++++++---------- 1 file changed, 24 insertions(+), 24 deletions(-) diff --git a/compiler/bindings/python/test/tools/import_onnx_test.py b/compiler/bindings/python/test/tools/import_onnx_test.py index 5c797e9be281..6089b65ef057 100644 --- a/compiler/bindings/python/test/tools/import_onnx_test.py +++ b/compiler/bindings/python/test/tools/import_onnx_test.py @@ -27,6 +27,30 @@ def run_tool(*argv: str): ) +class ImportOnnxTest(unittest.TestCase): + def setUp(self): + with tempfile.NamedTemporaryFile(delete=False) as f: + self.outputPath = f.name + + def tearDown(self) -> None: + if os.path.exists(self.outputPath): + os.unlink(self.outputPath) + + def testConsoleOutput(self): + # Just test that it doesn't crash: rely on the file test for verification. + run_tool(ONNX_FILE_PATH) + + def testDisableVerify(self): + # Just test that the flag is accepted. + run_tool(ONNX_FILE_PATH, "--no-verify") + + def testFileOutput(self): + run_tool(ONNX_FILE_PATH, "-o", self.outputPath) + with open(self.outputPath, "rt") as f: + contents = f.read() + self.assertIn("torch.operator", contents) + + class ImportOnnxwithExternalizationTest(unittest.TestCase): def setUp(self): with tempfile.NamedTemporaryFile(delete=False) as f: @@ -102,30 +126,6 @@ def testExternalizeMinimumThreshold(self): self.assertNotIn("onnx.Constant", contents) -class ImportOnnxTest(unittest.TestCase): - def setUp(self): - with tempfile.NamedTemporaryFile(delete=False) as f: - self.outputPath = f.name - - def tearDown(self) -> None: - if os.path.exists(self.outputPath): - os.unlink(self.outputPath) - - def testConsoleOutput(self): - # Just test that it doesn't crash: rely on the file test for verification. - run_tool(ONNX_FILE_PATH) - - def testDisableVerify(self): - # Just test that the flag is accepted. - run_tool(ONNX_FILE_PATH, "--no-verify") - - def testFileOutput(self): - run_tool(ONNX_FILE_PATH, "-o", self.outputPath) - with open(self.outputPath, "rt") as f: - contents = f.read() - self.assertIn("torch.operator", contents) - - if __name__ == "__main__": try: import onnx