Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement graph.vectorize and Blockwise Op #306

Merged
merged 7 commits into from
Sep 6, 2023
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
Blockwise some linalg Ops by default
ricardoV94 committed Sep 5, 2023
commit 9b71e4b5e617f4741a94ec556de0d1884463b0c0
2 changes: 1 addition & 1 deletion pytensor/tensor/basic.py
Original file line number Diff line number Diff line change
@@ -3764,7 +3764,7 @@ def stacklists(arg):
return arg


def swapaxes(y, axis1, axis2):
def swapaxes(y, axis1: int, axis2: int) -> TensorVariable:
"Swap the axes of a tensor."
y = as_tensor_variable(y)
ndim = y.ndim
20 changes: 15 additions & 5 deletions pytensor/tensor/nlinalg.py
Original file line number Diff line number Diff line change
@@ -10,11 +10,13 @@
from pytensor.tensor import basic as at
from pytensor.tensor import math as tm
from pytensor.tensor.basic import as_tensor_variable, extract_diag
from pytensor.tensor.blockwise import Blockwise
from pytensor.tensor.type import dvector, lscalar, matrix, scalar, vector


class MatrixPinv(Op):
__props__ = ("hermitian",)
gufunc_signature = "(m,n)->(n,m)"

def __init__(self, hermitian):
self.hermitian = hermitian
@@ -75,7 +77,7 @@ def pinv(x, hermitian=False):
solve op.

"""
return MatrixPinv(hermitian=hermitian)(x)
return Blockwise(MatrixPinv(hermitian=hermitian))(x)


class MatrixInverse(Op):
@@ -93,6 +95,8 @@ class MatrixInverse(Op):
"""

__props__ = ()
gufunc_signature = "(m,m)->(m,m)"
gufunc_spec = ("numpy.linalg.inv", 1, 1)

def __init__(self):
pass
@@ -150,7 +154,7 @@ def infer_shape(self, fgraph, node, shapes):
return shapes


inv = matrix_inverse = MatrixInverse()
inv = matrix_inverse = Blockwise(MatrixInverse())


def matrix_dot(*args):
@@ -181,6 +185,8 @@ class Det(Op):
"""

__props__ = ()
gufunc_signature = "(m,m)->()"
gufunc_spec = ("numpy.linalg.det", 1, 1)

def make_node(self, x):
x = as_tensor_variable(x)
@@ -209,7 +215,7 @@ def __str__(self):
return "Det"


det = Det()
det = Blockwise(Det())


class SLogDet(Op):
@@ -218,6 +224,8 @@ class SLogDet(Op):
"""

__props__ = ()
gufunc_signature = "(m, m)->(),()"
gufunc_spec = ("numpy.linalg.slogdet", 1, 2)

def make_node(self, x):
x = as_tensor_variable(x)
@@ -242,7 +250,7 @@ def __str__(self):
return "SLogDet"


slogdet = SLogDet()
slogdet = Blockwise(SLogDet())


class Eig(Op):
@@ -252,6 +260,8 @@ class Eig(Op):
"""

__props__: Tuple[str, ...] = ()
gufunc_signature = "(m,m)->(m),(m,m)"
gufunc_spec = ("numpy.linalg.eig", 1, 2)

def make_node(self, x):
x = as_tensor_variable(x)
@@ -270,7 +280,7 @@ def infer_shape(self, fgraph, node, shapes):
return [(n,), (n, n)]


eig = Eig()
eig = Blockwise(Eig())


class Eigh(Eig):
Loading