diff --git a/src/qibojit/backends/matrices.py b/src/qibojit/backends/matrices.py index e538f67..a6622d1 100644 --- a/src/qibojit/backends/matrices.py +++ b/src/qibojit/backends/matrices.py @@ -75,6 +75,9 @@ def __init__(self, dtype): self.cp = cp + def I(self, n=2): + return self.cp.eye(n, dtype=self.dtype) + def _cast(self, x, dtype): is_cupy = [ isinstance(item, self.cp.ndarray) for sublist in x for item in sublist