Skip to content

Commit

Permalink
Fix tyro.conf.UseAppendAction + abstract Sequence annotations (#248)
Browse files Browse the repository at this point in the history
* Fix `tyro.conf.UseAppendAction` + abstract `Sequence` annotations

* Fix Python 3.7/3.8 tests

* type: ignore argparse override

* ruff / type ignore
  • Loading branch information
brentyi authored Jan 29, 2025
1 parent e1deb24 commit 6ce3ade
Show file tree
Hide file tree
Showing 6 changed files with 66 additions and 4 deletions.
6 changes: 4 additions & 2 deletions src/tyro/_argparse_formatter.py
Original file line number Diff line number Diff line change
Expand Up @@ -302,8 +302,10 @@ def _print_message(self, message, file=None):
except (AttributeError, OSError): # pragma: no cover
pass

@override
def _parse_known_args(self, arg_strings, namespace): # pragma: no cover
# @override
def _parse_known_args( # type: ignore
self, arg_strings, namespace
): # pragma: no cover
"""We override _parse_known_args() to improve error messages in the presence of
subcommands. Difference is marked with <new>...</new> below."""

Expand Down
3 changes: 2 additions & 1 deletion src/tyro/_arguments.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

from __future__ import annotations

import collections.abc
import dataclasses
import json
import shlex
Expand Down Expand Up @@ -372,7 +373,7 @@ def append_instantiator(x: list[list[str]]) -> Any:
out.append(part)

# Return output with correct type.
if isinstance(out, dict):
if container_type in (dict, Sequence, collections.abc.Sequence):
return out
else:
return container_type(out)
Expand Down
16 changes: 15 additions & 1 deletion tests/test_conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import json as json_
import shlex
import sys
from typing import Any, Dict, Generic, List, Tuple, Type, TypeVar, Union
from typing import Any, Dict, Generic, List, Sequence, Tuple, Type, TypeVar, Union

import pytest
from helptext_utils import get_helptext_with_checks
Expand Down Expand Up @@ -903,6 +903,20 @@ class A:
tyro.cli(A, args=["--x", "1", "2", "3"])


def test_append_sequence() -> None:
@dataclasses.dataclass
class A:
x: tyro.conf.UseAppendAction[Sequence[int]]

assert tyro.cli(A, args=[]) == A(x=[])
assert tyro.cli(A, args="--x 1 --x 2 --x 3".split(" ")) == A(x=[1, 2, 3])
assert tyro.cli(A, args=[]) == A(x=[])
with pytest.raises(SystemExit):
tyro.cli(A, args=["--x"])
with pytest.raises(SystemExit):
tyro.cli(A, args=["--x", "1", "2", "3"])


def test_append_tuple() -> None:
@dataclasses.dataclass
class A:
Expand Down
15 changes: 15 additions & 0 deletions tests/test_new_style_annotations_min_py39.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import collections.abc
import dataclasses
from typing import Any, Literal, Optional, Type, Union

Expand Down Expand Up @@ -70,6 +71,20 @@ def test_tuple_direct() -> None:
assert tyro.cli(tuple[int, int], args="1 2".split(" ")) == (1, 2) # type: ignore


def test_append_abc_sequence() -> None:
@dataclasses.dataclass
class A:
x: tyro.conf.UseAppendAction[collections.abc.Sequence[int]]

assert tyro.cli(A, args=[]) == A(x=[])
assert tyro.cli(A, args="--x 1 --x 2 --x 3".split(" ")) == A(x=[1, 2, 3])
assert tyro.cli(A, args=[]) == A(x=[])
with pytest.raises(SystemExit):
tyro.cli(A, args=["--x"])
with pytest.raises(SystemExit):
tyro.cli(A, args=["--x", "1", "2", "3"])


try:
from torch.optim.lr_scheduler import LinearLR, LRScheduler

Expand Down
15 changes: 15 additions & 0 deletions tests/test_py311_generated/test_conf_generated.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
Dict,
Generic,
List,
Sequence,
Tuple,
Type,
TypedDict,
Expand Down Expand Up @@ -911,6 +912,20 @@ class A:
tyro.cli(A, args=["--x", "1", "2", "3"])


def test_append_sequence() -> None:
@dataclasses.dataclass
class A:
x: tyro.conf.UseAppendAction[Sequence[int]]

assert tyro.cli(A, args=[]) == A(x=[])
assert tyro.cli(A, args="--x 1 --x 2 --x 3".split(" ")) == A(x=[1, 2, 3])
assert tyro.cli(A, args=[]) == A(x=[])
with pytest.raises(SystemExit):
tyro.cli(A, args=["--x"])
with pytest.raises(SystemExit):
tyro.cli(A, args=["--x", "1", "2", "3"])


def test_append_tuple() -> None:
@dataclasses.dataclass
class A:
Expand Down
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import collections.abc
import dataclasses
from typing import Any, Literal, Optional, Type

Expand Down Expand Up @@ -70,6 +71,20 @@ def test_tuple_direct() -> None:
assert tyro.cli(tuple[int, int], args="1 2".split(" ")) == (1, 2) # type: ignore


def test_append_abc_sequence() -> None:
@dataclasses.dataclass
class A:
x: tyro.conf.UseAppendAction[collections.abc.Sequence[int]]

assert tyro.cli(A, args=[]) == A(x=[])
assert tyro.cli(A, args="--x 1 --x 2 --x 3".split(" ")) == A(x=[1, 2, 3])
assert tyro.cli(A, args=[]) == A(x=[])
with pytest.raises(SystemExit):
tyro.cli(A, args=["--x"])
with pytest.raises(SystemExit):
tyro.cli(A, args=["--x", "1", "2", "3"])


try:
from torch.optim.lr_scheduler import LinearLR, LRScheduler

Expand Down

0 comments on commit 6ce3ade

Please sign in to comment.