diff --git a/pytensor/graph/__init__.py b/pytensor/graph/__init__.py
index e849a090c7..f7c4202452 100644
--- a/pytensor/graph/__init__.py
+++ b/pytensor/graph/__init__.py
@@ -9,7 +9,7 @@
     clone,
     ancestors,
 )
-from pytensor.graph.replace import clone_replace, graph_replace
+from pytensor.graph.replace import clone_replace, graph_replace, vectorize
 from pytensor.graph.op import Op
 from pytensor.graph.type import Type
 from pytensor.graph.fg import FunctionGraph
diff --git a/pytensor/graph/replace.py b/pytensor/graph/replace.py
index d16f4119ba..892a4abd80 100644
--- a/pytensor/graph/replace.py
+++ b/pytensor/graph/replace.py
@@ -1,8 +1,9 @@
-from functools import partial
-from typing import Iterable, Optional, Sequence, Union, cast, overload
+from functools import partial, singledispatch
+from typing import Iterable, Mapping, Optional, Sequence, Union, cast, overload
 
 from pytensor.graph.basic import Apply, Constant, Variable, truncated_graph_inputs
 from pytensor.graph.fg import FunctionGraph
+from pytensor.graph.op import Op
 
 
 ReplaceTypes = Union[Iterable[tuple[Variable, Variable]], dict[Variable, Variable]]
@@ -198,3 +199,63 @@ def toposort_key(
         return list(fg.outputs)
     else:
         return fg.outputs[0]
+
+
+@singledispatch
+def _vectorize_node(op: Op, node: Apply, *batched_inputs) -> Apply:
+    # Default implementation is provided in pytensor.tensor.blockwise
+    raise NotImplementedError
+
+
+def vectorize_node(node: Apply, *batched_inputs) -> Apply:
+    """Return a vectorized version of `node` with the new batched inputs."""
+    op = node.op
+    return _vectorize_node(op, node, *batched_inputs)
+
+
+def vectorize(
+    outputs: Sequence[Variable], vectorize: Mapping[Variable, Variable]
+) -> Sequence[Variable]:
+    """Vectorize an output graph, given a mapping from old variables to their expanded counterparts.
+
+    Expanded dimensions must be on the left. Behavior is similar to that of `numpy.vectorize`.
+
+    Examples
+    --------
+    .. code-block:: python
+
+        import pytensor
+        import pytensor.tensor as pt
+
+        from pytensor.graph import vectorize
+
+        # Original graph
+        x = pt.vector("x")
+        y = pt.exp(x) / pt.sum(pt.exp(x))
+
+        # Vectorized graph
+        new_x = pt.matrix("new_x")
+        [new_y] = vectorize([y], {x: new_x})
+
+        fn = pytensor.function([new_x], new_y)
+        fn([[0, 1, 2], [2, 1, 0]])
+        # array([[0.09003057, 0.24472847, 0.66524096],
+        #        [0.66524096, 0.24472847, 0.09003057]])
+
+    """
+    inputs = truncated_graph_inputs(outputs, ancestors_to_include=vectorize.keys())
+    new_inputs = [vectorize.get(inp, inp) for inp in inputs]
+
+    def transform(var):
+        if var in inputs:
+            return new_inputs[inputs.index(var)]
+
+        node = var.owner
+        batched_inputs = [transform(inp) for inp in node.inputs]
+        batched_node = vectorize_node(node, *batched_inputs)
+        batched_var = batched_node.outputs[var.owner.outputs.index(var)]
+
+        return batched_var
+
+    # TODO: MergeOptimization or node caching?
+    return [transform(out) for out in outputs]
diff --git a/pytensor/scalar/loop.py b/pytensor/scalar/loop.py
index 08def4e230..17ffe5d711 100644
--- a/pytensor/scalar/loop.py
+++ b/pytensor/scalar/loop.py
@@ -2,7 +2,8 @@
 from typing import Optional, Sequence, Tuple
 
 from pytensor.compile import rebuild_collect_shared
-from pytensor.graph import Constant, FunctionGraph, Variable, clone
+from pytensor.graph.basic import Constant, Variable, clone
+from pytensor.graph.fg import FunctionGraph
 from pytensor.scalar.basic import ScalarInnerGraphOp, as_scalar
 
 
diff --git a/pytensor/tensor/blockwise.py b/pytensor/tensor/blockwise.py
index 50f7445ce2..210545b255 100644
--- a/pytensor/tensor/blockwise.py
+++ b/pytensor/tensor/blockwise.py
@@ -1,5 +1,4 @@
 import re
-from functools import singledispatch
 from typing import Any, Dict, List, Optional, Sequence, Tuple, cast
 
 import numpy as np
@@ -9,6 +8,7 @@
 from pytensor.graph.basic import Apply, Constant, Variable
 from pytensor.graph.null_type import NullType
 from pytensor.graph.op import Op
+from pytensor.graph.replace import _vectorize_node, vectorize
 from pytensor.tensor import as_tensor_variable
 from pytensor.tensor.shape import shape_padleft
 from pytensor.tensor.type import continuous_dtypes, discrete_dtypes, tensor
@@ -72,8 +72,8 @@ def operand_sig(operand: Variable, prefix: str) -> str:
     return f"{inputs_sig}->{outputs_sig}"
 
 
-@singledispatch
-def _vectorize_node(op: Op, node: Apply, *bached_inputs) -> Apply:
+@_vectorize_node.register(Op)
+def vectorize_node_fallback(op: Op, node: Apply, *bached_inputs) -> Apply:
     if hasattr(op, "gufunc_signature"):
         signature = op.gufunc_signature
     else:
@@ -83,12 +83,6 @@ def _vectorize_node(op: Op, node: Apply, *bached_inputs) -> Apply:
     return cast(Apply, Blockwise(op, signature=signature).make_node(*bached_inputs))
 
 
-def vectorize_node(node: Apply, *batched_inputs) -> Apply:
-    """Returns vectorized version of node with new batched inputs."""
-    op = node.op
-    return _vectorize_node(op, node, *batched_inputs)
-
-
 class Blockwise(Op):
     """Generalizes a core `Op` to work with batched dimensions.
 
@@ -279,42 +273,18 @@ def as_core(t, core_t):
 
         core_igrads = self.core_op.L_op(core_inputs, core_outputs, core_ograds)
 
-        batch_ndims = self._batch_ndim_from_outputs(outputs)
-
-        def transform(var):
-            # From a graph of ScalarOps, make a graph of Broadcast ops.
-            if isinstance(var.type, (NullType, DisconnectedType)):
-                return var
-            if var in core_inputs:
-                return inputs[core_inputs.index(var)]
-            if var in core_outputs:
-                return outputs[core_outputs.index(var)]
-            if var in core_ograds:
-                return ograds[core_ograds.index(var)]
-
-            node = var.owner
-
-            # The gradient contains a constant, which may be responsible for broadcasting
-            if node is None:
-                if batch_ndims:
-                    var = shape_padleft(var, batch_ndims)
-                return var
-
-            batched_inputs = [transform(inp) for inp in node.inputs]
-            batched_node = vectorize_node(node, *batched_inputs)
-            batched_var = batched_node.outputs[var.owner.outputs.index(var)]
-
-            return batched_var
-
-        ret = []
-        for core_igrad, ipt in zip(core_igrads, inputs):
-            # Undefined gradient
-            if core_igrad is None:
-                ret.append(None)
-            else:
-                ret.append(transform(core_igrad))
+        igrads = vectorize(
+            [core_igrad for core_igrad in core_igrads if core_igrad is not None],
+            vectorize=dict(
+                zip(core_inputs + core_outputs + core_ograds, inputs + outputs + ograds)
+            ),
+        )
 
-        return ret
+        igrads_iter = iter(igrads)
+        return [
+            None if core_igrad is None else next(igrads_iter)
+            for core_igrad in core_igrads
+        ]
 
     def L_op(self, inputs, outs, ograds):
         from pytensor.tensor.math import sum as pt_sum
diff --git a/pytensor/tensor/elemwise.py b/pytensor/tensor/elemwise.py
index 377444609d..ab71e1c586 100644
--- a/pytensor/tensor/elemwise.py
+++ b/pytensor/tensor/elemwise.py
@@ -8,6 +8,7 @@
 from pytensor.gradient import DisconnectedType
 from pytensor.graph.basic import Apply
 from pytensor.graph.null_type import NullType
+from pytensor.graph.replace import _vectorize_node
 from pytensor.graph.utils import MethodNotDefined
 from pytensor.link.c.basic import failure_code
 from pytensor.link.c.op import COp, ExternalCOp, OpenMPOp
@@ -22,7 +23,7 @@
 from pytensor.tensor import elemwise_cgen as cgen
 from pytensor.tensor import get_vector_length
 from pytensor.tensor.basic import _get_vector_length, as_tensor_variable
-from pytensor.tensor.blockwise import _vectorize_node, vectorize_not_needed
+from pytensor.tensor.blockwise import vectorize_not_needed
 from pytensor.tensor.type import (
     TensorType,
     continuous_dtypes,
diff --git a/pytensor/tensor/random/op.py b/pytensor/tensor/random/op.py
index 9461bf440a..52ee0cebaa 100644
--- a/pytensor/tensor/random/op.py
+++ b/pytensor/tensor/random/op.py
@@ -7,6 +7,7 @@
 from pytensor.configdefaults import config
 from pytensor.graph.basic import Apply, Variable, equal_computations
 from pytensor.graph.op import Op
+from pytensor.graph.replace import _vectorize_node
 from pytensor.misc.safe_asarray import _asarray
 from pytensor.scalar import ScalarVariable
 from pytensor.tensor.basic import (
@@ -17,7 +18,6 @@
     get_vector_length,
     infer_static_shape,
 )
-from pytensor.tensor.blockwise import _vectorize_node
 from pytensor.tensor.random.type import RandomGeneratorType, RandomStateType, RandomType
 from pytensor.tensor.random.utils import (
     broadcast_params,
diff --git a/pytensor/tensor/rewriting/blockwise.py b/pytensor/tensor/rewriting/blockwise.py
index c85fba3815..2533eb7aaa 100644
--- a/pytensor/tensor/rewriting/blockwise.py
+++ b/pytensor/tensor/rewriting/blockwise.py
@@ -1,7 +1,8 @@
 from pytensor.compile.mode import optdb
 from pytensor.graph import node_rewriter
+from pytensor.graph.replace import vectorize_node
 from pytensor.graph.rewriting.basic import copy_stack_trace, out2in
-from pytensor.tensor.blockwise import Blockwise, vectorize_node
+from pytensor.tensor.blockwise import Blockwise
 
 
 @node_rewriter([Blockwise])
diff --git a/tests/graph/test_replace.py b/tests/graph/test_replace.py
index 7fc0e530f9..393b9c567b 100644
--- a/tests/graph/test_replace.py
+++ b/tests/graph/test_replace.py
@@ -1,10 +1,11 @@
 import numpy as np
 import pytest
+import scipy.special
 
 import pytensor.tensor as pt
 from pytensor import config, function, shared
 from pytensor.graph.basic import graph_inputs
-from pytensor.graph.replace import clone_replace, graph_replace
+from pytensor.graph.replace import clone_replace, graph_replace, vectorize
 from pytensor.tensor import dvector, fvector, vector
 from tests import unittest_tools as utt
 from tests.graph.utils import MyOp, MyVariable
@@ -223,3 +224,21 @@ def test_graph_replace_disconnected(self):
         assert oc[0] is o
         with pytest.raises(ValueError, match="Some replacements were not used"):
             oc = graph_replace([o], {fake: x.clone()}, strict=True)
+
+
+class TestVectorize:
+    # TODO: Add tests with multiple outputs, constants, and other singleton types
+
+    def test_basic(self):
+        x = pt.vector("x")
+        y = pt.exp(x) / pt.sum(pt.exp(x))
+
+        new_x = pt.matrix("new_x")
+        [new_y] = vectorize([y], {x: new_x})
+
+        fn = function([new_x], new_y)
+        test_new_x = np.array([[0, 1, 2], [2, 1, 0]]).astype(config.floatX)
+        np.testing.assert_allclose(
+            fn(test_new_x),
+            scipy.special.softmax(test_new_x, axis=-1),
+        )
diff --git a/tests/tensor/random/test_op.py b/tests/tensor/random/test_op.py
index 0bc8f0a73f..4a389811e1 100644
--- a/tests/tensor/random/test_op.py
+++ b/tests/tensor/random/test_op.py
@@ -4,8 +4,8 @@
 import pytensor.tensor as at
 from pytensor import config, function
 from pytensor.gradient import NullTypeGradError, grad
+from pytensor.graph.replace import vectorize_node
 from pytensor.raise_op import Assert
-from pytensor.tensor.blockwise import vectorize_node
 from pytensor.tensor.math import eq
 from pytensor.tensor.random import normal
 from pytensor.tensor.random.op import RandomState, RandomVariable, default_rng
diff --git a/tests/tensor/test_blockwise.py b/tests/tensor/test_blockwise.py
index 658c527430..92e07cf4e0 100644
--- a/tests/tensor/test_blockwise.py
+++ b/tests/tensor/test_blockwise.py
@@ -8,8 +8,9 @@
 from pytensor import config
 from pytensor.gradient import grad
 from pytensor.graph import Apply, Op
+from pytensor.graph.replace import vectorize_node
 from pytensor.tensor import tensor
-from pytensor.tensor.blockwise import Blockwise, _parse_gufunc_signature, vectorize_node
+from pytensor.tensor.blockwise import Blockwise, _parse_gufunc_signature
 from pytensor.tensor.nlinalg import MatrixInverse
 from pytensor.tensor.slinalg import Cholesky, Solve
 
diff --git a/tests/tensor/test_elemwise.py b/tests/tensor/test_elemwise.py
index 3d3aa1b28d..dc64442eff 100644
--- a/tests/tensor/test_elemwise.py
+++ b/tests/tensor/test_elemwise.py
@@ -13,11 +13,11 @@
 from pytensor.configdefaults import config
 from pytensor.graph.basic import Apply, Variable
 from pytensor.graph.fg import FunctionGraph
+from pytensor.graph.replace import vectorize_node
 from pytensor.link.basic import PerformLinker
 from pytensor.link.c.basic import CLinker, OpWiseCLinker
 from pytensor.tensor import as_tensor_variable
 from pytensor.tensor.basic import second
-from pytensor.tensor.blockwise import vectorize_node
 from pytensor.tensor.elemwise import CAReduce, DimShuffle, Elemwise
 from pytensor.tensor.math import Any, Sum
 from pytensor.tensor.math import all as pt_all
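For downstream `Op` authors, the moved dispatch hook can be exercised end to end. The sketch below registers a dedicated vectorization rule for a toy op, so that `vectorize` re-applies the op directly instead of wrapping it in `Blockwise` through `vectorize_node_fallback`. `DoubleOp` and `_vectorize_double` are hypothetical names invented here for illustration; only `_vectorize_node`, `vectorize_node`, and `vectorize` come from this diff.

.. code-block:: python

    import numpy as np

    import pytensor
    import pytensor.tensor as pt
    from pytensor.graph.basic import Apply
    from pytensor.graph.op import Op
    from pytensor.graph.replace import _vectorize_node, vectorize


    class DoubleOp(Op):
        """Toy elementwise Op, defined only for illustration."""

        def make_node(self, x):
            x = pt.as_tensor_variable(x)
            return Apply(self, [x], [x.type()])

        def perform(self, node, inputs, output_storage):
            output_storage[0][0] = 2 * inputs[0]


    @_vectorize_node.register(DoubleOp)
    def _vectorize_double(op, node, *batched_inputs):
        # The op is elementwise, so re-applying it to the batched input
        # already yields the vectorized node; no Blockwise wrapper needed.
        return op.make_node(*batched_inputs)


    x = pt.vector("x")
    y = DoubleOp()(x)

    new_x = pt.matrix("new_x")
    [new_y] = vectorize([y], {x: new_x})

    fn = pytensor.function([new_x], new_y)
    test_x = np.arange(6).reshape(2, 3).astype(pytensor.config.floatX)
    print(fn(test_x))  # every entry doubled, batch dimension preserved

Without the registration, the same graph would still vectorize, but `vectorize_node` would fall back to `Blockwise` (or raise, since `DoubleOp` defines no `gufunc_signature`); registering on the `Op` subclass is what lets an op opt out of that default.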