Skip to content

Commit

Permalink
Add JAX rewrite for new Scan Op
Browse files Browse the repository at this point in the history
  • Loading branch information
ricardoV94 committed Jan 16, 2023
1 parent ebb9d1c commit 01011aa
Show file tree
Hide file tree
Showing 6 changed files with 165 additions and 7 deletions.
4 changes: 3 additions & 1 deletion pytensor/compile/mode.py
Original file line number Diff line number Diff line change
Expand Up @@ -449,7 +449,9 @@ def clone(self, link_kwargs=None, optimizer="", **kwargs):

JAX = Mode(
JAXLinker(),
RewriteDatabaseQuery(include=["fast_run", "jax"], exclude=["cxx_only", "BlasOpt"]),
RewriteDatabaseQuery(
include=["fast_run", "jax"], exclude=["cxx_only", "BlasOpt", "not_jax"]
),
)
NUMBA = Mode(
NumbaLinker(),
Expand Down
1 change: 1 addition & 0 deletions pytensor/link/jax/dispatch/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,5 +12,6 @@
import pytensor.link.jax.dispatch.random
import pytensor.link.jax.dispatch.elemwise
import pytensor.link.jax.dispatch.scan
import pytensor.link.jax.dispatch.loop

# isort: on
52 changes: 52 additions & 0 deletions pytensor/link/jax/dispatch/loop.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
import jax

from pytensor.compile.mode import get_mode
from pytensor.link.jax.dispatch.basic import jax_funcify
from pytensor.loop.op import Scan


@jax_funcify.register(Scan)
def jax_funcify_Scan(op, node, global_fgraph, **kwargs):
    """Build a JAX callable implementing a ``Scan`` node via ``jax.lax.scan``.

    Parameters
    ----------
    op
        The ``Scan`` op being converted. Its ``update_fg`` inner graph is
        cloned, rewritten with the JAX mode optimizer and converted with
        ``jax_funcify``.
    node
        The apply node of ``op``. Its outputs after the first ``op.n_states``
        are treated as the traces.
    global_fgraph
        Graph whose ``clients`` mapping is consulted to decide which traces
        have to be produced.
    **kwargs
        Forwarded unchanged to ``jax_funcify`` for the inner graph.

    Returns
    -------
    callable
        ``scan(max_iters, *outer_inputs)`` returning a tuple with the final
        states followed by one entry per state: the stacked trace when it is
        used, ``None`` otherwise.

    Raises
    ------
    NotImplementedError
        If ``op.has_while_condition`` is true.
    """
    # TODO: Rewrite as a while loop if only last states are used
    if op.has_while_condition:
        raise NotImplementedError(
            "Scan ops with while condition cannot be transpiled JAX"
        )

    # Apply inner rewrites
    # TODO: Not sure this is the right place to do this, should we have a rewrite that
    #  explicitly triggers the optimization of the inner graphs of Scan?
    inner_fgraph = op.update_fg.clone()
    get_mode("JAX").optimizer(inner_fgraph)
    inner_step_fn = jax_funcify(inner_fgraph, **kwargs)

    # Positions of the traces that have at least one client in the outer graph;
    # only these are accumulated by the loop.
    needed_traces = frozenset(
        pos
        for pos, trace in enumerate(node.outputs[op.n_states :])
        if global_fgraph.clients[trace]
    )

    def scan(max_iters, *outer_inputs):
        initial_states = list(outer_inputs[: op.n_states])
        constants = outer_inputs[op.n_states :]

        def step(carry, _):
            # The first output of the inner function is the "resume" flag,
            # the remaining ones are the updated states.
            resume, *next_states = inner_step_fn(*carry, *constants)
            assert resume
            # The new states are both carried over and (when needed) stacked
            emitted = [
                state for pos, state in enumerate(next_states) if pos in needed_traces
            ]
            return next_states, emitted

        final_states, stacked = jax.lax.scan(
            step, init=initial_states, xs=None, length=max_iters
        )

        # Re-align the stacked traces with the states, using None as the
        # placeholder for every trace that was not computed.
        remaining = iter(stacked)
        all_traces = [
            next(remaining) if pos in needed_traces else None
            for pos in range(len(final_states))
        ]

        return (*final_states, *all_traces)

    return scan
1 change: 1 addition & 0 deletions pytensor/link/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -736,6 +736,7 @@ def fgraph_to_python(
global_env = {}

body_assigns = []
kwargs.setdefault("global_fgraph", fgraph)
for node in order:
compiled_func = op_conversion_fn(
node.op, node=node, storage_map=storage_map, **kwargs
Expand Down
10 changes: 4 additions & 6 deletions tests/link/jax/test_basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,15 +5,13 @@
import pytest

from pytensor.compile.function import function
from pytensor.compile.mode import Mode
from pytensor.compile.mode import get_mode
from pytensor.compile.sharedvalue import SharedVariable, shared
from pytensor.configdefaults import config
from pytensor.graph.basic import Apply
from pytensor.graph.fg import FunctionGraph
from pytensor.graph.op import Op, get_test_value
from pytensor.graph.rewriting.db import RewriteDatabaseQuery
from pytensor.ifelse import ifelse
from pytensor.link.jax import JAXLinker
from pytensor.raise_op import assert_op
from pytensor.tensor.type import dscalar, scalar, vector

Expand All @@ -27,9 +25,9 @@ def set_pytensor_flags():
jax = pytest.importorskip("jax")


opts = RewriteDatabaseQuery(include=["jax"], exclude=["cxx_only", "BlasOpt"])
jax_mode = Mode(JAXLinker(), opts)
py_mode = Mode("py", opts)
# We assume that the JAX mode includes all the rewrites needed to transpile JAX graphs
jax_mode = get_mode("JAX")
py_mode = get_mode("FAST_COMPILE")


def compare_jax_and_py(
Expand Down
104 changes: 104 additions & 0 deletions tests/link/jax/test_loop.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
import numpy as np
import pytest

from pytensor import function, shared
from pytensor.graph import FunctionGraph
from pytensor.loop.basic import scan
from pytensor.scan import until
from pytensor.tensor import scalar, vector, zeros
from pytensor.tensor.random import normal
from tests.link.jax.test_basic import compare_jax_and_py


def test_scan_with_single_sequence():
    """Scan that multiplies each element of one sequence by 100."""
    xs = vector("xs")

    def step(x):
        return x * 100

    _, traces = scan(step, sequences=[xs])
    (ys,) = traces

    graph = FunctionGraph([xs], [ys])
    test_inputs = [np.arange(10)]
    compare_jax_and_py(graph, test_inputs)


def test_scan_with_single_sequence_shortened_by_nsteps():
    """Scan over a length-10 sequence while requesting only 9 steps."""
    # JAX needs the length to be constant
    xs = vector("xs", shape=(10,))

    def step(x):
        return x * 100

    _, traces = scan(step, sequences=[xs], n_steps=9)
    (ys,) = traces

    graph = FunctionGraph([xs], [ys])
    test_inputs = [np.arange(10)]
    compare_jax_and_py(graph, test_inputs)


def test_scan_with_multiple_sequences():
    """Scan that multiplies two equally sized sequences element by element."""
    # JAX can only handle constant n_steps
    xs = vector("xs", shape=(10,))
    ys = vector("ys", shape=(10,))

    _, traces = scan(fn=lambda x, y: x * y, sequences=[xs, ys])
    (zs,) = traces

    graph = FunctionGraph([xs, ys], [zs])
    xs_value = np.arange(10, dtype=xs.dtype)
    ys_value = np.arange(10, dtype=ys.dtype)
    compare_jax_and_py(graph, [xs_value, ys_value])


def test_scan_with_carried_and_non_carried_states():
    """Scan with one state that has an initial value and one declared as ``None``."""
    x = scalar("x")

    def step(xtm1):
        incremented = xtm1 + 1
        doubled = (xtm1 + 1) * 2
        return incremented, doubled

    _, traces = scan(fn=step, init_states=[x, None], n_steps=10)
    ys1, ys2 = traces

    graph = FunctionGraph([x], [ys1, ys2])
    compare_jax_and_py(graph, [-1])


def test_scan_with_sequence_and_carried_state():
    """Scan combining a sequence input with a state initialised to a scalar zero."""
    xs = vector("xs")
    initial_state = zeros(())

    _, traces = scan(
        fn=lambda x, ytm1: (ytm1 + 1) * x,
        init_states=[initial_state],
        sequences=[xs],
    )
    (ys,) = traces

    graph = FunctionGraph([xs], [ys])
    compare_jax_and_py(graph, [[1, 2, 3]])


def test_scan_with_rvs():
    """Scan drawing normal variates, compiled with and without RNG updates."""
    rng = shared(np.random.default_rng(123))

    final_states, traces = scan(
        fn=lambda prev_rng: normal(rng=prev_rng).owner.outputs,
        init_states=[rng, None],
        n_steps=10,
    )
    next_rng, _ = final_states
    _, xs = traces

    def draw_twice(compiled_fn):
        """Call the compiled function twice and return both results as sets."""
        first = compiled_fn()
        second = compiled_fn()
        return set(tuple(np.array(first))), set(tuple(np.array(second)))

    # First without updates: both calls must yield the same set of values
    fn = function([], xs, mode="JAX", updates=None)
    first_values, second_values = draw_twice(fn)
    assert first_values == second_values

    # Now with updates: the two calls must not share any value
    fn = function([], xs, mode="JAX", updates={rng: next_rng})
    first_values, second_values = draw_twice(fn)
    assert first_values.isdisjoint(second_values)


def test_while_scan_fails():
    """A scan using an ``until`` condition must raise ``NotImplementedError``."""

    def step(x):
        return x + 1, until((x + 1) >= 9)

    _, traces = scan(fn=step, init_states=[-1], n_steps=20)
    (xs,) = traces

    graph = FunctionGraph([], [xs])
    expected_message = "Scan ops with while condition cannot be transpiled JAX"
    with pytest.raises(NotImplementedError, match=expected_message):
        compare_jax_and_py(graph, [])

0 comments on commit 01011aa

Please sign in to comment.