fximporter: support newer torch versions (llvm#2999)
Uses version checking, since the relevant attributes exist in both torch versions; the only thing that changes is what we receive as an fx graph.
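
For context, here is a hedged repro sketch of what actually differs in the fx graph we receive: export a model with one dynamic dimension and print which sym_size target the graph carries. The module, dim name, and shapes are illustrative and not part of this change, and the sketch assumes the torch.export Dim/dynamic_shapes API available on torch >= 2.2.

import torch

class Flatten(torch.nn.Module):
    def forward(self, x):
        # Reading a dynamic dimension lowers to an aten sym_size node.
        return x.reshape(x.shape[0], -1)

batch = torch.export.Dim("batch")
prog = torch.export.export(
    Flatten(), (torch.randn(4, 3, 2),), dynamic_shapes={"x": {0: batch}}
)
for node in prog.graph.nodes:
    if "sym_size" in str(node.target):
        # torch <= 2.1 emits the bare packet aten.sym_size; newer torch
        # emits a typed overload such as aten.sym_size.int.
        print(node.target)

On torch <= 2.1.0 the same kind of trace carries the overload-less aten.sym_size packet instead, which is what the version gate in the diff below selects for.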
dan-garvey authored Mar 8, 2024
1 parent 6b3a7d0 commit 80c7bc3
Showing 1 changed file with 52 additions and 20 deletions.
python/torch_mlir/extras/fx_importer.py: 52 additions & 20 deletions (72 changes)
@@ -220,19 +220,47 @@
"gt": torch.ops.aten.gt,
}

SYMBOLIC_TORCH_OPS = {
torch.ops.aten.sym_size,
torch.ops.aten.sym_stride,
torch.ops.aten.sym_numel,
}

SYMBOLIC_OP_TO_TORCH_OP = {
(torch.ops.aten.sym_size, 1): torch.ops.aten.size.default,
(torch.ops.aten.sym_size, 2): torch.ops.aten.size.int,
(torch.ops.aten.sym_stride, 1): torch.ops.aten.stride.default,
(torch.ops.aten.sym_stride, 2): torch.ops.aten.stride.int,
(torch.ops.aten.sym_numel, 1): torch.ops.aten.numel.default,
}
# torch with cuda has a __version__ that looks like "2.1.0+cu113",
# so split by + and 0 index will always give the base version
_IS_TORCH_2_1_OR_EARLIER = torch.__version__.split("+")[0] <= "2.1.0"

# The following are maps from symbolic ops to their non symbolic equivalents.
# In <=2.1.0, imported fx graphs come with a type inspecific torch.ops.aten.sym_size
# We identify it using the number of args in the node, 1 being default, 2 being int
# In the mapping below (torch.aten.sym_size, 2) indicates len(args)=2 therefore
# map to torch.aten.size.int.
# Thankfully, newer versions provide a specific torch.ops.aten.sym_size.<type>.
# Once we drop support for <2.1.0, we can get rid of the the SYMBOLIC_TORCH_OPS
# set and just check key existence in SYMBOLIC_OP_TO_TORCH_OP

if _IS_TORCH_2_1_OR_EARLIER:
SYMBOLIC_TORCH_OPS = {
torch.ops.aten.sym_size,
torch.ops.aten.sym_stride,
torch.ops.aten.sym_numel,
}

SYMBOLIC_OP_TO_TORCH_OP = {
(torch.ops.aten.sym_size, 1): torch.ops.aten.size.default,
(torch.ops.aten.sym_size, 2): torch.ops.aten.size.int,
(torch.ops.aten.sym_stride, 1): torch.ops.aten.stride.default,
(torch.ops.aten.sym_stride, 2): torch.ops.aten.stride.int,
(torch.ops.aten.sym_numel, 1): torch.ops.aten.numel.default,
}
else:
SYMBOLIC_TORCH_OPS = {
torch.ops.aten.sym_size.int,
torch.ops.aten.sym_stride.int,
torch.ops.aten.sym_numel.default,
}

SYMBOLIC_OP_TO_TORCH_OP = {
torch.ops.aten.sym_size.default: torch.ops.aten.size.default,
torch.ops.aten.sym_size.int: torch.ops.aten.size.int,
torch.ops.aten.sym_stride.default: torch.ops.aten.stride.default,
torch.ops.aten.sym_stride.int: torch.ops.aten.stride.int,
torch.ops.aten.sym_numel.default: torch.ops.aten.numel.default,
}


 @dataclass(frozen=True)
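
To make the <=2.1.0 keying concrete, here is a minimal hedged sketch of how the arity-keyed table above would be consulted. The helper name is made up for illustration and is not part of the diff; the real lookup happens in _import_symbolic_torch_op further down.

# Illustrative only, assuming the <=2.1.0 tables above are in scope: the fx node
# carries the overload-less packet, so the node's arity picks the concrete op.
def resolve_legacy_sym_op(node):
    # aten.sym_size(x)    -> key (sym_size, 1) -> aten.size.default
    # aten.sym_size(x, 0) -> key (sym_size, 2) -> aten.size.int
    return SYMBOLIC_OP_TO_TORCH_OP.get((node.target, len(node.args)))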
@@ -638,7 +666,9 @@ def import_program(
         node_importer.return_node_values(loc, user_outputs)
         self.symbol_table.insert(func_op)

-    def import_frozen_program(self, prog: torch.export.ExportedProgram, func_name: str = "main"):
+    def import_frozen_program(
+        self, prog: torch.export.ExportedProgram, func_name: str = "main"
+    ):
         """Imports a consolidated torch.export.ExportedProgram instance.

         If using the new torch.export path (vs a lower level precursor), then this is
@@ -1137,14 +1167,14 @@ def import_nodes(
                             raise NotImplementedError(
                                 f"General getitem access to non-multi-result ops"
                             )
-                    elif isinstance(target, TorchOpOverload):
-                        # Dispatch to an ATen op.
-                        self._import_torch_op_overload(loc, node, target)
                     elif target in SYMBOLIC_TORCH_OPS or (
                         is_symbolic(node.meta.get("val"))
                         and is_builtin_function_or_method(target)
                     ):
                         self._import_symbolic_torch_op(loc, node, target)
+                    elif isinstance(target, TorchOpOverload):
+                        # Dispatch to an ATen op.
+                        self._import_torch_op_overload(loc, node, target)
                     else:
                         raise NotImplementedError(
                             f"FIX ME: Unimplemented call_function: target={node.target}, {node.meta}"
@@ -1227,7 +1257,10 @@ def _import_symbolic_torch_op(
), f"Unsupported builtin function for symbolic types: {target} with args {node.args}"
concrete_target = getattr(torch_op, op_overload)
else:
concrete_target = SYMBOLIC_OP_TO_TORCH_OP.get((target, len(node.args)))
if _IS_TORCH_2_1_OR_EARLIER:
concrete_target = SYMBOLIC_OP_TO_TORCH_OP.get((target, len(node.args)))
else:
concrete_target = SYMBOLIC_OP_TO_TORCH_OP.get(target)

assert (
concrete_target is not None
@@ -1628,8 +1661,7 @@ def lookup(self, t: type) -> Any:

 # Opaque value to indicate something is empty. Used in cases where 'None'
 # may have a different meaning.
-class EmptyType:
-    ...
+class EmptyType: ...


 Empty = EmptyType()