diff --git a/tensordict/nn/common.py b/tensordict/nn/common.py index 040b7f50a..b95bc5738 100644 --- a/tensordict/nn/common.py +++ b/tensordict/nn/common.py @@ -485,6 +485,35 @@ class TensorDictModuleBase(nn.Module): >>> tensordict_out = module.forward(tensordict_in) + Unlike :class:`~tensordict.nn.TensorDictModule`, `TensorDictModuleBase` is typically used via subclassing: + you can wrap any python function in a `TensorDictModuleBase` subclass, as long as the subclass forward reads and + writes tensordict (or related types) instances. + + The `in_keys` and `out_keys` should be properly specified. For example, `out_keys` can be dynamically reduced using + :meth:`~tensordict.nn.TensorDictBase.select_out_keys`. + + Examples: + >>> from tensordict import TensorDict + >>> from tensordict.nn import TensorDictModuleBase + >>> class Mod(TensorDictModuleBase): + ... in_keys = ["a"] # can also be specified during __init__ + ... out_keys = ["b", "c"] + ... def forward(self, tensordict): + ... b = tensordict["a"].clone() + ... c = b + 1 + ... return tensordict.replace({"b": b, "c": c}) + >>> mod = Mod() + >>> td = mod(TensorDict(a=0)) + >>> td["b"] + tensor(0) + >>> td["c"] + tensor(1) + >>> mod.select_out_keys("c") + >>> td = mod(TensorDict(a=0)) + >>> td["c"] + tensor(1) + >>> assert "b" not in td + """ def __new__(cls, *args, **kwargs):