Skip to content

Commit

Permalink
markers: add special handling for extra
Browse files Browse the repository at this point in the history
  • Loading branch information
radoering committed Sep 11, 2023
1 parent 9747649 commit b91e1df
Show file tree
Hide file tree
Showing 2 changed files with 123 additions and 4 deletions.
28 changes: 27 additions & 1 deletion src/poetry/core/version/markers.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@
from typing import TypeVar
from typing import Union

from packaging.utils import canonicalize_name

from poetry.core.constraints.generic import BaseConstraint
from poetry.core.constraints.generic import Constraint
from poetry.core.constraints.generic import MultiConstraint
Expand Down Expand Up @@ -243,8 +245,24 @@ def validate(self, environment: dict[str, Any] | None) -> bool:
if self._name not in environment:
return True

# "extra" is special because it can have multiple values at the same time.
# "extra == 'a'" will be true if "a" is one of the active extras.
# "extra != 'a'" will be true if "a" is not one of the active extras.
# Further, extra names are normalized for comparison.
if self._name == "extra":
extras = environment["extra"]
if isinstance(extras, str):
extras = {extras}
extras = {canonicalize_name(extra) for extra in extras}
assert isinstance(self._constraint, Constraint)
normalized_value = canonicalize_name(self._constraint.value)
if self._constraint.operator == "==":
return normalized_value in extras
assert self._constraint.operator == "!="
return normalized_value not in extras

# The type of constraint returned by the parser matches our constraint: either
# both are BaseConstraint or both are VersionConstraint. But it's hard for mypy
# both are BaseConstraint or both are VersionConstraint. But it's hard for mypy
# to know that.
constraint = self._parser(environment[self._name])
return self._constraint.allows(constraint) # type: ignore[arg-type]
Expand Down Expand Up @@ -976,6 +994,14 @@ def _merge_single_markers(
if marker1.name != marker2.name:
return None

# "extra" is special because it can have multiple values at the same time.
# That's why we can only merge two "extra" markers if they have the same value.
if marker1.name == "extra":
assert isinstance(marker1, SingleMarker)
assert isinstance(marker2, SingleMarker)
if marker1.value != marker2.value: # type: ignore[attr-defined]
return None

if merge_class == MultiMarker:
merge_method = marker1.constraint.intersect
else:
Expand Down
99 changes: 96 additions & 3 deletions tests/version/test_markers.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,18 @@
' "linux" and python_version < "3.6" and python_version >= "3.3" or'
' sys_platform == "darwin" and python_version < "3.3"'
),
# "extra" is a special marker that can have multiple values at the same time.
# Thus, "extra == 'a' and extra == 'b'" is not empty.
# Further, "extra == 'a' and extra != 'b'" cannot be simplified
# because it has the meaning "extra 'a' must and extra 'b' must not be active"
'extra == "a" and extra == "b"',
'extra == "a" and extra != "b"',
'extra != "a" and extra == "b"',
'extra != "a" and extra != "b"',
'extra == "a" or extra == "b"',
'extra == "a" or extra != "b"',
'extra != "a" or extra == "b"',
'extra != "a" or extra != "b"',
],
)
def test_parse_marker(marker: str) -> None:
Expand Down Expand Up @@ -199,6 +211,27 @@ def test_single_marker_not_in_python_intersection() -> None:
assert str(intersection) == 'python_version not in "2.7, 3.0, 3.1, 3.2"'


@pytest.mark.parametrize(
("marker1", "marker2", "expected"),
[
# same value
('extra == "a"', 'extra == "a"', 'extra == "a"'),
('extra == "a"', 'extra != "a"', "<empty>"),
('extra != "a"', 'extra == "a"', "<empty>"),
('extra != "a"', 'extra != "a"', 'extra != "a"'),
# different values
('extra == "a"', 'extra == "b"', 'extra == "a" and extra == "b"'),
('extra == "a"', 'extra != "b"', 'extra == "a" and extra != "b"'),
('extra != "a"', 'extra == "b"', 'extra != "a" and extra == "b"'),
('extra != "a"', 'extra != "b"', 'extra != "a" and extra != "b"'),
],
)
def test_single_marker_intersect_extras(
marker1: str, marker2: str, expected: str
) -> None:
assert str(parse_marker(marker1).intersect(parse_marker(marker2))) == expected


def test_single_marker_union() -> None:
m = parse_marker('sys_platform == "darwin"')

Expand Down Expand Up @@ -372,6 +405,25 @@ def test_single_marker_union_with_inverse() -> None:
assert union.is_any()


@pytest.mark.parametrize(
("marker1", "marker2", "expected"),
[
# same value
('extra == "a"', 'extra == "a"', 'extra == "a"'),
('extra == "a"', 'extra != "a"', ""),
('extra != "a"', 'extra == "a"', ""),
('extra != "a"', 'extra != "a"', 'extra != "a"'),
# different values
('extra == "a"', 'extra == "b"', 'extra == "a" or extra == "b"'),
('extra == "a"', 'extra != "b"', 'extra == "a" or extra != "b"'),
('extra != "a"', 'extra == "b"', 'extra != "a" or extra == "b"'),
('extra != "a"', 'extra != "b"', 'extra != "a" or extra != "b"'),
],
)
def test_single_marker_union_extras(marker1: str, marker2: str, expected: str) -> None:
assert str(parse_marker(marker1).union(parse_marker(marker2))) == expected


def test_multi_marker() -> None:
m = parse_marker('sys_platform == "darwin" and implementation_name == "cpython"')

Expand Down Expand Up @@ -858,8 +910,6 @@ def test_multi_marker_removes_duplicates() -> None:
{"os_name": "other", "python_version": "2.7.4"},
False,
),
("extra == 'security'", {"extra": "quux"}, False),
("extra == 'security'", {"extra": "security"}, True),
(f"os.name == '{os.name}'", None, True),
("sys.platform == 'win32'", {"sys_platform": "linux2"}, False),
("platform.version in 'Ubuntu'", {"platform_version": "#39"}, False),
Expand Down Expand Up @@ -906,6 +956,49 @@ def test_multi_marker_removes_duplicates() -> None:
{"platform_machine": "x86_64"},
False,
),
# extras
# single extra
("extra == 'security'", {"extra": "quux"}, False),
("extra == 'security'", {"extra": "security"}, True),
("extra != 'security'", {"extra": "quux"}, True),
("extra != 'security'", {"extra": "security"}, False),
# normalization
("extra == 'Security.1'", {"extra": "security-1"}, True),
# extra unknown
("extra == 'a'", {}, True),
("extra != 'a'", {}, True),
("extra == 'a' and extra == 'b'", {}, True),
# extra explicitly not set
("extra == 'a'", {"extra": ()}, False),
("extra != 'b'", {"extra": ()}, True),
("extra == 'a' and extra == 'b'", {"extra": ()}, False),
("extra == 'a' or extra == 'b'", {"extra": ()}, False),
("extra != 'a' and extra != 'b'", {"extra": ()}, True),
("extra != 'a' or extra != 'b'", {"extra": ()}, True),
("extra != 'a' and extra == 'b'", {"extra": ()}, False),
("extra != 'a' or extra == 'b'", {"extra": ()}, True),
# multiple extras
("extra == 'a'", {"extra": ("a", "b")}, True),
("extra == 'a'", {"extra": ("b", "c")}, False),
("extra != 'a'", {"extra": ("a", "b")}, False),
("extra != 'a'", {"extra": ("b", "c")}, True),
("extra == 'a' and extra == 'b'", {"extra": ("a", "b", "c")}, True),
("extra == 'a' and extra == 'b'", {"extra": ("a", "c")}, False),
("extra == 'a' or extra == 'b'", {"extra": ("a", "c")}, True),
("extra == 'a' or extra == 'b'", {"extra": ("b", "c")}, True),
("extra == 'a' or extra == 'b'", {"extra": ("c", "d")}, False),
("extra != 'a' and extra != 'b'", {"extra": ("a", "c")}, False),
("extra != 'a' and extra != 'b'", {"extra": ("b", "c")}, False),
("extra != 'a' and extra != 'b'", {"extra": ("c", "d")}, True),
("extra != 'a' or extra != 'b'", {"extra": ("a", "b", "c")}, False),
("extra != 'a' or extra != 'b'", {"extra": ("a", "c")}, True),
("extra != 'a' or extra != 'b'", {"extra": ("b", "c")}, True),
("extra != 'a' and extra == 'b'", {"extra": ("a", "b")}, False),
("extra != 'a' and extra == 'b'", {"extra": ("b", "c")}, True),
("extra != 'a' and extra == 'b'", {"extra": ("c", "d")}, False),
("extra != 'a' or extra == 'b'", {"extra": ("a", "b")}, True),
("extra != 'a' or extra == 'b'", {"extra": ("c", "d")}, True),
("extra != 'a' or extra == 'b'", {"extra": ("a", "c")}, False),
],
)
def test_validate(
Expand Down Expand Up @@ -959,7 +1052,7 @@ def test_parse_version_like_markers(marker: str, env: dict[str, str]) -> None:
'python_version >= "3.6" or extra == "foo" and implementation_name =='
' "pypy" or extra == "bar"'
),
"",
'python_version >= "3.6" or implementation_name == "pypy"',
),
('extra == "foo"', ""),
('extra == "foo" or extra == "bar"', ""),
Expand Down

0 comments on commit b91e1df

Please sign in to comment.