Blockwise some linalg Ops by default
ricardoV94 committed Aug 25, 2023
1 parent eff0721 commit c88f74c
Showing 8 changed files with 226 additions and 128 deletions.
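For orientation (not part of the diff): `Blockwise` gives these Ops the batch semantics of NumPy's generalized ufuncs. The `gufunc_signature` strings added below declare each Op's core dimensions, and `gufunc_spec` names the NumPy function that implements the batched computation in Python mode. A minimal NumPy sketch of the semantics being adopted:

    import numpy as np

    # np.linalg.inv behaves like a gufunc with core signature (m,m)->(m,m):
    # it maps the matrix inverse over any leading batch dimensions.
    batch = np.stack([np.eye(2), 2 * np.eye(2)])  # shape (2, 2, 2)
    inv_batch = np.linalg.inv(batch)              # inverted matrix-by-matrix
    assert inv_batch.shape == (2, 2, 2)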
2 changes: 1 addition & 1 deletion pytensor/tensor/basic.py
@@ -3735,7 +3735,7 @@ def stacklists(arg):
         return arg


-def swapaxes(y, axis1, axis2):
+def swapaxes(y, axis1: int, axis2: int) -> TensorVariable:
     "Swap the axes of a tensor."
     y = as_tensor_variable(y)
     ndim = y.ndim
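The annotation matters because `swapaxes(x, -1, -2)` is what the new `_T` helper in `pytensor/tensor/rewriting/linalg.py` builds on. A small usage sketch (shapes are illustrative):

    import pytensor.tensor as pt

    x = pt.tensor3("x")            # e.g. shape (batch, m, n)
    xT = pt.swapaxes(x, -1, -2)    # shape (batch, n, m): a batched matrix transpose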
20 changes: 15 additions & 5 deletions pytensor/tensor/nlinalg.py
@@ -10,11 +10,13 @@
 from pytensor.tensor import basic as at
 from pytensor.tensor import math as tm
 from pytensor.tensor.basic import as_tensor_variable, extract_diag
+from pytensor.tensor.blockwise import Blockwise
 from pytensor.tensor.type import dvector, lscalar, matrix, scalar, vector


 class MatrixPinv(Op):
     __props__ = ("hermitian",)
+    gufunc_signature = "(m,n)->(n,m)"

     def __init__(self, hermitian):
         self.hermitian = hermitian
@@ -75,7 +77,7 @@ def pinv(x, hermitian=False):
     solve op.

     """
-    return MatrixPinv(hermitian=hermitian)(x)
+    return Blockwise(MatrixPinv(hermitian=hermitian))(x)


 class MatrixInverse(Op):
@@ -93,6 +95,8 @@ class MatrixInverse(Op):
     """

     __props__ = ()
+    gufunc_signature = "(m,m)->(m,m)"
+    gufunc_spec = ("numpy.linalg.inv", 1, 1)

     def __init__(self):
         pass
@@ -150,7 +154,7 @@ def infer_shape(self, fgraph, node, shapes):
         return shapes


-inv = matrix_inverse = MatrixInverse()
+inv = matrix_inverse = Blockwise(MatrixInverse())


 def matrix_dot(*args):
@@ -181,6 +185,8 @@ class Det(Op):
     """

     __props__ = ()
+    gufunc_signature = "(m,m)->()"
+    gufunc_spec = ("numpy.linalg.det", 1, 1)

     def make_node(self, x):
         x = as_tensor_variable(x)
@@ -209,7 +215,7 @@ def __str__(self):
         return "Det"


-det = Det()
+det = Blockwise(Det())


 class SLogDet(Op):
@@ -218,6 +224,8 @@
     """

     __props__ = ()
+    gufunc_signature = "(m, m)->(),()"
+    gufunc_spec = ("numpy.linalg.slogdet", 1, 2)

     def make_node(self, x):
         x = as_tensor_variable(x)
@@ -242,7 +250,7 @@ def __str__(self):
         return "SLogDet"


-slogdet = SLogDet()
+slogdet = Blockwise(SLogDet())


 class Eig(Op):
@@ -252,6 +260,8 @@
     """

     __props__: Tuple[str, ...] = ()
+    gufunc_signature = "(m,m)->(m),(m,m)"
+    gufunc_spec = ("numpy.linalg.eig", 1, 2)

     def make_node(self, x):
         x = as_tensor_variable(x)
@@ -270,7 +280,7 @@ def infer_shape(self, fgraph, node, shapes):
         return [(n,), (n, n)]


-eig = Eig()
+eig = Blockwise(Eig())


 class Eigh(Eig):
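Taken together, the wrappers above mean the module-level entry points now accept batched inputs. A sketch of the intended behavior (shapes are illustrative; assumes `Blockwise` broadcasts leading dimensions like a NumPy gufunc):

    import numpy as np
    import pytensor
    import pytensor.tensor as pt
    from pytensor.tensor.nlinalg import det, matrix_inverse

    x = pt.tensor3("x")          # batch of square matrices, shape (b, m, m)
    d = det(x)                   # Blockwise(Det), core signature (m,m)->(): output shape (b,)
    xi = matrix_inverse(x)       # Blockwise(MatrixInverse): (m,m)->(m,m)

    f = pytensor.function([x], [d, xi])
    dets, invs = f(np.stack([np.eye(3), 2 * np.eye(3)]))
    # dets should be approximately [1., 8.]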
176 changes: 104 additions & 72 deletions pytensor/tensor/rewriting/linalg.py
@@ -1,11 +1,13 @@
 import logging
+from typing import cast

 from pytensor.graph.rewriting.basic import node_rewriter
-from pytensor.tensor import basic as at
+from pytensor.tensor.basic import TensorVariable, extract_diag, swapaxes
 from pytensor.tensor.blas import Dot22
+from pytensor.tensor.blockwise import Blockwise
 from pytensor.tensor.elemwise import DimShuffle
 from pytensor.tensor.math import Dot, Prod, log, prod
-from pytensor.tensor.nlinalg import Det, MatrixInverse
+from pytensor.tensor.nlinalg import MatrixInverse, det
 from pytensor.tensor.rewriting.basic import (
     register_canonicalize,
     register_specialize,
@@ -17,16 +19,40 @@
 logger = logging.getLogger(__name__)


+def is_matrix_transpose(x: TensorVariable) -> bool:
+    """Check if a variable corresponds to a transpose of the last two axes"""
+    node = x.owner
+    if (
+        node
+        and isinstance(node.op, DimShuffle)
+        and not (node.op.drop or node.op.augment)
+    ):
+        [inp] = node.inputs
+        ndims = inp.type.ndim
+        if ndims < 2:
+            return False
+        transpose_order = tuple(range(ndims - 2)) + (ndims - 1, ndims - 2)
+        return cast(bool, node.op.new_order == transpose_order)
+    return False
+
+
+def _T(x: TensorVariable) -> TensorVariable:
+    """Matrix transpose for potentially higher dimensionality tensors"""
+    return swapaxes(x, -1, -2)
+
+
 @register_canonicalize
 @node_rewriter([DimShuffle])
 def transinv_to_invtrans(fgraph, node):
-    if isinstance(node.op, DimShuffle):
-        if node.op.new_order == (1, 0):
-            (A,) = node.inputs
-            if A.owner:
-                if isinstance(A.owner.op, MatrixInverse):
-                    (X,) = A.owner.inputs
-                    return [A.owner.op(node.op(X))]
+    if is_matrix_transpose(node.outputs[0]):
+        (A,) = node.inputs
+        if (
+            A.owner
+            and isinstance(A.owner.op, Blockwise)
+            and isinstance(A.owner.op.core_op, MatrixInverse)
+        ):
+            (X,) = A.owner.inputs
+            return [A.owner.op(node.op(X))]


 @register_stabilize
@@ -37,86 +63,98 @@ def inv_as_solve(fgraph, node):
     """
     if isinstance(node.op, (Dot, Dot22)):
         l, r = node.inputs
-        if l.owner and isinstance(l.owner.op, MatrixInverse):
+        if (
+            l.owner
+            and isinstance(l.owner.op, Blockwise)
+            and isinstance(l.owner.op.core_op, MatrixInverse)
+        ):
             return [solve(l.owner.inputs[0], r)]
-        if r.owner and isinstance(r.owner.op, MatrixInverse):
+        if (
+            r.owner
+            and isinstance(r.owner.op, Blockwise)
+            and isinstance(r.owner.op.core_op, MatrixInverse)
+        ):
             x = r.owner.inputs[0]
             if getattr(x.tag, "symmetric", None) is True:
-                return [solve(x, l.T).T]
+                return [_T(solve(x, _T(l)))]
             else:
-                return [solve(x.T, l.T).T]
+                return [_T(solve(_T(x), _T(l)))]


 @register_stabilize
 @register_canonicalize
-@node_rewriter([Solve])
+@node_rewriter([Blockwise])
 def generic_solve_to_solve_triangular(fgraph, node):
     """
     If any solve() is applied to the output of a cholesky op, then
     replace it with a triangular solve.
     """
-    if isinstance(node.op, Solve):
-        A, b = node.inputs  # result is solution Ax=b
-        if A.owner and isinstance(A.owner.op, Cholesky):
-            if A.owner.op.lower:
-                return [SolveTriangular(lower=True)(A, b)]
-            else:
-                return [SolveTriangular(lower=False)(A, b)]
-        if (
-            A.owner
-            and isinstance(A.owner.op, DimShuffle)
-            and A.owner.op.new_order == (1, 0)
-        ):
-            (A_T,) = A.owner.inputs
-            if A_T.owner and isinstance(A_T.owner.op, Cholesky):
-                if A_T.owner.op.lower:
-                    return [SolveTriangular(lower=False)(A, b)]
-                else:
-                    return [SolveTriangular(lower=True)(A, b)]
+    if isinstance(node.op.core_op, Solve) and node.op.core_op.b_ndim == 1:
+        if node.op.core_op.assume_a == "gen":
+            A, b = node.inputs  # result is solution Ax=b
+            if (
+                A.owner
+                and isinstance(A.owner.op, Blockwise)
+                and isinstance(A.owner.op.core_op, Cholesky)
+            ):
+                if A.owner.op.core_op.lower:
+                    return [SolveTriangular(lower=True)(A, b)]
+                else:
+                    return [SolveTriangular(lower=False)(A, b)]
+            if is_matrix_transpose(A):
+                (A_T,) = A.owner.inputs
+                if (
+                    A_T.owner
+                    and isinstance(A_T.owner.op, Blockwise)
+                    and isinstance(A_T.owner.op.core_op, Cholesky)
+                ):
+                    if A_T.owner.op.core_op.lower:
+                        return [SolveTriangular(lower=False)(A, b)]
+                    else:
+                        return [SolveTriangular(lower=True)(A, b)]


 @register_canonicalize
 @register_stabilize
 @register_specialize
 @node_rewriter([DimShuffle])
 def no_transpose_symmetric(fgraph, node):
-    if isinstance(node.op, DimShuffle):
+    if is_matrix_transpose(node.outputs[0]):
         x = node.inputs[0]
-        if x.type.ndim == 2 and getattr(x.tag, "symmetric", None) is True:
-            if node.op.new_order == [1, 0]:
-                return [x]
+        if getattr(x.tag, "symmetric", None):
+            return [x]


 @register_stabilize
-@node_rewriter([Solve])
+@node_rewriter([Blockwise])
 def psd_solve_with_chol(fgraph, node):
     """
     This utilizes a boolean `psd` tag on matrices.
     """
-    if isinstance(node.op, Solve):
+    if isinstance(node.op.core_op, Solve) and node.op.core_op.b_ndim == 2:
         A, b = node.inputs  # result is solution Ax=b
         if getattr(A.tag, "psd", None) is True:
             L = cholesky(A)
             # N.B. this can be further reduced to a yet-unwritten cho_solve Op
-            # __if__ no other Op makes use of the the L matrix during the
+            # __if__ no other Op makes use of the L matrix during the
             # stabilization
-            Li_b = Solve(assume_a="sym", lower=True)(L, b)
-            x = Solve(assume_a="sym", lower=False)(L.T, Li_b)
+            Li_b = solve(L, b, assume_a="sym", lower=True, b_ndim=2)
+            x = solve(_T(L), Li_b, assume_a="sym", lower=False, b_ndim=2)
             return [x]


 @register_canonicalize
 @register_stabilize
-@node_rewriter([Cholesky])
+@node_rewriter([Blockwise])
 def cholesky_ldotlt(fgraph, node):
     """
     rewrite cholesky(dot(L, L.T), lower=True) = L, where L is lower triangular,
     or cholesky(dot(U.T, U), upper=True) = U where U is upper triangular.

     This utilizes a boolean `lower_triangular` or `upper_triangular` tag on matrices.
     """
-    if not isinstance(node.op, Cholesky):
+    if not isinstance(node.op.core_op, Cholesky):
         return

     A = node.inputs[0]
Expand All @@ -128,45 +166,40 @@ def cholesky_ldotlt(fgraph, node):
# cholesky(dot(L,L.T)) case
if (
getattr(l.tag, "lower_triangular", False)
and r.owner
and isinstance(r.owner.op, DimShuffle)
and r.owner.op.new_order == (1, 0)
and is_matrix_transpose(r)
and r.owner.inputs[0] == l
):
if node.op.lower:
if node.op.core_op.lower:
return [l]
return [r]

# cholesky(dot(U.T,U)) case
if (
getattr(r.tag, "upper_triangular", False)
and l.owner
and isinstance(l.owner.op, DimShuffle)
and l.owner.op.new_order == (1, 0)
and is_matrix_transpose(l)
and l.owner.inputs[0] == r
):
if node.op.lower:
if node.op.core_op.lower:
return [l]
return [r]


@register_stabilize
@register_specialize
@node_rewriter([Det])
@node_rewriter([det])
def local_det_chol(fgraph, node):
"""
If we have det(X) and there is already an L=cholesky(X)
floating around, then we can use prod(diag(L)) to get the determinant.
"""
if isinstance(node.op, Det):
(x,) = node.inputs
for cl, xpos in fgraph.clients[x]:
if cl == "output":
continue
if isinstance(cl.op, Cholesky):
L = cl.outputs[0]
return [prod(at.extract_diag(L) ** 2)]
(x,) = node.inputs
for cl, xpos in fgraph.clients[x]:
if cl == "output":
continue
if isinstance(cl.op, Blockwise) and isinstance(cl.op.core_op, Cholesky):
L = cl.outputs[0]
return [prod(extract_diag(L) ** 2, axis=(-1, -2))]


@register_canonicalize
@@ -177,16 +210,15 @@ def local_log_prod_sqr(fgraph, node):
     """
     This utilizes a boolean `positive` tag on matrices.
     """
-    if node.op == log:
-        (x,) = node.inputs
-        if x.owner and isinstance(x.owner.op, Prod):
-            # we cannot always make this substitution because
-            # the prod might include negative terms
-            p = x.owner.inputs[0]
-
-            # p is the matrix we're reducing with prod
-            if getattr(p.tag, "positive", None) is True:
-                return [log(p).sum(axis=x.owner.op.axis)]
-
-            # TODO: have a reduction like prod and sum that simply
-            # returns the sign of the prod multiplication.
+    (x,) = node.inputs
+    if x.owner and isinstance(x.owner.op, Prod):
+        # we cannot always make this substitution because
+        # the prod might include negative terms
+        p = x.owner.inputs[0]
+
+        # p is the matrix we're reducing with prod
+        if getattr(p.tag, "positive", None) is True:
+            return [log(p).sum(axis=x.owner.op.axis)]
+
+        # TODO: have a reduction like prod and sum that simply
+        # returns the sign of the prod multiplication.
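As a rough end-to-end illustration of the updated rewrites (a sketch, not part of the commit; assumes the default compilation mode runs the stabilize pass that registers `inv_as_solve`):

    import pytensor
    import pytensor.tensor as pt
    from pytensor.tensor.nlinalg import matrix_inverse

    A = pt.matrix("A")
    b = pt.matrix("b")
    y = matrix_inverse(A) @ b    # graph contains Blockwise(MatrixInverse) and Dot

    # After compilation, inv_as_solve should have replaced the explicit
    # inverse with a Solve of the original system.
    f = pytensor.function([A, b], y)
    pytensor.dprint(f)           # inspect the rewritten graph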