Skip to content

Commit

Permalink
[torch-mlir][sparse] higher dimension COO (llvm#3042)
Browse files Browse the repository at this point in the history
Lift this from 2-dim only to n-dim for n>=2
  • Loading branch information
aartbik authored Mar 19, 2024
1 parent df02692 commit fe59f1e
Showing 1 changed file with 7 additions and 4 deletions.
11 changes: 7 additions & 4 deletions python/torch_mlir/extras/fx_importer.py
Original file line number Diff line number Diff line change
Expand Up @@ -297,11 +297,14 @@ def sparsity_encoding(shape: torch.Size, sparsity: SparsityMeta) -> str:
assert dim == len(shape)
blocksize = sparsity.blocksize

dims = ",".join(f"d{d}" for d in range(0, dim))
dims = ",".join(f"d{d}" for d in range(dim))

if sparsity.layout is torch.sparse_coo:
assert sparse_dim == 2 and blocksize is None # TODO: deeper sparse dims
lvls = f"d{batch_dim}:compressed(nonunique),d{batch_dim+1}:singleton(soa)"
assert sparse_dim >= 2 and blocksize is None
trail_dim = batch_dim + sparse_dim - 1
coords = ",".join(f"d{d}:singleton(nonunique,soa)" for d in range(batch_dim+1, trail_dim))
sep = "," if sparse_dim > 2 else ""
lvls = f"d{batch_dim}:compressed(nonunique),{coords}{sep}d{trail_dim}:singleton(soa)"
elif sparsity.layout is torch.sparse_csr:
assert sparse_dim == 2 and blocksize is None
lvls = f"d{batch_dim}:dense,d{batch_dim+1}:compressed"
Expand All @@ -322,7 +325,7 @@ def sparsity_encoding(shape: torch.Size, sparsity: SparsityMeta) -> str:
)

if batch_dim > 0:
batch = ",".join(f"d{d}:dense" for d in range(0, batch_dim))
batch = ",".join(f"d{d}:dense" for d in range(batch_dim))
lvls = f"{batch},{lvls}"

if dense_dim > 0:
Expand Down

0 comments on commit fe59f1e

Please sign in to comment.