Skip to content

Commit

Permalink
requirement: Check requirements using the canonical name when fixing (#…
Browse files Browse the repository at this point in the history
  • Loading branch information
tetsuo-cpp authored Mar 29, 2023
1 parent 1f46ebd commit 35573a4
Show file tree
Hide file tree
Showing 3 changed files with 55 additions and 1 deletion.
4 changes: 4 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,10 @@ All versions prior to 0.0.9 are untracked.

### Fixed

* Fixed bug with the `--fix` flag where new requirements were sometimes being
appended to requirement files instead of patching the existing requirement
([#577](https://github.com/pypa/pip-audit/pull/577))

* Fixed a crash caused by auditing requirements files that refer to other
requirements files ([#568](https://github.com/pypa/pip-audit/pull/568))

Expand Down
6 changes: 5 additions & 1 deletion pip_audit/_dependency_source/requirement.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from typing import IO, Iterator

from packaging.specifiers import SpecifierSet
from packaging.utils import canonicalize_name
from pip_requirements_parser import InstallRequirement, InvalidRequirementLine, RequirementsFile

from pip_audit._dependency_source import DependencyFixError, DependencySource, DependencySourceError
Expand Down Expand Up @@ -203,7 +204,10 @@ def _fix_file(self, filename: Path, fix_version: ResolvedFixVersion) -> None:
with filename.open("w") as f:
found = False
for req in reqs:
if isinstance(req, InstallRequirement) and req.name == fix_version.dep.name:
if (
isinstance(req, InstallRequirement)
and canonicalize_name(req.name) == fix_version.dep.canonical_name
):
found = True
if req.specifier.contains(
fix_version.dep.version
Expand Down
46 changes: 46 additions & 0 deletions test/dependency_source/test_requirement.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,6 +198,52 @@ def test_requirement_source_fix(req_file):
)


def test_requirement_source_fix_roundtrip(req_file):
req_path = req_file()
with open(req_path, "w") as f:
f.write("flask==0.5")

source = requirement.RequirementSource([req_path])
specs = list(source.collect())

flask_dep: ResolvedDependency | None = None
for spec in specs:
if isinstance(spec, ResolvedDependency) and spec.canonical_name == "flask":
flask_dep = spec
break
assert flask_dep is not None
assert flask_dep == ResolvedDependency(name="Flask", version=Version("0.5"))

flask_fix = ResolvedFixVersion(dep=flask_dep, version=Version("1.0"))
source.fix(flask_fix)

with open(req_path) as f:
assert f.read().strip() == "flask==1.0"


def test_requirement_source_fix_roundtrip_non_canonical_name(req_file):
req_path = req_file()
with open(req_path, "w") as f:
f.write("Flask==0.5")

source = requirement.RequirementSource([req_path])
specs = list(source.collect())

flask_dep: ResolvedDependency | None = None
for spec in specs:
if isinstance(spec, ResolvedDependency) and spec.canonical_name == "flask":
flask_dep = spec
break
assert flask_dep is not None
assert flask_dep == ResolvedDependency(name="Flask", version=Version("0.5"))

flask_fix = ResolvedFixVersion(dep=flask_dep, version=Version("1.0"))
source.fix(flask_fix)

with open(req_path) as f:
assert f.read().strip() == "Flask==1.0"


def test_requirement_source_fix_multiple_files(req_file):
_check_fixes(
["flask==0.5", "requests==2.0\nflask==0.5"],
Expand Down

0 comments on commit 35573a4

Please sign in to comment.