diff --git a/tensordict/tensordict.py b/tensordict/tensordict.py
index d07727e1f..cae34dba6 100644
--- a/tensordict/tensordict.py
+++ b/tensordict/tensordict.py
@@ -614,36 +614,64 @@ def apply(
             a new tensordict with transformed tensors.

         """
-        out = (
-            self
-            if inplace
-            else TensorDict(
-                {},
-                batch_size=batch_size,
-                device=self.device,
-                _run_checks=False,
-                **constructor_kwargs,
-            )
-            if batch_size is not None
-            else self.clone(recurse=False)
-        )
-        is_locked = out.is_locked
-        if not inplace and is_locked:
-            out.unlock()
-        for key, item in self.items():
-            if isinstance(item, TensorDictBase):
-                item_trsf = item.apply(
-                    fn, inplace=inplace, batch_size=batch_size, **constructor_kwargs
-                )
-            else:
-                item_trsf = fn(item)
-            if item_trsf is not None:
-                out.set(
-                    key, item_trsf, inplace=inplace, _run_checks=False, _process=False
-                )
-        if not inplace and is_locked:
-            out.lock()
-        return out
+
+        def apply(
+            tensordict,
+            current_key: Optional[Tuple] = None,
+            being_computed: Optional[Dict] = None,
+        ):
+            """A version of apply that supports auto-nesting."""
+            out = (
+                tensordict
+                if inplace
+                else TensorDict(
+                    {},
+                    batch_size=batch_size,
+                    device=tensordict.device,
+                    _run_checks=False,
+                    **constructor_kwargs,
+                )
+                if batch_size is not None
+                else tensordict.clone(recurse=False)
+            )
+            is_locked = out.is_locked
+            if not inplace and is_locked:
+                out.unlock()
+
+            if current_key is None:
+                current_key = ()
+            if being_computed is None:
+                being_computed = {}
+            being_computed[current_key] = id(tensordict)
+            for key, item in tensordict.items():
+                if isinstance(item, TensorDictBase):
+                    nested_key = current_key + (key,)
+                    if id(item) in being_computed.values():
+                        # this sub-tensordict is already being processed:
+                        # record the cycle and restore it after the loop
+                        being_computed[nested_key] = id(item)
+                        continue
+                    item_trsf = apply(
+                        item, current_key=nested_key, being_computed=being_computed
+                    )
+                else:
+                    item_trsf = fn(item)
+                if item_trsf is not None:
+                    out.set(
+                        key, item_trsf, inplace=inplace, _run_checks=False, _process=False
+                    )
+
+            out2 = TensorDict(
+                out,
+                device=tensordict.device,
+                batch_size=tensordict.batch_size,
+            )
+            # re-create the self-references that were skipped during the loop
+            for other_nested_key, other_value in being_computed.items():
+                if other_nested_key != current_key:
+                    if other_value == id(tensordict):
+                        out2[other_nested_key] = out2
+            if not inplace and is_locked:
+                out2.lock()
+            return out2
+
+        return apply(self)
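# Illustrative sketch (not part of the patch): the cycle-detection scheme the
# hunk above relies on, reduced to plain dicts. A node whose id() is already
# registered in `being_computed` is recorded under its nested key and skipped,
# then re-linked once the parent result exists. `walk` is a hypothetical
# stand-in for TensorDict.apply and only restores cycles that point at a
# direct child of the node being rebuilt.

def walk(node, fn, current_key=(), being_computed=None):
    if being_computed is None:
        being_computed = {}
    being_computed[current_key] = id(node)
    out = {}
    for key, item in node.items():
        if isinstance(item, dict):
            nested_key = current_key + (key,)
            if id(item) in being_computed.values():
                being_computed[nested_key] = id(item)  # cycle: defer
                continue
            out[key] = walk(item, fn, nested_key, being_computed)
        else:
            out[key] = fn(item)
    for nested_key, ident in being_computed.items():
        if nested_key != current_key and ident == id(node):
            out[nested_key[-1]] = out  # restore the self-reference
    return out

d = {"a": 1}
d["self"] = d                       # a self-nested mapping
res = walk(d, lambda x: x + 1)
assert res["self"] is res and res["a"] == 2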
""" - if not isinstance(other, (TensorDictBase, dict, float, int)): - return False - if not isinstance(other, TensorDictBase) and isinstance(other, dict): - other = make_tensordict(**other, batch_size=self.batch_size) - if not isinstance(other, TensorDictBase): - return TensorDict( - {key: value != other for key, value in self.items()}, - self.batch_size, + def __ne__( + tensordict, other, current_key: Tuple = None, being_computed: Dict = None + ): + """A version of __ne__ that supports auto-nesting.""" + if not isinstance(other, (TensorDictBase, dict, float, int)): + return False + if not isinstance(other, TensorDictBase) and isinstance(other, dict): + other = make_tensordict(**other, batch_size=tensordict.batch_size) + if not isinstance(other, TensorDictBase): + return TensorDict( + {key: value == other for key, value in tensordict.items()}, + tensordict.batch_size, + device=tensordict.device, + ) + keys1 = set(tensordict.keys()) + keys2 = set(other.keys()) + if len(keys1.difference(keys2)) or len(keys1) != len(keys2): + raise KeyError(f"keys in tensordicts mismatch, got {keys1} and {keys2}") + out_dict = {} + + out_dict = {} + if current_key is None: + current_key = () + if being_computed is None: + being_computed = {} + being_computed[current_key] = id(tensordict) + for key, value in tensordict.items(): + if isinstance(value, TensorDictBase): + nested_key = current_key + (key,) + if id(value) in being_computed.values(): + being_computed[nested_key] = id(value) + continue + new_value = __ne__( + value, + other[key], + current_key=nested_key, + being_computed=being_computed, + ) + else: + new_value = value != other[key] + out_dict[key] = new_value + + out = TensorDict( + out_dict, device=self.device, + batch_size=self.batch_size, ) - keys1 = set(self.keys()) - keys2 = set(other.keys()) - if len(keys1.difference(keys2)) or len(keys1) != len(keys2): - raise KeyError( - f"keys in {self} and {other} mismatch, got {keys1} and {keys2}" - ) - d = {} - for (key, item1) in self.items(): - d[key] = item1 != other.get(key) - return TensorDict(batch_size=self.batch_size, source=d, device=self.device) + for other_nested_key, other_value in being_computed.items(): + if other_nested_key != current_key: + if other_value == id(tensordict): + out[other_nested_key] = out + return out + + return __ne__(tensordict=self, other=other) def __eq__(self, other: object) -> TensorDictBase: """Compares two tensordicts against each other, for every key. The two tensordicts must have the same key set. 

     def __eq__(self, other: object) -> TensorDictBase:
         """Compares two tensordicts against each other, for every key. The two tensordicts must have the same key set.

@@ -1515,10 +1577,41 @@ def contiguous(self) -> TensorDictBase:

     def to_dict(self) -> Dict[str, Any]:
         """Returns a dictionary with key-value pairs matching those of the tensordict."""
-        return {
-            key: value.to_dict() if isinstance(value, TensorDictBase) else value
-            for key, value in self.items()
-        }
+
+        def _to_dict(
+            tensordict,
+            current_key: Optional[Tuple] = None,
+            being_computed: Optional[Dict] = None,
+        ):
+            """A version of to_dict that supports auto-nesting."""
+            out_dict = {}
+            if current_key is None:
+                current_key = ()
+            if being_computed is None:
+                being_computed = {}
+            being_computed[current_key] = id(tensordict)
+            for key, value in tensordict.items():
+                if isinstance(value, TensorDictBase):
+                    nested_key = current_key + (key,)
+                    if id(value) in being_computed.values():
+                        being_computed[nested_key] = id(value)
+                        continue
+                    new_value = _to_dict(
+                        value, current_key=nested_key, being_computed=being_computed
+                    )
+                else:
+                    new_value = value
+                out_dict[key] = new_value
+
+            for other_nested_key, other_value in being_computed.items():
+                if other_nested_key != current_key:
+                    if other_value == id(tensordict):
+                        # walk down to the dict holding the cycling key before
+                        # re-linking, so nested keys of any depth are handled
+                        target = out_dict
+                        for k in other_nested_key[:-1]:
+                            target = target[k]
+                        target[other_nested_key[-1]] = out_dict
+            return out_dict
+
+        return _to_dict(self)
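# Hedged usage sketch for the auto-nesting to_dict(): the plain-dict output
# should contain the same reference cycle as the source tensordict. Assumes
# the self-assignment support this patch adds; not runnable against released
# tensordict.

import torch
from tensordict import TensorDict

td = TensorDict({"a": torch.ones(2)}, batch_size=[2])
td["loop"] = td

d = td.to_dict()
assert isinstance(d, dict)
assert d["loop"] is d              # the cycle survives the conversion
assert torch.equal(d["a"], torch.ones(2))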

     def unsqueeze(self, dim: int) -> TensorDictBase:
         """Unsqueeze all tensors for a dimension comprised in between `-td.batch_dims` and `td.batch_dims` and returns them in a new tensordict.

@@ -1911,8 +2004,46 @@ def any(self, dim: int = None) -> Union[bool, TensorDictBase]:
                 batch_size=[b for i, b in enumerate(self.batch_size) if i != dim],
                 device=self.device,
             )
-        return any([value.any() for key, value in self.items()])
-
+
+        def _any(
+            tensordict,
+            current_key: Optional[Tuple] = None,
+            being_computed: Optional[Dict] = None,
+        ):
+            """A version of any that supports auto-nesting."""
+            out_dict = {}
+            if current_key is None:
+                current_key = ()
+            if being_computed is None:
+                being_computed = {}
+            being_computed[current_key] = id(tensordict)
+            for key, value in tensordict.items():
+                if isinstance(value, TensorDictBase):
+                    nested_key = current_key + (key,)
+                    if id(value) in being_computed.values():
+                        being_computed[nested_key] = id(value)
+                        continue
+                    new_value = _any(
+                        value, current_key=nested_key, being_computed=being_computed
+                    )
+                else:
+                    new_value = value.any(dim=dim) if dim is not None else value.any()
+                out_dict[key] = new_value
+            if dim is None:
+                # a cycle contributes no values of its own, so being_computed
+                # needs no restoration for a scalar reduction
+                return any(value for value in out_dict.values())
+            out = TensorDict(
+                out_dict,
+                batch_size=[
+                    b for i, b in enumerate(tensordict.batch_size) if i != dim
+                ],
+                device=tensordict.device,
+            )
+            for other_nested_key, other_value in being_computed.items():
+                if other_nested_key != current_key:
+                    if other_value == id(tensordict):
+                        out[other_nested_key] = out
+            return out
+
+        return _any(self)
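# Hedged usage sketch for the auto-nesting any(): with dim=None the reduction
# terminates despite the cycle (a cycle adds no values of its own), while an
# integer dim yields a TensorDict that reproduces the nesting. Assumes the
# self-assignment support this patch adds; not runnable against released
# tensordict.

import torch
from tensordict import TensorDict

td = TensorDict({"a": torch.zeros(4, dtype=torch.bool)}, batch_size=[4])
td["loop"] = td

assert td.any() is False           # scalar reduction, cycle skipped
per_dim = td.any(dim=0)            # TensorDict with batch dim 0 reduced
assert per_dim["loop"] is per_dim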

     def get_sub_tensordict(self, idx: INDEX_TYPING) -> TensorDictBase:
         """Returns a SubTensorDict with the desired index."""
         return SubTensorDict(source=self, idx=idx)

@@ -2386,76 +2517,105 @@ def __init__(
         _is_shared: Optional[bool] = False,
         _is_memmap: Optional[bool] = False,
     ) -> None:
-        super().__init__()
-
-        self._is_shared = _is_shared
-        self._is_memmap = _is_memmap
-        if device is not None:
-            device = torch.device(device)
-        self._device = device
-
-        if not _run_checks:
-            if isinstance(source, dict):
-                self._tensordict: Dict = copy(source)
-            else:
-                self._tensordict: Dict = dict(source)
-            self._batch_size = torch.Size(batch_size)
-            upd_dict = {}
-            for key, value in self._tensordict.items():
-                if isinstance(value, dict):
-                    value = TensorDict(
-                        value,
-                        batch_size=self._batch_size,
-                        device=self._device,
-                        _run_checks=_run_checks,
-                        _is_shared=_is_shared,
-                        _is_memmap=_is_memmap,
-                    )
-                    upd_dict[key] = value
-            if upd_dict:
-                self._tensordict.update(upd_dict)
-        else:
-            self._tensordict = {}
-            if not isinstance(source, (TensorDictBase, dict)):
-                raise ValueError(
-                    "A TensorDict source is expected to be a TensorDictBase "
-                    f"sub-type or a dictionary, found type(source)={type(source)}."
-                )
-            self._batch_size = self._parse_batch_size(source, batch_size)
-
-            if source is not None:
-                for key, value in source.items():
-                    if isinstance(value, dict):
-                        value = TensorDict(
-                            value,
-                            batch_size=self._batch_size,
-                            device=self._device,
-                            _run_checks=_run_checks,
-                            _is_shared=_is_shared,
-                            _is_memmap=_is_memmap,
-                        )
-                    elif (
-                        isinstance(value, TensorDictBase)
-                        and value.batch_size[: self.batch_dims] != self.batch_size
-                    ):
-                        value = value.clone(False)
-                        value.batch_size = self.batch_size
-                    elif isinstance(value, (Tensor, MemmapTensor)):
-                        if value.shape[: len(self._batch_size)] != self._batch_size:
-                            raise RuntimeError(
-                                f"batch_size are incongruent, got {value.shape}, -- expected leading dims to be {self._batch_size}"
-                            )
-                    if device is not None:
-                        value = value.to(device)
-                    _meta_val = (
-                        None
-                        if _meta_source is None or key not in _meta_source
-                        else _meta_source[key]
-                    )
-                    self.set(key, value, _meta_val=_meta_val, _run_checks=False)
+
+        def td_init(
+            slf,
+            source,
+            current_key: Optional[Tuple] = None,
+            being_computed: Optional[Dict] = None,
+        ):
+            """A version of __init__ that supports auto-nesting."""
+            super(TensorDict, slf).__init__()
+            slf._is_shared = _is_shared
+            slf._is_memmap = _is_memmap
+            nonlocal device
+            if device is not None:
+                device = torch.device(device)
+            slf._device = device
+
+            if current_key is None:
+                current_key = ()
+            if being_computed is None:
+                being_computed = {}
+            being_computed[current_key] = id(source)
+
+            if not _run_checks:
+                if isinstance(source, dict):
+                    slf._tensordict: Dict = copy(source)
+                else:
+                    slf._tensordict: Dict = dict(source)
+                slf._batch_size = torch.Size(batch_size)
+                upd_dict = {}
+                for key, value in slf._tensordict.items():
+                    if isinstance(value, dict):
+                        nested_key = current_key + (key,)
+                        if id(value) in being_computed.values():
+                            being_computed[nested_key] = id(value)
+                            continue
+                        vv = TensorDict({}, batch_size)
+                        td_init(
+                            vv,
+                            value,
+                            current_key=nested_key,
+                            being_computed=being_computed,
+                        )
+                        upd_dict[key] = vv
+
+                # point skipped cycles back at the tensordict being built
+                for other_nested_key, other_value in being_computed.items():
+                    if other_nested_key != current_key:
+                        if other_value == id(source):
+                            upd_dict[other_nested_key] = slf
+
+                if upd_dict:
+                    slf._tensordict.update(upd_dict)
+            else:
+                slf._tensordict = {}
+                if not isinstance(source, (TensorDictBase, dict)):
+                    raise ValueError(
+                        "A TensorDict source is expected to be a TensorDictBase "
+                        f"sub-type or a dictionary, found type(source)={type(source)}."
+                    )
+                slf._batch_size = slf._parse_batch_size(source, batch_size)
+
+                if source is not None:
+                    for key, value in source.items():
+                        if isinstance(value, dict):
+                            nested_key = current_key + (key,)
+                            if id(value) in being_computed.values():
+                                being_computed[nested_key] = id(value)
+                                continue
+                            vv = TensorDict({}, batch_size)
+                            td_init(
+                                vv,
+                                value,
+                                current_key=nested_key,
+                                being_computed=being_computed,
+                            )
+                            value = vv
+                        elif (
+                            isinstance(value, TensorDictBase)
+                            and value.batch_size[: slf.batch_dims] != slf.batch_size
+                        ):
+                            value = value.clone(False)
+                            value.batch_size = slf.batch_size
+                        elif isinstance(value, (Tensor, MemmapTensor)):
+                            if value.shape[: len(slf._batch_size)] != slf._batch_size:
+                                raise RuntimeError(
+                                    f"batch_size is incongruent, got {value.shape}, "
+                                    f"expected leading dims to be {slf._batch_size}"
+                                )
+                        if device is not None:
+                            value = value.to(device)
+                        _meta_val = (
+                            None
+                            if _meta_source is None or key not in _meta_source
+                            else _meta_source[key]
+                        )
+                        slf.set(key, value, _meta_val=_meta_val, _run_checks=False)
+
+                # point skipped cycles back at the tensordict being built
+                for other_nested_key, other_value in being_computed.items():
+                    if other_nested_key != current_key:
+                        if other_value == id(source):
+                            slf.set(other_nested_key, slf, _run_checks=False)

         # self._check_batch_size()
         # self._check_device()
+        td_init(self, source)

     @staticmethod
     def _parse_batch_size(
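# Hedged usage sketch for the rewritten constructor: a source dict that
# contains itself should build into a TensorDict that contains itself instead
# of recursing without bound. Assumes this patch is applied; not runnable
# against released tensordict.

import torch
from tensordict import TensorDict

source = {"a": torch.randn(2)}
source["self"] = source            # cyclic source dict

td = TensorDict(source, batch_size=[2])
assert td["self"] is td            # the cycle is preserved, not re-expanded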