diff --git a/tensordict/_lazy.py b/tensordict/_lazy.py index 0f141cbd9..2cc9a4d31 100644 --- a/tensordict/_lazy.py +++ b/tensordict/_lazy.py @@ -13,9 +13,10 @@ import weakref from collections import defaultdict from copy import copy, deepcopy +from functools import wraps from pathlib import Path from textwrap import indent -from typing import Any, Callable, Iterator, Sequence, Type +from typing import Any, Callable, Iterator, OrderedDict, Sequence, Type import numpy as np import torch @@ -24,7 +25,6 @@ from tensordict._td import _SubTensorDict, _TensorDictKeysView, TensorDict from tensordict._tensordict import _unravel_key_to_tuple, unravel_key_list from tensordict.base import ( - _ACCEPTED_CLASSES, _is_tensor_collection, _register_tensor_class, BEST_ATTEMPT_INPLACE, @@ -37,10 +37,10 @@ from tensordict.memmap import MemoryMappedTensor as MemmapTensor from tensordict.utils import ( _broadcast_tensors, + _check_keys, _getitem_batch_size, _is_number, _parse_to, - _prune_selected_keys, _renamed_inplace_method, _shape, _td_fields, @@ -54,12 +54,10 @@ infer_size_impl, is_tensorclass, KeyedJaggedTensor, - lazy_legacy, lock_blocked, NestedKey, ) from torch import Tensor -from torch.utils._pytree import tree_map _has_functorch = False try: @@ -108,6 +106,18 @@ def __contains__(self, item): ) +def _fails_exclusive_keys(func): + @wraps(func) + def newfunc(self, *args, **kwargs): + if self._has_exclusive_keys: + raise RuntimeError( + f"the method {func.__name__} cannot complete when there are exclusive keys." + ) + return getattr(TensorDictBase, func.__name__)(self, *args, **kwargs) + + return newfunc + + class LazyStackedTensorDict(TensorDictBase): """A Lazy stack of TensorDicts. @@ -216,14 +226,53 @@ def __init__( if batch_size is not None and batch_size != self.batch_size: raise RuntimeError("batch_size does not match self.batch_size.") + @property + @cache # noqa: B019 + def _has_exclusive_keys(self): + keys = None + for td in self.tensordicts: + _keys = set(td.keys(True, True)) + if keys is None: + keys = _keys + else: + if keys != _keys: + return True + else: + return False + + @_fails_exclusive_keys + def to_dict(self) -> dict[str, Any]: + ... + + @_fails_exclusive_keys + def state_dict( + self, + destination=None, + prefix="", + keep_vars=False, + flatten=False, + ) -> OrderedDict[str, Any]: + ... + + @_fails_exclusive_keys + def flatten_keys( + self, + separator: str = ".", + inplace: bool = False, + is_leaf: Callable[[Type], bool] | None = None, + ) -> T: + ... + + @_fails_exclusive_keys + def unflatten_keys(self, separator: str = ".", inplace: bool = False) -> T: + ... + @property def device(self) -> torch.device | None: # devices might have changed, so we check that they're all the same device_set = {td.device for td in self.tensordicts} if len(device_set) != 1: - raise RuntimeError( - f"found multiple devices in {self.__class__.__name__}:" f" {device_set}" - ) + return None device = self.tensordicts[0].device return device @@ -608,33 +657,6 @@ def _set_at_str(self, key, value, index, *, validated): self_idx = (slice(None),) * split_index["mask_loc"] + (i,) self[self_idx]._set_at_str(key, _value, _idx, validated=validated) - # # it may be the case that we can't get the value - # # because it can't be stacked. 
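The `_fails_exclusive_keys` decorator and the cached `_has_exclusive_keys` property above gate methods that have no well-defined result on heterogeneous stacks. A minimal sketch of the intended failure mode (not part of the patch; `_has_exclusive_keys` is a private helper introduced here):

import torch
from tensordict import TensorDict, LazyStackedTensorDict

td0 = TensorDict({"a": torch.zeros(())}, [])
td1 = TensorDict({"b": torch.zeros(())}, [])  # "b" is exclusive to td1
stack = LazyStackedTensorDict(td0, td1, stack_dim=0)
assert stack._has_exclusive_keys
try:
    stack.to_dict()
except RuntimeError as err:
    # "the method to_dict cannot complete when there are exclusive keys."
    print(err)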
- # # self[index]._set_str(key, value, validated=validated, inplace=True) - # # return self - # split_index = self._split_index(index) - # converted_idx = split_index["index_dict"] - # num_single = split_index["num_single"] - # isinteger = split_index["isinteger"] - # if isinteger: - # for (i, _idx) in converted_idx.items(): - # if _idx: - # self.tensordicts[i]._set_at_str( - # key, value, _idx, validated=validated - # ) - # else: - # self.tensordicts[i]._set_str( - # key, - # value, - # validated=validated, - # inplace=True, - # ) - # return self - # unbind_dim = self.stack_dim - num_single - # for (i, _idx), _value in zip(converted_idx.items(), value.unbind(unbind_dim)): - # self.tensordicts[i]._set_at_str(key, _value, _idx, validated=validated) - # return self - def _set_at_tuple(self, key, value, idx, *, validated): if len(key) == 1: return self._set_at_str(key[0], value, idx, validated=validated) @@ -735,9 +757,8 @@ def unbind(self, dim: int) -> tuple[TensorDictBase, ...]: ) for td in self.tensordicts: out.append(td.unbind(new_dim)) - from tensordict._torch_func import _stack - return tuple(_stack(vals, new_stack_dim) for vals in zip(*out)) + return tuple(self.lazy_stack(vals, new_stack_dim) for vals in zip(*out)) def _stack_onto_( self, @@ -773,10 +794,8 @@ def _get_str( # then we consider this default as non-stackable and return prematurly return default try: - out = torch.stack(tensors, self.stack_dim) + out = self.lazy_stack(tensors, self.stack_dim) if _is_tensor_collection(out.__class__): - if self._td_dim_name is not None: - out._td_dim_name = self._td_dim_name if isinstance(out, LazyStackedTensorDict): # then it's a LazyStackedTD out.hook_out = self.hook_out @@ -787,14 +806,9 @@ def _get_str( self._batch_size + out.batch_size[(len(self._batch_size) + incr) :] ) - elif not lazy_legacy(): - # it must be a TensorDict - incr = 0 if not self._is_vmapped else 1 - out._batch_size = ( - self._batch_size - + out.batch_size[(len(self._batch_size) + incr) :] - ) - else: + if self._td_dim_name is not None: + out._td_dim_name = self._td_dim_name + elif is_tensorclass(out): # then it's a tensorclass out._tensordict.hook_out = self.hook_out out._tensordict.hook_in = self.hook_in @@ -804,6 +818,10 @@ def _get_str( self._batch_size + out._tensordict.batch_size[(len(self._batch_size) + incr) :] ) + if self._td_dim_name is not None: + out._tensordict._td_dim_name = self._td_dim_name + else: + raise RuntimeError elif self.hook_out is not None: out = self.hook_out(out) return out @@ -841,10 +859,205 @@ def _get_tuple(self, key, default): f" for key '{key[1:]}' in tensordict:\n{self}." 
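`unbind` now restacks through `self.lazy_stack` instead of the private `_stack` helper, so it keeps working on stacks with exclusive keys. A short sketch of the expected behaviour, assuming the public constructor:

import torch
from tensordict import TensorDict, LazyStackedTensorDict

td0 = TensorDict({"a": torch.zeros(4, 3)}, [4, 3])
td1 = TensorDict({"b": torch.ones(4, 3)}, [4, 3])  # exclusive key
stack = LazyStackedTensorDict(td0, td1, stack_dim=0)  # batch_size [2, 4, 3]
parts = stack.unbind(1)  # unbind along a dim other than the stack dim
assert len(parts) == 4
assert all(isinstance(p, LazyStackedTensorDict) for p in parts)
assert parts[0].batch_size == torch.Size([2, 3])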
             )

+    @classmethod
+    def lazy_stack(
+        cls,
+        items: Sequence[TensorDictBase],
+        dim: int = 0,
+        device: DeviceType | None = None,
+        out: T | None = None,
+    ) -> T:
+        """Stacks tensordicts in a LazyStackedTensorDict."""
+        if not items:
+            raise RuntimeError("items cannot be empty")
+
+        from .tensorclass import NonTensorData
+
+        if all(isinstance(item, torch.Tensor) for item in items):
+            return torch.stack(items, dim=dim, out=out)
+        if all(
+            is_tensorclass(item) and type(item) == type(items[0])  # noqa: E721
+            for item in items
+        ):
+            if all(isinstance(tensordict, NonTensorData) for tensordict in items):
+                return NonTensorData._stack_non_tensor(items, dim=dim)
+            lazy_stack = cls.lazy_stack(
+                [item._tensordict for item in items], dim=dim, out=out
+            )
+            # we take the first non_tensordict by convention
+            return type(items[0])._from_tensordict(
+                tensordict=lazy_stack, non_tensordict=items[0]._non_tensordict
+            )
+
+        batch_size = items[0].batch_size
+        if dim < 0:
+            dim = len(batch_size) + dim + 1
+
+        for td in items[1:]:
+            if td.batch_size != items[0].batch_size:
+                raise RuntimeError(
+                    "stacking tensordicts requires them to have congruent batch sizes, "
+                    f"got td1.batch_size={td.batch_size} and td2.batch_size="
+                    f"{items[0].batch_size}"
+                )
+
+        if out is None:
+            # We need to handle tensordicts with exclusive keys and tensordicts with
+            # mismatching shapes.
+            # The first case is handled within _check_keys which fails if keys
+            # don't match exactly.
+            # The second requires a check over the tensor shapes.
+            return LazyStackedTensorDict(*items, stack_dim=dim)
+        else:
+            batch_size = list(batch_size)
+            batch_size.insert(dim, len(items))
+            batch_size = torch.Size(batch_size)
+
+            if out.batch_size != batch_size:
+                raise RuntimeError(
+                    "out.batch_size and stacked batch size must match, "
+                    f"got out.batch_size={out.batch_size} and batch_size"
+                    f"={batch_size}"
+                )
+
+            try:
+                out._stack_onto_(items, dim)
+            except KeyError as err:
+                raise err
+            return out
+
+    @classmethod
+    def maybe_dense_stack(
+        cls,
+        items: Sequence[TensorDictBase],
+        dim: int = 0,
+        out: T | None = None,
+        strict: bool = False,
+    ) -> T:
+        """Stacks tensors or tensordicts densely if possible, or into a LazyStackedTensorDict otherwise.
+
+        Examples:
+            >>> td0 = TensorDict({"a": 0}, [])
+            >>> td1 = TensorDict({"b": 0}, [])
+            >>> LazyStackedTensorDict.maybe_dense_stack([td0, td0])  # returns a TensorDict with shape [2]
+            >>> LazyStackedTensorDict.maybe_dense_stack([td0, td1])  # returns a LazyStackedTensorDict with shape [2]
+            >>> LazyStackedTensorDict.maybe_dense_stack(list(torch.randn(2)))  # returns a torch.Tensor with shape [2]
+        """
+        if not items:
+            raise RuntimeError("items cannot be empty")
+
+        from .tensorclass import NonTensorData
+
+        if all(isinstance(item, torch.Tensor) for item in items):
+            return torch.stack(items, dim=dim, out=out)
+
+        if all(isinstance(tensordict, NonTensorData) for tensordict in items):
+            return NonTensorData._stack_non_tensor(items, dim=dim)
+
+        batch_size = items[0].batch_size
+        if dim < 0:
+            dim = len(batch_size) + dim + 1
+
+        for td in items[1:]:
+            if td.batch_size != items[0].batch_size:
+                raise RuntimeError(
+                    "stacking tensordicts requires them to have congruent batch sizes, "
+                    f"got td1.batch_size={td.batch_size} and td2.batch_size="
+                    f"{items[0].batch_size}"
+                )
+
+        if out is None:
+            # We need to handle tensordicts with exclusive keys and tensordicts with
+            # mismatching shapes.
+            # The first case is handled within _check_keys which fails if keys
+            # don't match exactly.
+            # The second requires a check over the tensor shapes.
+            device = items[0].device
+            if any(device != item.device for item in items[1:]):
+                device = None
+            if any(
+                isinstance(item, LazyStackedTensorDict) and item._has_exclusive_keys
+                for item in items
+            ):
+                return LazyStackedTensorDict(*items, stack_dim=dim)
+            try:
+                keys = _check_keys(items, strict=True)
+            except KeyError:
+                return LazyStackedTensorDict(*items, stack_dim=dim)
+
+            out = {}
+            for key in keys:
+                out[key] = []
+                tensor_shape = None
+                for _tensordict in items:
+                    # TODO: this can break if the tensor cannot be stacked and _tensordict is a lazy stack itself
+                    tensor = _tensordict._get_str(key, default=NO_DEFAULT)
+                    if tensor_shape is None:
+                        tensor_shape = tensor.shape
+                    elif tensor.shape != tensor_shape:
+                        return LazyStackedTensorDict(*items, stack_dim=dim)
+                    out[key].append(tensor)
+
+            def stack_fn(key_values):
+                key, values = key_values
+                return cls.maybe_dense_stack(values, dim)
+
+            out = {key: stack_fn((key, value)) for key, value in out.items()}
+
+            is_locked = any(item.is_locked for item in items)
+            result = TensorDict(
+                out,
+                batch_size=LazyStackedTensorDict._compute_batch_size(
+                    batch_size, dim, len(items)
+                ),
+                device=device,
+                _run_checks=False,
+            )
+            if is_locked:
+                return result.lock_()
+            return result
+        else:
+            keys = _check_keys(items)
+            batch_size = list(batch_size)
+            batch_size.insert(dim, len(items))
+            batch_size = torch.Size(batch_size)
+
+            if out.batch_size != batch_size:
+                raise RuntimeError(
+                    "out.batch_size and stacked batch size must match, "
+                    f"got out.batch_size={out.batch_size} and batch_size"
+                    f"={batch_size}"
+                )
+
+            out_keys = set(out.keys())
+            if strict:
+                in_keys = set(keys)
+                if len(out_keys - in_keys) > 0:
+                    raise RuntimeError(
+                        "The output tensordict has keys that are missing in the "
+                        f"tensordict that has to be written: {out_keys - in_keys}. "
+                        "As per the call to `stack(..., strict=True)`, this "
+                        "is not permitted."
+                    )
+                elif len(in_keys - out_keys) > 0:
+                    raise RuntimeError(
+                        "The resulting tensordict has keys that are missing in "
+                        f"its destination: {in_keys - out_keys}. As per the call "
+                        "to `stack(..., strict=True)`, this is not permitted."
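Taken together, `lazy_stack` always preserves the lazy (view) semantics while `maybe_dense_stack` densifies only when keys and shapes are congruent; a sketch of the three outcomes, mirroring the docstring above:

import torch
from tensordict import TensorDict, LazyStackedTensorDict

td_a = TensorDict({"x": torch.zeros(())}, [])
td_b = TensorDict({"x": torch.ones(())}, [])
# lazy_stack never densifies and returns a view over its inputs
lazy = LazyStackedTensorDict.lazy_stack([td_a, td_b], dim=0)
assert isinstance(lazy, LazyStackedTensorDict)
td_a["x"] += 1  # writes through to the stack
assert (lazy[0]["x"] == 1).all()
# maybe_dense_stack densifies when keys and shapes are congruent...
dense = LazyStackedTensorDict.maybe_dense_stack([td_a, td_b])
assert type(dense) is TensorDict
# ...and falls back to a lazy stack when keys are exclusive
td_c = TensorDict({"y": torch.zeros(())}, [])
still_lazy = LazyStackedTensorDict.maybe_dense_stack([td_a, td_c])
assert isinstance(still_lazy, LazyStackedTensorDict)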
+ ) + + try: + out._stack_onto_(items, dim) + except KeyError as err: + raise err + return out + @cache # noqa: B019 def _add_batch_dim(self, *, in_dim, vmap_level): if self.is_memmap(): - td = torch.stack([td.cpu().as_tensor() for td in self.tensordicts], 0) + td = LazyStackedTensorDict.lazy_stack( + [td.cpu().as_tensor() for td in self.tensordicts], 0 + ) else: td = self if in_dim < 0: @@ -875,6 +1088,8 @@ def _cached_add_batch_dims(cls, td, in_dim, vmap_level): out = td.copy() def hook_out(tensor, in_dim=in_dim, vmap_level=vmap_level): + if _is_tensor_collection(type(tensor)): + return tensor._add_batch_dim(in_dim=in_dim, vmap_level=vmap_level) return _add_batch_dim(tensor, in_dim, vmap_level) n = len(td.tensordicts) @@ -1004,6 +1219,12 @@ def contiguous(self) -> T: ) return out + def empty(self, recurse=False) -> T: + return LazyStackedTensorDict( + *[td.empty(recurse=recurse) for td in self.tensordicts], + stack_dim=self.stack_dim, + ) + def clone(self, recurse: bool = True) -> T: if recurse: # This could be optimized using copy but we must be careful with @@ -1180,7 +1401,7 @@ def exclude(self, *keys: str, inplace: bool = False) -> LazyStackedTensorDict: if inplace: self.tensordicts = tensordicts return self - return torch.stack(tensordicts, dim=self.stack_dim) + return LazyStackedTensorDict(*tensordicts, stack_dim=self.stack_dim) def __setitem__(self, index: IndexType, value: T) -> T: if isinstance(index, (tuple, str)): @@ -1306,7 +1527,7 @@ def __getitem__(self, index: IndexType) -> T: else: out.append(self.tensordicts[i][_idx]) out[-1] = out[-1].squeeze(cat_dim) - return torch.stack(out, cat_dim) + return LazyStackedTensorDict.lazy_stack(out, cat_dim) else: for i, _idx in converted_idx.items(): self_idx = (slice(None),) * split_index["mask_loc"] + (i,) @@ -1314,7 +1535,7 @@ def __getitem__(self, index: IndexType) -> T: return torch.cat(out, cat_dim) elif is_nd_tensor: new_stack_dim = self.stack_dim - num_single + num_none - return torch.stack( + return LazyStackedTensorDict.lazy_stack( [self[idx] for idx in converted_idx.values()], new_stack_dim ) else: @@ -1334,7 +1555,7 @@ def __getitem__(self, index: IndexType) -> T: new_stack_dim = self.stack_dim - num_single + num_none - num_squash for i, _idx in converted_idx.items(): out.append(self.tensordicts[i][_idx]) - out = torch.stack(out, new_stack_dim) + out = LazyStackedTensorDict.lazy_stack(out, new_stack_dim) out._td_dim_name = self._td_dim_name return out @@ -1342,15 +1563,29 @@ def __eq__(self, other): if is_tensorclass(other): return other == self if isinstance(other, (dict,)): - other = TensorDict.from_dict(other) + # we may want to broadcast it instead + other = TensorDict.from_dict(other, batch_size=self.batch_size) if _is_tensor_collection(other.__class__): + if other.batch_size != self.batch_size: + if self.ndim < other.ndim: + self_expand = self.expand(other.batch_size) + elif self.ndim > other.ndim: + other = other.expand(self.batch_size) + self_expand = self + else: + raise RuntimeError( + f"Could not compare tensordicts with shapes {self.shape} and {other.shape}" + ) + else: + self_expand = self out = [] - for i, td in enumerate(self.tensordicts): - idx = (slice(None),) * self.stack_dim + (i,) - out.append(other[idx] == td) - return torch.stack(out, self.stack_dim) + for td0, td1 in zip( + self_expand.tensordicts, other.unbind(self_expand.stack_dim) + ): + out.append(td0 == td1) + return LazyStackedTensorDict.lazy_stack(out, self.stack_dim) if isinstance(other, (numbers.Number, Tensor)): - return torch.stack( + return 
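`empty` and `exclude` now rebuild a `LazyStackedTensorDict` directly rather than round-tripping through `torch.stack`, which preserves the lazy structure, exclusive keys included. Sketch:

import torch
from tensordict import TensorDict, LazyStackedTensorDict

td0 = TensorDict({"a": torch.zeros(3), "c": torch.zeros(3)}, [3])
td1 = TensorDict({"b": torch.ones(3), "c": torch.ones(3)}, [3])
stack = LazyStackedTensorDict(td0, td1, stack_dim=0)
pruned = stack.exclude("c")
assert isinstance(pruned, LazyStackedTensorDict)
assert set(pruned[0].keys()) == {"a"} and set(pruned[1].keys()) == {"b"}
empty = stack.empty()
assert isinstance(empty, LazyStackedTensorDict) and not list(empty.keys())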
LazyStackedTensorDict.lazy_stack( [td == other for td in self.tensordicts], self.stack_dim, ) @@ -1360,20 +1595,98 @@ def __ne__(self, other): if is_tensorclass(other): return other != self if isinstance(other, (dict,)): - other = TensorDict.from_dict(other) + # we may want to broadcast it instead + other = TensorDict.from_dict(other, batch_size=self.batch_size) if _is_tensor_collection(other.__class__): + if other.batch_size != self.batch_size: + if self.ndim < other.ndim: + self_expand = self.expand(other.batch_size) + elif self.ndim > other.ndim: + other = other.expand(self.batch_size) + self_expand = self + else: + raise RuntimeError( + f"Could not compare tensordicts with shapes {self.shape} and {other.shape}" + ) + else: + self_expand = self out = [] - for i, td in enumerate(self.tensordicts): - idx = (slice(None),) * self.stack_dim + (i,) - out.append(other[idx] != td) - return torch.stack(out, self.stack_dim) + for td0, td1 in zip( + self_expand.tensordicts, other.unbind(self_expand.stack_dim) + ): + out.append(td0 != td1) + return LazyStackedTensorDict.lazy_stack(out, self.stack_dim) if isinstance(other, (numbers.Number, Tensor)): - return torch.stack( + return LazyStackedTensorDict.lazy_stack( [td != other for td in self.tensordicts], self.stack_dim, ) return True + def __xor__(self, other): + if is_tensorclass(other): + return other == self + if isinstance(other, (dict,)): + # we may want to broadcast it instead + other = TensorDict.from_dict(other, batch_size=self.batch_size) + if _is_tensor_collection(other.__class__): + if other.batch_size != self.batch_size: + if self.ndim < other.ndim: + self_expand = self.expand(other.batch_size) + elif self.ndim > other.ndim: + other = other.expand(self.batch_size) + self_expand = self + else: + raise RuntimeError( + f"Could not compare tensordicts with shapes {self.shape} and {other.shape}" + ) + else: + self_expand = self + out = [] + for td0, td1 in zip( + self_expand.tensordicts, other.unbind(self_expand.stack_dim) + ): + out.append(td0 ^ td1) + return LazyStackedTensorDict.lazy_stack(out, self.stack_dim) + if isinstance(other, (numbers.Number, Tensor)): + return LazyStackedTensorDict.lazy_stack( + [td ^ other for td in self.tensordicts], + self.stack_dim, + ) + return False + + def __or__(self, other): + if is_tensorclass(other): + return other == self + if isinstance(other, (dict,)): + # we may want to broadcast it instead + other = TensorDict.from_dict(other, batch_size=self.batch_size) + if _is_tensor_collection(other.__class__): + if other.batch_size != self.batch_size: + if self.ndim < other.ndim: + self_expand = self.expand(other.batch_size) + elif self.ndim > other.ndim: + other = other.expand(self.batch_size) + self_expand = self + else: + raise RuntimeError( + f"Could not compare tensordicts with shapes {self.shape} and {other.shape}" + ) + else: + self_expand = self + out = [] + for td0, td1 in zip( + self_expand.tensordicts, other.unbind(self_expand.stack_dim) + ): + out.append(td0 | td1) + return LazyStackedTensorDict.lazy_stack(out, self.stack_dim) + if isinstance(other, (numbers.Number, Tensor)): + return LazyStackedTensorDict.lazy_stack( + [td | other for td in self.tensordicts], + self.stack_dim, + ) + return False + def all(self, dim: int = None) -> bool | TensorDictBase: if dim is not None and (dim >= self.batch_dims or dim < -self.batch_dims): raise RuntimeError( @@ -1560,6 +1873,7 @@ def _memmap_( like=False, ) -> T: if prefix is not None: + prefix = Path(prefix) def save_metadata(prefix=prefix, self=self): prefix 
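`__eq__`, `__ne__` and the new `__xor__`/`__or__` share one broadcasting recipe: dicts are promoted with the stack's batch size, and whichever operand has fewer dims is expanded before a sub-tensordict-wise comparison. Sketch:

import torch
from tensordict import TensorDict, LazyStackedTensorDict

td0 = TensorDict({"a": torch.zeros(3)}, [3])
td1 = TensorDict({"a": torch.ones(3)}, [3])
stack = LazyStackedTensorDict(td0, td1, stack_dim=0)  # batch_size [2, 3]
other = TensorDict({"a": torch.zeros(3)}, [3])        # batch_size [3]
eq = stack == other  # other is expanded to [2, 3] before comparing
assert isinstance(eq, LazyStackedTensorDict)
assert eq[0]["a"].all() and not eq[1]["a"].any()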
= Path(prefix) @@ -1588,7 +1902,7 @@ def save_metadata(prefix=prefix, self=self): ) ) if not inplace: - results = torch.stack(results, dim=self.stack_dim) + results = LazyStackedTensorDict.lazy_stack(results, dim=self.stack_dim) else: results = self results._is_memmap = True @@ -1618,7 +1932,7 @@ def expand(self, *args: int, inplace: bool = False) -> T: self.tensordicts = tensordicts self.stack_dim = stack_dim return self - return torch.stack(tensordicts, stack_dim) + return LazyStackedTensorDict.maybe_dense_stack(tensordicts, dim=stack_dim) def update( self, @@ -1628,9 +1942,18 @@ def update( keys_to_update: Sequence[NestedKey] | None = None, **kwargs: Any, ) -> T: + # This implementation of update is compatible with exclusive keys + # as well as vmapped lazy stacks. + # We iterate over the tensordicts rather than iterating over the keys, + # which requires stacking and unbinding but is also not robust to missing keys. if input_dict_or_td is self: # no op return self + if isinstance(input_dict_or_td, dict): + input_dict_or_td = TensorDict.from_dict( + input_dict_or_td, batch_size=self.batch_size + ) + if keys_to_update is not None: keys_to_update = unravel_key_list(keys_to_update) if len(keys_to_update) == 0: @@ -1652,57 +1975,38 @@ def update( ) return self - inplace = kwargs.get("inplace", False) - for key, value in input_dict_or_td.items(): - if clone and hasattr(value, "clone"): - value = value.clone() - elif clone: - value = tree_map(torch.clone, value) - key = _unravel_key_to_tuple(key) - firstkey, subkey = key[0], key[1:] - if keys_to_update and not any( - firstkey == ktu if isinstance(ktu, str) else firstkey == ktu[0] - for ktu in keys_to_update - ): - continue - - if subkey: - # we must check that the target is not a leaf - target = self._get_str(firstkey, default=None) - if is_tensor_collection(target): - sub_keys_to_update = _prune_selected_keys(keys_to_update, firstkey) - target.update( - {subkey: value}, - inplace=inplace, - clone=clone, - keys_to_update=sub_keys_to_update, - ) - elif target is None: - self._set_tuple(key, value, inplace=inplace, validated=False) - else: - raise TypeError( - f"Type mismatch: self.get(key[0]) is {type(target)} but expected a tensor collection." - ) - else: - target = self._get_str(firstkey, default=None) - if is_tensor_collection(target) and ( - is_tensor_collection(value) or isinstance(value, dict) - ): - sub_keys_to_update = _prune_selected_keys(keys_to_update, firstkey) - target.update( - value, - inplace=inplace, - clone=clone, - keys_to_update=sub_keys_to_update, - ) - elif target is None or not is_tensor_collection(value): - self._set_str(firstkey, value, inplace=inplace, validated=False) - else: - raise TypeError( - f"Type mismatch: self.get(key) is {type(target)} but value is of type {type(value)}." - ) - - return self + if self.hook_in is not None: + self_upd = self.hook_in(self) + input_dict_or_td = self.hook_in(input_dict_or_td) + else: + self_upd = self + # Then we can decompose the tensordict along its stack dim + if input_dict_or_td.ndim <= self_upd.stack_dim or input_dict_or_td.batch_size[ + self_upd.stack_dim + ] != len(self_upd.tensordicts): + try: + # if the batch-size does not permit unbinding, let's first try to reset the batch-size. 
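`expand` now returns through `maybe_dense_stack`, so a heterogeneous stack stays lazy after expansion. A sketch, assuming the stack dim is shifted right by the number of prepended dims, as in previous releases:

import torch
from tensordict import TensorDict, LazyStackedTensorDict

td0 = TensorDict({"a": torch.zeros(3)}, [3])
td1 = TensorDict({"b": torch.zeros(3)}, [3])  # exclusive keys
stack = LazyStackedTensorDict(td0, td1, stack_dim=0)  # batch_size [2, 3]
expanded = stack.expand(4, 2, 3)
assert expanded.batch_size == torch.Size([4, 2, 3])
assert isinstance(expanded, LazyStackedTensorDict)  # exclusive keys stay lazy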
+ input_dict_or_td = input_dict_or_td.copy() + batch_size = self_upd.batch_size + if self_upd.hook_out is not None: + batch_size = list(batch_size) + batch_size.insert(self_upd.stack_dim, len(self_upd.tensordicts)) + input_dict_or_td.batch_size = batch_size + except RuntimeError as err: + raise ValueError( + "cannot update stacked tensordicts with different shapes." + ) from err + for td_dest, td_source in zip( + self_upd.tensordicts, input_dict_or_td.unbind(self_upd.stack_dim) + ): + td_dest.update( + td_source, clone=clone, keys_to_update=keys_to_update, **kwargs + ) + if self.hook_out is not None: + self_upd = self.hook_out(self_upd) + else: + self_upd = self + return self_upd def update_( self, @@ -1713,30 +2017,12 @@ def update_( if input_dict_or_td is self: # no op return self - if isinstance(input_dict_or_td, LazyStackedTensorDict): - if input_dict_or_td.stack_dim == self.stack_dim: - if not input_dict_or_td.shape[self.stack_dim] == len(self.tensordicts): - raise ValueError( - "cannot update stacked tensordicts with different shapes." - ) - for td_dest, td_source in zip( - self.tensordicts, input_dict_or_td.tensordicts - ): - td_dest.update_(td_source) - return self - else: - for i, td in enumerate(input_dict_or_td.tensordicts): - idx = (slice(None),) * input_dict_or_td.stack_dim + (i,) - self.update_at_(td, idx) - for key, value in input_dict_or_td.items(): - if not isinstance(value, tuple(_ACCEPTED_CLASSES)): - raise TypeError( - f"Expected value to be one of types {_ACCEPTED_CLASSES} " - f"but got {type(value)}" - ) - if clone: - value = value.clone() - self.set_(key, value, **kwargs) + if input_dict_or_td.batch_size[self.stack_dim] != len(self.tensordicts): + raise ValueError("cannot update stacked tensordicts with different shapes.") + for td_dest, td_source in zip( + self.tensordicts, input_dict_or_td.unbind(self.stack_dim) + ): + td_dest.update_(td_source, clone=clone, **kwargs) return self def update_at_( @@ -1745,46 +2031,34 @@ def update_at_( index: IndexType, clone: bool = False, ) -> T: - if isinstance(input_dict_or_td, TensorDictBase): - split_index = self._split_index(index) - converted_idx = split_index["index_dict"] - num_single = split_index["num_single"] - isinteger = split_index["isinteger"] - if isinteger: - # this will break if the index along the stack dim is [0] or :1 or smth - for i, _idx in converted_idx.items(): - self.tensordicts[i].update_at_( - input_dict_or_td, - _idx, - ) - return self - unbind_dim = self.stack_dim - num_single - for (i, _idx), _value in zip( - converted_idx.items(), - input_dict_or_td.unbind(unbind_dim), - ): + if not isinstance(input_dict_or_td, TensorDictBase): + input_dict_or_td = TensorDict.from_dict( + input_dict_or_td, batch_size=self.batch_size + ) + split_index = self._split_index(index) + converted_idx = split_index["index_dict"] + num_single = split_index["num_single"] + isinteger = split_index["isinteger"] + if isinteger: + # this will break if the index along the stack dim is [0] or :1 or smth + for i, _idx in converted_idx.items(): self.tensordicts[i].update_at_( - _value, + input_dict_or_td, _idx, ) return self - for key, value in input_dict_or_td.items(): - if not isinstance(value, _ACCEPTED_CLASSES): - raise TypeError( - f"Expected value to be one of types {_ACCEPTED_CLASSES} " - f"but got {type(value)}" - ) - if clone: - value = value.clone() - self.set_at_(key, value, index) + unbind_dim = self.stack_dim - num_single + for (i, _idx), _value in zip( + converted_idx.items(), + input_dict_or_td.unbind(unbind_dim), + ): 
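The reworked `update` unbinds the source along the stack dim and updates each sub-tensordict, as the new comment explains; this is what makes it robust to exclusive keys. Sketch:

import torch
from tensordict import TensorDict, LazyStackedTensorDict

td0 = TensorDict({"a": torch.zeros(3)}, [3])
td1 = TensorDict({"b": torch.zeros(3)}, [3])  # exclusive keys
stack = LazyStackedTensorDict(td0, td1, stack_dim=0)
stack.update(TensorDict({"c": torch.ones(2, 3)}, [2, 3]))
assert (stack[0]["c"] == 1).all() and (stack[1]["c"] == 1).all()
assert "a" in stack[0].keys() and "b" in stack[1].keys()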
+ self.tensordicts[i].update_at_( + _value, + _idx, + ) return self def rename_key_(self, old_key: str, new_key: str, safe: bool = False) -> T: - def sort_keys(element): - if isinstance(element, tuple): - return "_-|-_".join(element) - return element - for td in self.tensordicts: td.rename_key_(old_key, new_key, safe=safe) return self @@ -1800,7 +2074,7 @@ def where(self, condition, other, *, out=None, pad=None): and other.shape[: self.stack_dim] == self.shape[: self.stack_dim] ): other = other.unbind(self.stack_dim) - result = torch.stack( + result = LazyStackedTensorDict.maybe_dense_stack( [ td.where(cond, _other, pad=pad) for td, cond, _other in zip(self.tensordicts, condition, other) @@ -1808,7 +2082,7 @@ def where(self, condition, other, *, out=None, pad=None): self.stack_dim, ) else: - result = torch.stack( + result = LazyStackedTensorDict.maybe_dense_stack( [ td.where(cond, other, pad=pad) for td, cond in zip(self.tensordicts, condition) @@ -1996,8 +2270,6 @@ def _repr_exclusive_fields(self): unlock_ = TensorDictBase.unlock_ unlock = _renamed_inplace_method(unlock_) - __xor__ = TensorDict.__xor__ - __or__ = TensorDict.__or__ _check_device = TensorDict._check_device _check_is_shared = TensorDict._check_is_shared _convert_to_tensordict = TensorDict._convert_to_tensordict diff --git a/tensordict/_td.py b/tensordict/_td.py index 0358f4e21..9c876e1fc 100644 --- a/tensordict/_td.py +++ b/tensordict/_td.py @@ -34,11 +34,6 @@ TensorDictBase, ) -# from tensordict._memmap import ( -# empty_like as empty_like_memmap, -# from_filename, -# from_tensor as from_tensor_memmap, -# ) from tensordict.memmap import MemoryMappedTensor from tensordict.memmap_deprec import MemmapTensor as _MemmapTensor from tensordict.utils import ( @@ -551,17 +546,23 @@ def __setitem__( if isinstance(value, (TensorDictBase, dict)): indexed_bs = _getitem_batch_size(self.batch_size, index) if isinstance(value, dict): - value = self.empty(recurse=True)[index].update(value) + value = TensorDict.from_dict(value, batch_size=indexed_bs) + # value = self.empty(recurse=True)[index].update(value) if value.batch_size != indexed_bs: - # try to expand on the left (broadcasting) - try: + if value.shape == indexed_bs[-len(value.shape) :]: + # try to expand on the left (broadcasting) value = value.expand(indexed_bs) - except RuntimeError as err: - raise RuntimeError( - f"indexed destination TensorDict batch size is {indexed_bs} " - f"(batch_size = {self.batch_size}, index={index}), " - f"which differs from the source batch size {value.batch_size}" - ) from err + else: + try: + # copy and change batch_size if can't be expanded + value = value.copy() + value.batch_size = indexed_bs + except RuntimeError as err: + raise RuntimeError( + f"indexed destination TensorDict batch size is {indexed_bs} " + f"(batch_size = {self.batch_size}, index={index}), " + f"which differs from the source batch size {value.batch_size}" + ) from err keys = set(self.keys()) if any(key not in keys for key in value.keys()): @@ -1186,16 +1187,9 @@ def is_boolean(idx): if num_boolean_dim: names = [None] + names[num_boolean_dim:] else: - # def is_int(subidx): - # if isinstance(subidx, Number): - # return True - # if isinstance(subidx, Tensor) and len(subidx.shape) == 0: - # return True - # return False - if not isinstance(idx, tuple): idx = (idx,) - if len(idx) < self.ndim: + if len([_idx for _idx in idx if _idx is not None]) < self.ndim: idx = (*idx, Ellipsis) idx_names = convert_ellipsis_to_idx(idx, self.batch_size) # this will convert a [None, :, :, 0, None, 0] in 
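On the dense side, `__setitem__` now promotes dict values with the indexed batch size and, when left-broadcasting applies, expands the value before writing. Sketch of both happy paths:

import torch
from tensordict import TensorDict

td = TensorDict({"a": torch.zeros(4, 3, 2)}, [4, 3])
# dict values are converted with the indexed batch size
td[0] = {"a": torch.ones(3, 2)}
assert (td["a"][0] == 1).all()
# trailing-dim matches are expanded on the left (broadcasting)
td[1:3] = TensorDict({"a": torch.full((3, 2), 2.0)}, [3])
assert (td["a"][1:3] == 2).all()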
[None, 0, 1, None, 3]
@@ -2230,6 +2224,13 @@ def update(
         if input_dict_or_td is self:
             # no op
             return self
+        from ._lazy import LazyStackedTensorDict
+
+        if isinstance(self._source, LazyStackedTensorDict):
+            if self._source._has_exclusive_keys:
+                raise RuntimeError(
+                    "Cannot use _SubTensorDict.update with a LazyStackedTensorDict that has exclusive keys."
+                )
         if keys_to_update is not None:
             if len(keys_to_update) == 0:
                 return self
diff --git a/tensordict/_torch_func.py b/tensordict/_torch_func.py
index 33fc3bbf2..67d06a972 100644
--- a/tensordict/_torch_func.py
+++ b/tensordict/_torch_func.py
@@ -24,11 +24,10 @@
 )
 from torch import Tensor

-T = TypeVar("T", bound="TensorDictBase")
-
 TD_HANDLED_FUNCTIONS: dict[Callable, Callable] = {}
 LAZY_TD_HANDLED_FUNCTIONS: dict[Callable, Callable] = {}
+T = TypeVar("T", bound="TensorDictBase")


 def implements_for_td(torch_function: Callable) -> Callable[[Callable], Callable]:
diff --git a/tensordict/base.py b/tensordict/base.py
index a1c27a5f9..f179ea1f4 100644
--- a/tensordict/base.py
+++ b/tensordict/base.py
@@ -960,7 +960,7 @@ def transpose(self, dim0, dim1):

     @property
     def transpose(self):
-        """Returns a tensordit that is a transposed version of input. The given dimensions ``dim0`` and ``dim1`` are swapped.
+        """Returns a tensordict that is a transposed version of input. The given dimensions ``dim0`` and ``dim1`` are swapped.

         In-place or out-place modifications of the transposed tensordict will
         impact the original tensordict too as the memory is shared and the operations
@@ -985,21 +985,6 @@ def _transpose(self, dim0, dim1):
         ...

     def _legacy_transpose(self, dim0, dim1):
-        """Returns a tensordit that is a transposed version of input. The given dimensions ``dim0`` and ``dim1`` are swapped.
-
-        In-place or out-place modifications of the transposed tensordict will
-        impact the original tensordict too as the memory is shared and the operations
-        are mapped back on the original tensordict.
- - Examples: - >>> tensordict = TensorDict({"a": torch.randn(3, 4, 5)}, [3, 4]) - >>> tensordict_transpose = tensordict.transpose(0, 1) - >>> print(tensordict_transpose.shape) - torch.Size([4, 3]) - >>> tensordict_transpose.set("b",, torch.randn(4, 3)) - >>> print(tensordict.get("b").shape) - torch.Size([3, 4]) - """ if dim0 < 0: dim0 = self.ndim + dim0 if dim1 < 0: @@ -1784,7 +1769,15 @@ def memmap_like( if return_early: executor = ThreadPoolExecutor(max_workers=num_threads) futures = [] - result = self._memmap_( + # we create an empty copy of self + # This is because calling MMapTensor.from_tensor(mmap_tensor) does nothing + # if both are in filesystem + input = self.apply( + lambda x: torch.empty((), device=x.device, dtype=x.dtype).expand( + x.shape + ) + ) + result = input._memmap_( prefix=prefix, copy_existing=copy_existing, executor=executor, @@ -1797,7 +1790,10 @@ def memmap_like( return result else: return TensorDictFuture(futures, result) - return self._memmap_( + input = self.apply( + lambda x: torch.empty((), device=x.device, dtype=x.dtype).expand(x.shape) + ) + return input._memmap_( prefix=prefix, copy_existing=copy_existing, inplace=False, @@ -2356,11 +2352,6 @@ def update_( for ktu in keys_to_update ): continue - # if not isinstance(value, _accepted_classes): - # raise TypeError( - # f"Expected value to be one of types {_accepted_classes} " - # f"but got {type(value)}" - # ) if clone: value = value.clone() self.set_((firstkey, *nextkeys), value) @@ -3326,7 +3317,7 @@ def _irecv( def reduce( self, dst, - op=dist.ReduceOp.SUM, + op=None, async_op=False, return_premature=False, group=None, @@ -3336,17 +3327,21 @@ def reduce( Only the process with ``rank`` dst is going to receive the final result. """ + if op is None: + op = dist.ReduceOp.SUM return self._reduce(dst, op, async_op, return_premature, group=group) def _reduce( self, dst, - op=dist.ReduceOp.SUM, + op=None, async_op=False, return_premature=False, _future_list=None, group=None, ): + if op is None: + op = dist.ReduceOp.SUM root = False if _future_list is None: _future_list = [] diff --git a/tensordict/nn/params.py b/tensordict/nn/params.py index fcf8bfdd6..b9b15cf1a 100644 --- a/tensordict/nn/params.py +++ b/tensordict/nn/params.py @@ -205,6 +205,8 @@ def _carry_over(func): @wraps(func) def new_func(self, *args, **kwargs): out = getattr(self._param_td, name)(*args, **kwargs) + if out is self._param_td: + return self if not isinstance(out, TensorDictParams): out = TensorDictParams(out, no_convert=True) out.no_convert = self.no_convert diff --git a/test/test_tensorclass.py b/test/test_tensorclass.py index 09f93d33c..1783999c1 100644 --- a/test/test_tensorclass.py +++ b/test/test_tensorclass.py @@ -41,6 +41,15 @@ from torch import Tensor +def _make_data(shape): + return MyData( + X=torch.rand(*shape), + y=torch.rand(*shape), + z="test_tensorclass", + batch_size=shape[:1], + ) + + class MyData: X: torch.Tensor y: torch.Tensor @@ -62,1537 +71,1416 @@ class MyData2: z: list -def test_dataclass(): - data = MyData( - X=torch.ones(3, 4, 5), - y=torch.zeros(3, 4, 5, dtype=torch.bool), - z="test_tensorclass", - batch_size=[3, 4], - ) - assert dataclasses.is_dataclass(data) +class TestTensorClass: + def test_all_any(self): + @tensorclass + class MyClass1: + x: torch.Tensor + z: str + y: "MyClass1" = None + + # with all 0 + x = MyClass1( + torch.zeros(3, 1), + "z", + MyClass1(torch.zeros(3, 1), "z", batch_size=[3, 1]), + batch_size=[3, 1], + ) + assert not x.all() + assert not x.any() + assert isinstance(x.all(), bool) + assert 
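`memmap_like` now maps an empty-storage copy of `self` (per the comment above, `MemoryMappedTensor.from_tensor` is a no-op when source and destination already live in the filesystem), so the source values are never copied. Usage is unchanged; a sketch:

import tempfile
import torch
from tensordict import TensorDict

td = TensorDict({"a": torch.randn(3, 4)}, [3])
with tempfile.TemporaryDirectory() as tmpdir:
    mm = td.memmap_like(prefix=tmpdir)
    assert mm.is_memmap()
    assert mm["a"].shape == td["a"].shape  # storage allocated, values not copied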
isinstance(x.any(), bool) + for dim in [0, 1, -1, -2]: + assert isinstance(x.all(dim=dim), MyClass1) + assert isinstance(x.any(dim=dim), MyClass1) + assert not x.all(dim=dim).all() + assert not x.any(dim=dim).any() + # with all 1 + x = x.apply(lambda x: x.fill_(1.0)) + assert isinstance(x, MyClass1) + assert x.all() + assert x.any() + assert isinstance(x.all(), bool) + assert isinstance(x.any(), bool) + for dim in [0, 1]: + assert isinstance(x.all(dim=dim), MyClass1) + assert isinstance(x.any(dim=dim), MyClass1) + assert x.all(dim=dim).all() + assert x.any(dim=dim).any() + + # with 0 and 1 + x.y.x.fill_(0.0) + assert not x.all() + assert x.any() + assert isinstance(x.all(), bool) + assert isinstance(x.any(), bool) + for dim in [0, 1]: + assert isinstance(x.all(dim=dim), MyClass1) + assert isinstance(x.any(dim=dim), MyClass1) + assert not x.all(dim=dim).all() + assert x.any(dim=dim).any() + + assert not x.y.all() + assert not x.y.any() + + def test_args(self): + @tensorclass + class MyData: + D: torch.Tensor + B: torch.Tensor + A: torch.Tensor + C: torch.Tensor + E: str + + D = torch.ones(3, 4, 5) + B = torch.ones(3, 4, 5) + A = torch.ones(3, 4, 5) + C = torch.ones(3, 4, 5) + E = "test_tensorclass" + data1 = MyData(D, B=B, A=A, C=C, E=E, batch_size=[3, 4]) + data2 = MyData(D, B, A=A, C=C, E=E, batch_size=[3, 4]) + data3 = MyData(D, B, A, C=C, E=E, batch_size=[3, 4]) + data4 = MyData(D, B, A, C, E=E, batch_size=[3, 4]) + data5 = MyData(D, B, A, C, E, batch_size=[3, 4]) + data = torch.stack([data1, data2, data3, data4, data5], 0) + assert (data.A == A).all() + assert (data.B == B).all() + assert (data.C == C).all() + assert (data.D == D).all() + assert data.E == E + + def test_attributes(self): + X = torch.ones(3, 4, 5) + y = torch.zeros(3, 4, 5, dtype=torch.bool) + batch_size = [3, 4] + z = "test_tensorclass" + tensordict = TensorDict( + { + "X": X, + "y": y, + }, + batch_size=[3, 4], + ) + data = MyData(X=X, y=y, z=z, batch_size=batch_size) -def test_type(): - data = MyData( - X=torch.ones(3, 4, 5), - y=torch.zeros(3, 4, 5, dtype=torch.bool), - z="test_tensorclass", - batch_size=[3, 4], - ) - assert isinstance(data, MyData) - assert is_tensorclass(data) - assert is_tensorclass(MyData) - # we get an instance of the user defined class, not a dynamically defined subclass - assert type(data) is MyDataUndecorated + equality_tensordict = data._tensordict == tensordict + assert torch.equal(data.X, X) + assert torch.equal(data.y, y) + assert data.batch_size == torch.Size(batch_size) + assert equality_tensordict.all() + assert equality_tensordict.batch_size == torch.Size(batch_size) + assert data.z == z -def test_signature(): - sig = inspect.signature(MyData) - assert list(sig.parameters) == ["X", "y", "z", "batch_size", "device", "names"] + def test_banned_types(self): + @tensorclass + class MyAnyClass: + subclass: Any = None - with pytest.raises(TypeError, match="missing 3 required positional arguments"): - MyData(batch_size=[10]) + data = MyAnyClass(subclass=torch.ones(3, 4), batch_size=[3]) + assert data.subclass is not None - with pytest.raises(TypeError, match="missing 2 required positional argument"): - MyData(X=torch.rand(10), batch_size=[10]) + @tensorclass + class MyOptAnyClass: + subclass: Optional[Any] = None - with pytest.raises(TypeError, match="missing 1 required positional argument"): - MyData(X=torch.rand(10), y=torch.rand(10), batch_size=[10], device="cpu") + data = MyOptAnyClass(subclass=torch.ones(3, 4), batch_size=[3]) + assert data.subclass is not None - # if all positional 
arguments are specified, ommitting batch_size gives error - with pytest.raises( - TypeError, match="missing 1 required keyword-only argument: 'batch_size'" - ): - MyData(X=torch.rand(10), y=torch.rand(10)) + @tensorclass + class MyUnionAnyClass: + subclass: Union[Any] = None - # all positional arguments + batch_size is fine - MyData(X=torch.rand(10), y=torch.rand(10), z="test_tensorclass", batch_size=[10]) + data = MyUnionAnyClass(subclass=torch.ones(3, 4), batch_size=[3]) + assert data.subclass is not None + @tensorclass + class MyUnionAnyTDClass: + subclass: Union[Any, TensorDict] = None -@pytest.mark.parametrize("device", get_available_devices()) -def test_device(device): - data = MyData( - X=torch.ones(3, 4, 5), - y=torch.zeros(3, 4, 5, dtype=torch.bool), - z="test_tensorclass", - batch_size=[3, 4], - device=device, - ) - assert data.device == device - assert data.X.device == device - assert data.y.device == device + data = MyUnionAnyTDClass(subclass=torch.ones(3, 4), batch_size=[3]) + assert data.subclass is not None - with pytest.raises(AttributeError, match="'str' object has no attribute 'device'"): - assert data.z.device == device + @tensorclass + class MyOptionalClass: + subclass: Optional[TensorDict] = None - with pytest.raises( - RuntimeError, match="device cannot be set using tensorclass.device = device" - ): - data.device = torch.device("cpu") + data = MyOptionalClass(subclass=TensorDict({}, [3]), batch_size=[3]) + assert data.subclass is not None + data = MyOptionalClass(subclass=torch.ones(3), batch_size=[3]) + assert data.subclass is not None -def test_banned_types(): - @tensorclass - class MyAnyClass: - subclass: Any = None + @tensorclass + class MyUnionClass: + subclass: Union[MyOptionalClass, TensorDict] = None - data = MyAnyClass(subclass=torch.ones(3, 4), batch_size=[3]) - assert data.subclass is not None + data = MyUnionClass( + subclass=MyUnionClass._from_tensordict(TensorDict({}, [3])), batch_size=[3] + ) + assert data.subclass is not None + + def test_batch_size(self): + myc = MyData( + X=torch.rand(2, 3, 4), + y=torch.rand(2, 3, 4, 5), + z="test_tensorclass", + batch_size=[2, 3], + ) - @tensorclass - class MyOptAnyClass: - subclass: Optional[Any] = None + assert myc.batch_size == torch.Size([2, 3]) + assert myc.X.shape == torch.Size([2, 3, 4]) - data = MyOptAnyClass(subclass=torch.ones(3, 4), batch_size=[3]) - assert data.subclass is not None + myc.batch_size = torch.Size([2]) - @tensorclass - class MyUnionAnyClass: - subclass: Union[Any] = None + assert myc.batch_size == torch.Size([2]) + assert myc.X.shape == torch.Size([2, 3, 4]) - data = MyUnionAnyClass(subclass=torch.ones(3, 4), batch_size=[3]) - assert data.subclass is not None + def test_cat(self): + @tensorclass + class MyDataNested: + X: torch.Tensor + z: str + y: "MyDataNested" = None + + X = torch.ones(3, 4, 5) + z = "test_tensorclass" + batch_size = [3, 4] + data_nest = MyDataNested(X=X, z=z, batch_size=batch_size) + data1 = MyDataNested(X=X, y=data_nest, z=z, batch_size=batch_size) + data2 = MyDataNested(X=X, y=data_nest, z=z, batch_size=batch_size) + + catted_tc = torch.cat([data1, data2], 0) + assert type(catted_tc) is type(data1) + assert isinstance(catted_tc.y, type(data1.y)) + assert catted_tc.X.shape == torch.Size([6, 4, 5]) + assert catted_tc.y.X.shape == torch.Size([6, 4, 5]) + assert (catted_tc.X == 1).all() + assert (catted_tc.y.X == 1).all() + assert isinstance(catted_tc._tensordict, TensorDict) + assert catted_tc.z == catted_tc.y.z == z + + # Testing negative scenarios + y = torch.zeros(3, 
4, 5, dtype=torch.bool) + data3 = MyData(X=X, y=y, z=z, batch_size=batch_size) + + with pytest.raises( + TypeError, + match=("Multiple dispatch failed|no implementation found"), + ): + torch.cat([data1, data3], dim=0) + + def test_clone(self): + @tensorclass + class MyDataNested: + X: torch.Tensor + z: str + y: "MyDataNested" = None + + X = torch.ones(3, 4, 5) + z = "test_tensorclass" + batch_size = [3, 4] + data_nest = MyDataNested(X=X, z=z, batch_size=batch_size) + data = MyDataNested(X=X, y=data_nest, z=z, batch_size=batch_size) + clone_tc = torch.clone(data) + assert clone_tc.batch_size == torch.Size(data.batch_size) + assert torch.all(torch.eq(clone_tc.X, data.X)) + assert isinstance(clone_tc.y, MyDataNested) + assert torch.all(torch.eq(clone_tc.y.X, data.y.X)) + assert clone_tc.z == data.z == z + + def test_dataclass(self): + data = MyData( + X=torch.ones(3, 4, 5), + y=torch.zeros(3, 4, 5, dtype=torch.bool), + z="test_tensorclass", + batch_size=[3, 4], + ) + assert dataclasses.is_dataclass(data) - @tensorclass - class MyUnionAnyTDClass: - subclass: Union[Any, TensorDict] = None + def test_default(self): + @tensorclass + class MyData: + X: torch.Tensor = ( + None # TODO: do we want to allow any default, say an integer? + ) + y: torch.Tensor = torch.ones(3, 4, 5) + + data = MyData(batch_size=[3, 4]) + assert (data.y == 1).all() + assert data.X is None + data.X = torch.zeros(3, 4, 1) + assert (data.X == 0).all() + + MyData(batch_size=[3]) + MyData(batch_size=[]) + with pytest.raises(RuntimeError, match="batch dimension mismatch"): + MyData(batch_size=[4]) + + def test_defaultfactory(self): + @tensorclass + class MyData: + X: torch.Tensor = ( + None # TODO: do we want to allow any default, say an integer? + ) + y: torch.Tensor = dataclasses.field( + default_factory=lambda: torch.ones(3, 4, 5) + ) + + data = MyData(batch_size=[3, 4]) + assert (data.y == 1).all() + assert data.X is None + data.X = torch.zeros(3, 4, 1) + assert (data.X == 0).all() + + MyData(batch_size=[3]) + MyData(batch_size=[]) + with pytest.raises(RuntimeError, match="batch dimension mismatch"): + MyData(batch_size=[4]) + + @pytest.mark.parametrize("device", get_available_devices()) + def test_device(self, device): + data = MyData( + X=torch.ones(3, 4, 5), + y=torch.zeros(3, 4, 5, dtype=torch.bool), + z="test_tensorclass", + batch_size=[3, 4], + device=device, + ) + assert data.device == device + assert data.X.device == device + assert data.y.device == device + + with pytest.raises( + AttributeError, match="'str' object has no attribute 'device'" + ): + assert data.z.device == device + + with pytest.raises( + RuntimeError, match="device cannot be set using tensorclass.device = device" + ): + data.device = torch.device("cpu") + + def test_disallowed_attributes(self): + with pytest.raises( + AttributeError, + match="Attribute name reshape can't be used with @tensorclass", + ): + + @tensorclass + class MyInvalidClass: + x: torch.Tensor + y: torch.Tensor + reshape: torch.Tensor + + def test_equal(self): + @tensorclass + class MyClass1: + x: torch.Tensor + z: str + y: "MyClass1" = None - data = MyUnionAnyTDClass(subclass=torch.ones(3, 4), batch_size=[3]) - assert data.subclass is not None + @tensorclass + class MyClass2: + x: torch.Tensor + z: str + y: "MyClass2" = None + + a = MyClass1( + torch.zeros(3), + "z0", + MyClass1( + torch.ones(3), + "z1", + None, + batch_size=[3], + ), + batch_size=[3], + ) + b = MyClass2( + torch.zeros(3), + "z0", + MyClass2( + torch.ones(3), + "z1", + None, + batch_size=[3], + ), + 
batch_size=[3], + ) + c = TensorDict({"x": torch.zeros(3), "y": {"x": torch.ones(3)}}, batch_size=[3]) - @tensorclass - class MyOptionalClass: - subclass: Optional[TensorDict] = None + assert (a == a.clone()).all() + assert (a != 1.0).any() + assert (a[:2] != 1.0).any() - data = MyOptionalClass(subclass=TensorDict({}, [3]), batch_size=[3]) - assert data.subclass is not None + assert (a.y == 1).all() + assert (a[:2].y == 1).all() + assert (a.y[:2] == 1).all() - data = MyOptionalClass(subclass=torch.ones(3), batch_size=[3]) - assert data.subclass is not None + assert (a != torch.ones([])).any() + assert (a.y == torch.ones([])).all() - @tensorclass - class MyUnionClass: - subclass: Union[MyOptionalClass, TensorDict] = None + assert (a == b).all() + assert (b == a).all() + assert (b[:2] == a[:2]).all() - data = MyUnionClass( - subclass=MyUnionClass._from_tensordict(TensorDict({}, [3])), batch_size=[3] - ) - assert data.subclass is not None - - -def test_attributes(): - X = torch.ones(3, 4, 5) - y = torch.zeros(3, 4, 5, dtype=torch.bool) - batch_size = [3, 4] - z = "test_tensorclass" - tensordict = TensorDict( - { - "X": X, - "y": y, - }, - batch_size=[3, 4], - ) + assert (a == c).all() + assert (a[:2] == c[:2]).all() - data = MyData(X=X, y=y, z=z, batch_size=batch_size) + assert (c == a).all() + assert (c[:2] == a[:2]).all() - equality_tensordict = data._tensordict == tensordict + assert (a != c.clone().zero_()).any() + assert (c != a.clone().zero_()).any() - assert torch.equal(data.X, X) - assert torch.equal(data.y, y) - assert data.batch_size == torch.Size(batch_size) - assert equality_tensordict.all() - assert equality_tensordict.batch_size == torch.Size(batch_size) - assert data.z == z + def test_from_dict(self): + td = TensorDict( + { + ("a", "b", "c"): 1, + ("a", "d"): 2, + }, + [], + ).expand(10) + d = td.to_dict() + @tensorclass + class MyClass: + a: TensorDictBase -def test_disallowed_attributes(): - with pytest.raises( - AttributeError, - match="Attribute name reshape can't be used with @tensorclass", - ): + tc = MyClass.from_dict(d) + assert isinstance(tc, MyClass) + assert isinstance(tc.a, TensorDict) + assert tc.batch_size == torch.Size([10]) + def test_full_like(self): + @tensorclass + class MyDataNested: + X: torch.Tensor + z: str + y: "MyDataNested" = None + + X = torch.ones(3, 4, 5) + z = "test_tensorclass" + batch_size = [3, 4] + data_nest = MyDataNested(X=X, z=z, batch_size=batch_size) + data = MyDataNested(X=X, y=data_nest, z=z, batch_size=batch_size) + full_like_tc = torch.full_like(data, 9.0) + assert type(full_like_tc) is type(data) + assert full_like_tc.batch_size == torch.Size(data.batch_size) + assert full_like_tc.X.size() == data.X.size() + assert isinstance(full_like_tc.y, type(data.y)) + assert full_like_tc.y.X.size() == data.y.X.size() + assert (full_like_tc.X == 9).all() + assert (full_like_tc.y.X == 9).all() + assert full_like_tc.z == data.z == z + + @pytest.mark.parametrize("from_torch", [True, False]) + def test_gather(self, from_torch): @tensorclass - class MyInvalidClass: + class MyClass: x: torch.Tensor - y: torch.Tensor - reshape: torch.Tensor - - -def test_batch_size(): - myc = MyData( - X=torch.rand(2, 3, 4), - y=torch.rand(2, 3, 4, 5), - z="test_tensorclass", - batch_size=[2, 3], - ) - - assert myc.batch_size == torch.Size([2, 3]) - assert myc.X.shape == torch.Size([2, 3, 4]) - - myc.batch_size = torch.Size([2]) - - assert myc.batch_size == torch.Size([2]) - assert myc.X.shape == torch.Size([2, 3, 4]) - - -def test_len(): - myc = MyData( - 
X=torch.rand(2, 3, 4), - y=torch.rand(2, 3, 4, 5), - z="test_tensorclass", - batch_size=[2, 3], - ) - assert len(myc) == 2 - - myc2 = MyData( - X=torch.rand(2, 3, 4), - y=torch.rand(2, 3, 4, 5), - z="test_tensorclass", - batch_size=[], - ) - assert len(myc2) == 0 - + z: str + y: "MyClass" = None -def test_indexing(): - @tensorclass - class MyDataNested: - X: torch.Tensor - z: list - y: "MyDataNested" = None - - X = torch.ones(3, 4, 5) - z = ["a", "b", "c"] - batch_size = [3, 4] - data_nest = MyDataNested(X=X, z=z, batch_size=batch_size) - data = MyDataNested(X=X, y=data_nest, z=z, batch_size=batch_size) - - assert data[:2].batch_size == torch.Size([2, 4]) - assert data[:2].X.shape == torch.Size([2, 4, 5]) - assert (data[:2].X == X[:2]).all() - assert isinstance(data[:2].y, type(data_nest)) - - # Nested tensors all get indexed - assert (data[:2].y.X == X[:2]).all() - assert data[:2].y.batch_size == torch.Size([2, 4]) - assert data[1].batch_size == torch.Size([4]) - assert data[1][1].batch_size == torch.Size([]) - - # Non-tensor data won't get indexed - assert data[1].z == data[2].z == data[:2].z == z - - with pytest.raises( - RuntimeError, - match="indexing a tensordict with td.batch_dims==0 is not permitted", + c = MyClass( + torch.randn(3, 4), + "foo", + MyClass(torch.randn(3, 4, 5), "bar", None, batch_size=[3, 4, 5]), + batch_size=[3, 4], + ) + dim = -1 + index = torch.arange(3).expand(3, 3) + if from_torch: + c_gather = torch.gather(c, index=index, dim=dim) + else: + c_gather = c.gather(index=index, dim=dim) + assert c_gather.x.shape == torch.Size([3, 3]) + assert c_gather.y.shape == torch.Size([3, 3, 5]) + assert c_gather.y.x.shape == torch.Size([3, 3, 5]) + assert c_gather.y.z == "bar" + assert c_gather.z == "foo" + c_gather_zero = c_gather.clone().zero_() + if from_torch: + c_gather2 = torch.gather(c, index=index, dim=dim, out=c_gather_zero) + else: + c_gather2 = c.gather(index=index, dim=dim, out=c_gather_zero) + + assert (c_gather2 == c_gather).all() + + def test_get( + self, ): - data[1][1][1] - - with pytest.raises(ValueError, match="Invalid indexing arguments."): - data["X"] - + @tensorclass + class MyDataNest: + X: torch.Tensor + v: str -def test_setitem(): - data = MyData( - X=torch.ones(3, 4, 5), - y=torch.zeros(3, 4, 5), - z="test_tensorclass", - batch_size=[3, 4], - ) + @tensorclass + class MyDataParent: + X: Tensor + z: TensorDictBase + y: MyDataNest + v: str + k: Optional[Tensor] = None + + batch_size = [3, 4] + X = torch.ones(3, 4, 5) + td = TensorDict({}, batch_size) + data_nest = MyDataNest(X=X, v="test_nested", batch_size=batch_size) + v = "test_tensorclass" + data = MyDataParent(X=X, y=data_nest, z=td, v=v, batch_size=batch_size) + assert isinstance(data.y, type(data_nest)) + assert (data.get("X") == X).all() + assert data.get("batch_size") == torch.Size(batch_size) + assert data.get("v") == v + assert (data.get("z") == td).all() + + # Testing nested tensor class + assert data.get("y")._tensordict is data_nest._tensordict + assert (data.get("y").X == X).all() + assert (data.get(("y", "X")) == X).all() + assert data.get("y").v == "test_nested" + assert data.get(("y", "v")) == "test_nested" + assert data.get("y").batch_size == torch.Size(batch_size) + + # ensure optional fields are there + assert data.get("k") is None + + # ensure default works + assert data.get("foo", "working") == "working" + assert data.get(("foo", "foo2"), "working") == "working" + assert data.get(("X", "foo2"), "working") == "working" + + assert (data.get("X", "working") == X).all() + assert 
data.get("v", "working") == v + + @pytest.mark.parametrize("any_to_td", [True, False]) + def test_getattr(self, any_to_td): + @tensorclass + class MyDataNest: + X: torch.Tensor + v: str - x = torch.randn(3, 4, 5) - y = torch.ones(3, 4, 5) - z = "test_tensorclass" - batch_size = [3, 4] - data2 = MyData(X=x, y=y, z=z, batch_size=batch_size) - data3 = MyData(X=y, y=x, z=z, batch_size=batch_size) - - # Testing the data before setting - assert (data[:2].X == torch.ones(2, 4, 5)).all() - assert (data[:2].y == torch.zeros(2, 4, 5)).all() - assert data[:2].z == "test_tensorclass" - assert (data[[1, 2]].X == torch.ones(5)).all() - - # Setting the item and testing post setting the item - data[:2] = data2[:2].clone() - assert (data[:2].X == data2[:2].X).all() - assert (data[:2].y == data2[:2].y).all() - assert data[:2].z == z - - data[[1, 2]] = data3[[1, 2]].clone() - assert (data[[1, 2]].X == data3[[1, 2]].X).all() - assert (data[[1, 2]].y == data3[[1, 2]].y).all() - assert data[[1, 2]].z == z - - data[:, [1, 2]] = data2[:, [1, 2]].clone() - assert (data[:, [1, 2]].X == data2[:, [1, 2]].X).all() - assert (data[:, [1, 2]].y == data[:, [1, 2]].y).all() - assert data[:, [1, 2]].z == z - - with pytest.raises( - RuntimeError, match="indexed destination TensorDict batch size is" - ): - data[:, [1, 2]] = data.clone() - - # Negative testcase for non-tensor data - z = "test_bluff" - data2 = MyData(X=x, y=y, z=z, batch_size=batch_size) - with pytest.warns( - UserWarning, - match="Meta data at 'z' may or may not be equal, this may result in undefined behaviours", + @tensorclass + class MyDataParent: + W: Any + X: Tensor + z: TensorDictBase + y: MyDataNest + v: str + + batch_size = [3, 4] + if any_to_td: + W = TensorDict({}, batch_size) + else: + W = torch.zeros(*batch_size, 1) + X = torch.ones(3, 4, 5) + td = TensorDict({}, batch_size) + data_nest = MyDataNest(X=X, v="test_nested", batch_size=batch_size) + v = "test_tensorclass" + data = MyDataParent(X=X, y=data_nest, z=td, W=W, v=v, batch_size=batch_size) + assert isinstance(data.y, type(data_nest)) + assert (data.X == X).all() + assert data.batch_size == torch.Size(batch_size) + assert data.v == v + assert (data.z == td).all() + assert (data.W == W).all() + + # Testing nested tensor class + assert data.y._tensordict is data_nest._tensordict + assert (data.y.X == X).all() + assert data.y.v == "test_nested" + assert data.y.batch_size == torch.Size(batch_size) + + def test_indexing( + self, ): - data[1] = data2[1] - - # Validating nested test cases - @tensorclass - class MyDataNested: - X: torch.Tensor - z: list - y: "MyDataNested" = None - - X = torch.randn(3, 4, 5) - z = ["a", "b", "c"] - batch_size = [3, 4] - data_nest = MyDataNested(X=X, z=z, batch_size=batch_size) - data = MyDataNested(X=X, y=data_nest, z=z, batch_size=batch_size) - X2 = torch.ones(3, 4, 5) - data_nest2 = MyDataNested(X=X2, z=z, batch_size=batch_size) - data2 = MyDataNested(X=X2, y=data_nest2, z=z, batch_size=batch_size) - data[:2] = data2[:2].clone() - assert (data[:2].X == data2[:2].X).all() - assert (data[:2].y.X == data2[:2].y.X).all() - assert data[:2].z == z - - # Negative Scenario - data3 = MyDataNested(X=X2, y=data_nest2, z=["e", "f"], batch_size=batch_size) - with pytest.warns( - UserWarning, - match="Meta data at 'z' may or may not be equal, this may result in undefined behaviours", + @tensorclass + class MyDataNested: + X: torch.Tensor + z: list + y: "MyDataNested" = None + + X = torch.ones(3, 4, 5) + z = ["a", "b", "c"] + batch_size = [3, 4] + data_nest = MyDataNested(X=X, z=z, 
batch_size=batch_size) + data = MyDataNested(X=X, y=data_nest, z=z, batch_size=batch_size) + + assert data[:2].batch_size == torch.Size([2, 4]) + assert data[:2].X.shape == torch.Size([2, 4, 5]) + assert (data[:2].X == X[:2]).all() + assert isinstance(data[:2].y, type(data_nest)) + + # Nested tensors all get indexed + assert (data[:2].y.X == X[:2]).all() + assert data[:2].y.batch_size == torch.Size([2, 4]) + assert data[1].batch_size == torch.Size([4]) + assert data[1][1].batch_size == torch.Size([]) + + # Non-tensor data won't get indexed + assert data[1].z == data[2].z == data[:2].z == z + + with pytest.raises( + RuntimeError, + match="indexing a tensordict with td.batch_dims==0 is not permitted", + ): + data[1][1][1] + + with pytest.raises(ValueError, match="Invalid indexing arguments."): + data["X"] + + def test_kjt( + self, ): - data[:2] = data3[:2] + try: + from torchrec import KeyedJaggedTensor + except ImportError: + pytest.skip("TorchRec not installed.") + + def _get_kjt( + self, + ): + values = torch.Tensor( + [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0] + ) + weights = torch.Tensor( + [1.0, 0.5, 1.5, 1.0, 0.5, 1.0, 1.0, 1.5, 1.0, 1.0, 1.0] + ) + keys = ["index_0", "index_1", "index_2"] + offsets = torch.IntTensor([0, 2, 2, 3, 4, 5, 8, 9, 10, 11]) + + jag_tensor = KeyedJaggedTensor( + values=values, + keys=keys, + offsets=offsets, + weights=weights, + ) + return jag_tensor + + kjt = _get_kjt() + @tensorclass + class MyData: + X: torch.Tensor + y: KeyedJaggedTensor + z: str -def test_setitem_memmap(): - # regression test PR #203 - # We should be able to set tensors items with MemoryMappedTensors and viceversa - @tensorclass - class MyDataMemMap1: - x: torch.Tensor - y: MemoryMappedTensor - - data1 = MyDataMemMap1( - x=torch.zeros(3, 4, 5), - y=MemoryMappedTensor.from_tensor(torch.zeros(3, 4, 5)), - batch_size=[3, 4], - ) - - data2 = MyDataMemMap1( - x=MemoryMappedTensor.from_tensor(torch.ones(3, 4, 5)), - y=torch.ones(3, 4, 5), - batch_size=[3, 4], - ) - - data1[:2] = data2[:2] - assert (data1[:2] == 1).all() - assert (data1.x[:2] == 1).all() - assert (data1.y[:2] == 1).all() - data2[2:] = data1[2:] - assert (data2[2:] == 0).all() - assert (data2.x[2:] == 0).all() - assert (data2.y[2:] == 0).all() + z = "test_tensorclass" + data = MyData(X=torch.zeros(3, 1), y=kjt, z=z, batch_size=[3]) + subdata = data[:2] + assert ( + subdata.y["index_0"].to_padded_dense() + == torch.tensor([[1.0, 2.0], [0.0, 0.0]]) + ).all() + + subdata = data[[0, 2]] + assert ( + subdata.y["index_0"].to_padded_dense() + == torch.tensor([[1.0, 2.0], [3.0, 0.0]]) + ).all() + assert subdata.z == data.z == z + + def test_len( + self, + ): + myc = MyData( + X=torch.rand(2, 3, 4), + y=torch.rand(2, 3, 4, 5), + z="test_tensorclass", + batch_size=[2, 3], + ) + assert len(myc) == 2 + myc2 = MyData( + X=torch.rand(2, 3, 4), + y=torch.rand(2, 3, 4, 5), + z="test_tensorclass", + batch_size=[], + ) + assert len(myc2) == 0 -def test_setitem_other_cls(): - @tensorclass - class MyData1: - x: torch.Tensor - y: MemoryMappedTensor - - data1 = MyData1( - x=torch.zeros(3, 4, 5), - y=MemoryMappedTensor.from_tensor(torch.zeros(3, 4, 5)), - batch_size=[3, 4], - ) + def test_multiprocessing( + self, + ): + with Pool(os.cpu_count()) as p: + catted = torch.cat(p.map(_make_data, [(i, 2) for i in range(1, 9)]), dim=0) - # Set Item should work for other tensorclass - @tensorclass - class MyData2: - x: MemoryMappedTensor - y: torch.Tensor - - data_other_cls = MyData2( - x=MemoryMappedTensor.from_tensor(torch.ones(3, 4, 5)), - 
y=torch.ones(3, 4, 5), - batch_size=[3, 4], - ) - data1[:2] = data_other_cls[:2] - data_other_cls[2:] = data1[2:] + assert catted.batch_size == torch.Size([36]) + assert catted.z == "test_tensorclass" - # Set Item should raise if other tensorclass with different members - @tensorclass - class MyData3: - x: MemoryMappedTensor - z: torch.Tensor - - data_wrong_cls = MyData3( - x=MemoryMappedTensor.from_tensor(torch.ones(3, 4, 5)), - z=torch.ones(3, 4, 5), - batch_size=[3, 4], - ) - with pytest.raises( - ValueError, - match="__setitem__ is only allowed for same-class or compatible class .* assignment", + def test_nested( + self, ): - data1[:2] = data_wrong_cls[:2] - with pytest.raises( - ValueError, - match="__setitem__ is only allowed for same-class or compatible class .* assignment", + @tensorclass + class MyDataNested: + X: torch.Tensor + z: str + y: "MyDataNested" = None + + X = torch.ones(3, 4, 5) + z = "test_tensorclass" + batch_size = [3, 4] + data_nest = MyDataNested(X=X, z=z, batch_size=batch_size) + data = MyDataNested(X=X, y=data_nest, z=z, batch_size=batch_size) + assert isinstance(data.y, MyDataNested), type(data.y) + assert data.z == data_nest.z == data.y.z == z + + def test_nested_eq( + self, ): - data_wrong_cls[2:] = data1[2:] - + @tensorclass + class MyDataNested: + X: torch.Tensor + z: str + y: "MyDataNested" = None + + X = torch.ones(3, 4, 5) + z = "test_tensorclass" + batch_size = [3, 4] + data_nest = MyDataNested(X=X, z=z, batch_size=batch_size) + data = MyDataNested(X=X, y=data_nest, z=z, batch_size=batch_size) + data_nest2 = MyDataNested(X=X, z=z, batch_size=batch_size) + data2 = MyDataNested(X=X, y=data_nest2, z=z, batch_size=batch_size) + assert (data == data2).all() + assert (data == data2).X.all() + assert (data == data2).z is None + assert (data == data2).y.X.all() + assert (data == data2).y.z is None + + @pytest.mark.parametrize("any_to_td", [True, False]) + def test_nested_heterogeneous(self, any_to_td): + @tensorclass + class MyDataNest: + X: torch.Tensor + v: str -@pytest.mark.parametrize( - "broadcast_type", - ["scalar", "tensor", "tensordict", "maptensor"], -) -def test_setitem_broadcast(broadcast_type): - @tensorclass - class MyDataNested: - X: torch.Tensor - z: list - y: "MyDataNested" = None - - X = torch.ones(3, 4, 5) - z = ["a", "b", "c"] - batch_size = [3, 4] - data_nest = MyDataNested(X=X, z=z, batch_size=batch_size) - data = MyDataNested(X=X, y=data_nest, z=z, batch_size=batch_size) - - if broadcast_type == "scalar": - val = 0 - elif broadcast_type == "tensor": - val = torch.zeros(4, 5) - elif broadcast_type == "tensordict": - val = TensorDict({"X": torch.zeros(2, 4, 5)}, batch_size=[2, 4]) - elif broadcast_type == "maptensor": - val = MemoryMappedTensor.from_tensor(torch.zeros(4, 5)) - - data[:2] = val - assert (data[:2] == 0).all() - assert (data.X[:2] == 0).all() - assert (data.y.X[:2] == 0).all() - - -def test_stack(): - @tensorclass - class MyDataNested: - X: torch.Tensor - z: str - y: "MyDataNested" = None - - X = torch.ones(3, 4, 5) - z = "test_tensorclass" - batch_size = [3, 4] - data_nest = MyDataNested(X=X, z=z, batch_size=batch_size) - data1 = MyDataNested(X=X, y=data_nest, z=z, batch_size=batch_size) - data2 = MyDataNested(X=X, y=data_nest, z=z, batch_size=batch_size) - - stacked_tc = torch.stack([data1, data2], 0) - assert type(stacked_tc) is type(data1) - assert isinstance(stacked_tc.y, type(data1.y)) - assert stacked_tc.X.shape == torch.Size([2, 3, 4, 5]) - assert stacked_tc.y.X.shape == torch.Size([2, 3, 4, 5]) - assert (stacked_tc.X == 
1).all() - assert (stacked_tc.y.X == 1).all() - assert isinstance(stacked_tc._tensordict, LazyStackedTensorDict) - assert isinstance(stacked_tc.y._tensordict, LazyStackedTensorDict) - assert stacked_tc.z == stacked_tc.y.z == z - - # Testing negative scenarios - y = torch.zeros(3, 4, 5, dtype=torch.bool) - data3 = MyData(X=X, y=y, z=z, batch_size=batch_size) - - with pytest.raises( - TypeError, - match=("Multiple dispatch failed|no implementation found"), + @tensorclass + class MyDataParent: + W: Any + X: Tensor + z: TensorDictBase + y: MyDataNest + v: str + + batch_size = [3, 4] + if any_to_td: + W = TensorDict({}, batch_size) + else: + W = torch.zeros(*batch_size, 1) + X = torch.ones(3, 4, 5) + data_nest = MyDataNest(X=X, v="test_nested", batch_size=batch_size) + td = TensorDict({}, batch_size) + v = "test_tensorclass" + data = MyDataParent(X=X, y=data_nest, z=td, W=W, v=v, batch_size=batch_size) + assert isinstance(data.y, MyDataNest) + assert isinstance(data.y.X, Tensor) + assert isinstance(data.X, Tensor) + if not any_to_td: + assert isinstance(data.W, Tensor) + else: + assert isinstance(data.W, TensorDict) + assert isinstance(data, MyDataParent) + assert isinstance(data.z, TensorDict) + assert data.v == v + assert data.y.v == "test_nested" + # Testing nested indexing + assert isinstance(data[0], type(data)) + assert isinstance(data[0].y, type(data.y)) + assert data[0].y.X.shape == torch.Size([4, 5]) + + def test_nested_ne( + self, ): - torch.stack([data1, data3], dim=0) - - -def test_cat(): - @tensorclass - class MyDataNested: - X: torch.Tensor - z: str - y: "MyDataNested" = None - - X = torch.ones(3, 4, 5) - z = "test_tensorclass" - batch_size = [3, 4] - data_nest = MyDataNested(X=X, z=z, batch_size=batch_size) - data1 = MyDataNested(X=X, y=data_nest, z=z, batch_size=batch_size) - data2 = MyDataNested(X=X, y=data_nest, z=z, batch_size=batch_size) - - catted_tc = torch.cat([data1, data2], 0) - assert type(catted_tc) is type(data1) - assert isinstance(catted_tc.y, type(data1.y)) - assert catted_tc.X.shape == torch.Size([6, 4, 5]) - assert catted_tc.y.X.shape == torch.Size([6, 4, 5]) - assert (catted_tc.X == 1).all() - assert (catted_tc.y.X == 1).all() - assert isinstance(catted_tc._tensordict, TensorDict) - assert catted_tc.z == catted_tc.y.z == z - - # Testing negative scenarios - y = torch.zeros(3, 4, 5, dtype=torch.bool) - data3 = MyData(X=X, y=y, z=z, batch_size=batch_size) - - with pytest.raises( - TypeError, - match=("Multiple dispatch failed|no implementation found"), + @tensorclass + class MyDataNested: + X: torch.Tensor + z: str + y: "MyDataNested" = None + + X = torch.ones(3, 4, 5) + z = "test_tensorclass" + batch_size = [3, 4] + data_nest = MyDataNested(X=X, z=z, batch_size=batch_size) + data = MyDataNested(X=X, y=data_nest, z=z, batch_size=batch_size) + data_nest2 = MyDataNested(X=X, z=z, batch_size=batch_size) + z = "test_bluff" + data2 = MyDataNested(X=X + 1, y=data_nest2, z=z, batch_size=batch_size) + assert (data != data2).any() + assert (data != data2).X.all() + assert (data != data2).z is None + assert not (data != data2).y.X.any() + assert (data != data2).y.z is None + + def test_permute( + self, ): - torch.cat([data1, data3], dim=0) - - -def test_unbind(): - @tensorclass - class MyDataNested: - X: torch.Tensor - z: str - y: "MyDataNested" = None - - X = torch.ones(3, 4, 5) - z = "test_tensorclass" - batch_size = [3, 4] - data_nest = MyDataNested(X=X, z=z, batch_size=batch_size) - data = MyDataNested(X=X, y=data_nest, z=z, batch_size=batch_size) - unbind_tcs = 
torch.unbind(data, 0) - assert type(unbind_tcs[1]) is type(data) - assert type(unbind_tcs[0].y[0]) is type(data) - assert len(unbind_tcs) == 3 - assert torch.all(torch.eq(unbind_tcs[0].X, torch.ones(4, 5))) - assert torch.all(torch.eq(unbind_tcs[0].y[0].X, torch.ones(4, 5))) - assert unbind_tcs[0].batch_size == torch.Size([4]) - assert unbind_tcs[0].z == unbind_tcs[1].z == unbind_tcs[2].z == z - - -def test_full_like(): - @tensorclass - class MyDataNested: - X: torch.Tensor - z: str - y: "MyDataNested" = None - - X = torch.ones(3, 4, 5) - z = "test_tensorclass" - batch_size = [3, 4] - data_nest = MyDataNested(X=X, z=z, batch_size=batch_size) - data = MyDataNested(X=X, y=data_nest, z=z, batch_size=batch_size) - full_like_tc = torch.full_like(data, 9.0) - assert type(full_like_tc) is type(data) - assert full_like_tc.batch_size == torch.Size(data.batch_size) - assert full_like_tc.X.size() == data.X.size() - assert isinstance(full_like_tc.y, type(data.y)) - assert full_like_tc.y.X.size() == data.y.X.size() - assert (full_like_tc.X == 9).all() - assert (full_like_tc.y.X == 9).all() - assert full_like_tc.z == data.z == z - - -def test_clone(): - @tensorclass - class MyDataNested: - X: torch.Tensor - z: str - y: "MyDataNested" = None - - X = torch.ones(3, 4, 5) - z = "test_tensorclass" - batch_size = [3, 4] - data_nest = MyDataNested(X=X, z=z, batch_size=batch_size) - data = MyDataNested(X=X, y=data_nest, z=z, batch_size=batch_size) - clone_tc = torch.clone(data) - assert clone_tc.batch_size == torch.Size(data.batch_size) - assert torch.all(torch.eq(clone_tc.X, data.X)) - assert isinstance(clone_tc.y, MyDataNested) - assert torch.all(torch.eq(clone_tc.y.X, data.y.X)) - assert clone_tc.z == data.z == z - - -def test_squeeze(): - @tensorclass - class MyDataNested: - X: torch.Tensor - z: str - y: "MyDataNested" = None - - X = torch.ones(1, 4, 5) - z = "test_tensorclass" - batch_size = [1, 4] - data_nest = MyDataNested(X=X, z=z, batch_size=batch_size) - data = MyDataNested(X=X, y=data_nest, z=z, batch_size=batch_size) - squeeze_tc = torch.squeeze(data) - assert squeeze_tc.batch_size == torch.Size([4]) - assert squeeze_tc.X.shape == torch.Size([4, 5]) - assert squeeze_tc.y.X.shape == torch.Size([4, 5]) - assert squeeze_tc.z == squeeze_tc.y.z == z - - -def test_unsqueeze(): - @tensorclass - class MyDataNested: - X: torch.Tensor - z: str - y: "MyDataNested" = None - - X = torch.ones(3, 4, 5) - z = "test_tensorclass" - batch_size = [3, 4] - data_nest = MyDataNested(X=X, z=z, batch_size=batch_size) - data = MyDataNested(X=X, y=data_nest, z=z, batch_size=batch_size) - unsqueeze_tc = torch.unsqueeze(data, dim=1) - assert unsqueeze_tc.batch_size == torch.Size([3, 1, 4]) - assert unsqueeze_tc.X.shape == torch.Size([3, 1, 4, 5]) - assert unsqueeze_tc.y.X.shape == torch.Size([3, 1, 4, 5]) - assert unsqueeze_tc.z == unsqueeze_tc.y.z == z - - -def test_split(): - @tensorclass - class MyDataNested: - X: torch.Tensor - z: str - y: "MyDataNested" = None - - X = torch.ones(3, 6, 5) - z = "test_tensorclass" - batch_size = [3, 6] - data_nest = MyDataNested(X=X, z=z, batch_size=batch_size) - data = MyData(X=X, y=data_nest, z=z, batch_size=batch_size) - split_tcs = torch.split(data, split_size_or_sections=[3, 2, 1], dim=1) - assert type(split_tcs[1]) is type(data) - assert split_tcs[0].batch_size == torch.Size([3, 3]) - assert split_tcs[1].batch_size == torch.Size([3, 2]) - assert split_tcs[2].batch_size == torch.Size([3, 1]) - - assert split_tcs[0].y.batch_size == torch.Size([3, 3]) - assert split_tcs[1].y.batch_size == 
torch.Size([3, 2]) - assert split_tcs[2].y.batch_size == torch.Size([3, 1]) - - assert torch.all(torch.eq(split_tcs[0].X, torch.ones(3, 3, 5))) - assert torch.all(torch.eq(split_tcs[0].y[0].X, torch.ones(3, 3, 5))) - assert split_tcs[0].z == split_tcs[1].z == split_tcs[2].z == z - assert split_tcs[0].y[0].z == split_tcs[0].y[1].z == split_tcs[0].y[2].z == z - - -def test_reshape(): - @tensorclass - class MyDataNested: - X: torch.Tensor - z: str - y: "MyDataNested" = None - - X = torch.ones(3, 4, 5) - z = "test_tensorclass" - batch_size = [3, 4] - data_nest = MyDataNested(X=X, z=z, batch_size=batch_size) - data = MyDataNested(X=X, y=data_nest, z=z, batch_size=batch_size) - stacked_tc = data.reshape(-1) - assert stacked_tc.X.shape == torch.Size([12, 5]) - assert stacked_tc.y.X.shape == torch.Size([12, 5]) - assert stacked_tc.shape == torch.Size([12]) - assert (stacked_tc.X == 1).all() - assert isinstance(stacked_tc._tensordict, TensorDict) - assert stacked_tc.z == stacked_tc.y.z == z - - -def test_view(): - @tensorclass - class MyDataNested: - X: torch.Tensor - z: str - y: "MyDataNested" = None - - X = torch.ones(3, 4, 5) - z = "test_tensorclass" - batch_size = [3, 4] - data_nest = MyDataNested(X=X, z=z, batch_size=batch_size) - data = MyDataNested(X=X, y=data_nest, z=z, batch_size=batch_size) - stacked_tc = data.view(-1) - assert stacked_tc.X.shape == torch.Size([12, 5]) - assert stacked_tc.y.X.shape == torch.Size([12, 5]) - assert stacked_tc.shape == torch.Size([12]) - assert (stacked_tc.X == 1).all() - assert isinstance(stacked_tc._tensordict, _ViewedTensorDict) - assert stacked_tc.z == stacked_tc.y.z == z - - -def test_permute(): - @tensorclass - class MyDataNested: - X: torch.Tensor - z: str - y: "MyDataNested" = None - - X = torch.ones(3, 4, 5) - z = "test_tensorclass" - batch_size = [3, 4] - data_nest = MyDataNested(X=X, z=z, batch_size=batch_size) - data = MyDataNested(X=X, y=data_nest, z=z, batch_size=batch_size) - stacked_tc = data.permute(1, 0) - assert stacked_tc.X.shape == torch.Size([4, 3, 5]) - assert stacked_tc.y.X.shape == torch.Size([4, 3, 5]) - assert stacked_tc.shape == torch.Size([4, 3]) - assert (stacked_tc.X == 1).all() - assert isinstance(stacked_tc._tensordict, _PermutedTensorDict) - assert stacked_tc.z == stacked_tc.y.z == z - - -def test_nested(): - @tensorclass - class MyDataNested: - X: torch.Tensor - z: str - y: "MyDataNested" = None - - X = torch.ones(3, 4, 5) - z = "test_tensorclass" - batch_size = [3, 4] - data_nest = MyDataNested(X=X, z=z, batch_size=batch_size) - data = MyDataNested(X=X, y=data_nest, z=z, batch_size=batch_size) - assert isinstance(data.y, MyDataNested), type(data.y) - assert data.z == data_nest.z == data.y.z == z - - -def test_nested_eq(): - @tensorclass - class MyDataNested: - X: torch.Tensor - z: str - y: "MyDataNested" = None - - X = torch.ones(3, 4, 5) - z = "test_tensorclass" - batch_size = [3, 4] - data_nest = MyDataNested(X=X, z=z, batch_size=batch_size) - data = MyDataNested(X=X, y=data_nest, z=z, batch_size=batch_size) - data_nest2 = MyDataNested(X=X, z=z, batch_size=batch_size) - data2 = MyDataNested(X=X, y=data_nest2, z=z, batch_size=batch_size) - assert (data == data2).all() - assert (data == data2).X.all() - assert (data == data2).z is None - assert (data == data2).y.X.all() - assert (data == data2).y.z is None - - -def test_nested_ne(): - @tensorclass - class MyDataNested: - X: torch.Tensor - z: str - y: "MyDataNested" = None - - X = torch.ones(3, 4, 5) - z = "test_tensorclass" - batch_size = [3, 4] - data_nest = 
MyDataNested(X=X, z=z, batch_size=batch_size) - data = MyDataNested(X=X, y=data_nest, z=z, batch_size=batch_size) - data_nest2 = MyDataNested(X=X, z=z, batch_size=batch_size) - z = "test_bluff" - data2 = MyDataNested(X=X + 1, y=data_nest2, z=z, batch_size=batch_size) - assert (data != data2).any() - assert (data != data2).X.all() - assert (data != data2).z is None - assert not (data != data2).y.X.any() - assert (data != data2).y.z is None - - -def test_args(): - @tensorclass - class MyData: - D: torch.Tensor - B: torch.Tensor - A: torch.Tensor - C: torch.Tensor - E: str - - D = torch.ones(3, 4, 5) - B = torch.ones(3, 4, 5) - A = torch.ones(3, 4, 5) - C = torch.ones(3, 4, 5) - E = "test_tensorclass" - data1 = MyData(D, B=B, A=A, C=C, E=E, batch_size=[3, 4]) - data2 = MyData(D, B, A=A, C=C, E=E, batch_size=[3, 4]) - data3 = MyData(D, B, A, C=C, E=E, batch_size=[3, 4]) - data4 = MyData(D, B, A, C, E=E, batch_size=[3, 4]) - data5 = MyData(D, B, A, C, E, batch_size=[3, 4]) - data = torch.stack([data1, data2, data3, data4, data5], 0) - assert (data.A == A).all() - assert (data.B == B).all() - assert (data.C == C).all() - assert (data.D == D).all() - assert data.E == E - - -@pytest.mark.parametrize("any_to_td", [True, False]) -def test_nested_heterogeneous(any_to_td): - @tensorclass - class MyDataNest: - X: torch.Tensor - v: str - - @tensorclass - class MyDataParent: - W: Any - X: Tensor - z: TensorDictBase - y: MyDataNest - v: str - - batch_size = [3, 4] - if any_to_td: - W = TensorDict({}, batch_size) - else: - W = torch.zeros(*batch_size, 1) - X = torch.ones(3, 4, 5) - data_nest = MyDataNest(X=X, v="test_nested", batch_size=batch_size) - td = TensorDict({}, batch_size) - v = "test_tensorclass" - data = MyDataParent(X=X, y=data_nest, z=td, W=W, v=v, batch_size=batch_size) - assert isinstance(data.y, MyDataNest) - assert isinstance(data.y.X, Tensor) - assert isinstance(data.X, Tensor) - if not any_to_td: - assert isinstance(data.W, Tensor) - else: - assert isinstance(data.W, TensorDict) - assert isinstance(data, MyDataParent) - assert isinstance(data.z, TensorDict) - assert data.v == v - assert data.y.v == "test_nested" - # Testing nested indexing - assert isinstance(data[0], type(data)) - assert isinstance(data[0].y, type(data.y)) - assert data[0].y.X.shape == torch.Size([4, 5]) - - -@pytest.mark.parametrize("any_to_td", [True, False]) -def test_getattr(any_to_td): - @tensorclass - class MyDataNest: - X: torch.Tensor - v: str - - @tensorclass - class MyDataParent: - W: Any - X: Tensor - z: TensorDictBase - y: MyDataNest - v: str - - batch_size = [3, 4] - if any_to_td: - W = TensorDict({}, batch_size) - else: - W = torch.zeros(*batch_size, 1) - X = torch.ones(3, 4, 5) - td = TensorDict({}, batch_size) - data_nest = MyDataNest(X=X, v="test_nested", batch_size=batch_size) - v = "test_tensorclass" - data = MyDataParent(X=X, y=data_nest, z=td, W=W, v=v, batch_size=batch_size) - assert isinstance(data.y, type(data_nest)) - assert (data.X == X).all() - assert data.batch_size == torch.Size(batch_size) - assert data.v == v - assert (data.z == td).all() - assert (data.W == W).all() - - # Testing nested tensor class - assert data.y._tensordict is data_nest._tensordict - assert (data.y.X == X).all() - assert data.y.v == "test_nested" - assert data.y.batch_size == torch.Size(batch_size) - - -@pytest.mark.parametrize("any_to_td", [True, False]) -def test_setattr(any_to_td): - @tensorclass - class MyDataNest: - X: torch.Tensor - v: str - - @tensorclass - class MyDataParent: - W: Any - X: Tensor - z: 
TensorDictBase - y: MyDataNest - v: Any - k: Optional[Tensor] = None - - batch_size = [3, 4] - if any_to_td: - W = TensorDict({}, batch_size) - else: - W = torch.zeros(*batch_size, 1) - X = torch.ones(3, 4, 5) - td = TensorDict({}, batch_size) - data_nest = MyDataNest(X=X, v="test_nested", batch_size=batch_size) - data = MyDataParent( - X=X, y=data_nest, z=td, W=W, v="test_tensorclass", batch_size=batch_size - ) - assert isinstance(data.y, type(data_nest)) - assert data.y._tensordict is data_nest._tensordict - data.X = torch.zeros(3, 4, 5) - assert (data.X == torch.zeros(3, 4, 5)).all() - v_new = "test_bluff" - data.v = v_new - assert data.v == v_new - # check that you can't mess up the batch_size - with pytest.raises( - RuntimeError, match=re.escape("the tensor smth has shape torch.Size([1]) which") + @tensorclass + class MyDataNested: + X: torch.Tensor + z: str + y: "MyDataNested" = None + + X = torch.ones(3, 4, 5) + z = "test_tensorclass" + batch_size = [3, 4] + data_nest = MyDataNested(X=X, z=z, batch_size=batch_size) + data = MyDataNested(X=X, y=data_nest, z=z, batch_size=batch_size) + stacked_tc = data.permute(1, 0) + assert stacked_tc.X.shape == torch.Size([4, 3, 5]) + assert stacked_tc.y.X.shape == torch.Size([4, 3, 5]) + assert stacked_tc.shape == torch.Size([4, 3]) + assert (stacked_tc.X == 1).all() + assert isinstance(stacked_tc._tensordict, _PermutedTensorDict) + assert stacked_tc.z == stacked_tc.y.z == z + + def test_pickle( + self, ): - data.z = TensorDict({"smth": torch.zeros(1)}, []) - # check that you can't write any attribute - with pytest.raises(AttributeError, match=re.escape("Cannot set the attribute")): - data.newattr = TensorDict({"smth": torch.zeros(1)}, []) - # Testing nested cases - data_nest.X = torch.zeros(3, 4, 5) - assert (data_nest.X == torch.zeros(3, 4, 5)).all() - assert (data.y.X == torch.zeros(3, 4, 5)).all() - assert data.y.v == "test_nested" - data.y.v = "test_nested_new" - assert data.y.v == data_nest.v == "test_nested_new" - data_nest.v = "test_nested" - assert data_nest.v == data.y.v == "test_nested" - - # Testing if user can override the type of the attribute - data.v = torch.ones(3, 4, 5) - assert (data.v == torch.ones(3, 4, 5)).all() - assert "v" in data._tensordict.keys() - assert "v" not in data._non_tensordict.keys() - - data.v = "test" - assert data.v == "test" - assert "v" not in data._tensordict.keys() - assert "v" in data._non_tensordict.keys() - - # ensure optional fields are writable - data.k = torch.zeros(3, 4, 5) - - -def test_set(): - @tensorclass - class MyDataNest: - X: torch.Tensor - v: str - - @tensorclass - class MyDataParent: - X: Tensor - z: TensorDictBase - y: MyDataNest - v: str - k: Optional[Tensor] = None - - batch_size = [3, 4] - X = torch.ones(3, 4, 5) - td = TensorDict({}, batch_size) - data_nest = MyDataNest(X=X, v="test_nested", batch_size=batch_size) - data = MyDataParent( - X=X, y=data_nest, z=td, v="test_tensorclass", batch_size=batch_size - ) + data = MyData( + X=torch.ones(3, 4, 5), + y=torch.zeros(3, 4, 5, dtype=torch.bool), + z="test_tensorclass", + batch_size=[3, 4], + ) - assert isinstance(data.y, type(data_nest)) - assert data.y._tensordict is data_nest._tensordict - data.set("X", torch.zeros(3, 4, 5)) - assert (data.X == torch.zeros(3, 4, 5)).all() - v_new = "test_bluff" - data.set("v", v_new) - assert data.v == v_new - # check that you can't mess up the batch_size - with pytest.raises( - RuntimeError, match=re.escape("the tensor smth has shape torch.Size([1]) which") - ): - data.set("z", TensorDict({"smth": 
torch.zeros(1)}, [])) - # check that you can't write any attribute - with pytest.raises(AttributeError, match=re.escape("Cannot set the attribute")): - data.set("newattr", TensorDict({"smth": torch.zeros(1)}, [])) - - # Testing nested cases - data_nest.set("X", torch.zeros(3, 4, 5)) - assert (data_nest.X == torch.zeros(3, 4, 5)).all() - assert (data.y.X == torch.zeros(3, 4, 5)).all() - assert data.y.v == "test_nested" - data.set(("y", "v"), "test_nested_new") - assert data.y.v == data_nest.v == "test_nested_new" - data_nest.set("v", "test_nested") - assert data_nest.v == data.y.v == "test_nested" - - data.set(("y", ("v",)), "this time another string") - assert data.y.v == data_nest.v == "this time another string" - - # Testing if user can override the type of the attribute - vorig = torch.ones(3, 4, 5) - data.set("v", vorig) - assert (data.v == torch.ones(3, 4, 5)).all() - assert "v" in data._tensordict.keys() - assert "v" not in data._non_tensordict.keys() - - data.set("v", torch.zeros(3, 4, 5), inplace=True) - assert (vorig == 0).all() - with pytest.raises(RuntimeError, match="Cannot update an existing"): - data.set("v", "les chaussettes", inplace=True) - - data.set("v", "test") - assert data.v == "test" - assert "v" not in data._tensordict.keys() - assert "v" in data._non_tensordict.keys() - - with pytest.raises(RuntimeError, match="Cannot update an existing"): - data.set("v", vorig, inplace=True) - - # ensure optional fields are writable - data.set("k", torch.zeros(3, 4, 5)) - - -def test_get(): - @tensorclass - class MyDataNest: - X: torch.Tensor - v: str + with TemporaryDirectory() as tempdir: + tempdir = Path(tempdir) - @tensorclass - class MyDataParent: - X: Tensor - z: TensorDictBase - y: MyDataNest - v: str - k: Optional[Tensor] = None - - batch_size = [3, 4] - X = torch.ones(3, 4, 5) - td = TensorDict({}, batch_size) - data_nest = MyDataNest(X=X, v="test_nested", batch_size=batch_size) - v = "test_tensorclass" - data = MyDataParent(X=X, y=data_nest, z=td, v=v, batch_size=batch_size) - assert isinstance(data.y, type(data_nest)) - assert (data.get("X") == X).all() - assert data.get("batch_size") == torch.Size(batch_size) - assert data.get("v") == v - assert (data.get("z") == td).all() - - # Testing nested tensor class - assert data.get("y")._tensordict is data_nest._tensordict - assert (data.get("y").X == X).all() - assert (data.get(("y", "X")) == X).all() - assert data.get("y").v == "test_nested" - assert data.get(("y", "v")) == "test_nested" - assert data.get("y").batch_size == torch.Size(batch_size) - - # ensure optional fields are there - assert data.get("k") is None - - # ensure default works - assert data.get("foo", "working") == "working" - assert data.get(("foo", "foo2"), "working") == "working" - assert data.get(("X", "foo2"), "working") == "working" - - assert (data.get("X", "working") == X).all() - assert data.get("v", "working") == v - - -def test_tensorclass_set_at_(): - @tensorclass - class MyDataNest: - X: torch.Tensor - v: str + with open(tempdir / "test.pkl", "wb") as f: + pickle.dump(data, f) - @tensorclass - class MyDataParent: - X: Tensor - z: TensorDictBase - y: MyDataNest - v: str - k: Optional[Tensor] = None - - batch_size = [3, 4] - X = torch.ones(3, 4, 5) - td = TensorDict({}, batch_size) - data_nest = MyDataNest(X=X, v="test_nested", batch_size=batch_size) - v = "test_tensorclass" - data = MyDataParent(X=X, y=data_nest, z=td, v=v, batch_size=batch_size) - - data.set_at_("X", 5, slice(2, 3)) - data.set_at_(("y", "X"), 5, slice(2, 3)) - assert 
(data.get_at("X", slice(2, 3)) == 5).all() - assert (data.get_at(("y", "X"), slice(2, 3)) == 5).all() - # assert other not changed - assert (data.get_at("X", slice(0, 2)) == 1).all() - assert (data.get_at(("y", "X"), slice(0, 2)) == 1).all() - assert (data.get_at("X", slice(3, 5)) == 1).all() - assert (data.get_at(("y", "X"), slice(3, 5)) == 1).all() - - -def test_tensorclass_get_at(): - @tensorclass - class MyDataNest: - X: torch.Tensor - v: str + with open(tempdir / "test.pkl", "rb") as f: + data2 = pickle.load(f) - @tensorclass - class MyDataParent: - X: Tensor - z: TensorDictBase - y: MyDataNest - v: str - k: Optional[Tensor] = None + assert_allclose_td(data.to_tensordict(), data2.to_tensordict()) + assert isinstance(data2, MyData) + assert data2.z == data.z - batch_size = [3, 4] - X = torch.ones(3, 4, 5) - td = TensorDict({}, batch_size) - data_nest = MyDataNest(X=X, v="test_nested", batch_size=batch_size) - v = "test_tensorclass" - data = MyDataParent(X=X, y=data_nest, z=td, v=v, batch_size=batch_size) + def test_post_init( + self, + ): + @tensorclass + class MyDataPostInit: + X: torch.Tensor + y: torch.Tensor - assert (data.get("X")[2:3] == data.get_at("X", slice(2, 3))).all() - assert (data.get(("y", "X"))[2:3] == data.get_at(("y", "X"), slice(2, 3))).all() + def __post_init__(self): + assert (self.X > 0).all() + assert self.y.abs().max() <= 10 + self.y = self.y.abs() - # check default - assert data.get_at(("y", "foo"), slice(2, 3), "working") == "working" - assert data.get_at("foo", slice(2, 3), "working") == "working" + y = torch.clamp(torch.randn(3, 4), min=-10, max=10) + data = MyDataPostInit(X=torch.rand(3, 4), y=y, batch_size=[3, 4]) + assert (data.y == y.abs()).all() + # initialising from tensordict is fine + data = MyDataPostInit._from_tensordict( + TensorDict({"X": torch.rand(3, 4), "y": y}, batch_size=[3, 4]) + ) -def test_pre_allocate(): - @tensorclass - class M1: - X: Any + with pytest.raises(AssertionError): + MyDataPostInit(X=-torch.ones(2), y=torch.rand(2), batch_size=[2]) - @tensorclass - class M2: - X: Any + with pytest.raises(AssertionError): + MyDataPostInit._from_tensordict( + TensorDict({"X": -torch.ones(2), "y": torch.rand(2)}, batch_size=[2]) + ) - @tensorclass - class M3: - X: Any + def test_pre_allocate( + self, + ): + @tensorclass + class M1: + X: Any - m1 = M1(M2(M3(X=None, batch_size=[4]), batch_size=[4]), batch_size=[4]) - m2 = M1(M2(M3(X=torch.randn(2), batch_size=[]), batch_size=[]), batch_size=[]) - assert m1.X.X.X is None - m1[0] = m2 - assert (m1[0].X.X.X == m2.X.X.X).all() + @tensorclass + class M2: + X: Any + @tensorclass + class M3: + X: Any -def test_post_init(): - @tensorclass - class MyDataPostInit: - X: torch.Tensor - y: torch.Tensor - - def __post_init__(self): - assert (self.X > 0).all() - assert self.y.abs().max() <= 10 - self.y = self.y.abs() - - y = torch.clamp(torch.randn(3, 4), min=-10, max=10) - data = MyDataPostInit(X=torch.rand(3, 4), y=y, batch_size=[3, 4]) - assert (data.y == y.abs()).all() - - # initialising from tensordict is fine - data = MyDataPostInit._from_tensordict( - TensorDict({"X": torch.rand(3, 4), "y": y}, batch_size=[3, 4]) - ) + m1 = M1(M2(M3(X=None, batch_size=[4]), batch_size=[4]), batch_size=[4]) + m2 = M1(M2(M3(X=torch.randn(2), batch_size=[]), batch_size=[]), batch_size=[]) + assert m1.X.X.X is None + m1[0] = m2 + assert (m1[0].X.X.X == m2.X.X.X).all() - with pytest.raises(AssertionError): - MyDataPostInit(X=-torch.ones(2), y=torch.rand(2), batch_size=[2]) + def test_reshape( + self, + ): + @tensorclass + 
class MyDataNested: + X: torch.Tensor + z: str + y: "MyDataNested" = None + + X = torch.ones(3, 4, 5) + z = "test_tensorclass" + batch_size = [3, 4] + data_nest = MyDataNested(X=X, z=z, batch_size=batch_size) + data = MyDataNested(X=X, y=data_nest, z=z, batch_size=batch_size) + stacked_tc = data.reshape(-1) + assert stacked_tc.X.shape == torch.Size([12, 5]) + assert stacked_tc.y.X.shape == torch.Size([12, 5]) + assert stacked_tc.shape == torch.Size([12]) + assert (stacked_tc.X == 1).all() + assert isinstance(stacked_tc._tensordict, TensorDict) + assert stacked_tc.z == stacked_tc.y.z == z + + def test_set( + self, + ): + @tensorclass + class MyDataNest: + X: torch.Tensor + v: str - with pytest.raises(AssertionError): - MyDataPostInit._from_tensordict( - TensorDict({"X": -torch.ones(2), "y": torch.rand(2)}, batch_size=[2]) + @tensorclass + class MyDataParent: + X: Tensor + z: TensorDictBase + y: MyDataNest + v: str + k: Optional[Tensor] = None + + batch_size = [3, 4] + X = torch.ones(3, 4, 5) + td = TensorDict({}, batch_size) + data_nest = MyDataNest(X=X, v="test_nested", batch_size=batch_size) + data = MyDataParent( + X=X, y=data_nest, z=td, v="test_tensorclass", batch_size=batch_size ) + assert isinstance(data.y, type(data_nest)) + assert data.y._tensordict is data_nest._tensordict + data.set("X", torch.zeros(3, 4, 5)) + assert (data.X == torch.zeros(3, 4, 5)).all() + v_new = "test_bluff" + data.set("v", v_new) + assert data.v == v_new + # check that you can't mess up the batch_size + with pytest.raises( + RuntimeError, + match=re.escape("the tensor smth has shape torch.Size([1]) which"), + ): + data.set("z", TensorDict({"smth": torch.zeros(1)}, [])) + # check that you can't write any attribute + with pytest.raises(AttributeError, match=re.escape("Cannot set the attribute")): + data.set("newattr", TensorDict({"smth": torch.zeros(1)}, [])) + + # Testing nested cases + data_nest.set("X", torch.zeros(3, 4, 5)) + assert (data_nest.X == torch.zeros(3, 4, 5)).all() + assert (data.y.X == torch.zeros(3, 4, 5)).all() + assert data.y.v == "test_nested" + data.set(("y", "v"), "test_nested_new") + assert data.y.v == data_nest.v == "test_nested_new" + data_nest.set("v", "test_nested") + assert data_nest.v == data.y.v == "test_nested" + + data.set(("y", ("v",)), "this time another string") + assert data.y.v == data_nest.v == "this time another string" + + # Testing if user can override the type of the attribute + vorig = torch.ones(3, 4, 5) + data.set("v", vorig) + assert (data.v == torch.ones(3, 4, 5)).all() + assert "v" in data._tensordict.keys() + assert "v" not in data._non_tensordict.keys() + + data.set("v", torch.zeros(3, 4, 5), inplace=True) + assert (vorig == 0).all() + with pytest.raises(RuntimeError, match="Cannot update an existing"): + data.set("v", "les chaussettes", inplace=True) + + data.set("v", "test") + assert data.v == "test" + assert "v" not in data._tensordict.keys() + assert "v" in data._non_tensordict.keys() + + with pytest.raises(RuntimeError, match="Cannot update an existing"): + data.set("v", vorig, inplace=True) + + # ensure optional fields are writable + data.set("k", torch.zeros(3, 4, 5)) + + @pytest.mark.parametrize("any_to_td", [True, False]) + def test_setattr(self, any_to_td): + @tensorclass + class MyDataNest: + X: torch.Tensor + v: str -def test_default(): - @tensorclass - class MyData: - X: torch.Tensor = None # TODO: do we want to allow any default, say an integer? 
- y: torch.Tensor = torch.ones(3, 4, 5) - - data = MyData(batch_size=[3, 4]) - assert (data.y == 1).all() - assert data.X is None - data.X = torch.zeros(3, 4, 1) - assert (data.X == 0).all() - - MyData(batch_size=[3]) - MyData(batch_size=[]) - with pytest.raises(RuntimeError, match="batch dimension mismatch"): - MyData(batch_size=[4]) - - -def test_defaultfactory(): - @tensorclass - class MyData: - X: torch.Tensor = None # TODO: do we want to allow any default, say an integer? - y: torch.Tensor = dataclasses.field(default_factory=lambda: torch.ones(3, 4, 5)) - - data = MyData(batch_size=[3, 4]) - assert (data.y == 1).all() - assert data.X is None - data.X = torch.zeros(3, 4, 1) - assert (data.X == 0).all() - - MyData(batch_size=[3]) - MyData(batch_size=[]) - with pytest.raises(RuntimeError, match="batch dimension mismatch"): - MyData(batch_size=[4]) - - -def test_kjt(): - try: - from torchrec import KeyedJaggedTensor - except ImportError: - pytest.skip("TorchRec not installed.") - - def _get_kjt(): - values = torch.Tensor([1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0]) - weights = torch.Tensor([1.0, 0.5, 1.5, 1.0, 0.5, 1.0, 1.0, 1.5, 1.0, 1.0, 1.0]) - keys = ["index_0", "index_1", "index_2"] - offsets = torch.IntTensor([0, 2, 2, 3, 4, 5, 8, 9, 10, 11]) - - jag_tensor = KeyedJaggedTensor( - values=values, - keys=keys, - offsets=offsets, - weights=weights, + @tensorclass + class MyDataParent: + W: Any + X: Tensor + z: TensorDictBase + y: MyDataNest + v: Any + k: Optional[Tensor] = None + + batch_size = [3, 4] + if any_to_td: + W = TensorDict({}, batch_size) + else: + W = torch.zeros(*batch_size, 1) + X = torch.ones(3, 4, 5) + td = TensorDict({}, batch_size) + data_nest = MyDataNest(X=X, v="test_nested", batch_size=batch_size) + data = MyDataParent( + X=X, y=data_nest, z=td, W=W, v="test_tensorclass", batch_size=batch_size + ) + assert isinstance(data.y, type(data_nest)) + assert data.y._tensordict is data_nest._tensordict + data.X = torch.zeros(3, 4, 5) + assert (data.X == torch.zeros(3, 4, 5)).all() + v_new = "test_bluff" + data.v = v_new + assert data.v == v_new + # check that you can't mess up the batch_size + with pytest.raises( + RuntimeError, + match=re.escape("the tensor smth has shape torch.Size([1]) which"), + ): + data.z = TensorDict({"smth": torch.zeros(1)}, []) + # check that you can't write any attribute + with pytest.raises(AttributeError, match=re.escape("Cannot set the attribute")): + data.newattr = TensorDict({"smth": torch.zeros(1)}, []) + # Testing nested cases + data_nest.X = torch.zeros(3, 4, 5) + assert (data_nest.X == torch.zeros(3, 4, 5)).all() + assert (data.y.X == torch.zeros(3, 4, 5)).all() + assert data.y.v == "test_nested" + data.y.v = "test_nested_new" + assert data.y.v == data_nest.v == "test_nested_new" + data_nest.v = "test_nested" + assert data_nest.v == data.y.v == "test_nested" + + # Testing if user can override the type of the attribute + data.v = torch.ones(3, 4, 5) + assert (data.v == torch.ones(3, 4, 5)).all() + assert "v" in data._tensordict.keys() + assert "v" not in data._non_tensordict.keys() + + data.v = "test" + assert data.v == "test" + assert "v" not in data._tensordict.keys() + assert "v" in data._non_tensordict.keys() + + # ensure optional fields are writable + data.k = torch.zeros(3, 4, 5) + + def test_setitem( + self, + ): + data = MyData( + X=torch.ones(3, 4, 5), + y=torch.zeros(3, 4, 5), + z="test_tensorclass", + batch_size=[3, 4], ) - return jag_tensor - - kjt = _get_kjt() - - @tensorclass - class MyData: - X: torch.Tensor - 
y: KeyedJaggedTensor
-        z: str
-
-    z = "test_tensorclass"
-    data = MyData(X=torch.zeros(3, 1), y=kjt, z=z, batch_size=[3])
-    subdata = data[:2]
-    assert (
-        subdata.y["index_0"].to_padded_dense() == torch.tensor([[1.0, 2.0], [0.0, 0.0]])
-    ).all()
-
-    subdata = data[[0, 2]]
-    assert (
-        subdata.y["index_0"].to_padded_dense() == torch.tensor([[1.0, 2.0], [3.0, 0.0]])
-    ).all()
-    assert subdata.z == data.z == z
-
-
-def test_pickle():
-    data = MyData(
-        X=torch.ones(3, 4, 5),
-        y=torch.zeros(3, 4, 5, dtype=torch.bool),
-        z="test_tensorclass",
-        batch_size=[3, 4],
-    )
-
-    with TemporaryDirectory() as tempdir:
-        tempdir = Path(tempdir)
-
-        with open(tempdir / "test.pkl", "wb") as f:
-            pickle.dump(data, f)
-
-        with open(tempdir / "test.pkl", "rb") as f:
-            data2 = pickle.load(f)
-
-        assert_allclose_td(data.to_tensordict(), data2.to_tensordict())
-        assert isinstance(data2, MyData)
-        assert data2.z == data.z
-
-def _make_data(shape):
-    return MyData(
-        X=torch.rand(*shape),
-        y=torch.rand(*shape),
-        z="test_tensorclass",
-        batch_size=shape[:1],
+        x = torch.randn(3, 4, 5)
+        y = torch.ones(3, 4, 5)
+        z = "test_tensorclass"
+        batch_size = [3, 4]
+        data2 = MyData(X=x, y=y, z=z, batch_size=batch_size)
+        data3 = MyData(X=y, y=x, z=z, batch_size=batch_size)
+
+        # Testing the data before setting
+        assert (data[:2].X == torch.ones(2, 4, 5)).all()
+        assert (data[:2].y == torch.zeros(2, 4, 5)).all()
+        assert data[:2].z == "test_tensorclass"
+        assert (data[[1, 2]].X == torch.ones(5)).all()
+
+        # Setting the item and testing the data after the assignment
+        data[:2] = data2[:2].clone()
+        assert (data[:2].X == data2[:2].X).all()
+        assert (data[:2].y == data2[:2].y).all()
+        assert data[:2].z == z
+
+        data[[1, 2]] = data3[[1, 2]].clone()
+        assert (data[[1, 2]].X == data3[[1, 2]].X).all()
+        assert (data[[1, 2]].y == data3[[1, 2]].y).all()
+        assert data[[1, 2]].z == z
+
+        data[:, [1, 2]] = data2[:, [1, 2]].clone()
+        assert (data[:, [1, 2]].X == data2[:, [1, 2]].X).all()
+        assert (data[:, [1, 2]].y == data[:, [1, 2]].y).all()
+        assert data[:, [1, 2]].z == z
+
+        with pytest.raises(
+            RuntimeError, match="indexed destination TensorDict batch size is"
+        ):
+            data[:, [1, 2]] = data.clone()
+
+        # Negative test case for non-tensor data
+        z = "test_bluff"
+        data2 = MyData(X=x, y=y, z=z, batch_size=batch_size)
+        with pytest.warns(
+            UserWarning,
+            match="Meta data at 'z' may or may not be equal, this may result in undefined behaviours",
+        ):
+            data[1] = data2[1]
+
+        # Validating nested test cases
+        @tensorclass
+        class MyDataNested:
+            X: torch.Tensor
+            z: list
+            y: "MyDataNested" = None
+
+        X = torch.randn(3, 4, 5)
+        z = ["a", "b", "c"]
+        batch_size = [3, 4]
+        data_nest = MyDataNested(X=X, z=z, batch_size=batch_size)
+        data = MyDataNested(X=X, y=data_nest, z=z, batch_size=batch_size)
+        X2 = torch.ones(3, 4, 5)
+        data_nest2 = MyDataNested(X=X2, z=z, batch_size=batch_size)
+        data2 = MyDataNested(X=X2, y=data_nest2, z=z, batch_size=batch_size)
+        data[:2] = data2[:2].clone()
+        assert (data[:2].X == data2[:2].X).all()
+        assert (data[:2].y.X == data2[:2].y.X).all()
+        assert data[:2].z == z
+
+        # Negative scenario
+        data3 = MyDataNested(X=X2, y=data_nest2, z=["e", "f"], batch_size=batch_size)
+        with pytest.warns(
+            UserWarning,
+            match="Meta data at 'z' may or may not be equal, this may result in undefined behaviours",
+        ):
+            data[:2] = data3[:2]
+
+    @pytest.mark.parametrize(
+        "broadcast_type",
+        ["scalar", "tensor", "tensordict", "maptensor"],
    )
+    def test_setitem_broadcast(self, broadcast_type):
+        @tensorclass
+        class MyDataNested:
+            X: torch.Tensor
+            z: list
+            y: "MyDataNested" = None
+
+        X = torch.ones(3, 4, 5)
+        z = ["a", "b", "c"]
+        batch_size = [3, 4]
+        data_nest = MyDataNested(X=X, z=z, batch_size=batch_size)
+        data = MyDataNested(X=X, y=data_nest, z=z, batch_size=batch_size)
+
+        if broadcast_type == "scalar":
+            val = 0
+        elif broadcast_type == "tensor":
+            val = torch.zeros(4, 5)
+        elif broadcast_type == "tensordict":
+            val = TensorDict({"X": torch.zeros(2, 4, 5)}, batch_size=[2, 4])
+        elif broadcast_type == "maptensor":
+            val = MemoryMappedTensor.from_tensor(torch.zeros(4, 5))
+
+        data[:2] = val
+        assert (data[:2] == 0).all()
+        assert (data.X[:2] == 0).all()
+        assert (data.y.X[:2] == 0).all()
+
+    def test_setitem_memmap(
+        self,
+    ):
+        # regression test PR #203
+        # We should be able to set tensor items with MemoryMappedTensors and vice versa
+        @tensorclass
+        class MyDataMemMap1:
+            x: torch.Tensor
+            y: MemoryMappedTensor

+        data1 = MyDataMemMap1(
+            x=torch.zeros(3, 4, 5),
+            y=MemoryMappedTensor.from_tensor(torch.zeros(3, 4, 5)),
+            batch_size=[3, 4],
+        )

-def test_multiprocessing():
-    with Pool(os.cpu_count()) as p:
-        catted = torch.cat(p.map(_make_data, [(i, 2) for i in range(1, 9)]), dim=0)
-
-    assert catted.batch_size == torch.Size([36])
-    assert catted.z == "test_tensorclass"
-
-
-@pytest.mark.skipif(
-    not _has_torchsnapshot, reason=f"torchsnapshot not found: err={TORCHSNAPSHOT_ERR}"
-)
-def test_torchsnapshot(tmp_path):
-    @tensorclass
-    class MyClass:
-        x: torch.Tensor
-        z: str
-        y: "MyClass" = None
-
-    z = "test_tensorclass"
-    tc = MyClass(
-        x=torch.randn(3),
-        z=z,
-        y=MyClass(x=torch.randn(3), z=z, batch_size=[]),
-        batch_size=[],
-    )
-    tc.memmap_()
-    assert isinstance(tc.y.x, MemoryMappedTensor)
-    assert tc.z == z
-
-    app_state = {
-        "state": torchsnapshot.StateDict(tensordict=tc.state_dict(keep_vars=True))
-    }
-    snapshot = torchsnapshot.Snapshot.take(app_state=app_state, path=str(tmp_path))
-
-    tc_dest = MyClass(
-        x=torch.randn(3),
-        z="other",
-        y=MyClass(x=torch.randn(3), z=z, batch_size=[]),
-        batch_size=[],
-    )
-    tc_dest.memmap_()
-    assert isinstance(tc_dest.y.x, MemoryMappedTensor)
-    app_state = {
-        "state": torchsnapshot.StateDict(tensordict=tc_dest.state_dict(keep_vars=True))
-    }
-    snapshot.restore(app_state=app_state)
-
-    assert (tc_dest == tc).all()
-    assert tc_dest.y.batch_size == tc.y.batch_size
-    assert isinstance(tc_dest.y.x, MemoryMappedTensor)
-    # torchsnapshot does not support updating strings and such
-    assert tc_dest.z != z
-
-    tc_dest = MyClass(
-        x=torch.randn(3),
-        z="other",
-        y=MyClass(x=torch.randn(3), z=z, batch_size=[]),
-        batch_size=[],
-    )
-    tc_dest.memmap_()
-    tc_dest.load_state_dict(tc.state_dict())
-    assert (tc_dest == tc).all()
-    assert tc_dest.y.batch_size == tc.y.batch_size
-    assert isinstance(tc_dest.y.x, MemoryMappedTensor)
-    # load_state_dict outperforms snapshot in this case
-    assert tc_dest.z == z

+        data2 = MyDataMemMap1(
+            x=MemoryMappedTensor.from_tensor(torch.ones(3, 4, 5)),
+            y=torch.ones(3, 4, 5),
+            batch_size=[3, 4],
+        )

+        data1[:2] = data2[:2]
+        assert (data1[:2] == 1).all()
+        assert (data1.x[:2] == 1).all()
+        assert (data1.y[:2] == 1).all()
+        data2[2:] = data1[2:]
+        assert (data2[2:] == 0).all()
+        assert (data2.x[2:] == 0).all()
+        assert (data2.y[2:] == 0).all()
+
+    def test_setitem_other_cls(
+        self,
+    ):
+        @tensorclass
+        class MyData1:
+            x: torch.Tensor
+            y: MemoryMappedTensor

-def test_statedict_errors():
-    @tensorclass
-    class MyClass:
-        x: torch.Tensor
-        z: str
-        y: "MyClass" = None
-
-    z = "test_tensorclass"
-    tc = MyClass(
-        x=torch.randn(3),
-        z=z,
-        y=MyClass(x=torch.randn(3), z=z, batch_size=[]),
-        batch_size=[],
-    )
+        data1 = MyData1(
+            x=torch.zeros(3, 4, 5),
+            y=MemoryMappedTensor.from_tensor(torch.zeros(3, 4, 5)),
+            batch_size=[3, 4],
+        )

-    sd = tc.state_dict()
-    sd["a"] = None
-    with pytest.raises(KeyError, match="Key 'a' wasn't expected in the state-dict"):
-        tc.load_state_dict(sd)
-    del sd["a"]
-    sd["_tensordict"]["a"] = None
-    with pytest.raises(KeyError, match="Key 'a' wasn't expected in the state-dict"):
-        tc.load_state_dict(sd)
-    del sd["_tensordict"]["a"]
-    sd["_non_tensordict"]["a"] = None
-    with pytest.raises(KeyError, match="Key 'a' wasn't expected in the state-dict"):
-        tc.load_state_dict(sd)
-    del sd["_non_tensordict"]["a"]
-    sd["_tensordict"]["y"]["_tensordict"]["a"] = None
-    with pytest.raises(KeyError, match="Key 'a' wasn't expected in the state-dict"):
-        tc.load_state_dict(sd)
-
-
-def test_equal():
-    @tensorclass
-    class MyClass1:
-        x: torch.Tensor
-        z: str
-        y: "MyClass1" = None

+        # Setting an item should work with another compatible tensorclass
+        @tensorclass
+        class MyData2:
+            x: MemoryMappedTensor
+            y: torch.Tensor

-    @tensorclass
-    class MyClass2:
-        x: torch.Tensor
-        z: str
-        y: "MyClass2" = None
-
-    a = MyClass1(
-        torch.zeros(3),
-        "z0",
-        MyClass1(
-            torch.ones(3),
-            "z1",
-            None,
-            batch_size=[3],
-        ),
-        batch_size=[3],
-    )
-    b = MyClass2(
-        torch.zeros(3),
-        "z0",
-        MyClass2(
-            torch.ones(3),
-            "z1",
-            None,
-            batch_size=[3],
-        ),
-        batch_size=[3],
-    )
-    c = TensorDict({"x": torch.zeros(3), "y": {"x": torch.ones(3)}}, batch_size=[3])

+        data_other_cls = MyData2(
+            x=MemoryMappedTensor.from_tensor(torch.ones(3, 4, 5)),
+            y=torch.ones(3, 4, 5),
+            batch_size=[3, 4],
+        )
+        data1[:2] = data_other_cls[:2]
+        data_other_cls[2:] = data1[2:]

-    assert (a == a.clone()).all()
-    assert (a != 1.0).any()
-    assert (a[:2] != 1.0).any()

+        # Setting an item should raise if the other tensorclass has different members
+        @tensorclass
+        class MyData3:
+            x: MemoryMappedTensor
+            z: torch.Tensor

-    assert (a.y == 1).all()
-    assert (a[:2].y == 1).all()
-    assert (a.y[:2] == 1).all()

+        data_wrong_cls = MyData3(
+            x=MemoryMappedTensor.from_tensor(torch.ones(3, 4, 5)),
+            z=torch.ones(3, 4, 5),
+            batch_size=[3, 4],
+        )
+        with pytest.raises(
+            ValueError,
+            match="__setitem__ is only allowed for same-class or compatible class .* assignment",
+        ):
+            data1[:2] = data_wrong_cls[:2]
+        with pytest.raises(
+            ValueError,
+            match="__setitem__ is only allowed for same-class or compatible class .* assignment",
+        ):
+            data_wrong_cls[2:] = data1[2:]
+
+    def test_signature(
+        self,
+    ):
+        sig = inspect.signature(MyData)
+        assert list(sig.parameters) == ["X", "y", "z", "batch_size", "device", "names"]

-    assert (a != torch.ones([])).any()
-    assert (a.y == torch.ones([])).all()

+        with pytest.raises(TypeError, match="missing 3 required positional arguments"):
+            MyData(batch_size=[10])

-    assert (a == b).all()
-    assert (b == a).all()
-    assert (b[:2] == a[:2]).all()

+        with pytest.raises(TypeError, match="missing 2 required positional argument"):
+            MyData(X=torch.rand(10), batch_size=[10])

-    assert (a == c).all()
-    assert (a[:2] == c[:2]).all()

+        with pytest.raises(TypeError, match="missing 1 required positional argument"):
+            MyData(X=torch.rand(10), y=torch.rand(10), batch_size=[10], device="cpu")

-    assert (c == a).all()
-    assert (c[:2] == a[:2]).all()

+        # if all positional arguments are specified, omitting batch_size raises an error
+        with pytest.raises(
+            TypeError, match="missing 1 required keyword-only argument: 'batch_size'"
+        ):
+            MyData(X=torch.rand(10), y=torch.rand(10))

-    assert (a 
!= c.clone().zero_()).any() - assert (c != a.clone().zero_()).any() + # all positional arguments + batch_size is fine + MyData( + X=torch.rand(10), y=torch.rand(10), z="test_tensorclass", batch_size=[10] + ) + def test_split( + self, + ): + @tensorclass + class MyDataNested: + X: torch.Tensor + z: str + y: "MyDataNested" = None + + X = torch.ones(3, 6, 5) + z = "test_tensorclass" + batch_size = [3, 6] + data_nest = MyDataNested(X=X, z=z, batch_size=batch_size) + data = MyData(X=X, y=data_nest, z=z, batch_size=batch_size) + split_tcs = torch.split(data, split_size_or_sections=[3, 2, 1], dim=1) + assert type(split_tcs[1]) is type(data) + assert split_tcs[0].batch_size == torch.Size([3, 3]) + assert split_tcs[1].batch_size == torch.Size([3, 2]) + assert split_tcs[2].batch_size == torch.Size([3, 1]) + + assert split_tcs[0].y.batch_size == torch.Size([3, 3]) + assert split_tcs[1].y.batch_size == torch.Size([3, 2]) + assert split_tcs[2].y.batch_size == torch.Size([3, 1]) + + assert torch.all(torch.eq(split_tcs[0].X, torch.ones(3, 3, 5))) + assert torch.all(torch.eq(split_tcs[0].y[0].X, torch.ones(3, 3, 5))) + assert split_tcs[0].z == split_tcs[1].z == split_tcs[2].z == z + assert split_tcs[0].y[0].z == split_tcs[0].y[1].z == split_tcs[0].y[2].z == z + + def test_squeeze( + self, + ): + @tensorclass + class MyDataNested: + X: torch.Tensor + z: str + y: "MyDataNested" = None + + X = torch.ones(1, 4, 5) + z = "test_tensorclass" + batch_size = [1, 4] + data_nest = MyDataNested(X=X, z=z, batch_size=batch_size) + data = MyDataNested(X=X, y=data_nest, z=z, batch_size=batch_size) + squeeze_tc = torch.squeeze(data) + assert squeeze_tc.batch_size == torch.Size([4]) + assert squeeze_tc.X.shape == torch.Size([4, 5]) + assert squeeze_tc.y.X.shape == torch.Size([4, 5]) + assert squeeze_tc.z == squeeze_tc.y.z == z + + def test_stack( + self, + ): + @tensorclass + class MyDataNested: + X: torch.Tensor + z: str + y: "MyDataNested" = None + + X = torch.ones(3, 4, 5) + z = "test_tensorclass" + batch_size = [3, 4] + data_nest = MyDataNested(X=X, z=z, batch_size=batch_size) + data1 = MyDataNested(X=X, y=data_nest, z=z, batch_size=batch_size) + data2 = MyDataNested(X=X, y=data_nest, z=z, batch_size=batch_size) + + stacked_tc = torch.stack([data1, data2], 0) + assert type(stacked_tc) is type(data1) + assert isinstance(stacked_tc.y, type(data1.y)) + assert stacked_tc.X.shape == torch.Size([2, 3, 4, 5]) + assert stacked_tc.y.X.shape == torch.Size([2, 3, 4, 5]) + assert (stacked_tc.X == 1).all() + assert (stacked_tc.y.X == 1).all() + assert isinstance(stacked_tc._tensordict, LazyStackedTensorDict) + assert isinstance(stacked_tc.y._tensordict, LazyStackedTensorDict) + assert stacked_tc.z == stacked_tc.y.z == z + + # Testing negative scenarios + y = torch.zeros(3, 4, 5, dtype=torch.bool) + data3 = MyData(X=X, y=y, z=z, batch_size=batch_size) + + with pytest.raises( + TypeError, + match=("Multiple dispatch failed|no implementation found"), + ): + torch.stack([data1, data3], dim=0) + + def test_statedict_errors( + self, + ): + @tensorclass + class MyClass: + x: torch.Tensor + z: str + y: "MyClass" = None -def test_all_any(): - @tensorclass - class MyClass1: - x: torch.Tensor - z: str - y: "MyClass1" = None - - # with all 0 - x = MyClass1( - torch.zeros(3, 1), - "z", - MyClass1(torch.zeros(3, 1), "z", batch_size=[3, 1]), - batch_size=[3, 1], - ) - assert not x.all() - assert not x.any() - assert isinstance(x.all(), bool) - assert isinstance(x.any(), bool) - for dim in [0, 1, -1, -2]: - assert isinstance(x.all(dim=dim), 
MyClass1) - assert isinstance(x.any(dim=dim), MyClass1) - assert not x.all(dim=dim).all() - assert not x.any(dim=dim).any() - # with all 1 - x = x.apply(lambda x: x.fill_(1.0)) - assert isinstance(x, MyClass1) - assert x.all() - assert x.any() - assert isinstance(x.all(), bool) - assert isinstance(x.any(), bool) - for dim in [0, 1]: - assert isinstance(x.all(dim=dim), MyClass1) - assert isinstance(x.any(dim=dim), MyClass1) - assert x.all(dim=dim).all() - assert x.any(dim=dim).any() - - # with 0 and 1 - x.y.x.fill_(0.0) - assert not x.all() - assert x.any() - assert isinstance(x.all(), bool) - assert isinstance(x.any(), bool) - for dim in [0, 1]: - assert isinstance(x.all(dim=dim), MyClass1) - assert isinstance(x.any(dim=dim), MyClass1) - assert not x.all(dim=dim).all() - assert x.any(dim=dim).any() - - assert not x.y.all() - assert not x.y.any() - - -@pytest.mark.parametrize("from_torch", [True, False]) -def test_gather(from_torch): - @tensorclass - class MyClass: - x: torch.Tensor - z: str - y: "MyClass" = None - - c = MyClass( - torch.randn(3, 4), - "foo", - MyClass(torch.randn(3, 4, 5), "bar", None, batch_size=[3, 4, 5]), - batch_size=[3, 4], - ) - dim = -1 - index = torch.arange(3).expand(3, 3) - if from_torch: - c_gather = torch.gather(c, index=index, dim=dim) - else: - c_gather = c.gather(index=index, dim=dim) - assert c_gather.x.shape == torch.Size([3, 3]) - assert c_gather.y.shape == torch.Size([3, 3, 5]) - assert c_gather.y.x.shape == torch.Size([3, 3, 5]) - assert c_gather.y.z == "bar" - assert c_gather.z == "foo" - c_gather_zero = c_gather.clone().zero_() - if from_torch: - c_gather2 = torch.gather(c, index=index, dim=dim, out=c_gather_zero) - else: - c_gather2 = c.gather(index=index, dim=dim, out=c_gather_zero) - - assert (c_gather2 == c_gather).all() - - -def test_to_tensordict(): - @tensorclass - class MyClass: - x: torch.Tensor - z: str - y: "MyClass" = None - - c = MyClass( - torch.randn(3, 4), - "foo", - MyClass(torch.randn(3, 4, 5), "bar", None, batch_size=[3, 4, 5]), - batch_size=[3, 4], - ) + z = "test_tensorclass" + tc = MyClass( + x=torch.randn(3), + z=z, + y=MyClass(x=torch.randn(3), z=z, batch_size=[]), + batch_size=[], + ) - ctd = c.to_tensordict() - assert isinstance(ctd, TensorDictBase) - assert "x" in ctd.keys() - assert "z" not in ctd.keys() - assert "y" in ctd.keys() - assert ("y", "x") in ctd.keys(True) + sd = tc.state_dict() + sd["a"] = None + with pytest.raises(KeyError, match="Key 'a' wasn't expected in the state-dict"): + tc.load_state_dict(sd) + del sd["a"] + sd["_tensordict"]["a"] = None + with pytest.raises(KeyError, match="Key 'a' wasn't expected in the state-dict"): + tc.load_state_dict(sd) + del sd["_tensordict"]["a"] + sd["_non_tensordict"]["a"] = None + with pytest.raises(KeyError, match="Key 'a' wasn't expected in the state-dict"): + tc.load_state_dict(sd) + del sd["_non_tensordict"]["a"] + sd["_tensordict"]["y"]["_tensordict"]["a"] = None + with pytest.raises(KeyError, match="Key 'a' wasn't expected in the state-dict"): + tc.load_state_dict(sd) + + def test_tensorclass_get_at( + self, + ): + @tensorclass + class MyDataNest: + X: torch.Tensor + v: str + @tensorclass + class MyDataParent: + X: Tensor + z: TensorDictBase + y: MyDataNest + v: str + k: Optional[Tensor] = None + + batch_size = [3, 4] + X = torch.ones(3, 4, 5) + td = TensorDict({}, batch_size) + data_nest = MyDataNest(X=X, v="test_nested", batch_size=batch_size) + v = "test_tensorclass" + data = MyDataParent(X=X, y=data_nest, z=td, v=v, batch_size=batch_size) + + assert 
(data.get("X")[2:3] == data.get_at("X", slice(2, 3))).all()
+        assert (data.get(("y", "X"))[2:3] == data.get_at(("y", "X"), slice(2, 3))).all()
+
+        # check default
+        assert data.get_at(("y", "foo"), slice(2, 3), "working") == "working"
+        assert data.get_at("foo", slice(2, 3), "working") == "working"
+
+    def test_tensorclass_set_at_(
+        self,
+    ):
+        @tensorclass
+        class MyDataNest:
+            X: torch.Tensor
+            v: str

+        @tensorclass
+        class MyDataParent:
+            X: Tensor
+            z: TensorDictBase
+            y: MyDataNest
+            v: str
+            k: Optional[Tensor] = None
+
+        batch_size = [3, 4]
+        X = torch.ones(3, 4, 5)
+        td = TensorDict({}, batch_size)
+        data_nest = MyDataNest(X=X, v="test_nested", batch_size=batch_size)
+        v = "test_tensorclass"
+        data = MyDataParent(X=X, y=data_nest, z=td, v=v, batch_size=batch_size)
+
+        data.set_at_("X", 5, slice(2, 3))
+        data.set_at_(("y", "X"), 5, slice(2, 3))
+        assert (data.get_at("X", slice(2, 3)) == 5).all()
+        assert (data.get_at(("y", "X"), slice(2, 3)) == 5).all()
+        # assert the other entries are not changed
+        assert (data.get_at("X", slice(0, 2)) == 1).all()
+        assert (data.get_at(("y", "X"), slice(0, 2)) == 1).all()
+        assert (data.get_at("X", slice(3, 5)) == 1).all()
+        assert (data.get_at(("y", "X"), slice(3, 5)) == 1).all()
+
+    def test_to_tensordict(
+        self,
+    ):
        @tensorclass
        class MyClass:
            x: torch.Tensor
@@ -1606,34 +1494,157 @@ class MyClass:
            batch_size=[3, 4],
        )

-        cmemmap = c.memmap_()
-        assert cmemmap is c
-        assert isinstance(c.x, MemoryMappedTensor)
-        assert isinstance(c.y.x, MemoryMappedTensor)
-        assert c.z == "foo"
+        ctd = c.to_tensordict()
+        assert isinstance(ctd, TensorDictBase)
+        assert "x" in ctd.keys()
+        assert "z" not in ctd.keys()
+        assert "y" in ctd.keys()
+        assert ("y", "x") in ctd.keys(True)

-    def test_memmap_like(self):
+    @pytest.mark.skipif(
+        not _has_torchsnapshot,
+        reason=f"torchsnapshot not found: err={TORCHSNAPSHOT_ERR}",
+    )
+    def test_torchsnapshot(self, tmp_path):
        @tensorclass
        class MyClass:
            x: torch.Tensor
            z: str
            y: "MyClass" = None

-        c = MyClass(
-            torch.randn(3, 4),
-            "foo",
-            MyClass(torch.randn(3, 4, 5), "bar", None, batch_size=[3, 4, 5]),
+        z = "test_tensorclass"
+        tc = MyClass(
+            x=torch.randn(3),
+            z=z,
+            y=MyClass(x=torch.randn(3), z=z, batch_size=[]),
+            batch_size=[],
+        )
+        tc.memmap_()
+        assert isinstance(tc.y.x, MemoryMappedTensor)
+        assert tc.z == z
+
+        app_state = {
+            "state": torchsnapshot.StateDict(tensordict=tc.state_dict(keep_vars=True))
+        }
+        snapshot = torchsnapshot.Snapshot.take(app_state=app_state, path=str(tmp_path))
+
+        tc_dest = MyClass(
+            x=torch.randn(3),
+            z="other",
+            y=MyClass(x=torch.randn(3), z=z, batch_size=[]),
+            batch_size=[],
+        )
+        tc_dest.memmap_()
+        assert isinstance(tc_dest.y.x, MemoryMappedTensor)
+        app_state = {
+            "state": torchsnapshot.StateDict(
+                tensordict=tc_dest.state_dict(keep_vars=True)
+            )
+        }
+        snapshot.restore(app_state=app_state)
+
+        assert (tc_dest == tc).all()
+        assert tc_dest.y.batch_size == tc.y.batch_size
+        assert isinstance(tc_dest.y.x, MemoryMappedTensor)
+        # torchsnapshot does not support updating strings and such
+        assert tc_dest.z != z
+
+        tc_dest = MyClass(
+            x=torch.randn(3),
+            z="other",
+            y=MyClass(x=torch.randn(3), z=z, batch_size=[]),
+            batch_size=[],
+        )
+        tc_dest.memmap_()
+        tc_dest.load_state_dict(tc.state_dict())
+        assert (tc_dest == tc).all()
+        assert tc_dest.y.batch_size == tc.y.batch_size
+        assert isinstance(tc_dest.y.x, MemoryMappedTensor)
+        # load_state_dict outperforms snapshot in this case
+        assert tc_dest.z == z
+
+    def test_type(
+        self,
+    ):
+        data = 
MyData( + X=torch.ones(3, 4, 5), + y=torch.zeros(3, 4, 5, dtype=torch.bool), + z="test_tensorclass", batch_size=[3, 4], ) + assert isinstance(data, MyData) + assert is_tensorclass(data) + assert is_tensorclass(MyData) + # we get an instance of the user defined class, not a dynamically defined subclass + assert type(data) is MyDataUndecorated + + def test_unbind( + self, + ): + @tensorclass + class MyDataNested: + X: torch.Tensor + z: str + y: "MyDataNested" = None + + X = torch.ones(3, 4, 5) + z = "test_tensorclass" + batch_size = [3, 4] + data_nest = MyDataNested(X=X, z=z, batch_size=batch_size) + data = MyDataNested(X=X, y=data_nest, z=z, batch_size=batch_size) + unbind_tcs = torch.unbind(data, 0) + assert type(unbind_tcs[1]) is type(data) + assert type(unbind_tcs[0].y[0]) is type(data) + assert len(unbind_tcs) == 3 + assert torch.all(torch.eq(unbind_tcs[0].X, torch.ones(4, 5))) + assert torch.all(torch.eq(unbind_tcs[0].y[0].X, torch.ones(4, 5))) + assert unbind_tcs[0].batch_size == torch.Size([4]) + assert unbind_tcs[0].z == unbind_tcs[1].z == unbind_tcs[2].z == z + + def test_unsqueeze( + self, + ): + @tensorclass + class MyDataNested: + X: torch.Tensor + z: str + y: "MyDataNested" = None + + X = torch.ones(3, 4, 5) + z = "test_tensorclass" + batch_size = [3, 4] + data_nest = MyDataNested(X=X, z=z, batch_size=batch_size) + data = MyDataNested(X=X, y=data_nest, z=z, batch_size=batch_size) + unsqueeze_tc = torch.unsqueeze(data, dim=1) + assert unsqueeze_tc.batch_size == torch.Size([3, 1, 4]) + assert unsqueeze_tc.X.shape == torch.Size([3, 1, 4, 5]) + assert unsqueeze_tc.y.X.shape == torch.Size([3, 1, 4, 5]) + assert unsqueeze_tc.z == unsqueeze_tc.y.z == z + + def test_view( + self, + ): + @tensorclass + class MyDataNested: + X: torch.Tensor + z: str + y: "MyDataNested" = None + + X = torch.ones(3, 4, 5) + z = "test_tensorclass" + batch_size = [3, 4] + data_nest = MyDataNested(X=X, z=z, batch_size=batch_size) + data = MyDataNested(X=X, y=data_nest, z=z, batch_size=batch_size) + stacked_tc = data.view(-1) + assert stacked_tc.X.shape == torch.Size([12, 5]) + assert stacked_tc.y.X.shape == torch.Size([12, 5]) + assert stacked_tc.shape == torch.Size([12]) + assert (stacked_tc.X == 1).all() + assert isinstance(stacked_tc._tensordict, _ViewedTensorDict) + assert stacked_tc.z == stacked_tc.y.z == z - cmemmap = c.memmap_like() - assert cmemmap is not c - assert cmemmap.y is not c.y - assert (cmemmap == 0).all() - assert isinstance(cmemmap.x, MemoryMappedTensor) - assert isinstance(cmemmap.y.x, MemoryMappedTensor) - assert cmemmap.z == "foo" +class TestMemmap: def test_from_memmap(self, tmpdir): td = TensorDict( { @@ -1683,25 +1694,47 @@ class MyOtherClass: data3 = MyOtherClass.load_memmap(tmpdir) assert isinstance(data3, MyClass) + def test_memmap_(self): + @tensorclass + class MyClass: + x: torch.Tensor + z: str + y: "MyClass" = None + + c = MyClass( + torch.randn(3, 4), + "foo", + MyClass(torch.randn(3, 4, 5), "bar", None, batch_size=[3, 4, 5]), + batch_size=[3, 4], + ) -def test_from_dict(): - td = TensorDict( - { - ("a", "b", "c"): 1, - ("a", "d"): 2, - }, - [], - ).expand(10) - d = td.to_dict() + cmemmap = c.memmap_() + assert cmemmap is c + assert isinstance(c.x, MemoryMappedTensor) + assert isinstance(c.y.x, MemoryMappedTensor) + assert c.z == "foo" - @tensorclass - class MyClass: - a: TensorDictBase + def test_memmap_like(self): + @tensorclass + class MyClass: + x: torch.Tensor + z: str + y: "MyClass" = None - tc = MyClass.from_dict(d) - assert isinstance(tc, MyClass) - assert 
isinstance(tc.a, TensorDict) - assert tc.batch_size == torch.Size([10]) + c = MyClass( + torch.randn(3, 4), + "foo", + MyClass(torch.randn(3, 4, 5), "bar", None, batch_size=[3, 4, 5]), + batch_size=[3, 4], + ) + + cmemmap = c.memmap_like() + assert cmemmap is not c + assert cmemmap.y is not c.y + assert (cmemmap == 0).all() + assert isinstance(cmemmap.x, MemoryMappedTensor) + assert isinstance(cmemmap.y.x, MemoryMappedTensor) + assert cmemmap.z == "foo" class TestNesting: @@ -1720,28 +1753,28 @@ def get_nested(self): ) return td - def test_to(self): + def test_apply(self): td = self.get_nested() - td = td.to("cpu:1") + td = td.apply(lambda x: x + 1) + assert isinstance(td.get("c")[0], self.TensorClass) + + def test_chunk(self): + td = self.get_nested() + td, _ = td.chunk(2, dim=0) assert isinstance(td.get("c")[0], self.TensorClass) def test_idx(self): td = self.get_nested()[0] assert isinstance(td.get("c"), self.TensorClass) - def test_apply(self): - td = self.get_nested() - td = td.apply(lambda x: x + 1) - assert isinstance(td.get("c")[0], self.TensorClass) - def test_split(self): td = self.get_nested() td, _ = td.split([2, 1], dim=0) assert isinstance(td.get("c")[0], self.TensorClass) - def test_chunk(self): + def test_to(self): td = self.get_nested() - td, _ = td.chunk(2, dim=0) + td = td.to("cpu:1") assert isinstance(td.get("c")[0], self.TensorClass) diff --git a/test/test_tensordict.py b/test/test_tensordict.py index 5c8fb40f5..1edff3c8d 100644 --- a/test/test_tensordict.py +++ b/test/test_tensordict.py @@ -66,1050 +66,1610 @@ _IS_OSX = platform.system() == "Darwin" +TD_BATCH_SIZE = 4 -@pytest.mark.parametrize("device", get_available_devices()) -def test_tensordict_set(device): - torch.manual_seed(1) - td = TensorDict({}, batch_size=(4, 5), device=device) - td.set("key1", torch.randn(4, 5)) - assert td.device == torch.device(device) - # by default inplace: - with pytest.raises(RuntimeError): - td.set("key1", torch.randn(5, 5, device=device)) - - # robust to dtype casting - td.set_("key1", torch.ones(4, 5, device=device, dtype=torch.double)) - assert (td.get("key1") == 1).all() - - # robust to device casting - td.set("key_device", torch.ones(4, 5, device="cpu", dtype=torch.double)) - assert td.get("key_device").device == torch.device(device) - - with pytest.raises(KeyError, match="not found in TensorDict with keys"): - td.set_("smartypants", torch.ones(4, 5, device="cpu", dtype=torch.double)) - # test set_at_ - td.set("key2", torch.randn(4, 5, 6, device=device)) - x = torch.randn(6, device=device) - td.set_at_("key2", x, (2, 2)) - assert (td.get("key2")[2, 2] == x).all() - - # test set_at_ with dtype casting - x = torch.randn(6, dtype=torch.double, device=device) - td.set_at_("key2", x, (2, 2)) # robust to dtype casting - torch.testing.assert_close(td.get("key2")[2, 2], x.to(torch.float)) - - td.set("key1", torch.zeros(4, 5, dtype=torch.double, device=device), inplace=True) - assert (td.get("key1") == 0).all() - td.set( - "key1", - torch.randn(4, 5, 1, 2, dtype=torch.double, device=device), - inplace=False, - ) - assert td["key1"].shape == td._tensordict["key1"].shape - - -@pytest.mark.parametrize("device", get_available_devices()) -def test_tensordict_device(device): - tensordict = TensorDict({"a": torch.randn(3, 4)}, []) - assert tensordict.device is None - - tensordict = TensorDict({"a": torch.randn(3, 4, device=device)}, []) - assert tensordict["a"].device == device - assert tensordict.device is None - - tensordict = TensorDict( - { - "a": torch.randn(3, 4, device=device), - "b": 
torch.randn(3, 4),
-            "c": torch.randn(3, 4, device="cpu"),
-        },
-        [],
-        device=device,
-    )
-    assert tensordict.device == device
-    assert tensordict["a"].device == device
-    assert tensordict["b"].device == device
-    assert tensordict["c"].device == device
-
-    tensordict = TensorDict({}, [], device=device)
-    tensordict["a"] = torch.randn(3, 4)
-    tensordict["b"] = torch.randn(3, 4, device="cpu")
-    assert tensordict["a"].device == device
-    assert tensordict["b"].device == device
-
-    tensordict = TensorDict({"a": torch.randn(3, 4)}, [])
-    tensordict = tensordict.to(device)
-    assert tensordict.device == device
-    assert tensordict["a"].device == device
-
-
-@pytest.mark.skipif(torch.cuda.device_count() == 0, reason="No cuda device detected")
-@pytest.mark.parametrize("device", get_available_devices()[1:])
-def test_tensordict_error_messages(device):
-    sub1 = TensorDict({"a": torch.randn(2, 3)}, [2])
-    sub2 = TensorDict({"a": torch.randn(2, 3, device=device)}, [2])
-    td1 = TensorDict({"sub": sub1}, [2])
-    td2 = TensorDict({"sub": sub2}, [2])
-
-    with pytest.raises(
-        RuntimeError, match='tensors on different devices at key "sub" / "a"'
-    ):
-        torch.cat([td1, td2], 0)

+# returns True when td0 and td1 hold the very same tensor instances, recursing
+# through nested and lazily stacked tensordicts
+def _compare_tensors_identity(td0, td1):
+    if isinstance(td0, LazyStackedTensorDict):
+        if not isinstance(td1, LazyStackedTensorDict):
+            return False
+        for _td0, _td1 in zip(td0.tensordicts, td1.tensordicts):
+            if not _compare_tensors_identity(_td0, _td1):
+                return False
+        return True
+    if td0 is td1:
+        return True
+    for key, val in td0.items():
+        if is_tensor_collection(val):
+            if not _compare_tensors_identity(val, td1.get(key)):
+                return False
+        else:
+            if val is not td1.get(key):
+                return False
+    return True


-def test_pad():
-    dim0_left, dim0_right, dim1_left, dim1_right = [0, 1, 0, 2]
-    td = TensorDict(
-        {
-            "a": torch.ones(3, 4, 1),
-            "b": torch.zeros(3, 4, 1, 1),
-        },
-        batch_size=[3, 4],
-    )
-    padded_td = pad(td, [dim0_left, dim0_right, dim1_left, dim1_right], value=0.0)
+class TestGeneric:
+    # Generic, type-insensitive tests

-    expected_a = torch.cat([torch.ones(3, 4, 1), torch.zeros(1, 4, 1)], dim=0)
-    expected_a = torch.cat([expected_a, torch.zeros(4, 2, 1)], dim=1)
+    def test_batchsize_reset(self):
+        td = TensorDict(
+            {"a": torch.randn(3, 4, 5, 6), "b": torch.randn(3, 4, 5)}, batch_size=[3, 4]
+        )
+        # smoke-test
+        td.batch_size = torch.Size([3])

-    assert padded_td["a"].shape == (4, 6, 1)
-    assert padded_td["b"].shape == (4, 6, 1, 1)
-    assert torch.equal(padded_td["a"], expected_a)
-    padded_td._check_batch_size()
+        # test with list
+        td.batch_size = [3]

+        # test with tuple
+        td.batch_size = (3,)

-@pytest.mark.parametrize("device", get_available_devices())
-def test_tensordict_indexing(device):
-    torch.manual_seed(1)
-    td = TensorDict({}, batch_size=(4, 5))
-    td.set("key1", torch.randn(4, 5, 1, device=device))
-    td.set("key2", torch.randn(4, 5, 6, device=device, dtype=torch.double))
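+        # a TensorDict's batch size must be a prefix of the shape of every entry:
+        # "a" has shape [3, 4, 5, 6] and "b" has [3, 4, 5], so [3] and [3, 4] are
+        # valid batch sizes here, while [3, 5] below is rejected
+        # incompatible size
+        with pytest.raises(
+            RuntimeError,
+            match=re.escape(
+                "the tensor a has shape torch.Size([3, 4, 5, 6]) which is incompatible with the batch-size torch.Size([3, 5])." 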
+ ), + ): + td.batch_size = [3, 5] - td_select = td[2, 2] - td_select._check_batch_size() + # test set + td.set("c", torch.randn(3)) - td_select = td[2, :2] - td_select._check_batch_size() + # test index + td[torch.tensor([1, 2])] + td[:] + td[[1, 2]] + with pytest.raises( + IndexError, + match=re.escape("too many indices for tensor of dimension 1"), + ): + td[:, 0] - td_select = td[None, :2] - td_select._check_batch_size() + # test a greater batch_size + td = TensorDict( + {"a": torch.randn(3, 4, 5, 6), "b": torch.randn(3, 4, 5)}, batch_size=[3, 4] + ) + td.batch_size = torch.Size([3, 4, 5]) - td_reconstruct = stack_td(list(td), 0, contiguous=False) - assert ( - td_reconstruct == td - ).all(), f"td and td_reconstruct differ, got {td} and {td_reconstruct}" + td.set("c", torch.randn(3, 4, 5, 6)) + with pytest.raises( + RuntimeError, + match=re.escape( + "batch dimension mismatch, got self.batch_size=torch.Size([3, 4, 5]) and value.shape=torch.Size([3, 4, 2])" + ), + ): + td.set("d", torch.randn(3, 4, 2)) - superlist = [stack_td(list(_td), 0, contiguous=False) for _td in td] - td_reconstruct = stack_td(superlist, 0, contiguous=False) - assert ( - td_reconstruct == td - ).all(), f"td and td_reconstruct differ, got {td == td_reconstruct}" + # test that lazy tds return an exception + td_stack = stack_td([TensorDict({"a": torch.randn(3)}, [3]) for _ in range(2)]) + with pytest.raises( + RuntimeError, + match=re.escape( + "modifying the batch size of a lazy representation of a tensordict is not permitted. Consider instantiating the tensordict first by calling `td = td.to_tensordict()` before resetting the batch size." + ), + ): + td_stack.batch_size = [2] + td_stack.to_tensordict().batch_size = [2] - x = torch.randn(4, 5, device=device) - td = TensorDict( - source={"key1": torch.zeros(3, 4, 5, device=device)}, - batch_size=[3, 4], - ) - td[0].set_("key1", x) - torch.testing.assert_close(td.get("key1")[0], x) - torch.testing.assert_close(td.get("key1")[0], td[0].get("key1")) - - y = torch.randn(3, 5, device=device) - td[:, 0].set_("key1", y) - torch.testing.assert_close(td.get("key1")[:, 0], y) - torch.testing.assert_close(td.get("key1")[:, 0], td[:, 0].get("key1")) - - -@pytest.mark.parametrize("device", get_available_devices()) -def test_subtensordict_construction(device): - torch.manual_seed(1) - td = TensorDict({}, batch_size=(4, 5)) - val1 = torch.randn(4, 5, 1, device=device) - val2 = torch.randn(4, 5, 6, dtype=torch.double, device=device) - val1_copy = val1.clone() - val2_copy = val2.clone() - td.set("key1", val1) - td.set("key2", val2) - std1 = td._get_sub_tensordict(2) - std2 = std1._get_sub_tensordict(2) - idx = (2, 2) - std_control = td._get_sub_tensordict(idx) - assert (std_control.get("key1") == std2.get("key1")).all() - assert (std_control.get("key2") == std2.get("key2")).all() - - # write values - with pytest.raises(RuntimeError, match="is prohibited for existing tensors"): - std_control.set("key1", torch.randn(1, device=device)) - with pytest.raises(RuntimeError, match="is prohibited for existing tensors"): - std_control.set("key2", torch.randn(6, device=device, dtype=torch.double)) - - subval1 = torch.randn(1, device=device) - subval2 = torch.randn(6, device=device, dtype=torch.double) - std_control.set_("key1", subval1) - std_control.set_("key2", subval2) - assert (val1_copy[idx] != subval1).all() - assert (td.get("key1")[idx] == subval1).all() - assert (td.get("key1")[1, 1] == val1_copy[1, 1]).all() - - assert (val2_copy[idx] != subval2).all() - assert (td.get("key2")[idx] == 
subval2).all() - assert (td.get("key2")[1, 1] == val2_copy[1, 1]).all() - - assert (std_control.get("key1") == std2.get("key1")).all() - assert (std_control.get("key2") == std2.get("key2")).all() - - assert std_control.get_parent_tensordict() is td - assert ( - std_control.get_parent_tensordict() - is std2.get_parent_tensordict().get_parent_tensordict() - ) + td = TensorDict({"a": torch.randn(3, 4)}, [3, 4]) + subtd = td._get_sub_tensordict((slice(None), torch.tensor([1, 2]))) + with pytest.raises( + RuntimeError, + match=re.escape( + "modifying the batch size of a lazy representation of a tensordict is not permitted. Consider instantiating the tensordict first by calling `td = td.to_tensordict()` before resetting the batch size." + ), + ): + subtd.batch_size = [3, 2] + subtd.to_tensordict().batch_size = [3, 2] + td = TensorDict({"a": torch.randn(3, 4)}, [3, 4]) + td_u = td.unsqueeze(0) + with pytest.raises( + RuntimeError, + match=re.escape( + "modifying the batch size of a lazy representation of a tensordict is not permitted. Consider instantiating the tensordict first by calling `td = td.to_tensordict()` before resetting the batch size." + ), + ): + td_u.batch_size = [1] + td_u.to_tensordict().batch_size = [1] -@pytest.mark.parametrize("device", get_available_devices()) -def test_mask_td(device): - torch.manual_seed(1) - d = { - "key1": torch.randn(4, 5, 6, device=device), - "key2": torch.randn(4, 5, 10, device=device), - } - mask = torch.zeros(4, 5, dtype=torch.bool, device=device).bernoulli_() - td = TensorDict(batch_size=(4, 5), source=d) - - td_masked = torch.masked_select(td, mask) - assert len(td_masked.get("key1")) == td_masked.shape[0] - - -@pytest.mark.parametrize("device", get_available_devices()) -def test_unbind_td(device): - torch.manual_seed(1) - d = { - "key1": torch.randn(4, 5, 6, device=device), - "key2": torch.randn(4, 5, 10, device=device), - } - td = TensorDict(batch_size=(4, 5), source=d) - td_unbind = torch.unbind(td, dim=1) - assert ( - td_unbind[0].batch_size == td[:, 0].batch_size - ), f"got {td_unbind[0].batch_size} and {td[:, 0].batch_size}" - - -@pytest.mark.parametrize("device", get_available_devices()) -def test_cat_td(device): - torch.manual_seed(1) - d = { - "key1": torch.randn(4, 5, 6, device=device), - "key2": torch.randn(4, 5, 10, device=device), - "key3": {"key4": torch.randn(4, 5, 10, device=device)}, - } - td1 = TensorDict(batch_size=(4, 5), source=d, device=device) - d = { - "key1": torch.randn(4, 10, 6, device=device), - "key2": torch.randn(4, 10, 10, device=device), - "key3": {"key4": torch.randn(4, 10, 10, device=device)}, - } - td2 = TensorDict(batch_size=(4, 10), source=d, device=device) - - td_cat = torch.cat([td1, td2], 1) - assert td_cat.batch_size == torch.Size([4, 15]) - d = { - "key1": torch.zeros(4, 15, 6, device=device), - "key2": torch.zeros(4, 15, 10, device=device), - "key3": {"key4": torch.zeros(4, 15, 10, device=device)}, - } - td_out = TensorDict(batch_size=(4, 15), source=d, device=device) - data_ptr_set_before = {val.data_ptr() for val in decompose(td_out)} - torch.cat([td1, td2], 1, out=td_out) - data_ptr_set_after = {val.data_ptr() for val in decompose(td_out)} - assert data_ptr_set_before == data_ptr_set_after - assert td_out.batch_size == torch.Size([4, 15]) - assert (td_out["key1"] != 0).all() - assert (td_out["key2"] != 0).all() - assert (td_out["key3", "key4"] != 0).all() - - -@pytest.mark.parametrize("device", get_available_devices()) -def test_expand(device): - torch.manual_seed(1) - d = { - "key1": torch.randn(4, 5, 
6, device=device), - "key2": torch.randn(4, 5, 10, device=device), - } - td1 = TensorDict(batch_size=(4, 5), source=d) - td2 = td1.expand(3, 7, 4, 5) - assert td2.batch_size == torch.Size([3, 7, 4, 5]) - assert td2.get("key1").shape == torch.Size([3, 7, 4, 5, 6]) - assert td2.get("key2").shape == torch.Size([3, 7, 4, 5, 10]) - - -@pytest.mark.parametrize("device", get_available_devices()) -def test_expand_with_singleton(device): - torch.manual_seed(1) - d = { - "key1": torch.randn(1, 5, 6, device=device), - "key2": torch.randn(1, 5, 10, device=device), - } - td1 = TensorDict(batch_size=(1, 5), source=d) - td2 = td1.expand(3, 7, 4, 5) - assert td2.batch_size == torch.Size([3, 7, 4, 5]) - assert td2.get("key1").shape == torch.Size([3, 7, 4, 5, 6]) - assert td2.get("key2").shape == torch.Size([3, 7, 4, 5, 10]) - - -@pytest.mark.parametrize("device", get_available_devices()) -def test_squeeze(device): - torch.manual_seed(1) - d = { - "key1": torch.randn(4, 5, 6, device=device), - "key2": torch.randn(4, 5, 10, device=device), - } - td1 = TensorDict(batch_size=(4, 5), source=d) - td2 = torch.unsqueeze(td1, dim=1) - assert td2.batch_size == torch.Size([4, 1, 5]) - - td1b = torch.squeeze(td2, dim=1) - assert td1b.batch_size == td1.batch_size - - -@pytest.mark.parametrize("device", get_available_devices()) -def test_permute(device): - torch.manual_seed(1) - d = { - "a": torch.randn(4, 5, 6, 9, device=device), - "b": torch.randn(4, 5, 6, 7, device=device), - "c": torch.randn(4, 5, 6, device=device), - } - td1 = TensorDict(batch_size=(4, 5, 6), source=d) - td2 = torch.permute(td1, dims=(2, 1, 0)) - assert td2.shape == torch.Size((6, 5, 4)) - assert td2["a"].shape == torch.Size((6, 5, 4, 9)) - - td2 = torch.permute(td1, dims=(-1, -3, -2)) - assert td2.shape == torch.Size((6, 4, 5)) - assert td2["c"].shape == torch.Size((6, 4, 5)) - - td2 = torch.permute(td1, dims=(0, 1, 2)) - assert td2["a"].shape == torch.Size((4, 5, 6, 9)) - - t = TensorDict({"a": torch.randn(3, 4, 1)}, [3, 4]) - torch.permute(t, dims=(1, 0)).set("b", torch.randn(4, 3)) - assert t["b"].shape == torch.Size((3, 4)), t - - torch.permute(t, dims=(1, 0)).fill_("a", 0.0) - assert torch.sum(t["a"]) == torch.Tensor([0]) - - -@pytest.mark.parametrize("device", get_available_devices()) -def test_permute_applied_twice(device): - torch.manual_seed(1) - d = { - "a": torch.randn(4, 5, 6, 9, device=device), - "b": torch.randn(4, 5, 6, 7, device=device), - "c": torch.randn(4, 5, 6, device=device), - } - td1 = TensorDict(batch_size=(4, 5, 6), source=d) - td2 = torch.permute(td1, dims=(2, 1, 0)) - td3 = torch.permute(td2, dims=(2, 1, 0)) - assert td3 is td1 - td1 = TensorDict(batch_size=(4, 5, 6), source=d) - td2 = torch.permute(td1, dims=(2, 1, 0)) - td3 = torch.permute(td2, dims=(0, 1, 2)) - assert td3 is not td1 - - -@pytest.mark.parametrize("device", get_available_devices()) -def test_permute_exceptions(device): - torch.manual_seed(1) - d = { - "a": torch.randn(4, 5, 6, 7, device=device), - "b": torch.randn(4, 5, 6, 8, 9, device=device), - } - td1 = TensorDict(batch_size=(4, 5, 6), source=d) - - with pytest.raises(RuntimeError): - td2 = td1.permute(1, 1, 0) - _ = td2.shape - - with pytest.raises(RuntimeError): - td2 = td1.permute(3, 2, 1, 0) - _ = td2.shape - - with pytest.raises(RuntimeError): - td2 = td1.permute(2, -1, 0) - _ = td2.shape - - with pytest.raises(IndexError): - td2 = td1.permute(2, 3, 0) - _ = td2.shape - - with pytest.raises(IndexError): - td2 = td1.permute(2, -4, 0) - _ = td2.shape - - with pytest.raises(RuntimeError): - td2 = 
td1.permute(2, 1) - _ = td2.shape - - -@pytest.mark.parametrize("device", get_available_devices()) -def test_permute_with_tensordict_operations(device): - torch.manual_seed(1) - d = { - "a": torch.randn(20, 6, 9, device=device), - "b": torch.randn(20, 6, 7, device=device), - "c": torch.randn(20, 6, device=device), - } - td1 = TensorDict(batch_size=(20, 6), source=d).view(4, 5, 6).permute(2, 1, 0) - assert td1.shape == torch.Size((6, 5, 4)) - - d = { - "a": torch.randn(4, 5, 6, 7, 9, device=device), - "b": torch.randn(4, 5, 6, 7, 7, device=device), - "c": torch.randn(4, 5, 6, 7, device=device), - } - td1 = TensorDict(batch_size=(4, 5, 6, 7), source=d)[ - :, :, :, torch.tensor([1, 2]) - ].permute(3, 2, 1, 0) - assert td1.shape == torch.Size((2, 6, 5, 4)) - - d = { - "a": torch.randn(4, 5, 9, device=device), - "b": torch.randn(4, 5, 7, device=device), - "c": torch.randn(4, 5, device=device), - } - td1 = stack_td( - [TensorDict(batch_size=(4, 5), source=d).clone() for _ in range(6)], - 2, - contiguous=False, - ).permute(2, 1, 0) - assert td1.shape == torch.Size((6, 5, 4)) - - -def test_inferred_view_size(): - td = TensorDict({"a": torch.randn(3, 4)}, [3, 4]) - assert td.view(-1).view(-1, 4) is td - - assert td.view(-1, 4) is td - assert td.view(3, -1) is td - assert td.view(3, 4) is td - assert td.view(-1, 12).shape == torch.Size([1, 12]) + @pytest.mark.parametrize("device", get_available_devices()) + def test_cat_td(self, device): + torch.manual_seed(1) + d = { + "key1": torch.randn(4, 5, 6, device=device), + "key2": torch.randn(4, 5, 10, device=device), + "key3": {"key4": torch.randn(4, 5, 10, device=device)}, + } + td1 = TensorDict(batch_size=(4, 5), source=d, device=device) + d = { + "key1": torch.randn(4, 10, 6, device=device), + "key2": torch.randn(4, 10, 10, device=device), + "key3": {"key4": torch.randn(4, 10, 10, device=device)}, + } + td2 = TensorDict(batch_size=(4, 10), source=d, device=device) + + td_cat = torch.cat([td1, td2], 1) + assert td_cat.batch_size == torch.Size([4, 15]) + d = { + "key1": torch.zeros(4, 15, 6, device=device), + "key2": torch.zeros(4, 15, 10, device=device), + "key3": {"key4": torch.zeros(4, 15, 10, device=device)}, + } + td_out = TensorDict(batch_size=(4, 15), source=d, device=device) + data_ptr_set_before = {val.data_ptr() for val in decompose(td_out)} + torch.cat([td1, td2], 1, out=td_out) + data_ptr_set_after = {val.data_ptr() for val in decompose(td_out)} + assert data_ptr_set_before == data_ptr_set_after + assert td_out.batch_size == torch.Size([4, 15]) + assert (td_out["key1"] != 0).all() + assert (td_out["key2"] != 0).all() + assert (td_out["key3", "key4"] != 0).all() + @pytest.mark.parametrize( + "ellipsis_index, expectation", + [ + ((..., 0, ...), pytest.raises(RuntimeError)), + ((0, ..., 0, ...), pytest.raises(RuntimeError)), + ], + ) + def test_convert_ellipsis_to_idx_invalid(self, ellipsis_index, expectation): + torch.manual_seed(1) + batch_size = [3, 4, 5, 6, 7] -@pytest.mark.parametrize( - "ellipsis_index, expected_index", - [ - (..., (slice(None), slice(None), slice(None), slice(None), slice(None))), - ((0, ..., 0), (0, slice(None), slice(None), slice(None), 0)), - ((..., 0), (slice(None), slice(None), slice(None), slice(None), 0)), - ((0, ...), (0, slice(None), slice(None), slice(None), slice(None))), - ( - (slice(1, 2), ...), - (slice(1, 2), slice(None), slice(None), slice(None), slice(None)), - ), - ], -) -def test_convert_ellipsis_to_idx_valid(ellipsis_index, expected_index): - torch.manual_seed(1) - batch_size = [3, 4, 5, 6, 7] + with 
expectation: + _ = convert_ellipsis_to_idx(ellipsis_index, batch_size) - assert convert_ellipsis_to_idx(ellipsis_index, batch_size) == expected_index + @pytest.mark.parametrize( + "ellipsis_index, expected_index", + [ + (..., (slice(None), slice(None), slice(None), slice(None), slice(None))), + ((0, ..., 0), (0, slice(None), slice(None), slice(None), 0)), + ((..., 0), (slice(None), slice(None), slice(None), slice(None), 0)), + ((0, ...), (0, slice(None), slice(None), slice(None), slice(None))), + ( + (slice(1, 2), ...), + (slice(1, 2), slice(None), slice(None), slice(None), slice(None)), + ), + ], + ) + def test_convert_ellipsis_to_idx_valid(self, ellipsis_index, expected_index): + torch.manual_seed(1) + batch_size = [3, 4, 5, 6, 7] + assert convert_ellipsis_to_idx(ellipsis_index, batch_size) == expected_index -@pytest.mark.parametrize( - "ellipsis_index, expectation", - [ - ((..., 0, ...), pytest.raises(RuntimeError)), - ((0, ..., 0, ...), pytest.raises(RuntimeError)), - ], -) -def test_convert_ellipsis_to_idx_invalid(ellipsis_index, expectation): - torch.manual_seed(1) - batch_size = [3, 4, 5, 6, 7] + @pytest.mark.skipif(not torch.cuda.device_count(), reason="no cuda") + def test_create_on_device(self): + device = torch.device(0) + + # TensorDict + td = TensorDict({}, [5]) + assert td.device is None + + td.set("a", torch.randn(5, device=device)) + assert td.device is None + + td = TensorDict({}, [5], device="cuda:0") + td.set("a", torch.randn(5, 1)) + assert td.get("a").device == device + + # stacked TensorDict + td1 = TensorDict({}, [5]) + td2 = TensorDict({}, [5]) + stackedtd = stack_td([td1, td2], 0) + assert stackedtd.device is None + + stackedtd.set("a", torch.randn(2, 5, device=device)) + assert stackedtd.device is None + + stackedtd = stackedtd.to(device) + assert stackedtd.device == device + + td1 = TensorDict({}, [5], device="cuda:0") + td2 = TensorDict({}, [5], device="cuda:0") + stackedtd = stack_td([td1, td2], 0) + stackedtd.set("a", torch.randn(2, 5, 1)) + assert stackedtd.get("a").device == device + assert td1.get("a").device == device + assert td2.get("a").device == device + + # TensorDict, indexed + td = TensorDict({}, [5]) + subtd = td[1] + assert subtd.device is None + + subtd.set("a", torch.randn(1, device=device)) + # setting element of subtensordict doesn't set top-level device + assert subtd.device is None + + subtd = subtd.to(device) + assert subtd.device == device + assert subtd["a"].device == device + + td = TensorDict({}, [5], device="cuda:0") + subtd = td[1] + subtd.set("a", torch.randn(1)) + assert subtd.get("a").device == device + + td = TensorDict({}, [5], device="cuda:0") + subtd = td[1:3] + subtd.set("a", torch.randn(2)) + assert subtd.get("a").device == device + + # ViewedTensorDict + td = TensorDict({}, [6]) + viewedtd = td.view(2, 3) + assert viewedtd.device is None + + viewedtd = viewedtd.to(device) + assert viewedtd.device == device + + td = TensorDict({}, [6], device="cuda:0") + viewedtd = td.view(2, 3) + a = torch.randn(2, 3) + viewedtd.set("a", a) + assert viewedtd.get("a").device == device + assert (a.to(device) == viewedtd.get("a")).all() - with expectation: - _ = convert_ellipsis_to_idx(ellipsis_index, batch_size) + @pytest.mark.parametrize( + "stack_dim", + [0, 1, 2, 3], + ) + @pytest.mark.parametrize( + "nested_stack_dim", + [0, 1, 2], + ) + def test_dense_stack_tds(self, stack_dim, nested_stack_dim): + batch_size = (5, 6) + td0 = TensorDict( + {"a": torch.zeros(*batch_size, 3)}, + batch_size, + ) + td1 = TensorDict( + {"a": 
torch.zeros(*batch_size, 4), "b": torch.zeros(*batch_size, 2)}, + batch_size, + ) + td_lazy = torch.stack([td0, td1], dim=nested_stack_dim) + td_container = TensorDict({"lazy": td_lazy}, td_lazy.batch_size) + td_container_clone = td_container.clone() + td_container_clone.apply_(lambda x: x + 1) + + assert td_lazy.stack_dim == nested_stack_dim + td_stack = torch.stack([td_container, td_container_clone], dim=stack_dim) + assert td_stack.stack_dim == stack_dim + + assert isinstance(td_stack, LazyStackedTensorDict) + dense_td_stack = dense_stack_tds(td_stack) + assert isinstance(dense_td_stack, TensorDict) # check outer layer is non-lazy + assert isinstance( + dense_td_stack["lazy"], LazyStackedTensorDict + ) # while inner layer is still lazy + assert "b" not in dense_td_stack["lazy"].tensordicts[0].keys() + assert "b" in dense_td_stack["lazy"].tensordicts[1].keys() + assert assert_allclose_td( + dense_td_stack, + dense_stack_tds([td_container, td_container_clone], dim=stack_dim), + ) # This shows it is the same to pass a list or a LazyStackedTensorDict -TD_BATCH_SIZE = 4 + for i in range(2): + index = (slice(None),) * stack_dim + (i,) + assert (dense_td_stack[index] == i).all() + if stack_dim > nested_stack_dim: + assert dense_td_stack["lazy"].stack_dim == nested_stack_dim + else: + assert dense_td_stack["lazy"].stack_dim == nested_stack_dim + 1 -@pytest.mark.parametrize( - "td_name,device", - TestTensorDictsBase.TYPES_DEVICES, -) -class TestTensorDicts(TestTensorDictsBase): - def test_permute_applied_twice(self, td_name, device): - torch.manual_seed(0) - tensordict = getattr(self, td_name)(device) - for _ in range(10): - p = torch.randperm(4) - inv_p = p.argsort() - other_p = inv_p - while (other_p == inv_p).all(): - other_p = torch.randperm(4) - other_p = tuple(other_p.tolist()) - p = tuple(p.tolist()) - inv_p = tuple(inv_p.tolist()) - if td_name in ("td_params",): - # TODO: Should we break this? 
- assert ( - tensordict.permute(*p).permute(*inv_p)._param_td - is tensordict._param_td - ) - assert ( - tensordict.permute(*p).permute(*other_p)._param_td - is not tensordict._param_td - ) - assert ( - torch.permute(tensordict, p).permute(inv_p)._param_td - is tensordict._param_td - ) - assert ( - torch.permute(tensordict, p).permute(other_p)._param_td - is not tensordict._param_td - ) - else: - assert assert_allclose_td( - tensordict.permute(*p).permute(*inv_p), tensordict - ) - assert tensordict.permute(*p).permute(*inv_p) is tensordict - assert tensordict.permute(*p).permute(*other_p) is not tensordict - assert assert_allclose_td( - torch.permute(tensordict, p).permute(inv_p), tensordict - ) - assert torch.permute(tensordict, p).permute(inv_p) is tensordict - assert torch.permute(tensordict, p).permute(other_p) is not tensordict + def test_empty(self): + td = TensorDict( + { + "a": torch.zeros(()), + ("b", "c"): torch.zeros(()), + ("b", "d", "e"): torch.zeros(()), + }, + [], + ) + td_empty = td.empty(recurse=False) + assert len(list(td_empty.keys())) == 0 + td_empty = td.empty(recurse=True) + assert len(list(td_empty.keys())) == 1 + assert len(list(td_empty.get("b").keys())) == 1 - def test_to_tensordict(self, td_name, device): - torch.manual_seed(1) - td = getattr(self, td_name)(device) - td2 = td.to_tensordict() - assert (td2 == td).all() + def test_error_on_contains(self): + td = TensorDict( + {"a": TensorDict({"b": torch.rand(1, 2)}, [1, 2]), "c": torch.rand(1)}, [1] + ) + with pytest.raises( + NotImplementedError, + match="TensorDict does not support membership checks with the `in` keyword", + ): + "random_string" in td # noqa: B015 - @pytest.mark.parametrize("strict", [True, False]) @pytest.mark.parametrize("inplace", [True, False]) - def test_select(self, td_name, device, strict, inplace): - torch.manual_seed(1) - td = getattr(self, td_name)(device) - keys = ["a"] - if td_name == "td_h5": - with pytest.raises(NotImplementedError, match="Cannot call select"): - td.select(*keys, strict=strict, inplace=inplace) - return + def test_exclude_nested(self, inplace): + tensor_1 = torch.rand(4, 5, 6, 7) + tensor_2 = torch.rand(4, 5, 6, 7) + sub_sub_tensordict = TensorDict( + {"t1": tensor_1, "t2": tensor_2}, batch_size=[4, 5, 6] + ) + sub_tensordict = TensorDict( + {"double_nested": sub_sub_tensordict}, batch_size=[4, 5] + ) + tensordict = TensorDict( + { + "a": torch.rand(4, 3), + "b": torch.rand(4, 2), + "c": torch.rand(4, 1), + "nested": sub_tensordict, + }, + batch_size=[4], + ) + # making a copy for inplace tests + tensordict2 = tensordict.clone() - if td_name in ("nested_stacked_td", "nested_td"): - keys += [("my_nested_td", "inner")] + excluded = tensordict.exclude( + "b", ("nested", "double_nested", "t2"), inplace=inplace + ) + + assert set(excluded.keys(include_nested=True)) == { + "a", + "c", + "nested", + ("nested", "double_nested"), + ("nested", "double_nested", "t1"), + } - with td.unlock_() if td.is_locked else contextlib.nullcontext(): - td2 = td.select(*keys, strict=strict, inplace=inplace) if inplace: - assert td2 is td - else: - assert td2 is not td - if td_name == "saved_td": - assert (len(list(td2.keys())) == len(keys)) and ("a" in td2.keys()) - assert (len(list(td2.clone().keys())) == len(keys)) and ( - "a" in td2.clone().keys() - ) + assert excluded is tensordict + assert set(tensordict.keys(include_nested=True)) == { + "a", + "c", + "nested", + ("nested", "double_nested"), + ("nested", "double_nested", "t1"), + } else: - assert (len(list(td2.keys(True, True))) == 
len(keys)) and ( - "a" in td2.keys() - ) - assert (len(list(td2.clone().keys(True, True))) == len(keys)) and ( - "a" in td2.clone().keys() - ) + assert excluded is not tensordict + assert set(tensordict.keys(include_nested=True)) == { + "a", + "b", + "c", + "nested", + ("nested", "double_nested"), + ("nested", "double_nested", "t1"), + ("nested", "double_nested", "t2"), + } - @pytest.mark.parametrize("strict", [True, False]) - def test_select_exception(self, td_name, device, strict): + # excluding "nested" should exclude all subkeys also + excluded2 = tensordict2.exclude("nested", inplace=inplace) + assert set(excluded2.keys(include_nested=True)) == {"a", "b", "c"} + + @pytest.mark.parametrize("device", get_available_devices()) + def test_expand(self, device): torch.manual_seed(1) - td = getattr(self, td_name)(device) - if td_name == "td_h5": - with pytest.raises(NotImplementedError, match="Cannot call select"): - _ = td.select("tada", strict=strict) - return + d = { + "key1": torch.randn(4, 5, 6, device=device), + "key2": torch.randn(4, 5, 10, device=device), + } + td1 = TensorDict(batch_size=(4, 5), source=d) + td2 = td1.expand(3, 7, 4, 5) + assert td2.batch_size == torch.Size([3, 7, 4, 5]) + assert td2.get("key1").shape == torch.Size([3, 7, 4, 5, 6]) + assert td2.get("key2").shape == torch.Size([3, 7, 4, 5, 10]) - if strict: - with pytest.raises(KeyError): - _ = td.select("tada", strict=strict) + @pytest.mark.parametrize("device", get_available_devices()) + def test_expand_with_singleton(self, device): + torch.manual_seed(1) + d = { + "key1": torch.randn(1, 5, 6, device=device), + "key2": torch.randn(1, 5, 10, device=device), + } + td1 = TensorDict(batch_size=(1, 5), source=d) + td2 = td1.expand(3, 7, 4, 5) + assert td2.batch_size == torch.Size([3, 7, 4, 5]) + assert td2.get("key1").shape == torch.Size([3, 7, 4, 5, 6]) + assert td2.get("key2").shape == torch.Size([3, 7, 4, 5, 10]) + + @pytest.mark.parametrize("device", get_available_devices()) + @pytest.mark.parametrize( + "td_type", ["tensordict", "view", "unsqueeze", "squeeze", "stack"] + ) + @pytest.mark.parametrize("update", [True, False]) + def test_filling_empty_tensordict(self, device, td_type, update): + if td_type == "tensordict": + td = TensorDict({}, batch_size=[16], device=device) + elif td_type == "view": + td = TensorDict({}, batch_size=[4, 4], device=device).view(-1) + elif td_type == "unsqueeze": + td = TensorDict({}, batch_size=[16], device=device).unsqueeze(-1) + elif td_type == "squeeze": + td = TensorDict({}, batch_size=[16, 1], device=device).squeeze(-1) + elif td_type == "stack": + td = torch.stack([TensorDict({}, [], device=device) for _ in range(16)], 0) else: - td2 = td.select("tada", strict=strict) - assert td2 is not td - assert len(list(td2.keys())) == 0 + raise NotImplementedError - def test_exclude(self, td_name, device): - torch.manual_seed(1) - td = getattr(self, td_name)(device) - if td_name == "td_h5": - with pytest.raises(NotImplementedError, match="Cannot call exclude"): - _ = td.exclude("a") - return - td2 = td.exclude("a") - assert td2 is not td - assert ( - len(list(td2.keys())) == len(list(td.keys())) - 1 and "a" not in td2.keys() + for i in range(16): + other_td = TensorDict({"a": torch.randn(10), "b": torch.ones(1)}, []) + if td_type == "unsqueeze": + other_td = other_td.unsqueeze(-1).to_tensordict() + if update: + subtd = td._get_sub_tensordict(i) + subtd.update(other_td, inplace=True) + else: + td[i] = other_td + + assert td.device == device + assert td.get("a").device == device + assert 
(td.get("b") == 1).all() + if td_type == "view": + assert td._source["a"].shape == torch.Size([4, 4, 10]) + elif td_type == "unsqueeze": + assert td._source["a"].shape == torch.Size([16, 10]) + elif td_type == "squeeze": + assert td._source["a"].shape == torch.Size([16, 1, 10]) + elif td_type == "stack": + assert (td[-1] == other_td.to(device)).all() + + @pytest.mark.parametrize("inplace", [True, False]) + @pytest.mark.parametrize("separator", [",", "-"]) + def test_flatten_unflatten_key_collision(self, inplace, separator): + td1 = TensorDict( + { + f"a{separator}b{separator}c": torch.zeros(3), + "a": {"b": {"c": torch.zeros(3)}}, + }, + [], ) - assert ( - len(list(td2.clone().keys())) == len(list(td.keys())) - 1 - and "a" not in td2.clone().keys() + td2 = TensorDict( + { + f"a{separator}b": torch.zeros(3), + "a": {"b": torch.zeros(3)}, + "g": {"d": torch.zeros(3)}, + }, + [], + ) + td3 = TensorDict( + { + f"a{separator}b{separator}c": torch.zeros(3), + "a": {"b": {"c": torch.zeros(3), "d": torch.zeros(3)}}, + }, + [], ) - with td.unlock_(): - td2 = td.exclude("a", inplace=True) - assert td2 is td - - def test_assert(self, td_name, device): - torch.manual_seed(1) - td = getattr(self, td_name)(device) + td4 = TensorDict( + { + f"a{separator}b{separator}c{separator}d": torch.zeros(3), + "a": {"b": {"c": torch.zeros(3)}}, + }, + [], + ) + + td5 = TensorDict( + {f"a{separator}b": torch.zeros(3), "a": {"b": {"c": torch.zeros(3)}}}, [] + ) + + with pytest.raises(KeyError, match="Flattening keys in tensordict causes keys"): + _ = td1.flatten_keys(separator) + + with pytest.raises(KeyError, match="Flattening keys in tensordict causes keys"): + _ = td2.flatten_keys(separator) + + with pytest.raises(KeyError, match="Flattening keys in tensordict causes keys"): + _ = td3.flatten_keys(separator) + with pytest.raises( - RuntimeError, - match="Converting a tensordict to boolean value is not permitted", + KeyError, + match=re.escape( + "Unflattening key(s) in tensordict will override an existing for unflattened key" + ), ): - assert td + _ = td1.unflatten_keys(separator) - def test_expand(self, td_name, device): - torch.manual_seed(1) - td = getattr(self, td_name)(device) - batch_size = td.batch_size - expected_size = torch.Size([3, *batch_size]) + with pytest.raises( + KeyError, + match=re.escape( + "Unflattening key(s) in tensordict will override an existing for unflattened key" + ), + ): + _ = td2.unflatten_keys(separator) - new_td = td.expand(3, *batch_size) - assert new_td.batch_size == expected_size - assert all((_new_td == td).all() for _new_td in new_td) + with pytest.raises( + KeyError, + match=re.escape( + "Unflattening key(s) in tensordict will override an existing for unflattened key" + ), + ): + _ = td3.unflatten_keys(separator) - new_td_torch_size = td.expand(expected_size) - assert new_td_torch_size.batch_size == expected_size - assert all((_new_td == td).all() for _new_td in new_td_torch_size) + with pytest.raises( + KeyError, + match=re.escape( + "Unflattening key(s) in tensordict will override an existing for unflattened key" + ), + ): + _ = td4.unflatten_keys(separator) - new_td_iterable = td.expand([3, *batch_size]) - assert new_td_iterable.batch_size == expected_size - assert all((_new_td == td).all() for _new_td in new_td_iterable) + with pytest.raises( + KeyError, + match=re.escape( + "Unflattening key(s) in tensordict will override an existing for unflattened key" + ), + ): + _ = td5.unflatten_keys(separator) - def test_cast_to(self, td_name, device): - torch.manual_seed(1) - td = 
getattr(self, td_name)(device) - td_device = td.to("cpu:1") - assert td_device.device == torch.device("cpu:1") - td_dtype = td.to(torch.int) - assert all(t.dtype == torch.int for t in td_dtype.values(True, True)) - del td_dtype - # device (str), dtype - td_dtype_device = td.to("cpu:1", torch.int) - assert all(t.dtype == torch.int for t in td_dtype_device.values(True, True)) - assert td_dtype_device.device == torch.device("cpu:1") - del td_dtype_device - # device, dtype - td_dtype_device = td.to(torch.device("cpu:1"), torch.int) - assert all(t.dtype == torch.int for t in td_dtype_device.values(True, True)) - assert td_dtype_device.device == torch.device("cpu:1") - del td_dtype_device - # example tensor - td_dtype_device = td.to(torch.randn(3, dtype=torch.half, device="cpu:1")) - assert all(t.dtype == torch.half for t in td_dtype_device.values(True, True)) - # tensor on cpu:1 is actually on cpu. This is still meaningful for tensordicts on cuda. - assert td_dtype_device.device == torch.device("cpu") - del td_dtype_device - # example td - td_dtype_device = td.to( - other=TensorDict( - {"a": torch.randn(3, dtype=torch.half, device="cpu:1")}, - [], - device="cpu:1", - ) + td4_flat = td4.flatten_keys(separator) + assert (f"a{separator}b{separator}c{separator}d") in td4_flat.keys() + assert (f"a{separator}b{separator}c") in td4_flat.keys() + + td5_flat = td5.flatten_keys(separator) + assert (f"a{separator}b") in td5_flat.keys() + assert (f"a{separator}b{separator}c") in td5_flat.keys() + + @pytest.mark.parametrize("batch_size", [None, [3, 4]]) + @pytest.mark.parametrize("batch_dims", [None, 1, 2]) + @pytest.mark.parametrize("device", get_available_devices()) + def test_from_dict(self, batch_size, batch_dims, device): + data = { + "a": torch.zeros(3, 4, 5), + "b": {"c": torch.zeros(3, 4, 5, 6)}, + ("d", "e"): torch.ones(3, 4, 5), + ("b", "f"): torch.zeros(3, 4, 5, 5), + ("d", "g", "h"): torch.ones(3, 4, 5), + } + if batch_dims and batch_size: + with pytest.raises(ValueError, match="both"): + TensorDict.from_dict( + data, batch_size=batch_size, batch_dims=batch_dims, device=device + ) + return + data = TensorDict.from_dict( + data, batch_size=batch_size, batch_dims=batch_dims, device=device ) - assert all(t.dtype == torch.half for t in td_dtype_device.values(True, True)) - assert td_dtype_device.device == torch.device("cpu:1") - del td_dtype_device - # example td, many dtypes - td_nodtype_device = td.to( - other=TensorDict( - { - "a": torch.randn(3, dtype=torch.half, device="cpu:1"), - "b": torch.randint(10, ()), - }, - [], - device="cpu:1", - ) + assert data.device == device + assert "a" in data.keys() + assert ("b", "c") in data.keys(True) + assert ("b", "f") in data.keys(True) + assert ("d", "e") in data.keys(True) + assert data.device == device + if batch_dims: + assert data.ndim == batch_dims + assert data["b"].ndim == batch_dims + assert data["d"].ndim == batch_dims + assert data["d", "g"].ndim == batch_dims + elif batch_size: + assert data.batch_size == torch.Size(batch_size) + assert data["b"].batch_size == torch.Size(batch_size) + assert data["d"].batch_size == torch.Size(batch_size) + assert data["d", "g"].batch_size == torch.Size(batch_size) + + @pytest.mark.parametrize("memmap", [True, False]) + @pytest.mark.parametrize("params", [False, True]) + def test_from_module(self, memmap, params): + net = nn.Transformer( + d_model=16, + nhead=2, + num_encoder_layers=3, + dim_feedforward=12, ) - assert all(t.dtype != torch.half for t in td_nodtype_device.values(True, True)) - assert 
td_nodtype_device.device == torch.device("cpu:1")
-        del td_nodtype_device
-        # batch-size: check errors (or not)
-        if td_name in (
-            "stacked_td",
-            "unsqueezed_td",
-            "squeezed_td",
-            "permute_td",
-            "nested_stacked_td",
-        ):
-            with pytest.raises(TypeError, match="Cannot pass batch-size to a "):
-                td_dtype_device = td.to(
-                    torch.device("cpu:1"), torch.int, batch_size=torch.Size([])
-                )
+        td = TensorDict.from_module(net, as_module=params)
+        # check that we have empty tensordicts, reflecting modules without params
+        for subtd in td.values(True):
+            if isinstance(subtd, TensorDictBase) and subtd.is_empty():
+                break
        else:
-            td_dtype_device = td.to(
-                torch.device("cpu:1"), torch.int, batch_size=torch.Size([])
-            )
-            assert all(t.dtype == torch.int for t in td_dtype_device.values(True, True))
-            assert td_dtype_device.device == torch.device("cpu:1")
-            assert td_dtype_device.batch_size == torch.Size([])
-            del td_dtype_device
-        if td_name in (
-            "stacked_td",
-            "unsqueezed_td",
-            "squeezed_td",
-            "permute_td",
-            "nested_stacked_td",
+            raise RuntimeError
+        if memmap:
+            td = td.detach().memmap_()
+        net.load_state_dict(td.flatten_keys("."))
+
+        if not memmap and params:
+            assert set(td.parameters()) == set(net.parameters())
+
+    def test_from_module_state_dict(self):
+        net = nn.Transformer(
+            d_model=16,
+            nhead=2,
+            num_encoder_layers=3,
+            dim_feedforward=12,
+        )
+
+        def adder(module, *args, **kwargs):
+            for p in module.parameters(recurse=False):
+                p.data += 1
+
+        def remover(module, *args, **kwargs):
+            for p in module.parameters(recurse=False):
+                p.data = p.data - 1
+
+        for module in net.modules():
+            module.register_state_dict_pre_hook(adder)
+            module._register_state_dict_hook(remover)
+        params_reg = TensorDict.from_module(net)
+        params_reg = params_reg.select(*params_reg.keys(True, True))
+
+        params_sd = TensorDict.from_module(net, use_state_dict=True)
+        params_sd = params_sd.select(*params_sd.keys(True, True))
+        assert_allclose_td(params_sd, params_reg.apply(lambda x: x + 1))
+
+        sd = net.state_dict()
+        assert_allclose_td(params_sd.flatten_keys("."), TensorDict(sd, []))
+
+    @pytest.mark.parametrize(
+        "idx",
+        [
+            (slice(None),),
+            slice(None),
+            (3, 4),
+            (3, slice(None), slice(2, 2, 2)),
+            (torch.tensor([1, 2, 3]),),
+            ([1, 2, 3]),
+            (
+                torch.tensor([1, 2, 3]),
+                torch.tensor([2, 3, 4]),
+                torch.tensor([0, 10, 2]),
+                torch.tensor([2, 4, 1]),
+            ),
+            torch.zeros(10, 7, 11, 5, dtype=torch.bool).bernoulli_(),
+            torch.zeros(10, 7, 11, dtype=torch.bool).bernoulli_(),
+            (0, torch.zeros(7, dtype=torch.bool).bernoulli_()),
+        ],
+    )
+    def test_getitem_batch_size(self, idx):
+        shape = [10, 7, 11, 5]
+        shape = torch.Size(shape)
+        mocking_tensor = torch.zeros(*shape)
+        expected_shape = mocking_tensor[idx].shape
+        resulting_shape = _getitem_batch_size(shape, idx)
+        assert expected_shape == resulting_shape, (idx, expected_shape, resulting_shape)
+
+    def test_getitem_nested(self):
+        tensor = torch.randn(4, 5, 6, 7)
+        sub_sub_tensordict = TensorDict({"c": tensor}, [4, 5, 6])
+        sub_tensordict = TensorDict({}, [4, 5])
+        tensordict = TensorDict({}, [4])
+
+        sub_tensordict["b"] = sub_sub_tensordict
+        tensordict["a"] = sub_tensordict
+
+        # check that contents match
+        assert (tensordict["a"] == sub_tensordict).all()
+        assert (tensordict["a", "b"] == sub_sub_tensordict).all()
+        assert (tensordict["a", "b", "c"] == tensor).all()
+
+        # check that the get method returns the same contents
+        assert (tensordict.get("a") == sub_tensordict).all()
+        assert (tensordict.get(("a", "b")) == sub_sub_tensordict).all()
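+        # a tuple key such as ("a", "b", "c") is traversed one level per element,
+        # so td["a", "b", "c"] and td.get(("a", "b", "c")) reach the innermost tensor
+        assert 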
(tensordict.get(("a", "b", "c")) == tensor).all() + + # check that shapes are kept + assert tensordict.shape == torch.Size([4]) + assert sub_tensordict.shape == torch.Size([4, 5]) + assert sub_sub_tensordict.shape == torch.Size([4, 5, 6]) + + def test_inferred_view_size(self): + td = TensorDict({"a": torch.randn(3, 4)}, [3, 4]) + assert td.view(-1).view(-1, 4) is td + + assert td.view(-1, 4) is td + assert td.view(3, -1) is td + assert td.view(3, 4) is td + assert td.view(-1, 12).shape == torch.Size([1, 12]) + + def test_keys_view(self): + tensor = torch.randn(4, 5, 6, 7) + sub_sub_tensordict = TensorDict({"c": tensor}, [4, 5, 6]) + sub_tensordict = TensorDict({}, [4, 5]) + tensordict = TensorDict({}, [4]) + + sub_tensordict["b"] = sub_sub_tensordict + tensordict["a"] = sub_tensordict + + assert "a" in tensordict.keys() + assert "random_string" not in tensordict.keys() + + assert ("a",) in tensordict.keys(include_nested=True) + assert ("a", "b", "c") in tensordict.keys(include_nested=True) + assert ("a", "c", "b") not in tensordict.keys(include_nested=True) + + with pytest.raises( + TypeError, match="checks with tuples of strings is only supported" ): - with pytest.raises(TypeError, match="Cannot pass batch-size to a "): - td.to(batch_size=torch.Size([])) - else: - td_batchsize = td.to(batch_size=torch.Size([])) - assert td_batchsize.batch_size == torch.Size([]) - del td_batchsize + ("a", "b", "c") in tensordict.keys() # noqa: B015 - # Deprecated: - # def test_cast(self, td_name, device): - # torch.manual_seed(1) - # td = getattr(self, td_name)(device) - # td_td = td.to(TensorDict) - # assert (td == td_td).all() + with pytest.raises(TypeError, match="TensorDict keys are always strings."): + 42 in tensordict.keys() # noqa: B015 - def test_broadcast(self, td_name, device): - torch.manual_seed(1) - td = getattr(self, td_name)(device) - sub_td = td[:, :2].to_tensordict() - sub_td.zero_() - sub_dict = sub_td.to_dict() - if td_name == "td_params": - td_set = td.data - else: - td_set = td - td_set[:, :2] = sub_dict - assert (td[:, :2] == 0).all() + with pytest.raises(TypeError, match="TensorDict keys are always strings."): + ("a", 42) in tensordict.keys() # noqa: B015 - @pytest.mark.parametrize("call_del", [True, False]) - def test_remove(self, td_name, device, call_del): - torch.manual_seed(1) - td = getattr(self, td_name)(device) - with td.unlock_(): - if call_del: - del td["a"] - else: - td = td.del_("a") - assert td is not None - assert "a" not in td.keys() - if td_name in ("sub_td", "sub_td2"): - return - td.lock_() - with pytest.raises(RuntimeError, match="locked"): - del td["b"] + keys = set(tensordict.keys()) + keys_nested = set(tensordict.keys(include_nested=True)) - def test_set_unexisting(self, td_name, device): - torch.manual_seed(1) - td = getattr(self, td_name)(device) - if td.is_locked: - with pytest.raises( - RuntimeError, - match="Cannot modify locked TensorDict. 
For in-place modification", - ): - td.set("z", torch.ones_like(td.get("a"))) - else: - td.set("z", torch.ones_like(td.get("a"))) - assert (td.get("z") == 1).all() + assert keys == {"a"} + assert keys_nested == {"a", ("a", "b"), ("a", "b", "c")} - def test_fill_(self, td_name, device): + leaves = set(tensordict.keys(leaves_only=True)) + leaves_nested = set(tensordict.keys(include_nested=True, leaves_only=True)) + + assert leaves == set() + assert leaves_nested == {("a", "b", "c")} + + @pytest.mark.parametrize("device", get_available_devices()) + def test_mask_td(self, device): torch.manual_seed(1) - td = getattr(self, td_name)(device) - if td_name == "td_params": - td_set = td.data - else: - td_set = td - new_td = td_set.fill_("a", 0.1) - assert (td.get("a") == 0.1).all() - assert new_td is td_set + d = { + "key1": torch.randn(4, 5, 6, device=device), + "key2": torch.randn(4, 5, 10, device=device), + } + mask = torch.zeros(4, 5, dtype=torch.bool, device=device).bernoulli_() + td = TensorDict(batch_size=(4, 5), source=d) - def test_shape(self, td_name, device): - td = getattr(self, td_name)(device) - assert td.shape == td.batch_size + td_masked = torch.masked_select(td, mask) + assert len(td_masked.get("key1")) == td_masked.shape[0] - def test_flatten_unflatten(self, td_name, device): - td = getattr(self, td_name)(device) - shape = td.shape[:3] - td_flat = td.flatten(0, 2) - td_unflat = td_flat.unflatten(0, shape) - assert (td.to_tensordict() == td_unflat).all() - assert td.batch_size == td_unflat.batch_size + @pytest.mark.parametrize("device", get_available_devices()) + def test_memmap_as_tensor(self, device): + td = TensorDict( + {"a": torch.randn(3, 4), "b": {"c": torch.randn(3, 4)}}, + [3, 4], + device="cpu", + ) + td_memmap = td.clone().memmap_() + assert (td == td_memmap).all() - def test_flatten_unflatten_bis(self, td_name, device): - td = getattr(self, td_name)(device) - shape = td.shape[1:4] - td_flat = td.flatten(1, 3) - td_unflat = td_flat.unflatten(1, shape) - assert (td.to_tensordict() == td_unflat).all() - assert td.batch_size == td_unflat.batch_size + assert (td == td_memmap.apply(lambda x: x.clone())).all() + if device.type == "cuda": + td = td.pin_memory() + td_memmap = td.clone().memmap_() + td_memmap_pm = td_memmap.apply(lambda x: x.clone()).pin_memory() + assert (td.pin_memory().to(device) == td_memmap_pm.to(device)).all() - def test_masked_fill_(self, td_name, device): + @pytest.mark.parametrize("method", ["share_memory", "memmap"]) + def test_memory_lock(self, method): torch.manual_seed(1) - td = getattr(self, td_name)(device) - mask = torch.zeros(td.shape, dtype=torch.bool, device=device).bernoulli_() - if td_name == "td_params": - td_set = td.data + td = TensorDict({"a": torch.randn(4, 5)}, batch_size=(4, 5)) + + # lock=True + if method == "share_memory": + td.share_memory_() + elif method == "memmap": + td.memmap_() else: - td_set = td - new_td = td_set.masked_fill_(mask, -10.0) - assert new_td is td_set - for item in td.values(): - assert (item[mask] == -10).all(), item[mask] + raise NotImplementedError - def test_set_nested_batch_size(self, td_name, device): - td = getattr(self, td_name)(device) - td.unlock_() - batch_size = torch.Size([*td.batch_size, 3]) - td.set("some_other_td", TensorDict({}, batch_size)) - assert td["some_other_td"].batch_size == batch_size + td.set("a", torch.randn(4, 5), inplace=True) + td.set_("a", torch.randn(4, 5)) # No exception because set_ ignores the lock - def test_lock(self, td_name, device): - td = getattr(self, td_name)(device) - 
is_locked = td.is_locked - for item in td.values(): - if isinstance(item, TensorDictBase): - assert item.is_locked == is_locked - if isinstance(td, _SubTensorDict): - with pytest.raises(RuntimeError, match="the parent tensordict instead"): - td.is_locked = not is_locked - return - td.is_locked = not is_locked - assert td.is_locked != is_locked - for _, item in td.items(): - if isinstance(item, TensorDictBase): - assert item.is_locked != is_locked - td.lock_() - assert td.is_locked - for _, item in td.items(): - if isinstance(item, TensorDictBase): - assert item.is_locked - td.unlock_() - assert not td.is_locked - for _, item in td.items(): - if isinstance(item, TensorDictBase): - assert not item.is_locked + with pytest.raises(RuntimeError, match="Cannot modify locked TensorDict"): + td.set("a", torch.randn(4, 5)) - def test_lock_write(self, td_name, device): - td = getattr(self, td_name)(device) - if isinstance(td, _SubTensorDict): - with pytest.raises(RuntimeError, match="the parent tensordict instead"): - td.lock_() - return + with pytest.raises(RuntimeError, match="Cannot modify locked TensorDict"): + td.set("b", torch.randn(4, 5)) - td.lock_() - td_clone = td.clone() - assert not td_clone.is_locked - td_clone = td.to_tensordict() - assert not td_clone.is_locked - assert td.is_locked - if td_name == "td_h5": - td.unlock_() - for key in list(td.keys()): - del td[key] - td.lock_() - else: - with td.unlock_() if td.is_locked else contextlib.nullcontext(): - td = td.select(inplace=True) - for key, item in td_clone.items(True): - with pytest.raises(RuntimeError, match="Cannot modify locked TensorDict"): - td.set(key, item) - td.unlock_() - for key, item in td_clone.items(True): - td.set(key, item) - td.lock_() - for key, item in td_clone.items(True): - with pytest.raises(RuntimeError, match="Cannot modify locked TensorDict"): - td.set(key, item) - if td_name == "td_params": - td_set = td.data - else: - td_set = td - td_set.set_(key, item) + with pytest.raises(RuntimeError, match="Cannot modify locked TensorDict"): + td.set("b", torch.randn(4, 5), inplace=True) - def test_unlock(self, td_name, device): - torch.manual_seed(1) - td = getattr(self, td_name)(device) - td.unlock_() - assert not td.is_locked - if td.device is not None: - assert td.device.type == "cuda" or not td.is_shared() - else: - assert not td.is_shared() - assert not td.is_memmap() + def test_pad(self): + dim0_left, dim0_right, dim1_left, dim1_right = [0, 1, 0, 2] + td = TensorDict( + { + "a": torch.ones(3, 4, 1), + "b": torch.zeros(3, 4, 1, 1), + }, + batch_size=[3, 4], + ) - def test_lock_nested(self, td_name, device): - td = getattr(self, td_name)(device) - if td_name in ("sub_td", "sub_td2") and td.is_locked: - with pytest.raises(RuntimeError, match="Cannot unlock"): - td.unlock_() + padded_td = pad(td, [dim0_left, dim0_right, dim1_left, dim1_right], value=0.0) + + expected_a = torch.cat([torch.ones(3, 4, 1), torch.zeros(1, 4, 1)], dim=0) + expected_a = torch.cat([expected_a, torch.zeros(4, 2, 1)], dim=1) + + assert padded_td["a"].shape == (4, 6, 1) + assert padded_td["b"].shape == (4, 6, 1, 1) + assert torch.equal(padded_td["a"], expected_a) + padded_td._check_batch_size() + + @pytest.mark.parametrize("batch_first", [True, False]) + @pytest.mark.parametrize("make_mask", [True, False]) + def test_pad_sequence(self, batch_first, make_mask): + list_td = [ + TensorDict({"a": torch.ones((2,)), ("b", "c"): torch.ones((2, 3))}, [2]), + TensorDict({"a": torch.ones((4,)), ("b", "c"): torch.ones((4, 3))}, [4]), + ] + padded_td = 
pad_sequence( + list_td, batch_first=batch_first, return_mask=make_mask + ) + if batch_first: + assert padded_td.shape == torch.Size([2, 4]) + assert padded_td["a"].shape == torch.Size([2, 4]) + assert padded_td["a"][0, -1] == 0 + assert padded_td["b", "c"].shape == torch.Size([2, 4, 3]) + assert padded_td["b", "c"][0, -1, 0] == 0 else: - td.unlock_() - td.set(("some", "nested"), torch.zeros(td.shape)) - if td_name in ("sub_td", "sub_td2") and not td.is_locked: - with pytest.raises(RuntimeError, match="Cannot lock"): - td.lock_() - return - td.lock_() - some = td.get("some") - assert some.is_locked - with pytest.raises(RuntimeError): - some.unlock_() - # this assumes that td is out of scope after the call to del. - # an error in unlock_() is likely due to td leaving a trace somewhere. - del td - gc.collect() - some.unlock_() + assert padded_td.shape == torch.Size([4, 2]) + assert padded_td["a"].shape == torch.Size([4, 2]) + assert padded_td["a"][-1, 0] == 0 + assert padded_td["b", "c"].shape == torch.Size([4, 2, 3]) + assert padded_td["b", "c"][-1, 0, 0] == 0 + if make_mask: + assert "mask" in padded_td.keys() + assert not padded_td["mask"].all() + else: + assert "mask" not in padded_td.keys() - # @pytest.mark.parametrize("op", ["keys_root", "keys_nested", "values", "items"]) - @pytest.mark.parametrize("op", ["flatten", "unflatten"]) - def test_cache(self, td_name, device, op): + @pytest.mark.parametrize("device", get_available_devices()) + def test_permute(self, device): torch.manual_seed(1) - td = getattr(self, td_name)(device) - try: - td.lock_() - except Exception: - return - if op == "keys_root": - a = list(td.keys()) - b = list(td.keys()) - assert a == b - elif op == "keys_nested": - a = list(td.keys(True)) - b = list(td.keys(True)) - assert a == b - elif op == "values": - a = list(td.values(True)) - b = list(td.values(True)) - assert all((_a == _b).all() for _a, _b in zip(a, b)) - elif op == "items": - keys_a, values_a = zip(*td.items(True)) - keys_b, values_b = zip(*td.items(True)) - assert all((_a == _b).all() for _a, _b in zip(values_a, values_b)) - assert keys_a == keys_b - elif op == "flatten": - a = td.flatten_keys() - b = td.flatten_keys() - if td_name not in ("td_h5",): - assert a is b - else: - assert a is not b - elif op == "unflatten": - a = td.unflatten_keys() - b = td.unflatten_keys() - if td_name not in ("td_h5",): - assert a is b - else: - assert a is not b + d = { + "a": torch.randn(4, 5, 6, 9, device=device), + "b": torch.randn(4, 5, 6, 7, device=device), + "c": torch.randn(4, 5, 6, device=device), + } + td1 = TensorDict(batch_size=(4, 5, 6), source=d) + td2 = torch.permute(td1, dims=(2, 1, 0)) + assert td2.shape == torch.Size((6, 5, 4)) + assert td2["a"].shape == torch.Size((6, 5, 4, 9)) - if td_name != "td_params": - assert len(td._cache) - td.unlock_() - assert td._cache is None - for val in td.values(True): - if is_tensor_collection(val): - assert td._cache is None + td2 = torch.permute(td1, dims=(-1, -3, -2)) + assert td2.shape == torch.Size((6, 4, 5)) + assert td2["c"].shape == torch.Size((6, 4, 5)) - def test_enter_exit(self, td_name, device): - torch.manual_seed(1) - if td_name in ("sub_td", "sub_td2"): - return - td = getattr(self, td_name)(device) - is_locked = td.is_locked - with td.lock_() as other: - assert other is td - assert td.is_locked - with td.unlock_() as other: - assert other is td - assert not td.is_locked - assert td.is_locked - assert td.is_locked is is_locked + td2 = torch.permute(td1, dims=(0, 1, 2)) + assert td2["a"].shape == 
torch.Size((4, 5, 6, 9)) - def test_lock_change_names(self, td_name, device): + t = TensorDict({"a": torch.randn(3, 4, 1)}, [3, 4]) + torch.permute(t, dims=(1, 0)).set("b", torch.randn(4, 3)) + assert t["b"].shape == torch.Size((3, 4)), t + + torch.permute(t, dims=(1, 0)).fill_("a", 0.0) + assert torch.sum(t["a"]) == torch.Tensor([0]) + + @pytest.mark.parametrize("device", get_available_devices()) + def test_permute_applied_twice(self, device): torch.manual_seed(1) - td = getattr(self, td_name)(device) - try: - td.names = [str(i) for i in range(td.ndim)] - td.lock_() - except Exception: - return - # cache values - list(td.values(True)) - td.names = [str(-i) for i in range(td.ndim)] - for val in td.values(True): - if not is_tensor_collection(val): - continue - assert val.names[: td.ndim] == [str(-i) for i in range(td.ndim)] + d = { + "a": torch.randn(4, 5, 6, 9, device=device), + "b": torch.randn(4, 5, 6, 7, device=device), + "c": torch.randn(4, 5, 6, device=device), + } + td1 = TensorDict(batch_size=(4, 5, 6), source=d) + td2 = torch.permute(td1, dims=(2, 1, 0)) + td3 = torch.permute(td2, dims=(2, 1, 0)) + assert td3 is td1 + td1 = TensorDict(batch_size=(4, 5, 6), source=d) + td2 = torch.permute(td1, dims=(2, 1, 0)) + td3 = torch.permute(td2, dims=(0, 1, 2)) + assert td3 is not td1 - def test_sorted_keys(self, td_name, device): + @pytest.mark.parametrize("device", get_available_devices()) + def test_permute_exceptions(self, device): torch.manual_seed(1) - td = getattr(self, td_name)(device) - sorted_keys = td.sorted_keys - i = -1 - for i, (key1, key2) in enumerate(zip(sorted_keys, td.keys())): # noqa: B007 - assert key1 == key2 - assert i == len(td.keys()) - 1 - if td.is_locked: - assert td._cache.get("sorted_keys", None) is not None - td.unlock_() - assert td._cache is None - elif td_name not in ("sub_td", "sub_td2"): # we cannot lock sub tensordicts - if isinstance(td, _CustomOpTensorDict): - target = td._source - else: - target = td - assert target._cache is None - td.lock_() - _ = td.sorted_keys - assert target._cache.get("sorted_keys", None) is not None - td.unlock_() - assert target._cache is None + d = { + "a": torch.randn(4, 5, 6, 7, device=device), + "b": torch.randn(4, 5, 6, 8, 9, device=device), + } + td1 = TensorDict(batch_size=(4, 5, 6), source=d) - def test_masked_fill(self, td_name, device): + with pytest.raises(RuntimeError): + td2 = td1.permute(1, 1, 0) + _ = td2.shape + + with pytest.raises(RuntimeError): + td2 = td1.permute(3, 2, 1, 0) + _ = td2.shape + + with pytest.raises(RuntimeError): + td2 = td1.permute(2, -1, 0) + _ = td2.shape + + with pytest.raises(IndexError): + td2 = td1.permute(2, 3, 0) + _ = td2.shape + + with pytest.raises(IndexError): + td2 = td1.permute(2, -4, 0) + _ = td2.shape + + with pytest.raises(RuntimeError): + td2 = td1.permute(2, 1) + _ = td2.shape + + @pytest.mark.parametrize("device", get_available_devices()) + def test_permute_with_tensordict_operations(self, device): torch.manual_seed(1) - td = getattr(self, td_name)(device) - mask = torch.zeros(td.shape, dtype=torch.bool, device=device).bernoulli_() - new_td = td.masked_fill(mask, -10.0) - assert new_td is not td - for item in new_td.values(): - assert (item[mask] == -10).all() + d = { + "a": torch.randn(20, 6, 9, device=device), + "b": torch.randn(20, 6, 7, device=device), + "c": torch.randn(20, 6, device=device), + } + td1 = TensorDict(batch_size=(20, 6), source=d).view(4, 5, 6).permute(2, 1, 0) + assert td1.shape == torch.Size((6, 5, 4)) - def test_zero_(self, td_name, device): + d = { + 
"a": torch.randn(4, 5, 6, 7, 9, device=device), + "b": torch.randn(4, 5, 6, 7, 7, device=device), + "c": torch.randn(4, 5, 6, 7, device=device), + } + td1 = TensorDict(batch_size=(4, 5, 6, 7), source=d)[ + :, :, :, torch.tensor([1, 2]) + ].permute(3, 2, 1, 0) + assert td1.shape == torch.Size((2, 6, 5, 4)) + + d = { + "a": torch.randn(4, 5, 9, device=device), + "b": torch.randn(4, 5, 7, device=device), + "c": torch.randn(4, 5, device=device), + } + td1 = stack_td( + [TensorDict(batch_size=(4, 5), source=d).clone() for _ in range(6)], + 2, + contiguous=False, + ).permute(2, 1, 0) + assert td1.shape == torch.Size((6, 5, 4)) + + @pytest.mark.parametrize("device", get_available_devices()) + def test_requires_grad(self, device): torch.manual_seed(1) - td = getattr(self, td_name)(device) - new_td = td.zero_() - assert new_td is td - for k in td.keys(): - assert (td.get(k) == 0).all() + # Just one of the tensors have requires_grad + tensordicts = [ + TensorDict( + batch_size=[11, 12], + source={ + "key1": torch.randn( + 11, + 12, + 5, + device=device, + requires_grad=True if i == 5 else False, + ), + "key2": torch.zeros( + 11, 12, 50, device=device, dtype=torch.bool + ).bernoulli_(), + }, + ) + for i in range(10) + ] + stacked_td = LazyStackedTensorDict(*tensordicts, stack_dim=0) + # First stacked tensor has requires_grad == True + assert list(stacked_td.values())[0].requires_grad is True - @pytest.mark.parametrize("inplace", [False, True]) - def test_apply(self, td_name, device, inplace): - td = getattr(self, td_name)(device) - td_c = td.to_tensordict() - if inplace and td_name == "td_params": - with pytest.raises(ValueError, match="Failed to update"): - td.apply(lambda x: x + 1, inplace=inplace) - return - td_1 = td.apply(lambda x: x + 1, inplace=inplace) - if inplace: - for key in td.keys(True, True): - assert (td_c[key] + 1 == td[key]).all() - assert (td_1[key] == td[key]).all() + @pytest.mark.parametrize("like", [True, False]) + def test_save_load_memmap_stacked_td( + self, + like, + tmpdir, + ): + a = TensorDict({"a": [1]}, []) + b = TensorDict({"b": [1]}, []) + c = torch.stack([a, b]) + c = c.expand(10, 2) + if like: + d = c.memmap_like(prefix=tmpdir) else: - for key in td.keys(True, True): - assert (td_c[key] + 1 != td[key]).any() - assert (td_1[key] == td[key] + 1).all() + d = c.memmap_(prefix=tmpdir) - @pytest.mark.parametrize("inplace", [False, True]) - def test_apply_default(self, td_name, device, inplace): - if td_name in ("td_h5",): - pytest.skip("Cannot test assignment in persistent tensordict.") - td = getattr(self, td_name)(device) - td_c = td.to_tensordict() - if td_name in ("td_params",): + d2 = LazyStackedTensorDict.load_memmap(tmpdir) + assert (d2 == d).all() + assert (d2[:, 0] == d[:, 0]).all() + if like: + assert (d2[:, 0] == a.zero_()).all() + else: + assert (d2[:, 0] == a).all() + + @pytest.mark.parametrize("inplace", [True, False]) + def test_select_nested(self, inplace): + tensor_1 = torch.rand(4, 5, 6, 7) + tensor_2 = torch.rand(4, 5, 6, 7) + sub_sub_tensordict = TensorDict( + {"t1": tensor_1, "t2": tensor_2}, batch_size=[4, 5, 6] + ) + sub_tensordict = TensorDict( + {"double_nested": sub_sub_tensordict}, batch_size=[4, 5] + ) + tensordict = TensorDict( + { + "a": torch.rand(4, 3), + "b": torch.rand(4, 2), + "c": torch.rand(4, 1), + "nested": sub_tensordict, + }, + batch_size=[4], + ) + + selected = tensordict.select( + "b", ("nested", "double_nested", "t2"), inplace=inplace + ) + + assert set(selected.keys(include_nested=True)) == { + "b", + "nested", + ("nested", 
"double_nested"), + ("nested", "double_nested", "t2"), + } + + if inplace: + assert selected is tensordict + assert set(tensordict.keys(include_nested=True)) == { + "b", + "nested", + ("nested", "double_nested"), + ("nested", "double_nested", "t2"), + } + else: + assert selected is not tensordict + assert set(tensordict.keys(include_nested=True)) == { + "a", + "b", + "c", + "nested", + ("nested", "double_nested"), + ("nested", "double_nested", "t1"), + ("nested", "double_nested", "t2"), + } + + def test_select_nested_missing(self): + # checks that we keep a nested key even if missing nested keys are present + td = TensorDict({"a": {"b": [1], "c": [2]}}, []) + + td_select = td.select(("a", "b"), "r", ("a", "z"), strict=False) + assert ("a", "b") in list(td_select.keys(True, True)) + assert ("a", "b") in td_select.keys(True, True) + + def test_set_nested_keys(self): + tensor = torch.randn(4, 5, 6, 7) + tensor2 = torch.ones(4, 5, 6, 7) + tensordict = TensorDict({}, [4]) + sub_tensordict = TensorDict({}, [4, 5]) + sub_sub_tensordict = TensorDict({"c": tensor}, [4, 5, 6]) + sub_sub_tensordict2 = TensorDict({"c": tensor2}, [4, 5, 6]) + sub_tensordict.set("b", sub_sub_tensordict) + tensordict.set("a", sub_tensordict) + assert tensordict.get(("a", "b")) is sub_sub_tensordict + + tensordict.set(("a", "b"), sub_sub_tensordict2) + assert tensordict.get(("a", "b")) is sub_sub_tensordict2 + assert (tensordict.get(("a", "b", "c")) == 1).all() + + @pytest.mark.parametrize("index0", [None, slice(None)]) + def test_set_sub_key(self, index0): + # tests that parent tensordict is affected when subtensordict is set with a new key + batch_size = [10, 10] + source = {"a": torch.randn(10, 10, 10), "b": torch.ones(10, 10, 2)} + td = TensorDict(source, batch_size=batch_size) + idx0 = (index0, 0) if index0 is not None else 0 + td0 = td._get_sub_tensordict(idx0) + idx = (index0, slice(2, 4)) if index0 is not None else slice(2, 4) + sub_td = td._get_sub_tensordict(idx) + if index0 is None: + c = torch.randn(2, 10, 10) + else: + c = torch.randn(10, 2, 10) + sub_td.set("c", c) + assert (td.get("c")[idx] == sub_td.get("c")).all() + assert (sub_td.get("c") == c).all() + assert (td.get("c")[idx0] == 0).all() + assert (td._get_sub_tensordict(idx0).get("c") == 0).all() + assert (td0.get("c") == 0).all() + + def test_setdefault_nested(self): + tensor = torch.randn(4, 5, 6, 7) + tensor2 = torch.ones(4, 5, 6, 7) + sub_sub_tensordict = TensorDict({"c": tensor}, [4, 5, 6]) + sub_tensordict = TensorDict({"b": sub_sub_tensordict}, [4, 5]) + tensordict = TensorDict({"a": sub_tensordict}, [4]) + + # if key exists we return the existing value + assert tensordict.setdefault(("a", "b", "c"), tensor2) is tensor + + assert tensordict.setdefault(("a", "b", "d"), tensor2) is tensor2 + assert (tensordict["a", "b", "d"] == 1).all() + assert tensordict.get(("a", "b", "d")) is tensor2 + + def test_shared_inheritance(self): + td = TensorDict({"a": torch.randn(3, 4)}, [3, 4]) + td.share_memory_() + + td0, *_ = td.unbind(1) + assert td0.is_shared() + + td0, *_ = td.split(1, 0) + assert td0.is_shared() + + td0 = td.exclude("a") + assert td0.is_shared() + + td0 = td.select("a") + assert td0.is_shared() + + td.unlock_() + td0 = td.rename_key_("a", "a.a") + assert not td0.is_shared() + td.share_memory_() + + td0 = td.unflatten_keys(".") + assert td0.is_shared() + + td0 = td.flatten_keys(".") + assert td0.is_shared() + + td0 = td.view(-1) + assert td0.is_shared() + + td0 = td.permute(1, 0) + assert td0.is_shared() + + td0 = td.unsqueeze(0) + assert 
td0.is_shared() + + td0 = td0.squeeze(0) + assert td0.is_shared() + + def test_setitem_nested(self): + tensor = torch.randn(4, 5, 6, 7) + tensor2 = torch.ones(4, 5, 6, 7) + tensordict = TensorDict({}, [4]) + sub_tensordict = TensorDict({}, [4, 5]) + sub_sub_tensordict = TensorDict({"c": tensor}, [4, 5, 6]) + sub_sub_tensordict2 = TensorDict({"c": tensor2}, [4, 5, 6]) + sub_tensordict["b"] = sub_sub_tensordict + tensordict["a"] = sub_tensordict + assert tensordict["a", "b"] is sub_sub_tensordict + tensordict["a", "b"] = sub_sub_tensordict2 + assert tensordict["a", "b"] is sub_sub_tensordict2 + assert (tensordict["a", "b", "c"] == 1).all() + + # check the same with set method + sub_tensordict.set("b", sub_sub_tensordict) + tensordict.set("a", sub_tensordict) + assert tensordict["a", "b"] is sub_sub_tensordict + + tensordict.set(("a", "b"), sub_sub_tensordict2) + assert tensordict["a", "b"] is sub_sub_tensordict2 + assert (tensordict["a", "b", "c"] == 1).all() + + def test_split_with_empty_tensordict(self): + td = TensorDict({}, [10]) + + tds = td.split(4, 0) + assert len(tds) == 3 + assert tds[0].shape == torch.Size([4]) + assert tds[1].shape == torch.Size([4]) + assert tds[2].shape == torch.Size([2]) + + tds = td.split([1, 9], 0) + + assert len(tds) == 2 + assert tds[0].shape == torch.Size([1]) + assert tds[1].shape == torch.Size([9]) + + td = TensorDict({}, [10, 10, 3]) + + tds = td.split(4, 1) + assert len(tds) == 3 + assert tds[0].shape == torch.Size([10, 4, 3]) + assert tds[1].shape == torch.Size([10, 4, 3]) + assert tds[2].shape == torch.Size([10, 2, 3]) + + tds = td.split([1, 9], 1) + assert len(tds) == 2 + assert tds[0].shape == torch.Size([10, 1, 3]) + assert tds[1].shape == torch.Size([10, 9, 3]) + + def test_split_with_invalid_arguments(self): + td = TensorDict({"a": torch.zeros(2, 1)}, []) + # Test empty batch size + with pytest.raises(IndexError, match="Dimension out of range"): + td.split(1, 0) + + td = TensorDict({}, [3, 2]) + + # Test invalid split_size input + with pytest.raises(TypeError, match="must be int or list of ints"): + td.split("1", 0) + with pytest.raises(TypeError, match="must be int or list of ints"): + td.split(["1", 2], 0) + + # Test invalid split_size sum + with pytest.raises( + RuntimeError, match="Insufficient number of elements in split_size" + ): + td.split([], 0) + + with pytest.raises(RuntimeError, match="expects split_size to sum exactly"): + td.split([1, 1], 0) + + # Test invalid dimension input + with pytest.raises(IndexError, match="Dimension out of range"): + td.split(1, 2) + with pytest.raises(IndexError, match="Dimension out of range"): + td.split(1, -3) + + def test_split_with_negative_dim(self): + td = TensorDict( + {"a": torch.zeros(5, 4, 2, 1), "b": torch.zeros(5, 4, 1)}, [5, 4] + ) + + tds = td.split([1, 3], -1) + assert len(tds) == 2 + assert tds[0].shape == torch.Size([5, 1]) + assert tds[0]["a"].shape == torch.Size([5, 1, 2, 1]) + assert tds[0]["b"].shape == torch.Size([5, 1, 1]) + assert tds[1].shape == torch.Size([5, 3]) + assert tds[1]["a"].shape == torch.Size([5, 3, 2, 1]) + assert tds[1]["b"].shape == torch.Size([5, 3, 1]) + + @pytest.mark.parametrize("device", get_available_devices()) + def test_squeeze(self, device): + torch.manual_seed(1) + d = { + "key1": torch.randn(4, 5, 6, device=device), + "key2": torch.randn(4, 5, 10, device=device), + } + td1 = TensorDict(batch_size=(4, 5), source=d) + td2 = torch.unsqueeze(td1, dim=1) + assert td2.batch_size == torch.Size([4, 1, 5]) + + td1b = torch.squeeze(td2, dim=1) + assert 
td1b.batch_size == td1.batch_size + + @pytest.mark.parametrize("device", get_available_devices()) + def test_subtensordict_construction(self, device): + torch.manual_seed(1) + td = TensorDict({}, batch_size=(4, 5)) + val1 = torch.randn(4, 5, 1, device=device) + val2 = torch.randn(4, 5, 6, dtype=torch.double, device=device) + val1_copy = val1.clone() + val2_copy = val2.clone() + td.set("key1", val1) + td.set("key2", val2) + std1 = td._get_sub_tensordict(2) + std2 = std1._get_sub_tensordict(2) + idx = (2, 2) + std_control = td._get_sub_tensordict(idx) + assert (std_control.get("key1") == std2.get("key1")).all() + assert (std_control.get("key2") == std2.get("key2")).all() + + # write values + with pytest.raises(RuntimeError, match="is prohibited for existing tensors"): + std_control.set("key1", torch.randn(1, device=device)) + with pytest.raises(RuntimeError, match="is prohibited for existing tensors"): + std_control.set("key2", torch.randn(6, device=device, dtype=torch.double)) + + subval1 = torch.randn(1, device=device) + subval2 = torch.randn(6, device=device, dtype=torch.double) + std_control.set_("key1", subval1) + std_control.set_("key2", subval2) + assert (val1_copy[idx] != subval1).all() + assert (td.get("key1")[idx] == subval1).all() + assert (td.get("key1")[1, 1] == val1_copy[1, 1]).all() + + assert (val2_copy[idx] != subval2).all() + assert (td.get("key2")[idx] == subval2).all() + assert (td.get("key2")[1, 1] == val2_copy[1, 1]).all() + + assert (std_control.get("key1") == std2.get("key1")).all() + assert (std_control.get("key2") == std2.get("key2")).all() + + assert std_control.get_parent_tensordict() is td + assert ( + std_control.get_parent_tensordict() + is std2.get_parent_tensordict().get_parent_tensordict() + ) + + @pytest.mark.parametrize("device", get_available_devices()) + def test_tensordict_device(self, device): + tensordict = TensorDict({"a": torch.randn(3, 4)}, []) + assert tensordict.device is None + + tensordict = TensorDict({"a": torch.randn(3, 4, device=device)}, []) + assert tensordict["a"].device == device + assert tensordict.device is None + + tensordict = TensorDict( + { + "a": torch.randn(3, 4, device=device), + "b": torch.randn(3, 4), + "c": torch.randn(3, 4, device="cpu"), + }, + [], + device=device, + ) + assert tensordict.device == device + assert tensordict["a"].device == device + assert tensordict["b"].device == device + assert tensordict["c"].device == device + + tensordict = TensorDict({}, [], device=device) + tensordict["a"] = torch.randn(3, 4) + tensordict["b"] = torch.randn(3, 4, device="cpu") + assert tensordict["a"].device == device + assert tensordict["b"].device == device + + tensordict = TensorDict({"a": torch.randn(3, 4)}, []) + tensordict = tensordict.to(device) + assert tensordict.device == device + assert tensordict["a"].device == device + + @pytest.mark.skipif( + torch.cuda.device_count() == 0, reason="No cuda device detected" + ) + @pytest.mark.parametrize("device", get_available_devices()[1:]) + def test_tensordict_error_messages(self, device): + sub1 = TensorDict({"a": torch.randn(2, 3)}, [2]) + sub2 = TensorDict({"a": torch.randn(2, 3, device=device)}, [2]) + td1 = TensorDict({"sub": sub1}, [2]) + td2 = TensorDict({"sub": sub2}, [2]) + + with pytest.raises( + RuntimeError, match='tensors on different devices at key "sub" / "a"' + ): + torch.cat([td1, td2], 0) + + @pytest.mark.parametrize("device", get_available_devices()) + def test_tensordict_indexing(self, device): + torch.manual_seed(1) + td = TensorDict({}, batch_size=(4, 5)) + 
td.set("key1", torch.randn(4, 5, 1, device=device)) + td.set("key2", torch.randn(4, 5, 6, device=device, dtype=torch.double)) + + td_select = td[2, 2] + td_select._check_batch_size() + + td_select = td[2, :2] + td_select._check_batch_size() + + td_select = td[None, :2] + td_select._check_batch_size() + + td_reconstruct = stack_td(list(td), 0, contiguous=False) + assert ( + td_reconstruct == td + ).all(), f"td and td_reconstruct differ, got {td} and {td_reconstruct}" + + superlist = [stack_td(list(_td), 0, contiguous=False) for _td in td] + td_reconstruct = stack_td(superlist, 0, contiguous=False) + assert ( + td_reconstruct == td + ).all(), f"td and td_reconstruct differ, got {td == td_reconstruct}" + + x = torch.randn(4, 5, device=device) + td = TensorDict( + source={"key1": torch.zeros(3, 4, 5, device=device)}, + batch_size=[3, 4], + ) + td[0].set_("key1", x) + torch.testing.assert_close(td.get("key1")[0], x) + torch.testing.assert_close(td.get("key1")[0], td[0].get("key1")) + + y = torch.randn(3, 5, device=device) + td[:, 0].set_("key1", y) + torch.testing.assert_close(td.get("key1")[:, 0], y) + torch.testing.assert_close(td.get("key1")[:, 0], td[:, 0].get("key1")) + + def test_tensordict_prealloc_nested(self): + N = 3 + B = 5 + T = 4 + buffer = TensorDict({}, batch_size=[B, N]) + + td_0 = TensorDict( + { + "env.time": torch.rand(N, 1), + "agent.obs": TensorDict( + { # assuming 3 agents in a multi-agent setting + "image": torch.rand(N, T, 64), + "state": torch.rand(N, T, 3, 32, 32), + }, + batch_size=[N, T], + ), + }, + batch_size=[N], + ) + + td_1 = td_0.clone() + buffer[0] = td_0 + buffer[1] = td_1 + assert ( + repr(buffer) + == """TensorDict( + fields={ + agent.obs: TensorDict( + fields={ + image: Tensor(shape=torch.Size([5, 3, 4, 64]), device=cpu, dtype=torch.float32, is_shared=False), + state: Tensor(shape=torch.Size([5, 3, 4, 3, 32, 32]), device=cpu, dtype=torch.float32, is_shared=False)}, + batch_size=torch.Size([5, 3, 4]), + device=None, + is_shared=False), + env.time: Tensor(shape=torch.Size([5, 3, 1]), device=cpu, dtype=torch.float32, is_shared=False)}, + batch_size=torch.Size([5, 3]), + device=None, + is_shared=False)""" + ) + assert buffer.batch_size == torch.Size([B, N]) + assert buffer["agent.obs"].batch_size == torch.Size([B, N, T]) + + @pytest.mark.parametrize("device", get_available_devices()) + def test_tensordict_set(self, device): + torch.manual_seed(1) + td = TensorDict({}, batch_size=(4, 5), device=device) + td.set("key1", torch.randn(4, 5)) + assert td.device == torch.device(device) + # by default inplace: + with pytest.raises(RuntimeError): + td.set("key1", torch.randn(5, 5, device=device)) + + # robust to dtype casting + td.set_("key1", torch.ones(4, 5, device=device, dtype=torch.double)) + assert (td.get("key1") == 1).all() + + # robust to device casting + td.set("key_device", torch.ones(4, 5, device="cpu", dtype=torch.double)) + assert td.get("key_device").device == torch.device(device) + + with pytest.raises(KeyError, match="not found in TensorDict with keys"): + td.set_("smartypants", torch.ones(4, 5, device="cpu", dtype=torch.double)) + # test set_at_ + td.set("key2", torch.randn(4, 5, 6, device=device)) + x = torch.randn(6, device=device) + td.set_at_("key2", x, (2, 2)) + assert (td.get("key2")[2, 2] == x).all() + + # test set_at_ with dtype casting + x = torch.randn(6, dtype=torch.double, device=device) + td.set_at_("key2", x, (2, 2)) # robust to dtype casting + torch.testing.assert_close(td.get("key2")[2, 2], x.to(torch.float)) + + td.set( + "key1", 
torch.zeros(4, 5, dtype=torch.double, device=device), inplace=True + ) + assert (td.get("key1") == 0).all() + td.set( + "key1", + torch.randn(4, 5, 1, 2, dtype=torch.double, device=device), + inplace=False, + ) + assert td["key1"].shape == td._tensordict["key1"].shape + + def test_to_module_state_dict(self): + net0 = nn.Transformer( + d_model=16, + nhead=2, + num_encoder_layers=3, + dim_feedforward=12, + ) + net1 = nn.Transformer( + d_model=16, + nhead=2, + num_encoder_layers=3, + dim_feedforward=12, + ) + + def hook( + module, + state_dict, + prefix, + local_metadata, + strict, + missing_keys, + unexpected_keys, + error_msgs, + ): + for key, val in list(state_dict.items()): + state_dict[key] = val * 0 + + for module in net0.modules(): + module._register_load_state_dict_pre_hook(hook, with_module=True) + for module in net1.modules(): + module._register_load_state_dict_pre_hook(hook, with_module=True) + + params_reg = TensorDict.from_module(net0) + params_reg.to_module(net0, use_state_dict=True) + params_reg = TensorDict.from_module(net0) + + sd = net1.state_dict() + net1.load_state_dict(sd) + sd = net1.state_dict() + + assert (params_reg == 0).all() + assert set(params_reg.flatten_keys(".").keys()) == set(sd.keys()) + assert_allclose_td(params_reg.flatten_keys("."), TensorDict(sd, [])) + + def test_unbind_batchsize(self): + td = TensorDict({"a": TensorDict({"b": torch.zeros(2, 3)}, [2, 3])}, [2]) + td["a"].batch_size + tds = td.unbind(0) + assert tds[0].batch_size == torch.Size([]) + assert tds[0]["a"].batch_size == torch.Size([3]) + + @pytest.mark.parametrize("device", get_available_devices()) + def test_unbind_td(self, device): + torch.manual_seed(1) + d = { + "key1": torch.randn(4, 5, 6, device=device), + "key2": torch.randn(4, 5, 10, device=device), + } + td = TensorDict(batch_size=(4, 5), source=d) + td_unbind = torch.unbind(td, dim=1) + assert ( + td_unbind[0].batch_size == td[:, 0].batch_size + ), f"got {td_unbind[0].batch_size} and {td[:, 0].batch_size}" + + def test_update_nested_dict(self): + t = TensorDict({"a": {"d": [[[0]] * 3] * 2}}, [2, 3]) + assert ("a", "d") in t.keys(include_nested=True) + t.update({"a": {"b": [[[1]] * 3] * 2}}) + assert ("a", "d") in t.keys(include_nested=True) + assert ("a", "b") in t.keys(include_nested=True) + assert t["a", "b"].shape == torch.Size([2, 3, 1]) + t.update({"a": {"d": [[[1]] * 3] * 2}}) + + +@pytest.mark.parametrize( + "td_name,device", + TestTensorDictsBase.TYPES_DEVICES, +) +class TestTensorDicts(TestTensorDictsBase): + @pytest.mark.parametrize("nested", [False, True]) + def test_add_batch_dim_cache(self, td_name, device, nested): + td = getattr(self, td_name)(device) + if nested: + td = TensorDict({"parent": td}, td.batch_size) + from tensordict.nn import TensorDictModule # noqa + from torch import vmap + + fun = vmap(lambda x: x) + if td_name == "td_h5": + with pytest.raises( + RuntimeError, match="Persistent tensordicts cannot be used with vmap" + ): + fun(td) + return + if td_name == "memmap_td" and device.type != "cpu": + with pytest.raises( + RuntimeError, + match="MemoryMappedTensor with non-cpu device are not supported in vmap ops", + ): + fun(td) + return + fun(td) + + td.zero_() + # this value should be cached + std = fun(td) + for value in std.values(True, True): + assert (value == 0).all() + + @pytest.mark.parametrize("inplace", [False, True]) + def test_apply(self, td_name, device, inplace): + td = getattr(self, td_name)(device) + td_c = td.to_tensordict() + if inplace and td_name == "td_params": + with 
pytest.raises(ValueError, match="Failed to update"): + td.apply(lambda x: x + 1, inplace=inplace) + return + td_1 = td.apply(lambda x: x + 1, inplace=inplace) + if inplace: + for key in td.keys(True, True): + assert (td_c[key] + 1 == td[key]).all() + assert (td_1[key] == td[key]).all() + else: + for key in td.keys(True, True): + assert (td_c[key] + 1 != td[key]).any() + assert (td_1[key] == td[key] + 1).all() + + @pytest.mark.parametrize("inplace", [False, True]) + def test_apply_default(self, td_name, device, inplace): + if td_name in ("td_h5",): + pytest.skip("Cannot test assignment in persistent tensordict.") + td = getattr(self, td_name)(device) + td_c = td.to_tensordict() + if td_name in ("td_params",): td.data.zero_() else: td.zero_() @@ -1144,34 +1704,6 @@ def get_old_val(newval, oldval): assert key == ("nested", "newkey") assert (td_1[key] == 0).all() - @pytest.mark.parametrize("inplace", [False, True]) - def test_named_apply(self, td_name, device, inplace): - td = getattr(self, td_name)(device) - td_c = td.to_tensordict() - - def named_plus(name, x): - if "a" in name: - return x + 1 - - if inplace and td_name == "td_params": - with pytest.raises(ValueError, match="Failed to update"): - td.named_apply(named_plus, inplace=inplace) - return - td_1 = td.named_apply(named_plus, inplace=inplace) - if inplace: - assert td_1 is td - for key in td_1.keys(True, True): - if "a" in key: - assert (td_c[key] + 1 == td_1[key]).all() - else: - assert (td_c[key] == td_1[key]).all() - assert (td_1[key] == td[key]).all() - else: - for key in td_1.keys(True, True): - assert "a" in key - assert (td_c[key] + 1 != td[key]).any() - assert (td_1[key] == td[key] + 1).all() - @pytest.mark.parametrize("inplace", [False, True]) def test_apply_other(self, td_name, device, inplace): td = getattr(self, td_name)(device) @@ -1190,283 +1722,99 @@ def test_apply_other(self, td_name, device, inplace): assert (td_c[key] * 2 != td[key]).any() assert (td_1[key] == td[key] * 2).all() - def test_from_empty(self, td_name, device): + def test_as_tensor(self, td_name, device): + td = getattr(self, td_name)(device) + if "memmap" in td_name and device == torch.device("cpu"): + tdt = td.as_tensor() + assert (tdt == td).all() + elif "memmap" in td_name: + with pytest.raises( + RuntimeError, match="can only be called with MemoryMappedTensors stored" + ): + td.as_tensor() + else: + # checks that it runs + td.as_tensor() + + def test_assert(self, td_name, device): torch.manual_seed(1) td = getattr(self, td_name)(device) - new_td = TensorDict({}, batch_size=td.batch_size, device=device) - for key, item in td.items(): - new_td.set(key, item) - assert_allclose_td(td, new_td) - assert td.device == new_td.device - assert td.shape == new_td.shape - - def test_masking(self, td_name, device): - torch.manual_seed(1) - td = getattr(self, td_name)(device) - while True: - mask = torch.zeros( - td.batch_size, dtype=torch.bool, device=device - ).bernoulli_(0.8) - if not mask.all() and mask.any(): - break - td_masked = td[mask] - td_masked2 = torch.masked_select(td, mask) - assert_allclose_td(td_masked, td_masked2) - assert td_masked.batch_size[0] == mask.sum() - assert td_masked.batch_dims == 1 - - # mask_list = mask.cpu().numpy().tolist() - # td_masked3 = td[mask_list] - # assert_allclose_td(td_masked3, td_masked2) - # assert td_masked3.batch_size[0] == mask.sum() - # assert td_masked3.batch_dims == 1 - - def test_entry_type(self, td_name, device): - td = getattr(self, td_name)(device) - for key in td.keys(include_nested=True): - assert 
type(td.get(key)) is td.entry_class(key) - - def test_equal(self, td_name, device): - torch.manual_seed(1) - td = getattr(self, td_name)(device) - assert (td == td.to_tensordict()).all() - td0 = td.to_tensordict().zero_() - assert (td != td0).any() - - def test_equal_float(self, td_name, device): - torch.manual_seed(1) - td = getattr(self, td_name)(device) - if td_name == "td_params": - td_set = td.data - else: - td_set = td - td_set.zero_() - assert (td == 0.0).all() - td0 = td.clone() - if td_name == "td_params": - td_set = td0.data - else: - td_set = td0 - td_set.zero_() - assert (td0 != 1.0).all() - - def test_equal_other(self, td_name, device): - td = getattr(self, td_name)(device) - assert not td == "z" - assert td != "z" - - def test_equal_int(self, td_name, device): - torch.manual_seed(1) - td = getattr(self, td_name)(device) - if td_name == "td_params": - td_set = td.data - else: - td_set = td - td_set.zero_() - assert (td == 0).all() - td0 = td.to_tensordict().zero_() - assert (td0 != 1).all() + with pytest.raises( + RuntimeError, + match="Converting a tensordict to boolean value is not permitted", + ): + assert td - def test_equal_tensor(self, td_name, device): + def test_broadcast(self, td_name, device): torch.manual_seed(1) td = getattr(self, td_name)(device) + sub_td = td[:, :2].to_tensordict() + sub_td.zero_() + sub_dict = sub_td.to_dict() if td_name == "td_params": td_set = td.data else: td_set = td - td_set.zero_() - assert (td == torch.zeros([], dtype=torch.int, device=device)).all() - td0 = td.to_tensordict().zero_() - assert (td0 != torch.ones([], dtype=torch.int, device=device)).all() - - def test_equal_dict(self, td_name, device): - torch.manual_seed(1) - td = getattr(self, td_name)(device) - assert (td == td.to_dict()).all() - td0 = td.to_tensordict().zero_().to_dict() - assert (td != td0).any() - - @pytest.mark.parametrize("dim", [0, 1, 2, 3, -1, -2, -3]) - def test_gather(self, td_name, device, dim): - torch.manual_seed(1) - td = getattr(self, td_name)(device) - index = torch.ones(td.shape, device=td.device, dtype=torch.long) - other_dim = dim + index.ndim if dim < 0 else dim - idx = (*[slice(None) for _ in range(other_dim)], slice(2)) - index = index[idx] - index = index.cumsum(dim=other_dim) - 1 - # gather - td_gather = torch.gather(td, dim=dim, index=index) - # gather with out - td_gather.zero_() - out = td_gather.clone() - if td_name == "td_params": - with pytest.raises( - RuntimeError, match="don't support automatic differentiation" - ): - torch.gather(td, dim=dim, index=index, out=out) - return - td_gather2 = torch.gather(td, dim=dim, index=index, out=out) - assert (td_gather2 != 0).any() + td_set[:, :2] = sub_dict + assert (td[:, :2] == 0).all() - def test_where(self, td_name, device): + @pytest.mark.parametrize("op", ["flatten", "unflatten"]) + def test_cache(self, td_name, device, op): torch.manual_seed(1) td = getattr(self, td_name)(device) - mask = torch.zeros(td.shape, dtype=torch.bool, device=device).bernoulli_() - td_where = torch.where(mask, td, 0) - for k in td.keys(True, True): - assert (td_where.get(k)[~mask] == 0).all() - td_where = torch.where(mask, td, torch.ones_like(td)) - for k in td.keys(True, True): - assert (td_where.get(k)[~mask] == 1).all() - td_where = td.clone() - - if td_name == "td_h5": - with pytest.raises( - RuntimeError, - match="Cannot use a persistent tensordict as output of torch.where", - ): - torch.where(mask, td, torch.ones_like(td), out=td_where) + try: + td.lock_() + except Exception: return - torch.where(mask, td, 
torch.ones_like(td), out=td_where) - for k in td.keys(True, True): - assert (td_where.get(k)[~mask] == 1).all() - - def test_where_pad(self, td_name, device): - torch.manual_seed(1) - td = getattr(self, td_name)(device) - # test with other empty td - mask = torch.zeros(td.shape, dtype=torch.bool, device=td.device).bernoulli_() - if td_name in ("td_h5",): - td_full = td.to_tensordict() - else: - td_full = td - td_empty = td_full.empty() - result = td.where(mask, td_empty, pad=1) - for v in result.values(True, True): - assert (v[~mask] == 1).all() - td_empty = td_full.empty() - result = td_empty.where(~mask, td, pad=1) - for v in result.values(True, True): - assert (v[~mask] == 1).all() - # with output - td_out = td_full.empty() - result = td.where(mask, td_empty, pad=1, out=td_out) - for v in result.values(True, True): - assert (v[~mask] == 1).all() - if td_name not in ("td_params",): - assert result is td_out - # TODO: decide if we want where to return a TensorDictParams. - # probably not, given - # else: - # assert isinstance(result, TensorDictParams) - td_out = td_full.empty() - td_empty = td_full.empty() - result = td_empty.where(~mask, td, pad=1, out=td_out) - for v in result.values(True, True): - assert (v[~mask] == 1).all() - assert result is td_out - - with pytest.raises(KeyError, match="not found and no pad value provided"): - td.where(mask, td_full.empty()) - with pytest.raises(KeyError, match="not found and no pad value provided"): - td_full.empty().where(mask, td) - - def test_masking_set(self, td_name, device): - torch.manual_seed(1) - td = getattr(self, td_name)(device) - mask = torch.zeros(td.batch_size, dtype=torch.bool, device=device).bernoulli_( - 0.8 - ) - n = mask.sum() - d = td.ndimension() - pseudo_td = td.apply( - lambda item: torch.zeros( - (n, *item.shape[d:]), dtype=item.dtype, device=device - ), - batch_size=[n, *td.batch_size[d:]], - ) - - if td_name == "td_params": - td_set = td.data - else: - td_set = td + if op == "keys_root": + a = list(td.keys()) + b = list(td.keys()) + assert a == b + elif op == "keys_nested": + a = list(td.keys(True)) + b = list(td.keys(True)) + assert a == b + elif op == "values": + a = list(td.values(True)) + b = list(td.values(True)) + assert all((_a == _b).all() for _a, _b in zip(a, b)) + elif op == "items": + keys_a, values_a = zip(*td.items(True)) + keys_b, values_b = zip(*td.items(True)) + assert all((_a == _b).all() for _a, _b in zip(values_a, values_b)) + assert keys_a == keys_b + elif op == "flatten": + a = td.flatten_keys() + b = td.flatten_keys() + if td_name not in ("td_h5",): + assert a is b + else: + assert a is not b + elif op == "unflatten": + a = td.unflatten_keys() + b = td.unflatten_keys() + if td_name not in ("td_h5",): + assert a is b + else: + assert a is not b - td_set[mask] = pseudo_td - for item in td.values(): - assert (item[mask] == 0).all() + if td_name != "td_params": + assert len(td._cache) + td.unlock_() + assert td._cache is None + for val in td.values(True): + if is_tensor_collection(val): + assert td._cache is None @pytest.mark.skipif( torch.cuda.device_count() == 0, reason="No cuda device detected" ) - @pytest.mark.parametrize("device_cast", [0, "cuda:0", torch.device("cuda:0")]) - def test_pin_memory(self, td_name, device_cast, device): + @pytest.mark.parametrize("device_cast", get_available_devices()) + def test_cast_device(self, td_name, device, device_cast): torch.manual_seed(1) td = getattr(self, td_name)(device) - td.unlock_() - if device.type == "cuda": - with pytest.raises(RuntimeError, 
match="cannot pin"): - td.pin_memory() - return - td.pin_memory() td_device = td.to(device_cast) - _device_cast = torch.device(device_cast) - assert td_device.device == _device_cast - assert td_device.clone().device == _device_cast - if device != _device_cast: - assert td_device is not td - for item in td_device.values(): - assert item.device == _device_cast - for item in td_device.clone().values(): - assert item.device == _device_cast - # assert type(td_device) is type(td) - assert_allclose_td(td, td_device.to(device)) - - def test_indexed_properties(self, td_name, device): - td = getattr(self, td_name)(device) - td_index = td[0] - assert td_index.is_memmap() is td.is_memmap() - assert td_index.is_shared() is td.is_shared() - assert td_index.device == td.device - - @pytest.mark.parametrize( - "idx", - [ - (..., None), - (None, ...), - (None,), - None, - (slice(None), None), - (0, None), - (None, slice(None), slice(None)), - (None, ..., None), - (None, 1, ..., None), - (1, ..., None), - (..., None, 0), - ([1], ..., None), - ], - ) - def test_index_none(self, td_name, device, idx): - td = getattr(self, td_name)(device) - tdnone = td[idx] - tensor = torch.zeros(td.shape) - assert tdnone.shape == tensor[idx].shape, idx - # Fixed by 451 - # if td_name == "td_h5": - # with pytest.raises(TypeError, match="can't process None"): - # assert (tdnone.to_tensordict() == td.to_tensordict()[idx]).all() - # return - assert (tdnone.to_tensordict() == td.to_tensordict()[idx]).all() - - @pytest.mark.skipif( - torch.cuda.device_count() == 0, reason="No cuda device detected" - ) - @pytest.mark.parametrize("device_cast", get_available_devices()) - def test_cast_device(self, td_name, device, device_cast): - torch.manual_seed(1) - td = getattr(self, td_name)(device) - td_device = td.to(device_cast) - + for item in td_device.values(): assert item.device == device_cast for item in td_device.clone().values(): @@ -1482,6 +1830,166 @@ def test_cast_device(self, td_name, device, device_cast): assert td.to(device) is td assert_allclose_td(td, td_device.to(device)) + def test_cast_to(self, td_name, device): + torch.manual_seed(1) + td = getattr(self, td_name)(device) + td_device = td.to("cpu:1") + assert td_device.device == torch.device("cpu:1") + td_dtype = td.to(torch.int) + assert all(t.dtype == torch.int for t in td_dtype.values(True, True)) + del td_dtype + # device (str), dtype + td_dtype_device = td.to("cpu:1", torch.int) + assert all(t.dtype == torch.int for t in td_dtype_device.values(True, True)) + assert td_dtype_device.device == torch.device("cpu:1") + del td_dtype_device + # device, dtype + td_dtype_device = td.to(torch.device("cpu:1"), torch.int) + assert all(t.dtype == torch.int for t in td_dtype_device.values(True, True)) + assert td_dtype_device.device == torch.device("cpu:1") + del td_dtype_device + # example tensor + td_dtype_device = td.to(torch.randn(3, dtype=torch.half, device="cpu:1")) + assert all(t.dtype == torch.half for t in td_dtype_device.values(True, True)) + # tensor on cpu:1 is actually on cpu. This is still meaningful for tensordicts on cuda. 
+        assert td_dtype_device.device == torch.device("cpu")
+        del td_dtype_device
+        # example td
+        td_dtype_device = td.to(
+            other=TensorDict(
+                {"a": torch.randn(3, dtype=torch.half, device="cpu:1")},
+                [],
+                device="cpu:1",
+            )
+        )
+        assert all(t.dtype == torch.half for t in td_dtype_device.values(True, True))
+        assert td_dtype_device.device == torch.device("cpu:1")
+        del td_dtype_device
+        # example td, many dtypes
+        td_nodtype_device = td.to(
+            other=TensorDict(
+                {
+                    "a": torch.randn(3, dtype=torch.half, device="cpu:1"),
+                    "b": torch.randint(10, ()),
+                },
+                [],
+                device="cpu:1",
+            )
+        )
+        assert all(t.dtype != torch.half for t in td_nodtype_device.values(True, True))
+        assert td_nodtype_device.device == torch.device("cpu:1")
+        del td_nodtype_device
+        # batch-size: check errors (or not)
+        if td_name in (
+            "stacked_td",
+            "unsqueezed_td",
+            "squeezed_td",
+            "permute_td",
+            "nested_stacked_td",
+        ):
+            with pytest.raises(TypeError, match="Cannot pass batch-size to a "):
+                td_dtype_device = td.to(
+                    torch.device("cpu:1"), torch.int, batch_size=torch.Size([])
+                )
+        else:
+            td_dtype_device = td.to(
+                torch.device("cpu:1"), torch.int, batch_size=torch.Size([])
+            )
+            assert all(t.dtype == torch.int for t in td_dtype_device.values(True, True))
+            assert td_dtype_device.device == torch.device("cpu:1")
+            assert td_dtype_device.batch_size == torch.Size([])
+            del td_dtype_device
+        if td_name in (
+            "stacked_td",
+            "unsqueezed_td",
+            "squeezed_td",
+            "permute_td",
+            "nested_stacked_td",
+        ):
+            with pytest.raises(TypeError, match="Cannot pass batch-size to a "):
+                td.to(batch_size=torch.Size([]))
+        else:
+            td_batchsize = td.to(batch_size=torch.Size([]))
+            assert td_batchsize.batch_size == torch.Size([])
+            del td_batchsize
+
+    def test_casts(self, td_name, device):
+        td = getattr(self, td_name)(device)
+        # exclude non-tensor data
+        is_leaf = lambda cls: issubclass(cls, torch.Tensor)
+        tdfloat = td.float()
+        assert all(
+            value.dtype is torch.float
+            for value in tdfloat.values(True, True, is_leaf=is_leaf)
+        )
+        tddouble = td.double()
+        assert all(
+            value.dtype is torch.double
+            for value in tddouble.values(True, True, is_leaf=is_leaf)
+        )
+        tdbfloat16 = td.bfloat16()
+        assert all(
+            value.dtype is torch.bfloat16
+            for value in tdbfloat16.values(True, True, is_leaf=is_leaf)
+        )
+        tdhalf = td.half()
+        assert all(
+            value.dtype is torch.half
+            for value in tdhalf.values(True, True, is_leaf=is_leaf)
+        )
+        tdint = td.int()
+        assert all(
+            value.dtype is torch.int
+            for value in tdint.values(True, True, is_leaf=is_leaf)
+        )
+        tdint = td.type(torch.int)
+        assert all(
+            value.dtype is torch.int
+            for value in tdint.values(True, True, is_leaf=is_leaf)
+        )
+
+    @pytest.mark.parametrize("dim", [0, 1])
+    @pytest.mark.parametrize("chunks", [1, 2])
+    def test_chunk(self, td_name, device, dim, chunks):
+        torch.manual_seed(1)
+        td = getattr(self, td_name)(device)
+        if len(td.shape) - 1 < dim:
+            # pytest.skip raises, so no return is needed afterwards
+            pytest.skip(f"no dim {dim} in td")
+
+        chunks = min(td.shape[dim], chunks)
+        td_chunks = td.chunk(chunks, dim)
+        assert len(td_chunks) == chunks
+        assert sum([_td.shape[dim] for _td in td_chunks]) == td.shape[dim]
+        assert (torch.cat(td_chunks, dim) == td).all()
+
+    def test_clone_td(self, td_name, device, tmp_path):
+        torch.manual_seed(1)
+        td = getattr(self, td_name)(device)
+        if td_name == "td_h5":
+            # need a new file
+            newfile = tmp_path / "file.h5"
+            clone = td.clone(newfile=newfile)
+        else:
+            clone = torch.clone(td)
+        assert (clone == td).all()
+        assert td.batch_size == clone.batch_size
+        assert 
type(td.clone(recurse=False)) is type(td) + if td_name in ( + "stacked_td", + "nested_stacked_td", + "saved_td", + "squeezed_td", + "unsqueezed_td", + "sub_td", + "sub_td2", + "permute_td", + "td_h5", + ): + assert td.clone(recurse=False).get("a") is not td.get("a") + else: + assert td.clone(recurse=False).get("a") is td.get("a") + @pytest.mark.skipif( torch.cuda.device_count() == 0, reason="No cuda device detected" ) @@ -1493,94 +2001,157 @@ def test_cpu_cuda(self, td_name, device): assert td_device.device == torch.device("cuda") assert td_back.device == torch.device("cpu") - def test_state_dict(self, td_name, device): - torch.manual_seed(1) + def test_create_nested(self, td_name, device): td = getattr(self, td_name)(device) - sd = td.state_dict() - td_zero = td.clone().detach().zero_() - td_zero.load_state_dict(sd) - assert_allclose_td(td, td_zero) + with td.unlock_(): + td.create_nested("root") + assert td.get("root").shape == td.shape + assert is_tensor_collection(td.get("root")) + td.create_nested(("some", "nested", "key")) - def test_state_dict_strict(self, td_name, device): + some = td.get("some") + nested = some.get("nested") + _ = nested.get("key") + assert td.get(("some", "nested", "key")).shape == td.shape + assert is_tensor_collection(td.get(("some", "nested", "key"))) + if td_name in ("sub_td", "sub_td2"): + return + with td.lock_(), pytest.raises(RuntimeError): + td.create_nested("root") + + def test_default_nested(self, td_name, device): torch.manual_seed(1) td = getattr(self, td_name)(device) - sd = td.state_dict() - td_zero = td.clone().detach().zero_() - del sd["a"] - td_zero.load_state_dict(sd, strict=False) - with pytest.raises(RuntimeError): - td_zero.load_state_dict(sd, strict=True) + default_val = torch.randn(()) + timbers = td.get(("shiver", "my", "timbers"), default_val) + assert timbers == default_val - def test_state_dict_assign(self, td_name, device): + def test_delitem(self, td_name, device): torch.manual_seed(1) td = getattr(self, td_name)(device) - sd = td.state_dict() - td_zero = td.clone().detach().zero_() - shallow_copy = td_zero.clone(False) - td_zero.load_state_dict(sd, assign=True) - assert (shallow_copy == 0).all() - assert_allclose_td(td, td_zero) - - @pytest.mark.parametrize("dim", range(4)) - def test_unbind(self, td_name, device, dim): - if td_name not in ["sub_td", "idx_td", "td_reset_bs"]: - torch.manual_seed(1) - td = getattr(self, td_name)(device) - td_unbind = torch.unbind(td, dim=dim) - assert (td == stack_td(td_unbind, dim).contiguous()).all() - idx = (slice(None),) * dim + (0,) - assert (td[idx] == td_unbind[0]).all() + if td_name in ("memmap_td",): + with pytest.raises(RuntimeError, match="Cannot modify"): + del td["a"] + return + del td["a"] + assert "a" not in td.keys() - @pytest.mark.parametrize("squeeze_dim", [0, 1]) - def test_unsqueeze(self, td_name, device, squeeze_dim): + def test_empty_like(self, td_name, device): + if "sub_td" in td_name: + # we do not call skip to avoid systematic skips in internal code base + return + td = getattr(self, td_name)(device) + if isinstance(td, _CustomOpTensorDict): + # we do not call skip to avoid systematic skips in internal code base + return + td_empty = torch.empty_like(td) + + td.apply_(lambda x: x + 1.0) + assert type(td) is type(td_empty) + # exclude non tensor data + comp = td.filter_non_tensor_data() != td_empty.filter_non_tensor_data() + assert all(val.any() for val in comp.values(True, True)) + + def test_enter_exit(self, td_name, device): torch.manual_seed(1) + if td_name in ("sub_td", 
"sub_td2"): + return td = getattr(self, td_name)(device) - with td.unlock_(): # make sure that the td is not locked - td_unsqueeze = torch.unsqueeze(td, dim=squeeze_dim) - tensor = torch.ones_like(td.get("a").unsqueeze(squeeze_dim)) - if td_name in ("sub_td", "sub_td2"): - td_unsqueeze.set_("a", tensor) - else: - td_unsqueeze.set("a", tensor) - assert (td_unsqueeze.get("a") == tensor).all() - assert (td.get("a") == tensor.squeeze(squeeze_dim)).all() - # the tensors should match - assert _compare_tensors_identity(td_unsqueeze.squeeze(squeeze_dim), td) - assert (td_unsqueeze.get("a") == 1).all() - assert (td.get("a") == 1).all() + is_locked = td.is_locked + with td.lock_() as other: + assert other is td + assert td.is_locked + with td.unlock_() as other: + assert other is td + assert not td.is_locked + assert td.is_locked + assert td.is_locked is is_locked - def test_squeeze(self, td_name, device, squeeze_dim=-1): + def test_entry_type(self, td_name, device): + td = getattr(self, td_name)(device) + for key in td.keys(include_nested=True): + assert type(td.get(key)) is td.entry_class(key) + + def test_equal(self, td_name, device): torch.manual_seed(1) td = getattr(self, td_name)(device) - with td.unlock_(): # make sure that the td is not locked - td_squeeze = torch.squeeze(td, dim=-1) - tensor_squeeze_dim = td.batch_dims + squeeze_dim - tensor = torch.ones_like(td.get("a").squeeze(tensor_squeeze_dim)) - if td_name in ("sub_td", "sub_td2"): - td_squeeze.set_("a", tensor) - else: - td_squeeze.set("a", tensor) - assert td.batch_size[squeeze_dim] == 1 - assert (td_squeeze.get("a") == tensor).all() - assert (td.get("a") == tensor.unsqueeze(tensor_squeeze_dim)).all() - if td_name != "unsqueezed_td": - assert _compare_tensors_identity(td_squeeze.unsqueeze(squeeze_dim), td) - else: - assert td_squeeze is td._source - assert (td_squeeze.get("a") == 1).all() - assert (td.get("a") == 1).all() + assert (td == td.to_tensordict()).all() + td0 = td.to_tensordict().zero_() + assert (td != td0).any() - def test_squeeze_with_none(self, td_name, device, squeeze_dim=None): + def test_equal_dict(self, td_name, device): torch.manual_seed(1) td = getattr(self, td_name)(device) - td_squeeze = torch.squeeze(td, dim=None) - tensor = torch.ones_like(td.get("a").squeeze()) - td_squeeze.set_("a", tensor) - assert (td_squeeze.get("a") == tensor).all() - if td_name == "unsqueezed_td": - assert td_squeeze._source is td - assert (td_squeeze.get("a") == 1).all() - assert (td.get("a") == 1).all() + assert (td == td.to_dict()).all() + td0 = td.to_tensordict().zero_().to_dict() + assert (td != td0).any() + + def test_equal_float(self, td_name, device): + torch.manual_seed(1) + td = getattr(self, td_name)(device) + if td_name == "td_params": + td_set = td.data + else: + td_set = td + td_set.zero_() + assert (td == 0.0).all() + td0 = td.clone() + if td_name == "td_params": + td_set = td0.data + else: + td_set = td0 + td_set.zero_() + assert (td0 != 1.0).all() + + def test_equal_int(self, td_name, device): + torch.manual_seed(1) + td = getattr(self, td_name)(device) + if td_name == "td_params": + td_set = td.data + else: + td_set = td + td_set.zero_() + assert (td == 0).all() + td0 = td.to_tensordict().zero_() + assert (td0 != 1).all() + + def test_equal_other(self, td_name, device): + td = getattr(self, td_name)(device) + assert not td == "z" + assert td != "z" + + def test_equal_tensor(self, td_name, device): + torch.manual_seed(1) + td = getattr(self, td_name)(device) + if td_name == "td_params": + td_set = td.data + else: + 
td_set = td + td_set.zero_() + assert (td == torch.zeros([], dtype=torch.int, device=device)).all() + td0 = td.to_tensordict().zero_() + assert (td0 != torch.ones([], dtype=torch.int, device=device)).all() + + def test_exclude(self, td_name, device): + torch.manual_seed(1) + td = getattr(self, td_name)(device) + if td_name == "td_h5": + with pytest.raises(NotImplementedError, match="Cannot call exclude"): + _ = td.exclude("a") + return + td2 = td.exclude("a") + assert td2 is not td + assert ( + len(list(td2.keys())) == len(list(td.keys())) - 1 and "a" not in td2.keys() + ) + assert ( + len(list(td2.clone().keys())) == len(list(td.keys())) - 1 + and "a" not in td2.clone().keys() + ) + + with td.unlock_(): + td2 = td.exclude("a", inplace=True) + assert td2 is td @pytest.mark.parametrize("nested", [True, False]) def test_exclude_missing(self, td_name, device, nested): @@ -1630,158 +2201,219 @@ def test_exclude_nested(self, td_name, device, nested): # perhaps exclude should return an error in these cases? assert type(td2) is type(td) - @pytest.mark.parametrize("clone", [True, False]) - def test_update(self, td_name, device, clone): + def test_expand(self, td_name, device): + torch.manual_seed(1) td = getattr(self, td_name)(device) - td.unlock_() # make sure that the td is not locked - keys = set(td.keys()) - td.update({"x": torch.zeros(td.shape)}, clone=clone) - assert set(td.keys()) == keys.union({"x"}) - # now with nested: using tuples for keys - td.update({("somenested", "z"): torch.zeros(td.shape)}) - assert td["somenested"].shape == td.shape - assert td["somenested", "z"].shape == td.shape - td.update({("somenested", "zz"): torch.zeros(td.shape)}) - assert td["somenested"].shape == td.shape - assert td["somenested", "zz"].shape == td.shape - # now with nested: using nested dicts - td["newnested"] = {"z": torch.zeros(td.shape)} - keys = set(td.keys(True)) - assert ("newnested", "z") in keys - td.update({"newnested": {"y": torch.zeros(td.shape)}}, clone=clone) - keys = keys.union({("newnested", "y")}) - assert keys == set(td.keys(True)) - td.update( - { - ("newnested", "x"): torch.zeros(td.shape), - ("newnested", "w"): torch.zeros(td.shape), - }, - clone=clone, - ) - keys = keys.union({("newnested", "x"), ("newnested", "w")}) - assert keys == set(td.keys(True)) - td.update({("newnested",): {"v": torch.zeros(td.shape)}}, clone=clone) - keys = keys.union( - { - ("newnested", "v"), - } - ) - assert keys == set(td.keys(True)) + batch_size = td.batch_size + expected_size = torch.Size([3, *batch_size]) - if td_name in ("sub_td", "sub_td2"): - with pytest.raises(ValueError, match="Tried to replace a tensordict with"): - td.update({"newnested": torch.zeros(td.shape)}, clone=clone) - else: - td.update({"newnested": torch.zeros(td.shape)}, clone=clone) - assert isinstance(td["newnested"], torch.Tensor) + new_td = td.expand(3, *batch_size) + assert new_td.batch_size == expected_size + assert all((_new_td == td).all() for _new_td in new_td) - def test_update_at_(self, td_name, device): - td = getattr(self, td_name)(device) - td0 = td[1].clone().zero_() - td.update_at_(td0, 0) - assert (td[0] == 0).all() + new_td_torch_size = td.expand(expected_size) + assert new_td_torch_size.batch_size == expected_size + assert all((_new_td == td).all() for _new_td in new_td_torch_size) - def test_write_on_subtd(self, td_name, device): + new_td_iterable = td.expand([3, *batch_size]) + assert new_td_iterable.batch_size == expected_size + assert all((_new_td == td).all() for _new_td in new_td_iterable) + + def 
test_fill_(self, td_name, device): + torch.manual_seed(1) td = getattr(self, td_name)(device) - sub_td = td._get_sub_tensordict(0) - # should not work with td_params if td_name == "td_params": - with pytest.raises(RuntimeError, match="a view of a leaf"): - sub_td["a"] = torch.full((3, 2, 1, 5), 1.0, device=device) - return - sub_td["a"] = torch.full((3, 2, 1, 5), 1.0, device=device) - assert (td["a"][0] == 1).all() + td_set = td.data + else: + td_set = td + new_td = td_set.fill_("a", 0.1) + assert (td.get("a") == 0.1).all() + assert new_td is td_set - def test_pad(self, td_name, device): + @pytest.mark.parametrize("inplace", [True, False]) + @pytest.mark.parametrize("separator", [",", "-"]) + def test_flatten_keys(self, td_name, device, inplace, separator): td = getattr(self, td_name)(device) - paddings = [ - [0, 1, 0, 2], - [1, 0, 0, 2], - [1, 0, 2, 1], - ] - - for pad_size in paddings: - padded_td = pad(td, pad_size) - padded_td._check_batch_size() - amount_expanded = [0] * (len(pad_size) // 2) - for i in range(0, len(pad_size), 2): - amount_expanded[i // 2] = pad_size[i] + pad_size[i + 1] - - for key in padded_td.keys(): - expected_dims = tuple( - sum(p) - for p in zip( - td[key].shape, - amount_expanded - + [0] * (len(td[key].shape) - len(amount_expanded)), - ) - ) - assert padded_td[key].shape == expected_dims + locked = td.is_locked + td.unlock_() + nested_nested_tensordict = TensorDict( + { + "a": torch.zeros(*td.shape, 2, 3), + }, + [*td.shape, 2], + ) + nested_tensordict = TensorDict( + { + "a": torch.zeros(*td.shape, 2), + "nested_nested_tensordict": nested_nested_tensordict, + }, + td.shape, + ) + td["nested_tensordict"] = nested_tensordict + if locked: + td.lock_() - with pytest.raises(RuntimeError): - pad(td, [0] * 100) + if inplace and locked: + with pytest.raises(RuntimeError, match="Cannot modify locked TensorDict"): + td_flatten = td.flatten_keys(inplace=inplace, separator=separator) + return + elif td_name in ("td_h5",) and inplace: + with pytest.raises( + ValueError, + match="Cannot call flatten_keys in_place with a PersistentTensorDict", + ): + td_flatten = td.flatten_keys(inplace=inplace, separator=separator) + return + else: + td_flatten = td.flatten_keys(inplace=inplace, separator=separator) + for value in td_flatten.values(): + assert not isinstance(value, TensorDictBase) + assert ( + separator.join(["nested_tensordict", "nested_nested_tensordict", "a"]) + in td_flatten.keys() + ) + if inplace: + assert td_flatten is td + else: + assert td_flatten is not td - with pytest.raises(RuntimeError): - pad(td, [0]) + def test_flatten_unflatten(self, td_name, device): + td = getattr(self, td_name)(device) + shape = td.shape[:3] + td_flat = td.flatten(0, 2) + td_unflat = td_flat.unflatten(0, shape) + assert (td.to_tensordict() == td_unflat).all() + assert td.batch_size == td_unflat.batch_size - def test_reshape(self, td_name, device): + def test_flatten_unflatten_bis(self, td_name, device): td = getattr(self, td_name)(device) - td_reshape = td.reshape(td.shape) - # assert isinstance(td_reshape, TensorDict) - assert td_reshape.shape.numel() == td.shape.numel() - assert td_reshape.shape == td.shape - td_reshape = td.reshape(*td.shape) - # assert isinstance(td_reshape, TensorDict) - assert td_reshape.shape.numel() == td.shape.numel() - assert td_reshape.shape == td.shape - td_reshape = td.reshape(size=td.shape) - # assert isinstance(td_reshape, TensorDict) - assert td_reshape.shape.numel() == td.shape.numel() - assert td_reshape.shape == td.shape - td_reshape = td.reshape(-1) - 
assert isinstance(td_reshape, TensorDict) - assert td_reshape.shape.numel() == td.shape.numel() - assert td_reshape.shape == torch.Size([td.shape.numel()]) - td_reshape = td.reshape((-1,)) - assert isinstance(td_reshape, TensorDict) - assert td_reshape.shape.numel() == td.shape.numel() - assert td_reshape.shape == torch.Size([td.shape.numel()]) - td_reshape = td.reshape(size=(-1,)) - assert isinstance(td_reshape, TensorDict) - assert td_reshape.shape.numel() == td.shape.numel() - assert td_reshape.shape == torch.Size([td.shape.numel()]) + shape = td.shape[1:4] + td_flat = td.flatten(1, 3) + td_unflat = td_flat.unflatten(1, shape) + assert (td.to_tensordict() == td_unflat).all() + assert td.batch_size == td_unflat.batch_size - def test_view(self, td_name, device): - if td_name in ("permute_td", "sub_td2"): - pytest.skip("view incompatible with stride / permutation") + def test_from_empty(self, td_name, device): torch.manual_seed(1) td = getattr(self, td_name)(device) - with td.unlock_(): # make sure that the td is not locked - td_view = td.view(-1) - tensor = td.get("a") - tensor = tensor.view(-1, tensor.numel() // prod(td.batch_size)) - tensor = torch.ones_like(tensor) - if td_name == "sub_td": - td_view.set_("a", tensor) - else: - td_view.set("a", tensor) - assert (td_view.get("a") == tensor).all() - assert (td.get("a") == tensor.view(td.get("a").shape)).all() - if td_name in ("td_params",): - assert td_view.view(td.shape)._param_td is td._param_td - assert td_view.view(*td.shape)._param_td is td._param_td - else: - assert td_view.view(td.shape) is td - assert td_view.view(*td.shape) is td - assert (td_view.get("a") == 1).all() - assert (td.get("a") == 1).all() + new_td = TensorDict({}, batch_size=td.batch_size, device=device) + for key, item in td.items(): + new_td.set(key, item) + assert_allclose_td(td, new_td) + assert td.device == new_td.device + assert td.shape == new_td.shape - def test_default_nested(self, td_name, device): + @pytest.mark.parametrize("dim", [0, 1, 2, 3, -1, -2, -3]) + def test_gather(self, td_name, device, dim): torch.manual_seed(1) td = getattr(self, td_name)(device) - default_val = torch.randn(()) - timbers = td.get(("shiver", "my", "timbers"), default_val) - assert timbers == default_val + index = torch.ones(td.shape, device=td.device, dtype=torch.long) + other_dim = dim + index.ndim if dim < 0 else dim + idx = (*[slice(None) for _ in range(other_dim)], slice(2)) + index = index[idx] + index = index.cumsum(dim=other_dim) - 1 + # gather + td_gather = torch.gather(td, dim=dim, index=index) + # gather with out + td_gather.zero_() + out = td_gather.clone() + if td_name == "td_params": + with pytest.raises( + RuntimeError, match="don't support automatic differentiation" + ): + torch.gather(td, dim=dim, index=index, out=out) + return + td_gather2 = torch.gather(td, dim=dim, index=index, out=out) + assert (td_gather2 != 0).any() + + @pytest.mark.parametrize( + "actual_index,expected_index", + [ + (..., (slice(None),) * TD_BATCH_SIZE), + ((..., 0), (slice(None),) * (TD_BATCH_SIZE - 1) + (0,)), + ((0, ...), (0,) + (slice(None),) * (TD_BATCH_SIZE - 1)), + ((0, ..., 0), (0,) + (slice(None),) * (TD_BATCH_SIZE - 2) + (0,)), + ], + ) + def test_getitem_ellipsis(self, td_name, device, actual_index, expected_index): + torch.manual_seed(1) + + td = getattr(self, td_name)(device) + + actual_td = td[actual_index] + expected_td = td[expected_index] + other_expected_td = td.to_tensordict()[expected_index] + assert expected_td.shape == _getitem_batch_size( + td.batch_size, 
convert_ellipsis_to_idx(actual_index, td.batch_size) + ) + assert other_expected_td.shape == actual_td.shape + assert_allclose_td(actual_td, other_expected_td) + assert_allclose_td(actual_td, expected_td) + + def test_getitem_nestedtuple(self, td_name, device): + torch.manual_seed(1) + td = getattr(self, td_name)(device) + assert isinstance(td[(("a",))], torch.Tensor) + assert isinstance(td.get((("a",))), torch.Tensor) + + def test_getitem_range(self, td_name, device): + torch.manual_seed(1) + td = getattr(self, td_name)(device) + assert_allclose_td(td[range(2)], td[[0, 1]]) + if td_name not in ("td_h5",): + # for h5, we can't use a double list index + assert td[range(1), range(1)].shape == td[[0], [0]].shape + assert_allclose_td(td[range(1), range(1)], td[[0], [0]]) + assert_allclose_td(td[:, range(2)], td[:, [0, 1]]) + assert_allclose_td(td[..., range(1)], td[..., [0]]) + + if td_name in ("stacked_td", "nested_stacked_td"): + # this is a bit contrived, but want to check that if we pass something + # weird as the index to the stacking dimension we'll get the error + idx = (slice(None),) * td.stack_dim + ({1, 2, 3},) + with pytest.raises(TypeError, match="Invalid index"): + td[idx] + + def test_getitem_string(self, td_name, device): + torch.manual_seed(1) + td = getattr(self, td_name)(device) + assert isinstance(td["a"], torch.Tensor) + + @pytest.mark.parametrize( + "idx", + [ + (..., None), + (None, ...), + (None,), + None, + (slice(None), None), + (0, None), + (None, slice(None), slice(None)), + (None, ..., None), + (None, 1, ..., None), + (1, ..., None), + (..., None, 0), + ([1], ..., None), + ], + ) + def test_index_none(self, td_name, device, idx): + td = getattr(self, td_name)(device) + tdnone = td[idx] + tensor = torch.zeros(td.shape) + assert tdnone.shape == tensor[idx].shape, idx + # Fixed by 451 + # if td_name == "td_h5": + # with pytest.raises(TypeError, match="can't process None"): + # assert (tdnone.to_tensordict() == td.to_tensordict()[idx]).all() + # return + assert (tdnone.to_tensordict() == td.to_tensordict()[idx]).all() + + def test_indexed_properties(self, td_name, device): + td = getattr(self, td_name)(device) + td_index = td[0] + assert td_index.is_memmap() is td.is_memmap() + assert td_index.is_shared() is td.is_shared() + assert td_index.device == td.device def test_inferred_view_size(self, td_name, device): if td_name in ("permute_td", "sub_td2"): @@ -1803,682 +2435,471 @@ def test_inferred_view_size(self, td_name, device): assert td.view(*new_shape) is td assert td.view(-1).view(*new_shape) is td - @pytest.mark.parametrize("dim", [0, 1, -1, -5]) - @pytest.mark.parametrize( - "key", ["heterogeneous-entry", ("sub", "heterogeneous-entry")] - ) - def test_nestedtensor_stack(self, td_name, device, dim, key): + def test_items_values_keys(self, td_name, device): torch.manual_seed(1) - td1 = getattr(self, td_name)(device).unlock_() - td2 = getattr(self, td_name)(device).unlock_() + td = getattr(self, td_name)(device) + td.unlock_() + keys = list(td.keys()) + values = list(td.values()) + items = list(td.items()) - td1[key] = torch.randn(*td1.shape, 2) - td2[key] = torch.randn(*td1.shape, 3) - td_stack = torch.stack([td1, td2], dim) - # get will fail - with pytest.raises( - RuntimeError, match="Found more than one unique shape in the tensors" - ): - td_stack.get(key) - with pytest.raises( - RuntimeError, match="Found more than one unique shape in the tensors" - ): - td_stack[key] - if dim in (0, -5): - # this will work if stack_dim is 0 (or equivalently -self.batch_dims) - # 
it is the proper way to get that entry - td_stack.get_nestedtensor(key) - else: - # if the stack_dim is not zero, then calling get_nestedtensor is disallowed - with pytest.raises( - RuntimeError, - match="LazyStackedTensorDict.get_nestedtensor can only be called " - "when the stack_dim is 0.", - ): - td_stack.get_nestedtensor(key) - with pytest.raises( - RuntimeError, match="Found more than one unique shape in the tensors" - ): - td_stack.contiguous() - with pytest.raises( - RuntimeError, match="Found more than one unique shape in the tensors" - ): - td_stack.to_tensordict() - # cloning is type-preserving: we can do that operation - td_stack.clone() + # Test td.items() + constructed_td1 = TensorDict({}, batch_size=td.shape) + for key, value in items: + constructed_td1.set(key, value) - def test_clone_td(self, td_name, device, tmp_path): - torch.manual_seed(1) - td = getattr(self, td_name)(device) - if td_name == "td_h5": - # need a new file - newfile = tmp_path / "file.h5" - clone = td.clone(newfile=newfile) - else: - clone = torch.clone(td) - assert (clone == td).all() - assert td.batch_size == clone.batch_size - assert type(td.clone(recurse=False)) is type(td) - if td_name in ( - "stacked_td", - "nested_stacked_td", - "saved_td", - "squeezed_td", - "unsqueezed_td", - "sub_td", - "sub_td2", - "permute_td", - "td_h5", - ): - assert td.clone(recurse=False).get("a") is not td.get("a") - else: - assert td.clone(recurse=False).get("a") is td.get("a") + assert (td == constructed_td1).all() - def test_rename_key(self, td_name, device) -> None: - torch.manual_seed(1) - td = getattr(self, td_name)(device) - if td.is_locked: - with pytest.raises(RuntimeError, match=re.escape(_LOCK_ERROR)): - td.rename_key_("a", "b", safe=True) - else: - with pytest.raises(KeyError, match="already present in TensorDict"): - td.rename_key_("a", "b", safe=True) + # Test td.keys() and td.values() + # items = [key, value] should be verified + assert len(values) == len(items) + assert len(keys) == len(items) + constructed_td2 = TensorDict({}, batch_size=td.shape) + for key, value in list(zip(td.keys(), td.values())): + constructed_td2.set(key, value) + + assert (td == constructed_td2).all() + + # Test that keys is sorted + assert all(keys[i] <= keys[i + 1] for i in range(len(keys) - 1)) + + # Add new element to tensor a = td.get("a") - if td.is_locked: - with pytest.raises(RuntimeError, match="Cannot modify"): - td.rename_key_("a", "z") - return - else: - td.rename_key_("a", "z") - with pytest.raises(KeyError): - td.get("a") - assert "a" not in td.keys() + td.set("x", torch.randn_like(a)) + keys = list(td.keys()) + values = list(td.values()) + items = list(td.items()) - z = td.get("z") - torch.testing.assert_close(a, z) + # Test that keys is still sorted after adding the element + assert all(keys[i] <= keys[i + 1] for i in range(len(keys) - 1)) - new_z = torch.randn_like(z) - if td_name in ("sub_td", "sub_td2"): - td.set_("z", new_z) - else: - td.set("z", new_z) + # Test td.items() + # after adding the new element + constructed_td1 = TensorDict({}, batch_size=td.shape) + for key, value in items: + constructed_td1.set(key, value) - torch.testing.assert_close(new_z, td.get("z")) + assert (td == constructed_td1).all() - new_z = torch.randn_like(z) - if td_name == "td_params": - td.data.set_("z", new_z) - else: - td.set_("z", new_z) - torch.testing.assert_close(new_z, td.get("z")) + # Test td.keys() and td.values() + # items = [key, value] should be verified + # even after adding the new element + assert len(values) == 
len(items) + assert len(keys) == len(items) - def test_rename_key_nested(self, td_name, device) -> None: - torch.manual_seed(1) - td = getattr(self, td_name)(device) - td.unlock_() - td["nested", "conflict"] = torch.zeros(td.shape) - with pytest.raises(KeyError, match="already present in TensorDict"): - td.rename_key_(("nested", "conflict"), "b", safe=True) - td["nested", "first"] = torch.zeros(td.shape) - td.rename_key_(("nested", "first"), "second") - assert (td["second"] == 0).all() - assert ("nested", "first") not in td.keys(True) - td.rename_key_("second", ("nested", "back")) - assert (td[("nested", "back")] == 0).all() - assert "second" not in td.keys() + constructed_td2 = TensorDict({}, batch_size=td.shape) + for key, value in list(zip(td.keys(), td.values())): + constructed_td2.set(key, value) - def test_set_nontensor(self, td_name, device): - torch.manual_seed(1) + assert (td == constructed_td2).all() + + def test_lock(self, td_name, device): td = getattr(self, td_name)(device) + is_locked = td.is_locked + for item in td.values(): + if isinstance(item, TensorDictBase): + assert item.is_locked == is_locked + if isinstance(td, _SubTensorDict): + with pytest.raises(RuntimeError, match="the parent tensordict instead"): + td.is_locked = not is_locked + return + td.is_locked = not is_locked + assert td.is_locked != is_locked + for _, item in td.items(): + if isinstance(item, TensorDictBase): + assert item.is_locked != is_locked + td.lock_() + assert td.is_locked + for _, item in td.items(): + if isinstance(item, TensorDictBase): + assert item.is_locked td.unlock_() - r = torch.randn_like(td.get("a")) - td.set("numpy", r.cpu().numpy()) - torch.testing.assert_close(td.get("numpy"), r) + assert not td.is_locked + for _, item in td.items(): + if isinstance(item, TensorDictBase): + assert not item.is_locked - @pytest.mark.parametrize( - "actual_index,expected_index", - [ - (..., (slice(None),) * TD_BATCH_SIZE), - ((..., 0), (slice(None),) * (TD_BATCH_SIZE - 1) + (0,)), - ((0, ...), (0,) + (slice(None),) * (TD_BATCH_SIZE - 1)), - ((0, ..., 0), (0,) + (slice(None),) * (TD_BATCH_SIZE - 2) + (0,)), - ], - ) - def test_getitem_ellipsis(self, td_name, device, actual_index, expected_index): + def test_lock_change_names(self, td_name, device): torch.manual_seed(1) - td = getattr(self, td_name)(device) + try: + td.names = [str(i) for i in range(td.ndim)] + td.lock_() + except Exception: + return + # cache values + list(td.values(True)) + td.names = [str(-i) for i in range(td.ndim)] + for val in td.values(True): + if not is_tensor_collection(val): + continue + assert val.names[: td.ndim] == [str(-i) for i in range(td.ndim)] - actual_td = td[actual_index] - expected_td = td[expected_index] - other_expected_td = td.to_tensordict()[expected_index] - assert expected_td.shape == _getitem_batch_size( - td.batch_size, convert_ellipsis_to_idx(actual_index, td.batch_size) - ) - assert other_expected_td.shape == actual_td.shape - assert_allclose_td(actual_td, other_expected_td) - assert_allclose_td(actual_td, expected_td) - - @pytest.mark.parametrize("actual_index", [..., (..., 0), (0, ...), (0, ..., 0)]) - def test_setitem_ellipsis(self, td_name, device, actual_index): - torch.manual_seed(1) + def test_lock_nested(self, td_name, device): td = getattr(self, td_name)(device) - - idx = actual_index - td_clone = td.clone() - actual_td = td_clone[idx].clone() - if td_name in ("td_params",): - td_set = actual_td.apply(lambda x: x.data) - else: - td_set = actual_td - td_set.zero_() - - for key in actual_td.keys(): - 
assert (actual_td.get(key) == 0).all() - - if td_name in ("td_params",): - td_set = td_clone.data + if td_name in ("sub_td", "sub_td2") and td.is_locked: + with pytest.raises(RuntimeError, match="Cannot unlock"): + td.unlock_() else: - td_set = td_clone - - td_set[idx] = actual_td - for key in td_clone.keys(): - assert (td_clone[idx].get(key) == 0).all() + td.unlock_() + td.set(("some", "nested"), torch.zeros(td.shape)) + if td_name in ("sub_td", "sub_td2") and not td.is_locked: + with pytest.raises(RuntimeError, match="Cannot lock"): + td.lock_() + return + td.lock_() + some = td.get("some") + assert some.is_locked + with pytest.raises(RuntimeError): + some.unlock_() + # this assumes that td is out of scope after the call to del. + # an error in unlock_() is likely due to td leaving a trace somewhere. + del td + gc.collect() + some.unlock_() - @pytest.mark.parametrize( - "idx", [slice(1), torch.tensor([0]), torch.tensor([0, 1]), range(1), range(2)] - ) - def test_setitem(self, td_name, device, idx): - torch.manual_seed(1) + def test_lock_write(self, td_name, device): td = getattr(self, td_name)(device) - if isinstance(idx, torch.Tensor) and idx.numel() > 1 and td.shape[0] == 1: - pytest.mark.skip("cannot index tensor with desired index") + if isinstance(td, _SubTensorDict): + with pytest.raises(RuntimeError, match="the parent tensordict instead"): + td.lock_() return - td_clone = td[idx].to_tensordict().zero_() - if td_name == "td_params": - td.data[idx] = td_clone + td.lock_() + td_clone = td.clone() + assert not td_clone.is_locked + td_clone = td.to_tensordict() + assert not td_clone.is_locked + assert td.is_locked + if td_name == "td_h5": + td.unlock_() + for key in list(td.keys()): + del td[key] + td.lock_() else: - td[idx] = td_clone - assert (td[idx].get("a") == 0).all() - - td_clone = torch.cat([td_clone, td_clone], 0) - with pytest.raises( - RuntimeError, - match=r"differs from the source batch size|batch dimension mismatch|Cannot broadcast the tensordict", - ): - td[idx] = td_clone - - def test_setitem_string(self, td_name, device): - torch.manual_seed(1) - td = getattr(self, td_name)(device) + with td.unlock_() if td.is_locked else contextlib.nullcontext(): + td = td.select(inplace=True) + for key, item in td_clone.items(True): + with pytest.raises(RuntimeError, match="Cannot modify locked TensorDict"): + td.set(key, item) td.unlock_() - td["d"] = torch.randn(4, 3, 2, 1, 5) - assert "d" in td.keys() - - def test_getitem_string(self, td_name, device): - torch.manual_seed(1) - td = getattr(self, td_name)(device) - assert isinstance(td["a"], torch.Tensor) + for key, item in td_clone.items(True): + td.set(key, item) + td.lock_() + for key, item in td_clone.items(True): + with pytest.raises(RuntimeError, match="Cannot modify locked TensorDict"): + td.set(key, item) + if td_name == "td_params": + td_set = td.data + else: + td_set = td + td_set.set_(key, item) - def test_getitem_nestedtuple(self, td_name, device): + def test_masked_fill(self, td_name, device): torch.manual_seed(1) td = getattr(self, td_name)(device) - assert isinstance(td[(("a",))], torch.Tensor) - assert isinstance(td.get((("a",))), torch.Tensor) + mask = torch.zeros(td.shape, dtype=torch.bool, device=device).bernoulli_() + new_td = td.masked_fill(mask, -10.0) + assert new_td is not td + for item in new_td.values(): + assert (item[mask] == -10).all() - def test_setitem_nestedtuple(self, td_name, device): + def test_masked_fill_(self, td_name, device): torch.manual_seed(1) td = getattr(self, td_name)(device) - if 
td.is_locked: - td.unlock_() - td[" a ", (("little", "story")), "about", ("myself",)] = torch.zeros(td.shape) - assert (td[" a ", "little", "story", "about", "myself"] == 0).all() + mask = torch.zeros(td.shape, dtype=torch.bool, device=device).bernoulli_() + if td_name == "td_params": + td_set = td.data + else: + td_set = td + new_td = td_set.masked_fill_(mask, -10.0) + assert new_td is td_set + for item in td.values(): + assert (item[mask] == -10).all(), item[mask] - def test_getitem_range(self, td_name, device): + def test_masking(self, td_name, device): torch.manual_seed(1) td = getattr(self, td_name)(device) - assert_allclose_td(td[range(2)], td[[0, 1]]) - if td_name not in ("td_h5",): - # for h5, we can't use a double list index - assert td[range(1), range(1)].shape == td[[0], [0]].shape - assert_allclose_td(td[range(1), range(1)], td[[0], [0]]) - assert_allclose_td(td[:, range(2)], td[:, [0, 1]]) - assert_allclose_td(td[..., range(1)], td[..., [0]]) - - if td_name in ("stacked_td", "nested_stacked_td"): - # this is a bit contrived, but want to check that if we pass something - # weird as the index to the stacking dimension we'll get the error - idx = (slice(None),) * td.stack_dim + ({1, 2, 3},) - with pytest.raises(TypeError, match="Invalid index"): - td[idx] + while True: + mask = torch.zeros( + td.batch_size, dtype=torch.bool, device=device + ).bernoulli_(0.8) + if not mask.all() and mask.any(): + break + td_masked = td[mask] + td_masked2 = torch.masked_select(td, mask) + assert_allclose_td(td_masked, td_masked2) + assert td_masked.batch_size[0] == mask.sum() + assert td_masked.batch_dims == 1 - def test_setitem_nested_dict_value(self, td_name, device): + def test_masking_set(self, td_name, device): torch.manual_seed(1) td = getattr(self, td_name)(device) - - # Create equivalent TensorDict and dict nested values for setitem - nested_dict_value = {"e": torch.randn(4, 3, 2, 1, 10)} - nested_tensordict_value = TensorDict( - nested_dict_value, batch_size=td.batch_size, device=device + mask = torch.zeros(td.batch_size, dtype=torch.bool, device=device).bernoulli_( + 0.8 + ) + n = mask.sum() + d = td.ndimension() + pseudo_td = td.apply( + lambda item: torch.zeros( + (n, *item.shape[d:]), dtype=item.dtype, device=device + ), + batch_size=[n, *td.batch_size[d:]], ) - td_clone1 = td.clone(recurse=True) - td_clone2 = td.clone(recurse=True) - - td_clone1["d"] = nested_dict_value - td_clone2["d"] = nested_tensordict_value - assert (td_clone1 == td_clone2).all() - def test_transpose(self, td_name, device): - td = getattr(self, td_name)(device) - tdt = td.transpose(0, 1) - assert tdt.shape == torch.Size([td.shape[1], td.shape[0], *td.shape[2:]]) - for key, value in tdt.items(True): - assert value.shape == torch.Size( - [td.get(key).shape[1], td.get(key).shape[0], *td.get(key).shape[2:]] - ) - tdt = td.transpose(-1, -2) - for key, value in tdt.items(True): - assert value.shape == td.get(key).transpose(2, 3).shape - if td_name in ("td_params",): - assert tdt.transpose(-1, -2)._param_td is td._param_td - else: - assert tdt.transpose(-1, -2) is td - with td.unlock_(): - tdt.set(("some", "transposed", "tensor"), torch.zeros(tdt.shape)) - assert td.get(("some", "transposed", "tensor")).shape == td.shape - if td_name in ("td_params",): - assert td.transpose(0, 0)._param_td is td._param_td + if td_name == "td_params": + td_set = td.data else: - assert td.transpose(0, 0) is td - with pytest.raises( - ValueError, match="The provided dimensions are incompatible" - ): - td.transpose(-5, -6) - with 
pytest.raises( - ValueError, match="The provided dimensions are incompatible" - ): - tdt.transpose(-5, -6) + td_set = td - def test_create_nested(self, td_name, device): - td = getattr(self, td_name)(device) - with td.unlock_(): - td.create_nested("root") - assert td.get("root").shape == td.shape - assert is_tensor_collection(td.get("root")) - td.create_nested(("some", "nested", "key")) + td_set[mask] = pseudo_td + for item in td.values(): + assert (item[mask] == 0).all() - some = td.get("some") - nested = some.get("nested") - _ = nested.get("key") - assert td.get(("some", "nested", "key")).shape == td.shape - assert is_tensor_collection(td.get(("some", "nested", "key"))) + @pytest.mark.parametrize("use_dir", [True, False]) + @pytest.mark.parametrize("num_threads", [0, 2]) + def test_memmap_(self, td_name, device, use_dir, tmpdir, num_threads): + td = getattr(self, td_name)(device) if td_name in ("sub_td", "sub_td2"): + with pytest.raises( + RuntimeError, + match="Converting a sub-tensordict values to memmap cannot be done", + ): + td.memmap_( + prefix=tmpdir if use_dir else None, + num_threads=num_threads, + copy_existing=True, + ) return - with td.lock_(), pytest.raises(RuntimeError): - td.create_nested("root") + elif td_name in ("td_h5", "td_params"): + with pytest.raises( + RuntimeError, + match="Cannot build a memmap TensorDict in-place", + ): + td.memmap_( + prefix=tmpdir if use_dir else None, + num_threads=num_threads, + copy_existing=True, + ) + return + else: + td.memmap_( + prefix=tmpdir if use_dir else None, + num_threads=num_threads, + copy_existing=True, + ) + assert td.is_memmap(), (td, td._is_memmap) + if use_dir: + assert_allclose_td(TensorDict.load_memmap(tmpdir), td) - def test_tensordict_set(self, td_name, device): - torch.manual_seed(1) - np.random.seed(1) - td = getattr(self, td_name)(device) - td.unlock_() + @pytest.mark.parametrize("copy_existing", [False, True]) + def test_memmap_existing(self, td_name, device, copy_existing, tmp_path): + if td_name == "memmap_td": + pytest.skip( + "Memmap case is redundant, functionality checked by other cases" + ) + elif td_name in ("sub_td", "sub_td2", "td_h5", "td_params"): + pytest.skip( + "_SubTensorDict/H5 and memmap_ incompatibility is checked elsewhere" + ) - # test set - val1 = np.ones(shape=(4, 3, 2, 1, 10)) - td.set("key1", val1) - assert (td.get("key1") == 1).all() - with pytest.raises(RuntimeError): - td.set("key1", np.ones(shape=(5, 10))) + td = getattr(self, td_name)(device).memmap_(prefix=tmp_path / "tensordict") + td2 = getattr(self, td_name)(device).memmap_() - # test set_ - val2 = np.zeros(shape=(4, 3, 2, 1, 10)) - td.set_("key1", val2) - assert (td.get("key1") == 0).all() - if td_name not in ("stacked_td", "nested_stacked_td"): - err_msg = r"key.*smartypants.*not found in " - elif td_name in ("td_h5",): - err_msg = "Unable to open object" + if copy_existing: + td3 = td.memmap_(prefix=tmp_path / "tensordict2", copy_existing=True) + assert (td == td3).all() else: - err_msg = "setting a value in-place on a stack of TensorDict" + with pytest.raises( + RuntimeError, + match="A filename was provided but the tensor already has a file associated", + ): + # calling memmap_ with prefix that is different to contents gives error + td.memmap_(prefix=tmp_path / "tensordict2") - with pytest.raises(KeyError, match=err_msg): - td.set_("smartypants", np.ones(shape=(4, 3, 2, 1, 5))) + # calling memmap_ without prefix means no-op, regardless of whether contents + # were saved in temporary or designated location (td vs. td2 resp.) 
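+            # Editor's note (illustrative, not from the original PR): memmap_()
+            # returns the tensordict it was called on, so the no-op form could
+            # also be checked with a plain identity test, e.g.:
+            #
+            #     assert td.memmap_() is td  # hypothetical one-liner
+            #
+            # The entry-wise identity assertions below verify the same behaviour.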
+ td3 = td.memmap_() + td4 = td2.memmap_() - # test set_at_ - td.set("key2", np.random.randn(4, 3, 2, 1, 5)) - x = np.ones(shape=(2, 1, 5)) * 42 - td.set_at_("key2", x, (2, 2)) - assert (td.get("key2")[2, 2] == 42).all() + if td_name in ("stacked_td", "nested_stacked_td"): + assert all( + all( + td3_[key] is value + for key, value in td_.items( + include_nested=True, leaves_only=True + ) + ) + for td_, td3_ in zip(td.tensordicts, td3.tensordicts) + ) + assert all( + all( + td4_[key] is value + for key, value in td2_.items( + include_nested=True, leaves_only=True + ) + ) + for td2_, td4_ in zip(td2.tensordicts, td4.tensordicts) + ) + elif td_name in ("permute_td", "squeezed_td", "unsqueezed_td"): + assert all( + td3._source[key] is value + for key, value in td._source.items( + include_nested=True, leaves_only=True + ) + ) + assert all( + td4._source[key] is value + for key, value in td2._source.items( + include_nested=True, leaves_only=True + ) + ) + else: + assert all( + td3[key] is value + for key, value in td.items(include_nested=True, leaves_only=True) + ) + assert all( + td4[key] is value + for key, value in td2.items(include_nested=True, leaves_only=True) + ) - def test_tensordict_set_dict_value(self, td_name, device): - torch.manual_seed(1) - np.random.seed(1) + @pytest.mark.parametrize("use_dir", [True, False]) + @pytest.mark.parametrize("num_threads", [0, 2]) + def test_memmap_like(self, td_name, device, use_dir, tmpdir, num_threads): td = getattr(self, td_name)(device) - td.unlock_() + tdmemmap = td.memmap_like( + prefix=tmpdir if use_dir else None, + num_threads=num_threads, + copy_existing=True, + ) + assert tdmemmap is not td + for key in td.keys(True): + assert td[key] is not tdmemmap[key] + assert (tdmemmap == 0).all() - # test set - val1 = {"subkey1": torch.ones(4, 3, 2, 1, 10)} - td.set("key1", val1) - assert (td.get("key1").get("subkey1") == 1).all() - with pytest.raises(RuntimeError): - td.set("key1", torch.ones(5, 10)) + def test_memmap_prefix(self, td_name, device, tmp_path): + if td_name == "memmap_td": + pytest.skip( + "Memmap case is redundant, functionality checked by other cases" + ) - # test set_ - val2 = {"subkey1": torch.zeros(4, 3, 2, 1, 10)} - if td_name in ("td_params",): - td.data.set_("key1", val2) + td = getattr(self, td_name)(device) + if td_name in ("sub_td", "sub_td2"): + with pytest.raises( + RuntimeError, + match="Converting a sub-tensordict values to memmap cannot be done", + ): + td.memmap_(tmp_path / "tensordict") + return + elif td_name in ("td_h5", "td_params"): + with pytest.raises( + RuntimeError, + match="Cannot build a memmap TensorDict in-place", + ): + td.memmap_(tmp_path / "tensordict") + return else: - td.set_("key1", val2) - assert (td.get("key1").get("subkey1") == 0).all() - + td.memmap_(tmp_path / "tensordict") if td_name not in ("stacked_td", "nested_stacked_td"): - err_msg = r"key.*smartypants.*not found in " - elif td_name in ("td_h5",): - err_msg = "Unable to open object" + jsonpath = tmp_path / "tensordict" / "meta.json" else: - err_msg = "setting a value in-place on a stack of TensorDict" + jsonpath = tmp_path / "tensordict" / "0" / "meta.json" + assert jsonpath.exists(), td + with open(jsonpath) as file: + metadata = json.load(file) + if td_name in ("stacked_td", "nested_stacked_td"): + assert metadata["shape"] == list(td.tensordicts[0].batch_size) + else: + assert metadata["shape"] == list(td.batch_size) - with pytest.raises(KeyError, match=err_msg): - td.set_("smartypants", np.ones(shape=(4, 3, 2, 1, 5))) + td2 = 
td.load_memmap(tmp_path / "tensordict") + assert (td.cpu() == td2.cpu()).all() - def test_delitem(self, td_name, device): - torch.manual_seed(1) + @pytest.mark.parametrize("use_dir", [True, False]) + @pytest.mark.parametrize("num_threads", [2]) + def test_memmap_threads(self, td_name, device, use_dir, tmpdir, num_threads): td = getattr(self, td_name)(device) - if td_name in ("memmap_td",): - with pytest.raises(RuntimeError, match="Cannot modify"): - del td["a"] - return - del td["a"] - assert "a" not in td.keys() + tdmmap = td.memmap( + prefix=tmpdir if use_dir else None, + num_threads=num_threads, + copy_existing=True, + ) + tdfuture = td.memmap( + prefix=tmpdir if use_dir else None, + num_threads=num_threads, + copy_existing=True, + return_early=True, + ) + assert_allclose_td(td.cpu().detach(), tdmmap) + assert_allclose_td(td.cpu().detach(), tdfuture.result()) - def test_to_dict_nested(self, td_name, device): - def recursive_checker(cur_dict): - for _, value in cur_dict.items(): - if is_tensor_collection(value): - return False - elif isinstance(value, dict) and not recursive_checker(value): - return False - return True + @pytest.mark.parametrize("inplace", [False, True]) + def test_named_apply(self, td_name, device, inplace): + td = getattr(self, td_name)(device) + td_c = td.to_tensordict() + + def named_plus(name, x): + if "a" in name: + return x + 1 + + if inplace and td_name == "td_params": + with pytest.raises(ValueError, match="Failed to update"): + td.named_apply(named_plus, inplace=inplace) + return + td_1 = td.named_apply(named_plus, inplace=inplace) + if inplace: + assert td_1 is td + for key in td_1.keys(True, True): + if "a" in key: + assert (td_c[key] + 1 == td_1[key]).all() + else: + assert (td_c[key] == td_1[key]).all() + assert (td_1[key] == td[key]).all() + else: + for key in td_1.keys(True, True): + assert "a" in key + assert (td_c[key] + 1 != td[key]).any() + assert (td_1[key] == td[key] + 1).all() + def test_nested_dict_init(self, td_name, device): + torch.manual_seed(1) td = getattr(self, td_name)(device) td.unlock_() - # Create nested TensorDict + # Create TensorDict and dict equivalent values, and populate each with according nested value + td_clone = td.clone(recurse=True) + td_dict = td.to_dict() + nested_dict_value = {"e": torch.randn(4, 3, 2, 1, 10)} nested_tensordict_value = TensorDict( - {"e": torch.randn(4, 3, 2, 1, 10)}, batch_size=td.batch_size, device=device + nested_dict_value, batch_size=td.batch_size, device=device ) - td["d"] = nested_tensordict_value + td_dict["d"] = nested_dict_value + td_clone["d"] = nested_tensordict_value - # Convert into dictionary and recursively check if the values are TensorDicts - td_dict = td.to_dict() - assert recursive_checker(td_dict) - if td_name == "td_with_non_tensor": - assert td_dict["data"]["non_tensor"] == "some text data" - assert (TensorDict.from_dict(td_dict) == td).all() + # Re-init new TensorDict from dict, and check if they're equal + td_dict_init = TensorDict(td_dict, batch_size=td.batch_size, device=device) - @pytest.mark.parametrize( - "index", ["tensor1", "mask", "int", "range", "tensor2", "slice_tensor"] - ) - def test_update_subtensordict(self, td_name, device, index): + assert (td_clone == td_dict_init).all() + + def test_nested_td(self, td_name, device): td = getattr(self, td_name)(device) - if index == "mask": - index = torch.zeros(td.shape[0], dtype=torch.bool, device=device) - index[-1] = 1 - elif index == "int": - index = td.shape[0] - 1 - elif index == "range": - index = range(td.shape[0] - 1, 
td.shape[0]) - elif index == "tensor1": - index = torch.tensor(td.shape[0] - 1, device=device) - elif index == "tensor2": - index = torch.tensor([td.shape[0] - 2, td.shape[0] - 1], device=device) - elif index == "slice_tensor": - index = ( - slice(None), - torch.tensor([td.shape[1] - 2, td.shape[1] - 1], device=device), - ) + td.unlock_() + tdin = TensorDict({"inner": torch.randn(td.shape)}, td.shape, device=device) + td.set("inner_td", tdin) + assert (td["inner_td"] == tdin).all() - sub_td = td._get_sub_tensordict(index) - assert sub_td.shape == td.to_tensordict()[index].shape - assert sub_td.shape == td[index].shape, (td, index) - td0 = td[index] - td0 = td0.to_tensordict() - td0 = td0.apply(lambda x: x * 0 + 2) - assert sub_td.shape == td0.shape - if td_name == "td_params": - with pytest.raises(RuntimeError, match="a leaf Variable"): - sub_td.update(td0) - return - sub_td.update(td0) - assert (sub_td == 2).all() - assert (td[index] == 2).all() + def test_nested_td_emptyshape(self, td_name, device): + td = getattr(self, td_name)(device) + td.unlock_() + tdin = TensorDict({"inner": torch.randn(*td.shape, 1)}, [], device=device) + td["inner_td"] = tdin + tdin.batch_size = td.batch_size + assert (td["inner_td"] == tdin).all() - @pytest.mark.filterwarnings("error") - def test_stack_onto(self, td_name, device, tmpdir): - torch.manual_seed(1) + def test_nested_td_index(self, td_name, device): td = getattr(self, td_name)(device) - if td_name == "td_h5": - td0 = td.clone(newfile=tmpdir / "file0.h5").apply_(lambda x: x.zero_()) - td1 = td.clone(newfile=tmpdir / "file1.h5").apply_(lambda x: x.zero_() + 1) - else: - td0 = td.clone() - if td_name in ("td_params",): - td0.data.apply_(lambda x: x.zero_()) - else: - td0.apply_(lambda x: x.zero_()) - td1 = td.clone() - if td_name in ("td_params",): - td1.data.apply_(lambda x: x.zero_() + 1) - else: - td1.apply_(lambda x: x.zero_() + 1) - - td_out = td.unsqueeze(1).expand(td.shape[0], 2, *td.shape[1:]).clone() - td_stack = torch.stack([td0, td1], 1) - if td_name == "td_params": - with pytest.raises(RuntimeError, match="out.batch_size and stacked"): - torch.stack([td0, td1], 0, out=td_out) - return - data_ptr_set_before = {val.data_ptr() for val in decompose(td_out)} - torch.stack([td0, td1], 1, out=td_out) - data_ptr_set_after = {val.data_ptr() for val in decompose(td_out)} - assert data_ptr_set_before == data_ptr_set_after - assert (td_stack == td_out).all() - - @pytest.mark.filterwarnings("error") - def test_stack_tds_on_subclass(self, td_name, device): - torch.manual_seed(1) - td = getattr(self, td_name)(device) - tds_count = td.batch_size[0] - tds_batch_size = td.batch_size[1:] - tds_list = [ - TensorDict( - source={ - "a": torch.ones(*tds_batch_size, 5), - "b": torch.ones(*tds_batch_size, 10), - "c": torch.ones(*tds_batch_size, 3, dtype=torch.long), - }, - batch_size=tds_batch_size, - device=device, - ) - for _ in range(tds_count) - ] - if td_name in ("sub_td", "sub_td2"): - with pytest.raises(IndexError, match="storages of the indexed tensors"): - torch.stack(tds_list, 0, out=td) - return - data_ptr_set_before = {val.data_ptr() for val in decompose(td)} - - stacked_td = torch.stack(tds_list, 0, out=td) - data_ptr_set_after = {val.data_ptr() for val in decompose(td)} - assert data_ptr_set_before == data_ptr_set_after - assert stacked_td.batch_size == td.batch_size - assert stacked_td is td - for key in ("a", "b", "c"): - assert (stacked_td[key] == 1).all() - - @pytest.mark.filterwarnings("error") - def test_stack_subclasses_on_td(self, td_name, 
device): - torch.manual_seed(1) - td = getattr(self, td_name)(device) - td = td.expand(3, *td.batch_size).clone().zero_() - tds_list = [getattr(self, td_name)(device) for _ in range(3)] - if td_name == "td_params": - with pytest.raises(RuntimeError, match="arguments don't support automatic"): - torch.stack(tds_list, 0, out=td) - return - data_ptr_set_before = {val.data_ptr() for val in decompose(td)} - stacked_td = stack_td(tds_list, 0, out=td) - data_ptr_set_after = {val.data_ptr() for val in decompose(td)} - assert data_ptr_set_before == data_ptr_set_after - assert stacked_td.batch_size == td.batch_size - for key in ("a", "b", "c"): - assert (stacked_td[key] == td[key]).all() - - @pytest.mark.parametrize("dim", [0, 1]) - @pytest.mark.parametrize("chunks", [1, 2]) - def test_chunk(self, td_name, device, dim, chunks): - torch.manual_seed(1) - td = getattr(self, td_name)(device) - if len(td.shape) - 1 < dim: - pytest.mark.skip(f"no dim {dim} in td") - return - - chunks = min(td.shape[dim], chunks) - td_chunks = td.chunk(chunks, dim) - assert len(td_chunks) == chunks - assert sum([_td.shape[dim] for _td in td_chunks]) == td.shape[dim] - assert (torch.cat(td_chunks, dim) == td).all() - - def test_as_tensor(self, td_name, device): - td = getattr(self, td_name)(device) - if "memmap" in td_name and device == torch.device("cpu"): - tdt = td.as_tensor() - assert (tdt == td).all() - elif "memmap" in td_name: - with pytest.raises( - RuntimeError, match="can only be called with MemoryMappedTensors stored" - ): - td.as_tensor() - else: - # checks that it runs - td.as_tensor() - - def test_items_values_keys(self, td_name, device): - torch.manual_seed(1) - td = getattr(self, td_name)(device) - td.unlock_() - keys = list(td.keys()) - values = list(td.values()) - items = list(td.items()) - - # Test td.items() - constructed_td1 = TensorDict({}, batch_size=td.shape) - for key, value in items: - constructed_td1.set(key, value) - - assert (td == constructed_td1).all() - - # Test td.keys() and td.values() - # items = [key, value] should be verified - assert len(values) == len(items) - assert len(keys) == len(items) - constructed_td2 = TensorDict({}, batch_size=td.shape) - for key, value in list(zip(td.keys(), td.values())): - constructed_td2.set(key, value) - - assert (td == constructed_td2).all() - - # Test that keys is sorted - assert all(keys[i] <= keys[i + 1] for i in range(len(keys) - 1)) - - # Add new element to tensor - a = td.get("a") - td.set("x", torch.randn_like(a)) - keys = list(td.keys()) - values = list(td.values()) - items = list(td.items()) - - # Test that keys is still sorted after adding the element - assert all(keys[i] <= keys[i + 1] for i in range(len(keys) - 1)) - - # Test td.items() - # after adding the new element - constructed_td1 = TensorDict({}, batch_size=td.shape) - for key, value in items: - constructed_td1.set(key, value) - - assert (td == constructed_td1).all() - - # Test td.keys() and td.values() - # items = [key, value] should be verified - # even after adding the new element - assert len(values) == len(items) - assert len(keys) == len(items) - - constructed_td2 = TensorDict({}, batch_size=td.shape) - for key, value in list(zip(td.keys(), td.values())): - constructed_td2.set(key, value) - - assert (td == constructed_td2).all() - - def test_set_requires_grad(self, td_name, device): - td = getattr(self, td_name)(device) - if td_name in ("td_params",): - td.apply(lambda x: x.requires_grad_(False)) - td.unlock_() - assert not td.get("a").requires_grad - if td_name in ("td_h5",): - 
with pytest.raises( - RuntimeError, match="Cannot set a tensor that has requires_grad=True" - ): - td.set("a", torch.randn_like(td.get("a")).requires_grad_()) - return - if td_name in ("sub_td", "sub_td2"): - td.set_("a", torch.randn_like(td.get("a")).requires_grad_()) - else: - td.set("a", torch.randn_like(td.get("a")).requires_grad_()) - - assert td.get("a").requires_grad - - def test_nested_td_emptyshape(self, td_name, device): - td = getattr(self, td_name)(device) - td.unlock_() - tdin = TensorDict({"inner": torch.randn(*td.shape, 1)}, [], device=device) - td["inner_td"] = tdin - tdin.batch_size = td.batch_size - assert (td["inner_td"] == tdin).all() - - def test_nested_td(self, td_name, device): - td = getattr(self, td_name)(device) - td.unlock_() - tdin = TensorDict({"inner": torch.randn(td.shape)}, td.shape, device=device) - td.set("inner_td", tdin) - assert (td["inner_td"] == tdin).all() - - def test_nested_dict_init(self, td_name, device): - torch.manual_seed(1) - td = getattr(self, td_name)(device) - td.unlock_() - - # Create TensorDict and dict equivalent values, and populate each with according nested value - td_clone = td.clone(recurse=True) - td_dict = td.to_dict() - nested_dict_value = {"e": torch.randn(4, 3, 2, 1, 10)} - nested_tensordict_value = TensorDict( - nested_dict_value, batch_size=td.batch_size, device=device - ) - td_dict["d"] = nested_dict_value - td_clone["d"] = nested_tensordict_value - - # Re-init new TensorDict from dict, and check if they're equal - td_dict_init = TensorDict(td_dict, batch_size=td.batch_size, device=device) - - assert (td_clone == td_dict_init).all() - - def test_nested_td_index(self, td_name, device): - td = getattr(self, td_name)(device) - td.unlock_() + td.unlock_() sub_td = TensorDict({}, [*td.shape, 2], device=device) a = torch.zeros([*td.shape, 2, 2], device=device) @@ -2510,406 +2931,210 @@ def test_nested_td_index(self, td_name, device): td["sub_td"]["sub_sub_td"]["b"] == td["sub_td", "sub_sub_td", "b"] ).all() - @pytest.mark.parametrize("inplace", [True, False]) - @pytest.mark.parametrize("separator", [",", "-"]) - def test_flatten_keys(self, td_name, device, inplace, separator): - td = getattr(self, td_name)(device) - locked = td.is_locked - td.unlock_() - nested_nested_tensordict = TensorDict( - { - "a": torch.zeros(*td.shape, 2, 3), - }, - [*td.shape, 2], - ) - nested_tensordict = TensorDict( - { - "a": torch.zeros(*td.shape, 2), - "nested_nested_tensordict": nested_nested_tensordict, - }, - td.shape, - ) - td["nested_tensordict"] = nested_tensordict - if locked: - td.lock_() + @pytest.mark.parametrize("dim", [0, 1, -1, -5]) + @pytest.mark.parametrize( + "key", ["heterogeneous-entry", ("sub", "heterogeneous-entry")] + ) + def test_nestedtensor_stack(self, td_name, device, dim, key): + torch.manual_seed(1) + td1 = getattr(self, td_name)(device).unlock_() + td2 = getattr(self, td_name)(device).unlock_() - if inplace and locked: - with pytest.raises(RuntimeError, match="Cannot modify locked TensorDict"): - td_flatten = td.flatten_keys(inplace=inplace, separator=separator) - return - elif td_name in ("td_h5",) and inplace: - with pytest.raises( - ValueError, - match="Cannot call flatten_keys in_place with a PersistentTensorDict", - ): - td_flatten = td.flatten_keys(inplace=inplace, separator=separator) - return - else: - td_flatten = td.flatten_keys(inplace=inplace, separator=separator) - for value in td_flatten.values(): - assert not isinstance(value, TensorDictBase) - assert ( - separator.join(["nested_tensordict", 
"nested_nested_tensordict", "a"]) - in td_flatten.keys() - ) - if inplace: - assert td_flatten is td + td1[key] = torch.randn(*td1.shape, 2) + td2[key] = torch.randn(*td1.shape, 3) + td_stack = torch.stack([td1, td2], dim) + # get will fail + with pytest.raises( + RuntimeError, match="Found more than one unique shape in the tensors" + ): + td_stack.get(key) + with pytest.raises( + RuntimeError, match="Found more than one unique shape in the tensors" + ): + td_stack[key] + if dim in (0, -5): + # this will work if stack_dim is 0 (or equivalently -self.batch_dims) + # it is the proper way to get that entry + td_stack.get_nestedtensor(key) else: - assert td_flatten is not td + # if the stack_dim is not zero, then calling get_nestedtensor is disallowed + with pytest.raises( + RuntimeError, + match="LazyStackedTensorDict.get_nestedtensor can only be called " + "when the stack_dim is 0.", + ): + td_stack.get_nestedtensor(key) + with pytest.raises( + RuntimeError, match="Found more than one unique shape in the tensors" + ): + td_stack.contiguous() + with pytest.raises( + RuntimeError, match="Found more than one unique shape in the tensors" + ): + td_stack.to_tensordict() + # cloning is type-preserving: we can do that operation + td_stack.clone() - @pytest.mark.parametrize("inplace", [True, False]) - @pytest.mark.parametrize("separator", [",", "-"]) - def test_unflatten_keys(self, td_name, device, inplace, separator): + def test_non_tensor_data(self, td_name, device): td = getattr(self, td_name)(device) - locked = td.is_locked - td.unlock_() - nested_nested_tensordict = TensorDict( - { - "a": torch.zeros(*td.shape, 2, 3), - }, - [*td.shape, 2], - ) - nested_tensordict = TensorDict( - { - "a": torch.zeros(*td.shape, 2), - "nested_nested_tensordict": nested_nested_tensordict, - }, - td.shape, - ) - td["nested_tensordict"] = nested_tensordict - - if inplace and locked: - td_flatten = td.flatten_keys(inplace=inplace, separator=separator) - td_flatten.lock_() - with pytest.raises(RuntimeError, match="Cannot modify locked TensorDict"): - td_unflatten = td_flatten.unflatten_keys( - inplace=inplace, separator=separator - ) - return - else: - if locked: - td.lock_() - if td_name in ("td_h5",) and inplace: - with pytest.raises( - ValueError, - match="Cannot call flatten_keys in_place with a PersistentTensorDict", - ): - td_flatten = td.flatten_keys(inplace=inplace, separator=separator) + # check lock + if td_name not in ("sub_td", "sub_td2"): + with td.lock_(), pytest.raises(RuntimeError, match=re.escape(_LOCK_ERROR)): + td.set_non_tensor(("this", "will"), "fail") + # check set + with td.unlock_(): + td.set(("this", "tensor"), torch.zeros(td.shape)) + reached = False + with pytest.raises( + RuntimeError, + match="set_non_tensor is not compatible with the tensordict type", + ) if td_name in ("td_h5",) else contextlib.nullcontext(): + td.set_non_tensor(("this", "will"), "succeed") + reached = True + if not reached: return - td_flatten = td.flatten_keys(inplace=inplace, separator=separator) - td_unflatten = td_flatten.unflatten_keys( - inplace=inplace, separator=separator - ) - assert (td == td_unflatten).all() - if inplace: - assert td is td_unflatten - - def test_repr(self, td_name, device): - td = getattr(self, td_name)(device) - _ = str(td) + # check get (for tensor) + assert (td.get_non_tensor(("this", "tensor")) == 0).all() + # check get (for non-tensor) + assert td.get_non_tensor(("this", "will")) == "succeed" + assert isinstance(td.get(("this", "will")), NonTensorData) - 
@pytest.mark.parametrize("use_dir", [True, False]) - @pytest.mark.parametrize("num_threads", [0, 2]) - def test_memmap_(self, td_name, device, use_dir, tmpdir, num_threads): + def test_non_tensor_data_flatten_keys(self, td_name, device): td = getattr(self, td_name)(device) - if td_name in ("sub_td", "sub_td2"): + with td.unlock_(): + td.set(("this", "tensor"), torch.zeros(td.shape)) + reached = False with pytest.raises( RuntimeError, - match="Converting a sub-tensordict values to memmap cannot be done", - ): - td.memmap_( - prefix=tmpdir if use_dir else None, - num_threads=num_threads, - copy_existing=True, - ) - return - elif td_name in ("td_h5", "td_params"): + match="set_non_tensor is not compatible with the tensordict type", + ) if td_name in ("td_h5",) else contextlib.nullcontext(): + td.set_non_tensor(("this", "will"), "succeed") + reached = True + if not reached: + return + td_flat = td.flatten_keys() + assert (td_flat.get("this.tensor") == 0).all() + assert td_flat.get_non_tensor("this.will") == "succeed" + + def test_non_tensor_data_pickle(self, td_name, device, tmpdir): + td = getattr(self, td_name)(device) + with td.unlock_(): + td.set(("this", "tensor"), torch.zeros(td.shape)) + reached = False with pytest.raises( RuntimeError, - match="Cannot build a memmap TensorDict in-place", - ): - td.memmap_( - prefix=tmpdir if use_dir else None, - num_threads=num_threads, - copy_existing=True, - ) - return - else: - td.memmap_( - prefix=tmpdir if use_dir else None, - num_threads=num_threads, - copy_existing=True, - ) - assert td.is_memmap() - if use_dir: - assert_allclose_td(TensorDict.load_memmap(tmpdir), td) + match="set_non_tensor is not compatible with the tensordict type", + ) if td_name in ("td_h5",) else contextlib.nullcontext(): + td.set_non_tensor(("this", "will"), "succeed") + reached = True + if not reached: + return + td.set_non_tensor(("non", "json", "serializable"), DummyPicklableClass(10)) + td.memmap(prefix=tmpdir, copy_existing=True) + loaded = TensorDict.load_memmap(tmpdir) + assert isinstance(loaded.get(("non", "json", "serializable")), NonTensorData) + assert loaded.get_non_tensor(("non", "json", "serializable")).value == 10 + assert loaded.get_non_tensor(("this", "will")) == "succeed" - @pytest.mark.parametrize("use_dir", [True, False]) - @pytest.mark.parametrize("num_threads", [2]) - def test_memmap_threads(self, td_name, device, use_dir, tmpdir, num_threads): + def test_pad(self, td_name, device): td = getattr(self, td_name)(device) - tdmmap = td.memmap( - prefix=tmpdir if use_dir else None, - num_threads=num_threads, - copy_existing=True, - ) - tdfuture = td.memmap( - prefix=tmpdir if use_dir else None, - num_threads=num_threads, - copy_existing=True, - return_early=True, - ) - assert_allclose_td(td.cpu().detach(), tdmmap) - assert_allclose_td(td.cpu().detach(), tdfuture.result()) + paddings = [ + [0, 1, 0, 2], + [1, 0, 0, 2], + [1, 0, 2, 1], + ] - @pytest.mark.parametrize("use_dir", [True, False]) - @pytest.mark.parametrize("num_threads", [0, 2]) - def test_memmap_like(self, td_name, device, use_dir, tmpdir, num_threads): - td = getattr(self, td_name)(device) - tdmemmap = td.memmap_like( - prefix=tmpdir if use_dir else None, - num_threads=num_threads, - copy_existing=True, - ) - assert tdmemmap is not td - for key in td.keys(True): - assert td[key] is not tdmemmap[key] - assert (tdmemmap == 0).all() + for pad_size in paddings: + padded_td = pad(td, pad_size) + padded_td._check_batch_size() + amount_expanded = [0] * (len(pad_size) // 2) + for i in range(0, 
len(pad_size), 2):
+                amount_expanded[i // 2] = pad_size[i] + pad_size[i + 1]
+
+            for key in padded_td.keys():
+                expected_dims = tuple(
+                    sum(p)
+                    for p in zip(
+                        td[key].shape,
+                        amount_expanded
+                        + [0] * (len(td[key].shape) - len(amount_expanded)),
+                    )
+                )
+                assert padded_td[key].shape == expected_dims
+
+        with pytest.raises(RuntimeError):
+            pad(td, [0] * 100)
+
+        with pytest.raises(RuntimeError):
+            pad(td, [0])
+
+    def test_permute_applied_twice(self, td_name, device):
+        torch.manual_seed(0)
+        tensordict = getattr(self, td_name)(device)
+        for _ in range(10):
+            p = torch.randperm(4)
+            inv_p = p.argsort()
+            other_p = inv_p
+            while (other_p == inv_p).all():
+                other_p = torch.randperm(4)
+            other_p = tuple(other_p.tolist())
+            p = tuple(p.tolist())
+            inv_p = tuple(inv_p.tolist())
+            if td_name in ("td_params",):
+                # TODO: Should we break this?
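+                # Editor's note: TensorDictParams tracks identity on the wrapped
+                # _param_td rather than on the wrapper object itself, so the
+                # permute round-trip below is compared through _param_td.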
+ assert ( + tensordict.permute(*p).permute(*inv_p)._param_td + is tensordict._param_td ) - assert all( - all( - td4_[key] is value - for key, value in td2_.items( - include_nested=True, leaves_only=True - ) - ) - for td2_, td4_ in zip(td2.tensordicts, td4.tensordicts) + assert ( + tensordict.permute(*p).permute(*other_p)._param_td + is not tensordict._param_td ) - elif td_name in ("permute_td", "squeezed_td", "unsqueezed_td"): - assert all( - td3._source[key] is value - for key, value in td._source.items( - include_nested=True, leaves_only=True - ) + assert ( + torch.permute(tensordict, p).permute(inv_p)._param_td + is tensordict._param_td ) - assert all( - td4._source[key] is value - for key, value in td2._source.items( - include_nested=True, leaves_only=True - ) + assert ( + torch.permute(tensordict, p).permute(other_p)._param_td + is not tensordict._param_td ) else: - assert all( - td3[key] is value - for key, value in td.items(include_nested=True, leaves_only=True) + assert assert_allclose_td( + tensordict.permute(*p).permute(*inv_p), tensordict ) - assert all( - td4[key] is value - for key, value in td2.items(include_nested=True, leaves_only=True) + assert tensordict.permute(*p).permute(*inv_p) is tensordict + assert tensordict.permute(*p).permute(*other_p) is not tensordict + assert assert_allclose_td( + torch.permute(tensordict, p).permute(inv_p), tensordict ) + assert torch.permute(tensordict, p).permute(inv_p) is tensordict + assert torch.permute(tensordict, p).permute(other_p) is not tensordict - def test_setdefault_missing_key(self, td_name, device): - td = getattr(self, td_name)(device) - td.unlock_() - expected = torch.ones_like(td.get("a")) - inserted = td.setdefault("z", expected) - assert (inserted == expected).all() - - def test_setdefault_existing_key(self, td_name, device): - td = getattr(self, td_name)(device) - td.unlock_() - expected = td.get("a") - inserted = td.setdefault("a", torch.ones_like(td.get("b"))) - assert (inserted == expected).all() - - def test_setdefault_nested(self, td_name, device): + @pytest.mark.skipif( + torch.cuda.device_count() == 0, reason="No cuda device detected" + ) + @pytest.mark.parametrize("device_cast", [0, "cuda:0", torch.device("cuda:0")]) + def test_pin_memory(self, td_name, device_cast, device): + torch.manual_seed(1) td = getattr(self, td_name)(device) td.unlock_() - - tensor = torch.randn(4, 3, 2, 1, 5, device=device) - tensor2 = torch.ones(4, 3, 2, 1, 5, device=device) - sub_sub_tensordict = TensorDict({"c": tensor}, [4, 3, 2, 1], device=device) - sub_tensordict = TensorDict( - {"b": sub_sub_tensordict}, [4, 3, 2, 1], device=device - ) - if td_name == "td_h5": - del td["a"] - if td_name == "sub_td": - td = td._source.set( - "a", sub_tensordict.expand(2, *sub_tensordict.shape) - )._get_sub_tensordict(1) - elif td_name == "sub_td2": - td = td._source.set( - "a", - sub_tensordict.expand(2, *sub_tensordict.shape).permute(1, 0, 2, 3, 4), - )._get_sub_tensordict((slice(None), 1)) - else: - td.set("a", sub_tensordict) - - # if key exists we return the existing value - torch.testing.assert_close(td.setdefault(("a", "b", "c"), tensor2), tensor) - - if not td_name == "stacked_td": - torch.testing.assert_close(td.setdefault(("a", "b", "d"), tensor2), tensor2) - torch.testing.assert_close(td.get(("a", "b", "d")), tensor2) - - @pytest.mark.parametrize("performer", ["torch", "tensordict"]) - @pytest.mark.parametrize("dim", range(4)) - def test_split(self, td_name, device, performer, dim): - td = getattr(self, td_name)(device) - t = 
torch.zeros(()).expand(td.shape) - for dim in range(td.batch_dims): - rep, remainder = divmod(td.shape[dim], 2) - split_sizes = [2] * rep + [1] * remainder - for test_split_size in (2, split_sizes): - tensorsplit = t.split(test_split_size, dim=dim) - length = len(tensorsplit) - if performer == "torch": - tds = torch.split(td, test_split_size, dim) - elif performer == "tensordict": - tds = td.split(test_split_size, dim) - assert len(tds) == length - - for idx, split_td in enumerate(tds): - expected_split_dim_size = 1 if idx == rep else 2 - expected_batch_size = tensorsplit[idx].shape - # Test each split_td has the expected batch_size - assert split_td.batch_size == torch.Size(expected_batch_size) - - if td_name == "nested_td": - assert isinstance(split_td["my_nested_td"], TensorDict) - assert isinstance( - split_td["my_nested_td", "inner"], torch.Tensor - ) - - # Test each tensor (or nested_td) in split_td has the expected shape - for key, item in split_td.items(): - expected_shape = [ - expected_split_dim_size if dim_idx == dim else dim_size - for (dim_idx, dim_size) in enumerate(td[key].shape) - ] - assert item.shape == torch.Size(expected_shape) - - if key == "my_nested_td": - expected_inner_tensor_size = [ - expected_split_dim_size if dim_idx == dim else dim_size - for (dim_idx, dim_size) in enumerate( - td[key]["inner"].shape - ) - ] - assert item["inner"].shape == torch.Size( - expected_inner_tensor_size - ) - # tensor = torch.zeros(()).expand(td.shape) - # - # rep, remainder = divmod(td.shape[dim], 2) - # - # # split_sizes to be [2, 2, ..., 2, 1] or [2, 2, ..., 2] - # split_sizes = [2] * rep + [1] * remainder - # for test_split_size in (2, split_sizes): - # if performer == "torch": - # tds = torch.split(td, test_split_size, dim) - # elif performer == "tensordict": - # tds = td.split(test_split_size, dim) - # tensors = tensor.split(test_split_size, dim) - # length = len(tensors) - # assert len(tds) == length, ( - # test_split_size, - # dim, - # [td.shape for td in tds], - # td.shape, - # length, - # ) - # - # for _tensor, split_td in zip(tensors, tds): - # assert _tensor.shape == split_td.shape - # - # if td_name == "nested_td": - # assert isinstance(split_td["my_nested_td"], TensorDict) - # assert isinstance(split_td["my_nested_td", "inner"], torch.Tensor) - # - # # Test each tensor (or nested_td) in split_td has the expected shape - # for key, item in split_td.items(True, True): - # expected_shape = _tensor.shape + td[key].shape[len(_tensor.shape) :] - # assert item.shape == torch.Size(expected_shape) + if device.type == "cuda": + with pytest.raises(RuntimeError, match="cannot pin"): + td.pin_memory() + return + td.pin_memory() + td_device = td.to(device_cast) + _device_cast = torch.device(device_cast) + assert td_device.device == _device_cast + assert td_device.clone().device == _device_cast + if device != _device_cast: + assert td_device is not td + for item in td_device.values(): + assert item.device == _device_cast + for item in td_device.clone().values(): + assert item.device == _device_cast + # assert type(td_device) is type(td) + assert_allclose_td(td, td_device.to(device)) def test_pop(self, td_name, device): td = getattr(self, td_name)(device) @@ -2939,150 +3164,156 @@ def test_pop(self, td_name, device): ): td.pop("z") - def test_setitem_slice(self, td_name, device): - td = getattr(self, td_name)(device) - if td_name == "td_params": - td_set = td.data - else: - td_set = td - td_set[:] = td.clone() - td_set[:1] = td[:1].clone().zero_() - assert (td[:1] == 0).all() + 
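+    # Editor's aside (sketch, not part of the original diff): del_ removes an
+    # entry in-place and returns the tensordict, so deletions can be chained:
+    #
+    #     td = TensorDict({"a": torch.zeros(3), "b": torch.zeros(3)}, [3])
+    #     assert "a" not in td.del_("a").keys()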
@pytest.mark.parametrize("call_del", [True, False]) + def test_remove(self, td_name, device, call_del): + torch.manual_seed(1) td = getattr(self, td_name)(device) - if td_name == "td_params": - td_set = td.data - else: - td_set = td - td_set[:1] = td[:1].to_tensordict().zero_() - assert (td[:1] == 0).all() + with td.unlock_(): + if call_del: + del td["a"] + else: + td = td.del_("a") + assert td is not None + assert "a" not in td.keys() + if td_name in ("sub_td", "sub_td2"): + return + td.lock_() + with pytest.raises(RuntimeError, match="locked"): + del td["b"] - # with broadcast - td = getattr(self, td_name)(device) - if td_name == "td_params": - td_set = td.data - else: - td_set = td - td_set[:1] = td[0].clone().zero_() - assert (td[:1] == 0).all() + def test_rename_key(self, td_name, device) -> None: + torch.manual_seed(1) td = getattr(self, td_name)(device) - if td_name == "td_params": - td_set = td.data + if td.is_locked: + with pytest.raises(RuntimeError, match=re.escape(_LOCK_ERROR)): + td.rename_key_("a", "b", safe=True) else: - td_set = td - td_set[:1] = td[0].to_tensordict().zero_() - assert (td[:1] == 0).all() + with pytest.raises(KeyError, match="already present in TensorDict"): + td.rename_key_("a", "b", safe=True) + a = td.get("a") + if td.is_locked: + with pytest.raises(RuntimeError, match="Cannot modify"): + td.rename_key_("a", "z") + return + else: + td.rename_key_("a", "z") + with pytest.raises(KeyError): + td.get("a") + assert "a" not in td.keys() - td = getattr(self, td_name)(device) - if td_name == "td_params": - td_set = td.data + z = td.get("z") + torch.testing.assert_close(a, z) + + new_z = torch.randn_like(z) + if td_name in ("sub_td", "sub_td2"): + td.set_("z", new_z) else: - td_set = td - td_set[:1, 0] = td[0, 0].clone().zero_() - assert (td[:1, 0] == 0).all() - td = getattr(self, td_name)(device) + td.set("z", new_z) + + torch.testing.assert_close(new_z, td.get("z")) + + new_z = torch.randn_like(z) if td_name == "td_params": - td_set = td.data + td.data.set_("z", new_z) else: - td_set = td - td_set[:1, 0] = td[0, 0].to_tensordict().zero_() - assert (td[:1, 0] == 0).all() + td.set_("z", new_z) + torch.testing.assert_close(new_z, td.get("z")) + def test_rename_key_nested(self, td_name, device) -> None: + torch.manual_seed(1) td = getattr(self, td_name)(device) - if td_name == "td_params": - td_set = td.data - else: - td_set = td - td_set[:1, :, 0] = td[0, :, 0].clone().zero_() - assert (td[:1, :, 0] == 0).all() + td.unlock_() + td["nested", "conflict"] = torch.zeros(td.shape) + with pytest.raises(KeyError, match="already present in TensorDict"): + td.rename_key_(("nested", "conflict"), "b", safe=True) + td["nested", "first"] = torch.zeros(td.shape) + td.rename_key_(("nested", "first"), "second") + assert (td["second"] == 0).all() + assert ("nested", "first") not in td.keys(True) + td.rename_key_("second", ("nested", "back")) + assert (td[("nested", "back")] == 0).all() + assert "second" not in td.keys() + + def test_repr(self, td_name, device): td = getattr(self, td_name)(device) - if td_name == "td_params": - td_set = td.data - else: - td_set = td - td_set[:1, :, 0] = td[0, :, 0].to_tensordict().zero_() - assert (td[:1, :, 0] == 0).all() + _ = str(td) - def test_casts(self, td_name, device): + def test_reshape(self, td_name, device): td = getattr(self, td_name)(device) - # exclude non-tensor data - is_leaf = lambda cls: issubclass(cls, torch.Tensor) - tdfloat = td.float() - assert all( - value.dtype is torch.float - for value in tdfloat.values(True, True, 
is_leaf=is_leaf) - ) - tddouble = td.double() - assert all( - value.dtype is torch.double - for value in tddouble.values(True, True, is_leaf=is_leaf) - ) - tdbfloat16 = td.bfloat16() - assert all( - value.dtype is torch.bfloat16 - for value in tdbfloat16.values(True, True, is_leaf=is_leaf) - ) - tdhalf = td.half() - assert all( - value.dtype is torch.half - for value in tdhalf.values(True, True, is_leaf=is_leaf) - ) - tdint = td.int() - assert all( - value.dtype is torch.int - for value in tdint.values(True, True, is_leaf=is_leaf) - ) - tdint = td.type(torch.int) - assert all( - value.dtype is torch.int - for value in tdint.values(True, True, is_leaf=is_leaf) - ) + td_reshape = td.reshape(td.shape) + # assert isinstance(td_reshape, TensorDict) + assert td_reshape.shape.numel() == td.shape.numel() + assert td_reshape.shape == td.shape + td_reshape = td.reshape(*td.shape) + # assert isinstance(td_reshape, TensorDict) + assert td_reshape.shape.numel() == td.shape.numel() + assert td_reshape.shape == td.shape + td_reshape = td.reshape(size=td.shape) + # assert isinstance(td_reshape, TensorDict) + assert td_reshape.shape.numel() == td.shape.numel() + assert td_reshape.shape == td.shape + td_reshape = td.reshape(-1) + assert isinstance(td_reshape, TensorDict) + assert td_reshape.shape.numel() == td.shape.numel() + assert td_reshape.shape == torch.Size([td.shape.numel()]) + td_reshape = td.reshape((-1,)) + assert isinstance(td_reshape, TensorDict) + assert td_reshape.shape.numel() == td.shape.numel() + assert td_reshape.shape == torch.Size([td.shape.numel()]) + td_reshape = td.reshape(size=(-1,)) + assert isinstance(td_reshape, TensorDict) + assert td_reshape.shape.numel() == td.shape.numel() + assert td_reshape.shape == torch.Size([td.shape.numel()]) - def test_empty_like(self, td_name, device): - if "sub_td" in td_name: - # we do not call skip to avoid systematic skips in internal code base - return + @pytest.mark.parametrize("strict", [True, False]) + @pytest.mark.parametrize("inplace", [True, False]) + def test_select(self, td_name, device, strict, inplace): + torch.manual_seed(1) td = getattr(self, td_name)(device) - if isinstance(td, _CustomOpTensorDict): - # we do not call skip to avoid systematic skips in internal code base + keys = ["a"] + if td_name == "td_h5": + with pytest.raises(NotImplementedError, match="Cannot call select"): + td.select(*keys, strict=strict, inplace=inplace) return - td_empty = torch.empty_like(td) - td.apply_(lambda x: x + 1.0) - assert type(td) is type(td_empty) - # exclude non tensor data - comp = td.filter_non_tensor_data() != td_empty.filter_non_tensor_data() - logging.info(td.filter_non_tensor_data()) - assert all(val.any() for val in comp.values(True, True)) + if td_name in ("nested_stacked_td", "nested_td"): + keys += [("my_nested_td", "inner")] - @pytest.mark.parametrize("nested", [False, True]) - def test_add_batch_dim_cache(self, td_name, device, nested): - td = getattr(self, td_name)(device) - if nested: - td = TensorDict({"parent": td}, td.batch_size) - from tensordict.nn import TensorDictModule # noqa - from torch import vmap + with td.unlock_() if td.is_locked else contextlib.nullcontext(): + td2 = td.select(*keys, strict=strict, inplace=inplace) + if inplace: + assert td2 is td + else: + assert td2 is not td + if td_name == "saved_td": + assert (len(list(td2.keys())) == len(keys)) and ("a" in td2.keys()) + assert (len(list(td2.clone().keys())) == len(keys)) and ( + "a" in td2.clone().keys() + ) + else: + assert (len(list(td2.keys(True, True))) == 
len(keys)) and ( + "a" in td2.keys() + ) + assert (len(list(td2.clone().keys(True, True))) == len(keys)) and ( + "a" in td2.clone().keys() + ) - fun = vmap(lambda x: x) + @pytest.mark.parametrize("strict", [True, False]) + def test_select_exception(self, td_name, device, strict): + torch.manual_seed(1) + td = getattr(self, td_name)(device) if td_name == "td_h5": - with pytest.raises( - RuntimeError, match="Persistent tensordicts cannot be used with vmap" - ): - fun(td) - return - if td_name == "memmap_td" and device.type != "cpu": - with pytest.raises( - RuntimeError, - match="MemoryMappedTensor with non-cpu device are not supported in vmap ops", - ): - fun(td) + with pytest.raises(NotImplementedError, match="Cannot call select"): + _ = td.select("tada", strict=strict) return - fun(td) - td.zero_() - # this value should be cached - std = fun(td) - for value in std.values(True, True): - assert (value == 0).all() + if strict: + with pytest.raises(KeyError): + _ = td.select("tada", strict=strict) + else: + td2 = td.select("tada", strict=strict) + assert td2 is not td + assert len(list(td2.keys())) == 0 def test_set_lazy_legacy(self, td_name, device): def test_not_id(): @@ -3166,2125 +3397,1987 @@ def test_id(): assert lazy_legacy() test_id() - def test_update_select(self, td_name, device): - if td_name in ("memmap_td",): - pytest.skip(reason="update not possible with memory-mapped td") + def test_set_nested_batch_size(self, td_name, device): td = getattr(self, td_name)(device) - t = lambda: torch.zeros(()).expand((4, 3, 2, 1)) - other_td = TensorDict( - { - "My": {"father": {"was": t(), "a": t()}, "relentlessly": t()}, - "self-improving": t(), - }, - batch_size=(4, 3, 2, 1), - ) - td.update( - other_td, - keys_to_update=(("My", ("father",), "was"), ("My", "relentlessly")), - ) - assert ("My", "father", "was") in td.keys(True) - assert ("My", ("father",), "was") in td.keys(True) - assert ("My", "relentlessly") in td.keys(True) - assert ("My", "father", "a") in td.keys(True) - assert ("self-improving",) not in td.keys(True) - t = lambda: torch.ones(()).expand((4, 3, 2, 1)) - other_td = TensorDict( - { - "My": {"father": {"was": t(), "a": t()}, "relentlessly": t()}, - "self-improving": t(), - }, - batch_size=(4, 3, 2, 1), - ) - td.update(other_td, keys_to_update=(("My", "relentlessly"),)) - assert (td["My", "relentlessly"] == 1).all() - assert (td["My", "father", "was"] == 0).all() - td.update(other_td, keys_to_update=(("My", ("father",), "was"),)) - assert (td["My", "father", "was"] == 1).all() + td.unlock_() + batch_size = torch.Size([*td.batch_size, 3]) + td.set("some_other_td", TensorDict({}, batch_size)) + assert td["some_other_td"].batch_size == batch_size - def test_non_tensor_data(self, td_name, device): + def test_set_nontensor(self, td_name, device): + torch.manual_seed(1) td = getattr(self, td_name)(device) - # check lock - if td_name not in ("sub_td", "sub_td2"): - with td.lock_(), pytest.raises(RuntimeError, match=re.escape(_LOCK_ERROR)): - td.set_non_tensor(("this", "will"), "fail") - # check set - with td.unlock_(): - td.set(("this", "tensor"), torch.zeros(td.shape)) - reached = False - with pytest.raises( - RuntimeError, - match="set_non_tensor is not compatible with the tensordict type", - ) if td_name in ("td_h5",) else contextlib.nullcontext(): - td.set_non_tensor(("this", "will"), "succeed") - reached = True - if not reached: - return - # check get (for tensor) - assert (td.get_non_tensor(("this", "tensor")) == 0).all() - # check get (for non-tensor) - assert 
td.get_non_tensor(("this", "will")) == "succeed" - assert isinstance(td.get(("this", "will")), NonTensorData) + td.unlock_() + r = torch.randn_like(td.get("a")) + td.set("numpy", r.cpu().numpy()) + torch.testing.assert_close(td.get("numpy"), r) - def test_non_tensor_data_flatten_keys(self, td_name, device): + def test_set_requires_grad(self, td_name, device): td = getattr(self, td_name)(device) - with td.unlock_(): - td.set(("this", "tensor"), torch.zeros(td.shape)) - reached = False + if td_name in ("td_params",): + td.apply(lambda x: x.requires_grad_(False)) + td.unlock_() + assert not td.get("a").requires_grad + if td_name in ("td_h5",): with pytest.raises( - RuntimeError, - match="set_non_tensor is not compatible with the tensordict type", - ) if td_name in ("td_h5",) else contextlib.nullcontext(): - td.set_non_tensor(("this", "will"), "succeed") - reached = True - if not reached: - return - td_flat = td.flatten_keys() - assert (td_flat.get("this.tensor") == 0).all() - assert td_flat.get_non_tensor("this.will") == "succeed" + RuntimeError, match="Cannot set a tensor that has requires_grad=True" + ): + td.set("a", torch.randn_like(td.get("a")).requires_grad_()) + return + if td_name in ("sub_td", "sub_td2"): + td.set_("a", torch.randn_like(td.get("a")).requires_grad_()) + else: + td.set("a", torch.randn_like(td.get("a")).requires_grad_()) - def test_non_tensor_data_pickle(self, td_name, device, tmpdir): + assert td.get("a").requires_grad + + def test_set_unexisting(self, td_name, device): + torch.manual_seed(1) td = getattr(self, td_name)(device) - with td.unlock_(): - td.set(("this", "tensor"), torch.zeros(td.shape)) - reached = False + if td.is_locked: with pytest.raises( RuntimeError, - match="set_non_tensor is not compatible with the tensordict type", - ) if td_name in ("td_h5",) else contextlib.nullcontext(): - td.set_non_tensor(("this", "will"), "succeed") - reached = True - if not reached: - return - td.set_non_tensor(("non", "json", "serializable"), DummyPicklableClass(10)) - td.memmap(prefix=tmpdir, copy_existing=True) - loaded = TensorDict.load_memmap(tmpdir) - assert isinstance(loaded.get(("non", "json", "serializable")), NonTensorData) - assert loaded.get_non_tensor(("non", "json", "serializable")).value == 10 - assert loaded.get_non_tensor(("this", "will")) == "succeed" + match="Cannot modify locked TensorDict. 
For in-place modification", + ): + td.set("z", torch.ones_like(td.get("a"))) + else: + td.set("z", torch.ones_like(td.get("a"))) + assert (td.get("z") == 1).all() + def test_setdefault_existing_key(self, td_name, device): + td = getattr(self, td_name)(device) + td.unlock_() + expected = td.get("a") + inserted = td.setdefault("a", torch.ones_like(td.get("b"))) + assert (inserted == expected).all() -@pytest.mark.parametrize("device", [None, *get_available_devices()]) -@pytest.mark.parametrize("dtype", [torch.float32, torch.uint8]) -class TestTensorDictRepr: - def td(self, device, dtype): - if device is not None: - device_not_none = device - elif torch.has_cuda and torch.cuda.device_count(): - device_not_none = torch.device("cuda:0") - else: - device_not_none = torch.device("cpu") + def test_setdefault_missing_key(self, td_name, device): + td = getattr(self, td_name)(device) + td.unlock_() + expected = torch.ones_like(td.get("a")) + inserted = td.setdefault("z", expected) + assert (inserted == expected).all() - return TensorDict( - source={ - "a": torch.zeros(4, 3, 2, 1, 5, dtype=dtype, device=device_not_none) - }, - batch_size=[4, 3, 2, 1], - device=device, - ) + def test_setdefault_nested(self, td_name, device): + td = getattr(self, td_name)(device) + td.unlock_() - def nested_td(self, device, dtype): - if device is not None: - device_not_none = device - elif torch.has_cuda and torch.cuda.device_count(): - device_not_none = torch.device("cuda:0") - else: - device_not_none = torch.device("cpu") - return TensorDict( - source={ - "my_nested_td": self.td(device, dtype), - "b": torch.zeros(4, 3, 2, 1, 5, dtype=dtype, device=device_not_none), - }, - batch_size=[4, 3, 2, 1], - device=device, + tensor = torch.randn(4, 3, 2, 1, 5, device=device) + tensor2 = torch.ones(4, 3, 2, 1, 5, device=device) + sub_sub_tensordict = TensorDict({"c": tensor}, [4, 3, 2, 1], device=device) + sub_tensordict = TensorDict( + {"b": sub_sub_tensordict}, [4, 3, 2, 1], device=device ) + if td_name == "td_h5": + del td["a"] + if td_name == "sub_td": + td = td._source.set( + "a", sub_tensordict.expand(2, *sub_tensordict.shape) + )._get_sub_tensordict(1) + elif td_name == "sub_td2": + td = td._source.set( + "a", + sub_tensordict.expand(2, *sub_tensordict.shape).permute(1, 0, 2, 3, 4), + )._get_sub_tensordict((slice(None), 1)) + else: + td.set("a", sub_tensordict) - def nested_tensorclass(self, device, dtype): - from tensordict import tensorclass + # if key exists we return the existing value + torch.testing.assert_close(td.setdefault(("a", "b", "c"), tensor2), tensor) - @tensorclass - class MyClass: - X: torch.Tensor - y: "MyClass" - z: str + if not td_name == "stacked_td": + torch.testing.assert_close(td.setdefault(("a", "b", "d"), tensor2), tensor2) + torch.testing.assert_close(td.get(("a", "b", "d")), tensor2) - if device is not None: - device_not_none = device - elif torch.has_cuda and torch.cuda.device_count(): - device_not_none = torch.device("cuda:0") - else: - device_not_none = torch.device("cpu") - nested_class = MyClass( - X=torch.zeros(4, 3, 2, 1, dtype=dtype, device=device_not_none), - y=MyClass( - X=torch.zeros(4, 3, 2, 1, dtype=dtype, device=device_not_none), - y=None, - z=None, - batch_size=[4, 3, 2, 1], - ), - z="z", - batch_size=[4, 3, 2, 1], - ) - return TensorDict( - source={ - "my_nested_td": nested_class, - "b": torch.zeros(4, 3, 2, 1, 5, dtype=dtype, device=device_not_none), - }, - batch_size=[4, 3, 2, 1], - device=device, - ) + @pytest.mark.parametrize( + "idx", [slice(1), torch.tensor([0]), 
torch.tensor([0, 1]), range(1), range(2)] + ) + def test_setitem(self, td_name, device, idx): + torch.manual_seed(1) + td = getattr(self, td_name)(device) + if isinstance(idx, torch.Tensor) and idx.numel() > 1 and td.shape[0] == 1: + pytest.mark.skip("cannot index tensor with desired index") + return - def stacked_td(self, device, dtype): - if device is not None: - device_not_none = device - elif torch.has_cuda and torch.cuda.device_count(): - device_not_none = torch.device("cuda:0") + td_clone = td[idx].to_tensordict().zero_() + if td_name == "td_params": + td.data[idx] = td_clone else: - device_not_none = torch.device("cpu") - td1 = TensorDict( - source={ - "a": torch.zeros(4, 3, 1, 5, dtype=dtype, device=device_not_none), - "c": torch.zeros(4, 3, 1, 5, dtype=dtype, device=device_not_none), - }, - batch_size=[4, 3, 1], - device=device, - ) - td2 = TensorDict( - source={ - "a": torch.zeros(4, 3, 1, 5, dtype=dtype, device=device_not_none), - "b": torch.zeros(4, 3, 1, 10, dtype=dtype, device=device_not_none), - }, - batch_size=[4, 3, 1], - device=device, - ) - - return stack_td([td1, td2], 2) + td[idx] = td_clone + assert (td[idx].get("a") == 0).all() - def memmap_td(self, device, dtype): - if device is not None and device.type != "cpu": - pytest.skip("MemoryMappedTensors can only be placed on CPU.") - return self.td(device, dtype).memmap_() + td_clone = torch.cat([td_clone, td_clone], 0) + with pytest.raises( + RuntimeError, + match=r"differs from the source batch size|batch dimension mismatch|Cannot broadcast the tensordict", + ): + td[idx] = td_clone - def share_memory_td(self, device, dtype): - return self.td(device, dtype).share_memory_() + @pytest.mark.parametrize("actual_index", [..., (..., 0), (0, ...), (0, ..., 0)]) + def test_setitem_ellipsis(self, td_name, device, actual_index): + torch.manual_seed(1) + td = getattr(self, td_name)(device) - def test_repr_plain(self, device, dtype): - tensordict = self.td(device, dtype) - if device is not None and device.type == "cuda": - is_shared = True + idx = actual_index + td_clone = td.clone() + actual_td = td_clone[idx].clone() + if td_name in ("td_params",): + td_set = actual_td.apply(lambda x: x.data) else: - is_shared = False - tensor_device = device if device else tensordict["a"].device - if tensor_device.type == "cuda": - is_shared_tensor = True + td_set = actual_td + td_set.zero_() + + for key in actual_td.keys(): + assert (actual_td.get(key) == 0).all() + + if td_name in ("td_params",): + td_set = td_clone.data else: - is_shared_tensor = is_shared - expected = f"""TensorDict( - fields={{ - a: Tensor(shape=torch.Size([4, 3, 2, 1, 5]), device={tensor_device}, dtype={dtype}, is_shared={is_shared_tensor})}}, - batch_size=torch.Size([4, 3, 2, 1]), - device={str(device)}, - is_shared={is_shared})""" - assert repr(tensordict) == expected + td_set = td_clone - def test_repr_memmap(self, device, dtype): - tensordict = self.memmap_td(device, dtype) - # tensor_device = device if device else tensordict["a"].device # noqa: F841 - expected = f"""TensorDict( - fields={{ - a: MemoryMappedTensor(shape=torch.Size([4, 3, 2, 1, 5]), device=cpu, dtype={dtype}, is_shared=False)}}, - batch_size=torch.Size([4, 3, 2, 1]), - device=cpu, - is_shared=False)""" - assert repr(tensordict) == expected + td_set[idx] = actual_td + for key in td_clone.keys(): + assert (td_clone[idx].get(key) == 0).all() - def test_repr_share_memory(self, device, dtype): - tensordict = self.share_memory_td(device, dtype) - is_shared = True - tensor_class = "Tensor" - tensor_device = 
device if device else tensordict["a"].device - if tensor_device.type == "cuda": - is_shared_tensor = True - else: - is_shared_tensor = is_shared - expected = f"""TensorDict( - fields={{ - a: {tensor_class}(shape=torch.Size([4, 3, 2, 1, 5]), device={tensor_device}, dtype={dtype}, is_shared={is_shared_tensor})}}, - batch_size=torch.Size([4, 3, 2, 1]), - device={str(device)}, - is_shared={is_shared})""" - assert repr(tensordict) == expected + def test_setitem_nested_dict_value(self, td_name, device): + torch.manual_seed(1) + td = getattr(self, td_name)(device) - def test_repr_nested(self, device, dtype): - nested_td = self.nested_td(device, dtype) - if device is not None and device.type == "cuda": - is_shared = True - else: - is_shared = False - tensor_class = "Tensor" - tensor_device = device if device else nested_td["b"].device - if tensor_device.type == "cuda": - is_shared_tensor = True - else: - is_shared_tensor = is_shared - expected = f"""TensorDict( - fields={{ - b: {tensor_class}(shape=torch.Size([4, 3, 2, 1, 5]), device={tensor_device}, dtype={dtype}, is_shared={is_shared_tensor}), - my_nested_td: TensorDict( - fields={{ - a: {tensor_class}(shape=torch.Size([4, 3, 2, 1, 5]), device={tensor_device}, dtype={dtype}, is_shared={is_shared_tensor})}}, - batch_size=torch.Size([4, 3, 2, 1]), - device={str(device)}, - is_shared={is_shared})}}, - batch_size=torch.Size([4, 3, 2, 1]), - device={str(device)}, - is_shared={is_shared})""" - assert repr(nested_td) == expected + # Create equivalent TensorDict and dict nested values for setitem + nested_dict_value = {"e": torch.randn(4, 3, 2, 1, 10)} + nested_tensordict_value = TensorDict( + nested_dict_value, batch_size=td.batch_size, device=device + ) + td_clone1 = td.clone(recurse=True) + td_clone2 = td.clone(recurse=True) - def test_repr_nested_update(self, device, dtype): - nested_td = self.nested_td(device, dtype) - nested_td["my_nested_td"].rename_key_("a", "z") - if device is not None and device.type == "cuda": - is_shared = True + td_clone1["d"] = nested_dict_value + td_clone2["d"] = nested_tensordict_value + assert (td_clone1 == td_clone2).all() + + def test_setitem_nestedtuple(self, td_name, device): + torch.manual_seed(1) + td = getattr(self, td_name)(device) + if td.is_locked: + td.unlock_() + td[" a ", (("little", "story")), "about", ("myself",)] = torch.zeros(td.shape) + assert (td[" a ", "little", "story", "about", "myself"] == 0).all() + + def test_setitem_slice(self, td_name, device): + td = getattr(self, td_name)(device) + if td_name == "td_params": + td_set = td.data else: - is_shared = False - tensor_class = "Tensor" - tensor_device = device if device else nested_td["b"].device - if tensor_device.type == "cuda": - is_shared_tensor = True + td_set = td + td_set[:] = td.clone() + td_set[:1] = td[:1].clone().zero_() + assert (td[:1] == 0).all() + td = getattr(self, td_name)(device) + if td_name == "td_params": + td_set = td.data else: - is_shared_tensor = is_shared - expected = f"""TensorDict( - fields={{ - b: {tensor_class}(shape=torch.Size([4, 3, 2, 1, 5]), device={tensor_device}, dtype={dtype}, is_shared={is_shared_tensor}), - my_nested_td: TensorDict( - fields={{ - z: {tensor_class}(shape=torch.Size([4, 3, 2, 1, 5]), device={tensor_device}, dtype={dtype}, is_shared={is_shared_tensor})}}, - batch_size=torch.Size([4, 3, 2, 1]), - device={str(device)}, - is_shared={is_shared})}}, - batch_size=torch.Size([4, 3, 2, 1]), - device={str(device)}, - is_shared={is_shared})""" - assert repr(nested_td) == expected + td_set = td + td_set[:1] 
= td[:1].to_tensordict().zero_() + assert (td[:1] == 0).all() - def test_repr_stacked(self, device, dtype): - stacked_td = self.stacked_td(device, dtype) - if device is not None and device.type == "cuda": - is_shared = True + # with broadcast + td = getattr(self, td_name)(device) + if td_name == "td_params": + td_set = td.data else: - is_shared = False - tensor_class = "Tensor" - tensor_device = device if device else stacked_td["a"].device - if tensor_device.type == "cuda": - is_shared_tensor = True + td_set = td + td_set[:1] = td[0].clone().zero_() + assert (td[:1] == 0).all() + td = getattr(self, td_name)(device) + if td_name == "td_params": + td_set = td.data else: - is_shared_tensor = is_shared - expected = f"""LazyStackedTensorDict( - fields={{ - a: {tensor_class}(shape=torch.Size([4, 3, 2, 1, 5]), device={tensor_device}, dtype={dtype}, is_shared={is_shared_tensor})}}, - exclusive_fields={{ - 0 -> - c: {tensor_class}(shape=torch.Size([4, 3, 1, 5]), device={tensor_device}, dtype={dtype}, is_shared={is_shared_tensor}), - 1 -> - b: {tensor_class}(shape=torch.Size([4, 3, 1, 10]), device={tensor_device}, dtype={dtype}, is_shared={is_shared_tensor})}}, - batch_size=torch.Size([4, 3, 2, 1]), - device={str(device)}, - is_shared={is_shared}, - stack_dim={stacked_td.stack_dim})""" - assert repr(stacked_td) == expected + td_set = td + td_set[:1] = td[0].to_tensordict().zero_() + assert (td[:1] == 0).all() - def test_repr_stacked_het(self, device, dtype): - stacked_td = torch.stack( - [ - TensorDict( - { - "a": torch.zeros(3, dtype=dtype), - "b": torch.zeros(2, 3, dtype=dtype), - }, - [], - device=device, - ), - TensorDict( - { - "a": torch.zeros(2, dtype=dtype), - "b": torch.zeros((), dtype=dtype), - }, - [], - device=device, - ), - ] - ) - if device is not None and device.type == "cuda": - is_shared = True + td = getattr(self, td_name)(device) + if td_name == "td_params": + td_set = td.data else: - is_shared = False - tensor_device = device if device else torch.device("cpu") - if tensor_device.type == "cuda": - is_shared_tensor = True + td_set = td + td_set[:1, 0] = td[0, 0].clone().zero_() + assert (td[:1, 0] == 0).all() + td = getattr(self, td_name)(device) + if td_name == "td_params": + td_set = td.data else: - is_shared_tensor = is_shared - expected = f"""LazyStackedTensorDict( - fields={{ - a: Tensor(shape=torch.Size([2, -1]), device={tensor_device}, dtype={dtype}, is_shared={is_shared_tensor}), - b: Tensor(shape=torch.Size([-1]), device={tensor_device}, dtype={dtype}, is_shared={is_shared_tensor})}}, - exclusive_fields={{ - }}, - batch_size=torch.Size([2]), - device={str(device)}, - is_shared={is_shared}, - stack_dim={stacked_td.stack_dim})""" - assert repr(stacked_td) == expected + td_set = td + td_set[:1, 0] = td[0, 0].to_tensordict().zero_() + assert (td[:1, 0] == 0).all() - @pytest.mark.parametrize("index", [None, (slice(None), 0)]) - def test_repr_indexed_tensordict(self, device, dtype, index): - tensordict = self.td(device, dtype)[index] - if device is not None and device.type == "cuda": - is_shared = True - else: - is_shared = False - tensor_class = "Tensor" - tensor_device = device if device else tensordict["a"].device - if tensor_device.type == "cuda": - is_shared_tensor = True + td = getattr(self, td_name)(device) + if td_name == "td_params": + td_set = td.data else: - is_shared_tensor = is_shared - if index is None: - expected = f"""TensorDict( - fields={{ - a: {tensor_class}(shape=torch.Size([1, 4, 3, 2, 1, 5]), device={tensor_device}, dtype={dtype}, 
is_shared={is_shared_tensor})}}, - batch_size=torch.Size([1, 4, 3, 2, 1]), - device={str(device)}, - is_shared={is_shared})""" + td_set = td + td_set[:1, :, 0] = td[0, :, 0].clone().zero_() + assert (td[:1, :, 0] == 0).all() + td = getattr(self, td_name)(device) + if td_name == "td_params": + td_set = td.data else: - expected = f"""TensorDict( - fields={{ - a: {tensor_class}(shape=torch.Size([4, 2, 1, 5]), device={tensor_device}, dtype={dtype}, is_shared={is_shared_tensor})}}, - batch_size=torch.Size([4, 2, 1]), - device={str(device)}, - is_shared={is_shared})""" + td_set = td + td_set[:1, :, 0] = td[0, :, 0].to_tensordict().zero_() + assert (td[:1, :, 0] == 0).all() - assert repr(tensordict) == expected + def test_setitem_string(self, td_name, device): + torch.manual_seed(1) + td = getattr(self, td_name)(device) + td.unlock_() + td["d"] = torch.randn(4, 3, 2, 1, 5) + assert "d" in td.keys() - @pytest.mark.parametrize("index", [None, (slice(None), 0)]) - def test_repr_indexed_nested_tensordict(self, device, dtype, index): - nested_tensordict = self.nested_td(device, dtype)[index] - if device is not None and device.type == "cuda": - is_shared = True - else: - is_shared = False - tensor_class = "Tensor" - tensor_device = device if device else nested_tensordict["b"].device - if tensor_device.type == "cuda": - is_shared_tensor = True - else: - is_shared_tensor = is_shared - if index is None: - expected = f"""TensorDict( - fields={{ - b: {tensor_class}(shape=torch.Size([1, 4, 3, 2, 1, 5]), device={tensor_device}, dtype={dtype}, is_shared={is_shared_tensor}), - my_nested_td: TensorDict( - fields={{ - a: {tensor_class}(shape=torch.Size([1, 4, 3, 2, 1, 5]), device={tensor_device}, dtype={dtype}, is_shared={is_shared_tensor})}}, - batch_size=torch.Size([1, 4, 3, 2, 1]), - device={str(device)}, - is_shared={is_shared})}}, - batch_size=torch.Size([1, 4, 3, 2, 1]), - device={str(device)}, - is_shared={is_shared})""" - else: - expected = f"""TensorDict( - fields={{ - b: {tensor_class}(shape=torch.Size([4, 2, 1, 5]), device={tensor_device}, dtype={dtype}, is_shared={is_shared_tensor}), - my_nested_td: TensorDict( - fields={{ - a: {tensor_class}(shape=torch.Size([4, 2, 1, 5]), device={tensor_device}, dtype={dtype}, is_shared={is_shared_tensor})}}, - batch_size=torch.Size([4, 2, 1]), - device={str(device)}, - is_shared={is_shared})}}, - batch_size=torch.Size([4, 2, 1]), - device={str(device)}, - is_shared={is_shared})""" - assert repr(nested_tensordict) == expected - - @pytest.mark.parametrize("index", [None, (slice(None), 0)]) - def test_repr_indexed_stacked_tensordict(self, device, dtype, index): - stacked_tensordict = self.stacked_td(device, dtype) - if device is not None and device.type == "cuda": - is_shared = True - else: - is_shared = False - tensor_class = "Tensor" - tensor_device = device if device else stacked_tensordict["a"].device - if tensor_device.type == "cuda": - is_shared_tensor = True - else: - is_shared_tensor = is_shared - - expected = f"""LazyStackedTensorDict( - fields={{ - a: {tensor_class}(shape=torch.Size([4, 3, 2, 1, 5]), device={tensor_device}, dtype={dtype}, is_shared={is_shared_tensor})}}, - exclusive_fields={{ - 0 -> - c: {tensor_class}(shape=torch.Size([4, 3, 1, 5]), device={tensor_device}, dtype={dtype}, is_shared={is_shared_tensor}), - 1 -> - b: {tensor_class}(shape=torch.Size([4, 3, 1, 10]), device={tensor_device}, dtype={dtype}, is_shared={is_shared_tensor})}}, - batch_size=torch.Size([4, 3, 2, 1]), - device={str(device)}, - is_shared={is_shared}, - 
stack_dim={stacked_tensordict.stack_dim})""" - - assert repr(stacked_tensordict) == expected - - @pytest.mark.skipif(not torch.cuda.device_count(), reason="no cuda") - @pytest.mark.parametrize("device_cast", get_available_devices()) - def test_repr_device_to_device(self, device, dtype, device_cast): - td = self.td(device, dtype) - if (device_cast is None and (torch.cuda.device_count() > 0)) or ( - device_cast is not None and device_cast.type == "cuda" - ): - is_shared = True - else: - is_shared = False - tensor_class = "Tensor" - td2 = td.to(device_cast) - tensor_device = device_cast if device_cast else td2["a"].device - if tensor_device.type == "cuda": - is_shared_tensor = True - else: - is_shared_tensor = is_shared - expected = f"""TensorDict( - fields={{ - a: {tensor_class}(shape=torch.Size([4, 3, 2, 1, 5]), device={tensor_device}, dtype={dtype}, is_shared={is_shared_tensor})}}, - batch_size=torch.Size([4, 3, 2, 1]), - device={str(device_cast)}, - is_shared={is_shared})""" - assert repr(td2) == expected - - @pytest.mark.skipif(not torch.cuda.device_count(), reason="no cuda") - def test_repr_batch_size_update(self, device, dtype): - td = self.td(device, dtype) - td.batch_size = torch.Size([4, 3, 2]) - is_shared = False - tensor_class = "Tensor" - if device is not None and device.type == "cuda": - is_shared = True - tensor_device = device if device else td["a"].device - if tensor_device.type == "cuda": - is_shared_tensor = True - else: - is_shared_tensor = is_shared - expected = f"""TensorDict( - fields={{ - a: {tensor_class}(shape=torch.Size([4, 3, 2, 1, 5]), device={tensor_device}, dtype={dtype}, is_shared={is_shared_tensor})}}, - batch_size=torch.Size([4, 3, 2]), - device={device}, - is_shared={is_shared})""" - assert repr(td) == expected - - -@pytest.mark.parametrize( - "td_name", - [ - "td", - "stacked_td", - "sub_td", - "idx_td", - "unsqueezed_td", - "td_reset_bs", - ], -) -@pytest.mark.parametrize( - "device", - get_available_devices(), -) -class TestTensorDictsRequiresGrad: - def td(self, device): - return TensorDict( - source={ - "a": torch.randn(3, 1, 5, device=device), - "b": torch.randn(3, 1, 10, device=device, requires_grad=True), - "c": torch.randint(10, (3, 1, 3), device=device), - }, - batch_size=[3, 1], - ) + def test_shape(self, td_name, device): + td = getattr(self, td_name)(device) + assert td.shape == td.batch_size - def stacked_td(self, device): - return stack_td([self.td(device) for _ in range(2)], 0) + def test_sorted_keys(self, td_name, device): + torch.manual_seed(1) + td = getattr(self, td_name)(device) + sorted_keys = td.sorted_keys + i = -1 + for i, (key1, key2) in enumerate(zip(sorted_keys, td.keys())): # noqa: B007 + assert key1 == key2 + assert i == len(td.keys()) - 1 + if td.is_locked: + assert td._cache.get("sorted_keys", None) is not None + td.unlock_() + assert td._cache is None + elif td_name not in ("sub_td", "sub_td2"): # we cannot lock sub tensordicts + if isinstance(td, _CustomOpTensorDict): + target = td._source + else: + target = td + assert target._cache is None + td.lock_() + _ = td.sorted_keys + assert target._cache.get("sorted_keys", None) is not None + td.unlock_() + assert target._cache is None - def idx_td(self, device): - return self.td(device)[0] + @pytest.mark.parametrize("performer", ["torch", "tensordict"]) + @pytest.mark.parametrize("dim", range(4)) + def test_split(self, td_name, device, performer, dim): + td = getattr(self, td_name)(device) + t = torch.zeros(()).expand(td.shape) + for dim in range(td.batch_dims): + rep, remainder = 
divmod(td.shape[dim], 2) + split_sizes = [2] * rep + [1] * remainder + for test_split_size in (2, split_sizes): + tensorsplit = t.split(test_split_size, dim=dim) + length = len(tensorsplit) + if performer == "torch": + tds = torch.split(td, test_split_size, dim) + elif performer == "tensordict": + tds = td.split(test_split_size, dim) + assert len(tds) == length - def sub_td(self, device): - return self.td(device)._get_sub_tensordict(0) + for idx, split_td in enumerate(tds): + expected_split_dim_size = 1 if idx == rep else 2 + expected_batch_size = tensorsplit[idx].shape + # Test each split_td has the expected batch_size + assert split_td.batch_size == torch.Size(expected_batch_size) - def unsqueezed_td(self, device): - return self.td(device).unsqueeze(0) + if td_name == "nested_td": + assert isinstance(split_td["my_nested_td"], TensorDict) + assert isinstance( + split_td["my_nested_td", "inner"], torch.Tensor + ) - def td_reset_bs(self, device): - td = self.td(device) - td = td.unsqueeze(-1).to_tensordict() - td.batch_size = torch.Size([3, 1]) - return td + # Test each tensor (or nested_td) in split_td has the expected shape + for key, item in split_td.items(): + expected_shape = [ + expected_split_dim_size if dim_idx == dim else dim_size + for (dim_idx, dim_size) in enumerate(td[key].shape) + ] + assert item.shape == torch.Size(expected_shape) - def test_view(self, td_name, device): - torch.manual_seed(1) - td = getattr(self, td_name)(device) - td_view = td.view(-1) - assert td_view.get("b").requires_grad + if key == "my_nested_td": + expected_inner_tensor_size = [ + expected_split_dim_size if dim_idx == dim else dim_size + for (dim_idx, dim_size) in enumerate( + td[key]["inner"].shape + ) + ] + assert item["inner"].shape == torch.Size( + expected_inner_tensor_size + ) - def test_expand(self, td_name, device): + def test_squeeze(self, td_name, device, squeeze_dim=-1): torch.manual_seed(1) td = getattr(self, td_name)(device) - batch_size = td.batch_size - new_td = td.expand(3, *batch_size) - assert new_td.get("b").requires_grad - assert new_td.batch_size == torch.Size([3, *batch_size]) - - # Deprecated - # def test_cast(self, td_name, device): - # torch.manual_seed(1) - # td = getattr(self, td_name)(device) - # td_td = td.to(TensorDict) - # assert td_td.get("b").requires_grad + with td.unlock_(): # make sure that the td is not locked + td_squeeze = torch.squeeze(td, dim=-1) + tensor_squeeze_dim = td.batch_dims + squeeze_dim + tensor = torch.ones_like(td.get("a").squeeze(tensor_squeeze_dim)) + if td_name in ("sub_td", "sub_td2"): + td_squeeze.set_("a", tensor) + else: + td_squeeze.set("a", tensor) + assert td.batch_size[squeeze_dim] == 1 + assert (td_squeeze.get("a") == tensor).all() + assert (td.get("a") == tensor.unsqueeze(tensor_squeeze_dim)).all() + if td_name != "unsqueezed_td": + assert _compare_tensors_identity(td_squeeze.unsqueeze(squeeze_dim), td) + else: + assert td_squeeze is td._source + assert (td_squeeze.get("a") == 1).all() + assert (td.get("a") == 1).all() - def test_clone_td(self, td_name, device): + def test_squeeze_with_none(self, td_name, device, squeeze_dim=None): torch.manual_seed(1) td = getattr(self, td_name)(device) - assert torch.clone(td).get("b").requires_grad + td_squeeze = torch.squeeze(td, dim=None) + tensor = torch.ones_like(td.get("a").squeeze()) + td_squeeze.set_("a", tensor) + assert (td_squeeze.get("a") == tensor).all() + if td_name == "unsqueezed_td": + assert td_squeeze._source is td + assert (td_squeeze.get("a") == 1).all() + assert (td.get("a") == 
1).all() - def test_squeeze(self, td_name, device, squeeze_dim=-1): + @pytest.mark.filterwarnings("error") + def test_stack_onto(self, td_name, device, tmpdir): torch.manual_seed(1) td = getattr(self, td_name)(device) - assert torch.squeeze(td, dim=-1).get("b").requires_grad - - -def test_batchsize_reset(): - td = TensorDict( - {"a": torch.randn(3, 4, 5, 6), "b": torch.randn(3, 4, 5)}, batch_size=[3, 4] - ) - # smoke-test - td.batch_size = torch.Size([3]) - - # test with list - td.batch_size = [3] + if td_name == "td_h5": + td0 = td.clone(newfile=tmpdir / "file0.h5").apply_(lambda x: x.zero_()) + td1 = td.clone(newfile=tmpdir / "file1.h5").apply_(lambda x: x.zero_() + 1) + else: + td0 = td.clone() + if td_name in ("td_params",): + td0.data.apply_(lambda x: x.zero_()) + else: + td0.apply_(lambda x: x.zero_()) + td1 = td.clone() + if td_name in ("td_params",): + td1.data.apply_(lambda x: x.zero_() + 1) + else: + td1.apply_(lambda x: x.zero_() + 1) - # test with tuple - td.batch_size = (3,) + td_out = td.unsqueeze(1).expand(td.shape[0], 2, *td.shape[1:]).clone() + td_stack = torch.stack([td0, td1], 1) + if td_name == "td_params": + with pytest.raises(RuntimeError, match="out.batch_size and stacked"): + torch.stack([td0, td1], 0, out=td_out) + return + data_ptr_set_before = {val.data_ptr() for val in decompose(td_out)} + torch.stack([td0, td1], 1, out=td_out) + data_ptr_set_after = {val.data_ptr() for val in decompose(td_out)} + assert data_ptr_set_before == data_ptr_set_after + assert (td_stack == td_out).all() - # incompatible size - with pytest.raises( - RuntimeError, - match=re.escape( - "the tensor a has shape torch.Size([3, 4, 5, 6]) which is incompatible with the batch-size torch.Size([3, 5])." - ), - ): - td.batch_size = [3, 5] - - # test set - td.set("c", torch.randn(3)) - - # test index - td[torch.tensor([1, 2])] - td[:] - td[[1, 2]] - with pytest.raises( - IndexError, - match=re.escape("too many indices for tensor of dimension 1"), - ): - td[:, 0] + @pytest.mark.filterwarnings("error") + def test_stack_subclasses_on_td(self, td_name, device): + torch.manual_seed(1) + td = getattr(self, td_name)(device) + td = td.expand(3, *td.batch_size).clone().zero_() + tds_list = [getattr(self, td_name)(device) for _ in range(3)] + if td_name == "td_params": + with pytest.raises(RuntimeError, match="arguments don't support automatic"): + torch.stack(tds_list, 0, out=td) + return + data_ptr_set_before = {val.data_ptr() for val in decompose(td)} + stacked_td = stack_td(tds_list, 0, out=td) + data_ptr_set_after = {val.data_ptr() for val in decompose(td)} + assert data_ptr_set_before == data_ptr_set_after + assert stacked_td.batch_size == td.batch_size + for key in ("a", "b", "c"): + assert (stacked_td[key] == td[key]).all() - # test a greater batch_size - td = TensorDict( - {"a": torch.randn(3, 4, 5, 6), "b": torch.randn(3, 4, 5)}, batch_size=[3, 4] - ) - td.batch_size = torch.Size([3, 4, 5]) - - td.set("c", torch.randn(3, 4, 5, 6)) - with pytest.raises( - RuntimeError, - match=re.escape( - "batch dimension mismatch, got self.batch_size=torch.Size([3, 4, 5]) and value.shape=torch.Size([3, 4, 2])" - ), - ): - td.set("d", torch.randn(3, 4, 2)) - - # test that lazy tds return an exception - td_stack = stack_td([TensorDict({"a": torch.randn(3)}, [3]) for _ in range(2)]) - with pytest.raises( - RuntimeError, - match=re.escape( - "modifying the batch size of a lazy representation of a tensordict is not permitted. 
Consider instantiating the tensordict first by calling `td = td.to_tensordict()` before resetting the batch size." - ), - ): - td_stack.batch_size = [2] - td_stack.to_tensordict().batch_size = [2] - - td = TensorDict({"a": torch.randn(3, 4)}, [3, 4]) - subtd = td._get_sub_tensordict((slice(None), torch.tensor([1, 2]))) - with pytest.raises( - RuntimeError, - match=re.escape( - "modifying the batch size of a lazy representation of a tensordict is not permitted. Consider instantiating the tensordict first by calling `td = td.to_tensordict()` before resetting the batch size." - ), - ): - subtd.batch_size = [3, 2] - subtd.to_tensordict().batch_size = [3, 2] - - td = TensorDict({"a": torch.randn(3, 4)}, [3, 4]) - td_u = td.unsqueeze(0) - with pytest.raises( - RuntimeError, - match=re.escape( - "modifying the batch size of a lazy representation of a tensordict is not permitted. Consider instantiating the tensordict first by calling `td = td.to_tensordict()` before resetting the batch size." - ), - ): - td_u.batch_size = [1] - td_u.to_tensordict().batch_size = [1] - - -@pytest.mark.parametrize("index0", [None, slice(None)]) -def test_set_sub_key(index0): - # tests that parent tensordict is affected when subtensordict is set with a new key - batch_size = [10, 10] - source = {"a": torch.randn(10, 10, 10), "b": torch.ones(10, 10, 2)} - td = TensorDict(source, batch_size=batch_size) - idx0 = (index0, 0) if index0 is not None else 0 - td0 = td._get_sub_tensordict(idx0) - idx = (index0, slice(2, 4)) if index0 is not None else slice(2, 4) - sub_td = td._get_sub_tensordict(idx) - if index0 is None: - c = torch.randn(2, 10, 10) - else: - c = torch.randn(10, 2, 10) - sub_td.set("c", c) - assert (td.get("c")[idx] == sub_td.get("c")).all() - assert (sub_td.get("c") == c).all() - assert (td.get("c")[idx0] == 0).all() - assert (td._get_sub_tensordict(idx0).get("c") == 0).all() - assert (td0.get("c") == 0).all() - - -@pytest.mark.skipif(not torch.cuda.device_count(), reason="no cuda") -def test_create_on_device(): - device = torch.device(0) - - # TensorDict - td = TensorDict({}, [5]) - assert td.device is None - - td.set("a", torch.randn(5, device=device)) - assert td.device is None - - td = TensorDict({}, [5], device="cuda:0") - td.set("a", torch.randn(5, 1)) - assert td.get("a").device == device - - # stacked TensorDict - td1 = TensorDict({}, [5]) - td2 = TensorDict({}, [5]) - stackedtd = stack_td([td1, td2], 0) - assert stackedtd.device is None - - stackedtd.set("a", torch.randn(2, 5, device=device)) - assert stackedtd.device is None - - stackedtd = stackedtd.to(device) - assert stackedtd.device == device - - td1 = TensorDict({}, [5], device="cuda:0") - td2 = TensorDict({}, [5], device="cuda:0") - stackedtd = stack_td([td1, td2], 0) - stackedtd.set("a", torch.randn(2, 5, 1)) - assert stackedtd.get("a").device == device - assert td1.get("a").device == device - assert td2.get("a").device == device - - # TensorDict, indexed - td = TensorDict({}, [5]) - subtd = td[1] - assert subtd.device is None - - subtd.set("a", torch.randn(1, device=device)) - # setting element of subtensordict doesn't set top-level device - assert subtd.device is None - - subtd = subtd.to(device) - assert subtd.device == device - assert subtd["a"].device == device - - td = TensorDict({}, [5], device="cuda:0") - subtd = td[1] - subtd.set("a", torch.randn(1)) - assert subtd.get("a").device == device - - td = TensorDict({}, [5], device="cuda:0") - subtd = td[1:3] - subtd.set("a", torch.randn(2)) - assert subtd.get("a").device == device - - # 
ViewedTensorDict - td = TensorDict({}, [6]) - viewedtd = td.view(2, 3) - assert viewedtd.device is None - - viewedtd = viewedtd.to(device) - assert viewedtd.device == device - - td = TensorDict({}, [6], device="cuda:0") - viewedtd = td.view(2, 3) - a = torch.randn(2, 3) - viewedtd.set("a", a) - assert viewedtd.get("a").device == device - assert (a.to(device) == viewedtd.get("a")).all() + @pytest.mark.filterwarnings("error") + def test_stack_tds_on_subclass(self, td_name, device): + torch.manual_seed(1) + td = getattr(self, td_name)(device) + tds_count = td.batch_size[0] + tds_batch_size = td.batch_size[1:] + tds_list = [ + TensorDict( + source={ + "a": torch.ones(*tds_batch_size, 5), + "b": torch.ones(*tds_batch_size, 10), + "c": torch.ones(*tds_batch_size, 3, dtype=torch.long), + }, + batch_size=tds_batch_size, + device=device, + ) + for _ in range(tds_count) + ] + if td_name in ("sub_td", "sub_td2"): + with pytest.raises(IndexError, match="storages of the indexed tensors"): + torch.stack(tds_list, 0, out=td) + return + data_ptr_set_before = {val.data_ptr() for val in decompose(td)} + stacked_td = torch.stack(tds_list, 0, out=td) + data_ptr_set_after = {val.data_ptr() for val in decompose(td)} + assert data_ptr_set_before == data_ptr_set_after + assert stacked_td.batch_size == td.batch_size + assert stacked_td is td + for key in ("a", "b", "c"): + assert (stacked_td[key] == 1).all() -class TestMPInplace: - @classmethod - def _remote_process( - cls, worker_id, command_pipe_child, command_pipe_parent, tensordict - ): - command_pipe_parent.close() - while True: - cmd, val = command_pipe_child.recv() - if cmd == "recv": - b = tensordict.get("b") - assert (b == val).all() - command_pipe_child.send("done") - elif cmd == "send": - a = torch.ones(2) * val - tensordict.set_("a", a) - assert ( - tensordict.get("a") == a - ).all(), f'found {a} and {tensordict.get("a")}' - command_pipe_child.send("done") - elif cmd == "set_done": - tensordict.set_("done", torch.ones(1, dtype=torch.bool)) - command_pipe_child.send("done") - elif cmd == "set_undone_": - tensordict.set_("done", torch.zeros(1, dtype=torch.bool)) - command_pipe_child.send("done") - elif cmd == "update": - tensordict.update_( - TensorDict( - source={"a": tensordict.get("a").clone() + 1}, - batch_size=tensordict.batch_size, - ) - ) - command_pipe_child.send("done") - elif cmd == "update_": - tensordict.update_( - TensorDict( - source={"a": tensordict.get("a").clone() - 1}, - batch_size=tensordict.batch_size, - ) - ) - command_pipe_child.send("done") + def test_state_dict(self, td_name, device): + torch.manual_seed(1) + td = getattr(self, td_name)(device) + sd = td.state_dict() + td_zero = td.clone().detach().zero_() + td_zero.load_state_dict(sd) + assert_allclose_td(td, td_zero) - elif cmd == "close": - command_pipe_child.close() - break + def test_state_dict_assign(self, td_name, device): + torch.manual_seed(1) + td = getattr(self, td_name)(device) + sd = td.state_dict() + td_zero = td.clone().detach().zero_() + shallow_copy = td_zero.clone(False) + td_zero.load_state_dict(sd, assign=True) + assert (shallow_copy == 0).all() + assert_allclose_td(td, td_zero) - @classmethod - def _driver_func(cls, tensordict, tensordict_unbind): - procs = [] - children = [] - parents = [] + def test_state_dict_strict(self, td_name, device): + torch.manual_seed(1) + td = getattr(self, td_name)(device) + sd = td.state_dict() + td_zero = td.clone().detach().zero_() + del sd["a"] + td_zero.load_state_dict(sd, strict=False) + with pytest.raises(RuntimeError): + 
td_zero.load_state_dict(sd, strict=True) - for i in range(2): - command_pipe_parent, command_pipe_child = mp.Pipe() - proc = mp.Process( - target=cls._remote_process, - args=(i, command_pipe_child, command_pipe_parent, tensordict_unbind[i]), - ) - proc.start() - command_pipe_child.close() - parents.append(command_pipe_parent) - children.append(command_pipe_child) - procs.append(proc) + def test_tensordict_set(self, td_name, device): + torch.manual_seed(1) + np.random.seed(1) + td = getattr(self, td_name)(device) + td.unlock_() - b = torch.ones(2, 1) * 10 - tensordict.set_("b", b) - for i in range(2): - parents[i].send(("recv", 10)) - is_done = parents[i].recv() - assert is_done == "done" + # test set + val1 = np.ones(shape=(4, 3, 2, 1, 10)) + td.set("key1", val1) + assert (td.get("key1") == 1).all() + with pytest.raises(RuntimeError): + td.set("key1", np.ones(shape=(5, 10))) - for i in range(2): - parents[i].send(("send", i)) - is_done = parents[i].recv() - assert is_done == "done" - a = tensordict.get("a").clone() - assert (a[0] == 0).all() - assert (a[1] == 1).all() + # test set_ + val2 = np.zeros(shape=(4, 3, 2, 1, 10)) + td.set_("key1", val2) + assert (td.get("key1") == 0).all() + if td_name not in ("stacked_td", "nested_stacked_td"): + err_msg = r"key.*smartypants.*not found in " + elif td_name in ("td_h5",): + err_msg = "Unable to open object" + else: + err_msg = "setting a value in-place on a stack of TensorDict" - assert not tensordict.get("done").any() - for i in range(2): - parents[i].send(("set_done", i)) - is_done = parents[i].recv() - assert is_done == "done" - assert tensordict.get("done").all() + with pytest.raises(KeyError, match=err_msg): + td.set_("smartypants", np.ones(shape=(4, 3, 2, 1, 5))) - for i in range(2): - parents[i].send(("set_undone_", i)) - is_done = parents[i].recv() - assert is_done == "done" - assert not tensordict.get("done").any() + # test set_at_ + td.set("key2", np.random.randn(4, 3, 2, 1, 5)) + x = np.ones(shape=(2, 1, 5)) * 42 + td.set_at_("key2", x, (2, 2)) + assert (td.get("key2")[2, 2] == 42).all() - a_prev = tensordict.get("a").clone().contiguous() - for i in range(2): - parents[i].send(("update_", i)) - is_done = parents[i].recv() - assert is_done == "done" - new_a = tensordict.get("a").clone().contiguous() - torch.testing.assert_close(a_prev - 1, new_a) + def test_tensordict_set_dict_value(self, td_name, device): + torch.manual_seed(1) + np.random.seed(1) + td = getattr(self, td_name)(device) + td.unlock_() - a_prev = tensordict.get("a").clone().contiguous() - for i in range(2): - parents[i].send(("update", i)) - is_done = parents[i].recv() - assert is_done == "done" - new_a = tensordict.get("a").clone().contiguous() - torch.testing.assert_close(a_prev + 1, new_a) + # test set + val1 = {"subkey1": torch.ones(4, 3, 2, 1, 10)} + td.set("key1", val1) + assert (td.get("key1").get("subkey1") == 1).all() + with pytest.raises(RuntimeError): + td.set("key1", torch.ones(5, 10)) - for i in range(2): - parents[i].send(("close", None)) - procs[i].join() + # test set_ + val2 = {"subkey1": torch.zeros(4, 3, 2, 1, 10)} + if td_name in ("td_params",): + td.data.set_("key1", val2) + else: + td.set_("key1", val2) + assert (td.get("key1").get("subkey1") == 0).all() - @pytest.mark.parametrize( - "td_type", - [ - "memmap", - "memmap_stack", - "contiguous", - "stack", - ], - ) - def test_mp(self, td_type): - tensordict = TensorDict( - source={ - "a": torch.randn(2, 2), - "b": torch.randn(2, 1), - "done": torch.zeros(2, 1, dtype=torch.bool), - }, - batch_size=[2], - 
) - if td_type == "contiguous": - tensordict = tensordict.share_memory_() - elif td_type == "stack": - tensordict = stack_td( - [ - tensordict[0].clone().share_memory_(), - tensordict[1].clone().share_memory_(), - ], - 0, - ) - elif td_type == "memmap": - tensordict = tensordict.memmap_() - elif td_type == "memmap_stack": - tensordict = stack_td( - [ - tensordict[0].clone().memmap_(), - tensordict[1].clone().memmap_(), - ], - 0, - ) + if td_name not in ("stacked_td", "nested_stacked_td"): + err_msg = r"key.*smartypants.*not found in " + elif td_name in ("td_h5",): + err_msg = "Unable to open object" else: - raise NotImplementedError - self._driver_func( - tensordict, - (tensordict._get_sub_tensordict(0), tensordict._get_sub_tensordict(1)) - # tensordict, - # tensordict.unbind(0), - ) + err_msg = "setting a value in-place on a stack of TensorDict" + with pytest.raises(KeyError, match=err_msg): + td.set_("smartypants", np.ones(shape=(4, 3, 2, 1, 5))) -@pytest.mark.parametrize( - "idx", - [ - (slice(None),), - slice(None), - (3, 4), - (3, slice(None), slice(2, 2, 2)), - (torch.tensor([1, 2, 3]),), - ([1, 2, 3]), - ( - torch.tensor([1, 2, 3]), - torch.tensor([2, 3, 4]), - torch.tensor([0, 10, 2]), - torch.tensor([2, 4, 1]), - ), - torch.zeros(10, 7, 11, 5, dtype=torch.bool).bernoulli_(), - torch.zeros(10, 7, 11, dtype=torch.bool).bernoulli_(), - (0, torch.zeros(7, dtype=torch.bool).bernoulli_()), - ], -) -def test_getitem_batch_size(idx): - shape = [10, 7, 11, 5] - shape = torch.Size(shape) - mocking_tensor = torch.zeros(*shape) - expected_shape = mocking_tensor[idx].shape - resulting_shape = _getitem_batch_size(shape, idx) - assert expected_shape == resulting_shape, (idx, expected_shape, resulting_shape) - - -@pytest.mark.parametrize("device", get_available_devices()) -def test_requires_grad(device): - torch.manual_seed(1) - # Just one of the tensors have requires_grad - tensordicts = [ - TensorDict( - batch_size=[11, 12], - source={ - "key1": torch.randn( - 11, 12, 5, device=device, requires_grad=True if i == 5 else False - ), - "key2": torch.zeros( - 11, 12, 50, device=device, dtype=torch.bool - ).bernoulli_(), - }, - ) - for i in range(10) - ] - stacked_td = LazyStackedTensorDict(*tensordicts, stack_dim=0) - # First stacked tensor has requires_grad == True - assert list(stacked_td.values())[0].requires_grad is True + def test_to_dict_nested(self, td_name, device): + def recursive_checker(cur_dict): + for _, value in cur_dict.items(): + if is_tensor_collection(value): + return False + elif isinstance(value, dict) and not recursive_checker(value): + return False + return True + td = getattr(self, td_name)(device) + td.unlock_() -@pytest.mark.parametrize("device", get_available_devices()) -@pytest.mark.parametrize( - "td_type", ["tensordict", "view", "unsqueeze", "squeeze", "stack"] -) -@pytest.mark.parametrize("update", [True, False]) -def test_filling_empty_tensordict(device, td_type, update): - if td_type == "tensordict": - td = TensorDict({}, batch_size=[16], device=device) - elif td_type == "view": - td = TensorDict({}, batch_size=[4, 4], device=device).view(-1) - elif td_type == "unsqueeze": - td = TensorDict({}, batch_size=[16], device=device).unsqueeze(-1) - elif td_type == "squeeze": - td = TensorDict({}, batch_size=[16, 1], device=device).squeeze(-1) - elif td_type == "stack": - td = torch.stack([TensorDict({}, [], device=device) for _ in range(16)], 0) - else: - raise NotImplementedError - - for i in range(16): - other_td = TensorDict({"a": torch.randn(10), "b": torch.ones(1)}, []) 
- if td_type == "unsqueeze": - other_td = other_td.unsqueeze(-1).to_tensordict() - if update: - subtd = td._get_sub_tensordict(i) - subtd.update(other_td, inplace=True) - else: - td[i] = other_td - - assert td.device == device - assert td.get("a").device == device - assert (td.get("b") == 1).all() - if td_type == "view": - assert td._source["a"].shape == torch.Size([4, 4, 10]) - elif td_type == "unsqueeze": - assert td._source["a"].shape == torch.Size([16, 10]) - elif td_type == "squeeze": - assert td._source["a"].shape == torch.Size([16, 1, 10]) - elif td_type == "stack": - assert (td[-1] == other_td.to(device)).all() - - -def test_getitem_nested(): - tensor = torch.randn(4, 5, 6, 7) - sub_sub_tensordict = TensorDict({"c": tensor}, [4, 5, 6]) - sub_tensordict = TensorDict({}, [4, 5]) - tensordict = TensorDict({}, [4]) - - sub_tensordict["b"] = sub_sub_tensordict - tensordict["a"] = sub_tensordict - - # check that content match - assert (tensordict["a"] == sub_tensordict).all() - assert (tensordict["a", "b"] == sub_sub_tensordict).all() - assert (tensordict["a", "b", "c"] == tensor).all() - - # check that get method returns same contents - assert (tensordict.get("a") == sub_tensordict).all() - assert (tensordict.get(("a", "b")) == sub_sub_tensordict).all() - assert (tensordict.get(("a", "b", "c")) == tensor).all() - - # check that shapes are kept - assert tensordict.shape == torch.Size([4]) - assert sub_tensordict.shape == torch.Size([4, 5]) - assert sub_sub_tensordict.shape == torch.Size([4, 5, 6]) - - -def test_setitem_nested(): - tensor = torch.randn(4, 5, 6, 7) - tensor2 = torch.ones(4, 5, 6, 7) - tensordict = TensorDict({}, [4]) - sub_tensordict = TensorDict({}, [4, 5]) - sub_sub_tensordict = TensorDict({"c": tensor}, [4, 5, 6]) - sub_sub_tensordict2 = TensorDict({"c": tensor2}, [4, 5, 6]) - sub_tensordict["b"] = sub_sub_tensordict - tensordict["a"] = sub_tensordict - assert tensordict["a", "b"] is sub_sub_tensordict - tensordict["a", "b"] = sub_sub_tensordict2 - assert tensordict["a", "b"] is sub_sub_tensordict2 - assert (tensordict["a", "b", "c"] == 1).all() - - # check the same with set method - sub_tensordict.set("b", sub_sub_tensordict) - tensordict.set("a", sub_tensordict) - assert tensordict["a", "b"] is sub_sub_tensordict - - tensordict.set(("a", "b"), sub_sub_tensordict2) - assert tensordict["a", "b"] is sub_sub_tensordict2 - assert (tensordict["a", "b", "c"] == 1).all() - - -def test_setdefault_nested(): - tensor = torch.randn(4, 5, 6, 7) - tensor2 = torch.ones(4, 5, 6, 7) - sub_sub_tensordict = TensorDict({"c": tensor}, [4, 5, 6]) - sub_tensordict = TensorDict({"b": sub_sub_tensordict}, [4, 5]) - tensordict = TensorDict({"a": sub_tensordict}, [4]) - - # if key exists we return the existing value - assert tensordict.setdefault(("a", "b", "c"), tensor2) is tensor - - assert tensordict.setdefault(("a", "b", "d"), tensor2) is tensor2 - assert (tensordict["a", "b", "d"] == 1).all() - assert tensordict.get(("a", "b", "d")) is tensor2 - - -@pytest.mark.parametrize("inplace", [True, False]) -def test_select_nested(inplace): - tensor_1 = torch.rand(4, 5, 6, 7) - tensor_2 = torch.rand(4, 5, 6, 7) - sub_sub_tensordict = TensorDict( - {"t1": tensor_1, "t2": tensor_2}, batch_size=[4, 5, 6] - ) - sub_tensordict = TensorDict( - {"double_nested": sub_sub_tensordict}, batch_size=[4, 5] - ) - tensordict = TensorDict( - { - "a": torch.rand(4, 3), - "b": torch.rand(4, 2), - "c": torch.rand(4, 1), - "nested": sub_tensordict, - }, - batch_size=[4], - ) + # Create nested TensorDict + 
nested_tensordict_value = TensorDict( + {"e": torch.randn(4, 3, 2, 1, 10)}, batch_size=td.batch_size, device=device + ) + td["d"] = nested_tensordict_value - selected = tensordict.select( - "b", ("nested", "double_nested", "t2"), inplace=inplace - ) + # Convert into dictionary and recursively check if the values are TensorDicts + td_dict = td.to_dict() + assert recursive_checker(td_dict) + if td_name == "td_with_non_tensor": + assert td_dict["data"]["non_tensor"] == "some text data" + assert (TensorDict.from_dict(td_dict) == td).all() - assert set(selected.keys(include_nested=True)) == { - "b", - "nested", - ("nested", "double_nested"), - ("nested", "double_nested", "t2"), - } + def test_to_tensordict(self, td_name, device): + torch.manual_seed(1) + td = getattr(self, td_name)(device) + td2 = td.to_tensordict() + assert (td2 == td).all() - if inplace: - assert selected is tensordict - assert set(tensordict.keys(include_nested=True)) == { - "b", - "nested", - ("nested", "double_nested"), - ("nested", "double_nested", "t2"), - } - else: - assert selected is not tensordict - assert set(tensordict.keys(include_nested=True)) == { - "a", - "b", - "c", - "nested", - ("nested", "double_nested"), - ("nested", "double_nested", "t1"), - ("nested", "double_nested", "t2"), - } + def test_transpose(self, td_name, device): + td = getattr(self, td_name)(device) + tdt = td.transpose(0, 1) + assert tdt.shape == torch.Size([td.shape[1], td.shape[0], *td.shape[2:]]) + for key, value in tdt.items(True): + assert value.shape == torch.Size( + [td.get(key).shape[1], td.get(key).shape[0], *td.get(key).shape[2:]] + ) + tdt = td.transpose(-1, -2) + for key, value in tdt.items(True): + assert value.shape == td.get(key).transpose(2, 3).shape + if td_name in ("td_params",): + assert tdt.transpose(-1, -2)._param_td is td._param_td + else: + assert tdt.transpose(-1, -2) is td + with td.unlock_(): + tdt.set(("some", "transposed", "tensor"), torch.zeros(tdt.shape)) + assert td.get(("some", "transposed", "tensor")).shape == td.shape + if td_name in ("td_params",): + assert td.transpose(0, 0)._param_td is td._param_td + else: + assert td.transpose(0, 0) is td + with pytest.raises( + ValueError, match="The provided dimensions are incompatible" + ): + td.transpose(-5, -6) + with pytest.raises( + ValueError, match="The provided dimensions are incompatible" + ): + tdt.transpose(-5, -6) + @pytest.mark.parametrize("dim", range(4)) + def test_unbind(self, td_name, device, dim): + if td_name not in ["sub_td", "idx_td", "td_reset_bs"]: + torch.manual_seed(1) + td = getattr(self, td_name)(device) + td_unbind = torch.unbind(td, dim=dim) + assert (td == stack_td(td_unbind, dim).contiguous()).all() + idx = (slice(None),) * dim + (0,) + assert (td[idx] == td_unbind[0]).all() -def test_select_nested_missing(): - # checks that we keep a nested key even if missing nested keys are present - td = TensorDict({"a": {"b": [1], "c": [2]}}, []) + @pytest.mark.parametrize("inplace", [True, False]) + @pytest.mark.parametrize("separator", [",", "-"]) + def test_unflatten_keys(self, td_name, device, inplace, separator): + td = getattr(self, td_name)(device) + locked = td.is_locked + td.unlock_() + nested_nested_tensordict = TensorDict( + { + "a": torch.zeros(*td.shape, 2, 3), + }, + [*td.shape, 2], + ) + nested_tensordict = TensorDict( + { + "a": torch.zeros(*td.shape, 2), + "nested_nested_tensordict": nested_nested_tensordict, + }, + td.shape, + ) + td["nested_tensordict"] = nested_tensordict - td_select = td.select(("a", "b"), "r", ("a", "z"), 
-    assert ("a", "b") in list(td_select.keys(True, True))
-    assert ("a", "b") in td_select.keys(True, True)
+        if inplace and locked:
+            td_flatten = td.flatten_keys(inplace=inplace, separator=separator)
+            td_flatten.lock_()
+            with pytest.raises(RuntimeError, match="Cannot modify locked TensorDict"):
+                td_unflatten = td_flatten.unflatten_keys(
+                    inplace=inplace, separator=separator
+                )
+            return
+        else:
+            if locked:
+                td.lock_()
+            if td_name in ("td_h5",) and inplace:
+                with pytest.raises(
+                    ValueError,
+                    match="Cannot call flatten_keys in_place with a PersistentTensorDict",
+                ):
+                    td_flatten = td.flatten_keys(inplace=inplace, separator=separator)
+                return
+            td_flatten = td.flatten_keys(inplace=inplace, separator=separator)
+            td_unflatten = td_flatten.unflatten_keys(
+                inplace=inplace, separator=separator
+            )
+        assert (td == td_unflatten).all()
+        if inplace:
+            assert td is td_unflatten
+
+    def test_unlock(self, td_name, device):
+        torch.manual_seed(1)
+        td = getattr(self, td_name)(device)
+        td.unlock_()
+        assert not td.is_locked
+        if td.device is not None:
+            assert td.device.type == "cuda" or not td.is_shared()
+        else:
+            assert not td.is_shared()
+        assert not td.is_memmap()
-
-@pytest.mark.parametrize("inplace", [True, False])
-def test_exclude_nested(inplace):
-    tensor_1 = torch.rand(4, 5, 6, 7)
-    tensor_2 = torch.rand(4, 5, 6, 7)
-    sub_sub_tensordict = TensorDict(
-        {"t1": tensor_1, "t2": tensor_2}, batch_size=[4, 5, 6]
-    )
-    sub_tensordict = TensorDict(
-        {"double_nested": sub_sub_tensordict}, batch_size=[4, 5]
-    )
-    tensordict = TensorDict(
-        {
-            "a": torch.rand(4, 3),
-            "b": torch.rand(4, 2),
-            "c": torch.rand(4, 1),
-            "nested": sub_tensordict,
-        },
-        batch_size=[4],
-    )
-    # making a copy for inplace tests
-    tensordict2 = tensordict.clone()
+    @pytest.mark.parametrize("squeeze_dim", [0, 1])
+    def test_unsqueeze(self, td_name, device, squeeze_dim):
+        torch.manual_seed(1)
+        td = getattr(self, td_name)(device)
+        with td.unlock_():  # make sure that the td is not locked
+            td_unsqueeze = torch.unsqueeze(td, dim=squeeze_dim)
+            tensor = torch.ones_like(td.get("a").unsqueeze(squeeze_dim))
+            if td_name in ("sub_td", "sub_td2"):
+                td_unsqueeze.set_("a", tensor)
+            else:
+                td_unsqueeze.set("a", tensor)
+        assert (td_unsqueeze.get("a") == tensor).all()
+        assert (td.get("a") == tensor.squeeze(squeeze_dim)).all()
+        # the tensors should match
+        assert _compare_tensors_identity(td_unsqueeze.squeeze(squeeze_dim), td)
+        assert (td_unsqueeze.get("a") == 1).all()
+        assert (td.get("a") == 1).all()
-
-    excluded = tensordict.exclude(
-        "b", ("nested", "double_nested", "t2"), inplace=inplace
-    )
-
-    assert set(excluded.keys(include_nested=True)) == {
-        "a",
-        "c",
-        "nested",
-        ("nested", "double_nested"),
-        ("nested", "double_nested", "t1"),
-    }
-
-    if inplace:
-        assert excluded is tensordict
-        assert set(tensordict.keys(include_nested=True)) == {
-            "a",
-            "c",
-            "nested",
-            ("nested", "double_nested"),
-            ("nested", "double_nested", "t1"),
-        }
-    else:
-        assert excluded is not tensordict
-        assert set(tensordict.keys(include_nested=True)) == {
-            "a",
-            "b",
-            "c",
-            "nested",
-            ("nested", "double_nested"),
-            ("nested", "double_nested", "t1"),
-            ("nested", "double_nested", "t2"),
-        }
-
-    # excluding "nested" should exclude all subkeys also
-    excluded2 = tensordict2.exclude("nested", inplace=inplace)
-    assert set(excluded2.keys(include_nested=True)) == {"a", "b", "c"}
+    @pytest.mark.parametrize("clone", [True, False])
+    def test_update(self, td_name, device, clone):
+        td = getattr(self, td_name)(device)
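+        # update() accepts flat string keys, tuple (nested) keys and nested dicts;
+        # each spelling is exercised below.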
+        td.unlock_()  # make sure that the td is not locked
+        keys = set(td.keys())
+        td.update({"x": torch.zeros(td.shape)}, clone=clone)
+        assert set(td.keys()) == keys.union({"x"})
+        # now with nested: using tuples for keys
+        td.update({("somenested", "z"): torch.zeros(td.shape)})
+        assert td["somenested"].shape == td.shape
+        assert td["somenested", "z"].shape == td.shape
+        td.update({("somenested", "zz"): torch.zeros(td.shape)})
+        assert td["somenested"].shape == td.shape
+        assert td["somenested", "zz"].shape == td.shape
+        # now with nested: using nested dicts
+        td["newnested"] = {"z": torch.zeros(td.shape)}
+        keys = set(td.keys(True))
+        assert ("newnested", "z") in keys
+        td.update({"newnested": {"y": torch.zeros(td.shape)}}, clone=clone)
+        keys = keys.union({("newnested", "y")})
+        assert keys == set(td.keys(True))
+        td.update(
+            {
+                ("newnested", "x"): torch.zeros(td.shape),
+                ("newnested", "w"): torch.zeros(td.shape),
+            },
+            clone=clone,
+        )
+        keys = keys.union({("newnested", "x"), ("newnested", "w")})
+        assert keys == set(td.keys(True))
+        td.update({("newnested",): {"v": torch.zeros(td.shape)}}, clone=clone)
+        keys = keys.union(
+            {
+                ("newnested", "v"),
+            }
+        )
+        assert keys == set(td.keys(True))
+        if td_name in ("sub_td", "sub_td2"):
+            with pytest.raises(ValueError, match="Tried to replace a tensordict with"):
+                td.update({"newnested": torch.zeros(td.shape)}, clone=clone)
+        else:
+            td.update({"newnested": torch.zeros(td.shape)}, clone=clone)
+            assert isinstance(td["newnested"], torch.Tensor)
-
-def test_set_nested_keys():
-    tensor = torch.randn(4, 5, 6, 7)
-    tensor2 = torch.ones(4, 5, 6, 7)
-    tensordict = TensorDict({}, [4])
-    sub_tensordict = TensorDict({}, [4, 5])
-    sub_sub_tensordict = TensorDict({"c": tensor}, [4, 5, 6])
-    sub_sub_tensordict2 = TensorDict({"c": tensor2}, [4, 5, 6])
-    sub_tensordict.set("b", sub_sub_tensordict)
-    tensordict.set("a", sub_tensordict)
-    assert tensordict.get(("a", "b")) is sub_sub_tensordict
+    def test_update_at_(self, td_name, device):
+        td = getattr(self, td_name)(device)
+        td0 = td[1].clone().zero_()
+        td.update_at_(td0, 0)
+        assert (td[0] == 0).all()
-
-    tensordict.set(("a", "b"), sub_sub_tensordict2)
-    assert tensordict.get(("a", "b")) is sub_sub_tensordict2
-    assert (tensordict.get(("a", "b", "c")) == 1).all()
+    def test_update_select(self, td_name, device):
+        if td_name in ("memmap_td",):
+            pytest.skip(reason="update not possible with memory-mapped td")
+        td = getattr(self, td_name)(device)
+        t = lambda: torch.zeros(()).expand((4, 3, 2, 1))
+        other_td = TensorDict(
+            {
+                "My": {"father": {"was": t(), "a": t()}, "relentlessly": t()},
+                "self-improving": t(),
+            },
+            batch_size=(4, 3, 2, 1),
+        )
+        td.update(
+            other_td,
+            keys_to_update=(("My", ("father",), "was"), ("My", "relentlessly")),
+        )
+        assert ("My", "father", "was") in td.keys(True)
+        assert ("My", ("father",), "was") in td.keys(True)
+        assert ("My", "relentlessly") in td.keys(True)
+        assert ("My", "father", "a") in td.keys(True)
+        assert ("self-improving",) not in td.keys(True)
+        t = lambda: torch.ones(()).expand((4, 3, 2, 1))
+        other_td = TensorDict(
+            {
+                "My": {"father": {"was": t(), "a": t()}, "relentlessly": t()},
+                "self-improving": t(),
+            },
+            batch_size=(4, 3, 2, 1),
+        )
+        td.update(other_td, keys_to_update=(("My", "relentlessly"),))
+        assert (td["My", "relentlessly"] == 1).all()
+        assert (td["My", "father", "was"] == 0).all()
+        td.update(other_td, keys_to_update=(("My", ("father",), "was"),))
+        assert (td["My", "father", "was"] == 1).all()
+
+    @pytest.mark.parametrize(
["tensor1", "mask", "int", "range", "tensor2", "slice_tensor"] + ) + def test_update_subtensordict(self, td_name, device, index): + td = getattr(self, td_name)(device) + if index == "mask": + index = torch.zeros(td.shape[0], dtype=torch.bool, device=device) + index[-1] = 1 + elif index == "int": + index = td.shape[0] - 1 + elif index == "range": + index = range(td.shape[0] - 1, td.shape[0]) + elif index == "tensor1": + index = torch.tensor(td.shape[0] - 1, device=device) + elif index == "tensor2": + index = torch.tensor([td.shape[0] - 2, td.shape[0] - 1], device=device) + elif index == "slice_tensor": + index = ( + slice(None), + torch.tensor([td.shape[1] - 2, td.shape[1] - 1], device=device), + ) -def test_keys_view(): - tensor = torch.randn(4, 5, 6, 7) - sub_sub_tensordict = TensorDict({"c": tensor}, [4, 5, 6]) - sub_tensordict = TensorDict({}, [4, 5]) - tensordict = TensorDict({}, [4]) + sub_td = td._get_sub_tensordict(index) + assert sub_td.shape == td.to_tensordict()[index].shape + assert sub_td.shape == td[index].shape, (td, index) + td0 = td[index] + td0 = td0.to_tensordict() + td0 = td0.apply(lambda x: x * 0 + 2) + assert sub_td.shape == td0.shape + if td_name == "td_params": + with pytest.raises(RuntimeError, match="a leaf Variable"): + sub_td.update(td0) + return + sub_td.update(td0) + assert (sub_td == 2).all() + assert (td[index] == 2).all() - sub_tensordict["b"] = sub_sub_tensordict - tensordict["a"] = sub_tensordict + def test_view(self, td_name, device): + if td_name in ("permute_td", "sub_td2"): + pytest.skip("view incompatible with stride / permutation") + torch.manual_seed(1) + td = getattr(self, td_name)(device) + with td.unlock_(): # make sure that the td is not locked + td_view = td.view(-1) + tensor = td.get("a") + tensor = tensor.view(-1, tensor.numel() // prod(td.batch_size)) + tensor = torch.ones_like(tensor) + if td_name == "sub_td": + td_view.set_("a", tensor) + else: + td_view.set("a", tensor) + assert (td_view.get("a") == tensor).all() + assert (td.get("a") == tensor.view(td.get("a").shape)).all() + if td_name in ("td_params",): + assert td_view.view(td.shape)._param_td is td._param_td + assert td_view.view(*td.shape)._param_td is td._param_td + else: + assert td_view.view(td.shape) is td + assert td_view.view(*td.shape) is td + assert (td_view.get("a") == 1).all() + assert (td.get("a") == 1).all() - assert "a" in tensordict.keys() - assert "random_string" not in tensordict.keys() + def test_where(self, td_name, device): + torch.manual_seed(1) + td = getattr(self, td_name)(device) + mask = torch.zeros(td.shape, dtype=torch.bool, device=device).bernoulli_() + td_where = torch.where(mask, td, 0) + for k in td.keys(True, True): + assert (td_where.get(k)[~mask] == 0).all() + td_where = torch.where(mask, td, torch.ones_like(td)) + for k in td.keys(True, True): + assert (td_where.get(k)[~mask] == 1).all() + td_where = td.clone() - assert ("a",) in tensordict.keys(include_nested=True) - assert ("a", "b", "c") in tensordict.keys(include_nested=True) - assert ("a", "c", "b") not in tensordict.keys(include_nested=True) + if td_name == "td_h5": + with pytest.raises( + RuntimeError, + match="Cannot use a persistent tensordict as output of torch.where", + ): + torch.where(mask, td, torch.ones_like(td), out=td_where) + return + torch.where(mask, td, torch.ones_like(td), out=td_where) + for k in td.keys(True, True): + assert (td_where.get(k)[~mask] == 1).all() - with pytest.raises( - TypeError, match="checks with tuples of strings is only supported" - ): - ("a", "b", "c") in 
+    def test_where_pad(self, td_name, device):
+        torch.manual_seed(1)
+        td = getattr(self, td_name)(device)
+        # test with other empty td
+        mask = torch.zeros(td.shape, dtype=torch.bool, device=td.device).bernoulli_()
+        if td_name in ("td_h5",):
+            td_full = td.to_tensordict()
+        else:
+            td_full = td
+        td_empty = td_full.empty()
+        result = td.where(mask, td_empty, pad=1)
+        for v in result.values(True, True):
+            assert (v[~mask] == 1).all()
+        td_empty = td_full.empty()
+        result = td_empty.where(~mask, td, pad=1)
+        for v in result.values(True, True):
+            assert (v[~mask] == 1).all()
+        # with output
+        td_out = td_full.empty()
+        result = td.where(mask, td_empty, pad=1, out=td_out)
+        for v in result.values(True, True):
+            assert (v[~mask] == 1).all()
+        if td_name not in ("td_params",):
+            assert result is td_out
+        # TODO: decide if we want where to return a TensorDictParams.
+        # probably not, given
+        # else:
+        #     assert isinstance(result, TensorDictParams)
+        td_out = td_full.empty()
+        td_empty = td_full.empty()
+        result = td_empty.where(~mask, td, pad=1, out=td_out)
+        for v in result.values(True, True):
+            assert (v[~mask] == 1).all()
+        assert result is td_out
-
-    with pytest.raises(TypeError, match="TensorDict keys are always strings."):
-        42 in tensordict.keys()  # noqa: B015
+        with pytest.raises(KeyError, match="not found and no pad value provided"):
+            td.where(mask, td_full.empty())
+        with pytest.raises(KeyError, match="not found and no pad value provided"):
+            td_full.empty().where(mask, td)
-
-    with pytest.raises(TypeError, match="TensorDict keys are always strings."):
-        ("a", 42) in tensordict.keys()  # noqa: B015
+    def test_write_on_subtd(self, td_name, device):
+        td = getattr(self, td_name)(device)
+        sub_td = td._get_sub_tensordict(0)
+        # should not work with td_params
+        if td_name == "td_params":
+            with pytest.raises(RuntimeError, match="a view of a leaf"):
+                sub_td["a"] = torch.full((3, 2, 1, 5), 1.0, device=device)
+            return
+        sub_td["a"] = torch.full((3, 2, 1, 5), 1.0, device=device)
+        assert (td["a"][0] == 1).all()
-
-    keys = set(tensordict.keys())
-    keys_nested = set(tensordict.keys(include_nested=True))
+    def test_zero_(self, td_name, device):
+        torch.manual_seed(1)
+        td = getattr(self, td_name)(device)
+        new_td = td.zero_()
+        assert new_td is td
+        for k in td.keys():
+            assert (td.get(k) == 0).all()
-
-    assert keys == {"a"}
-    assert keys_nested == {"a", ("a", "b"), ("a", "b", "c")}
-    leaves = set(tensordict.keys(leaves_only=True))
-    leaves_nested = set(tensordict.keys(include_nested=True, leaves_only=True))
+
+@pytest.mark.parametrize("device", [None, *get_available_devices()])
+@pytest.mark.parametrize("dtype", [torch.float32, torch.uint8])
+class TestTensorDictRepr:
+    def memmap_td(self, device, dtype):
+        if device is not None and device.type != "cpu":
+            pytest.skip("MemoryMappedTensors can only be placed on CPU.")
+        return self.td(device, dtype).memmap_()
-
-    assert leaves == set()
-    assert leaves_nested == {("a", "b", "c")}
+    def nested_td(self, device, dtype):
+        if device is not None:
+            device_not_none = device
+        elif torch.has_cuda and torch.cuda.device_count():
+            device_not_none = torch.device("cuda:0")
+        else:
+            device_not_none = torch.device("cpu")
+        return TensorDict(
+            source={
+                "my_nested_td": self.td(device, dtype),
+                "b": torch.zeros(4, 3, 2, 1, 5, dtype=dtype, device=device_not_none),
+            },
+            batch_size=[4, 3, 2, 1],
+            device=device,
+        )
+
+    def nested_tensorclass(self, device, dtype):
+        from tensordict import tensorclass
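+        # build a tensorclass that nests another tensorclass instance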
-
-def test_error_on_contains():
-    td = TensorDict(
-        {"a": TensorDict({"b": torch.rand(1, 2)}, [1, 2]), "c": torch.rand(1)}, [1]
-    )
-    with pytest.raises(
-        NotImplementedError,
-        match="TensorDict does not support membership checks with the `in` keyword",
-    ):
-        "random_string" in td  # noqa: B015
+        @tensorclass
+        class MyClass:
+            X: torch.Tensor
+            y: "MyClass"
+            z: str
+
+        if device is not None:
+            device_not_none = device
+        elif torch.has_cuda and torch.cuda.device_count():
+            device_not_none = torch.device("cuda:0")
+        else:
+            device_not_none = torch.device("cpu")
+        nested_class = MyClass(
+            X=torch.zeros(4, 3, 2, 1, dtype=dtype, device=device_not_none),
+            y=MyClass(
+                X=torch.zeros(4, 3, 2, 1, dtype=dtype, device=device_not_none),
+                y=None,
+                z=None,
+                batch_size=[4, 3, 2, 1],
+            ),
+            z="z",
+            batch_size=[4, 3, 2, 1],
+        )
+        return TensorDict(
+            source={
+                "my_nested_td": nested_class,
+                "b": torch.zeros(4, 3, 2, 1, 5, dtype=dtype, device=device_not_none),
+            },
+            batch_size=[4, 3, 2, 1],
+            device=device,
+        )
-
-@pytest.mark.parametrize("method", ["share_memory", "memmap"])
-def test_memory_lock(method):
-    torch.manual_seed(1)
-    td = TensorDict({"a": torch.randn(4, 5)}, batch_size=(4, 5))
+    def share_memory_td(self, device, dtype):
+        return self.td(device, dtype).share_memory_()
-
-    # lock=True
-    if method == "share_memory":
-        td.share_memory_()
-    elif method == "memmap":
-        td.memmap_()
-    else:
-        raise NotImplementedError
-
-    td.set("a", torch.randn(4, 5), inplace=True)
-    td.set_("a", torch.randn(4, 5))  # No exception because set_ ignores the lock
-
-    with pytest.raises(RuntimeError, match="Cannot modify locked TensorDict"):
-        td.set("a", torch.randn(4, 5))
-
-    with pytest.raises(RuntimeError, match="Cannot modify locked TensorDict"):
-        td.set("b", torch.randn(4, 5))
-
-    with pytest.raises(RuntimeError, match="Cannot modify locked TensorDict"):
-        td.set("b", torch.randn(4, 5), inplace=True)
-
-
-class TestMakeTensorDict:
-    def test_create_tensordict(self):
-        tensordict = make_tensordict(a=torch.zeros(3, 4))
-        assert (tensordict["a"] == torch.zeros(3, 4)).all()
-
-    def test_tensordict_batch_size(self):
-        tensordict = make_tensordict()
-        assert tensordict.batch_size == torch.Size([])
-
-        tensordict = make_tensordict(a=torch.randn(3, 4))
-        assert tensordict.batch_size == torch.Size([3, 4])
-
-        tensordict = make_tensordict(a=torch.randn(3, 4), b=torch.randn(3, 4, 5))
-        assert tensordict.batch_size == torch.Size([3, 4])
-
-        nested_tensordict = make_tensordict(c=tensordict, d=torch.randn(3, 5))  # nested
-        assert nested_tensordict.batch_size == torch.Size([3])
-
-        nested_tensordict = make_tensordict(c=tensordict, d=torch.randn(4, 5))  # nested
-        assert nested_tensordict.batch_size == torch.Size([])
-
-        tensordict = make_tensordict(a=torch.randn(3, 4, 2), b=torch.randn(3, 4, 5))
-        assert tensordict.batch_size == torch.Size([3, 4])
-
-        tensordict = make_tensordict(a=torch.randn(3, 4), b=torch.randn(1))
-        assert tensordict.batch_size == torch.Size([])
-
-        tensordict = make_tensordict(
-            a=torch.randn(3, 4), b=torch.randn(3, 4, 5), batch_size=[3]
-        )
-        assert tensordict.batch_size == torch.Size([3])
-
-        tensordict = make_tensordict(
-            a=torch.randn(3, 4), b=torch.randn(3, 4, 5), batch_size=[]
-        )
-        assert tensordict.batch_size == torch.Size([])
-
-    @pytest.mark.parametrize("device", get_available_devices())
-    def test_tensordict_device(self, device):
-        tensordict = make_tensordict(
-            a=torch.randn(3, 4), b=torch.randn(3, 4), device=device
-        )
-        assert tensordict.device == device
-        assert tensordict["a"].device == device
tensordict["b"].device == device - - tensordict = make_tensordict( - a=torch.randn(3, 4, device=device), - b=torch.randn(3, 4), - c=torch.randn(3, 4, device="cpu"), + def stacked_td(self, device, dtype): + if device is not None: + device_not_none = device + elif torch.has_cuda and torch.cuda.device_count(): + device_not_none = torch.device("cuda:0") + else: + device_not_none = torch.device("cpu") + td1 = TensorDict( + source={ + "a": torch.zeros(4, 3, 1, 5, dtype=dtype, device=device_not_none), + "c": torch.zeros(4, 3, 1, 5, dtype=dtype, device=device_not_none), + }, + batch_size=[4, 3, 1], device=device, ) - assert tensordict.device == device - assert tensordict["a"].device == device - assert tensordict["b"].device == device - assert tensordict["c"].device == device - - def test_nested(self): - input_dict = { - "a": {"b": torch.randn(3, 4), "c": torch.randn(3, 4, 5)}, - "d": torch.randn(3), - } - tensordict = make_tensordict(input_dict) - assert tensordict.shape == torch.Size([3]) - assert tensordict["a"].shape == torch.Size([3, 4]) - input_tensordict = TensorDict( - { - "a": {"b": torch.randn(3, 4), "c": torch.randn(3, 4, 5)}, - "d": torch.randn(3), + td2 = TensorDict( + source={ + "a": torch.zeros(4, 3, 1, 5, dtype=dtype, device=device_not_none), + "b": torch.zeros(4, 3, 1, 10, dtype=dtype, device=device_not_none), }, - [], + batch_size=[4, 3, 1], + device=device, ) - tensordict = make_tensordict(input_tensordict) - assert tensordict.shape == torch.Size([3]) - assert tensordict["a"].shape == torch.Size([3, 4]) - input_dict = { - ("a", "b"): torch.randn(3, 4), - ("a", "c"): torch.randn(3, 4, 5), - "d": torch.randn(3), - } - tensordict = make_tensordict(input_dict) - assert tensordict.shape == torch.Size([3]) - assert tensordict["a"].shape == torch.Size([3, 4]) - - -def test_update_nested_dict(): - t = TensorDict({"a": {"d": [[[0]] * 3] * 2}}, [2, 3]) - assert ("a", "d") in t.keys(include_nested=True) - t.update({"a": {"b": [[[1]] * 3] * 2}}) - assert ("a", "d") in t.keys(include_nested=True) - assert ("a", "b") in t.keys(include_nested=True) - assert t["a", "b"].shape == torch.Size([2, 3, 1]) - t.update({"a": {"d": [[[1]] * 3] * 2}}) - - -@pytest.mark.parametrize("inplace", [True, False]) -@pytest.mark.parametrize("separator", [",", "-"]) -def test_flatten_unflatten_key_collision(inplace, separator): - td1 = TensorDict( - { - f"a{separator}b{separator}c": torch.zeros(3), - "a": {"b": {"c": torch.zeros(3)}}, - }, - [], - ) - td2 = TensorDict( - { - f"a{separator}b": torch.zeros(3), - "a": {"b": torch.zeros(3)}, - "g": {"d": torch.zeros(3)}, - }, - [], - ) - td3 = TensorDict( - { - f"a{separator}b{separator}c": torch.zeros(3), - "a": {"b": {"c": torch.zeros(3), "d": torch.zeros(3)}}, - }, - [], - ) - - td4 = TensorDict( - { - f"a{separator}b{separator}c{separator}d": torch.zeros(3), - "a": {"b": {"c": torch.zeros(3)}}, - }, - [], - ) - - td5 = TensorDict( - {f"a{separator}b": torch.zeros(3), "a": {"b": {"c": torch.zeros(3)}}}, [] - ) - - with pytest.raises(KeyError, match="Flattening keys in tensordict causes keys"): - _ = td1.flatten_keys(separator) - - with pytest.raises(KeyError, match="Flattening keys in tensordict causes keys"): - _ = td2.flatten_keys(separator) - - with pytest.raises(KeyError, match="Flattening keys in tensordict causes keys"): - _ = td3.flatten_keys(separator) - - with pytest.raises( - KeyError, - match=re.escape( - "Unflattening key(s) in tensordict will override an existing for unflattened key" - ), - ): - _ = td1.unflatten_keys(separator) - - with pytest.raises( 
-        KeyError,
-        match=re.escape(
-            "Unflattening key(s) in tensordict will override an existing for unflattened key"
-        ),
-    ):
-        _ = td2.unflatten_keys(separator)
-
-    with pytest.raises(
-        KeyError,
-        match=re.escape(
-            "Unflattening key(s) in tensordict will override an existing for unflattened key"
-        ),
-    ):
-        _ = td3.unflatten_keys(separator)
-
-    with pytest.raises(
-        KeyError,
-        match=re.escape(
-            "Unflattening key(s) in tensordict will override an existing for unflattened key"
-        ),
-    ):
-        _ = td4.unflatten_keys(separator)
-
-    with pytest.raises(
-        KeyError,
-        match=re.escape(
-            "Unflattening key(s) in tensordict will override an existing for unflattened key"
-        ),
-    ):
-        _ = td5.unflatten_keys(separator)
-
-    td4_flat = td4.flatten_keys(separator)
-    assert (f"a{separator}b{separator}c{separator}d") in td4_flat.keys()
-    assert (f"a{separator}b{separator}c") in td4_flat.keys()
-
-    td5_flat = td5.flatten_keys(separator)
-    assert (f"a{separator}b") in td5_flat.keys()
-    assert (f"a{separator}b{separator}c") in td5_flat.keys()
-
-
-def test_split_with_invalid_arguments():
-    td = TensorDict({"a": torch.zeros(2, 1)}, [])
-    # Test empty batch size
-    with pytest.raises(IndexError, match="Dimension out of range"):
-        td.split(1, 0)
-
-    td = TensorDict({}, [3, 2])
-
-    # Test invalid split_size input
-    with pytest.raises(TypeError, match="must be int or list of ints"):
-        td.split("1", 0)
-    with pytest.raises(TypeError, match="must be int or list of ints"):
-        td.split(["1", 2], 0)
-    # Test invalid split_size sum
-    with pytest.raises(
-        RuntimeError, match="Insufficient number of elements in split_size"
-    ):
-        td.split([], 0)
-
-    with pytest.raises(RuntimeError, match="expects split_size to sum exactly"):
-        td.split([1, 1], 0)
-
-    # Test invalid dimension input
-    with pytest.raises(IndexError, match="Dimension out of range"):
-        td.split(1, 2)
-    with pytest.raises(IndexError, match="Dimension out of range"):
-        td.split(1, -3)
-
-
-def test_split_with_empty_tensordict():
-    td = TensorDict({}, [10])
-
-    tds = td.split(4, 0)
-    assert len(tds) == 3
-    assert tds[0].shape == torch.Size([4])
-    assert tds[1].shape == torch.Size([4])
-    assert tds[2].shape == torch.Size([2])
+        return stack_td([td1, td2], 2)
-
-    tds = td.split([1, 9], 0)
+    def td(self, device, dtype):
+        if device is not None:
+            device_not_none = device
+        elif torch.has_cuda and torch.cuda.device_count():
+            device_not_none = torch.device("cuda:0")
+        else:
+            device_not_none = torch.device("cpu")
-
-    assert len(tds) == 2
-    assert tds[0].shape == torch.Size([1])
-    assert tds[1].shape == torch.Size([9])
+        return TensorDict(
+            source={
+                "a": torch.zeros(4, 3, 2, 1, 5, dtype=dtype, device=device_not_none)
+            },
+            batch_size=[4, 3, 2, 1],
+            device=device,
+        )
-
-    td = TensorDict({}, [10, 10, 3])
+    @pytest.mark.skipif(not torch.cuda.device_count(), reason="no cuda")
+    def test_repr_batch_size_update(self, device, dtype):
+        td = self.td(device, dtype)
+        td.batch_size = torch.Size([4, 3, 2])
+        is_shared = False
+        tensor_class = "Tensor"
+        if device is not None and device.type == "cuda":
+            is_shared = True
+        tensor_device = device if device else td["a"].device
+        if tensor_device.type == "cuda":
+            is_shared_tensor = True
+        else:
+            is_shared_tensor = is_shared
+        expected = f"""TensorDict(
+    fields={{
+        a: {tensor_class}(shape=torch.Size([4, 3, 2, 1, 5]), device={tensor_device}, dtype={dtype}, is_shared={is_shared_tensor})}},
+    batch_size=torch.Size([4, 3, 2]),
+    device={device},
+    is_shared={is_shared})"""
+        assert repr(td) == expected
-
-    tds = td.split(4, 1)
-    assert len(tds) == 3
-    assert tds[0].shape == torch.Size([10, 4, 3])
-    assert tds[1].shape == torch.Size([10, 4, 3])
-    assert tds[2].shape == torch.Size([10, 2, 3])
+    @pytest.mark.skipif(not torch.cuda.device_count(), reason="no cuda")
+    @pytest.mark.parametrize("device_cast", get_available_devices())
+    def test_repr_device_to_device(self, device, dtype, device_cast):
+        td = self.td(device, dtype)
+        if (device_cast is None and (torch.cuda.device_count() > 0)) or (
+            device_cast is not None and device_cast.type == "cuda"
+        ):
+            is_shared = True
+        else:
+            is_shared = False
+        tensor_class = "Tensor"
+        td2 = td.to(device_cast)
+        tensor_device = device_cast if device_cast else td2["a"].device
+        if tensor_device.type == "cuda":
+            is_shared_tensor = True
+        else:
+            is_shared_tensor = is_shared
+        expected = f"""TensorDict(
+    fields={{
+        a: {tensor_class}(shape=torch.Size([4, 3, 2, 1, 5]), device={tensor_device}, dtype={dtype}, is_shared={is_shared_tensor})}},
+    batch_size=torch.Size([4, 3, 2, 1]),
+    device={str(device_cast)},
+    is_shared={is_shared})"""
+        assert repr(td2) == expected
-
-    tds = td.split([1, 9], 1)
-    assert len(tds) == 2
-    assert tds[0].shape == torch.Size([10, 1, 3])
-    assert tds[1].shape == torch.Size([10, 9, 3])
+    @pytest.mark.parametrize("index", [None, (slice(None), 0)])
+    def test_repr_indexed_nested_tensordict(self, device, dtype, index):
+        nested_tensordict = self.nested_td(device, dtype)[index]
+        if device is not None and device.type == "cuda":
+            is_shared = True
+        else:
+            is_shared = False
+        tensor_class = "Tensor"
+        tensor_device = device if device else nested_tensordict["b"].device
+        if tensor_device.type == "cuda":
+            is_shared_tensor = True
+        else:
+            is_shared_tensor = is_shared
+        if index is None:
+            expected = f"""TensorDict(
+    fields={{
+        b: {tensor_class}(shape=torch.Size([1, 4, 3, 2, 1, 5]), device={tensor_device}, dtype={dtype}, is_shared={is_shared_tensor}),
+        my_nested_td: TensorDict(
+            fields={{
+                a: {tensor_class}(shape=torch.Size([1, 4, 3, 2, 1, 5]), device={tensor_device}, dtype={dtype}, is_shared={is_shared_tensor})}},
+            batch_size=torch.Size([1, 4, 3, 2, 1]),
+            device={str(device)},
+            is_shared={is_shared})}},
+    batch_size=torch.Size([1, 4, 3, 2, 1]),
+    device={str(device)},
+    is_shared={is_shared})"""
+        else:
+            expected = f"""TensorDict(
+    fields={{
+        b: {tensor_class}(shape=torch.Size([4, 2, 1, 5]), device={tensor_device}, dtype={dtype}, is_shared={is_shared_tensor}),
+        my_nested_td: TensorDict(
+            fields={{
+                a: {tensor_class}(shape=torch.Size([4, 2, 1, 5]), device={tensor_device}, dtype={dtype}, is_shared={is_shared_tensor})}},
+            batch_size=torch.Size([4, 2, 1]),
+            device={str(device)},
+            is_shared={is_shared})}},
+    batch_size=torch.Size([4, 2, 1]),
+    device={str(device)},
+    is_shared={is_shared})"""
+        assert repr(nested_tensordict) == expected
+
+    @pytest.mark.parametrize("index", [None, (slice(None), 0)])
+    def test_repr_indexed_stacked_tensordict(self, device, dtype, index):
+        stacked_tensordict = self.stacked_td(device, dtype)
+        if device is not None and device.type == "cuda":
+            is_shared = True
+        else:
+            is_shared = False
+        tensor_class = "Tensor"
+        tensor_device = device if device else stacked_tensordict["a"].device
+        if tensor_device.type == "cuda":
+            is_shared_tensor = True
+        else:
+            is_shared_tensor = is_shared
-
-def test_split_with_negative_dim():
-    td = TensorDict({"a": torch.zeros(5, 4, 2, 1), "b": torch.zeros(5, 4, 1)}, [5, 4])
+        expected = f"""LazyStackedTensorDict(
+    fields={{
+        a: {tensor_class}(shape=torch.Size([4, 3, 2, 1, 5]), device={tensor_device}, dtype={dtype}, is_shared={is_shared_tensor})}},
+    exclusive_fields={{
+        0 ->
+            c: {tensor_class}(shape=torch.Size([4, 3, 1, 5]), device={tensor_device}, dtype={dtype}, is_shared={is_shared_tensor}),
+        1 ->
+            b: {tensor_class}(shape=torch.Size([4, 3, 1, 10]), device={tensor_device}, dtype={dtype}, is_shared={is_shared_tensor})}},
+    batch_size=torch.Size([4, 3, 2, 1]),
+    device={str(device)},
+    is_shared={is_shared},
+    stack_dim={stacked_tensordict.stack_dim})"""
-
-    tds = td.split([1, 3], -1)
-    assert len(tds) == 2
-    assert tds[0].shape == torch.Size([5, 1])
-    assert tds[0]["a"].shape == torch.Size([5, 1, 2, 1])
-    assert tds[0]["b"].shape == torch.Size([5, 1, 1])
-    assert tds[1].shape == torch.Size([5, 3])
-    assert tds[1]["a"].shape == torch.Size([5, 3, 2, 1])
-    assert tds[1]["b"].shape == torch.Size([5, 3, 1])
+        assert repr(stacked_tensordict) == expected
+
+    @pytest.mark.parametrize("index", [None, (slice(None), 0)])
+    def test_repr_indexed_tensordict(self, device, dtype, index):
+        tensordict = self.td(device, dtype)[index]
+        if device is not None and device.type == "cuda":
+            is_shared = True
+        else:
+            is_shared = False
+        tensor_class = "Tensor"
+        tensor_device = device if device else tensordict["a"].device
+        if tensor_device.type == "cuda":
+            is_shared_tensor = True
+        else:
+            is_shared_tensor = is_shared
+        if index is None:
+            expected = f"""TensorDict(
+    fields={{
+        a: {tensor_class}(shape=torch.Size([1, 4, 3, 2, 1, 5]), device={tensor_device}, dtype={dtype}, is_shared={is_shared_tensor})}},
+    batch_size=torch.Size([1, 4, 3, 2, 1]),
+    device={str(device)},
+    is_shared={is_shared})"""
+        else:
+            expected = f"""TensorDict(
+    fields={{
+        a: {tensor_class}(shape=torch.Size([4, 2, 1, 5]), device={tensor_device}, dtype={dtype}, is_shared={is_shared_tensor})}},
+    batch_size=torch.Size([4, 2, 1]),
+    device={str(device)},
+    is_shared={is_shared})"""
-
-def test_shared_inheritance():
-    td = TensorDict({"a": torch.randn(3, 4)}, [3, 4])
-    td.share_memory_()
+        assert repr(tensordict) == expected
-
-    td0, *_ = td.unbind(1)
-    assert td0.is_shared()
+    def test_repr_memmap(self, device, dtype):
+        tensordict = self.memmap_td(device, dtype)
+        # tensor_device = device if device else tensordict["a"].device  # noqa: F841
+        expected = f"""TensorDict(
+    fields={{
+        a: MemoryMappedTensor(shape=torch.Size([4, 3, 2, 1, 5]), device=cpu, dtype={dtype}, is_shared=False)}},
+    batch_size=torch.Size([4, 3, 2, 1]),
+    device=cpu,
+    is_shared=False)"""
+        assert repr(tensordict) == expected
-
-    td0, *_ = td.split(1, 0)
-    assert td0.is_shared()
+    def test_repr_nested(self, device, dtype):
+        nested_td = self.nested_td(device, dtype)
+        if device is not None and device.type == "cuda":
+            is_shared = True
+        else:
+            is_shared = False
+        tensor_class = "Tensor"
+        tensor_device = device if device else nested_td["b"].device
+        if tensor_device.type == "cuda":
+            is_shared_tensor = True
+        else:
+            is_shared_tensor = is_shared
+        expected = f"""TensorDict(
+    fields={{
+        b: {tensor_class}(shape=torch.Size([4, 3, 2, 1, 5]), device={tensor_device}, dtype={dtype}, is_shared={is_shared_tensor}),
+        my_nested_td: TensorDict(
+            fields={{
+                a: {tensor_class}(shape=torch.Size([4, 3, 2, 1, 5]), device={tensor_device}, dtype={dtype}, is_shared={is_shared_tensor})}},
+            batch_size=torch.Size([4, 3, 2, 1]),
+            device={str(device)},
+            is_shared={is_shared})}},
+    batch_size=torch.Size([4, 3, 2, 1]),
+    device={str(device)},
+    is_shared={is_shared})"""
+        assert repr(nested_td) == expected
-
-    td0 = td.exclude("a")
-    assert td0.is_shared()
+    def test_repr_nested_update(self, device, dtype):
+        nested_td = self.nested_td(device, dtype)
+        nested_td["my_nested_td"].rename_key_("a", "z")
+        if device is not None and device.type == "cuda":
+            is_shared = True
+        else:
+            is_shared = False
+        tensor_class = "Tensor"
+        tensor_device = device if device else nested_td["b"].device
+        if tensor_device.type == "cuda":
+            is_shared_tensor = True
+        else:
+            is_shared_tensor = is_shared
+        expected = f"""TensorDict(
+    fields={{
+        b: {tensor_class}(shape=torch.Size([4, 3, 2, 1, 5]), device={tensor_device}, dtype={dtype}, is_shared={is_shared_tensor}),
+        my_nested_td: TensorDict(
+            fields={{
+                z: {tensor_class}(shape=torch.Size([4, 3, 2, 1, 5]), device={tensor_device}, dtype={dtype}, is_shared={is_shared_tensor})}},
+            batch_size=torch.Size([4, 3, 2, 1]),
+            device={str(device)},
+            is_shared={is_shared})}},
+    batch_size=torch.Size([4, 3, 2, 1]),
+    device={str(device)},
+    is_shared={is_shared})"""
+        assert repr(nested_td) == expected
-
-    td0 = td.select("a")
-    assert td0.is_shared()
+    def test_repr_plain(self, device, dtype):
+        tensordict = self.td(device, dtype)
+        if device is not None and device.type == "cuda":
+            is_shared = True
+        else:
+            is_shared = False
+        tensor_device = device if device else tensordict["a"].device
+        if tensor_device.type == "cuda":
+            is_shared_tensor = True
+        else:
+            is_shared_tensor = is_shared
+        expected = f"""TensorDict(
+    fields={{
+        a: Tensor(shape=torch.Size([4, 3, 2, 1, 5]), device={tensor_device}, dtype={dtype}, is_shared={is_shared_tensor})}},
+    batch_size=torch.Size([4, 3, 2, 1]),
+    device={str(device)},
+    is_shared={is_shared})"""
+        assert repr(tensordict) == expected
-
-    td.unlock_()
-    td0 = td.rename_key_("a", "a.a")
-    assert not td0.is_shared()
-    td.share_memory_()
+    def test_repr_share_memory(self, device, dtype):
+        tensordict = self.share_memory_td(device, dtype)
+        is_shared = True
+        tensor_class = "Tensor"
+        tensor_device = device if device else tensordict["a"].device
+        if tensor_device.type == "cuda":
+            is_shared_tensor = True
+        else:
+            is_shared_tensor = is_shared
+        expected = f"""TensorDict(
+    fields={{
+        a: {tensor_class}(shape=torch.Size([4, 3, 2, 1, 5]), device={tensor_device}, dtype={dtype}, is_shared={is_shared_tensor})}},
+    batch_size=torch.Size([4, 3, 2, 1]),
+    device={str(device)},
+    is_shared={is_shared})"""
+        assert repr(tensordict) == expected
-
-    td0 = td.unflatten_keys(".")
-    assert td0.is_shared()
+    def test_repr_stacked(self, device, dtype):
+        stacked_td = self.stacked_td(device, dtype)
+        if device is not None and device.type == "cuda":
+            is_shared = True
+        else:
+            is_shared = False
+        tensor_class = "Tensor"
+        tensor_device = device if device else stacked_td["a"].device
+        if tensor_device.type == "cuda":
+            is_shared_tensor = True
+        else:
+            is_shared_tensor = is_shared
+        expected = f"""LazyStackedTensorDict(
+    fields={{
+        a: {tensor_class}(shape=torch.Size([4, 3, 2, 1, 5]), device={tensor_device}, dtype={dtype}, is_shared={is_shared_tensor})}},
+    exclusive_fields={{
+        0 ->
+            c: {tensor_class}(shape=torch.Size([4, 3, 1, 5]), device={tensor_device}, dtype={dtype}, is_shared={is_shared_tensor}),
+        1 ->
+            b: {tensor_class}(shape=torch.Size([4, 3, 1, 10]), device={tensor_device}, dtype={dtype}, is_shared={is_shared_tensor})}},
+    batch_size=torch.Size([4, 3, 2, 1]),
+    device={str(device)},
+    is_shared={is_shared},
+    stack_dim={stacked_td.stack_dim})"""
+        assert repr(stacked_td) == expected
-
-    td0 = td.flatten_keys(".")
-    assert td0.is_shared()
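+    # in the repr of a stack over heterogeneous shapes, mismatched dims show as -1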
+    def test_repr_stacked_het(self, device, dtype):
+        stacked_td = torch.stack(
+            [
+                TensorDict(
+                    {
+                        "a": torch.zeros(3, dtype=dtype),
+                        "b": torch.zeros(2, 3, dtype=dtype),
+                    },
+                    [],
+                    device=device,
+                ),
+                TensorDict(
+                    {
+                        "a": torch.zeros(2, dtype=dtype),
+                        "b": torch.zeros((), dtype=dtype),
+                    },
+                    [],
+                    device=device,
+                ),
+            ]
+        )
+        if device is not None and device.type == "cuda":
+            is_shared = True
+        else:
+            is_shared = False
+        tensor_device = device if device else torch.device("cpu")
+        if tensor_device.type == "cuda":
+            is_shared_tensor = True
+        else:
+            is_shared_tensor = is_shared
+        expected = f"""LazyStackedTensorDict(
+    fields={{
+        a: Tensor(shape=torch.Size([2, -1]), device={tensor_device}, dtype={dtype}, is_shared={is_shared_tensor}),
+        b: Tensor(shape=torch.Size([-1]), device={tensor_device}, dtype={dtype}, is_shared={is_shared_tensor})}},
+    exclusive_fields={{
+    }},
+    batch_size=torch.Size([2]),
+    device={str(device)},
+    is_shared={is_shared},
+    stack_dim={stacked_td.stack_dim})"""
+        assert repr(stacked_td) == expected
-
-    td0 = td.view(-1)
-    assert td0.is_shared()
-
-    td0 = td.permute(1, 0)
-    assert td0.is_shared()
+
+@pytest.mark.parametrize(
+    "td_name",
+    [
+        "td",
+        "stacked_td",
+        "sub_td",
+        "idx_td",
+        "unsqueezed_td",
+        "td_reset_bs",
+    ],
+)
+@pytest.mark.parametrize(
+    "device",
+    get_available_devices(),
+)
+class TestTensorDictsRequiresGrad:
+    def idx_td(self, device):
+        return self.td(device)[0]
-
-    td0 = td.unsqueeze(0)
-    assert td0.is_shared()
+    def stacked_td(self, device):
+        return stack_td([self.td(device) for _ in range(2)], 0)
-
-    td0 = td0.squeeze(0)
-    assert td0.is_shared()
+    def sub_td(self, device):
+        return self.td(device)._get_sub_tensordict(0)
+
+    def test_clone_td(self, td_name, device):
+        torch.manual_seed(1)
+        td = getattr(self, td_name)(device)
+        assert torch.clone(td).get("b").requires_grad
-
-class TestLazyStackedTensorDict:
-    @staticmethod
-    def nested_lazy_het_td(batch_size):
-        shared = torch.zeros(4, 4, 2)
-        hetero_3d = torch.zeros(3)
-        hetero_2d = torch.zeros(2)
+    def test_expand(self, td_name, device):
+        torch.manual_seed(1)
+        td = getattr(self, td_name)(device)
+        batch_size = td.batch_size
+        new_td = td.expand(3, *batch_size)
+        assert new_td.get("b").requires_grad
+        assert new_td.batch_size == torch.Size([3, *batch_size])
-
-        individual_0_tensor = torch.zeros(1)
-        individual_1_tensor = torch.zeros(1, 2)
-        individual_2_tensor = torch.zeros(1, 2, 3)
+    def test_squeeze(self, td_name, device, squeeze_dim=-1):
+        torch.manual_seed(1)
+        td = getattr(self, td_name)(device)
+        assert torch.squeeze(td, dim=-1).get("b").requires_grad
-
-        td_list = [
-            TensorDict(
-                {
-                    "shared": shared,
-                    "hetero": hetero_3d,
-                    "individual_0_tensor": individual_0_tensor,
-                },
-                [],
-            ),
-            TensorDict(
-                {
-                    "shared": shared,
-                    "hetero": hetero_3d,
-                    "individual_1_tensor": individual_1_tensor,
-                },
-                [],
-            ),
-            TensorDict(
-                {
-                    "shared": shared,
-                    "hetero": hetero_2d,
-                    "individual_2_tensor": individual_2_tensor,
-                },
-                [],
-            ),
-        ]
-        for i, td in enumerate(td_list):
-            td[f"individual_{i}_td"] = td.clone()
-            td["shared_td"] = td.clone()
+    def test_view(self, td_name, device):
+        torch.manual_seed(1)
+        td = getattr(self, td_name)(device)
+        td_view = td.view(-1)
+        assert td_view.get("b").requires_grad
-
-        td_stack = torch.stack(td_list, dim=0)
-        obs = TensorDict(
-            {"lazy": td_stack, "dense": torch.zeros(3, 3, 2)},
-            [],
+    def td(self, device):
+        return TensorDict(
+            source={
+                "a": torch.randn(3, 1, 5, device=device),
+                "b": torch.randn(3, 1, 10, device=device, requires_grad=True),
+                "c": torch.randint(10, (3, 1, 3), device=device),
+            },
+            batch_size=[3, 1],
         )
-        obs = obs.expand(batch_size).clone()
-        return obs
-
-    @pytest.mark.parametrize("batch_size", [(), (2,), (1, 2)])
-    @pytest.mark.parametrize("cat_dim", [0, 1, 2])
-    def test_cat_lazy_stack(self, batch_size, cat_dim):
-        if cat_dim > len(batch_size):
-            return
-        td_lazy = self.nested_lazy_het_td(batch_size)["lazy"]
-        assert isinstance(td_lazy, LazyStackedTensorDict)
-        res = torch.cat([td_lazy], dim=cat_dim)
-        assert assert_allclose_td(res, td_lazy)
-        assert res is not td_lazy
-        td_lazy_clone = td_lazy.clone()
-        data_ptr_set_before = {val.data_ptr() for val in decompose(td_lazy)}
-        res = torch.cat([td_lazy_clone], dim=cat_dim, out=td_lazy)
-        data_ptr_set_after = {val.data_ptr() for val in decompose(td_lazy)}
-        assert data_ptr_set_after == data_ptr_set_before
-        assert res is td_lazy
-        assert assert_allclose_td(res, td_lazy_clone)
-
-        td_lazy_2 = td_lazy.clone()
-        td_lazy_2.apply_(lambda x: x + 1)
-
-        res = torch.cat([td_lazy, td_lazy_2], dim=cat_dim)
-        assert res.stack_dim == len(batch_size)
-        assert res.shape[cat_dim] == td_lazy.shape[cat_dim] + td_lazy_2.shape[cat_dim]
-        index = (slice(None),) * cat_dim + (slice(0, td_lazy.shape[cat_dim]),)
-        assert assert_allclose_td(res[index], td_lazy)
-        index = (slice(None),) * cat_dim + (slice(td_lazy.shape[cat_dim], None),)
-        assert assert_allclose_td(res[index], td_lazy_2)
-
-        res = torch.cat([td_lazy, td_lazy_2], dim=cat_dim)
-        assert res.stack_dim == len(batch_size)
-        assert res.shape[cat_dim] == td_lazy.shape[cat_dim] + td_lazy_2.shape[cat_dim]
-        index = (slice(None),) * cat_dim + (slice(0, td_lazy.shape[cat_dim]),)
-        assert assert_allclose_td(res[index], td_lazy)
-        index = (slice(None),) * cat_dim + (slice(td_lazy.shape[cat_dim], None),)
-        assert assert_allclose_td(res[index], td_lazy_2)
-
-        if cat_dim != len(batch_size):  # cat dim is not stack dim
-            batch_size = list(batch_size)
-            batch_size[cat_dim] *= 2
-            td_lazy_dest = self.nested_lazy_het_td(batch_size)["lazy"]
-            data_ptr_set_before = {val.data_ptr() for val in decompose(td_lazy_dest)}
-            res = torch.cat([td_lazy, td_lazy_2], dim=cat_dim, out=td_lazy_dest)
-            data_ptr_set_after = {val.data_ptr() for val in decompose(td_lazy_dest)}
-            assert data_ptr_set_after == data_ptr_set_before
-            assert res is td_lazy_dest
-            index = (slice(None),) * cat_dim + (slice(0, td_lazy.shape[cat_dim]),)
-            assert assert_allclose_td(res[index], td_lazy)
-            index = (slice(None),) * cat_dim + (slice(td_lazy.shape[cat_dim], None),)
-            assert assert_allclose_td(res[index], td_lazy_2)
-
-    def recursively_check_key(self, td, value: int):
-        if isinstance(td, LazyStackedTensorDict):
-            for t in td.tensordicts:
-                if not self.recursively_check_key(t, value):
-                    return False
-        elif isinstance(td, TensorDict):
-            for i in td.values():
-                if not self.recursively_check_key(i, value):
-                    return False
-        elif isinstance(td, torch.Tensor):
-            return (td == value).all()
-        else:
-            return False
+    def td_reset_bs(self, device):
+        td = self.td(device)
+        td = td.unsqueeze(-1).to_tensordict()
+        td.batch_size = torch.Size([3, 1])
+        return td
-
-        return True
+    def unsqueezed_td(self, device):
+        return self.td(device).unsqueeze(0)
-
-    def dense_stack_tds_v1(self, td_list, stack_dim: int) -> TensorDictBase:
-        shape = list(td_list[0].shape)
-        shape.insert(stack_dim, len(td_list))
-        out = td_list[0].unsqueeze(stack_dim).expand(shape).clone()
-        for i in range(1, len(td_list)):
-            index = (slice(None),) * stack_dim + (i,)  # this is index_select
-            out[index] = td_list[i]
+
+class TestMPInplace:
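+    # Workers receive (cmd, val) tuples over a pipe and write to the shared or
+    # memory-mapped tensordict in place; the driver then checks that the writes
+    # are visible across processes.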
+    @classmethod
+    def _remote_process(
+        cls, worker_id, command_pipe_child, command_pipe_parent, tensordict
+    ):
+        command_pipe_parent.close()
+        while True:
+            cmd, val = command_pipe_child.recv()
+            if cmd == "recv":
+                b = tensordict.get("b")
+                assert (b == val).all()
+                command_pipe_child.send("done")
+            elif cmd == "send":
+                a = torch.ones(2) * val
+                tensordict.set_("a", a)
+                assert (
+                    tensordict.get("a") == a
+                ).all(), f'found {a} and {tensordict.get("a")}'
+                command_pipe_child.send("done")
+            elif cmd == "set_done":
+                tensordict.set_("done", torch.ones(1, dtype=torch.bool))
+                command_pipe_child.send("done")
+            elif cmd == "set_undone_":
+                tensordict.set_("done", torch.zeros(1, dtype=torch.bool))
+                command_pipe_child.send("done")
+            elif cmd == "update":
+                tensordict.update_(
+                    TensorDict(
+                        source={"a": tensordict.get("a").clone() + 1},
+                        batch_size=tensordict.batch_size,
+                    )
+                )
+                command_pipe_child.send("done")
+            elif cmd == "update_":
+                tensordict.update_(
+                    TensorDict(
+                        source={"a": tensordict.get("a").clone() - 1},
+                        batch_size=tensordict.batch_size,
+                    )
+                )
+                command_pipe_child.send("done")
-
-        return out
+            elif cmd == "close":
+                command_pipe_child.close()
+                break
-
-    def dense_stack_tds_v2(self, td_list, stack_dim: int) -> TensorDictBase:
-        shape = list(td_list[0].shape)
-        shape.insert(stack_dim, len(td_list))
-        out = td_list[0].unsqueeze(stack_dim).expand(shape).clone()
+    @classmethod
+    def _driver_func(cls, tensordict, tensordict_unbind):
+        procs = []
+        children = []
+        parents = []
-
-        data_ptr_set_before = {val.data_ptr() for val in decompose(out)}
-        res = torch.stack(td_list, dim=stack_dim, out=out)
-        data_ptr_set_after = {val.data_ptr() for val in decompose(out)}
-        assert data_ptr_set_before == data_ptr_set_after
+        for i in range(2):
+            command_pipe_parent, command_pipe_child = mp.Pipe()
+            proc = mp.Process(
+                target=cls._remote_process,
+                args=(i, command_pipe_child, command_pipe_parent, tensordict_unbind[i]),
+            )
+            proc.start()
+            command_pipe_child.close()
+            parents.append(command_pipe_parent)
+            children.append(command_pipe_child)
+            procs.append(proc)
-
-        return res
+        b = torch.ones(2, 1) * 10
+        tensordict.set_("b", b)
+        for i in range(2):
+            parents[i].send(("recv", 10))
+            is_done = parents[i].recv()
+            assert is_done == "done"
-
-    @pytest.mark.parametrize("batch_size", [(), (2,), (1, 2)])
-    @pytest.mark.parametrize("stack_dim", [0, 1, 2])
-    def test_setitem_hetero(self, batch_size, stack_dim):
-        obs = self.nested_lazy_het_td(batch_size)
-        obs1 = obs.clone()
-        obs1.apply_(lambda x: x + 1)
+        for i in range(2):
+            parents[i].send(("send", i))
+            is_done = parents[i].recv()
+            assert is_done == "done"
+        a = tensordict.get("a").clone()
+        assert (a[0] == 0).all()
+        assert (a[1] == 1).all()
-
-        if stack_dim > len(batch_size):
-            return
+        assert not tensordict.get("done").any()
+        for i in range(2):
+            parents[i].send(("set_done", i))
+            is_done = parents[i].recv()
+            assert is_done == "done"
+        assert tensordict.get("done").all()
-
-        res1 = self.dense_stack_tds_v1([obs, obs1], stack_dim=stack_dim)
-        res2 = self.dense_stack_tds_v2([obs, obs1], stack_dim=stack_dim)
+        for i in range(2):
+            parents[i].send(("set_undone_", i))
+            is_done = parents[i].recv()
+            assert is_done == "done"
+        assert not tensordict.get("done").any()
-
-        index = (slice(None),) * stack_dim + (0,)  # get the first in the stack
-        assert self.recursively_check_key(res1[index], 0)  # check all 0
-        assert self.recursively_check_key(res2[index], 0)  # check all 0
-        index = (slice(None),) * stack_dim + (1,)  # get the second in the stack
-        assert self.recursively_check_key(res1[index], 1)  # check all 1
-        assert self.recursively_check_key(res2[index], 1)  # check all 1
+        a_prev = tensordict.get("a").clone().contiguous()
+        for i in range(2):
+            parents[i].send(("update_", i))
+            is_done = parents[i].recv()
+            assert is_done == "done"
+        new_a = tensordict.get("a").clone().contiguous()
+        torch.testing.assert_close(a_prev - 1, new_a)
-
-    @pytest.mark.parametrize("batch_size", [(), (32,), (32, 4)])
-    def test_lazy_stack_stack(self, batch_size):
-        obs = self.nested_lazy_het_td(batch_size)
+        a_prev = tensordict.get("a").clone().contiguous()
+        for i in range(2):
+            parents[i].send(("update", i))
+            is_done = parents[i].recv()
+            assert is_done == "done"
+        new_a = tensordict.get("a").clone().contiguous()
+        torch.testing.assert_close(a_prev + 1, new_a)
-
-        assert isinstance(obs, TensorDict)
-        assert isinstance(obs["lazy"], LazyStackedTensorDict)
-        assert obs["lazy"].stack_dim == len(obs["lazy"].shape) - 1  # succeeds
-        assert obs["lazy"].shape == (*batch_size, 3)
-        assert isinstance(obs["lazy"][..., 0], TensorDict)  # succeeds
+        for i in range(2):
+            parents[i].send(("close", None))
+            procs[i].join()
-
-        obs_stack = torch.stack([obs])
+    @pytest.mark.parametrize(
+        "td_type",
+        [
+            "memmap",
+            "memmap_stack",
+            "contiguous",
+            "stack",
+        ],
+    )
+    def test_mp(self, td_type):
+        tensordict = TensorDict(
+            source={
+                "a": torch.randn(2, 2),
+                "b": torch.randn(2, 1),
+                "done": torch.zeros(2, 1, dtype=torch.bool),
+            },
+            batch_size=[2],
+        )
+        if td_type == "contiguous":
+            tensordict = tensordict.share_memory_()
+        elif td_type == "stack":
+            tensordict = stack_td(
+                [
+                    tensordict[0].clone().share_memory_(),
+                    tensordict[1].clone().share_memory_(),
+                ],
+                0,
+            )
+        elif td_type == "memmap":
+            tensordict = tensordict.memmap_()
+        elif td_type == "memmap_stack":
+            tensordict = stack_td(
+                [
+                    tensordict[0].clone().memmap_(),
+                    tensordict[1].clone().memmap_(),
+                ],
+                0,
+            )
+        else:
+            raise NotImplementedError
+        self._driver_func(
+            tensordict,
+            (tensordict._get_sub_tensordict(0), tensordict._get_sub_tensordict(1))
+            # tensordict,
+            # tensordict.unbind(0),
+        )
-
-        assert (
-            isinstance(obs_stack, LazyStackedTensorDict) and obs_stack.stack_dim == 0
-        )  # succeeds
-        assert obs_stack.batch_size == (1, *batch_size)  # succeeds
-        assert obs_stack[0] is obs  # succeeds
-        assert isinstance(obs_stack["lazy"], LazyStackedTensorDict)
-        assert obs_stack["lazy"].shape == (1, *batch_size, 3)
-        assert obs_stack["lazy"].stack_dim == 0  # succeeds
-        assert obs_stack["lazy"][0] is obs["lazy"]
-        obs2 = obs.clone()
-        obs_stack = torch.stack([obs, obs2])
+
+class TestMakeTensorDict:
+    def test_create_tensordict(self):
+        tensordict = make_tensordict(a=torch.zeros(3, 4))
+        assert (tensordict["a"] == torch.zeros(3, 4)).all()
-
-        assert (
-            isinstance(obs_stack, LazyStackedTensorDict) and obs_stack.stack_dim == 0
-        )  # succeeds
-        assert obs_stack.batch_size == (2, *batch_size)  # succeeds
-        assert obs_stack[0] is obs  # succeeds
-        assert isinstance(obs_stack["lazy"], LazyStackedTensorDict)
-        assert obs_stack["lazy"].shape == (2, *batch_size, 3)
-        assert obs_stack["lazy"].stack_dim == 0  # succeeds
-        assert obs_stack["lazy"][0] is obs["lazy"]
+    def test_nested(self):
+        input_dict = {
+            "a": {"b": torch.randn(3, 4), "c": torch.randn(3, 4, 5)},
+            "d": torch.randn(3),
+        }
+        tensordict = make_tensordict(input_dict)
+        assert tensordict.shape == torch.Size([3])
+        assert tensordict["a"].shape == torch.Size([3, 4])
+        input_tensordict = TensorDict(
+            {
+                "a": {"b": torch.randn(3, 4), "c": torch.randn(3, 4, 5)},
+                "d": torch.randn(3),
+            },
+            [],
+        )
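+        # make_tensordict also accepts a TensorDict input and re-infers its batch size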
+        tensordict = make_tensordict(input_tensordict)
+        assert tensordict.shape == torch.Size([3])
+        assert tensordict["a"].shape == torch.Size([3, 4])
+        input_dict = {
+            ("a", "b"): torch.randn(3, 4),
+            ("a", "c"): torch.randn(3, 4, 5),
+            "d": torch.randn(3),
+        }
+        tensordict = make_tensordict(input_dict)
+        assert tensordict.shape == torch.Size([3])
+        assert tensordict["a"].shape == torch.Size([3, 4])
-
-    @pytest.mark.parametrize("batch_size", [(), (32,), (32, 4)])
-    def test_stack_hetero(self, batch_size):
-        obs = self.nested_lazy_het_td(batch_size)
+    def test_tensordict_batch_size(self):
+        tensordict = make_tensordict()
+        assert tensordict.batch_size == torch.Size([])
-
-        obs2 = obs.clone()
-        obs2.apply_(lambda x: x + 1)
+        tensordict = make_tensordict(a=torch.randn(3, 4))
+        assert tensordict.batch_size == torch.Size([3, 4])
+
+        tensordict = make_tensordict(a=torch.randn(3, 4), b=torch.randn(3, 4, 5))
+        assert tensordict.batch_size == torch.Size([3, 4])
-
-        obs_stack = torch.stack([obs, obs2])
-        obs_stack_resolved = self.dense_stack_tds_v2([obs, obs2], stack_dim=0)
+        nested_tensordict = make_tensordict(c=tensordict, d=torch.randn(3, 5))  # nested
+        assert nested_tensordict.batch_size == torch.Size([3])
-
-        assert isinstance(obs_stack, LazyStackedTensorDict) and obs_stack.stack_dim == 0
-        assert isinstance(obs_stack_resolved, TensorDict)
+        nested_tensordict = make_tensordict(c=tensordict, d=torch.randn(4, 5))  # nested
+        assert nested_tensordict.batch_size == torch.Size([])
-
-        assert obs_stack.batch_size == (2, *batch_size)
-        assert obs_stack_resolved.batch_size == obs_stack.batch_size
+        tensordict = make_tensordict(a=torch.randn(3, 4, 2), b=torch.randn(3, 4, 5))
+        assert tensordict.batch_size == torch.Size([3, 4])
-
-        assert obs_stack["lazy"].shape == (2, *batch_size, 3)
-        assert obs_stack_resolved["lazy"].batch_size == obs_stack["lazy"].batch_size
+        tensordict = make_tensordict(a=torch.randn(3, 4), b=torch.randn(1))
+        assert tensordict.batch_size == torch.Size([])
-
-        assert obs_stack["lazy"].stack_dim == 0
-        assert (
-            obs_stack_resolved["lazy"].stack_dim
-            == len(obs_stack_resolved["lazy"].batch_size) - 1
+        tensordict = make_tensordict(
+            a=torch.randn(3, 4), b=torch.randn(3, 4, 5), batch_size=[3]
        )
-        for stack in [obs_stack_resolved, obs_stack]:
-            for index in range(2):
-                assert (stack[index]["dense"] == index).all()
-                assert (stack["dense"][index] == index).all()
-                assert (stack["lazy"][index]["shared"] == index).all()
-                assert (stack[index]["lazy"]["shared"] == index).all()
-                assert (stack["lazy"]["shared"][index] == index).all()
-                assert (
-                    stack["lazy"][index][..., 0]["individual_0_tensor"] == index
-                ).all()
-                assert (
-                    stack[index]["lazy"][..., 0]["individual_0_tensor"] == index
-                ).all()
-                assert (
-                    stack["lazy"][..., 0]["individual_0_tensor"][index] == index
-                ).all()
-                assert (
-                    stack["lazy"][..., 0][index]["individual_0_tensor"] == index
-                ).all()
+        assert tensordict.batch_size == torch.Size([3])
-
-    def test_add_batch_dim_cache(self):
-        td = TensorDict(
-            {"a": torch.rand(3, 4, 5), ("b", "c"): torch.rand(3, 4, 5)}, [3, 4, 5]
+        tensordict = make_tensordict(
+            a=torch.randn(3, 4), b=torch.randn(3, 4, 5), batch_size=[]
        )
-        td = torch.stack([td, td.clone()], 0)
-        from tensordict.nn import TensorDictModule  # noqa
-        from torch import vmap
+        assert tensordict.batch_size == torch.Size([])
-
-        logging.info("first call to vmap")
-        fun = vmap(lambda x: x)
-        fun(td)
-        td.zero_()
-        # this value should be cached
-        logging.info("second call to vmap")
-        std = fun(td)
-        for value in std.values(True, True):
-            assert (value == 0).all()
+    @pytest.mark.parametrize("device", get_available_devices())
+    def test_tensordict_device(self, device):
+        tensordict = make_tensordict(
+            a=torch.randn(3, 4), b=torch.randn(3, 4), device=device
+        )
+        assert tensordict.device == device
+        assert tensordict["a"].device == device
+        assert tensordict["b"].device == device
-
-    def test_add_batch_dim_cache_nested(self):
-        td = TensorDict(
-            {"a": torch.rand(3, 4, 5), ("b", "c"): torch.rand(3, 4, 5)}, [3, 4, 5]
+        tensordict = make_tensordict(
+            a=torch.randn(3, 4, device=device),
+            b=torch.randn(3, 4),
+            c=torch.randn(3, 4, device="cpu"),
+            device=device,
         )
-        td = TensorDict({"parent": torch.stack([td, td.clone()], 0)}, [2, 3, 4, 5])
-        from tensordict.nn import TensorDictModule  # noqa
-        from torch import vmap
+        assert tensordict.device == device
+        assert tensordict["a"].device == device
+        assert tensordict["b"].device == device
+        assert tensordict["c"].device == device
-
-        fun = vmap(lambda x: x)
-        logging.info("first call to vmap")
-        fun(td)
-        td.zero_()
-        # this value should be cached
-        logging.info("second call to vmap")
-        std = fun(td)
-        for value in std.values(True, True):
-            assert (value == 0).all()
-
-    def test_update_with_lazy(self):
-        td0 = TensorDict(
-            {
-                ("a", "b", "c"): torch.ones(3, 4),
-                ("a", "b", "d"): torch.ones(3, 4),
-                "common": torch.ones(3),
-            },
-            [3],
-        )
-        td1 = TensorDict(
-            {
-                ("a", "b", "c"): torch.ones(3, 5) * 2,
-                "common": torch.ones(3) * 2,
-            },
-            [3],
-        )
-        td = TensorDict({"parent": torch.stack([td0, td1], 0)}, [2])
+
+class TestLazyStackedTensorDict:
+    @property
+    def _idx_list(self):
+        return {
+            0: 1,
+            1: slice(None),
+            2: slice(1, 2),
+            3: self._tensor_index,
+            4: range(1, 2),
+            5: None,
+            6: [0, 1],
+            7: self._tensor_index.numpy(),
+        }
-
-        td_void = TensorDict(
-            {
-                ("parent", "a", "b", "c"): torch.zeros(2, 3, 4),
-                ("parent", "a", "b", "e"): torch.zeros(2, 3, 4),
-                ("parent", "a", "b", "d"): torch.zeros(2, 3, 5),
-            },
-            [2],
-        )
-        td_void.update(td)
-        assert type(td_void.get("parent")) is LazyStackedTensorDict
-        assert type(td_void.get(("parent", "a"))) is LazyStackedTensorDict
-        assert type(td_void.get(("parent", "a", "b"))) is LazyStackedTensorDict
-        assert (td_void.get(("parent", "a", "b"))[0].get("c") == 1).all()
-        assert (td_void.get(("parent", "a", "b"))[1].get("c") == 2).all()
-        assert (td_void.get(("parent", "a", "b"))[0].get("d") == 1).all()
-        assert (td_void.get(("parent", "a", "b"))[1].get("d") == 0).all()  # unaffected
-        assert (td_void.get(("parent", "a", "b")).get("e") == 0).all()  # unaffected
+    @property
+    def _tensor_index(self):
+        torch.manual_seed(0)
+        return torch.randint(2, (5, 2))
-
-    @pytest.mark.parametrize("unsqueeze_dim", [0, 1, -1, -2])
-    def test_stack_unsqueeze(self, unsqueeze_dim):
-        td = TensorDict({("a", "b"): torch.ones(3, 4, 5)}, [3, 4])
-        td_stack = torch.stack(td.unbind(1), 1)
-        td_unsqueeze = td.unsqueeze(unsqueeze_dim)
-        td_stack_unsqueeze = td_stack.unsqueeze(unsqueeze_dim)
-        assert isinstance(td_stack_unsqueeze, LazyStackedTensorDict)
-        for key in td_unsqueeze.keys(True, True):
-            assert td_unsqueeze.get(key).shape == td_stack_unsqueeze.get(key).shape
+    def dense_stack_tds_v1(self, td_list, stack_dim: int) -> TensorDictBase:
+        shape = list(td_list[0].shape)
+        shape.insert(stack_dim, len(td_list))
-
-    def test_stack_apply(self):
-        td0 = TensorDict(
-            {
-                ("a", "b", "c"): torch.ones(3, 4),
-                ("a", "b", "d"): torch.ones(3, 4),
-                "common": torch.ones(3),
-            },
-            [3],
-        )
-        td1 = TensorDict(
-            {
-                ("a", "b", "c"): torch.ones(3, 5) * 2,
-                "common": torch.ones(3) * 2,
-            },
-            [3],
-        )
-        td = TensorDict({"parent": torch.stack([td0, td1], 0)}, [2])
-        td2 = td.clone()
-        tdapply = td.apply(lambda x, y: x + y, td2)
-        assert isinstance(tdapply["parent", "a", "b"], LazyStackedTensorDict)
-        assert (tdapply["parent", "a", "b"][0]["c"] == 2).all()
-        assert (tdapply["parent", "a", "b"][1]["c"] == 4).all()
-        assert (tdapply["parent", "a", "b"][0]["d"] == 2).all()
+        out = td_list[0].unsqueeze(stack_dim).expand(shape).clone()
+        for i in range(1, len(td_list)):
+            index = (slice(None),) * stack_dim + (i,)  # this is index_select
+            out[index] = td_list[i]
-
-    def test_stack_keys(self):
-        td1 = TensorDict(source={"a": torch.randn(3)}, batch_size=[])
-        td2 = TensorDict(
-            source={
-                "a": torch.randn(3),
-                "b": torch.randn(3),
-                "c": torch.randn(4),
-                "d": torch.randn(5),
-            },
-            batch_size=[],
-        )
-        td = stack_td([td1, td2], 0)
-        assert "a" in td.keys()
-        assert "b" not in td.keys()
-        assert "b" in td[1].keys()
-        td.set("b", torch.randn(2, 10), inplace=False)  # overwrites
-        with pytest.raises(KeyError):
-            td.set_("c", torch.randn(2, 10))  # overwrites
-        td.set_("b", torch.randn(2, 10))  # b has been set before
+
+        return out
+
+    def dense_stack_tds_v2(self, td_list, stack_dim: int) -> TensorDictBase:
+        shape = list(td_list[0].shape)
+        shape.insert(stack_dim, len(td_list))
+        out = td_list[0].unsqueeze(stack_dim).expand(shape).clone()
+
+        data_ptr_set_before = {val.data_ptr() for val in decompose(out)}
+        res = torch.stack(td_list, dim=stack_dim, out=out)
+        data_ptr_set_after = {val.data_ptr() for val in decompose(out)}
+        assert data_ptr_set_before == data_ptr_set_after
+
+        return res
+
+    @staticmethod
+    def nested_lazy_het_td(batch_size):
+        shared = torch.zeros(4, 4, 2)
+        hetero_3d = torch.zeros(3)
+        hetero_2d = torch.zeros(2)
+
+        individual_0_tensor = torch.zeros(1)
+        individual_1_tensor = torch.zeros(1, 2)
+        individual_2_tensor = torch.zeros(1, 2, 3)
+
+        td_list = [
+            TensorDict(
+                {
+                    "shared": shared,
+                    "hetero": hetero_3d,
+                    "individual_0_tensor": individual_0_tensor,
+                },
+                [],
+            ),
+            TensorDict(
+                {
+                    "shared": shared,
+                    "hetero": hetero_3d,
+                    "individual_1_tensor": individual_1_tensor,
+                },
+                [],
+            ),
+            TensorDict(
+                {
+                    "shared": shared,
+                    "hetero": hetero_2d,
+                    "individual_2_tensor": individual_2_tensor,
+                },
+                [],
+            ),
+        ]
+        for i, td in enumerate(td_list):
+            td[f"individual_{i}_td"] = td.clone()
+            td["shared_td"] = td.clone()
-
-        td1.set("c", torch.randn(4))
-        td[
-            "c"
-        ]  # we must first query that key for the stacked tensordict to update the list
-        assert "c" in td.keys(), list(td.keys())  # now all tds have the key c
-        td.get("c")
+        td_stack = torch.stack(td_list, dim=0)
+        obs = TensorDict(
+            {"lazy": td_stack, "dense": torch.zeros(3, 3, 2)},
+            [],
+        )
+        obs = obs.expand(batch_size).clone()
+        return obs
-
-        td1.set("d", torch.randn(6))
-        with pytest.raises(RuntimeError):
-            td.get("d")
+    def recursively_check_key(self, td, value: int):
+        if isinstance(td, LazyStackedTensorDict):
+            for t in td.tensordicts:
+                if not self.recursively_check_key(t, value):
+                    return False
+        elif isinstance(td, TensorDict):
+            for i in td.values():
+                if not self.recursively_check_key(i, value):
+                    return False
+        elif isinstance(td, torch.Tensor):
+            return (td == value).all()
+        else:
+            return False
-
-        td["e"] = torch.randn(2, 4)
-        assert "e" in td.keys()  # now all tds have the key c
-        td.get("e")
+
+        return True
-
-    def test_stacked_td_nested_keys(self):
-        td = torch.stack(
-            [
-                TensorDict({"a": {"b": {"d": [1]}, "c": [2]}}, []),
-                TensorDict({"a": {"b": {"d": [1]}, "d": [2]}}, []),
-            ],
-            0,
test_add_batch_dim_cache(self): + td = TensorDict( + {"a": torch.rand(3, 4, 5), ("b", "c"): torch.rand(3, 4, 5)}, [3, 4, 5] ) - assert ("a", "b") in td.keys(True) - assert ("a", "c") not in td.keys(True) - assert ("a", "b", "d") in td.keys(True) - td["a", "c"] = [[2], [3]] - assert ("a", "c") in td.keys(True) + td = torch.stack([td, td.clone()], 0) + from tensordict.nn import TensorDictModule # noqa + from torch import vmap - keys, items = zip(*td.items(True)) - assert ("a", "b") in keys - assert ("a", "c") in keys - assert ("a", "d") not in keys + logging.info("first call to vmap") + fun = vmap(lambda x: x) + fun(td) + td.zero_() + # this value should be cached + logging.info("second call to vmap") + std = fun(td) + for value in std.values(True, True): + assert (value == 0).all() - td["a", "c"] = td["a", "c"] + 1 - assert (td["a", "c"] == torch.tensor([[3], [4]], device=td.device)).all() + def test_add_batch_dim_cache_nested(self): + td = TensorDict( + {"a": torch.rand(3, 4, 5), ("b", "c"): torch.rand(3, 4, 5)}, [3, 4, 5] + ) + td = TensorDict({"parent": torch.stack([td, td.clone()], 0)}, [2, 3, 4, 5]) + from tensordict.nn import TensorDictModule # noqa + from torch import vmap - @pytest.mark.parametrize("device", get_available_devices()) - @pytest.mark.parametrize("stack_dim", [0, 1]) - def test_stacked_td(self, stack_dim, device): - tensordicts = [ - TensorDict( - batch_size=[11, 12], - source={ - "key1": torch.randn(11, 12, 5, device=device), - "key2": torch.zeros( - 11, 12, 50, device=device, dtype=torch.bool - ).bernoulli_(), - }, - ) - for _ in range(10) - ] + fun = vmap(lambda x: x) + logging.info("first call to vmap") + fun(td) + td.zero_() + # this value should be cached + logging.info("second call to vmap") + std = fun(td) + for value in std.values(True, True): + assert (value == 0).all() - tensordicts0 = tensordicts[0] - tensordicts1 = tensordicts[1] - tensordicts2 = tensordicts[2] - tensordicts3 = tensordicts[3] - sub_td = LazyStackedTensorDict(*tensordicts, stack_dim=stack_dim) + def test_all_keys(self): + td = TensorDict({"a": torch.zeros(1)}, []) + td2 = TensorDict({"a": torch.zeros(2)}, []) + stack = torch.stack([td, td2]) + assert set(stack.keys(True, True)) == {"a"} - std_bis = stack_td(tensordicts, dim=stack_dim, contiguous=False) - assert (sub_td == std_bis).all() + def test_best_intention_stack(self): + td0 = TensorDict({"a": 1, "b": TensorDict({"c": 2}, [])}, []) + td1 = TensorDict({"a": 1, "b": TensorDict({"d": 2}, [])}, []) + with set_lazy_legacy(False): + td = torch.stack([td0, td1]) + assert isinstance(td, TensorDict) + assert isinstance(td.get("b"), LazyStackedTensorDict) + td1 = TensorDict({"a": 1, "b": TensorDict({"c": [2]}, [])}, []) + with set_lazy_legacy(False): + td = torch.stack([td0, td1]) + assert isinstance(td, TensorDict) + assert isinstance(td.get("b"), LazyStackedTensorDict) - item = (*[slice(None) for _ in range(stack_dim)], 0) - tensordicts0.zero_() - assert (sub_td[item].get("key1") == sub_td.get("key1")[item]).all() - assert ( - sub_td.contiguous()[item].get("key1") - == sub_td.contiguous().get("key1")[item] - ).all() - assert (sub_td.contiguous().get("key1")[item] == 0).all() + @pytest.mark.parametrize("batch_size", [(), (2,), (1, 2)]) + @pytest.mark.parametrize("cat_dim", [0, 1, 2]) + def test_cat_lazy_stack(self, batch_size, cat_dim): + if cat_dim > len(batch_size): + return + td_lazy = self.nested_lazy_het_td(batch_size)["lazy"] + assert isinstance(td_lazy, LazyStackedTensorDict) + res = torch.cat([td_lazy], dim=cat_dim) + assert 
assert_allclose_td(res, td_lazy)
+        assert res is not td_lazy
+        td_lazy_clone = td_lazy.clone()
+        data_ptr_set_before = {val.data_ptr() for val in decompose(td_lazy)}
+        res = torch.cat([td_lazy_clone], dim=cat_dim, out=td_lazy)
+        data_ptr_set_after = {val.data_ptr() for val in decompose(td_lazy)}
+        assert data_ptr_set_after == data_ptr_set_before
+        assert res is td_lazy
+        assert assert_allclose_td(res, td_lazy_clone)
-        item = (*[slice(None) for _ in range(stack_dim)], 1)
-        std2 = sub_td[:5]
-        tensordicts1.zero_()
-        assert (std2[item].get("key1") == std2.get("key1")[item]).all()
-        assert (
-            std2.contiguous()[item].get("key1") == std2.contiguous().get("key1")[item]
-        ).all()
-        assert (std2.contiguous().get("key1")[item] == 0).all()
+
+        td_lazy_2 = td_lazy.clone()
+        td_lazy_2.apply_(lambda x: x + 1)
-        std3 = sub_td[:5, :, :5]
-        tensordicts2.zero_()
-        item = (*[slice(None) for _ in range(stack_dim)], 2)
-        assert (std3[item].get("key1") == std3.get("key1")[item]).all()
-        assert (
-            std3.contiguous()[item].get("key1") == std3.contiguous().get("key1")[item]
-        ).all()
-        assert (std3.contiguous().get("key1")[item] == 0).all()
+
+        res = torch.cat([td_lazy, td_lazy_2], dim=cat_dim)
+        assert res.stack_dim == len(batch_size)
+        assert res.shape[cat_dim] == td_lazy.shape[cat_dim] + td_lazy_2.shape[cat_dim]
+        index = (slice(None),) * cat_dim + (slice(0, td_lazy.shape[cat_dim]),)
+        assert assert_allclose_td(res[index], td_lazy)
+        index = (slice(None),) * cat_dim + (slice(td_lazy.shape[cat_dim], None),)
+        assert assert_allclose_td(res[index], td_lazy_2)
-        std4 = sub_td.select("key1")
-        tensordicts3.zero_()
-        item = (*[slice(None) for _ in range(stack_dim)], 3)
-        assert (std4[item].get("key1") == std4.get("key1")[item]).all()
-        assert (
-            std4.contiguous()[item].get("key1") == std4.contiguous().get("key1")[item]
-        ).all()
-        assert (std4.contiguous().get("key1")[item] == 0).all()
-        std5 = sub_td.unbind(1)[0]
-        assert (std5.contiguous() == sub_td.contiguous().unbind(1)[0]).all()
+
+        if cat_dim != len(batch_size):  # cat dim is not stack dim
+            batch_size = list(batch_size)
+            batch_size[cat_dim] *= 2
+            td_lazy_dest = self.nested_lazy_het_td(batch_size)["lazy"]
+            data_ptr_set_before = {val.data_ptr() for val in decompose(td_lazy_dest)}
+            res = torch.cat([td_lazy, td_lazy_2], dim=cat_dim, out=td_lazy_dest)
+            data_ptr_set_after = {val.data_ptr() for val in decompose(td_lazy_dest)}
+            assert data_ptr_set_after == data_ptr_set_before
+            assert res is td_lazy_dest
+            index = (slice(None),) * cat_dim + (slice(0, td_lazy.shape[cat_dim]),)
+            assert assert_allclose_td(res[index], td_lazy)
+            index = (slice(None),) * cat_dim + (slice(td_lazy.shape[cat_dim], None),)
+            assert assert_allclose_td(res[index], td_lazy_2)
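Editor's aside, not part of the patch: the behaviour `test_cat_lazy_stack` pins down can be stated in a few lines. The sketch below is a minimal illustration under the same assumptions the surrounding tests make (`torch.cat` is overloaded for tensordicts, and the `LazyStackedTensorDict(*tds, stack_dim=...)` constructor mirrors its use in `test_stacked_td` further down); the names `cat0`/`cat1` are illustrative only.

    import torch
    from tensordict import TensorDict, LazyStackedTensorDict

    td = TensorDict({"a": torch.zeros(3, 4)}, [3, 4])
    # two tensordicts with identical keys, lazily stacked along dim 0 -> shape (2, 3, 4)
    lazy = LazyStackedTensorDict(td, td.apply(lambda x: x + 1), stack_dim=0)

    # cat along the stack dim simply extends the list of stacked tensordicts
    cat0 = torch.cat([lazy, lazy.clone()], dim=0)
    assert isinstance(cat0, LazyStackedTensorDict)
    assert cat0.shape == torch.Size([4, 3, 4])

    # cat along any other batch dim concatenates the matching leaves instead
    cat1 = torch.cat([lazy, lazy.clone()], dim=1)
    assert cat1.shape == torch.Size([2, 6, 4])
    assert (cat1["a"][:, :3] == cat1["a"][:, 3:]).all()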
+    @pytest.mark.parametrize("pos1", range(8))
+    @pytest.mark.parametrize("pos2", range(8))
+    @pytest.mark.parametrize("pos3", range(8))
+    def test_lazy_indexing(self, pos1, pos2, pos3):
+        torch.manual_seed(0)
+        td_leaf_1 = TensorDict({"a": torch.ones(2, 3)}, [])
+        inner = torch.stack([td_leaf_1] * 4, 0)
+        middle = torch.stack([inner] * 3, 0)
+        outer = torch.stack([middle] * 2, 0)
+        outer_dense = outer.to_tensordict()
+        ref_tensor = torch.zeros(2, 3, 4)
+        pos1 = self._idx_list[pos1]
+        pos2 = self._idx_list[pos2]
+        pos3 = self._idx_list[pos3]
+        index = (pos1, pos2, pos3)
+        result = outer[index]
+        assert result.batch_size == ref_tensor[index].shape, index
+        assert result.batch_size == outer_dense[index].shape, index
+
+    @pytest.mark.parametrize("stack_dim", [0, 1, 2])
+    @pytest.mark.parametrize("mask_dim", [0, 1, 2])
+    @pytest.mark.parametrize("single_mask_dim", [True, False])
+    @pytest.mark.parametrize("device", get_available_devices())
+    def test_lazy_mask_indexing(self, stack_dim, mask_dim, single_mask_dim, device):
+        torch.manual_seed(0)
+        td = TensorDict({"a": torch.zeros(9, 10, 11)}, [9, 10, 11], device=device)
+        td = torch.stack(
+            [
+                td,
+                td.apply(lambda x: x + 1),
+                td.apply(lambda x: x + 2),
+                td.apply(lambda x: x + 3),
+            ],
+            stack_dim,
+        )
+        mask = torch.zeros(())
+        while not mask.any():
+            if single_mask_dim:
+                mask = torch.zeros(td.shape[mask_dim], dtype=torch.bool).bernoulli_()
+            else:
+                mask = torch.zeros(
+                    td.shape[mask_dim : mask_dim + 2], dtype=torch.bool
+                ).bernoulli_()
+        index = (slice(None),) * mask_dim + (mask,)
+        tdmask = td[index]
+        assert tdmask["a"].shape == td["a"][index].shape
+        assert (tdmask["a"] == td["a"][index]).all()
+        index = (0,) * mask_dim + (mask,)
+        tdmask = td[index]
+        assert tdmask["a"].shape == td["a"][index].shape
+        assert (tdmask["a"] == td["a"][index]).all()
+        index = (slice(1),) * mask_dim + (mask,)
+        tdmask = td[index]
+        assert tdmask["a"].shape == td["a"][index].shape
+        assert (tdmask["a"] == td["a"][index]).all()

-    @pytest.mark.parametrize("device", get_available_devices())
     @pytest.mark.parametrize("stack_dim", [0, 1, 2])
-    def test_stacked_indexing(self, device, stack_dim):
-        tensordict = TensorDict(
-            {"a": torch.randn(3, 4, 5), "b": torch.randn(3, 4, 5)},
-            batch_size=[3, 4, 5],
-            device=device,
+    @pytest.mark.parametrize("mask_dim", [0, 1, 2])
+    @pytest.mark.parametrize("single_mask_dim", [True, False])
+    @pytest.mark.parametrize("device", get_available_devices())
+    def test_lazy_mask_setitem(self, stack_dim, mask_dim, single_mask_dim, device):
+        torch.manual_seed(0)
+        td = TensorDict({"a": torch.zeros(9, 10, 11)}, [9, 10, 11], device=device)
+        td = torch.stack(
+            [
+                td,
+                td.apply(lambda x: x + 1),
+                td.apply(lambda x: x + 2),
+                td.apply(lambda x: x + 3),
+            ],
+            stack_dim,
         )
+        mask = torch.zeros(())
+        while not mask.any():
+            if single_mask_dim:
+                mask = torch.zeros(td.shape[mask_dim], dtype=torch.bool).bernoulli_()
+            else:
+                mask = torch.zeros(
+                    td.shape[mask_dim : mask_dim + 2], dtype=torch.bool
+                ).bernoulli_()
+        index = (slice(None),) * mask_dim + (mask,)
+        tdset = TensorDict({"a": td["a"][index] * 0 - 1}, [])
+        # we know that the batch size is accurate from test_lazy_mask_indexing
+        tdset.batch_size = td[index].batch_size
+        td[index] = tdset
+        assert (td["a"][index] == tdset["a"]).all()
+        assert (td["a"][index] == tdset["a"]).all()
+        index = (slice(1),) * mask_dim + (mask,)
+        tdset = TensorDict({"a": td["a"][index] * 0 - 1}, [])
+        tdset.batch_size = td[index].batch_size
+        td[index] = tdset
+        assert (td["a"][index] == tdset["a"]).all()
+        assert (td["a"][index] == tdset["a"]).all()

-        tds = torch.stack(list(tensordict.unbind(stack_dim)), stack_dim)
+    @pytest.mark.parametrize("batch_size", [(), (32,), (32, 4)])
+    def test_lazy_stack_stack(self, batch_size):
+        obs = self.nested_lazy_het_td(batch_size)

-        for item, expected_shape in (
-            ((2, 2), torch.Size([5])),
-            ((slice(1, 2), 2), torch.Size([1, 5])),
-            ((..., 2), 
torch.Size([3, 4])), - ): - assert tds[item].batch_size == expected_shape - assert (tds[item].get("a") == tds.get("a")[item]).all() - assert (tds[item].get("a") == tensordict[item].get("a")).all() + assert isinstance(obs, TensorDict) + assert isinstance(obs["lazy"], LazyStackedTensorDict) + assert obs["lazy"].stack_dim == len(obs["lazy"].shape) - 1 # succeeds + assert obs["lazy"].shape == (*batch_size, 3) + assert isinstance(obs["lazy"][..., 0], TensorDict) # succeeds - @pytest.mark.parametrize("device", get_available_devices()) - def test_stack(self, device): - torch.manual_seed(1) - tds_list = [TensorDict(source={}, batch_size=(4, 5)) for _ in range(3)] - tds = stack_td(tds_list, 0, contiguous=False) - assert tds[0] is tds_list[0] + obs_stack = torch.stack([obs]) - td = TensorDict( - source={"a": torch.randn(4, 5, 3, device=device)}, batch_size=(4, 5) - ) - td_list = list(td) - td_reconstruct = stack_td(td_list, 0) - assert td_reconstruct.batch_size == td.batch_size - assert (td_reconstruct == td).all() + assert ( + isinstance(obs_stack, LazyStackedTensorDict) and obs_stack.stack_dim == 0 + ) # succeeds + assert obs_stack.batch_size == (1, *batch_size) # succeeds + assert obs_stack[0] is obs # succeeds + assert isinstance(obs_stack["lazy"], LazyStackedTensorDict) + assert obs_stack["lazy"].shape == (1, *batch_size, 3) + assert obs_stack["lazy"].stack_dim == 0 # succeeds + assert obs_stack["lazy"][0] is obs["lazy"] + + obs2 = obs.clone() + obs_stack = torch.stack([obs, obs2]) + + assert ( + isinstance(obs_stack, LazyStackedTensorDict) and obs_stack.stack_dim == 0 + ) # succeeds + assert obs_stack.batch_size == (2, *batch_size) # succeeds + assert obs_stack[0] is obs # succeeds + assert isinstance(obs_stack["lazy"], LazyStackedTensorDict) + assert obs_stack["lazy"].shape == (2, *batch_size, 3) + assert obs_stack["lazy"].stack_dim == 0 # succeeds + assert obs_stack["lazy"][0] is obs["lazy"] @pytest.mark.parametrize("dim", range(2)) - @pytest.mark.parametrize("index", range(2)) @pytest.mark.parametrize("device", get_available_devices()) - def test_lazy_stacked_insert(self, dim, index, device): + def test_lazy_stacked_append(self, dim, device): td = TensorDict({"a": torch.zeros(4)}, [4], device=device) lstd = torch.stack([td] * 2, dim=dim) - lstd.insert( - index, + lstd.append( TensorDict( {"a": torch.ones(4), "invalid": torch.rand(4)}, [4], device=device - ), + ) ) bs = [4] @@ -5296,23 +5389,23 @@ def test_lazy_stacked_insert(self, dim, index, device): t = torch.zeros(*bs, device=device) if dim == 0: - t[index] = 1 + t[-1] = 1 else: - t[:, index] = 1 + t[:, -1] = 1 torch.testing.assert_close(lstd["a"], t) with pytest.raises( TypeError, match="Expected new value to be TensorDictBase instance" ): - lstd.insert(index, torch.rand(10)) + lstd.append(torch.rand(10)) if device != torch.device("cpu"): with pytest.raises(ValueError, match="Devices differ"): - lstd.insert(index, TensorDict({"a": torch.ones(4)}, [4], device="cpu")) + lstd.append(TensorDict({"a": torch.ones(4)}, [4], device="cpu")) with pytest.raises(ValueError, match="Batch sizes in tensordicts differs"): - lstd.insert(index, TensorDict({"a": torch.ones(17)}, [17], device=device)) + lstd.append(TensorDict({"a": torch.ones(17)}, [17], device=device)) def test_lazy_stacked_contains(self): td = TensorDict( @@ -5330,15 +5423,17 @@ def test_lazy_stacked_contains(self): "random_string" in lstd # noqa: B015 @pytest.mark.parametrize("dim", range(2)) + @pytest.mark.parametrize("index", range(2)) @pytest.mark.parametrize("device", 
get_available_devices()) - def test_lazy_stacked_append(self, dim, device): + def test_lazy_stacked_insert(self, dim, index, device): td = TensorDict({"a": torch.zeros(4)}, [4], device=device) lstd = torch.stack([td] * 2, dim=dim) - lstd.append( + lstd.insert( + index, TensorDict( {"a": torch.ones(4), "invalid": torch.rand(4)}, [4], device=device - ) + ), ) bs = [4] @@ -5350,35 +5445,171 @@ def test_lazy_stacked_append(self, dim, device): t = torch.zeros(*bs, device=device) if dim == 0: - t[-1] = 1 + t[index] = 1 else: - t[:, -1] = 1 + t[:, index] = 1 torch.testing.assert_close(lstd["a"], t) with pytest.raises( TypeError, match="Expected new value to be TensorDictBase instance" ): - lstd.append(torch.rand(10)) + lstd.insert(index, torch.rand(10)) if device != torch.device("cpu"): with pytest.raises(ValueError, match="Devices differ"): - lstd.append(TensorDict({"a": torch.ones(4)}, [4], device="cpu")) + lstd.insert(index, TensorDict({"a": torch.ones(4)}, [4], device="cpu")) + + with pytest.raises(ValueError, match="Batch sizes in tensordicts differs"): + lstd.insert(index, TensorDict({"a": torch.ones(17)}, [17], device=device)) + + @pytest.mark.parametrize("batch_size", [(), (2,), (1, 2)]) + @pytest.mark.parametrize("stack_dim", [0, 1, 2]) + def test_setitem_hetero(self, batch_size, stack_dim): + obs = self.nested_lazy_het_td(batch_size) + obs1 = obs.clone() + obs1.apply_(lambda x: x + 1) + + if stack_dim > len(batch_size): + return + + res1 = self.dense_stack_tds_v1([obs, obs1], stack_dim=stack_dim) + res2 = self.dense_stack_tds_v2([obs, obs1], stack_dim=stack_dim) + + index = (slice(None),) * stack_dim + (0,) # get the first in the stack + assert self.recursively_check_key(res1[index], 0) # check all 0 + assert self.recursively_check_key(res2[index], 0) # check all 0 + index = (slice(None),) * stack_dim + (1,) # get the second in the stack + assert self.recursively_check_key(res1[index], 1) # check all 1 + assert self.recursively_check_key(res2[index], 1) # check all 1 + + @pytest.mark.parametrize("device", get_available_devices()) + def test_stack(self, device): + torch.manual_seed(1) + tds_list = [TensorDict(source={}, batch_size=(4, 5)) for _ in range(3)] + tds = stack_td(tds_list, 0, contiguous=False) + assert tds[0] is tds_list[0] + + td = TensorDict( + source={"a": torch.randn(4, 5, 3, device=device)}, batch_size=(4, 5) + ) + td_list = list(td) + td_reconstruct = stack_td(td_list, 0) + assert td_reconstruct.batch_size == td.batch_size + assert (td_reconstruct == td).all() + + def test_stack_apply(self): + td0 = TensorDict( + { + ("a", "b", "c"): torch.ones(3, 4), + ("a", "b", "d"): torch.ones(3, 4), + "common": torch.ones(3), + }, + [3], + ) + td1 = TensorDict( + { + ("a", "b", "c"): torch.ones(3, 5) * 2, + "common": torch.ones(3) * 2, + }, + [3], + ) + td = TensorDict({"parent": torch.stack([td0, td1], 0)}, [2]) + td2 = td.clone() + tdapply = td.apply(lambda x, y: x + y, td2) + assert isinstance(tdapply["parent", "a", "b"], LazyStackedTensorDict) + assert (tdapply["parent", "a", "b"][0]["c"] == 2).all() + assert (tdapply["parent", "a", "b"][1]["c"] == 4).all() + assert (tdapply["parent", "a", "b"][0]["d"] == 2).all() + + @pytest.mark.parametrize("batch_size", [(), (32,), (32, 4)]) + def test_stack_hetero(self, batch_size): + obs = self.nested_lazy_het_td(batch_size) + + obs2 = obs.clone() + obs2.apply_(lambda x: x + 1) + + obs_stack = torch.stack([obs, obs2]) + obs_stack_resolved = self.dense_stack_tds_v2([obs, obs2], stack_dim=0) + + assert isinstance(obs_stack, 
LazyStackedTensorDict) and obs_stack.stack_dim == 0
+        assert isinstance(obs_stack_resolved, TensorDict)
+
+        assert obs_stack.batch_size == (2, *batch_size)
+        assert obs_stack_resolved.batch_size == obs_stack.batch_size
+
+        assert obs_stack["lazy"].shape == (2, *batch_size, 3)
+        assert obs_stack_resolved["lazy"].batch_size == obs_stack["lazy"].batch_size
+
+        assert obs_stack["lazy"].stack_dim == 0
+        assert (
+            obs_stack_resolved["lazy"].stack_dim
+            == len(obs_stack_resolved["lazy"].batch_size) - 1
+        )
+        for stack in [obs_stack_resolved, obs_stack]:
+            for index in range(2):
+                assert (stack[index]["dense"] == index).all()
+                assert (stack["dense"][index] == index).all()
+                assert (stack["lazy"][index]["shared"] == index).all()
+                assert (stack[index]["lazy"]["shared"] == index).all()
+                assert (stack["lazy"]["shared"][index] == index).all()
+                assert (
+                    stack["lazy"][index][..., 0]["individual_0_tensor"] == index
+                ).all()
+                assert (
+                    stack[index]["lazy"][..., 0]["individual_0_tensor"] == index
+                ).all()
+                assert (
+                    stack["lazy"][..., 0]["individual_0_tensor"][index] == index
+                ).all()
+                assert (
+                    stack["lazy"][..., 0][index]["individual_0_tensor"] == index
+                ).all()
+
+    def test_stack_keys(self):
+        td1 = TensorDict(source={"a": torch.randn(3)}, batch_size=[])
+        td2 = TensorDict(
+            source={
+                "a": torch.randn(3),
+                "b": torch.randn(3),
+                "c": torch.randn(4),
+                "d": torch.randn(5),
+            },
+            batch_size=[],
+        )
+        td = stack_td([td1, td2], 0)
+        assert "a" in td.keys()
+        assert "b" not in td.keys()
+        assert "b" in td[1].keys()
+        td.set("b", torch.randn(2, 10), inplace=False)  # overwrites
+        with pytest.raises(KeyError):
+            td.set_("c", torch.randn(2, 10))  # "c" is not in all tds
+        td.set_("b", torch.randn(2, 10))  # b has been set before
+
+        td1.set("c", torch.randn(4))
+        td[
+            "c"
+        ]  # we must first query that key for the stacked tensordict to update the list
+        assert "c" in td.keys(), list(td.keys())  # now all tds have the key c
+        td.get("c")
-        with pytest.raises(ValueError, match="Batch sizes in tensordicts differs"):
-            lstd.append(TensorDict({"a": torch.ones(17)}, [17], device=device))
+
+        td1.set("d", torch.randn(6))
+        with pytest.raises(RuntimeError):
+            td.get("d")
-    def test_unbind_lazystack(self):
-        td0 = TensorDict(
-            {
-                "a": {"b": torch.randn(3, 4), "d": torch.randn(3, 4)},
-                "c": torch.randn(3, 4),
-            },
-            [3, 4],
-        )
-        td = torch.stack([td0, td0, td0], 1)
+
+        td["e"] = torch.randn(2, 4)
+        assert "e" in td.keys()  # now all tds have the key e
+        td.get("e")
-        assert all(_td is td0 for _td in td.unbind(1))
+
+    @pytest.mark.parametrize("unsqueeze_dim", [0, 1, -1, -2])
+    def test_stack_unsqueeze(self, unsqueeze_dim):
+        td = TensorDict({("a", "b"): torch.ones(3, 4, 5)}, [3, 4])
+        td_stack = torch.stack(td.unbind(1), 1)
+        td_unsqueeze = td.unsqueeze(unsqueeze_dim)
+        td_stack_unsqueeze = td_stack.unsqueeze(unsqueeze_dim)
+        assert isinstance(td_stack_unsqueeze, LazyStackedTensorDict)
+        for key in td_unsqueeze.keys(True, True):
+            assert td_unsqueeze.get(key).shape == td_stack_unsqueeze.get(key).shape

     @pytest.mark.parametrize("stack_dim", [0, 1, -1])
     def test_stack_update_heter_stacked_td(self, stack_dim):
@@ -5399,136 +5630,159 @@ def test_stack_update_heter_stacked_td(self, stack_dim):
         ):
             td_a.update_(td_b.to_tensordict())

-    @property
-    def _tensor_index(self):
-        torch.manual_seed(0)
-        return torch.randint(2, (5, 2))
+    @pytest.mark.parametrize("device", get_available_devices())
+    @pytest.mark.parametrize("stack_dim", [0, 1, 2])
+    def test_stacked_indexing(self, device, stack_dim):
+        tensordict = TensorDict(
+            {"a": 
torch.randn(3, 4, 5), "b": torch.randn(3, 4, 5)}, + batch_size=[3, 4, 5], + device=device, + ) - @property - def _idx_list(self): - return { - 0: 1, - 1: slice(None), - 2: slice(1, 2), - 3: self._tensor_index, - 4: range(1, 2), - 5: None, - 6: [0, 1], - 7: self._tensor_index.numpy(), - } + tds = torch.stack(list(tensordict.unbind(stack_dim)), stack_dim) - @pytest.mark.parametrize("pos1", range(8)) - @pytest.mark.parametrize("pos2", range(8)) - @pytest.mark.parametrize("pos3", range(8)) - def test_lazy_indexing(self, pos1, pos2, pos3): - torch.manual_seed(0) - td_leaf_1 = TensorDict({"a": torch.ones(2, 3)}, []) - inner = torch.stack([td_leaf_1] * 4, 0) - middle = torch.stack([inner] * 3, 0) - outer = torch.stack([middle] * 2, 0) - outer_dense = outer.to_tensordict() - ref_tensor = torch.zeros(2, 3, 4) - pos1 = self._idx_list[pos1] - pos2 = self._idx_list[pos2] - pos3 = self._idx_list[pos3] - index = (pos1, pos2, pos3) - result = outer[index] - assert result.batch_size == ref_tensor[index].shape, index - assert result.batch_size == outer_dense[index].shape, index + for item, expected_shape in ( + ((2, 2), torch.Size([5])), + ((slice(1, 2), 2), torch.Size([1, 5])), + ((..., 2), torch.Size([3, 4])), + ): + assert tds[item].batch_size == expected_shape + assert (tds[item].get("a") == tds.get("a")[item]).all() + assert (tds[item].get("a") == tensordict[item].get("a")).all() - @pytest.mark.parametrize("stack_dim", [0, 1, 2]) - @pytest.mark.parametrize("mask_dim", [0, 1, 2]) - @pytest.mark.parametrize("single_mask_dim", [True, False]) @pytest.mark.parametrize("device", get_available_devices()) - def test_lazy_mask_indexing(self, stack_dim, mask_dim, single_mask_dim, device): - torch.manual_seed(0) - td = TensorDict({"a": torch.zeros(9, 10, 11)}, [9, 10, 11], device=device) + @pytest.mark.parametrize("stack_dim", [0, 1]) + def test_stacked_td(self, stack_dim, device): + tensordicts = [ + TensorDict( + batch_size=[11, 12], + source={ + "key1": torch.randn(11, 12, 5, device=device), + "key2": torch.zeros( + 11, 12, 50, device=device, dtype=torch.bool + ).bernoulli_(), + }, + ) + for _ in range(10) + ] + + tensordicts0 = tensordicts[0] + tensordicts1 = tensordicts[1] + tensordicts2 = tensordicts[2] + tensordicts3 = tensordicts[3] + sub_td = LazyStackedTensorDict(*tensordicts, stack_dim=stack_dim) + + std_bis = stack_td(tensordicts, dim=stack_dim, contiguous=False) + assert (sub_td == std_bis).all() + + item = (*[slice(None) for _ in range(stack_dim)], 0) + tensordicts0.zero_() + assert (sub_td[item].get("key1") == sub_td.get("key1")[item]).all() + assert ( + sub_td.contiguous()[item].get("key1") + == sub_td.contiguous().get("key1")[item] + ).all() + assert (sub_td.contiguous().get("key1")[item] == 0).all() + + item = (*[slice(None) for _ in range(stack_dim)], 1) + std2 = sub_td[:5] + tensordicts1.zero_() + assert (std2[item].get("key1") == std2.get("key1")[item]).all() + assert ( + std2.contiguous()[item].get("key1") == std2.contiguous().get("key1")[item] + ).all() + assert (std2.contiguous().get("key1")[item] == 0).all() + + std3 = sub_td[:5, :, :5] + tensordicts2.zero_() + item = (*[slice(None) for _ in range(stack_dim)], 2) + assert (std3[item].get("key1") == std3.get("key1")[item]).all() + assert ( + std3.contiguous()[item].get("key1") == std3.contiguous().get("key1")[item] + ).all() + assert (std3.contiguous().get("key1")[item] == 0).all() + + std4 = sub_td.select("key1") + tensordicts3.zero_() + item = (*[slice(None) for _ in range(stack_dim)], 3) + assert (std4[item].get("key1") == 
std4.get("key1")[item]).all() + assert ( + std4.contiguous()[item].get("key1") == std4.contiguous().get("key1")[item] + ).all() + assert (std4.contiguous().get("key1")[item] == 0).all() + + std5 = sub_td.unbind(1)[0] + assert (std5.contiguous() == sub_td.contiguous().unbind(1)[0]).all() + + def test_stacked_td_nested_keys(self): td = torch.stack( [ - td, - td.apply(lambda x: x + 1), - td.apply(lambda x: x + 2), - td.apply(lambda x: x + 3), + TensorDict({"a": {"b": {"d": [1]}, "c": [2]}}, []), + TensorDict({"a": {"b": {"d": [1]}, "d": [2]}}, []), ], - stack_dim, + 0, ) - mask = torch.zeros(()) - while not mask.any(): - if single_mask_dim: - mask = torch.zeros(td.shape[mask_dim], dtype=torch.bool).bernoulli_() - else: - mask = torch.zeros( - td.shape[mask_dim : mask_dim + 2], dtype=torch.bool - ).bernoulli_() - index = (slice(None),) * mask_dim + (mask,) - tdmask = td[index] - assert tdmask["a"].shape == td["a"][index].shape - assert (tdmask["a"] == td["a"][index]).all() - index = (0,) * mask_dim + (mask,) - tdmask = td[index] - assert tdmask["a"].shape == td["a"][index].shape - assert (tdmask["a"] == td["a"][index]).all() - index = (slice(1),) * mask_dim + (mask,) - tdmask = td[index] - assert tdmask["a"].shape == td["a"][index].shape - assert (tdmask["a"] == td["a"][index]).all() + assert ("a", "b") in td.keys(True) + assert ("a", "c") not in td.keys(True) + assert ("a", "b", "d") in td.keys(True) + td["a", "c"] = [[2], [3]] + assert ("a", "c") in td.keys(True) - @pytest.mark.parametrize("stack_dim", [0, 1, 2]) - @pytest.mark.parametrize("mask_dim", [0, 1, 2]) - @pytest.mark.parametrize("single_mask_dim", [True, False]) - @pytest.mark.parametrize("device", get_available_devices()) - def test_lazy_mask_setitem(self, stack_dim, mask_dim, single_mask_dim, device): - torch.manual_seed(0) - td = TensorDict({"a": torch.zeros(9, 10, 11)}, [9, 10, 11], device=device) - td = torch.stack( - [ - td, - td.apply(lambda x: x + 1), - td.apply(lambda x: x + 2), - td.apply(lambda x: x + 3), - ], - stack_dim, + keys, items = zip(*td.items(True)) + assert ("a", "b") in keys + assert ("a", "c") in keys + assert ("a", "d") not in keys + + td["a", "c"] = td["a", "c"] + 1 + assert (td["a", "c"] == torch.tensor([[3], [4]], device=td.device)).all() + + def test_unbind_lazystack(self): + td0 = TensorDict( + { + "a": {"b": torch.randn(3, 4), "d": torch.randn(3, 4)}, + "c": torch.randn(3, 4), + }, + [3, 4], ) - mask = torch.zeros(()) - while not mask.any(): - if single_mask_dim: - mask = torch.zeros(td.shape[mask_dim], dtype=torch.bool).bernoulli_() - else: - mask = torch.zeros( - td.shape[mask_dim : mask_dim + 2], dtype=torch.bool - ).bernoulli_() - index = (slice(None),) * mask_dim + (mask,) - tdset = TensorDict({"a": td["a"][index] * 0 - 1}, []) - # we know that the batch size is accurate from test_lazy_mask_indexing - tdset.batch_size = td[index].batch_size - td[index] = tdset - assert (td["a"][index] == tdset["a"]).all() - assert (td["a"][index] == tdset["a"]).all() - index = (slice(1),) * mask_dim + (mask,) - tdset = TensorDict({"a": td["a"][index] * 0 - 1}, []) - tdset.batch_size = td[index].batch_size - td[index] = tdset - assert (td["a"][index] == tdset["a"]).all() - assert (td["a"][index] == tdset["a"]).all() + td = torch.stack([td0, td0, td0], 1) - def test_all_keys(self): - td = TensorDict({"a": torch.zeros(1)}, []) - td2 = TensorDict({"a": torch.zeros(2)}, []) - stack = torch.stack([td, td2]) - assert set(stack.keys(True, True)) == {"a"} + assert all(_td is td0 for _td in td.unbind(1)) + + def 
test_update_with_lazy(self): + td0 = TensorDict( + { + ("a", "b", "c"): torch.ones(3, 4), + ("a", "b", "d"): torch.ones(3, 4), + "common": torch.ones(3), + }, + [3], + ) + td1 = TensorDict( + { + ("a", "b", "c"): torch.ones(3, 5) * 2, + "common": torch.ones(3) * 2, + }, + [3], + ) + td = TensorDict({"parent": torch.stack([td0, td1], 0)}, [2]) - def test_best_intention_stack(self): - td0 = TensorDict({"a": 1, "b": TensorDict({"c": 2}, [])}, []) - td1 = TensorDict({"a": 1, "b": TensorDict({"d": 2}, [])}, []) - with set_lazy_legacy(False): - td = torch.stack([td0, td1]) - assert isinstance(td, TensorDict) - assert isinstance(td.get("b"), LazyStackedTensorDict) - td1 = TensorDict({"a": 1, "b": TensorDict({"c": [2]}, [])}, []) - with set_lazy_legacy(False): - td = torch.stack([td0, td1]) - assert isinstance(td, TensorDict) - assert isinstance(td.get("b"), LazyStackedTensorDict) + td_void = TensorDict( + { + ("parent", "a", "b", "c"): torch.zeros(2, 3, 4), + ("parent", "a", "b", "e"): torch.zeros(2, 3, 4), + ("parent", "a", "b", "d"): torch.zeros(2, 3, 5), + }, + [2], + ) + td_void.update(td) + assert type(td_void.get("parent")) is LazyStackedTensorDict + assert type(td_void.get(("parent", "a"))) is LazyStackedTensorDict + assert type(td_void.get(("parent", "a", "b"))) is LazyStackedTensorDict + assert (td_void.get(("parent", "a", "b"))[0].get("c") == 1).all() + assert (td_void.get(("parent", "a", "b"))[1].get("c") == 2).all() + assert (td_void.get(("parent", "a", "b"))[0].get("d") == 1).all() + assert (td_void.get(("parent", "a", "b"))[1].get("d") == 0).all() # unaffected + assert (td_void.get(("parent", "a", "b")).get("e") == 0).all() # unaffected @pytest.mark.skipif( @@ -5594,88 +5848,6 @@ def test_update( assert (td_plain == tensordict2).all() -@pytest.mark.parametrize("device", get_available_devices()) -def test_memmap_as_tensor(device): - td = TensorDict( - {"a": torch.randn(3, 4), "b": {"c": torch.randn(3, 4)}}, [3, 4], device="cpu" - ) - td_memmap = td.clone().memmap_() - assert (td == td_memmap).all() - - assert (td == td_memmap.apply(lambda x: x.clone())).all() - if device.type == "cuda": - td = td.pin_memory() - td_memmap = td.clone().memmap_() - td_memmap_pm = td_memmap.apply(lambda x: x.clone()).pin_memory() - assert (td.pin_memory().to(device) == td_memmap_pm.to(device)).all() - - -def test_tensordict_prealloc_nested(): - N = 3 - B = 5 - T = 4 - buffer = TensorDict({}, batch_size=[B, N]) - - td_0 = TensorDict( - { - "env.time": torch.rand(N, 1), - "agent.obs": TensorDict( - { # assuming 3 agents in a multi-agent setting - "image": torch.rand(N, T, 64), - "state": torch.rand(N, T, 3, 32, 32), - }, - batch_size=[N, T], - ), - }, - batch_size=[N], - ) - - td_1 = td_0.clone() - buffer[0] = td_0 - buffer[1] = td_1 - assert ( - repr(buffer) - == """TensorDict( - fields={ - agent.obs: TensorDict( - fields={ - image: Tensor(shape=torch.Size([5, 3, 4, 64]), device=cpu, dtype=torch.float32, is_shared=False), - state: Tensor(shape=torch.Size([5, 3, 4, 3, 32, 32]), device=cpu, dtype=torch.float32, is_shared=False)}, - batch_size=torch.Size([5, 3, 4]), - device=None, - is_shared=False), - env.time: Tensor(shape=torch.Size([5, 3, 1]), device=cpu, dtype=torch.float32, is_shared=False)}, - batch_size=torch.Size([5, 3]), - device=None, - is_shared=False)""" - ) - assert buffer.batch_size == torch.Size([B, N]) - assert buffer["agent.obs"].batch_size == torch.Size([B, N, T]) - - -@pytest.mark.parametrize("like", [True, False]) -def test_save_load_memmap_stacked_td( - like, - tmpdir, -): - a = 
TensorDict({"a": [1]}, []) - b = TensorDict({"b": [1]}, []) - c = torch.stack([a, b]) - c = c.expand(10, 2) - if like: - d = c.memmap_like(prefix=tmpdir) - else: - d = c.memmap_(prefix=tmpdir) - - d2 = LazyStackedTensorDict.load_memmap(tmpdir) - assert (d2 == d).all() - assert (d2[:, 0] == d[:, 0]).all() - if like: - assert (d2[:, 0] == a.zero_()).all() - else: - assert (d2[:, 0] == a).all() - - class TestErrorMessage: @staticmethod def test_err_msg_missing_nested(): @@ -5690,84 +5862,21 @@ def test_inplace_error(): td.set_("a", torch.randn(2)) -@pytest.mark.parametrize("batch_first", [True, False]) -@pytest.mark.parametrize("make_mask", [True, False]) -def test_pad_sequence(batch_first, make_mask): - list_td = [ - TensorDict({"a": torch.ones((2,)), ("b", "c"): torch.ones((2, 3))}, [2]), - TensorDict({"a": torch.ones((4,)), ("b", "c"): torch.ones((4, 3))}, [4]), - ] - padded_td = pad_sequence(list_td, batch_first=batch_first, return_mask=make_mask) - if batch_first: - assert padded_td.shape == torch.Size([2, 4]) - assert padded_td["a"].shape == torch.Size([2, 4]) - assert padded_td["a"][0, -1] == 0 - assert padded_td["b", "c"].shape == torch.Size([2, 4, 3]) - assert padded_td["b", "c"][0, -1, 0] == 0 - else: - assert padded_td.shape == torch.Size([4, 2]) - assert padded_td["a"].shape == torch.Size([4, 2]) - assert padded_td["a"][-1, 0] == 0 - assert padded_td["b", "c"].shape == torch.Size([4, 2, 3]) - assert padded_td["b", "c"][-1, 0, 0] == 0 - if make_mask: - assert "mask" in padded_td.keys() - assert not padded_td["mask"].all() - else: - assert "mask" not in padded_td.keys() - - class TestNamedDims(TestTensorDictsBase): - def test_noname(self): - td = TensorDict({}, batch_size=[3, 4, 5, 6], names=None) - assert td.names == [None] * 4 - - def test_fullname(self): - td = TensorDict({}, batch_size=[3, 4, 5, 6], names=["a", "b", "c", "d"]) - assert td.names == ["a", "b", "c", "d"] - - def test_partial_name(self): - td = TensorDict({}, batch_size=[3, 4, 5, 6], names=["a", None, None, "d"]) - assert td.names == ["a", None, None, "d"] - - def test_partial_set(self): - td = TensorDict({}, batch_size=[3, 4, 5, 6], names=None) - td.names = ["a", None, None, "d"] - assert td.names == ["a", None, None, "d"] - td.names = ["a", "b", "c", "d"] - assert td.names == ["a", "b", "c", "d"] - with pytest.raises( - ValueError, - match="the length of the dimension names must equate the tensordict batch_dims", - ): - td.names = ["a", "b", "c"] - - def test_rename(self): - td = TensorDict({}, batch_size=[3, 4, 5, 6], names=None) - td.names = ["a", None, None, "d"] - td.rename_(a="c") - assert td.names == ["c", None, None, "d"] - td.rename_(d="z") - assert td.names == ["c", None, None, "z"] - td.rename_(*list("mnop")) - assert td.names == ["m", "n", "o", "p"] - td2 = td.rename(p="q") - assert td.names == ["m", "n", "o", "p"] - assert td2.names == ["m", "n", "o", "q"] - td2 = td.rename(*list("wxyz")) - assert td.names == ["m", "n", "o", "p"] - assert td2.names == ["w", "x", "y", "z"] + def test_all(self): + td = TensorDict({}, batch_size=[3, 4, 1, 6], names=["a", "b", "c", "d"]) + tda = td.all(2) + assert tda.names == ["a", "b", "d"] + tda = td.any(2) + assert tda.names == ["a", "b", "d"] - def test_stack(self): - td = TensorDict({}, batch_size=[3, 4, 5, 6], names=["a", "b", "c", "d"]) - tds = torch.stack([td, td], 0) - assert tds.names == [None, "a", "b", "c", "d"] - tds = torch.stack([td, td], -1) - assert tds.names == ["a", "b", "c", "d", None] - tds = torch.stack([td, td], 2) - tds.names = list("mnopq") - 
assert tds.names == list("mnopq") - assert td.names == ["m", "n", "p", "q"] + def test_apply(self): + td = TensorDict({}, batch_size=[3, 4, 1, 6], names=["a", "b", "c", "d"]) + tda = td.apply(lambda x: x + 1) + assert tda.names == ["a", "b", "c", "d"] + tda = td.apply(lambda x: x.squeeze(2), batch_size=[3, 4, 6]) + # no way to tell what the names have become, in general + assert tda.names == [None] * 3 def test_cat(self): td = TensorDict({}, batch_size=[3, 4, 5, 6], names=None) @@ -5777,25 +5886,17 @@ def test_cat(self): tdc = torch.cat([td, td], -1) assert tdc.names == ["a", "b", "c", "d"] - def test_unsqueeze(self): - td = TensorDict({}, batch_size=[3, 4, 5, 6], names=None) - td.names = ["a", "b", "c", "d"] - tdu = td.unsqueeze(0) - assert tdu.names == [None, "a", "b", "c", "d"] - tdu = td.unsqueeze(-1) - assert tdu.names == ["a", "b", "c", "d", None] - tdu = td.unsqueeze(2) - assert tdu.names == ["a", "b", None, "c", "d"] - - def test_squeeze(self): - td = TensorDict({}, batch_size=[3, 4, 5, 6], names=None) - td.names = ["a", "b", "c", "d"] - tds = td.squeeze(0) - assert tds.names == ["a", "b", "c", "d"] - td = TensorDict({}, batch_size=[3, 1, 5, 6], names=None) - td.names = ["a", "b", "c", "d"] - tds = td.squeeze(1) - assert tds.names == ["a", "c", "d"] + def test_change_batch_size(self): + td = TensorDict({}, batch_size=[3, 4, 1, 6], names=["a", "b", "c", "z"]) + td.batch_size = [3, 4, 1, 6, 1] + assert td.names == ["a", "b", "c", "z", None] + td.batch_size = [] + assert td.names == [] + td.batch_size = [3, 4] + assert td.names == [None, None] + td.names = ["a", None] + td.batch_size = [3] + assert td.names == ["a"] def test_clone(self): td = TensorDict({}, batch_size=[3, 4, 5, 6], names=None) @@ -5805,99 +5906,36 @@ def test_clone(self): tdc = td.clone(False) assert tdc.names == ["a", "b", "c", "d"] - def test_permute(self): - td = TensorDict({}, batch_size=[3, 4, 5, 6], names=None) - td.names = ["a", "b", "c", "d"] - tdp = td.permute(-1, -2, -3, -4) - assert tdp.names == list("dcba") - tdp = td.permute(-1, 1, 2, -4) - assert tdp.names == list("dbca") - - def test_refine_names(self): - td = TensorDict({}, batch_size=[3, 4, 5, 6]) - tdr = td.refine_names(None, None, None, "d") - assert tdr.names == [None, None, None, "d"] - tdr = tdr.refine_names(None, None, "c", "d") - assert tdr.names == [None, None, "c", "d"] - with pytest.raises( - RuntimeError, match="refine_names: cannot coerce TensorDict" - ): - tdr.refine_names(None, None, "d", "d") - tdr = td.refine_names(..., "d") - assert tdr.names == [None, None, "c", "d"] - tdr = td.refine_names("a", ..., "d") - assert tdr.names == ["a", None, "c", "d"] - - def test_index(self): - td = TensorDict({}, batch_size=[3, 4, 5, 6], names=["a", "b", "c", "d"]) - assert td[0].names == ["b", "c", "d"] - assert td[:, 0].names == ["a", "c", "d"] - assert td[0, :].names == ["b", "c", "d"] - assert td[0, :1].names == ["b", "c", "d"] - assert td[..., -1].names == ["a", "b", "c"] - assert td[0, ..., -1].names == ["b", "c"] - assert td[0, ..., [-1]].names == ["b", "c", "d"] - assert td[0, ..., torch.tensor([-1])].names == ["b", "c", "d"] - assert td[0, ..., torch.tensor(-1)].names == ["b", "c"] - assert td[0, ..., :-1].names == ["b", "c", "d"] - assert td[:1, ..., :-1].names == ["a", "b", "c", "d"] - tdbool = td[torch.ones(3, dtype=torch.bool)] - assert tdbool.names == [None, "b", "c", "d"] - assert tdbool.ndim == 4 - tdbool = td[torch.ones(3, 4, dtype=torch.bool)] - assert tdbool.names == [None, "c", "d"] - assert tdbool.ndim == 3 + def 
test_detach(self): + td = TensorDict({}, batch_size=[3, 4, 1, 6], names=["a", "b", "c", "d"]) + td[""] = torch.zeros(td.shape, requires_grad=True) + tdd = td.detach() + assert tdd.names == ["a", "b", "c", "d"] - def test_subtd(self): - td = TensorDict({}, batch_size=[3, 4, 5, 6], names=["a", "b", "c", "d"]) - assert td._get_sub_tensordict(0).names == ["b", "c", "d"] - assert td._get_sub_tensordict((slice(None), 0)).names == ["a", "c", "d"] - assert td._get_sub_tensordict((0, slice(None))).names == ["b", "c", "d"] - assert td._get_sub_tensordict((0, slice(None, 1))).names == ["b", "c", "d"] - assert td._get_sub_tensordict((..., -1)).names == ["a", "b", "c"] - assert td._get_sub_tensordict((0, ..., -1)).names == ["b", "c"] - assert td._get_sub_tensordict((0, ..., [-1])).names == ["b", "c", "d"] - assert td._get_sub_tensordict((0, ..., torch.tensor([-1]))).names == [ - "b", - "c", - "d", - ] - assert td._get_sub_tensordict((0, ..., torch.tensor(-1))).names == ["b", "c"] - assert td._get_sub_tensordict((0, ..., slice(None, -1))).names == [ - "b", - "c", - "d", - ] - assert td._get_sub_tensordict((slice(None, 1), ..., slice(None, -1))).names == [ - "a", - "b", - "c", - "d", - ] - tdbool = td._get_sub_tensordict(torch.ones(3, dtype=torch.bool)) - assert tdbool.names == [None, "b", "c", "d"] - assert tdbool.ndim == 4 - tdbool = td._get_sub_tensordict(torch.ones(3, 4, dtype=torch.bool)) - assert tdbool.names == [None, "c", "d"] - assert tdbool.ndim == 3 - with pytest.raises( - RuntimeError, match="Names of a subtensordict cannot be modified" - ): - tdbool.names = "All work and no play makes Jack a dull boy" + def test_error_similar(self): + with pytest.raises(ValueError): + td = TensorDict({}, batch_size=[3, 4, 1, 6], names=["a", "b", "c", "a"]) + with pytest.raises(ValueError): + td = TensorDict( + {}, + batch_size=[3, 4, 1, 6], + ) + td.names = ["a", "b", "c", "a"] + with pytest.raises(ValueError): + td = TensorDict( + {}, + batch_size=[3, 4, 1, 6], + ) + td.refine_names("a", "a", ...) 
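+        # renaming a dim onto a name that another dim already uses must fail as well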
+ with pytest.raises(ValueError): + td = TensorDict({}, batch_size=[3, 4, 1, 6], names=["a", "b", "c", "z"]) + td.rename_(a="z") def test_expand(self): td = TensorDict({}, batch_size=[3, 4, 1, 6], names=["a", "b", "c", "d"]) tde = td.expand(2, 3, 4, 5, 6) assert tde.names == [None, "a", "b", "c", "d"] - def test_apply(self): - td = TensorDict({}, batch_size=[3, 4, 1, 6], names=["a", "b", "c", "d"]) - tda = td.apply(lambda x: x + 1) - assert tda.names == ["a", "b", "c", "d"] - tda = td.apply(lambda x: x.squeeze(2), batch_size=[3, 4, 6]) - # no way to tell what the names have become, in general - assert tda.names == [None] * 3 - def test_flatten(self): td = TensorDict({}, batch_size=[3, 4, 1, 6], names=["a", "b", "c", "d"]) tdf = td.flatten(1, 3) @@ -5913,49 +5951,52 @@ def test_flatten(self): tdu = tdf.unflatten(0, (3, 4, 1)) assert tdu.names == [None, None, None, "d"] + def test_fullname(self): + td = TensorDict({}, batch_size=[3, 4, 5, 6], names=["a", "b", "c", "d"]) + assert td.names == ["a", "b", "c", "d"] + def test_gather(self): td = TensorDict({}, batch_size=[3, 4, 1, 6], names=["a", "b", "c", "d"]) idx = torch.randint(6, (3, 4, 1, 18)) tdg = td.gather(dim=-1, index=idx) assert tdg.names == ["a", "b", "c", "d"] - def test_select(self): - td = TensorDict({}, batch_size=[3, 4, 1, 6], names=["a", "b", "c", "d"]) - tds = td.select() - assert tds.names == ["a", "b", "c", "d"] - tde = td.exclude() - assert tde.names == ["a", "b", "c", "d"] - td[""] = torch.zeros(td.shape) - td["*"] = torch.zeros(td.shape) - tds = td.select("") - assert tds.names == ["a", "b", "c", "d"] - - def test_detach(self): - td = TensorDict({}, batch_size=[3, 4, 1, 6], names=["a", "b", "c", "d"]) - td[""] = torch.zeros(td.shape, requires_grad=True) - tdd = td.detach() - assert tdd.names == ["a", "b", "c", "d"] - - def test_unbind(self): - td = TensorDict({}, batch_size=[3, 4, 1, 6], names=["a", "b", "c", "d"]) - *_, tdu = td.unbind(-1) - assert tdu.names == ["a", "b", "c"] - *_, tdu = td.unbind(-2) - assert tdu.names == ["a", "b", "d"] + def test_h5(self, tmpdir): + td = TensorDict( + {"a": torch.zeros(3, 4, 1, 6)}, + batch_size=[3, 4, 1, 6], + names=["a", "b", "c", "d"], + ) + tdm = td.to_h5(filename=tmpdir / "file.h5") + assert tdm.names == ["a", "b", "c", "d"] - def test_split(self): - td = TensorDict({}, batch_size=[3, 4, 1, 6], names=["a", "b", "c", "d"]) - _, tdu = td.split(dim=-1, split_size=[3, 3]) - assert tdu.names == ["a", "b", "c", "d"] - _, tdu = td.split(dim=1, split_size=[1, 3]) - assert tdu.names == ["a", "b", "c", "d"] + def test_h5_td(self): + td = self.td_h5("cpu") + td.names = list("abcd") + assert td.rename(c="g").names == list("abgd") + assert td.names == list("abcd") + td.rename_(c="g") + assert td.names == list("abgd") - def test_all(self): - td = TensorDict({}, batch_size=[3, 4, 1, 6], names=["a", "b", "c", "d"]) - tda = td.all(2) - assert tda.names == ["a", "b", "d"] - tda = td.any(2) - assert tda.names == ["a", "b", "d"] + def test_index(self): + td = TensorDict({}, batch_size=[3, 4, 5, 6], names=["a", "b", "c", "d"]) + assert td[0].names == ["b", "c", "d"] + assert td[:, 0].names == ["a", "c", "d"] + assert td[0, :].names == ["b", "c", "d"] + assert td[0, :1].names == ["b", "c", "d"] + assert td[..., -1].names == ["a", "b", "c"] + assert td[0, ..., -1].names == ["b", "c"] + assert td[0, ..., [-1]].names == ["b", "c", "d"] + assert td[0, ..., torch.tensor([-1])].names == ["b", "c", "d"] + assert td[0, ..., torch.tensor(-1)].names == ["b", "c"] + assert td[0, ..., :-1].names == ["b", "c", 
"d"] + assert td[:1, ..., :-1].names == ["a", "b", "c", "d"] + tdbool = td[torch.ones(3, dtype=torch.bool)] + assert tdbool.names == [None, "b", "c", "d"] + assert tdbool.ndim == 4 + tdbool = td[torch.ones(3, 4, dtype=torch.bool)] + assert tdbool.names == [None, "c", "d"] + assert tdbool.ndim == 3 def test_masked_fill(self): td = TensorDict({}, batch_size=[3, 4, 1, 6], names=["a", "b", "c", "d"]) @@ -5971,14 +6012,14 @@ def test_memmap_like(self, tmpdir): tdm = td.memmap_like(prefix=tmpdir) assert tdm.names == ["a", "b", "c", "d"] - def test_h5(self, tmpdir): - td = TensorDict( - {"a": torch.zeros(3, 4, 1, 6)}, - batch_size=[3, 4, 1, 6], - names=["a", "b", "c", "d"], - ) - tdm = td.to_h5(filename=tmpdir / "file.h5") - assert tdm.names == ["a", "b", "c", "d"] + def test_memmap_td(self): + td = self.memmap_td("cpu") + td.names = list("abcd") + assert td.rename(c="g").names == list("abgd") + assert td.names == list("abcd") + td.rename_(c="g") + assert td.names == list("abgd") + assert td.clone().names == list("abgd") def test_nested(self): td = TensorDict({}, batch_size=[3, 4, 1, 6], names=["a", "b", "c", "d"]) @@ -5993,24 +6034,128 @@ def test_nested(self): td.set_("a", TensorDict({}, batch_size=[3, 4, 1, 6])) assert td["a"].names == td.names - def test_error_similar(self): - with pytest.raises(ValueError): - td = TensorDict({}, batch_size=[3, 4, 1, 6], names=["a", "b", "c", "a"]) - with pytest.raises(ValueError): - td = TensorDict( - {}, - batch_size=[3, 4, 1, 6], - ) - td.names = ["a", "b", "c", "a"] - with pytest.raises(ValueError): - td = TensorDict( - {}, - batch_size=[3, 4, 1, 6], - ) - td.refine_names("a", "a", ...) - with pytest.raises(ValueError): - td = TensorDict({}, batch_size=[3, 4, 1, 6], names=["a", "b", "c", "z"]) - td.rename_(a="z") + def test_nested_indexing(self): + td = TensorDict( + {"": TensorDict({}, [3, 4], names=["c", "d"])}, [3], names=["c"] + ) + assert td[0][""].names == td[""][0].names == ["d"] + + def test_nested_stacked_td(self): + td = self.nested_stacked_td("cpu") + td.names = list("abcd") + assert td.names == list("abcd") + assert td[:, 1].names == list("acd") + assert td["my_nested_td"][:, 1].names == list("acd") + assert td[:, 1]["my_nested_td"].names == list("acd") + tdr = td.rename(c="z") + assert td.names == list("abcd") + assert tdr.names == list("abzd") + td.rename_(c="z") + assert td.names == list("abzd") + assert td[:, 1].names == list("azd") + assert td["my_nested_td"][:, 1].names == list("azd") + assert td[:, 1]["my_nested_td"].names == list("azd") + assert td.contiguous().names == list("abzd") + assert td[:, 1].contiguous()["my_nested_td"].names == list("azd") + + def test_nested_tc(self): + nested_td = self.nested_tensorclass("cpu") + nested_td.names = list("abcd") + assert nested_td.rename(c="g").names == list("abgd") + assert nested_td.names == list("abcd") + nested_td.rename_(c="g") + assert nested_td.names == list("abgd") + assert nested_td.get("my_nested_tc").names == list("abgd") + assert nested_td.contiguous().names == list("abgd") + assert nested_td.contiguous().get("my_nested_tc").names == list("abgd") + + def test_nested_td(self): + nested_td = self.nested_td("cpu") + nested_td.names = list("abcd") + assert nested_td.rename(c="g").names == list("abgd") + assert nested_td.names == list("abcd") + nested_td.rename_(c="g") + assert nested_td.names == list("abgd") + assert nested_td["my_nested_td"].names == list("abgd") + assert nested_td.contiguous().names == list("abgd") + assert nested_td.contiguous()["my_nested_td"].names == 
list("abgd") + + def test_noname(self): + td = TensorDict({}, batch_size=[3, 4, 5, 6], names=None) + assert td.names == [None] * 4 + + def test_partial_name(self): + td = TensorDict({}, batch_size=[3, 4, 5, 6], names=["a", None, None, "d"]) + assert td.names == ["a", None, None, "d"] + + def test_partial_set(self): + td = TensorDict({}, batch_size=[3, 4, 5, 6], names=None) + td.names = ["a", None, None, "d"] + assert td.names == ["a", None, None, "d"] + td.names = ["a", "b", "c", "d"] + assert td.names == ["a", "b", "c", "d"] + with pytest.raises( + ValueError, + match="the length of the dimension names must equate the tensordict batch_dims", + ): + td.names = ["a", "b", "c"] + + def test_permute(self): + td = TensorDict({}, batch_size=[3, 4, 5, 6], names=None) + td.names = ["a", "b", "c", "d"] + tdp = td.permute(-1, -2, -3, -4) + assert tdp.names == list("dcba") + tdp = td.permute(-1, 1, 2, -4) + assert tdp.names == list("dbca") + + def test_permute_td(self): + td = self.unsqueezed_td("cpu") + with pytest.raises( + RuntimeError, match="Names of a lazy tensordict cannot be modified" + ): + td.names = list("abcd") + + def test_refine_names(self): + td = TensorDict({}, batch_size=[3, 4, 5, 6]) + tdr = td.refine_names(None, None, None, "d") + assert tdr.names == [None, None, None, "d"] + tdr = tdr.refine_names(None, None, "c", "d") + assert tdr.names == [None, None, "c", "d"] + with pytest.raises( + RuntimeError, match="refine_names: cannot coerce TensorDict" + ): + tdr.refine_names(None, None, "d", "d") + tdr = td.refine_names(..., "d") + assert tdr.names == [None, None, "c", "d"] + tdr = td.refine_names("a", ..., "d") + assert tdr.names == ["a", None, "c", "d"] + + def test_rename(self): + td = TensorDict({}, batch_size=[3, 4, 5, 6], names=None) + td.names = ["a", None, None, "d"] + td.rename_(a="c") + assert td.names == ["c", None, None, "d"] + td.rename_(d="z") + assert td.names == ["c", None, None, "z"] + td.rename_(*list("mnop")) + assert td.names == ["m", "n", "o", "p"] + td2 = td.rename(p="q") + assert td.names == ["m", "n", "o", "p"] + assert td2.names == ["m", "n", "o", "q"] + td2 = td.rename(*list("wxyz")) + assert td.names == ["m", "n", "o", "p"] + assert td2.names == ["w", "x", "y", "z"] + + def test_select(self): + td = TensorDict({}, batch_size=[3, 4, 1, 6], names=["a", "b", "c", "d"]) + tds = td.select() + assert tds.names == ["a", "b", "c", "d"] + tde = td.exclude() + assert tde.names == ["a", "b", "c", "d"] + td[""] = torch.zeros(td.shape) + td["*"] = torch.zeros(td.shape) + tds = td.select("") + assert tds.names == ["a", "b", "c", "d"] def test_set_at(self): td = TensorDict( @@ -6022,39 +6167,46 @@ def test_set_at(self): assert td.names == ["a", "b", "c", "d"] assert td[""].names == ["a", "b", "c", "d"] - def test_change_batch_size(self): - td = TensorDict({}, batch_size=[3, 4, 1, 6], names=["a", "b", "c", "z"]) - td.batch_size = [3, 4, 1, 6, 1] - assert td.names == ["a", "b", "c", "z", None] - td.batch_size = [] - assert td.names == [] - td.batch_size = [3, 4] - assert td.names == [None, None] - td.names = ["a", None] - td.batch_size = [3] - assert td.names == ["a"] - def test_set_item_populate_names(self): td = TensorDict({}, [3]) td["a"] = TensorDict({}, [3, 4], names=["a", "b"]) assert td.names == ["a"] assert td["a"].names == ["a", "b"] - def test_nested_indexing(self): - td = TensorDict( - {"": TensorDict({}, [3, 4], names=["c", "d"])}, [3], names=["c"] - ) - assert td[0][""].names == td[""][0].names == ["d"] + def test_split(self): + td = TensorDict({}, 
batch_size=[3, 4, 1, 6], names=["a", "b", "c", "d"]) + _, tdu = td.split(dim=-1, split_size=[3, 3]) + assert tdu.names == ["a", "b", "c", "d"] + _, tdu = td.split(dim=1, split_size=[1, 3]) + assert tdu.names == ["a", "b", "c", "d"] + + def test_squeeze(self): + td = TensorDict({}, batch_size=[3, 4, 5, 6], names=None) + td.names = ["a", "b", "c", "d"] + tds = td.squeeze(0) + assert tds.names == ["a", "b", "c", "d"] + td = TensorDict({}, batch_size=[3, 1, 5, 6], names=None) + td.names = ["a", "b", "c", "d"] + tds = td.squeeze(1) + assert tds.names == ["a", "c", "d"] - @pytest.mark.parametrize("device", get_available_devices()) - def test_to(self, device): - td = TensorDict( - {"": TensorDict({}, [3, 4, 1, 6])}, - batch_size=[3, 4, 1, 6], - names=["a", "b", "c", "d"], - ) - tdt = td.to(device) - assert tdt.names == ["a", "b", "c", "d"] + def test_squeeze_td(self): + td = self.squeezed_td("cpu") + with pytest.raises( + RuntimeError, match="Names of a lazy tensordict cannot be modified" + ): + td.names = list("abcd") + + def test_stack(self): + td = TensorDict({}, batch_size=[3, 4, 5, 6], names=["a", "b", "c", "d"]) + tds = torch.stack([td, td], 0) + assert tds.names == [None, "a", "b", "c", "d"] + tds = torch.stack([td, td], -1) + assert tds.names == ["a", "b", "c", "d", None] + tds = torch.stack([td, td], 2) + tds.names = list("mnopq") + assert tds.names == list("mnopq") + assert td.names == ["m", "n", "p", "q"] def test_stack_assign(self): td = TensorDict( @@ -6071,46 +6223,6 @@ def test_stack_assign(self): assert tds[0].names == ["e"] assert tds[0][""].names == tds[""][0].names == ["e", "d"] - def test_nested_td(self): - nested_td = self.nested_td("cpu") - nested_td.names = list("abcd") - assert nested_td.rename(c="g").names == list("abgd") - assert nested_td.names == list("abcd") - nested_td.rename_(c="g") - assert nested_td.names == list("abgd") - assert nested_td["my_nested_td"].names == list("abgd") - assert nested_td.contiguous().names == list("abgd") - assert nested_td.contiguous()["my_nested_td"].names == list("abgd") - - def test_nested_tc(self): - nested_td = self.nested_tensorclass("cpu") - nested_td.names = list("abcd") - assert nested_td.rename(c="g").names == list("abgd") - assert nested_td.names == list("abcd") - nested_td.rename_(c="g") - assert nested_td.names == list("abgd") - assert nested_td.get("my_nested_tc").names == list("abgd") - assert nested_td.contiguous().names == list("abgd") - assert nested_td.contiguous().get("my_nested_tc").names == list("abgd") - - def test_nested_stacked_td(self): - td = self.nested_stacked_td("cpu") - td.names = list("abcd") - assert td.names == list("abcd") - assert td[:, 1].names == list("acd") - assert td["my_nested_td"][:, 1].names == list("acd") - assert td[:, 1]["my_nested_td"].names == list("acd") - tdr = td.rename(c="z") - assert td.names == list("abcd") - assert tdr.names == list("abzd") - td.rename_(c="z") - assert td.names == list("abzd") - assert td[:, 1].names == list("azd") - assert td["my_nested_td"][:, 1].names == list("azd") - assert td[:, 1]["my_nested_td"].names == list("azd") - assert td.contiguous().names == list("abzd") - assert td[:, 1].contiguous()["my_nested_td"].names == list("azd") - def test_sub_td(self): td = self.sub_td("cpu") with pytest.raises( @@ -6123,36 +6235,69 @@ def test_sub_td(self): ): td.names = list("abcd") - def test_memmap_td(self): - td = self.memmap_td("cpu") - td.names = list("abcd") - assert td.rename(c="g").names == list("abgd") - assert td.names == list("abcd") - td.rename_(c="g") - assert 
td.names == list("abgd") - assert td.clone().names == list("abgd") - - def test_h5_td(self): - td = self.td_h5("cpu") - td.names = list("abcd") - assert td.rename(c="g").names == list("abgd") - assert td.names == list("abcd") - td.rename_(c="g") - assert td.names == list("abgd") - - def test_permute_td(self): - td = self.unsqueezed_td("cpu") + def test_subtd(self): + td = TensorDict({}, batch_size=[3, 4, 5, 6], names=["a", "b", "c", "d"]) + assert td._get_sub_tensordict(0).names == ["b", "c", "d"] + assert td._get_sub_tensordict((slice(None), 0)).names == ["a", "c", "d"] + assert td._get_sub_tensordict((0, slice(None))).names == ["b", "c", "d"] + assert td._get_sub_tensordict((0, slice(None, 1))).names == ["b", "c", "d"] + assert td._get_sub_tensordict((..., -1)).names == ["a", "b", "c"] + assert td._get_sub_tensordict((0, ..., -1)).names == ["b", "c"] + assert td._get_sub_tensordict((0, ..., [-1])).names == ["b", "c", "d"] + assert td._get_sub_tensordict((0, ..., torch.tensor([-1]))).names == [ + "b", + "c", + "d", + ] + assert td._get_sub_tensordict((0, ..., torch.tensor(-1))).names == ["b", "c"] + assert td._get_sub_tensordict((0, ..., slice(None, -1))).names == [ + "b", + "c", + "d", + ] + assert td._get_sub_tensordict((slice(None, 1), ..., slice(None, -1))).names == [ + "a", + "b", + "c", + "d", + ] + tdbool = td._get_sub_tensordict(torch.ones(3, dtype=torch.bool)) + assert tdbool.names == [None, "b", "c", "d"] + assert tdbool.ndim == 4 + tdbool = td._get_sub_tensordict(torch.ones(3, 4, dtype=torch.bool)) + assert tdbool.names == [None, "c", "d"] + assert tdbool.ndim == 3 with pytest.raises( - RuntimeError, match="Names of a lazy tensordict cannot be modified" + RuntimeError, match="Names of a subtensordict cannot be modified" ): - td.names = list("abcd") + tdbool.names = "All work and no play makes Jack a dull boy" - def test_squeeze_td(self): - td = self.squeezed_td("cpu") - with pytest.raises( - RuntimeError, match="Names of a lazy tensordict cannot be modified" - ): - td.names = list("abcd") + @pytest.mark.parametrize("device", get_available_devices()) + def test_to(self, device): + td = TensorDict( + {"": TensorDict({}, [3, 4, 1, 6])}, + batch_size=[3, 4, 1, 6], + names=["a", "b", "c", "d"], + ) + tdt = td.to(device) + assert tdt.names == ["a", "b", "c", "d"] + + def test_unbind(self): + td = TensorDict({}, batch_size=[3, 4, 1, 6], names=["a", "b", "c", "d"]) + *_, tdu = td.unbind(-1) + assert tdu.names == ["a", "b", "c"] + *_, tdu = td.unbind(-2) + assert tdu.names == ["a", "b", "d"] + + def test_unsqueeze(self): + td = TensorDict({}, batch_size=[3, 4, 5, 6], names=None) + td.names = ["a", "b", "c", "d"] + tdu = td.unsqueeze(0) + assert tdu.names == [None, "a", "b", "c", "d"] + tdu = td.unsqueeze(-1) + assert tdu.names == ["a", "b", "c", "d", None] + tdu = td.unsqueeze(2) + assert tdu.names == ["a", "b", None, "c", "d"] def test_unsqueeze_td(self): td = self.unsqueezed_td("cpu") @@ -6162,27 +6307,6 @@ def test_unsqueeze_td(self): td.names = list("abcd") -def _compare_tensors_identity(td0, td1): - if isinstance(td0, LazyStackedTensorDict): - if not isinstance(td1, LazyStackedTensorDict): - return False - for _td0, _td1 in zip(td0.tensordicts, td1.tensordicts): - if not _compare_tensors_identity(_td0, _td1): - return False - return True - if td0 is td1: - return True - for key, val in td0.items(): - if is_tensor_collection(val): - if not _compare_tensors_identity(val, td1.get(key)): - return False - else: - if val is not td1.get(key): - return False - else: - return True - - class 
@@ -6195,6 +6319,81 @@ def check_weakref_count(weakref_list, expected):
             ids.add(id(td))
         assert count == expected, {id(ref()) for ref in weakref_list}

+    def test_lock_stack(self):
+        td0 = TensorDict({("a", "b", "c", "d"): 1.0}, [])
+        td1 = td0.clone()
+        td = torch.stack([td0, td1])
+        td = td.lock_()
+        a = td["a"]
+        b = td["a", "b"]
+        c = td["a", "b", "c"]
+        a0 = td0["a"]
+        b0 = td0["a", "b"]
+        c0 = td0["a", "b", "c"]
+        self.check_weakref_count(a._lock_parents_weakrefs, 3)  # td, td0, td1
+        self.check_weakref_count(b._lock_parents_weakrefs, 5)  # td, td0, td1, a0, a1
+        self.check_weakref_count(
+            c._lock_parents_weakrefs, 7
+        )  # td, td0, td1, a0, a1, b0, b1
+        self.check_weakref_count(a0._lock_parents_weakrefs, 2)  # td, td0
+        self.check_weakref_count(b0._lock_parents_weakrefs, 3)  # td, td0, a0
+        self.check_weakref_count(c0._lock_parents_weakrefs, 4)  # td, td0, a0, b0
+        td.unlock_()
+        td.lock_()
+        del td, td0, td1
+        gc.collect()
+        a.unlock_()
+        a.lock_()
+        self.check_weakref_count(a._lock_parents_weakrefs, 0)
+        self.check_weakref_count(b._lock_parents_weakrefs, 3)  # a, a0, a1
+        self.check_weakref_count(c._lock_parents_weakrefs, 5)  # a, a0, a1, b0, b1
+        self.check_weakref_count(a0._lock_parents_weakrefs, 1)  # a
+        self.check_weakref_count(b0._lock_parents_weakrefs, 2)  # a, a0
+        self.check_weakref_count(c0._lock_parents_weakrefs, 3)  # a, a0, b0
+        del a, a0
+        gc.collect()
+        b.unlock_()
+        b.lock_()
+        del b
+
+    def test_lock_two_roots(self):
+        td = TensorDict({("a", "b", "c", "d"): 1.0}, [])
+        td = td.lock_()
+        a = td["a"]
+        b = td["a", "b"]
+        c = td["a", "b", "c"]
+        other_td = TensorDict({"a": a}, [])
+        other_td.lock_()
+        # we cannot unlock anything anymore
+        with pytest.raises(
+            RuntimeError,
+            match="Cannot unlock a tensordict that is part of a locked graph.",
+        ):
+            other_td.unlock_()
+        assert td._is_locked
+        assert td.is_locked
+        with pytest.raises(
+            RuntimeError,
+            match="Cannot unlock a tensordict that is part of a locked graph.",
+        ):
+            td.unlock_()
+        # if we group them we can't unlock
+        supertd = TensorDict({"td": td, "other": other_td}, [])
+        supertd = supertd.lock_()
+        supertd = supertd.unlock_()
+        supertd = supertd.lock_()
+        del supertd, other_td
+        gc.collect()
+        self.check_weakref_count(td._lock_parents_weakrefs, 0)
+        # self.check_td_not_in_weakref_list(supertd, a._lock_parents_weakrefs)
+        # self.check_td_not_in_weakref_list(other_td, a._lock_parents_weakrefs)
+        self.check_weakref_count(a._lock_parents_weakrefs, 1)
+        # self.check_td_not_in_weakref_list(supertd, b._lock_parents_weakrefs)
+        # self.check_td_not_in_weakref_list(other_td, b._lock_parents_weakrefs)
+        self.check_weakref_count(b._lock_parents_weakrefs, 2)
+        self.check_weakref_count(c._lock_parents_weakrefs, 3)
+        td.unlock_()
+
     def test_nested_lock(self):
         td = TensorDict({("a", "b", "c", "d"): 1.0}, [])
         td = td.lock_()
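# --- Editor's note: illustration, not part of the patch ---------------------
# The behaviour `test_lock_two_roots` and `test_nested_lock` rely on: locking
# a root locks every nested tensordict, and each child keeps weak references
# to its locked parents so that unlocking from below is refused. A minimal
# sketch:
from tensordict import TensorDict

td = TensorDict({("a", "b"): 1.0}, []).lock_()
child = td["a"]
assert child.is_locked              # locked through the parent
try:
    child.unlock_()                 # refused: part of a locked graph
except RuntimeError as err:
    print(err)
td.unlock_()                        # unlocking from the root is allowed
assert not child.is_locked
# ---------------------------------------------------------------------------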
@@ -6255,81 +6454,6 @@ def test_nested_lock_erros(self):
             ):
                 b.unlock_()

-    def test_lock_two_roots(self):
-        td = TensorDict({("a", "b", "c", "d"): 1.0}, [])
-        td = td.lock_()
-        a = td["a"]
-        b = td["a", "b"]
-        c = td["a", "b", "c"]
-        other_td = TensorDict({"a": a}, [])
-        other_td.lock_()
-        # we cannot unlock anything anymore
-        with pytest.raises(
-            RuntimeError,
-            match="Cannot unlock a tensordict that is part of a locked graph.",
-        ):
-            other_td.unlock_()
-        assert td._is_locked
-        assert td.is_locked
-        with pytest.raises(
-            RuntimeError,
-            match="Cannot unlock a tensordict that is part of a locked graph.",
-        ):
-            td.unlock_()
-        # if we group them we can't unlock
-        supertd = TensorDict({"td": td, "other": other_td}, [])
-        supertd = supertd.lock_()
-        supertd = supertd.unlock_()
-        supertd = supertd.lock_()
-        del supertd, other_td
-        gc.collect()
-        self.check_weakref_count(td._lock_parents_weakrefs, 0)
-        # self.check_td_not_in_weakref_list(supertd, a._lock_parents_weakrefs)
-        # self.check_td_not_in_weakref_list(other_td, a._lock_parents_weakrefs)
-        self.check_weakref_count(a._lock_parents_weakrefs, 1)
-        # self.check_td_not_in_weakref_list(supertd, b._lock_parents_weakrefs)
-        # self.check_td_not_in_weakref_list(other_td, b._lock_parents_weakrefs)
-        self.check_weakref_count(b._lock_parents_weakrefs, 2)
-        self.check_weakref_count(c._lock_parents_weakrefs, 3)
-        td.unlock_()
-
-    def test_lock_stack(self):
-        td0 = TensorDict({("a", "b", "c", "d"): 1.0}, [])
-        td1 = td0.clone()
-        td = torch.stack([td0, td1])
-        td = td.lock_()
-        a = td["a"]
-        b = td["a", "b"]
-        c = td["a", "b", "c"]
-        a0 = td0["a"]
-        b0 = td0["a", "b"]
-        c0 = td0["a", "b", "c"]
-        self.check_weakref_count(a._lock_parents_weakrefs, 3)  # td, td0, td1
-        self.check_weakref_count(b._lock_parents_weakrefs, 5)  # td, td0, td1, a0, a1
-        self.check_weakref_count(
-            c._lock_parents_weakrefs, 7
-        )  # td, td0, td1, a0, a1, b0, b1
-        self.check_weakref_count(a0._lock_parents_weakrefs, 2)  # td, td0
-        self.check_weakref_count(b0._lock_parents_weakrefs, 3)  # td, td0, a0
-        self.check_weakref_count(c0._lock_parents_weakrefs, 4)  # td, td0, a0, b0
-        td.unlock_()
-        td.lock_()
-        del td, td0, td1
-        gc.collect()
-        a.unlock_()
-        a.lock_()
-        self.check_weakref_count(a._lock_parents_weakrefs, 0)
-        self.check_weakref_count(b._lock_parents_weakrefs, 3)  # a, a0, a1
-        self.check_weakref_count(c._lock_parents_weakrefs, 5)  # a, a0, a1, b0, b1
-        self.check_weakref_count(a0._lock_parents_weakrefs, 1)  # a
-        self.check_weakref_count(b0._lock_parents_weakrefs, 2)  # a, a0
-        self.check_weakref_count(c0._lock_parents_weakrefs, 3)  # a, a0, b0
-        del a, a0
-        gc.collect()
-        b.unlock_()
-        b.lock_()
-        del b
-
     def test_stack_cache_lock(self):
         td0 = TensorDict({("a", "b", "c", "d"): 1.0}, [])
         td1 = td0.clone()
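# --- Editor's note: illustration, not part of the patch ---------------------
# `test_lock_stack` exercises the same weak-reference bookkeeping across a
# lazy stack: locking the stack locks its component tensordicts, which is why
# the components' children count the stack among their lock parents. A
# minimal sketch:
import torch
from tensordict import TensorDict

td0 = TensorDict({("a", "b"): 1.0}, [])
td1 = td0.clone()
stack = torch.stack([td0, td1])          # LazyStackedTensorDict over td0/td1
stack.lock_()
assert td0.is_locked and td1.is_locked   # locked through the stack
stack.unlock_()
assert not td0.is_locked
# ---------------------------------------------------------------------------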
@@ -6380,250 +6504,12 @@ def test_stacked_append_and_insert(self):
             td.append(td0)


-@pytest.mark.parametrize("memmap", [True, False])
-@pytest.mark.parametrize("params", [False, True])
-def test_from_module(memmap, params):
-    net = nn.Transformer(
-        d_model=16,
-        nhead=2,
-        num_encoder_layers=3,
-        dim_feedforward=12,
-    )
-    td = TensorDict.from_module(net, as_module=params)
-    # check that we have empty tensordicts, reflecting modules without params
-    for subtd in td.values(True):
-        if isinstance(subtd, TensorDictBase) and subtd.is_empty():
-            break
-    else:
-        raise RuntimeError
-    if memmap:
-        td = td.detach().memmap_()
-    net.load_state_dict(td.flatten_keys("."))
-
-    if not memmap and params:
-        assert set(td.parameters()) == set(net.parameters())
-
-
-def test_from_module_state_dict():
-    net = nn.Transformer(
-        d_model=16,
-        nhead=2,
-        num_encoder_layers=3,
-        dim_feedforward=12,
-    )
-
-    def adder(module, *args, **kwargs):
-        for p in module.parameters(recurse=False):
-            p.data += 1
-
-    def remover(module, *args, **kwargs):
-        for p in module.parameters(recurse=False):
-            p.data = p.data - 1
-
-    for module in net.modules():
-        module.register_state_dict_pre_hook(adder)
-        module._register_state_dict_hook(remover)
-    params_reg = TensorDict.from_module(net)
-    params_reg = params_reg.select(*params_reg.keys(True, True))
-
-    params_sd = TensorDict.from_module(net, use_state_dict=True)
-    params_sd = params_sd.select(*params_sd.keys(True, True))
-    assert_allclose_td(params_sd, params_reg.apply(lambda x: x + 1))
-
-    sd = net.state_dict()
-    assert_allclose_td(params_sd.flatten_keys("."), TensorDict(sd, []))
-
-
-def test_to_module_state_dict():
-    net0 = nn.Transformer(
-        d_model=16,
-        nhead=2,
-        num_encoder_layers=3,
-        dim_feedforward=12,
-    )
-    net1 = nn.Transformer(
-        d_model=16,
-        nhead=2,
-        num_encoder_layers=3,
-        dim_feedforward=12,
-    )
-
-    def hook(
-        module,
-        state_dict,
-        prefix,
-        local_metadata,
-        strict,
-        missing_keys,
-        unexpected_keys,
-        error_msgs,
-    ):
-        for key, val in list(state_dict.items()):
-            state_dict[key] = val * 0
-
-    for module in net0.modules():
-        module._register_load_state_dict_pre_hook(hook, with_module=True)
-    for module in net1.modules():
-        module._register_load_state_dict_pre_hook(hook, with_module=True)
-
-    params_reg = TensorDict.from_module(net0)
-    params_reg.to_module(net0, use_state_dict=True)
-    params_reg = TensorDict.from_module(net0)
-
-    sd = net1.state_dict()
-    net1.load_state_dict(sd)
-    sd = net1.state_dict()
-
-    assert (params_reg == 0).all()
-    assert set(params_reg.flatten_keys(".").keys()) == set(sd.keys())
-    assert_allclose_td(params_reg.flatten_keys("."), TensorDict(sd, []))
-
-
-@pytest.mark.parametrize("batch_size", [None, [3, 4]])
-@pytest.mark.parametrize("batch_dims", [None, 1, 2])
-@pytest.mark.parametrize("device", get_available_devices())
-def test_from_dict(batch_size, batch_dims, device):
-    data = {
-        "a": torch.zeros(3, 4, 5),
-        "b": {"c": torch.zeros(3, 4, 5, 6)},
-        ("d", "e"): torch.ones(3, 4, 5),
-        ("b", "f"): torch.zeros(3, 4, 5, 5),
-        ("d", "g", "h"): torch.ones(3, 4, 5),
-    }
-    if batch_dims and batch_size:
-        with pytest.raises(ValueError, match="both"):
-            TensorDict.from_dict(
-                data, batch_size=batch_size, batch_dims=batch_dims, device=device
-            )
-        return
-    data = TensorDict.from_dict(
-        data, batch_size=batch_size, batch_dims=batch_dims, device=device
-    )
-    assert data.device == device
-    assert "a" in data.keys()
-    assert ("b", "c") in data.keys(True)
-    assert ("b", "f") in data.keys(True)
-    assert ("d", "e") in data.keys(True)
-    assert data.device == device
-    if batch_dims:
-        assert data.ndim == batch_dims
-        assert data["b"].ndim == batch_dims
-        assert data["d"].ndim == batch_dims
-        assert data["d", "g"].ndim == batch_dims
-    elif batch_size:
-        assert data.batch_size == torch.Size(batch_size)
-        assert data["b"].batch_size == torch.Size(batch_size)
-        assert data["d"].batch_size == torch.Size(batch_size)
-        assert data["d", "g"].batch_size == torch.Size(batch_size)
-
-
-def test_unbind_batchsize():
-    td = TensorDict({"a": TensorDict({"b": torch.zeros(2, 3)}, [2, 3])}, [2])
-    td["a"].batch_size
-    tds = td.unbind(0)
-    assert tds[0].batch_size == torch.Size([])
-    assert tds[0]["a"].batch_size == torch.Size([3])
-
-
-def test_empty():
-    td = TensorDict(
-        {
-            "a": torch.zeros(()),
-            ("b", "c"): torch.zeros(()),
-            ("b", "d", "e"): torch.zeros(()),
-        },
-        [],
-    )
-    td_empty = td.empty(recurse=False)
-    assert len(list(td_empty.keys())) == 0
-    td_empty = td.empty(recurse=True)
-    assert len(list(td_empty.keys())) == 1
-    assert len(list(td_empty.get("b").keys())) == 1
-
-
-@pytest.mark.parametrize(
-    "stack_dim",
-    [0, 1, 2, 3],
-)
-@pytest.mark.parametrize(
-    "nested_stack_dim",
-    [0, 1, 2],
-)
-def test_dense_stack_tds(stack_dim, nested_stack_dim):
-    batch_size = (5, 6)
-    td0 = TensorDict(
-        {"a": torch.zeros(*batch_size, 3)},
-        batch_size,
-    )
-    td1 = TensorDict(
-        {"a": torch.zeros(*batch_size, 4), "b": torch.zeros(*batch_size, 2)},
-        batch_size,
-    )
-    td_lazy = torch.stack([td0, td1], dim=nested_stack_dim)
-    td_container = TensorDict({"lazy": td_lazy}, td_lazy.batch_size)
-    td_container_clone = td_container.clone()
-    td_container_clone.apply_(lambda x: x + 1)
-
-    assert td_lazy.stack_dim == nested_stack_dim
-    td_stack = torch.stack([td_container, td_container_clone], dim=stack_dim)
-    assert td_stack.stack_dim == stack_dim
-
-    assert isinstance(td_stack, LazyStackedTensorDict)
-    dense_td_stack = dense_stack_tds(td_stack)
-    assert isinstance(dense_td_stack, TensorDict)  # check outer layer is non-lazy
-    assert isinstance(
-        dense_td_stack["lazy"], LazyStackedTensorDict
-    )  # while inner layer is still lazy
-    assert "b" not in dense_td_stack["lazy"].tensordicts[0].keys()
-    assert "b" in dense_td_stack["lazy"].tensordicts[1].keys()
-
-    assert assert_allclose_td(
-        dense_td_stack,
-        dense_stack_tds([td_container, td_container_clone], dim=stack_dim),
-    )  # This shows it is the same to pass a list or a LazyStackedTensorDict
-
-    for i in range(2):
-        index = (slice(None),) * stack_dim + (i,)
-        assert (dense_td_stack[index] == i).all()
-
-    if stack_dim > nested_stack_dim:
-        assert dense_td_stack["lazy"].stack_dim == nested_stack_dim
-    else:
-        assert dense_td_stack["lazy"].stack_dim == nested_stack_dim + 1
-
-
 @pytest.mark.parametrize(
     "td_name,device",
     TestTensorDictsBase.TYPES_DEVICES,
 )
 class TestTensorDictMP(TestTensorDictsBase):
     # Tests sharing a locked tensordict
-    @staticmethod
-    def worker_lock(td, q):
-        assert td.is_locked
-        for val in td.values(True):
-            if is_tensor_collection(val):
-                assert val.is_locked
-                assert val._lock_parents_weakrefs
-        assert not td._lock_parents_weakrefs
-        q.put("succeeded")
-
-    def test_sharing_locked_td(self, td_name, device):
-        td = getattr(self, td_name)(device)
-        if td_name in ("sub_td", "sub_td2"):
-            pytest.skip("cannot lock sub-tds")
-        if td_name in ("td_h5",):
-            pytest.skip("h5 files should not be opened across different processes.")
-        q = mp.Queue(1)
-        try:
-            p = mp.Process(target=self.worker_lock, args=(td.lock_(), q))
-            p.start()
-            assert q.get(timeout=30) == "succeeded"
-        finally:
-            try:
-                p.join()
-            except AssertionError:
-                pass

     @staticmethod
     def add1(x):
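# --- Editor's note: illustration, not part of the patch ---------------------
# The contract of the relocated `test_from_module`: `TensorDict.from_module`
# mirrors the module tree (keeping a node even for parameter-less submodules,
# per the empty-node assertion in the moved test) and round-trips into a
# state dict via `flatten_keys(".")`. A minimal sketch; the empty ReLU node
# is inferred from that assertion rather than from documented API:
import torch.nn as nn
from tensordict import TensorDict

net = nn.Sequential(nn.Linear(3, 4), nn.ReLU(), nn.Linear(4, 1))
params = TensorDict.from_module(net)
assert params["0", "weight"].shape == (4, 3)
assert params["1"].is_empty()                  # ReLU holds no parameters
net.load_state_dict(params.flatten_keys("."))  # same keys as net.state_dict()
# ---------------------------------------------------------------------------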
@@ -6638,35 +6524,6 @@ def add1_app_error(x):
         # algebraic ops are not supported
         return x + 1

-    @staticmethod
-    def write_pid(x):
-        return TensorDict({"pid": os.getpid()}, []).expand(x.shape)
-
-    @pytest.mark.parametrize("dim", [-2, -1, 0, 1, 2, 3])
-    def test_map(self, td_name, device, dim, _pool_fixt):
-        td = getattr(self, td_name)(device)
-        if td_name == "td_params":
-            with pytest.raises(
-                RuntimeError, match="Cannot call map on a TensorDictParams object"
-            ):
-                td.map(self.add1_app, dim=dim, pool=_pool_fixt)
-            return
-        assert (
-            td.map(self.add1_app, dim=dim, pool=_pool_fixt) == td.apply(self.add1)
-        ).all()
-
-    @pytest.mark.parametrize("dim", [-2, -1, 0, 1, 2, 3])
-    def test_map_exception(self, td_name, device, dim, _pool_fixt):
-        td = getattr(self, td_name)(device)
-        if td_name == "td_params":
-            with pytest.raises(
-                RuntimeError, match="Cannot call map on a TensorDictParams object"
-            ):
-                td.map(self.add1_app_error, dim=dim, pool=_pool_fixt)
-            return
-        with pytest.raises(TypeError, match="unsupported operand"):
-            td.map(self.add1_app_error, dim=dim, pool=_pool_fixt)
-
     @pytest.mark.parametrize(
         "chunksize,num_chunks", [[None, 2], [4, None], [None, None], [2, 2]]
     )
@@ -6703,6 +6560,62 @@ def test_chunksize_num_chunks(
         elif num_chunks is not None:
             assert pids.numel() == num_chunks

+    @pytest.mark.parametrize("dim", [-2, -1, 0, 1, 2, 3])
+    def test_map(self, td_name, device, dim, _pool_fixt):
+        td = getattr(self, td_name)(device)
+        if td_name == "td_params":
+            with pytest.raises(
+                RuntimeError, match="Cannot call map on a TensorDictParams object"
+            ):
+                td.map(self.add1_app, dim=dim, pool=_pool_fixt)
+            return
+        assert (
+            td.map(self.add1_app, dim=dim, pool=_pool_fixt) == td.apply(self.add1)
+        ).all()
+
+    @pytest.mark.parametrize("dim", [-2, -1, 0, 1, 2, 3])
+    def test_map_exception(self, td_name, device, dim, _pool_fixt):
+        td = getattr(self, td_name)(device)
+        if td_name == "td_params":
+            with pytest.raises(
+                RuntimeError, match="Cannot call map on a TensorDictParams object"
+            ):
+                td.map(self.add1_app_error, dim=dim, pool=_pool_fixt)
+            return
+        with pytest.raises(TypeError, match="unsupported operand"):
+            td.map(self.add1_app_error, dim=dim, pool=_pool_fixt)
+
+    def test_sharing_locked_td(self, td_name, device):
+        td = getattr(self, td_name)(device)
+        if td_name in ("sub_td", "sub_td2"):
+            pytest.skip("cannot lock sub-tds")
+        if td_name in ("td_h5",):
+            pytest.skip("h5 files should not be opened across different processes.")
+        q = mp.Queue(1)
+        try:
+            p = mp.Process(target=self.worker_lock, args=(td.lock_(), q))
+            p.start()
+            assert q.get(timeout=30) == "succeeded"
+        finally:
+            try:
+                p.join()
+            except AssertionError:
+                pass
+
+    @staticmethod
+    def worker_lock(td, q):
+        assert td.is_locked
+        for val in td.values(True):
+            if is_tensor_collection(val):
+                assert val.is_locked
+                assert val._lock_parents_weakrefs
+        assert not td._lock_parents_weakrefs
+        q.put("succeeded")
+
+    @staticmethod
+    def write_pid(x):
+        return TensorDict({"pid": os.getpid()}, []).expand(x.shape)
+

 @pytest.fixture(scope="class")
 def _pool_fixt():
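# --- Editor's note: illustration, not part of the patch ---------------------
# What the re-ordered `test_map` checks: `TensorDict.map` splits the
# tensordict along `dim`, applies the function in worker processes and
# re-stacks the result, so a pure elementwise function must agree with
# `apply`. A minimal sketch (the mapped function must be a picklable,
# top-level callable, and the call belongs under `if __name__ == "__main__"`
# with the spawn start method):
import torch
from tensordict import TensorDict


def add1_app(td):
    return td.apply(lambda x: x + 1)


if __name__ == "__main__":
    td = TensorDict({"x": torch.zeros(4, 3)}, [4])
    out = td.map(add1_app, dim=0, num_workers=2)
    assert (out == td.apply(lambda x: x + 1)).all()
# ---------------------------------------------------------------------------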
@@ -6849,6 +6762,10 @@ def test_modules(self, as_module):
 class TestMap:
     """Tests for TensorDict.map that are independent from tensordict's type."""

+    @staticmethod
+    def _set_2(td):
+        return td.set("2", 2)
+
     @classmethod
     def get_rand_incr(cls, td):
         # torch
@@ -6944,10 +6861,6 @@ def test_map_seed_single(self):
             td_out_1["s"].sort().values,
         )

-    @staticmethod
-    def _set_2(td):
-        return td.set("2", 2)
-
     def test_map_unbind(self):
         if mp.get_start_method(allow_none=True) is None:
             mp.set_start_method("spawn")
@@ -6976,6 +6889,17 @@ def non_tensor_data(self):
             batch_size=[],
         )

+    def test_comparison(self, non_tensor_data):
+        non_tensor_data = non_tensor_data.exclude(("nested", "str"))
+        assert (non_tensor_data | non_tensor_data).get_non_tensor(("nested", "bool"))
+        assert not (non_tensor_data ^ non_tensor_data).get_non_tensor(
+            ("nested", "bool")
+        )
+        assert (non_tensor_data == non_tensor_data).get_non_tensor(("nested", "bool"))
+        assert not (non_tensor_data != non_tensor_data).get_non_tensor(
+            ("nested", "bool")
+        )
+
     def test_nontensor_dict(self, non_tensor_data):
         assert (
             TensorDict.from_dict(non_tensor_data.to_dict()) == non_tensor_data
@@ -7013,17 +6937,6 @@ def test_stack(self, non_tensor_data):
             LazyStackedTensorDict,
         )

-    def test_comparison(self, non_tensor_data):
-        non_tensor_data = non_tensor_data.exclude(("nested", "str"))
-        assert (non_tensor_data | non_tensor_data).get_non_tensor(("nested", "bool"))
-        assert not (non_tensor_data ^ non_tensor_data).get_non_tensor(
-            ("nested", "bool")
-        )
-        assert (non_tensor_data == non_tensor_data).get_non_tensor(("nested", "bool"))
-        assert not (non_tensor_data != non_tensor_data).get_non_tensor(
-            ("nested", "bool")
-        )
-

 if __name__ == "__main__":
     args, unknown = argparse.ArgumentParser().parse_known_args()
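# --- Editor's note: illustration, not part of the patch ---------------------
# What the relocated `test_comparison` checks: comparing tensordicts that
# carry non-tensor leaves yields boolean non-tensor leaves, which
# `get_non_tensor` unwraps to plain Python values. A minimal sketch
# (`set_non_tensor` is used here as the write-side counterpart on the
# assumption it matches this build):
from tensordict import TensorDict

td = TensorDict({}, [])
td.set_non_tensor("answer", 42)
assert td.get_non_tensor("answer") == 42
assert (td == td).get_non_tensor("answer")       # equality holds leaf-wise
assert not (td != td).get_non_tensor("answer")
# ---------------------------------------------------------------------------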