From ae715544506a5a2938c75931da498d3e4931be41 Mon Sep 17 00:00:00 2001 From: Stella Laurenzo Date: Wed, 24 Apr 2024 19:17:11 -0700 Subject: [PATCH] Finishes tests. --- tests/runtime/device_test.py | 16 ++++++++++++++-- 1 file changed, 14 insertions(+), 2 deletions(-) diff --git a/tests/runtime/device_test.py b/tests/runtime/device_test.py index 386ab2849..47567ae66 100644 --- a/tests/runtime/device_test.py +++ b/tests/runtime/device_test.py @@ -153,8 +153,20 @@ def testFromTorchDevice(self): def testJit(self): from shark_turbine.ops import iree as iree_ops - t = torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0]).to("cuda:0") - print(iree_ops._test_add(t, t)) + t = torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0], device="cuda:0") + result = iree_ops._test_add(t, t) + expected = torch.tensor([ 2., 4., 6., 8., 10.], device="cpu") + torch.testing.assert_close(result.cpu(), expected) + + +class TorchCPUInterop(unittest.TestCase): + def testJit(self): + from shark_turbine.ops import iree as iree_ops + + t = torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0], device="cpu") + result = iree_ops._test_add(t, t) + expected = torch.tensor([ 2., 4., 6., 8., 10.], device="cpu") + torch.testing.assert_close(result, expected) if __name__ == "__main__":