Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

requirement: Check requirements using the canonical name when fixing #577

Merged
merged 8 commits into from
Mar 29, 2023
4 changes: 4 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,10 @@ All versions prior to 0.0.9 are untracked.

### Fixed

* Fixed bug with the `--fix` flag where new requirements were sometimes being
appended to requirement files instead of patching the existing requirement
([#577](https://github.com/pypa/pip-audit/pull/577))

* Fixed a crash caused by auditing requirements files that refer to other
requirements files ([#568](https://github.com/pypa/pip-audit/pull/568))

Expand Down
6 changes: 5 additions & 1 deletion pip_audit/_dependency_source/requirement.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from typing import IO, Iterator

from packaging.specifiers import SpecifierSet
from packaging.utils import canonicalize_name
from pip_requirements_parser import InstallRequirement, InvalidRequirementLine, RequirementsFile

from pip_audit._dependency_source import DependencyFixError, DependencySource, DependencySourceError
Expand Down Expand Up @@ -203,7 +204,10 @@ def _fix_file(self, filename: Path, fix_version: ResolvedFixVersion) -> None:
with filename.open("w") as f:
found = False
for req in reqs:
if isinstance(req, InstallRequirement) and req.name == fix_version.dep.name:
if (
isinstance(req, InstallRequirement)
and canonicalize_name(req.name) == fix_version.dep.canonical_name
):
found = True
if req.specifier.contains(
fix_version.dep.version
Expand Down
46 changes: 46 additions & 0 deletions test/dependency_source/test_requirement.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,6 +198,52 @@ def test_requirement_source_fix(req_file):
)


def test_requirement_source_fix_roundtrip(req_file):
req_path = req_file()
with open(req_path, "w") as f:
f.write("flask==0.5")

source = requirement.RequirementSource([req_path])
specs = list(source.collect())

flask_dep: ResolvedDependency | None = None
for spec in specs:
if isinstance(spec, ResolvedDependency) and spec.canonical_name == "flask":
flask_dep = spec
break
assert flask_dep is not None
assert flask_dep == ResolvedDependency(name="Flask", version=Version("0.5"))

flask_fix = ResolvedFixVersion(dep=flask_dep, version=Version("1.0"))
source.fix(flask_fix)

with open(req_path) as f:
assert f.read().strip() == "flask==1.0"


def test_requirement_source_fix_roundtrip_non_canonical_name(req_file):
req_path = req_file()
with open(req_path, "w") as f:
f.write("Flask==0.5")

source = requirement.RequirementSource([req_path])
specs = list(source.collect())

flask_dep: ResolvedDependency | None = None
for spec in specs:
if isinstance(spec, ResolvedDependency) and spec.canonical_name == "flask":
flask_dep = spec
break
assert flask_dep is not None
assert flask_dep == ResolvedDependency(name="Flask", version=Version("0.5"))

flask_fix = ResolvedFixVersion(dep=flask_dep, version=Version("1.0"))
source.fix(flask_fix)

with open(req_path) as f:
assert f.read().strip() == "Flask==1.0"


def test_requirement_source_fix_multiple_files(req_file):
_check_fixes(
["flask==0.5", "requests==2.0\nflask==0.5"],
Expand Down