Skip to content

Commit

Permalink
Allow non-TensorVariable types to be traced in new Scan Op
Browse files Browse the repository at this point in the history
  • Loading branch information
ricardoV94 committed Jan 16, 2023
1 parent 60e80f5 commit a750fd7
Show file tree
Hide file tree
Showing 4 changed files with 100 additions and 29 deletions.
19 changes: 14 additions & 5 deletions pytensor/link/jax/dispatch/loop.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
import jax
from jax.tree_util import tree_flatten, tree_unflatten

from pytensor.compile.mode import get_mode
from pytensor.link.jax.dispatch.basic import jax_funcify
from pytensor.loop.op import Scan
from pytensor.typed_list import TypedListType


@jax_funcify.register(Scan)
Expand Down Expand Up @@ -43,10 +45,17 @@ def scan_fn(carry, _):
states, traces = jax.lax.scan(
scan_fn, init=list(states), xs=None, length=max_iters
)
for i in range(len(states)):
if i not in used_traces_idxs:
traces.insert(i, None)

return *states, *traces
final_traces = [None] * len(states)
for idx, trace in zip(used_traces_idxs, traces):
if isinstance(op.trace_types[idx], TypedListType):
flattened_trace, treedef = tree_flatten(trace)
transposed_trace = [
tree_unflatten(treedef, l) for l in zip(*flattened_trace)
]
final_traces[idx] = transposed_trace
else:
final_traces[idx] = trace

return *states, *final_traces

return scan
60 changes: 45 additions & 15 deletions pytensor/loop/op.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,18 +7,13 @@
from pytensor.graph import Apply, FunctionGraph, Op, Type, node_rewriter
from pytensor.graph.rewriting.basic import in2out
from pytensor.scalar import constant
from pytensor.tensor import (
NoneConst,
add,
and_,
empty,
get_scalar_constant_value,
set_subtensor,
)
from pytensor.tensor import add, and_, empty, get_scalar_constant_value, set_subtensor
from pytensor.tensor.exceptions import NotScalarConstantError
from pytensor.tensor.shape import Shape_i
from pytensor.tensor.subtensor import Subtensor, get_idx_list
from pytensor.tensor.type import DenseTensorType, TensorType
from pytensor.tensor.type_other import NoneTypeT
from pytensor.typed_list import GetItem, TypedListType, append, make_empty_list


def validate_loop_update_types(update):
Expand Down Expand Up @@ -176,8 +171,7 @@ def __init__(
)
)
else:
# We can't concatenate all types of states, such as RandomTypes
self.trace_types.append(NoneConst.type)
self.trace_types.append(TypedListType(state_type))

self.constant_types = [inp.type for inp in update_fg.inputs[self.n_states :]]
self.n_constants = len(self.constant_types)
Expand Down Expand Up @@ -312,10 +306,6 @@ def scan(fn, idx, initial_states, constants, max_iters):
if fgraph.clients[trace]
]

# Check that outputs that cannot be converted into sequences (such as RandomTypes) are not being referenced
for trace_idx in used_traces_idxs:
assert not isinstance(old_states[trace_idx].type, NoneTypeT)

# Inputs to the new Loop
max_iters = node.inputs[0]
init_states = node.inputs[1 : 1 + op.n_states]
Expand All @@ -324,6 +314,8 @@ def scan(fn, idx, initial_states, constants, max_iters):
(max_iters, *tuple(init_states[trace_idx].shape)),
dtype=init_states[trace_idx].dtype,
)
if isinstance(init_states[trace_idx].type, DenseTensorType)
else make_empty_list(init_states[trace_idx].type)
for trace_idx in used_traces_idxs
]
constants = node.inputs[1 + op.n_states :]
Expand Down Expand Up @@ -387,6 +379,8 @@ def scan(fn, idx, initial_states, constants, max_iters):
inner_while_cond, *inner_next_states = update_fg.outputs
inner_next_traces = [
set_subtensor(prev_trace[inner_idx], inner_next_states[trace_idx])
if isinstance(prev_trace.type, DenseTensorType)
else append(prev_trace, inner_next_states[trace_idx])
for trace_idx, prev_trace in zip(used_traces_idxs, inner_traces)
]
for t in inner_next_traces:
Expand Down Expand Up @@ -429,7 +423,7 @@ def scan(fn, idx, initial_states, constants, max_iters):
replacements = dict(zip(old_states, new_states))
for trace_idx, new_trace in zip(used_traces_idxs, new_traces):
# If there is no while condition, the whole trace will be used
if op.has_while_condition:
if op.has_while_condition and isinstance(new_trace.type, DenseTensorType):
new_trace = new_trace[:final_idx]
replacements[old_traces[trace_idx]] = new_trace
return replacements
Expand All @@ -446,3 +440,39 @@ def scan(fn, idx, initial_states, constants, max_iters):
"not_jax",
position=1.0,
)


@node_rewriter([Scan])
def scan_view_last_state(fgraph, node):
"""Replace trace[-1] by the last state output of a Scan node"""
replacements = {}
for final_state, trace in zip(
node.outputs[: node.op.n_states], node.outputs[node.op.n_states :]
):
clients = fgraph.clients[trace]
for client, _ in clients:
if client == "output":
continue
if isinstance(client.op, (Subtensor, GetItem)):
if isinstance(client.op, Subtensor):
idxs = get_idx_list(client.inputs, client.op.idx_list)
if len(idxs) == 1:
idx = idxs[0]
else:
idx = client.inputs[1]
try:
last_index = get_scalar_constant_value(idx) == -1
except NotScalarConstantError:
continue
if last_index:
replacements[client.default_output()] = final_state
return replacements


optdb.register(
"scan_view_last_state",
in2out(scan_view_last_state),
"fast_compile",
"fast_run",
position=0.999,
)
13 changes: 11 additions & 2 deletions tests/link/jax/test_loop.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import numpy as np
import pytest
from jax.tree_util import tree_leaves

from pytensor import function, shared
from pytensor.graph import FunctionGraph
Expand Down Expand Up @@ -70,7 +71,7 @@ def test_scan_with_sequence_and_carried_state():
def test_scan_with_rvs():
rng = shared(np.random.default_rng(123))

[next_rng, _], [_, xs] = scan(
[final_rng, _], [rngs, xs] = scan(
fn=lambda prev_rng: normal(rng=prev_rng).owner.outputs,
init_states=[rng, None],
n_steps=10,
Expand All @@ -83,11 +84,19 @@ def test_scan_with_rvs():
assert not set(tuple(np.array(res1))) ^ set(tuple(np.array(res2)))

# Now with updates
fn = function([], xs, mode="JAX", updates={rng: next_rng})
fn = function([], xs, mode="JAX", updates={rng: final_rng})
res1 = fn()
res2 = fn()
assert not set(tuple(np.array(res1))) & set(tuple(np.array(res2)))

# Test traced rngs
fn = function([], [rngs, final_rng], mode="JAX")
rngs_res, final_rng_res = fn()
assert isinstance(rngs_res, list) and len(rngs_res) == 10
assert [np.array(v).tolist() for v in tree_leaves(rngs_res[-1])] == [
np.array(v).tolist() for v in tree_leaves(final_rng_res)
]


def test_while_scan_fails():
_, [xs] = scan(
Expand Down
37 changes: 30 additions & 7 deletions tests/loop/test_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,13 @@
from pytensor import function, shared
from pytensor.compile import DeepCopyOp
from pytensor.graph import FunctionGraph
from pytensor.loop.op import Loop, Scan
from pytensor.graph.rewriting.basic import in2out
from pytensor.loop.op import Loop, Scan, scan_view_last_state
from pytensor.tensor import constant, empty, lscalar, scalar, vector
from pytensor.tensor.random import normal
from pytensor.tensor.random.type import RandomGeneratorType
from pytensor.tensor.subtensor import Subtensor
from pytensor.tensor.type_other import NoneTypeT
from pytensor.typed_list import TypedListType


def test_loop_basic():
Expand Down Expand Up @@ -152,10 +154,31 @@ def test_fori_random_scan():
[constant(np.array(True)), *normal(rng=rng).owner.outputs[::-1]],
)

_, new_rng, ys, rngs = Scan(update_fg=update_fg)(n_iters, dummy_init, rng_shared)
assert isinstance(rngs.type, NoneTypeT)
last_y, last_rng, ys, rngs = Scan(update_fg=update_fg)(
n_iters, dummy_init, rng_shared
)
assert isinstance(last_rng.type, RandomGeneratorType)
assert isinstance(rngs.type, TypedListType)
assert isinstance(rngs.type.ttype, RandomGeneratorType)

fn = function([], [ys, rngs], updates={rng_shared: last_rng})
for i in range(2):
ys_res, rngs_res = fn()
for y_res, rng_res in zip(ys_res, rngs_res):
np.testing.assert_almost_equal(y_res, rng_test.normal())
assert rng_res.__getstate__() == rng_test.__getstate__()

fn = function([], ys, updates={rng_shared: new_rng})

np.testing.assert_array_equal(fn(), rng_test.normal(size=5))
np.testing.assert_array_equal(fn(), rng_test.normal(size=5))
def test_scan_view_last_state():
x = scalar("x")
update_fg = FunctionGraph([x], [x > 5, x + 2])

n_iters = 10
y1, ys = Scan(update_fg=update_fg)(n_iters, x)

y2 = ys[-1]
fgraph = FunctionGraph(outputs=[y2, ys], clone=False)
assert fgraph.outputs[0] is not y1
in2out(scan_view_last_state).apply(fgraph)
assert fgraph.outputs[0] is y1
assert fgraph.outputs[1] is ys

0 comments on commit a750fd7

Please sign in to comment.