diff --git a/tensordict/_reductions.py b/tensordict/_reductions.py index be8aa42f1..ecfe956be 100644 --- a/tensordict/_reductions.py +++ b/tensordict/_reductions.py @@ -11,12 +11,14 @@ from tensordict._lazy import LazyStackedTensorDict from tensordict._td import TensorDict -from tensordict.tensorclass import NonTensorData +from tensordict.tensorclass import NonTensorData, NonTensorStack from tensordict.utils import _STRDTYPE2DTYPE CLS_MAP = { "TensorDict": TensorDict, "LazyStackedTensorDict": LazyStackedTensorDict, + "NonTensorData": NonTensorData, + "NonTensorStack": NonTensorStack, } @@ -57,7 +59,9 @@ def from_metadata(metadata=metadata_dict, prefix=None): d[k] = from_metadata( v, prefix=prefix + (k,) if prefix is not None else (k,) ) - result = CLS_MAP[cls]._from_dict_validated(d, **cls_metadata) + if isinstance(cls, str): + cls = CLS_MAP[cls] + result = cls._from_dict_validated(d, **cls_metadata) if is_locked: result.lock_() # if is_shared: @@ -121,7 +125,9 @@ def from_metadata(metadata=metadata, prefix=None): d[k] = from_metadata( v, prefix=prefix + (k,) if prefix is not None else (k,) ) - result = CLS_MAP[cls]._from_dict_validated(d, **cls_metadata) + if isinstance(cls, str): + cls = CLS_MAP[cls] + result = cls._from_dict_validated(d, **cls_metadata) if is_locked: result = result.lock_() result._consolidated = consolidated diff --git a/tensordict/base.py b/tensordict/base.py index 21a4ab133..b8d47c075 100644 --- a/tensordict/base.py +++ b/tensordict/base.py @@ -4994,7 +4994,7 @@ def _reduce_vals_and_metadata(self, *, dtype=NO_DEFAULT, requires_metadata): if requires_metadata: # metadata is nested metadata_dict = { - "cls": type(self).__name__, + "cls": type(self), "non_tensors": {}, "leaves": {}, "cls_metadata": self._reduce_get_metadata(), @@ -5055,7 +5055,7 @@ def assign( metadata_dict_key = None if requires_metadata: metadata_dict_key = metadata_dict[key] = { - "cls": cls.__name__, + "cls": cls, "non_tensors": {}, "leaves": {}, "cls_metadata": value._reduce_get_metadata(),