Skip to content

Commit

Permalink
Finishes tests.
Browse files Browse the repository at this point in the history
  • Loading branch information
stellaraccident committed Apr 25, 2024
1 parent c59b870 commit ae71554
Showing 1 changed file with 14 additions and 2 deletions.
16 changes: 14 additions & 2 deletions tests/runtime/device_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,8 +153,20 @@ def testFromTorchDevice(self):
def testJit(self):
from shark_turbine.ops import iree as iree_ops

t = torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0]).to("cuda:0")
print(iree_ops._test_add(t, t))
t = torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0], device="cuda:0")
result = iree_ops._test_add(t, t)
expected = torch.tensor([ 2., 4., 6., 8., 10.], device="cpu")
torch.testing.assert_close(result.cpu(), expected)


class TorchCPUInterop(unittest.TestCase):
def testJit(self):
from shark_turbine.ops import iree as iree_ops

t = torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0], device="cpu")
result = iree_ops._test_add(t, t)
expected = torch.tensor([ 2., 4., 6., 8., 10.], device="cpu")
torch.testing.assert_close(result, expected)


if __name__ == "__main__":
Expand Down

0 comments on commit ae71554

Please sign in to comment.