[torch-mlir] provide FX traced graph importer for sparse tensors (llvm#2817)

Note that we are waiting for actual FX traced graph support for sparse
tensors. For details see

pytorch/pytorch#117188

Until then, however, we provide this clever importer that builds the FX
traced graph for the dense case and then puts a sparse annotation
back on the parameters.

With import test.
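In short, the workaround condenses to the sketch below (hedged: this inlines the `sparse_export` trick from the new test file added in this commit, and assumes a torch-mlir build with this patch applied; the module and variable names are illustrative):

import torch
from torch_mlir import ir
from torch_mlir.dialects import torch as torch_d
from torch_mlir.extras.fx_importer import FxImporter

class SumNet(torch.nn.Module):
    def forward(self, x):
        return x.sum()

# Trace with a dense stand-in, then put the sparse layout annotation back
# on the corresponding placeholder node.
sparse_input = torch.ones(64, 64).to_sparse_csr()
prog = torch.export.export(SumNet(), (sparse_input.to_dense(),))
for i, node in enumerate(prog.graph.nodes):
    if node.op == "placeholder" and i == 0:
        node.meta["sparsity"] = sparse_input.layout

context = ir.Context()
torch_d.register_dialect(context)
importer = FxImporter(context=context)
importer.import_frozen_exported_program(prog)
print(importer.module_op)  # %arg0 should get type !torch.vtensor<[64,64],f32,#csr>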
aartbik authored Jan 31, 2024
1 parent 1a7442e commit 105aad6
Showing 2 changed files with 166 additions and 10 deletions.
46 changes: 36 additions & 10 deletions python/torch_mlir/extras/fx_importer.py
@@ -207,10 +207,32 @@
}


"""Check whether an object in our graph is symbolic"""
def sparsity_encoding(shape: torch.Size, sparse_layout : torch.layout) -> str:
"""Returns sparse tensor encoding for the given sparse layout as string.
The method currently just supports 2-dim sparse formats. This should be
generalized to the torch.sparse encodings for prefix dense batch dimensions
and suffix dense subtensor dimensions. Since MLIR supports a superset of what
is currently implememented in torch.sparse, this should not a be problem.
"""

# TODO: any rank
if len(shape) != 2:
raise RuntimeError(f"Unsupported sparse rank {len(shape)}")

if sparse_layout is torch.sparse_coo:
return '#sparse_tensor.encoding<{map=(i,j)->(i:compressed(nonunique),j:singleton)}>'
if sparse_layout is torch.sparse_csr:
return '#sparse_tensor.encoding<{map=(i,j)->(i:dense,j:compressed)}>'
if sparse_layout is torch.sparse_csc:
return '#sparse_tensor.encoding<{map=(i,j)->(j:dense,i:compressed)}>'
# TODO: block format (derive block size!)

raise RuntimeError(f"Unsupported sparse layout {sparse_layout}")


def is_symbolic(obj: Any) -> bool:
"""Check whether an object in our graph is symbolic"""
return isinstance(obj, (torch.SymInt, torch.SymFloat, torch.SymBool))


@@ -337,7 +359,7 @@ def import_frozen_exported_program(self, prog: torch.export.ExportedProgram):
                ) from e
            arg_replacements[input_name] = state_value

        # Remove any lifted placeholders, replacing their uses with the state
        # replacement value.
        g = prog.graph
        for node in g.nodes:
@@ -455,17 +477,21 @@ def format_asm_shape(self, shape: torch.Size) -> str:

"""Return IrType for !torch.vtensor with the given shape and dtype"""

def get_vtensor_type(self, shape: torch.Size, dtype: torch.dtype):
def get_vtensor_type(self, shape: torch.Size, dtype: torch.dtype, sparse_layout: torch.layout = None):
shape_asm = self.format_asm_shape(shape)
mlir_dtype = str(self.dtype_to_type(dtype))
if sparse_layout is not None:
sparsity = sparsity_encoding(shape, sparse_layout)
return IrType.parse(
f"!torch.vtensor<[{shape_asm}],{str(mlir_dtype)},{sparsity}>", context=self._c)
return IrType.parse(
f"!torch.vtensor<[{shape_asm}],{str(mlir_dtype)}>", context=self._c
)
f"!torch.vtensor<[{shape_asm}],{str(mlir_dtype)}>", context=self._c)

def node_val_to_type(self, node: torch_fx.Node) -> IrType:
try:
tensor_meta = node.meta.get("tensor_meta")
val = node.meta.get("val")
sparse_layout = node.meta.get("sparsity", None)
if tensor_meta is not None:
assert isinstance(tensor_meta, TensorMetadata)
# Quantized tensor meta data is not preserved in our lowering,
@@ -475,12 +501,12 @@ def node_val_to_type(self, node: torch_fx.Node) -> IrType:
f"Quantized tensor meta data is not supported."
)
else:
return self.tensor_metadata_to_type(tensor_meta)
return self.tensor_metadata_to_type(tensor_meta, sparse_layout)
elif val is not None:
# some nodes with symbolic inputs pass a 'val' attribute rather than
# tensor_meta
if isinstance(val, TorchFakeTensor):
return self.get_vtensor_type(val.size(), val.dtype)
return self.get_vtensor_type(val.size(), val.dtype, sparse_layout)

t = SCALAR_TYPE_TO_TORCH_MLIR_TYPE.get(type(val))
if t is not None:
@@ -495,15 +521,15 @@ def node_val_to_type(self, node: torch_fx.Node) -> IrType:
f"FIXME: Illegal access to torch.fx.Node.meta: {e} ({node.meta.keys()} : {node.meta})"
)

def tensor_metadata_to_type(self, tm: TensorMetadata) -> IrType:
def tensor_metadata_to_type(self, tm: TensorMetadata, sparse_layout : torch.layout = None) -> IrType:
tm_shape = tuple(
item.node if is_symbolic(item) else item for item in list(tm.shape)
)

key = (tm_shape, tm.dtype)
key = (tm_shape, tm.dtype, sparse_layout)
t = self._tensor_metadata_cache.get(key)
if t is None:
t = self.get_vtensor_type(tm.shape, tm.dtype)
t = self.get_vtensor_type(tm.shape, tm.dtype, sparse_layout)
self._tensor_metadata_cache[key] = t
return t

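As a quick illustration of the helper added above, a minimal sketch (assuming the patch is installed so that sparsity_encoding is importable from torch_mlir.extras.fx_importer, where it is a module-level function):

import torch
from torch_mlir.extras.fx_importer import sparsity_encoding

shape = torch.Size([64, 64])
print(sparsity_encoding(shape, torch.sparse_csr))
# -> #sparse_tensor.encoding<{map=(i,j)->(i:dense,j:compressed)}>
print(sparsity_encoding(shape, torch.sparse_coo))
# -> #sparse_tensor.encoding<{map=(i,j)->(i:compressed(nonunique),j:singleton)}>
# Any other rank or layout (e.g. torch.sparse_bsr) currently raises RuntimeError.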
130 changes: 130 additions & 0 deletions test/python/fx_importer/sparse_test.py
@@ -0,0 +1,130 @@
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
# Also available under a BSD-style license. See LICENSE.

# RUN: %PYTHON %s | FileCheck %s

from typing import Any, Callable, Dict, Optional, Tuple

import torch
import torch.export
import torch.nn as nn

from torch_mlir.extras.fx_importer import FxImporter
from torch_mlir import ir
from torch_mlir.dialects import torch as torch_d


# All sparse layouts currently supported in torch.sparse.
SPARSE_LAYOUTS = [
    torch.sparse_coo,
    torch.sparse_csr,
    torch.sparse_csc,
    torch.sparse_bsr,
    torch.sparse_bsc,
]


def sparse_export(f: Callable,
                  args: Tuple[Any, ...],
                  kwargs: Optional[Dict[str, Any]] = None) -> torch.export.ExportedProgram:
    """
    This is a ***temporary*** wrapper around `torch.export.export`
    that eventually should be removed and simply replaced by the
    standard API for exporting traced graphs.
    But until issue
    https://github.com/pytorch/pytorch/pull/117907
    is addressed, this wrapper provides support for the sparse
    tensor types by first converting all operands to dense tensors,
    building the traced graph as for the dense case, and then
    annotating sparse parameters with their actual sparse layout
    attributes. This temporary solution accelerates testing
    torch-mlir with PyTorch sparse tensors until the issue is
    resolved.
    """
    # Convert all arguments to dense.
    dargs = tuple(a.to_dense() if a.layout in SPARSE_LAYOUTS else a for a in args)
    mask = [a.layout in SPARSE_LAYOUTS for a in args]
    # Build the regular FX traced graph with only dense arguments
    # (the current version would crash otherwise, see issue above).
    prog = torch.export.export(f, dargs, kwargs, constraints=None)
    # Annotate sparse arguments in the graph.
    alen = len(args)
    for i, node in enumerate(prog.graph.nodes):
        if node.op == "placeholder" and i < alen and mask[i]:
            node.meta["sparsity"] = args[i].layout
    # TODO: annotate inputs to change calling conventions!
    return prog


def export_and_import(f, *args, **kwargs):
    """This method implements Stella's importer, stripped down to essentials."""
    context = ir.Context()
    torch_d.register_dialect(context)
    fx_importer = FxImporter(context=context)
    prog = sparse_export(f, args, kwargs)
    fx_importer.import_frozen_exported_program(prog)
    return fx_importer.module_op


def run(f):
    print(f"{f.__name__}")
    print("-" * len(f.__name__))
    f()
    print()


@run
# CHECK-LABEL: test_sparse_sum
# CHECK: #[[$CSR:.*]] = #sparse_tensor.encoding<{ map = (d0, d1) -> (d0 : dense, d1 : compressed) }>
# CHECK: func.func @main(
# CHECK-SAME: %[[A:.*]]: !torch.vtensor<[64,64],f32,#[[$CSR]]>) -> !torch.vtensor<[],f32> {
# CHECK: %[[N:.*]] = torch.constant.none
# CHECK: %[[R:.*]] = torch.aten.sum %[[A]], %[[N]] : !torch.vtensor<[64,64],f32,#[[$CSR]]>, !torch.none -> !torch.vtensor<[],f32>
# CHECK: return %[[R]] : !torch.vtensor<[],f32>
# CHECK: }
def test_sparse_sum():
    class SumNet(torch.nn.Module):
        def __init__(self):
            super(SumNet, self).__init__()

        def forward(self, x):
            return x.sum()

    dense_input = torch.ones(64, 64)
    sparse_input = dense_input.to_sparse_csr()
    m = export_and_import(SumNet(), sparse_input)
    print(m)


@run
# CHECK-LABEL: test_sparse_SpMM
# CHECK: #[[$COO:.*]] = #sparse_tensor.encoding<{ map = (d0, d1) -> (d0 : compressed(nonunique), d1 : singleton) }>
# CHECK: func.func @main(
# CHECK-SAME: %[[A:.*0]]: !torch.vtensor<[64,64],f32,#[[$COO]]>,
# CHECK-SAME: %[[B:.*1]]: !torch.vtensor<[64,64],f32>) -> !torch.vtensor<[64,64],f32> {
# CHECK: %[[R:.*]] = torch.aten.mm %[[A]], %[[B]] : !torch.vtensor<[64,64],f32,#[[$COO]]>, !torch.vtensor<[64,64],f32> -> !torch.vtensor<[64,64],f32>
# CHECK: return %[[R]] : !torch.vtensor<[64,64],f32>
# CHECK: }
def test_sparse_SpMM():
    class MatMulNet(torch.nn.Module):
        def __init__(self):
            super(MatMulNet, self).__init__()

        def forward(self, x, y):
            return torch.matmul(x, y)

    dense_input = torch.ones(64, 64)
    sparse_input = dense_input.to_sparse_coo()
    m = export_and_import(MatMulNet(), sparse_input, dense_input)
    print(m)
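For reference, assembling the CHECK lines above, the module printed by `test_sparse_sum` should look roughly like the following (the attribute alias and SSA value names are illustrative; FileCheck only pins the patterns shown):

#csr = #sparse_tensor.encoding<{ map = (d0, d1) -> (d0 : dense, d1 : compressed) }>
module {
  func.func @main(%arg0: !torch.vtensor<[64,64],f32,#csr>) -> !torch.vtensor<[],f32> {
    %none = torch.constant.none
    %0 = torch.aten.sum %arg0, %none : !torch.vtensor<[64,64],f32,#csr>, !torch.none -> !torch.vtensor<[],f32>
    return %0 : !torch.vtensor<[],f32>
  }
}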
