Skip to content

Commit

Permalink
Workaround for __type_params__ bug in Python 3.12.0 (#236)
Browse files Browse the repository at this point in the history
* Add test from #235

* torch import try/except

* test gen, ruff

* Add 3.12.0 to pytest yml

* Workaround for `type[T]` bug in Python 3.12.0
  • Loading branch information
brentyi authored Jan 16, 2025
1 parent 6792e04 commit 7577a30
Show file tree
Hide file tree
Showing 4 changed files with 73 additions and 9 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/pytest.yml
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ jobs:
runs-on: ubuntu-22.04
strategy:
matrix:
python-version: ["3.7", "3.8", "3.9", "3.10", "3.11", "3.12", "3.13"]
python-version: ["3.7", "3.8", "3.9", "3.10", "3.11", "3.12", "3.12.0", "3.13"]

steps:
- uses: actions/checkout@v2
Expand Down
14 changes: 8 additions & 6 deletions src/tyro/_resolver.py
Original file line number Diff line number Diff line change
Expand Up @@ -381,12 +381,14 @@ def concretize_type_params(
typ = resolve_newtype_and_aliases(typ)
type_from_typevar = {}
GenericAlias = getattr(types, "GenericAlias", None)
while (
GenericAlias is not None
and isinstance(typ, GenericAlias)
and len(getattr(typ, "__type_params__", ())) > 0
):
for k, v in zip(typ.__type_params__, get_args(typ)): # type: ignore
while GenericAlias is not None and isinstance(typ, GenericAlias):
type_params = getattr(typ, "__type_params__", ())
# The __len__ check is for a bug in Python 3.12.0:
# https://github.com/brentyi/tyro/issues/235
if not hasattr(type_params, "__len__") or len(type_params) == 0:
break

for k, v in zip(type_params, get_args(typ)):
type_from_typevar[k] = TypeParamResolver.concretize_type_params(
v, seen=seen
)
Expand Down
33 changes: 32 additions & 1 deletion tests/test_new_style_annotations_min_py39.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
import dataclasses
from typing import Any, Literal, Optional, Union
from typing import Any, Literal, Optional, Type, Union

import pytest
from helptext_utils import get_helptext_with_checks

import tyro

Expand Down Expand Up @@ -67,3 +68,33 @@ def main(
def test_tuple_direct() -> None:
assert tyro.cli(tuple[int, ...], args="1 2".split(" ")) == (1, 2) # type: ignore
assert tyro.cli(tuple[int, int], args="1 2".split(" ")) == (1, 2) # type: ignore


try:
from torch.optim.lr_scheduler import LinearLR, LRScheduler

def test_type_with_init_false() -> None:
"""https://github.com/brentyi/tyro/issues/235"""

@dataclasses.dataclass(frozen=True)
class LinearLRConfig:
_target: type[LRScheduler] = dataclasses.field(
init=False, default_factory=lambda: LinearLR
)
_target2: Type[LRScheduler] = dataclasses.field(
init=False, default_factory=lambda: LinearLR
)
start_factor: float = 1.0 / 3
end_factor: float = 1.0
total_iters: Optional[int] = None

def main(config: LinearLRConfig) -> LinearLRConfig:
return config

assert tyro.cli(main, args=[]) == LinearLRConfig()
assert "_target" not in get_helptext_with_checks(LinearLRConfig)
except ImportError:
# We can't install PyTorch in Python 3.13.
import sys

assert sys.version_info >= (3, 13)
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
import dataclasses
from typing import Any, Literal, Optional
from typing import Any, Literal, Optional, Type

import pytest
from helptext_utils import get_helptext_with_checks

import tyro

Expand Down Expand Up @@ -67,3 +68,33 @@ def main(
def test_tuple_direct() -> None:
assert tyro.cli(tuple[int, ...], args="1 2".split(" ")) == (1, 2) # type: ignore
assert tyro.cli(tuple[int, int], args="1 2".split(" ")) == (1, 2) # type: ignore


try:
from torch.optim.lr_scheduler import LinearLR, LRScheduler

def test_type_with_init_false() -> None:
"""https://github.com/brentyi/tyro/issues/235"""

@dataclasses.dataclass(frozen=True)
class LinearLRConfig:
_target: type[LRScheduler] = dataclasses.field(
init=False, default_factory=lambda: LinearLR
)
_target2: Type[LRScheduler] = dataclasses.field(
init=False, default_factory=lambda: LinearLR
)
start_factor: float = 1.0 / 3
end_factor: float = 1.0
total_iters: Optional[int] = None

def main(config: LinearLRConfig) -> LinearLRConfig:
return config

assert tyro.cli(main, args=[]) == LinearLRConfig()
assert "_target" not in get_helptext_with_checks(LinearLRConfig)
except ImportError:
# We can't install PyTorch in Python 3.13.
import sys

assert sys.version_info >= (3, 13)

0 comments on commit 7577a30

Please sign in to comment.