[BUG] TensorClass.get(..., default) not propagated #1208

Closed
3 tasks done
egorchakov opened this issue Feb 5, 2025 · 5 comments
Labels
bug Something isn't working

Comments

@egorchakov

Describe the bug

When calling .get() with a nested key on a @tensorclass that has TensorDict attributes, the default kwarg does not seem to be propagated to the nested TensorDict.

To Reproduce

# file: tensorclass_get_default.py
# /// script
# requires-python = "==3.12"
# dependencies = [
#     "tensordict==0.6.2",
#     "packaging",
# ]
# ///

from tensordict import TensorDict, tensorclass


@tensorclass
class Data:
    td: TensorDict


Data(td=TensorDict({})).get(("td", "missing"), default=None)
: uv run tensorclass_get_default.py
Reading inline script metadata from `tensorclass_get_default.py`
Traceback (most recent call last):
  File "[...]/tensorclass_get_default.py", line 18, in <module>
    Data(td=TensorDict({})).get(("td", "missing"), default=None)
  File "/home/evgenii/.cache/uv/archive-v0/yCPXuVvMie1NN0uqK0b3L/lib/python3.12/site-packages/tensordict/tensorclass.py", line 1792, in _get
    return getattr(self, key[0]).get(key[1:])
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/evgenii/.cache/uv/archive-v0/yCPXuVvMie1NN0uqK0b3L/lib/python3.12/site-packages/tensordict/base.py", line 5044, in get
    return self._get_tuple(key, default=default)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/evgenii/.cache/uv/archive-v0/yCPXuVvMie1NN0uqK0b3L/lib/python3.12/site-packages/tensordict/_td.py", line 2468, in _get_tuple
    first = self._get_str(key[0], default)
            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/evgenii/.cache/uv/archive-v0/yCPXuVvMie1NN0uqK0b3L/lib/python3.12/site-packages/tensordict/_td.py", line 2464, in _get_str
    return self._default_get(first_key, default)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/evgenii/.cache/uv/archive-v0/yCPXuVvMie1NN0uqK0b3L/lib/python3.12/site-packages/tensordict/base.py", line 4995, in _default_get
    raise KeyError(
KeyError: 'key "missing" not found in TensorDict with keys []'

Expected behavior

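The default kwarg should be propagated to the nested TensorDict, so that Data(td=TensorDict({})).get(("td", "missing"), default=None) behaves like TensorDict({}).get("missing", default=None) and returns None instead of raising a KeyError.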
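
A minimal pure-Python sketch of the expected semantics follows. Leaf and Wrapper are hypothetical stand-ins for TensorDict and a @tensorclass (not tensordict's real classes). get_buggy mirrors the delegation visible in the traceback, getattr(self, key[0]).get(key[1:]), which accepts default but never forwards it; get_fixed is only an assumed shape of a fix that forwards default unchanged, not necessarily what tensordict actually does.

```python
from typing import Any

# Sentinel so that "no default given" is distinguishable from default=None.
_MISSING = object()


class Leaf:
    """Hypothetical stand-in for a TensorDict: a flat store with get(key, default)."""

    def __init__(self, items: dict) -> None:
        self._items = dict(items)

    def get(self, key: Any, default: Any = _MISSING) -> Any:
        # Accept both "a" and ("a",), like nested-key APIs commonly do.
        if isinstance(key, tuple):
            if len(key) != 1:
                raise KeyError(key)  # this stand-in is flat
            key = key[0]
        if key in self._items:
            return self._items[key]
        if default is _MISSING:
            raise KeyError(f'key "{key}" not found')
        return default


class Wrapper:
    """Hypothetical stand-in for a @tensorclass whose fields are Leaf objects."""

    def __init__(self, **fields: Leaf) -> None:
        for name, value in fields.items():
            setattr(self, name, value)

    def get_buggy(self, key: tuple, default: Any = _MISSING) -> Any:
        # Mirrors the traceback: `default` is accepted but never forwarded.
        return getattr(self, key[0]).get(key[1:])

    def get_fixed(self, key: tuple, default: Any = _MISSING) -> Any:
        # Forward `default` as-is, whatever its value (None, 0, ...).
        return getattr(self, key[0]).get(key[1:], default)


data = Wrapper(td=Leaf({}))
print(data.get_fixed(("td", "missing"), default=None))  # None
print(data.get_fixed(("td", "missing"), default=0))  # 0
```

Checking a non-None default (0 here) alongside None matters: a fix that only special-cases None would pass the first check and fail the second.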

System info

0.6.2 2.2.2 3.12.0 (main, Oct  3 2023, 01:27:23) [Clang 17.0.1 ] linux 2.6.0+cu124

Reason and Possible fixes

Setting tensordict.set_get_defaults_to_none(True) works around this, but only when the desired default is None.

Checklist

  • I have checked that there is no similar issue in the repo (required)
  • I have read the documentation (required)
  • I have provided a minimal working example to reproduce the bug (required)
egorchakov added the bug (Something isn't working) label on Feb 5, 2025
@vmoens
Contributor

vmoens commented Feb 5, 2025

I can't reproduce this on the nightlies, so I'm closing it.
The new release should be out today (hopefully).

vmoens closed this as completed on Feb 5, 2025
@egorchakov
Author

egorchakov commented Feb 6, 2025

@vmoens It seems that in v0.7.0 the returned value is None rather than the specified default:

# file: tensorclass_get_default.py
# /// script
# requires-python = "==3.12"
# dependencies = [
#     "tensordict==0.7.0",
# ]
# ///

from tensordict import TensorClass, TensorDict


class Data(TensorClass):
    td: TensorDict


assert Data(td=TensorDict({})).get(("td", "missing"), default=0) == 0
: uv run tensorclass_get_default.py
Traceback (most recent call last):
  File "[...]/tensorclass_get_default.py", line 16, in <module>
    assert Data(td=TensorDict({})).get(("td", "missing"), default=0) == 0
AssertionError

@vmoens
Contributor

vmoens commented Feb 6, 2025

This is really bad :/
I will make a minor release with the fix.

egorchakov changed the title from "[BUG] @tensorclass .get(..., default) not propagated" to "[BUG] TensorClass.get(..., default) not propagated" on Feb 6, 2025
@vmoens
Contributor

vmoens commented Feb 6, 2025

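(In the meantime, feel free to call get on the nested TensorDict directly: data.key0.get(key[1:], default), e.g. data.td.get("missing", default) for the example above.)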
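
A minimal sketch of that workaround as a reusable helper. get_nested is a hypothetical name, not part of tensordict; it assumes key has at least two elements and that the attribute named by key[0] exposes get(key, default), as TensorDict does. When no default is passed it calls get without one, so the nested object's own no-default behavior applies instead of a sentinel leaking through.

```python
from typing import Any, Tuple

# Sentinel meaning "caller did not pass a default".
_MISSING = object()


def get_nested(obj: Any, key: Tuple[str, ...], default: Any = _MISSING) -> Any:
    """Hypothetical helper: look up key[0] as an attribute of `obj`, then call
    .get on it with the remaining key, forwarding `default` if one was given."""
    if len(key) < 2:
        raise ValueError("expected a nested key such as ('td', 'missing')")
    inner = getattr(obj, key[0])
    rest = key[1:]
    if default is _MISSING:
        # No default supplied: defer to the nested object's own behavior.
        return inner.get(rest)
    return inner.get(rest, default)
```

With the Data class from the report, get_nested(data, ("td", "missing"), 0) should then be equivalent to data.td.get(("missing",), 0).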

@vmoens
Contributor

vmoens commented Feb 6, 2025

Solved by #1211
