Skip to content

Commit

Permalink
improve performance for merging markers from overrides
Browse files Browse the repository at this point in the history
  • Loading branch information
radoering authored and abn committed Jan 11, 2025
1 parent 30fe6c7 commit 4ef2df2
Show file tree
Hide file tree
Showing 2 changed files with 153 additions and 110 deletions.
94 changes: 64 additions & 30 deletions src/poetry/puzzle/solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,12 @@ def _solve_in_compatibility_mode(
self,
overrides: tuple[dict[Package, dict[str, Dependency]], ...],
) -> dict[Package, TransitivePackageInfo]:
packages: dict[Package, TransitivePackageInfo] = {}
override_packages: list[
tuple[
dict[Package, dict[str, Dependency]],
dict[Package, TransitivePackageInfo],
]
] = []
for override in overrides:
self._provider.debug(
# ignore the warning as provider does not do interpolation
Expand All @@ -149,9 +154,9 @@ def _solve_in_compatibility_mode(
)
self._provider.set_overrides(override)
new_packages = self._solve()
merge_packages_from_override(packages, new_packages, override)
override_packages.append((override, new_packages))

return packages
return merge_override_packages(override_packages)

def _solve(self) -> dict[Package, TransitivePackageInfo]:
if self._provider._overrides:
Expand Down Expand Up @@ -406,34 +411,63 @@ def calculate_markers(
transitive_info.markers = transitive_marker


def merge_packages_from_override(
packages: dict[Package, TransitivePackageInfo],
new_packages: dict[Package, TransitivePackageInfo],
override: dict[Package, dict[str, Dependency]],
) -> None:
override_marker: BaseMarker = AnyMarker()
for deps in override.values():
for dep in deps.values():
override_marker = override_marker.intersect(dep.marker.without_extras())
for new_package, new_package_info in new_packages.items():
if package_info := packages.get(new_package):
# update existing package
package_info.depth = max(package_info.depth, new_package_info.depth)
package_info.groups.update(new_package_info.groups)
for group, marker in new_package_info.markers.items():
package_info.markers[group] = package_info.markers.get(
group, EmptyMarker()
).union(override_marker.intersect(marker))
for package in packages:
if package == new_package:
for dep in new_package.requires:
if dep not in package.requires:
package.add_dependency(dep)

def merge_override_packages(
override_packages: list[
tuple[
dict[Package, dict[str, Dependency]], dict[Package, TransitivePackageInfo]
]
],
) -> dict[Package, TransitivePackageInfo]:
result: dict[Package, TransitivePackageInfo] = {}
all_packages: dict[
Package, list[tuple[Package, TransitivePackageInfo, BaseMarker]]
] = {}
for override, o_packages in override_packages:
override_marker: BaseMarker = AnyMarker()
for deps in override.values():
for dep in deps.values():
override_marker = override_marker.intersect(dep.marker.without_extras())
for package, info in o_packages.items():
all_packages.setdefault(package, []).append(
(package, info, override_marker)
)
for package_duplicates in all_packages.values():
base = package_duplicates[0]
package = base[0]
package_info = base[1]
first_override_marker = base[2]
result[package] = package_info
package_info.depth = max(info.depth for _, info, _ in package_duplicates)
package_info.groups = {
g for _, info, _ in package_duplicates for g in info.groups
}
if all(
info.markers == package_info.markers for _, info, _ in package_duplicates
):
# performance shortcut:
# if markers are the same for all overrides,
# we can use less expensive marker operations
override_marker = EmptyMarker()
for _, _, marker in package_duplicates:
override_marker = override_marker.union(marker)
package_info.markers = {
group: override_marker.intersect(marker)
for group, marker in package_info.markers.items()
}
else:
for group, marker in new_package_info.markers.items():
new_package_info.markers[group] = override_marker.intersect(marker)
packages[new_package] = new_package_info
# fallback / general algorithm with performance issues
for group, marker in package_info.markers.items():
package_info.markers[group] = first_override_marker.intersect(marker)
for _, info, override_marker in package_duplicates[1:]:
for group, marker in info.markers.items():
package_info.markers[group] = package_info.markers.get(
group, EmptyMarker()
).union(override_marker.intersect(marker))
for duplicate_package, _, _ in package_duplicates[1:]:
for dep in duplicate_package.requires:
if dep not in package.requires:
package.add_dependency(dep)
return result


@functools.cache
Expand Down
169 changes: 89 additions & 80 deletions tests/puzzle/test_solver_internals.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from poetry.puzzle.solver import PackageNode
from poetry.puzzle.solver import Solver
from poetry.puzzle.solver import depth_first_search
from poetry.puzzle.solver import merge_packages_from_override
from poetry.puzzle.solver import merge_override_packages


if TYPE_CHECKING:
Expand Down Expand Up @@ -359,28 +359,29 @@ def test_propagate_markers_with_cycle(package: ProjectPackage, solver: Solver) -
}


def test_merge_packages_from_override_restricted(package: ProjectPackage) -> None:
def test_merge_override_packages_restricted(package: ProjectPackage) -> None:
"""Markers of dependencies should be intersected with override markers."""
a = Package("a", "1")

packages: dict[Package, TransitivePackageInfo] = {}
merge_packages_from_override(
packages,
{
a: TransitivePackageInfo(
0, {"main"}, {"main": parse_marker("sys_platform == 'win32'")}
)
},
{package: {"a": dep("b", 'python_version < "3.9"')}},
)
merge_packages_from_override(
packages,
{
a: TransitivePackageInfo(
0, {"main"}, {"main": parse_marker("sys_platform == 'linux'")}
)
},
{package: {"a": dep("b", 'python_version >= "3.9"')}},
packages = merge_override_packages(
[
(
{package: {"a": dep("b", 'python_version < "3.9"')}},
{
a: TransitivePackageInfo(
0, {"main"}, {"main": parse_marker("sys_platform == 'win32'")}
)
},
),
(
{package: {"a": dep("b", 'python_version >= "3.9"')}},
{
a: TransitivePackageInfo(
0, {"main"}, {"main": parse_marker("sys_platform == 'linux'")}
)
},
),
]
)
assert len(packages) == 1
assert packages[a].groups == {"main"}
Expand All @@ -392,28 +393,33 @@ def test_merge_packages_from_override_restricted(package: ProjectPackage) -> Non
}


def test_merge_packages_from_override_extras(package: ProjectPackage) -> None:
def test_merge_override_packages_extras(package: ProjectPackage) -> None:
"""Extras from overrides should not be visible in the resulting marker."""
a = Package("a", "1")

packages: dict[Package, TransitivePackageInfo] = {}
merge_packages_from_override(
packages,
{
a: TransitivePackageInfo(
0, {"main"}, {"main": parse_marker("sys_platform == 'win32'")}
)
},
{package: {"a": dep("b", 'python_version < "3.9" and extra == "foo"')}},
)
merge_packages_from_override(
packages,
{
a: TransitivePackageInfo(
0, {"main"}, {"main": parse_marker("sys_platform == 'linux'")}
)
},
{package: {"a": dep("b", 'python_version >= "3.9" and extra == "foo"')}},
packages = merge_override_packages(
[
(
{package: {"a": dep("b", 'python_version < "3.9" and extra == "foo"')}},
{
a: TransitivePackageInfo(
0, {"main"}, {"main": parse_marker("sys_platform == 'win32'")}
)
},
),
(
{
package: {
"a": dep("b", 'python_version >= "3.9" and extra == "foo"')
}
},
{
a: TransitivePackageInfo(
0, {"main"}, {"main": parse_marker("sys_platform == 'linux'")}
)
},
),
]
)
assert len(packages) == 1
assert packages[a].groups == {"main"}
Expand All @@ -425,21 +431,23 @@ def test_merge_packages_from_override_extras(package: ProjectPackage) -> None:
}


def test_merge_packages_from_override_multiple_deps(package: ProjectPackage) -> None:
def test_merge_override_packages_multiple_deps(package: ProjectPackage) -> None:
"""All override markers should be intersected."""
a = Package("a", "1")

packages: dict[Package, TransitivePackageInfo] = {}
merge_packages_from_override(
packages,
{a: TransitivePackageInfo(0, {"main"}, {"main": AnyMarker()})},
{
package: {
"a": dep("b", 'python_version < "3.9"'),
"c": dep("d", 'sys_platform == "linux"'),
},
a: {"e": dep("f", 'python_version >= "3.8"')},
},
packages = merge_override_packages(
[
(
{
package: {
"a": dep("b", 'python_version < "3.9"'),
"c": dep("d", 'sys_platform == "linux"'),
},
a: {"e": dep("f", 'python_version >= "3.8"')},
},
{a: TransitivePackageInfo(0, {"main"}, {"main": AnyMarker()})},
),
]
)

assert len(packages) == 1
Expand All @@ -452,44 +460,45 @@ def test_merge_packages_from_override_multiple_deps(package: ProjectPackage) ->
}


def test_merge_packages_from_override_groups(package: ProjectPackage) -> None:
def test_merge_override_packages_groups(package: ProjectPackage) -> None:
a = Package("a", "1")
b = Package("b", "1")

packages: dict[Package, TransitivePackageInfo] = {}
merge_packages_from_override(
packages,
{
a: TransitivePackageInfo(
0, {"main"}, {"main": parse_marker("sys_platform == 'win32'")}
),
b: TransitivePackageInfo(
0,
{"main", "dev"},
packages = merge_override_packages(
[
(
{package: {"a": dep("b", 'python_version < "3.9"')}},
{
"main": parse_marker("sys_platform == 'win32'"),
"dev": parse_marker("sys_platform == 'linux'"),
a: TransitivePackageInfo(
0, {"main"}, {"main": parse_marker("sys_platform == 'win32'")}
),
b: TransitivePackageInfo(
0,
{"main", "dev"},
{
"main": parse_marker("sys_platform == 'win32'"),
"dev": parse_marker("sys_platform == 'linux'"),
},
),
},
),
},
{package: {"a": dep("b", 'python_version < "3.9"')}},
)
merge_packages_from_override(
packages,
{
a: TransitivePackageInfo(
0, {"dev"}, {"dev": parse_marker("sys_platform == 'linux'")}
),
b: TransitivePackageInfo(
0,
{"main", "dev"},
(
{package: {"a": dep("b", 'python_version >= "3.9"')}},
{
"main": parse_marker("platform_machine == 'amd64'"),
"dev": parse_marker("platform_machine == 'aarch64'"),
a: TransitivePackageInfo(
0, {"dev"}, {"dev": parse_marker("sys_platform == 'linux'")}
),
b: TransitivePackageInfo(
0,
{"main", "dev"},
{
"main": parse_marker("platform_machine == 'amd64'"),
"dev": parse_marker("platform_machine == 'aarch64'"),
},
),
},
),
},
{package: {"a": dep("b", 'python_version >= "3.9"')}},
]
)
assert len(packages) == 2
assert packages[a].groups == {"main", "dev"}
Expand Down

0 comments on commit 4ef2df2

Please sign in to comment.