From 105aad6f57a19db1cfcf17bb394367431973b65e Mon Sep 17 00:00:00 2001 From: Aart Bik <39774503+aartbik@users.noreply.github.com> Date: Tue, 30 Jan 2024 21:22:12 -0800 Subject: [PATCH] [torch-mlir] provide FX traced graph importer for sparse tensors (#2817) Note that we are waiting for actual FX traced graph support for sparse tensors. For details see https://github.com/pytorch/pytorch/issues/117188 Until then, however, we provide this clever importer that builds the FX traced graph for the dense case and then puts a sparse annotation back on the parameters. With import test. --- python/torch_mlir/extras/fx_importer.py | 46 +++++++-- test/python/fx_importer/sparse_test.py | 130 ++++++++++++++++++++++++ 2 files changed, 166 insertions(+), 10 deletions(-) create mode 100644 test/python/fx_importer/sparse_test.py diff --git a/python/torch_mlir/extras/fx_importer.py b/python/torch_mlir/extras/fx_importer.py index d799d61f6a92..8cffcb1ea935 100644 --- a/python/torch_mlir/extras/fx_importer.py +++ b/python/torch_mlir/extras/fx_importer.py @@ -207,10 +207,32 @@ } -"""Check whether an object in our graph is symbolic""" +def sparsity_encoding(shape: torch.Size, sparse_layout : torch.layout) -> str: + """Returns sparse tensor encoding for the given sparse layout as string. + + The method currently just supports 2-dim sparse formats. This should be + generalized to the torch.sparse encodings for prefix dense batch dimensions + and suffix dense subtensor dimensions. Since MLIR supports a superset of what + is currently implemented in torch.sparse, this should not be a problem. 
+ """ + + # TODO: any rank + if len(shape) != 2: + raise RuntimeError(f"Unsupported sparse rank {len(shape)}") + + if sparse_layout is torch.sparse_coo: + return '#sparse_tensor.encoding<{map=(i,j)->(i:compressed(nonunique),j:singleton)}>' + if sparse_layout is torch.sparse_csr: + return '#sparse_tensor.encoding<{map=(i,j)->(i:dense,j:compressed)}>' + if sparse_layout is torch.sparse_csc: + return '#sparse_tensor.encoding<{map=(i,j)->(j:dense,i:compressed)}>' + # TODO: block format (derive block size!) + + raise RuntimeError(f"Unsupported sparse layout {sparse_layout}") def is_symbolic(obj: Any) -> bool: + """Check whether an object in our graph is symbolic""" return isinstance(obj, (torch.SymInt, torch.SymFloat, torch.SymBool)) @@ -337,7 +359,7 @@ def import_frozen_exported_program(self, prog: torch.export.ExportedProgram): ) from e arg_replacements[input_name] = state_value - # Remove any lifted placeholders, replacing their uses with the state + # Remove any lifted placeholders, replacing their uses with the state # replacement value. 
g = prog.graph for node in g.nodes: @@ -455,17 +477,21 @@ def format_asm_shape(self, shape: torch.Size) -> str: """Return IrType for !torch.vtensor with the given shape and dtype""" - def get_vtensor_type(self, shape: torch.Size, dtype: torch.dtype): + def get_vtensor_type(self, shape: torch.Size, dtype: torch.dtype, sparse_layout: torch.layout = None): shape_asm = self.format_asm_shape(shape) mlir_dtype = str(self.dtype_to_type(dtype)) + if sparse_layout is not None: + sparsity = sparsity_encoding(shape, sparse_layout) + return IrType.parse( + f"!torch.vtensor<[{shape_asm}],{str(mlir_dtype)},{sparsity}>", context=self._c) return IrType.parse( - f"!torch.vtensor<[{shape_asm}],{str(mlir_dtype)}>", context=self._c - ) + f"!torch.vtensor<[{shape_asm}],{str(mlir_dtype)}>", context=self._c) def node_val_to_type(self, node: torch_fx.Node) -> IrType: try: tensor_meta = node.meta.get("tensor_meta") val = node.meta.get("val") + sparse_layout = node.meta.get("sparsity", None) if tensor_meta is not None: assert isinstance(tensor_meta, TensorMetadata) # Quantized tensor meta data is not preserved in our lowering, @@ -475,12 +501,12 @@ def node_val_to_type(self, node: torch_fx.Node) -> IrType: f"Quantized tensor meta data is not supported." 
) else: - return self.tensor_metadata_to_type(tensor_meta) + return self.tensor_metadata_to_type(tensor_meta, sparse_layout) elif val is not None: # some nodes with symbolic inputs pass a 'val' attribute rather than # tensor_meta if isinstance(val, TorchFakeTensor): - return self.get_vtensor_type(val.size(), val.dtype) + return self.get_vtensor_type(val.size(), val.dtype, sparse_layout) t = SCALAR_TYPE_TO_TORCH_MLIR_TYPE.get(type(val)) if t is not None: @@ -495,15 +521,15 @@ def node_val_to_type(self, node: torch_fx.Node) -> IrType: f"FIXME: Illegal access to torch.fx.Node.meta: {e} ({node.meta.keys()} : {node.meta})" ) - def tensor_metadata_to_type(self, tm: TensorMetadata) -> IrType: + def tensor_metadata_to_type(self, tm: TensorMetadata, sparse_layout : torch.layout = None) -> IrType: tm_shape = tuple( item.node if is_symbolic(item) else item for item in list(tm.shape) ) - key = (tm_shape, tm.dtype) + key = (tm_shape, tm.dtype, sparse_layout) t = self._tensor_metadata_cache.get(key) if t is None: - t = self.get_vtensor_type(tm.shape, tm.dtype) + t = self.get_vtensor_type(tm.shape, tm.dtype, sparse_layout) self._tensor_metadata_cache[key] = t return t diff --git a/test/python/fx_importer/sparse_test.py b/test/python/fx_importer/sparse_test.py new file mode 100644 index 000000000000..1490c160c3f1 --- /dev/null +++ b/test/python/fx_importer/sparse_test.py @@ -0,0 +1,130 @@ +# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +# Also available under a BSD-style license. See LICENSE. 
+ +# RUN: %PYTHON %s | FileCheck %s + +from typing import Any, Callable, Dict, Optional, Tuple + +import torch +import torch.export +import torch.nn as nn + +from torch_mlir.extras.fx_importer import FxImporter +from torch_mlir import ir +from torch_mlir.dialects import torch as torch_d + + +# All sparse layouts currently supported in torch.sparse. +SPARSE_LAYOUTS = [ + torch.sparse_coo, + torch.sparse_csr, + torch.sparse_csc, + torch.sparse_bsr, + torch.sparse_bsc +] + + +def sparse_export(f: Callable, + args: Tuple[Any, ...], + kwargs: Optional[Dict[str, Any]] = None) -> torch.export.ExportedProgram: + """ + This is a ***temporary*** wrapper around `torch.export.export` + that eventually should be removed and simply replaced by the + standard API for exporting traced graphs. + + But until issue + + https://github.com/pytorch/pytorch/pull/117907 + + is addressed, this wrapper provides support for the sparse + tensor types by first converting all operands to dense tensors, + building the traced graph as for the dense case, and then + annotating sparse parameters with their actual sparse layout + attributes. This temporary solution accelerates testing + torch-mlir with PyTorch sparse tensors until the issue is + resolved. + """ + # Convert all arguments to dense. + dargs = tuple( a.to_dense() if a.layout in SPARSE_LAYOUTS else a for a in args ) + mask = [ a.layout in SPARSE_LAYOUTS for a in args ] + # Build the regular FX traced graph with only dense arguments + # (the current version would crash otherwise, see issue above). + prog = torch.export.export(f, dargs, kwargs, constraints=None) + # Annotate sparse arguments in the graph. + alen = len(args) + for i, node in enumerate(prog.graph.nodes): + if node.op == "placeholder" and i < alen and mask[i]: + node.meta['sparsity'] = args[i].layout + # TODO: annotate inputs to change calling conventions! 
+ return prog + + +def export_and_import(f, *args, **kwargs): + """This method implements Stella's importer, stripped down to essentials.""" + context = ir.Context() + torch_d.register_dialect(context) + fx_importer = FxImporter(context=context) + prog = sparse_export(f, args, kwargs) + fx_importer.import_frozen_exported_program(prog) + return fx_importer.module_op + + +def run(f): + print(f"{f.__name__}") + print("-" * len(f.__name__)) + f() + print() + + +@run +# CHECK-LABEL: test_sparse_sum +# CHECK: #[[$CSR:.*]] = #sparse_tensor.encoding<{ map = (d0, d1) -> (d0 : dense, d1 : compressed) }> +# CHECK: func.func @main( +# CHECK-SAME: %[[A:.*]]: !torch.vtensor<[64,64],f32,#[[$CSR]]>) -> !torch.vtensor<[],f32> { +# CHECK: %[[N:.*]] = torch.constant.none +# CHECK: %[[R:.*]] = torch.aten.sum %[[A]], %[[N]] : !torch.vtensor<[64,64],f32,#[[$CSR]]>, !torch.none -> !torch.vtensor<[],f32> +# CHECK: return %[[R]] : !torch.vtensor<[],f32> +# CHECK: } +def test_sparse_sum(): + + class SumNet(torch.nn.Module): + + def __init__(self): + super(SumNet, self).__init__() + + def forward(self, x): + return x.sum() + + + dense_input = torch.ones(64, 64) + sparse_input = dense_input.to_sparse_csr() + m = export_and_import(SumNet(), sparse_input) + print(m) + + +@run +# CHECK-LABEL: test_sparse_SpMM +# CHECK: #[[$COO:.*]] = #sparse_tensor.encoding<{ map = (d0, d1) -> (d0 : compressed(nonunique), d1 : singleton) }> +# CHECK: func.func @main( +# CHECK-SAME: %[[A:.*0]]: !torch.vtensor<[64,64],f32,#[[$COO]]>, +# CHECK-SAME: %[[B:.*1]]: !torch.vtensor<[64,64],f32>) -> !torch.vtensor<[64,64],f32> { +# CHECK: %[[R:.*]] = torch.aten.mm %[[A]], %[[B]] : !torch.vtensor<[64,64],f32,#[[$COO]]>, !torch.vtensor<[64,64],f32> -> !torch.vtensor<[64,64],f32> +# CHECK: return %[[R]] : !torch.vtensor<[64,64],f32> +# CHECK: } +def test_sparse_SpMM(): + + class MatMulNet(torch.nn.Module): + + def __init__(self): + super(MatMulNet, self).__init__() + + def forward(self, x, y): + return torch.matmul(x, y) + + 
+ dense_input = torch.ones(64, 64) + sparse_input = dense_input.to_sparse_coo() + m = export_and_import(MatMulNet(), sparse_input, dense_input) + print(m)