From fc48d84af3ab806c5578c317b0ad880057d362da Mon Sep 17 00:00:00 2001 From: vmoens Date: Thu, 15 Feb 2024 11:03:17 +0000 Subject: [PATCH 01/20] init --- tensordict/base.py | 2 +- tensordict/tensorclass.py | 16 +++++++++++++--- 2 files changed, 14 insertions(+), 4 deletions(-) diff --git a/tensordict/base.py b/tensordict/base.py index 51c734dd1..daed2d7ec 100644 --- a/tensordict/base.py +++ b/tensordict/base.py @@ -1975,7 +1975,7 @@ def memmap_like( else: return TensorDictFuture(futures, result) input = self.apply( - lambda x: torch.empty((), device=x.device, dtype=x.dtype).expand(x.shape) + lambda x: torch.empty_like(x) ) return input._memmap_( prefix=prefix, diff --git a/tensordict/tensorclass.py b/tensordict/tensorclass.py index d2b249cb5..a3b8bdf4f 100644 --- a/tensordict/tensorclass.py +++ b/tensordict/tensorclass.py @@ -56,6 +56,9 @@ torch.full_like, torch.zeros_like, torch.ones_like, + torch.empty_like, + torch.randn_like, + torch.rand_like, torch.clone, torch.squeeze, torch.unsqueeze, @@ -1246,6 +1249,13 @@ class NonTensorData: # and all the overhead falls back on this class. data: Any + @classmethod + def from_tensor(cls, value: torch.Tensor, batch_size, device=None, names=None): + """A util to create a NonTensorData containing a tensor.""" + out = cls(data=None, batch_size=batch_size, device=device, names=names) + out._non_tensordict["data"] = value + return out + def __post_init__(self): if isinstance(self.data, NonTensorData): self.data = self.data.data @@ -1304,7 +1314,7 @@ def __or__(self, other): self.__class__.__or__ = __or__ def empty(self, recurse=False): - return NonTensorData( + return type(self)( data=self.data, batch_size=self.batch_size, names=self.names if self._has_names() else None, @@ -1332,7 +1342,7 @@ def _check_equal(a, b): if all(_check_equal(data.data, first.data) for data in list_of_non_tensor[1:]): batch_size = list(first.batch_size) batch_size.insert(dim, len(list_of_non_tensor)) - return NonTensorData( + return type(self)( data=first.data, batch_size=batch_size, names=first.names if first._has_names() else None, @@ -1358,7 +1368,7 @@ def __torch_function__( ): return NotImplemented - escape_conversion = func in (torch.stack,) + escape_conversion = func in (torch.stack, torch.ones_like, torch.zeros_like, torch.empty_like, torch.randn_like, torch.rand_like) if kwargs is None: kwargs = {} From 8f7722d0729067eb982a1be2daffc7ed2bb80d0a Mon Sep 17 00:00:00 2001 From: vmoens Date: Thu, 15 Feb 2024 16:09:48 +0000 Subject: [PATCH 02/20] amend --- tensordict/base.py | 4 +--- tensordict/tensorclass.py | 11 +++++++++-- 2 files changed, 10 insertions(+), 5 deletions(-) diff --git a/tensordict/base.py b/tensordict/base.py index daed2d7ec..d26c2c900 100644 --- a/tensordict/base.py +++ b/tensordict/base.py @@ -1974,9 +1974,7 @@ def memmap_like( return result else: return TensorDictFuture(futures, result) - input = self.apply( - lambda x: torch.empty_like(x) - ) + input = self.apply(lambda x: torch.empty_like(x)) return input._memmap_( prefix=prefix, copy_existing=copy_existing, diff --git a/tensordict/tensorclass.py b/tensordict/tensorclass.py index a3b8bdf4f..509ccfb2a 100644 --- a/tensordict/tensorclass.py +++ b/tensordict/tensorclass.py @@ -1342,7 +1342,7 @@ def _check_equal(a, b): if all(_check_equal(data.data, first.data) for data in list_of_non_tensor[1:]): batch_size = list(first.batch_size) batch_size.insert(dim, len(list_of_non_tensor)) - return type(self)( + return type(cls)( data=first.data, batch_size=batch_size, names=first.names if first._has_names() else 
None, @@ -1368,7 +1368,14 @@ def __torch_function__( ): return NotImplemented - escape_conversion = func in (torch.stack, torch.ones_like, torch.zeros_like, torch.empty_like, torch.randn_like, torch.rand_like) + escape_conversion = func in ( + torch.stack, + torch.ones_like, + torch.zeros_like, + torch.empty_like, + torch.randn_like, + torch.rand_like, + ) if kwargs is None: kwargs = {} From ff2c5b0d7baa4a77b83962928eafae3d9357b4af Mon Sep 17 00:00:00 2001 From: vmoens Date: Tue, 20 Feb 2024 19:48:51 -0800 Subject: [PATCH 03/20] init --- tensordict/_lazy.py | 4 ++++ tensordict/base.py | 3 ++- tensordict/tensorclass.py | 4 ++-- 3 files changed, 8 insertions(+), 3 deletions(-) diff --git a/tensordict/_lazy.py b/tensordict/_lazy.py index b0c6ee6cc..659b04cfe 100644 --- a/tensordict/_lazy.py +++ b/tensordict/_lazy.py @@ -2441,6 +2441,10 @@ def _unsqueeze(self, dim): _to_module = TensorDict._to_module +class StackNonTensor(LazyStackedTensorDict): + """A thin wrapper aroung LazyStackedTensorDict to make stack on non-tensor data easily recognizable.""" + pass + class _CustomOpTensorDict(TensorDictBase): """Encodes lazy operations on tensors contained in a TensorDict.""" diff --git a/tensordict/base.py b/tensordict/base.py index 9203b6ba3..52b83e089 100644 --- a/tensordict/base.py +++ b/tensordict/base.py @@ -5377,9 +5377,10 @@ def _default_is_leaf(cls: Type) -> bool: def _is_leaf_nontensor(cls: Type) -> bool: from tensordict.tensorclass import NonTensorData + from tensordict._lazy import StackNonTensor if issubclass(cls, KeyedJaggedTensor): return False if _is_tensor_collection(cls): - return issubclass(cls, NonTensorData) + return issubclass(cls, (NonTensorData, StackNonTensor)) return issubclass(cls, torch.Tensor) diff --git a/tensordict/tensorclass.py b/tensordict/tensorclass.py index d2b249cb5..81f16c598 100644 --- a/tensordict/tensorclass.py +++ b/tensordict/tensorclass.py @@ -1339,9 +1339,9 @@ def _check_equal(a, b): device=first.device, ) - from tensordict._lazy import LazyStackedTensorDict + from tensordict._lazy import StackNonTensor - return LazyStackedTensorDict(*list_of_non_tensor, stack_dim=dim) + return StackNonTensor(*list_of_non_tensor, stack_dim=dim) @classmethod def __torch_function__( From 40dd77349bd1e81010d41cbd1c3cce619805111e Mon Sep 17 00:00:00 2001 From: vmoens Date: Wed, 21 Feb 2024 08:21:14 -0800 Subject: [PATCH 04/20] amend --- tensordict/_lazy.py | 16 +++++++++++++--- tensordict/_td.py | 9 ++++++--- tensordict/_torch_func.py | 31 ++++++++++++++++++++++++++++++- tensordict/tensorclass.py | 12 +++++++++++- tensordict/utils.py | 25 +++++++++++++++++++++++++ 5 files changed, 85 insertions(+), 8 deletions(-) diff --git a/tensordict/_lazy.py b/tensordict/_lazy.py index 659b04cfe..1ff73948e 100644 --- a/tensordict/_lazy.py +++ b/tensordict/_lazy.py @@ -1473,7 +1473,7 @@ def __setitem__(self, index: IndexType, value: T) -> T: elif isinstance(index, (list, range)): index = torch.tensor(index, device=self.device) - if isinstance(value, (TensorDictBase, dict)): + if is_tensor_collection(value) or isinstance(value, dict): indexed_bs = _getitem_batch_size(self.batch_size, index) if isinstance(value, dict): value = TensorDict( @@ -1500,7 +1500,11 @@ def __setitem__(self, index: IndexType, value: T) -> T: if isinteger: # this will break if the index along the stack dim is [0] or :1 or smth for i, _idx in converted_idx.items(): - self.tensordicts[i][_idx] = value + if _idx != (): + self.tensordicts[i][_idx] = value + else: + self.tensordicts[i] = value + return self if is_nd_tensor: 
raise RuntimeError( @@ -2443,7 +2447,13 @@ def _unsqueeze(self, dim): class StackNonTensor(LazyStackedTensorDict): """A thin wrapper aroung LazyStackedTensorDict to make stack on non-tensor data easily recognizable.""" - pass + + def tolist(self): + if self.stack_dim == 0: + return [td.tolist() for td in self.tensordicts] + else: + return [td.tolist() for td in self.unbind(0)] + class _CustomOpTensorDict(TensorDictBase): """Encodes lazy operations on tensors contained in a TensorDict.""" diff --git a/tensordict/_td.py b/tensordict/_td.py index 24e909dab..c52a3d594 100644 --- a/tensordict/_td.py +++ b/tensordict/_td.py @@ -1538,7 +1538,9 @@ def _set_at_str(self, key, value, idx, *, validated): tensor_in = _sub_index(tensor_in, idx) tensor_in.copy_(value) else: - _set_item(tensor_in, idx, value, validated=validated) + tensor_out = _set_item(tensor_in, idx, value, validated=validated) + if tensor_in is not tensor_out: + self._set_str(key, tensor_out, validated=True, inplace=False) return self @@ -2366,10 +2368,11 @@ def _set_at_str(self, key, value, idx, *, validated): ) tensor_in = _sub_index(tensor_in, idx) tensor_in.copy_(value) + tensor_out = tensor_in else: - _set_item(tensor_in, idx, value, validated=validated) + tensor_out = _set_item(tensor_in, idx, value, validated=validated) # make sure that the value is updated - self._source._set_at_str(key, tensor_in, self.idx, validated=validated) + self._source._set_at_str(key, tensor_out, self.idx, validated=validated) return self def _set_at_tuple(self, key, value, idx, *, validated): diff --git a/tensordict/_torch_func.py b/tensordict/_torch_func.py index 3584b1d33..47ad18077 100644 --- a/tensordict/_torch_func.py +++ b/tensordict/_torch_func.py @@ -161,6 +161,32 @@ def _ones_like(td: T, **kwargs: Any) -> T: return td_clone +@implements_for_td(torch.rand_like) +def _rand_like(td: T, **kwargs: Any) -> T: + td_clone = td._fast_apply(lambda x: torch.rand_like(x)) + if "device" in kwargs: + td_clone = td_clone.to(kwargs.pop("device")) + if len(kwargs): + raise RuntimeError( + f"keyword arguments {list(kwargs.keys())} are not " + f"supported with full_like with TensorDict" + ) + return td_clone + + +@implements_for_td(torch.randn_like) +def _randn_like(td: T, **kwargs: Any) -> T: + td_clone = td._fast_apply(lambda x: torch.randn_like(x)) + if "device" in kwargs: + td_clone = td_clone.to(kwargs.pop("device")) + if len(kwargs): + raise RuntimeError( + f"keyword arguments {list(kwargs.keys())} are not " + f"supported with full_like with TensorDict" + ) + return td_clone + + @implements_for_td(torch.empty_like) def _empty_like(td: T, *args, **kwargs) -> T: try: @@ -353,9 +379,12 @@ def _stack( if not list_of_tensordicts: raise RuntimeError("list_of_tensordicts cannot be empty") + from tensordict._lazy import StackNonTensor from tensordict.tensorclass import NonTensorData - if all(isinstance(td, NonTensorData) for td in list_of_tensordicts): + if all( + isinstance(td, (NonTensorData, StackNonTensor)) for td in list_of_tensordicts + ): return NonTensorData._stack_non_tensor(list_of_tensordicts, dim=dim) batch_size = list_of_tensordicts[0].batch_size diff --git a/tensordict/tensorclass.py b/tensordict/tensorclass.py index 81f16c598..38003e1bc 100644 --- a/tensordict/tensorclass.py +++ b/tensordict/tensorclass.py @@ -56,6 +56,9 @@ torch.full_like, torch.zeros_like, torch.ones_like, + torch.rand_like, + torch.empty_like, + torch.randn_like, torch.clone, torch.squeeze, torch.unsqueeze, @@ -1329,7 +1332,9 @@ def _check_equal(a, b): iseq = False return iseq 
- if all(_check_equal(data.data, first.data) for data in list_of_non_tensor[1:]): + if all(isinstance(data, NonTensorData) for data in list_of_non_tensor) and all( + _check_equal(data.data, first.data) for data in list_of_non_tensor[1:] + ): batch_size = list(first.batch_size) batch_size.insert(dim, len(list_of_non_tensor)) return NonTensorData( @@ -1395,3 +1400,8 @@ def _fast_apply(self, *args, **kwargs): return _wrap_method(self, "_fast_apply", self._tensordict._fast_apply)( *args, **kwargs ) + + def tolist(self): + if not self.batch_size: + return self.data + return [ntd.tolist() for ntd in self.unbind(0)] diff --git a/tensordict/utils.py b/tensordict/utils.py index c744a7377..f572c3ff4 100644 --- a/tensordict/utils.py +++ b/tensordict/utils.py @@ -656,6 +656,31 @@ def _set_item(tensor: Tensor, index: IndexType, value: Tensor, *, validated) -> elif isinstance(tensor, KeyedJaggedTensor): tensor = setitem_keyedjaggedtensor(tensor, index, value) return tensor + from tensordict._lazy import StackNonTensor + from tensordict.tensorclass import NonTensorData + + if isinstance(tensor, (NonTensorData, StackNonTensor)): + if ( + isinstance(value, NonTensorData) + and isinstance(tensor, NonTensorData) + and tensor.data == value.data + ): + return tensor + if isinstance(index, tuple): + if len(index) == 1: + index = index[0] + else: + idx = index[0] + tensor_idx = tensor[idx] + tensor_idx = _set_item(tensor_idx, index[1:], value, validated=True) + tensor = _set_item(tensor, idx, tensor_idx, validated=True) + return tensor + if isinstance(tensor, NonTensorData): + tensor = StackNonTensor(*[tensor[0]] * tensor.shape[0], stack_dim=0) + elif tensor.stack_dim != 0: + tensor = StackNonTensor(*tensor.unbind(0), stack_dim=0) + tensor[index] = value + return tensor else: tensor[index] = value return tensor From 20dc288d6dda3822383faf6c6fe0aed431a05ffa Mon Sep 17 00:00:00 2001 From: vmoens Date: Wed, 21 Feb 2024 09:31:54 -0800 Subject: [PATCH 05/20] amend --- tensordict/_lazy.py | 5 ++++- tensordict/_td.py | 10 +++++++--- tensordict/base.py | 13 +++++++++++-- tensordict/tensorclass.py | 5 +++++ test/test_tensordict.py | 37 +++++++++++++++++++++++++++++++++++++ 5 files changed, 64 insertions(+), 6 deletions(-) diff --git a/tensordict/_lazy.py b/tensordict/_lazy.py index 1ff73948e..6ff283228 100644 --- a/tensordict/_lazy.py +++ b/tensordict/_lazy.py @@ -1518,7 +1518,10 @@ def __setitem__(self, index: IndexType, value: T) -> T: converted_idx.items(), value_unbind, ): - self.tensordicts[i][_idx] = _value + if _idx != (): + self.tensordicts[i][_idx] = _value + else: + self.tensordicts[i] = _value else: # we must split, not unbind mask_unbind = split_index["individual_masks"] diff --git a/tensordict/_td.py b/tensordict/_td.py index c52a3d594..74f099faa 100644 --- a/tensordict/_td.py +++ b/tensordict/_td.py @@ -308,9 +308,10 @@ def is_empty(self): if _is_tensor_collection(type(item)): if not item.is_empty(): return False + from tensordict._lazy import StackNonTensor from tensordict.tensorclass import NonTensorData - if isinstance(item, NonTensorData): + if isinstance(item, (NonTensorData, StackNonTensor)): return False else: return False @@ -2433,10 +2434,13 @@ def get( def _get_non_tensor(self, key: NestedKey, default=NO_DEFAULT): out = super()._get_non_tensor(key, default=default) + from tensordict._lazy import StackNonTensor from tensordict.tensorclass import NonTensorData - if isinstance(out, _SubTensorDict) and isinstance(out._source, NonTensorData): - return out._source.data + if isinstance(out, 
_SubTensorDict) and isinstance( + out._source, (NonTensorData, StackNonTensor) + ): + return out._source return out def _get_str(self, key, default): diff --git a/tensordict/base.py b/tensordict/base.py index 52b83e089..253836712 100644 --- a/tensordict/base.py +++ b/tensordict/base.py @@ -246,6 +246,10 @@ def __getitem__(self, index: IndexType) -> T: if isinstance(result, NonTensorData): return result.data + from ._lazy import StackNonTensor + + if isinstance(result, StackNonTensor): + return result.tolist() return result if (istuple and not index) or (not istuple and index is Ellipsis): @@ -2308,10 +2312,15 @@ def _get_non_tensor(self, key: NestedKey, default=NO_DEFAULT): return subtd return subtd._get_non_tensor(key[1:], default=default) value = self._get_str(key, default=default) - from tensordict.tensorclass import NonTensorData + + from .tensorclass import NonTensorData if isinstance(value, NonTensorData): return value.data + from ._lazy import StackNonTensor + + if isinstance(value, StackNonTensor): + return value.tolist() return value def filter_non_tensor_data(self) -> T: @@ -5376,8 +5385,8 @@ def _default_is_leaf(cls: Type) -> bool: def _is_leaf_nontensor(cls: Type) -> bool: - from tensordict.tensorclass import NonTensorData from tensordict._lazy import StackNonTensor + from tensordict.tensorclass import NonTensorData if issubclass(cls, KeyedJaggedTensor): return False diff --git a/tensordict/tensorclass.py b/tensordict/tensorclass.py index 38003e1bc..b7016bcff 100644 --- a/tensordict/tensorclass.py +++ b/tensordict/tensorclass.py @@ -1402,6 +1402,11 @@ def _fast_apply(self, *args, **kwargs): ) def tolist(self): + """Converts the data in a list if the batch-size is non-empty. + + If the batch-size is empty, returns the data. + + """ if not self.batch_size: return self.data return [ntd.tolist() for ntd in self.unbind(0)] diff --git a/test/test_tensordict.py b/test/test_tensordict.py index 7c40eb5ca..adb8fae65 100644 --- a/test/test_tensordict.py +++ b/test/test_tensordict.py @@ -7785,6 +7785,43 @@ def test_stack(self, non_tensor_data): LazyStackedTensorDict, ) + def test_assign_non_tensor(self): + data = TensorDict({}, [1, 10]) + + data[0, 0] = TensorDict({"a": 0, "b": "a string!"}, []) + + assert data["b"] == "a string!" 
+ assert data.get("b").tolist() == [["a string!"] * 10] + data[0, 1] = TensorDict({"a": 0, "b": "another string!"}, []) + assert data.get("b").tolist() == [ + ["a string!"] + ["another string!"] + ["a string!"] * 8 + ] + + data = TensorDict({}, [1, 10]) + + data[0, 0] = TensorDict({"a": 0, "b": "a string!"}, []) + + data[0, 5:] = TensorDict({"a": torch.zeros(5), "b": "another string!"}, [5]) + assert data.get("b").tolist() == [["a string!"] * 5 + ["another string!"] * 5] + + data = TensorDict({}, [1, 10]) + + data[0, 0] = TensorDict({"a": 0, "b": "a string!"}, []) + + data[0, 0::2] = TensorDict( + {"a": torch.zeros(5, dtype=torch.long), "b": "another string!"}, [5] + ) + assert data.get("b").tolist() == [["another string!", "a string!"] * 5] + + data = TensorDict({}, [1, 10]) + + data[0, 0] = TensorDict({"a": 0, "b": "a string!"}, []) + + data[0] = TensorDict( + {"a": torch.zeros(10, dtype=torch.long), "b": "another string!"}, [10] + ) + assert data.get("b").tolist() == [["another string!"] * 10] + if __name__ == "__main__": args, unknown = argparse.ArgumentParser().parse_known_args() From 0f44acc2d3a37135ccd73ef93c38107264fdb96d Mon Sep 17 00:00:00 2001 From: vmoens Date: Wed, 21 Feb 2024 14:08:14 -0800 Subject: [PATCH 06/20] amend --- docs/source/reference/tensorclass.rst | 1 + tensordict/__init__.py | 3 +- tensordict/_lazy.py | 38 +++++++----------- tensordict/_td.py | 17 ++++---- tensordict/_torch_func.py | 5 +-- tensordict/base.py | 28 ++++++++++---- tensordict/tensorclass.py | 56 ++++++++++++++++++++++----- tensordict/utils.py | 9 ++--- 8 files changed, 100 insertions(+), 57 deletions(-) diff --git a/docs/source/reference/tensorclass.rst b/docs/source/reference/tensorclass.rst index 17518df1d..8e6e4e907 100644 --- a/docs/source/reference/tensorclass.rst +++ b/docs/source/reference/tensorclass.rst @@ -273,3 +273,4 @@ Here is an example: tensorclass NonTensorData + NonTensorStack diff --git a/tensordict/__init__.py b/tensordict/__init__.py index c71661ceb..5e6dc8761 100644 --- a/tensordict/__init__.py +++ b/tensordict/__init__.py @@ -16,7 +16,7 @@ from tensordict.memmap import MemoryMappedTensor from tensordict.memmap_deprec import is_memmap, MemmapTensor, set_transfer_ownership from tensordict.persistent import PersistentTensorDict -from tensordict.tensorclass import NonTensorData, tensorclass +from tensordict.tensorclass import NonTensorData, NonTensorStack, tensorclass from tensordict.utils import ( assert_allclose_td, is_batchedtensor, @@ -43,6 +43,7 @@ "TensorDict", "TensorDictBase", "merge_tensordicts", + "NonTensorStack", "set_transfer_ownership", "pad_sequence", "is_memmap", diff --git a/tensordict/_lazy.py b/tensordict/_lazy.py index 6ff283228..8059b6819 100644 --- a/tensordict/_lazy.py +++ b/tensordict/_lazy.py @@ -718,7 +718,7 @@ def _legacy_unsqueeze(self, dim: int) -> T: else: dim = dim - 1 stack_dim = self.stack_dim - return LazyStackedTensorDict( + return type(self)( *(tensordict.unsqueeze(dim) for tensordict in self.tensordicts), stack_dim=stack_dim, ) @@ -756,7 +756,7 @@ def _legacy_squeeze(self, dim: int | None = None) -> T: else: dim = dim - 1 stack_dim = self.stack_dim - return LazyStackedTensorDict( + return type(self)( *(tensordict.squeeze(dim) for tensordict in self.tensordicts), stack_dim=stack_dim, ) @@ -1236,7 +1236,7 @@ def contiguous(self) -> T: return out def empty(self, recurse=False) -> T: - return LazyStackedTensorDict( + return type(self)( *[td.empty(recurse=recurse) for td in self.tensordicts], stack_dim=self.stack_dim, ) @@ -1245,12 +1245,12 @@ def 
_clone(self, recurse: bool = True) -> T: if recurse: # This could be optimized using copy but we must be careful with # metadata (_is_shared etc) - result = LazyStackedTensorDict( + result = type(self)( *[td._clone() for td in self.tensordicts], stack_dim=self.stack_dim, ) else: - result = LazyStackedTensorDict( + result = type(self)( *[td._clone(recurse=False) for td in self.tensordicts], stack_dim=self.stack_dim, ) @@ -1274,7 +1274,7 @@ def to(self, *args, **kwargs) -> T: if device is not None and dtype is None and device == self.device: return result - return LazyStackedTensorDict( + return type(self)( *[td.to(*args, **kwargs) for td in self.tensordicts], stack_dim=self.stack_dim, hook_out=self.hook_out, @@ -1403,7 +1403,7 @@ def _apply_nest( if filter_empty and all(r is None for r in results): return if not inplace: - out = LazyStackedTensorDict( + out = type(self)( *results, stack_dim=self.stack_dim, ) @@ -1429,7 +1429,7 @@ def _select( ] if inplace: return self - result = LazyStackedTensorDict(*tensordicts, stack_dim=self.stack_dim) + result = type(self)(*tensordicts, stack_dim=self.stack_dim) return result def _exclude( @@ -1442,7 +1442,7 @@ def _exclude( if inplace: self.tensordicts = tensordicts return self - result = LazyStackedTensorDict(*tensordicts, stack_dim=self.stack_dim) + result = type(self)(*tensordicts, stack_dim=self.stack_dim) return result def __setitem__(self, index: IndexType, value: T) -> T: @@ -2336,9 +2336,9 @@ def _transpose(self, dim0, dim1): # example: shape = [5, 4, 3, 2, 1], stack_dim=1, dim0=1, dim1=4 # resulting shape: [5, 1, 3, 2, 4] if dim1 == dim0 + 1: - result = LazyStackedTensorDict(*self.tensordicts, stack_dim=dim1) + result = type(self)(*self.tensordicts, stack_dim=dim1) else: - result = LazyStackedTensorDict( + result = type(self)( *(td.transpose(dim0, dim1 - 1) for td in self.tensordicts), stack_dim=dim1, ) @@ -2346,16 +2346,16 @@ def _transpose(self, dim0, dim1): # example: shape = [5, 4, 3, 2, 1], stack_dim=3, dim0=1, dim1=3 # resulting shape: [5, 2, 3, 4, 1] if dim0 + 1 == dim1: - result = LazyStackedTensorDict(*self.tensordicts, stack_dim=dim0) + result = type(self)(*self.tensordicts, stack_dim=dim0) else: - result = LazyStackedTensorDict( + result = type(self)( *(td.transpose(dim0 + 1, dim1) for td in self.tensordicts), stack_dim=dim0, ) else: dim0 = dim0 if dim0 < self.stack_dim else dim0 - 1 dim1 = dim1 if dim1 < self.stack_dim else dim1 - 1 - result = LazyStackedTensorDict( + result = type(self)( *(td.transpose(dim0, dim1) for td in self.tensordicts), stack_dim=self.stack_dim, ) @@ -2448,16 +2448,6 @@ def _unsqueeze(self, dim): _to_module = TensorDict._to_module -class StackNonTensor(LazyStackedTensorDict): - """A thin wrapper aroung LazyStackedTensorDict to make stack on non-tensor data easily recognizable.""" - - def tolist(self): - if self.stack_dim == 0: - return [td.tolist() for td in self.tensordicts] - else: - return [td.tolist() for td in self.unbind(0)] - - class _CustomOpTensorDict(TensorDictBase): """Encodes lazy operations on tensors contained in a TensorDict.""" diff --git a/tensordict/_td.py b/tensordict/_td.py index 74f099faa..43562aa89 100644 --- a/tensordict/_td.py +++ b/tensordict/_td.py @@ -29,6 +29,7 @@ _register_tensor_class, BEST_ATTEMPT_INPLACE, CompatibleType, + is_non_tensor, is_tensor_collection, NO_DEFAULT, T, @@ -308,10 +309,9 @@ def is_empty(self): if _is_tensor_collection(type(item)): if not item.is_empty(): return False - from tensordict._lazy import StackNonTensor - from tensordict.tensorclass import 
NonTensorData + from tensordict.tensorclass import NonTensorData, NonTensorStack - if isinstance(item, (NonTensorData, StackNonTensor)): + if isinstance(item, (NonTensorData, NonTensorStack)): return False else: return False @@ -680,7 +680,11 @@ def make_result(): any_set = False for key, item in self.items(): - if not call_on_nested and _is_tensor_collection(item.__class__): + if ( + not call_on_nested + and _is_tensor_collection(item.__class__) + and not is_non_tensor(item) + ): if default is not NO_DEFAULT: _others = [_other._get_str(key, default=None) for _other in others] _others = [ @@ -2434,11 +2438,10 @@ def get( def _get_non_tensor(self, key: NestedKey, default=NO_DEFAULT): out = super()._get_non_tensor(key, default=default) - from tensordict._lazy import StackNonTensor - from tensordict.tensorclass import NonTensorData + from tensordict.tensorclass import NonTensorData, NonTensorStack if isinstance(out, _SubTensorDict) and isinstance( - out._source, (NonTensorData, StackNonTensor) + out._source, (NonTensorData, NonTensorStack) ): return out._source return out diff --git a/tensordict/_torch_func.py b/tensordict/_torch_func.py index 47ad18077..d27cbb1d7 100644 --- a/tensordict/_torch_func.py +++ b/tensordict/_torch_func.py @@ -379,11 +379,10 @@ def _stack( if not list_of_tensordicts: raise RuntimeError("list_of_tensordicts cannot be empty") - from tensordict._lazy import StackNonTensor - from tensordict.tensorclass import NonTensorData + from tensordict.tensorclass import NonTensorData, NonTensorStack if all( - isinstance(td, (NonTensorData, StackNonTensor)) for td in list_of_tensordicts + isinstance(td, (NonTensorData, NonTensorStack)) for td in list_of_tensordicts ): return NonTensorData._stack_non_tensor(list_of_tensordicts, dim=dim) diff --git a/tensordict/base.py b/tensordict/base.py index 253836712..3fe1ac63b 100644 --- a/tensordict/base.py +++ b/tensordict/base.py @@ -246,9 +246,9 @@ def __getitem__(self, index: IndexType) -> T: if isinstance(result, NonTensorData): return result.data - from ._lazy import StackNonTensor + from tensordict.tensorclass import NonTensorStack - if isinstance(result, StackNonTensor): + if isinstance(result, NonTensorStack): return result.tolist() return result @@ -1659,6 +1659,14 @@ def cuda(self, device: int = None) -> T: return self.to(torch.device("cuda")) return self.to(f"cuda:{device}") + @property + def is_cuda(self): + return self.device is not None and self.device.type == "cuda" + + @property + def is_cpu(self): + return self.device is not None and self.device.type == "cpu" + # Serialization functionality def state_dict( self, @@ -2317,9 +2325,9 @@ def _get_non_tensor(self, key: NestedKey, default=NO_DEFAULT): if isinstance(value, NonTensorData): return value.data - from ._lazy import StackNonTensor + from tensordict.tensorclass import NonTensorStack - if isinstance(value, StackNonTensor): + if isinstance(value, NonTensorStack): return value.tolist() return value @@ -5380,16 +5388,22 @@ def is_tensor_collection(datatype: type | Any) -> bool: return _is_tensor_collection(datatype) +def is_non_tensor(data): + """Checks if an item is a non-tensor.""" + from tensordict.tensorclass import NonTensorData, NonTensorStack + + return isinstance(data, (NonTensorData, NonTensorStack)) + + def _default_is_leaf(cls: Type) -> bool: return not _is_tensor_collection(cls) def _is_leaf_nontensor(cls: Type) -> bool: - from tensordict._lazy import StackNonTensor - from tensordict.tensorclass import NonTensorData + from tensordict.tensorclass import 
NonTensorData, NonTensorStack if issubclass(cls, KeyedJaggedTensor): return False if _is_tensor_collection(cls): - return issubclass(cls, (NonTensorData, StackNonTensor)) + return issubclass(cls, (NonTensorData, NonTensorStack)) return issubclass(cls, torch.Tensor) diff --git a/tensordict/tensorclass.py b/tensordict/tensorclass.py index b7016bcff..90b00ecb8 100644 --- a/tensordict/tensorclass.py +++ b/tensordict/tensorclass.py @@ -15,7 +15,7 @@ import re import sys import warnings -from copy import copy +from copy import copy, deepcopy from dataclasses import dataclass from pathlib import Path from textwrap import indent @@ -24,6 +24,7 @@ import tensordict as tensordict_lib import torch +from tensordict import LazyStackedTensorDict from tensordict._td import is_tensor_collection, NO_DEFAULT, TensorDict, TensorDictBase from tensordict._tensordict import _unravel_key_to_tuple from tensordict._torch_func import TD_HANDLED_FUNCTIONS @@ -475,7 +476,7 @@ def wrapper(self, item: str) -> Any: return wrapper -SET_ATTRIBUTES = ("batch_size", "device", "_locked_tensordicts") +SET_ATTRIBUTES = ("batch_size", "device", "_locked_tensordicts", "names") def _setattr_wrapper(setattr_: Callable, expected_keys: set[str]) -> Callable: @@ -489,12 +490,10 @@ def wrapper(self, key: str, value: Any) -> None: # noqa: D417 """ __dict__ = self.__dict__ - if ( - "_tensordict" not in __dict__ - or "_non_tensordict" not in __dict__ - or key in SET_ATTRIBUTES - ): + if "_tensordict" not in __dict__ or "_non_tensordict" not in __dict__: return setattr_(self, key, value) + if key in SET_ATTRIBUTES: + return setattr(self._tensordict, key, value) out = self.set(key, value) if out is not self: @@ -714,6 +713,9 @@ def _set(self, key: NestedKey, value: Any, inplace: bool = False): __dict__ = self.__dict__ if __dict__["_tensordict"].is_locked: raise RuntimeError(_LOCK_ERROR) + if key in ("batch_size", "names", "device"): + # handled by setattr + return expected_keys = self.__dataclass_fields__ if key not in expected_keys: raise AttributeError( @@ -1344,9 +1346,7 @@ def _check_equal(a, b): device=first.device, ) - from tensordict._lazy import StackNonTensor - - return StackNonTensor(*list_of_non_tensor, stack_dim=dim) + return NonTensorStack(*list_of_non_tensor, stack_dim=dim) @classmethod def __torch_function__( @@ -1410,3 +1410,39 @@ def tolist(self): if not self.batch_size: return self.data return [ntd.tolist() for ntd in self.unbind(0)] + + def copy_(self, src: NonTensorData | NonTensorStack, non_blocking: bool = False): + if isinstance(src, NonTensorStack): + raise RuntimeError( + "Cannot update a NonTensorData with a NonTensorStack object." + ) + if not isinstance(src, NonTensorData): + raise RuntimeError( + "NonTensorData.copy_ requires the source to be a NonTensorData object." 
+ ) + self._non_tensordict["data"] = src.data + + def clone(self, recurse: bool = True): + if recurse: + return type(self)( + data=deepcopy(self.data), + batch_size=self.batch_size, + device=self.device, + names=self.names, + ) + return type(self)( + data=self.data, + batch_size=self.batch_size, + device=self.device, + names=self.names, + ) + + +class NonTensorStack(LazyStackedTensorDict): + """A thin wrapper aroung LazyStackedTensorDict to make stack on non-tensor data easily recognizable.""" + + def tolist(self): + if self.stack_dim == 0: + return [td.tolist() for td in self.tensordicts] + else: + return [td.tolist() for td in self.unbind(0)] diff --git a/tensordict/utils.py b/tensordict/utils.py index f572c3ff4..b02e367c8 100644 --- a/tensordict/utils.py +++ b/tensordict/utils.py @@ -656,10 +656,9 @@ def _set_item(tensor: Tensor, index: IndexType, value: Tensor, *, validated) -> elif isinstance(tensor, KeyedJaggedTensor): tensor = setitem_keyedjaggedtensor(tensor, index, value) return tensor - from tensordict._lazy import StackNonTensor - from tensordict.tensorclass import NonTensorData + from tensordict.tensorclass import NonTensorData, NonTensorStack - if isinstance(tensor, (NonTensorData, StackNonTensor)): + if isinstance(tensor, (NonTensorData, NonTensorStack)): if ( isinstance(value, NonTensorData) and isinstance(tensor, NonTensorData) @@ -676,9 +675,9 @@ def _set_item(tensor: Tensor, index: IndexType, value: Tensor, *, validated) -> tensor = _set_item(tensor, idx, tensor_idx, validated=True) return tensor if isinstance(tensor, NonTensorData): - tensor = StackNonTensor(*[tensor[0]] * tensor.shape[0], stack_dim=0) + tensor = NonTensorStack(*[tensor[0]] * tensor.shape[0], stack_dim=0) elif tensor.stack_dim != 0: - tensor = StackNonTensor(*tensor.unbind(0), stack_dim=0) + tensor = NonTensorStack(*tensor.unbind(0), stack_dim=0) tensor[index] = value return tensor else: From 7104a0b466d1b23be4b16025090d129df9575a61 Mon Sep 17 00:00:00 2001 From: vmoens Date: Wed, 21 Feb 2024 16:32:57 -0800 Subject: [PATCH 07/20] amend --- tensordict/nn/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensordict/nn/utils.py b/tensordict/nn/utils.py index 3e1d698bd..105be6538 100644 --- a/tensordict/nn/utils.py +++ b/tensordict/nn/utils.py @@ -14,7 +14,7 @@ import torch from torch import nn -AUTO_MAKE_FUNCTIONAL = strtobool(os.environ.get("AUTO_MAKE_FUNCTIONAL", "True")) +AUTO_MAKE_FUNCTIONAL = strtobool(os.environ.get("AUTO_MAKE_FUNCTIONAL", "False")) DISPATCH_TDNN_MODULES = strtobool(os.environ.get("DISPATCH_TDNN_MODULES", "True")) From d1866394de41bbeb8b29c2dd450d3e4a91487933 Mon Sep 17 00:00:00 2001 From: vmoens Date: Sat, 24 Feb 2024 17:44:07 -0800 Subject: [PATCH 08/20] amend --- tensordict/_td.py | 10 ++++++---- tensordict/base.py | 5 +---- 2 files changed, 7 insertions(+), 8 deletions(-) diff --git a/tensordict/_td.py b/tensordict/_td.py index c01256e24..aeca5001a 100644 --- a/tensordict/_td.py +++ b/tensordict/_td.py @@ -309,9 +309,8 @@ def is_empty(self): if _is_tensor_collection(type(item)): if not item.is_empty(): return False - from tensordict.tensorclass import NonTensorData, NonTensorStack - if isinstance(item, (NonTensorData, NonTensorStack)): + if is_non_tensor(item): return False else: return False @@ -693,7 +692,7 @@ def make_result(): if ( not call_on_nested and _is_tensor_collection(item.__class__) - and not is_non_tensor(item) + # and not is_non_tensor(item) ): if default is not NO_DEFAULT: _others = [_other._get_str(key, default=None) for _other in others] 
@@ -2467,7 +2466,10 @@ def _get_non_tensor(self, key: NestedKey, default=NO_DEFAULT): def _get_str(self, key, default): if key in self.keys() and _is_tensor_collection(self.entry_class(key)): - return _SubTensorDict(self._source._get_str(key, NO_DEFAULT), self.idx) + data = self._source._get_str(key, NO_DEFAULT) + if is_non_tensor(data): + return data[self.idx] + return _SubTensorDict(data, self.idx) return self._source._get_at_str(key, self.idx, default=default) def _get_tuple(self, key, default): diff --git a/tensordict/base.py b/tensordict/base.py index 3fe1ac63b..33fa7b239 100644 --- a/tensordict/base.py +++ b/tensordict/base.py @@ -2321,12 +2321,9 @@ def _get_non_tensor(self, key: NestedKey, default=NO_DEFAULT): return subtd._get_non_tensor(key[1:], default=default) value = self._get_str(key, default=default) - from .tensorclass import NonTensorData - + from .tensorclass import NonTensorData, NonTensorStack if isinstance(value, NonTensorData): return value.data - from tensordict.tensorclass import NonTensorStack - if isinstance(value, NonTensorStack): return value.tolist() return value From 4f665201d32ec5d4d4dd70d3ebd072475fdbaa1a Mon Sep 17 00:00:00 2001 From: vmoens Date: Sun, 25 Feb 2024 15:26:58 -0800 Subject: [PATCH 09/20] amend --- tensordict/_lazy.py | 155 +++++++++++++++++++++++++++----------- tensordict/base.py | 1 + tensordict/tensorclass.py | 78 +++++++++++++++++-- tensordict/utils.py | 18 ++--- test/test_tensordict.py | 19 ++++- 5 files changed, 206 insertions(+), 65 deletions(-) diff --git a/tensordict/_lazy.py b/tensordict/_lazy.py index 8059b6819..df256f265 100644 --- a/tensordict/_lazy.py +++ b/tensordict/_lazy.py @@ -200,6 +200,10 @@ def __init__( ) _batch_size = tensordicts[0].batch_size device = tensordicts[0].device + if stack_dim > len(_batch_size): + raise RuntimeError( + f"Stack dim {stack_dim} is too big for batch size {_batch_size}." 
+ ) for td in tensordicts[1:]: if not is_tensor_collection(td): @@ -487,9 +491,10 @@ def _split_index(self, index): isinteger = False is_nd_tensor = False cursor = 0 # the dimension cursor - selected_td_idx = range(len(self.tensordicts)) + selected_td_idx = torch.arange(len(self.tensordicts)) has_bool = False num_squash = 0 + encountered_tensor = False for i, idx in enumerate(index): # noqa: B007 cursor_incr = 1 if idx is None: @@ -509,10 +514,8 @@ def _split_index(self, index): if not isinstance(selected_td_idx, range): isinteger = True selected_td_idx = [selected_td_idx] - elif isinstance(idx, (list, range)): - selected_td_idx = idx - elif isinstance(idx, (torch.Tensor, np.ndarray)): - if idx.dtype in (np.dtype("bool"), torch.bool): + elif isinstance(idx, torch.Tensor): + if idx.dtype == torch.bool: # we mark that we need to dispatch the indices across stack idx has_bool = True # split mask along dim @@ -522,11 +525,14 @@ def _split_index(self, index): split_dim = self.stack_dim - num_single mask_loc = i else: - if isinstance(idx, np.ndarray): - idx = torch.tensor(idx) is_nd_tensor = True - selected_td_idx = range(len(idx)) - out.append(idx.unbind(0)) + if not encountered_tensor: + # num_single -= idx.ndim - 1 + encountered_tensor = True + else: + num_single += 1 + selected_td_idx = idx + # out.append(idx.unbind(0)) else: raise TypeError(f"Invalid index type: {type(idx)}.") else: @@ -537,13 +543,11 @@ def _split_index(self, index): ( ftdim.Dim, slice, - list, - range, ), ): out.append(idx) - elif isinstance(idx, (np.ndarray, torch.Tensor)): - if idx.dtype in (np.dtype("bool"), torch.bool): + elif isinstance(idx, torch.Tensor): + if idx.dtype == torch.bool: cursor_incr = idx.ndim if cursor < self.stack_dim: num_squash += cursor_incr - 1 @@ -568,7 +572,11 @@ def _split_index(self, index): # smth[torch.tensor(1)].ndim = smth.ndim-1 # smth[torch.tensor([1])].ndim = smth.ndim # smth[torch.tensor([[1]])].ndim = smth.ndim+1 - num_single -= idx.ndim - 1 + if not encountered_tensor: + num_single -= idx.ndim - 1 + encountered_tensor = True + else: + num_single += 1 out.append(idx) else: raise TypeError(f"Invalid index type: {type(idx)}.") @@ -593,20 +601,45 @@ def _split_index(self, index): elif is_nd_tensor: def isindexable(idx): - if isinstance(idx, (torch.Tensor, np.ndarray)): - if idx.dtype in (torch.bool, np.dtype("bool")): + if isinstance(idx, torch.Tensor): + if idx.dtype == torch.bool: return False return True if isinstance(idx, (tuple, list, range)): return True return False - out = tuple( - tuple(idx if not isindexable(idx) else idx[i] for idx in out) - for i in selected_td_idx - ) + def outer_list(tensor_index, tuple_index): + """Converts a tensor and a tuple to a nested list where each leaf is a (int, index) tuple where the index only points to one element.""" + if isinstance(tensor_index, torch.Tensor): + list_index = tensor_index.tolist() + else: + list_index = tensor_index + list_result = [] + + def index_tuple_index(i, convert=False): + for idx in tuple_index: + if isindexable(idx): + if convert: + yield int(idx[i]) + else: + yield idx[i] + else: + yield idx + + for i, idx in enumerate(list_index): + if isinstance(idx, int): + list_result.append( + (idx, tuple(index_tuple_index(i, convert=True))) + ) + elif isinstance(idx, list): + list_result.append(outer_list(idx, tuple(index_tuple_index(i)))) + else: + raise NotImplementedError + return list_result + return { - "index_dict": dict(enumerate(out)), + "index_dict": outer_list(selected_td_idx, out), "num_single": num_single, 
"isinteger": isinteger, "has_bool": has_bool, @@ -646,8 +679,19 @@ def _set_at_str(self, key, value, index, *, validated): if is_nd_tensor: unbind_dim = self.stack_dim - num_single + num_none - num_squash value_unbind = value.unbind(unbind_dim) - for idx, _value in zip(converted_idx.values(), value_unbind): - self._set_at_str(key, _value, idx, validated=validated) + + def set_at_str(converted_idx): + for i, item in enumerate(converted_idx): + if isinstance(item, list): + set_at_str(item) + else: + _value = value_unbind[i] + stack_idx, idx = item + self.tensordicts[stack_idx]._set_at_str( + key, _value, idx, validated=validated + ) + + set_at_str(converted_idx) return self elif not has_bool: unbind_dim = self.stack_dim - num_single + num_none - num_squash @@ -1460,10 +1504,12 @@ def __setitem__(self, index: IndexType, value: T) -> T: ) return - if any(isinstance(sub_index, (list, range)) for sub_index in index): + if any( + isinstance(sub_index, (list, range, np.ndarray)) for sub_index in index + ): index = tuple( - torch.tensor(sub_index, device=self.device) - if isinstance(sub_index, (list, range)) + torch.as_tensor(sub_index, device=self.device) + if isinstance(sub_index, (list, range, np.ndarray)) else sub_index for sub_index in index ) @@ -1471,7 +1517,7 @@ def __setitem__(self, index: IndexType, value: T) -> T: if index is Ellipsis or (isinstance(index, tuple) and Ellipsis in index): index = convert_ellipsis_to_idx(index, self.batch_size) elif isinstance(index, (list, range)): - index = torch.tensor(index, device=self.device) + index = torch.as_tensor(index, device=self.device) if is_tensor_collection(value) or isinstance(value, dict): indexed_bs = _getitem_batch_size(self.batch_size, index) @@ -1500,17 +1546,25 @@ def __setitem__(self, index: IndexType, value: T) -> T: if isinteger: # this will break if the index along the stack dim is [0] or :1 or smth for i, _idx in converted_idx.items(): - if _idx != (): - self.tensordicts[i][_idx] = value + if _idx == (): + self.tensordicts[i].update(value, inplace=True) else: - self.tensordicts[i] = value - + self.tensordicts[i][_idx] = value return self if is_nd_tensor: - raise RuntimeError( - "Indexing along stack dim with a non-boolean tensor is not supported yet. " - "Use SubTensorDict instead." 
- ) + unbind_dim = self.stack_dim - num_single + num_none - num_squash + # converted_idx is a nested list with (int, index) items + def assign(converted_idx, value=value): + value = value.unbind(unbind_dim) + for i, item in enumerate(converted_idx): + if isinstance(item, list): + assign(item) + else: + stack_item, idx = item + self.tensordicts[stack_item][idx] = value[i] + + assign(converted_idx) + return self if not has_bool: unbind_dim = self.stack_dim - num_single + num_none - num_squash value_unbind = value.unbind(unbind_dim) @@ -1518,10 +1572,10 @@ def __setitem__(self, index: IndexType, value: T) -> T: converted_idx.items(), value_unbind, ): - if _idx != (): - self.tensordicts[i][_idx] = _value + if _idx == (): + self.tensordicts[i].update(_value, inplace=True) else: - self.tensordicts[i] = _value + self.tensordicts[i][_idx] = _value else: # we must split, not unbind mask_unbind = split_index["individual_masks"] @@ -1589,11 +1643,21 @@ def __getitem__(self, index: IndexType) -> T: return torch.cat(result, cat_dim) elif is_nd_tensor: new_stack_dim = self.stack_dim - num_single + num_none - out = LazyStackedTensorDict.lazy_stack( - [self[idx] for idx in converted_idx.values()], new_stack_dim - ) - out._td_dim_name = self._td_dim_name - return out + + def recompose(converted_idx, stack_dim=new_stack_dim): + stack = [] + for item in converted_idx: + if isinstance(item, list): + stack.append(recompose(item, stack_dim=stack_dim)) + else: + stack_elt, idx = item + stack.append(self.tensordicts[stack_elt][idx]) + result = LazyStackedTensorDict.lazy_stack(stack, stack_dim) + # TODO: this produces multiple dims with the same name + result._td_dim_name = self._td_dim_name + return result + + return recompose(converted_idx) else: if isinteger: for ( @@ -1610,7 +1674,10 @@ def __getitem__(self, index: IndexType) -> T: result = [] new_stack_dim = self.stack_dim - num_single + num_none - num_squash for i, _idx in converted_idx.items(): - result.append(self.tensordicts[i][_idx]) + if _idx == (): + result.append(self.tensordicts[i]) + else: + result.append(self.tensordicts[i][_idx]) result = LazyStackedTensorDict.lazy_stack(result, new_stack_dim) result._td_dim_name = self._td_dim_name return result diff --git a/tensordict/base.py b/tensordict/base.py index 33fa7b239..e08ba35f9 100644 --- a/tensordict/base.py +++ b/tensordict/base.py @@ -2322,6 +2322,7 @@ def _get_non_tensor(self, key: NestedKey, default=NO_DEFAULT): value = self._get_str(key, default=default) from .tensorclass import NonTensorData, NonTensorStack + if isinstance(value, NonTensorData): return value.data if isinstance(value, NonTensorStack): diff --git a/tensordict/tensorclass.py b/tensordict/tensorclass.py index c468697b4..4421aa318 100644 --- a/tensordict/tensorclass.py +++ b/tensordict/tensorclass.py @@ -28,7 +28,7 @@ from tensordict._td import is_tensor_collection, NO_DEFAULT, TensorDict, TensorDictBase from tensordict._tensordict import _unravel_key_to_tuple from tensordict._torch_func import TD_HANDLED_FUNCTIONS -from tensordict.base import _ACCEPTED_CLASSES, _register_tensor_class +from tensordict.base import _ACCEPTED_CLASSES, _register_tensor_class, CompatibleType from tensordict.memmap_deprec import MemmapTensor as _MemmapTensor from tensordict.utils import ( @@ -1319,6 +1319,23 @@ def __or__(self, other): self.__class__.__or__ = __or__ + def update( + self, + input_dict_or_td: dict[str, CompatibleType] | T, + clone: bool = False, + inplace: bool = False, + *, + keys_to_update: Sequence[NestedKey] | None = None, + ) -> T: + 
if isinstance(input_dict_or_td, NonTensorData): + data = input_dict_or_td.data + if clone: + data = deepcopy(data) + self.data = data + elif not input_dict_or_td.is_empty(): + raise RuntimeError(f"Unexpected type {type(input_dict_or_td)}") + return self + def empty(self, recurse=False): return NonTensorData( data=self.data, @@ -1450,10 +1467,59 @@ def clone(self, recurse: bool = True): class NonTensorStack(LazyStackedTensorDict): - """A thin wrapper aroung LazyStackedTensorDict to make stack on non-tensor data easily recognizable.""" + """A thin wrapper around LazyStackedTensorDict to make stack on non-tensor data easily recognizable. + + A ``NonTensorStack`` is returned whenever :func:`~torch.stack` is called on + a list of :class:`~tensordict.NonTensorData` or ``NonTensorStack``. + + Examples: + >>> from tensordict import NonTensorData + >>> import torch + >>> data = torch.stack([ + ... torch.stack([NonTensorData(data=(i, j), batch_size=[]) for i in range(2)]) + ... for j in range(3)]) + >>> print(data) + NonTensorStack( + [[(0, 0), (1, 0)], [(0, 1), (1, 1)], [(0, 2), (1, ..., + batch_size=torch.Size([3, 2]), + device=None) + + To obtain the values stored in a ``NonTensorStack``, call :class:`~.tolist`. + + """ def tolist(self): - if self.stack_dim == 0: - return [td.tolist() for td in self.tensordicts] - else: - return [td.tolist() for td in self.unbind(0)] + """Extracts the content of a :class:`tensordict.tensorclass.NonTensorStack` in a nested list. + + Examples: + >>> from tensordict import NonTensorData + >>> import torch + >>> data = torch.stack([ + ... torch.stack([NonTensorData(data=(i, j), batch_size=[]) for i in range(2)]) + ... for j in range(3)]) + >>> data.tolist() + [[(0, 0), (1, 0)], [(0, 1), (1, 1)], [(0, 2), (1, 2)]] + + """ + iterator = self.tensordicts if self.stack_dim == 0 else self.unbind(0) + return [td.tolist() for td in iterator] + + @classmethod + def from_nontensordata(cls, non_tensor: NonTensorData): + data = non_tensor.data + prev = NonTensorData(data, batch_size=[], device=non_tensor.device) + for dim in reversed(non_tensor.shape): + prev = cls(*[prev.clone(False) for _ in range(dim)], stack_dim=0) + return prev + + def __repr__(self): + selfrepr = str(self.tolist()) + if len(selfrepr) > 50: + selfrepr = f"{selfrepr[:50]}..." 
+ selfrepr = indent(selfrepr, prefix=4 * " ") + batch_size = indent(f"batch_size={self.batch_size}", prefix=4 * " ") + device = indent(f"device={self.device}", prefix=4 * " ") + return f"NonTensorStack(\n{selfrepr}," f"\n{batch_size}," f"\n{device})" + + def to_dict(self) -> dict[str, Any]: + return self.tolist() diff --git a/tensordict/utils.py b/tensordict/utils.py index 7192c66b3..bb80ca1de 100644 --- a/tensordict/utils.py +++ b/tensordict/utils.py @@ -668,18 +668,9 @@ def _set_item(tensor: Tensor, index: IndexType, value: Tensor, *, validated) -> and tensor.data == value.data ): return tensor - if isinstance(index, tuple): - if len(index) == 1: - index = index[0] - else: - idx = index[0] - tensor_idx = tensor[idx] - tensor_idx = _set_item(tensor_idx, index[1:], value, validated=True) - tensor = _set_item(tensor, idx, tensor_idx, validated=True) - return tensor - if isinstance(tensor, NonTensorData): - tensor = NonTensorStack(*[tensor[0]] * tensor.shape[0], stack_dim=0) - elif tensor.stack_dim != 0: + elif isinstance(tensor, NonTensorData): + tensor = NonTensorStack.from_nontensordata(tensor) + if tensor.stack_dim != 0: tensor = NonTensorStack(*tensor.unbind(0), stack_dim=0) tensor[index] = value return tensor @@ -1610,8 +1601,9 @@ def wrapper(*args, **kwargs): def _broadcast_tensors(index): # tensors and range need to be broadcast + assert isinstance(index, tuple) tensors = { - i: tensor if isinstance(tensor, Tensor) else torch.tensor(tensor) + i: torch.as_tensor(tensor) for i, tensor in enumerate(index) if isinstance(tensor, (range, list, np.ndarray, Tensor)) } diff --git a/test/test_tensordict.py b/test/test_tensordict.py index 79cbd092e..01bdb358d 100644 --- a/test/test_tensordict.py +++ b/test/test_tensordict.py @@ -2774,6 +2774,7 @@ def test_index_tensor_nd_names(self, td_name, device, npy): index = index.numpy() td_idx = td[:, index] assert tensor_example[:, index].shape == td_idx.shape + # TODO: this multiple dims with identical names should not be allowed assert td_idx.names == [names[0], names[1], names[1], *names[2:]] td_idx = td[0, index] assert tensor_example[0, index].shape == td_idx.shape @@ -6100,10 +6101,24 @@ def test_lazy_indexing(self, pos1, pos2, pos3): pos2 = self._idx_list[pos2] pos3 = self._idx_list[pos3] index = (pos1, pos2, pos3) + print( + "index", + tuple( + index if not hasattr(index, "shape") else index.shape for index in index + ), + ) result = outer[index] ref_tensor = torch.zeros(outer.shape) - assert result.batch_size == ref_tensor[index].shape, index - assert result.batch_size == outer_dense[index].shape, index + assert result.batch_size == ref_tensor[index].shape, ( + result.batch_size, + ref_tensor[index].shape, + index, + ) + assert result.batch_size == outer_dense[index].shape, ( + result.batch_size, + outer_dense[index].shape, + index, + ) @pytest.mark.parametrize("stack_dim", [0, 1, 2]) @pytest.mark.parametrize("mask_dim", [0, 1, 2]) From a06aa778a6cf302ff72504ed117a90f18939a112 Mon Sep 17 00:00:00 2001 From: vmoens Date: Sun, 25 Feb 2024 15:28:52 -0800 Subject: [PATCH 10/20] amend --- tensordict/_lazy.py | 1 + tensordict/nn/params.py | 8 -------- tensordict/utils.py | 1 - test/test_tensordict.py | 6 ------ 4 files changed, 1 insertion(+), 15 deletions(-) diff --git a/tensordict/_lazy.py b/tensordict/_lazy.py index df256f265..02c37d491 100644 --- a/tensordict/_lazy.py +++ b/tensordict/_lazy.py @@ -1553,6 +1553,7 @@ def __setitem__(self, index: IndexType, value: T) -> T: return self if is_nd_tensor: unbind_dim = self.stack_dim - num_single + 
num_none - num_squash + # converted_idx is a nested list with (int, index) items def assign(converted_idx, value=value): value = value.unbind(unbind_dim) diff --git a/tensordict/nn/params.py b/tensordict/nn/params.py index 81a124c81..e5714bf4e 100644 --- a/tensordict/nn/params.py +++ b/tensordict/nn/params.py @@ -1117,8 +1117,6 @@ def compute_should_use_set_data(tensor, tensor_applied): param.data = param_applied out_param = param else: - assert isinstance(param, nn.Parameter) - assert param.is_leaf out_param = nn.Parameter(param_applied, param.requires_grad) self._parameters[key] = out_param @@ -1129,10 +1127,8 @@ def compute_should_use_set_data(tensor, tensor_applied): param.grad, grad_applied ) if should_use_set_data: - assert out_param.grad is not None out_param.grad.data = grad_applied else: - assert param.grad.is_leaf out_param.grad = grad_applied.requires_grad_( param.grad.requires_grad ) @@ -1150,8 +1146,6 @@ def compute_should_use_set_data(tensor, tensor_applied): buffer.data = buffer_applied out_buffer = buffer else: - assert isinstance(buffer, Buffer) - assert buffer.is_leaf out_buffer = Buffer(buffer_applied, buffer.requires_grad) self._buffers[key] = out_buffer @@ -1162,10 +1156,8 @@ def compute_should_use_set_data(tensor, tensor_applied): buffer.grad, grad_applied ) if should_use_set_data: - assert out_buffer.grad is not None out_buffer.grad.data = grad_applied else: - assert buffer.grad.is_leaf out_buffer.grad = grad_applied.requires_grad_( buffer.grad.requires_grad ) diff --git a/tensordict/utils.py b/tensordict/utils.py index bb80ca1de..c23c68ec9 100644 --- a/tensordict/utils.py +++ b/tensordict/utils.py @@ -1601,7 +1601,6 @@ def wrapper(*args, **kwargs): def _broadcast_tensors(index): # tensors and range need to be broadcast - assert isinstance(index, tuple) tensors = { i: torch.as_tensor(tensor) for i, tensor in enumerate(index) diff --git a/test/test_tensordict.py b/test/test_tensordict.py index 01bdb358d..9ffb602a6 100644 --- a/test/test_tensordict.py +++ b/test/test_tensordict.py @@ -6101,12 +6101,6 @@ def test_lazy_indexing(self, pos1, pos2, pos3): pos2 = self._idx_list[pos2] pos3 = self._idx_list[pos3] index = (pos1, pos2, pos3) - print( - "index", - tuple( - index if not hasattr(index, "shape") else index.shape for index in index - ), - ) result = outer[index] ref_tensor = torch.zeros(outer.shape) assert result.batch_size == ref_tensor[index].shape, ( From d50104126b355ff2efef87e0a92dd1a8055421dd Mon Sep 17 00:00:00 2001 From: vmoens Date: Mon, 26 Feb 2024 09:48:51 -0500 Subject: [PATCH 11/20] amend --- tensordict/_td.py | 32 ++++++++++++++++++++------------ test/test_nn.py | 2 ++ 2 files changed, 22 insertions(+), 12 deletions(-) diff --git a/tensordict/_td.py b/tensordict/_td.py index aeca5001a..8ccf1388b 100644 --- a/tensordict/_td.py +++ b/tensordict/_td.py @@ -3043,7 +3043,10 @@ def __contains__(self, key: NestedKey) -> bool: if isinstance(key, str): if key in self._keys(): if self.leaves_only: - return not _is_tensor_collection(self.tensordict.entry_class(key)) + # TODO: make this faster for LazyStacked without compromising regular + return not _is_tensor_collection( + type(self.tensordict._get_str(key)) + ) return True return False else: @@ -3051,25 +3054,30 @@ def __contains__(self, key: NestedKey) -> bool: if len(key) == 1: return key[0] in self._keys() elif self.include_nested: - if key[0] in self._keys(): - entry_type = self.tensordict.entry_class(key[0]) - if entry_type in (Tensor, _MemmapTensor): + item_root = self.tensordict._get_str(key[0], 
default=None) + if item_root is not None: + entry_type = type(item_root) + if issubclass(entry_type, (Tensor, _MemmapTensor)): return False - if entry_type is KeyedJaggedTensor: + elif entry_type is KeyedJaggedTensor: if len(key) > 2: return False - return key[1] in self.tensordict.get(key[0]).keys() + return key[1] in item_root.keys() + # TODO: make this faster for LazyStacked without compromising regular _is_tensordict = _is_tensor_collection(entry_type) if _is_tensordict: # # this will call _unravel_key_to_tuple many times # return key[1:] in self.tensordict._get_str(key[0], NO_DEFAULT).keys(include_nested=self.include_nested) # this won't call _unravel_key_to_tuple but requires to get the default which can be suboptimal - leaf_td = self.tensordict._get_tuple(key[:-1], None) - if leaf_td is None or ( - not _is_tensor_collection(leaf_td.__class__) - and not isinstance(leaf_td, KeyedJaggedTensor) - ): - return False + if len(key) >= 3: + leaf_td = item_root._get_tuple(key[1:-1], None) + if leaf_td is None or ( + not _is_tensor_collection(leaf_td.__class__) + and not isinstance(leaf_td, KeyedJaggedTensor) + ): + return False + else: + leaf_td = item_root return key[-1] in leaf_td.keys() return False # this is reached whenever there is more than one key but include_nested is False diff --git a/test/test_nn.py b/test/test_nn.py index 37d5975b0..016237077 100644 --- a/test/test_nn.py +++ b/test/test_nn.py @@ -415,6 +415,7 @@ def test_functional_before(self): tensordict_module = TensorDictModule( module=net, in_keys=["in"], out_keys=["out"] ) + make_functional(tensordict_module, return_params=False) td = TensorDict({"in": torch.randn(3, 3)}, [3]) tensordict_module(td, params=TensorDict({"module": params}, [])) @@ -580,6 +581,7 @@ def test_functional_with_buffer(self): tdmodule = TensorDictModule(module=net, in_keys=["in"], out_keys=["out"]) td = TensorDict({"in": torch.randn(3, 32 * param_multiplier)}, [3]) + make_functional(tdmodule, return_params=False) tdmodule(td, params=TensorDict({"module": params}, [])) assert td.shape == torch.Size([3]) assert td.get("out").shape == torch.Size([3, 32]) From c23e2577e482ba8060185841001fbddddbb81d33 Mon Sep 17 00:00:00 2001 From: vmoens Date: Mon, 26 Feb 2024 10:25:28 -0500 Subject: [PATCH 12/20] amend --- tensordict/_lazy.py | 7 ++++--- tensordict/_td.py | 5 +---- tensordict/_torch_func.py | 13 ++++++++----- tensordict/base.py | 30 +++++++++--------------------- tensordict/tensorclass.py | 14 +++++++++++--- tensordict/utils.py | 9 ++++----- 6 files changed, 37 insertions(+), 41 deletions(-) diff --git a/tensordict/_lazy.py b/tensordict/_lazy.py index 02c37d491..96827fe44 100644 --- a/tensordict/_lazy.py +++ b/tensordict/_lazy.py @@ -29,6 +29,7 @@ _register_tensor_class, BEST_ATTEMPT_INPLACE, CompatibleType, + is_non_tensor, is_tensor_collection, NO_DEFAULT, T, @@ -931,15 +932,15 @@ def lazy_stack( if not items: raise RuntimeError("items cannot be empty") - from .tensorclass import NonTensorData - if all(isinstance(item, torch.Tensor) for item in items): return torch.stack(items, dim=dim, out=out) if all( is_tensorclass(item) and type(item) == type(items[0]) # noqa: E721 for item in items ): - if all(isinstance(tensordict, NonTensorData) for tensordict in items): + if all(is_non_tensor(tensordict) for tensordict in items): + from .tensorclass import NonTensorData + return NonTensorData._stack_non_tensor(items, dim=dim) lazy_stack = cls.lazy_stack( [item._tensordict for item in items], dim=dim, out=out diff --git a/tensordict/_td.py b/tensordict/_td.py 
index 8ccf1388b..d2f5bc327 100644 --- a/tensordict/_td.py +++ b/tensordict/_td.py @@ -2456,11 +2456,8 @@ def get( def _get_non_tensor(self, key: NestedKey, default=NO_DEFAULT): out = super()._get_non_tensor(key, default=default) - from tensordict.tensorclass import NonTensorData, NonTensorStack - if isinstance(out, _SubTensorDict) and isinstance( - out._source, (NonTensorData, NonTensorStack) - ): + if isinstance(out, _SubTensorDict) and is_non_tensor(out._source): return out._source return out diff --git a/tensordict/_torch_func.py b/tensordict/_torch_func.py index d7b9a12e5..6d3280958 100644 --- a/tensordict/_torch_func.py +++ b/tensordict/_torch_func.py @@ -13,7 +13,12 @@ import torch from tensordict._lazy import LazyStackedTensorDict from tensordict._td import TensorDict -from tensordict.base import _is_leaf_nontensor, NO_DEFAULT, TensorDictBase +from tensordict.base import ( + _is_leaf_nontensor, + is_non_tensor, + NO_DEFAULT, + TensorDictBase, +) from tensordict.persistent import PersistentTensorDict from tensordict.utils import ( _check_keys, @@ -381,11 +386,9 @@ def _stack( if not list_of_tensordicts: raise RuntimeError("list_of_tensordicts cannot be empty") - from tensordict.tensorclass import NonTensorData, NonTensorStack + if all(is_non_tensor(td) for td in list_of_tensordicts): + from tensordict.tensorclass import NonTensorData - if all( - isinstance(td, (NonTensorData, NonTensorStack)) for td in list_of_tensordicts - ): return NonTensorData._stack_non_tensor(list_of_tensordicts, dim=dim) batch_size = list_of_tensordicts[0].batch_size diff --git a/tensordict/base.py b/tensordict/base.py index e08ba35f9..5ac9bf5bf 100644 --- a/tensordict/base.py +++ b/tensordict/base.py @@ -242,14 +242,11 @@ def __getitem__(self, index: IndexType) -> T: idx_unravel = _unravel_key_to_tuple(index) if idx_unravel: result = self._get_tuple(idx_unravel, NO_DEFAULT) - from .tensorclass import NonTensorData - - if isinstance(result, NonTensorData): - return result.data - from tensordict.tensorclass import NonTensorStack - - if isinstance(result, NonTensorStack): - return result.tolist() + if is_non_tensor(result): + result_data = getattr(result, "data", NO_DEFAULT) + if result_data is NO_DEFAULT: + return result_data.tolist() + return result_data return result if (istuple and not index) or (not istuple and index is Ellipsis): @@ -2321,20 +2318,15 @@ def _get_non_tensor(self, key: NestedKey, default=NO_DEFAULT): return subtd._get_non_tensor(key[1:], default=default) value = self._get_str(key, default=default) - from .tensorclass import NonTensorData, NonTensorStack - - if isinstance(value, NonTensorData): - return value.data - if isinstance(value, NonTensorStack): + if is_non_tensor(value): return value.tolist() return value def filter_non_tensor_data(self) -> T: """Filters out all non-tensor-data.""" - from tensordict.tensorclass import NonTensorData def _filter(x): - if not isinstance(x, NonTensorData): + if not is_non_tensor(x): if is_tensor_collection(x): return x.filter_non_tensor_data() return x @@ -5388,9 +5380,7 @@ def is_tensor_collection(datatype: type | Any) -> bool: def is_non_tensor(data): """Checks if an item is a non-tensor.""" - from tensordict.tensorclass import NonTensorData, NonTensorStack - - return isinstance(data, (NonTensorData, NonTensorStack)) + return type(data).__dict__.get("_non_tensor", False) def _default_is_leaf(cls: Type) -> bool: @@ -5398,10 +5388,8 @@ def _default_is_leaf(cls: Type) -> bool: def _is_leaf_nontensor(cls: Type) -> bool: - from tensordict.tensorclass import 
NonTensorData, NonTensorStack - if issubclass(cls, KeyedJaggedTensor): return False if _is_tensor_collection(cls): - return issubclass(cls, (NonTensorData, NonTensorStack)) + return cls.__dict__.get("_non_tensor", False) return issubclass(cls, torch.Tensor) diff --git a/tensordict/tensorclass.py b/tensordict/tensorclass.py index 4421aa318..0c0dfdbf4 100644 --- a/tensordict/tensorclass.py +++ b/tensordict/tensorclass.py @@ -28,7 +28,12 @@ from tensordict._td import is_tensor_collection, NO_DEFAULT, TensorDict, TensorDictBase from tensordict._tensordict import _unravel_key_to_tuple from tensordict._torch_func import TD_HANDLED_FUNCTIONS -from tensordict.base import _ACCEPTED_CLASSES, _register_tensor_class, CompatibleType +from tensordict.base import ( + _ACCEPTED_CLASSES, + _register_tensor_class, + CompatibleType, + is_non_tensor, +) from tensordict.memmap_deprec import MemmapTensor as _MemmapTensor from tensordict.utils import ( @@ -1261,10 +1266,11 @@ class NonTensorData: # to patch tensordict with additional checks that will encur unwanted overhead # and all the overhead falls back on this class. data: Any + _non_tensor: bool = True def __post_init__(self): - if isinstance(self.data, NonTensorData): - self.data = self.data.data + if is_non_tensor(self.data): + self.data = self.data.tolist() old_eq = self.__class__.__eq__ if old_eq is _eq: @@ -1488,6 +1494,8 @@ class NonTensorStack(LazyStackedTensorDict): """ + _non_tensor: bool = True + def tolist(self): """Extracts the content of a :class:`tensordict.tensorclass.NonTensorStack` in a nested list. diff --git a/tensordict/utils.py b/tensordict/utils.py index c23c68ec9..0dd27a153 100644 --- a/tensordict/utils.py +++ b/tensordict/utils.py @@ -49,6 +49,8 @@ unravel_key_list, unravel_keys, ) + +from tensordict.base import is_non_tensor from torch import Tensor from torch._C import _disabled_torch_function_impl from torch.nn.parameter import ( @@ -59,7 +61,6 @@ ) from torch.utils.data._utils.worker import _generate_state - if TYPE_CHECKING: from tensordict.memmap_deprec import MemmapTensor as _MemmapTensor from tensordict.tensordict import TensorDictBase @@ -661,7 +662,7 @@ def _set_item(tensor: Tensor, index: IndexType, value: Tensor, *, validated) -> return tensor from tensordict.tensorclass import NonTensorData, NonTensorStack - if isinstance(tensor, (NonTensorData, NonTensorStack)): + if is_non_tensor(tensor): if ( isinstance(value, NonTensorData) and isinstance(tensor, NonTensorData) @@ -1521,9 +1522,7 @@ def _expand_to_match_shape( def _set_max_batch_size(source: T, batch_dims=None): """Updates a tensordict with its maximium batch size.""" - from tensordict import NonTensorData - - tensor_data = [val for val in source.values() if not isinstance(val, NonTensorData)] + tensor_data = [val for val in source.values() if not is_non_tensor(val)] for val in tensor_data: from tensordict.base import _is_tensor_collection From 4bfe951dddd2bc085012a205e81a645adbe3182c Mon Sep 17 00:00:00 2001 From: vmoens Date: Mon, 26 Feb 2024 10:40:15 -0500 Subject: [PATCH 13/20] amend --- tensordict/_lazy.py | 2 +- tensordict/_td.py | 2 +- tensordict/_torch_func.py | 8 ++------ tensordict/base.py | 6 +----- tensordict/tensorclass.py | 8 ++------ tensordict/utils.py | 6 +++++- 6 files changed, 12 insertions(+), 20 deletions(-) diff --git a/tensordict/_lazy.py b/tensordict/_lazy.py index 96827fe44..3f9765246 100644 --- a/tensordict/_lazy.py +++ b/tensordict/_lazy.py @@ -29,7 +29,6 @@ _register_tensor_class, BEST_ATTEMPT_INPLACE, CompatibleType, - is_non_tensor, 
is_tensor_collection, NO_DEFAULT, T, @@ -54,6 +53,7 @@ expand_right, IndexType, infer_size_impl, + is_non_tensor, is_tensorclass, KeyedJaggedTensor, lock_blocked, diff --git a/tensordict/_td.py b/tensordict/_td.py index d2f5bc327..021c9151e 100644 --- a/tensordict/_td.py +++ b/tensordict/_td.py @@ -29,7 +29,6 @@ _register_tensor_class, BEST_ATTEMPT_INPLACE, CompatibleType, - is_non_tensor, is_tensor_collection, NO_DEFAULT, T, @@ -72,6 +71,7 @@ DeviceType, expand_as_right, IndexType, + is_non_tensor, is_tensorclass, KeyedJaggedTensor, lock_blocked, diff --git a/tensordict/_torch_func.py b/tensordict/_torch_func.py index 6d3280958..fbabc77e4 100644 --- a/tensordict/_torch_func.py +++ b/tensordict/_torch_func.py @@ -13,17 +13,13 @@ import torch from tensordict._lazy import LazyStackedTensorDict from tensordict._td import TensorDict -from tensordict.base import ( - _is_leaf_nontensor, - is_non_tensor, - NO_DEFAULT, - TensorDictBase, -) +from tensordict.base import _is_leaf_nontensor, NO_DEFAULT, TensorDictBase from tensordict.persistent import PersistentTensorDict from tensordict.utils import ( _check_keys, _ErrorInteceptor, DeviceType, + is_non_tensor, lazy_legacy, set_lazy_legacy, ) diff --git a/tensordict/base.py b/tensordict/base.py index 5ac9bf5bf..0700ec4e4 100644 --- a/tensordict/base.py +++ b/tensordict/base.py @@ -60,6 +60,7 @@ IndexType, infer_size_impl, int_generator, + is_non_tensor, KeyedJaggedTensor, lazy_legacy, lock_blocked, @@ -5378,11 +5379,6 @@ def is_tensor_collection(datatype: type | Any) -> bool: return _is_tensor_collection(datatype) -def is_non_tensor(data): - """Checks if an item is a non-tensor.""" - return type(data).__dict__.get("_non_tensor", False) - - def _default_is_leaf(cls: Type) -> bool: return not _is_tensor_collection(cls) diff --git a/tensordict/tensorclass.py b/tensordict/tensorclass.py index 0c0dfdbf4..f64927c1b 100644 --- a/tensordict/tensorclass.py +++ b/tensordict/tensorclass.py @@ -28,12 +28,7 @@ from tensordict._td import is_tensor_collection, NO_DEFAULT, TensorDict, TensorDictBase from tensordict._tensordict import _unravel_key_to_tuple from tensordict._torch_func import TD_HANDLED_FUNCTIONS -from tensordict.base import ( - _ACCEPTED_CLASSES, - _register_tensor_class, - CompatibleType, - is_non_tensor, -) +from tensordict.base import _ACCEPTED_CLASSES, _register_tensor_class, CompatibleType from tensordict.memmap_deprec import MemmapTensor as _MemmapTensor from tensordict.utils import ( @@ -42,6 +37,7 @@ _LOCK_ERROR, DeviceType, IndexType, + is_non_tensor, is_tensorclass, NestedKey, ) diff --git a/tensordict/utils.py b/tensordict/utils.py index 0dd27a153..c64c30a7a 100644 --- a/tensordict/utils.py +++ b/tensordict/utils.py @@ -50,7 +50,6 @@ unravel_keys, ) -from tensordict.base import is_non_tensor from torch import Tensor from torch._C import _disabled_torch_function_impl from torch.nn.parameter import ( @@ -2170,3 +2169,8 @@ def __call__(self, mod: torch.nn.Module, args, kwargs): return else: raise RuntimeError("did not find pre-hook") + + +def is_non_tensor(data): + """Checks if an item is a non-tensor.""" + return type(data).__dict__.get("_non_tensor", False) From 856079865a830d9041781d3b40a80d412a73646b Mon Sep 17 00:00:00 2001 From: vmoens Date: Mon, 26 Feb 2024 11:49:37 -0500 Subject: [PATCH 14/20] amend --- tensordict/base.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/tensordict/base.py b/tensordict/base.py index 0700ec4e4..93216ca12 100644 --- a/tensordict/base.py +++ b/tensordict/base.py @@ -2320,7 +2320,10 @@ 
def _get_non_tensor(self, key: NestedKey, default=NO_DEFAULT): value = self._get_str(key, default=default) if is_non_tensor(value): - return value.tolist() + data = getattr(value, "data", None) + if data is None: + return value.tolist() + return data return value def filter_non_tensor_data(self) -> T: From 43ddacd2bd2324b5f3d681604c6ef284aa2c3c6f Mon Sep 17 00:00:00 2001 From: vmoens Date: Mon, 26 Feb 2024 13:57:34 -0500 Subject: [PATCH 15/20] amend --- tensordict/base.py | 2 +- tensordict/tensorclass.py | 5 ++++- 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/tensordict/base.py b/tensordict/base.py index 93216ca12..bbc3cbf0c 100644 --- a/tensordict/base.py +++ b/tensordict/base.py @@ -246,7 +246,7 @@ def __getitem__(self, index: IndexType) -> T: if is_non_tensor(result): result_data = getattr(result, "data", NO_DEFAULT) if result_data is NO_DEFAULT: - return result_data.tolist() + return result.tolist() return result_data return result diff --git a/tensordict/tensorclass.py b/tensordict/tensorclass.py index f64927c1b..cac696893 100644 --- a/tensordict/tensorclass.py +++ b/tensordict/tensorclass.py @@ -1266,7 +1266,10 @@ class NonTensorData: def __post_init__(self): if is_non_tensor(self.data): - self.data = self.data.tolist() + data = getattr(self.data, "data", None) + if data is None: + data = self.data.tolist() + self.data = data old_eq = self.__class__.__eq__ if old_eq is _eq: From 7c13cfaa435ae2ccac3467bc9eabe649d7161f77 Mon Sep 17 00:00:00 2001 From: vmoens Date: Mon, 26 Feb 2024 14:31:46 -0500 Subject: [PATCH 16/20] amend --- tensordict/_lazy.py | 58 ++++++++++++++++++++++++++++----------------- 1 file changed, 36 insertions(+), 22 deletions(-) diff --git a/tensordict/_lazy.py b/tensordict/_lazy.py index 3f9765246..00410a6dd 100644 --- a/tensordict/_lazy.py +++ b/tensordict/_lazy.py @@ -135,6 +135,8 @@ class LazyStackedTensorDict(TensorDictBase): `td.ndimension()-1` along which the stack should be performed. hook_out (callable, optional): a callable to execute after :meth:`~.get`. hook_in (callable, optional): a callable to execute before :meth:`~.set`. + stack_dim_name (str, optional): the name of the stack dimension. + Defaults to ``None``. 
Examples: >>> from tensordict import TensorDict @@ -185,6 +187,7 @@ def __init__( hook_out: callable | None = None, hook_in: callable | None = None, batch_size: Sequence[int] | None = None, # TODO: remove + stack_dim_name: str | None = None, ) -> None: self._is_locked = None @@ -229,6 +232,8 @@ def __init__( self.hook_in = hook_in if batch_size is not None and batch_size != self.batch_size: raise RuntimeError("batch_size does not match self.batch_size.") + if stack_dim_name is not None: + self._td_dim_name = stack_dim_name # These attributes should never be set @property @@ -855,7 +860,9 @@ def _get_str( # then we consider this default as non-stackable and return prematurly return default try: - out = self.lazy_stack(tensors, self.stack_dim) + out = self.lazy_stack( + tensors, self.stack_dim, stack_dim_name=self._td_dim_name + ) if _is_tensor_collection(out.__class__): if isinstance(out, LazyStackedTensorDict): # then it's a LazyStackedTD @@ -867,8 +874,6 @@ def _get_str( self._batch_size + out.batch_size[(len(self._batch_size) + incr) :] ) - if self._td_dim_name is not None: - out._td_dim_name = self._td_dim_name elif is_tensorclass(out): # then it's a tensorclass out._tensordict.hook_out = self.hook_out @@ -879,8 +884,6 @@ def _get_str( self._batch_size + out._tensordict.batch_size[(len(self._batch_size) + incr) :] ) - if self._td_dim_name is not None: - out._tensordict._td_dim_name = self._td_dim_name else: raise RuntimeError elif self.hook_out is not None: @@ -925,8 +928,10 @@ def lazy_stack( cls, items: Sequence[TensorDictBase], dim: int = 0, + *, device: DeviceType | None = None, out: T | None = None, + stack_dim_name: str | None = None, ) -> T: """Stacks tensordicts in a LazyStackedTensorDict.""" if not items: @@ -943,7 +948,10 @@ def lazy_stack( return NonTensorData._stack_non_tensor(items, dim=dim) lazy_stack = cls.lazy_stack( - [item._tensordict for item in items], dim=dim, out=out + [item._tensordict for item in items], + dim=dim, + out=out, + stack_dim_name=stack_dim_name, ) # we take the first non_tensordict by convention return type(items[0])._from_tensordict( @@ -968,7 +976,9 @@ def lazy_stack( # The first case is handled within _check_keys which fails if keys # don't match exactly. # The second requires a check over the tensor shapes. 
- return LazyStackedTensorDict(*items, stack_dim=dim) + return LazyStackedTensorDict( + *items, stack_dim=dim, stack_dim_name=stack_dim_name + ) else: batch_size = list(batch_size) batch_size.insert(dim, len(items)) @@ -1293,14 +1303,14 @@ def _clone(self, recurse: bool = True) -> T: result = type(self)( *[td._clone() for td in self.tensordicts], stack_dim=self.stack_dim, + stack_dim_name=self._td_dim_name, ) else: result = type(self)( *[td._clone(recurse=False) for td in self.tensordicts], stack_dim=self.stack_dim, + stack_dim_name=self._td_dim_name, ) - if self._td_dim_name is not None: - result._td_dim_name = self._td_dim_name return result def pin_memory(self) -> T: @@ -1451,13 +1461,12 @@ def _apply_nest( out = type(self)( *results, stack_dim=self.stack_dim, + stack_dim_name=self._td_dim_name, ) else: out = self if names is not None: out.names = names - else: - out._td_dim_name = self._td_dim_name return out def _select( @@ -1654,9 +1663,10 @@ def recompose(converted_idx, stack_dim=new_stack_dim): else: stack_elt, idx = item stack.append(self.tensordicts[stack_elt][idx]) - result = LazyStackedTensorDict.lazy_stack(stack, stack_dim) # TODO: this produces multiple dims with the same name - result._td_dim_name = self._td_dim_name + result = LazyStackedTensorDict.lazy_stack( + stack, stack_dim, stack_dim=self._td_dim_name + ) return result return recompose(converted_idx) @@ -1680,8 +1690,9 @@ def recompose(converted_idx, stack_dim=new_stack_dim): result.append(self.tensordicts[i]) else: result.append(self.tensordicts[i][_idx]) - result = LazyStackedTensorDict.lazy_stack(result, new_stack_dim) - result._td_dim_name = self._td_dim_name + result = LazyStackedTensorDict.lazy_stack( + result, new_stack_dim, stack_dim_name=self._td_dim_name + ) return result def __eq__(self, other): @@ -2427,8 +2438,8 @@ def _transpose(self, dim0, dim1): result = type(self)( *(td.transpose(dim0, dim1) for td in self.tensordicts), stack_dim=self.stack_dim, + stack_dim_name=self._td_dim_name, ) - result._td_dim_name = self._td_dim_name return result def _permute( @@ -2446,9 +2457,10 @@ def _permute( d if d < self.stack_dim else d - 1 for d in dims_list if d != self.stack_dim ] result = LazyStackedTensorDict.lazy_stack( - [td.permute(dims_list) for td in self.tensordicts], stack_dim + [td.permute(dims_list) for td in self.tensordicts], + stack_dim, + stack_dim_name=self._td_dim_name, ) - result._td_dim_name = self._td_dim_name return result def _squeeze(self, dim=None): @@ -2471,9 +2483,10 @@ def _squeeze(self, dim=None): else: stack_dim = self.stack_dim - 1 result = LazyStackedTensorDict.lazy_stack( - [td.squeeze(dim) for td in self.tensordicts], stack_dim + [td.squeeze(dim) for td in self.tensordicts], + stack_dim, + stack_dim_name=self._td_dim_name, ) - result._td_dim_name = result._td_dim_name else: result = self for dim in range(self.batch_dims - 1, -1, -1): @@ -2496,9 +2509,10 @@ def _unsqueeze(self, dim): else: stack_dim = self.stack_dim + 1 result = LazyStackedTensorDict.lazy_stack( - [td.unsqueeze(dim) for td in self.tensordicts], stack_dim + [td.unsqueeze(dim) for td in self.tensordicts], + stack_dim, + stack_dim_name=self._td_dim_name, ) - result._td_dim_name = result._td_dim_name return result lock_ = TensorDictBase.lock_ From 82eab7c7d84f921912700f4d15a13cbc3adfc0b9 Mon Sep 17 00:00:00 2001 From: vmoens Date: Mon, 26 Feb 2024 14:53:04 -0500 Subject: [PATCH 17/20] fix --- tensordict/_lazy.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensordict/_lazy.py b/tensordict/_lazy.py 
index 00410a6dd..971ef142d 100644 --- a/tensordict/_lazy.py +++ b/tensordict/_lazy.py @@ -1665,7 +1665,7 @@ def recompose(converted_idx, stack_dim=new_stack_dim): stack.append(self.tensordicts[stack_elt][idx]) # TODO: this produces multiple dims with the same name result = LazyStackedTensorDict.lazy_stack( - stack, stack_dim, stack_dim=self._td_dim_name + stack, stack_dim, stack_dim_name=self._td_dim_name ) return result From d5c299a5f13962614e8c746597baaf1486a4de9a Mon Sep 17 00:00:00 2001 From: vmoens Date: Mon, 26 Feb 2024 16:42:42 -0500 Subject: [PATCH 18/20] fix --- tensordict/tensorclass.py | 29 ++++++++++++++++++++++++++--- 1 file changed, 26 insertions(+), 3 deletions(-) diff --git a/tensordict/tensorclass.py b/tensordict/tensorclass.py index cac696893..c2029f4fa 100644 --- a/tensordict/tensorclass.py +++ b/tensordict/tensorclass.py @@ -221,6 +221,7 @@ def __torch_function__( cls.to_tensordict = _to_tensordict cls.device = property(_device, _device_setter) cls.batch_size = property(_batch_size, _batch_size_setter) + cls.names = property(_names, _names_setter) cls.__doc__ = f"{cls.__name__}{inspect.signature(cls)}" @@ -502,10 +503,12 @@ def wrapper(self, key: str, value: Any) -> None: # noqa: D417 """ __dict__ = self.__dict__ - if "_tensordict" not in __dict__ or "_non_tensordict" not in __dict__: + if ( + "_tensordict" not in __dict__ + or "_non_tensordict" not in __dict__ + or key in SET_ATTRIBUTES + ): return setattr_(self, key, value) - if key in SET_ATTRIBUTES: - return setattr(self._tensordict, key, value) out = self.set(key, value) if out is not self: @@ -844,6 +847,26 @@ def _batch_size_setter(self, new_size: torch.Size) -> None: # noqa: D417 self._tensordict._batch_size_setter(new_size) +def _names(self) -> torch.Size: + """Retrieves the dim names for the tensor class. + + Returns: + names (list of str) + + """ + return self._tensordict.names + + +def _names_setter(self, names: str) -> None: # noqa: D417 + """Set the value of ``tensorclass.names``. + + Args: + names (sequence of str) + + """ + self._tensordict.names = names + + def _state_dict( self, destination=None, prefix="", keep_vars=False, flatten=False ) -> dict[str, Any]: From 26b427dcc5016cabe3f2c1203b25414407928319 Mon Sep 17 00:00:00 2001 From: vmoens Date: Mon, 26 Feb 2024 18:11:55 -0500 Subject: [PATCH 19/20] amend --- test/test_tensordict.py | 23 +++++++++++++++++++++++ 1 file changed, 23 insertions(+) diff --git a/test/test_tensordict.py b/test/test_tensordict.py index 9ffb602a6..66e2875c8 100644 --- a/test/test_tensordict.py +++ b/test/test_tensordict.py @@ -7797,6 +7797,29 @@ def test_map_with_out(self, mmap, chunksize, tmpdir): input.map(self.selectfn, num_workers=2, chunksize=chunksize, out=out) assert (out["a"] == torch.arange(10)).all(), (chunksize, mmap) + @classmethod + def nontensor_check(cls, td): + td["check"] = td["non_tensor"] == ( + "a string!" if (td["tensor"] % 2) == 0 else "another string!" 
+ ) + return td + + def test_non_tensor(self): + # with NonTensorStack + td = TensorDict( + {"tensor": torch.arange(10), "non_tensor": "a string!"}, batch_size=[10] + ) + td[1::2] = TensorDict({"non_tensor": "another string!"}, [5]) + td = td.map(self.nontensor_check, chunksize=0) + assert td["check"].all() + # with NonTensorData + td = TensorDict( + {"tensor": torch.zeros(10, dtype=torch.int), "non_tensor": "a string!"}, + batch_size=[10], + ) + td = td.map(self.nontensor_check, chunksize=0) + assert td["check"].all() + # class TestNonTensorData: class TestNonTensorData: From 137d89da3df83f7b0de573cd4bc2ef00aee687f1 Mon Sep 17 00:00:00 2001 From: vmoens Date: Mon, 26 Feb 2024 18:12:38 -0500 Subject: [PATCH 20/20] amend --- .github/workflows/test-rl-gpu.yml | 1 + 1 file changed, 1 insertion(+) diff --git a/.github/workflows/test-rl-gpu.yml b/.github/workflows/test-rl-gpu.yml index ab284d754..dea057e8b 100644 --- a/.github/workflows/test-rl-gpu.yml +++ b/.github/workflows/test-rl-gpu.yml @@ -31,6 +31,7 @@ jobs: repository: pytorch/tensordict gpu-arch-type: cuda gpu-arch-version: ${{ matrix.cuda_arch_version }} + timeout: 120 script: | # Set env vars from matrix export PYTHON_VERSION=${{ matrix.python_version }}
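As a usage sketch of what the series above converges to — the `is_non_tensor` helper added to `tensordict/utils.py` in PATCH 13/20, non-tensor payloads stored inside a `TensorDict`, and the new `stack_dim_name` keyword of `LazyStackedTensorDict.lazy_stack` from PATCH 16/20 — the snippet below shows the intended call pattern. It is an illustrative sketch only: the top-level import of `LazyStackedTensorDict`, the dimension name "time" and the string payloads are assumptions made for the example, not part of the patches.

    import torch
    from tensordict import TensorDict, LazyStackedTensorDict
    from tensordict.utils import is_non_tensor  # helper defined there by PATCH 13/20

    # A plain string stored in a batched TensorDict is wrapped as non-tensor data;
    # __getitem__ unwraps it back to the raw payload.
    td = TensorDict({"tensor": torch.arange(3), "msg": "a string!"}, batch_size=[3])
    assert is_non_tensor(td.get("msg"))  # the wrapped entry exposes _non_tensor=True
    assert td["msg"] == "a string!"      # indexing returns the underlying data

    # lazy_stack now receives the stack-dimension name directly (PATCH 16/20)
    # instead of writing _td_dim_name onto the result after the fact.
    stacked = LazyStackedTensorDict.lazy_stack(
        [td.clone(), td.clone()], 0, stack_dim_name="time"
    )
    print(stacked.names)  # the stacked dimension is expected to carry the name "time"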