Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add mart.nn.Get() to extract a value from the kwargs dict. #251

Merged
merged 12 commits into from
May 16, 2024
18 changes: 17 additions & 1 deletion mart/nn/nn.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@

import torch

__all__ = ["GroupNorm32", "SequentialDict", "ReturnKwargs", "CallWith", "Sum"]
__all__ = ["GroupNorm32", "SequentialDict", "ReturnKwargs", "CallWith", "Sum", "Get"]

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -300,3 +300,19 @@ def __init__(self):

def forward(self, *args):
return sum(args)


class Get:
"""Get a value from the kwargs dictionary by key.

The key can be a path to a nested dictionary, concatenated by dots. For example,
`Get(key="a.b")(a={"b": 1}) == 1`.
"""

def __init__(self, key):
self.key = key

def __call__(self, **kwargs):
# Add support to nested dicts.
kwargs = DotDict(kwargs)
return kwargs[self.key]
Loading