From c88f74c599f44c98e66a48997beaf742e08b7318 Mon Sep 17 00:00:00 2001
From: Ricardo Vieira
Date: Wed, 21 Jun 2023 12:43:23 +0200
Subject: [PATCH] Blockwise some linalg Ops by default

---
 pytensor/tensor/basic.py              |   2 +-
 pytensor/tensor/nlinalg.py            |  20 ++-
 pytensor/tensor/rewriting/linalg.py   | 176 +++++++++++++++-----------
 pytensor/tensor/slinalg.py            |  89 ++++++++++---
 tests/link/numba/test_nlinalg.py      |   4 +-
 tests/tensor/rewriting/test_linalg.py |  21 ++-
 tests/tensor/test_blockwise.py        |  10 +-
 tests/tensor/test_slinalg.py          |  32 ++---
 8 files changed, 226 insertions(+), 128 deletions(-)

diff --git a/pytensor/tensor/basic.py b/pytensor/tensor/basic.py
index 485a833542..2454de43fb 100644
--- a/pytensor/tensor/basic.py
+++ b/pytensor/tensor/basic.py
@@ -3735,7 +3735,7 @@ def stacklists(arg):
     return arg
 
 
-def swapaxes(y, axis1, axis2):
+def swapaxes(y, axis1: int, axis2: int) -> TensorVariable:
     "Swap the axes of a tensor."
     y = as_tensor_variable(y)
     ndim = y.ndim
diff --git a/pytensor/tensor/nlinalg.py b/pytensor/tensor/nlinalg.py
index 32fa47d28d..5e0fc2d580 100644
--- a/pytensor/tensor/nlinalg.py
+++ b/pytensor/tensor/nlinalg.py
@@ -10,11 +10,13 @@
 from pytensor.tensor import basic as at
 from pytensor.tensor import math as tm
 from pytensor.tensor.basic import as_tensor_variable, extract_diag
+from pytensor.tensor.blockwise import Blockwise
 from pytensor.tensor.type import dvector, lscalar, matrix, scalar, vector
 
 
 class MatrixPinv(Op):
     __props__ = ("hermitian",)
+    gufunc_signature = "(m,n)->(n,m)"
 
     def __init__(self, hermitian):
         self.hermitian = hermitian
@@ -75,7 +77,7 @@ def pinv(x, hermitian=False):
     solve op.
 
     """
-    return MatrixPinv(hermitian=hermitian)(x)
+    return Blockwise(MatrixPinv(hermitian=hermitian))(x)
 
 
 class MatrixInverse(Op):
@@ -93,6 +95,8 @@ class MatrixInverse(Op):
     """
 
     __props__ = ()
+    gufunc_signature = "(m,m)->(m,m)"
+    gufunc_spec = ("numpy.linalg.inv", 1, 1)
 
     def __init__(self):
         pass
@@ -150,7 +154,7 @@ def infer_shape(self, fgraph, node, shapes):
         return shapes
 
 
-inv = matrix_inverse = MatrixInverse()
+inv = matrix_inverse = Blockwise(MatrixInverse())
 
 
 def matrix_dot(*args):
@@ -181,6 +185,8 @@ class Det(Op):
     """
 
     __props__ = ()
+    gufunc_signature = "(m,m)->()"
+    gufunc_spec = ("numpy.linalg.det", 1, 1)
 
     def make_node(self, x):
         x = as_tensor_variable(x)
@@ -209,7 +215,7 @@ def __str__(self):
         return "Det"
 
 
-det = Det()
+det = Blockwise(Det())
 
 
 class SLogDet(Op):
@@ -216,8 +222,10 @@ class SLogDet(Op):
     """
     Compute the log determinant and its sign of the matrix.
     """
 
     __props__ = ()
+    gufunc_signature = "(m, m)->(),()"
+    gufunc_spec = ("numpy.linalg.slogdet", 1, 2)
 
     def make_node(self, x):
         x = as_tensor_variable(x)
@@ -242,7 +250,7 @@ def __str__(self):
         return "SLogDet"
 
 
-slogdet = SLogDet()
+slogdet = Blockwise(SLogDet())
 
 
 class Eig(Op):
@@ -249,8 +257,10 @@ class Eig(Op):
     """
     Compute the eigenvalues and right eigenvectors of a square array.
     """
 
     __props__: Tuple[str, ...] = ()
+    gufunc_signature = "(m,m)->(m),(m,m)"
+    gufunc_spec = ("numpy.linalg.eig", 1, 2)
 
     def make_node(self, x):
         x = as_tensor_variable(x)
@@ -270,7 +280,7 @@ def infer_shape(self, fgraph, node, shapes):
         return [(n,), (n, n)]
 
 
-eig = Eig()
+eig = Blockwise(Eig())
 
 
 class Eigh(Eig):
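
With `gufunc_signature`/`gufunc_spec` attached and the module-level callables wrapped in `Blockwise`, these Ops now map over leading batch dimensions the way NumPy gufuncs do. A minimal sketch of the intended behavior (illustrative shapes; assumes a PyTensor build containing this patch):

    import numpy as np
    import pytensor
    import pytensor.tensor as pt
    from pytensor.tensor.nlinalg import det

    x = pt.tensor3("x")   # a (batch, m, m) stack of square matrices
    d = det(x)            # det is now Blockwise(Det()): one determinant per matrix
    f = pytensor.function([x], d)
    batch = np.stack([np.eye(3), 2 * np.eye(3)]).astype(pytensor.config.floatX)
    print(f(batch))       # -> [1. 8.]
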
diff --git a/pytensor/tensor/rewriting/linalg.py b/pytensor/tensor/rewriting/linalg.py
index 103feb07b6..ae020f27af 100644
--- a/pytensor/tensor/rewriting/linalg.py
+++ b/pytensor/tensor/rewriting/linalg.py
@@ -1,11 +1,13 @@
 import logging
+from typing import cast
 
 from pytensor.graph.rewriting.basic import node_rewriter
-from pytensor.tensor import basic as at
+from pytensor.tensor.basic import TensorVariable, extract_diag, swapaxes
 from pytensor.tensor.blas import Dot22
+from pytensor.tensor.blockwise import Blockwise
 from pytensor.tensor.elemwise import DimShuffle
 from pytensor.tensor.math import Dot, Prod, log, prod
-from pytensor.tensor.nlinalg import Det, MatrixInverse
+from pytensor.tensor.nlinalg import MatrixInverse, det
 from pytensor.tensor.rewriting.basic import (
     register_canonicalize,
     register_specialize,
@@ -17,16 +19,40 @@
 logger = logging.getLogger(__name__)
 
 
+def is_matrix_transpose(x: TensorVariable) -> bool:
+    """Check if a variable corresponds to a transpose of the last two axes"""
+    node = x.owner
+    if (
+        node
+        and isinstance(node.op, DimShuffle)
+        and not (node.op.drop or node.op.augment)
+    ):
+        [inp] = node.inputs
+        ndims = inp.type.ndim
+        if ndims < 2:
+            return False
+        transpose_order = tuple(range(ndims - 2)) + (ndims - 1, ndims - 2)
+        return cast(bool, node.op.new_order == transpose_order)
+    return False
+
+
+def _T(x: TensorVariable) -> TensorVariable:
+    """Matrix transpose for potentially higher dimensionality tensors"""
+    return swapaxes(x, -1, -2)
+
+
 @register_canonicalize
 @node_rewriter([DimShuffle])
 def transinv_to_invtrans(fgraph, node):
-    if isinstance(node.op, DimShuffle):
-        if node.op.new_order == (1, 0):
-            (A,) = node.inputs
-            if A.owner:
-                if isinstance(A.owner.op, MatrixInverse):
-                    (X,) = A.owner.inputs
-                    return [A.owner.op(node.op(X))]
+    if is_matrix_transpose(node.outputs[0]):
+        (A,) = node.inputs
+        if (
+            A.owner
+            and isinstance(A.owner.op, Blockwise)
+            and isinstance(A.owner.op.core_op, MatrixInverse)
+        ):
+            (X,) = A.owner.inputs
+            return [A.owner.op(node.op(X))]
 
 
 @register_stabilize
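
The two helpers above generalize the hard-coded `new_order == (1, 0)` checks used previously: a transpose of the last two axes of an N-d tensor is the DimShuffle `(0, ..., N-3, N-1, N-2)`. A quick sketch of the pattern being matched (hypothetical snippet, not part of the patch):

    import pytensor.tensor as pt

    x = pt.tensor3("x")
    y = pt.swapaxes(x, -1, -2)     # what the `_T` helper builds
    print(y.owner.op.new_order)    # (0, 2, 1): batch axis kept, core axes swapped
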
""" - if isinstance(node.op, Solve): - A, b = node.inputs # result is solution Ax=b - if A.owner and isinstance(A.owner.op, Cholesky): - if A.owner.op.lower: - return [SolveTriangular(lower=True)(A, b)] - else: - return [SolveTriangular(lower=False)(A, b)] - if ( - A.owner - and isinstance(A.owner.op, DimShuffle) - and A.owner.op.new_order == (1, 0) - ): - (A_T,) = A.owner.inputs - if A_T.owner and isinstance(A_T.owner.op, Cholesky): - if A_T.owner.op.lower: - return [SolveTriangular(lower=False)(A, b)] - else: + if isinstance(node.op.core_op, Solve) and node.op.core_op.b_ndim == 1: + if node.op.core_op.assume_a == "gen": + A, b = node.inputs # result is solution Ax=b + if ( + A.owner + and isinstance(A.owner.op, Blockwise) + and isinstance(A.owner.op.core_op, Cholesky) + ): + if A.owner.op.core_op.lower: return [SolveTriangular(lower=True)(A, b)] + else: + return [SolveTriangular(lower=False)(A, b)] + if is_matrix_transpose(A): + (A_T,) = A.owner.inputs + if ( + A_T.owner + and isinstance(A_T.owner.op, Blockwise) + and isinstance(A_T.owner.op, Cholesky) + ): + if A_T.owner.op.lower: + return [SolveTriangular(lower=False)(A, b)] + else: + return [SolveTriangular(lower=True)(A, b)] @register_canonicalize @@ -81,34 +120,33 @@ def generic_solve_to_solve_triangular(fgraph, node): @register_specialize @node_rewriter([DimShuffle]) def no_transpose_symmetric(fgraph, node): - if isinstance(node.op, DimShuffle): + if is_matrix_transpose(node.outputs[0]): x = node.inputs[0] - if x.type.ndim == 2 and getattr(x.tag, "symmetric", None) is True: - if node.op.new_order == [1, 0]: - return [x] + if getattr(x.tag, "symmetric", None): + return [x] @register_stabilize -@node_rewriter([Solve]) +@node_rewriter([Blockwise]) def psd_solve_with_chol(fgraph, node): """ This utilizes a boolean `psd` tag on matrices. """ - if isinstance(node.op, Solve): + if isinstance(node.op.core_op, Solve) and node.op.core_op.b_ndim == 2: A, b = node.inputs # result is solution Ax=b if getattr(A.tag, "psd", None) is True: L = cholesky(A) # N.B. this can be further reduced to a yet-unwritten cho_solve Op - # __if__ no other Op makes use of the the L matrix during the + # __if__ no other Op makes use of the L matrix during the # stabilization - Li_b = Solve(assume_a="sym", lower=True)(L, b) - x = Solve(assume_a="sym", lower=False)(L.T, Li_b) + Li_b = solve(L, b, assume_a="sym", lower=True, b_ndim=2) + x = solve(_T(L), Li_b, assume_a="sym", lower=False, b_ndim=2) return [x] @register_canonicalize @register_stabilize -@node_rewriter([Cholesky]) +@node_rewriter([Blockwise]) def cholesky_ldotlt(fgraph, node): """ rewrite cholesky(dot(L, L.T), lower=True) = L, where L is lower triangular, @@ -116,7 +154,7 @@ def cholesky_ldotlt(fgraph, node): This utilizes a boolean `lower_triangular` or `upper_triangular` tag on matrices. 
""" - if not isinstance(node.op, Cholesky): + if not isinstance(node.op.core_op, Cholesky): return A = node.inputs[0] @@ -128,45 +166,40 @@ def cholesky_ldotlt(fgraph, node): # cholesky(dot(L,L.T)) case if ( getattr(l.tag, "lower_triangular", False) - and r.owner - and isinstance(r.owner.op, DimShuffle) - and r.owner.op.new_order == (1, 0) + and is_matrix_transpose(r) and r.owner.inputs[0] == l ): - if node.op.lower: + if node.op.core_op.lower: return [l] return [r] # cholesky(dot(U.T,U)) case if ( getattr(r.tag, "upper_triangular", False) - and l.owner - and isinstance(l.owner.op, DimShuffle) - and l.owner.op.new_order == (1, 0) + and is_matrix_transpose(l) and l.owner.inputs[0] == r ): - if node.op.lower: + if node.op.core_op.lower: return [l] return [r] @register_stabilize @register_specialize -@node_rewriter([Det]) +@node_rewriter([det]) def local_det_chol(fgraph, node): """ If we have det(X) and there is already an L=cholesky(X) floating around, then we can use prod(diag(L)) to get the determinant. """ - if isinstance(node.op, Det): - (x,) = node.inputs - for cl, xpos in fgraph.clients[x]: - if cl == "output": - continue - if isinstance(cl.op, Cholesky): - L = cl.outputs[0] - return [prod(at.extract_diag(L) ** 2)] + (x,) = node.inputs + for cl, xpos in fgraph.clients[x]: + if cl == "output": + continue + if isinstance(cl.op, Blockwise) and isinstance(cl.op.core_op, Cholesky): + L = cl.outputs[0] + return [prod(extract_diag(L) ** 2, axis=(-1, -2))] @register_canonicalize @@ -177,16 +210,15 @@ def local_log_prod_sqr(fgraph, node): """ This utilizes a boolean `positive` tag on matrices. """ - if node.op == log: - (x,) = node.inputs - if x.owner and isinstance(x.owner.op, Prod): - # we cannot always make this substitution because - # the prod might include negative terms - p = x.owner.inputs[0] - - # p is the matrix we're reducing with prod - if getattr(p.tag, "positive", None) is True: - return [log(p).sum(axis=x.owner.op.axis)] - - # TODO: have a reduction like prod and sum that simply - # returns the sign of the prod multiplication. + (x,) = node.inputs + if x.owner and isinstance(x.owner.op, Prod): + # we cannot always make this substitution because + # the prod might include negative terms + p = x.owner.inputs[0] + + # p is the matrix we're reducing with prod + if getattr(p.tag, "positive", None) is True: + return [log(p).sum(axis=x.owner.op.axis)] + + # TODO: have a reduction like prod and sum that simply + # returns the sign of the prod multiplication. 
diff --git a/pytensor/tensor/slinalg.py b/pytensor/tensor/slinalg.py
index 1d14689768..4b0310c4ce 100644
--- a/pytensor/tensor/slinalg.py
+++ b/pytensor/tensor/slinalg.py
@@ -1,6 +1,6 @@
 import logging
 import warnings
-from typing import TYPE_CHECKING, Literal, Union
+from typing import TYPE_CHECKING, Literal, Optional, Union
 
 import numpy as np
 import scipy.linalg
@@ -12,6 +12,7 @@
 from pytensor.tensor import as_tensor_variable
 from pytensor.tensor import basic as at
 from pytensor.tensor import math as atm
+from pytensor.tensor.blockwise import Blockwise
 from pytensor.tensor.shape import reshape
 from pytensor.tensor.type import matrix, tensor, vector
 from pytensor.tensor.var import TensorVariable
@@ -46,6 +47,7 @@ class Cholesky(Op):
     # TODO: LAPACK wrapper with in-place behavior, for solve also
 
     __props__ = ("lower", "destructive", "on_error")
+    gufunc_signature = "(m,m)->(m,m)"
 
     def __init__(self, *, lower=True, on_error="raise"):
         self.lower = lower
@@ -107,7 +109,7 @@ def tril_and_halve_diagonal(mtx):
 
         def conjugate_solve_triangular(outer, inner):
             """Computes L^{-T} P L^{-1} for lower-triangular L."""
-            solve_upper = SolveTriangular(lower=False)
+            solve_upper = SolveTriangular(lower=False, b_ndim=2)
             return solve_upper(outer.T, solve_upper(outer.T, inner.T).T)
 
         s = conjugate_solve_triangular(
@@ -126,7 +128,7 @@ def conjugate_solve_triangular(outer, inner):
 
 
 def cholesky(x, lower=True, on_error="raise"):
-    return Cholesky(lower=lower, on_error=on_error)(x)
+    return Blockwise(Cholesky(lower=lower, on_error=on_error))(x)
 
 
 class SolveBase(Op):
@@ -135,6 +137,7 @@ class SolveBase(Op):
     __props__ = (
         "lower",
         "check_finite",
+        "b_ndim",
     )
 
     def __init__(
@@ -142,9 +145,16 @@ def __init__(
         *,
         lower=False,
         check_finite=True,
+        b_ndim,
     ):
         self.lower = lower
         self.check_finite = check_finite
+        assert b_ndim in (1, 2)
+        self.b_ndim = b_ndim
+        if b_ndim == 1:
+            self.gufunc_signature = "(m,m),(m)->(m)"
+        else:
+            self.gufunc_signature = "(m,m),(m,n)->(m,n)"
 
     def perform(self, node, inputs, outputs):
         pass
@@ -155,8 +165,8 @@ def make_node(self, A, b):
 
         if A.ndim != 2:
             raise ValueError(f"`A` must be a matrix; got {A.type} instead.")
-        if b.ndim not in (1, 2):
-            raise ValueError(f"`b` must be a matrix or a vector; got {b.type} instead.")
+        if b.ndim != self.b_ndim:
+            raise ValueError(f"`b` must have {self.b_ndim} dims; got {b.type} instead.")
 
         # Infer dtype by solving the most simple case with 1x1 matrices
         o_dtype = scipy.linalg.solve(
@@ -207,6 +217,16 @@ def L_op(self, inputs, outputs, output_gradients):
         return [A_bar, b_bar]
 
 
+def _default_b_ndim(b, b_ndim):
+    if b_ndim is not None:
+        assert b_ndim in (1, 2)
+        return b_ndim
+
+    b = as_tensor_variable(b)
+    if b_ndim is None:
+        return min(b.ndim, 2)  # By default assume the core case is a matrix
+
+
 class CholeskySolve(SolveBase):
     def __init__(self, **kwargs):
         kwargs.setdefault("lower", True)
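
`b_ndim` is what lets `Blockwise` split `b` into batch and core dimensions: the same `(batch, m)` array is a batch of vectors under `b_ndim=1` but a single matrix under the default `min(b.ndim, 2)`. A sketch of the difference (illustrative shapes; assumes the `solve` signature introduced later in this patch):

    import pytensor.tensor as pt
    from pytensor.tensor.slinalg import solve

    A = pt.tensor3("A")               # (batch, m, m)
    b_vecs = pt.matrix("b_vecs")      # (batch, m): one vector per A
    x1 = solve(A, b_vecs, b_ndim=1)   # batch of vector solves -> (batch, m)

    b_mat = pt.matrix("b_mat")        # (m, n): a single right-hand-side matrix
    x2 = solve(A, b_mat)              # default b_ndim=min(2, 2)=2: b_mat is
                                      # broadcast against every A -> (batch, m, n)
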
@@ -226,19 +246,26 @@ def L_op(self, *args, **kwargs):
         raise NotImplementedError()
 
 
-def cho_solve(c_and_lower, b, *, check_finite=True):
+def cho_solve(c_and_lower, b, *, check_finite=True, b_ndim: Optional[int] = None):
     """Solve the linear equations A x = b, given the Cholesky factorization of A.
 
     Parameters
     ----------
     (c, lower) : tuple, (array, bool)
         Cholesky factorization of a, as given by cho_factor
     b : array
         Right-hand side
     check_finite : bool, optional
         Whether to check that the input matrices contain only finite numbers.
         Disabling may give a performance gain, but may result in problems
         (crashes, non-termination) if the inputs do contain infinities or NaNs.
+    b_ndim : int
+        Whether the core case of b is a vector (1) or matrix (2).
+        This will influence how batched dimensions are interpreted.
     """
     A, lower = c_and_lower
-    return CholeskySolve(lower=lower, check_finite=check_finite)(A, b)
+    b_ndim = _default_b_ndim(b, b_ndim)
+    return Blockwise(
+        CholeskySolve(lower=lower, check_finite=check_finite, b_ndim=b_ndim)
+    )(A, b)
 
 
 class SolveTriangular(SolveBase):
@@ -252,6 +278,7 @@ class SolveTriangular(SolveBase):
         "unit_diagonal",
         "lower",
         "check_finite",
+        "b_ndim",
     )
 
     def __init__(self, *, trans=0, unit_diagonal=False, **kwargs):
@@ -289,6 +316,7 @@ def solve_triangular(
     lower: bool = False,
     unit_diagonal: bool = False,
     check_finite: bool = True,
+    b_ndim: Optional[int] = None,
 ) -> TensorVariable:
     """Solve the equation `a x = b` for `x`, assuming `a` is a triangular matrix.
 
@@ -312,12 +340,19 @@ def solve_triangular(
         Whether to check that the input matrices contain only finite numbers.
         Disabling may give a performance gain, but may result in problems
         (crashes, non-termination) if the inputs do contain infinities or NaNs.
+    b_ndim : int
+        Whether the core case of b is a vector (1) or matrix (2).
+        This will influence how batched dimensions are interpreted.
     """
-    return SolveTriangular(
-        lower=lower,
-        trans=trans,
-        unit_diagonal=unit_diagonal,
-        check_finite=check_finite,
+    b_ndim = _default_b_ndim(b, b_ndim)
+    return Blockwise(
+        SolveTriangular(
+            lower=lower,
+            trans=trans,
+            unit_diagonal=unit_diagonal,
+            check_finite=check_finite,
+            b_ndim=b_ndim,
+        )
     )(a, b)
 
 
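
A usage sketch for the batched triangular solve this enables (illustrative shapes; assumes this patch):

    import pytensor.tensor as pt
    from pytensor.tensor.slinalg import solve_triangular

    L = pt.tensor3("L")                               # (batch, m, m), lower triangular
    b = pt.matrix("b")                                # (batch, m)
    x = solve_triangular(L, b, lower=True, b_ndim=1)  # Blockwise maps over the batch
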
""" - return Solve( - lower=lower, - check_finite=check_finite, - assume_a=assume_a, + b_ndim = _default_b_ndim(b, b_ndim) + return Blockwise( + Solve( + lower=lower, + check_finite=check_finite, + assume_a=assume_a, + b_ndim=b_ndim, + ) )(a, b) diff --git a/tests/link/numba/test_nlinalg.py b/tests/link/numba/test_nlinalg.py index 51c1c4b648..857bd49152 100644 --- a/tests/link/numba/test_nlinalg.py +++ b/tests/link/numba/test_nlinalg.py @@ -91,7 +91,7 @@ def test_Cholesky(x, lower, exc): ], ) def test_Solve(A, x, lower, exc): - g = slinalg.Solve(lower=lower)(A, x) + g = slinalg.Solve(lower=lower, b_ndim=1)(A, x) if isinstance(g, list): g_fg = FunctionGraph(outputs=g) @@ -125,7 +125,7 @@ def test_Solve(A, x, lower, exc): ], ) def test_SolveTriangular(A, x, lower, exc): - g = slinalg.SolveTriangular(lower=lower)(A, x) + g = slinalg.SolveTriangular(lower=lower, b_ndim=1)(A, x) if isinstance(g, list): g_fg = FunctionGraph(outputs=g) diff --git a/tests/tensor/rewriting/test_linalg.py b/tests/tensor/rewriting/test_linalg.py index 3a09244a40..58ea98d626 100644 --- a/tests/tensor/rewriting/test_linalg.py +++ b/tests/tensor/rewriting/test_linalg.py @@ -9,11 +9,12 @@ from pytensor import tensor as at from pytensor.compile import get_default_mode from pytensor.configdefaults import config +from pytensor.tensor.blockwise import Blockwise from pytensor.tensor.elemwise import DimShuffle from pytensor.tensor.math import _allclose from pytensor.tensor.nlinalg import Det, MatrixInverse, matrix_inverse from pytensor.tensor.rewriting.linalg import inv_as_solve -from pytensor.tensor.slinalg import Cholesky, Solve, SolveTriangular, solve +from pytensor.tensor.slinalg import Cholesky, Solve, SolveTriangular, cholesky, solve from pytensor.tensor.type import dmatrix, matrix, vector from tests import unittest_tools as utt from tests.test_rop import break_op @@ -23,7 +24,7 @@ def test_rop_lop(): mx = matrix("mx") mv = matrix("mv") v = vector("v") - y = matrix_inverse(mx).sum(axis=0) + y = MatrixInverse()(mx).sum(axis=0) yv = pytensor.gradient.Rop(y, mx, mv) rop_f = function([mx, mv], yv) @@ -83,13 +84,11 @@ def test_transinv_to_invtrans(): def test_generic_solve_to_solve_triangular(): - cholesky_lower = Cholesky(lower=True) - cholesky_upper = Cholesky(lower=False) A = matrix("A") x = matrix("x") - L = cholesky_lower(A) - U = cholesky_upper(A) + L = cholesky(A, lower=True) + U = cholesky(A, lower=False) b1 = solve(L, x) b2 = solve(U, x) f = pytensor.function([A, x], b1) @@ -130,15 +129,15 @@ def test_matrix_inverse_solve(): b = dmatrix("b") node = matrix_inverse(A).dot(b).owner [out] = inv_as_solve.transform(None, node) - assert isinstance(out.owner.op, Solve) + assert isinstance(out.owner.op, Blockwise) and isinstance( + out.owner.op.core_op, Solve + ) @pytest.mark.parametrize("tag", ("lower", "upper", None)) @pytest.mark.parametrize("cholesky_form", ("lower", "upper")) @pytest.mark.parametrize("product", ("lower", "upper", None)) def test_cholesky_ldotlt(tag, cholesky_form, product): - cholesky = Cholesky(lower=(cholesky_form == "lower")) - transform_removes_chol = tag is not None and product == tag transform_transposes = transform_removes_chol and cholesky_form != tag @@ -153,11 +152,9 @@ def test_cholesky_ldotlt(tag, cholesky_form, product): else: M = A - C = cholesky(M) + C = cholesky(M, lower=(cholesky_form == "lower")) f = pytensor.function([A], C, mode=get_default_mode().including("cholesky_ldotlt")) - print(f.maker.fgraph.apply_nodes) - no_cholesky_in_graph = not any( isinstance(node.op, Cholesky) for node 
diff --git a/tests/tensor/rewriting/test_linalg.py b/tests/tensor/rewriting/test_linalg.py
index 3a09244a40..58ea98d626 100644
--- a/tests/tensor/rewriting/test_linalg.py
+++ b/tests/tensor/rewriting/test_linalg.py
@@ -9,11 +9,12 @@
 from pytensor import tensor as at
 from pytensor.compile import get_default_mode
 from pytensor.configdefaults import config
+from pytensor.tensor.blockwise import Blockwise
 from pytensor.tensor.elemwise import DimShuffle
 from pytensor.tensor.math import _allclose
 from pytensor.tensor.nlinalg import Det, MatrixInverse, matrix_inverse
 from pytensor.tensor.rewriting.linalg import inv_as_solve
-from pytensor.tensor.slinalg import Cholesky, Solve, SolveTriangular, solve
+from pytensor.tensor.slinalg import Cholesky, Solve, SolveTriangular, cholesky, solve
 from pytensor.tensor.type import dmatrix, matrix, vector
 from tests import unittest_tools as utt
 from tests.test_rop import break_op
@@ -23,7 +24,7 @@ def test_rop_lop():
     mx = matrix("mx")
     mv = matrix("mv")
     v = vector("v")
-    y = matrix_inverse(mx).sum(axis=0)
+    y = MatrixInverse()(mx).sum(axis=0)
 
     yv = pytensor.gradient.Rop(y, mx, mv)
     rop_f = function([mx, mv], yv)
@@ -83,13 +84,11 @@ def test_transinv_to_invtrans():
 
 
 def test_generic_solve_to_solve_triangular():
-    cholesky_lower = Cholesky(lower=True)
-    cholesky_upper = Cholesky(lower=False)
     A = matrix("A")
     x = matrix("x")
 
-    L = cholesky_lower(A)
-    U = cholesky_upper(A)
+    L = cholesky(A, lower=True)
+    U = cholesky(A, lower=False)
     b1 = solve(L, x)
     b2 = solve(U, x)
     f = pytensor.function([A, x], b1)
@@ -130,15 +129,15 @@ def test_matrix_inverse_solve():
     b = dmatrix("b")
     node = matrix_inverse(A).dot(b).owner
     [out] = inv_as_solve.transform(None, node)
-    assert isinstance(out.owner.op, Solve)
+    assert isinstance(out.owner.op, Blockwise) and isinstance(
+        out.owner.op.core_op, Solve
+    )
 
 
 @pytest.mark.parametrize("tag", ("lower", "upper", None))
 @pytest.mark.parametrize("cholesky_form", ("lower", "upper"))
 @pytest.mark.parametrize("product", ("lower", "upper", None))
 def test_cholesky_ldotlt(tag, cholesky_form, product):
-    cholesky = Cholesky(lower=(cholesky_form == "lower"))
-
     transform_removes_chol = tag is not None and product == tag
     transform_transposes = transform_removes_chol and cholesky_form != tag
 
@@ -153,11 +152,9 @@ def test_cholesky_ldotlt(tag, cholesky_form, product):
     else:
         M = A
 
-    C = cholesky(M)
+    C = cholesky(M, lower=(cholesky_form == "lower"))
     f = pytensor.function([A], C, mode=get_default_mode().including("cholesky_ldotlt"))
 
-    print(f.maker.fgraph.apply_nodes)
-
     no_cholesky_in_graph = not any(
         isinstance(node.op, Cholesky) for node in f.maker.fgraph.apply_nodes
     )
diff --git a/tests/tensor/test_blockwise.py b/tests/tensor/test_blockwise.py
index 437e3bbc22..658c527430 100644
--- a/tests/tensor/test_blockwise.py
+++ b/tests/tensor/test_blockwise.py
@@ -24,6 +24,7 @@ def test_vectorize_blockwise():
     assert isinstance(vect_node.op, Blockwise) and isinstance(
         vect_node.op.core_op, MatrixInverse
     )
+    assert vect_node.op.signature == "(m,m)->(m,m)"
     assert vect_node.inputs[0] is tns
 
     # Useless blockwise
@@ -253,6 +254,11 @@ class TestMatrixInverse(MatrixOpBlockwiseTester):
     signature = "(m, m) -> (m, m)"
 
 
-class TestSolve(BlockwiseOpTester):
-    core_op = Solve(lower=True)
+class TestSolveVector(BlockwiseOpTester):
+    core_op = Solve(lower=True, b_ndim=1)
     signature = "(m, m),(m) -> (m)"
+
+
+class TestSolveMatrix(BlockwiseOpTester):
+    core_op = Solve(lower=True, b_ndim=2)
+    signature = "(m, m),(m, n) -> (m, n)"
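
The signature strings in these testers are the same gufunc signatures the core Ops now advertise; a `Blockwise` built without an explicit signature appears to pick it up from the core op, as `test_vectorize_blockwise` above asserts. Sketch under that assumption:

    from pytensor.tensor.blockwise import Blockwise
    from pytensor.tensor.slinalg import Solve

    op = Blockwise(Solve(lower=True, b_ndim=2))
    print(op.signature)   # "(m,m),(m,n)->(m,n)", taken from the core op
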
diff --git a/tests/tensor/test_slinalg.py b/tests/tensor/test_slinalg.py
index 5d61a6a2af..849734f05c 100644
--- a/tests/tensor/test_slinalg.py
+++ b/tests/tensor/test_slinalg.py
@@ -180,7 +180,7 @@ class TestSolveBase(utt.InferShapeTester):
         (
             matrix,
             functools.partial(tensor, dtype="floatX", shape=(None,) * 3),
-            "`b` must be a matrix or a vector.*",
+            "`b` must have 2 dims.*",
         ),
     ],
 )
@@ -189,20 +189,20 @@ def test_make_node(self, A_func, b_func, error_message):
         with pytest.raises(ValueError, match=error_message):
             A = A_func()
             b = b_func()
-            SolveBase()(A, b)
+            SolveBase(b_ndim=2)(A, b)
 
     def test__repr__(self):
         np.random.default_rng(utt.fetch_seed())
         A = matrix()
         b = matrix()
-        y = SolveBase()(A, b)
-        assert y.__repr__() == "SolveBase{lower=False, check_finite=True}.0"
+        y = SolveBase(b_ndim=2)(A, b)
+        assert y.__repr__() == "SolveBase{lower=False, check_finite=True, b_ndim=2}.0"
 
 
 class TestSolve(utt.InferShapeTester):
     def test__init__(self):
         with pytest.raises(ValueError) as excinfo:
-            Solve(assume_a="test")
+            Solve(assume_a="test", b_ndim=2)
         assert "is not a recognized matrix structure" in str(excinfo.value)
 
     @pytest.mark.parametrize("b_shape", [(5, 1), (5,)])
@@ -277,7 +277,7 @@ def test_solve_grad(self, m, n, assume_a, lower):
         if config.floatX == "float64":
             eps = 2e-8
 
-        solve_op = Solve(assume_a=assume_a, lower=lower)
+        solve_op = Solve(assume_a=assume_a, lower=lower, b_ndim=1 if n is None else 2)
         utt.verify_grad(solve_op, [A_val, b_val], 3, rng, eps=eps)
@@ -348,19 +348,20 @@ def test_solve_grad(self, m, n, lower):
         if config.floatX == "float64":
             eps = 2e-8
 
-        solve_op = SolveTriangular(lower=lower)
+        solve_op = SolveTriangular(lower=lower, b_ndim=1 if n is None else 2)
         utt.verify_grad(solve_op, [A_val, b_val], 3, rng, eps=eps)
 
 
 class TestCholeskySolve(utt.InferShapeTester):
     def setup_method(self):
         self.op_class = CholeskySolve
-        self.op = CholeskySolve()
-        self.op_upper = CholeskySolve(lower=False)
         super().setup_method()
 
     def test_repr(self):
-        assert repr(CholeskySolve()) == "CholeskySolve(lower=True,check_finite=True)"
+        assert (
+            repr(CholeskySolve(lower=True, b_ndim=1))
+            == "CholeskySolve(lower=True,check_finite=True,b_ndim=1)"
+        )
 
     def test_infer_shape(self):
         rng = np.random.default_rng(utt.fetch_seed())
@@ -367,8 +368,8 @@ def test_infer_shape(self):
         A = matrix()
         b = matrix()
         self._compile_and_check(
             [A, b],  # pytensor.function inputs
-            [self.op(A, b)],  # pytensor.function outputs
+            [self.op_class(b_ndim=2)(A, b)],  # pytensor.function outputs
             # A must be square
             [
                 np.asarray(rng.random((5, 5)), dtype=config.floatX),
@@ -382,7 +383,7 @@ def test_infer_shape(self):
         b = vector()
         self._compile_and_check(
             [A, b],  # pytensor.function inputs
-            [self.op(A, b)],  # pytensor.function outputs
+            [self.op_class(b_ndim=1)(A, b)],  # pytensor.function outputs
             # A must be square
             [
                 np.asarray(rng.random((5, 5)), dtype=config.floatX),
@@ -396,10 +397,10 @@ def test_solve_correctness(self):
         rng = np.random.default_rng(utt.fetch_seed())
         A = matrix()
         b = matrix()
-        y = self.op(A, b)
+        y = self.op_class(lower=True, b_ndim=2)(A, b)
         cho_solve_lower_func = pytensor.function([A, b], y)
 
-        y = self.op_upper(A, b)
+        y = self.op_class(lower=False, b_ndim=2)(A, b)
         cho_solve_upper_func = pytensor.function([A, b], y)
 
         b_val = np.asarray(rng.random((5, 1)), dtype=config.floatX)
@@ -434,12 +435,13 @@ def test_solve_dtype(self):
         A_val = np.eye(2)
         b_val = np.ones((2, 1))
 
+        op = self.op_class(b_ndim=2)
         # try all dtype combinations
         for A_dtype, b_dtype in itertools.product(dtypes, dtypes):
            A = matrix(dtype=A_dtype)
            b = matrix(dtype=b_dtype)
-           x = self.op(A, b)
+           x = op(A, b)
            fn = function([A, b], x)
 
            x_result = fn(A_val.astype(A_dtype), b_val.astype(b_dtype))
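
An end-to-end sketch of the batched behavior these tests pin down (illustrative shapes; assumes a PyTensor build containing this patch):

    import numpy as np
    import pytensor
    import pytensor.tensor as pt
    from pytensor.tensor.slinalg import solve

    A = pt.tensor3("A")            # (batch, 5, 5)
    b = pt.matrix("b")             # (batch, 5): one rhs vector per matrix
    x = solve(A, b, b_ndim=1)
    f = pytensor.function([A, b], x)

    A_val = np.stack([np.eye(5)] * 10).astype(pytensor.config.floatX)
    b_val = np.ones((10, 5), dtype=pytensor.config.floatX)
    assert f(A_val, b_val).shape == (10, 5)
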