From c88b70b7f45dbbb2642ce851eed9ab32d9eacae8 Mon Sep 17 00:00:00 2001 From: Ricardo Vieira Date: Wed, 21 Jun 2023 11:46:08 +0200 Subject: [PATCH 1/7] Change `py_only` database placement --- pytensor/compile/mode.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/pytensor/compile/mode.py b/pytensor/compile/mode.py index 514f8f48c4..dd81a81ded 100644 --- a/pytensor/compile/mode.py +++ b/pytensor/compile/mode.py @@ -251,13 +251,14 @@ def apply(self, fgraph): # especially constant merge optdb.register("merge2", MergeOptimizer(), "fast_run", "merge", position=49) +optdb.register("py_only", EquilibriumDB(), "fast_compile", position=49.1) + optdb.register( "add_destroy_handler", AddDestroyHandler(), "fast_run", "inplace", position=49.5 ) # final pass just to make sure optdb.register("merge3", MergeOptimizer(), "fast_run", "merge", position=100) -optdb.register("py_only", EquilibriumDB(), "fast_compile", position=100) _tags: Union[Tuple[str, str], Tuple] From 2ccd058ba1c62688701dd10d338b02125065ceae Mon Sep 17 00:00:00 2001 From: Ricardo Vieira Date: Wed, 21 Jun 2023 14:49:44 +0200 Subject: [PATCH 2/7] Remove _numop attribute from linalg Ops --- pytensor/tensor/nlinalg.py | 25 +++++++++---------------- scripts/mypy-failing.txt | 3 +-- 2 files changed, 10 insertions(+), 18 deletions(-) diff --git a/pytensor/tensor/nlinalg.py b/pytensor/tensor/nlinalg.py index 3f460e9303..709507afdd 100644 --- a/pytensor/tensor/nlinalg.py +++ b/pytensor/tensor/nlinalg.py @@ -1,6 +1,5 @@ -import typing from functools import partial -from typing import Callable, Tuple +from typing import Tuple import numpy as np @@ -271,7 +270,6 @@ class Eig(Op): """ - _numop = staticmethod(np.linalg.eig) __props__: Tuple[str, ...] 
= () def make_node(self, x): @@ -284,7 +282,7 @@ def make_node(self, x): def perform(self, node, inputs, outputs): (x,) = inputs (w, v) = outputs - w[0], v[0] = (z.astype(x.dtype) for z in self._numop(x)) + w[0], v[0] = (z.astype(x.dtype) for z in np.linalg.eig(x)) def infer_shape(self, fgraph, node, shapes): n = shapes[0][0] @@ -300,7 +298,6 @@ class Eigh(Eig): """ - _numop = typing.cast(Callable, staticmethod(np.linalg.eigh)) __props__ = ("UPLO",) def __init__(self, UPLO="L"): @@ -315,7 +312,7 @@ def make_node(self, x): # LAPACK. Rather than trying to reproduce the (rather # involved) logic, we just probe linalg.eigh with a trivial # input. - w_dtype = self._numop([[np.dtype(x.dtype).type()]])[0].dtype.name + w_dtype = np.linalg.eigh([[np.dtype(x.dtype).type()]])[0].dtype.name w = vector(dtype=w_dtype) v = matrix(dtype=w_dtype) return Apply(self, [x], [w, v]) @@ -323,7 +320,7 @@ def make_node(self, x): def perform(self, node, inputs, outputs): (x,) = inputs (w, v) = outputs - w[0], v[0] = self._numop(x, self.UPLO) + w[0], v[0] = np.linalg.eigh(x, self.UPLO) def grad(self, inputs, g_outputs): r"""The gradient function should return @@ -446,7 +443,6 @@ class QRFull(Op): """ - _numop = staticmethod(np.linalg.qr) __props__ = ("mode",) def __init__(self, mode): @@ -478,7 +474,7 @@ def make_node(self, x): def perform(self, node, inputs, outputs): (x,) = inputs assert x.ndim == 2, "The input of qr function should be a matrix." - res = self._numop(x, self.mode) + res = np.linalg.qr(x, self.mode) if self.mode != "r": outputs[0][0], outputs[1][0] = res else: @@ -547,7 +543,6 @@ class SVD(Op): """ # See doc in the docstring of the function just after this class. - _numop = staticmethod(np.linalg.svd) __props__ = ("full_matrices", "compute_uv") def __init__(self, full_matrices=True, compute_uv=True): @@ -575,10 +570,10 @@ def perform(self, node, inputs, outputs): assert x.ndim == 2, "The input of svd function should be a matrix." 
if self.compute_uv: u, s, vt = outputs - u[0], s[0], vt[0] = self._numop(x, self.full_matrices, self.compute_uv) + u[0], s[0], vt[0] = np.linalg.svd(x, self.full_matrices, self.compute_uv) else: (s,) = outputs - s[0] = self._numop(x, self.full_matrices, self.compute_uv) + s[0] = np.linalg.svd(x, self.full_matrices, self.compute_uv) def infer_shape(self, fgraph, node, shapes): (x_shape,) = shapes @@ -730,7 +725,6 @@ class TensorInv(Op): PyTensor utilization of numpy.linalg.tensorinv; """ - _numop = staticmethod(np.linalg.tensorinv) __props__ = ("ind",) def __init__(self, ind=2): @@ -744,7 +738,7 @@ def make_node(self, a): def perform(self, node, inputs, outputs): (a,) = inputs (x,) = outputs - x[0] = self._numop(a, self.ind) + x[0] = np.linalg.tensorinv(a, self.ind) def infer_shape(self, fgraph, node, shapes): sp = shapes[0][self.ind :] + shapes[0][: self.ind] @@ -790,7 +784,6 @@ class TensorSolve(Op): """ - _numop = staticmethod(np.linalg.tensorsolve) __props__ = ("axes",) def __init__(self, axes=None): @@ -809,7 +802,7 @@ def perform(self, node, inputs, outputs): b, ) = inputs (x,) = outputs - x[0] = self._numop(a, b, self.axes) + x[0] = np.linalg.tensorsolve(a, b, self.axes) def tensorsolve(a, b, axes=None): diff --git a/scripts/mypy-failing.txt b/scripts/mypy-failing.txt index 61b7887c06..1cae4d9152 100644 --- a/scripts/mypy-failing.txt +++ b/scripts/mypy-failing.txt @@ -32,5 +32,4 @@ pytensor/tensor/slinalg.py pytensor/tensor/subtensor.py pytensor/tensor/type.py pytensor/tensor/type_other.py -pytensor/tensor/variable.py -pytensor/tensor/nlinalg.py \ No newline at end of file +pytensor/tensor/variable.py \ No newline at end of file From d5a988105874cf11da6e84f3d7503786845f079f Mon Sep 17 00:00:00 2001 From: Ricardo Vieira Date: Wed, 21 Jun 2023 14:50:35 +0200 Subject: [PATCH 3/7] Remove duplicated `Inv` Op --- pytensor/link/numba/dispatch/nlinalg.py | 13 ------------- pytensor/tensor/nlinalg.py | 21 +-------------------- tests/link/numba/test_nlinalg.py | 20 
-------------------- 3 files changed, 1 insertion(+), 53 deletions(-) diff --git a/pytensor/link/numba/dispatch/nlinalg.py b/pytensor/link/numba/dispatch/nlinalg.py index 63115e1926..860560d0a6 100644 --- a/pytensor/link/numba/dispatch/nlinalg.py +++ b/pytensor/link/numba/dispatch/nlinalg.py @@ -14,7 +14,6 @@ Det, Eig, Eigh, - Inv, MatrixInverse, MatrixPinv, QRFull, @@ -125,18 +124,6 @@ def eigh(x): return eigh -@numba_funcify.register(Inv) -def numba_funcify_Inv(op, node, **kwargs): - out_dtype = node.outputs[0].type.numpy_dtype - inputs_cast = int_to_float_fn(node.inputs, out_dtype) - - @numba_basic.numba_njit(inline="always") - def inv(x): - return np.linalg.inv(inputs_cast(x)).astype(out_dtype) - - return inv - - @numba_funcify.register(MatrixInverse) def numba_funcify_MatrixInverse(op, node, **kwargs): out_dtype = node.outputs[0].type.numpy_dtype diff --git a/pytensor/tensor/nlinalg.py b/pytensor/tensor/nlinalg.py index 709507afdd..32fa47d28d 100644 --- a/pytensor/tensor/nlinalg.py +++ b/pytensor/tensor/nlinalg.py @@ -78,25 +78,6 @@ def pinv(x, hermitian=False): return MatrixPinv(hermitian=hermitian)(x) -class Inv(Op): - """Computes the inverse of one or more matrices.""" - - def make_node(self, x): - x = as_tensor_variable(x) - return Apply(self, [x], [x.type()]) - - def perform(self, node, inputs, outputs): - (x,) = inputs - (z,) = outputs - z[0] = np.linalg.inv(x).astype(x.dtype) - - def infer_shape(self, fgraph, node, shapes): - return shapes - - -inv = Inv() - - class MatrixInverse(Op): r"""Computes the inverse of a matrix :math:`A`. 
@@ -169,7 +150,7 @@ def infer_shape(self, fgraph, node, shapes): return shapes -matrix_inverse = MatrixInverse() +inv = matrix_inverse = MatrixInverse() def matrix_dot(*args): diff --git a/tests/link/numba/test_nlinalg.py b/tests/link/numba/test_nlinalg.py index 7bc60d1313..223e14fed0 100644 --- a/tests/link/numba/test_nlinalg.py +++ b/tests/link/numba/test_nlinalg.py @@ -352,26 +352,6 @@ def test_Eigh(x, uplo, exc): None, (), ), - ( - nlinalg.Inv, - set_test_value( - at.dmatrix(), - (lambda x: x.T.dot(x))(rng.random(size=(3, 3)).astype("float64")), - ), - None, - (), - ), - ( - nlinalg.Inv, - set_test_value( - at.lmatrix(), - (lambda x: x.T.dot(x))( - rng.integers(1, 10, size=(3, 3)).astype("int64") - ), - ), - None, - (), - ), ( nlinalg.MatrixPinv, set_test_value( From f49b2cc318e2411846d4d61b69f30e5e87755557 Mon Sep 17 00:00:00 2001 From: Ricardo Vieira Date: Fri, 23 Jun 2023 10:38:38 +0200 Subject: [PATCH 4/7] CholeskySolve inherits from BaseSolve --- pytensor/tensor/slinalg.py | 138 +++++++++++-------------------- tests/link/numba/test_nlinalg.py | 6 +- tests/tensor/test_slinalg.py | 2 +- 3 files changed, 51 insertions(+), 95 deletions(-) diff --git a/pytensor/tensor/slinalg.py b/pytensor/tensor/slinalg.py index 58b09deca3..68fac3e90b 100644 --- a/pytensor/tensor/slinalg.py +++ b/pytensor/tensor/slinalg.py @@ -49,7 +49,7 @@ class Cholesky(Op): __props__ = ("lower", "destructive", "on_error") - def __init__(self, lower=True, on_error="raise"): + def __init__(self, *, lower=True, on_error="raise"): self.lower = lower self.destructive = False if on_error not in ("raise", "nan"): @@ -127,77 +127,8 @@ def conjugate_solve_triangular(outer, inner): return [grad] -cholesky = Cholesky() - - -class CholeskySolve(Op): - __props__ = ("lower", "check_finite") - - def __init__( - self, - lower=True, - check_finite=True, - ): - self.lower = lower - self.check_finite = check_finite - - def __repr__(self): - return "CholeskySolve{%s}" % str(self._props()) - - def 
make_node(self, C, b): - C = as_tensor_variable(C) - b = as_tensor_variable(b) - assert C.ndim == 2 - assert b.ndim in (1, 2) - - # infer dtype by solving the most simple - # case with (1, 1) matrices - o_dtype = scipy.linalg.solve( - np.eye(1).astype(C.dtype), np.eye(1).astype(b.dtype) - ).dtype - x = tensor(dtype=o_dtype, shape=b.type.shape) - return Apply(self, [C, b], [x]) - - def perform(self, node, inputs, output_storage): - C, b = inputs - rval = scipy.linalg.cho_solve( - (C, self.lower), - b, - check_finite=self.check_finite, - ) - - output_storage[0][0] = rval - - def infer_shape(self, fgraph, node, shapes): - Cshape, Bshape = shapes - rows = Cshape[1] - if len(Bshape) == 1: # b is a Vector - return [(rows,)] - else: - cols = Bshape[1] # b is a Matrix - return [(rows, cols)] - - -cho_solve = CholeskySolve() - - -def cho_solve(c_and_lower, b, check_finite=True): - """Solve the linear equations A x = b, given the Cholesky factorization of A. - - Parameters - ---------- - (c, lower) : tuple, (array, bool) - Cholesky factorization of a, as given by cho_factor - b : array - Right-hand side - check_finite : bool, optional - Whether to check that the input matrices contain only finite numbers. - Disabling may give a performance gain, but may result in problems - (crashes, non-termination) if the inputs do contain infinities or NaNs. 
- """ - - A, lower = c_and_lower - return CholeskySolve(lower=lower, check_finite=check_finite)(A, b) +def cholesky(x, lower=True, on_error="raise"): + return Cholesky(lower=lower, on_error=on_error)(x) class SolveBase(Op): @@ -210,6 +141,7 @@ class SolveBase(Op): def __init__( self, + *, lower=False, check_finite=True, ): @@ -276,28 +208,56 @@ def L_op(self, inputs, outputs, output_gradients): return [A_bar, b_bar] - def __repr__(self): - return f"{type(self).__name__}{self._props()}" + +class CholeskySolve(SolveBase): + def __init__(self, **kwargs): + kwargs.setdefault("lower", True) + super().__init__(**kwargs) + + def perform(self, node, inputs, output_storage): + C, b = inputs + rval = scipy.linalg.cho_solve( + (C, self.lower), + b, + check_finite=self.check_finite, + ) + + output_storage[0][0] = rval + + def L_op(self, *args, **kwargs): + raise NotImplementedError() + + +def cho_solve(c_and_lower, b, *, check_finite=True): + """Solve the linear equations A x = b, given the Cholesky factorization of A. + + Parameters + ---------- + (c, lower) : tuple, (array, bool) + Cholesky factorization of a, as given by cho_factor + b : array + Right-hand side + check_finite : bool, optional + Whether to check that the input matrices contain only finite numbers. + Disabling may give a performance gain, but may result in problems + (crashes, non-termination) if the inputs do contain infinities or NaNs. 
+ """ + A, lower = c_and_lower + return CholeskySolve(lower=lower, check_finite=check_finite)(A, b) class SolveTriangular(SolveBase): """Solve a system of linear equations.""" __props__ = ( - "lower", "trans", "unit_diagonal", + "lower", "check_finite", ) - def __init__( - self, - trans=0, - lower=False, - unit_diagonal=False, - check_finite=True, - ): - super().__init__(lower=lower, check_finite=check_finite) + def __init__(self, *, trans=0, unit_diagonal=False, **kwargs): + super().__init__(**kwargs) self.trans = trans self.unit_diagonal = unit_diagonal @@ -326,6 +286,7 @@ def L_op(self, inputs, outputs, output_gradients): def solve_triangular( a: TensorVariable, b: TensorVariable, + *, trans: Union[int, str] = 0, lower: bool = False, unit_diagonal: bool = False, @@ -373,16 +334,11 @@ class Solve(SolveBase): "check_finite", ) - def __init__( - self, - assume_a="gen", - lower=False, - check_finite=True, - ): + def __init__(self, *, assume_a="gen", **kwargs): if assume_a not in ("gen", "sym", "her", "pos"): raise ValueError(f"{assume_a} is not a recognized matrix structure") - super().__init__(lower=lower, check_finite=check_finite) + super().__init__(**kwargs) self.assume_a = assume_a def perform(self, node, inputs, outputs): @@ -396,7 +352,7 @@ def perform(self, node, inputs, outputs): ) -def solve(a, b, assume_a="gen", lower=False, check_finite=True): +def solve(a, b, *, assume_a="gen", lower=False, check_finite=True): """Solves the linear equation set ``a * x = b`` for the unknown ``x`` for square ``a`` matrix. 
If the data matrix is known to be a particular type then supplying the diff --git a/tests/link/numba/test_nlinalg.py b/tests/link/numba/test_nlinalg.py index 223e14fed0..51c1c4b648 100644 --- a/tests/link/numba/test_nlinalg.py +++ b/tests/link/numba/test_nlinalg.py @@ -46,7 +46,7 @@ ], ) def test_Cholesky(x, lower, exc): - g = slinalg.Cholesky(lower)(x) + g = slinalg.Cholesky(lower=lower)(x) if isinstance(g, list): g_fg = FunctionGraph(outputs=g) @@ -91,7 +91,7 @@ def test_Cholesky(x, lower, exc): ], ) def test_Solve(A, x, lower, exc): - g = slinalg.Solve(lower)(A, x) + g = slinalg.Solve(lower=lower)(A, x) if isinstance(g, list): g_fg = FunctionGraph(outputs=g) @@ -125,7 +125,7 @@ def test_Solve(A, x, lower, exc): ], ) def test_SolveTriangular(A, x, lower, exc): - g = slinalg.SolveTriangular(lower)(A, x) + g = slinalg.SolveTriangular(lower=lower)(A, x) if isinstance(g, list): g_fg = FunctionGraph(outputs=g) diff --git a/tests/tensor/test_slinalg.py b/tests/tensor/test_slinalg.py index 1f6ffccdd4..fa3b8844ff 100644 --- a/tests/tensor/test_slinalg.py +++ b/tests/tensor/test_slinalg.py @@ -361,7 +361,7 @@ def setup_method(self): super().setup_method() def test_repr(self): - assert repr(CholeskySolve()) == "CholeskySolve{(True, True)}" + assert repr(CholeskySolve()) == "CholeskySolve(lower=True,check_finite=True)" def test_infer_shape(self): rng = np.random.default_rng(utt.fetch_seed()) From 4a091fd6e3d77407e31d56c978f70159d5d65ecb Mon Sep 17 00:00:00 2001 From: Ricardo Vieira Date: Fri, 12 May 2023 22:41:41 +0200 Subject: [PATCH 5/7] Implement Blockwise Op to vectorize existing Ops Inspired by: https://github.com/aesara-devs/aesara/pull/1215 Co-authored-by: Brandon T. 
Willard Co-authored-by: Purna Chandra Mansingh Co-authored-by: Sayam Kumar Co-authored-by: Kaustubh --- pytensor/tensor/blockwise.py | 413 +++++++++++++++++++++++ pytensor/tensor/elemwise.py | 73 ++-- pytensor/tensor/random/op.py | 29 +- pytensor/tensor/rewriting/__init__.py | 1 + pytensor/tensor/rewriting/blockwise.py | 41 +++ pytensor/tensor/utils.py | 53 +++ tests/tensor/random/test_op.py | 36 ++ tests/tensor/rewriting/test_blockwise.py | 38 +++ tests/tensor/test_blockwise.py | 258 ++++++++++++++ tests/tensor/test_elemwise.py | 70 +++- 10 files changed, 966 insertions(+), 46 deletions(-) create mode 100644 pytensor/tensor/blockwise.py create mode 100644 pytensor/tensor/rewriting/blockwise.py create mode 100644 tests/tensor/rewriting/test_blockwise.py create mode 100644 tests/tensor/test_blockwise.py diff --git a/pytensor/tensor/blockwise.py b/pytensor/tensor/blockwise.py new file mode 100644 index 0000000000..859bd68c55 --- /dev/null +++ b/pytensor/tensor/blockwise.py @@ -0,0 +1,413 @@ +import re +from functools import singledispatch +from typing import Any, Dict, List, Optional, Sequence, Tuple, cast + +import numpy as np + +from pytensor import config +from pytensor.gradient import DisconnectedType +from pytensor.graph.basic import Apply, Constant, Variable +from pytensor.graph.null_type import NullType +from pytensor.graph.op import Op +from pytensor.tensor import as_tensor_variable +from pytensor.tensor.shape import shape_padleft +from pytensor.tensor.type import continuous_dtypes, discrete_dtypes, tensor +from pytensor.tensor.utils import broadcast_static_dim_lengths, import_func_from_string +from pytensor.tensor.variable import TensorVariable + + +# TODO: Implement vectorize helper to batch whole graphs (similar to what Blockwise does for the grad) + +# Copied verbatim from numpy.lib.function_base +# https://github.com/numpy/numpy/blob/f2db090eb95b87d48a3318c9a3f9d38b67b0543c/numpy/lib/function_base.py#L1999-L2029 +_DIMENSION_NAME = r"\w+" 
+_CORE_DIMENSION_LIST = "(?:{0:}(?:,{0:})*)?".format(_DIMENSION_NAME) +_ARGUMENT = rf"\({_CORE_DIMENSION_LIST}\)" +_ARGUMENT_LIST = "{0:}(?:,{0:})*".format(_ARGUMENT) +_SIGNATURE = "^{0:}->{0:}$".format(_ARGUMENT_LIST) + + +def _parse_gufunc_signature(signature): + """ + Parse string signatures for a generalized universal function. + + Arguments + --------- + signature : string + Generalized universal function signature, e.g., ``(m,n),(n,p)->(m,p)`` + for ``np.matmul``. + + Returns + ------- + Tuple of input and output core dimensions parsed from the signature, each + of the form List[Tuple[str, ...]]. + """ + signature = re.sub(r"\s+", "", signature) + + if not re.match(_SIGNATURE, signature): + raise ValueError(f"not a valid gufunc signature: {signature}") + return tuple( + [ + tuple(re.findall(_DIMENSION_NAME, arg)) + for arg in re.findall(_ARGUMENT, arg_list) + ] + for arg_list in signature.split("->") + ) + + +def safe_signature( + core_inputs: Sequence[Variable], + core_outputs: Sequence[Variable], +) -> str: + def operand_sig(operand: Variable, prefix: str) -> str: + operands = ",".join(f"{prefix}{i}" for i in range(operand.type.ndim)) + return f"({operands})" + + inputs_sig = ",".join( + operand_sig(i, prefix=f"i{n}") for n, i in enumerate(core_inputs) + ) + outputs_sig = ",".join( + operand_sig(o, prefix=f"o{n}") for n, o in enumerate(core_outputs) + ) + return f"{inputs_sig}->{outputs_sig}" + + +@singledispatch +def _vectorize_node(op: Op, node: Apply, *bached_inputs) -> Apply: + if hasattr(op, "gufunc_signature"): + signature = op.gufunc_signature + else: + # TODO: This is pretty bad for shape inference and merge optimization! 
+ # Should get better as we add signatures to our Ops + signature = safe_signature(node.inputs, node.outputs) + return cast(Apply, Blockwise(op, signature=signature).make_node(*bached_inputs)) + + +def vectorize_node(node: Apply, *batched_inputs) -> Apply: + """Returns vectorized version of node with new batched inputs.""" + op = node.op + return _vectorize_node(op, node, *batched_inputs) + + +class Blockwise(Op): + """Generalizes a core `Op` to work with batched dimensions. + + TODO: Dispatch JAX (should be easy with the vectorize macro) + TODO: Dispatch Numba + TODO: C implementation? + TODO: Fuse Blockwise? + """ + + __props__ = ("core_op", "signature") + + def __init__( + self, + core_op: Op, + signature: Optional[str] = None, + name: Optional[str] = None, + **kwargs, + ): + """ + + Parameters + ---------- + core_op + An instance of a subclass of `Op` which works on the core case. + signature + Generalized universal function signature, + e.g., (m,n),(n)->(m) for vectorized matrix-vector multiplication + + """ + if isinstance(core_op, Blockwise): + raise TypeError("Core Op is already a Blockwise") + + if signature is None: + signature = getattr(core_op, "gufunc_signature", None) + if signature is None: + raise ValueError( + f"Signature not provided nor found in core_op {core_op}" + ) + + self.core_op = core_op + self.signature = signature + self.name = name + self.inputs_sig, self.outputs_sig = _parse_gufunc_signature(signature) + self._gufunc = None + super().__init__(**kwargs) + + def _create_dummy_core_node(self, inputs: Sequence[TensorVariable]) -> Apply: + core_input_types = [] + for i, (inp, sig) in enumerate(zip(inputs, self.inputs_sig)): + if inp.type.ndim < len(sig): + raise ValueError( + f"Input {i} {inp} has insufficient core dimensions for signature {self.signature}" + ) + # ndim_supp = 0 case + if not sig: + core_shape = () + else: + core_shape = inp.type.shape[-len(sig) :] + core_input_types.append(tensor(dtype=inp.type.dtype, shape=core_shape)) + 
+ core_node = self.core_op.make_node(*core_input_types) + + if len(core_node.outputs) != len(self.outputs_sig): + raise ValueError( + f"Insufficient number of outputs for signature {self.signature}: {len(core_node.outputs)}" + ) + for i, (core_out, sig) in enumerate(zip(core_node.outputs, self.outputs_sig)): + if core_out.type.ndim != len(sig): + raise ValueError( + f"Output {i} of {self.core_op} has wrong number of core dimensions for signature {self.signature}: {core_out.type.ndim}" + ) + + return core_node + + def make_node(self, *inputs): + inputs = [as_tensor_variable(i) for i in inputs] + + core_node = self._create_dummy_core_node(inputs) + + batch_ndims = max( + inp.type.ndim - len(sig) for inp, sig in zip(inputs, self.inputs_sig) + ) + + batched_inputs = [] + batch_shapes = [] + for i, (inp, sig) in enumerate(zip(inputs, self.inputs_sig)): + # Append missing dims to the left + missing_batch_ndims = batch_ndims - (inp.type.ndim - len(sig)) + if missing_batch_ndims: + inp = shape_padleft(inp, missing_batch_ndims) + batched_inputs.append(inp) + + if not sig: + batch_shapes.append(inp.type.shape) + else: + batch_shapes.append(inp.type.shape[: -len(sig)]) + + try: + batch_shape = tuple( + [ + broadcast_static_dim_lengths(batch_dims) + for batch_dims in zip(*batch_shapes) + ] + ) + except ValueError: + raise ValueError( + f"Incompatible Blockwise batch input shapes {[inp.type.shape for inp in inputs]}" + ) + + batched_outputs = [ + tensor(dtype=core_out.type.dtype, shape=batch_shape + core_out.type.shape) + for core_out in core_node.outputs + ] + + return Apply(self, batched_inputs, batched_outputs) + + def _batch_ndim_from_outputs(self, outputs: Sequence[TensorVariable]) -> int: + return cast(int, outputs[0].type.ndim - len(self.outputs_sig[0])) + + def infer_shape( + self, fgraph, node, input_shapes + ) -> List[Tuple[TensorVariable, ...]]: + from pytensor.tensor import broadcast_shape + from pytensor.tensor.shape import Shape_i + + batch_ndims = 
self._batch_ndim_from_outputs(node.outputs) + core_dims: Dict[str, Any] = {} + batch_shapes = [] + for input_shape, sig in zip(input_shapes, self.inputs_sig): + batch_shapes.append(input_shape[:batch_ndims]) + core_shape = input_shape[batch_ndims:] + + for core_dim, dim_name in zip(core_shape, sig): + prev_core_dim = core_dims.get(core_dim) + if prev_core_dim is None: + core_dims[dim_name] = core_dim + # Prefer constants + elif not isinstance(prev_core_dim, Constant): + core_dims[dim_name] = core_dim + + batch_shape = broadcast_shape(*batch_shapes, arrays_are_shapes=True) + + out_shapes = [] + for output, sig in zip(node.outputs, self.outputs_sig): + core_out_shape = [] + for i, dim_name in enumerate(sig): + # The output dim is the same as another input dim + if dim_name in core_dims: + core_out_shape.append(core_dims[dim_name]) + else: + # TODO: We could try to make use of infer_shape of core_op + core_out_shape.append(Shape_i(batch_ndims + i)(output)) + out_shapes.append((*batch_shape, *core_out_shape)) + + return out_shapes + + def connection_pattern(self, node): + if hasattr(self.core_op, "connection_pattern"): + return self.core_op.connection_pattern(node) + + return [[True for _ in node.outputs] for _ in node.inputs] + + def _bgrad(self, inputs, outputs, ograds): + # Grad, with respect to broadcasted versions of inputs + + def as_core(t, core_t): + # Inputs could be NullType or DisconnectedType + if isinstance(t.type, (NullType, DisconnectedType)): + return t + return core_t.type() + + with config.change_flags(compute_test_value="off"): + safe_inputs = [ + tensor(dtype=inp.type.dtype, shape=(None,) * len(sig)) + for inp, sig in zip(inputs, self.inputs_sig) + ] + core_node = self._create_dummy_core_node(safe_inputs) + + core_inputs = [ + as_core(inp, core_inp) + for inp, core_inp in zip(inputs, core_node.inputs) + ] + core_ograds = [ + as_core(ograd, core_ograd) + for ograd, core_ograd in zip(ograds, core_node.outputs) + ] + core_outputs = core_node.outputs + 
+ core_igrads = self.core_op.L_op(core_inputs, core_outputs, core_ograds) + + batch_ndims = self._batch_ndim_from_outputs(outputs) + + def transform(var): + # From a graph of ScalarOps, make a graph of Broadcast ops. + if isinstance(var.type, (NullType, DisconnectedType)): + return var + if var in core_inputs: + return inputs[core_inputs.index(var)] + if var in core_outputs: + return outputs[core_outputs.index(var)] + if var in core_ograds: + return ograds[core_ograds.index(var)] + + node = var.owner + + # The gradient contains a constant, which may be responsible for broadcasting + if node is None: + if batch_ndims: + var = shape_padleft(var, batch_ndims) + return var + + batched_inputs = [transform(inp) for inp in node.inputs] + batched_node = vectorize_node(node, *batched_inputs) + batched_var = batched_node.outputs[var.owner.outputs.index(var)] + + return batched_var + + ret = [] + for core_igrad, ipt in zip(core_igrads, inputs): + # Undefined gradient + if core_igrad is None: + ret.append(None) + else: + ret.append(transform(core_igrad)) + + return ret + + def L_op(self, inputs, outs, ograds): + from pytensor.tensor.math import sum as pt_sum + + # Compute grad with respect to broadcasted input + rval = self._bgrad(inputs, outs, ograds) + + # TODO: (Borrowed from Elemwise) make sure that zeros are clearly identifiable + # to the gradient.grad method when the outputs have + # some integer and some floating point outputs + if any(out.type.dtype not in continuous_dtypes for out in outs): + # For integer output, return value may only be zero or undefined + # We don't bother with trying to check that the scalar ops + # correctly returned something that evaluates to 0, we just make + # the return value obviously zero so that gradient.grad can tell + # this op did the right thing. 
+ new_rval = [] + for elem, inp in zip(rval, inputs): + if isinstance(elem.type, (NullType, DisconnectedType)): + new_rval.append(elem) + else: + elem = inp.zeros_like() + if str(elem.type.dtype) not in continuous_dtypes: + elem = elem.astype(config.floatX) + assert str(elem.type.dtype) not in discrete_dtypes + new_rval.append(elem) + return new_rval + + # Sum out the broadcasted dimensions + batch_ndims = self._batch_ndim_from_outputs(outs) + batch_shape = outs[0].type.shape[:batch_ndims] + for i, (inp, sig) in enumerate(zip(inputs, self.inputs_sig)): + if isinstance(rval[i].type, (NullType, DisconnectedType)): + continue + + assert inp.type.ndim == batch_ndims + len(sig) + + to_sum = [ + j + for j, (inp_s, out_s) in enumerate(zip(inp.type.shape, batch_shape)) + if inp_s == 1 and out_s != 1 + ] + if to_sum: + rval[i] = pt_sum(rval[i], axis=to_sum, keepdims=True) + + return rval + + def _create_gufunc(self, node): + if hasattr(self.core_op, "gufunc_spec"): + self._gufunc = import_func_from_string(self.core_op.gufunc_spec[0]) + if self._gufunc: + return self._gufunc + + n_outs = len(self.outputs_sig) + core_node = self._create_dummy_core_node(node.inputs) + + def core_func(*inner_inputs): + inner_outputs = [[None] for _ in range(n_outs)] + + inner_inputs = [np.asarray(inp) for inp in inner_inputs] + self.core_op.perform(core_node, inner_inputs, inner_outputs) + + if len(inner_outputs) == 1: + return inner_outputs[0][0] + else: + return tuple(r[0] for r in inner_outputs) + + self._gufunc = np.vectorize(core_func, signature=self.signature) + return self._gufunc + + def perform(self, node, inputs, output_storage): + gufunc = self._gufunc + + if gufunc is None: + gufunc = self._create_gufunc(node) + + res = gufunc(*inputs) + if not isinstance(res, tuple): + res = (res,) + + for node_out, out_storage, r in zip(node.outputs, output_storage, res): + out_dtype = getattr(node_out, "dtype", None) + if out_dtype and out_dtype != r.dtype: + r = np.asarray(r, dtype=out_dtype) + 
out_storage[0] = r + + def __str__(self): + if self.name is None: + return f"{type(self).__name__}{{{self.core_op}, {self.signature}}}" + else: + return self.name + + +@_vectorize_node.register(Blockwise) +def vectorize_not_needed(op, node, *batch_inputs): + return op.make_node(*batch_inputs) diff --git a/pytensor/tensor/elemwise.py b/pytensor/tensor/elemwise.py index 995ec7d45e..a1ff659882 100644 --- a/pytensor/tensor/elemwise.py +++ b/pytensor/tensor/elemwise.py @@ -22,6 +22,7 @@ from pytensor.tensor import elemwise_cgen as cgen from pytensor.tensor import get_vector_length from pytensor.tensor.basic import _get_vector_length, as_tensor_variable +from pytensor.tensor.blockwise import _vectorize_node, vectorize_not_needed from pytensor.tensor.type import ( TensorType, continuous_dtypes, @@ -29,6 +30,7 @@ float_dtypes, lvector, ) +from pytensor.tensor.utils import broadcast_static_dim_lengths, import_func_from_string from pytensor.tensor.variable import TensorVariable from pytensor.utils import uniq @@ -232,7 +234,7 @@ def __str__(self): return f"Transpose{{axes={self.shuffle}}}" return f"DimShuffle{{order=[{','.join(map(str, self.new_order))}]}}" - def perform(self, node, inp, out, params): + def perform(self, node, inp, out, params=None): (res,) = inp (storage,) = out @@ -429,28 +431,12 @@ def get_output_info(self, dim_shuffle, *inputs): # of all inputs in parallel... the all() gives us each output # broadcastable bit in turn. 
- def get_most_specialized_shape(shapes): - shapes = set(shapes) - # All shapes are the same - if len(shapes) == 1: - return tuple(shapes)[0] - - # Only valid indeterminate case - if shapes == {None, 1}: - return None - - shapes.discard(1) - shapes.discard(None) - if len(shapes) > 1: - raise ValueError - return tuple(shapes)[0] - # it is multiplied by nout because Elemwise supports multiple outputs # (nout of them) try: out_shapes = [ [ - get_most_specialized_shape(shape) + broadcast_static_dim_lengths(shape) for shape in zip(*[inp.type.shape for inp in inputs]) ] ] * shadow.nout @@ -665,22 +651,7 @@ def prepare_node(self, node, storage_map, compute_map, impl): impl = "c" if getattr(self, "nfunc_spec", None) and impl != "c": - self.nfunc = getattr(np, self.nfunc_spec[0], None) - if self.nfunc is None: - # Not inside NumPy. So probably another package like scipy. - symb = self.nfunc_spec[0].split(".") - for idx in range(1, len(self.nfunc_spec[0])): - try: - module = __import__(".".join(symb[:idx])) - except ImportError: - break - for sub in symb[1:]: - try: - module = getattr(module, sub) - except AttributeError: - module = None - break - self.nfunc = module + self.nfunc = import_func_from_string(self.nfunc_spec[0]) if ( (len(node.inputs) + len(node.outputs)) <= 32 @@ -1768,3 +1739,37 @@ def _get_vector_length_Elemwise(op, var): return get_vector_length(var.owner.inputs[0]) raise ValueError(f"Length of {var} cannot be determined") + + +_vectorize_node.register(Elemwise, vectorize_not_needed) + + +@_vectorize_node.register(DimShuffle) +def vectorize_dimshuffle(op: DimShuffle, node: Apply, x: TensorVariable) -> Apply: + batched_ndims = x.type.ndim - node.inputs[0].type.ndim + if not batched_ndims: + return node.op.make_node(x) + input_broadcastable = x.type.broadcastable[:batched_ndims] + op.input_broadcastable + # e.g., ds(matrix, order=(1, "x", 0)) -> ds(tensor4, order=(0, 1, 3, "x", 2)) + # e.g., ds(row, order=(1, "x")) -> ds(tensor4, order=(0, 1, 3, "x")) + 
new_order = list(range(batched_ndims)) + [ + "x" if (o == "x") else (o + batched_ndims) for o in op.new_order + ] + return DimShuffle(input_broadcastable, new_order).make_node(x) + + +@_vectorize_node.register(CAReduce) +def vectorize_careduce(op: CAReduce, node: Apply, x: TensorVariable) -> Apply: + batched_ndims = x.type.ndim - node.inputs[0].type.ndim + if not batched_ndims: + return node.op.make_node(x) + axes = op.axis + # e.g., sum(matrix, axis=None) -> sum(tensor4, axis=(2, 3)) + # e.g., sum(matrix, axis=0) -> sum(tensor4, axis=(2,)) + if axes is None: + axes = list(range(node.inputs[0].type.ndim)) + else: + axes = list(axes) + new_axes = [axis + batched_ndims for axis in axes] + new_op = op.clone(axis=new_axes) + return new_op.make_node(x) diff --git a/pytensor/tensor/random/op.py b/pytensor/tensor/random/op.py index 8575114f56..628916f508 100644 --- a/pytensor/tensor/random/op.py +++ b/pytensor/tensor/random/op.py @@ -5,19 +5,25 @@ import pytensor from pytensor.configdefaults import config -from pytensor.graph.basic import Apply, Variable +from pytensor.graph.basic import Apply, Variable, equal_computations from pytensor.graph.op import Op from pytensor.misc.safe_asarray import _asarray from pytensor.scalar import ScalarVariable from pytensor.tensor.basic import ( as_tensor_variable, + concatenate, constant, get_underlying_scalar_constant_value, get_vector_length, infer_static_shape, ) +from pytensor.tensor.blockwise import _vectorize_node from pytensor.tensor.random.type import RandomGeneratorType, RandomStateType, RandomType -from pytensor.tensor.random.utils import normalize_size_param, params_broadcast_shapes +from pytensor.tensor.random.utils import ( + broadcast_params, + normalize_size_param, + params_broadcast_shapes, +) from pytensor.tensor.shape import shape_tuple from pytensor.tensor.type import TensorType, all_dtypes from pytensor.tensor.type_other import NoneConst @@ -383,3 +389,22 @@ class DefaultGeneratorMakerOp(AbstractRNGConstructor): 
default_rng = DefaultGeneratorMakerOp() + + +@_vectorize_node.register(RandomVariable) +def vectorize_random_variable( + op: RandomVariable, node: Apply, rng, size, dtype, *dist_params +) -> Apply: + # If size was provided originally and a new size hasn't been provided, + # We extend it to accommodate the new input batch dimensions. + # Otherwise, we assume the new size already has the right values + old_size = node.inputs[1] + len_old_size = get_vector_length(old_size) + if len_old_size and equal_computations([old_size], [size]): + bcasted_param = broadcast_params(dist_params, op.ndims_params)[0] + new_param_ndim = (bcasted_param.type.ndim - op.ndims_params[0]) - len_old_size + if new_param_ndim >= 0: + new_size_dims = bcasted_param.shape[:new_param_ndim] + size = concatenate([new_size_dims, size]) + + return op.make_node(rng, size, dtype, *dist_params) diff --git a/pytensor/tensor/rewriting/__init__.py b/pytensor/tensor/rewriting/__init__.py index 80946d524c..617eab04fa 100644 --- a/pytensor/tensor/rewriting/__init__.py +++ b/pytensor/tensor/rewriting/__init__.py @@ -2,6 +2,7 @@ import pytensor.tensor.rewriting.blas import pytensor.tensor.rewriting.blas_c import pytensor.tensor.rewriting.blas_scipy +import pytensor.tensor.rewriting.blockwise import pytensor.tensor.rewriting.elemwise import pytensor.tensor.rewriting.extra_ops diff --git a/pytensor/tensor/rewriting/blockwise.py b/pytensor/tensor/rewriting/blockwise.py new file mode 100644 index 0000000000..c85fba3815 --- /dev/null +++ b/pytensor/tensor/rewriting/blockwise.py @@ -0,0 +1,41 @@ +from pytensor.compile.mode import optdb +from pytensor.graph import node_rewriter +from pytensor.graph.rewriting.basic import copy_stack_trace, out2in +from pytensor.tensor.blockwise import Blockwise, vectorize_node + + +@node_rewriter([Blockwise]) +def local_useless_blockwise(fgraph, node): + """ + If there is a dispatch implementation that does not require Blockwise, use that instead. 
+ This means a user created a Blockwise manually when there was no need. + + Note: This rewrite is not registered by default anywhere + """ + op = node.op + inputs = node.inputs + dummy_core_node = op._create_dummy_core_node(node.inputs) + vect_node = vectorize_node(dummy_core_node, *inputs) + if not isinstance(vect_node.op, Blockwise): + return copy_stack_trace(node.outputs, vect_node.outputs) + + +@node_rewriter([Blockwise]) +def local_useless_unbatched_blockwise(fgraph, node): + """Remove Blockwise that don't have any batched dims.""" + op = node.op + inputs = node.inputs + + if max(inp.type.ndim - len(sig) for inp, sig in zip(inputs, op.inputs_sig)) == 0: + return copy_stack_trace(node.outputs, op.core_op.make_node(*inputs).outputs) + + +# We register this rewrite late, so that other rewrites need only target Blockwise Ops +optdb.register( + "local_useless_unbatched_blockwise", + out2in(local_useless_unbatched_blockwise, ignore_newtrees=True), + "fast_run", + "fast_compile", + "blockwise", + position=49, +) diff --git a/pytensor/tensor/utils.py b/pytensor/tensor/utils.py index 7535f47c5c..2150587180 100644 --- a/pytensor/tensor/utils.py +++ b/pytensor/tensor/utils.py @@ -1,3 +1,5 @@ +from typing import Sequence, Union + import numpy as np import pytensor @@ -107,3 +109,54 @@ def as_list(x): return list(x) except TypeError: return [x] + + +def import_func_from_string(func_string: str): # -> Optional[Callable]: + func = getattr(np, func_string, None) + if func is not None: + return func + + # Not a NumPy attribute, so probably a dotted path into another package such as scipy.
+ module = None + items = func_string.split(".") + for idx in range(1, len(items)): + try: + module = __import__(".".join(items[:idx])) + except ImportError: + break + + if module: + for sub in items[1:]: + try: + module = getattr(module, sub) + except AttributeError: + module = None + break + return module + + +def broadcast_static_dim_lengths( + dim_lengths: Sequence[Union[int, None]] +) -> Union[int, None]: + """Apply static broadcast given static dim length of inputs (obtained from var.type.shape). + + Raises + ------ + ValueError + When static dim lengths are incompatible + """ + + dim_lengths_set = set(dim_lengths) + # All dim_lengths are the same + if len(dim_lengths_set) == 1: + return tuple(dim_lengths_set)[0] + + # Only valid indeterminate case + if dim_lengths_set == {None, 1}: + return None + + dim_lengths_set.discard(1) + dim_lengths_set.discard(None) + if len(dim_lengths_set) > 1: + raise ValueError + return tuple(dim_lengths_set)[0] diff --git a/tests/tensor/random/test_op.py b/tests/tensor/random/test_op.py index 0eec50e5a6..0bc8f0a73f 100644 --- a/tests/tensor/random/test_op.py +++ b/tests/tensor/random/test_op.py @@ -5,7 +5,9 @@ from pytensor import config, function from pytensor.gradient import NullTypeGradError, grad from pytensor.raise_op import Assert +from pytensor.tensor.blockwise import vectorize_node from pytensor.tensor.math import eq +from pytensor.tensor.random import normal from pytensor.tensor.random.op import RandomState, RandomVariable, default_rng from pytensor.tensor.shape import specify_shape from pytensor.tensor.type import all_dtypes, iscalar, tensor @@ -202,3 +204,37 @@ def test_RandomVariable_incompatible_size(): ValueError, match="Size length is incompatible with batched dimensions" ): rv_op(np.zeros((2, 4, 3)), 1, size=(4,)) + + +def test_vectorize_node(): + vec = tensor(shape=(None,)) + vec.tag.test_value = [0, 0, 0] + mat = tensor(shape=(None, None)) + mat.tag.test_value = [[0, 0, 0], [1, 1, 1]] + + # Test without size + 
node = normal(vec).owner + new_inputs = node.inputs.copy() + new_inputs[3] = mat + vect_node = vectorize_node(node, *new_inputs) + assert vect_node.op is normal + assert vect_node.inputs[3] is mat + + # Test with size, new size provided + node = normal(vec, size=(3,)).owner + new_inputs = node.inputs.copy() + new_inputs[1] = (2, 3) + new_inputs[3] = mat + vect_node = vectorize_node(node, *new_inputs) + assert vect_node.op is normal + assert tuple(vect_node.inputs[1].eval()) == (2, 3) + assert vect_node.inputs[3] is mat + + # Test with size, new size not provided + node = normal(vec, size=(3,)).owner + new_inputs = node.inputs.copy() + new_inputs[3] = mat + vect_node = vectorize_node(node, *new_inputs) + assert vect_node.op is normal + assert vect_node.inputs[3] is mat + assert tuple(vect_node.inputs[1].eval({mat: mat.tag.test_value})) == (2, 3) diff --git a/tests/tensor/rewriting/test_blockwise.py b/tests/tensor/rewriting/test_blockwise.py new file mode 100644 index 0000000000..0b67eba197 --- /dev/null +++ b/tests/tensor/rewriting/test_blockwise.py @@ -0,0 +1,38 @@ +from pytensor import function +from pytensor.graph import FunctionGraph +from pytensor.scalar import log as scalar_log +from pytensor.tensor import matrix, tensor3 +from pytensor.tensor.blockwise import Blockwise +from pytensor.tensor.elemwise import Elemwise +from pytensor.tensor.nlinalg import MatrixPinv +from pytensor.tensor.rewriting.blockwise import local_useless_blockwise + + +def test_useless_blockwise_of_elemwise(): + x = matrix("x") + out = Blockwise(Elemwise(scalar_log), signature="()->()")(x) + assert isinstance(out.owner.op, Blockwise) + assert isinstance(out.owner.op.core_op, Elemwise) + + fg = FunctionGraph([x], [out], clone=False) + [new_out] = local_useless_blockwise.transform(fg, out.owner) + assert isinstance(new_out.owner.op, Elemwise) + + +def test_useless_unbatched_blockwise(): + x = matrix("x") + blockwise_op = Blockwise(MatrixPinv(hermitian=False), signature="(m,n)->(n,m)") + out 
= blockwise_op(x) + + assert isinstance(out.owner.op, Blockwise) + assert isinstance(out.owner.op.core_op, MatrixPinv) + + fn = function([x], out, mode="FAST_COMPILE") + assert isinstance(fn.maker.fgraph.outputs[0].owner.op, MatrixPinv) + + # Test that it's not removed when there are batched dims + x = tensor3("x") + out = blockwise_op(x) + fn = function([x], out, mode="FAST_COMPILE") + assert isinstance(fn.maker.fgraph.outputs[0].owner.op, Blockwise) + assert isinstance(fn.maker.fgraph.outputs[0].owner.op.core_op, MatrixPinv) diff --git a/tests/tensor/test_blockwise.py b/tests/tensor/test_blockwise.py new file mode 100644 index 0000000000..437e3bbc22 --- /dev/null +++ b/tests/tensor/test_blockwise.py @@ -0,0 +1,258 @@ +from itertools import product +from typing import Optional, Tuple, Union + +import numpy as np +import pytest + +import pytensor +from pytensor import config +from pytensor.gradient import grad +from pytensor.graph import Apply, Op +from pytensor.tensor import tensor +from pytensor.tensor.blockwise import Blockwise, _parse_gufunc_signature, vectorize_node +from pytensor.tensor.nlinalg import MatrixInverse +from pytensor.tensor.slinalg import Cholesky, Solve + + +def test_vectorize_blockwise(): + mat = tensor(shape=(None, None)) + tns = tensor(shape=(None, None, None)) + + # Something that falls back to Blockwise + node = MatrixInverse()(mat).owner + vect_node = vectorize_node(node, tns) + assert isinstance(vect_node.op, Blockwise) and isinstance( + vect_node.op.core_op, MatrixInverse + ) + assert vect_node.inputs[0] is tns + + # Useless blockwise + tns4 = tensor(shape=(5, None, None, None)) + new_vect_node = vectorize_node(vect_node, tns4) + assert new_vect_node.op is vect_node.op + assert isinstance(new_vect_node.op, Blockwise) and isinstance( + new_vect_node.op.core_op, MatrixInverse + ) + assert new_vect_node.inputs[0] is tns4 + + +class TestOp(Op): + def make_node(self, *inputs): + return Apply(self, inputs, [i.type() for i in inputs]) + + def 
perform(self, *args, **kwargs): + raise NotImplementedError("Test Op should not be present in final graph") + + +test_op = TestOp() + + +def test_vectorize_node_default_signature(): + vec = tensor(shape=(None,)) + mat = tensor(shape=(5, None)) + node = test_op.make_node(vec, mat) + + vect_node = vectorize_node(node, mat, mat) + assert isinstance(vect_node.op, Blockwise) and isinstance( + vect_node.op.core_op, TestOp + ) + assert vect_node.op.signature == ("(i00),(i10,i11)->(o00),(o10,o11)") + + with pytest.raises( + ValueError, match="Signature not provided nor found in core_op TestOp" + ): + Blockwise(test_op) + + vect_node = Blockwise(test_op, signature="(m),(n)->(m),(n)").make_node(vec, mat) + assert vect_node.outputs[0].type.shape == ( + 5, + None, + ) + assert vect_node.outputs[1].type.shape == ( + 5, + None, + ) + + +def test_blockwise_shape(): + # Single output + inp = tensor(shape=(5, None, None)) + inp_test = np.zeros((5, 4, 3), dtype=config.floatX) + + # Shape can be inferred from inputs + op = Blockwise(test_op, signature="(m, n) -> (n, m)") + out = op(inp) + assert out.type.shape == (5, None, None) + + shape_fn = pytensor.function([inp], out.shape) + assert not any( + isinstance(getattr(n.op, "core_op", n.op), TestOp) + for n in shape_fn.maker.fgraph.apply_nodes + ) + assert tuple(shape_fn(inp_test)) == (5, 3, 4) + + # Shape can only be partially inferred from inputs + op = Blockwise(test_op, signature="(m, n) -> (m, k)") + out = op(inp) + assert out.type.shape == (5, None, None) + + shape_fn = pytensor.function([inp], out.shape) + assert any( + isinstance(getattr(n.op, "core_op", n.op), TestOp) + for n in shape_fn.maker.fgraph.apply_nodes + ) + + shape_fn = pytensor.function([inp], out.shape[:-1]) + assert not any( + isinstance(getattr(n.op, "core_op", n.op), TestOp) + for n in shape_fn.maker.fgraph.apply_nodes + ) + assert tuple(shape_fn(inp_test)) == (5, 4) + + # Multiple outputs + inp1 = tensor(shape=(7, 1, None, None)) + inp2 = tensor(shape=(1, 5,
None, None)) + inp1_test = np.zeros((7, 1, 4, 3), dtype=config.floatX) + inp2_test = np.zeros((1, 5, 4, 3), dtype=config.floatX) + + op = Blockwise(test_op, signature="(m, n), (m, n) -> (n, m), (m, k)") + outs = op(inp1, inp2) + assert outs[0].type.shape == (7, 5, None, None) + assert outs[1].type.shape == (7, 5, None, None) + + shape_fn = pytensor.function([inp1, inp2], [out.shape for out in outs]) + assert any( + isinstance(getattr(n.op, "core_op", n.op), TestOp) + for n in shape_fn.maker.fgraph.apply_nodes + ) + + shape_fn = pytensor.function([inp1, inp2], outs[0].shape) + assert not any( + isinstance(getattr(n.op, "core_op", n.op), TestOp) + for n in shape_fn.maker.fgraph.apply_nodes + ) + assert tuple(shape_fn(inp1_test, inp2_test)) == (7, 5, 3, 4) + + shape_fn = pytensor.function([inp1, inp2], [outs[0].shape, outs[1].shape[:-1]]) + assert not any( + isinstance(getattr(n.op, "core_op", n.op), TestOp) + for n in shape_fn.maker.fgraph.apply_nodes + ) + assert tuple(shape_fn(inp1_test, inp2_test)[0]) == (7, 5, 3, 4) + assert tuple(shape_fn(inp1_test, inp2_test)[1]) == (7, 5, 4) + + +class BlockwiseOpTester: + """Base class to test Blockwise works for specific Ops""" + + core_op = None + signature = None + batcheable_axes = None + + @classmethod + def setup_class(cls): + seed = sum(map(ord, str(cls.core_op))) + cls.rng = np.random.default_rng(seed) + cls.params_sig, cls.outputs_sig = _parse_gufunc_signature(cls.signature) + if cls.batcheable_axes is None: + cls.batcheable_axes = list(range(len(cls.params_sig))) + batch_shapes = [(), (1,), (5,), (1, 1), (1, 5), (3, 1), (3, 5)] + cls.test_batch_shapes = list( + product(batch_shapes, repeat=len(cls.batcheable_axes)) + ) + cls.block_op = Blockwise(core_op=cls.core_op, signature=cls.signature) + + @staticmethod + def parse_shape(shape: Tuple[Union[str, int], ...]) -> Tuple[int, ...]: + """ + Convert (5, "m", "n") -> (5, 7, 11) + """ + mapping = {"m": 7, "n": 11, "k": 19} + return tuple(mapping.get(p, p) for p in shape) 
+ + def create_testvals(self, shape): + return self.rng.normal(size=self.parse_shape(shape)).astype(config.floatX) + + def create_batched_inputs(self, batch_idx: Optional[int] = None): + for batch_shapes in self.test_batch_shapes: + vec_inputs = [] + vec_inputs_testvals = [] + for idx, (batch_shape, param_sig) in enumerate( + zip(batch_shapes, self.params_sig) + ): + if batch_idx is not None and idx != batch_idx: + # Skip out combinations in which other inputs are batched + if batch_shape != (): + break + vec_inputs.append(tensor(shape=batch_shape + (None,) * len(param_sig))) + vec_inputs_testvals.append( + self.create_testvals(shape=batch_shape + param_sig) + ) + else: # no-break + yield vec_inputs, vec_inputs_testvals + + def test_perform(self): + base_inputs = [ + tensor(shape=(None,) * len(param_sig)) for param_sig in self.params_sig + ] + core_func = pytensor.function(base_inputs, self.core_op(*base_inputs)) + np_func = np.vectorize(core_func, signature=self.signature) + + for vec_inputs, vec_inputs_testvals in self.create_batched_inputs(): + pt_func = pytensor.function(vec_inputs, self.block_op(*vec_inputs)) + if len(self.outputs_sig) != 1: + raise NotImplementedError("Did not implement test for multi-output Ops") + np.testing.assert_allclose( + pt_func(*vec_inputs_testvals), + np_func(*vec_inputs_testvals), + ) + + def test_grad(self): + base_inputs = [ + tensor(shape=(None,) * len(param_sig)) for param_sig in self.params_sig + ] + out = self.core_op(*base_inputs).sum() + # Create separate numpy vectorized functions for each input + np_funcs = [] + for i, inp in enumerate(base_inputs): + core_grad_func = pytensor.function(base_inputs, grad(out, wrt=inp)) + params_sig = self.signature.split("->")[0] + param_sig = f"({','.join(self.params_sig[i])})" + grad_sig = f"{params_sig}->{param_sig}" + np_func = np.vectorize(core_grad_func, signature=grad_sig) + np_funcs.append(np_func) + + # We test gradient wrt to one batched input at a time + for test_input_idx in 
range(len(base_inputs)): + for vec_inputs, vec_inputs_testvals in self.create_batched_inputs( + batch_idx=test_input_idx + ): + out = self.block_op(*vec_inputs).sum() + pt_func = pytensor.function( + vec_inputs, grad(out, wrt=vec_inputs[test_input_idx]) + ) + pt_out = pt_func(*vec_inputs_testvals) + np_out = np_funcs[test_input_idx](*vec_inputs_testvals) + np.testing.assert_allclose(pt_out, np_out, atol=1e-6) + + +class MatrixOpBlockwiseTester(BlockwiseOpTester): + def create_testvals(self, shape): + # Return a posdef matrix + X = super().create_testvals(shape) + return np.einsum("...ij,...kj->...ik", X, X) + + +class TestCholesky(MatrixOpBlockwiseTester): + core_op = Cholesky(lower=True) + signature = "(m, m) -> (m, m)" + + +class TestMatrixInverse(MatrixOpBlockwiseTester): + core_op = MatrixInverse() + signature = "(m, m) -> (m, m)" + + +class TestSolve(BlockwiseOpTester): + core_op = Solve(lower=True) + signature = "(m, m),(m) -> (m)" diff --git a/tests/tensor/test_elemwise.py b/tests/tensor/test_elemwise.py index 9f0c3a5976..1346744b6d 100644 --- a/tests/tensor/test_elemwise.py +++ b/tests/tensor/test_elemwise.py @@ -17,10 +17,13 @@ from pytensor.link.c.basic import CLinker, OpWiseCLinker from pytensor.tensor import as_tensor_variable from pytensor.tensor.basic import second +from pytensor.tensor.blockwise import vectorize_node from pytensor.tensor.elemwise import CAReduce, DimShuffle, Elemwise -from pytensor.tensor.math import all as at_all -from pytensor.tensor.math import any as at_any +from pytensor.tensor.math import Any, Sum +from pytensor.tensor.math import all as pt_all +from pytensor.tensor.math import any as pt_any from pytensor.tensor.math import exp +from pytensor.tensor.math import sum as pt_sum from pytensor.tensor.type import ( TensorType, bmatrix, @@ -470,12 +473,12 @@ def with_mode( axis2.append(a) assert len(axis2) == len(tosum) tosum = tuple(axis2) - if tensor_op == at_all: + if tensor_op == pt_all: for axis in sorted(tosum, reverse=True): zv 
= np.all(zv, axis) if len(tosum) == 0: zv = zv != 0 - elif tensor_op == at_any: + elif tensor_op == pt_any: for axis in sorted(tosum, reverse=True): zv = np.any(zv, axis) if len(tosum) == 0: @@ -553,8 +556,8 @@ def test_perform(self): self.with_mode(Mode(linker="py"), aes.mul, dtype=dtype) self.with_mode(Mode(linker="py"), aes.scalar_maximum, dtype=dtype) self.with_mode(Mode(linker="py"), aes.scalar_minimum, dtype=dtype) - self.with_mode(Mode(linker="py"), aes.and_, dtype=dtype, tensor_op=at_all) - self.with_mode(Mode(linker="py"), aes.or_, dtype=dtype, tensor_op=at_any) + self.with_mode(Mode(linker="py"), aes.and_, dtype=dtype, tensor_op=pt_all) + self.with_mode(Mode(linker="py"), aes.or_, dtype=dtype, tensor_op=pt_any) for dtype in ["int8", "uint8"]: self.with_mode(Mode(linker="py"), aes.or_, dtype=dtype) self.with_mode(Mode(linker="py"), aes.and_, dtype=dtype) @@ -575,14 +578,14 @@ def test_perform_nan(self): aes.or_, dtype=dtype, test_nan=True, - tensor_op=at_any, + tensor_op=pt_any, ) self.with_mode( Mode(linker="py"), aes.and_, dtype=dtype, test_nan=True, - tensor_op=at_all, + tensor_op=pt_all, ) @pytest.mark.skipif( @@ -606,8 +609,8 @@ def test_c(self): for dtype in ["bool", "floatX", "int8", "uint8"]: self.with_mode(Mode(linker="c"), aes.scalar_minimum, dtype=dtype) self.with_mode(Mode(linker="c"), aes.scalar_maximum, dtype=dtype) - self.with_mode(Mode(linker="c"), aes.and_, dtype=dtype, tensor_op=at_all) - self.with_mode(Mode(linker="c"), aes.or_, dtype=dtype, tensor_op=at_any) + self.with_mode(Mode(linker="c"), aes.and_, dtype=dtype, tensor_op=pt_all) + self.with_mode(Mode(linker="c"), aes.or_, dtype=dtype, tensor_op=pt_any) for dtype in ["bool", "int8", "uint8"]: self.with_mode(Mode(linker="c"), aes.or_, dtype=dtype) self.with_mode(Mode(linker="c"), aes.and_, dtype=dtype) @@ -915,3 +918,50 @@ def grad(self, inputs, gout): # Verify that trying to use the not implemented gradient fails. 
with pytest.raises(pytensor.gradient.NullTypeGradError): pytensor.gradient.grad(test_op(x, 2), x) + + +class TestVectorize: + def test_elemwise(self): + vec = tensor(shape=(None,)) + mat = tensor(shape=(None, None)) + + node = exp(vec).owner + vect_node = vectorize_node(node, mat) + assert vect_node.op == exp + assert vect_node.inputs[0] is mat + + def test_dimshuffle(self): + vec = tensor(shape=(None,)) + mat = tensor(shape=(None, None)) + + node = exp(vec).owner + vect_node = vectorize_node(node, mat) + assert vect_node.op == exp + assert vect_node.inputs[0] is mat + + col_mat = tensor(shape=(None, 1)) + tcol_mat = tensor(shape=(None, None, 1)) + node = col_mat.dimshuffle(0).owner # drop column + vect_node = vectorize_node(node, tcol_mat) + assert isinstance(vect_node.op, DimShuffle) + assert vect_node.op.new_order == (0, 1) + assert vect_node.inputs[0] is tcol_mat + assert vect_node.outputs[0].type.shape == (None, None) + + def test_CAReduce(self): + mat = tensor(shape=(None, None)) + tns = tensor(shape=(None, None, None)) + + node = pt_sum(mat).owner + vect_node = vectorize_node(node, tns) + assert isinstance(vect_node.op, Sum) + assert vect_node.op.axis == (1, 2) + assert vect_node.inputs[0] is tns + + bool_mat = tensor(dtype="bool", shape=(None, None)) + bool_tns = tensor(dtype="bool", shape=(None, None, None)) + node = pt_any(bool_mat, axis=-2).owner + vect_node = vectorize_node(node, bool_tns) + assert isinstance(vect_node.op, Any) + assert vect_node.op.axis == (1,) + assert vect_node.inputs[0] is bool_tns From 9b71e4b5e617f4741a94ec556de0d1884463b0c0 Mon Sep 17 00:00:00 2001 From: Ricardo Vieira Date: Wed, 21 Jun 2023 12:43:23 +0200 Subject: [PATCH 6/7] Blockwise some linalg Ops by default --- pytensor/tensor/basic.py | 2 +- pytensor/tensor/nlinalg.py | 20 ++- pytensor/tensor/rewriting/linalg.py | 194 ++++++++++++++++---------- pytensor/tensor/slinalg.py | 89 +++++++++--- tests/link/numba/test_nlinalg.py | 4 +- tests/tensor/rewriting/test_linalg.py | 21 
++- tests/tensor/test_blockwise.py | 10 +- tests/tensor/test_slinalg.py | 32 +++-- 8 files changed, 243 insertions(+), 129 deletions(-) diff --git a/pytensor/tensor/basic.py b/pytensor/tensor/basic.py index c21b33720f..95d9d7a03e 100644 --- a/pytensor/tensor/basic.py +++ b/pytensor/tensor/basic.py @@ -3764,7 +3764,7 @@ def stacklists(arg): return arg -def swapaxes(y, axis1, axis2): +def swapaxes(y, axis1: int, axis2: int) -> TensorVariable: "Swap the axes of a tensor." y = as_tensor_variable(y) ndim = y.ndim diff --git a/pytensor/tensor/nlinalg.py b/pytensor/tensor/nlinalg.py index 32fa47d28d..5e0fc2d580 100644 --- a/pytensor/tensor/nlinalg.py +++ b/pytensor/tensor/nlinalg.py @@ -10,11 +10,13 @@ from pytensor.tensor import basic as at from pytensor.tensor import math as tm from pytensor.tensor.basic import as_tensor_variable, extract_diag +from pytensor.tensor.blockwise import Blockwise from pytensor.tensor.type import dvector, lscalar, matrix, scalar, vector class MatrixPinv(Op): __props__ = ("hermitian",) + gufunc_signature = "(m,n)->(n,m)" def __init__(self, hermitian): self.hermitian = hermitian @@ -75,7 +77,7 @@ def pinv(x, hermitian=False): solve op. 
""" - return MatrixPinv(hermitian=hermitian)(x) + return Blockwise(MatrixPinv(hermitian=hermitian))(x) class MatrixInverse(Op): @@ -93,6 +95,8 @@ class MatrixInverse(Op): """ __props__ = () + gufunc_signature = "(m,m)->(m,m)" + gufunc_spec = ("numpy.linalg.inv", 1, 1) def __init__(self): pass @@ -150,7 +154,7 @@ def infer_shape(self, fgraph, node, shapes): return shapes -inv = matrix_inverse = MatrixInverse() +inv = matrix_inverse = Blockwise(MatrixInverse()) def matrix_dot(*args): @@ -181,6 +185,8 @@ class Det(Op): """ __props__ = () + gufunc_signature = "(m,m)->()" + gufunc_spec = ("numpy.linalg.det", 1, 1) def make_node(self, x): x = as_tensor_variable(x) @@ -209,7 +215,7 @@ def __str__(self): return "Det" -det = Det() +det = Blockwise(Det()) class SLogDet(Op): @@ -218,6 +224,8 @@ class SLogDet(Op): """ __props__ = () + gufunc_signature = "(m, m)->(),()" + gufunc_spec = ("numpy.linalg.slogdet", 1, 2) def make_node(self, x): x = as_tensor_variable(x) @@ -242,7 +250,7 @@ def __str__(self): return "SLogDet" -slogdet = SLogDet() +slogdet = Blockwise(SLogDet()) class Eig(Op): @@ -252,6 +260,8 @@ class Eig(Op): """ __props__: Tuple[str, ...] 
= () + gufunc_signature = "(m,m)->(m),(m,m)" + gufunc_spec = ("numpy.linalg.eig", 1, 2) def make_node(self, x): x = as_tensor_variable(x) @@ -270,7 +280,7 @@ def infer_shape(self, fgraph, node, shapes): return [(n,), (n, n)] -eig = Eig() +eig = Blockwise(Eig()) class Eigh(Eig): diff --git a/pytensor/tensor/rewriting/linalg.py b/pytensor/tensor/rewriting/linalg.py index 103feb07b6..9589c7aa79 100644 --- a/pytensor/tensor/rewriting/linalg.py +++ b/pytensor/tensor/rewriting/linalg.py @@ -1,32 +1,58 @@ import logging +from typing import cast from pytensor.graph.rewriting.basic import node_rewriter -from pytensor.tensor import basic as at +from pytensor.tensor.basic import TensorVariable, diagonal, swapaxes from pytensor.tensor.blas import Dot22 +from pytensor.tensor.blockwise import Blockwise from pytensor.tensor.elemwise import DimShuffle from pytensor.tensor.math import Dot, Prod, log, prod -from pytensor.tensor.nlinalg import Det, MatrixInverse +from pytensor.tensor.nlinalg import MatrixInverse, det from pytensor.tensor.rewriting.basic import ( register_canonicalize, register_specialize, register_stabilize, ) -from pytensor.tensor.slinalg import Cholesky, Solve, SolveTriangular, cholesky, solve +from pytensor.tensor.slinalg import Cholesky, Solve, cholesky, solve, solve_triangular logger = logging.getLogger(__name__) +def is_matrix_transpose(x: TensorVariable) -> bool: + """Check if a variable corresponds to a transpose of the last two axes""" + node = x.owner + if ( + node + and isinstance(node.op, DimShuffle) + and not (node.op.drop or node.op.augment) + ): + [inp] = node.inputs + ndims = inp.type.ndim + if ndims < 2: + return False + transpose_order = tuple(range(ndims - 2)) + (ndims - 1, ndims - 2) + return cast(bool, node.op.new_order == transpose_order) + return False + + +def _T(x: TensorVariable) -> TensorVariable: + """Matrix transpose for potentially higher dimensionality tensors""" + return swapaxes(x, -1, -2) + + @register_canonicalize 
@node_rewriter([DimShuffle]) def transinv_to_invtrans(fgraph, node): - if isinstance(node.op, DimShuffle): - if node.op.new_order == (1, 0): - (A,) = node.inputs - if A.owner: - if isinstance(A.owner.op, MatrixInverse): - (X,) = A.owner.inputs - return [A.owner.op(node.op(X))] + if is_matrix_transpose(node.outputs[0]): + (A,) = node.inputs + if ( + A.owner + and isinstance(A.owner.op, Blockwise) + and isinstance(A.owner.op.core_op, MatrixInverse) + ): + (X,) = A.owner.inputs + return [A.owner.op(node.op(X))] @register_stabilize @@ -37,43 +63,72 @@ def inv_as_solve(fgraph, node): """ if isinstance(node.op, (Dot, Dot22)): l, r = node.inputs - if l.owner and isinstance(l.owner.op, MatrixInverse): + if ( + l.owner + and isinstance(l.owner.op, Blockwise) + and isinstance(l.owner.op.core_op, MatrixInverse) + ): return [solve(l.owner.inputs[0], r)] - if r.owner and isinstance(r.owner.op, MatrixInverse): + if ( + r.owner + and isinstance(r.owner.op, Blockwise) + and isinstance(r.owner.op.core_op, MatrixInverse) + ): x = r.owner.inputs[0] if getattr(x.tag, "symmetric", None) is True: - return [solve(x, l.T).T] + return [_T(solve(x, _T(l)))] else: - return [solve(x.T, l.T).T] + return [_T(solve(_T(x), _T(l)))] @register_stabilize @register_canonicalize -@node_rewriter([Solve]) +@node_rewriter([Blockwise]) def generic_solve_to_solve_triangular(fgraph, node): """ If any solve() is applied to the output of a cholesky op, then replace it with a triangular solve. 
""" - if isinstance(node.op, Solve): - A, b = node.inputs # result is solution Ax=b - if A.owner and isinstance(A.owner.op, Cholesky): - if A.owner.op.lower: - return [SolveTriangular(lower=True)(A, b)] - else: - return [SolveTriangular(lower=False)(A, b)] - if ( - A.owner - and isinstance(A.owner.op, DimShuffle) - and A.owner.op.new_order == (1, 0) - ): - (A_T,) = A.owner.inputs - if A_T.owner and isinstance(A_T.owner.op, Cholesky): - if A_T.owner.op.lower: - return [SolveTriangular(lower=False)(A, b)] + if isinstance(node.op.core_op, Solve): + if node.op.core_op.assume_a == "gen": + A, b = node.inputs # result is solution Ax=b + if ( + A.owner + and isinstance(A.owner.op, Blockwise) + and isinstance(A.owner.op.core_op, Cholesky) + ): + if A.owner.op.core_op.lower: + return [ + solve_triangular( + A, b, lower=True, b_ndim=node.op.core_op.b_ndim + ) + ] else: - return [SolveTriangular(lower=True)(A, b)] + return [ + solve_triangular( + A, b, lower=False, b_ndim=node.op.core_op.b_ndim + ) + ] + if is_matrix_transpose(A): + (A_T,) = A.owner.inputs + if ( + A_T.owner + and isinstance(A_T.owner.op, Blockwise) + and isinstance(A_T.owner.op.core_op, Cholesky) + ): + if A_T.owner.op.core_op.lower: + return [ + solve_triangular( + A, b, lower=False, b_ndim=node.op.core_op.b_ndim + ) + ] + else: + return [ + solve_triangular( + A, b, lower=True, b_ndim=node.op.core_op.b_ndim + ) + ] @register_canonicalize @@ -81,34 +136,33 @@ def generic_solve_to_solve_triangular(fgraph, node): @register_specialize @node_rewriter([DimShuffle]) def no_transpose_symmetric(fgraph, node): - if isinstance(node.op, DimShuffle): + if is_matrix_transpose(node.outputs[0]): x = node.inputs[0] - if x.type.ndim == 2 and getattr(x.tag, "symmetric", None) is True: - if node.op.new_order == [1, 0]: - return [x] + if getattr(x.tag, "symmetric", None): + return [x] @register_stabilize -@node_rewriter([Solve]) +@node_rewriter([Blockwise]) def psd_solve_with_chol(fgraph, node): """ This utilizes a boolean `psd` tag on 
matrices. """ - if isinstance(node.op, Solve): + if isinstance(node.op.core_op, Solve) and node.op.core_op.b_ndim == 2: A, b = node.inputs # result is solution Ax=b if getattr(A.tag, "psd", None) is True: L = cholesky(A) # N.B. this can be further reduced to a yet-unwritten cho_solve Op - # __if__ no other Op makes use of the the L matrix during the + # __if__ no other Op makes use of the L matrix during the # stabilization - Li_b = Solve(assume_a="sym", lower=True)(L, b) - x = Solve(assume_a="sym", lower=False)(L.T, Li_b) + Li_b = solve(L, b, assume_a="sym", lower=True, b_ndim=2) + x = solve(_T(L), Li_b, assume_a="sym", lower=False, b_ndim=2) return [x] @register_canonicalize @register_stabilize -@node_rewriter([Cholesky]) +@node_rewriter([Blockwise]) def cholesky_ldotlt(fgraph, node): """ rewrite cholesky(dot(L, L.T), lower=True) = L, where L is lower triangular, @@ -116,7 +170,7 @@ def cholesky_ldotlt(fgraph, node): This utilizes a boolean `lower_triangular` or `upper_triangular` tag on matrices. 
""" - if not isinstance(node.op, Cholesky): + if not isinstance(node.op.core_op, Cholesky): return A = node.inputs[0] @@ -128,45 +182,40 @@ def cholesky_ldotlt(fgraph, node): # cholesky(dot(L,L.T)) case if ( getattr(l.tag, "lower_triangular", False) - and r.owner - and isinstance(r.owner.op, DimShuffle) - and r.owner.op.new_order == (1, 0) + and is_matrix_transpose(r) and r.owner.inputs[0] == l ): - if node.op.lower: + if node.op.core_op.lower: return [l] return [r] # cholesky(dot(U.T,U)) case if ( getattr(r.tag, "upper_triangular", False) - and l.owner - and isinstance(l.owner.op, DimShuffle) - and l.owner.op.new_order == (1, 0) + and is_matrix_transpose(l) and l.owner.inputs[0] == r ): - if node.op.lower: + if node.op.core_op.lower: return [l] return [r] @register_stabilize @register_specialize -@node_rewriter([Det]) +@node_rewriter([det]) def local_det_chol(fgraph, node): """ If we have det(X) and there is already an L=cholesky(X) floating around, then we can use prod(diag(L)) to get the determinant. """ - if isinstance(node.op, Det): - (x,) = node.inputs - for cl, xpos in fgraph.clients[x]: - if cl == "output": - continue - if isinstance(cl.op, Cholesky): - L = cl.outputs[0] - return [prod(at.extract_diag(L) ** 2)] + (x,) = node.inputs + for cl, xpos in fgraph.clients[x]: + if cl == "output": + continue + if isinstance(cl.op, Blockwise) and isinstance(cl.op.core_op, Cholesky): + L = cl.outputs[0] + return [prod(diagonal(L, axis1=-2, axis2=-1) ** 2, axis=-1)] @register_canonicalize @@ -177,16 +226,15 @@ def local_log_prod_sqr(fgraph, node): """ This utilizes a boolean `positive` tag on matrices. 
""" - if node.op == log: - (x,) = node.inputs - if x.owner and isinstance(x.owner.op, Prod): - # we cannot always make this substitution because - # the prod might include negative terms - p = x.owner.inputs[0] - - # p is the matrix we're reducing with prod - if getattr(p.tag, "positive", None) is True: - return [log(p).sum(axis=x.owner.op.axis)] - - # TODO: have a reduction like prod and sum that simply - # returns the sign of the prod multiplication. + (x,) = node.inputs + if x.owner and isinstance(x.owner.op, Prod): + # we cannot always make this substitution because + # the prod might include negative terms + p = x.owner.inputs[0] + + # p is the matrix we're reducing with prod + if getattr(p.tag, "positive", None) is True: + return [log(p).sum(axis=x.owner.op.axis)] + + # TODO: have a reduction like prod and sum that simply + # returns the sign of the prod multiplication. diff --git a/pytensor/tensor/slinalg.py b/pytensor/tensor/slinalg.py index 68fac3e90b..6e24c56f55 100644 --- a/pytensor/tensor/slinalg.py +++ b/pytensor/tensor/slinalg.py @@ -1,7 +1,7 @@ import logging import typing import warnings -from typing import TYPE_CHECKING, Literal, Union +from typing import TYPE_CHECKING, Literal, Optional, Union import numpy as np import scipy.linalg @@ -13,6 +13,7 @@ from pytensor.tensor import as_tensor_variable from pytensor.tensor import basic as at from pytensor.tensor import math as atm +from pytensor.tensor.blockwise import Blockwise from pytensor.tensor.nlinalg import matrix_dot from pytensor.tensor.shape import reshape from pytensor.tensor.type import matrix, tensor, vector @@ -48,6 +49,7 @@ class Cholesky(Op): # TODO: LAPACK wrapper with in-place behavior, for solve also __props__ = ("lower", "destructive", "on_error") + gufunc_signature = "(m,m)->(m,m)" def __init__(self, *, lower=True, on_error="raise"): self.lower = lower @@ -109,7 +111,7 @@ def tril_and_halve_diagonal(mtx): def conjugate_solve_triangular(outer, inner): """Computes L^{-T} P L^{-1} for 
lower-triangular L.""" - solve_upper = SolveTriangular(lower=False) + solve_upper = SolveTriangular(lower=False, b_ndim=2) return solve_upper(outer.T, solve_upper(outer.T, inner.T).T) s = conjugate_solve_triangular( @@ -128,7 +130,7 @@ def conjugate_solve_triangular(outer, inner): def cholesky(x, lower=True, on_error="raise"): - return Cholesky(lower=lower, on_error=on_error)(x) + return Blockwise(Cholesky(lower=lower, on_error=on_error))(x) class SolveBase(Op): @@ -137,6 +139,7 @@ class SolveBase(Op): __props__ = ( "lower", "check_finite", + "b_ndim", ) def __init__( @@ -144,9 +147,16 @@ def __init__( *, lower=False, check_finite=True, + b_ndim, ): self.lower = lower self.check_finite = check_finite + assert b_ndim in (1, 2) + self.b_ndim = b_ndim + if b_ndim == 1: + self.gufunc_signature = "(m,m),(m)->(m)" + else: + self.gufunc_signature = "(m,m),(m,n)->(m,n)" def perform(self, node, inputs, outputs): pass @@ -157,8 +167,8 @@ def make_node(self, A, b): if A.ndim != 2: raise ValueError(f"`A` must be a matrix; got {A.type} instead.") - if b.ndim not in (1, 2): - raise ValueError(f"`b` must be a matrix or a vector; got {b.type} instead.") + if b.ndim != self.b_ndim: + raise ValueError(f"`b` must have {self.b_ndim} dims; got {b.type} instead.") # Infer dtype by solving the most simple case with 1x1 matrices o_dtype = scipy.linalg.solve( @@ -209,6 +219,16 @@ def L_op(self, inputs, outputs, output_gradients): return [A_bar, b_bar] +def _default_b_ndim(b, b_ndim): + if b_ndim is not None: + assert b_ndim in (1, 2) + return b_ndim + + b = as_tensor_variable(b) + if b_ndim is None: + return min(b.ndim, 2) # By default assume the core case is a matrix + + class CholeskySolve(SolveBase): def __init__(self, **kwargs): kwargs.setdefault("lower", True) @@ -228,7 +248,7 @@ def L_op(self, *args, **kwargs): raise NotImplementedError() -def cho_solve(c_and_lower, b, *, check_finite=True): +def cho_solve(c_and_lower, b, *, check_finite=True, b_ndim: Optional[int] = None): """Solve 
the linear equations A x = b, given the Cholesky factorization of A. Parameters @@ -241,9 +261,15 @@ def cho_solve(c_and_lower, b, *, check_finite=True): Whether to check that the input matrices contain only finite numbers. Disabling may give a performance gain, but may result in problems (crashes, non-termination) if the inputs do contain infinities or NaNs. + b_ndim : int + Whether the core case of b is a vector (1) or matrix (2). + This will influence how batched dimensions are interpreted. """ A, lower = c_and_lower - return CholeskySolve(lower=lower, check_finite=check_finite)(A, b) + b_ndim = _default_b_ndim(b, b_ndim) + return Blockwise( + CholeskySolve(lower=lower, check_finite=check_finite, b_ndim=b_ndim) + )(A, b) class SolveTriangular(SolveBase): @@ -254,6 +280,7 @@ class SolveTriangular(SolveBase): "unit_diagonal", "lower", "check_finite", + "b_ndim", ) def __init__(self, *, trans=0, unit_diagonal=False, **kwargs): @@ -291,6 +318,7 @@ def solve_triangular( lower: bool = False, unit_diagonal: bool = False, check_finite: bool = True, + b_ndim: Optional[int] = None, ) -> TensorVariable: """Solve the equation `a x = b` for `x`, assuming `a` is a triangular matrix. @@ -314,12 +342,19 @@ def solve_triangular( Whether to check that the input matrices contain only finite numbers. Disabling may give a performance gain, but may result in problems (crashes, non-termination) if the inputs do contain infinities or NaNs. + b_ndim : int + Whether the core case of b is a vector (1) or matrix (2). + This will influence how batched dimensions are interpreted. 
""" - return SolveTriangular( - lower=lower, - trans=trans, - unit_diagonal=unit_diagonal, - check_finite=check_finite, + b_ndim = _default_b_ndim(b, b_ndim) + return Blockwise( + SolveTriangular( + lower=lower, + trans=trans, + unit_diagonal=unit_diagonal, + check_finite=check_finite, + b_ndim=b_ndim, + ) )(a, b) @@ -332,6 +367,7 @@ class Solve(SolveBase): "assume_a", "lower", "check_finite", + "b_ndim", ) def __init__(self, *, assume_a="gen", **kwargs): @@ -352,7 +388,15 @@ def perform(self, node, inputs, outputs): ) -def solve(a, b, *, assume_a="gen", lower=False, check_finite=True): +def solve( + a, + b, + *, + assume_a="gen", + lower=False, + check_finite=True, + b_ndim: Optional[int] = None, +): """Solves the linear equation set ``a * x = b`` for the unknown ``x`` for square ``a`` matrix. If the data matrix is known to be a particular type then supplying the @@ -375,9 +419,9 @@ def solve(a, b, *, assume_a="gen", lower=False, check_finite=True): Parameters ---------- - a : (N, N) array_like + a : (..., N, N) array_like Square input data - b : (N, NRHS) array_like + b : (..., N, NRHS) array_like Input data for the right hand side. lower : bool, optional If True, only the data contained in the lower triangle of `a`. Default @@ -388,11 +432,18 @@ def solve(a, b, *, assume_a="gen", lower=False, check_finite=True): (crashes, non-termination) if the inputs do contain infinities or NaNs. assume_a : str, optional Valid entries are explained above. + b_ndim : int + Whether the core case of b is a vector (1) or matrix (2). + This will influence how batched dimensions are interpreted. 
""" - return Solve( - lower=lower, - check_finite=check_finite, - assume_a=assume_a, + b_ndim = _default_b_ndim(b, b_ndim) + return Blockwise( + Solve( + lower=lower, + check_finite=check_finite, + assume_a=assume_a, + b_ndim=b_ndim, + ) )(a, b) diff --git a/tests/link/numba/test_nlinalg.py b/tests/link/numba/test_nlinalg.py index 51c1c4b648..857bd49152 100644 --- a/tests/link/numba/test_nlinalg.py +++ b/tests/link/numba/test_nlinalg.py @@ -91,7 +91,7 @@ def test_Cholesky(x, lower, exc): ], ) def test_Solve(A, x, lower, exc): - g = slinalg.Solve(lower=lower)(A, x) + g = slinalg.Solve(lower=lower, b_ndim=1)(A, x) if isinstance(g, list): g_fg = FunctionGraph(outputs=g) @@ -125,7 +125,7 @@ def test_Solve(A, x, lower, exc): ], ) def test_SolveTriangular(A, x, lower, exc): - g = slinalg.SolveTriangular(lower=lower)(A, x) + g = slinalg.SolveTriangular(lower=lower, b_ndim=1)(A, x) if isinstance(g, list): g_fg = FunctionGraph(outputs=g) diff --git a/tests/tensor/rewriting/test_linalg.py b/tests/tensor/rewriting/test_linalg.py index 3a09244a40..58ea98d626 100644 --- a/tests/tensor/rewriting/test_linalg.py +++ b/tests/tensor/rewriting/test_linalg.py @@ -9,11 +9,12 @@ from pytensor import tensor as at from pytensor.compile import get_default_mode from pytensor.configdefaults import config +from pytensor.tensor.blockwise import Blockwise from pytensor.tensor.elemwise import DimShuffle from pytensor.tensor.math import _allclose from pytensor.tensor.nlinalg import Det, MatrixInverse, matrix_inverse from pytensor.tensor.rewriting.linalg import inv_as_solve -from pytensor.tensor.slinalg import Cholesky, Solve, SolveTriangular, solve +from pytensor.tensor.slinalg import Cholesky, Solve, SolveTriangular, cholesky, solve from pytensor.tensor.type import dmatrix, matrix, vector from tests import unittest_tools as utt from tests.test_rop import break_op @@ -23,7 +24,7 @@ def test_rop_lop(): mx = matrix("mx") mv = matrix("mv") v = vector("v") - y = matrix_inverse(mx).sum(axis=0) + y = 
MatrixInverse()(mx).sum(axis=0) yv = pytensor.gradient.Rop(y, mx, mv) rop_f = function([mx, mv], yv) @@ -83,13 +84,11 @@ def test_transinv_to_invtrans(): def test_generic_solve_to_solve_triangular(): - cholesky_lower = Cholesky(lower=True) - cholesky_upper = Cholesky(lower=False) A = matrix("A") x = matrix("x") - L = cholesky_lower(A) - U = cholesky_upper(A) + L = cholesky(A, lower=True) + U = cholesky(A, lower=False) b1 = solve(L, x) b2 = solve(U, x) f = pytensor.function([A, x], b1) @@ -130,15 +129,15 @@ def test_matrix_inverse_solve(): b = dmatrix("b") node = matrix_inverse(A).dot(b).owner [out] = inv_as_solve.transform(None, node) - assert isinstance(out.owner.op, Solve) + assert isinstance(out.owner.op, Blockwise) and isinstance( + out.owner.op.core_op, Solve + ) @pytest.mark.parametrize("tag", ("lower", "upper", None)) @pytest.mark.parametrize("cholesky_form", ("lower", "upper")) @pytest.mark.parametrize("product", ("lower", "upper", None)) def test_cholesky_ldotlt(tag, cholesky_form, product): - cholesky = Cholesky(lower=(cholesky_form == "lower")) - transform_removes_chol = tag is not None and product == tag transform_transposes = transform_removes_chol and cholesky_form != tag @@ -153,11 +152,9 @@ def test_cholesky_ldotlt(tag, cholesky_form, product): else: M = A - C = cholesky(M) + C = cholesky(M, lower=(cholesky_form == "lower")) f = pytensor.function([A], C, mode=get_default_mode().including("cholesky_ldotlt")) - print(f.maker.fgraph.apply_nodes) - no_cholesky_in_graph = not any( isinstance(node.op, Cholesky) for node in f.maker.fgraph.apply_nodes ) diff --git a/tests/tensor/test_blockwise.py b/tests/tensor/test_blockwise.py index 437e3bbc22..658c527430 100644 --- a/tests/tensor/test_blockwise.py +++ b/tests/tensor/test_blockwise.py @@ -24,6 +24,7 @@ def test_vectorize_blockwise(): assert isinstance(vect_node.op, Blockwise) and isinstance( vect_node.op.core_op, MatrixInverse ) + assert vect_node.op.signature == ("(m,m)->(m,m)") assert 
vect_node.inputs[0] is tns # Useless blockwise @@ -253,6 +254,11 @@ class TestMatrixInverse(MatrixOpBlockwiseTester): signature = "(m, m) -> (m, m)" -class TestSolve(BlockwiseOpTester): - core_op = Solve(lower=True) +class TestSolveVector(BlockwiseOpTester): + core_op = Solve(lower=True, b_ndim=1) signature = "(m, m),(m) -> (m)" + + +class TestSolveMatrix(BlockwiseOpTester): + core_op = Solve(lower=True, b_ndim=2) + signature = "(m, m),(m, n) -> (m, n)" diff --git a/tests/tensor/test_slinalg.py b/tests/tensor/test_slinalg.py index fa3b8844ff..504d848140 100644 --- a/tests/tensor/test_slinalg.py +++ b/tests/tensor/test_slinalg.py @@ -181,7 +181,7 @@ class TestSolveBase(utt.InferShapeTester): ( matrix, functools.partial(tensor, dtype="floatX", shape=(None,) * 3), - "`b` must be a matrix or a vector.*", + "`b` must have 2 dims.*", ), ], ) @@ -190,20 +190,20 @@ def test_make_node(self, A_func, b_func, error_message): with pytest.raises(ValueError, match=error_message): A = A_func() b = b_func() - SolveBase()(A, b) + SolveBase(b_ndim=2)(A, b) def test__repr__(self): np.random.default_rng(utt.fetch_seed()) A = matrix() b = matrix() - y = SolveBase()(A, b) - assert y.__repr__() == "SolveBase{lower=False, check_finite=True}.0" + y = SolveBase(b_ndim=2)(A, b) + assert y.__repr__() == "SolveBase{lower=False, check_finite=True, b_ndim=2}.0" class TestSolve(utt.InferShapeTester): def test__init__(self): with pytest.raises(ValueError) as excinfo: - Solve(assume_a="test") + Solve(assume_a="test", b_ndim=2) assert "is not a recognized matrix structure" in str(excinfo.value) @pytest.mark.parametrize("b_shape", [(5, 1), (5,)]) @@ -278,7 +278,7 @@ def test_solve_grad(self, m, n, assume_a, lower): if config.floatX == "float64": eps = 2e-8 - solve_op = Solve(assume_a=assume_a, lower=lower) + solve_op = Solve(assume_a=assume_a, lower=lower, b_ndim=1 if n is None else 2) utt.verify_grad(solve_op, [A_val, b_val], 3, rng, eps=eps) @@ -349,19 +349,20 @@ def test_solve_grad(self, m, n, 
lower): if config.floatX == "float64": eps = 2e-8 - solve_op = SolveTriangular(lower=lower) + solve_op = SolveTriangular(lower=lower, b_ndim=1 if n is None else 2) utt.verify_grad(solve_op, [A_val, b_val], 3, rng, eps=eps) class TestCholeskySolve(utt.InferShapeTester): def setup_method(self): self.op_class = CholeskySolve - self.op = CholeskySolve() - self.op_upper = CholeskySolve(lower=False) super().setup_method() def test_repr(self): - assert repr(CholeskySolve()) == "CholeskySolve(lower=True,check_finite=True)" + assert ( + repr(CholeskySolve(lower=True, b_ndim=1)) + == "CholeskySolve(lower=True,check_finite=True,b_ndim=1)" + ) def test_infer_shape(self): rng = np.random.default_rng(utt.fetch_seed()) @@ -369,7 +370,7 @@ def test_infer_shape(self): b = matrix() self._compile_and_check( [A, b], # pytensor.function inputs - [self.op(A, b)], # pytensor.function outputs + [self.op_class(b_ndim=2)(A, b)], # pytensor.function outputs # A must be square [ np.asarray(rng.random((5, 5)), dtype=config.floatX), @@ -383,7 +384,7 @@ def test_infer_shape(self): b = vector() self._compile_and_check( [A, b], # pytensor.function inputs - [self.op(A, b)], # pytensor.function outputs + [self.op_class(b_ndim=1)(A, b)], # pytensor.function outputs # A must be square [ np.asarray(rng.random((5, 5)), dtype=config.floatX), @@ -397,10 +398,10 @@ def test_solve_correctness(self): rng = np.random.default_rng(utt.fetch_seed()) A = matrix() b = matrix() - y = self.op(A, b) + y = self.op_class(lower=True, b_ndim=2)(A, b) cho_solve_lower_func = pytensor.function([A, b], y) - y = self.op_upper(A, b) + y = self.op_class(lower=False, b_ndim=2)(A, b) cho_solve_upper_func = pytensor.function([A, b], y) b_val = np.asarray(rng.random((5, 1)), dtype=config.floatX) @@ -435,12 +436,13 @@ def test_solve_dtype(self): A_val = np.eye(2) b_val = np.ones((2, 1)) + op = self.op_class(b_ndim=2) # try all dtype combinations for A_dtype, b_dtype in itertools.product(dtypes, dtypes): A = matrix(dtype=A_dtype) b = 
matrix(dtype=b_dtype) - x = self.op(A, b) + x = op(A, b) fn = function([A, b], x) x_result = fn(A_val.astype(A_dtype), b_val.astype(b_dtype)) From f19f95e0611e0bb8deb9c505406013df776e135a Mon Sep 17 00:00:00 2001 From: Ricardo Vieira Date: Fri, 25 Aug 2023 12:52:13 +0200 Subject: [PATCH 7/7] Implement vectorize utility --- pytensor/graph/__init__.py | 2 +- pytensor/graph/replace.py | 67 +++++++++++++++++++++++++- pytensor/scalar/loop.py | 3 +- pytensor/tensor/blockwise.py | 58 ++++++---------------- pytensor/tensor/elemwise.py | 3 +- pytensor/tensor/random/op.py | 2 +- pytensor/tensor/rewriting/blockwise.py | 3 +- tests/graph/test_replace.py | 21 +++++++- tests/tensor/random/test_op.py | 2 +- tests/tensor/test_blockwise.py | 3 +- tests/tensor/test_elemwise.py | 2 +- 11 files changed, 111 insertions(+), 55 deletions(-) diff --git a/pytensor/graph/__init__.py b/pytensor/graph/__init__.py index e849a090c7..f7c4202452 100644 --- a/pytensor/graph/__init__.py +++ b/pytensor/graph/__init__.py @@ -9,7 +9,7 @@ clone, ancestors, ) -from pytensor.graph.replace import clone_replace, graph_replace +from pytensor.graph.replace import clone_replace, graph_replace, vectorize from pytensor.graph.op import Op from pytensor.graph.type import Type from pytensor.graph.fg import FunctionGraph diff --git a/pytensor/graph/replace.py b/pytensor/graph/replace.py index d16f4119ba..892a4abd80 100644 --- a/pytensor/graph/replace.py +++ b/pytensor/graph/replace.py @@ -1,8 +1,9 @@ -from functools import partial -from typing import Iterable, Optional, Sequence, Union, cast, overload +from functools import partial, singledispatch +from typing import Iterable, Mapping, Optional, Sequence, Union, cast, overload from pytensor.graph.basic import Apply, Constant, Variable, truncated_graph_inputs from pytensor.graph.fg import FunctionGraph +from pytensor.graph.op import Op ReplaceTypes = Union[Iterable[tuple[Variable, Variable]], dict[Variable, Variable]] @@ -198,3 +199,65 @@ def toposort_key( return 
list(fg.outputs) else: return fg.outputs[0] + + +@singledispatch +def _vectorize_node(op: Op, node: Apply, *bached_inputs) -> Apply: + # Default implementation is provided in pytensor.tensor.blockwise + raise NotImplementedError + + +def vectorize_node(node: Apply, *batched_inputs) -> Apply: + """Returns vectorized version of node with new batched inputs.""" + op = node.op + return _vectorize_node(op, node, *batched_inputs) + + +def vectorize( + outputs: Sequence[Variable], vectorize: Mapping[Variable, Variable] +) -> Sequence[Variable]: + """Vectorize an outputs graph given a mapping from old variables to their expanded counterparts. + + Expanded dimensions must be on the left. Behavior is similar to the functional `numpy.vectorize`. + + Examples + -------- + .. code-block:: python + + import pytensor + import pytensor.tensor as pt + + from pytensor.graph import vectorize + + # Original graph + x = pt.vector("x") + y = pt.exp(x) / pt.sum(pt.exp(x)) + + # Vectorized graph + new_x = pt.matrix("new_x") + [new_y] = vectorize([y], {x: new_x}) + + fn = pytensor.function([new_x], new_y) + fn([[0, 1, 2], [2, 1, 0]]) + # array([[0.09003057, 0.24472847, 0.66524096], + # [0.66524096, 0.24472847, 0.09003057]]) + + """ + # Avoid circular import + + inputs = truncated_graph_inputs(outputs, ancestors_to_include=vectorize.keys()) + new_inputs = [vectorize.get(inp, inp) for inp in inputs] + + def transform(var): + if var in inputs: + return new_inputs[inputs.index(var)] + + node = var.owner + batched_inputs = [transform(inp) for inp in node.inputs] + batched_node = vectorize_node(node, *batched_inputs) + batched_var = batched_node.outputs[var.owner.outputs.index(var)] + + return batched_var + + # TODO: MergeOptimization or node caching? 
+ return [transform(out) for out in outputs] diff --git a/pytensor/scalar/loop.py b/pytensor/scalar/loop.py index 08def4e230..17ffe5d711 100644 --- a/pytensor/scalar/loop.py +++ b/pytensor/scalar/loop.py @@ -2,7 +2,8 @@ from typing import Optional, Sequence, Tuple from pytensor.compile import rebuild_collect_shared -from pytensor.graph import Constant, FunctionGraph, Variable, clone +from pytensor.graph.basic import Constant, Variable, clone +from pytensor.graph.fg import FunctionGraph from pytensor.scalar.basic import ScalarInnerGraphOp, as_scalar diff --git a/pytensor/tensor/blockwise.py b/pytensor/tensor/blockwise.py index 859bd68c55..91d2fcaef1 100644 --- a/pytensor/tensor/blockwise.py +++ b/pytensor/tensor/blockwise.py @@ -1,5 +1,4 @@ import re -from functools import singledispatch from typing import Any, Dict, List, Optional, Sequence, Tuple, cast import numpy as np @@ -9,6 +8,7 @@ from pytensor.graph.basic import Apply, Constant, Variable from pytensor.graph.null_type import NullType from pytensor.graph.op import Op +from pytensor.graph.replace import _vectorize_node, vectorize from pytensor.tensor import as_tensor_variable from pytensor.tensor.shape import shape_padleft from pytensor.tensor.type import continuous_dtypes, discrete_dtypes, tensor @@ -72,8 +72,8 @@ def operand_sig(operand: Variable, prefix: str) -> str: return f"{inputs_sig}->{outputs_sig}" -@singledispatch -def _vectorize_node(op: Op, node: Apply, *bached_inputs) -> Apply: +@_vectorize_node.register(Op) +def vectorize_node_fallback(op: Op, node: Apply, *bached_inputs) -> Apply: if hasattr(op, "gufunc_signature"): signature = op.gufunc_signature else: @@ -83,12 +83,6 @@ def _vectorize_node(op: Op, node: Apply, *bached_inputs) -> Apply: return cast(Apply, Blockwise(op, signature=signature).make_node(*bached_inputs)) -def vectorize_node(node: Apply, *batched_inputs) -> Apply: - """Returns vectorized version of node with new batched inputs.""" - op = node.op - return _vectorize_node(op, node, 
*batched_inputs) - - class Blockwise(Op): """Generalizes a core `Op` to work with batched dimensions. @@ -279,42 +273,18 @@ def as_core(t, core_t): core_igrads = self.core_op.L_op(core_inputs, core_outputs, core_ograds) - batch_ndims = self._batch_ndim_from_outputs(outputs) - - def transform(var): - # From a graph of ScalarOps, make a graph of Broadcast ops. - if isinstance(var.type, (NullType, DisconnectedType)): - return var - if var in core_inputs: - return inputs[core_inputs.index(var)] - if var in core_outputs: - return outputs[core_outputs.index(var)] - if var in core_ograds: - return ograds[core_ograds.index(var)] - - node = var.owner - - # The gradient contains a constant, which may be responsible for broadcasting - if node is None: - if batch_ndims: - var = shape_padleft(var, batch_ndims) - return var - - batched_inputs = [transform(inp) for inp in node.inputs] - batched_node = vectorize_node(node, *batched_inputs) - batched_var = batched_node.outputs[var.owner.outputs.index(var)] - - return batched_var - - ret = [] - for core_igrad, ipt in zip(core_igrads, inputs): - # Undefined gradient - if core_igrad is None: - ret.append(None) - else: - ret.append(transform(core_igrad)) + igrads = vectorize( + [core_igrad for core_igrad in core_igrads if core_igrad is not None], + vectorize=dict( + zip(core_inputs + core_outputs + core_ograds, inputs + outputs + ograds) + ), + ) - return ret + igrads_iter = iter(igrads) + return [ + None if core_igrad is None else next(igrads_iter) + for core_igrad in core_igrads + ] def L_op(self, inputs, outs, ograds): from pytensor.tensor.math import sum as pt_sum diff --git a/pytensor/tensor/elemwise.py b/pytensor/tensor/elemwise.py index a1ff659882..94cbcd5f6f 100644 --- a/pytensor/tensor/elemwise.py +++ b/pytensor/tensor/elemwise.py @@ -8,6 +8,7 @@ from pytensor.gradient import DisconnectedType from pytensor.graph.basic import Apply from pytensor.graph.null_type import NullType +from pytensor.graph.replace import _vectorize_node 
from pytensor.graph.utils import MethodNotDefined from pytensor.link.c.basic import failure_code from pytensor.link.c.op import COp, ExternalCOp, OpenMPOp @@ -22,7 +23,7 @@ from pytensor.tensor import elemwise_cgen as cgen from pytensor.tensor import get_vector_length from pytensor.tensor.basic import _get_vector_length, as_tensor_variable -from pytensor.tensor.blockwise import _vectorize_node, vectorize_not_needed +from pytensor.tensor.blockwise import vectorize_not_needed from pytensor.tensor.type import ( TensorType, continuous_dtypes, diff --git a/pytensor/tensor/random/op.py b/pytensor/tensor/random/op.py index 628916f508..597f68865d 100644 --- a/pytensor/tensor/random/op.py +++ b/pytensor/tensor/random/op.py @@ -7,6 +7,7 @@ from pytensor.configdefaults import config from pytensor.graph.basic import Apply, Variable, equal_computations from pytensor.graph.op import Op +from pytensor.graph.replace import _vectorize_node from pytensor.misc.safe_asarray import _asarray from pytensor.scalar import ScalarVariable from pytensor.tensor.basic import ( @@ -17,7 +18,6 @@ get_vector_length, infer_static_shape, ) -from pytensor.tensor.blockwise import _vectorize_node from pytensor.tensor.random.type import RandomGeneratorType, RandomStateType, RandomType from pytensor.tensor.random.utils import ( broadcast_params, diff --git a/pytensor/tensor/rewriting/blockwise.py b/pytensor/tensor/rewriting/blockwise.py index c85fba3815..2533eb7aaa 100644 --- a/pytensor/tensor/rewriting/blockwise.py +++ b/pytensor/tensor/rewriting/blockwise.py @@ -1,7 +1,8 @@ from pytensor.compile.mode import optdb from pytensor.graph import node_rewriter +from pytensor.graph.replace import vectorize_node from pytensor.graph.rewriting.basic import copy_stack_trace, out2in -from pytensor.tensor.blockwise import Blockwise, vectorize_node +from pytensor.tensor.blockwise import Blockwise @node_rewriter([Blockwise]) diff --git a/tests/graph/test_replace.py b/tests/graph/test_replace.py index 
7fc0e530f9..393b9c567b 100644 --- a/tests/graph/test_replace.py +++ b/tests/graph/test_replace.py @@ -1,10 +1,11 @@ import numpy as np import pytest +import scipy.special import pytensor.tensor as pt from pytensor import config, function, shared from pytensor.graph.basic import graph_inputs -from pytensor.graph.replace import clone_replace, graph_replace +from pytensor.graph.replace import clone_replace, graph_replace, vectorize from pytensor.tensor import dvector, fvector, vector from tests import unittest_tools as utt from tests.graph.utils import MyOp, MyVariable @@ -223,3 +224,21 @@ def test_graph_replace_disconnected(self): assert oc[0] is o with pytest.raises(ValueError, match="Some replacements were not used"): oc = graph_replace([o], {fake: x.clone()}, strict=True) + + +class TestVectorize: + # TODO: Add tests with multiple outputs, constants, and other singleton types + + def test_basic(self): + x = pt.vector("x") + y = pt.exp(x) / pt.sum(pt.exp(x)) + + new_x = pt.matrix("new_x") + [new_y] = vectorize([y], {x: new_x}) + + fn = function([new_x], new_y) + test_new_y = np.array([[0, 1, 2], [2, 1, 0]]).astype(config.floatX) + np.testing.assert_allclose( + fn(test_new_y), + scipy.special.softmax(test_new_y, axis=-1), + ) diff --git a/tests/tensor/random/test_op.py b/tests/tensor/random/test_op.py index 0bc8f0a73f..4a389811e1 100644 --- a/tests/tensor/random/test_op.py +++ b/tests/tensor/random/test_op.py @@ -4,8 +4,8 @@ import pytensor.tensor as at from pytensor import config, function from pytensor.gradient import NullTypeGradError, grad +from pytensor.graph.replace import vectorize_node from pytensor.raise_op import Assert -from pytensor.tensor.blockwise import vectorize_node from pytensor.tensor.math import eq from pytensor.tensor.random import normal from pytensor.tensor.random.op import RandomState, RandomVariable, default_rng diff --git a/tests/tensor/test_blockwise.py b/tests/tensor/test_blockwise.py index 658c527430..92e07cf4e0 100644 --- 
a/tests/tensor/test_blockwise.py +++ b/tests/tensor/test_blockwise.py @@ -8,8 +8,9 @@ from pytensor import config from pytensor.gradient import grad from pytensor.graph import Apply, Op +from pytensor.graph.replace import vectorize_node from pytensor.tensor import tensor -from pytensor.tensor.blockwise import Blockwise, _parse_gufunc_signature, vectorize_node +from pytensor.tensor.blockwise import Blockwise, _parse_gufunc_signature from pytensor.tensor.nlinalg import MatrixInverse from pytensor.tensor.slinalg import Cholesky, Solve diff --git a/tests/tensor/test_elemwise.py b/tests/tensor/test_elemwise.py index 1346744b6d..b4a8cc9aea 100644 --- a/tests/tensor/test_elemwise.py +++ b/tests/tensor/test_elemwise.py @@ -13,11 +13,11 @@ from pytensor.configdefaults import config from pytensor.graph.basic import Apply, Variable from pytensor.graph.fg import FunctionGraph +from pytensor.graph.replace import vectorize_node from pytensor.link.basic import PerformLinker from pytensor.link.c.basic import CLinker, OpWiseCLinker from pytensor.tensor import as_tensor_variable from pytensor.tensor.basic import second -from pytensor.tensor.blockwise import vectorize_node from pytensor.tensor.elemwise import CAReduce, DimShuffle, Elemwise from pytensor.tensor.math import Any, Sum from pytensor.tensor.math import all as pt_all