Skip to content

Commit

Permalink
Add type hints for args
Browse files Browse the repository at this point in the history
  • Loading branch information
cdce8p committed Jan 19, 2021
1 parent ebc95e2 commit 53ebae0
Showing 1 changed file with 34 additions and 11 deletions.
45 changes: 34 additions & 11 deletions piplicenses.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@
from email.parser import FeedParser
from enum import Enum, auto
from functools import partial
from typing import List, Optional, Text
from typing import List, Optional, Sequence, Text

try:
from pip._internal.utils.misc import get_installed_distributions
Expand Down Expand Up @@ -129,7 +129,7 @@
LICENSE_UNKNOWN = 'UNKNOWN'


def get_packages(args):
def get_packages(args: "CustomNamespace"):

def get_pkg_included_file(pkg, file_names):
"""
Expand Down Expand Up @@ -259,7 +259,8 @@ def get_pkg_info(pkg):
yield pkg_info


def create_licenses_table(args, output_fields=DEFAULT_OUTPUT_FIELDS):
def create_licenses_table(
args: "CustomNamespace", output_fields=DEFAULT_OUTPUT_FIELDS):
table = factory_styled_table_with_args(args, output_fields)
from_source = getattr(args, 'from')

Expand All @@ -282,7 +283,7 @@ def create_licenses_table(args, output_fields=DEFAULT_OUTPUT_FIELDS):
return table


def create_summary_table(args):
def create_summary_table(args: "CustomNamespace"):
counts = Counter(pkg['license'] for pkg in get_packages(args))

table = factory_styled_table_with_args(args, SUMMARY_FIELD_NAMES)
Expand Down Expand Up @@ -406,7 +407,8 @@ def get_string(self, **kwargs):
return output


def factory_styled_table_with_args(args, output_fields=DEFAULT_OUTPUT_FIELDS):
def factory_styled_table_with_args(
args: "CustomNamespace", output_fields=DEFAULT_OUTPUT_FIELDS):
table = PrettyTable()
table.field_names = output_fields
table.align = 'l'
Expand Down Expand Up @@ -457,7 +459,7 @@ def select_license_by_source(from_source, license_classifier, license_meta):
return license_meta


def get_output_fields(args):
def get_output_fields(args: "CustomNamespace"):
if args.summary:
return list(SUMMARY_OUTPUT_FIELDS)

Expand Down Expand Up @@ -492,7 +494,7 @@ def get_output_fields(args):
return output_fields


def get_sortby(args):
def get_sortby(args: "CustomNamespace"):
if args.summary and args.order == OrderArg.COUNT:
return 'Count'
elif args.summary or args.order == OrderArg.LICENSE:
Expand All @@ -507,7 +509,7 @@ def get_sortby(args):
return 'Name'


def create_output_string(args):
def create_output_string(args: "CustomNamespace"):
output_fields = get_output_fields(args)

if args.summary:
Expand All @@ -523,7 +525,7 @@ def create_output_string(args):
return table.get_string(fields=output_fields, sortby=sortby)


def create_warn_string(args):
def create_warn_string(args: "CustomNamespace"):
warn_messages = []
warn = partial(output_colored, '33')

Expand Down Expand Up @@ -580,13 +582,34 @@ def _split_lines(self, text: Text, width: int) -> List[str]:
return super()._split_lines(text, width)


class CustomNamespace(argparse.Namespace):
_from: "FromArg"
order: "OrderArg"
format: "FormatArg"
summary: bool
output_file: str
ignore_packages: List[str]
with_system: bool
with_authors: bool
with_urls: bool
with_description: bool
with_license_file: bool
no_license_path: bool
with_notice_file: bool
filter_strings: bool
filter_code_page: str
fail_on: str
allow_only: str


class CompatibleArgumentParser(argparse.ArgumentParser):
def parse_args(self, args=None, namespace=None):
def parse_args(self, args: Optional[Sequence[Text]] = None,
namespace: CustomNamespace = None) -> CustomNamespace:
args = super().parse_args(args, namespace)
self._verify_args(args)
return args

def _verify_args(self, args):
def _verify_args(self, args: CustomNamespace):
if args.with_license_file is False and (
args.no_license_path is True or
args.with_notice_file is True):
Expand Down

0 comments on commit 53ebae0

Please sign in to comment.