diff --git a/python/torch_mlir/extras/fx_importer.py b/python/torch_mlir/extras/fx_importer.py
index 91f3c27ee263..e6d0f03deda4 100644
--- a/python/torch_mlir/extras/fx_importer.py
+++ b/python/torch_mlir/extras/fx_importer.py
@@ -623,7 +623,7 @@ def import_program(
             node_importer.return_node_values(loc, user_outputs)
         self.symbol_table.insert(func_op)
 
-    def import_frozen_program(self, prog: torch.export.ExportedProgram):
+    def import_frozen_program(self, prog: torch.export.ExportedProgram, func_name: str = "main"):
         """Imports a consolidated torch.export.ExportedProgram instance.
 
         If using the new torch.export path (vs a lower level precursor), then this is
@@ -702,7 +702,7 @@ def import_frozen_program(self, prog: torch.export.ExportedProgram):
                 node.replace_all_uses_with(replacement)
                 g.erase_node(node)
 
-        self.import_stateless_graph(g)
+        self.import_stateless_graph(g, func_name)
 
     def import_graph_module(self, gm: GraphModule):
         """Low-level import of a GraphModule assuming that it has been functionalized.
diff --git a/python/torch_mlir/fx.py b/python/torch_mlir/fx.py
index 1f5aa8f74add..76cd91f82e0a 100644
--- a/python/torch_mlir/fx.py
+++ b/python/torch_mlir/fx.py
@@ -23,6 +23,7 @@ def export_and_import(
     constraints: Optional[torch.export.Constraint] = None,
     experimental_support_mutation: bool = False,
     hooks: Optional[FxImporterHooks] = None,
+    func_name: str = "main",
     **kwargs,
 ):
     context = ir.Context()
@@ -36,8 +37,8 @@ def export_and_import(
     if experimental_support_mutation:
         if torch.__version__ < "2.3.0.dev20240207":
             warnings.warn("Mutable program import only supported on PyTorch 2.3+")
-        fx_importer.import_program(prog)
+        fx_importer.import_program(prog, func_name=func_name)
     else:
-        fx_importer.import_frozen_program(prog)
+        fx_importer.import_frozen_program(prog, func_name=func_name)
     return fx_importer.module_op
 
diff --git a/test/python/fx_importer/basic_test.py b/test/python/fx_importer/basic_test.py
index 36c554862506..fc5b2030b648 100644
--- a/test/python/fx_importer/basic_test.py
+++ b/test/python/fx_importer/basic_test.py
@@ -56,3 +56,24 @@ def forward(self, x):
 
     m = fx.export_and_import(Basic(), torch.randn(3, 4))
     print(m)
+
+
+@run
+# CHECK-LABEL: test_import_frozen_exported_program_with_func_name
+# CHECK: func.func @test_net(%[[ARG0:[a-zA-Z0-9]+]]: !torch.vtensor<[3,4],f32>) -> !torch.vtensor<[3,4],f32>
+def test_import_frozen_exported_program_with_func_name():
+    @torch._dynamo.assume_constant_result
+    def get_a():
+        return torch.randn(1, 4)
+
+    class Basic(nn.Module):
+        def __init__(self):
+            super().__init__()
+            self.b = torch.randn(3, 1)
+            self.p = nn.Parameter(torch.randn(1, 1))
+
+        def forward(self, x):
+            return torch.tanh(x) * get_a() * self.b * self.p
+
+    m = fx.export_and_import(Basic(), torch.randn(3, 4), func_name="test_net")
+    print(m)