Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[BugFix] Enforce zip(..., strict=True) in TDModules #1212

Merged
merged 2 commits into from
Feb 6, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 14 additions & 12 deletions tensordict/nn/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -1055,7 +1055,7 @@ def _write_to_tensordict(
tensordict_out = TensorDict()
else:
tensordict_out = tensordict
for _out_key, _tensor in zip(out_keys, tensors):
for _out_key, _tensor in _zip_strict(out_keys, tensors):
if _out_key != "_":
tensordict_out.set(_out_key, TensorDict.from_any(_tensor))
return tensordict_out
Expand Down Expand Up @@ -1097,7 +1097,9 @@ def forward(
for in_key in self.in_keys
)
try:
tensors = self._call_module(tensors, **kwargs)
tensors_out = self._call_module(tensors, **kwargs)
if tensors_out is None:
tensors_out = ()
except Exception as err:
if any(tensor is None for tensor in tensors) and "None" in str(err):
none_set = {
Expand All @@ -1112,18 +1114,18 @@ def forward(
) from err
else:
raise err
if isinstance(tensors, (dict, TensorDictBase)) and all(
key in tensors for key in self.out_keys
if isinstance(tensors_out, (dict, TensorDictBase)) and all(
key in tensors_out for key in self.out_keys
):
if isinstance(tensors, dict):
keys = unravel_key_list(list(tensors.keys()))
values = tensors.values()
tensors = dict(_zip_strict(keys, values))
tensors = tuple(tensors.get(key) for key in self.out_keys)
if not isinstance(tensors, tuple):
tensors = (tensors,)
if isinstance(tensors_out, dict):
keys = unravel_key_list(list(tensors_out.keys()))
values = tensors_out.values()
tensors_out = dict(_zip_strict(keys, values))
tensors_out = tuple(tensors_out.get(key) for key in self.out_keys)
if not isinstance(tensors_out, tuple):
tensors_out = (tensors_out,)
tensordict_out = self._write_to_tensordict(
tensordict, tensors, tensordict_out
tensordict, tensors_out, tensordict_out
)
return tensordict_out
except Exception as err:
Expand Down
4 changes: 2 additions & 2 deletions tensordict/nn/cudagraphs.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
PYTREE_REGISTERED_LAZY_TDS,
PYTREE_REGISTERED_TDS,
)
from tensordict.utils import strtobool
from tensordict.utils import _zip_strict, strtobool
from torch import Tensor

from torch.utils._pytree import SUPPORTED_NODES, tree_map
Expand Down Expand Up @@ -296,7 +296,7 @@ def check_tensor_id(name, t0, t1):
def _call(*args: torch.Tensor, **kwargs: torch.Tensor):
if self.counter >= self._warmup:
srcs, dests = [], []
for arg_src, arg_dest in zip(
for arg_src, arg_dest in _zip_strict(
tree_leaves((args, kwargs)), self._flat_tree
):
self._maybe_copy_onto_(arg_src, arg_dest, srcs, dests)
Expand Down
5 changes: 3 additions & 2 deletions tensordict/nn/params.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
from tensordict.memmap import MemoryMappedTensor
from tensordict.utils import (
_LOCK_ERROR,
_zip_strict,
BufferLegacy,
erase_cache,
implement_for,
Expand Down Expand Up @@ -475,8 +476,8 @@ def _reset_params(self, params: dict | None = None, buffers: dict | None = None)
buffer_keys.append(key)
buffers.append(value)

self._parameters.update(dict(zip(param_keys, params)))
self._buffers.update(dict(zip(buffer_keys, buffers)))
self._parameters.update(dict(_zip_strict(param_keys, params)))
self._buffers.update(dict(_zip_strict(buffer_keys, buffers)))
else:
self._parameters.update(params)
self._buffers.update(buffers)
Expand Down
7 changes: 3 additions & 4 deletions tensordict/nn/probabilistic.py
Original file line number Diff line number Diff line change
Expand Up @@ -576,7 +576,7 @@ def log_prob(
return dist.log_prob(tensordict.get(self.out_keys[0]))

def _update_td_lp(self, lp):
for out_key, lp_key in zip(self.dist_sample_keys, self.log_prob_keys):
for out_key, lp_key in _zip_strict(self.dist_sample_keys, self.log_prob_keys):
lp_key_expected = _add_suffix(out_key, "_log_prob")
if lp_key != lp_key_expected:
lp.rename_key_(lp_key_expected, lp_key)
Expand Down Expand Up @@ -637,7 +637,7 @@ def forward(
if isinstance(out_tensors, Tensor):
out_tensors = (out_tensors,)
tensordict_out.update(
{key: value for key, value in zip(self.out_keys, out_tensors)}
dict(_zip_strict(self.dist_sample_keys, out_tensors))
)
if self.return_log_prob:
log_prob = dist.log_prob(*out_tensors)
Expand Down Expand Up @@ -1155,8 +1155,7 @@ def get_dist(
if isinstance(tdm, ProbabilisticTensorDictModule):
if isinstance(sample, torch.Tensor):
sample = [sample]
for val, key in zip(sample, tdm.out_keys):
td_copy.set(key, val)
td_copy.update(dict(_zip_strict(tdm.dist_sample_keys, sample)))
else:
td_copy.update(sample)
dists[tdm.out_keys[0]] = dist
Expand Down
6 changes: 3 additions & 3 deletions tensordict/nn/sequence.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
)
from tensordict.nn.utils import _set_skip_existing_None
from tensordict.tensordict import LazyStackedTensorDict, TensorDictBase
from tensordict.utils import unravel_key_list
from tensordict.utils import _zip_strict, unravel_key_list
from torch import nn

_has_functorch = False
Expand Down Expand Up @@ -205,7 +205,7 @@ def __init__(
in_keys, out_keys = self._compute_in_and_out_keys(modules_vals)
self._complete_out_keys = list(out_keys)
modules = collections.OrderedDict(
**{key: val for key, val in zip(modules[0], modules_vals)}
**{key: val for key, val in _zip_strict(modules[0], modules_vals)}
)
super().__init__(
module=nn.ModuleDict(modules), in_keys=in_keys, out_keys=out_keys
Expand Down Expand Up @@ -493,7 +493,7 @@ def select_subsequence(
else:
keys = [key for key in self.module if self.module[key] in modules]
modules_dict = collections.OrderedDict(
**{key: val for key, val in zip(keys, modules)}
**{key: val for key, val in _zip_strict(keys, modules)}
)
return type(self)(modules_dict)

Expand Down
Loading