-
Notifications
You must be signed in to change notification settings - Fork 117
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
ebb9d1c
commit 01011aa
Showing
6 changed files
with
165 additions
and
7 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,52 @@ | ||
import jax | ||
|
||
from pytensor.compile.mode import get_mode | ||
from pytensor.link.jax.dispatch.basic import jax_funcify | ||
from pytensor.loop.op import Scan | ||
|
||
|
||
@jax_funcify.register(Scan)
def jax_funcify_Scan(op, node, global_fgraph, **kwargs):
    """Build a JAX callable for a ``Scan`` op on top of ``jax.lax.scan``.

    Parameters
    ----------
    op
        The ``Scan`` op being converted. Its ``has_while_condition``,
        ``update_fg`` and ``n_states`` attributes are read.
    node
        The node applying ``op``. Its outputs after the first
        ``op.n_states`` are treated as the traces.
    global_fgraph
        Graph whose ``clients`` mapping is consulted to decide which
        traces are used elsewhere.
    **kwargs
        Forwarded to ``jax_funcify`` when converting the inner graph.

    Returns
    -------
    callable
        ``scan(max_iters, *outer_inputs)``, returning the final states
        followed by the traces, with ``None`` in place of unused traces.

    Raises
    ------
    NotImplementedError
        If ``op.has_while_condition`` is truthy.
    """
    # TODO: Rewrite as a while loop if only last states are used
    if op.has_while_condition:
        raise NotImplementedError(
            "Scan ops with while condition cannot be transpiled JAX"
        )

    # Apply inner rewrites
    # TODO: Not sure this is the right place to do this, should we have a rewrite that
    #  explicitly triggers the optimization of the inner graphs of Scan?
    inner_fgraph = op.update_fg.clone()
    get_mode("JAX").optimizer(inner_fgraph)

    jaxified_scan_inner_fn = jax_funcify(inner_fgraph, **kwargs)

    # Only include the intermediate states that are used elsewhere
    used_traces_idxs = []
    for idx, trace in enumerate(node.outputs[op.n_states :]):
        if global_fgraph.clients[trace]:
            used_traces_idxs.append(idx)

    def scan(max_iters, *outer_inputs):
        init_states = list(outer_inputs[: op.n_states])
        constants = outer_inputs[op.n_states :]

        def scan_fn(carry, _):
            # The inner function returns the resume flag first, then the
            # updated states.
            resume, *next_carry = jaxified_scan_inner_fn(*carry, *constants)
            assert resume
            # Return states as both carry and output to be appended,
            # keeping only the outputs whose traces are actually used.
            kept = []
            for idx, state in enumerate(next_carry):
                if idx in used_traces_idxs:
                    kept.append(state)
            return next_carry, kept

        final_states, traces = jax.lax.scan(
            scan_fn, init=init_states, xs=None, length=max_iters
        )
        # Put a ``None`` placeholder back at the position of every unused
        # trace so the returned outputs keep their positions.
        for pos, _ in enumerate(final_states):
            if pos in used_traces_idxs:
                continue
            traces.insert(pos, None)

        return (*final_states, *traces)

    return scan
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,104 @@ | ||
import numpy as np | ||
import pytest | ||
|
||
from pytensor import function, shared | ||
from pytensor.graph import FunctionGraph | ||
from pytensor.loop.basic import scan | ||
from pytensor.scan import until | ||
from pytensor.tensor import scalar, vector, zeros | ||
from pytensor.tensor.random import normal | ||
from tests.link.jax.test_basic import compare_jax_and_py | ||
|
||
|
||
def test_scan_with_single_sequence():
    """Scan ``x * 100`` over a single sequence and check it with ``compare_jax_and_py``."""
    seq = vector("xs")
    _, [scaled] = scan(lambda x: x * 100, sequences=[seq])

    fgraph = FunctionGraph([seq], [scaled])
    test_values = [np.arange(10)]
    compare_jax_and_py(fgraph, test_values)
|
||
|
||
def test_scan_with_single_sequence_shortened_by_nsteps():
    """Scan a length-10 sequence with ``n_steps=9`` and check it with ``compare_jax_and_py``."""
    # JAX needs the length to be constant
    seq = vector("xs", shape=(10,))
    _, [scaled] = scan(lambda x: x * 100, sequences=[seq], n_steps=9)

    fgraph = FunctionGraph([seq], [scaled])
    test_values = [np.arange(10)]
    compare_jax_and_py(fgraph, test_values)
|
||
|
||
def test_scan_with_multiple_sequences():
    """Scan the elementwise product of two length-10 sequences and check it with ``compare_jax_and_py``."""
    # JAX can only handle constant n_steps
    left = vector("xs", shape=(10,))
    right = vector("ys", shape=(10,))
    _, [products] = scan(fn=lambda x, y: x * y, sequences=[left, right])

    fgraph = FunctionGraph([left, right], [products])
    left_values = np.arange(10, dtype=left.dtype)
    right_values = np.arange(10, dtype=right.dtype)
    compare_jax_and_py(fgraph, [left_values, right_values])
|
||
|
||
def test_scan_with_carried_and_non_carried_states():
    """Scan with ``init_states=[x, None]`` for 10 steps and check both traces with ``compare_jax_and_py``."""

    def step(xtm1):
        return xtm1 + 1, (xtm1 + 1) * 2

    start = scalar("x")
    _, [trace_a, trace_b] = scan(fn=step, init_states=[start, None], n_steps=10)

    fgraph = FunctionGraph([start], [trace_a, trace_b])
    compare_jax_and_py(fgraph, [-1])
|
||
|
||
def test_scan_with_sequence_and_carried_state():
    """Scan combining one sequence with one carried state and check it with ``compare_jax_and_py``."""

    def step(x, ytm1):
        return (ytm1 + 1) * x

    seq = vector("xs")
    initial = zeros(())
    _, [trace] = scan(fn=step, init_states=[initial], sequences=[seq])

    fgraph = FunctionGraph([seq], [trace])
    compare_jax_and_py(fgraph, [[1, 2, 3]])
|
||
|
||
def test_scan_with_rvs():
    """Check draws from a scan over ``normal`` with and without updates to the shared RNG.

    Without updates, two calls must give the same set of values; with the
    update ``{rng: next_rng}``, two calls must share no value.
    """
    rng = shared(np.random.default_rng(123))

    def step(prev_rng):
        return normal(rng=prev_rng).owner.outputs

    final_states, traces = scan(fn=step, init_states=[rng, None], n_steps=10)
    [next_rng, _] = final_states
    [_, xs] = traces

    # First without updates
    fn = function([], xs, mode="JAX", updates=None)
    first = fn()
    second = fn()
    first_values = set(tuple(np.array(first)))
    second_values = set(tuple(np.array(second)))
    assert not first_values.symmetric_difference(second_values)

    # Now with updates
    fn = function([], xs, mode="JAX", updates={rng: next_rng})
    first = fn()
    second = fn()
    first_values = set(tuple(np.array(first)))
    second_values = set(tuple(np.array(second)))
    assert not first_values.intersection(second_values)
|
||
|
||
def test_while_scan_fails():
    """A scan whose step function returns an ``until`` condition must raise ``NotImplementedError``."""

    def step(x):
        return x + 1, until((x + 1) >= 9)

    _, [trace] = scan(fn=step, init_states=[-1], n_steps=20)

    fgraph = FunctionGraph([], [trace])
    expected_message = "Scan ops with while condition cannot be transpiled JAX"
    with pytest.raises(NotImplementedError, match=expected_message):
        compare_jax_and_py(fgraph, [])