From fc8dc1117b123ebece23aa39135b07931155ddc7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Carlos=20Mochol=C3=AD?= Date: Tue, 7 Feb 2023 17:31:11 +0100 Subject: [PATCH 01/12] Refactor from_str --- src/lightning_utilities/core/enums.py | 50 ++++++++------------------- 1 file changed, 15 insertions(+), 35 deletions(-) diff --git a/src/lightning_utilities/core/enums.py b/src/lightning_utilities/core/enums.py index 6c0953fa..3ea1548f 100644 --- a/src/lightning_utilities/core/enums.py +++ b/src/lightning_utilities/core/enums.py @@ -2,9 +2,8 @@ # Licensed under the Apache License, Version 2.0 (the "License"); # http://www.apache.org/licenses/LICENSE-2.0 # -import warnings from enum import Enum -from typing import Optional +from typing import List from typing_extensions import Literal @@ -19,46 +18,29 @@ class StrEnum(str, Enum): True >>> MySE.from_str("t-2", source="value") == MySE.t2 True + >>> MySE.from_str("t-2", source="value") + + >>> MySE.from_str("t-3", source="any") + Traceback (most recent call last): + ... + ValueError: Invalid match: expected one of ['t1', 't2', 'T-1', 'T-2'], but got t-3. """ @classmethod - def from_str( - cls, value: str, source: Literal["key", "value", "any"] = "key", strict: bool = False - ) -> Optional["StrEnum"]: - """Create StrEnum from a sting matching the key or value. + def from_str(cls, value: str, source: Literal["key", "value", "any"] = "key") -> "StrEnum": + """Create StrEnum from a string matching the key or value. Args: value: matching string source: compare with: - - - ``"key"``: validates only with Enum keys, typical alphanumeric with "_" - - ``"value"``: validates only with Enum values, could be any string - - ``"key"``: validates with any key or value, but key has priority - - strict: allow not matching string and returns None; if false raises exceptions + - ``"key"``: validates only from the enum keys, typical alphanumeric with "_" + - ``"value"``: validates only from the values, could be any string + - ``"any"``: validates with any key or value, but key has priority Raises: ValueError: - if requested string does not match any option based on selected source and use ``"strict=True"`` - UserWarning: - if requested string does not match any option based on selected source and use ``"strict=False"`` - - Example: - >>> class MySE(StrEnum): - ... t1 = "T-1" - ... t2 = "T-2" - >>> MySE.from_str("t-1", source="key") - >>> MySE.from_str("t-2", source="value") - - >>> MySE.from_str("t-3", source="any", strict=True) - Traceback (most recent call last): - ... - ValueError: Invalid match: expected one of ['t1', 't2', 'T-1', 'T-2'], but got t-3. + if requested string does not match any option based on selected source. """ - allowed = cls._allowed_matches(source) - if strict and not any(enum_.lower() == value.lower() for enum_ in allowed): - raise ValueError(f"Invalid match: expected one of {allowed}, but got {value}.") - if source in ("key", "any"): for enum_key in cls.__members__.keys(): if enum_key.lower() == value.lower(): @@ -67,12 +49,10 @@ def from_str( for enum_key, enum_val in cls.__members__.items(): if enum_val == value: return cls[enum_key] - - warnings.warn(UserWarning(f"Invalid string: expected one of {allowed}, but got {value}.")) - return None + raise ValueError(f"Invalid match: expected one of {cls._allowed_matches(source)}, but got {value}.") @classmethod - def _allowed_matches(cls, source: str) -> list: + def _allowed_matches(cls, source: str) -> List[str]: keys, vals = [], [] for enum_key, enum_val in cls.__members__.items(): keys.append(enum_key) From 51cb36ae1e8d245f8b973f2c00d7bbd2f3f909f0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Carlos=20Mochol=C3=AD?= Date: Tue, 7 Feb 2023 17:36:23 +0100 Subject: [PATCH 02/12] more --- src/lightning_utilities/core/enums.py | 9 ++++++++- tests/unittests/core/test_enums.py | 4 ++-- 2 files changed, 10 insertions(+), 3 deletions(-) diff --git a/src/lightning_utilities/core/enums.py b/src/lightning_utilities/core/enums.py index 3ea1548f..c7f69a15 100644 --- a/src/lightning_utilities/core/enums.py +++ b/src/lightning_utilities/core/enums.py @@ -3,7 +3,7 @@ # http://www.apache.org/licenses/LICENSE-2.0 # from enum import Enum -from typing import List +from typing import List, Optional from typing_extensions import Literal @@ -51,6 +51,13 @@ def from_str(cls, value: str, source: Literal["key", "value", "any"] = "key") -> return cls[enum_key] raise ValueError(f"Invalid match: expected one of {cls._allowed_matches(source)}, but got {value}.") + @classmethod + def maybe_from_str(cls, value: str, source: Literal["key", "value", "any"] = "key") -> Optional["StrEnum"]: + try: + return cls.from_str(value, source) + except ValueError: + return + @classmethod def _allowed_matches(cls, source: str) -> List[str]: keys, vals = [], [] diff --git a/tests/unittests/core/test_enums.py b/tests/unittests/core/test_enums.py index 6ddba5ca..d465a5c9 100644 --- a/tests/unittests/core/test_enums.py +++ b/tests/unittests/core/test_enums.py @@ -47,9 +47,9 @@ class MyEnum(StrEnum): T2 = "t:2" assert MyEnum.from_str("T1", source="key") - assert MyEnum.from_str("T1", source="value") is None + assert MyEnum.maybe_from_str("T1", source="value") is None assert MyEnum.from_str("T1", source="any") - assert MyEnum.from_str("T:2", source="key") is None + assert MyEnum.maybe_from_str("T:2", source="key") is None assert MyEnum.from_str("T:2", source="value") assert MyEnum.from_str("T:2", source="any") From f6cfd89aa2dbae9d411f715f2478bfb48097b31b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Carlos=20Mochol=C3=AD?= Date: Tue, 7 Feb 2023 17:41:47 +0100 Subject: [PATCH 03/12] Update src/lightning_utilities/core/enums.py Co-authored-by: Jirka Borovec <6035284+Borda@users.noreply.github.com> --- src/lightning_utilities/core/enums.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/lightning_utilities/core/enums.py b/src/lightning_utilities/core/enums.py index c7f69a15..c87b9847 100644 --- a/src/lightning_utilities/core/enums.py +++ b/src/lightning_utilities/core/enums.py @@ -28,7 +28,7 @@ class StrEnum(str, Enum): @classmethod def from_str(cls, value: str, source: Literal["key", "value", "any"] = "key") -> "StrEnum": - """Create StrEnum from a string matching the key or value. + """Create ``StrEnum`` from a string matching the key or value. Args: value: matching string From aa0771b45ebc963c95d97ac435d056cb7393680e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Carlos=20Mochol=C3=AD?= Date: Tue, 7 Feb 2023 17:43:03 +0100 Subject: [PATCH 04/12] Update src/lightning_utilities/core/enums.py --- src/lightning_utilities/core/enums.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/lightning_utilities/core/enums.py b/src/lightning_utilities/core/enums.py index c87b9847..0b992606 100644 --- a/src/lightning_utilities/core/enums.py +++ b/src/lightning_utilities/core/enums.py @@ -56,7 +56,7 @@ def maybe_from_str(cls, value: str, source: Literal["key", "value", "any"] = "ke try: return cls.from_str(value, source) except ValueError: - return + return None @classmethod def _allowed_matches(cls, source: str) -> List[str]: From 943d027fbb0b6ab69ba4f3b67afdf53d43382997 Mon Sep 17 00:00:00 2001 From: Jirka Borovec <6035284+Borda@users.noreply.github.com> Date: Fri, 10 Feb 2023 11:33:03 +0100 Subject: [PATCH 05/12] Apply suggestions from code review --- src/lightning_utilities/core/enums.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/lightning_utilities/core/enums.py b/src/lightning_utilities/core/enums.py index 0b992606..62f31dc4 100644 --- a/src/lightning_utilities/core/enums.py +++ b/src/lightning_utilities/core/enums.py @@ -33,6 +33,7 @@ def from_str(cls, value: str, source: Literal["key", "value", "any"] = "key") -> Args: value: matching string source: compare with: + - ``"key"``: validates only from the enum keys, typical alphanumeric with "_" - ``"value"``: validates only from the values, could be any string - ``"any"``: validates with any key or value, but key has priority From 1f7a7434b32d037556855f852b22f13182b25cfe Mon Sep 17 00:00:00 2001 From: Jirka Borovec <6035284+Borda@users.noreply.github.com> Date: Tue, 14 Feb 2023 13:45:00 +0100 Subject: [PATCH 06/12] Apply suggestions from code review --- src/lightning_utilities/core/enums.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/lightning_utilities/core/enums.py b/src/lightning_utilities/core/enums.py index 62f31dc4..42503a95 100644 --- a/src/lightning_utilities/core/enums.py +++ b/src/lightning_utilities/core/enums.py @@ -53,11 +53,11 @@ def from_str(cls, value: str, source: Literal["key", "value", "any"] = "key") -> raise ValueError(f"Invalid match: expected one of {cls._allowed_matches(source)}, but got {value}.") @classmethod - def maybe_from_str(cls, value: str, source: Literal["key", "value", "any"] = "key") -> Optional["StrEnum"]: + def try_from_str(cls, value: str, source: Literal["key", "value", "any"] = "key") -> Optional["StrEnum"]: try: return cls.from_str(value, source) except ValueError: - return None + warnings.warn(UserWarning(f"Invalid string: expected one of {cls._allowed_matches(source)}, but got {value}.")) @classmethod def _allowed_matches(cls, source: str) -> List[str]: From f402100fe3471ec05391379a6a576bfdc4c90243 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 14 Feb 2023 12:45:08 +0000 Subject: [PATCH 07/12] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/lightning_utilities/core/enums.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/lightning_utilities/core/enums.py b/src/lightning_utilities/core/enums.py index 42503a95..b486afe9 100644 --- a/src/lightning_utilities/core/enums.py +++ b/src/lightning_utilities/core/enums.py @@ -57,7 +57,9 @@ def try_from_str(cls, value: str, source: Literal["key", "value", "any"] = "key" try: return cls.from_str(value, source) except ValueError: - warnings.warn(UserWarning(f"Invalid string: expected one of {cls._allowed_matches(source)}, but got {value}.")) + warnings.warn( + UserWarning(f"Invalid string: expected one of {cls._allowed_matches(source)}, but got {value}.") + ) @classmethod def _allowed_matches(cls, source: str) -> List[str]: From 9b09a33d29eda39563a64edec8e4fd7ad30b8895 Mon Sep 17 00:00:00 2001 From: Jirka Date: Tue, 14 Feb 2023 13:46:33 +0100 Subject: [PATCH 08/12] tests --- tests/unittests/core/test_enums.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/unittests/core/test_enums.py b/tests/unittests/core/test_enums.py index d465a5c9..3bd3638c 100644 --- a/tests/unittests/core/test_enums.py +++ b/tests/unittests/core/test_enums.py @@ -47,9 +47,9 @@ class MyEnum(StrEnum): T2 = "t:2" assert MyEnum.from_str("T1", source="key") - assert MyEnum.maybe_from_str("T1", source="value") is None + assert MyEnum.try_from_str("T1", source="value") is None assert MyEnum.from_str("T1", source="any") - assert MyEnum.maybe_from_str("T:2", source="key") is None + assert MyEnum.try_from_str("T:2", source="key") is None assert MyEnum.from_str("T:2", source="value") assert MyEnum.from_str("T:2", source="any") From 1a723e81f0d1104c6ac4ab3694ef111c23a1d6fb Mon Sep 17 00:00:00 2001 From: Jirka Date: Tue, 14 Feb 2023 13:47:22 +0100 Subject: [PATCH 09/12] chlog --- CHANGELOG.md | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index c9596e72..1374db1a 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -23,7 +23,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Fixed -- Fixed `StrEnum.from_str` with source as key ([#99](https://github.com/Lightning-AI/utilities/pull/99)) +- Fixed `StrEnum.from_str` with source as key ( + [#99](https://github.com/Lightning-AI/utilities/pull/99), + [#102](https://github.com/Lightning-AI/utilities/pull/102) +) ## [0.6.0] - 2023-01-23 From 848c31d13853f1eb30e0b4e2eb3756e822f12018 Mon Sep 17 00:00:00 2001 From: Jirka Date: Tue, 14 Feb 2023 13:52:15 +0100 Subject: [PATCH 10/12] warning --- src/lightning_utilities/core/enums.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/lightning_utilities/core/enums.py b/src/lightning_utilities/core/enums.py index b486afe9..44c1c67e 100644 --- a/src/lightning_utilities/core/enums.py +++ b/src/lightning_utilities/core/enums.py @@ -2,6 +2,7 @@ # Licensed under the Apache License, Version 2.0 (the "License"); # http://www.apache.org/licenses/LICENSE-2.0 # +import warnings from enum import Enum from typing import List, Optional From 27a2957cc23496f741d24215e07b377e95776293 Mon Sep 17 00:00:00 2001 From: Jirka Date: Tue, 14 Feb 2023 13:54:31 +0100 Subject: [PATCH 11/12] mocks --- tests/unittests/core/__init__.py | 0 tests/unittests/core/test_apply_func.py | 2 +- tests/unittests/{core => }/mocks.py | 0 3 files changed, 1 insertion(+), 1 deletion(-) delete mode 100644 tests/unittests/core/__init__.py rename tests/unittests/{core => }/mocks.py (100%) diff --git a/tests/unittests/core/__init__.py b/tests/unittests/core/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/tests/unittests/core/test_apply_func.py b/tests/unittests/core/test_apply_func.py index 1dfc71ed..f5e0a5e2 100644 --- a/tests/unittests/core/test_apply_func.py +++ b/tests/unittests/core/test_apply_func.py @@ -5,7 +5,7 @@ from typing import Any, ClassVar, List, Optional import pytest -from unittests.core.mocks import torch +from unittests.mocks import torch from lightning_utilities.core.apply_func import apply_to_collection, apply_to_collections diff --git a/tests/unittests/core/mocks.py b/tests/unittests/mocks.py similarity index 100% rename from tests/unittests/core/mocks.py rename to tests/unittests/mocks.py From b4cda31e3d316cacbd2bc48561deb3f3c2979e76 Mon Sep 17 00:00:00 2001 From: Jirka Date: Tue, 14 Feb 2023 13:54:57 +0100 Subject: [PATCH 12/12] return --- src/lightning_utilities/core/enums.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/lightning_utilities/core/enums.py b/src/lightning_utilities/core/enums.py index 44c1c67e..e0c9b873 100644 --- a/src/lightning_utilities/core/enums.py +++ b/src/lightning_utilities/core/enums.py @@ -61,6 +61,7 @@ def try_from_str(cls, value: str, source: Literal["key", "value", "any"] = "key" warnings.warn( UserWarning(f"Invalid string: expected one of {cls._allowed_matches(source)}, but got {value}.") ) + return None @classmethod def _allowed_matches(cls, source: str) -> List[str]: