Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Loop check and iter check for TensorDictKeysView #200

Merged
merged 15 commits into from
Feb 6, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions tensordict/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@

from .memmap import MemmapTensor, set_transfer_ownership
from .tensordict import (
_TensorDictKeysView,
detect_loop,
LazyStackedTensorDict,
merge_tensordicts,
SubTensorDict,
Expand All @@ -23,4 +25,5 @@
"TensorDict",
"merge_tensordicts",
"set_transfer_ownership",
"_TensorDictKeysView",
]
43 changes: 43 additions & 0 deletions tensordict/tensordict.py
Original file line number Diff line number Diff line change
Expand Up @@ -5469,3 +5469,46 @@ def _clone_value(value, recurse):
return value.clone(recurse=False)
else:
return value


def detect_loop(tensordict: TensorDict) -> bool:
"""
This helper function detects the presence of an auto nesting loop inside
a TensorDict object. Auto nesting appears when a key of TensorDict references
another TensorDict and initiates a recursive infinite loop. It returns True
if at least one loop is found, otherwise returns False. An example is:

>>> td = TensorDict(
>>> source={
>>> "a": TensorDict(
>>> source={"b": torch.randn(4, 3, 1)},
>>> batch_size=[4, 3, 1]),
>>> },
>>> batch_size=[4, 3, 1]
>>> )
>>> td["b"]["c"] = td
>>>
>>> print(detect_loop(td))
True

Args:
tensordict (TensorDict): The Tensordict Object to check for autonested loops presence.
Returns
bool: True if one loop is found, otherwise False
"""
visited = set()
visited.add(id(tensordict))

def detect(t_d: TensorDict):
for k, v in t_d.items():
if id(v) in visited:
return True
visited.add(id(v))
if isinstance(v, TensorDict):
loop = detect(v)
if loop:
return True
visited.remove(id(v))
return False

return detect(tensordict)
234 changes: 233 additions & 1 deletion test/test_tensordict.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,13 @@
import torch
import torchsnapshot
from _utils_internal import get_available_devices, prod, TestTensorDictsBase
from tensordict import LazyStackedTensorDict, MemmapTensor, TensorDict
from tensordict import (
_TensorDictKeysView,
detect_loop,
LazyStackedTensorDict,
MemmapTensor,
TensorDict,
)
from tensordict.tensordict import (
_stack as stack_td,
assert_allclose_td,
Expand Down Expand Up @@ -3715,6 +3721,232 @@ def test_tensordict_prealloc_nested():
assert buffer["agent.obs"].batch_size == torch.Size([B, N, T])


def test_tensordict_view_iteration():
td_simple = TensorDict(
source={"a": torch.randn(4, 3, 2, 1, 5), "b": torch.randn(4, 3, 2, 1, 5)},
batch_size=[4, 3, 2, 1],
)

view = _TensorDictKeysView(
tensordict=td_simple, include_nested=True, leaves_only=True, error_on_loop=True
)
keys = list(view)
assert len(keys) == 2
assert "a" in keys
assert "b" in keys

td_nested = TensorDict(
source={
"a": torch.randn(4, 3, 2, 1, 5),
"b": TensorDict({"c": torch.randn(4, 3, 2, 1, 2)}, [4, 3, 2, 1]),
},
batch_size=[4, 3, 2, 1],
)

view = _TensorDictKeysView(
tensordict=td_nested, include_nested=True, leaves_only=True, error_on_loop=True
)
keys = list(view)
assert len(keys) == 2
assert "a" in keys
assert ("b", "c") in keys

view = _TensorDictKeysView(
tensordict=td_nested, include_nested=False, leaves_only=True, error_on_loop=True
)
keys = list(view)
assert len(keys) == 1
assert "a" in keys

view = _TensorDictKeysView(
tensordict=td_nested, include_nested=True, leaves_only=False, error_on_loop=True
)
keys = list(view)
assert len(keys) == 3
assert "a" in keys
assert "b" in keys
assert ("b", "c") in keys

# We are not considering loops given by referencing non Dicts (leaf nodes) from two different key sequences

td_auto_nested_loop = TensorDict(
source={
"a": torch.randn(4, 3, 2, 1, 5),
"b": TensorDict({"c": torch.randn(4, 3, 2, 1, 2)}, [4, 3, 2, 1]),
},
batch_size=[4, 3, 2, 1],
)
td_auto_nested_loop["b"]["d"] = td_auto_nested_loop

view = _TensorDictKeysView(
tensordict=td_auto_nested_loop,
include_nested=False,
leaves_only=False,
error_on_loop=True,
)
keys = list(view)
assert len(keys) == 2
assert "a" in keys
assert "b" in keys

view = _TensorDictKeysView(
tensordict=td_auto_nested_loop,
include_nested=False,
leaves_only=True,
error_on_loop=True,
)
keys = list(view)
assert len(keys) == 1
assert "a" in keys

with pytest.raises(RecursionError):
view = _TensorDictKeysView(
tensordict=td_auto_nested_loop,
include_nested=True,
leaves_only=True,
error_on_loop=True,
)
list(view)

with pytest.raises(RecursionError):
view = _TensorDictKeysView(
tensordict=td_auto_nested_loop,
include_nested=True,
leaves_only=False,
error_on_loop=True,
)
list(view)

view = _TensorDictKeysView(
tensordict=td_auto_nested_loop,
include_nested=True,
leaves_only=False,
error_on_loop=False,
)

keys = list(view)
assert len(keys) == 3
assert "a" in keys
assert "b" in keys
assert ("b", "c") in keys

view = _TensorDictKeysView(
tensordict=td_auto_nested_loop,
include_nested=True,
leaves_only=True,
error_on_loop=False,
)

keys = list(view)
assert len(keys) == 2
assert "a" in keys
assert ("b", "c") in keys

td_auto_nested_loop_2 = TensorDict(
source={
"a": torch.randn(4, 3, 2, 1, 5),
"b": TensorDict({"c": torch.randn(4, 3, 2, 1, 2)}, [4, 3, 2, 1]),
},
batch_size=[4, 3, 2, 1],
)
td_auto_nested_loop_2["b"]["d"] = td_auto_nested_loop_2["b"]

view = _TensorDictKeysView(
tensordict=td_auto_nested_loop_2,
include_nested=True,
leaves_only=False,
error_on_loop=False,
)

keys = list(view)
assert len(keys) == 3
assert "a" in keys
assert "b" in keys
assert ("b", "c") in keys


def test_detect_loop():
td_simple = TensorDict(
source={"a": torch.randn(4, 3, 2, 1, 5), "b": torch.randn(4, 3, 2, 1, 5)},
batch_size=[4, 3, 2, 1],
)
assert not detect_loop(td_simple)

td_nested = TensorDict(
source={
"a": torch.randn(4, 3, 2, 1, 5),
"b": TensorDict({"c": torch.randn(4, 3, 2, 1, 2)}, [4, 3, 2, 1]),
},
batch_size=[4, 3, 2, 1],
)
assert not detect_loop(td_nested)

td_auto_nested_no_loop_1 = TensorDict(
source={
"a": torch.randn(4, 3, 2, 1, 5),
"b": TensorDict({"c": torch.randn(4, 3, 2, 1, 2)}, [4, 3, 2, 1]),
},
batch_size=[4, 3, 2, 1],
)
td_auto_nested_no_loop_1["b"]["d"] = td_auto_nested_no_loop_1["a"]

assert not detect_loop(td_auto_nested_no_loop_1)

td_auto_nested_no_loop_2 = TensorDict(
source={
"a": TensorDict(
source={"c": torch.randn(4, 3, 2, 1, 2)}, batch_size=[4, 3, 2, 1]
),
"b": TensorDict(
source={"d": torch.randn(4, 3, 2, 1, 2)}, batch_size=[4, 3, 2, 1]
),
},
batch_size=[4, 3, 2, 1],
)
td_auto_nested_no_loop_2["b"]["e"] = td_auto_nested_no_loop_2["a"]

assert not detect_loop(td_auto_nested_no_loop_2)

td_auto_nested_no_loop_3 = TensorDict(
source={
"a": torch.randn(4, 3, 2, 1, 2),
"b": TensorDict(
source={"c": torch.randn(4, 3, 2, 1, 2)}, batch_size=[4, 3, 2, 1]
),
},
batch_size=[4, 3, 2, 1],
)
td_auto_nested_no_loop_3["b"]["d"] = td_auto_nested_no_loop_3["b"]["c"]

assert not detect_loop(td_auto_nested_no_loop_3)

td_auto_nested_loop_1 = TensorDict(
source={
"a": torch.randn(4, 3, 2, 1, 2),
"b": TensorDict(
source={"c": torch.randn(4, 3, 2, 1, 2)}, batch_size=[4, 3, 2, 1]
),
},
batch_size=[4, 3, 2, 1],
)
td_auto_nested_loop_1["b"]["d"] = td_auto_nested_loop_1["b"]

assert detect_loop(td_auto_nested_loop_1)

td_auto_nested_loop_2 = TensorDict(
source={
"a": torch.randn(4, 3, 2, 1, 2),
"b": TensorDict(
source={"c": torch.randn(4, 3, 2, 1, 2)}, batch_size=[4, 3, 2, 1]
),
},
batch_size=[4, 3, 2, 1],
)
td_auto_nested_loop_2["b"]["d"] = td_auto_nested_loop_2

assert detect_loop(td_auto_nested_loop_2)


if __name__ == "__main__":
args, unknown = argparse.ArgumentParser().parse_known_args()
pytest.main([__file__, "--capture", "no", "--exitfirst"] + unknown)