diff --git a/src/poetry/puzzle/solver.py b/src/poetry/puzzle/solver.py index 234518fc52e..7bbce2e4b5b 100644 --- a/src/poetry/puzzle/solver.py +++ b/src/poetry/puzzle/solver.py @@ -140,7 +140,12 @@ def _solve_in_compatibility_mode( self, overrides: tuple[dict[Package, dict[str, Dependency]], ...], ) -> dict[Package, TransitivePackageInfo]: - packages: dict[Package, TransitivePackageInfo] = {} + override_packages: list[ + tuple[ + dict[Package, dict[str, Dependency]], + dict[Package, TransitivePackageInfo], + ] + ] = [] for override in overrides: self._provider.debug( # ignore the warning as provider does not do interpolation @@ -149,9 +154,9 @@ def _solve_in_compatibility_mode( ) self._provider.set_overrides(override) new_packages = self._solve() - merge_packages_from_override(packages, new_packages, override) + override_packages.append((override, new_packages)) - return packages + return merge_override_packages(override_packages) def _solve(self) -> dict[Package, TransitivePackageInfo]: if self._provider._overrides: @@ -406,34 +411,63 @@ def calculate_markers( transitive_info.markers = transitive_marker -def merge_packages_from_override( - packages: dict[Package, TransitivePackageInfo], - new_packages: dict[Package, TransitivePackageInfo], - override: dict[Package, dict[str, Dependency]], -) -> None: - override_marker: BaseMarker = AnyMarker() - for deps in override.values(): - for dep in deps.values(): - override_marker = override_marker.intersect(dep.marker.without_extras()) - for new_package, new_package_info in new_packages.items(): - if package_info := packages.get(new_package): - # update existing package - package_info.depth = max(package_info.depth, new_package_info.depth) - package_info.groups.update(new_package_info.groups) - for group, marker in new_package_info.markers.items(): - package_info.markers[group] = package_info.markers.get( - group, EmptyMarker() - ).union(override_marker.intersect(marker)) - for package in packages: - if package == new_package: - for dep in new_package.requires: - if dep not in package.requires: - package.add_dependency(dep) - +def merge_override_packages( + override_packages: list[ + tuple[ + dict[Package, dict[str, Dependency]], dict[Package, TransitivePackageInfo] + ] + ], +) -> dict[Package, TransitivePackageInfo]: + result: dict[Package, TransitivePackageInfo] = {} + all_packages: dict[ + Package, list[tuple[Package, TransitivePackageInfo, BaseMarker]] + ] = {} + for override, o_packages in override_packages: + override_marker: BaseMarker = AnyMarker() + for deps in override.values(): + for dep in deps.values(): + override_marker = override_marker.intersect(dep.marker.without_extras()) + for package, info in o_packages.items(): + all_packages.setdefault(package, []).append( + (package, info, override_marker) + ) + for package_duplicates in all_packages.values(): + base = package_duplicates[0] + package = base[0] + package_info = base[1] + first_override_marker = base[2] + result[package] = package_info + package_info.depth = max(info.depth for _, info, _ in package_duplicates) + package_info.groups = { + g for _, info, _ in package_duplicates for g in info.groups + } + if all( + info.markers == package_info.markers for _, info, _ in package_duplicates + ): + # performance shortcut: + # if markers are the same for all overrides, + # we can use less expensive marker operations + override_marker = EmptyMarker() + for _, _, marker in package_duplicates: + override_marker = override_marker.union(marker) + package_info.markers = { + group: override_marker.intersect(marker) + for group, marker in package_info.markers.items() + } else: - for group, marker in new_package_info.markers.items(): - new_package_info.markers[group] = override_marker.intersect(marker) - packages[new_package] = new_package_info + # fallback / general algorithm with performance issues + for group, marker in package_info.markers.items(): + package_info.markers[group] = first_override_marker.intersect(marker) + for _, info, override_marker in package_duplicates[1:]: + for group, marker in info.markers.items(): + package_info.markers[group] = package_info.markers.get( + group, EmptyMarker() + ).union(override_marker.intersect(marker)) + for duplicate_package, _, _ in package_duplicates[1:]: + for dep in duplicate_package.requires: + if dep not in package.requires: + package.add_dependency(dep) + return result @functools.cache diff --git a/tests/puzzle/test_solver_internals.py b/tests/puzzle/test_solver_internals.py index 228654e70a3..ae882560e5c 100644 --- a/tests/puzzle/test_solver_internals.py +++ b/tests/puzzle/test_solver_internals.py @@ -13,7 +13,7 @@ from poetry.puzzle.solver import PackageNode from poetry.puzzle.solver import Solver from poetry.puzzle.solver import depth_first_search -from poetry.puzzle.solver import merge_packages_from_override +from poetry.puzzle.solver import merge_override_packages if TYPE_CHECKING: @@ -359,28 +359,29 @@ def test_propagate_markers_with_cycle(package: ProjectPackage, solver: Solver) - } -def test_merge_packages_from_override_restricted(package: ProjectPackage) -> None: +def test_merge_override_packages_restricted(package: ProjectPackage) -> None: """Markers of dependencies should be intersected with override markers.""" a = Package("a", "1") - packages: dict[Package, TransitivePackageInfo] = {} - merge_packages_from_override( - packages, - { - a: TransitivePackageInfo( - 0, {"main"}, {"main": parse_marker("sys_platform == 'win32'")} - ) - }, - {package: {"a": dep("b", 'python_version < "3.9"')}}, - ) - merge_packages_from_override( - packages, - { - a: TransitivePackageInfo( - 0, {"main"}, {"main": parse_marker("sys_platform == 'linux'")} - ) - }, - {package: {"a": dep("b", 'python_version >= "3.9"')}}, + packages = merge_override_packages( + [ + ( + {package: {"a": dep("b", 'python_version < "3.9"')}}, + { + a: TransitivePackageInfo( + 0, {"main"}, {"main": parse_marker("sys_platform == 'win32'")} + ) + }, + ), + ( + {package: {"a": dep("b", 'python_version >= "3.9"')}}, + { + a: TransitivePackageInfo( + 0, {"main"}, {"main": parse_marker("sys_platform == 'linux'")} + ) + }, + ), + ] ) assert len(packages) == 1 assert packages[a].groups == {"main"} @@ -392,28 +393,33 @@ def test_merge_packages_from_override_restricted(package: ProjectPackage) -> Non } -def test_merge_packages_from_override_extras(package: ProjectPackage) -> None: +def test_merge_override_packages_extras(package: ProjectPackage) -> None: """Extras from overrides should not be visible in the resulting marker.""" a = Package("a", "1") - packages: dict[Package, TransitivePackageInfo] = {} - merge_packages_from_override( - packages, - { - a: TransitivePackageInfo( - 0, {"main"}, {"main": parse_marker("sys_platform == 'win32'")} - ) - }, - {package: {"a": dep("b", 'python_version < "3.9" and extra == "foo"')}}, - ) - merge_packages_from_override( - packages, - { - a: TransitivePackageInfo( - 0, {"main"}, {"main": parse_marker("sys_platform == 'linux'")} - ) - }, - {package: {"a": dep("b", 'python_version >= "3.9" and extra == "foo"')}}, + packages = merge_override_packages( + [ + ( + {package: {"a": dep("b", 'python_version < "3.9" and extra == "foo"')}}, + { + a: TransitivePackageInfo( + 0, {"main"}, {"main": parse_marker("sys_platform == 'win32'")} + ) + }, + ), + ( + { + package: { + "a": dep("b", 'python_version >= "3.9" and extra == "foo"') + } + }, + { + a: TransitivePackageInfo( + 0, {"main"}, {"main": parse_marker("sys_platform == 'linux'")} + ) + }, + ), + ] ) assert len(packages) == 1 assert packages[a].groups == {"main"} @@ -425,21 +431,23 @@ def test_merge_packages_from_override_extras(package: ProjectPackage) -> None: } -def test_merge_packages_from_override_multiple_deps(package: ProjectPackage) -> None: +def test_merge_override_packages_multiple_deps(package: ProjectPackage) -> None: """All override markers should be intersected.""" a = Package("a", "1") - packages: dict[Package, TransitivePackageInfo] = {} - merge_packages_from_override( - packages, - {a: TransitivePackageInfo(0, {"main"}, {"main": AnyMarker()})}, - { - package: { - "a": dep("b", 'python_version < "3.9"'), - "c": dep("d", 'sys_platform == "linux"'), - }, - a: {"e": dep("f", 'python_version >= "3.8"')}, - }, + packages = merge_override_packages( + [ + ( + { + package: { + "a": dep("b", 'python_version < "3.9"'), + "c": dep("d", 'sys_platform == "linux"'), + }, + a: {"e": dep("f", 'python_version >= "3.8"')}, + }, + {a: TransitivePackageInfo(0, {"main"}, {"main": AnyMarker()})}, + ), + ] ) assert len(packages) == 1 @@ -452,44 +460,45 @@ def test_merge_packages_from_override_multiple_deps(package: ProjectPackage) -> } -def test_merge_packages_from_override_groups(package: ProjectPackage) -> None: +def test_merge_override_packages_groups(package: ProjectPackage) -> None: a = Package("a", "1") b = Package("b", "1") - packages: dict[Package, TransitivePackageInfo] = {} - merge_packages_from_override( - packages, - { - a: TransitivePackageInfo( - 0, {"main"}, {"main": parse_marker("sys_platform == 'win32'")} - ), - b: TransitivePackageInfo( - 0, - {"main", "dev"}, + packages = merge_override_packages( + [ + ( + {package: {"a": dep("b", 'python_version < "3.9"')}}, { - "main": parse_marker("sys_platform == 'win32'"), - "dev": parse_marker("sys_platform == 'linux'"), + a: TransitivePackageInfo( + 0, {"main"}, {"main": parse_marker("sys_platform == 'win32'")} + ), + b: TransitivePackageInfo( + 0, + {"main", "dev"}, + { + "main": parse_marker("sys_platform == 'win32'"), + "dev": parse_marker("sys_platform == 'linux'"), + }, + ), }, ), - }, - {package: {"a": dep("b", 'python_version < "3.9"')}}, - ) - merge_packages_from_override( - packages, - { - a: TransitivePackageInfo( - 0, {"dev"}, {"dev": parse_marker("sys_platform == 'linux'")} - ), - b: TransitivePackageInfo( - 0, - {"main", "dev"}, + ( + {package: {"a": dep("b", 'python_version >= "3.9"')}}, { - "main": parse_marker("platform_machine == 'amd64'"), - "dev": parse_marker("platform_machine == 'aarch64'"), + a: TransitivePackageInfo( + 0, {"dev"}, {"dev": parse_marker("sys_platform == 'linux'")} + ), + b: TransitivePackageInfo( + 0, + {"main", "dev"}, + { + "main": parse_marker("platform_machine == 'amd64'"), + "dev": parse_marker("platform_machine == 'aarch64'"), + }, + ), }, ), - }, - {package: {"a": dep("b", 'python_version >= "3.9"')}}, + ] ) assert len(packages) == 2 assert packages[a].groups == {"main", "dev"}