From 25f4ce2898ac4b1e6bcfbf1c353862abaef70a7d Mon Sep 17 00:00:00 2001 From: Tom Begley Date: Thu, 2 Feb 2023 14:54:58 +0000 Subject: [PATCH 1/9] PoC --- tensordict/tensordict.py | 240 +++++++++++++++++++++++++-------------- test/_utils_internal.py | 5 + test/test_tensordict.py | 1 + 3 files changed, 161 insertions(+), 85 deletions(-) diff --git a/tensordict/tensordict.py b/tensordict/tensordict.py index f2f44ba0b..2aedb97f8 100644 --- a/tensordict/tensordict.py +++ b/tensordict/tensordict.py @@ -168,26 +168,26 @@ def __iter__(self): else: yield from self._keys() else: - yield from self._iter_helper(self.tensordict) + try: + yield from self._iter_helper(self.tensordict) + except RecursionError as e: + raise RecursionError( + "Iterating over contents of TensorDict resulted in a recursion " + "error. It's likely that you have auto-nested values, in which " + f"case iteration with `include_nested=True` is not supported. {e}" + ) def _iter_helper(self, tensordict, prefix=None): - items_iter = self._items(tensordict) - - for key, value in items_iter: + for key, value in self._items(tensordict): full_key = self._combine_keys(prefix, key) - if ( - isinstance(value, (TensorDictBase, KeyedJaggedTensor)) - and self.include_nested - ): - subkeys = tuple( + if not self.leaves_only or not isinstance(value, TensorDictBase): + yield full_key + if isinstance(value, (TensorDictBase, KeyedJaggedTensor)): + yield from tuple( self._iter_helper( - value, - full_key if isinstance(full_key, tuple) else (full_key,), + value, full_key if isinstance(full_key, tuple) else (full_key,) ) ) - yield from subkeys - if not (isinstance(value, TensorDictBase) and self.leaves_only): - yield full_key def _combine_keys(self, prefix, key): if prefix is not None: @@ -596,7 +596,8 @@ def apply_(self, fn: Callable) -> TensorDictBase: self or a copy of self with the function applied """ - return self.apply(fn, inplace=True) + return _apply_safe(lambda _, value: fn(value), self, inplace=True) + # return self.apply(fn, inplace=True) def apply( self, @@ -990,22 +991,24 @@ def __eq__(self, other: object) -> TensorDictBase: """ if not isinstance(other, (TensorDictBase, dict, float, int)): return False - if not isinstance(other, TensorDictBase) and isinstance(other, dict): + if isinstance(other, dict): other = make_tensordict(**other, batch_size=self.batch_size) - if not isinstance(other, TensorDictBase): - return TensorDict( - {key: value == other for key, value in self.items()}, - self.batch_size, - device=self.device, - ) - keys1 = set(self.keys()) - keys2 = set(other.keys()) - if len(keys1.difference(keys2)) or len(keys1) != len(keys2): - raise KeyError(f"keys in tensordicts mismatch, got {keys1} and {keys2}") - d = {} - for (key, item1) in self.items(): - d[key] = item1 == other.get(key) - return TensorDict(batch_size=self.batch_size, source=d, device=self.device) + + def hook(key, value): + if isinstance(other, TensorDictBase): + other_ = other.get(key) if key else other + keys1 = set(value.keys()) + keys2 = set(other_.keys()) + if len(keys1.difference(keys2)) or len(keys1) != len(keys2): + raise KeyError( + f"keys in tensordicts mismatch, got {keys1} and {keys2}" + ) + + def fn(key, value): + other_ = other.get(key) if isinstance(other, TensorDictBase) else other + return value == other_ + + return _apply_safe(fn, self, hook=hook) @abc.abstractmethod def del_(self, key: str) -> TensorDictBase: @@ -1174,21 +1177,12 @@ def to_tensordict(self): a new TensorDict object containing the same values. 
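        With auto-nested values the copy is built by the _apply_safe helper
        added below, which restores the self-reference on the new TensorDict
        instead of recursing forever. A sketch of the intended behaviour,
        mirroring the autonested_td test fixture added in this patch (which
        sets td["self"] = td):

            >>> td = TensorDict({"a": torch.zeros(3)}, [3])
            >>> td["self"] = td
            >>> td2 = td.to_tensordict()
            >>> assert td2["self"] is td2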
""" - return TensorDict( - { - key: value.clone() - if not isinstance(value, TensorDictBase) - else value.to_tensordict() - for key, value in self.items() - }, - device=self.device, - batch_size=self.batch_size, - ) + return _apply_safe(lambda _, value: value.clone(), self) def zero_(self) -> TensorDictBase: """Zeros all tensors in the tensordict in-place.""" - for key in self.keys(): - self.fill_(key, 0) + for _, value in _items_safe(self): + value.zero_() return self def unbind(self, dim: int) -> Tuple[TensorDictBase, ...]: @@ -1258,15 +1252,7 @@ def clone(self, recurse: bool = True) -> TensorDictBase: TensorDict will be copied too. Default is `True`. """ - - return TensorDict( - source={key: _clone_value(value, recurse) for key, value in self.items()}, - batch_size=self.batch_size, - device=self.device, - _run_checks=False, - _is_shared=self.is_shared() if not recurse else False, - _is_memmap=self.is_memmap() if not recurse else False, - ) + return _apply_safe(lambda _, value: _clone_value(value, recurse=recurse), self) @classmethod def __torch_function__( @@ -1714,13 +1700,29 @@ def permute( ) def __repr__(self) -> str: - fields = _td_fields(self) - field_str = indent(f"fields={{{fields}}}", 4 * " ") - batch_size_str = indent(f"batch_size={self.batch_size}", 4 * " ") - device_str = indent(f"device={self.device}", 4 * " ") - is_shared_str = indent(f"is_shared={self.is_shared()}", 4 * " ") - string = ",\n".join([field_str, batch_size_str, device_str, is_shared_str]) - return f"{type(self).__name__}(\n{string})" + visited = {id(self)} + + def _repr(td): + fields = [] + for key, value in td.items(): + if is_tensordict(value): + if id(value) in visited: + fields.append(f"{key}: {value.__class__.__name__}(...)") + else: + visited.add(id(value)) + fields.append(f"{key}: {_repr(value)}") + visited.remove(id(value)) + else: + fields.append(f"{key}: {get_repr(value)}") + fields = indent("\n" + ",\n".join(sorted(fields)), " " * 4) + field_str = indent(f"fields={{{fields}}}", 4 * " ") + batch_size_str = indent(f"batch_size={td.batch_size}", 4 * " ") + device_str = indent(f"device={td.device}", 4 * " ") + is_shared_str = indent(f"is_shared={td.is_shared()}", 4 * " ") + string = ",\n".join([field_str, batch_size_str, device_str, is_shared_str]) + return f"{td.__class__.__name__}(\n{string})" + + return _repr(self) def all(self, dim: int = None) -> Union[bool, TensorDictBase]: """Checks if all values are True/non-null in the tensordict. @@ -1741,12 +1743,8 @@ def all(self, dim: int = None) -> Union[bool, TensorDictBase]: if dim is not None: if dim < 0: dim = self.batch_dims + dim - return TensorDict( - source={key: value.all(dim=dim) for key, value in self.items()}, - batch_size=[b for i, b in enumerate(self.batch_size) if i != dim], - device=self.device, - ) - return all(value.all() for value in self.values()) + return _apply_safe(lambda _, value: value.all(dim=dim), self) + return all(value.all() for _, value in _items_safe(self)) def any(self, dim: int = None) -> Union[bool, TensorDictBase]: """Checks if any value is True/non-null in the tensordict. 
@@ -1767,12 +1765,8 @@ def any(self, dim: int = None) -> Union[bool, TensorDictBase]:
         if dim is not None:
             if dim < 0:
                 dim = self.batch_dims + dim
-            return TensorDict(
-                source={key: value.any(dim=dim) for key, value in self.items()},
-                batch_size=[b for i, b in enumerate(self.batch_size) if i != dim],
-                device=self.device,
-            )
-        return any([value.any() for key, value in self.items()])
+            return _apply_safe(lambda _, value: value.any(dim=dim), self)
+        return any(value.any() for _, value in _items_safe(self))
 
     def get_sub_tensordict(self, idx: INDEX_TYPING) -> TensorDictBase:
         """Returns a SubTensorDict with the desired index."""
@@ -2135,6 +2129,100 @@ def unlock(self):
         return self
 
 
+def _apply_safe(fn, tensordict, inplace=False, hook=None):
+    """
+    Safely apply a function to all values in a TensorDict that may contain self-nested
+    values.
+
+    Args:
+        fn (Callable[[key, value], Any]): Function to apply to each value. Takes the key
+            and value at that key as arguments. The key is useful for example when
+            implementing __eq__, as it lets us do something like
+            fn=lambda key, value: value == other.get(key). The results of this function
+            are used to set / update values in the TensorDict.
+        tensordict (TensorDictBase): The tensordict to apply the function to.
+        inplace (bool): If True, updates are applied in-place.
+        hook (Callable[[key, value], None]): A hook called on any tensordicts
+            encountered during the recursion. Can be used to perform input validation
+            at each level of the recursion (e.g. checking keys match)
+    """
+    # store ids of values together with the keys they appear under. root tensordict is
+    # given the "key" None
+    visited = {id(tensordict): None}
+    # update will map nested keys to the corresponding key higher up in the tree
+    # e.g. if we have
+    # >>> d = {"a": 1, "b": {"c": 0}}
+    # >>> d["b"]["d"] = d["b"]
+    # then after recursing update should look like {("b", "d"): ("b",)}
+    update = {}
+
+    def recurse(td, prefix=()):
+        if hook is not None:
+            hook(prefix, td)
+
+        out = (
+            td
+            if inplace
+            else TensorDict({}, batch_size=td.batch_size, device=td.device)
+        )
+
+        for key, value in td.items():
+            full_key = prefix + (key,)
+            if isinstance(value, TensorDictBase):
+                if id(value) in visited:
+                    # we have already visited this value, capture the key we saw it at
+                    # so that we can restore auto-nesting at the end of recursion
+                    update[full_key] = visited[id(value)]
+                else:
+                    visited[id(value)] = full_key
+                    out.set(key, recurse(value, prefix=full_key), inplace=inplace)
+                    del visited[id(value)]
+            else:
+                out.set(key, fn(full_key, value), inplace=inplace)
+        return out
+
+    out = recurse(tensordict)
+    if not inplace:
+        # only need to restore self-nesting if not inplace
+        for nested_key, root_key in update.items():
+            if root_key is None:
+                out[nested_key] = out
+            else:
+                out[nested_key] = out[root_key]
+
+    return out
+
+
+def _items_safe(tensordict):
+    """
+    Safely iterate over leaf tensors in the presence of self-nesting
+
+    Args:
+        tensordict (TensorDictBase): TensorDict over which to iterate
+    """
+    # safely iterate over keys and values in a tensordict, accounting for possible
+    # auto-nesting
+    visited = {id(tensordict)}
+    # create a keys view instance we can use to iterate over items
+    _keys_view = _TensorDictKeysView(None, False, False)
+
+    def recurse(td, prefix=()):
+        for key, value in _keys_view._items(td):
+            full_key = prefix + (key if isinstance(key, tuple) else (key,))
+            if isinstance(value, (TensorDictBase, KeyedJaggedTensor)):
+                if id(value) not in visited:
+                    visited.add(id(value))
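+                    # mark the nested tensordict as visited for the duration of
+                    # the descent so that an auto-nested reference back to an
+                    # ancestor is skipped rather than recursed into forever
+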
yield from recurse(value, prefix=full_key) + visited.remove(id(value)) + else: + yield full_key, value + + yield from recurse(tensordict) + + class TensorDict(TensorDictBase): """A batched dictionary of tensors. @@ -5247,7 +5332,6 @@ def _stack_onto_( list_item: List[COMPATIBLE_TYPES], dim: int, ) -> TensorDictBase: - permute_dims = self.custom_op_kwargs["dims"] inv_permute_dims = np.argsort(permute_dims) new_dim = [i for i, v in enumerate(inv_permute_dims) if v == dim][0] @@ -5273,20 +5357,6 @@ def get_repr(tensor): return f"{tensor.__class__.__name__}({s})" -def _make_repr(key, item, tensordict): - if is_tensordict(type(item)): - return f"{key}: {repr(tensordict.get(key))}" - return f"{key}: {get_repr(item)}" - - -def _td_fields(td: TensorDictBase) -> str: - return indent( - "\n" - + ",\n".join(sorted([_make_repr(key, item, td) for key, item in td.items()])), - 4 * " ", - ) - - def _check_keys( list_of_tensordicts: Sequence[TensorDictBase], strict: bool = False ) -> Set[str]: diff --git a/test/_utils_internal.py b/test/_utils_internal.py index e2760b125..e9518b658 100644 --- a/test/_utils_internal.py +++ b/test/_utils_internal.py @@ -175,6 +175,11 @@ def td_reset_bs(self, device): td.batch_size = torch.Size([4, 3, 2, 1]) return td + def autonested_td(self, device): + td = self.td(device) + td["self"] = td + return td + def expand_list(list_of_tensors, *dims): n = len(list_of_tensors) diff --git a/test/test_tensordict.py b/test/test_tensordict.py index f1393ccd1..9b6014c11 100644 --- a/test/test_tensordict.py +++ b/test/test_tensordict.py @@ -512,6 +512,7 @@ def test_convert_ellipsis_to_idx_invalid(ellipsis_index, expectation): "nested_td", "permute_td", "nested_stacked_td", + "autonested_td", ], ) @pytest.mark.parametrize("device", get_available_devices()) From 6b9381abe980f76b49e7e62d3f65102bd78f57a0 Mon Sep 17 00:00:00 2001 From: Ruggero Vasile Date: Fri, 3 Feb 2023 09:48:59 +0100 Subject: [PATCH 2/9] added loop detection function for TensorDict --- tensordict/__init__.py | 1 + tensordict/tensordict.py | 19 ++++++++ test/test_tensordict.py | 97 +++++++++++++++++++++++++++++++++++++++- 3 files changed, 116 insertions(+), 1 deletion(-) diff --git a/tensordict/__init__.py b/tensordict/__init__.py index 253ab0608..54baeba1f 100644 --- a/tensordict/__init__.py +++ b/tensordict/__init__.py @@ -9,6 +9,7 @@ merge_tensordicts, SubTensorDict, TensorDict, + detect_loop, ) try: diff --git a/tensordict/tensordict.py b/tensordict/tensordict.py index f2f44ba0b..83c94d59b 100644 --- a/tensordict/tensordict.py +++ b/tensordict/tensordict.py @@ -5400,3 +5400,22 @@ def _clone_value(value, recurse): return value.clone(recurse=False) else: return value + + +def detect_loop(tensordict: TensorDict) -> bool: + visited = set() + visited.add(id(tensordict)) + + def detect(t_d: TensorDict): + for k, v in t_d.items(): + if id(v) in visited: + return True + visited.add(id(v)) + if isinstance(v, TensorDict): + loop = detect(v) + if loop: + return True + visited.remove(id(v)) + return False + + return detect(tensordict) diff --git a/test/test_tensordict.py b/test/test_tensordict.py index f1393ccd1..30bc2491e 100644 --- a/test/test_tensordict.py +++ b/test/test_tensordict.py @@ -12,7 +12,7 @@ import torch import torchsnapshot from _utils_internal import get_available_devices, prod, TestTensorDictsBase -from tensordict import LazyStackedTensorDict, MemmapTensor, TensorDict +from tensordict import LazyStackedTensorDict, MemmapTensor, TensorDict, detect_loop from tensordict.tensordict import ( _stack as stack_td, 
assert_allclose_td, @@ -3714,6 +3714,101 @@ def test_tensordict_prealloc_nested(): assert buffer["agent.obs"].batch_size == torch.Size([B, N, T]) +def test_detect_loop(): + + td_simple = TensorDict( + source={ + "a": torch.randn(4, 3, 2, 1, 5), + "b": torch.randn(4, 3, 2, 1, 5) + }, + batch_size=[4, 3, 2, 1] + ) + assert not detect_loop(td_simple) + + td_nested = TensorDict( + source={ + "a": torch.randn(4, 3, 2, 1, 5), + "b": TensorDict( + {"c": torch.randn(4, 3, 2, 1, 2)}, [4, 3, 2, 1] + ), + }, + batch_size=[4, 3, 2, 1] + ) + assert not detect_loop(td_nested) + + td_auto_nested_no_loop_1 = TensorDict( + source={ + "a": torch.randn(4, 3, 2, 1, 5), + "b": TensorDict( + {"c": torch.randn(4, 3, 2, 1, 2)}, [4, 3, 2, 1] + ), + }, + batch_size=[4, 3, 2, 1] + ) + td_auto_nested_no_loop_1["b"]["d"] = td_auto_nested_no_loop_1["a"] + + assert not detect_loop(td_auto_nested_no_loop_1) + + td_auto_nested_no_loop_2 = TensorDict( + source={ + "a": TensorDict( + source={"c": torch.randn(4, 3, 2, 1, 2)}, + batch_size=[4, 3, 2, 1] + ), + "b": TensorDict( + source={"d": torch.randn(4, 3, 2, 1, 2)}, + batch_size=[4, 3, 2, 1] + ), + }, + batch_size=[4, 3, 2, 1] + ) + td_auto_nested_no_loop_2["b"]["e"] = td_auto_nested_no_loop_2["a"] + + assert not detect_loop(td_auto_nested_no_loop_2) + + td_auto_nested_no_loop_3 = TensorDict( + source={ + "a": torch.randn(4, 3, 2, 1, 2), + "b": TensorDict( + source={"c": torch.randn(4, 3, 2, 1, 2)}, + batch_size=[4, 3, 2, 1] + ), + }, + batch_size=[4, 3, 2, 1] + ) + td_auto_nested_no_loop_3["b"]["d"] = td_auto_nested_no_loop_3["b"]["c"] + + assert not detect_loop(td_auto_nested_no_loop_3) + + td_auto_nested_loop_1 = TensorDict( + source={ + "a": torch.randn(4, 3, 2, 1, 2), + "b": TensorDict( + source={"c": torch.randn(4, 3, 2, 1, 2)}, + batch_size=[4, 3, 2, 1] + ), + }, + batch_size=[4, 3, 2, 1] + ) + td_auto_nested_loop_1["b"]["d"] = td_auto_nested_loop_1["b"] + + assert detect_loop(td_auto_nested_loop_1) + + td_auto_nested_loop_2 = TensorDict( + source={ + "a": torch.randn(4, 3, 2, 1, 2), + "b": TensorDict( + source={"c": torch.randn(4, 3, 2, 1, 2)}, + batch_size=[4, 3, 2, 1] + ), + }, + batch_size=[4, 3, 2, 1] + ) + td_auto_nested_loop_2["b"]["d"] = td_auto_nested_loop_2 + + assert detect_loop(td_auto_nested_loop_2) + + if __name__ == "__main__": args, unknown = argparse.ArgumentParser().parse_known_args() pytest.main([__file__, "--capture", "no", "--exitfirst"] + unknown) From 8ac8f034fb2794600f9b8d457860d6206fa75024 Mon Sep 17 00:00:00 2001 From: Tom Begley Date: Fri, 3 Feb 2023 10:07:48 +0000 Subject: [PATCH 3/9] Incorporate _items_safe into _TensorDictKeysView --- tensordict/tensordict.py | 87 +++++++++++++++++++--------------------- 1 file changed, 42 insertions(+), 45 deletions(-) diff --git a/tensordict/tensordict.py b/tensordict/tensordict.py index 2aedb97f8..3bda74a6d 100644 --- a/tensordict/tensordict.py +++ b/tensordict/tensordict.py @@ -151,11 +151,17 @@ class _TensorDictKeysView: """ def __init__( - self, tensordict: "TensorDictBase", include_nested: bool, leaves_only: bool + self, + tensordict: "TensorDictBase", + include_nested: bool, + leaves_only: bool, + error_on_loop: bool = True, ): self.tensordict = tensordict self.include_nested = include_nested self.leaves_only = leaves_only + self.error_on_loop = error_on_loop + self.visited = set() def __iter__(self): if not self.include_nested: @@ -168,26 +174,32 @@ def __iter__(self): else: yield from self._keys() else: - try: - yield from self._iter_helper(self.tensordict) - except RecursionError as e: - 
raise RecursionError( - "Iterating over contents of TensorDict resulted in a recursion " - "error. It's likely that you have auto-nested values, in which " - f"case iteration with `include_nested=True` is not supported. {e}" - ) + yield from self._iter_helper(self.tensordict) def _iter_helper(self, tensordict, prefix=None): for key, value in self._items(tensordict): full_key = self._combine_keys(prefix, key) if not self.leaves_only or not isinstance(value, TensorDictBase): - yield full_key + if id(value) not in self.visited: + yield full_key if isinstance(value, (TensorDictBase, KeyedJaggedTensor)): - yield from tuple( - self._iter_helper( - value, full_key if isinstance(full_key, tuple) else (full_key,) + if id(value) in self.visited: + if self.error_on_loop: + raise RecursionError( + "Iterating over contents of TensorDict resulted in a " + "recursion error. It's likely that you have auto-nested " + "values, in which case iteration with " + "`include_nested=True` is not supported." + ) + else: + self.visited.add(id(value)) + yield from tuple( + self._iter_helper( + value, + full_key if isinstance(full_key, tuple) else (full_key,), + ) ) - ) + self.visited.remove(id(value)) def _combine_keys(self, prefix, key): if prefix is not None: @@ -1181,8 +1193,10 @@ def to_tensordict(self): def zero_(self) -> TensorDictBase: """Zeros all tensors in the tensordict in-place.""" - for _, value in _items_safe(self): - value.zero_() + for key in _TensorDictKeysView( + self, include_nested=True, leaves_only=True, error_on_loop=False + ): + self.get(key).zero_() return self def unbind(self, dim: int) -> Tuple[TensorDictBase, ...]: @@ -1744,7 +1758,12 @@ def all(self, dim: int = None) -> Union[bool, TensorDictBase]: if dim < 0: dim = self.batch_dims + dim return _apply_safe(lambda _, value: value.all(dim=dim), self) - return all(value.all() for _, value in _items_safe(self)) + return all( + self.get(key).all() + for key in _TensorDictKeysView( + self, include_nested=True, leaves_only=True, error_on_loop=False + ) + ) def any(self, dim: int = None) -> Union[bool, TensorDictBase]: """Checks if any value is True/non-null in the tensordict. 
@@ -1766,7 +1785,12 @@ def any(self, dim: int = None) -> Union[bool, TensorDictBase]:
             if dim < 0:
                 dim = self.batch_dims + dim
             return _apply_safe(lambda _, value: value.any(dim=dim), self)
-        return any(value.any() for _, value in _items_safe(self))
+        return any(
+            self.get(key).any()
+            for key in _TensorDictKeysView(
+                self, include_nested=True, leaves_only=True, error_on_loop=False
+            )
+        )
 
     def get_sub_tensordict(self, idx: INDEX_TYPING) -> TensorDictBase:
         """Returns a SubTensorDict with the desired index."""
@@ -2193,36 +2217,6 @@ def recurse(td, prefix=()):
         return out
 
 
-def _items_safe(tensordict):
-    """
-    Safely iterate over leaf tensors in the presence of self-nesting
-
-    Args:
-        tensordict (TensorDictBase): TensorDict over which to iterate
-    """
-    # safely iterate over keys and values in a tensordict, accounting for possible
-    # auto-nesting
-    visited = {id(tensordict)}
-    # create a keys view instance we can use to iterate over items
-    _keys_view = _TensorDictKeysView(None, False, False)
-
-    def recurse(td, prefix=()):
-        for key, value in _keys_view._items(td):
-            full_key = prefix + (key if isinstance(key, tuple) else (key,))
-            if isinstance(value, (TensorDictBase, KeyedJaggedTensor)):
-                if id(value) not in visited:
-                    visited.add(id(value))
-                    # mark the nested tensordict as visited for the duration of
-                    # the descent so that an auto-nested reference back to an
-                    # ancestor is skipped rather than recursed into forever
-                    yield from recurse(value, prefix=full_key)
-                    visited.remove(id(value))
-            else:
-                yield full_key, value
-
-    yield from recurse(tensordict)
-
-
 class TensorDict(TensorDictBase):
     """A batched dictionary of tensors.

From 7fcb560fc7c4021ab527a37390c8e17946e31db5 Mon Sep 17 00:00:00 2001
From: Ruggero Vasile
Date: Fri, 3 Feb 2023 16:09:34 +0100
Subject: [PATCH 4/9] Added test for TensorDict view iteration behavior

---
 tensordict/__init__.py  |  2 +
 test/test_tensordict.py | 88 ++++++++++++++++++++++++++++++++++++++++-
 2 files changed, 89 insertions(+), 1 deletion(-)

diff --git a/tensordict/__init__.py b/tensordict/__init__.py
index 54baeba1f..9c9bb58f4 100644
--- a/tensordict/__init__.py
+++ b/tensordict/__init__.py
@@ -10,6 +10,7 @@
     SubTensorDict,
     TensorDict,
     detect_loop,
+    _TensorDictKeysView,
 )
 
 try:
@@ -24,4 +25,5 @@
     "TensorDict",
     "merge_tensordicts",
     "set_transfer_ownership",
+    "_TensorDictKeysView"
 ]
diff --git a/test/test_tensordict.py b/test/test_tensordict.py
index dd541d7fe..4bb8f9592 100644
--- a/test/test_tensordict.py
+++ b/test/test_tensordict.py
@@ -12,7 +12,7 @@
 import torch
 import torchsnapshot
 from _utils_internal import get_available_devices, prod, TestTensorDictsBase
-from tensordict import LazyStackedTensorDict, MemmapTensor, TensorDict, detect_loop
+from tensordict import LazyStackedTensorDict, MemmapTensor, TensorDict, detect_loop, _TensorDictKeysView
 from tensordict.tensordict import (
     _stack as stack_td,
     assert_allclose_td,
@@ -3715,6 +3715,90 @@ def test_tensordict_prealloc_nested():
     assert buffer["agent.obs"].batch_size == torch.Size([B, N, T])
 
 
+def test_tensordict_view_iteration():
+
+    td_simple = TensorDict(
+        source={
+            "a": torch.randn(4, 3, 2, 1, 5),
+            "b": torch.randn(4, 3, 2, 1, 5)
+        },
+        batch_size=[4, 3, 2, 1]
+    )
+
+    view = _TensorDictKeysView(tensordict=td_simple, include_nested=True, leaves_only=True, error_on_loop=True)
+    keys = list(view)
+    assert len(keys) == 2
+    assert "a" in keys
+    assert "b" in keys
+
+    td_nested = TensorDict(
+        source={
+            "a": torch.randn(4, 3, 2, 1, 5),
+            "b": TensorDict(
+                {"c": torch.randn(4, 3, 2, 1, 2)}, [4, 3, 2, 1]
+            ),
+        },
+        batch_size=[4, 3, 2, 1]
+    )
+
+    view = _TensorDictKeysView(tensordict=td_nested, include_nested=True, leaves_only=True, error_on_loop=True)
+    keys = list(view)
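+    # with include_nested=True and leaves_only=True we expect the leaf "a" and
+    # the nested leaf under ("b", "c"), reported as a tuple key
+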
assert len(keys) == 2 + assert "a" in keys + assert ("b", "c") in keys + + view = _TensorDictKeysView(tensordict=td_nested, include_nested=False, leaves_only=True, error_on_loop=True) + keys = list(view) + assert len(keys) == 1 + assert "a" in keys + + view = _TensorDictKeysView(tensordict=td_nested, include_nested=True, leaves_only=False, error_on_loop=True) + keys = list(view) + assert len(keys) == 3 + assert "a" in keys + assert "b" in keys + assert ("b", "c") in keys + + # We are not considering loops given by referencing non Dicts (leaf nodes) from two different key sequences + + td_auto_nested_loop = TensorDict( + source={ + "a": torch.randn(4, 3, 2, 1, 5), + "b": TensorDict( + {"c": torch.randn(4, 3, 2, 1, 2)}, [4, 3, 2, 1] + ), + }, + batch_size=[4, 3, 2, 1] + ) + td_auto_nested_loop["b"]["d"] = td_auto_nested_loop + + view = _TensorDictKeysView(tensordict=td_auto_nested_loop, include_nested=False, leaves_only=False, + error_on_loop=True) + keys = list(view) + assert len(keys) == 2 + assert 'a' in keys + assert 'b' in keys + + view = _TensorDictKeysView(tensordict=td_auto_nested_loop, include_nested=False, leaves_only=True, + error_on_loop=True) + keys = list(view) + assert len(keys) == 1 + assert 'a' in keys + + with pytest.raises(RecursionError): + view = _TensorDictKeysView(tensordict=td_auto_nested_loop, include_nested=True, leaves_only=True, + error_on_loop=True) + list(view) + + with pytest.raises(RecursionError): + view = _TensorDictKeysView(tensordict=td_auto_nested_loop, include_nested=True, leaves_only=False, + error_on_loop=True) + list(view) + + view = _TensorDictKeysView(tensordict=td_auto_nested_loop, include_nested=True, leaves_only=False, + error_on_loop=False) + # TODO Specify this undefined behavior better + def test_detect_loop(): td_simple = TensorDict( @@ -3810,6 +3894,8 @@ def test_detect_loop(): assert detect_loop(td_auto_nested_loop_2) + + if __name__ == "__main__": args, unknown = argparse.ArgumentParser().parse_known_args() pytest.main([__file__, "--capture", "no", "--exitfirst"] + unknown) From 8a2619a18b1c06c4ceda2e80c5e7b028c0b3ca93 Mon Sep 17 00:00:00 2001 From: Ruggero Vasile Date: Fri, 3 Feb 2023 09:48:59 +0100 Subject: [PATCH 5/9] added loop detection function for TensorDict --- tensordict/__init__.py | 1 + tensordict/tensordict.py | 19 ++++++++ test/test_tensordict.py | 97 +++++++++++++++++++++++++++++++++++++++- 3 files changed, 116 insertions(+), 1 deletion(-) diff --git a/tensordict/__init__.py b/tensordict/__init__.py index 253ab0608..54baeba1f 100644 --- a/tensordict/__init__.py +++ b/tensordict/__init__.py @@ -9,6 +9,7 @@ merge_tensordicts, SubTensorDict, TensorDict, + detect_loop, ) try: diff --git a/tensordict/tensordict.py b/tensordict/tensordict.py index 3bda74a6d..930c1f149 100644 --- a/tensordict/tensordict.py +++ b/tensordict/tensordict.py @@ -5467,3 +5467,22 @@ def _clone_value(value, recurse): return value.clone(recurse=False) else: return value + + +def detect_loop(tensordict: TensorDict) -> bool: + visited = set() + visited.add(id(tensordict)) + + def detect(t_d: TensorDict): + for k, v in t_d.items(): + if id(v) in visited: + return True + visited.add(id(v)) + if isinstance(v, TensorDict): + loop = detect(v) + if loop: + return True + visited.remove(id(v)) + return False + + return detect(tensordict) diff --git a/test/test_tensordict.py b/test/test_tensordict.py index 9b6014c11..dd541d7fe 100644 --- a/test/test_tensordict.py +++ b/test/test_tensordict.py @@ -12,7 +12,7 @@ import torch import torchsnapshot from _utils_internal 
import get_available_devices, prod, TestTensorDictsBase -from tensordict import LazyStackedTensorDict, MemmapTensor, TensorDict +from tensordict import LazyStackedTensorDict, MemmapTensor, TensorDict, detect_loop from tensordict.tensordict import ( _stack as stack_td, assert_allclose_td, @@ -3715,6 +3715,101 @@ def test_tensordict_prealloc_nested(): assert buffer["agent.obs"].batch_size == torch.Size([B, N, T]) +def test_detect_loop(): + + td_simple = TensorDict( + source={ + "a": torch.randn(4, 3, 2, 1, 5), + "b": torch.randn(4, 3, 2, 1, 5) + }, + batch_size=[4, 3, 2, 1] + ) + assert not detect_loop(td_simple) + + td_nested = TensorDict( + source={ + "a": torch.randn(4, 3, 2, 1, 5), + "b": TensorDict( + {"c": torch.randn(4, 3, 2, 1, 2)}, [4, 3, 2, 1] + ), + }, + batch_size=[4, 3, 2, 1] + ) + assert not detect_loop(td_nested) + + td_auto_nested_no_loop_1 = TensorDict( + source={ + "a": torch.randn(4, 3, 2, 1, 5), + "b": TensorDict( + {"c": torch.randn(4, 3, 2, 1, 2)}, [4, 3, 2, 1] + ), + }, + batch_size=[4, 3, 2, 1] + ) + td_auto_nested_no_loop_1["b"]["d"] = td_auto_nested_no_loop_1["a"] + + assert not detect_loop(td_auto_nested_no_loop_1) + + td_auto_nested_no_loop_2 = TensorDict( + source={ + "a": TensorDict( + source={"c": torch.randn(4, 3, 2, 1, 2)}, + batch_size=[4, 3, 2, 1] + ), + "b": TensorDict( + source={"d": torch.randn(4, 3, 2, 1, 2)}, + batch_size=[4, 3, 2, 1] + ), + }, + batch_size=[4, 3, 2, 1] + ) + td_auto_nested_no_loop_2["b"]["e"] = td_auto_nested_no_loop_2["a"] + + assert not detect_loop(td_auto_nested_no_loop_2) + + td_auto_nested_no_loop_3 = TensorDict( + source={ + "a": torch.randn(4, 3, 2, 1, 2), + "b": TensorDict( + source={"c": torch.randn(4, 3, 2, 1, 2)}, + batch_size=[4, 3, 2, 1] + ), + }, + batch_size=[4, 3, 2, 1] + ) + td_auto_nested_no_loop_3["b"]["d"] = td_auto_nested_no_loop_3["b"]["c"] + + assert not detect_loop(td_auto_nested_no_loop_3) + + td_auto_nested_loop_1 = TensorDict( + source={ + "a": torch.randn(4, 3, 2, 1, 2), + "b": TensorDict( + source={"c": torch.randn(4, 3, 2, 1, 2)}, + batch_size=[4, 3, 2, 1] + ), + }, + batch_size=[4, 3, 2, 1] + ) + td_auto_nested_loop_1["b"]["d"] = td_auto_nested_loop_1["b"] + + assert detect_loop(td_auto_nested_loop_1) + + td_auto_nested_loop_2 = TensorDict( + source={ + "a": torch.randn(4, 3, 2, 1, 2), + "b": TensorDict( + source={"c": torch.randn(4, 3, 2, 1, 2)}, + batch_size=[4, 3, 2, 1] + ), + }, + batch_size=[4, 3, 2, 1] + ) + td_auto_nested_loop_2["b"]["d"] = td_auto_nested_loop_2 + + assert detect_loop(td_auto_nested_loop_2) + + if __name__ == "__main__": args, unknown = argparse.ArgumentParser().parse_known_args() pytest.main([__file__, "--capture", "no", "--exitfirst"] + unknown) From a4529f620f8b9ca0b1570fd1ae82a97c346e02eb Mon Sep 17 00:00:00 2001 From: Ruggero Vasile Date: Fri, 3 Feb 2023 16:09:34 +0100 Subject: [PATCH 6/9] Added test for TensorDict view iteration behavior --- tensordict/__init__.py | 2 + test/test_tensordict.py | 88 ++++++++++++++++++++++++++++++++++++++++- 2 files changed, 89 insertions(+), 1 deletion(-) diff --git a/tensordict/__init__.py b/tensordict/__init__.py index 54baeba1f..9c9bb58f4 100644 --- a/tensordict/__init__.py +++ b/tensordict/__init__.py @@ -10,6 +10,7 @@ SubTensorDict, TensorDict, detect_loop, + _TensorDictKeysView, ) try: @@ -24,4 +25,5 @@ "TensorDict", "merge_tensordicts", "set_transfer_ownership", + "_TensorDictKeysView" ] diff --git a/test/test_tensordict.py b/test/test_tensordict.py index dd541d7fe..4bb8f9592 100644 --- a/test/test_tensordict.py +++ 
b/test/test_tensordict.py @@ -12,7 +12,7 @@ import torch import torchsnapshot from _utils_internal import get_available_devices, prod, TestTensorDictsBase -from tensordict import LazyStackedTensorDict, MemmapTensor, TensorDict, detect_loop +from tensordict import LazyStackedTensorDict, MemmapTensor, TensorDict, detect_loop, _TensorDictKeysView from tensordict.tensordict import ( _stack as stack_td, assert_allclose_td, @@ -3715,6 +3715,90 @@ def test_tensordict_prealloc_nested(): assert buffer["agent.obs"].batch_size == torch.Size([B, N, T]) +def test_tensordict_view_iteration(): + + td_simple = TensorDict( + source={ + "a": torch.randn(4, 3, 2, 1, 5), + "b": torch.randn(4, 3, 2, 1, 5) + }, + batch_size=[4, 3, 2, 1] + ) + + view = _TensorDictKeysView(tensordict=td_simple, include_nested=True, leaves_only=True, error_on_loop=True) + keys = list(view) + assert len(keys) == 2 + assert "a" in keys + assert "b" in keys + + td_nested = TensorDict( + source={ + "a": torch.randn(4, 3, 2, 1, 5), + "b": TensorDict( + {"c": torch.randn(4, 3, 2, 1, 2)}, [4, 3, 2, 1] + ), + }, + batch_size=[4, 3, 2, 1] + ) + + view = _TensorDictKeysView(tensordict=td_nested, include_nested=True, leaves_only=True, error_on_loop=True) + keys = list(view) + assert len(keys) == 2 + assert "a" in keys + assert ("b", "c") in keys + + view = _TensorDictKeysView(tensordict=td_nested, include_nested=False, leaves_only=True, error_on_loop=True) + keys = list(view) + assert len(keys) == 1 + assert "a" in keys + + view = _TensorDictKeysView(tensordict=td_nested, include_nested=True, leaves_only=False, error_on_loop=True) + keys = list(view) + assert len(keys) == 3 + assert "a" in keys + assert "b" in keys + assert ("b", "c") in keys + + # We are not considering loops given by referencing non Dicts (leaf nodes) from two different key sequences + + td_auto_nested_loop = TensorDict( + source={ + "a": torch.randn(4, 3, 2, 1, 5), + "b": TensorDict( + {"c": torch.randn(4, 3, 2, 1, 2)}, [4, 3, 2, 1] + ), + }, + batch_size=[4, 3, 2, 1] + ) + td_auto_nested_loop["b"]["d"] = td_auto_nested_loop + + view = _TensorDictKeysView(tensordict=td_auto_nested_loop, include_nested=False, leaves_only=False, + error_on_loop=True) + keys = list(view) + assert len(keys) == 2 + assert 'a' in keys + assert 'b' in keys + + view = _TensorDictKeysView(tensordict=td_auto_nested_loop, include_nested=False, leaves_only=True, + error_on_loop=True) + keys = list(view) + assert len(keys) == 1 + assert 'a' in keys + + with pytest.raises(RecursionError): + view = _TensorDictKeysView(tensordict=td_auto_nested_loop, include_nested=True, leaves_only=True, + error_on_loop=True) + list(view) + + with pytest.raises(RecursionError): + view = _TensorDictKeysView(tensordict=td_auto_nested_loop, include_nested=True, leaves_only=False, + error_on_loop=True) + list(view) + + view = _TensorDictKeysView(tensordict=td_auto_nested_loop, include_nested=True, leaves_only=False, + error_on_loop=False) + # TODO Specify this undefined behavior better + def test_detect_loop(): td_simple = TensorDict( @@ -3810,6 +3894,8 @@ def test_detect_loop(): assert detect_loop(td_auto_nested_loop_2) + + if __name__ == "__main__": args, unknown = argparse.ArgumentParser().parse_known_args() pytest.main([__file__, "--capture", "no", "--exitfirst"] + unknown) From 0870fc4484cb92b3b78fcac9db921dcd33b1cc6b Mon Sep 17 00:00:00 2001 From: Ruggero Vasile Date: Mon, 6 Feb 2023 11:00:03 +0100 Subject: [PATCH 7/9] Added docstring to detect loop function --- tensordict/tensordict.py | 24 ++++ 
test/test_tensordict.py | 271 +++++++++++++++++++-------------------- 2 files changed, 158 insertions(+), 137 deletions(-) diff --git a/tensordict/tensordict.py b/tensordict/tensordict.py index 930c1f149..b52ca7bc5 100644 --- a/tensordict/tensordict.py +++ b/tensordict/tensordict.py @@ -5470,6 +5470,30 @@ def _clone_value(value, recurse): def detect_loop(tensordict: TensorDict) -> bool: + """ + This helper function detects the presence of an auto nesting loop inside + a TensorDict object. Auto nesting appears when a key of TensorDict references + another TensorDict and initiates a recursive infinite loop. It returns True + if at least one loop is found, otherwise returns False. An example is: + + >>> td = TensorDict( + >>> source={ + >>> "a": TensorDict( + >>> source={"b": torch.randn(4, 3, 1)}, + >>> batch_size=[4, 3, 1]), + >>> }, + >>> batch_size=[4, 3, 1] + >>> ) + >>> td["b"]["c"] = td + >>> + >>> print(detect_loop(td)) + True + + Args: + tensordict (TensorDict): The Tensordict Object to check for autonested loops presence. + Returns + bool: True if one loop is found, otherwise False + """ visited = set() visited.add(id(tensordict)) diff --git a/test/test_tensordict.py b/test/test_tensordict.py index 4bb8f9592..96c15ebdd 100644 --- a/test/test_tensordict.py +++ b/test/test_tensordict.py @@ -43,7 +43,7 @@ def test_tensordict_set(device): assert td.get("key_device").device == torch.device(device) with pytest.raises( - AttributeError, match="for populating tensordict with new key-value pair" + AttributeError, match="for populating tensordict with new key-value pair" ): td.set_("smartypants", torch.ones(4, 5, device="cpu", dtype=torch.double)) # test set_at_ @@ -111,7 +111,7 @@ def test_tensordict_error_messages(device): td2 = TensorDict({"sub": sub2}, [2]) with pytest.raises( - RuntimeError, match='tensors on different devices at key "sub" / "a"' + RuntimeError, match='tensors on different devices at key "sub" / "a"' ): torch.cat([td1, td2], 0) @@ -155,13 +155,13 @@ def test_tensordict_indexing(device): td_reconstruct = stack_td(list(td), 0, contiguous=False) assert ( - td_reconstruct == td + td_reconstruct == td ).all(), f"td and td_reconstruct differ, got {td} and {td_reconstruct}" superlist = [stack_td(list(_td), 0, contiguous=False) for _td in td] td_reconstruct = stack_td(superlist, 0, contiguous=False) assert ( - td_reconstruct == td + td_reconstruct == td ).all(), f"td and td_reconstruct differ, got {td == td_reconstruct}" x = torch.randn(4, 5, device=device) @@ -219,8 +219,8 @@ def test_subtensordict_construction(device): assert std_control.get_parent_tensordict() is td assert ( - std_control.get_parent_tensordict() - is std2.get_parent_tensordict().get_parent_tensordict() + std_control.get_parent_tensordict() + is std2.get_parent_tensordict().get_parent_tensordict() ) @@ -257,7 +257,7 @@ def test_unbind_td(device): td = TensorDict(batch_size=(4, 5), source=d) td_unbind = torch.unbind(td, dim=1) assert ( - td_unbind[0].batch_size == td[:, 0].batch_size + td_unbind[0].batch_size == td[:, 0].batch_size ), f"got {td_unbind[0].batch_size} and {td[:, 0].batch_size}" @@ -432,8 +432,8 @@ def test_permute_with_tensordict_operations(device): "c": torch.randn(4, 5, 6, 7, device=device), } td1 = TensorDict(batch_size=(4, 5, 6, 7), source=d)[ - :, :, :, torch.tensor([1, 2]) - ].permute(3, 2, 1, 0) + :, :, :, torch.tensor([1, 2]) + ].permute(3, 2, 1, 0) assert td1.shape == torch.Size((2, 6, 5, 4)) d = { @@ -467,8 +467,8 @@ def test_inferred_view_size(): ((..., 0), (slice(None), slice(None), 
slice(None), slice(None), 0)), ((0, ...), (0, slice(None), slice(None), slice(None), slice(None))), ( - (slice(1, 2), ...), - (slice(1, 2), slice(None), slice(None), slice(None), slice(None)), + (slice(1, 2), ...), + (slice(1, 2), slice(None), slice(None), slice(None), slice(None)), ), ], ) @@ -552,14 +552,14 @@ def test_select(self, td_name, device, strict, inplace): if td_name == "saved_td": assert (len(list(td2.keys())) == len(keys)) and ("a" in td2.keys()) assert (len(list(td2.clone().keys())) == len(keys)) and ( - "a" in td2.clone().keys() + "a" in td2.clone().keys() ) else: assert (len(list(td2.keys(True, True))) == len(keys)) and ( - "a" in td2.keys() + "a" in td2.keys() ) assert (len(list(td2.clone().keys(True, True))) == len(keys)) and ( - "a" in td2.clone().keys() + "a" in td2.clone().keys() ) @pytest.mark.parametrize("strict", [True, False]) @@ -580,11 +580,11 @@ def test_exclude(self, td_name, device): td2 = td.exclude("a") assert td2 is not td assert ( - len(list(td2.keys())) == len(list(td.keys())) - 1 and "a" not in td2.keys() + len(list(td2.keys())) == len(list(td.keys())) - 1 and "a" not in td2.keys() ) assert ( - len(list(td2.clone().keys())) == len(list(td.keys())) - 1 - and "a" not in td2.clone().keys() + len(list(td2.clone().keys())) == len(list(td.keys())) - 1 + and "a" not in td2.clone().keys() ) td2 = td.exclude("a", inplace=True) @@ -594,8 +594,8 @@ def test_assert(self, td_name, device): torch.manual_seed(1) td = getattr(self, td_name)(device) with pytest.raises( - ValueError, - match="Converting a tensordict to boolean value is not permitted", + ValueError, + match="Converting a tensordict to boolean value is not permitted", ): assert td @@ -648,8 +648,8 @@ def test_set_unexisting(self, td_name, device): td = getattr(self, td_name)(device) if td.is_locked: with pytest.raises( - RuntimeError, - match="Cannot modify locked TensorDict. For in-place modification", + RuntimeError, + match="Cannot modify locked TensorDict. For in-place modification", ): td.set("z", torch.ones_like(td.get("a"))) else: @@ -1012,11 +1012,11 @@ def test_exclude_nested(self, td_name, device, nested): assert "a" in td.keys() assert "a" not in td2.keys() if td_name not in ( - "sub_td", - "sub_td2", - "unsqueezed_td", - "squeezed_td", - "permute_td", + "sub_td", + "sub_td2", + "unsqueezed_td", + "squeezed_td", + "permute_td", ): # TODO: document this as an edge-case: with a sub-tensordict, exclude acts on the parent tensordict # perhaps exclude should return an error in these cases? 
@@ -1148,11 +1148,11 @@ def test_nestedtensor_stack(self, td_name, device, dim, key): td_stack = torch.stack([td1, td2], dim) # get will fail with pytest.raises( - RuntimeError, match="Found more than one unique shape in the tensors" + RuntimeError, match="Found more than one unique shape in the tensors" ): td_stack.get(key) with pytest.raises( - RuntimeError, match="Found more than one unique shape in the tensors" + RuntimeError, match="Found more than one unique shape in the tensors" ): td_stack[key] if dim in (0, -5): @@ -1162,17 +1162,17 @@ def test_nestedtensor_stack(self, td_name, device, dim, key): else: # if the stack_dim is not zero, then calling get_nestedtensor is disallowed with pytest.raises( - RuntimeError, - match="LazyStackedTensorDict.get_nestedtensor can only be called " - "when the stack_dim is 0.", + RuntimeError, + match="LazyStackedTensorDict.get_nestedtensor can only be called " + "when the stack_dim is 0.", ): td_stack.get_nestedtensor(key) with pytest.raises( - RuntimeError, match="Found more than one unique shape in the tensors" + RuntimeError, match="Found more than one unique shape in the tensors" ): td_stack.contiguous() with pytest.raises( - RuntimeError, match="Found more than one unique shape in the tensors" + RuntimeError, match="Found more than one unique shape in the tensors" ): td_stack.to_tensordict() # cloning is type-preserving: we can do that operation @@ -1184,14 +1184,14 @@ def test_clone_td(self, td_name, device): assert (torch.clone(td) == td).all() assert td.batch_size == torch.clone(td).batch_size if td_name in ( - "stacked_td", - "nested_stacked_td", - "saved_td", - "squeezed_td", - "unsqueezed_td", - "sub_td", - "sub_td2", - "permute_td", + "stacked_td", + "nested_stacked_td", + "saved_td", + "squeezed_td", + "unsqueezed_td", + "sub_td", + "sub_td2", + "permute_td", ): with pytest.raises(AssertionError): assert td.clone(recurse=False).get("a") is td.get("a") @@ -1320,7 +1320,7 @@ def test_getitem_range(self, td_name, device): # weird as the index to the stacking dimension we'll get the error idx = (slice(None),) * td.stack_dim + ({1, 2, 3},) with pytest.raises( - TypeError, match="Invalid index used for stack dimension." + TypeError, match="Invalid index used for stack dimension." 
): td[idx] @@ -1578,7 +1578,7 @@ def test_nested_td_index(self, td_name, device): td.set("sub_td", sub_td) assert (td["sub_td", "sub_sub_td", "a"] == 0).all() assert ( - td["sub_td"]["sub_sub_td"]["a"] == td["sub_td", "sub_sub_td", "a"] + td["sub_td"]["sub_sub_td"]["a"] == td["sub_td", "sub_sub_td", "a"] ).all() a = torch.ones_like(a) @@ -1586,7 +1586,7 @@ def test_nested_td_index(self, td_name, device): td["sub_td", "sub_sub_td"] = other_sub_sub_td assert (td["sub_td", "sub_sub_td", "a"] == 1).all() assert ( - td["sub_td"]["sub_sub_td"]["a"] == td["sub_td", "sub_sub_td", "a"] + td["sub_td"]["sub_sub_td"]["a"] == td["sub_td", "sub_sub_td", "a"] ).all() b = torch.ones_like(a) @@ -1598,7 +1598,7 @@ def test_nested_td_index(self, td_name, device): td["sub_td", "sub_sub_td"] = other_sub_sub_td assert (td["sub_td", "sub_sub_td", "b"] == 1).all() assert ( - td["sub_td"]["sub_sub_td"]["b"] == td["sub_td", "sub_sub_td", "b"] + td["sub_td"]["sub_sub_td"]["b"] == td["sub_td", "sub_sub_td", "b"] ).all() @pytest.mark.parametrize("inplace", [True, False]) @@ -1633,8 +1633,8 @@ def test_flatten_keys(self, td_name, device, inplace, separator): for value in td_flatten.values(): assert not isinstance(value, TensorDictBase) assert ( - separator.join(["nested_tensordict", "nested_nested_tensordict", "a"]) - in td_flatten.keys() + separator.join(["nested_tensordict", "nested_nested_tensordict", "a"]) + in td_flatten.keys() ) if inplace: assert td_flatten is td @@ -1689,8 +1689,8 @@ def test_memmap_(self, td_name, device): td = getattr(self, td_name)(device) if td_name in ("sub_td", "sub_td2"): with pytest.raises( - RuntimeError, - match="Converting a sub-tensordict values to memmap cannot be done", + RuntimeError, + match="Converting a sub-tensordict values to memmap cannot be done", ): td.memmap_() else: @@ -1706,8 +1706,8 @@ def test_memmap_prefix(self, td_name, device, tmpdir): td = getattr(self, td_name)(device) if td_name in ("sub_td", "sub_td2"): with pytest.raises( - RuntimeError, - match="Converting a sub-tensordict values to memmap cannot be done", + RuntimeError, + match="Converting a sub-tensordict values to memmap cannot be done", ): td.memmap_(tmpdir / "tensordict") return @@ -1747,7 +1747,7 @@ def test_memmap_existing(self, td_name, device, copy_existing, tmpdir): assert (td == td3).all() else: with pytest.raises( - RuntimeError, match="TensorDict already contains MemmapTensors" + RuntimeError, match="TensorDict already contains MemmapTensors" ): # calling memmap_ with prefix that is different to contents gives error td.memmap_(prefix=tmpdir / "tensordict2") @@ -1921,11 +1921,11 @@ def test_pop(self, td_name, device): assert (out == default).all() with pytest.raises( - KeyError, - match=re.escape( - "You are trying to pop key `z` which is not in dict" - "without providing default value" - ), + KeyError, + match=re.escape( + "You are trying to pop key `z` which is not in dict" + "without providing default value" + ), ): td.pop("z") @@ -2224,7 +2224,7 @@ def test_repr_indexed_stacked_tensordict(self, device, dtype, index): def test_repr_device_to_device(self, device, dtype, device_cast): td = self.td(device, dtype) if (device_cast is None and (torch.cuda.device_count() > 0)) or ( - device_cast is not None and device_cast.type == "cuda" + device_cast is not None and device_cast.type == "cuda" ): is_shared = True else: @@ -2356,11 +2356,11 @@ def test_batchsize_reset(): # incompatible size with pytest.raises( - RuntimeError, - match=re.escape( - "the tensor a has shape torch.Size([3, 4, 5, " - "6]) 
which is incompatible with the new shape torch.Size([3, 5])" - ), + RuntimeError, + match=re.escape( + "the tensor a has shape torch.Size([3, 4, 5, " + "6]) which is incompatible with the new shape torch.Size([3, 5])" + ), ): td.batch_size = [3, 5] @@ -2370,8 +2370,8 @@ def test_batchsize_reset(): # test index td[torch.tensor([1, 2])] with pytest.raises( - IndexError, - match=re.escape("too many indices for tensor of dimension 1"), + IndexError, + match=re.escape("too many indices for tensor of dimension 1"), ): td[:, 0] @@ -2383,22 +2383,22 @@ def test_batchsize_reset(): td.set("c", torch.randn(3, 4, 5, 6)) with pytest.raises( - RuntimeError, - match=re.escape( - "batch dimension mismatch, " - "got self.batch_size=torch.Size([3, 4, 5]) and tensor.shape[:self.batch_dims]=torch.Size([3, 4, 2])" - ), + RuntimeError, + match=re.escape( + "batch dimension mismatch, " + "got self.batch_size=torch.Size([3, 4, 5]) and tensor.shape[:self.batch_dims]=torch.Size([3, 4, 2])" + ), ): td.set("d", torch.randn(3, 4, 2)) # test that lazy tds return an exception td_stack = stack_td([TensorDict({"a": torch.randn(3)}, [3]) for _ in range(2)]) with pytest.raises( - RuntimeError, - match=re.escape( - "modifying the batch size of a lazy repesentation " - "of a tensordict is not permitted. Consider instantiating the tensordict first by calling `td = td.to_tensordict()` before resetting the batch size." - ), + RuntimeError, + match=re.escape( + "modifying the batch size of a lazy repesentation " + "of a tensordict is not permitted. Consider instantiating the tensordict first by calling `td = td.to_tensordict()` before resetting the batch size." + ), ): td_stack.batch_size = [2] td_stack.to_tensordict().batch_size = [2] @@ -2406,10 +2406,10 @@ def test_batchsize_reset(): td = TensorDict({"a": torch.randn(3, 4)}, [3, 4]) subtd = td.get_sub_tensordict((slice(None), torch.tensor([1, 2]))) with pytest.raises( - RuntimeError, - match=re.escape( - "modifying the batch size of a lazy repesentation of a tensordict is not permitted. Consider instantiating the tensordict first by calling `td = td.to_tensordict()` before resetting the batch size." - ), + RuntimeError, + match=re.escape( + "modifying the batch size of a lazy repesentation of a tensordict is not permitted. Consider instantiating the tensordict first by calling `td = td.to_tensordict()` before resetting the batch size." + ), ): subtd.batch_size = [3, 2] subtd.to_tensordict().batch_size = [3, 2] @@ -2417,10 +2417,10 @@ def test_batchsize_reset(): td = TensorDict({"a": torch.randn(3, 4)}, [3, 4]) td_u = td.unsqueeze(0) with pytest.raises( - RuntimeError, - match=re.escape( - "modifying the batch size of a lazy repesentation of a tensordict is not permitted. Consider instantiating the tensordict first by calling `td = td.to_tensordict()` before resetting the batch size." - ), + RuntimeError, + match=re.escape( + "modifying the batch size of a lazy repesentation of a tensordict is not permitted. Consider instantiating the tensordict first by calling `td = td.to_tensordict()` before resetting the batch size." 
+ ), ): td_u.batch_size = [1] td_u.to_tensordict().batch_size = [1] @@ -2534,7 +2534,7 @@ def _remote_process(worker_id, command_pipe_child, command_pipe_parent, tensordi a = torch.ones(2) * val tensordict.set_("a", a) assert ( - tensordict.get("a") == a + tensordict.get("a") == a ).all(), f'found {a} and {tensordict.get("a")}' command_pipe_child.send("done") elif cmd == "set_done": @@ -2689,10 +2689,10 @@ def test_mp(td_type): (torch.tensor([1, 2, 3]),), ([1, 2, 3]), ( - torch.tensor([1, 2, 3]), - torch.tensor([2, 3, 4]), - torch.tensor([0, 10, 2]), - torch.tensor([2, 4, 1]), + torch.tensor([1, 2, 3]), + torch.tensor([2, 3, 4]), + torch.tensor([0, 10, 2]), + torch.tensor([2, 4, 1]), ), torch.zeros(10, 7, 11, 5, dtype=torch.bool).bernoulli_(), torch.zeros(10, 7, 11, dtype=torch.bool).bernoulli_(), @@ -2978,7 +2978,7 @@ def test_keys_view(): assert ("a", "c", "b") not in tensordict.keys(include_nested=True) with pytest.raises( - TypeError, match="checks with tuples of strings is only supported" + TypeError, match="checks with tuples of strings is only supported" ): ("a", "b", "c") in tensordict.keys() # noqa: B015 @@ -3006,8 +3006,8 @@ def test_error_on_contains(): {"a": TensorDict({"b": torch.rand(1, 2)}, [1, 2]), "c": torch.rand(1)}, [1] ) with pytest.raises( - NotImplementedError, - match="TensorDict does not support membership checks with the `in` keyword", + NotImplementedError, + match="TensorDict does not support membership checks with the `in` keyword", ): "random_string" in td # noqa: B015 @@ -3145,57 +3145,57 @@ def test_flatten_unflatten_key_collision(inplace, separator): ) with pytest.raises( - KeyError, match="Flattening keys in tensordict collides with existing key *" + KeyError, match="Flattening keys in tensordict collides with existing key *" ): _ = td1.flatten_keys(separator) with pytest.raises( - KeyError, match="Flattening keys in tensordict collides with existing key *" + KeyError, match="Flattening keys in tensordict collides with existing key *" ): _ = td2.flatten_keys(separator) with pytest.raises( - KeyError, match="Flattening keys in tensordict collides with existing key *" + KeyError, match="Flattening keys in tensordict collides with existing key *" ): _ = td3.flatten_keys(separator) with pytest.raises( - KeyError, - match=re.escape( - "Unflattening key(s) in tensordict will override existing unflattened key" - ), + KeyError, + match=re.escape( + "Unflattening key(s) in tensordict will override existing unflattened key" + ), ): _ = td1.unflatten_keys(separator) with pytest.raises( - KeyError, - match=re.escape( - "Unflattening key(s) in tensordict will override existing unflattened key" - ), + KeyError, + match=re.escape( + "Unflattening key(s) in tensordict will override existing unflattened key" + ), ): _ = td2.unflatten_keys(separator) with pytest.raises( - KeyError, - match=re.escape( - "Unflattening key(s) in tensordict will override existing unflattened key" - ), + KeyError, + match=re.escape( + "Unflattening key(s) in tensordict will override existing unflattened key" + ), ): _ = td3.unflatten_keys(separator) with pytest.raises( - KeyError, - match=re.escape( - "Unflattening key(s) in tensordict will override existing unflattened key" - ), + KeyError, + match=re.escape( + "Unflattening key(s) in tensordict will override existing unflattened key" + ), ): _ = td4.unflatten_keys(separator) with pytest.raises( - KeyError, - match=re.escape( - "Unflattening key(s) in tensordict will override existing unflattened key" - ), + KeyError, + match=re.escape( + 
"Unflattening key(s) in tensordict will override existing unflattened key" + ), ): _ = td5.unflatten_keys(separator) @@ -3405,8 +3405,8 @@ def test_stacked_td(self, stack_dim, device): tensordicts0.zero_() assert (sub_td[item].get("key1") == sub_td.get("key1")[item]).all() assert ( - sub_td.contiguous()[item].get("key1") - == sub_td.contiguous().get("key1")[item] + sub_td.contiguous()[item].get("key1") + == sub_td.contiguous().get("key1")[item] ).all() assert (sub_td.contiguous().get("key1")[item] == 0).all() @@ -3415,7 +3415,7 @@ def test_stacked_td(self, stack_dim, device): tensordicts1.zero_() assert (std2[item].get("key1") == std2.get("key1")[item]).all() assert ( - std2.contiguous()[item].get("key1") == std2.contiguous().get("key1")[item] + std2.contiguous()[item].get("key1") == std2.contiguous().get("key1")[item] ).all() assert (std2.contiguous().get("key1")[item] == 0).all() @@ -3424,7 +3424,7 @@ def test_stacked_td(self, stack_dim, device): item = (*[slice(None) for _ in range(stack_dim)], 2) assert (std3[item].get("key1") == std3.get("key1")[item]).all() assert ( - std3.contiguous()[item].get("key1") == std3.contiguous().get("key1")[item] + std3.contiguous()[item].get("key1") == std3.contiguous().get("key1")[item] ).all() assert (std3.contiguous().get("key1")[item] == 0).all() @@ -3433,7 +3433,7 @@ def test_stacked_td(self, stack_dim, device): item = (*[slice(None) for _ in range(stack_dim)], 3) assert (std4[item].get("key1") == std4.get("key1")[item]).all() assert ( - std4.contiguous()[item].get("key1") == std4.contiguous().get("key1")[item] + std4.contiguous()[item].get("key1") == std4.contiguous().get("key1")[item] ).all() assert (std4.contiguous().get("key1")[item] == 0).all() @@ -3452,9 +3452,9 @@ def test_stacked_indexing(self, device, stack_dim): tds = torch.stack(list(tensordict.unbind(stack_dim)), stack_dim) for item, expected_shape in ( - ((2, 2), torch.Size([5])), - ((slice(1, 2), 2), torch.Size([1, 5])), - ((..., 2), torch.Size([3, 4])), + ((2, 2), torch.Size([5])), + ((slice(1, 2), 2), torch.Size([1, 5])), + ((..., 2), torch.Size([3, 4])), ): assert tds[item].batch_size == expected_shape assert (tds[item].get("a") == tds.get("a")[item]).all() @@ -3505,7 +3505,7 @@ def test_lazy_stacked_insert(self, dim, index, device): torch.testing.assert_close(lstd["a"], t) with pytest.raises( - TypeError, match="Expected new value to be TensorDictBase instance" + TypeError, match="Expected new value to be TensorDictBase instance" ): lstd.insert(index, torch.rand(10)) @@ -3526,8 +3526,8 @@ def test_lazy_stacked_contains(self): assert td.clone() not in lstd with pytest.raises( - NotImplementedError, - match="TensorDict does not support membership checks with the `in` keyword", + NotImplementedError, + match="TensorDict does not support membership checks with the `in` keyword", ): "random_string" in lstd # noqa: B015 @@ -3559,7 +3559,7 @@ def test_lazy_stacked_append(self, dim, device): torch.testing.assert_close(lstd["a"], t) with pytest.raises( - TypeError, match="Expected new value to be TensorDictBase instance" + TypeError, match="Expected new value to be TensorDictBase instance" ): lstd.append(torch.rand(10)) @@ -3590,14 +3590,14 @@ def test_stack_update_heter_stacked_td(self, stack_dim): td_b = td_a.clone() td_a.update(td_b) with pytest.raises( - RuntimeError, - match="Found more than one unique shape in the tensors to be stacked", + RuntimeError, + match="Found more than one unique shape in the tensors to be stacked", ): td_a.update(td_b.to_tensordict()) td_a.update_(td_b) with 
pytest.raises( - RuntimeError, - match="Found more than one unique shape in the tensors to be stacked", + RuntimeError, + match="Found more than one unique shape in the tensors to be stacked", ): td_a.update_(td_b.to_tensordict()) @@ -3638,7 +3638,7 @@ def test_inplace(self, save_name): assert isinstance(td_dest["b", "c"], MemmapTensor) def test_update( - self, + self, ): tensordict = TensorDict({"a": torch.randn(3), "b": {"c": torch.randn(3)}}, []) state = {"state": tensordict} @@ -3696,8 +3696,8 @@ def test_tensordict_prealloc_nested(): buffer[0] = td_0 buffer[1] = td_1 assert ( - repr(buffer) - == """TensorDict( + repr(buffer) + == """TensorDict( fields={ agent.obs: TensorDict( fields={ @@ -3716,7 +3716,6 @@ def test_tensordict_prealloc_nested(): def test_tensordict_view_iteration(): - td_simple = TensorDict( source={ "a": torch.randn(4, 3, 2, 1, 5), @@ -3799,8 +3798,8 @@ def test_tensordict_view_iteration(): error_on_loop=False) # TODO Specify this undefined behavior better -def test_detect_loop(): +def test_detect_loop(): td_simple = TensorDict( source={ "a": torch.randn(4, 3, 2, 1, 5), @@ -3894,8 +3893,6 @@ def test_detect_loop(): assert detect_loop(td_auto_nested_loop_2) - - if __name__ == "__main__": args, unknown = argparse.ArgumentParser().parse_known_args() pytest.main([__file__, "--capture", "no", "--exitfirst"] + unknown) From 852bb4dd2a65bf833ab17a2b5c7ed36f06c71043 Mon Sep 17 00:00:00 2001 From: Ruggero Vasile Date: Mon, 6 Feb 2023 12:13:19 +0100 Subject: [PATCH 8/9] Completed test for iteration in TensorDictKeysViews --- test/test_tensordict.py | 36 +++++++++++++++++++++++++++++++++++- 1 file changed, 35 insertions(+), 1 deletion(-) diff --git a/test/test_tensordict.py b/test/test_tensordict.py index 96c15ebdd..247087d86 100644 --- a/test/test_tensordict.py +++ b/test/test_tensordict.py @@ -3796,7 +3796,41 @@ def test_tensordict_view_iteration(): view = _TensorDictKeysView(tensordict=td_auto_nested_loop, include_nested=True, leaves_only=False, error_on_loop=False) - # TODO Specify this undefined behavior better + + keys = list(view) + assert len(keys) == 3 + assert 'a' in keys + assert 'b' in keys + assert ('b', 'c') in keys + + view = _TensorDictKeysView(tensordict=td_auto_nested_loop, include_nested=True, leaves_only=True, + error_on_loop=False) + + keys = list(view) + assert len(keys) == 2 + assert 'a' in keys + assert ('b', 'c') in keys + + td_auto_nested_loop_2 = TensorDict( + source={ + "a": torch.randn(4, 3, 2, 1, 5), + "b": TensorDict( + {"c": torch.randn(4, 3, 2, 1, 2)}, [4, 3, 2, 1] + ), + }, + batch_size=[4, 3, 2, 1] + ) + td_auto_nested_loop_2["b"]["d"] = td_auto_nested_loop_2['b'] + + view = _TensorDictKeysView(tensordict=td_auto_nested_loop_2, include_nested=True, leaves_only=False, + error_on_loop=False) + + keys = list(view) + assert len(keys) == 3 + assert 'a' in keys + assert 'b' in keys + assert ('b', 'c') in keys + def test_detect_loop(): From 0a71de054d497f3aa0105c200f5b73645961f1a3 Mon Sep 17 00:00:00 2001 From: Ruggero Vasile Date: Mon, 6 Feb 2023 12:19:40 +0100 Subject: [PATCH 9/9] pre-commit format --- tensordict/__init__.py | 6 +- tensordict/tensordict.py | 42 ++-- test/test_tensordict.py | 438 ++++++++++++++++++++------------------- 3 files changed, 253 insertions(+), 233 deletions(-) diff --git a/tensordict/__init__.py b/tensordict/__init__.py index 9c9bb58f4..3f7bad24a 100644 --- a/tensordict/__init__.py +++ b/tensordict/__init__.py @@ -5,12 +5,12 @@ from .memmap import MemmapTensor, set_transfer_ownership from .tensordict import ( + 
_TensorDictKeysView,
+    detect_loop,
     LazyStackedTensorDict,
     merge_tensordicts,
     SubTensorDict,
     TensorDict,
-    detect_loop,
-    _TensorDictKeysView,
 )
 
 try:
@@ -25,5 +25,5 @@
     "TensorDict",
     "merge_tensordicts",
     "set_transfer_ownership",
-    "_TensorDictKeysView"
+    "_TensorDictKeysView",
 ]
diff --git a/tensordict/tensordict.py b/tensordict/tensordict.py
index c0f8f3ac1..dadbb2342 100644
--- a/tensordict/tensordict.py
+++ b/tensordict/tensordict.py
@@ -5473,28 +5473,28 @@ def _clone_value(value, recurse):
 def detect_loop(tensordict: TensorDict) -> bool:
     """
-    This helper function detects the presence of an auto nesting loop inside
-    a TensorDict object. Auto nesting appears when a key of TensorDict references
-    another TensorDict and initiates a recursive infinite loop. It returns True
-    if at least one loop is found, otherwise returns False. An example is:
+    Detects the presence of an auto-nesting loop inside a TensorDict.
+    Auto-nesting arises when a nested value references the TensorDict itself
+    (or one of its ancestors), creating an infinitely recursive structure.
+    Returns True if at least one loop is found, otherwise False. An example:
+
+    >>> td = TensorDict(
+    >>>     source={
+    >>>         "a": TensorDict(
+    >>>             source={"b": torch.randn(4, 3, 1)},
+    >>>             batch_size=[4, 3, 1]),
+    >>>     },
+    >>>     batch_size=[4, 3, 1]
+    >>> )
+    >>> td["a"]["c"] = td
+    >>>
+    >>> print(detect_loop(td))
+    True
 
-    >>> td = TensorDict(
-    >>>     source={
-    >>>         "a": TensorDict(
-    >>>             source={"b": torch.randn(4, 3, 1)},
-    >>>             batch_size=[4, 3, 1]),
-    >>>     },
-    >>>     batch_size=[4, 3, 1]
-    >>> )
-    >>> td["b"]["c"] = td
-    >>>
-    >>> print(detect_loop(td))
-    True
-
-    Args:
-        tensordict (TensorDict): The Tensordict Object to check for autonested loops presence.
+
+    Args:
+        tensordict (TensorDict): the TensorDict to check for auto-nesting loops.
+ Returns + bool: True if one loop is found, otherwise False """ visited = set() visited.add(id(tensordict)) diff --git a/test/test_tensordict.py b/test/test_tensordict.py index 247087d86..713cc1ddc 100644 --- a/test/test_tensordict.py +++ b/test/test_tensordict.py @@ -12,7 +12,13 @@ import torch import torchsnapshot from _utils_internal import get_available_devices, prod, TestTensorDictsBase -from tensordict import LazyStackedTensorDict, MemmapTensor, TensorDict, detect_loop, _TensorDictKeysView +from tensordict import ( + _TensorDictKeysView, + detect_loop, + LazyStackedTensorDict, + MemmapTensor, + TensorDict, +) from tensordict.tensordict import ( _stack as stack_td, assert_allclose_td, @@ -43,7 +49,7 @@ def test_tensordict_set(device): assert td.get("key_device").device == torch.device(device) with pytest.raises( - AttributeError, match="for populating tensordict with new key-value pair" + AttributeError, match="for populating tensordict with new key-value pair" ): td.set_("smartypants", torch.ones(4, 5, device="cpu", dtype=torch.double)) # test set_at_ @@ -111,7 +117,7 @@ def test_tensordict_error_messages(device): td2 = TensorDict({"sub": sub2}, [2]) with pytest.raises( - RuntimeError, match='tensors on different devices at key "sub" / "a"' + RuntimeError, match='tensors on different devices at key "sub" / "a"' ): torch.cat([td1, td2], 0) @@ -155,13 +161,13 @@ def test_tensordict_indexing(device): td_reconstruct = stack_td(list(td), 0, contiguous=False) assert ( - td_reconstruct == td + td_reconstruct == td ).all(), f"td and td_reconstruct differ, got {td} and {td_reconstruct}" superlist = [stack_td(list(_td), 0, contiguous=False) for _td in td] td_reconstruct = stack_td(superlist, 0, contiguous=False) assert ( - td_reconstruct == td + td_reconstruct == td ).all(), f"td and td_reconstruct differ, got {td == td_reconstruct}" x = torch.randn(4, 5, device=device) @@ -219,8 +225,8 @@ def test_subtensordict_construction(device): assert std_control.get_parent_tensordict() is td assert ( - std_control.get_parent_tensordict() - is std2.get_parent_tensordict().get_parent_tensordict() + std_control.get_parent_tensordict() + is std2.get_parent_tensordict().get_parent_tensordict() ) @@ -257,7 +263,7 @@ def test_unbind_td(device): td = TensorDict(batch_size=(4, 5), source=d) td_unbind = torch.unbind(td, dim=1) assert ( - td_unbind[0].batch_size == td[:, 0].batch_size + td_unbind[0].batch_size == td[:, 0].batch_size ), f"got {td_unbind[0].batch_size} and {td[:, 0].batch_size}" @@ -432,8 +438,8 @@ def test_permute_with_tensordict_operations(device): "c": torch.randn(4, 5, 6, 7, device=device), } td1 = TensorDict(batch_size=(4, 5, 6, 7), source=d)[ - :, :, :, torch.tensor([1, 2]) - ].permute(3, 2, 1, 0) + :, :, :, torch.tensor([1, 2]) + ].permute(3, 2, 1, 0) assert td1.shape == torch.Size((2, 6, 5, 4)) d = { @@ -467,8 +473,8 @@ def test_inferred_view_size(): ((..., 0), (slice(None), slice(None), slice(None), slice(None), 0)), ((0, ...), (0, slice(None), slice(None), slice(None), slice(None))), ( - (slice(1, 2), ...), - (slice(1, 2), slice(None), slice(None), slice(None), slice(None)), + (slice(1, 2), ...), + (slice(1, 2), slice(None), slice(None), slice(None), slice(None)), ), ], ) @@ -552,14 +558,14 @@ def test_select(self, td_name, device, strict, inplace): if td_name == "saved_td": assert (len(list(td2.keys())) == len(keys)) and ("a" in td2.keys()) assert (len(list(td2.clone().keys())) == len(keys)) and ( - "a" in td2.clone().keys() + "a" in td2.clone().keys() ) else: assert 
(len(list(td2.keys(True, True))) == len(keys)) and ( - "a" in td2.keys() + "a" in td2.keys() ) assert (len(list(td2.clone().keys(True, True))) == len(keys)) and ( - "a" in td2.clone().keys() + "a" in td2.clone().keys() ) @pytest.mark.parametrize("strict", [True, False]) @@ -580,11 +586,11 @@ def test_exclude(self, td_name, device): td2 = td.exclude("a") assert td2 is not td assert ( - len(list(td2.keys())) == len(list(td.keys())) - 1 and "a" not in td2.keys() + len(list(td2.keys())) == len(list(td.keys())) - 1 and "a" not in td2.keys() ) assert ( - len(list(td2.clone().keys())) == len(list(td.keys())) - 1 - and "a" not in td2.clone().keys() + len(list(td2.clone().keys())) == len(list(td.keys())) - 1 + and "a" not in td2.clone().keys() ) td2 = td.exclude("a", inplace=True) @@ -594,8 +600,8 @@ def test_assert(self, td_name, device): torch.manual_seed(1) td = getattr(self, td_name)(device) with pytest.raises( - ValueError, - match="Converting a tensordict to boolean value is not permitted", + ValueError, + match="Converting a tensordict to boolean value is not permitted", ): assert td @@ -648,8 +654,8 @@ def test_set_unexisting(self, td_name, device): td = getattr(self, td_name)(device) if td.is_locked: with pytest.raises( - RuntimeError, - match="Cannot modify locked TensorDict. For in-place modification", + RuntimeError, + match="Cannot modify locked TensorDict. For in-place modification", ): td.set("z", torch.ones_like(td.get("a"))) else: @@ -1012,11 +1018,11 @@ def test_exclude_nested(self, td_name, device, nested): assert "a" in td.keys() assert "a" not in td2.keys() if td_name not in ( - "sub_td", - "sub_td2", - "unsqueezed_td", - "squeezed_td", - "permute_td", + "sub_td", + "sub_td2", + "unsqueezed_td", + "squeezed_td", + "permute_td", ): # TODO: document this as an edge-case: with a sub-tensordict, exclude acts on the parent tensordict # perhaps exclude should return an error in these cases? 
@@ -1148,11 +1154,11 @@ def test_nestedtensor_stack(self, td_name, device, dim, key): td_stack = torch.stack([td1, td2], dim) # get will fail with pytest.raises( - RuntimeError, match="Found more than one unique shape in the tensors" + RuntimeError, match="Found more than one unique shape in the tensors" ): td_stack.get(key) with pytest.raises( - RuntimeError, match="Found more than one unique shape in the tensors" + RuntimeError, match="Found more than one unique shape in the tensors" ): td_stack[key] if dim in (0, -5): @@ -1162,17 +1168,17 @@ def test_nestedtensor_stack(self, td_name, device, dim, key): else: # if the stack_dim is not zero, then calling get_nestedtensor is disallowed with pytest.raises( - RuntimeError, - match="LazyStackedTensorDict.get_nestedtensor can only be called " - "when the stack_dim is 0.", + RuntimeError, + match="LazyStackedTensorDict.get_nestedtensor can only be called " + "when the stack_dim is 0.", ): td_stack.get_nestedtensor(key) with pytest.raises( - RuntimeError, match="Found more than one unique shape in the tensors" + RuntimeError, match="Found more than one unique shape in the tensors" ): td_stack.contiguous() with pytest.raises( - RuntimeError, match="Found more than one unique shape in the tensors" + RuntimeError, match="Found more than one unique shape in the tensors" ): td_stack.to_tensordict() # cloning is type-preserving: we can do that operation @@ -1184,14 +1190,14 @@ def test_clone_td(self, td_name, device): assert (torch.clone(td) == td).all() assert td.batch_size == torch.clone(td).batch_size if td_name in ( - "stacked_td", - "nested_stacked_td", - "saved_td", - "squeezed_td", - "unsqueezed_td", - "sub_td", - "sub_td2", - "permute_td", + "stacked_td", + "nested_stacked_td", + "saved_td", + "squeezed_td", + "unsqueezed_td", + "sub_td", + "sub_td2", + "permute_td", ): with pytest.raises(AssertionError): assert td.clone(recurse=False).get("a") is td.get("a") @@ -1320,7 +1326,7 @@ def test_getitem_range(self, td_name, device): # weird as the index to the stacking dimension we'll get the error idx = (slice(None),) * td.stack_dim + ({1, 2, 3},) with pytest.raises( - TypeError, match="Invalid index used for stack dimension." + TypeError, match="Invalid index used for stack dimension." 
): td[idx] @@ -1578,7 +1584,7 @@ def test_nested_td_index(self, td_name, device): td.set("sub_td", sub_td) assert (td["sub_td", "sub_sub_td", "a"] == 0).all() assert ( - td["sub_td"]["sub_sub_td"]["a"] == td["sub_td", "sub_sub_td", "a"] + td["sub_td"]["sub_sub_td"]["a"] == td["sub_td", "sub_sub_td", "a"] ).all() a = torch.ones_like(a) @@ -1586,7 +1592,7 @@ def test_nested_td_index(self, td_name, device): td["sub_td", "sub_sub_td"] = other_sub_sub_td assert (td["sub_td", "sub_sub_td", "a"] == 1).all() assert ( - td["sub_td"]["sub_sub_td"]["a"] == td["sub_td", "sub_sub_td", "a"] + td["sub_td"]["sub_sub_td"]["a"] == td["sub_td", "sub_sub_td", "a"] ).all() b = torch.ones_like(a) @@ -1598,7 +1604,7 @@ def test_nested_td_index(self, td_name, device): td["sub_td", "sub_sub_td"] = other_sub_sub_td assert (td["sub_td", "sub_sub_td", "b"] == 1).all() assert ( - td["sub_td"]["sub_sub_td"]["b"] == td["sub_td", "sub_sub_td", "b"] + td["sub_td"]["sub_sub_td"]["b"] == td["sub_td", "sub_sub_td", "b"] ).all() @pytest.mark.parametrize("inplace", [True, False]) @@ -1633,8 +1639,8 @@ def test_flatten_keys(self, td_name, device, inplace, separator): for value in td_flatten.values(): assert not isinstance(value, TensorDictBase) assert ( - separator.join(["nested_tensordict", "nested_nested_tensordict", "a"]) - in td_flatten.keys() + separator.join(["nested_tensordict", "nested_nested_tensordict", "a"]) + in td_flatten.keys() ) if inplace: assert td_flatten is td @@ -1689,8 +1695,8 @@ def test_memmap_(self, td_name, device): td = getattr(self, td_name)(device) if td_name in ("sub_td", "sub_td2"): with pytest.raises( - RuntimeError, - match="Converting a sub-tensordict values to memmap cannot be done", + RuntimeError, + match="Converting a sub-tensordict values to memmap cannot be done", ): td.memmap_() else: @@ -1706,8 +1712,8 @@ def test_memmap_prefix(self, td_name, device, tmpdir): td = getattr(self, td_name)(device) if td_name in ("sub_td", "sub_td2"): with pytest.raises( - RuntimeError, - match="Converting a sub-tensordict values to memmap cannot be done", + RuntimeError, + match="Converting a sub-tensordict values to memmap cannot be done", ): td.memmap_(tmpdir / "tensordict") return @@ -1747,7 +1753,7 @@ def test_memmap_existing(self, td_name, device, copy_existing, tmpdir): assert (td == td3).all() else: with pytest.raises( - RuntimeError, match="TensorDict already contains MemmapTensors" + RuntimeError, match="TensorDict already contains MemmapTensors" ): # calling memmap_ with prefix that is different to contents gives error td.memmap_(prefix=tmpdir / "tensordict2") @@ -1921,11 +1927,11 @@ def test_pop(self, td_name, device): assert (out == default).all() with pytest.raises( - KeyError, - match=re.escape( - "You are trying to pop key `z` which is not in dict" - "without providing default value" - ), + KeyError, + match=re.escape( + "You are trying to pop key `z` which is not in dict" + "without providing default value" + ), ): td.pop("z") @@ -2224,7 +2230,7 @@ def test_repr_indexed_stacked_tensordict(self, device, dtype, index): def test_repr_device_to_device(self, device, dtype, device_cast): td = self.td(device, dtype) if (device_cast is None and (torch.cuda.device_count() > 0)) or ( - device_cast is not None and device_cast.type == "cuda" + device_cast is not None and device_cast.type == "cuda" ): is_shared = True else: @@ -2356,11 +2362,11 @@ def test_batchsize_reset(): # incompatible size with pytest.raises( - RuntimeError, - match=re.escape( - "the tensor a has shape torch.Size([3, 4, 5, " - "6]) 
which is incompatible with the new shape torch.Size([3, 5])" - ), + RuntimeError, + match=re.escape( + "the tensor a has shape torch.Size([3, 4, 5, " + "6]) which is incompatible with the new shape torch.Size([3, 5])" + ), ): td.batch_size = [3, 5] @@ -2370,8 +2376,8 @@ def test_batchsize_reset(): # test index td[torch.tensor([1, 2])] with pytest.raises( - IndexError, - match=re.escape("too many indices for tensor of dimension 1"), + IndexError, + match=re.escape("too many indices for tensor of dimension 1"), ): td[:, 0] @@ -2383,22 +2389,22 @@ def test_batchsize_reset(): td.set("c", torch.randn(3, 4, 5, 6)) with pytest.raises( - RuntimeError, - match=re.escape( - "batch dimension mismatch, " - "got self.batch_size=torch.Size([3, 4, 5]) and tensor.shape[:self.batch_dims]=torch.Size([3, 4, 2])" - ), + RuntimeError, + match=re.escape( + "batch dimension mismatch, " + "got self.batch_size=torch.Size([3, 4, 5]) and tensor.shape[:self.batch_dims]=torch.Size([3, 4, 2])" + ), ): td.set("d", torch.randn(3, 4, 2)) # test that lazy tds return an exception td_stack = stack_td([TensorDict({"a": torch.randn(3)}, [3]) for _ in range(2)]) with pytest.raises( - RuntimeError, - match=re.escape( - "modifying the batch size of a lazy repesentation " - "of a tensordict is not permitted. Consider instantiating the tensordict first by calling `td = td.to_tensordict()` before resetting the batch size." - ), + RuntimeError, + match=re.escape( + "modifying the batch size of a lazy repesentation " + "of a tensordict is not permitted. Consider instantiating the tensordict first by calling `td = td.to_tensordict()` before resetting the batch size." + ), ): td_stack.batch_size = [2] td_stack.to_tensordict().batch_size = [2] @@ -2406,10 +2412,10 @@ def test_batchsize_reset(): td = TensorDict({"a": torch.randn(3, 4)}, [3, 4]) subtd = td.get_sub_tensordict((slice(None), torch.tensor([1, 2]))) with pytest.raises( - RuntimeError, - match=re.escape( - "modifying the batch size of a lazy repesentation of a tensordict is not permitted. Consider instantiating the tensordict first by calling `td = td.to_tensordict()` before resetting the batch size." - ), + RuntimeError, + match=re.escape( + "modifying the batch size of a lazy repesentation of a tensordict is not permitted. Consider instantiating the tensordict first by calling `td = td.to_tensordict()` before resetting the batch size." + ), ): subtd.batch_size = [3, 2] subtd.to_tensordict().batch_size = [3, 2] @@ -2417,10 +2423,10 @@ def test_batchsize_reset(): td = TensorDict({"a": torch.randn(3, 4)}, [3, 4]) td_u = td.unsqueeze(0) with pytest.raises( - RuntimeError, - match=re.escape( - "modifying the batch size of a lazy repesentation of a tensordict is not permitted. Consider instantiating the tensordict first by calling `td = td.to_tensordict()` before resetting the batch size." - ), + RuntimeError, + match=re.escape( + "modifying the batch size of a lazy repesentation of a tensordict is not permitted. Consider instantiating the tensordict first by calling `td = td.to_tensordict()` before resetting the batch size." 
+ ), ): td_u.batch_size = [1] td_u.to_tensordict().batch_size = [1] @@ -2534,7 +2540,7 @@ def _remote_process(worker_id, command_pipe_child, command_pipe_parent, tensordi a = torch.ones(2) * val tensordict.set_("a", a) assert ( - tensordict.get("a") == a + tensordict.get("a") == a ).all(), f'found {a} and {tensordict.get("a")}' command_pipe_child.send("done") elif cmd == "set_done": @@ -2689,10 +2695,10 @@ def test_mp(td_type): (torch.tensor([1, 2, 3]),), ([1, 2, 3]), ( - torch.tensor([1, 2, 3]), - torch.tensor([2, 3, 4]), - torch.tensor([0, 10, 2]), - torch.tensor([2, 4, 1]), + torch.tensor([1, 2, 3]), + torch.tensor([2, 3, 4]), + torch.tensor([0, 10, 2]), + torch.tensor([2, 4, 1]), ), torch.zeros(10, 7, 11, 5, dtype=torch.bool).bernoulli_(), torch.zeros(10, 7, 11, dtype=torch.bool).bernoulli_(), @@ -2978,7 +2984,7 @@ def test_keys_view(): assert ("a", "c", "b") not in tensordict.keys(include_nested=True) with pytest.raises( - TypeError, match="checks with tuples of strings is only supported" + TypeError, match="checks with tuples of strings is only supported" ): ("a", "b", "c") in tensordict.keys() # noqa: B015 @@ -3006,8 +3012,8 @@ def test_error_on_contains(): {"a": TensorDict({"b": torch.rand(1, 2)}, [1, 2]), "c": torch.rand(1)}, [1] ) with pytest.raises( - NotImplementedError, - match="TensorDict does not support membership checks with the `in` keyword", + NotImplementedError, + match="TensorDict does not support membership checks with the `in` keyword", ): "random_string" in td # noqa: B015 @@ -3145,57 +3151,57 @@ def test_flatten_unflatten_key_collision(inplace, separator): ) with pytest.raises( - KeyError, match="Flattening keys in tensordict collides with existing key *" + KeyError, match="Flattening keys in tensordict collides with existing key *" ): _ = td1.flatten_keys(separator) with pytest.raises( - KeyError, match="Flattening keys in tensordict collides with existing key *" + KeyError, match="Flattening keys in tensordict collides with existing key *" ): _ = td2.flatten_keys(separator) with pytest.raises( - KeyError, match="Flattening keys in tensordict collides with existing key *" + KeyError, match="Flattening keys in tensordict collides with existing key *" ): _ = td3.flatten_keys(separator) with pytest.raises( - KeyError, - match=re.escape( - "Unflattening key(s) in tensordict will override existing unflattened key" - ), + KeyError, + match=re.escape( + "Unflattening key(s) in tensordict will override existing unflattened key" + ), ): _ = td1.unflatten_keys(separator) with pytest.raises( - KeyError, - match=re.escape( - "Unflattening key(s) in tensordict will override existing unflattened key" - ), + KeyError, + match=re.escape( + "Unflattening key(s) in tensordict will override existing unflattened key" + ), ): _ = td2.unflatten_keys(separator) with pytest.raises( - KeyError, - match=re.escape( - "Unflattening key(s) in tensordict will override existing unflattened key" - ), + KeyError, + match=re.escape( + "Unflattening key(s) in tensordict will override existing unflattened key" + ), ): _ = td3.unflatten_keys(separator) with pytest.raises( - KeyError, - match=re.escape( - "Unflattening key(s) in tensordict will override existing unflattened key" - ), + KeyError, + match=re.escape( + "Unflattening key(s) in tensordict will override existing unflattened key" + ), ): _ = td4.unflatten_keys(separator) with pytest.raises( - KeyError, - match=re.escape( - "Unflattening key(s) in tensordict will override existing unflattened key" - ), + KeyError, + match=re.escape( + 
"Unflattening key(s) in tensordict will override existing unflattened key" + ), ): _ = td5.unflatten_keys(separator) @@ -3405,8 +3411,8 @@ def test_stacked_td(self, stack_dim, device): tensordicts0.zero_() assert (sub_td[item].get("key1") == sub_td.get("key1")[item]).all() assert ( - sub_td.contiguous()[item].get("key1") - == sub_td.contiguous().get("key1")[item] + sub_td.contiguous()[item].get("key1") + == sub_td.contiguous().get("key1")[item] ).all() assert (sub_td.contiguous().get("key1")[item] == 0).all() @@ -3415,7 +3421,7 @@ def test_stacked_td(self, stack_dim, device): tensordicts1.zero_() assert (std2[item].get("key1") == std2.get("key1")[item]).all() assert ( - std2.contiguous()[item].get("key1") == std2.contiguous().get("key1")[item] + std2.contiguous()[item].get("key1") == std2.contiguous().get("key1")[item] ).all() assert (std2.contiguous().get("key1")[item] == 0).all() @@ -3424,7 +3430,7 @@ def test_stacked_td(self, stack_dim, device): item = (*[slice(None) for _ in range(stack_dim)], 2) assert (std3[item].get("key1") == std3.get("key1")[item]).all() assert ( - std3.contiguous()[item].get("key1") == std3.contiguous().get("key1")[item] + std3.contiguous()[item].get("key1") == std3.contiguous().get("key1")[item] ).all() assert (std3.contiguous().get("key1")[item] == 0).all() @@ -3433,7 +3439,7 @@ def test_stacked_td(self, stack_dim, device): item = (*[slice(None) for _ in range(stack_dim)], 3) assert (std4[item].get("key1") == std4.get("key1")[item]).all() assert ( - std4.contiguous()[item].get("key1") == std4.contiguous().get("key1")[item] + std4.contiguous()[item].get("key1") == std4.contiguous().get("key1")[item] ).all() assert (std4.contiguous().get("key1")[item] == 0).all() @@ -3452,9 +3458,9 @@ def test_stacked_indexing(self, device, stack_dim): tds = torch.stack(list(tensordict.unbind(stack_dim)), stack_dim) for item, expected_shape in ( - ((2, 2), torch.Size([5])), - ((slice(1, 2), 2), torch.Size([1, 5])), - ((..., 2), torch.Size([3, 4])), + ((2, 2), torch.Size([5])), + ((slice(1, 2), 2), torch.Size([1, 5])), + ((..., 2), torch.Size([3, 4])), ): assert tds[item].batch_size == expected_shape assert (tds[item].get("a") == tds.get("a")[item]).all() @@ -3505,7 +3511,7 @@ def test_lazy_stacked_insert(self, dim, index, device): torch.testing.assert_close(lstd["a"], t) with pytest.raises( - TypeError, match="Expected new value to be TensorDictBase instance" + TypeError, match="Expected new value to be TensorDictBase instance" ): lstd.insert(index, torch.rand(10)) @@ -3526,8 +3532,8 @@ def test_lazy_stacked_contains(self): assert td.clone() not in lstd with pytest.raises( - NotImplementedError, - match="TensorDict does not support membership checks with the `in` keyword", + NotImplementedError, + match="TensorDict does not support membership checks with the `in` keyword", ): "random_string" in lstd # noqa: B015 @@ -3559,7 +3565,7 @@ def test_lazy_stacked_append(self, dim, device): torch.testing.assert_close(lstd["a"], t) with pytest.raises( - TypeError, match="Expected new value to be TensorDictBase instance" + TypeError, match="Expected new value to be TensorDictBase instance" ): lstd.append(torch.rand(10)) @@ -3590,14 +3596,14 @@ def test_stack_update_heter_stacked_td(self, stack_dim): td_b = td_a.clone() td_a.update(td_b) with pytest.raises( - RuntimeError, - match="Found more than one unique shape in the tensors to be stacked", + RuntimeError, + match="Found more than one unique shape in the tensors to be stacked", ): td_a.update(td_b.to_tensordict()) td_a.update_(td_b) with 
pytest.raises( - RuntimeError, - match="Found more than one unique shape in the tensors to be stacked", + RuntimeError, + match="Found more than one unique shape in the tensors to be stacked", ): td_a.update_(td_b.to_tensordict()) @@ -3638,7 +3644,7 @@ def test_inplace(self, save_name): assert isinstance(td_dest["b", "c"], MemmapTensor) def test_update( - self, + self, ): tensordict = TensorDict({"a": torch.randn(3), "b": {"c": torch.randn(3)}}, []) state = {"state": tensordict} @@ -3696,8 +3702,8 @@ def test_tensordict_prealloc_nested(): buffer[0] = td_0 buffer[1] = td_1 assert ( - repr(buffer) - == """TensorDict( + repr(buffer) + == """TensorDict( fields={ agent.obs: TensorDict( fields={ @@ -3717,14 +3723,13 @@ def test_tensordict_prealloc_nested(): def test_tensordict_view_iteration(): td_simple = TensorDict( - source={ - "a": torch.randn(4, 3, 2, 1, 5), - "b": torch.randn(4, 3, 2, 1, 5) - }, - batch_size=[4, 3, 2, 1] + source={"a": torch.randn(4, 3, 2, 1, 5), "b": torch.randn(4, 3, 2, 1, 5)}, + batch_size=[4, 3, 2, 1], ) - view = _TensorDictKeysView(tensordict=td_simple, include_nested=True, leaves_only=True, error_on_loop=True) + view = _TensorDictKeysView( + tensordict=td_simple, include_nested=True, leaves_only=True, error_on_loop=True + ) keys = list(view) assert len(keys) == 2 assert "a" in keys @@ -3733,25 +3738,29 @@ def test_tensordict_view_iteration(): td_nested = TensorDict( source={ "a": torch.randn(4, 3, 2, 1, 5), - "b": TensorDict( - {"c": torch.randn(4, 3, 2, 1, 2)}, [4, 3, 2, 1] - ), + "b": TensorDict({"c": torch.randn(4, 3, 2, 1, 2)}, [4, 3, 2, 1]), }, - batch_size=[4, 3, 2, 1] + batch_size=[4, 3, 2, 1], ) - view = _TensorDictKeysView(tensordict=td_nested, include_nested=True, leaves_only=True, error_on_loop=True) + view = _TensorDictKeysView( + tensordict=td_nested, include_nested=True, leaves_only=True, error_on_loop=True + ) keys = list(view) assert len(keys) == 2 assert "a" in keys assert ("b", "c") in keys - view = _TensorDictKeysView(tensordict=td_nested, include_nested=False, leaves_only=True, error_on_loop=True) + view = _TensorDictKeysView( + tensordict=td_nested, include_nested=False, leaves_only=True, error_on_loop=True + ) keys = list(view) assert len(keys) == 1 assert "a" in keys - view = _TensorDictKeysView(tensordict=td_nested, include_nested=True, leaves_only=False, error_on_loop=True) + view = _TensorDictKeysView( + tensordict=td_nested, include_nested=True, leaves_only=False, error_on_loop=True + ) keys = list(view) assert len(keys) == 3 assert "a" in keys @@ -3763,105 +3772,121 @@ def test_tensordict_view_iteration(): td_auto_nested_loop = TensorDict( source={ "a": torch.randn(4, 3, 2, 1, 5), - "b": TensorDict( - {"c": torch.randn(4, 3, 2, 1, 2)}, [4, 3, 2, 1] - ), + "b": TensorDict({"c": torch.randn(4, 3, 2, 1, 2)}, [4, 3, 2, 1]), }, - batch_size=[4, 3, 2, 1] + batch_size=[4, 3, 2, 1], ) td_auto_nested_loop["b"]["d"] = td_auto_nested_loop - view = _TensorDictKeysView(tensordict=td_auto_nested_loop, include_nested=False, leaves_only=False, - error_on_loop=True) + view = _TensorDictKeysView( + tensordict=td_auto_nested_loop, + include_nested=False, + leaves_only=False, + error_on_loop=True, + ) keys = list(view) assert len(keys) == 2 - assert 'a' in keys - assert 'b' in keys + assert "a" in keys + assert "b" in keys - view = _TensorDictKeysView(tensordict=td_auto_nested_loop, include_nested=False, leaves_only=True, - error_on_loop=True) + view = _TensorDictKeysView( + tensordict=td_auto_nested_loop, + include_nested=False, + leaves_only=True, + 
error_on_loop=True, + ) keys = list(view) assert len(keys) == 1 - assert 'a' in keys + assert "a" in keys with pytest.raises(RecursionError): - view = _TensorDictKeysView(tensordict=td_auto_nested_loop, include_nested=True, leaves_only=True, - error_on_loop=True) + view = _TensorDictKeysView( + tensordict=td_auto_nested_loop, + include_nested=True, + leaves_only=True, + error_on_loop=True, + ) list(view) with pytest.raises(RecursionError): - view = _TensorDictKeysView(tensordict=td_auto_nested_loop, include_nested=True, leaves_only=False, - error_on_loop=True) + view = _TensorDictKeysView( + tensordict=td_auto_nested_loop, + include_nested=True, + leaves_only=False, + error_on_loop=True, + ) list(view) - view = _TensorDictKeysView(tensordict=td_auto_nested_loop, include_nested=True, leaves_only=False, - error_on_loop=False) + view = _TensorDictKeysView( + tensordict=td_auto_nested_loop, + include_nested=True, + leaves_only=False, + error_on_loop=False, + ) keys = list(view) assert len(keys) == 3 - assert 'a' in keys - assert 'b' in keys - assert ('b', 'c') in keys + assert "a" in keys + assert "b" in keys + assert ("b", "c") in keys - view = _TensorDictKeysView(tensordict=td_auto_nested_loop, include_nested=True, leaves_only=True, - error_on_loop=False) + view = _TensorDictKeysView( + tensordict=td_auto_nested_loop, + include_nested=True, + leaves_only=True, + error_on_loop=False, + ) keys = list(view) assert len(keys) == 2 - assert 'a' in keys - assert ('b', 'c') in keys + assert "a" in keys + assert ("b", "c") in keys td_auto_nested_loop_2 = TensorDict( source={ "a": torch.randn(4, 3, 2, 1, 5), - "b": TensorDict( - {"c": torch.randn(4, 3, 2, 1, 2)}, [4, 3, 2, 1] - ), + "b": TensorDict({"c": torch.randn(4, 3, 2, 1, 2)}, [4, 3, 2, 1]), }, - batch_size=[4, 3, 2, 1] + batch_size=[4, 3, 2, 1], ) - td_auto_nested_loop_2["b"]["d"] = td_auto_nested_loop_2['b'] + td_auto_nested_loop_2["b"]["d"] = td_auto_nested_loop_2["b"] - view = _TensorDictKeysView(tensordict=td_auto_nested_loop_2, include_nested=True, leaves_only=False, - error_on_loop=False) + view = _TensorDictKeysView( + tensordict=td_auto_nested_loop_2, + include_nested=True, + leaves_only=False, + error_on_loop=False, + ) keys = list(view) assert len(keys) == 3 - assert 'a' in keys - assert 'b' in keys - assert ('b', 'c') in keys - + assert "a" in keys + assert "b" in keys + assert ("b", "c") in keys def test_detect_loop(): td_simple = TensorDict( - source={ - "a": torch.randn(4, 3, 2, 1, 5), - "b": torch.randn(4, 3, 2, 1, 5) - }, - batch_size=[4, 3, 2, 1] + source={"a": torch.randn(4, 3, 2, 1, 5), "b": torch.randn(4, 3, 2, 1, 5)}, + batch_size=[4, 3, 2, 1], ) assert not detect_loop(td_simple) td_nested = TensorDict( source={ "a": torch.randn(4, 3, 2, 1, 5), - "b": TensorDict( - {"c": torch.randn(4, 3, 2, 1, 2)}, [4, 3, 2, 1] - ), + "b": TensorDict({"c": torch.randn(4, 3, 2, 1, 2)}, [4, 3, 2, 1]), }, - batch_size=[4, 3, 2, 1] + batch_size=[4, 3, 2, 1], ) assert not detect_loop(td_nested) td_auto_nested_no_loop_1 = TensorDict( source={ "a": torch.randn(4, 3, 2, 1, 5), - "b": TensorDict( - {"c": torch.randn(4, 3, 2, 1, 2)}, [4, 3, 2, 1] - ), + "b": TensorDict({"c": torch.randn(4, 3, 2, 1, 2)}, [4, 3, 2, 1]), }, - batch_size=[4, 3, 2, 1] + batch_size=[4, 3, 2, 1], ) td_auto_nested_no_loop_1["b"]["d"] = td_auto_nested_no_loop_1["a"] @@ -3870,15 +3895,13 @@ def test_detect_loop(): td_auto_nested_no_loop_2 = TensorDict( source={ "a": TensorDict( - source={"c": torch.randn(4, 3, 2, 1, 2)}, - batch_size=[4, 3, 2, 1] + source={"c": torch.randn(4, 
3, 2, 1, 2)}, batch_size=[4, 3, 2, 1] ), "b": TensorDict( - source={"d": torch.randn(4, 3, 2, 1, 2)}, - batch_size=[4, 3, 2, 1] + source={"d": torch.randn(4, 3, 2, 1, 2)}, batch_size=[4, 3, 2, 1] ), }, - batch_size=[4, 3, 2, 1] + batch_size=[4, 3, 2, 1], ) td_auto_nested_no_loop_2["b"]["e"] = td_auto_nested_no_loop_2["a"] @@ -3888,11 +3911,10 @@ def test_detect_loop(): source={ "a": torch.randn(4, 3, 2, 1, 2), "b": TensorDict( - source={"c": torch.randn(4, 3, 2, 1, 2)}, - batch_size=[4, 3, 2, 1] + source={"c": torch.randn(4, 3, 2, 1, 2)}, batch_size=[4, 3, 2, 1] ), }, - batch_size=[4, 3, 2, 1] + batch_size=[4, 3, 2, 1], ) td_auto_nested_no_loop_3["b"]["d"] = td_auto_nested_no_loop_3["b"]["c"] @@ -3902,11 +3924,10 @@ def test_detect_loop(): source={ "a": torch.randn(4, 3, 2, 1, 2), "b": TensorDict( - source={"c": torch.randn(4, 3, 2, 1, 2)}, - batch_size=[4, 3, 2, 1] + source={"c": torch.randn(4, 3, 2, 1, 2)}, batch_size=[4, 3, 2, 1] ), }, - batch_size=[4, 3, 2, 1] + batch_size=[4, 3, 2, 1], ) td_auto_nested_loop_1["b"]["d"] = td_auto_nested_loop_1["b"] @@ -3916,11 +3937,10 @@ def test_detect_loop(): source={ "a": torch.randn(4, 3, 2, 1, 2), "b": TensorDict( - source={"c": torch.randn(4, 3, 2, 1, 2)}, - batch_size=[4, 3, 2, 1] + source={"c": torch.randn(4, 3, 2, 1, 2)}, batch_size=[4, 3, 2, 1] ), }, - batch_size=[4, 3, 2, 1] + batch_size=[4, 3, 2, 1], ) td_auto_nested_loop_2["b"]["d"] = td_auto_nested_loop_2
 
     assert detect_loop(td_auto_nested_loop_2)
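
For reference, the loop detection exercised by the tests above can be summarized
by the following minimal sketch. It is not part of any patch in this series: the
helper name `detect_loop_sketch` and the exact traversal are illustrative
assumptions, not the patched implementation. The idea is to track the `id()`s of
the TensorDicts on the current traversal path; encountering a value whose `id()`
is already on the path means a value references its own ancestor, i.e. an
auto-nesting loop, while shared (DAG-like) references that do not point back to
an ancestor are left alone.

import torch
from tensordict import TensorDict


def detect_loop_sketch(td, path_ids=None):
    # ids of the TensorDicts on the path from the root down to `td`;
    # revisiting one of them closes a cycle (auto-nesting loop)
    if path_ids is None:
        path_ids = {id(td)}
    for _, value in td.items():
        if isinstance(value, TensorDict):
            if id(value) in path_ids:
                return True
            path_ids.add(id(value))
            if detect_loop_sketch(value, path_ids):
                return True
            # pop on the way back up so sibling subtrees may share references
            path_ids.remove(id(value))
    return False


# a root reachable from its own child is reported as a loop
td = TensorDict({"a": TensorDict({"b": torch.randn(4, 3, 1)}, [4, 3, 1])}, [4, 3, 1])
td["a"]["c"] = td
assert detect_loop_sketch(td)

# a shared reference between siblings (a DAG, no ancestor cycle) is not a loop
td2 = TensorDict(
    {"a": TensorDict({"c": torch.randn(4, 1)}, [4]), "b": TensorDict({}, [4])}, [4]
)
td2["b"]["e"] = td2["a"]
assert not detect_loop_sketch(td2)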