From 5343b4f995405639a15d5e8c24466431e30ac542 Mon Sep 17 00:00:00 2001 From: Yeison Vargas Date: Wed, 22 Jun 2022 19:12:21 -0500 Subject: [PATCH 1/2] Removing the click context use --- README.md | 2 +- safety/cli.py | 47 +++++++++++++---- safety/output_utils.py | 48 ++++++++--------- safety/safety.py | 39 +++++++++----- safety/util.py | 104 ++++++++++++++++++++++++++++--------- tests/test_cli.py | 35 ++++++++++++- tests/test_output_utils.py | 16 +++--- tests/test_safety.py | 5 +- 8 files changed, 210 insertions(+), 86 deletions(-) diff --git a/README.md b/README.md index 54a74f2c..8a83c316 100644 --- a/README.md +++ b/README.md @@ -4,7 +4,7 @@ [![Travis](https://img.shields.io/travis/pyupio/safety.svg)](https://travis-ci.org/pyupio/safety) [![Updates](https://pyup.io/repos/github/pyupio/safety/shield.svg)](https://pyup.io/repos/github/pyupio/safety/) -Safety checks your installed Python dependencies for known security vulnerabilities and suggests the proper remediations for vulnerabilities detected. Safety can be run on developer machines, in CI/CD pipelines and on production systems. +Safety checks Python dependencies for known security vulnerabilities and suggests the proper remediations for vulnerabilities detected. Safety can be run on developer machines, in CI/CD pipelines and on production systems. By default it uses the open Python vulnerability database [Safety DB](https://github.com/pyupio/safety-db), which is **licensed for non-commercial use only**. diff --git a/safety/cli.py b/safety/cli.py index 74257bae..27c59915 100644 --- a/safety/cli.py +++ b/safety/cli.py @@ -20,12 +20,18 @@ LOG = logging.getLogger(__name__) + @click.group() @click.option('--debug/--no-debug', default=False) @click.option('--telemetry/--disable-telemetry', default=True) @click.version_option(version=get_safety_version()) @click.pass_context def cli(ctx, debug, telemetry): + """ + Safety checks Python dependencies for known security vulnerabilities and suggests the proper + remediations for vulnerabilities detected. Safety can be run on developer machines, in CI/CD pipelines and + on production systems. + """ ctx.telemetry = telemetry level = logging.CRITICAL if debug: @@ -47,7 +53,8 @@ def cli(ctx, debug, telemetry): with_values={"output": ['json', 'bare'], "json": [True, False], "bare": [True, False]}, help='Full reports include a security advisory (if available). Default: --short-report') @click.option("--cache", is_flag=False, flag_value=60, default=0, - help="Cache requests to the vulnerability database locally. Default: 0 seconds") + help="Cache requests to the vulnerability database locally. Default: 0 seconds", + hidden=True) @click.option("--stdin/--no-stdin", default=False, cls=MutuallyExclusiveOption, mutually_exclusive=["files"], help="Read input from stdin. Default: --no-stdin") @click.option("files", "--file", "-r", multiple=True, type=click.File(), cls=MutuallyExclusiveOption, mutually_exclusive=["stdin"], @@ -55,9 +62,11 @@ def cli(ctx, debug, telemetry): @click.option("--ignore", "-i", multiple=True, type=str, default=[], callback=transform_ignore, help="Ignore one (or multiple) vulnerabilities by ID. Default: empty") @click.option('--json/--no-json', default=False, cls=MutuallyExclusiveOption, mutually_exclusive=["output", "bare"], - with_values={"output": ['screen', 'text', 'bare', 'json'], "bare": [True, False]}, callback=json_alias) + with_values={"output": ['screen', 'text', 'bare', 'json'], "bare": [True, False]}, callback=json_alias, + hidden=True) @click.option('--bare/--not-bare', default=False, cls=MutuallyExclusiveOption, mutually_exclusive=["output", "json"], - with_values={"output": ['screen', 'text', 'bare', 'json'], "json": [True, False]}, callback=bare_alias) + with_values={"output": ['screen', 'text', 'bare', 'json'], "json": [True, False]}, callback=bare_alias, + hidden=True) @click.option('--output', "-o", type=click.Choice(['screen', 'text', 'json', 'bare'], case_sensitive=False), default='screen', callback=active_color_if_needed, envvar='SAFETY_OUTPUT') @click.option("--proxy-protocol", "-pr", type=click.Choice(['http', 'https']), default='https', cls=DependentOption, required_options=['proxy_host'], @@ -70,15 +79,19 @@ def cli(ctx, debug, telemetry): help="Output standard exit codes. Default: --exit-code") @click.option("--policy-file", type=SafetyPolicyFile(), default='.safety-policy.yml', help="Define the policy file to be used") -@click.option("--save-json", default="", help="Path to where output file will be placed. Default: empty") +@click.option("--save-json", default="", help="Path to where output file will be placed, if the path is a directory, " + "Safety will use safety-report.json as filename. Default: empty") @click.pass_context def check(ctx, key, db, full_report, stdin, files, cache, ignore, output, json, bare, proxy_protocol, proxy_host, proxy_port, exit_code, policy_file, save_json): + """ + Find vulnerabilities in Python dependencies at the target provided. + + """ LOG.info('Running check command') try: packages = get_packages(files, stdin) - ctx.obj = packages proxy_dictionary = get_proxy_dict(proxy_protocol, proxy_host, proxy_port) announcements = [] @@ -89,14 +102,15 @@ def check(ctx, key, db, full_report, stdin, files, cache, ignore, output, json, ignore_severity_rules = None ignore, ignore_severity_rules, exit_code = get_processed_options(policy_file, ignore, ignore_severity_rules, exit_code) - ctx.continue_on_error = not exit_code - ctx.ignore_severity_rules = ignore_severity_rules is_env_scan = not stdin and not files + params = {'stdin': stdin, 'files': files, 'policy_file': policy_file, 'continue_on_error': not exit_code, + 'ignore_severity_rules': ignore_severity_rules} LOG.info('Calling the check function') vulns, db_full = safety.check(packages=packages, key=key, db_mirror=db, cached=cache, ignore_vulns=ignore, ignore_severity_rules=ignore_severity_rules, proxy=proxy_dictionary, - include_ignored=True, is_env_scan=is_env_scan, telemetry=ctx.parent.telemetry) + include_ignored=True, is_env_scan=is_env_scan, telemetry=ctx.parent.telemetry, + params=params) LOG.debug('Vulnerabilities returned: %s', vulns) LOG.debug('full database returned is None: %s', db_full is None) @@ -117,11 +131,15 @@ def check(ctx, key, db, full_report, stdin, files, cache, ignore, output, json, LOG.info('All vulnerabilities found (ignored and Not ignored): %s', len(vulns)) if save_json: + default_name = 'safety-report.json' json_report = output_report if output != 'json': json_report = SafetyFormatter(output='json').render_vulnerabilities(announcements, vulns, remediations, full_report, packages) + if os.path.isdir(save_json): + save_json = os.path.join(save_json, default_name) + with open(save_json, 'w+') as output_json_file: output_json_file.write(json_report) @@ -152,6 +170,9 @@ def check(ctx, key, db, full_report, stdin, files, cache, ignore, output, json, help="Read input from an insecure report file. Default: empty") @click.pass_context def review(ctx, full_report, output, file): + """ + Show an output from a previous exported JSON report. + """ LOG.info('Running check command') announcements = safety.get_announcements(key=None, proxy=None, telemetry=ctx.parent.telemetry) report = {} @@ -166,7 +187,8 @@ def review(ctx, full_report, output, file): exception = e if isinstance(e, SafetyException) else SafetyException(info=e) output_exception(exception, exit_code_output=True) - vulns, remediations, packages = safety.review(report) + params = {'file': file} + vulns, remediations, packages = safety.review(report, params=params) output_report = SafetyFormatter(output=output).render_vulnerabilities(announcements, vulns, remediations, full_report, packages) @@ -197,6 +219,9 @@ def review(ctx, full_report, output, file): help="Proxy protocol (https or http) --proxy-protocol") @click.pass_context def license(ctx, key, db, output, cache, files, proxyprotocol, proxyhost, proxyport): + """ + Find the open source licenses used by your Python dependencies. + """ LOG.info('Running license command') packages = get_packages(files, False) ctx.obj = packages @@ -230,7 +255,7 @@ def license(ctx, key, db, output, cache, files, proxyprotocol, proxyhost, proxyp @click.argument('name') @click.pass_context def generate(ctx, name, path): - """create a basic supported file type. + """Create a boilerplate supported file type. NAME is the name of the file type to generate. Valid values are: policy_file """ @@ -270,7 +295,7 @@ def generate(ctx, name, path): @click.argument('name') @click.pass_context def validate(ctx, name, path): - """verify a supported file type. + """Verify the validity of a supported file type. NAME is the name of the file type to validate. Valid values are: policy_file """ diff --git a/safety/output_utils.py b/safety/output_utils.py index 3ef07dc0..827381c5 100644 --- a/safety/output_utils.py +++ b/safety/output_utils.py @@ -6,7 +6,7 @@ import click from safety.constants import RED, YELLOW -from safety.util import get_safety_version, Package, get_terminal_size +from safety.util import get_safety_version, Package, get_terminal_size, SafetyContext LOG = logging.getLogger(__name__) @@ -136,7 +136,7 @@ def format_vulnerability(vulnerability, full_mode, only_text=False, columns=get_ {'value': vulnerability.advisory.replace('\n', '')}]} ] - if click.get_current_context().params.get('key', False): + if SafetyContext().key: fixed_version_line = {'words': [ {'style': {'bold': True}, 'value': 'Fixed versions: '}, {'value': ', '.join(vulnerability.fixed_versions) if vulnerability.fixed_versions else 'No known fix'} @@ -357,12 +357,13 @@ def format_long_text(text, color='', columns=get_terminal_size().columns, start_ def get_printable_list_of_scanned_items(scanning_target): - context = click.get_current_context() + context = SafetyContext() + result = [] scanned_items_data = [] if scanning_target == 'environment': - locations = set([pkg.found for pkg in context.obj if isinstance(pkg, Package)]) + locations = set([pkg.found for pkg in context.packages if isinstance(pkg, Package)]) for path in locations: result.append([{'styled': False, 'value': '-> ' + path}]) @@ -374,7 +375,7 @@ def get_printable_list_of_scanned_items(scanning_target): scanned_items_data.append(msg) elif scanning_target == 'stdin': - scanned_stdin = [pkg.name for pkg in context.obj if isinstance(pkg, Package)] + scanned_stdin = [pkg.name for pkg in context.packages if isinstance(pkg, Package)] value = 'No found packages in stdin' scanned_items_data = [value] @@ -435,8 +436,9 @@ def build_report_brief_section(columns=None, primary_announcement=None, report_t def build_report_for_review_vuln_report(as_dict=False): - report_from_file = click.get_current_context().review - packages = click.get_current_context().obj + ctx = SafetyContext() + report_from_file = ctx.review + packages = ctx.packages if as_dict: return report_from_file @@ -509,15 +511,15 @@ def build_scanned_count_sentence(packages): def add_warnings_if_needed(brief_info): - ctx = click.get_current_context() + ctx = SafetyContext() warnings = [] - if ctx.obj: - if hasattr(ctx, 'continue_on_error') and ctx.continue_on_error: + if ctx.packages: + if ctx.params.get('continue_on_error', False): warnings += [[{'style': True, 'value': '* Continue-on-error is enabled, so returning successful (0) exit code in all cases.'}]] - if hasattr(ctx, 'ignore_severity_rules') and ctx.ignore_severity_rules and not is_using_api_key(): + if ctx.params.get('ignore_severity_rules', False) and not is_using_api_key(): warnings += [[{'style': True, 'value': '* Could not filter by severity, please upgrade your account to include severity data.'}]] @@ -528,18 +530,18 @@ def add_warnings_if_needed(brief_info): def get_report_brief_info(as_dict=False, report_type=1, **kwargs): LOG.info('get_report_brief_info: %s, %s, %s', as_dict, report_type, kwargs) - context = click.get_current_context() + context = SafetyContext() - packages = [pkg for pkg in context.obj if isinstance(pkg, Package)] + packages = [pkg for pkg in context.packages if isinstance(pkg, Package)] brief_data = {} - command = context.command.name + command = context.command if command == 'review': review = build_report_for_review_vuln_report(as_dict) return review - key = context.params.get('key', False) - db = context.params.get('db', False) + key = context.key + db = context.db_mirror scanning_types = {'check': {'name': 'Vulnerabilities', 'action': 'Scanning dependencies', 'scanning_target': 'environment'}, # Files, Env or Stdin 'license': {'name': 'Licenses', 'action': 'Scanning licenses', 'scanning_target': 'environment'}, # Files or Env @@ -552,14 +554,14 @@ def get_report_brief_info(as_dict=False, report_type=1, **kwargs): scanning_types[command]['scanning_target'] = target break - scanning_target = scanning_types.get(context.command.name, {}).get('scanning_target', '') + scanning_target = scanning_types.get(context.command, {}).get('scanning_target', '') brief_data['scan_target'] = scanning_target scanned_items, data = get_printable_list_of_scanned_items(scanning_target) brief_data['scanned'] = data nl = [{'style': False, 'value': ''}] action_executed = [ - {'style': True, 'value': scanning_types.get(context.command.name, {}).get('action', '')}, + {'style': True, 'value': scanning_types.get(context.command, {}).get('action', '')}, {'style': False, 'value': ' in your '}, {'style': True, 'value': scanning_target + ':'}, ] @@ -618,7 +620,7 @@ def get_report_brief_info(as_dict=False, report_type=1, **kwargs): brief_info = [[{'style': False, 'value': 'Safety '}, {'style': True, 'value': 'v' + get_safety_version()}, {'style': False, 'value': ' is scanning for '}, - {'style': True, 'value': scanning_types.get(context.command.name, {}).get('name', '')}, + {'style': True, 'value': scanning_types.get(context.command, {}).get('name', '')}, {'style': True, 'value': '...'}] + safety_policy_used, action_executed ] + [nl] + scanned_items + [nl] + [using_sentence] + [scanned_count_sentence] + [timestamp] @@ -650,15 +652,11 @@ def build_primary_announcement(primary_announcement, columns=None, only_text=Fal def is_using_api_key(): - context = click.get_current_context() - review_used_api_key = context.review.get('api_key', False) if hasattr(context, - 'review') and context.review else False - return bool(context.params.get('key', None)) or review_used_api_key + return bool(SafetyContext().key) def is_using_a_safety_policy_file(): - context = click.get_current_context() - return bool(context.params.get('policy_file', None)) + return bool(SafetyContext().params.get('policy_file', None)) def should_add_nl(output, found_vulns): diff --git a/safety/safety.py b/safety/safety.py index 0a1e980d..628e6192 100644 --- a/safety/safety.py +++ b/safety/safety.py @@ -8,7 +8,6 @@ import time from datetime import datetime -import click import requests from packaging.specifiers import SpecifierSet from packaging.utils import canonicalize_name @@ -20,7 +19,8 @@ InvalidKeyError, TooManyRequestsError, NetworkConnectionError, RequestTimeoutError, ServerError, MalformedDatabase) from .models import Vulnerability, CVE, Severity -from .util import RequirementFile, read_requirements, Package, build_telemetry_data +from .util import RequirementFile, read_requirements, Package, build_telemetry_data, sync_safety_context, SafetyContext, \ + validate_expiration_date session = requests.session() @@ -101,6 +101,9 @@ def fetch_database_url(mirror, db_name, key, cached, proxy, telemetry=True): if key: headers["X-Api-Key"] = key + if not proxy: + proxy = {} + if cached: cached_data = get_from_cache(db_name=db_name, cache_valid_seconds=cached) if cached_data: @@ -148,7 +151,7 @@ def fetch_database_file(path, db_name): return json.loads(f.read()) -def fetch_database(full=False, key=False, db=False, cached=False, proxy={}, telemetry=True): +def fetch_database(full=False, key=False, db=False, cached=0, proxy=None, telemetry=True): if db: mirrors = [db] else: @@ -224,7 +227,7 @@ def get_cve_from(data, db_full): def ignore_vuln_if_needed(vuln_id, cve, ignore_vulns, ignore_severity_rules): - if not ignore_severity_rules: + if not ignore_severity_rules or not isinstance(ignore_vulns, dict): return severity = None @@ -248,9 +251,10 @@ def ignore_vuln_if_needed(vuln_id, cve, ignore_vulns, ignore_severity_rules): ignore_vulns[vuln_id] = {'reason': reason, 'expires': None} -def check(packages, key, db_mirror, cached, ignore_vulns, ignore_severity_rules, proxy, include_ignored=False, - is_env_scan=True, telemetry=True): - key = key if key else os.environ.get("SAFETY_API_KEY", False) +@sync_safety_context +def check(packages, key=False, db_mirror=False, cached=0, ignore_vulns=None, ignore_severity_rules=None, proxy=None, + include_ignored=False, is_env_scan=True, telemetry=True, params=None): + SafetyContext().command = 'check' db = fetch_database(key=key, db=db_mirror, cached=cached, proxy=proxy, telemetry=telemetry) db_full = None vulnerable_packages = frozenset(db.keys()) @@ -376,7 +380,9 @@ def calculate_remediations(vulns, db_full): return remediations -def review(report): +@sync_safety_context +def review(report=None, params=None): + SafetyContext().command = 'review' vulnerable = [] vulnerabilities = report.get('vulnerabilities', []) + report.get('ignored_vulnerabilities', []) remediations = {} @@ -398,8 +404,11 @@ def review(report): packages = report.get('scanned_packages', []) pkgs = {pkg_name: Package(**pkg_values) for pkg_name, pkg_values in packages.items()} - click.get_current_context().obj = pkgs.values() - click.get_current_context().review = report.get('report_meta', []) + ctx = SafetyContext() + found_packages = list(pkgs.values()) + ctx.packages = found_packages + ctx.review = report.get('report_meta', []) + ctx.key = ctx.review.get('api_key', False) cvssv2 = None cvssv3 = None @@ -417,14 +426,20 @@ def review(report): else: vuln['severity'] = None + ignored_expires = vuln.get('ignored_expires', None) + + if ignored_expires: + vuln['ignored_expires'] = validate_expiration_date(ignored_expires) + vuln['CVE'] = CVE(name=XVE_ID, cvssv2=cvssv2, cvssv3=cvssv3) if XVE_ID else None vulnerable.append(Vulnerability(**vuln)) - return vulnerable, remediations, pkgs + return vulnerable, remediations, found_packages -def get_licenses(key, db_mirror, cached, proxy, telemetry=True): +@sync_safety_context +def get_licenses(key=False, db_mirror=False, cached=0, proxy=None, telemetry=True): key = key if key else os.environ.get("SAFETY_API_KEY", False) if not key and not db_mirror: diff --git a/safety/util.py b/safety/util.py index b6afdc00..20a3c073 100644 --- a/safety/util.py +++ b/safety/util.py @@ -4,6 +4,7 @@ import sys from datetime import datetime from difflib import SequenceMatcher +from threading import Lock from typing import List import click @@ -178,12 +179,13 @@ def get_packages_licenses(packages, licenses_db): def get_flags_from_context(): flags = {} - context = click.get_current_context() + context = click.get_current_context(silent=True) - for option in context.command.params: - flags_per_opt = option.opts + option.secondary_opts - for flag in flags_per_opt: - flags[flag] = option.name + if context: + for option in context.command.params: + flags_per_opt = option.opts + option.secondary_opts + for flag in flags_per_opt: + flags[flag] = option.name return flags @@ -231,12 +233,14 @@ def get_basic_announcements(announcements): def build_telemetry_data(telemetry=True): + context = SafetyContext() + body = { 'os_type': os.environ.get("SAFETY_OS_TYPE", None) or platform.system(), 'os_release': os.environ.get("SAFETY_OS_RELEASE", None) or platform.release(), 'os_description': os.environ.get("SAFETY_OS_DESCRIPTION", None) or platform.platform(), 'python_version': platform.python_version(), - 'safety_command': click.get_current_context().command.name, + 'safety_command': context.command, 'safety_options': get_used_options() } if telemetry else {} @@ -390,6 +394,23 @@ def get_terminal_size(): return os.terminal_size((columns, lines)) +def validate_expiration_date(expiration_date): + d = None + + if expiration_date: + try: + d = datetime.strptime(expiration_date, '%Y-%m-%d') + except ValueError as e: + pass + + try: + d = datetime.strptime(expiration_date, '%Y-%m-%d %H:%M:%S') + except ValueError as e: + pass + + return d + + class SafetyPolicyFile(click.ParamType): """ Custom Safety Policy file to hold validations @@ -528,24 +549,14 @@ def convert(self, value, param, ctx): ) # Validate expires - d = None - if expires: - try: - d = datetime.strptime(expires, '%Y-%m-%d') - except ValueError as e: - pass - - try: - d = datetime.strptime(expires, '%Y-%m-%d %H:%M:%S') - except ValueError as e: - pass - - if not d: - self.fail(msg.format(hint=f"{context_msg}expires: \"{expires}\" isn't a valid format " - f"for the expires keyword, " - "valid options are: YYYY-MM-DD or " - "YYYY-MM-DD HH:MM:SS") - ) + d = validate_expiration_date(expires) + + if expires and not d: + self.fail(msg.format(hint=f"{context_msg}expires: \"{expires}\" isn't a valid format " + f"for the expires keyword, " + "valid options are: YYYY-MM-DD or " + "YYYY-MM-DD HH:MM:SS") + ) normalized[str(ignored_vuln_id)] = {'reason': reason, 'expires': d} @@ -583,3 +594,48 @@ def shell_complete( from click.shell_completion import CompletionItem return [CompletionItem(incomplete, type="file")] + + +class SingletonMeta(type): + + _instances = {} + + _lock = Lock() + + def __call__(cls, *args, **kwargs): + with cls._lock: + if cls not in cls._instances: + instance = super().__call__(*args, **kwargs) + cls._instances[cls] = instance + return cls._instances[cls] + + +class SafetyContext(metaclass=SingletonMeta): + packages = None + key = False + db_mirror = False + cached = None + ignore_vulns = None + ignore_severity_rules = None + proxy = None + include_ignored = False + telemetry = None + files = None + stdin = None + is_env_scan = None + command = None + review = None + params = {} + + +def sync_safety_context(f): + def new_func(*args, **kwargs): + ctx = SafetyContext() + + for attr in dir(ctx): + if attr in kwargs: + setattr(ctx, attr, kwargs.get(attr)) + + return f(*args, **kwargs) + + return new_func diff --git a/tests/test_cli.py b/tests/test_cli.py index 62a4303a..64cc8c40 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -2,6 +2,7 @@ import os import tempfile import unittest +from datetime import datetime from unittest.mock import patch, Mock import click @@ -9,7 +10,7 @@ from safety import cli from safety.models import Vulnerability, CVE, Severity -from safety.util import Package +from safety.util import Package, SafetyContext def get_vulnerability(vuln_kwargs=None, cve_kwargs=None, pkg_kwargs=None): @@ -136,6 +137,38 @@ def test_review_pass(self, mocked_announcements): self.assertEqual(result.exit_code, 0) self.assertEqual(result.output, u'insecure-package\n') + @patch("safety.util.SafetyContext") + @patch("safety.safety.check") + @patch("safety.cli.get_packages") + def test_chained_review_pass(self, get_packages, check_func, ctx): + expires = datetime.strptime('2022-10-21', '%Y-%m-%d') + vulns = [get_vulnerability(), get_vulnerability(vuln_kwargs={'vulnerability_id': '25853', 'ignored': True, + 'ignored_reason': 'A basic reason', + 'ignored_expires': expires})] + packages = [pkg for pkg in {vuln.pkg.name: vuln.pkg for vuln in vulns}.values()] + get_packages.return_value = packages + provided_context = SafetyContext() + provided_context.command = 'check' + provided_context.packages = packages + ctx.return_value = provided_context + check_func.return_value = vulns, None + + with tempfile.TemporaryDirectory() as tempdir: + for output in self.output_options: + path_to_report = os.path.join(tempdir, f'report_{output}.json') + + pre_result = self.runner.invoke(cli.cli, [ + 'check', '--key', 'foo', '-o', output, + '--save-json', path_to_report]) + + self.assertEqual(pre_result.exit_code, 64) + + for output in self.output_options: + filename = f'report_{output}.json' + path_to_report = os.path.join(tempdir, filename) + result = self.runner.invoke(cli.cli, ['review', '--output', output, '--file', path_to_report]) + self.assertEqual(result.exit_code, 0, f'Unable to load the previous saved report: {filename}') + @patch("safety.safety.session") def test_license_with_file(self, requests_session): licenses_db = { diff --git a/tests/test_output_utils.py b/tests/test_output_utils.py index fc292f72..571d49c4 100644 --- a/tests/test_output_utils.py +++ b/tests/test_output_utils.py @@ -17,7 +17,6 @@ class TestOutputUtils(unittest.TestCase): def setUp(self) -> None: self.maxDiff = None - @patch.object(click, 'get_current_context', Mock(params=Mock(key=Mock(return_value='foobar')))) def test_format_vulnerability(self): numpy_pkg = {'name': 'numpy', 'version': '1.22.0', 'secure_versions': ['1.22.3'], 'insecure_versions': ['1.22.2', '1.22.1', '1.22.0', '1.22.0rc3', '1.21.5']} @@ -57,7 +56,6 @@ def test_format_vulnerability(self): EXPECTED = '\n'.join(lines) self.assertEqual(output, EXPECTED) - @patch.object(click, 'get_current_context', Mock(params=Mock(key=Mock(return_value='foobar')))) def test_format_vulnerability_with_ignored_vulnerability(self): numpy_pkg = {'name': 'numpy', 'version': '1.22.0', 'secure_versions': ['1.22.3'], 'insecure_versions': ['1.22.2', '1.22.1', '1.22.0', '1.22.0rc3', '1.21.5']} @@ -117,9 +115,9 @@ def test_format_vulnerability_with_ignored_vulnerability(self): EXPECTED = '\n'.join(lines) self.assertEqual(output, EXPECTED) - @patch("safety.output_utils.click.get_current_context") + @patch("safety.output_utils.SafetyContext") def test_get_printable_list_of_scanned_items_stdin(self, ctx): - ctx.return_value = Mock(obj=[]) + ctx.return_value = Mock(packages=[]) output = get_printable_list_of_scanned_items('stdin') EXPECTED = ( @@ -131,7 +129,7 @@ def test_get_printable_list_of_scanned_items_stdin(self, ctx): p_kwargs = {'name': 'django', 'version': '2.2', 'found': '/site-packages/django', 'insecure_versions': [], 'secure_versions': ['2.2'], 'latest_version_without_known_vulnerabilities': '2.2', 'latest_version': '2.2', 'more_info_url': 'https://pyup.io/package/foo'} - ctx.return_value = Mock(obj=[Package(**p_kwargs)]) + ctx.return_value = Mock(packages=[Package(**p_kwargs)]) output = get_printable_list_of_scanned_items('stdin') EXPECTED = ( @@ -140,9 +138,9 @@ def test_get_printable_list_of_scanned_items_stdin(self, ctx): self.assertTupleEqual(output, EXPECTED) - @patch("safety.output_utils.click.get_current_context") + @patch("safety.output_utils.SafetyContext") def test_get_printable_list_of_scanned_items_environment(self, ctx): - ctx.return_value = Mock(obj=[]) + ctx.return_value = Mock(packages=[]) output = get_printable_list_of_scanned_items('environment') no_locations = 'No locations found in the environment' @@ -153,7 +151,7 @@ def test_get_printable_list_of_scanned_items_environment(self, ctx): self.assertTupleEqual(output, EXPECTED) - @patch("safety.output_utils.click.get_current_context") + @patch("safety.output_utils.SafetyContext") def test_get_printable_list_of_scanned_items_files(self, ctx): dirname = os.path.dirname(__file__) file_a = open(os.path.join(dirname, "reqs_1.txt"), mode='r') @@ -170,7 +168,7 @@ def test_get_printable_list_of_scanned_items_files(self, ctx): self.assertTupleEqual(output, EXPECTED) - @patch("safety.output_utils.click.get_current_context") + @patch("safety.output_utils.SafetyContext") def test_get_printable_list_of_scanned_items_file(self, ctx): # Used by the review command report = open(os.path.join( diff --git a/tests/test_safety.py b/tests/test_safety.py index b206ad44..cecc7891 100644 --- a/tests/test_safety.py +++ b/tests/test_safety.py @@ -28,6 +28,7 @@ from safety.models import CVE from safety.safety import ignore_vuln_if_needed, get_closest_ver, precompute_remediations, compute_sec_ver, \ calculate_remediations, read_vulnerabilities +from safety.util import SafetyContext from tests.resources import VALID_REPORT, VULNS, SCANNED_PACKAGES, REMEDIATIONS from tests.test_cli import get_vulnerability @@ -643,14 +644,12 @@ def test_read_vulnerabilities(self): with open(os.path.join(self.dirname, "test_db", "report.json")) as f: self.assertDictEqual(self.report, read_vulnerabilities(f)) - @patch.object(click, 'get_current_context', Mock()) def test_review_without_recommended_fix(self): vulns, remediations, packages = safety.review(self.report) - self.assertDictEqual(packages, self.report_packages) + self.assertListEqual(packages, list(self.report_packages.values())) self.assertDictEqual(remediations, self.report_remediations) self.assertListEqual(vulns, self.report_vulns) - @patch.object(click, 'get_current_context', Mock()) def test_report_with_recommended_fix(self): REMEDIATIONS_WITH_FIX = {'django': {'version': '4.0.1', 'vulns_found': 4, 'secure_versions': ['2.2.28', '3.2.13', '4.0.4'], 'closest_secure_version': {'major': parse('4.0.4'), From 3a5d425df52c22923bb99a91d578721b11950339 Mon Sep 17 00:00:00 2001 From: Yeison Vargas Date: Thu, 23 Jun 2022 00:44:39 -0500 Subject: [PATCH 2/2] Improving the CLI flags --- safety/cli.py | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/safety/cli.py b/safety/cli.py index 27c59915..b0329c33 100644 --- a/safety/cli.py +++ b/safety/cli.py @@ -55,18 +55,19 @@ def cli(ctx, debug, telemetry): @click.option("--cache", is_flag=False, flag_value=60, default=0, help="Cache requests to the vulnerability database locally. Default: 0 seconds", hidden=True) -@click.option("--stdin/--no-stdin", default=False, cls=MutuallyExclusiveOption, mutually_exclusive=["files"], - help="Read input from stdin. Default: --no-stdin") -@click.option("files", "--file", "-r", multiple=True, type=click.File(), cls=MutuallyExclusiveOption, mutually_exclusive=["stdin"], +@click.option("--stdin", default=False, cls=MutuallyExclusiveOption, mutually_exclusive=["files"], + help="Read input from stdin.", is_flag=True, show_default=True) +@click.option("files", "--file", "-r", multiple=True, type=click.File(), cls=MutuallyExclusiveOption, + mutually_exclusive=["stdin"], help="Read input from one (or multiple) requirement files. Default: empty") @click.option("--ignore", "-i", multiple=True, type=str, default=[], callback=transform_ignore, help="Ignore one (or multiple) vulnerabilities by ID. Default: empty") -@click.option('--json/--no-json', default=False, cls=MutuallyExclusiveOption, mutually_exclusive=["output", "bare"], +@click.option('--json', default=False, cls=MutuallyExclusiveOption, mutually_exclusive=["output", "bare"], with_values={"output": ['screen', 'text', 'bare', 'json'], "bare": [True, False]}, callback=json_alias, - hidden=True) -@click.option('--bare/--not-bare', default=False, cls=MutuallyExclusiveOption, mutually_exclusive=["output", "json"], + hidden=True, is_flag=True, show_default=True) +@click.option('--bare', default=False, cls=MutuallyExclusiveOption, mutually_exclusive=["output", "json"], with_values={"output": ['screen', 'text', 'bare', 'json'], "json": [True, False]}, callback=bare_alias, - hidden=True) + hidden=True, is_flag=True, show_default=True) @click.option('--output', "-o", type=click.Choice(['screen', 'text', 'json', 'bare'], case_sensitive=False), default='screen', callback=active_color_if_needed, envvar='SAFETY_OUTPUT') @click.option("--proxy-protocol", "-pr", type=click.Choice(['http', 'https']), default='https', cls=DependentOption, required_options=['proxy_host'],