diff --git a/piplicenses.py b/piplicenses.py index 2ff86a9..6de8afb 100644 --- a/piplicenses.py +++ b/piplicenses.py @@ -32,6 +32,7 @@ import os import sys from collections import Counter +from enum import Enum, auto from email import message_from_string from email.parser import FeedParser from functools import partial @@ -411,26 +412,26 @@ def factory_styled_table_with_args(args, output_fields=DEFAULT_OUTPUT_FIELDS): table = PrettyTable() table.field_names = output_fields table.align = 'l' - table.border = (args.format == 'markdown' or args.format == 'rst' or - args.format == 'confluence' or args.format == 'json') + table.border = args.format in (FormatArg.MARKDOWN, FormatArg.RST, + FormatArg.CONFLUENCE, FormatArg.JSON) table.header = True - if args.format == 'markdown': + if args.format == FormatArg.MARKDOWN: table.junction_char = '|' table.hrules = RULE_HEADER - elif args.format == 'rst': + elif args.format == FormatArg.RST: table.junction_char = '+' table.hrules = RULE_ALL - elif args.format == 'confluence': + elif args.format == FormatArg.CONFLUENCE: table.junction_char = '|' table.hrules = RULE_NONE - elif args.format == 'json': + elif args.format == FormatArg.JSON: table = JsonPrettyTable(table.field_names) - elif args.format == 'json-license-finder': + elif args.format == FormatArg.JSON_LICENSE_FINDER: table = JsonLicenseFinderTable(table.field_names) - elif args.format == 'csv': + elif args.format == FormatArg.CSV: table = CSVPrettyTable(table.field_names) - elif args.format == 'plain-vertical': + elif args.format == FormatArg.PLAIN_VERTICAL: table = PlainVerticalTable(table.field_names) return table @@ -451,8 +452,8 @@ def find_license_from_classifier(message): def select_license_by_source(from_source, license_classifier, license_meta): license_classifier_str = ', '.join(license_classifier) or LICENSE_UNKNOWN - if (from_source == 'classifier' or - from_source == 'mixed' and len(license_classifier) > 0): + if (from_source == FromArg.CLASSIFIER or + from_source == FromArg.MIXED and len(license_classifier) > 0): return license_classifier_str else: return license_meta @@ -464,7 +465,7 @@ def get_output_fields(args): output_fields = list(DEFAULT_OUTPUT_FIELDS) - if getattr(args, 'from') == 'all': + if getattr(args, 'from') == FromArg.ALL: output_fields.append('License-Metadata') output_fields.append('License-Classifier') else: @@ -494,15 +495,15 @@ def get_output_fields(args): def get_sortby(args): - if args.summary and args.order == 'count': + if args.summary and args.order == OrderArg.COUNT: return 'Count' - elif args.summary or args.order == 'license': + elif args.summary or args.order == OrderArg.LICENSE: return 'License' - elif args.order == 'name': + elif args.order == OrderArg.NAME: return 'Name' - elif args.order == 'author' and args.with_authors: + elif args.order == OrderArg.AUTHOR and args.with_authors: return 'Author' - elif args.order == 'url' and args.with_urls: + elif args.order == OrderArg.URL and args.with_urls: return 'URL' return 'Name' @@ -518,7 +519,7 @@ def create_output_string(args): sortby = get_sortby(args) - if args.format == 'html': + if args.format == FormatArg.HTML: return table.get_html_string(fields=output_fields, sortby=sortby) else: return table.get_string(fields=output_fields, sortby=sortby) @@ -528,7 +529,7 @@ def create_warn_string(args): warn_messages = [] warn = partial(output_colored, '33') - if args.with_license_file and not args.format == 'json': + if args.with_license_file and not args.format == FormatArg.JSON: message = warn(('Due to the length of these fields, this option is ' 'best paired with --format=json.')) warn_messages.append(message) @@ -579,7 +580,6 @@ class CompatibleArgumentParser(argparse.ArgumentParser): def parse_args(self, args=None, namespace=None): args = super(CompatibleArgumentParser, self).parse_args(args, namespace) - self._compatible_format_args(args) self._check_code_page(args.filter_code_page) return args @@ -595,60 +595,65 @@ def _check_code_page(code_page): "codecs.html for valid code pages") % code_page) sys.exit(1) - @staticmethod - def _compatible_format_args(args): - from_input = getattr(args, 'from').lower() - order_input = args.order.lower() - format_input = args.format.lower() - - # XXX: Use enum when drop support Python 2.7 - if from_input in ('meta', 'm'): - setattr(args, 'from', 'meta') - if from_input in ('classifier', 'c'): - setattr(args, 'from', 'classifier') +class NoValueEnum(Enum): + def __repr__(self): # pragma: no cover + return '<%s.%s>' % (self.__class__.__name__, self.name) - if from_input in ('mixed', 'mix'): - setattr(args, 'from', 'mixed') - if order_input in ('count', 'c'): - args.order = 'count' +class FromArg(NoValueEnum): + META = M = auto() + CLASSIFIER = C = auto() + MIXED = MIX = auto() + ALL = auto() - if order_input in ('license', 'l'): - args.order = 'license' - if order_input in ('name', 'n'): - args.order = 'name' +class OrderArg(NoValueEnum): + COUNT = C = auto() + LICENSE = L = auto() + NAME = N = auto() + AUTHOR = A = auto() + URL = U = auto() - if order_input in ('author', 'a'): - args.order = 'author' - if order_input in ('url', 'u'): - args.order = 'url' +class FormatArg(NoValueEnum): + PLAIN = P = auto() + PLAIN_VERTICAL = auto() + MARKDOWN = MD = M = auto() + RST = REST = R = auto() + CONFLUENCE = C = auto() + HTML = H = auto() + JSON = J = auto() + JSON_LICENSE_FINDER = JLF = auto() + CSV = auto() - if format_input in ('plain', 'p'): - args.format = 'plain' - if format_input in ('markdown', 'md', 'm'): - args.format = 'markdown' +def format_arg_option(value: str) -> str: + return value.replace('-', '_').upper() - if format_input in ('rst', 'rest', 'r'): - args.format = 'rst' - if format_input in ('confluence', 'c'): - args.format = 'confluence' +def choices_from_enum(enum_cls: NoValueEnum) -> List[str]: + return [key.replace('_', '-').lower() + for key in enum_cls.__members__.keys()] - if format_input in ('html', 'h'): - args.format = 'html' - if format_input in ('json', 'j'): - args.format = 'json' +map_dest_to_enum = { + 'from': FromArg, + 'order': OrderArg, + 'format': FormatArg, +} - if format_input in ('json-license-finder', 'jlf'): - args.format = 'json-license-finder' - if format_input in ('csv', ): - args.format = 'csv' +class SelectAction(argparse.Action): + def __call__( + self, parser: argparse.ArgumentParser, + namespace: argparse.Namespace, + values: Text, + option_string: Optional[Text] = None, + ) -> None: + enum_cls = map_dest_to_enum[self.dest] + values = format_arg_option(values) + setattr(namespace, self.dest, getattr(enum_cls, values)) def create_parser(): @@ -667,22 +672,25 @@ def create_parser(): common_options.add_argument( '--from', - action='store', type=str, - default='mixed', metavar='SOURCE', + action=SelectAction, type=str, + default=FromArg.MIXED, metavar='SOURCE', + choices=choices_from_enum(FromArg), help='R|where to find license information\n' '"meta", "classifier, "mixed", "all"\n' '(default: %(default)s)') common_options.add_argument( '-o', '--order', - action='store', type=str, - default='name', metavar='COL', + action=SelectAction, type=str, + default=OrderArg.NAME, metavar='COL', + choices=choices_from_enum(OrderArg), help='R|order by column\n' '"name", "license", "author", "url"\n' '(default: %(default)s)') common_options.add_argument( '-f', '--format', - action='store', type=str, - default='plain', metavar='STYLE', + action=SelectAction, type=str, + default=FormatArg.PLAIN, metavar='STYLE', + choices=choices_from_enum(FormatArg), help='R|dump as set format style\n' '"plain", "plain-vertical" "markdown", "rst", \n' '"confluence", "html", "json", \n' diff --git a/test_piplicenses.py b/test_piplicenses.py index b1dd260..ec8d17f 100644 --- a/test_piplicenses.py +++ b/test_piplicenses.py @@ -11,7 +11,7 @@ import docutils.frontend import piplicenses -from piplicenses import (__pkgname__, create_parser, output_colored, +from piplicenses import (FromArg, __pkgname__, create_parser, output_colored, create_licenses_table, get_output_fields, get_sortby, factory_styled_table_with_args, create_warn_string, find_license_from_classifier, create_output_string, @@ -203,22 +203,22 @@ def test_not_found_license_from_classifier(self): def test_select_license_by_source(self): self.assertEqual('MIT License', - select_license_by_source('classifier', + select_license_by_source(FromArg.CLASSIFIER, ['MIT License'], 'MIT')) self.assertEqual(LICENSE_UNKNOWN, - select_license_by_source('classifier', + select_license_by_source(FromArg.CLASSIFIER, [], 'MIT')) self.assertEqual('MIT License', - select_license_by_source('mixed', + select_license_by_source(FromArg.MIXED, ['MIT License'], 'MIT')) self.assertEqual('MIT', - select_license_by_source('mixed', + select_license_by_source(FromArg.MIXED, [], 'MIT'))