diff --git a/pyproject.toml b/pyproject.toml index 6e4c991d25..3eb04a3958 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -72,6 +72,7 @@ module = [ "core.settings.*", "core.util.authentication_for_opds", "core.util.cache", + "tests.fixtures.authenticator", "tests.migration.*", ] no_implicit_reexport = true diff --git a/tests/api/admin/controller/test_patron_auth.py b/tests/api/admin/controller/test_patron_auth.py index e9daddc7ae..c7c2844e5f 100644 --- a/tests/api/admin/controller/test_patron_auth.py +++ b/tests/api/admin/controller/test_patron_auth.py @@ -35,16 +35,21 @@ from api.sip import SIP2AuthenticationProvider from core.integration.goals import Goals from core.model import AdminRole, Library, get_one -from core.model.integration import ( - IntegrationConfiguration, - IntegrationLibraryConfiguration, -) +from core.model.integration import IntegrationConfiguration from core.util.problem_detail import ProblemDetail if TYPE_CHECKING: from tests.fixtures.api_admin import SettingsControllerFixture - from tests.fixtures.authenticator import AuthProviderFixture - from tests.fixtures.database import DatabaseTransactionFixture + from tests.fixtures.authenticator import ( + MilleniumAuthIntegrationFixture, + SamlAuthIntegrationFixture, + SimpleAuthIntegrationFixture, + Sip2AuthIntegrationFixture, + ) + from tests.fixtures.database import ( + DatabaseTransactionFixture, + IntegrationLibraryConfigurationFixture, + ) @pytest.fixture @@ -116,10 +121,8 @@ def test_patron_auth_services_get_with_simple_auth_service( self, settings_ctrl_fixture: SettingsControllerFixture, db: DatabaseTransactionFixture, - create_simple_auth_integration: Callable[..., AuthProviderFixture], - create_integration_library_configuration: Callable[ - ..., IntegrationLibraryConfiguration - ], + create_simple_auth_integration: SimpleAuthIntegrationFixture, + create_integration_library_configuration: IntegrationLibraryConfigurationFixture, get_response: Callable[[], dict[str, Any] | ProblemDetail], ): auth_service, _ = create_simple_auth_integration( @@ -163,7 +166,7 @@ def test_patron_auth_services_get_with_millenium_auth_service( self, settings_ctrl_fixture: SettingsControllerFixture, db: DatabaseTransactionFixture, - create_millenium_auth_integration: Callable[..., AuthProviderFixture], + create_millenium_auth_integration: MilleniumAuthIntegrationFixture, get_response: Callable[[], dict[str, Any] | ProblemDetail], ): auth_service, _ = create_millenium_auth_integration( @@ -194,7 +197,7 @@ def test_patron_auth_services_get_with_sip2_auth_service( self, settings_ctrl_fixture: SettingsControllerFixture, db: DatabaseTransactionFixture, - create_sip2_auth_integration: Callable[..., AuthProviderFixture], + create_sip2_auth_integration: Sip2AuthIntegrationFixture, get_response: Callable[[], dict[str, Any] | ProblemDetail], ): auth_service, _ = create_sip2_auth_integration( @@ -229,7 +232,7 @@ def test_patron_auth_services_get_with_saml_auth_service( self, settings_ctrl_fixture: SettingsControllerFixture, db: DatabaseTransactionFixture, - create_saml_auth_integration: Callable[..., AuthProviderFixture], + create_saml_auth_integration: SamlAuthIntegrationFixture, get_response: Callable[[], dict[str, Any] | ProblemDetail], ): auth_service, _ = create_saml_auth_integration( @@ -284,7 +287,7 @@ def test_patron_auth_services_post_missing_service( def test_patron_auth_services_post_cannot_change_protocol( self, post_response: Callable[..., Response | ProblemDetail], - create_simple_auth_integration: Callable[..., AuthProviderFixture], + create_simple_auth_integration: SimpleAuthIntegrationFixture, ): auth_service, _ = create_simple_auth_integration() form = ImmutableMultiDict( @@ -299,7 +302,7 @@ def test_patron_auth_services_post_cannot_change_protocol( def test_patron_auth_services_post_name_in_use( self, post_response: Callable[..., Response | ProblemDetail], - create_simple_auth_integration: Callable[..., AuthProviderFixture], + create_simple_auth_integration: SimpleAuthIntegrationFixture, ): auth_service, _ = create_simple_auth_integration() form = ImmutableMultiDict( @@ -314,7 +317,7 @@ def test_patron_auth_services_post_name_in_use( def test_patron_auth_services_post_invalid_configuration( self, post_response: Callable[..., Response | ProblemDetail], - create_millenium_auth_integration: Callable[..., AuthProviderFixture], + create_millenium_auth_integration: MilleniumAuthIntegrationFixture, common_args: list[tuple[str, str]], ): auth_service, _ = create_millenium_auth_integration() @@ -336,7 +339,7 @@ def test_patron_auth_services_post_invalid_configuration( def test_patron_auth_services_post_incomplete_configuration( self, post_response: Callable[..., Response | ProblemDetail], - create_simple_auth_integration: Callable[..., AuthProviderFixture], + create_simple_auth_integration: SimpleAuthIntegrationFixture, common_args: list[tuple[str, str]], ): auth_service, _ = create_simple_auth_integration() @@ -385,7 +388,7 @@ def test_patron_auth_services_post_missing_patron_auth_no_such_library( def test_patron_auth_services_post_missing_patron_auth_multiple_basic( self, post_response: Callable[..., Response | ProblemDetail], - create_simple_auth_integration: Callable[..., AuthProviderFixture], + create_simple_auth_integration: SimpleAuthIntegrationFixture, default_library: Library, common_args: list[tuple[str, str]], ): @@ -544,7 +547,7 @@ def test_patron_auth_services_post_edit( post_response: Callable[..., Response | ProblemDetail], common_args: List[Tuple[str, str]], settings_ctrl_fixture: SettingsControllerFixture, - create_simple_auth_integration: Callable[..., AuthProviderFixture], + create_simple_auth_integration: SimpleAuthIntegrationFixture, db: DatabaseTransactionFixture, monkeypatch: MonkeyPatch, ): @@ -610,7 +613,7 @@ def test_patron_auth_service_delete( self, common_args: List[Tuple[str, str]], settings_ctrl_fixture: SettingsControllerFixture, - create_simple_auth_integration: Callable[..., AuthProviderFixture], + create_simple_auth_integration: SimpleAuthIntegrationFixture, ): controller = settings_ctrl_fixture.manager.admin_patron_auth_services_controller db = settings_ctrl_fixture.ctrl.db diff --git a/tests/api/admin/controller/test_patron_auth_self_tests.py b/tests/api/admin/controller/test_patron_auth_self_tests.py index ea2e991f3d..54f615fb31 100644 --- a/tests/api/admin/controller/test_patron_auth_self_tests.py +++ b/tests/api/admin/controller/test_patron_auth_self_tests.py @@ -1,7 +1,7 @@ from __future__ import annotations import json -from typing import TYPE_CHECKING, Callable +from typing import TYPE_CHECKING from unittest.mock import MagicMock import pytest @@ -23,7 +23,7 @@ from _pytest.monkeypatch import MonkeyPatch from flask.ctx import RequestContext - from tests.fixtures.authenticator import AuthProviderFixture + from tests.fixtures.authenticator import SimpleAuthIntegrationFixture from tests.fixtures.database import DatabaseTransactionFixture @@ -56,7 +56,7 @@ def test_patron_auth_self_tests_get_with_no_libraries( self, controller: PatronAuthServiceSelfTestsController, get_request_context: RequestContext, - create_simple_auth_integration: Callable[..., AuthProviderFixture], + create_simple_auth_integration: SimpleAuthIntegrationFixture, ): auth_service, _ = create_simple_auth_integration() response_obj = controller.process_patron_auth_service_self_tests( @@ -75,7 +75,7 @@ def test_patron_auth_self_tests_test_get_no_results( self, controller: PatronAuthServiceSelfTestsController, get_request_context: RequestContext, - create_simple_auth_integration: Callable[..., AuthProviderFixture], + create_simple_auth_integration: SimpleAuthIntegrationFixture, default_library: Library, ): auth_service, _ = create_simple_auth_integration(library=default_library) @@ -99,7 +99,7 @@ def test_patron_auth_self_tests_test_get( self, controller: PatronAuthServiceSelfTestsController, get_request_context: RequestContext, - create_simple_auth_integration: Callable[..., AuthProviderFixture], + create_simple_auth_integration: SimpleAuthIntegrationFixture, monkeypatch: MonkeyPatch, default_library: Library, ): @@ -136,7 +136,7 @@ def test_patron_auth_self_tests_post_with_no_libraries( self, controller: PatronAuthServiceSelfTestsController, post_request_context: RequestContext, - create_simple_auth_integration: Callable[..., AuthProviderFixture], + create_simple_auth_integration: SimpleAuthIntegrationFixture, ): auth_service, _ = create_simple_auth_integration() response = controller.process_patron_auth_service_self_tests(auth_service.id) @@ -149,7 +149,7 @@ def test_patron_auth_self_tests_test_post( self, controller: PatronAuthServiceSelfTestsController, post_request_context: RequestContext, - create_simple_auth_integration: Callable[..., AuthProviderFixture], + create_simple_auth_integration: SimpleAuthIntegrationFixture, monkeypatch: MonkeyPatch, db: DatabaseTransactionFixture, ): diff --git a/tests/api/test_authenticator.py b/tests/api/test_authenticator.py index b18d0d8639..5dcce67a1a 100644 --- a/tests/api/test_authenticator.py +++ b/tests/api/test_authenticator.py @@ -67,7 +67,10 @@ if TYPE_CHECKING: from ..fixtures.api_controller import ControllerFixture - from ..fixtures.authenticator import AuthProviderFixture + from ..fixtures.authenticator import ( + CreateAuthIntegrationFixture, + MilleniumAuthIntegrationFixture, + ) from ..fixtures.database import DatabaseTransactionFixture from ..fixtures.vendor_id import VendorIDFixture @@ -471,7 +474,7 @@ class TestAuthenticator: def test_init( self, controller_fixture: ControllerFixture, - create_millenium_auth_integration: Callable[..., AuthProviderFixture], + create_millenium_auth_integration: MilleniumAuthIntegrationFixture, ): db = controller_fixture.db @@ -601,7 +604,7 @@ class TestLibraryAuthenticator: def test_from_config_basic_auth_only( self, db: DatabaseTransactionFixture, - create_millenium_auth_integration: Callable[..., AuthProviderFixture], + create_millenium_auth_integration: MilleniumAuthIntegrationFixture, ): # Only a basic auth provider. create_millenium_auth_integration(db.default_library()) @@ -657,8 +660,8 @@ def test_config_succeeds_when_no_providers_configured( def test_configuration_exception_during_from_config_stored( self, db: DatabaseTransactionFixture, - create_millenium_auth_integration: Callable[..., AuthProviderFixture], - create_auth_integration_configuration: Callable[..., AuthProviderFixture], + create_millenium_auth_integration: MilleniumAuthIntegrationFixture, + create_auth_integration_configuration: CreateAuthIntegrationFixture, ): # If the initialization of an AuthenticationProvider from config # raises CannotLoadConfiguration or ImportError, the exception @@ -753,15 +756,15 @@ def __init__(self, *args, **kwargs): def test_register_provider_basic_auth( self, db: DatabaseTransactionFixture, - create_auth_integration_configuration: Callable[..., AuthProviderFixture], + create_auth_integration_configuration: CreateAuthIntegrationFixture, patron_auth_registry: PatronAuthRegistry, ): library = db.default_library() - protocol = patron_auth_registry.get_protocol(SIP2AuthenticationProvider) + protocol = patron_auth_registry.get_protocol(SIP2AuthenticationProvider, "") _, integration = create_auth_integration_configuration( protocol, library, - settings={ + settings_dict={ "url": "http://url/", "password": "secret", }, diff --git a/tests/api/test_axis.py b/tests/api/test_axis.py index a164a8b3e5..1ae132359c 100644 --- a/tests/api/test_axis.py +++ b/tests/api/test_axis.py @@ -6,7 +6,7 @@ import ssl import urllib from functools import partial -from typing import TYPE_CHECKING, Callable +from typing import TYPE_CHECKING from unittest.mock import MagicMock, Mock, PropertyMock import pytest @@ -68,7 +68,7 @@ if TYPE_CHECKING: from ..fixtures.api_axis_files import AxisFilesFixture - from ..fixtures.authenticator import AuthProviderFixture + from ..fixtures.authenticator import SimpleAuthIntegrationFixture from ..fixtures.database import DatabaseTransactionFixture @@ -140,7 +140,7 @@ def test_external_integration(self, axis360: Axis360Fixture): def test__run_self_tests( self, axis360: Axis360Fixture, - create_simple_auth_integration: Callable[..., AuthProviderFixture], + create_simple_auth_integration: SimpleAuthIntegrationFixture, ): # Verify that Axis360API._run_self_tests() calls the right # methods. diff --git a/tests/api/test_bibliotheca.py b/tests/api/test_bibliotheca.py index dca27836be..62231870dd 100644 --- a/tests/api/test_bibliotheca.py +++ b/tests/api/test_bibliotheca.py @@ -4,15 +4,7 @@ import random from datetime import datetime, timedelta from io import BytesIO, StringIO -from typing import ( - TYPE_CHECKING, - Any, - Callable, - ClassVar, - Optional, - Protocol, - runtime_checkable, -) +from typing import TYPE_CHECKING, Any, ClassVar, Optional, Protocol, runtime_checkable from unittest import mock from unittest.mock import MagicMock @@ -77,7 +69,7 @@ if TYPE_CHECKING: from tests.fixtures.api_bibliotheca_files import BibliothecaFilesFixture - from tests.fixtures.authenticator import AuthProviderFixture + from tests.fixtures.authenticator import SimpleAuthIntegrationFixture from tests.fixtures.database import DatabaseTransactionFixture from tests.fixtures.time import Time @@ -110,7 +102,7 @@ def test_external_integration(self, bibliotheca_fixture: BibliothecaAPITestFixtu def test__run_self_tests( self, bibliotheca_fixture: BibliothecaAPITestFixture, - create_simple_auth_integration: Callable[..., AuthProviderFixture], + create_simple_auth_integration: SimpleAuthIntegrationFixture, ): db = bibliotheca_fixture.db # Verify that BibliothecaAPI._run_self_tests() calls the right diff --git a/tests/api/test_enki.py b/tests/api/test_enki.py index 9c54da938c..390faa1cca 100644 --- a/tests/api/test_enki.py +++ b/tests/api/test_enki.py @@ -2,7 +2,7 @@ import datetime import json -from typing import TYPE_CHECKING, Callable +from typing import TYPE_CHECKING import pytest @@ -31,7 +31,7 @@ if TYPE_CHECKING: from tests.fixtures.api_enki_files import EnkiFilesFixture - from tests.fixtures.authenticator import AuthProviderFixture + from tests.fixtures.authenticator import SimpleAuthIntegrationFixture class EnkiTestFixure: @@ -96,7 +96,7 @@ def test_collection(self, enki_test_fixture: EnkiTestFixure): def test__run_self_tests( self, enki_test_fixture: EnkiTestFixure, - create_simple_auth_integration: Callable[..., AuthProviderFixture], + create_simple_auth_integration: SimpleAuthIntegrationFixture, ): db = enki_test_fixture.db diff --git a/tests/api/test_odilo.py b/tests/api/test_odilo.py index 65451566b9..f3f0c1f791 100644 --- a/tests/api/test_odilo.py +++ b/tests/api/test_odilo.py @@ -1,7 +1,7 @@ from __future__ import annotations import json -from typing import TYPE_CHECKING, Callable +from typing import TYPE_CHECKING import pytest @@ -32,7 +32,7 @@ if TYPE_CHECKING: from ..fixtures.api_odilo_files import OdiloFilesFixture - from ..fixtures.authenticator import AuthProviderFixture + from ..fixtures.authenticator import SimpleAuthIntegrationFixture from ..fixtures.database import DatabaseTransactionFixture @@ -226,7 +226,7 @@ def test_external_integration(self, odilo: OdiloFixture): def test__run_self_tests( self, odilo: OdiloFixture, - create_simple_auth_integration: Callable[..., AuthProviderFixture], + create_simple_auth_integration: SimpleAuthIntegrationFixture, ): """Verify that OdiloAPI._run_self_tests() calls the right methods. diff --git a/tests/api/test_overdrive.py b/tests/api/test_overdrive.py index 21c0cf210b..f778d19c1b 100644 --- a/tests/api/test_overdrive.py +++ b/tests/api/test_overdrive.py @@ -5,7 +5,7 @@ import os import random from datetime import timedelta -from typing import TYPE_CHECKING, Any, Callable, Dict +from typing import TYPE_CHECKING, Any, Dict from unittest.mock import MagicMock, create_autospec import pytest @@ -47,7 +47,7 @@ if TYPE_CHECKING: from ..fixtures.api_overdrive_files import OverdriveAPIFilesFixture - from ..fixtures.authenticator import AuthProviderFixture + from ..fixtures.authenticator import SimpleAuthIntegrationFixture from ..fixtures.time import Time @@ -122,7 +122,7 @@ def test_lock_in_format(self, overdrive_api_fixture: OverdriveAPIFixture): def test__run_self_tests( self, overdrive_api_fixture: OverdriveAPIFixture, - create_simple_auth_integration: Callable[..., AuthProviderFixture], + create_simple_auth_integration: SimpleAuthIntegrationFixture, ): # Verify that OverdriveAPI._run_self_tests() calls the right # methods. diff --git a/tests/api/test_scripts.py b/tests/api/test_scripts.py index 7b78f973df..9d22f285ea 100644 --- a/tests/api/test_scripts.py +++ b/tests/api/test_scripts.py @@ -4,7 +4,7 @@ import datetime from io import StringIO from pathlib import Path -from typing import TYPE_CHECKING, Callable +from typing import TYPE_CHECKING from unittest.mock import MagicMock, patch import pytest @@ -63,7 +63,7 @@ from tests.fixtures.library import LibraryFixture if TYPE_CHECKING: - from tests.fixtures.authenticator import AuthProviderFixture + from tests.fixtures.authenticator import SimpleAuthIntegrationFixture from tests.fixtures.database import DatabaseTransactionFixture from tests.fixtures.sample_covers import SampleCoversFixture from tests.fixtures.search import ExternalSearchFixture @@ -1493,7 +1493,7 @@ def patron(self, authdata, db: DatabaseTransactionFixture): def authentication_provider( self, db: DatabaseTransactionFixture, - create_simple_auth_integration: Callable[..., AuthProviderFixture], + create_simple_auth_integration: SimpleAuthIntegrationFixture, ): barcode = "12345" pin = "abcd" diff --git a/tests/api/test_selftest.py b/tests/api/test_selftest.py index 2dc72298b0..5b3f16cd75 100644 --- a/tests/api/test_selftest.py +++ b/tests/api/test_selftest.py @@ -3,7 +3,7 @@ import datetime from io import StringIO -from typing import TYPE_CHECKING, Callable +from typing import TYPE_CHECKING from unittest import mock import pytest @@ -18,7 +18,7 @@ from core.util.problem_detail import ProblemDetail if TYPE_CHECKING: - from tests.fixtures.authenticator import AuthProviderFixture + from tests.fixtures.authenticator import SimpleAuthIntegrationFixture from tests.fixtures.database import DatabaseTransactionFixture @@ -26,7 +26,7 @@ class TestHasSelfTests: def test__determine_self_test_patron( self, db: DatabaseTransactionFixture, - create_simple_auth_integration: Callable[..., AuthProviderFixture], + create_simple_auth_integration: SimpleAuthIntegrationFixture, ): """Test per-library default patron lookup for self-tests. @@ -95,7 +95,7 @@ def test__determine_self_test_patron( def test_default_patrons( self, db: DatabaseTransactionFixture, - create_simple_auth_integration: Callable[..., AuthProviderFixture], + create_simple_auth_integration: SimpleAuthIntegrationFixture, ): """Some self-tests must run with a patron's credentials. The default_patrons() method finds the default Patron for every diff --git a/tests/fixtures/authenticator.py b/tests/fixtures/authenticator.py index 77e8db76f8..57323f28a4 100644 --- a/tests/fixtures/authenticator.py +++ b/tests/fixtures/authenticator.py @@ -1,4 +1,4 @@ -from typing import Callable, Optional, Tuple, Type +from typing import Dict, Optional, Tuple, Type import pytest @@ -15,41 +15,56 @@ IntegrationLibraryConfiguration, ) from tests.api.saml.saml_strings import CORRECT_XML_WITH_ONE_SP +from tests.fixtures.database import ( + IntegrationConfigurationFixture, + IntegrationLibraryConfigurationFixture, +) AuthProviderFixture = Tuple[ IntegrationConfiguration, Optional[IntegrationLibraryConfiguration] ] -@pytest.fixture -def create_auth_integration_configuration( - create_integration_configuration, - create_integration_library_configuration: Callable[ - ..., IntegrationLibraryConfiguration - ], -) -> Callable[..., AuthProviderFixture]: - def create_integration( +class CreateAuthIntegrationFixture: + def __init__( + self, + integration_configuration: IntegrationConfigurationFixture, + integration_library_configuration: IntegrationLibraryConfigurationFixture, + ): + self.integration_configuration = integration_configuration + self.integration_library_configuration = integration_library_configuration + + def __call__( + self, protocol: str, library: Optional[Library], - settings: Optional[dict] = None, - library_settings: Optional[dict] = None, + settings_dict: Optional[Dict[str, str]] = None, + library_settings_dict: Optional[Dict[str, str]] = None, ) -> AuthProviderFixture: - settings = settings or {} - library_settings = library_settings or {} - integration = create_integration_configuration( + settings_dict = settings_dict or {} + library_settings_dict = library_settings_dict or {} + integration = self.integration_configuration( protocol, Goals.PATRON_AUTH_GOAL, - settings, + settings_dict, ) if library is not None: - library_integration = create_integration_library_configuration( - library, integration, library_settings + library_integration = self.integration_library_configuration( + library, integration, library_settings_dict ) else: library_integration = None return integration, library_integration - return create_integration + +@pytest.fixture +def create_auth_integration_configuration( + create_integration_configuration: IntegrationConfigurationFixture, + create_integration_library_configuration: IntegrationLibraryConfigurationFixture, +) -> CreateAuthIntegrationFixture: + return CreateAuthIntegrationFixture( + create_integration_configuration, create_integration_library_configuration + ) @pytest.fixture() @@ -57,25 +72,40 @@ def patron_auth_registry() -> PatronAuthRegistry: return PatronAuthRegistry() +class AuthProtocolFixture: + def __init__(self, registry: PatronAuthRegistry): + self.registry = registry + + def __call__(self, protocol: Type[AuthenticationProvider]) -> str: + return self.registry.get_protocol(protocol, "") + + @pytest.fixture def get_auth_protocol( patron_auth_registry: PatronAuthRegistry, -) -> Callable[[Type[AuthenticationProvider]], Optional[str]]: - return lambda x: patron_auth_registry.get_protocol(x) - +) -> AuthProtocolFixture: + return AuthProtocolFixture(patron_auth_registry) + + +class SimpleAuthIntegrationFixture: + def __init__( + self, + create_auth_integration_configuration: CreateAuthIntegrationFixture, + get_auth_protocol: AuthProtocolFixture, + ): + self.create_auth_integration_configuration = ( + create_auth_integration_configuration + ) + self.get_auth_protocol = get_auth_protocol -@pytest.fixture -def create_simple_auth_integration( - create_auth_integration_configuration: Callable[..., AuthProviderFixture], - get_auth_protocol: Callable[[Type[AuthenticationProvider]], Optional[str]], -) -> Callable[..., AuthProviderFixture]: - def create_integration( + def __call__( + self, library: Optional[Library] = None, test_identifier: str = "username1", test_password: str = "password1", ) -> AuthProviderFixture: - return create_auth_integration_configuration( - get_auth_protocol(SimpleAuthenticationProvider), + return self.create_auth_integration_configuration( + self.get_auth_protocol(SimpleAuthenticationProvider), library, dict( test_identifier=test_identifier, @@ -83,67 +113,111 @@ def create_integration( ), ) - return create_integration - @pytest.fixture -def create_millenium_auth_integration( - create_auth_integration_configuration: Callable[..., AuthProviderFixture], - get_auth_protocol: Callable[[Type[AuthenticationProvider]], Optional[str]], -) -> Callable[..., AuthProviderFixture]: - protocol = get_auth_protocol(MilleniumPatronAPI) +def create_simple_auth_integration( + create_auth_integration_configuration: CreateAuthIntegrationFixture, + get_auth_protocol: AuthProtocolFixture, +) -> SimpleAuthIntegrationFixture: + return SimpleAuthIntegrationFixture( + create_auth_integration_configuration, get_auth_protocol + ) + + +class MilleniumAuthIntegrationFixture: + def __init__( + self, + create_auth_integration_configuration: CreateAuthIntegrationFixture, + get_auth_protocol: AuthProtocolFixture, + ): + self.create_auth_integration_configuration = ( + create_auth_integration_configuration + ) + self.get_auth_protocol = get_auth_protocol - def create_integration( - library: Optional[Library] = None, **kwargs + def __call__( + self, library: Optional[Library] = None, **kwargs: str ) -> AuthProviderFixture: if "url" not in kwargs: kwargs["url"] = "http://url.com/" - return create_auth_integration_configuration( - protocol, + return self.create_auth_integration_configuration( + self.get_auth_protocol(MilleniumPatronAPI), library, kwargs, ) - return create_integration - @pytest.fixture -def create_sip2_auth_integration( - create_auth_integration_configuration: Callable[..., AuthProviderFixture], - get_auth_protocol: Callable[[Type[AuthenticationProvider]], Optional[str]], -) -> Callable[..., AuthProviderFixture]: - protocol = get_auth_protocol(SIP2AuthenticationProvider) +def create_millenium_auth_integration( + create_auth_integration_configuration: CreateAuthIntegrationFixture, + get_auth_protocol: AuthProtocolFixture, +) -> MilleniumAuthIntegrationFixture: + return MilleniumAuthIntegrationFixture( + create_auth_integration_configuration, get_auth_protocol + ) + + +class Sip2AuthIntegrationFixture: + def __init__( + self, + create_auth_integration_configuration: CreateAuthIntegrationFixture, + get_auth_protocol: AuthProtocolFixture, + ): + self.create_auth_integration_configuration = ( + create_auth_integration_configuration + ) + self.get_auth_protocol = get_auth_protocol - def create_integration( - library: Optional[Library] = None, **kwargs + def __call__( + self, library: Optional[Library] = None, **kwargs: str ) -> AuthProviderFixture: if "url" not in kwargs: kwargs["url"] = "url.com" - return create_auth_integration_configuration( - protocol, + return self.create_auth_integration_configuration( + self.get_auth_protocol(SIP2AuthenticationProvider), library, kwargs, ) - return create_integration - @pytest.fixture -def create_saml_auth_integration( - create_auth_integration_configuration: Callable[..., AuthProviderFixture], - get_auth_protocol: Callable[[Type[AuthenticationProvider]], Optional[str]], -) -> Callable[..., AuthProviderFixture]: - protocol = get_auth_protocol(SAMLWebSSOAuthenticationProvider) +def create_sip2_auth_integration( + create_auth_integration_configuration: CreateAuthIntegrationFixture, + get_auth_protocol: AuthProtocolFixture, +) -> Sip2AuthIntegrationFixture: + return Sip2AuthIntegrationFixture( + create_auth_integration_configuration, get_auth_protocol + ) + + +class SamlAuthIntegrationFixture: + def __init__( + self, + create_auth_integration_configuration: CreateAuthIntegrationFixture, + get_auth_protocol: AuthProtocolFixture, + ): + self.create_auth_integration_configuration = ( + create_auth_integration_configuration + ) + self.get_auth_protocol = get_auth_protocol - def create_integration( - library: Optional[Library] = None, **kwargs + def __call__( + self, library: Optional[Library] = None, **kwargs: str ) -> AuthProviderFixture: if "service_provider_xml_metadata" not in kwargs: kwargs["service_provider_xml_metadata"] = CORRECT_XML_WITH_ONE_SP - return create_auth_integration_configuration( - protocol, + return self.create_auth_integration_configuration( + self.get_auth_protocol(SAMLWebSSOAuthenticationProvider), library, kwargs, ) - return create_integration + +@pytest.fixture +def create_saml_auth_integration( + create_auth_integration_configuration: CreateAuthIntegrationFixture, + get_auth_protocol: AuthProtocolFixture, +) -> SamlAuthIntegrationFixture: + return SamlAuthIntegrationFixture( + create_auth_integration_configuration, get_auth_protocol + ) diff --git a/tests/fixtures/database.py b/tests/fixtures/database.py index 523ae8eb72..c63cf6f81a 100644 --- a/tests/fixtures/database.py +++ b/tests/fixtures/database.py @@ -8,7 +8,7 @@ import time import uuid from textwrap import dedent -from typing import Callable, Generator, Iterable, List, Optional, Tuple +from typing import Generator, Iterable, List, Optional, Tuple import pytest import sqlalchemy @@ -1025,46 +1025,59 @@ def db( tr.close() -@pytest.fixture -def create_integration_configuration( - db: DatabaseTransactionFixture, -) -> Callable[..., IntegrationConfiguration]: - def create_integration( - protocol: str, goal: Goals, settings: Optional[dict] = None +class IntegrationConfigurationFixture: + def __init__(self, db: DatabaseTransactionFixture): + self.db = db + + def __call__( + self, protocol: Optional[str], goal: Goals, settings_dict: Optional[dict] = None ) -> IntegrationConfiguration: integration, _ = create( - db.session, + self.db.session, IntegrationConfiguration, - name=db.fresh_str(), + name=self.db.fresh_str(), protocol=protocol, goal=goal, - settings_dict=settings or {}, + settings_dict=settings_dict or {}, ) return integration - return create_integration - @pytest.fixture -def create_integration_library_configuration( +def create_integration_configuration( db: DatabaseTransactionFixture, -) -> Callable[..., IntegrationLibraryConfiguration]: - def create_library_integration( +) -> IntegrationConfigurationFixture: + fixture = IntegrationConfigurationFixture(db) + return fixture + + +class IntegrationLibraryConfigurationFixture: + def __init__(self, db: DatabaseTransactionFixture): + self.db = db + + def __call__( + self, library: Library, parent: IntegrationConfiguration, - settings: Optional[dict] = None, + settings_dict: Optional[dict] = None, ) -> IntegrationLibraryConfiguration: - settings = settings or {} + settings_dict = settings_dict or {} integration, _ = create( - db.session, + self.db.session, IntegrationLibraryConfiguration, parent=parent, library=library, - settings_dict=settings, + settings_dict=settings_dict, ) return integration - return create_library_integration + +@pytest.fixture +def create_integration_library_configuration( + db: DatabaseTransactionFixture, +) -> IntegrationLibraryConfigurationFixture: + fixture = IntegrationLibraryConfigurationFixture(db) + return fixture class DBStatementCounter: