[torch-mlir][sparse] add a true network to our NN tests (llvm#3305)
Objective: make to_sparse() work end-to-end!
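For context, to_sparse() is the stock PyTorch conversion from a dense tensor to sparse COO form, and the end-to-end goal is for the LIF filter in the new test to return its binary spike tensor that way. A minimal illustration of the conversion (plain PyTorch, not part of this commit):

import torch

dense = torch.tensor([[1.0, 0.0], [0.0, 1.0]])  # a binary pattern like an LIF spike map
coo = dense.to_sparse()  # COO form: coordinates of the nonzeros plus their values
print(coo.indices())  # tensor([[0, 1], [0, 1]])
print(coo.values())  # tensor([1., 1.])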
aartbik authored May 9, 2024
1 parent cff144b commit 89bb740
Showing 1 changed file, test/python/fx_importer/sparse_test.py, with 90 additions and 0 deletions.
@@ -204,6 +204,7 @@ def run(f):


@run
#
# CHECK-LABEL: test_sparse_id
# CHECK: #[[$COO:.*]] = #sparse_tensor.encoding<{ map = (d0, d1) -> (d0 : compressed(nonunique), d1 : singleton(soa)), posWidth = 64, crdWidth = 64 }>
# CHECK: func.func @main(
@@ -250,6 +251,7 @@ def forward(self, x):


@run
#
# CHECK-LABEL: test_sparse_sum
# CHECK: #[[$CSR:.*]] = #sparse_tensor.encoding<{ map = (d0, d1) -> (d0 : dense, d1 : compressed), posWidth = 64, crdWidth = 64 }>
# CHECK: func.func @main(
@@ -284,6 +286,7 @@ def forward(self, x):


@run
#
# CHECK-LABEL: test_sparse_SpMV
# CHECK: #[[$BSR:.*]] = #sparse_tensor.encoding<{ map = (d0, d1) -> (d0 floordiv 2 : dense, d1 floordiv 2 : compressed, d0 mod 2 : dense, d1 mod 2 : dense), posWidth = 64, crdWidth = 64 }>
# CHECK: func.func @main(
@@ -319,6 +322,7 @@ def forward(self, x, v):


@run
#
# CHECK-LABEL: test_sparse_SpMM
# CHECK: #[[$COO:.*]] = #sparse_tensor.encoding<{ map = (d0, d1) -> (d0 : compressed(nonunique), d1 : singleton(soa)), posWidth = 64, crdWidth = 64 }>
# CHECK: func.func @main(
@@ -361,6 +365,7 @@ def forward(self, x, y):


@run
#
# CHECK-LABEL: test_sparse_eltwise
# CHECK: #[[$CSRD:.*]] = #sparse_tensor.encoding<{ map = (d0, d1, d2) -> (d0 : dense, d1 : compressed, d2 : dense), posWidth = 64, crdWidth = 64 }>
# CHECK: func.func @main(
@@ -428,6 +433,7 @@ def forward(self, x):


@run
#
# CHECK-LABEL: test_sparse_coo3
# CHECK: #[[$COO3:.*]] = #sparse_tensor.encoding<{ map = (d0, d1, d2) -> (d0 : compressed(nonunique), d1 : singleton(nonunique, soa), d2 : singleton(soa)), posWidth = 64, crdWidth = 64 }>
# CHECK: func.func @main(
@@ -473,6 +479,7 @@ def forward(self, x):


@run
#
# CHECK-LABEL: test_sparse_activation
# CHECK: #[[$COO:.*]] = #sparse_tensor.encoding<{ map = (d0, d1, d2) -> (d0 : compressed(nonunique), d1 : singleton(nonunique, soa), d2 : singleton(soa)), posWidth = 64, crdWidth = 64 }>
# CHECK: func.func @main(
@@ -518,3 +525,86 @@ def forward(self, x):
    print(res2[2])
    print(res2[3])
    print(res2[4])


@run
#
# CHECK-LABEL: test_sparse_network
# CHECK: func.func @main(
# CHECK-SAME: %[[A:.*]]: !torch.vtensor<[2,3,8,8],f32>) -> !torch.vtensor<[8],f32> {
# ... lots of IR ...
# CHECK-COUNT-15: torch.aten.mul.Tensor
# ... lots of IR ...
# CHECK: }
#
# CHECK: torch.sparse
# CHECK: tensor([48., 48., 48., 48., 48., 48., 48., 48.])
# CHECK: torch.mlir
# CHECK: [48. 48. 48. 48. 48. 48. 48. 48.]
#
def test_sparse_network():
    def spike(input):
        return (input >= 0).float()

    def sqSum(input):
        return (input * input).sum()

    class LIF(nn.Module):
        def __init__(self):
            super(LIF, self).__init__()
            self.thresh = 1.0
            self.decay = 0.5
            self.act = spike

        def forward(self, X):
            """A filter that yields a binary-valued tensor (eventually sparse, see the TODO below)."""
            mem = 0
            spike_pot = []
            T = X.size(-1)
            for t in range(T):
                # Leaky integration: decay the potential, then add the input.
                mem = mem * self.decay + X[..., t]
                # Fire wherever the potential reaches the threshold.
                spike = self.act(mem - self.thresh)
                # Reset the potential of the neurons that fired.
                mem = mem * (1.0 - spike)
                spike_pot.append(spike)
            spike_pot = torch.stack(spike_pot, dim=-1)
            # TODO: we would like to see something like
            #   return spike_pot.to_sparse()
            return spike_pot

    class tdLayer(nn.Module):
        def __init__(self, layer):
            super(tdLayer, self).__init__()
            self.layer = layer

        def forward(self, X):
            # Apply the wrapped layer to each time slice independently.
            T = X.size(-1)
            out = []
            for t in range(T):
                m = self.layer(X[..., t])
                out.append(m)
            out = torch.stack(out, dim=-1)
            return out

    class Block(nn.Module):
        def __init__(self):
            super(Block, self).__init__()
            self.spike = LIF()
            self.layer = tdLayer(sqSum)

        def forward(self, X):
            out = self.spike(X)
            out = self.layer(out)
            return out

    net = Block()
    x = torch.ones(2, 3, 8, 8)
    m = export_and_import(net, x)
    print(m)

    # Run it with PyTorch torch.sparse and with TORCH-MLIR sparse_jit.
    res1 = net(x)
    res2 = sparse_jit(net, x)
    print("torch.sparse")
    print(res1)
    print("torch.mlir")
    print(res2)
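A quick hand check of the FileCheck values above: with the all-ones input, the membrane potential reaches the threshold at every timestep (mem = 0 * 0.5 + 1 = 1 >= 1.0), so every neuron fires and then resets, spike_pot is all ones, and sqSum over each (2, 3, 8) time slice is 2 * 3 * 8 = 48, which yields the eight 48s. A standalone recomputation of that expectation (plain PyTorch, just for illustration):

import torch

x = torch.ones(2, 3, 8, 8)
spikes = torch.ones_like(x)  # every element fires at every timestep
per_step = (spikes * spikes).sum(dim=(0, 1, 2))  # sqSum applied per time slice
print(per_step)  # tensor([48., 48., 48., 48., 48., 48., 48., 48.])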
