Skip to content

Commit

Permalink
Expose func_name to the main fx import API (llvm#2949)
Browse files Browse the repository at this point in the history
As titled.
  • Loading branch information
sjain-stanford authored Feb 26, 2024
1 parent c5a1da1 commit 3cbe6c9
Show file tree
Hide file tree
Showing 3 changed files with 26 additions and 4 deletions.
4 changes: 2 additions & 2 deletions python/torch_mlir/extras/fx_importer.py
Original file line number Diff line number Diff line change
Expand Up @@ -623,7 +623,7 @@ def import_program(
node_importer.return_node_values(loc, user_outputs)
self.symbol_table.insert(func_op)

def import_frozen_program(self, prog: torch.export.ExportedProgram):
def import_frozen_program(self, prog: torch.export.ExportedProgram, func_name: str = "main"):
"""Imports a consolidated torch.export.ExportedProgram instance.
If using the new torch.export path (vs a lower level precursor), then this is
Expand Down Expand Up @@ -702,7 +702,7 @@ def import_frozen_program(self, prog: torch.export.ExportedProgram):
node.replace_all_uses_with(replacement)
g.erase_node(node)

self.import_stateless_graph(g)
self.import_stateless_graph(g, func_name)

def import_graph_module(self, gm: GraphModule):
"""Low-level import of a GraphModule assuming that it has been functionalized.
Expand Down
5 changes: 3 additions & 2 deletions python/torch_mlir/fx.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ def export_and_import(
constraints: Optional[torch.export.Constraint] = None,
experimental_support_mutation: bool = False,
hooks: Optional[FxImporterHooks] = None,
func_name: str = "main",
**kwargs,
):
context = ir.Context()
Expand All @@ -36,8 +37,8 @@ def export_and_import(
if experimental_support_mutation:
if torch.__version__ < "2.3.0.dev20240207":
warnings.warn("Mutable program import only supported on PyTorch 2.3+")
fx_importer.import_program(prog)
fx_importer.import_program(prog, func_name=func_name)
else:
fx_importer.import_frozen_program(prog)
fx_importer.import_frozen_program(prog, func_name=func_name)

return fx_importer.module_op
21 changes: 21 additions & 0 deletions test/python/fx_importer/basic_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,3 +56,24 @@ def forward(self, x):

m = fx.export_and_import(Basic(), torch.randn(3, 4))
print(m)


@run
# CHECK-LABEL: test_import_frozen_exported_program_with_func_name
# CHECK: func.func @test_net(%[[ARG0:[a-zA-Z0-9]+]]: !torch.vtensor<[3,4],f32>) -> !torch.vtensor<[3,4],f32>
def test_import_frozen_exported_program_with_func_name():
@torch._dynamo.assume_constant_result
def get_a():
return torch.randn(1, 4)

class Basic(nn.Module):
def __init__(self):
super().__init__()
self.b = torch.randn(3, 1)
self.p = nn.Parameter(torch.randn(1, 1))

def forward(self, x):
return torch.tanh(x) * get_a() * self.b * self.p

m = fx.export_and_import(Basic(), torch.randn(3, 4), func_name="test_net")
print(m)

0 comments on commit 3cbe6c9

Please sign in to comment.