Skip to content

Commit

Permalink
Bring back dynamic_shapes constraints in fx importer API (llvm#3026)
Browse files Browse the repository at this point in the history
llvm#2992 dropped `constraints` from
the fx importer API,
[breaking](https://github.com/cruise-automation/mlir-tcp/actions/runs/8284385380/job/22669774071)
downstream AOT compile tests in `mlir-tcp` that use it. This knob has
been soft-deprecated for a while now, replaced by `dynamic_shapes` - a
more ergonomic interface. This PR brings back dynamic_shapes constraints
in the new supported form. Also added a Python lit test with dynamic-shape
annotations.
  • Loading branch information
sjain-stanford authored Mar 14, 2024
1 parent 29ac23a commit 0b2f9c8
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 3 deletions.
5 changes: 3 additions & 2 deletions python/torch_mlir/fx.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
# Also available under a BSD-style license. See LICENSE.

from typing import Optional
from typing import Optional, Union, Dict, Tuple, Any

import warnings

Expand All @@ -20,6 +20,7 @@ def export_and_import(
f,
*args,
fx_importer: Optional[FxImporter] = None,
dynamic_shapes: Optional[Union[Dict[str, Any], Tuple[Any]]] = None,
experimental_support_mutation: bool = False,
hooks: Optional[FxImporterHooks] = None,
func_name: str = "main",
Expand All @@ -30,7 +31,7 @@ def export_and_import(

if fx_importer is None:
fx_importer = FxImporter(context=context, hooks=hooks)
prog = torch.export.export(f, args, kwargs)
prog = torch.export.export(f, args, kwargs, dynamic_shapes=dynamic_shapes)
decomp_table = get_decomposition_table()
prog = prog.run_decompositions(decomp_table)
if experimental_support_mutation:
Expand Down
18 changes: 17 additions & 1 deletion test/python/fx_importer/basic_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,8 @@
from typing import Optional

import torch
import torch.export
import torch.nn as nn
from torch.export import Dim

from torch_mlir import fx

Expand Down Expand Up @@ -77,3 +77,19 @@ def forward(self, x):

m = fx.export_and_import(Basic(), torch.randn(3, 4), func_name="test_net")
print(m)

@run
# CHECK-LABEL: test_import_frozen_exported_program_with_dynamic_shapes
# CHECK: func.func @test_net(%[[ARG0:[a-zA-Z0-9]+]]: !torch.vtensor<[?,4],f32>) -> !torch.vtensor<[?,4],f32>
def test_import_frozen_exported_program_with_dynamic_shapes():
    """Lit test: export a module with a dynamic batch dim and import it to MLIR.

    The `# CHECK` comments above are FileCheck directives (consumed by the lit
    runner, not by Python) — they verify that the imported `func.func` carries
    a dynamic leading dimension (`?`) in its tensor types, proving the
    `dynamic_shapes` constraint was threaded through `fx.export_and_import`.
    """

    class Basic(nn.Module):
        def __init__(self):
            super().__init__()

        def forward(self, x):
            # Elementwise op, so the output type mirrors the (dynamic) input type.
            return torch.tanh(x)

    # Mark dim 0 of the input named "x" as dynamic using torch.export's Dim API;
    # the dict key must match the forward() parameter name.
    batch = Dim("batch")
    dynamic_shapes = {"x": {0: batch}}
    m = fx.export_and_import(Basic(), torch.randn(3, 4), dynamic_shapes=dynamic_shapes, func_name="test_net")
    print(m)

0 comments on commit 0b2f9c8

Please sign in to comment.