diff --git a/Makefile b/Makefile index c09a20e0..0f48c611 100644 --- a/Makefile +++ b/Makefile @@ -41,16 +41,6 @@ TOX_DIR := .tox DEFAULT_TARGET ?= dev .DEFAULT_GOAL := $(DEFAULT_TARGET) - -ifeq ($(DEFAULT_TARGET),dev) - BUILD_TARGET := $(SETUP_DEV_SENTINEL) -else ifeq ($(DEFAULT_TARGET),cicd) - BUILD_TARGET := $(SETUP_CICD_SENTINEL) -else - $(error DEFAULT_TARGET must be one of "dev" or "cicd") -endif - - ACTIVATE_VENV := . $(VENV)/bin/activate REPORT_VENV_USAGE := echo '\nActivate your venv with `. $(VENV)/bin/activate`' @@ -73,6 +63,15 @@ SETUP_DEV_SENTINEL = $(MAKE_ARTIFACT_DIRECTORY)/setup_dev_sentinel SETUP_CICD_SENTINEL = $(MAKE_ARTIFACT_DIRECTORY)/setup_cicd_sentinel PYPROJECT_FILES_SENTINEL = $(MAKE_ARTIFACT_DIRECTORY)/pyproject_sentinel +ifeq ($(DEFAULT_TARGET),dev) + BUILD_TARGET := $(SETUP_DEV_SENTINEL) +else ifeq ($(DEFAULT_TARGET),cicd) + BUILD_TARGET := $(SETUP_CICD_SENTINEL) +else + $(error DEFAULT_TARGET must be one of "dev" or "cicd") +endif + + $(PYPROJECT_FILES_SENTINEL): $(VENV) $(MAKE) $(PYPROJECT_FILES) touch $@ @@ -196,10 +195,10 @@ test-bandit : $(VENV) $(BUILD_TARGET) -r . test-pytest : $(VENV) $(BUILD_TARGET) - -$(ACTIVATE_VENV) && \ + $(ACTIVATE_VENV) && \ PYTEST_TARGET=$(PYTEST_TARGET) tox && PYTEST_EXIT_CODE=0 || PYTEST_EXIT_CODE=$$?; \ coverage html --data-file=$(REPORTS_DIR)/$(PYTEST_REPORT)/.coverage; \ - junit2html $(REPORTS_DIR)/$(PYTEST_REPORT)/pytest.xml $(REPORTS_DIR)/$(PYTEST_REPORT)/pytest.html; \ + junit2html $(REPORTS_DIR)/$(PYTEST_REPORT)/pytest.xml $(REPORTS_DIR)/$(PYTEST_REPORT)/pytest.html && \ exit $$PYTEST_EXIT_CODE .PHONY: test test-pytest test-bandit test-pyright test-ruff test-isort diff --git a/pyproject.toml b/pyproject.toml index caf5c109..d5104d57 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -146,6 +146,7 @@ reportUnnecessaryTypeIgnoreComment = "information" reportUnusedCallResult = "information" reportMissingTypeStubs = "information" reportWildcardImportFromLibrary = "warning" +reportDeprecated = "error" [tool.pytest.ini_options] pythonpath = [ diff --git a/src/database/Ligare/database/config.py b/src/database/Ligare/database/config.py index 7b55aee3..e9855faa 100644 --- a/src/database/Ligare/database/config.py +++ b/src/database/Ligare/database/config.py @@ -3,6 +3,7 @@ from Ligare.programming.config import AbstractConfig from pydantic import BaseModel from pydantic.config import ConfigDict +from typing_extensions import override class DatabaseConnectArgsConfig(BaseModel): @@ -36,6 +37,10 @@ def __init__(self, **data: Any): elif self.connection_string.startswith("postgresql://"): self.connect_args = PostgreSQLDatabaseConnectArgsConfig(**model_data) + @override + def post_load(self) -> None: + return super().post_load() + connection_string: str = "sqlite:///:memory:" sqlalchemy_echo: bool = False # the static field allows Pydantic to store @@ -44,4 +49,8 @@ def __init__(self, **data: Any): class Config(BaseModel, AbstractConfig): + @override + def post_load(self) -> None: + return super().post_load() + database: DatabaseConfig diff --git a/src/database/Ligare/database/dependency_injection.py b/src/database/Ligare/database/dependency_injection.py index 8900a616..56d1ef3f 100644 --- a/src/database/Ligare/database/dependency_injection.py +++ b/src/database/Ligare/database/dependency_injection.py @@ -1,8 +1,10 @@ -from injector import Binder, CallableProvider, Injector, Module, inject, singleton +from injector import Binder, CallableProvider, Injector, inject, singleton from Ligare.database.config import Config, DatabaseConfig from Ligare.database.engine import DatabaseEngine from Ligare.database.types import MetaBase +from Ligare.programming.config import AbstractConfig from Ligare.programming.dependency_injection import ConfigModule +from Ligare.programming.patterns.dependency_injection import ConfigurableModule from sqlalchemy.orm.scoping import ScopedSession from sqlalchemy.orm.session import Session from typing_extensions import override @@ -10,11 +12,16 @@ from .config import DatabaseConfig -class ScopedSessionModule(Module): +class ScopedSessionModule(ConfigurableModule): """ Configure SQLAlchemy Session depedencies for Injector. """ + @override + @staticmethod + def get_config_type() -> type[AbstractConfig]: + return DatabaseConfig + _bases: list[MetaBase | type[MetaBase]] | None = None @override diff --git a/src/database/Ligare/database/migrations/alembic.py b/src/database/Ligare/database/migrations/alembic.py deleted file mode 100644 index ddec9385..00000000 --- a/src/database/Ligare/database/migrations/alembic.py +++ /dev/null @@ -1,209 +0,0 @@ -# TODO integrate Alembic into Ligare -# import logging -# from configparser import ConfigParser -# from dataclasses import dataclass -# from functools import lru_cache -# from logging.config import fileConfig -# from typing import Any, List, Optional, Protocol, cast -# -# from alembic import context -# from psycopg2.errors import UndefinedTable -# from sqlalchemy import engine_from_config, pool -# from sqlalchemy.engine import Connectable, Connection, Engine -# from sqlalchemy.exc import ProgrammingError -# from sqlalchemy.sql.schema import MetaData, Table -# -# from Ligare.database.migrations import DialectHelper, MetaBaseType -# -# -# class type_include_object(Protocol): -# def __call__( -# self, -# object: Table, -# name: str, -# type_: str, -# reflected: Any, -# compare_to: Any, -# ) -> bool: -# ... -# -# -# class type_include_schemas(Protocol): -# def __call__(self, names: list[str]) -> type_include_object: -# ... -# -# -# @dataclass -# class type_metadata: -# include_schemas: type_include_schemas -# target_metadata: List[MetaData] -# schemas: List[str] -# -# -# class AlembicEnvSetup: -# _connection_string: str -# _bases: list[MetaBaseType] -# -# def __init__(self, connection_string: str, bases: list[MetaBaseType]) -> None: -# self._connection_string = connection_string -# self._bases = bases -# -# @lru_cache(maxsize=1) -# def get_config(self): -# # this is the Alembic Config object, which provides -# # access to the values within the .ini file in use. -# config = context.config -# -# # Interpret the config file for Python logging. -# # This line sets up loggers basically. -# if config.config_file_name is not None: -# # raise Exception("Config file is missing.") -# fileConfig(config.config_file_name) -# -# config.set_main_option("sqlalchemy.url", self._connection_string) -# -# return config -# -# @lru_cache(maxsize=1) -# def get_metadata(self): -# # add your model's MetaData object here -# # for 'autogenerate' support -# # from myapp import mymodel -# # target_metadata = mymodel.Base.metadata -# # from CAP.database.models.CAP import Base -# # from CAP.database.models.identity import IdentityBase -# # from CAP.database.models.platform import PlatformBase -# -# def include_schemas(names: List[str]): -# def include_object( -# object: Table, -# name: str, -# type_: str, -# reflected: Any, -# compare_to: Any, -# ): -# if type_ == "table": -# return object.schema in names -# return True -# -# return include_object -# -# target_metadata = [base.metadata for base in self._bases] -# schemas: list[str] = [] -# for base in self._bases: -# schema = DialectHelper.get_schema(base) -# if schema is not None: -# schemas.append(schema) -# -# return type_metadata(include_schemas, target_metadata, schemas) -# -# def _configure_context(self, connection: Connection | Connectable | Engine): -# metadata = self.get_metadata() -# target_metadata = metadata.target_metadata -# include_schemas = metadata.include_schemas -# schemas = metadata.schemas -# -# if connection.engine is not None and connection.engine.name == "sqlite": -# context.configure( -# connection=cast(Connection, connection), -# target_metadata=target_metadata, -# compare_type=True, -# include_schemas=True, -# include_object=include_schemas(schemas), -# render_as_batch=True, -# ) -# else: -# context.configure( -# connection=cast(Connection, connection), -# target_metadata=target_metadata, -# compare_type=True, -# include_schemas=True, -# include_object=include_schemas(schemas), -# ) -# -# def _run_migrations(self, connection: Connection | Connectable | Engine): -# if connection.engine is None: -# raise Exception( -# "SQLAlchemy Session is not bound to an engine. This is not supported." -# ) -# -# metadata = self.get_metadata() -# schemas = metadata.schemas -# with context.begin_transaction(): -# try: -# if connection.engine.name == "postgresql": -# _ = connection.execute( -# f"SET search_path TO {','.join(schemas)},public;" -# ) -# context.run_migrations() -# except ProgrammingError as error: -# # This occurs when downgrading from the very last version -# # because the `alembic_version` table is dropped. The exception -# # can be safely ignored because the migration commits the transaction -# # before the failure, and there is nothing left for Alembic to do. -# if not ( -# type(error.orig) is UndefinedTable -# and "DELETE FROM alembic_version" in error.statement -# ): -# raise -# -# def run_migrations_offline(self, connection_string: str): -# """Run migrations in 'offline' mode. -# -# This configures the context with just a URL -# and not an Engine, though an Engine is acceptable -# here as well. By skipping the Engine creation -# we don't even need a DBAPI to be available. -# -# Calls to context.execute() here emit the given string to the -# script output. -# -# """ -# -# config = self.get_config() -# metadata = self.get_metadata() -# target_metadata = metadata.target_metadata -# include_schemas = metadata.include_schemas -# schemas = metadata.schemas -# -# url = config.get_main_option("sqlalchemy.url") -# context.configure( -# url=url, -# target_metadata=target_metadata, -# literal_binds=True, -# dialect_opts={"paramstyle": "named"}, -# compare_type=True, -# include_schemas=True, -# include_object=include_schemas(schemas), -# ) -# -# with context.begin_transaction(): -# context.run_migrations() -# -# def run_migrations_online(self, connection_string: str): -# """Run migrations in 'online' mode. -# -# In this scenario we need to create an Engine -# and associate a connection with the context. -# -# """ -# config = self.get_config() -# -# connectable: Connectable = cast(dict[Any, Any], config.attributes).get( -# "connection", None -# ) -# -# if connectable: -# self._configure_context(connectable) -# self._run_migrations(connectable) -# else: -# connectable = engine_from_config( -# config.get_section(config.config_ini_section), -# prefix="sqlalchemy.", -# poolclass=pool.NullPool, -# ) -# -# with connectable.connect() as connection: -# self._configure_context(connection) -# self._run_migrations(connection) -# diff --git a/src/database/Ligare/database/migrations/alembic/env_setup.py b/src/database/Ligare/database/migrations/alembic/env_setup.py index 1f642bf5..c55272de 100644 --- a/src/database/Ligare/database/migrations/alembic/env_setup.py +++ b/src/database/Ligare/database/migrations/alembic/env_setup.py @@ -172,7 +172,7 @@ def _run_migrations( try: if connection.engine.name == "postgresql": _ = connection.execute( - f"SET search_path TO {','.join(schemas)},public;" + f"SET search_path TO {','.join(schemas + ['public'])};" ) context.run_migrations() except ProgrammingError as error: diff --git a/src/database/Ligare/database/schema/__init__.py b/src/database/Ligare/database/schema/__init__.py index be2eed47..9c2c6b1c 100644 --- a/src/database/Ligare/database/schema/__init__.py +++ b/src/database/Ligare/database/schema/__init__.py @@ -7,7 +7,7 @@ _dialect_type_map = {"sqlite": SQLiteDialect, "postgresql": PostgreSQLDialect} -def get_type_from_dialect(dialect: Dialect): +def get_type_from_dialect(dialect: Dialect) -> PostgreSQLDialect | SQLiteDialect: if not _dialect_type_map.get(dialect.name): raise ValueError( f"Unexpected dialect with name `{dialect.name}`. Expected one of {list(_dialect_type_map.keys())}." @@ -16,6 +16,6 @@ def get_type_from_dialect(dialect: Dialect): return _dialect_type_map[dialect.name](dialect) -def get_type_from_op(op: Operations): +def get_type_from_op(op: Operations) -> PostgreSQLDialect | SQLiteDialect: dialect: Dialect = op.get_bind().dialect return get_type_from_dialect(dialect) diff --git a/src/database/Ligare/database/schema/dialect.py b/src/database/Ligare/database/schema/dialect.py index 4b64fe56..f6db4cef 100644 --- a/src/database/Ligare/database/schema/dialect.py +++ b/src/database/Ligare/database/schema/dialect.py @@ -8,7 +8,7 @@ class DialectBase(ABC): supports_schemas: bool = False @staticmethod - def get_schema(meta: MetaBase): + def get_schema(meta: MetaBase) -> str | None: table_args = hasattr(meta, "__table_args__") and meta.__table_args__ or None if isinstance(table_args, dict): @@ -21,7 +21,7 @@ def iterate_table_names( dialect: "DialectBase", schema_tables: dict[MetaBase, list[str]], table_name_callback: TableNameCallback, - ): + ) -> None: """ Call `table_name_callback` once for every table in every Base. @@ -40,13 +40,13 @@ def iterate_table_names( dialect_schema, full_table_name, base_table, meta_base ) - def get_dialect_schema(self, meta: MetaBase): + def get_dialect_schema(self, meta: MetaBase) -> str | None: if self.supports_schemas: return DialectBase.get_schema(meta) return None - def get_full_table_name(self, table_name: str, meta: MetaBase): + def get_full_table_name(self, table_name: str, meta: MetaBase) -> str: """ If the dialect supports schemas, then the table name does not have the schema prepended. In dialects that don't support schemas, e.g., SQLite, the table name has the schema prepended. diff --git a/src/database/test/unit/migrations/alembic/test_ligare_alembic.py b/src/database/test/unit/migrations/alembic/test_ligare_alembic.py index 00f6b2e4..b9e0bc6e 100644 --- a/src/database/test/unit/migrations/alembic/test_ligare_alembic.py +++ b/src/database/test/unit/migrations/alembic/test_ligare_alembic.py @@ -76,6 +76,7 @@ def test__LigareAlembic__passes_through_to_alembic_with_default_config_when_not_ ) ligare_alembic = LigareAlembic(None, MagicMock()) + ligare_alembic._write_ligare_alembic_config = MagicMock() # pyright: ignore[reportPrivateUsage] ligare_alembic.run() assert alembic_main.called diff --git a/src/identity/Ligare/identity/config.py b/src/identity/Ligare/identity/config.py index a4b94f39..b5ef49d1 100644 --- a/src/identity/Ligare/identity/config.py +++ b/src/identity/Ligare/identity/config.py @@ -3,6 +3,7 @@ from Ligare.programming.config import AbstractConfig from pydantic import BaseModel from pydantic.config import ConfigDict +from typing_extensions import override class SSOSettingsConfig(BaseModel): @@ -33,6 +34,10 @@ def __init__(self, **data: Any): if self.protocol == "SAML2": self.settings = SAML2Config(**model_data) + @override + def post_load(self) -> None: + return super().post_load() + protocol: str = "SAML2" # the static field allows Pydantic to store # values from a dictionary @@ -40,4 +45,8 @@ def __init__(self, **data: Any): class Config(BaseModel, AbstractConfig): + @override + def post_load(self) -> None: + return super().post_load() + sso: SSOConfig diff --git a/src/identity/Ligare/identity/dependency_injection.py b/src/identity/Ligare/identity/dependency_injection.py index 06e72cfe..fff58eec 100644 --- a/src/identity/Ligare/identity/dependency_injection.py +++ b/src/identity/Ligare/identity/dependency_injection.py @@ -10,11 +10,7 @@ class SSOModule(Module): - def __init__(self): # , metadata: str, settings: AnyDict) -> None: - """ - metadata can be XML or a URL - """ - super().__init__() + pass class SAML2Module(SSOModule): diff --git a/src/platform/Ligare/platform/feature_flag/__init__.py b/src/platform/Ligare/platform/feature_flag/__init__.py index 082c8fe8..3e8ec3d2 100644 --- a/src/platform/Ligare/platform/feature_flag/__init__.py +++ b/src/platform/Ligare/platform/feature_flag/__init__.py @@ -1,5 +1,15 @@ from .caching_feature_flag_router import CachingFeatureFlagRouter +from .caching_feature_flag_router import FeatureFlag as CacheFeatureFlag from .db_feature_flag_router import DBFeatureFlagRouter -from .feature_flag_router import FeatureFlagRouter +from .db_feature_flag_router import FeatureFlag as DBFeatureFlag +from .feature_flag_router import FeatureFlag, FeatureFlagChange, FeatureFlagRouter -__all__ = ("FeatureFlagRouter", "CachingFeatureFlagRouter", "DBFeatureFlagRouter") +__all__ = ( + "FeatureFlagRouter", + "CachingFeatureFlagRouter", + "DBFeatureFlagRouter", + "FeatureFlag", + "CacheFeatureFlag", + "DBFeatureFlag", + "FeatureFlagChange", +) diff --git a/src/platform/Ligare/platform/feature_flag/caching_feature_flag_router.py b/src/platform/Ligare/platform/feature_flag/caching_feature_flag_router.py index bf86cb84..c6be383c 100644 --- a/src/platform/Ligare/platform/feature_flag/caching_feature_flag_router.py +++ b/src/platform/Ligare/platform/feature_flag/caching_feature_flag_router.py @@ -1,11 +1,19 @@ from logging import Logger +from typing import Generic, Sequence, cast +from injector import inject from typing_extensions import override -from .feature_flag_router import FeatureFlagRouter +from .feature_flag_router import FeatureFlag as FeatureFlagBaseData +from .feature_flag_router import FeatureFlagChange, FeatureFlagRouter, TFeatureFlag -class CachingFeatureFlagRouter(FeatureFlagRouter): +class FeatureFlag(FeatureFlagBaseData): + pass + + +class CachingFeatureFlagRouter(Generic[TFeatureFlag], FeatureFlagRouter[TFeatureFlag]): + @inject def __init__(self, logger: Logger) -> None: self._logger: Logger = logger self._feature_flags: dict[str, bool] = {} @@ -35,7 +43,7 @@ def _validate_name(self, name: str): raise ValueError("`name` parameter is required and cannot be empty.") @override - def set_feature_is_enabled(self, name: str, is_enabled: bool) -> None: + def set_feature_is_enabled(self, name: str, is_enabled: bool) -> FeatureFlagChange: """ Enables or disables a feature flag in the in-memory dictionary of feature flags. @@ -43,17 +51,23 @@ def set_feature_is_enabled(self, name: str, is_enabled: bool) -> None: :param str name: The feature flag to check. :param bool is_enabled: Whether the feature flag is to be enabled or disabled. + :return FeatureFlagChange: An object representing the previous and new values of the changed feature flag. """ self._validate_name(name) if type(is_enabled) != bool: raise TypeError("`is_enabled` must be a boolean.") - self._notify_change(name, is_enabled, self._feature_flags.get(name)) + old_enabled_value = self._feature_flags.get(name) + self._notify_change(name, is_enabled, old_enabled_value) self._feature_flags[name] = is_enabled - return super().set_feature_is_enabled(name, is_enabled) + _ = super().set_feature_is_enabled(name, is_enabled) + + return FeatureFlagChange( + name=name, old_value=old_enabled_value, new_value=is_enabled + ) @override def feature_is_enabled(self, name: str, default: bool = False) -> bool: @@ -79,3 +93,32 @@ def feature_is_cached(self, name: str): self._validate_name(name) return name in self._feature_flags + + @override + def _create_feature_flag(self, name: str, enabled: bool) -> TFeatureFlag: + return cast(TFeatureFlag, FeatureFlag(name, enabled)) + + @override + def get_feature_flags( + self, names: list[str] | None = None + ) -> Sequence[TFeatureFlag]: + """ + Get all feature flags and their status. + + :params list[str] | None names: Get only the flags contained in this list. + :return tuple[TFeatureFlag]: An immutable sequence (a tuple) of feature flags. + If `names` is `None` this sequence contains _all_ feature flags in the cache. Otherwise, the list is filtered. + """ + if names is None: + return tuple( + self._create_feature_flag(name=key, enabled=value) + for key, value in self._feature_flags.items() + ) + else: + return tuple( + ( + self._create_feature_flag(name=key, enabled=value) + for key, value in self._feature_flags.items() + if key in names + ) + ) diff --git a/src/platform/Ligare/platform/feature_flag/database/__init__.py b/src/platform/Ligare/platform/feature_flag/database/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/platform/Ligare/platform/feature_flag/database/migrate.py b/src/platform/Ligare/platform/feature_flag/database/migrate.py new file mode 100644 index 00000000..ad24a2d6 --- /dev/null +++ b/src/platform/Ligare/platform/feature_flag/database/migrate.py @@ -0,0 +1,79 @@ +"""Add feature flags""" + +import sqlalchemy as sa +from alembic.operations.base import Operations +from Ligare.database.schema import get_type_from_op +from sqlalchemy import false + +from ..db_feature_flag_router import FeatureFlagTable + + +# fmt: off +def upgrade(op: Operations): + dialect = get_type_from_op(op) + dialect_supports_schemas = dialect.supports_schemas + get_full_table_name = dialect.get_full_table_name + get_dialect_schema = dialect.get_dialect_schema + timestamp_sql = dialect.timestamp_sql + + base_schema_name = get_dialect_schema(FeatureFlagTable) # pyright: ignore[reportArgumentType] + if dialect_supports_schemas: + if base_schema_name: + op.execute(f'CREATE SCHEMA IF NOT EXISTS {base_schema_name}') + + full_table_name = get_full_table_name('feature_flag', FeatureFlagTable) # pyright: ignore[reportArgumentType] + _ = op.create_table(full_table_name, + sa.Column('ctime', sa.DateTime(), server_default=sa.text(timestamp_sql), nullable=False), + sa.Column('mtime', sa.DateTime(), server_default=sa.text(timestamp_sql), nullable=False), + sa.Column('name', sa.Unicode(), nullable=False), + sa.Column('enabled', sa.Boolean(), nullable=True, server_default=false()), + sa.Column('description', sa.Unicode(), nullable=False), + sa.PrimaryKeyConstraint('name'), + schema=base_schema_name + ) + + if dialect.DIALECT_NAME == 'postgresql': + op.execute(""" +CREATE OR REPLACE FUNCTION func_update_mtime() +RETURNS TRIGGER LANGUAGE 'plpgsql' AS +' +BEGIN + NEW.mtime = now(); + RETURN NEW; +END; +';""") + + op.execute(f""" +CREATE TRIGGER trigger_update_mtime +BEFORE UPDATE ON {base_schema_name}.{full_table_name} +FOR EACH ROW EXECUTE PROCEDURE func_update_mtime();""") + + else: + op.execute(f""" +CREATE TRIGGER IF NOT EXISTS '{full_table_name}.trigger_update_mtime' +BEFORE UPDATE +ON '{full_table_name}' +FOR EACH ROW +BEGIN + UPDATE '{full_table_name}' + SET mtime=CURRENT_TIMESTAMP + WHERE name = NEW.name; +END;""") + + +def downgrade(op: Operations): + dialect = get_type_from_op(op) + get_full_table_name = dialect.get_full_table_name + get_dialect_schema = dialect.get_dialect_schema + + base_schema_name = get_dialect_schema(FeatureFlagTable) # pyright: ignore[reportArgumentType] + full_table_name = get_full_table_name('feature_flag', FeatureFlagTable) # pyright: ignore[reportArgumentType] + + if dialect.DIALECT_NAME == 'postgresql': + op.execute(f'DROP SCHEMA {base_schema_name} CASCADE;') + op.execute("DROP FUNCTION func_update_mtime;") + op.execute('COMMIT;') + + else: + op.execute(f"""DROP TRIGGER '{full_table_name}.trigger_update_mtime';""") + op.drop_table(full_table_name, schema=base_schema_name) diff --git a/src/platform/Ligare/platform/feature_flag/db_feature_flag_router.py b/src/platform/Ligare/platform/feature_flag/db_feature_flag_router.py index a3aa493d..0137bff8 100644 --- a/src/platform/Ligare/platform/feature_flag/db_feature_flag_router.py +++ b/src/platform/Ligare/platform/feature_flag/db_feature_flag_router.py @@ -1,18 +1,29 @@ from abc import ABC +from dataclasses import dataclass from logging import Logger -from typing import Type, cast, overload +from typing import Sequence, Type, TypeVar, cast, overload from injector import inject -from sqlalchemy import Boolean, Column, Unicode +from sqlalchemy import Boolean, Column, String, Unicode from sqlalchemy.exc import NoResultFound from sqlalchemy.ext.declarative import DeclarativeMeta from sqlalchemy.orm.session import Session from typing_extensions import override from .caching_feature_flag_router import CachingFeatureFlagRouter +from .feature_flag_router import FeatureFlag as FeatureFlagBaseData +from .feature_flag_router import FeatureFlagChange -class FeatureFlag(ABC): +@dataclass(frozen=True) +class FeatureFlag(FeatureFlagBaseData): + description: str | None + + +TFeatureFlag = TypeVar("TFeatureFlag", bound=FeatureFlag, covariant=True) + + +class FeatureFlagTableBase(ABC): def __init__( # pyright: ignore[reportMissingSuperCall] self, /, @@ -21,7 +32,7 @@ def __init__( # pyright: ignore[reportMissingSuperCall] enabled: bool | None = False, ) -> None: raise NotImplementedError( - f"`{FeatureFlag.__class__.__name__}` should only be used for type checking." + f"`{FeatureFlagTableBase.__class__.__name__}` should only be used for type checking." ) __tablename__: str @@ -31,7 +42,9 @@ def __init__( # pyright: ignore[reportMissingSuperCall] class FeatureFlagTable: - def __new__(cls, base: Type[DeclarativeMeta]) -> type[FeatureFlag]: + __table_args__ = {"schema": "platform"} + + def __new__(cls, base: Type[DeclarativeMeta]) -> type[FeatureFlagTableBase]: class _FeatureFlag(base): """ A feature flag. @@ -53,32 +66,34 @@ class _FeatureFlag(base): def __repr__(self) -> str: return "" % (self.name) - return cast(type[FeatureFlag], _FeatureFlag) + return cast(type[FeatureFlagTableBase], _FeatureFlag) -class DBFeatureFlagRouter(CachingFeatureFlagRouter): - _feature_flag: type[FeatureFlag] +class DBFeatureFlagRouter(CachingFeatureFlagRouter[TFeatureFlag]): + # The SQLAlchemy table type used for querying from the type[FeatureFlag] database table + _feature_flag: type[FeatureFlagTableBase] + # The SQLAlchemy session used for connecting to and querying the database _session: Session @inject def __init__( - self, feature_flag: type[FeatureFlag], session: Session, logger: Logger + self, feature_flag: type[FeatureFlagTableBase], session: Session, logger: Logger ) -> None: self._feature_flag = feature_flag self._session = session super().__init__(logger) @override - def set_feature_is_enabled(self, name: str, is_enabled: bool): + def set_feature_is_enabled(self, name: str, is_enabled: bool) -> FeatureFlagChange: """ Enable or disable a feature flag in the database. This method caches the value of `is_enabled` for the specified feature flag unless saving to the database fails. - name: The feature flag to check. - - is_enabled: Whether the feature flag is to be enabled or disabled. + :param str name: The feature flag to check. + :param bool is_enabled: Whether the feature flag is to be enabled or disabled. + :return FeatureFlagChange: An object representing the previous and new values of the changed feature flag. """ if type(name) != str: @@ -87,11 +102,11 @@ def set_feature_is_enabled(self, name: str, is_enabled: bool): if not name: raise ValueError("`name` parameter is required and cannot be empty.") - feature_flag: FeatureFlag + feature_flag: FeatureFlagTableBase try: feature_flag = ( self._session.query(self._feature_flag) - .filter(cast(Column[Unicode], self._feature_flag.name) == name) + .filter(self._feature_flag.name == name) .one() ) except NoResultFound as e: @@ -99,9 +114,14 @@ def set_feature_is_enabled(self, name: str, is_enabled: bool): f"The feature flag `{name}` does not exist. It must be created before being accessed." ) from e + old_enabled_value = cast(bool | None, feature_flag.enabled) feature_flag.enabled = is_enabled self._session.commit() - super().set_feature_is_enabled(name, is_enabled) + _ = super().set_feature_is_enabled(name, is_enabled) + + return FeatureFlagChange( + name=name, old_value=old_enabled_value, new_value=is_enabled + ) @overload def feature_is_enabled(self, name: str, default: bool = False) -> bool: ... @@ -131,7 +151,7 @@ def feature_is_enabled( feature_flag = ( self._session.query(self._feature_flag) - .filter(cast(Column[Unicode], self._feature_flag.name) == name) + .filter(self._feature_flag.name == name) .one_or_none() ) @@ -143,6 +163,55 @@ def feature_is_enabled( is_enabled = cast(bool, feature_flag.enabled) - super().set_feature_is_enabled(name, is_enabled) + _ = super().set_feature_is_enabled(name, is_enabled) return is_enabled + + @override + def _create_feature_flag( + self, name: str, enabled: bool, description: str | None = None + ) -> TFeatureFlag: + parent_feature_flag = super()._create_feature_flag(name, enabled) + return cast( + TFeatureFlag, + FeatureFlag( + parent_feature_flag.name, parent_feature_flag.enabled, description + ), + ) + + @override + def get_feature_flags( + self, names: list[str] | None = None + ) -> Sequence[TFeatureFlag]: + """ + Get all feature flags and their status from the database. + This methods updates the cache to the values retrieved from the database. + + :param list[str] | None names: Get only the flags contained in this list. + :return tuple[TFeatureFlag]: An immutable sequence (a tuple) of feature flags. + If `names` is `None` this sequence contains _all_ feature flags in the database. Otherwise, the list is filtered. + """ + db_feature_flags: list[FeatureFlagTableBase] + if names is None: + db_feature_flags = self._session.query(self._feature_flag).all() + else: + db_feature_flags = ( + self._session.query(self._feature_flag) + .filter(cast(Column[String], self._feature_flag.name).in_(names)) + .all() + ) + + feature_flags = tuple( + self._create_feature_flag( + name=cast(str, feature_flag.name), + enabled=cast(bool, feature_flag.enabled), + description=cast(str, feature_flag.description), + ) + for feature_flag in db_feature_flags + ) + + # cache the feature flags + for feature_flag in feature_flags: + _ = super().set_feature_is_enabled(feature_flag.name, feature_flag.enabled) + + return feature_flags diff --git a/src/platform/Ligare/platform/feature_flag/feature_flag_router.py b/src/platform/Ligare/platform/feature_flag/feature_flag_router.py index d66ac562..71f368b2 100644 --- a/src/platform/Ligare/platform/feature_flag/feature_flag_router.py +++ b/src/platform/Ligare/platform/feature_flag/feature_flag_router.py @@ -1,7 +1,25 @@ from abc import ABC, abstractmethod +from dataclasses import dataclass +from typing import Generic, Sequence, TypeVar -class FeatureFlagRouter(ABC): +@dataclass(frozen=True) +class FeatureFlag: + name: str + enabled: bool + + +@dataclass(frozen=True) +class FeatureFlagChange: + name: str + old_value: bool | None + new_value: bool | None + + +TFeatureFlag = TypeVar("TFeatureFlag", bound=FeatureFlag, covariant=True) + + +class FeatureFlagRouter(Generic[TFeatureFlag], ABC): """ The base feature flag router. All feature flag routers should extend this class. @@ -21,12 +39,13 @@ def _notify_change( """ @abstractmethod - def set_feature_is_enabled(self, name: str, is_enabled: bool) -> None: + def set_feature_is_enabled(self, name: str, is_enabled: bool) -> FeatureFlagChange: """ Enable or disable a feature flag. :param str name: The name of the feature flag. :param bool is_enabled: If `True`, the feature is enabled. If `False`, the feature is disabled. + :return FeatureFlagChange: An object representing the previous and new values of the changed feature flag. """ @abstractmethod @@ -38,3 +57,28 @@ def feature_is_enabled(self, name: str, default: bool = False) -> bool: :param bool default: A default value to return for cases where a feature flag may not exist. Defaults to False. :return bool: If `True`, the feature is enabled. If `False`, the feature is disabled. """ + + @abstractmethod + def _create_feature_flag(self, name: str, enabled: bool) -> FeatureFlag: + """ + Subclasses should override this in order to instantiate type-safe + instances of `TFeatureFlag` to any other `FeatureFlag` subclasses + in the type hierarchy. + + :param str name: _description_ + :param bool enabled: _description_ + :return TFeatureFlag: An instance of `TFeatureFlag` + """ + + @abstractmethod + def get_feature_flags( + self, names: list[str] | None = None + ) -> Sequence[TFeatureFlag]: + """ + Get all feature flags and whether they are enabled. + If `names` is not `None`, this only returns the enabled state of the flags in the list. + + :param list[str] | None names: Get only the flags contained in this list. + :return tuple[TFeatureFlag]: An immutable sequence (a tuple) of feature flags. + If `names` is `None` this sequence contains _all_ feature flags. Otherwise, the list is filtered. + """ diff --git a/src/platform/test/unit/feature_flags/test_caching_feature_flag_router.py b/src/platform/test/unit/feature_flags/test_caching_feature_flag_router.py index 14e708f7..952c6dab 100644 --- a/src/platform/test/unit/feature_flags/test_caching_feature_flag_router.py +++ b/src/platform/test/unit/feature_flags/test_caching_feature_flag_router.py @@ -5,15 +5,17 @@ from Ligare.platform.feature_flag.caching_feature_flag_router import ( CachingFeatureFlagRouter, ) +from Ligare.platform.feature_flag.feature_flag_router import FeatureFlag from mock import MagicMock from pytest import LogCaptureFixture _FEATURE_FLAG_TEST_NAME = "foo_feature" +_FEATURE_FLAG_LOGGER_NAME = "FeatureFlagLogger" def test__feature_is_enabled__disallows_empty_name(): - logger = logging.getLogger("FeatureFlagLogger") - caching_feature_flag_router = CachingFeatureFlagRouter(logger) + logger = logging.getLogger(_FEATURE_FLAG_LOGGER_NAME) + caching_feature_flag_router = CachingFeatureFlagRouter[FeatureFlag](logger) with pytest.raises(ValueError): _ = caching_feature_flag_router.feature_is_enabled("") @@ -21,34 +23,34 @@ def test__feature_is_enabled__disallows_empty_name(): @pytest.mark.parametrize("name", [0, False, True, {}, [], (0,)]) def test__feature_is_enabled__disallows_non_string_names(name: Any): - logger = logging.getLogger("FeatureFlagLogger") - caching_feature_flag_router = CachingFeatureFlagRouter(logger) + logger = logging.getLogger(_FEATURE_FLAG_LOGGER_NAME) + caching_feature_flag_router = CachingFeatureFlagRouter[FeatureFlag](logger) with pytest.raises(TypeError): _ = caching_feature_flag_router.feature_is_enabled(name) def test__set_feature_is_enabled__disallows_empty_name(): - logger = logging.getLogger("FeatureFlagLogger") - caching_feature_flag_router = CachingFeatureFlagRouter(logger) + logger = logging.getLogger(_FEATURE_FLAG_LOGGER_NAME) + caching_feature_flag_router = CachingFeatureFlagRouter[FeatureFlag](logger) with pytest.raises(ValueError): - caching_feature_flag_router.set_feature_is_enabled("", False) + _ = caching_feature_flag_router.set_feature_is_enabled("", False) @pytest.mark.parametrize("name", [0, False, True, {}, [], (0,)]) def test__set_feature_is_enabled__disallows_non_string_names(name: Any): - logger = logging.getLogger("FeatureFlagLogger") - caching_feature_flag_router = CachingFeatureFlagRouter(logger) + logger = logging.getLogger(_FEATURE_FLAG_LOGGER_NAME) + caching_feature_flag_router = CachingFeatureFlagRouter[FeatureFlag](logger) with pytest.raises(TypeError): - caching_feature_flag_router.set_feature_is_enabled(name, False) + _ = caching_feature_flag_router.set_feature_is_enabled(name, False) @pytest.mark.parametrize("value", [None, "", "False", "True", 0, 1, -1, {}, [], (0,)]) def test__set_feature_is_enabled__disallows_non_bool_values(value: Any): - logger = logging.getLogger("FeatureFlagLogger") - caching_feature_flag_router = CachingFeatureFlagRouter(logger) + logger = logging.getLogger(_FEATURE_FLAG_LOGGER_NAME) + caching_feature_flag_router = CachingFeatureFlagRouter[FeatureFlag](logger) with pytest.raises(TypeError) as e: _ = caching_feature_flag_router.set_feature_is_enabled( @@ -60,18 +62,51 @@ def test__set_feature_is_enabled__disallows_non_bool_values(value: Any): @pytest.mark.parametrize("value", [True, False]) def test__set_feature_is_enabled__sets_correct_value(value: bool): - logger = logging.getLogger("FeatureFlagLogger") - caching_feature_flag_router = CachingFeatureFlagRouter(logger) + logger = logging.getLogger(_FEATURE_FLAG_LOGGER_NAME) + caching_feature_flag_router = CachingFeatureFlagRouter[FeatureFlag](logger) - caching_feature_flag_router.set_feature_is_enabled(_FEATURE_FLAG_TEST_NAME, value) + _ = caching_feature_flag_router.set_feature_is_enabled( + _FEATURE_FLAG_TEST_NAME, value + ) is_enabled = caching_feature_flag_router.feature_is_enabled(_FEATURE_FLAG_TEST_NAME) assert is_enabled == value +@pytest.mark.parametrize("value", [True, False]) +def test__set_feature_is_enabled__returns_correct_initial_state(value: bool): + logger = logging.getLogger(_FEATURE_FLAG_LOGGER_NAME) + caching_feature_flag_router = CachingFeatureFlagRouter[FeatureFlag](logger) + + initial_change = caching_feature_flag_router.set_feature_is_enabled( + _FEATURE_FLAG_TEST_NAME, value + ) + + assert initial_change.name == _FEATURE_FLAG_TEST_NAME + assert initial_change.old_value is None + assert initial_change.new_value == value + + +@pytest.mark.parametrize("value", [True, False]) +def test__set_feature_is_enabled__returns_correct_new_and_old_state(value: bool): + logger = logging.getLogger(_FEATURE_FLAG_LOGGER_NAME) + caching_feature_flag_router = CachingFeatureFlagRouter[FeatureFlag](logger) + + initial_change = caching_feature_flag_router.set_feature_is_enabled( # pyright: ignore[reportUnusedVariable] + _FEATURE_FLAG_TEST_NAME, value + ) + change = caching_feature_flag_router.set_feature_is_enabled( + _FEATURE_FLAG_TEST_NAME, not value + ) + + assert change.name == _FEATURE_FLAG_TEST_NAME + assert change.old_value == value + assert change.new_value == (not value) + + def test__feature_is_enabled__defaults_to_false_when_flag_does_not_exist(): - logger = logging.getLogger("FeatureFlagLogger") - caching_feature_flag_router = CachingFeatureFlagRouter(logger) + logger = logging.getLogger(_FEATURE_FLAG_LOGGER_NAME) + caching_feature_flag_router = CachingFeatureFlagRouter[FeatureFlag](logger) is_enabled = caching_feature_flag_router.feature_is_enabled(_FEATURE_FLAG_TEST_NAME) @@ -80,10 +115,12 @@ def test__feature_is_enabled__defaults_to_false_when_flag_does_not_exist(): @pytest.mark.parametrize("value", [True, False]) def test__set_feature_is_enabled__caches_new_flag(value: bool): - logger = logging.getLogger("FeatureFlagLogger") - caching_feature_flag_router = CachingFeatureFlagRouter(logger) + logger = logging.getLogger(_FEATURE_FLAG_LOGGER_NAME) + caching_feature_flag_router = CachingFeatureFlagRouter[FeatureFlag](logger) - caching_feature_flag_router.set_feature_is_enabled(_FEATURE_FLAG_TEST_NAME, value) + _ = caching_feature_flag_router.set_feature_is_enabled( + _FEATURE_FLAG_TEST_NAME, value + ) assert _FEATURE_FLAG_TEST_NAME in caching_feature_flag_router._feature_flags # pyright: ignore[reportPrivateUsage] is_enabled = caching_feature_flag_router._feature_flags.get(_FEATURE_FLAG_TEST_NAME) # pyright: ignore[reportPrivateUsage] assert is_enabled == value @@ -91,12 +128,14 @@ def test__set_feature_is_enabled__caches_new_flag(value: bool): @pytest.mark.parametrize("value", [True, False]) def test__feature_is_enabled__uses_cache(value: bool): - logger = logging.getLogger("FeatureFlagLogger") - caching_feature_flag_router = CachingFeatureFlagRouter(logger) + logger = logging.getLogger(_FEATURE_FLAG_LOGGER_NAME) + caching_feature_flag_router = CachingFeatureFlagRouter[FeatureFlag](logger) mock_dict = MagicMock() caching_feature_flag_router._feature_flags = mock_dict # pyright: ignore[reportPrivateUsage] - caching_feature_flag_router.set_feature_is_enabled(_FEATURE_FLAG_TEST_NAME, value) + _ = caching_feature_flag_router.set_feature_is_enabled( + _FEATURE_FLAG_TEST_NAME, value + ) _ = caching_feature_flag_router.feature_is_enabled(_FEATURE_FLAG_TEST_NAME) _ = caching_feature_flag_router.feature_is_enabled(_FEATURE_FLAG_TEST_NAME) @@ -106,16 +145,18 @@ def test__feature_is_enabled__uses_cache(value: bool): @pytest.mark.parametrize("enable", [True, False]) def test__set_feature_is_enabled__resets_cache_when_flag_enable_is_set(enable: bool): - logger = logging.getLogger("FeatureFlagLogger") - caching_feature_flag_router = CachingFeatureFlagRouter(logger) + logger = logging.getLogger(_FEATURE_FLAG_LOGGER_NAME) + caching_feature_flag_router = CachingFeatureFlagRouter[FeatureFlag](logger) - caching_feature_flag_router.set_feature_is_enabled(_FEATURE_FLAG_TEST_NAME, enable) + _ = caching_feature_flag_router.set_feature_is_enabled( + _FEATURE_FLAG_TEST_NAME, enable + ) _ = caching_feature_flag_router.feature_is_enabled(_FEATURE_FLAG_TEST_NAME) first_value = caching_feature_flag_router.feature_is_enabled( _FEATURE_FLAG_TEST_NAME ) - caching_feature_flag_router.set_feature_is_enabled( + _ = caching_feature_flag_router.set_feature_is_enabled( _FEATURE_FLAG_TEST_NAME, not enable ) _ = caching_feature_flag_router.feature_is_enabled(_FEATURE_FLAG_TEST_NAME) @@ -132,10 +173,12 @@ def test__set_feature_is_enabled__notifies_when_setting_new_flag( value: bool, caplog: LogCaptureFixture, ): - logger = logging.getLogger("FeatureFlagLogger") - caching_feature_flag_router = CachingFeatureFlagRouter(logger) + logger = logging.getLogger(_FEATURE_FLAG_LOGGER_NAME) + caching_feature_flag_router = CachingFeatureFlagRouter[FeatureFlag](logger) - caching_feature_flag_router.set_feature_is_enabled(_FEATURE_FLAG_TEST_NAME, value) + _ = caching_feature_flag_router.set_feature_is_enabled( + _FEATURE_FLAG_TEST_NAME, value + ) assert f"Setting new feature flag '{_FEATURE_FLAG_TEST_NAME}' to `{value}`." in { record.msg for record in caplog.records @@ -173,14 +216,101 @@ def test__set_feature_is_enabled__notifies_when_changing_flag( expected_log_msg: str, caplog: LogCaptureFixture, ): - logger = logging.getLogger("FeatureFlagLogger") - caching_feature_flag_router = CachingFeatureFlagRouter(logger) + logger = logging.getLogger(_FEATURE_FLAG_LOGGER_NAME) + caching_feature_flag_router = CachingFeatureFlagRouter[FeatureFlag](logger) - caching_feature_flag_router.set_feature_is_enabled( + _ = caching_feature_flag_router.set_feature_is_enabled( _FEATURE_FLAG_TEST_NAME, first_value ) - caching_feature_flag_router.set_feature_is_enabled( + _ = caching_feature_flag_router.set_feature_is_enabled( _FEATURE_FLAG_TEST_NAME, second_value ) assert expected_log_msg in {record.msg for record in caplog.records} + + +@pytest.mark.parametrize("value", [True, False]) +def test__feature_is_cached__correctly_determines_whether_value_is_cached(value: bool): + logger = logging.getLogger(_FEATURE_FLAG_LOGGER_NAME) + caching_feature_flag_router = CachingFeatureFlagRouter[FeatureFlag](logger) + + _ = caching_feature_flag_router.set_feature_is_enabled( + _FEATURE_FLAG_TEST_NAME, value + ) + + feature_is_cached = caching_feature_flag_router.feature_is_cached( + _FEATURE_FLAG_TEST_NAME + ) + + assert feature_is_cached + + +def test___create_feature_flag__returns_correct_TFeatureFlag(): + logger = logging.getLogger(_FEATURE_FLAG_LOGGER_NAME) + caching_feature_flag_router = CachingFeatureFlagRouter[FeatureFlag](logger) + + feature_flag = caching_feature_flag_router._create_feature_flag( # pyright: ignore[reportPrivateUsage] + _FEATURE_FLAG_TEST_NAME, True + ) + + assert isinstance(feature_flag, FeatureFlag) + + +@pytest.mark.parametrize("value", [True, False]) +def test___create_feature_flag__creates_correct_TFeatureFlag(value: bool): + logger = logging.getLogger(_FEATURE_FLAG_LOGGER_NAME) + caching_feature_flag_router = CachingFeatureFlagRouter[FeatureFlag](logger) + + feature_flag = caching_feature_flag_router._create_feature_flag( # pyright: ignore[reportPrivateUsage] + _FEATURE_FLAG_TEST_NAME, value + ) + + assert feature_flag.name == _FEATURE_FLAG_TEST_NAME + assert feature_flag.enabled == value + + +def test__get_feature_flags__returns_empty_sequence_when_no_flags_exist(): + logger = logging.getLogger(_FEATURE_FLAG_LOGGER_NAME) + caching_feature_flag_router = CachingFeatureFlagRouter[FeatureFlag](logger) + + feature_flags = caching_feature_flag_router.get_feature_flags() + + assert isinstance(feature_flags, tuple) + assert not feature_flags + + +def test__get_feature_flags__returns_all_existing_flags(): + logger = logging.getLogger(_FEATURE_FLAG_LOGGER_NAME) + caching_feature_flag_router = CachingFeatureFlagRouter[FeatureFlag](logger) + + FLAG_COUNT = 4 + + for i in range(FLAG_COUNT): + _ = caching_feature_flag_router.set_feature_is_enabled( + f"{_FEATURE_FLAG_TEST_NAME}{i}", (i % 2) == 0 + ) + + feature_flags = caching_feature_flag_router.get_feature_flags() + + assert isinstance(feature_flags, tuple) + assert len(feature_flags) == FLAG_COUNT + + +def test__get_feature_flags__returns_filtered_flags(): + logger = logging.getLogger(_FEATURE_FLAG_LOGGER_NAME) + caching_feature_flag_router = CachingFeatureFlagRouter[FeatureFlag](logger) + + FLAG_COUNT = 4 + FILTERED_FLAG_NAME = f"{_FEATURE_FLAG_TEST_NAME}2" + + for i in range(FLAG_COUNT): + _ = caching_feature_flag_router.set_feature_is_enabled( + f"{_FEATURE_FLAG_TEST_NAME}{i}", (i % 2) == 0 + ) + + feature_flags = caching_feature_flag_router.get_feature_flags([FILTERED_FLAG_NAME]) + + assert isinstance(feature_flags, tuple) + assert len(feature_flags) == 1 + assert feature_flags[0].name == FILTERED_FLAG_NAME + assert feature_flags[0].enabled # (2 % 2) == 0 ## - True diff --git a/src/platform/test/unit/feature_flags/test_db_feature_flag_router.py b/src/platform/test/unit/feature_flags/test_db_feature_flag_router.py index 7288fbeb..392a5233 100644 --- a/src/platform/test/unit/feature_flags/test_db_feature_flag_router.py +++ b/src/platform/test/unit/feature_flags/test_db_feature_flag_router.py @@ -6,6 +6,7 @@ from Ligare.database.dependency_injection import ScopedSessionModule from Ligare.platform.feature_flag.db_feature_flag_router import ( DBFeatureFlagRouter, + FeatureFlag, FeatureFlagTable, ) from Ligare.programming.dependency_injection import ConfigModule @@ -15,6 +16,7 @@ _FEATURE_FLAG_TEST_NAME = "foo_feature" _FEATURE_FLAG_TEST_DESCRIPTION = "foo description" +_FEATURE_FLAG_LOGGER_NAME = "FeatureFlagLogger" class PlatformMetaBase(DeclarativeMeta): @@ -28,7 +30,7 @@ class PlatformBase(object): from Ligare.database.testing.config import inmemory_database_config PlatformBase = declarative_base(cls=PlatformBase, metaclass=PlatformMetaBase) -FeatureFlag = FeatureFlagTable(PlatformBase) +FeatureFlagTableBase = FeatureFlagTable(PlatformBase) @pytest.fixture() @@ -45,19 +47,24 @@ def feature_flag_session(): return session -def _create_feature_flag(session: Session): +def _create_feature_flag( + session: Session, name: str | None = None, description: str | None = None +): session.add( - FeatureFlag( - name=_FEATURE_FLAG_TEST_NAME, description=_FEATURE_FLAG_TEST_DESCRIPTION + FeatureFlagTableBase( + name=_FEATURE_FLAG_TEST_NAME if name is None else name, + description=_FEATURE_FLAG_TEST_DESCRIPTION + if description is None + else description, ) ) session.commit() def test__feature_is_enabled__defaults_to_false(feature_flag_session: Session): - logger = logging.getLogger("FeatureFlagLogger") - db_feature_flag_router = DBFeatureFlagRouter( - FeatureFlag, feature_flag_session, logger + logger = logging.getLogger(_FEATURE_FLAG_LOGGER_NAME) + db_feature_flag_router = DBFeatureFlagRouter[FeatureFlag]( + FeatureFlagTableBase, feature_flag_session, logger ) _create_feature_flag(feature_flag_session) @@ -72,9 +79,9 @@ def test__feature_is_enabled__uses_default_when_flag_does_not_exist( default: bool, feature_flag_session: Session, ): - logger = logging.getLogger("FeatureFlagLogger") - db_feature_flag_router = DBFeatureFlagRouter( - FeatureFlag, feature_flag_session, logger + logger = logging.getLogger(_FEATURE_FLAG_LOGGER_NAME) + db_feature_flag_router = DBFeatureFlagRouter[FeatureFlag]( + FeatureFlagTableBase, feature_flag_session, logger ) is_enabled = db_feature_flag_router.feature_is_enabled( @@ -85,9 +92,9 @@ def test__feature_is_enabled__uses_default_when_flag_does_not_exist( def test__feature_is_enabled__disallows_empty_name(feature_flag_session: Session): - logger = logging.getLogger("FeatureFlagLogger") - db_feature_flag_router = DBFeatureFlagRouter( - FeatureFlag, feature_flag_session, logger + logger = logging.getLogger(_FEATURE_FLAG_LOGGER_NAME) + db_feature_flag_router = DBFeatureFlagRouter[FeatureFlag]( + FeatureFlagTableBase, feature_flag_session, logger ) with pytest.raises(ValueError): @@ -98,9 +105,9 @@ def test__feature_is_enabled__disallows_empty_name(feature_flag_session: Session def test__feature_is_enabled__disallows_non_string_names( name: Any, feature_flag_session: Session ): - logger = logging.getLogger("FeatureFlagLogger") - db_feature_flag_router = DBFeatureFlagRouter( - FeatureFlag, feature_flag_session, logger + logger = logging.getLogger(_FEATURE_FLAG_LOGGER_NAME) + db_feature_flag_router = DBFeatureFlagRouter[FeatureFlag]( + FeatureFlagTableBase, feature_flag_session, logger ) with pytest.raises(TypeError): @@ -110,50 +117,50 @@ def test__feature_is_enabled__disallows_non_string_names( def test__set_feature_is_enabled__fails_when_flag_does_not_exist( feature_flag_session: Session, ): - logger = logging.getLogger("FeatureFlagLogger") - db_feature_flag_router = DBFeatureFlagRouter( - FeatureFlag, feature_flag_session, logger + logger = logging.getLogger(_FEATURE_FLAG_LOGGER_NAME) + db_feature_flag_router = DBFeatureFlagRouter[FeatureFlag]( + FeatureFlagTableBase, feature_flag_session, logger ) with pytest.raises(LookupError): - db_feature_flag_router.set_feature_is_enabled(_FEATURE_FLAG_TEST_NAME, True) + _ = db_feature_flag_router.set_feature_is_enabled(_FEATURE_FLAG_TEST_NAME, True) def test__set_feature_is_enabled__disallows_empty_name(feature_flag_session: Session): - logger = logging.getLogger("FeatureFlagLogger") - db_feature_flag_router = DBFeatureFlagRouter( - FeatureFlag, feature_flag_session, logger + logger = logging.getLogger(_FEATURE_FLAG_LOGGER_NAME) + db_feature_flag_router = DBFeatureFlagRouter[FeatureFlag]( + FeatureFlagTableBase, feature_flag_session, logger ) _create_feature_flag(feature_flag_session) with pytest.raises(ValueError): - db_feature_flag_router.set_feature_is_enabled("", False) + _ = db_feature_flag_router.set_feature_is_enabled("", False) @pytest.mark.parametrize("name", [0, False, True, {}, [], (0,)]) def test__set_feature_is_enabled__disallows_non_string_names( name: Any, feature_flag_session: Session ): - logger = logging.getLogger("FeatureFlagLogger") - db_feature_flag_router = DBFeatureFlagRouter( - FeatureFlag, feature_flag_session, logger + logger = logging.getLogger(_FEATURE_FLAG_LOGGER_NAME) + db_feature_flag_router = DBFeatureFlagRouter[FeatureFlag]( + FeatureFlagTableBase, feature_flag_session, logger ) with pytest.raises(TypeError): - db_feature_flag_router.set_feature_is_enabled(name, False) + _ = db_feature_flag_router.set_feature_is_enabled(name, False) @pytest.mark.parametrize("enable", [True, False]) def test__set_feature_is_enabled__sets_correct_value( enable: bool, feature_flag_session: Session ): - logger = logging.getLogger("FeatureFlagLogger") - db_feature_flag_router = DBFeatureFlagRouter( - FeatureFlag, feature_flag_session, logger + logger = logging.getLogger(_FEATURE_FLAG_LOGGER_NAME) + db_feature_flag_router = DBFeatureFlagRouter[FeatureFlag]( + FeatureFlagTableBase, feature_flag_session, logger ) _create_feature_flag(feature_flag_session) - db_feature_flag_router.set_feature_is_enabled(_FEATURE_FLAG_TEST_NAME, enable) + _ = db_feature_flag_router.set_feature_is_enabled(_FEATURE_FLAG_TEST_NAME, enable) is_enabled = db_feature_flag_router.feature_is_enabled(_FEATURE_FLAG_TEST_NAME) assert is_enabled == enable @@ -165,10 +172,12 @@ def test__set_feature_is_enabled__caches_flags(enable: bool, mocker: MockerFixtu session_query_mock = mocker.patch("sqlalchemy.orm.session.Session.query") session_mock.query = session_query_mock - logger = logging.getLogger("FeatureFlagLogger") - db_feature_flag_router = DBFeatureFlagRouter(FeatureFlag, session_mock, logger) + logger = logging.getLogger(_FEATURE_FLAG_LOGGER_NAME) + db_feature_flag_router = DBFeatureFlagRouter[FeatureFlag]( + FeatureFlagTableBase, session_mock, logger + ) - db_feature_flag_router.set_feature_is_enabled(_FEATURE_FLAG_TEST_NAME, enable) + _ = db_feature_flag_router.set_feature_is_enabled(_FEATURE_FLAG_TEST_NAME, enable) _ = db_feature_flag_router.feature_is_enabled(_FEATURE_FLAG_TEST_NAME) _ = db_feature_flag_router.feature_is_enabled(_FEATURE_FLAG_TEST_NAME) @@ -187,8 +196,10 @@ def test__feature_is_enabled__checks_cache( "Ligare.platform.feature_flag.caching_feature_flag_router.CachingFeatureFlagRouter.set_feature_is_enabled" ) - logger = logging.getLogger("FeatureFlagLogger") - db_feature_flag_router = DBFeatureFlagRouter(FeatureFlag, session_mock, logger) + logger = logging.getLogger(_FEATURE_FLAG_LOGGER_NAME) + db_feature_flag_router = DBFeatureFlagRouter[FeatureFlag]( + FeatureFlagTableBase, session_mock, logger + ) _ = db_feature_flag_router.feature_is_enabled( _FEATURE_FLAG_TEST_NAME, False, check_cache[0] @@ -210,8 +221,10 @@ def test__feature_is_enabled__sets_cache( ) feature_is_enabled_mock.return_value = True - logger = logging.getLogger("FeatureFlagLogger") - db_feature_flag_router = DBFeatureFlagRouter(FeatureFlag, session_mock, logger) + logger = logging.getLogger(_FEATURE_FLAG_LOGGER_NAME) + db_feature_flag_router = DBFeatureFlagRouter[FeatureFlag]( + FeatureFlagTableBase, session_mock, logger + ) _ = db_feature_flag_router.feature_is_enabled( _FEATURE_FLAG_TEST_NAME, False, check_cache[0] @@ -228,17 +241,283 @@ def test__set_feature_is_enabled__resets_cache_when_flag_enable_is_set( session_query_mock = mocker.patch("sqlalchemy.orm.session.Session.query") session_mock.query = session_query_mock - logger = logging.getLogger("FeatureFlagLogger") - db_feature_flag_router = DBFeatureFlagRouter(FeatureFlag, session_mock, logger) + logger = logging.getLogger(_FEATURE_FLAG_LOGGER_NAME) + db_feature_flag_router = DBFeatureFlagRouter[FeatureFlag]( + FeatureFlagTableBase, session_mock, logger + ) - db_feature_flag_router.set_feature_is_enabled(_FEATURE_FLAG_TEST_NAME, enable) + _ = db_feature_flag_router.set_feature_is_enabled(_FEATURE_FLAG_TEST_NAME, enable) _ = db_feature_flag_router.feature_is_enabled(_FEATURE_FLAG_TEST_NAME) first_value = db_feature_flag_router.feature_is_enabled(_FEATURE_FLAG_TEST_NAME) - db_feature_flag_router.set_feature_is_enabled(_FEATURE_FLAG_TEST_NAME, not enable) + _ = db_feature_flag_router.set_feature_is_enabled( + _FEATURE_FLAG_TEST_NAME, not enable + ) _ = db_feature_flag_router.feature_is_enabled(_FEATURE_FLAG_TEST_NAME) second_value = db_feature_flag_router.feature_is_enabled(_FEATURE_FLAG_TEST_NAME) assert session_query_mock.call_count == 2 assert first_value == enable assert second_value == (not enable) + + +@pytest.mark.parametrize("value", [True, False]) +def test__set_feature_is_enabled__returns_correct_initial_state( + value: bool, feature_flag_session: Session +): + logger = logging.getLogger(_FEATURE_FLAG_LOGGER_NAME) + db_feature_flag_router = DBFeatureFlagRouter[FeatureFlag]( + FeatureFlagTableBase, feature_flag_session, logger + ) + _create_feature_flag(feature_flag_session) + + initial_change = db_feature_flag_router.set_feature_is_enabled( + _FEATURE_FLAG_TEST_NAME, value + ) + + assert initial_change.name == _FEATURE_FLAG_TEST_NAME + assert initial_change.old_value == False + assert initial_change.new_value == value + + +@pytest.mark.parametrize("value", [True, False]) +def test__set_feature_is_enabled__returns_correct_new_and_old_state( + value: bool, feature_flag_session: Session +): + logger = logging.getLogger(_FEATURE_FLAG_LOGGER_NAME) + db_feature_flag_router = DBFeatureFlagRouter[FeatureFlag]( + FeatureFlagTableBase, feature_flag_session, logger + ) + _create_feature_flag(feature_flag_session) + + initial_change = db_feature_flag_router.set_feature_is_enabled( # pyright: ignore[reportUnusedVariable] + _FEATURE_FLAG_TEST_NAME, value + ) + change = db_feature_flag_router.set_feature_is_enabled( + _FEATURE_FLAG_TEST_NAME, not value + ) + + assert change.name == _FEATURE_FLAG_TEST_NAME + assert change.old_value == value + assert change.new_value == (not value) + + +def test___create_feature_flag__returns_correct_TFeatureFlag( + feature_flag_session: Session, +): + logger = logging.getLogger(_FEATURE_FLAG_LOGGER_NAME) + db_feature_flag_router = DBFeatureFlagRouter[FeatureFlag]( + FeatureFlagTableBase, feature_flag_session, logger + ) + + feature_flag = db_feature_flag_router._create_feature_flag( # pyright: ignore[reportPrivateUsage] + _FEATURE_FLAG_TEST_NAME, True, _FEATURE_FLAG_TEST_DESCRIPTION + ) + + assert isinstance(feature_flag, FeatureFlag) + + +@pytest.mark.parametrize("value", [True, False]) +def test___create_feature_flag__creates_correct_TFeatureFlag( + value: bool, + feature_flag_session: Session, +): + logger = logging.getLogger(_FEATURE_FLAG_LOGGER_NAME) + db_feature_flag_router = DBFeatureFlagRouter[FeatureFlag]( + FeatureFlagTableBase, feature_flag_session, logger + ) + + feature_flag = db_feature_flag_router._create_feature_flag( # pyright: ignore[reportPrivateUsage] + _FEATURE_FLAG_TEST_NAME, value, _FEATURE_FLAG_TEST_DESCRIPTION + ) + + assert feature_flag.name == _FEATURE_FLAG_TEST_NAME + assert feature_flag.description == _FEATURE_FLAG_TEST_DESCRIPTION + assert feature_flag.enabled == value + + +def test__get_feature_flags__returns_empty_sequence_when_no_flags_exist( + feature_flag_session: Session, +): + logger = logging.getLogger(_FEATURE_FLAG_LOGGER_NAME) + db_feature_flag_router = DBFeatureFlagRouter[FeatureFlag]( + FeatureFlagTableBase, feature_flag_session, logger + ) + + feature_flags = db_feature_flag_router.get_feature_flags() + + assert isinstance(feature_flags, tuple) + assert not feature_flags + + +@pytest.mark.parametrize( + "added_flags", + [ + {f"{_FEATURE_FLAG_TEST_NAME}2": True}, + {f"{_FEATURE_FLAG_TEST_NAME}2": False}, + { + f"{_FEATURE_FLAG_TEST_NAME}1": True, + f"{_FEATURE_FLAG_TEST_NAME}2": False, + f"{_FEATURE_FLAG_TEST_NAME}3": True, + f"{_FEATURE_FLAG_TEST_NAME}4": False, + f"{_FEATURE_FLAG_TEST_NAME}5": True, + f"{_FEATURE_FLAG_TEST_NAME}6": False, + }, + ], +) +def test__get_feature_flags__returns_all_existing_flags( + added_flags: dict[str, bool], + feature_flag_session: Session, +): + logger = logging.getLogger(_FEATURE_FLAG_LOGGER_NAME) + db_feature_flag_router = DBFeatureFlagRouter[FeatureFlag]( + FeatureFlagTableBase, feature_flag_session, logger + ) + + for flag_name, enabled in added_flags.items(): + _create_feature_flag(feature_flag_session, flag_name) + _ = db_feature_flag_router.set_feature_is_enabled(flag_name, enabled) + + feature_flags = db_feature_flag_router.get_feature_flags() + feature_flags_dict = { + feature_flag.name: (feature_flag.enabled, feature_flag.description) + for feature_flag in feature_flags + } + + assert isinstance(feature_flags, tuple) + assert len(feature_flags) == len(added_flags) + + for filtered_flag in feature_flags_dict: + assert filtered_flag in feature_flags_dict + assert feature_flags_dict[filtered_flag][0] == added_flags[filtered_flag] + assert feature_flags_dict[filtered_flag][1] == _FEATURE_FLAG_TEST_DESCRIPTION + + +@pytest.mark.parametrize( + "added_flags,filtered_flags", + [ + [{f"{_FEATURE_FLAG_TEST_NAME}2": True}, [f"{_FEATURE_FLAG_TEST_NAME}2"]], + [{f"{_FEATURE_FLAG_TEST_NAME}2": False}, [f"{_FEATURE_FLAG_TEST_NAME}2"]], + [ + { + f"{_FEATURE_FLAG_TEST_NAME}1": True, + f"{_FEATURE_FLAG_TEST_NAME}2": False, + f"{_FEATURE_FLAG_TEST_NAME}3": True, + f"{_FEATURE_FLAG_TEST_NAME}4": False, + f"{_FEATURE_FLAG_TEST_NAME}5": True, + f"{_FEATURE_FLAG_TEST_NAME}6": False, + }, + [f"{_FEATURE_FLAG_TEST_NAME}1", f"{_FEATURE_FLAG_TEST_NAME}2"], + ], + ], +) +def test__get_feature_flags__returns_filtered_flags( + added_flags: dict[str, bool], + filtered_flags: list[str], + feature_flag_session: Session, +): + logger = logging.getLogger(_FEATURE_FLAG_LOGGER_NAME) + db_feature_flag_router = DBFeatureFlagRouter[FeatureFlag]( + FeatureFlagTableBase, feature_flag_session, logger + ) + + for flag_name, enabled in added_flags.items(): + _create_feature_flag(feature_flag_session, flag_name) + _ = db_feature_flag_router.set_feature_is_enabled(flag_name, enabled) + + feature_flags = db_feature_flag_router.get_feature_flags(filtered_flags) + feature_flags_dict = { + feature_flag.name: (feature_flag.enabled, feature_flag.description) + for feature_flag in feature_flags + } + + assert isinstance(feature_flags, tuple) + assert len(feature_flags) == len(filtered_flags) + + for filtered_flag in filtered_flags: + assert filtered_flag in feature_flags_dict + assert feature_flags_dict[filtered_flag][0] == added_flags[filtered_flag] + assert feature_flags_dict[filtered_flag][1] == _FEATURE_FLAG_TEST_DESCRIPTION + + +@pytest.mark.parametrize( + "added_flags,filtered_flags", + [ + [{f"{_FEATURE_FLAG_TEST_NAME}2": True}, [f"{_FEATURE_FLAG_TEST_NAME}1"]], + [{f"{_FEATURE_FLAG_TEST_NAME}2": False}, [f"{_FEATURE_FLAG_TEST_NAME}3"]], + [ + { + f"{_FEATURE_FLAG_TEST_NAME}1": True, + f"{_FEATURE_FLAG_TEST_NAME}2": False, + f"{_FEATURE_FLAG_TEST_NAME}3": True, + f"{_FEATURE_FLAG_TEST_NAME}4": False, + f"{_FEATURE_FLAG_TEST_NAME}5": True, + f"{_FEATURE_FLAG_TEST_NAME}6": False, + }, + [f"{_FEATURE_FLAG_TEST_NAME}7", f"{_FEATURE_FLAG_TEST_NAME}8"], + ], + ], +) +def test__get_feature_flags__returns_empty_sequence_when_flags_exist_but_filtered_list_items_do_not_exist( + added_flags: dict[str, bool], + filtered_flags: list[str], + feature_flag_session: Session, +): + logger = logging.getLogger(_FEATURE_FLAG_LOGGER_NAME) + db_feature_flag_router = DBFeatureFlagRouter[FeatureFlag]( + FeatureFlagTableBase, feature_flag_session, logger + ) + + for flag_name, enabled in added_flags.items(): + _create_feature_flag(feature_flag_session, flag_name) + _ = db_feature_flag_router.set_feature_is_enabled(flag_name, enabled) + + feature_flags = db_feature_flag_router.get_feature_flags(filtered_flags) + + assert isinstance(feature_flags, tuple) + assert len(feature_flags) == 0 + + +@pytest.mark.parametrize( + "added_flags", + [ + {f"{_FEATURE_FLAG_TEST_NAME}2": True}, + {f"{_FEATURE_FLAG_TEST_NAME}2": False}, + { + f"{_FEATURE_FLAG_TEST_NAME}1": True, + f"{_FEATURE_FLAG_TEST_NAME}2": False, + f"{_FEATURE_FLAG_TEST_NAME}3": True, + f"{_FEATURE_FLAG_TEST_NAME}4": False, + f"{_FEATURE_FLAG_TEST_NAME}5": True, + f"{_FEATURE_FLAG_TEST_NAME}6": False, + }, + ], +) +def test__get_feature_flags__caches_all_existing_flags_when_queried( + added_flags: dict[str, bool], feature_flag_session: Session, mocker: MockerFixture +): + logger = logging.getLogger(_FEATURE_FLAG_LOGGER_NAME) + db_feature_flag_router = DBFeatureFlagRouter[FeatureFlag]( + FeatureFlagTableBase, feature_flag_session, logger + ) + + for flag_name, enabled in added_flags.items(): + _create_feature_flag(feature_flag_session, flag_name) + _ = db_feature_flag_router.set_feature_is_enabled(flag_name, enabled) + + cache_mock = mocker.patch( + "Ligare.platform.feature_flag.caching_feature_flag_router.CachingFeatureFlagRouter.set_feature_is_enabled", + autospec=True, + ) + + _ = db_feature_flag_router.get_feature_flags() + + call_args_dict: dict[str, bool] = { + call.args[1]: call.args[2] for call in cache_mock.call_args_list + } + + # CachingFeatureFlagRouter.set_feature_is_enabled should be called + # once for every feature flag retrieved from the database + assert cache_mock.call_count == len(added_flags) + for flag_name, enabled in added_flags.items(): + assert call_args_dict[flag_name] == enabled diff --git a/src/platform/test/unit/feature_flags/test_feature_flag_router.py b/src/platform/test/unit/feature_flags/test_feature_flag_router.py index f6e92d5e..be89c1ce 100644 --- a/src/platform/test/unit/feature_flags/test_feature_flag_router.py +++ b/src/platform/test/unit/feature_flags/test_feature_flag_router.py @@ -1,21 +1,36 @@ import ast import inspect +from typing import Sequence -from Ligare.platform.feature_flag.feature_flag_router import FeatureFlagRouter +from Ligare.platform.feature_flag.feature_flag_router import ( + FeatureFlag, + FeatureFlagChange, + FeatureFlagRouter, +) from typing_extensions import override _FEATURE_FLAG_TEST_NAME = "foo_feature" -class TestFeatureFlagRouter(FeatureFlagRouter): +class TestFeatureFlagRouter(FeatureFlagRouter[FeatureFlag]): @override - def set_feature_is_enabled(self, name: str, is_enabled: bool) -> None: + def set_feature_is_enabled(self, name: str, is_enabled: bool) -> FeatureFlagChange: return super().set_feature_is_enabled(name, is_enabled) @override def feature_is_enabled(self, name: str, default: bool = False) -> bool: return super().feature_is_enabled(name, default) + @override + def _create_feature_flag(self, name: str, enabled: bool) -> FeatureFlag: + return super()._create_feature_flag(name, enabled) + + @override + def get_feature_flags( + self, names: list[str] | None = None + ) -> Sequence[FeatureFlag]: + return super().get_feature_flags(names) + class NotifyingFeatureFlagRouter(TestFeatureFlagRouter): def __init__(self) -> None: @@ -47,7 +62,9 @@ def test___notify_change__is_a_noop(): def test__set_feature_is_enabled__does_not_notify(): notifying_feature_flag_router = NotifyingFeatureFlagRouter() - notifying_feature_flag_router.set_feature_is_enabled(_FEATURE_FLAG_TEST_NAME, True) + _ = notifying_feature_flag_router.set_feature_is_enabled( + _FEATURE_FLAG_TEST_NAME, True + ) assert notifying_feature_flag_router.notification_count == 0 @@ -63,7 +80,9 @@ def test__feature_is_enabled__does_not_notify(): def test__feature_is_enabled__does_not_notify_after_flag_is_set(): notifying_feature_flag_router = NotifyingFeatureFlagRouter() - notifying_feature_flag_router.set_feature_is_enabled(_FEATURE_FLAG_TEST_NAME, True) + _ = notifying_feature_flag_router.set_feature_is_enabled( + _FEATURE_FLAG_TEST_NAME, True + ) notifying_feature_flag_router.notification_count = 0 _ = notifying_feature_flag_router.feature_is_enabled(_FEATURE_FLAG_TEST_NAME) diff --git a/src/programming/Ligare/programming/collections/dict.py b/src/programming/Ligare/programming/collections/dict.py index 97ab78d2..cf11cc9e 100644 --- a/src/programming/Ligare/programming/collections/dict.py +++ b/src/programming/Ligare/programming/collections/dict.py @@ -1,8 +1,11 @@ from __future__ import annotations -from typing import Any, Union +from typing import Any, TypeVar, Union AnyDict = dict[Any, Union[Any, "AnyDict"]] +TKey = TypeVar("TKey") +TValue = TypeVar("TValue") +NestedDict = dict[TKey, Union[TValue, "NestedDict"]] def merge(a: AnyDict, b: AnyDict, skip_existing: bool = False): diff --git a/src/programming/Ligare/programming/config/__init__.py b/src/programming/Ligare/programming/config/__init__.py index 85557827..8020d68c 100644 --- a/src/programming/Ligare/programming/config/__init__.py +++ b/src/programming/Ligare/programming/config/__init__.py @@ -8,24 +8,44 @@ ConfigBuilderStateError, NotEndsWithConfigError, ) - -TConfig = TypeVar("TConfig") +from typing_extensions import Self class AbstractConfig(abc.ABC): - pass + @abc.abstractmethod + def post_load(self) -> None: + pass + + +TConfig = TypeVar("TConfig", bound=AbstractConfig) + +from collections import deque class ConfigBuilder(Generic[TConfig]): _root_config: type[TConfig] | None = None - _configs: list[type[AbstractConfig]] | None = None + _configs: deque[type[AbstractConfig]] | None = None - def with_root_config(self, config: "type[TConfig]"): - self._root_config = config + def with_root_config(self, config_type: type[TConfig]) -> Self: + self._root_config = config_type return self - def with_configs(self, configs: list[type[AbstractConfig]]): - self._configs = configs + def with_configs(self, configs: list[type[AbstractConfig]] | None) -> Self: + if configs is None: + return self + + if self._configs is None: + self._configs = deque(configs) + else: + self._configs.extend(configs) + + return self + + def with_config(self, config_type: type[AbstractConfig]) -> Self: + if self._configs is None: + self._configs = deque() + + self._configs.append(config_type) return self def build(self) -> type[TConfig]: @@ -37,16 +57,23 @@ def build(self) -> type[TConfig]: "Cannot build a config without any base config types specified." ) - _new_type_base = self._root_config if self._root_config else object + def test_type_name(config_type: type[AbstractConfig]): + if not config_type.__name__.endswith("Config"): + raise NotEndsWithConfigError( + f"Class name '{config_type.__name__}' is not a valid config class. The name must end with 'Config'" + ) + + _new_type_base = ( + self._root_config if self._root_config else self._configs.popleft() + ) + + test_type_name(_new_type_base) attrs: dict[Any, Any] = {} annotations: dict[str, Any] = {} for config in self._configs: - if not config.__name__.endswith("Config"): - raise NotEndsWithConfigError( - f"Class name '{config.__name__}' is not a valid config class. The name must end with 'Config'" - ) + test_type_name(config) config_name = config.__name__[: config.__name__.rindex("Config")].lower() annotations[config_name] = config @@ -73,4 +100,7 @@ def load_config( config_dict = merge(config_dict, config_overrides) config = config_type(**config_dict) + + config.post_load() + return config diff --git a/src/programming/Ligare/programming/patterns/dependency_injection.py b/src/programming/Ligare/programming/patterns/dependency_injection.py index 40be53c3..1a5393ce 100644 --- a/src/programming/Ligare/programming/patterns/dependency_injection.py +++ b/src/programming/Ligare/programming/patterns/dependency_injection.py @@ -3,6 +3,7 @@ from typing import Callable, TypeVar from injector import Binder, Module, Provider +from Ligare.programming.config import AbstractConfig from typing_extensions import override @@ -40,3 +41,12 @@ def __init__( def configure(self, binder: Binder) -> None: for interface, to in self._registrations.items(): binder.bind(interface, to) + + +from abc import ABC, abstractmethod + + +class ConfigurableModule(Module, ABC): + @staticmethod + @abstractmethod + def get_config_type() -> type[AbstractConfig]: ... diff --git a/src/programming/test/unit/test_config.py b/src/programming/test/unit/test_config.py index 998e0fa0..152ae3ae 100644 --- a/src/programming/test/unit/test_config.py +++ b/src/programming/test/unit/test_config.py @@ -6,28 +6,39 @@ ) from pydantic import BaseModel from pytest_mock import MockerFixture +from typing_extensions import override class FooConfig(BaseModel): - value: str - other_value: bool = False + foo_value: str + foo_other_value: bool = False class BarConfig(BaseModel): - value: str + bar_value: str class BazConfig(BaseModel, AbstractConfig): - value: str + @override + def post_load(self) -> None: + return super().post_load() + + baz_value: str class TestConfig(BaseModel, AbstractConfig): - foo: FooConfig = FooConfig(value="xyz") + @override + def post_load(self) -> None: + return super().post_load() + + foo: FooConfig = FooConfig(foo_value="xyz") bar: BarConfig | None = None class InvalidConfigClass(BaseModel, AbstractConfig): - pass + @override + def post_load(self) -> None: + return super().post_load() def test__Config__load_config__reads_toml_file(mocker: MockerFixture): @@ -38,29 +49,29 @@ def test__Config__load_config__reads_toml_file(mocker: MockerFixture): def test__Config__load_config__initializes_section_config_value(mocker: MockerFixture): - fake_config_dict = {"foo": {"value": "abc123"}} + fake_config_dict = {"foo": {"foo_value": "abc123"}} _ = mocker.patch("io.open") _ = mocker.patch("toml.decoder.loads", return_value=fake_config_dict) config = load_config(TestConfig, "foo.toml") - assert config.foo.value == "abc123" + assert config.foo.foo_value == "abc123" def test__Config__load_config__initializes_section_config(mocker: MockerFixture): - fake_config_dict = {"bar": {"value": "abc123"}} + fake_config_dict = {"bar": {"bar_value": "abc123"}} _ = mocker.patch("io.open") _ = mocker.patch("toml.decoder.loads", return_value=fake_config_dict) config = load_config(TestConfig, "foo.toml") assert config.bar is not None - assert config.bar.value == "abc123" + assert config.bar.bar_value == "abc123" def test__Config__load_config__applies_overrides(mocker: MockerFixture): - fake_config_dict = {"foo": {"value": "abc123"}} - override_config_dict = {"foo": {"value": "XYZ"}} + fake_config_dict = {"foo": {"foo_value": "abc123"}} + override_config_dict = {"foo": {"foo_value": "XYZ"}} _ = mocker.patch("io.open") _ = mocker.patch("toml.decoder.loads", return_value=fake_config_dict) config = load_config(TestConfig, "foo.toml", override_config_dict) - assert config.foo.value == override_config_dict["foo"]["value"] + assert config.foo.foo_value == override_config_dict["foo"]["foo_value"] def test__ConfigBuilder__build__raises_error_when_no_root_config_and_no_section_configs_specified(): @@ -70,20 +81,20 @@ def test__ConfigBuilder__build__raises_error_when_no_root_config_and_no_section_ def test__ConfigBuilder__build__raises_error_when_section_class_name_is_invalid(): - config_builder = ConfigBuilder[TestConfig]() + config_builder = ConfigBuilder[InvalidConfigClass()]() _ = config_builder.with_configs([InvalidConfigClass]) with pytest.raises(NotEndsWithConfigError): _ = config_builder.build() -def test__ConfigBuilder__build__uses_object_as_root_config_when_no_root_config_specified(): - config_builder = ConfigBuilder[TestConfig]() +def test__ConfigBuilder__build__uses_first_config_as_root_config_when_no_root_config_specified(): + config_builder = ConfigBuilder[BazConfig]() _ = config_builder.with_configs([BazConfig]) config_type = config_builder.build() assert TestConfig not in config_type.__mro__ - assert BazConfig not in config_type.__mro__ - assert hasattr(config_type, "baz") - assert hasattr(config_type(), "baz") + assert BazConfig in config_type.__mro__ + assert "baz_value" in config_type.model_fields + assert hasattr(config_type(baz_value="abc"), "baz_value") def test__ConfigBuilder__build__uses_root_config_when_no_section_configs_specified(): @@ -97,7 +108,7 @@ def test__ConfigBuilder__build__uses_root_config_when_no_section_configs_specifi def test__ConfigBuilder__build__creates_config_type_when_multiple_configs_specified( mocker: MockerFixture, ): - fake_config_dict = {"baz": {"value": "ABC"}} + fake_config_dict = {"baz": {"baz_value": "ABC"}} _ = mocker.patch("io.open") _ = mocker.patch("toml.decoder.loads", return_value=fake_config_dict) @@ -114,7 +125,7 @@ def test__ConfigBuilder__build__creates_config_type_when_multiple_configs_specif def test__ConfigBuilder__build__sets_dynamic_config_values_when_multiple_configs_specified( mocker: MockerFixture, ): - fake_config_dict = {"baz": {"value": "ABC"}} + fake_config_dict = {"baz": {"baz_value": "ABC"}} _ = mocker.patch("io.open") _ = mocker.patch("toml.decoder.loads", return_value=fake_config_dict) @@ -126,5 +137,5 @@ def test__ConfigBuilder__build__sets_dynamic_config_values_when_multiple_configs assert hasattr(config, "baz") assert getattr(config, "baz") - assert getattr(getattr(config, "baz"), "value") - assert getattr(getattr(config, "baz"), "value") == "ABC" + assert getattr(getattr(config, "baz"), "baz_value") + assert getattr(getattr(config, "baz"), "baz_value") == "ABC" diff --git a/src/programming/test/unit/test_dependency_injection.py b/src/programming/test/unit/test_dependency_injection.py index 3597d5e2..47fd7298 100644 --- a/src/programming/test/unit/test_dependency_injection.py +++ b/src/programming/test/unit/test_dependency_injection.py @@ -1,10 +1,14 @@ from injector import Injector from Ligare.programming.config import AbstractConfig from Ligare.programming.dependency_injection import ConfigModule +from typing_extensions import override def test__ConfigModule__injector_binds_Config_module_to_AbstractConfig_by_default(): - class FooConfig(AbstractConfig): ... + class FooConfig(AbstractConfig): + @override + def post_load(self) -> None: + return super().post_load() foo_config = FooConfig() config_module = ConfigModule(foo_config) @@ -15,6 +19,10 @@ class FooConfig(AbstractConfig): ... def test__ConfigModule__injector_binds_configured_Config_module(): class FooConfig(AbstractConfig): + @override + def post_load(self) -> None: + return super().post_load() + x: int = 123 foo_config = FooConfig() diff --git a/src/web/Ligare/web/application.py b/src/web/Ligare/web/application.py index 3a237d6a..db18405c 100644 --- a/src/web/Ligare/web/application.py +++ b/src/web/Ligare/web/application.py @@ -1,13 +1,18 @@ -""" -Compound Assay Platform Flask application. - -Flask entry point. -""" - import logging +from collections import defaultdict from dataclasses import dataclass from os import environ, path -from typing import Generator, Generic, Optional, TypeVar, cast +from typing import ( + Any, + Generator, + Generic, + Optional, + Protocol, + TypeVar, + cast, + final, + overload, +) import json_logging from connexion import FlaskApp @@ -16,8 +21,18 @@ from injector import Module from lib_programname import get_path_executed_script from Ligare.AWS.ssm import SSMParameters -from Ligare.programming.config import AbstractConfig, ConfigBuilder, load_config +from Ligare.programming.collections.dict import NestedDict +from Ligare.programming.config import ( + AbstractConfig, + ConfigBuilder, + TConfig, + load_config, +) +from Ligare.programming.config.exceptions import ConfigBuilderStateError from Ligare.programming.dependency_injection import ConfigModule +from Ligare.programming.patterns.dependency_injection import ConfigurableModule +from Ligare.web.exception import BuilderBuildError, InvalidBuilderStateError +from typing_extensions import Self, deprecated from .config import Config from .middleware import ( @@ -34,15 +49,34 @@ TApp = Flask | FlaskApp T_app = TypeVar("T_app", bound=TApp) +TAppConfig = TypeVar("TAppConfig", bound=Config) + @dataclass class AppInjector(Generic[T_app]): + """ + Contains an instantiated `T_app` application in `app`, + and its associated `FlaskInjector` IoC container. + + :param T_app Generic: An instance of Flask or FlaskApp. + :param flask_inject FlaskInject: The applications IoC container. + """ + app: T_app flask_injector: FlaskInjector @dataclass class CreateAppResult(Generic[T_app]): + """ + Contains an instantiated Flask application and its + associated application "container." This is either + the same Flask instance, or an OpenAPI application. + + :param flask_app Generic: The Flask application. + :param app_injector AppInjector[T_app]: The application's wrapper and IoC container. + """ + flask_app: Flask app_injector: AppInjector[T_app] @@ -54,6 +88,7 @@ class CreateAppResult(Generic[T_app]): # In Python 3.12 we can use generics in functions, # but we target >= Python 3.10. This is one way # around that limitation. +@deprecated("`App` is deprecated. Use `ApplicationBuilder`.") class App(Generic[T_app]): """ Create a new generic type for the application instance. @@ -62,6 +97,7 @@ class App(Generic[T_app]): T_app: Either `Flask` or `FlaskApp` """ + @deprecated("`App.create` is deprecated. Use `ApplicationBuilder`.") @staticmethod def create( config_filename: str = "config.toml", @@ -79,105 +115,316 @@ def create( """ return cast( CreateAppResult[T_app], - create_app(config_filename, application_configs, application_modules), + _create_app(config_filename, application_configs, application_modules), ) +class UseConfigurationCallback(Protocol[TConfig]): + """ + The callback for configuring an application's configuration. + + :param TConfig Protocol: The AbstractConfig type to be configured. + """ + + def __call__( + self, + config_builder: ConfigBuilder[TConfig], + config_overrides: dict[str, Any], + ) -> "None | ConfigBuilder[TConfig]": + """ + Set up parameters for the application's configuration. + + :param ConfigBuilder[TConfig] config_builder: The ConfigBuilder instance. + :param dict[str, Any] config_overrides: A dictionary of key/values that are applied over all keys that might exist in an instantiated config. + :raises InvalidBuilderStateError: Upon a call to `build()`, the builder is misconfigured. + :raises BuilderBuildError: Upon a call to `build()`, a failure occurred during the instantiation of the configuration. + :raises Exception: Upon a call to `build()`, an unknown error occurred. + :return None | ConfigBuilder[TConfig]: The callback may return `None` or the received `ConfigBuilder` instance so as to support the use of lambdas. This return value is not used. + """ + + +@final +class ApplicationConfigBuilder(Generic[TConfig]): + _DEFAULT_CONFIG_FILENAME: str = "config.toml" + + def __init__(self) -> None: + self._config_value_overrides: dict[str, Any] = {} + self._config_builder: ConfigBuilder[TConfig] = ConfigBuilder[TConfig]() + self._config_filename: str = ApplicationConfigBuilder._DEFAULT_CONFIG_FILENAME + self._use_filename: bool = False + self._use_ssm: bool = False + + def with_config_builder(self, config_builder: ConfigBuilder[TConfig]) -> Self: + self._config_builder = config_builder + return self + + def with_root_config_type(self, config_type: type[TConfig]) -> Self: + _ = self._config_builder.with_root_config(config_type) + return self + + def with_config_types(self, configs: list[type[AbstractConfig]] | None) -> Self: + _ = self._config_builder.with_configs(configs) + return self + + def with_config_type(self, config_type: type[AbstractConfig]) -> Self: + _ = self._config_builder.with_config(config_type) + return self + + def with_config_value_overrides(self, values: dict[str, Any]) -> Self: + self._config_value_overrides = values + return self + + def with_config_filename(self, filename: str) -> Self: + self._config_filename = filename + self._use_filename = True + return self + + def enable_ssm(self, value: bool) -> Self: + """ + Try to load config from AWS SSM. If `use_filename` was configured, + a failed attempt to load from SSM will instead attempt to load from + the configured filename. If `use_filename` is not configured and SSM + fails, an exception is raised. If SSM succeeds, `build` will not + load from the configured filename. + + :param bool value: Whether to use SSM + :return Self: + """ + self._use_ssm = value + return self + + def build(self) -> TConfig | None: + if not (self._use_ssm or self._use_filename): + raise InvalidBuilderStateError( + "Cannot build the application config without either `use_ssm` or `use_filename` having been configured." + ) + + try: + config_type = self._config_builder.build() + except ConfigBuilderStateError as e: + raise BuilderBuildError( + "A root config must be specified using `with_root_config` before calling `build()`." + ) from e + + full_config: TConfig | None = None + SSM_FAIL_ERROR_MSG = "Unable to load configuration. SSM parameter load failed and the builder is configured not to load from a file." + if self._use_ssm: + try: + # requires that aws-ssm.ini exists and is correctly configured + ssm_parameters = SSMParameters() + full_config = ssm_parameters.load_config(config_type) + + if not self._use_filename and full_config is None: + raise BuilderBuildError(SSM_FAIL_ERROR_MSG) + except Exception as e: + if self._use_filename: + logging.getLogger().info("SSM parameter load failed.", exc_info=e) + else: + raise BuilderBuildError(SSM_FAIL_ERROR_MSG) from e + + if self._use_filename and full_config is None: + if self._config_value_overrides: + full_config = load_config( + config_type, self._config_filename, self._config_value_overrides + ) + else: + full_config = load_config(config_type, self._config_filename) + + return full_config + + +class ApplicationConfigBuilderCallback(Protocol[TAppConfig]): + def __call__( + self, + config_builder: ApplicationConfigBuilder[TAppConfig], + ) -> "None | ApplicationConfigBuilder[TAppConfig]": ... + + +@final +class ApplicationBuilder(Generic[T_app]): + def __init__(self) -> None: + self._modules: list[Module | type[Module]] = [] + self._config_overrides: dict[str, Any] = {} + + _APPLICATION_CONFIG_BUILDER_PROPERTY_NAME: str = "__application_config_builder" + + @property + def _application_config_builder(self) -> ApplicationConfigBuilder[Config]: + builder = getattr( + self, ApplicationBuilder._APPLICATION_CONFIG_BUILDER_PROPERTY_NAME, None + ) + + if builder is None: + builder = ApplicationConfigBuilder[Config]() + self._application_config_builder = builder.with_root_config_type(Config) + + return builder + + @_application_config_builder.setter + def _application_config_builder(self, value: ApplicationConfigBuilder[Config]): + setattr( + self, ApplicationBuilder._APPLICATION_CONFIG_BUILDER_PROPERTY_NAME, value + ) + + @overload + def with_module(self, module: Module) -> Self: ... + @overload + def with_module(self, module: type[Module]) -> Self: ... + def with_module(self, module: Module | type[Module]) -> Self: + module_type = type(module) if isinstance(module, Module) else module + if issubclass(module_type, ConfigurableModule): + _ = self._application_config_builder.with_config_type( + module_type.get_config_type() + ) + + self._modules.append(module) + return self + + def with_modules(self, modules: list[Module | type[Module]] | None) -> Self: + if modules is not None: + for module in modules: + _ = self.with_module(module) + return self + + @overload + def use_configuration( + self, + __application_config_builder_callback: ApplicationConfigBuilderCallback[Config], + ) -> Self: + """ + Execute changes to the builder's `ApplicationConfigBuilder[TAppConfig]` instance. + + `__builder_callback` can return `None`, or the instance of `ApplicationConfigBuilder[TAppConfig]` passed to its `config_builder` argument. + This allowance is so lambdas can be used; `ApplicationBuilder[T_app, TAppConfig]` does not use the return value. + """ + ... + + @overload + def use_configuration( + self, __application_config_builder: ApplicationConfigBuilder[Config] + ) -> Self: + """Replace the builder's default `ApplicationConfigBuilder[TAppConfig]` instance, or any instance previously assigned.""" + ... + + def use_configuration( + self, + application_config_builder: ApplicationConfigBuilderCallback[Config] + | ApplicationConfigBuilder[Config], + ) -> Self: + if callable(application_config_builder): + _ = application_config_builder(self._application_config_builder) + else: + self._application_config_builder = application_config_builder + + return self + + def with_flask_app_name(self, value: str | None) -> Self: + self._config_overrides["app_name"] = value + return self + + def with_flask_env(self, value: str | None) -> Self: + self._config_overrides["env"] = value + return self + + def build(self) -> CreateAppResult[T_app]: + config_overrides: NestedDict[str, Any] = defaultdict(dict) + + if ( + override_app_name := self._config_overrides.get("app_name", None) + ) is not None and override_app_name != "": + config_overrides["flask"]["app_name"] = override_app_name + + if ( + override_env := self._config_overrides.get("env", None) + ) is not None and override_env != "": + config_overrides["flask"]["env"] = override_env + + _ = self._application_config_builder.with_config_value_overrides( + config_overrides + ) + config = self._application_config_builder.build() + + if config is None: + raise BuilderBuildError( + "Failed to load the application configuration correctly." + ) + + if config.flask is None: + raise Exception("You must set [flask] in the application configuration.") + + if not config.flask.app_name: + raise Exception( + "You must set the Flask application name in the [flask.app_name] config or FLASK_APP envvar." + ) + + app: T_app + + if config.flask.openapi is not None: + openapi = configure_openapi(config) + app = cast(T_app, openapi) + else: + app = cast(T_app, configure_blueprint_routes(config)) + + register_error_handlers(app) + _ = register_api_request_handlers(app) + _ = register_api_response_handlers(app) + _ = register_context_middleware(app) + + application_modules = [ + ConfigModule(config, type(config)) + for (_, config) in cast( + Generator[tuple[str, AbstractConfig], None, None], config + ) + ] + (self._modules if self._modules else []) + # The `config` module cannot be overridden unless the application + # IoC container is fiddled with. `config` is the instance registered + # to `AbstractConfig`. + modules = application_modules + [ConfigModule(config, Config)] + flask_injector = configure_dependencies(app, application_modules=modules) + + flask_app = app.app if isinstance(app, FlaskApp) else app + return CreateAppResult[T_app]( + flask_app, AppInjector[T_app](app, flask_injector) + ) + + +@deprecated("`create_app` is deprecated. Use `ApplicationBuilder`.") def create_app( config_filename: str = "config.toml", # FIXME should be a list of PydanticDataclass application_configs: list[type[AbstractConfig]] | None = None, application_modules: list[Module | type[Module]] | None = None, - # FIXME eventually should replace with builders - # and configurators so this list of params doesn't - # just grow and grow. - # startup_builder: IStartupBuilder, - # config: Config, ) -> CreateAppResult[TApp]: """ - Do not use this method directly. Instead, use `App[T_app].create()` + Do not use this method directly. Instead, use `App[T_app].create()` or `ApplicationBuilder[TApp, TConfig]()` """ + return _create_app(config_filename, application_configs, application_modules) + + +def _create_app( + config_filename: str = "config.toml", + # FIXME should be a list of PydanticDataclass + application_configs: list[type[AbstractConfig]] | None = None, + application_modules: list[Module | type[Module]] | None = None, +) -> CreateAppResult[TApp]: # set up the default configuration as soon as possible # also required to call before json_logging.config_root_logger() logging.basicConfig(force=True) - config_overrides = {} - if environ.get("FLASK_APP"): - config_overrides["app_name"] = environ["FLASK_APP"] - - if environ.get("FLASK_ENV"): - config_overrides["env"] = environ["FLASK_ENV"] - - config_type = Config - if application_configs is not None: - # fmt: off - config_type = ConfigBuilder[Config]()\ - .with_root_config(Config)\ - .with_configs(application_configs)\ - .build() - # fmt: on - - full_config: Config | None = None - try: - # requires that aws-ssm.ini exists and is correctly configured - ssm_parameters = SSMParameters() - full_config = ssm_parameters.load_config(config_type) - except Exception as e: - logging.getLogger().warning(f"SSM parameter load failed: {e}") - - if full_config is None: - if config_overrides: - full_config = load_config( - config_type, config_filename, {"flask": config_overrides} - ) - else: - full_config = load_config(config_type, config_filename) - - full_config.prepare_env_for_flask() - - if full_config.flask is None: - raise Exception("You must set [flask] in the application configuration.") - - if not full_config.flask.app_name: - raise Exception( - "You must set the Flask application name in the [flask.app_name] config or FLASK_APP envvar." + application_builder = ( + ApplicationBuilder[TApp]() + .with_flask_app_name(environ.get("FLASK_APP", None)) + .with_flask_env(environ.get("FLASK_ENV", None)) + .with_modules(application_modules) + .use_configuration( + lambda config_builder: config_builder.enable_ssm(True) + .with_config_filename(config_filename) + .with_root_config_type(Config) + .with_config_types(application_configs) ) - - app: Flask | FlaskApp - - if full_config.flask.openapi is not None: - openapi = configure_openapi(full_config) - app = openapi - else: - app = configure_blueprint_routes(full_config) - - register_error_handlers(app) - _ = register_api_request_handlers(app) - _ = register_api_response_handlers(app) - _ = register_context_middleware(app) - # register_app_teardown_handlers(app) - - # Register every subconfig as a ConfigModule. - # This will allow subpackages to resolve their own config types, - # allow for type safety against objects of those types. - # Otherwise, they can resolve `AbstractConfig`, but then type - # safety is lost. - # Note that, if any `ConfigModule` is provided in `application_modules`, - # those will override the automatically generated `ConfigModule`s. - application_modules = [ - ConfigModule(config, type(config)) - for (_, config) in cast( - Generator[tuple[str, AbstractConfig], None, None], full_config - ) - ] + (application_modules if application_modules else []) - # The `full_config` module cannot be overridden unless the application - # IoC container is fiddled with. `full_config` is the instance registered - # to `AbstractConfig`. - modules = application_modules + [ConfigModule(full_config, Config)] - flask_injector = configure_dependencies(app, application_modules=modules) - - flask_app = app.app if isinstance(app, FlaskApp) else app - return CreateAppResult(flask_app, AppInjector(app, flask_injector)) + ) + app = application_builder.build() + return app def configure_openapi(config: Config, name: Optional[str] = None): @@ -194,26 +441,10 @@ def configure_openapi(config: Config, name: Optional[str] = None): "OpenAPI configuration is empty. Review the `openapi` section of your application's `config.toml`." ) - ## host configuration set up - ## TODO host/port setup should move into application initialization - ## and not be tied to connexion configuration - # host = "127.0.0.1" - # port = 5000 - ## TODO replace SERVER_NAME with host/port in config - # if environ.get("SERVER_NAME") is not None: - # (host, port_str) = environ["SERVER_NAME"].split(":") - # port = int(port_str) - - # connexion and openapi set up - # openapi_spec_dir: str = "app/swagger/" - # if environ.get("OPENAPI_SPEC_DIR"): - # openapi_spec_dir = environ["OPENAPI_SPEC_DIR"] - exec_dir = _get_exec_dir() connexion_app = FlaskApp( config.flask.app_name, - # TODO support relative OPENAPI_SPEC_DIR and prepend program_dir? specification_dir=exec_dir, # host=host, # port=port, @@ -313,7 +544,6 @@ def _import_blueprint_modules(app: Flask, blueprint_import_subdir: str): # find all Flask blueprints in # the module and register them for module_name, module_var in vars(module).items(): - # TODO why did we allow _blueprint when it's not a Blueprint? if module_name.endswith("_blueprint") or isinstance(module_var, Blueprint): blueprint_modules.append(module_var) diff --git a/src/web/Ligare/web/config.py b/src/web/Ligare/web/config.py index e92696c5..e41c2ebf 100644 --- a/src/web/Ligare/web/config.py +++ b/src/web/Ligare/web/config.py @@ -3,7 +3,9 @@ from typing import Literal from flask.config import Config as FlaskAppConfig +from Ligare.programming.config import AbstractConfig from pydantic import BaseModel +from typing_extensions import override class LoggingConfig(BaseModel): @@ -141,9 +143,6 @@ class ConfigObject: flask_app_config.from_object(ConfigObject) -from Ligare.programming.config import AbstractConfig - - class Config(BaseModel, AbstractConfig): logging: LoggingConfig = LoggingConfig() web: WebConfig = WebConfig() @@ -158,3 +157,7 @@ def update_flask_config(self, flask_app_config: FlaskAppConfig): self.flask._update_flask_config( # pyright: ignore[reportPrivateUsage] flask_app_config ) + + @override + def post_load(self) -> None: + self.prepare_env_for_flask() diff --git a/src/web/Ligare/web/exception.py b/src/web/Ligare/web/exception.py new file mode 100644 index 00000000..a70953e1 --- /dev/null +++ b/src/web/Ligare/web/exception.py @@ -0,0 +1,6 @@ +class InvalidBuilderStateError(Exception): + """The builder's state is invalid and the builder cannot execute `build()`.""" + + +class BuilderBuildError(Exception): + """The builder failed during execution of `build()`.""" diff --git a/src/web/Ligare/web/middleware/feature_flags/__init__.py b/src/web/Ligare/web/middleware/feature_flags/__init__.py new file mode 100644 index 00000000..d9487611 --- /dev/null +++ b/src/web/Ligare/web/middleware/feature_flags/__init__.py @@ -0,0 +1,308 @@ +from dataclasses import dataclass +from functools import wraps +from logging import Logger +from typing import Any, Callable, Generic, ParamSpec, Sequence, TypedDict, TypeVar, cast + +from connexion import FlaskApp, request +from flask import Blueprint, Flask, abort +from injector import Binder, Injector, Module, inject, provider, singleton +from Ligare.platform.feature_flag.caching_feature_flag_router import ( + CachingFeatureFlagRouter, +) +from Ligare.platform.feature_flag.caching_feature_flag_router import ( + FeatureFlag as CachingFeatureFlag, +) +from Ligare.platform.feature_flag.db_feature_flag_router import DBFeatureFlagRouter +from Ligare.platform.feature_flag.db_feature_flag_router import ( + FeatureFlag as DBFeatureFlag, +) +from Ligare.platform.feature_flag.db_feature_flag_router import ( + FeatureFlagTable, + FeatureFlagTableBase, +) +from Ligare.platform.feature_flag.feature_flag_router import ( + FeatureFlag, + FeatureFlagRouter, + TFeatureFlag, +) +from Ligare.programming.config import AbstractConfig +from Ligare.programming.patterns.dependency_injection import ConfigurableModule +from Ligare.web.middleware.sso import login_required +from pydantic import BaseModel +from starlette.types import ASGIApp, Receive, Scope, Send +from typing_extensions import override + + +class FeatureFlagConfig(BaseModel): + api_base_url: str = "/platform" + access_role_name: str | bool | None = None + + +class Config(BaseModel, AbstractConfig): + @override + def post_load(self) -> None: + return super().post_load() + + feature_flag: FeatureFlagConfig + + +class FeatureFlagPatchRequest(TypedDict): + name: str + enabled: bool + + +@dataclass +class FeatureFlagPatch: + name: str + enabled: bool + + +class FeatureFlagRouterModule(ConfigurableModule, Generic[TFeatureFlag]): + def __init__(self, t_feature_flag: type[FeatureFlagRouter[TFeatureFlag]]) -> None: + self._t_feature_flag = t_feature_flag + super().__init__() + + @override + @staticmethod + def get_config_type() -> type[AbstractConfig]: + return Config + + @singleton + @provider + def _provide_feature_flag_router( + self, injector: Injector + ) -> FeatureFlagRouter[FeatureFlag]: + return injector.get(self._t_feature_flag) + + +class DBFeatureFlagRouterModule(FeatureFlagRouterModule[DBFeatureFlag]): + def __init__(self) -> None: + super().__init__(DBFeatureFlagRouter) + + @singleton + @provider + def _provide_db_feature_flag_router( + self, injector: Injector + ) -> FeatureFlagRouter[DBFeatureFlag]: + return injector.get(self._t_feature_flag) + + @singleton + @provider + def _provide_db_feature_flag_router_table_base(self) -> type[FeatureFlagTableBase]: + # FeatureFlagTable is a FeatureFlagTableBase provided through + # SQLAlchemy's declarative meta API + return cast(type[FeatureFlagTableBase], FeatureFlagTable) + + +class CachingFeatureFlagRouterModule(FeatureFlagRouterModule[CachingFeatureFlag]): + def __init__(self) -> None: + super().__init__(CachingFeatureFlagRouter) + + @singleton + @provider + def _provide_caching_feature_flag_router( + self, injector: Injector + ) -> FeatureFlagRouter[CachingFeatureFlag]: + return injector.get(self._t_feature_flag) + + +P = ParamSpec("P") +R = TypeVar("R") + + +@inject +def _get_feature_flag_blueprint(app: Flask, config: FeatureFlagConfig, log: Logger): + feature_flag_blueprint = Blueprint( + "feature_flag", __name__, url_prefix=f"{config.api_base_url}" + ) + + access_role = config.access_role_name + + def _login_required(require_flask_login: bool): + """ + Decorate an API endpoint with the correct flask_login authentication + method given the requirements of the API endpoint. + + require_flask_login is ignored if flask_login has been configured. + + If flask_login has _not_ been configured: + * If require_flask_login is True, a warning is logged and the request is aborted with a 405 error + * If require_flask_login is False, the endpoint function is executed + + :param bool require_flask_login: Determine whether flask_login must be configured for this endpoint to function + :return _type_: _description_ + """ + + def __login_required( + fn: Callable[P, R], + ) -> Callable[P, R]: + authorization_implementation: Callable[..., Any] + + if access_role is False: + authorization_implementation = fn + # None means no roles were specified, but a session is still required + elif access_role is None or access_role is True: + authorization_implementation = login_required(fn) + else: + authorization_implementation = login_required([access_role])(fn) + + @wraps(fn) + def wrapper(*args: Any, **kwargs: Any) -> Any: + if not hasattr(app, "login_manager"): + if require_flask_login: + log.warning( + "The Feature Flag module expects flask_login to be configured in order to control access to feature flag modifications. flask_login has not been configured, so the Feature Flag modification API is disabled." + ) + + return abort(405) + else: + return fn(*args, **kwargs) + return authorization_implementation(*args, **kwargs) + + return wrapper + + return __login_required + + @feature_flag_blueprint.route("/feature_flag", methods=("GET",)) + @_login_required(False) + @inject + def feature_flag(feature_flag_router: FeatureFlagRouter[FeatureFlag]): # pyright: ignore[reportUnusedFunction] + request_query_names: list[str] | None = request.query_params.getlist("name") + + feature_flags: Sequence[FeatureFlag] + missing_flags: set[str] | None = None + if request_query_names is None or not request_query_names: + feature_flags = feature_flag_router.get_feature_flags() + elif isinstance(request_query_names, list): # pyright: ignore[reportUnnecessaryIsInstance] + feature_flags = feature_flag_router.get_feature_flags(request_query_names) + missing_flags = set(request_query_names).difference( + set([feature_flag.name for feature_flag in feature_flags]) + ) + else: + raise ValueError("Unexpected type from Flask query parameters.") + + response: dict[str, Any] = {} + problems: list[Any] = [] + + if missing_flags is not None: + for missing_flag in missing_flags: + problems.append({ + "title": "feature flag not found", + "detail": "Queried feature flag does not exist.", + "instance": missing_flag, + "status": 404, + "type": None, + }) + response["problems"] = problems + + elif not feature_flags: + problems.append({ + "title": "No feature flags found", + "detail": "Queried feature flags do not exist.", + "instance": "", + "status": 404, + "type": None, + }) + response["problems"] = problems + + if feature_flags: + response["data"] = feature_flags + return response + else: + return response, 404 + + @feature_flag_blueprint.route("/feature_flag", methods=("PATCH",)) + @_login_required(True) + @inject + async def feature_flag_patch(feature_flag_router: FeatureFlagRouter[FeatureFlag]): # pyright: ignore[reportUnusedFunction] + feature_flags_request: list[FeatureFlagPatchRequest] = await request.json() + + feature_flags = [ + FeatureFlagPatch(name=flag["name"], enabled=flag["enabled"]) + for flag in feature_flags_request + ] + + changes: list[Any] = [] + problems: list[Any] = [] + for flag in feature_flags: + try: + change = feature_flag_router.set_feature_is_enabled( + flag.name, flag.enabled + ) + changes.append(change) + except LookupError: + problems.append({ + "title": "feature flag not found", + "detail": "Feature flag to PATCH does not exist. It must be created first.", + "instance": flag.name, + "status": 404, + "type": None, + }) + + response: dict[str, Any] = {} + + if problems: + response["problems"] = problems + + if changes: + response["data"] = changes + return response + else: + return response, 404 + + return feature_flag_blueprint + + +class FeatureFlagMiddlewareModule(Module): + """ + Enable the use of Feature Flags and a Feature Flag management API. + """ + + @override + def configure(self, binder: Binder) -> None: + super().configure(binder) + + def register_middleware(self, app: FlaskApp): + app.add_middleware(FeatureFlagMiddlewareModule.FeatureFlagMiddleware) + + class FeatureFlagMiddleware: + """ + ASGI middleware for Feature Flags. + + This middleware create a Flask blueprint the enables a Feature Flag management API. + """ + + _app: ASGIApp + + def __init__(self, app: ASGIApp): + super().__init__() + self._app = app + + @inject + async def __call__( + self, + scope: Scope, + receive: Receive, + send: Send, + app: Flask, + injector: Injector, + log: Logger, + ) -> None: + async def wrapped_send(message: Any) -> None: + # Only run during startup of the application + if ( + scope["type"] != "lifespan" + or message["type"] != "lifespan.startup.complete" + or not scope["app"] + ): + return await send(message) + + log.debug("Registering FeatureFlag blueprint.") + app.register_blueprint( + injector.call_with_injection(_get_feature_flag_blueprint) + ) + log.debug("FeatureFlag blueprint registered.") + + return await send(message) + + await self._app(scope, receive, wrapped_send) diff --git a/src/web/Ligare/web/middleware/openapi/__init__.py b/src/web/Ligare/web/middleware/openapi/__init__.py index c173ece2..d6a0ccec 100644 --- a/src/web/Ligare/web/middleware/openapi/__init__.py +++ b/src/web/Ligare/web/middleware/openapi/__init__.py @@ -187,6 +187,7 @@ def _headers_as_dict( def _log_all_api_requests( request: MiddlewareRequestDict, response: MiddlewareResponseDict, + app: Flask, config: Config, log: Logger, ): @@ -216,7 +217,10 @@ def _log_all_api_requests( f"{server.host}:{server.port}", f"{client.host}:{client.port}", "Anonymous" - if isinstance(current_user, AnonymousUserMixin) + if ( + isinstance(current_user, AnonymousUserMixin) + or not hasattr(app, "login_manager") + ) else current_user.get_id(), extra={ "props": { @@ -318,7 +322,13 @@ def __init__(self, app: ASGIApp): @inject async def __call__( - self, scope: Scope, receive: Receive, send: Send, config: Config, log: Logger + self, + scope: Scope, + receive: Receive, + send: Send, + config: Config, + log: Logger, + app: Flask, ) -> None: async def wrapped_send(message: Any) -> None: nonlocal scope @@ -331,7 +341,7 @@ async def wrapped_send(message: Any) -> None: response = cast(MiddlewareResponseDict, scope) request = cast(MiddlewareRequestDict, scope) - _log_all_api_requests(request, response, config, log) + _log_all_api_requests(request, response, app, config, log) return await send(message) diff --git a/src/web/Ligare/web/middleware/openapi/cors.py b/src/web/Ligare/web/middleware/openapi/cors.py index 7bdaa8a1..2e4313e0 100644 --- a/src/web/Ligare/web/middleware/openapi/cors.py +++ b/src/web/Ligare/web/middleware/openapi/cors.py @@ -13,7 +13,7 @@ def register_middleware(self, app: FlaskApp, config: Config): app.add_middleware( CORSMiddleware, position=MiddlewarePosition.BEFORE_EXCEPTION, - allow_origins=cors_config.origins, + allow_origins=cors_config.origins or [], allow_credentials=cors_config.allow_credentials, allow_methods=cors_config.allow_methods, allow_headers=cors_config.allow_headers, diff --git a/src/web/Ligare/web/middleware/sso.py b/src/web/Ligare/web/middleware/sso.py index b4e29003..e46207a9 100644 --- a/src/web/Ligare/web/middleware/sso.py +++ b/src/web/Ligare/web/middleware/sso.py @@ -14,6 +14,7 @@ TypedDict, TypeVar, cast, + overload, ) from urllib.parse import urlparse @@ -37,11 +38,13 @@ from flask_login import logout_user # pyright: ignore[reportUnknownVariableType] from flask_login import current_user from flask_login import login_required as flask_login_required -from injector import Binder, Injector, Module, inject +from injector import Binder, Injector, inject from Ligare.identity.config import Config, SAML2Config, SSOConfig from Ligare.identity.dependency_injection import SAML2Module, SSOModule from Ligare.identity.SAML2 import SAML2Client from Ligare.platform.identity.user_loader import Role, UserId, UserLoader, UserMixin +from Ligare.programming.config import AbstractConfig +from Ligare.programming.patterns.dependency_injection import ConfigurableModule from Ligare.web.config import Config from Ligare.web.encryption import decrypt_flask_cookie from saml2.validate import ( @@ -77,15 +80,80 @@ def __call__(self, user: AuthCheckUser, *args: Any, **kwargs: Any) -> bool: ... R = TypeVar("R") +@overload +def login_required() -> Callable[[Callable[P, R]], Callable[P, R]]: + """ + Require a user session for the decorated API. No further requirements are applied. + In effect, this uses `flask_login.login_required` and is used in its usual way. + + This is meant to be used as a decorator with `@login_required`. It is the the equivalent + of using the decorator `@flask_login.login_required`. + + :return Callable[[Callable[P, R]], Callable[P, R]]: Returns the `flask_login.login_required` decorated function. + """ + + +@overload +def login_required(function: Callable[P, R], /) -> Callable[P, R]: + """ + Require a user session for the decorated API. No further requirements are applied. + In effect, this passes along `flask_login.login_required` without modification. + + This can be used as a decorator with `@login_required()`, not `@login_required`, though + its use case is to wrap a function without using the decorator form, e.g., `wrapped_func = login_required(my_func)`. + This is the equivalent of `wrapped_func = flask_login.login_required(my_func)`. + + :return Callable[P, R]: Returns the `flask_login.login_required` wrapped function. + """ + + +@overload +def login_required( + roles: Sequence[Role | str], / +) -> Callable[[Callable[P, R]], Callable[P, R]]: + """ + Require a user session, and require that the user has at least one of the specified roles. + + :param Sequence[Role | str] roles: The list of roles the user can have that will allow them to access the decorated API. + :return Callable[[Callable[P, R]], Callable[P, R]]: Returns the decorated function. + """ + + +@overload def login_required( - roles: Sequence[Role] | Callable[P, R] | Callable[..., Any] | None = None, + roles: Sequence[Role | str], auth_check_override: AuthCheckOverrideCallable, / +) -> Callable[[Callable[P, R]], Callable[P, R]]: + """ + Require a user session, and require that the user has at least one of the specified roles. + + `auth_check_override` is called to override authorization. If it returns True, the user is considered to have access to the API. + If it returns False, the roles are checked instead, and the user will have access to the API if they have one of the specified roles. + + :param Sequence[Role | str] roles: The list of roles the user can have that will allow them to access the decorated API. + :param AuthCheckOverrideCallable auth_check_override: The method that is called to override authorization. It receives the following parameters: + + * `user` is the current session user + + * `*args` will be any arguments passed without argument keywords. When using `login_required` as a + decorator, this will be an empty tuple. + + * `**kwargs` will be any parameters specified with keywords. When using `login_required` as a decorator, + this will be the parameters passed into the decorated method. + In the case of a Flask API endpoint, for example, this will be all of the endpoint method parameters. + :return Callable[[Callable[P, R]], Callable[P, R]]: _description_ + """ + + +def login_required( + roles: Sequence[Role | str] | Callable[P, R] | Callable[..., Any] | None = None, auth_check_override: AuthCheckOverrideCallable | None = None, -): + /, +) -> Callable[[Callable[P, R]], Callable[P, R]] | Callable[P, R] | Callable[..., Any]: """ Require a valid Flask session before calling the decorated function. This method uses the list of `roles` to determine whether the current session user - has any of the roles listed. Alternatively, the use of `auth_check_override` can is used to + has any of the roles listed. Alternatively, the use of `auth_check_override` is used to bypass the role check. If the `auth_check_override` method returns True, the user is considered to have access to the decorated API endpoint. If the `auth_check_override` method returns False, `login_required` falls back to checking `roles`. @@ -107,8 +175,10 @@ def login_required( If `auth_check_override` is a callable, it will be called with the following parameters: * `user` is the current session user + * `*args` will be any arguments passed without argument keywords. When using `login_required` as a decorator, this will be an empty tuple. + * `**kwargs` will be any parameters specified with keywords. When using `login_required` as a decorator, this will be the parameters passed into the decorated method. In the case of a Flask API endpoint, for example, this will be all of the endpoint method parameters. @@ -154,7 +224,9 @@ def decorated_view(*args: P.args, **kwargs: P.kwargs) -> R | Response: # if roles is empty, no roles will intersect. # this means an empty list means "no roles have access" role_intersection = [ - role for role in user.roles if role in (roles or []) + str(role) + for role in user.roles + if (str(role) in ({str(r) for r in roles}) or []) ] if len(role_intersection) == 0: # this should end up raising a 401 exception @@ -439,7 +511,12 @@ def unauthorized(self): raise Unauthorized(response.data, response) -class SAML2MiddlewareModule(Module): +class SAML2MiddlewareModule(ConfigurableModule): # Module): + @override + @staticmethod + def get_config_type() -> type[AbstractConfig]: + return SSOConfig + @override def configure(self, binder: Binder) -> None: binder.install(SAML2Module) diff --git a/src/web/Ligare/web/testing/create_app.py b/src/web/Ligare/web/testing/create_app.py index 5e821823..234ff90e 100644 --- a/src/web/Ligare/web/testing/create_app.py +++ b/src/web/Ligare/web/testing/create_app.py @@ -5,7 +5,6 @@ from contextlib import _GeneratorContextManager # pyright: ignore[reportPrivateUsage] from contextlib import ExitStack from dataclasses import dataclass -from functools import lru_cache from types import ModuleType from typing import ( Any, @@ -22,7 +21,6 @@ import json_logging import pytest -import yaml from _pytest.fixtures import SubRequest from connexion import FlaskApp from flask import Flask, Request, Response, session @@ -36,10 +34,11 @@ from Ligare.platform.dependency_injection import UserLoaderModule from Ligare.platform.identity import Role, User from Ligare.platform.identity.user_loader import TRole, UserId, UserMixin +from Ligare.programming.collections.dict import NestedDict from Ligare.programming.config import AbstractConfig, ConfigBuilder from Ligare.programming.str import get_random_str from Ligare.web.application import ( - App, + ApplicationBuilder, CreateAppResult, FlaskAppResult, OpenAPIAppResult, @@ -303,12 +302,7 @@ def _client_configurable( self, mocker: MockerFixture, app_getter: Callable[ - [ - Config, - MockerFixture, - list[type[AbstractConfig]] | None, - list[Module | type[Module]] | None, - ], + [Config, MockerFixture, TAppInitHook | None], Generator[CreateAppResult[T_app], Any, None], ], client_getter: Callable[ @@ -321,16 +315,7 @@ def _client_getter( client_init_hook: TClientInitHook[T_app] | None = None, app_init_hook: TAppInitHook | None = None, ) -> Generator[ClientInjector[T_flask_client], Any, None]: - application_configs: list[type[AbstractConfig]] | None = None - application_modules: list[Module | type[Module]] | None = None - if app_init_hook is not None: - application_configs = [] - application_modules = [] - app_init_hook(application_configs, application_modules) - - application_result = next( - app_getter(config, mocker, application_configs, application_modules) - ) + application_result = next(app_getter(config, mocker, app_init_hook)) if client_init_hook is not None: client_init_hook(application_result) @@ -397,8 +382,7 @@ def __get_basic_flask_app( self, config: Config, mocker: MockerFixture, - application_configs: list[type[AbstractConfig]] | None = None, - application_modules: list[Module | type[Module]] | None = None, + app_init_hook: TAppInitHook | None = None, ) -> Generator[FlaskAppResult, Any, None]: # prevents the creation of a Connexion application if config.flask is not None: @@ -414,13 +398,27 @@ def __get_basic_flask_app( return_value=MagicMock(load_config=MagicMock(return_value=config)), ) - if application_configs is None: - application_configs = [] - if application_modules is None: - application_modules = [] - application_configs.append(SSOConfig) + application_configs: list[type[AbstractConfig]] | None = [] + application_modules: list[Module | type[Module]] | None = [] + + if app_init_hook is not None: + app_init_hook(application_configs, application_modules) + application_modules.append(SAML2MiddlewareModule) - app = App[Flask].create("config.toml", application_configs, application_modules) + + logging.basicConfig(force=True) + + application_builder = ( + ApplicationBuilder[Flask]() + .with_modules(application_modules) + .use_configuration( + lambda config_builder: config_builder.enable_ssm(True) + .with_config_filename("config.toml") + .with_root_config_type(Config) + .with_config_types(application_configs) + ) + ) + app = application_builder.build() yield app @pytest.fixture() @@ -520,8 +518,7 @@ def _get_real_openapi_app( self, config: Config, mocker: MockerFixture, - application_configs: list[type[AbstractConfig]] | None = None, - application_modules: list[Module | type[Module]] | None = None, + app_init_hook: TAppInitHook | None = None, ) -> Generator[OpenAPIAppResult, Any, None]: # prevents the creation of a Connexion application if config.flask is None or config.flask.openapi is None: @@ -535,24 +532,35 @@ def _get_real_openapi_app( return_value=MagicMock(load_config=MagicMock(return_value=config)), ) - if application_configs is None: - application_configs = [] - if application_modules is None: - application_modules = [] - application_configs.append(SSOConfig) - application_modules.append(SAML2MiddlewareModule) - application_modules.append( - UserLoaderModule( - loader=User, # pyright: ignore[reportArgumentType] - roles=Role, # pyright: ignore[reportArgumentType] - user_table=MagicMock(), # pyright: ignore[reportArgumentType] - role_table=MagicMock(), # pyright: ignore[reportArgumentType] - bases=[], - ) + _application_configs: list[type[AbstractConfig]] = [] + _application_modules = cast( + list[Module | type[Module]], + [ + SAML2MiddlewareModule, + UserLoaderModule( + loader=User, # pyright: ignore[reportArgumentType] + roles=Role, # pyright: ignore[reportArgumentType] + user_table=MagicMock(), # pyright: ignore[reportArgumentType] + role_table=MagicMock(), # pyright: ignore[reportArgumentType] + bases=[], + ), + ], ) - app = App[FlaskApp].create( - "config.toml", application_configs, application_modules + + if app_init_hook is not None: + app_init_hook(_application_configs, _application_modules) + + application_builder = ( + ApplicationBuilder[FlaskApp]() + .with_modules(_application_modules) + .use_configuration( + lambda config_builder: config_builder.enable_ssm(True) + .with_config_filename("config.toml") + .with_root_config_type(Config) + .with_config_types(_application_configs) + ) ) + app = application_builder.build() yield app @pytest.fixture() @@ -655,6 +663,7 @@ def openapi_client( def openapi_client_configurable( self, mocker: MockerFixture ) -> OpenAPIClientInjectorConfigurable: + # FIXME some day json_logging needs to be fixed _ = mocker.patch("Ligare.web.application.json_logging") return self._client_configurable( mocker, self._get_real_openapi_app, self._openapi_client @@ -693,32 +702,51 @@ def _flask_request_getter( return _flask_request_getter - @lru_cache - def _get_openapi_spec(self): - return yaml.safe_load( - """openapi: 3.0.3 -servers: - - url: http://testserver/ - description: Test Application -info: - title: "Test Application" - version: 3.0.3 -paths: - /: - get: - description: "Check whether the application is running." - operationId: "root.get" - parameters: [] - responses: - "200": - content: - application/json: - schema: - type: string - description: "Application is running correctly." - summary: "A simple method that returns 200 as long as the application is running." -""" - ) + # this is the YAML-parsed dictionary from this OpenAPI spec + # openapi: 3.0.3 + # servers: + # - url: http://testserver/ + # description: Test Application + # info: + # title: "Test Application" + # version: 3.0.3 + # paths: + # /: + # get: + # description: "Check whether the application is running." + # operationId: "root.get" + # parameters: [] + # responses: + # "200": + # content: + # application/json: + # schema: + # type: string + # description: "Application is running correctly." + # summary: "A simple method that returns 200 as long as the application is running." + _openapi_spec: NestedDict[str, Any] = { + "info": {"title": "Test Application", "version": "3.0.3"}, + "openapi": "3.0.3", + "paths": { + "/": { + "get": { + "description": "Check whether the application is running.", + "operationId": "root.get", + "parameters": [], + "responses": { + "200": { + "content": { + "application/json": {"schema": {"type": "string"}} + }, + "description": "Application is running correctly", + } + }, + "summary": "A simple method that returns 200 as long as the application is running.", + } + } + }, + "servers": [{"description": "Test Application", "url": "http://testserver/"}], + } @pytest.fixture() def openapi_mock_controller(self, request: FixtureRequest, mocker: MockerFixture): @@ -760,7 +788,7 @@ def begin(): if spec_loader_mock is None: spec_loader_mock = mocker.patch( "connexion.spec.Specification._load_spec_from_file", - return_value=self._get_openapi_spec(), + return_value=CreateOpenAPIApp._openapi_spec, ) def end(): @@ -775,6 +803,8 @@ def end(): mock_controller = MockController(begin=begin, end=end) try: + # TODO can this be a context manager instead of requiring + # the explicit begin() call? yield mock_controller finally: mock_controller.end() diff --git a/src/web/test/unit/application/test_create_flask_app.py b/src/web/test/unit/application/test_create_flask_app.py index b5dfae26..1a1c6983 100644 --- a/src/web/test/unit/application/test_create_flask_app.py +++ b/src/web/test/unit/application/test_create_flask_app.py @@ -7,7 +7,8 @@ from flask import Blueprint, Flask from Ligare.programming.config import AbstractConfig from Ligare.programming.str import get_random_str -from Ligare.web.application import App, configure_blueprint_routes +from Ligare.web.application import App # pyright: ignore[reportDeprecated] +from Ligare.web.application import configure_blueprint_routes from Ligare.web.config import Config, FlaskConfig from Ligare.web.testing.create_app import ( CreateFlaskApp, @@ -16,6 +17,7 @@ from mock import MagicMock from pydantic import BaseModel from pytest_mock import MockerFixture +from typing_extensions import override class TestCreateFlaskApp(CreateFlaskApp): @@ -258,7 +260,7 @@ def test__CreateFlaskApp__create_app__loads_config_from_toml( ) toml_filename = f"{TestCreateFlaskApp.test__CreateFlaskApp__create_app__loads_config_from_toml.__name__}-config.toml" - _ = App[Flask].create(config_filename=toml_filename) + _ = App[Flask].create(config_filename=toml_filename) # pyright: ignore[reportDeprecated] assert load_config_mock.called assert load_config_mock.call_args and load_config_mock.call_args[0] assert load_config_mock.call_args[0][1] == toml_filename @@ -281,9 +283,13 @@ def test__CreateFlaskApp__create_app__uses_custom_config_types( _ = mocker.patch("toml.load", return_value=toml_load_result) class CustomConfig(BaseModel, AbstractConfig): + @override + def post_load(self) -> None: + return super().post_load() + foo: str = get_random_str(k=26) - app = App[Flask].create( + app = App[Flask].create( # pyright: ignore[reportDeprecated] config_filename=toml_filename, application_configs=[CustomConfig] ) @@ -317,7 +323,7 @@ def test__CreateFlaskApp__create_app__updates_flask_config_from_envvars( _ = mocker.patch( "Ligare.web.application.load_config", return_value=basic_config ) - _ = App[Flask].create() + _ = App[Flask].create() # pyright: ignore[reportDeprecated] assert object.__getattribute__(basic_config.flask, config_var_name) == var_value @@ -356,9 +362,9 @@ def test__CreateFlaskApp__create_app__requires_application_name( if should_fail: with pytest.raises(Exception): - _ = App[Flask].create() + _ = App[Flask].create() # pyright: ignore[reportDeprecated] else: - _ = App[Flask].create() + _ = App[Flask].create() # pyright: ignore[reportDeprecated] def test__CreateFlaskApp__create_app__configures_appropriate_app_type_based_on_config( self, mocker: MockerFixture @@ -379,6 +385,6 @@ def test__CreateFlaskApp__create_app__configures_appropriate_app_type_based_on_c ) config = Config(flask=FlaskConfig(app_name=app_name)) _ = mocker.patch("Ligare.web.application.load_config", return_value=config) - _ = App[Flask].create(config_filename=toml_filename) + _ = App[Flask].create(config_filename=toml_filename) # pyright: ignore[reportDeprecated] configure_method_mock.assert_called_once_with(config) diff --git a/src/web/test/unit/application/test_create_openapi_app.py b/src/web/test/unit/application/test_create_openapi_app.py index 15c2f384..2c2c109c 100644 --- a/src/web/test/unit/application/test_create_openapi_app.py +++ b/src/web/test/unit/application/test_create_openapi_app.py @@ -5,12 +5,14 @@ from flask import Flask from Ligare.programming.config import AbstractConfig from Ligare.programming.str import get_random_str -from Ligare.web.application import App, configure_openapi +from Ligare.web.application import App # pyright:ignore[reportDeprecated] +from Ligare.web.application import ApplicationBuilder, configure_openapi from Ligare.web.config import Config, FlaskConfig, FlaskOpenApiConfig from Ligare.web.testing.create_app import CreateOpenAPIApp from mock import MagicMock from pydantic import BaseModel from pytest_mock import MockerFixture +from typing_extensions import override class TestCreateOpenAPIApp(CreateOpenAPIApp): @@ -53,7 +55,10 @@ def test__CreateOpenAPIApp__create_app__loads_config_from_toml( ) toml_filename = f"{TestCreateOpenAPIApp.test__CreateOpenAPIApp__create_app__loads_config_from_toml.__name__}-config.toml" - _ = App[Flask].create(config_filename=toml_filename) + application_builder = ApplicationBuilder[Flask]().use_configuration( + lambda config_builder: config_builder.with_config_filename(toml_filename) + ) + _ = application_builder.build() assert load_config_mock.called assert load_config_mock.call_args and load_config_mock.call_args[0] assert load_config_mock.call_args[0][1] == toml_filename @@ -76,9 +81,13 @@ def test__CreateOpenAPIApp__create_app__uses_custom_config_types( _ = mocker.patch("toml.load", return_value=toml_load_result) class CustomConfig(BaseModel, AbstractConfig): + @override + def post_load(self) -> None: + return super().post_load() + foo: str = get_random_str(k=26) - app = App[Flask].create( + app = App[Flask].create( # pyright:ignore[reportDeprecated] config_filename=toml_filename, application_configs=[CustomConfig] ) @@ -112,7 +121,7 @@ def test__CreateOpenAPIApp__create_app__updates_flask_config_from_envvars( _ = mocker.patch( "Ligare.web.application.load_config", return_value=basic_config ) - _ = App[Flask].create() + _ = App[Flask].create() # pyright:ignore[reportDeprecated] assert object.__getattribute__(basic_config.flask, config_var_name) == var_value @@ -151,9 +160,9 @@ def test__CreateOpenAPIApp__create_app__requires_application_name( if should_fail: with pytest.raises(Exception): - _ = App[Flask].create() + _ = App[Flask].create() # pyright:ignore[reportDeprecated] else: - _ = App[Flask].create() + _ = App[Flask].create() # pyright:ignore[reportDeprecated] def test__CreateOpenAPIApp__create_app__configures_appropriate_app_type_based_on_config( self, mocker: MockerFixture @@ -174,6 +183,6 @@ def test__CreateOpenAPIApp__create_app__configures_appropriate_app_type_based_on flask=FlaskConfig(app_name=app_name, openapi=FlaskOpenApiConfig()) ) _ = mocker.patch("Ligare.web.application.load_config", return_value=config) - _ = App[FlaskApp].create(config_filename=toml_filename) + _ = App[FlaskApp].create(config_filename=toml_filename) # pyright:ignore[reportDeprecated] configure_method_mock.assert_called_once_with(config) diff --git a/src/web/test/unit/middleware/test_feature_flags_middleware.py b/src/web/test/unit/middleware/test_feature_flags_middleware.py new file mode 100644 index 00000000..dfa20432 --- /dev/null +++ b/src/web/test/unit/middleware/test_feature_flags_middleware.py @@ -0,0 +1,434 @@ +from dataclasses import dataclass +from enum import auto +from typing import Generic, Sequence, TypeVar + +import pytest +from connexion import FlaskApp +from flask_login import UserMixin +from injector import Module +from Ligare.platform.dependency_injection import UserLoaderModule +from Ligare.platform.feature_flag import FeatureFlag +from Ligare.platform.feature_flag.feature_flag_router import FeatureFlagRouter +from Ligare.platform.identity.user_loader import Role as LoaderRole +from Ligare.programming.config import AbstractConfig +from Ligare.web.application import CreateAppResult, OpenAPIAppResult +from Ligare.web.config import Config +from Ligare.web.middleware.feature_flags import ( + CachingFeatureFlagRouterModule, + FeatureFlagConfig, + FeatureFlagMiddlewareModule, +) +from Ligare.web.testing.create_app import ( + CreateOpenAPIApp, + OpenAPIClientInjectorConfigurable, + OpenAPIMockController, +) +from mock import MagicMock +from pytest_mock import MockerFixture +from typing_extensions import override + + +@dataclass +class UserId: + user_id: int + username: str + + +class Role(LoaderRole): + User = auto() + Administrator = auto() + Operator = auto() + + @staticmethod + def items(): + return Role.__members__.items() + + +TRole = TypeVar("TRole", bound=Role, covariant=True) + + +class User(UserMixin, Generic[TRole]): + """ + Represents the user object stored in a session. + """ + + id: UserId + roles: Sequence[TRole] + + @override + def get_id(self): + """ + Override the UserMixin.get_id so the username is returned instead of `id` (the dataclass) + when `flask_login.login_user` calls this method to assign the + session `_user_id` key. + """ + return str(self.id.username) + + @override + def __init__(self, id: UserId, roles: Sequence[TRole] | None = None): + """ + Create a new user with the given user name or id, and a list of roles. + If roles are not given, an empty list is assigned by default. + """ + super().__init__() + + if roles is None: + roles = [] + + self.id = id + self.roles = roles + + +class TestFeatureFlagsMiddleware(CreateOpenAPIApp): + def _user_session_app_init_hook( + self, + application_configs: list[type[AbstractConfig]], + application_modules: list[Module | type[Module]], + ): + application_modules.append( + UserLoaderModule( + loader=User, # pyright: ignore[reportArgumentType] + roles=Role, + user_table=MagicMock(), # pyright: ignore[reportArgumentType] + role_table=MagicMock(), # pyright: ignore[reportArgumentType] + bases=[], + ) + ) + application_modules.append(CachingFeatureFlagRouterModule) + application_modules.append(FeatureFlagMiddlewareModule()) + + def test__FeatureFlagMiddleware__feature_flag_api_GET_requires_user_session_when_flask_login_is_configured( + self, + openapi_config: Config, + openapi_client_configurable: OpenAPIClientInjectorConfigurable, + openapi_mock_controller: OpenAPIMockController, + mocker: MockerFixture, + ): + def app_init_hook( + application_configs: list[type[AbstractConfig]], + application_modules: list[Module | type[Module]], + ): + application_modules.append(CachingFeatureFlagRouterModule) + application_modules.append(FeatureFlagMiddlewareModule()) + + openapi_mock_controller.begin() + app = next( + openapi_client_configurable( + openapi_config, + app_init_hook=app_init_hook, + ) + ) + + response = app.client.get("/platform/feature_flag") + + assert response.status_code == 401 + + @pytest.mark.parametrize( + "flask_login_is_configured,user_has_session", + [[True, True], [False, True], [False, False]], + ) + def test__FeatureFlagMiddleware__feature_flag_api_GET_gets_feature_flags( + self, + flask_login_is_configured: bool, + user_has_session: bool, + openapi_config: Config, + openapi_client_configurable: OpenAPIClientInjectorConfigurable, + openapi_mock_controller: OpenAPIMockController, + mocker: MockerFixture, + ): + def app_init_hook( + application_configs: list[type[AbstractConfig]], + application_modules: list[Module | type[Module]], + ): + if not flask_login_is_configured: + application_modules.clear() + application_modules.append(CachingFeatureFlagRouterModule) + application_modules.append(FeatureFlagMiddlewareModule()) + + def client_init_hook(app: CreateAppResult[FlaskApp]): + caching_feature_flag_router = app.app_injector.flask_injector.injector.get( + FeatureFlagRouter[FeatureFlag] + ) + _ = caching_feature_flag_router.set_feature_is_enabled("foo_feature", True) + + openapi_mock_controller.begin() + app = next( + openapi_client_configurable( + openapi_config, + client_init_hook=client_init_hook, + app_init_hook=app_init_hook, + ) + ) + + if user_has_session: + with self.get_authenticated_request_context( + app, + User, # pyright: ignore[reportArgumentType] + mocker, + ): + response = app.client.get("/platform/feature_flag") + else: + response = app.client.get("/platform/feature_flag") + + assert response.status_code == 200 + response_json = response.json() + assert (data := response_json.get("data", None)) is not None + assert len(data) == 1 + assert (name := data[0].get("name", None)) is not None + assert (enabled := data[0].get("enabled", None)) is not None + assert name == "foo_feature" + assert enabled == True + + @pytest.mark.parametrize("has_role", [True, False]) + def test__FeatureFlagMiddleware__feature_flag_api_GET_requires_specified_role_when_flask_login_is_configured( + self, + has_role: bool, + openapi_config: Config, + openapi_client_configurable: OpenAPIClientInjectorConfigurable, + openapi_mock_controller: OpenAPIMockController, + mocker: MockerFixture, + ): + get_feature_flag_mock = mocker.patch( + "Ligare.web.middleware.feature_flags.CachingFeatureFlagRouter.get_feature_flags", + return_value=[], + ) + + def client_init_hook(app: OpenAPIAppResult): + feature_flag_config = FeatureFlagConfig( + access_role_name="Operator", + api_base_url="/platform", # the default + ) + app.app_injector.flask_injector.injector.binder.bind( + FeatureFlagConfig, to=feature_flag_config + ) + + openapi_mock_controller.begin() + app = next( + openapi_client_configurable( + openapi_config, client_init_hook, self._user_session_app_init_hook + ) + ) + + with self.get_authenticated_request_context( + app, + User, # pyright: ignore[reportArgumentType] + mocker, + [Role.Operator] if has_role else [], + ): + response = app.client.get("/platform/feature_flag") + + if has_role: + assert response.status_code == 404 + get_feature_flag_mock.assert_called_once() + else: + assert response.status_code == 401 + get_feature_flag_mock.assert_not_called() + + def test__FeatureFlagMiddleware__feature_flag_api_GET_returns_no_feature_flags_when_none_exist( + self, + openapi_config: Config, + openapi_client_configurable: OpenAPIClientInjectorConfigurable, + openapi_mock_controller: OpenAPIMockController, + mocker: MockerFixture, + ): + openapi_mock_controller.begin() + app = next( + openapi_client_configurable( + openapi_config, app_init_hook=self._user_session_app_init_hook + ) + ) + + with self.get_authenticated_request_context( + app, + User, # pyright: ignore[reportArgumentType] + mocker, + ): + response = app.client.get("/platform/feature_flag") + + assert response.status_code == 404 + response_json = response.json() + assert (problems := response_json.get("problems", None)) is not None + assert len(problems) == 1 + assert (title := problems[0].get("title", None)) is not None + assert title == "No feature flags found" + + def test__FeatureFlagMiddleware__feature_flag_api_GET_returns_feature_flags_when_they_exist( + self, + openapi_config: Config, + openapi_client_configurable: OpenAPIClientInjectorConfigurable, + openapi_mock_controller: OpenAPIMockController, + mocker: MockerFixture, + ): + def client_init_hook(app: CreateAppResult[FlaskApp]): + caching_feature_flag_router = app.app_injector.flask_injector.injector.get( + FeatureFlagRouter[FeatureFlag] + ) + _ = caching_feature_flag_router.set_feature_is_enabled("foo_feature", True) + + openapi_mock_controller.begin() + app = next( + openapi_client_configurable( + openapi_config, + client_init_hook, + self._user_session_app_init_hook, + ) + ) + + with self.get_authenticated_request_context( + app, + User, # pyright: ignore[reportArgumentType] + mocker, + ): + response = app.client.get("/platform/feature_flag") + + assert response.status_code == 200 + response_json = response.json() + assert (data := response_json.get("data", None)) is not None + assert len(data) == 1 + assert data[0].get("enabled", None) is True + assert data[0].get("name", None) == "foo_feature" + + @pytest.mark.parametrize( + "query_flags", ["bar_feature", ["foo_feature", "baz_feature"]] + ) + def test__FeatureFlagMiddleware__feature_flag_api_GET_returns_specific_feature_flags_when_they_exist( + self, + query_flags: str | list[str], + openapi_config: Config, + openapi_client_configurable: OpenAPIClientInjectorConfigurable, + openapi_mock_controller: OpenAPIMockController, + mocker: MockerFixture, + ): + def client_init_hook(app: CreateAppResult[FlaskApp]): + caching_feature_flag_router = app.app_injector.flask_injector.injector.get( + FeatureFlagRouter[FeatureFlag] + ) + _ = caching_feature_flag_router.set_feature_is_enabled("foo_feature", True) + _ = caching_feature_flag_router.set_feature_is_enabled("bar_feature", False) + _ = caching_feature_flag_router.set_feature_is_enabled("baz_feature", True) + + openapi_mock_controller.begin() + app = next( + openapi_client_configurable( + openapi_config, + client_init_hook, + self._user_session_app_init_hook, + ) + ) + + with self.get_authenticated_request_context( + app, + User, # pyright: ignore[reportArgumentType] + mocker, + ): + response = app.client.get( + "/platform/feature_flag", params={"name": query_flags} + ) + + assert response.status_code == 200 + response_json = response.json() + assert (data := response_json.get("data", None)) is not None + if isinstance(query_flags, str): + assert len(data) == 1 + assert data[0].get("enabled", None) is False + assert data[0].get("name", None) == query_flags + else: + assert len(data) == len(query_flags) + for i, flag in enumerate(query_flags): + assert data[i].get("enabled", None) is True + assert data[i].get("name", None) == flag + + @pytest.mark.parametrize( + "flask_login_is_configured,user_has_session,error_code", + [[True, False, 401], [False, True, 405], [False, False, 405]], + ) + def test__FeatureFlagMiddleware__feature_flag_api_PATCH_requires_user_session_and_flask_login( + self, + flask_login_is_configured: bool, + user_has_session: bool, + error_code: int, + openapi_config: Config, + openapi_client_configurable: OpenAPIClientInjectorConfigurable, + openapi_mock_controller: OpenAPIMockController, + mocker: MockerFixture, + ): + set_feature_flag_mock = mocker.patch( + "Ligare.web.middleware.feature_flags.CachingFeatureFlagRouter.set_feature_is_enabled", + return_value=[], + ) + + def app_init_hook( + application_configs: list[type[AbstractConfig]], + application_modules: list[Module | type[Module]], + ): + if not flask_login_is_configured: + application_modules.clear() + application_modules.append(CachingFeatureFlagRouterModule) + application_modules.append(FeatureFlagMiddlewareModule()) + + openapi_mock_controller.begin() + app = next( + openapi_client_configurable( + openapi_config, + app_init_hook=app_init_hook, + ) + ) + + if user_has_session: + with self.get_authenticated_request_context( + app, + User, # pyright: ignore[reportArgumentType] + mocker, + ): + response = app.client.patch( + "/platform/feature_flag", + json=[{"name": "foo_feature", "enabled": False}], + ) + else: + response = app.client.patch( + "/platform/feature_flag", + json=[{"name": "foo_feature", "enabled": False}], + ) + + assert response.status_code == error_code + set_feature_flag_mock.assert_not_called() + + def test__FeatureFlagMiddleware__feature_flag_api_PATCH_modifies_feature_flag_when_user_has_session_and_flask_login_is_configured( + self, + openapi_config: Config, + openapi_client_configurable: OpenAPIClientInjectorConfigurable, + openapi_mock_controller: OpenAPIMockController, + mocker: MockerFixture, + ): + def client_init_hook(app: CreateAppResult[FlaskApp]): + caching_feature_flag_router = app.app_injector.flask_injector.injector.get( + FeatureFlagRouter[FeatureFlag] + ) + _ = caching_feature_flag_router.set_feature_is_enabled("foo_feature", True) + + openapi_mock_controller.begin() + app = next( + openapi_client_configurable( + openapi_config, + client_init_hook, + self._user_session_app_init_hook, + ) + ) + + with self.get_authenticated_request_context( + app, + User, # pyright: ignore[reportArgumentType] + mocker, + ): + response = app.client.patch( + "/platform/feature_flag", + json=[{"name": "foo_feature", "enabled": False}], + ) + + assert response.status_code == 200 + response_json = response.json() + assert (data := response_json.get("data", None)) is not None + assert len(data) == 1 + assert data[0].get("name", None) == "foo_feature" + assert data[0].get("new_value", None) == False + assert data[0].get("old_value", None) == True diff --git a/src/web/test/unit/test_config.py b/src/web/test/unit/test_config.py index ae36a857..37fe3c24 100644 --- a/src/web/test/unit/test_config.py +++ b/src/web/test/unit/test_config.py @@ -1,8 +1,14 @@ import pytest +from flask import Flask from Ligare.programming.collections.dict import AnyDict -from Ligare.programming.config import load_config +from Ligare.programming.config import AbstractConfig, load_config +from Ligare.programming.config.exceptions import ConfigBuilderStateError +from Ligare.web.application import ApplicationBuilder, ApplicationConfigBuilder from Ligare.web.config import Config +from Ligare.web.exception import BuilderBuildError, InvalidBuilderStateError +from pydantic import BaseModel from pytest_mock import MockerFixture +from typing_extensions import override def test__Config__load_config__reads_toml_file(mocker: MockerFixture): @@ -37,9 +43,169 @@ def test__Config__prepare_env_for_flask__requires_flask_secret_key_when_sessions } _ = mocker.patch("io.open") _ = mocker.patch("toml.decoder.loads", return_value=fake_config_dict) - config = load_config(Config, "foo.toml") with pytest.raises( Exception, match=r"^`flask.session.cookie.secret_key` must be set in config.$" ): - config.prepare_env_for_flask() + _ = load_config(Config, "foo.toml") + + +@pytest.mark.parametrize("mode", ["ssm", "filename"]) +def test__ApplicationConfigBuilder__build__succeeds_with_either_ssm_or_filename( + mode: str, mocker: MockerFixture +): + fake_config_dict = {"logging": {"log_level": "DEBUG"}} + _ = mocker.patch("io.open") + _ = mocker.patch("toml.decoder.loads", return_value=fake_config_dict) + _ = mocker.patch("Ligare.AWS.ssm.SSMParameters.load_config") + + application_config_builder = ApplicationConfigBuilder[ + Config + ]().with_root_config_type(Config) + + if mode == "ssm": + _ = application_config_builder.enable_ssm(True) + else: + _ = application_config_builder.with_config_filename("foo.toml") + + _ = application_config_builder.build() + + +def test__ApplicationConfigBuilder__build__raises_InvalidBuilderStateError_without_ssm_or_filename( + mocker: MockerFixture, +): + fake_config_dict = {"logging": {"log_level": "DEBUG"}} + _ = mocker.patch("io.open") + _ = mocker.patch("toml.decoder.loads", return_value=fake_config_dict) + _ = mocker.patch("Ligare.AWS.ssm.SSMParameters.load_config") + + application_config_builder = ApplicationConfigBuilder[Config]() + + with pytest.raises(InvalidBuilderStateError): + _ = application_config_builder.build() + + +def test__ApplicationConfigBuilder__build__raises_BuilderBuildError_when_ssm_fails_and_filename_not_configured( + mocker: MockerFixture, +): + fake_config_dict = {"logging": {"log_level": "DEBUG"}} + _ = mocker.patch("io.open") + _ = mocker.patch("toml.decoder.loads", return_value=fake_config_dict) + _ = mocker.patch( + "Ligare.AWS.ssm.SSMParameters.load_config", + side_effect=Exception("Test mode failure."), + ) + + application_config_builder = ApplicationConfigBuilder[Config]().enable_ssm(True) + + with pytest.raises(BuilderBuildError): + _ = application_config_builder.build() + + +def test__ApplicationConfigBuilder__build__uses_filename_when_ssm_fails( + mocker: MockerFixture, +): + _ = mocker.patch("io.open") + _ = mocker.patch("toml.decoder.loads") + _ = mocker.patch( + "Ligare.AWS.ssm.SSMParameters.load_config", + side_effect=Exception("Test mode failure."), + ) + toml_mock = mocker.patch("toml.load") + + application_config_builder = ( + ApplicationConfigBuilder[Config]() + .with_root_config_type(Config) + .with_config_filename("foo.toml") + ) + + _ = application_config_builder.build() + + assert toml_mock.called + + +def test__ApplicationConfigBuilder__build__requires_root_config( + mocker: MockerFixture, +): + _ = mocker.patch("io.open") + _ = mocker.patch("toml.decoder.loads") + + application_config_builder = ApplicationConfigBuilder[ + Config + ]().with_config_filename("foo.toml") + + with pytest.raises( + BuilderBuildError, + match="A root config must be specified", + ) as e: + _ = application_config_builder.build() + + assert isinstance(e.value.__cause__, ConfigBuilderStateError) + + +def test__ApplicationConfigBuilder__build__applies_additional_configs( + mocker: MockerFixture, +): + fake_config_dict = {"logging": {"log_level": "DEBUG"}, "test": {"foo": "bar"}} + _ = mocker.patch("io.open") + _ = mocker.patch("toml.decoder.loads", return_value=fake_config_dict) + + class TestConfig(BaseModel, AbstractConfig): + @override + def post_load(self) -> None: + return super().post_load() + + foo: str + + application_config_builder = ( + ApplicationConfigBuilder[Config]() + .with_root_config_type(Config) + .with_config_type(TestConfig) + .with_config_filename("foo.toml") + ) + config = application_config_builder.build() + + assert config is not None + assert hasattr(config, "test") + assert hasattr(getattr(config, "test"), "foo") + assert getattr(getattr(config, "test"), "foo") == "bar" + + +def test__ApplicationConfigBuilder__build__applies_config_overrides( + mocker: MockerFixture, +): + fake_config_dict = {"logging": {"log_level": "DEBUG"}} + _ = mocker.patch("io.open") + _ = mocker.patch("toml.decoder.loads", return_value=fake_config_dict) + + application_config_builder = ( + ApplicationConfigBuilder[Config]() + .with_root_config_type(Config) + .with_config_value_overrides({"logging": {"log_level": "INFO"}}) + .with_config_filename("foo.toml") + ) + config = application_config_builder.build() + + assert config is not None + assert hasattr(config, "logging") + assert hasattr(getattr(config, "logging"), "log_level") + assert getattr(getattr(config, "logging"), "log_level") == "INFO" + + +# FIXME move to application tests +def test__ApplicationBuilder__build__something(mocker: MockerFixture): + fake_config_dict = {"logging": {"log_level": "DEBUG"}, "flask": {"app_name": "app"}} + _ = mocker.patch("io.open") + _ = mocker.patch("toml.decoder.loads", return_value=fake_config_dict) + + application_builder = ApplicationBuilder[Flask]().use_configuration( + lambda config_builder: config_builder.with_root_config_type( + Config + ).with_config_filename("foo.toml") + ) + + _ = ( + application_builder.with_flask_app_name("overridden_app") + .with_flask_env("overridden_dev") + .build() + )