diff --git a/tensordict/tensordict.py b/tensordict/tensordict.py index 2ab0911a6..d07727e1f 100644 --- a/tensordict/tensordict.py +++ b/tensordict/tensordict.py @@ -1004,24 +1004,61 @@ def __eq__(self, other: object) -> TensorDictBase: tensors of the same shape as the original tensors. """ - if not isinstance(other, (TensorDictBase, dict, float, int)): - return False - if not isinstance(other, TensorDictBase) and isinstance(other, dict): - other = make_tensordict(**other, batch_size=self.batch_size) - if not isinstance(other, TensorDictBase): - return TensorDict( - {key: value == other for key, value in self.items()}, - self.batch_size, + + def __eq__( + tensordict, other, current_key: Tuple = None, being_computed: Dict = None + ): + """A version of __eq__ that supports auto-nesting.""" + if not isinstance(other, (TensorDictBase, dict, float, int)): + return False + if not isinstance(other, TensorDictBase) and isinstance(other, dict): + other = make_tensordict(**other, batch_size=tensordict.batch_size) + if not isinstance(other, TensorDictBase): + return TensorDict( + {key: value == other for key, value in tensordict.items()}, + tensordict.batch_size, + device=tensordict.device, + ) + keys1 = set(tensordict.keys()) + keys2 = set(other.keys()) + if len(keys1.difference(keys2)) or len(keys1) != len(keys2): + raise KeyError(f"keys in tensordicts mismatch, got {keys1} and {keys2}") + out_dict = {} + + out_dict = {} + if current_key is None: + current_key = () + if being_computed is None: + being_computed = {} + being_computed[current_key] = id(tensordict) + for key, value in tensordict.items(): + if isinstance(value, TensorDictBase): + nested_key = current_key + (key,) + if id(value) in being_computed.values(): + being_computed[nested_key] = id(value) + continue + new_value = __eq__( + value, + other[key], + current_key=nested_key, + being_computed=being_computed, + ) + else: + new_value = value == other[key] + out_dict[key] = new_value + + out = TensorDict( + 
out_dict, device=self.device, + batch_size=self.batch_size, ) - keys1 = set(self.keys()) - keys2 = set(other.keys()) - if len(keys1.difference(keys2)) or len(keys1) != len(keys2): - raise KeyError(f"keys in tensordicts mismatch, got {keys1} and {keys2}") - d = {} - for (key, item1) in self.items(): - d[key] = item1 == other.get(key) - return TensorDict(batch_size=self.batch_size, source=d, device=self.device) + for other_nested_key, other_value in being_computed.items(): + if other_nested_key != current_key: + if other_value == id(tensordict): + out[other_nested_key] = out + return out + + return __eq__(tensordict=self, other=other) @abc.abstractmethod def del_(self, key: str) -> TensorDictBase: @@ -1186,16 +1223,42 @@ def to_tensordict(self): a new TensorDict object containing the same values. """ - return TensorDict( - { - key: value.clone() - if not isinstance(value, TensorDictBase) - else value.to_tensordict() - for key, value in self.items() - }, - device=self.device, - batch_size=self.batch_size, - ) + + def to_tensordict( + tensordict, current_key: Tuple = None, being_computed: Dict = None + ): + """A version of to_tensordict that supports auto-nesting.""" + out_dict = {} + if current_key is None: + current_key = () + if being_computed is None: + being_computed = {} + being_computed[current_key] = id(tensordict) + for key, value in tensordict.items(): + if isinstance(value, TensorDictBase): + nested_key = current_key + (key,) + if id(value) in being_computed.values(): + being_computed[nested_key] = id(value) + continue + new_value = to_tensordict( + value, current_key=nested_key, being_computed=being_computed + ) + else: + new_value = value.clone() + out_dict[key] = new_value + out = TensorDict( + out_dict, + device=tensordict.device, + batch_size=tensordict.batch_size, + _run_checks=False, + ) + for other_nested_key, other_value in being_computed.items(): + if other_nested_key != current_key: + if other_value == id(tensordict): + out[other_nested_key] = out + return 
out + + return to_tensordict(self) def zero_(self) -> TensorDictBase: """Zeros all tensors in the tensordict in-place.""" @@ -1267,14 +1330,44 @@ def clone(self, recurse: bool = True) -> TensorDictBase: """ - return TensorDict( - source={key: _clone_value(value, recurse) for key, value in self.items()}, - batch_size=self.batch_size, - device=self.device, - _run_checks=False, - _is_shared=self.is_shared() if not recurse else False, - _is_memmap=self.is_memmap() if not recurse else False, - ) + def clone(tensordict, current_key: Tuple = None, being_computed: Dict = None): + """A version of clone that supports auto-nesting.""" + out_dict = {} + if current_key is None: + current_key = () + if being_computed is None: + being_computed = {} + being_computed[current_key] = id(tensordict) + for key, value in tensordict.items(): + if isinstance(value, TensorDictBase): + nested_key = current_key + (key,) + if id(value) in being_computed.values(): + being_computed[nested_key] = id(value) + continue + new_value = clone( + value, current_key=nested_key, being_computed=being_computed + ) + else: + if not recurse: + new_value = value + else: + new_value = value.clone() + out_dict[key] = new_value + out = TensorDict( + out_dict, + device=tensordict.device, + batch_size=tensordict.batch_size, + _run_checks=False, + _is_shared=tensordict.is_shared() if not recurse else False, + _is_memmap=tensordict.is_memmap() if not recurse else False, + ) + for other_nested_key, other_value in being_computed.items(): + if other_nested_key != current_key: + if other_value == id(tensordict): + out[other_nested_key] = out + return out + + return clone(self) @classmethod def __torch_function__( @@ -1752,15 +1845,47 @@ def all(self, dim: int = None) -> Union[bool, TensorDictBase]: "dim must be greater than -tensordict.batch_dims and smaller " "than tensordict.batchdims" ) - if dim is not None: - if dim < 0: - dim = self.batch_dims + dim - return TensorDict( - source={key: value.all(dim=dim) for 
key, value in self.items()}, - batch_size=[b for i, b in enumerate(self.batch_size) if i != dim], - device=self.device, - ) - return all(value.all() for value in self.values()) + elif dim is not None and dim < 0: + dim = self.batch_dims + dim + + def _all(tensordict, current_key: Tuple = None, being_computed: Dict = None): + """A version of all that supports auto-nesting.""" + out_dict = {} + if current_key is None: + current_key = () + if being_computed is None: + being_computed = {} + being_computed[current_key] = id(tensordict) + for key, value in tensordict.items(): + if isinstance(value, TensorDictBase): + nested_key = current_key + (key,) + if id(value) in being_computed.values(): + being_computed[nested_key] = id(value) + continue + new_value = _all( + value, current_key=nested_key, being_computed=being_computed + ) + else: + new_value = value.all(dim=dim) if dim is not None else value.all() + out_dict[key] = new_value + if dim is None: + # no need to do anything with being_computed, as a True in it will be kept + out = all(value for value in out_dict.values()) + else: + out = TensorDict( + out_dict, + batch_size=[ + b for i, b in enumerate(tensordict.batch_size) if i != dim + ], + device=tensordict.device, + ) + for other_nested_key, other_value in being_computed.items(): + if other_nested_key != current_key: + if other_value == id(tensordict): + out[other_nested_key] = out + return out + + return _all(self) def any(self, dim: int = None) -> Union[bool, TensorDictBase]: """Checks if any value is True/non-null in the tensordict.