diff --git a/changelog.d/1018.change.rst b/changelog.d/1018.change.rst new file mode 100644 index 00000000..cd2f51ec --- /dev/null +++ b/changelog.d/1018.change.rst @@ -0,0 +1 @@ +Add CLI `ignore` option for refiners, providers and subtitle ids. diff --git a/changelog.d/585.change.rst b/changelog.d/585.change.rst new file mode 100644 index 00000000..cd2f51ec --- /dev/null +++ b/changelog.d/585.change.rst @@ -0,0 +1 @@ +Add CLI `ignore` option for refiners, providers and subtitle ids. diff --git a/docs/config.toml b/docs/config.toml index 5be91410..c8816c68 100644 --- a/docs/config.toml +++ b/docs/config.toml @@ -19,7 +19,8 @@ apikey = "xxxxxxxxx" [download] provider = ["addic7ed", "opensubtitlescom", "opensubtitles"] -refiner = ["metadata", "hash", "omdb", "tmdb"] +refiner = ["metadata", "hash", "omdb"] +ignore_refiner = ["tmdb"] language = ["fr", "en", "pt-br"] encoding = "utf-8" min_score = 50 diff --git a/pyproject.toml b/pyproject.toml index 0c347f4e..e814a500 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -102,13 +102,12 @@ tvsubtitles = "subliminal.providers.tvsubtitles:TVsubtitlesProvider" hash = "subliminal.refiners.hash:refine" metadata = "subliminal.refiners.metadata:refine" omdb = "subliminal.refiners.omdb:refine" +tmdb = "subliminal.refiners.tmdb:refine" tvdb = "subliminal.refiners.tvdb:refine" [project.entry-points."babelfish.language_converters"] addic7ed = "subliminal.converters.addic7ed:Addic7edConverter" opensubtitlescom = "subliminal.converters.opensubtitlescom:OpenSubtitlesComConverter" -shooter = "subliminal.converters.shooter:ShooterConverter" -thesubdb = "subliminal.converters.thesubdb:TheSubDBConverter" tvsubtitles = "subliminal.converters.tvsubtitles:TVsubtitlesConverter" [tool.setuptools] diff --git a/subliminal/cli.py b/subliminal/cli.py index 8edc6841..cb832294 100644 --- a/subliminal/cli.py +++ b/subliminal/cli.py @@ -39,7 +39,9 @@ scan_videos, ) from subliminal.core import ARCHIVE_EXTENSIONS, scan_name, search_external_subtitles +from subliminal.extensions import default_providers, default_refiners from subliminal.score import match_hearing_impaired +from subliminal.utils import merge_extend_and_ignore_unions if TYPE_CHECKING: from collections.abc import Sequence @@ -125,11 +127,11 @@ def configure(ctx: click.Context, param: click.Parameter | None, filename: str | with open(filename, 'rb') as f: toml_dict = tomli.load(f) except tomli.TOMLDecodeError: - msg = f'Cannot read the configuration file at {filename}' + msg = f'Cannot read the configuration file at "{filename}"' else: - msg = f'Using configuration file at {filename}' + msg = f'Using configuration file at "{filename}"' else: - msg = 'Not using any configuration file.' + msg = f'Not using any configuration file, not a file "{filename}"' options = {} @@ -140,6 +142,17 @@ def configure(ctx: click.Context, param: click.Parameter | None, filename: str | # make download options download_dict = toml_dict.setdefault('download', {}) + # remove the provider and refiner lists to select, extend and ignore + provider_lists = { + 'select': download_dict.pop('provider', []), + 'extend': download_dict.pop('extend_provider', []), + 'ignore': download_dict.pop('ignore_provider', []), + } + refiner_lists = { + 'select': download_dict.pop('refiner', []), + 'extend': download_dict.pop('extend_refiner', []), + 'ignore': download_dict.pop('ignore_refiner', []), + } options['download'] = download_dict # make provider and refiner options @@ -147,7 +160,9 @@ def configure(ctx: click.Context, param: click.Parameter | None, filename: str | refiners_dict = toml_dict.setdefault('refiner', {}) ctx.obj = { - '__config__': {'dict': toml_dict, 'debug_message': msg}, + 'debug_message': msg, + 'provider_lists': provider_lists, + 'refiner_lists': refiner_lists, 'provider_configs': providers_dict, 'refiner_configs': refiners_dict, } @@ -165,9 +180,9 @@ def plural(quantity: int, name: str, *, bold: bool = True, **kwargs: Any) -> str AGE = AgeParamType() -PROVIDER = click.Choice(sorted(provider_manager.names())) +PROVIDER = click.Choice(['ALL', *sorted(provider_manager.names())]) -REFINER = click.Choice(sorted(refiner_manager.names())) +REFINER = click.Choice(['ALL', *sorted(refiner_manager.names())]) dirs = PlatformDirs('subliminal') cache_file = 'subliminal.dbm' @@ -257,7 +272,7 @@ def subliminal( logging.getLogger('subliminal').addHandler(handler) logging.getLogger('subliminal').setLevel(logging.DEBUG) # log about the config file - msg = ctx.obj['__config__']['debug_message'] + msg = ctx.obj['debug_message'] logger.info(msg) ctx.obj['debug'] = debug @@ -305,7 +320,54 @@ def cache(ctx: click.Context, clear_subliminal: bool) -> None: help='Language as IETF code, e.g. en, pt-BR (can be used multiple times).', ) @click.option('-p', '--provider', type=PROVIDER, multiple=True, help='Provider to use (can be used multiple times).') +@click.option( + '-pp', + '--extend-provider', + type=PROVIDER, + multiple=True, + help=( + 'Provider to use, on top of the default list (can be used multiple times). ' + 'Supersedes the providers used or ignored in the configuration file.' + ), +) +@click.option( + '-P', + '--ignore-provider', + type=PROVIDER, + multiple=True, + help=( + 'Provider to ignore (can be used multiple times). ' + 'Supersedes the providers used or ignored in the configuration file.' + ), +) @click.option('-r', '--refiner', type=REFINER, multiple=True, help='Refiner to use (can be used multiple times).') +@click.option( + '-rr', + '--extend-refiner', + type=REFINER, + multiple=True, + help=( + 'Refiner to use, on top of the default list (can be used multiple times). ' + 'Supersedes the refiners used or ignored in the configuration file.' + ), +) +@click.option( + '-R', + '--ignore-refiner', + type=REFINER, + multiple=True, + help=( + 'Refiner to ignore (can be used multiple times). ' + 'Supersedes the refiners used or ignored in the configuration file.' + ), +) +@click.option( + '-I', + '--ignore-subtitles', + type=click.STRING, + multiple=True, + help='Subtitle ids to ignore (can be used multiple times).', +) @click.option('-a', '--age', type=AGE, help='Filter videos newer than AGE, e.g. 12h, 1w2d.') @click.option( '--use_creation_time', @@ -381,7 +443,12 @@ def cache(ctx: click.Context, clear_subliminal: bool) -> None: def download( obj: dict[str, Any], provider: Sequence[str], + extend_provider: Sequence[str], + ignore_provider: Sequence[str], refiner: Sequence[str], + extend_refiner: Sequence[str], + ignore_refiner: Sequence[str], + ignore_subtitles: Sequence[str], language: Sequence[Language], age: timedelta | None, use_ctime: bool, @@ -420,6 +487,28 @@ def download( if debug: verbose = 3 + # parse list of refiners + use_providers = merge_extend_and_ignore_unions( + { + 'select': provider, + 'extend': extend_provider, + 'ignore': ignore_provider, + }, + obj['provider_lists'], + default_providers, + ) + logger.info('Use providers: %s', use_providers) + use_refiners = merge_extend_and_ignore_unions( + { + 'select': refiner, + 'extend': extend_refiner, + 'ignore': ignore_refiner, + }, + obj['refiner_lists'], + default_refiners, + ) + logger.info('Use refiners: %s', use_refiners) + # scan videos videos = [] ignored_videos = [] @@ -474,11 +563,10 @@ def download( if check_video(video, languages=language_set, age=age, use_ctime=use_ctime, undefined=single): refine( video, - episode_refiners=refiner, - movie_refiners=refiner, + refiners=use_refiners, refiner_configs=obj['refiner_configs'], embedded_subtitles=not force, - providers=provider, + providers=use_providers, languages=language_set, ) videos.append(video) @@ -516,11 +604,21 @@ def download( if not videos: return + # exit if no providers are used + if len(use_providers) == 0: + click.echo('No provider was selected to download subtitles.') + if 'ALL' in ignore_provider: + click.echo('All ignored from CLI argument: `--ignore-provider=ALL`') + elif 'ALL' in obj['provider_lists']['ignore']: + config_ignore = list(obj['provider_lists']['ignore']) + click.echo(f'All ignored from configuration: `ignore_provider={config_ignore}`') + return + # download best subtitles downloaded_subtitles = defaultdict(list) with AsyncProviderPool( max_workers=max_workers, - providers=provider, + providers=use_providers, provider_configs=obj['provider_configs'], ) as pp: with click.progressbar( @@ -540,6 +638,7 @@ def download( min_score=scores['hash'] * min_score // 100, hearing_impaired=hearing_impaired, only_one=single, + ignore_subtitles=ignore_subtitles, ) downloaded_subtitles[v] = subtitles diff --git a/subliminal/core.py b/subliminal/core.py index 81972dd5..b6d429ff 100644 --- a/subliminal/core.py +++ b/subliminal/core.py @@ -15,7 +15,14 @@ from guessit import guessit # type: ignore[import-untyped] from rarfile import BadRarFile, Error, NotRarFile, RarCannotExec, RarFile, is_rarfile # type: ignore[import-untyped] -from .extensions import default_providers, provider_manager, refiner_manager +from .extensions import ( + default_providers, + default_refiners, + discarded_episode_refiners, + discarded_movie_refiners, + provider_manager, + refiner_manager, +) from .score import compute_score as default_compute_score from .subtitle import SUBTITLE_EXTENSIONS, Subtitle from .utils import get_age, handle_exception @@ -65,7 +72,7 @@ def __init__( providers: Sequence[str] | None = None, provider_configs: Mapping[str, Any] | None = None, ) -> None: - self.providers = providers or default_providers + self.providers = providers if providers is not None else default_providers self.provider_configs = provider_configs or {} self.initialized_providers = {} self.discarded_providers = set() @@ -213,6 +220,7 @@ def download_best_subtitles( hearing_impaired: bool = False, only_one: bool = False, compute_score: ComputeScore | None = None, + ignore_subtitles: Sequence[str] | None = None, ) -> list[Subtitle]: """Download the best matching subtitles. @@ -227,11 +235,16 @@ def download_best_subtitles( :param bool only_one: download only one subtitle, not one per language. :param compute_score: function that takes `subtitle` and `video` as positional arguments, `hearing_impaired` as keyword argument and returns the score. + :param ignore_subtitles: list of subtitle ids to ignore (None defaults to an empty list). :return: downloaded subtitles. :rtype: list of :class:`~subliminal.subtitle.Subtitle` """ compute_score = compute_score or default_compute_score + ignore_subtitles = ignore_subtitles or [] + + # ignore subtitles + subtitles = [s for s in subtitles if s.id not in ignore_subtitles] # sort subtitles by score scored_subtitles = sorted( @@ -303,7 +316,11 @@ def list_subtitles_provider_tuple( def list_subtitles(self, video: Video, languages: Set[Language]) -> list[Subtitle]: """List subtitles, multi-threaded.""" - subtitles = [] + subtitles: list[Subtitle] = [] + + # Avoid raising a ValueError with `ThreadPoolExecutor(self.max_workers)` + if self.max_workers == 0: + return subtitles with ThreadPoolExecutor(self.max_workers) as executor: executor_map = executor.map( @@ -648,8 +665,7 @@ def scan_videos( def refine( video: Video, *, - episode_refiners: Sequence[str] | None = None, - movie_refiners: Sequence[str] | None = None, + refiners: Sequence[str] | None = None, refiner_configs: Mapping[str, Any] | None = None, **kwargs: Any, ) -> Video: @@ -661,20 +677,19 @@ def refine( :param video: the video to refine. :type video: :class:`~subliminal.video.Video` - :param tuple episode_refiners: refiners to use for episodes. - :param tuple movie_refiners: refiners to use for movies. + :param Sequence refiners: refiners to select. None defaults to all refiners. :param dict refiner_configs: refiner configuration as keyword arguments per refiner name to pass when calling the refine method :param kwargs: additional parameters for the :func:`~subliminal.refiners.refine` functions. """ - refiners: tuple[str, ...] = () + refiners = refiners if refiners is not None else default_refiners + if isinstance(video, Movie): + refiners = [r for r in refiners if r not in discarded_movie_refiners] if isinstance(video, Episode): - refiners = tuple(episode_refiners) if episode_refiners is not None else ('metadata', 'tvdb', 'omdb', 'tmdb') - elif isinstance(video, Movie): # pragma: no branch - refiners = tuple(movie_refiners) if movie_refiners is not None else ('metadata', 'omdb', 'tmdb') + refiners = [r for r in refiners if r not in discarded_episode_refiners] - for refiner in ('hash', *refiners): + for refiner in refiners: logger.info('Refining video with %s', refiner) try: refiner_manager[refiner].plugin(video, **dict((refiner_configs or {}).get(refiner, {}), **kwargs)) diff --git a/subliminal/extensions.py b/subliminal/extensions.py index c7e9a725..03d6e1a5 100644 --- a/subliminal/extensions.py +++ b/subliminal/extensions.py @@ -55,7 +55,7 @@ def list_entry_points(self) -> list[EntryPoint]: # registered extensions for rep in self.registered_extensions: ep = parse_entry_point(rep, self.namespace) - if ep.name not in [e.name for e in eps]: + if ep.name not in [e.name for e in eps]: # pragma: no branch eps.append(ep) return eps @@ -84,7 +84,7 @@ def register(self, entry_point: str) -> None: verify_requirements=False, ) self.extensions.append(ext) - if self._extensions_by_name is not None: + if self._extensions_by_name is not None: # pragma: no branch self._extensions_by_name[ext.name] = ext self.registered_extensions.insert(0, entry_point) @@ -101,9 +101,9 @@ def unregister(self, entry_point: str) -> None: ep = parse_entry_point(entry_point, self.namespace) self.registered_extensions.remove(entry_point) - if self._extensions_by_name is not None: + if self._extensions_by_name is not None: # pragma: no branch del self._extensions_by_name[ep.name] - for i, ext in enumerate(self.extensions): + for i, ext in enumerate(self.extensions): # pragma: no branch if ext.name == ep.name: del self.extensions[i] break @@ -153,3 +153,15 @@ def parse_entry_point(src: str, group: str) -> EntryPoint: 'tmdb = subliminal.refiners.tmdb:refine', ], ) + +#: Disabled refiners +disabled_refiners: list[str] = [] + +#: Default enabled refiners +default_refiners = [r for r in refiner_manager.names() if r not in disabled_refiners] + +#: Discarded Movie refiners +discarded_movie_refiners: list[str] = ['tvdb'] + +#: Discarded Episode refiners +discarded_episode_refiners: list[str] = [] diff --git a/subliminal/utils.py b/subliminal/utils.py index ee7cb696..f49f6395 100644 --- a/subliminal/utils.py +++ b/subliminal/utils.py @@ -20,7 +20,17 @@ if TYPE_CHECKING: from collections.abc import Sequence, Set - from typing import TypeGuard + from typing import TypedDict, TypeGuard + + S = TypeVar('S') + + class ExtendedLists(Generic[S], TypedDict): + """Dict with item to select, extend-select and ignore.""" + + select: Sequence[S] + extend: Sequence[S] + ignore: Sequence[S] + T = TypeVar('T') R = TypeVar('R') @@ -61,12 +71,12 @@ def sanitize(string: str, ignore_characters: Set[str] | None = None) -> str: # replace some characters with one space characters = {'-', ':', '(', ')', '.', ','} - ignore_characters - if characters: + if characters: # pragma: no branch string = re.sub(r'[{}]'.format(re.escape(''.join(characters))), ' ', string) # remove some characters characters = {"'"} - ignore_characters - if characters: + if characters: # pragma: no branch string = re.sub(r'[{}]'.format(re.escape(''.join(characters))), '', string) # replace multiple spaces with one @@ -95,6 +105,7 @@ def sanitize_release_group(string: str) -> str: @none_passthrough def sanitize_id(id_: str | int) -> int: """Sanitize the IMDB (or other) id and transform it to a string (without leading 'tt' or zeroes).""" + # TODO: use str.removeprefix('tt') id_ = str(id_).lower().lstrip('t') return int(id_) @@ -233,3 +244,85 @@ def get_age( file_date = max(file_date, creation_date(filepath)) reference_date = reference_date if reference_date is not None else datetime.now(timezone.utc) return reference_date - datetime.fromtimestamp(file_date, timezone.utc) + + +def merge_extend_and_ignore_unions( + lists: ExtendedLists[str], + default_lists: ExtendedLists[str], + defaults: Sequence[str] | None = None, + all_token: str | None = 'ALL', # noqa: S107 +) -> list[str]: + """Merge lists of item to select and ignore. + + Ignore lists supersede the select lists. + `select` and `ignore` supersede `default_select` and `default_ignore`. + + :param Sequence[T] select: sequence of items to select (supersede the defaults). + :param Sequence[T] ignore: sequence of items to select (supersede the defaults and `select`). + :param Sequence[T] default_select: default sequence of items to select. + :param Sequence[T] default_ignore: default sequence of items to ignore. + :return: the list of selected and not-ignored items. + :rtype: list[T] + """ + extend = lists['extend'] or [] + ignore = lists['ignore'] or [] + defaults = defaults or [] + + # Ignore all + if all_token is not None and all_token in ignore: + return [] + + # Nothing selected, start by the selected list using the default_lists + if not lists['select']: + item_set = set(get_extend_and_ignore_union(**default_lists, defaults=defaults, all_token=all_token)) + else: + item_set = set(lists['select']) + + # Add the extend list + item_set.update(set(extend)) + # Replace all_token + if all_token in item_set: + item_set -= {all_token} + item_set.update(defaults) + # Remove the ignore list + item_set -= set(ignore) + + return list(item_set) + + +def get_extend_and_ignore_union( + select: Sequence[str] | None = None, + extend: Sequence[str] | None = None, + ignore: Sequence[str] | None = None, + defaults: Sequence[str] | None = None, + all_token: str | None = 'ALL', # noqa: S107 +) -> list[str]: + """Get the list of items to use. + + :param Sequence select: items to select. Empty sequence or None is equivalent to `defaults`. + :param Sequence extend: like 'select', but add additional items (empty sequence does nothing). + :param Sequence ignore: items to ignore. + :param Sequence defaults: default items + :param str all_token: token used to represent all the items. + + """ + extend = extend or [] + ignore = ignore or [] + defaults = defaults or [] + + # Ignore all + if all_token is not None and all_token in ignore: + return [] + + # Start with the defaults + item_set = set(select or defaults) + # Add the extend list + item_set.update(set(extend)) + # Replace all_token + if all_token is not None and all_token in item_set: + item_set -= {all_token} + item_set.update(defaults) + # Remove the ignore list + item_set -= set(ignore) + + return list(item_set) diff --git a/tests/test_core.py b/tests/test_core.py index 8fe161e7..1e0efbb1 100644 --- a/tests/test_core.py +++ b/tests/test_core.py @@ -1,5 +1,4 @@ # ruff: noqa: PT011, SIM115 -import logging import os import sys from datetime import datetime, timedelta, timezone @@ -15,7 +14,6 @@ download_best_subtitles, download_subtitles, list_subtitles, - refine, save_subtitles, scan_archive, scan_name, @@ -23,7 +21,7 @@ scan_videos, search_external_subtitles, ) -from subliminal.extensions import provider_manager, refiner_manager +from subliminal.extensions import provider_manager from subliminal.providers.tvsubtitles import TVsubtitlesSubtitle from subliminal.score import episode_scores from subliminal.subtitle import Subtitle @@ -55,24 +53,6 @@ def _mock_providers(monkeypatch): monkeypatch.setattr(provider.plugin, 'terminate', Mock()) -@pytest.fixture() -def refiner_mocks(monkeypatch): - mocks = {} - # monkeypatch refiners - for refiner in refiner_manager: - mocks[refiner.name] = Mock() - - def mocked_refine(refiner): - def func(video, *args, **kwargs): - mocks[refiner.name](video) - return refiner.plugin(*args, **kwargs) - - return func - - monkeypatch.setattr(refiner, 'plugin', mocked_refine(refiner)) - return mocks - - def test_provider_pool_get_keyerror(): pool = ProviderPool() with pytest.raises(KeyError): @@ -467,38 +447,6 @@ def test_scan_videos_age(movies, tmpdir, monkeypatch): mock_scan_video.assert_has_calls(scan_video_calls, any_order=True) # type: ignore[arg-type] -def test_refine_movie(movies, caplog, refiner_mocks): - video = movies['man_of_steel'] - - with caplog.at_level(logging.INFO): - refine(video, movie_refiners=['metadata'], providers=['opensubtitles']) - - # test refiners - for name in ('omdb', 'tmdb'): - assert f'Refining video with {name}' not in caplog.text - refiner_mocks[name].assert_not_called() - - for name in ('hash', 'metadata'): - assert f'Refining video with {name}' in caplog.text - refiner_mocks[name].assert_called_once_with(video) - - -def test_refine_episode(episodes, caplog, refiner_mocks): - video = episodes['bbt_s07e05'] - - with caplog.at_level(logging.INFO): - refine(video, episode_refiners=['omdb', 'tvdb'], providers=['opensubtitles']) - - # test refiners - for name in ('metadata', 'tmdb'): - assert f'Refining video with {name}' not in caplog.text - refiner_mocks[name].assert_not_called() - - for name in ('hash', 'omdb', 'tvdb'): - assert f'Refining video with {name}' in caplog.text - refiner_mocks[name].assert_called_once_with(video) - - @pytest.mark.usefixtures('_mock_providers') def test_list_subtitles_movie(movies): video = movies['man_of_steel'] diff --git a/tests/test_extensions.py b/tests/test_extensions.py index 893c9c24..b3bd931b 100644 --- a/tests/test_extensions.py +++ b/tests/test_extensions.py @@ -6,15 +6,35 @@ else: from importlib_metadata import entry_points # type: ignore[assignment,no-redef,import-not-found] +import pytest from subliminal.extensions import ( + EntryPoint, RegistrableExtensionManager, default_providers, + default_refiners, disabled_providers, + disabled_refiners, parse_entry_point, provider_manager, + refiner_manager, ) +def test_parse_entry_point() -> None: + src = 'addic7ed = subliminal.providers.addic7ed:Addic7edProvider' + ep = parse_entry_point(src, group='subliminal.providers') + assert isinstance(ep, EntryPoint) + assert ep.name == 'addic7ed' + assert ep.value == 'subliminal.providers.addic7ed:Addic7edProvider' + assert ep.group == 'subliminal.providers' + + +def test_parse_entry_point_wrong() -> None: + src = 'subliminal.providers.addic7ed:Addic7edProvider' + with pytest.raises(ValueError, match='EntryPoint must be'): + parse_entry_point(src, group='subliminal.providers') + + def test_registrable_extension_manager_all_extensions(): native_extensions = sorted(e.name for e in provider_manager) @@ -53,6 +73,18 @@ def test_registrable_extension_manager_register(): assert len(list(manager)) == 3 assert 'de7cidda' in manager.names() + eps = manager.list_entry_points() + ep_names = [ep.name for ep in eps] + assert ep_names == ['addic7ed', 'opensubtitles', 'de7cidda'] + + # Raise ValueError on same entry point + with pytest.raises(ValueError, match='Extension already registered'): + manager.register('de7cidda = subliminal.providers.addic7ed:Addic7edProvider') + + # Raise ValueError on same entry point name + with pytest.raises(ValueError, match='An extension with the same name already exist'): + manager.register('de7cidda = subliminal.providers.opensubtitles:OpenSubtitlesProvider') + def test_registrable_extension_manager_unregister(): manager = RegistrableExtensionManager( @@ -68,6 +100,10 @@ def test_registrable_extension_manager_unregister(): assert len(list(manager)) == 2 assert set(manager.names()) == {'gestdown', 'tvsubtitles'} + # Raise ValueError on entry point not found + with pytest.raises(ValueError, match='Extension not registered'): + manager.unregister('seltitbusnepo = subliminal.providers.opensubtitles:OpenSubtitlesProvider') + def test_provider_manager(): setup_names = {ep.name for ep in entry_points(group=provider_manager.namespace)} @@ -78,3 +114,14 @@ def test_provider_manager(): disabled_names = set(disabled_providers) assert enabled_names == setup_names - disabled_names assert internal_names == enabled_names | disabled_names + + +def test_refiner_manager(): + setup_names = {ep.name for ep in entry_points(group=refiner_manager.namespace)} + internal_names = { + parse_entry_point(iep, refiner_manager.namespace).name for iep in refiner_manager.internal_extensions + } + enabled_names = set(default_refiners) + disabled_names = set(disabled_refiners) + assert enabled_names == setup_names - disabled_names + assert internal_names == enabled_names | disabled_names diff --git a/tests/test_utils.py b/tests/test_utils.py index 894124de..661ddd19 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -1,10 +1,25 @@ +from __future__ import annotations + import datetime from typing import Any +from xmlrpc.client import ProtocolError -from subliminal.utils import get_age, sanitize +import pytest +import requests +from subliminal.exceptions import ServiceUnavailable +from subliminal.utils import ( + ensure_list, + get_age, + get_extend_and_ignore_union, + handle_exception, + matches_title, + merge_extend_and_ignore_unions, + sanitize, +) def test_sanitize(): + assert sanitize(None) is None assert sanitize("Marvel's Agents of S.H.I.E.L.D.") == 'marvels agents of s h i e l d' @@ -38,3 +53,132 @@ def mock_creation_date_sooner(*args: Any) -> float: c_age_2 = get_age(__file__, use_ctime=True, reference_date=NOW) assert c_age_2 == datetime.timedelta(weeks=2) + + +@pytest.mark.parametrize( + ('actual', 'title', 'alt', 'expected'), + [ + (None, 'The Big Bang Theory', [], False), + ('The Big Bang Theory', None, [], False), + ('the.big.bang.theory', 'The Big Bang Theory', [], True), + ('the.big.bang.theory', 'Big Bang Theory', None, False), + ('the.big.bang.theory', 'Big Bang Theory', ['The Big Bang Theory'], True), + ('the.big.bang.theory', 'Big Bang Theory', ['Not The Big Bang Theory'], False), + ], +) +def test_matches_title(actual: str | None, title: str | None, alt: list[str], expected: bool) -> None: + ret = matches_title(actual, title, alt) + assert ret == expected + + +@pytest.mark.parametrize( + ('err', 'msg'), + [ + (requests.Timeout(), 'Request timed out'), + (ServiceUnavailable(), 'Service unavailable'), + (ProtocolError('', 0, '', {}), 'Service unavailable'), + (requests.exceptions.HTTPError(response=requests.Response()), 'HTTP error'), + (requests.exceptions.SSLError(''), 'SSL error'), + (ValueError(), 'Unexpected error'), + ], +) +def test_handle_exception(caplog: pytest.LogCaptureFixture, err: Exception, msg: str) -> None: + handle_exception(err, '') + for record in caplog.records: + assert record.levelname == 'ERROR' + assert record.message.startswith(msg) + + +def test_ensure_list() -> None: + ret: list = ensure_list(None) + assert isinstance(ret, list) + assert ret == [] + + ret = ensure_list('a') + assert isinstance(ret, list) + assert set(ret) == {'a'} + + ret = ensure_list(('a', 'b')) + assert isinstance(ret, list) + assert set(ret) == {'a', 'b'} + + ret = ensure_list({'a', 'b'}) # type: ignore[arg-type] + assert isinstance(ret, list) + assert set(ret) == {'a', 'b'} + + +@pytest.mark.parametrize( + ('select', 'extend', 'ignore', 'defaults', 'expected'), + [ + (None, None, None, None, set()), + (None, None, None, ['a', 'b'], {'a', 'b'}), + ([], None, None, ['a', 'b'], {'a', 'b'}), + ([], [], ['a'], ['a', 'b'], {'b'}), + ([], ['c'], ['a'], ['a', 'b'], {'b', 'c'}), + (['a'], ['b'], ['c'], ['a', 'b'], {'a', 'b'}), + (['a', 'b', 'c'], ['c'], ['a'], ['a', 'b'], {'b', 'c'}), + (['a', 'b'], ['c'], ['c'], ['a', 'b'], {'a', 'b'}), + ([], ['c'], ['c'], ['a', 'b'], {'a', 'b'}), + (['ALL'], ['b'], [], ['a', 'b'], {'a', 'b'}), + (['ALL'], ['c'], [], ['a', 'b'], {'a', 'b', 'c'}), + (['ALL'], [], ['b'], ['a', 'b'], {'a'}), + (['c'], ['ALL'], [], ['a', 'b'], {'a', 'b', 'c'}), + ([], [], ['ALL'], ['a', 'b'], set()), + (['ALL'], [], ['ALL'], ['a', 'b'], set()), + (['ALL'], ['ALL'], ['ALL'], ['a', 'b'], set()), + ], +) +def test_get_extend_and_ignore_union( + select: list[str] | None, + extend: list[str] | None, + ignore: list[str] | None, + defaults: list[str] | None, + expected: set[str], +) -> None: + final = set(get_extend_and_ignore_union(select, extend, ignore, defaults)) + assert final == expected + + +@pytest.mark.parametrize( + ('lists', 'default_lists', 'defaults', 'expected'), + [ + ( + {'select': None, 'extend': None, 'ignore': None}, + {'select': None, 'extend': None, 'ignore': None}, + ['a', 'b'], + {'a', 'b'}, + ), + ( + {'select': None, 'extend': None, 'ignore': 'ALL'}, + {'select': None, 'extend': None, 'ignore': None}, + ['a', 'b'], + set(), + ), + ( + {'select': None, 'extend': None, 'ignore': None}, + {'select': None, 'extend': None, 'ignore': 'ALL'}, + ['a', 'b'], + set(), + ), + ( + {'select': ['c'], 'extend': None, 'ignore': None}, + {'select': None, 'extend': None, 'ignore': 'ALL'}, + ['a', 'b'], + {'c'}, + ), + ( + {'select': ['c'], 'extend': ['ALL'], 'ignore': None}, + {'select': None, 'extend': None, 'ignore': 'ALL'}, + ['a', 'b'], + {'a', 'b', 'c'}, + ), + ], +) +def test_merge_extend_and_ignore_unions( + lists: dict, + default_lists: dict, + defaults: list[str], + expected: set[str], +) -> None: + final = set(merge_extend_and_ignore_unions(lists, default_lists, defaults)) # type: ignore[arg-type] + assert final == expected