Fix attribute inheritance regression (introduced in 0.9.9) (#240)
* Fix attribute inheritance regression

* Add test

* Workaround for Python 3.8 __annotations__ behavior

* Remove unnecessary filter

* Additional test

* ruff
brentyi authored Jan 22, 2025
1 parent 690e04e commit 4d18b8e
Showing 3 changed files with 194 additions and 1 deletion.
13 changes: 12 additions & 1 deletion src/tyro/_resolver.py
@@ -673,7 +673,13 @@ def get_hints_for_bound_method(cls) -> Dict[str, Any]:
for x, t in _get_type_hints_backported_syntax(
obj, include_extras=include_extras
).items()
-if x in obj.__annotations__
+# Only include type hints that are explicitly defined in this class.
+#
+# Why `cls.__dict__.__annotations__` instead of `cls.__annotations__`?
+# Because in Python 3.8 and earlier, `cls.__annotations__`
+# recursively merges parent class annotations.
+# See this issue: https://github.com/python/cpython/issues/99535
+if x in obj.__dict__.get("__annotations__", {})
}

# We need to recurse into base classes in order to correctly resolve superclass parameters.
@@ -689,6 +695,11 @@ def get_hints_for_bound_method(cls) -> Dict[str, Any]:
{
x: TypeParamResolver.concretize_type_params(t)
for x, t in base_hints.items()
+# Include type hints that are not assigned earlier in the MRO.
+#
+# This needs to be recursive (include parents of parents),
+# so we shouldn't filter by local __annotations__.
+if x not in out
}
)
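
The first hunk above works around a quirk in how `__annotations__` is looked up on classes: when a class body defines no annotations of its own, attribute access falls through to a parent class. A minimal standalone sketch of the difference (illustrative only, not part of the commit; the version cutoff is as described in the comment above):

class Base:
    x: int

class Child(Base):
    pass

# On older interpreters, attribute lookup falls back to the parent class,
# so this can report Base's annotations: {'x': <class 'int'>}
print(Child.__annotations__)

# The class's own __dict__ is never inherited, so this only contains
# annotations written in Child's body; here it prints {}.
print(Child.__dict__.get("__annotations__", {}))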

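The second hunk merges hints from base classes in MRO order, keeping whatever was assigned by a more derived class. A rough sketch of that merge strategy (an assumed simplification, not tyro's actual implementation, which also concretizes type parameters):

from typing import Any, Dict

def merge_annotations(cls: type) -> Dict[str, Any]:
    out: Dict[str, Any] = {}
    for base in cls.__mro__:
        # Only look at annotations defined directly on each class in the MRO...
        for name, hint in vars(base).get("__annotations__", {}).items():
            # ...and keep the first (most derived) definition of each name,
            # while still reaching parents of parents.
            if name not in out:
                out[name] = hint
    return out

# With Base and Child from the sketch above: merge_annotations(Child) == {'x': int}
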
91 changes: 91 additions & 0 deletions tests/test_conf.py
@@ -1794,3 +1794,94 @@ def main(
with pytest.raises(SystemExit):
# ConfigB has a required argument.
assert tyro.cli(main, args=["x:config-b"])


_dataset_map = {
"alpaca": "tatsu-lab/alpaca",
"alpaca_clean": "yahma/alpaca-cleaned",
"alpaca_gpt4": "vicgalle/alpaca-gpt4",
}
_inv_dataset_map = {value: key for key, value in _dataset_map.items()}
_datasets = list(_dataset_map.keys())

HFDataset = Annotated[
str,
tyro.constructors.PrimitiveConstructorSpec(
nargs=1,
metavar="{" + ",".join(_datasets) + "}",
instance_from_str=lambda args: _dataset_map[args[0]],
is_instance=lambda instance: isinstance(instance, str)
and instance in _inv_dataset_map,
str_from_instance=lambda instance: [_inv_dataset_map[instance]],
choices=tuple(_datasets),
),
tyro.conf.arg(
help_behavior_hint=lambda df: f"(default: {df}, run datasets.py for full options)"
),
]


def test_annotated_attribute_inheritance() -> None:
"""From @mirceamironenco.
https://github.com/brentyi/tyro/issues/239"""

@dataclasses.dataclass(frozen=True)
class TrainConfig:
dataset: str = "vicgalle/alpaca-gpt4"

@dataclasses.dataclass(frozen=True)
class CLITrainerConfig(TrainConfig):
dataset: HFDataset = "vicgalle/alpaca-gpt4"

assert "{alpaca,alpaca_clean,alpaca_gpt4}" in get_helptext_with_checks(
CLITrainerConfig
)
assert (
"default: alpaca_gpt4, run datasets.py for full options"
in get_helptext_with_checks(CLITrainerConfig)
)
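
A hypothetical runtime counterpart to the helptext checks above (an illustrative sketch, not part of the committed tests; the Demo* names are made up): with the regression fixed, the subclass's Annotated[HFDataset] metadata should also drive parsing, so the short alias is mapped through instance_from_str.

def _demo_dataset_alias_parsing() -> None:
    @dataclasses.dataclass(frozen=True)
    class DemoTrainConfig:
        dataset: str = "vicgalle/alpaca-gpt4"

    @dataclasses.dataclass(frozen=True)
    class DemoCLITrainerConfig(DemoTrainConfig):
        # The subclass narrows the annotation; with the fix this is no longer
        # shadowed by the parent's plain `str` annotation.
        dataset: HFDataset = "vicgalle/alpaca-gpt4"

    config = tyro.cli(DemoCLITrainerConfig, args=["--dataset", "alpaca"])
    assert config.dataset == "tatsu-lab/alpaca"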


@dataclasses.dataclass(frozen=True)
class OptimizerConfig:
lr: float = 1e-1


@dataclasses.dataclass(frozen=True)
class AdamConfig(OptimizerConfig):
adam_foo: float = 1.0


@dataclasses.dataclass(frozen=True)
class SGDConfig(OptimizerConfig):
sgd_foo: float = 1.0


@dataclasses.dataclass
class TrainConfig:
optimizer: OptimizerConfig = AdamConfig()


def _dummy_constructor() -> Type[OptimizerConfig]:
return Union[AdamConfig, SGDConfig] # type: ignore


CLIOptimizerConfig = Annotated[
OptimizerConfig,
tyro.conf.arg(constructor_factory=_dummy_constructor),
]


def test_attribute_inheritance_2() -> None:
"""From @mirceamironenco.
https://github.com/brentyi/tyro/issues/239"""

@dataclasses.dataclass
class CLITrainerConfig(TrainConfig):
optimizer: CLIOptimizerConfig = SGDConfig()

assert "[{optimizer:adam-config,optimizer:sgd-config}]" in get_helptext_with_checks(
CLITrainerConfig
)
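
Similarly, a rough sketch (assumed behavior, not part of the commit) of what the subcommand helptext above corresponds to at runtime, reusing the module-level AdamConfig, SGDConfig, and CLIOptimizerConfig; DemoTrainerConfig is illustrative only.

def _demo_optimizer_subcommand() -> None:
    @dataclasses.dataclass
    class DemoTrainerConfig(TrainConfig):
        # With the fix, this CLIOptimizerConfig annotation is honored instead
        # of the parent's plain OptimizerConfig annotation.
        optimizer: CLIOptimizerConfig = SGDConfig()

    # Selecting the adam-config subcommand should swap out the SGDConfig default.
    config = tyro.cli(DemoTrainerConfig, args=["optimizer:adam-config"])
    assert isinstance(config.optimizer, AdamConfig)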
91 changes: 91 additions & 0 deletions tests/test_py311_generated/test_conf_generated.py
@@ -1800,3 +1800,94 @@ def main(
with pytest.raises(SystemExit):
# ConfigB has a required argument.
assert tyro.cli(main, args=["x:config-b"])


_dataset_map = {
"alpaca": "tatsu-lab/alpaca",
"alpaca_clean": "yahma/alpaca-cleaned",
"alpaca_gpt4": "vicgalle/alpaca-gpt4",
}
_inv_dataset_map = {value: key for key, value in _dataset_map.items()}
_datasets = list(_dataset_map.keys())

HFDataset = Annotated[
str,
tyro.constructors.PrimitiveConstructorSpec(
nargs=1,
metavar="{" + ",".join(_datasets) + "}",
instance_from_str=lambda args: _dataset_map[args[0]],
is_instance=lambda instance: isinstance(instance, str)
and instance in _inv_dataset_map,
str_from_instance=lambda instance: [_inv_dataset_map[instance]],
choices=tuple(_datasets),
),
tyro.conf.arg(
help_behavior_hint=lambda df: f"(default: {df}, run datasets.py for full options)"
),
]


def test_annotated_attribute_inheritance() -> None:
"""From @mirceamironenco.
https://github.com/brentyi/tyro/issues/239"""

@dataclasses.dataclass(frozen=True)
class TrainConfig:
dataset: str = "vicgalle/alpaca-gpt4"

@dataclasses.dataclass(frozen=True)
class CLITrainerConfig(TrainConfig):
dataset: HFDataset = "vicgalle/alpaca-gpt4"

assert "{alpaca,alpaca_clean,alpaca_gpt4}" in get_helptext_with_checks(
CLITrainerConfig
)
assert (
"default: alpaca_gpt4, run datasets.py for full options"
in get_helptext_with_checks(CLITrainerConfig)
)


@dataclasses.dataclass(frozen=True)
class OptimizerConfig:
lr: float = 1e-1


@dataclasses.dataclass(frozen=True)
class AdamConfig(OptimizerConfig):
adam_foo: float = 1.0


@dataclasses.dataclass(frozen=True)
class SGDConfig(OptimizerConfig):
sgd_foo: float = 1.0


@dataclasses.dataclass
class TrainConfig:
optimizer: OptimizerConfig = AdamConfig()


def _dummy_constructor() -> Type[OptimizerConfig]:
return AdamConfig | SGDConfig # type: ignore


CLIOptimizerConfig = Annotated[
OptimizerConfig,
tyro.conf.arg(constructor_factory=_dummy_constructor),
]


def test_attribute_inheritance_2() -> None:
"""From @mirceamironenco.
https://github.com/brentyi/tyro/issues/239"""

@dataclasses.dataclass
class CLITrainerConfig(TrainConfig):
optimizer: CLIOptimizerConfig = SGDConfig()

assert "[{optimizer:adam-config,optimizer:sgd-config}]" in get_helptext_with_checks(
CLITrainerConfig
)
