From a750fd7214a6bed7f947adedf538f6b2e8275fe5 Mon Sep 17 00:00:00 2001
From: Ricardo Vieira
Date: Mon, 16 Jan 2023 11:23:22 +0100
Subject: [PATCH] Allow non-TensorVariable types to be traced in new Scan Op

---
 pytensor/link/jax/dispatch/loop.py | 19 +++++++---
 pytensor/loop/op.py                | 64 ++++++++++++++++++++++--------
 tests/link/jax/test_loop.py        | 13 ++++++-
 tests/loop/test_op.py              | 37 ++++++++++++++----
 4 files changed, 104 insertions(+), 29 deletions(-)

diff --git a/pytensor/link/jax/dispatch/loop.py b/pytensor/link/jax/dispatch/loop.py
index 2c5b9609d9..d17d740d15 100644
--- a/pytensor/link/jax/dispatch/loop.py
+++ b/pytensor/link/jax/dispatch/loop.py
@@ -1,8 +1,10 @@
 import jax
+from jax.tree_util import tree_flatten, tree_unflatten
 
 from pytensor.compile.mode import get_mode
 from pytensor.link.jax.dispatch.basic import jax_funcify
 from pytensor.loop.op import Scan
+from pytensor.typed_list import TypedListType
 
 
 @jax_funcify.register(Scan)
@@ -43,10 +45,17 @@ def scan_fn(carry, _):
         states, traces = jax.lax.scan(
             scan_fn, init=list(states), xs=None, length=max_iters
         )
-        for i in range(len(states)):
-            if i not in used_traces_idxs:
-                traces.insert(i, None)
-
-        return *states, *traces
+        final_traces = [None] * len(states)
+        for idx, trace in zip(used_traces_idxs, traces):
+            if isinstance(op.trace_types[idx], TypedListType):
+                flattened_trace, treedef = tree_flatten(trace)
+                transposed_trace = [
+                    tree_unflatten(treedef, l) for l in zip(*flattened_trace)
+                ]
+                final_traces[idx] = transposed_trace
+            else:
+                final_traces[idx] = trace
+
+        return *states, *final_traces
 
     return scan
diff --git a/pytensor/loop/op.py b/pytensor/loop/op.py
index ba14e383ed..565d2a3b7c 100644
--- a/pytensor/loop/op.py
+++ b/pytensor/loop/op.py
@@ -7,18 +7,13 @@
 from pytensor.graph import Apply, FunctionGraph, Op, Type, node_rewriter
 from pytensor.graph.rewriting.basic import in2out
 from pytensor.scalar import constant
-from pytensor.tensor import (
-    NoneConst,
-    add,
-    and_,
-    empty,
-    get_scalar_constant_value,
-    set_subtensor,
-)
+from pytensor.tensor import add, and_, empty, get_scalar_constant_value, set_subtensor
 from pytensor.tensor.exceptions import NotScalarConstantError
 from pytensor.tensor.shape import Shape_i
+from pytensor.tensor.subtensor import Subtensor, get_idx_list
 from pytensor.tensor.type import DenseTensorType, TensorType
 from pytensor.tensor.type_other import NoneTypeT
+from pytensor.typed_list import GetItem, TypedListType, append, make_empty_list
 
 
 def validate_loop_update_types(update):
@@ -176,8 +171,7 @@ def __init__(
                     )
                 )
             else:
-                # We can't concatenate all types of states, such as RandomTypes
-                self.trace_types.append(NoneConst.type)
+                self.trace_types.append(TypedListType(state_type))
 
         self.constant_types = [inp.type for inp in update_fg.inputs[self.n_states :]]
         self.n_constants = len(self.constant_types)
@@ -312,10 +306,6 @@ def scan(fn, idx, initial_states, constants, max_iters):
         if fgraph.clients[trace]
     ]
 
-    # Check that outputs that cannot be converted into sequences (such as RandomTypes) are not being referenced
-    for trace_idx in used_traces_idxs:
-        assert not isinstance(old_states[trace_idx].type, NoneTypeT)
-
     # Inputs to the new Loop
     max_iters = node.inputs[0]
     init_states = node.inputs[1 : 1 + op.n_states]
@@ -324,6 +314,8 @@ def scan(fn, idx, initial_states, constants, max_iters):
             (max_iters, *tuple(init_states[trace_idx].shape)),
             dtype=init_states[trace_idx].dtype,
         )
+        if isinstance(init_states[trace_idx].type, DenseTensorType)
+        else make_empty_list(init_states[trace_idx].type)
         for trace_idx in used_traces_idxs
     ]
     constants = node.inputs[1 + op.n_states :]
@@ -387,6 +379,8 @@ def scan(fn, idx, initial_states, constants, max_iters):
     inner_while_cond, *inner_next_states = update_fg.outputs
     inner_next_traces = [
         set_subtensor(prev_trace[inner_idx], inner_next_states[trace_idx])
+        if isinstance(prev_trace.type, DenseTensorType)
+        else append(prev_trace, inner_next_states[trace_idx])
         for trace_idx, prev_trace in zip(used_traces_idxs, inner_traces)
     ]
     for t in inner_next_traces:
@@ -429,7 +423,7 @@ def scan(fn, idx, initial_states, constants, max_iters):
     replacements = dict(zip(old_states, new_states))
     for trace_idx, new_trace in zip(used_traces_idxs, new_traces):
         # If there is no while condition, the whole trace will be used
-        if op.has_while_condition:
+        if op.has_while_condition and isinstance(new_trace.type, DenseTensorType):
             new_trace = new_trace[:final_idx]
         replacements[old_traces[trace_idx]] = new_trace
     return replacements
@@ -446,3 +440,43 @@ def scan(fn, idx, initial_states, constants, max_iters):
     "not_jax",
     position=1.0,
 )
+
+
+@node_rewriter([Scan])
+def scan_view_last_state(fgraph, node):
+    """Replace trace[-1] by the last state output of a Scan node"""
+    replacements = {}
+    for final_state, trace in zip(
+        node.outputs[: node.op.n_states], node.outputs[node.op.n_states :]
+    ):
+        clients = fgraph.clients[trace]
+        for client, _ in clients:
+            if client == "output":
+                continue
+            if isinstance(client.op, (Subtensor, GetItem)):
+                if isinstance(client.op, Subtensor):
+                    idxs = get_idx_list(client.inputs, client.op.idx_list)
+                    # Only a lone integer index can select one whole state.
+                    # Skipping slices and multi-dimensional indexing also keeps
+                    # `idx` from being unbound (or stale from a prior client).
+                    if len(idxs) != 1 or isinstance(idxs[0], slice):
+                        continue
+                    idx = idxs[0]
+                else:
+                    idx = client.inputs[1]
+                try:
+                    last_index = get_scalar_constant_value(idx) == -1
+                except NotScalarConstantError:
+                    continue
+                if last_index:
+                    replacements[client.default_output()] = final_state
+    return replacements
+
+
+optdb.register(
+    "scan_view_last_state",
+    in2out(scan_view_last_state),
+    "fast_compile",
+    "fast_run",
+    position=0.999,
+)
diff --git a/tests/link/jax/test_loop.py b/tests/link/jax/test_loop.py
index 04ae1e0ba4..e958631b33 100644
--- a/tests/link/jax/test_loop.py
+++ b/tests/link/jax/test_loop.py
@@ -1,5 +1,6 @@
 import numpy as np
 import pytest
+from jax.tree_util import tree_leaves
 
 from pytensor import function, shared
 from pytensor.graph import FunctionGraph
@@ -70,7 +71,7 @@ def test_scan_with_sequence_and_carried_state():
 def test_scan_with_rvs():
     rng = shared(np.random.default_rng(123))
 
-    [next_rng, _], [_, xs] = scan(
+    [final_rng, _], [rngs, xs] = scan(
         fn=lambda prev_rng: normal(rng=prev_rng).owner.outputs,
         init_states=[rng, None],
         n_steps=10,
@@ -83,11 +84,19 @@ def test_scan_with_rvs():
     assert not set(tuple(np.array(res1))) ^ set(tuple(np.array(res2)))
 
     # Now with updates
-    fn = function([], xs, mode="JAX", updates={rng: next_rng})
+    fn = function([], xs, mode="JAX", updates={rng: final_rng})
     res1 = fn()
     res2 = fn()
     assert not set(tuple(np.array(res1))) & set(tuple(np.array(res2)))
 
+    # Test traced rngs
+    fn = function([], [rngs, final_rng], mode="JAX")
+    rngs_res, final_rng_res = fn()
+    assert isinstance(rngs_res, list) and len(rngs_res) == 10
+    assert [np.array(v).tolist() for v in tree_leaves(rngs_res[-1])] == [
+        np.array(v).tolist() for v in tree_leaves(final_rng_res)
+    ]
+
 
 def test_while_scan_fails():
     _, [xs] = scan(
diff --git a/tests/loop/test_op.py b/tests/loop/test_op.py
index 28393d0d6a..632258143b 100644
--- a/tests/loop/test_op.py
+++ b/tests/loop/test_op.py
@@ -4,11 +4,13 @@
 from pytensor import function, shared
 from pytensor.compile import DeepCopyOp
 from pytensor.graph import FunctionGraph
-from pytensor.loop.op import Loop, Scan
+from pytensor.graph.rewriting.basic import in2out
+from pytensor.loop.op import Loop, Scan, scan_view_last_state
 from pytensor.tensor import constant, empty, lscalar, scalar, vector
 from pytensor.tensor.random import normal
+from pytensor.tensor.random.type import RandomGeneratorType
 from pytensor.tensor.subtensor import Subtensor
-from pytensor.tensor.type_other import NoneTypeT
+from pytensor.typed_list import TypedListType
 
 
 def test_loop_basic():
@@ -152,10 +154,31 @@ def test_fori_random_scan():
         [constant(np.array(True)), *normal(rng=rng).owner.outputs[::-1]],
     )
 
-    _, new_rng, ys, rngs = Scan(update_fg=update_fg)(n_iters, dummy_init, rng_shared)
-    assert isinstance(rngs.type, NoneTypeT)
+    last_y, last_rng, ys, rngs = Scan(update_fg=update_fg)(
+        n_iters, dummy_init, rng_shared
+    )
+    assert isinstance(last_rng.type, RandomGeneratorType)
+    assert isinstance(rngs.type, TypedListType)
+    assert isinstance(rngs.type.ttype, RandomGeneratorType)
+
+    fn = function([], [ys, rngs], updates={rng_shared: last_rng})
+    for i in range(2):
+        ys_res, rngs_res = fn()
+        for y_res, rng_res in zip(ys_res, rngs_res):
+            np.testing.assert_almost_equal(y_res, rng_test.normal())
+            assert rng_res.__getstate__() == rng_test.__getstate__()
 
-    fn = function([], ys, updates={rng_shared: new_rng})
 
-    np.testing.assert_array_equal(fn(), rng_test.normal(size=5))
-    np.testing.assert_array_equal(fn(), rng_test.normal(size=5))
+def test_scan_view_last_state():
+    x = scalar("x")
+    update_fg = FunctionGraph([x], [x > 5, x + 2])
+
+    n_iters = 10
+    y1, ys = Scan(update_fg=update_fg)(n_iters, x)
+
+    y2 = ys[-1]
+    fgraph = FunctionGraph(outputs=[y2, ys], clone=False)
+    assert fgraph.outputs[0] is not y1
+    in2out(scan_view_last_state).apply(fgraph)
+    assert fgraph.outputs[0] is y1
+    assert fgraph.outputs[1] is ys