Skip to content

Commit

Permalink
refactor is_in_ancestors to support multiple inputs
Browse files Browse the repository at this point in the history
  • Loading branch information
ferrine authored and ricardoV94 committed Dec 9, 2022
1 parent 38731ad commit 8ad3317
Show file tree
Hide file tree
Showing 4 changed files with 27 additions and 19 deletions.
20 changes: 12 additions & 8 deletions pytensor/graph/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -1568,30 +1568,34 @@ def expand(o: Apply) -> List[Apply]:
)


def is_in_ancestors(l_apply: Apply, f_apply: Apply) -> bool:
"""Determine if `f_apply` is in the graph given by `l_apply`.
def apply_depends_on(apply: Apply, depends_on: Union[Apply, Collection[Apply]]) -> bool:
"""Determine if any `depends_on` is in the graph given by ``apply``.
Parameters
----------
l_apply : Apply
The node to walk.
f_apply : Apply
The node to find in `l_apply`.
apply : Apply
The Apply node to check.
depends_on : Union[Apply, Collection[Apply]]
Apply nodes to check dependency on
Returns
-------
bool
"""
computed = set()
todo = [l_apply]
todo = [apply]
if not isinstance(depends_on, Collection):
depends_on = {depends_on}
else:
depends_on = set(depends_on)
while todo:
cur = todo.pop()
if cur.outputs[0] in computed:
continue
if all(i in computed or i.owner is None for i in cur.inputs):
computed.update(cur.outputs)
if cur is f_apply:
if cur in depends_on:
return True
else:
todo.append(cur)
Expand Down
8 changes: 4 additions & 4 deletions pytensor/ifelse.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from pytensor import as_symbolic
from pytensor.compile import optdb
from pytensor.configdefaults import config
from pytensor.graph.basic import Apply, Variable, is_in_ancestors
from pytensor.graph.basic import Apply, Variable, apply_depends_on
from pytensor.graph.op import _NoPythonOp
from pytensor.graph.replace import clone_replace
from pytensor.graph.rewriting.basic import GraphRewriter, in2out, node_rewriter
Expand Down Expand Up @@ -604,7 +604,7 @@ def apply(self, fgraph):
return False
merging_node = cond_nodes[0]
for proposal in cond_nodes[1:]:
if proposal.inputs[0] == merging_node.inputs[0] and not is_in_ancestors(
if proposal.inputs[0] == merging_node.inputs[0] and not apply_depends_on(
proposal, merging_node
):
# Create a list of replacements for proposal
Expand Down Expand Up @@ -704,8 +704,8 @@ def cond_merge_random_op(fgraph, main_node):
for proposal in cond_nodes[1:]:
if (
proposal.inputs[0] == merging_node.inputs[0]
and not is_in_ancestors(proposal, merging_node)
and not is_in_ancestors(merging_node, proposal)
and not apply_depends_on(proposal, merging_node)
and not apply_depends_on(merging_node, proposal)
):
# Create a list of replacements for proposal
mn_ts = merging_node.inputs[1:][: merging_node.op.n_outs]
Expand Down
6 changes: 3 additions & 3 deletions pytensor/scan/rewriting.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,10 @@
Apply,
Constant,
Variable,
apply_depends_on,
equal_computations,
graph_inputs,
io_toposort,
is_in_ancestors,
)
from pytensor.graph.destroyhandler import DestroyHandler
from pytensor.graph.features import ReplaceValidate
Expand Down Expand Up @@ -1642,7 +1642,7 @@ def save_mem_new_scan(fgraph, node):
old_new += [(o, new_outs[nw_pos])]
# Check if the new outputs depend on the old scan node
old_scan_is_used = [
is_in_ancestors(new.owner, node) for old, new in old_new
apply_depends_on(new.owner, node) for old, new in old_new
]
if any(old_scan_is_used):
return False
Expand Down Expand Up @@ -1877,7 +1877,7 @@ def belongs_to_set(self, node, set_nodes):

# Check to see if it is an input of a different node
for nd in set_nodes:
if is_in_ancestors(node, nd) or is_in_ancestors(nd, node):
if apply_depends_on(node, nd) or apply_depends_on(nd, node):
return False

if not node.op.info.as_while:
Expand Down
12 changes: 8 additions & 4 deletions tests/graph/test_basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
NominalVariable,
Variable,
ancestors,
apply_depends_on,
applys_between,
as_string,
clone,
Expand All @@ -20,7 +21,6 @@
get_var_by_name,
graph_inputs,
io_toposort,
is_in_ancestors,
list_of_nodes,
orphans_between,
vars_between,
Expand Down Expand Up @@ -491,15 +491,19 @@ def test_list_of_nodes():
assert res == [o2.owner, o1.owner]


def test_is_in_ancestors():
def test_apply_depends_on():

r1, r2, r3 = MyVariable(1), MyVariable(2), MyVariable(3)
o1 = MyOp(r1, r2)
o1.name = "o1"
o2 = MyOp(r3, o1)
o2 = MyOp(r1, o1)
o2.name = "o2"
o3 = MyOp(r3, o1, o2)
o3.name = "o3"

assert is_in_ancestors(o2.owner, o1.owner)
assert apply_depends_on(o2.owner, o1.owner)
assert apply_depends_on(o2.owner, o2.owner)
assert apply_depends_on(o3.owner, [o1.owner, o2.owner])


@pytest.mark.xfail(reason="Not implemented")
Expand Down

0 comments on commit 8ad3317

Please sign in to comment.