From 6b289f29f2815d90b1de39f0ca659db2a42c12c8 Mon Sep 17 00:00:00 2001 From: penguin_wwy <940375606@qq.com> Date: Wed, 16 Oct 2024 10:32:52 +0800 Subject: [PATCH] =?UTF-8?q?[FxImporter]=20Added=20FxImporter=20test=20meth?= =?UTF-8?q?od=20to=20be=20executed=20via=20torch.co=E2=80=A6=20(#3795)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../configs/fx_importer_backend.py | 80 ++++++++++++++++++- 1 file changed, 78 insertions(+), 2 deletions(-) diff --git a/projects/pt1/python/torch_mlir_e2e_test/configs/fx_importer_backend.py b/projects/pt1/python/torch_mlir_e2e_test/configs/fx_importer_backend.py index 11a6ef6ffd6f..91bc49ebb893 100644 --- a/projects/pt1/python/torch_mlir_e2e_test/configs/fx_importer_backend.py +++ b/projects/pt1/python/torch_mlir_e2e_test/configs/fx_importer_backend.py @@ -8,6 +8,7 @@ import torch.utils._pytree as pytree from torch.export.graph_signature import OutputSpec, OutputKind from torch.export import ExportedProgram +from torch._dynamo.backends.common import aot_autograd from torch_mlir import fx from torch_mlir_e2e_test.configs.utils import ( @@ -15,6 +16,7 @@ recursively_convert_from_numpy, ) from torch_mlir_e2e_test.framework import TestConfig, Trace, TraceItem +from torch_mlir_e2e_test.annotations import TORCH_MLIR_ARG_ANNOTATIONS_ATTR_NAME def refine_result_type(_result): @@ -31,9 +33,10 @@ def refine_result_type(_result): class FxImporterTestConfig(TestConfig): """TestConfig that runs the torch.nn.Module with Fx Importer""" - def __init__(self, backend, output_type="linalg-on-tensors"): + def __init__(self, backend, output_type="linalg-on-tensors", torch_compile=False): super().__init__() self._backend = backend + self._torch_compile = torch_compile self._output_type = output_type def compile( @@ -41,7 +44,80 @@ def compile( ) -> torch.nn.Module: return program - def run(self, artifact: torch.nn.Module, trace: Trace) -> Trace: + def run(self, artifact: torch.nn.Module, trace: Trace): + return ( + self._export_run(artifact, trace) + if not self._torch_compile + else self._stateless_run(artifact, trace) + ) + + def _stateless_run(self, artifact: torch.nn.Module, trace: Trace): + dynamic_argument_pos = None + dynamic_dim_pos = None + annotations = getattr(artifact.forward, TORCH_MLIR_ARG_ANNOTATIONS_ATTR_NAME) + for i, annotation in enumerate(annotations): + if i == 0: # Skip the "self" annotation. + continue + if not annotation[2]: + raise ValueError( + "Can only compile inputs annotated as having value semantics." + ) + for dim_i, dim in enumerate(annotation[0]): + if dim == -1: + dynamic_argument_pos = i - 1 + dynamic_dim_pos = dim_i + break + if dynamic_argument_pos is not None: + break + result: Trace = [] + for item in trace: + + def _base_backend(gm: torch.fx.GraphModule, example_inputs): + for node in gm.graph.nodes: + if node.op == "placeholder": + if ( + isinstance(node.meta["val"], torch.SymInt) + and not node.users + ): + gm.graph.erase_node(node) + module = fx.stateless_fx_import( + gm, + output_type=self._output_type, + model_name=artifact.__class__.__name__, + ) + module = self._backend.compile(module) + backend_module = self._backend.load(module) + + def invoke_func(*torch_inputs): + torch_inputs = [ + x + for x in filter( + lambda i: isinstance(i, torch.Tensor), torch_inputs + ) + ] + with torch.no_grad(): + numpy_inputs = recursively_convert_to_numpy(torch_inputs) + return recursively_convert_from_numpy( + getattr(backend_module, artifact.__class__.__name__)( + *numpy_inputs + ) + ) + + return invoke_func + + fw_compiler = aot_autograd(fw_compiler=_base_backend) + if dynamic_argument_pos is not None: + torch._dynamo.mark_dynamic( + item.inputs[dynamic_argument_pos], dynamic_dim_pos + ) + module = torch.compile(artifact, backend=fw_compiler) + outputs = module(*item.inputs) + result.append( + TraceItem(symbol=item.symbol, inputs=item.inputs, output=outputs) + ) + return result + + def _export_run(self, artifact: torch.nn.Module, trace: Trace) -> Trace: result: Trace = [] for item in trace: prog: ExportedProgram = torch.export.export(artifact, tuple(item.inputs))