Skip to content

Commit

Permalink
Merge pull request #1 from vmoens/zoolmain_solvenested
Browse files Browse the repository at this point in the history
Solving infinite recursion
  • Loading branch information
Zooll authored Jan 20, 2023
2 parents 1d4571c + 397b99f commit 6be56da
Showing 1 changed file with 168 additions and 43 deletions.
211 changes: 168 additions & 43 deletions tensordict/tensordict.py
Original file line number Diff line number Diff line change
Expand Up @@ -1004,24 +1004,61 @@ def __eq__(self, other: object) -> TensorDictBase:
tensors of the same shape as the original tensors.
"""
if not isinstance(other, (TensorDictBase, dict, float, int)):
return False
if not isinstance(other, TensorDictBase) and isinstance(other, dict):
other = make_tensordict(**other, batch_size=self.batch_size)
if not isinstance(other, TensorDictBase):
return TensorDict(
{key: value == other for key, value in self.items()},
self.batch_size,

def __eq__(
tensordict, other, current_key: Tuple = None, being_computed: Dict = None
):
"""A version of __eq__ that supports auto-nesting."""
if not isinstance(other, (TensorDictBase, dict, float, int)):
return False
if not isinstance(other, TensorDictBase) and isinstance(other, dict):
other = make_tensordict(**other, batch_size=tensordict.batch_size)
if not isinstance(other, TensorDictBase):
return TensorDict(
{key: value == other for key, value in tensordict.items()},
tensordict.batch_size,
device=tensordict.device,
)
keys1 = set(tensordict.keys())
keys2 = set(other.keys())
if len(keys1.difference(keys2)) or len(keys1) != len(keys2):
raise KeyError(f"keys in tensordicts mismatch, got {keys1} and {keys2}")
out_dict = {}

out_dict = {}
if current_key is None:
current_key = ()
if being_computed is None:
being_computed = {}
being_computed[current_key] = id(tensordict)
for key, value in tensordict.items():
if isinstance(value, TensorDictBase):
nested_key = current_key + (key,)
if id(value) in being_computed.values():
being_computed[nested_key] = id(value)
continue
new_value = __eq__(
value,
other[key],
current_key=nested_key,
being_computed=being_computed,
)
else:
new_value = value == other[key]
out_dict[key] = new_value

out = TensorDict(
out_dict,
device=self.device,
batch_size=self.batch_size,
)
keys1 = set(self.keys())
keys2 = set(other.keys())
if len(keys1.difference(keys2)) or len(keys1) != len(keys2):
raise KeyError(f"keys in tensordicts mismatch, got {keys1} and {keys2}")
d = {}
for (key, item1) in self.items():
d[key] = item1 == other.get(key)
return TensorDict(batch_size=self.batch_size, source=d, device=self.device)
for other_nested_key, other_value in being_computed.items():
if other_nested_key != current_key:
if other_value == id(tensordict):
out[other_nested_key] = out
return out

return __eq__(tensordict=self, other=other)

@abc.abstractmethod
def del_(self, key: str) -> TensorDictBase:
Expand Down Expand Up @@ -1186,16 +1223,42 @@ def to_tensordict(self):
a new TensorDict object containing the same values.
"""
return TensorDict(
{
key: value.clone()
if not isinstance(value, TensorDictBase)
else value.to_tensordict()
for key, value in self.items()
},
device=self.device,
batch_size=self.batch_size,
)

def to_tensordict(
tensordict, current_key: Tuple = None, being_computed: Dict = None
):
"""A version of to_tensordict that supports auto-nesting."""
out_dict = {}
if current_key is None:
current_key = ()
if being_computed is None:
being_computed = {}
being_computed[current_key] = id(tensordict)
for key, value in tensordict.items():
if isinstance(value, TensorDictBase):
nested_key = current_key + (key,)
if id(value) in being_computed.values():
being_computed[nested_key] = id(value)
continue
new_value = to_tensordict(
value, current_key=nested_key, being_computed=being_computed
)
else:
new_value = value.clone()
out_dict[key] = new_value
out = TensorDict(
out_dict,
device=self.device,
batch_size=self.batch_size,
_run_checks=False,
)
for other_nested_key, other_value in being_computed.items():
if other_nested_key != current_key:
if other_value == id(tensordict):
out[other_nested_key] = out
return out

return to_tensordict(self)

def zero_(self) -> TensorDictBase:
"""Zeros all tensors in the tensordict in-place."""
Expand Down Expand Up @@ -1267,14 +1330,44 @@ def clone(self, recurse: bool = True) -> TensorDictBase:
"""

return TensorDict(
source={key: _clone_value(value, recurse) for key, value in self.items()},
batch_size=self.batch_size,
device=self.device,
_run_checks=False,
_is_shared=self.is_shared() if not recurse else False,
_is_memmap=self.is_memmap() if not recurse else False,
)
def clone(tensordict, current_key: Tuple = None, being_computed: Dict = None):
"""A version of to_tensordict that supports auto-nesting."""
out_dict = {}
if current_key is None:
current_key = ()
if being_computed is None:
being_computed = {}
being_computed[current_key] = id(tensordict)
for key, value in tensordict.items():
if isinstance(value, TensorDictBase):
nested_key = current_key + (key,)
if id(value) in being_computed.values():
being_computed[nested_key] = id(value)
continue
new_value = clone(
value, current_key=nested_key, being_computed=being_computed
)
else:
if not recurse:
new_value = value
else:
new_value = value.clone()
out_dict[key] = new_value
out = TensorDict(
out_dict,
device=tensordict.device,
batch_size=tensordict.batch_size,
_run_checks=False,
_is_shared=tensordict.is_shared() if not recurse else False,
_is_memmap=tensordict.is_memmap() if not recurse else False,
)
for other_nested_key, other_value in being_computed.items():
if other_nested_key != current_key:
if other_value == id(tensordict):
out[other_nested_key] = out
return out

return clone(self)

@classmethod
def __torch_function__(
Expand Down Expand Up @@ -1752,15 +1845,47 @@ def all(self, dim: int = None) -> Union[bool, TensorDictBase]:
"dim must be greater than -tensordict.batch_dims and smaller "
"than tensordict.batchdims"
)
if dim is not None:
if dim < 0:
dim = self.batch_dims + dim
return TensorDict(
source={key: value.all(dim=dim) for key, value in self.items()},
batch_size=[b for i, b in enumerate(self.batch_size) if i != dim],
device=self.device,
)
return all(value.all() for value in self.values())
elif dim is not None and dim < 0:
dim = self.batch_dims + dim

def _all(tensordict, current_key: Tuple = None, being_computed: Dict = None):
"""A version of to_tensordict that supports auto-nesting."""
out_dict = {}
if current_key is None:
current_key = ()
if being_computed is None:
being_computed = {}
being_computed[current_key] = id(tensordict)
for key, value in tensordict.items():
if isinstance(value, TensorDictBase):
nested_key = current_key + (key,)
if id(value) in being_computed.values():
being_computed[nested_key] = id(value)
continue
new_value = _all(
value, current_key=nested_key, being_computed=being_computed
)
else:
new_value = value.all(dim=dim) if dim is not None else value.all()
out_dict[key] = new_value
if dim is None:
# no need to do anything with being_computed, as a True in it will be kept
out = all(value for value in out_dict.values())
else:
out = TensorDict(
out_dict,
batch_size=[
b for i, b in enumerate(tensordict.batch_size) if i != dim
],
device=tensordict.device,
)
for other_nested_key, other_value in being_computed.items():
if other_nested_key != current_key:
if other_value == id(tensordict):
out[other_nested_key] = out
return out

return _all(self)

def any(self, dim: int = None) -> Union[bool, TensorDictBase]:
"""Checks if any value is True/non-null in the tensordict.
Expand Down

0 comments on commit 6be56da

Please sign in to comment.