From 76353a15e8d5022e7ff2b7d129bc38bf1bc25bd7 Mon Sep 17 00:00:00 2001 From: Aaron Holmes Date: Thu, 24 Oct 2024 13:06:06 -0700 Subject: [PATCH] Alter how FeatureFlags are queried. Ligare needs a way to understand the base SQLAlchemy classes involved in querying a database. This is for a few reasons: - schema translation for systems that don't support schemas - supporting schemas in general - not overwritting `ScopedSession` registration when multiple `ScopedSession` instances are registered with Injector --- .../feature_flag/db_feature_flag_router.py | 78 ++++++++++--------- .../web/middleware/feature_flags/__init__.py | 56 +++++++++---- 2 files changed, 80 insertions(+), 54 deletions(-) diff --git a/src/platform/Ligare/platform/feature_flag/db_feature_flag_router.py b/src/platform/Ligare/platform/feature_flag/db_feature_flag_router.py index 0137bff8..a9a5bda9 100644 --- a/src/platform/Ligare/platform/feature_flag/db_feature_flag_router.py +++ b/src/platform/Ligare/platform/feature_flag/db_feature_flag_router.py @@ -7,6 +7,7 @@ from sqlalchemy import Boolean, Column, String, Unicode from sqlalchemy.exc import NoResultFound from sqlalchemy.ext.declarative import DeclarativeMeta +from sqlalchemy.orm.scoping import ScopedSession from sqlalchemy.orm.session import Session from typing_extensions import override @@ -24,6 +25,8 @@ class FeatureFlag(FeatureFlagBaseData): class FeatureFlagTableBase(ABC): + __tablename__: str + def __init__( # pyright: ignore[reportMissingSuperCall] self, /, @@ -36,9 +39,9 @@ def __init__( # pyright: ignore[reportMissingSuperCall] ) __tablename__: str - name: Column[Unicode] | str - description: Column[Unicode] | str - enabled: Column[Boolean] | bool + name: str + description: str + enabled: bool class FeatureFlagTable: @@ -70,17 +73,15 @@ def __repr__(self) -> str: class DBFeatureFlagRouter(CachingFeatureFlagRouter[TFeatureFlag]): - # The SQLAlchemy table type used for querying from the type[FeatureFlag] database table - _feature_flag: type[FeatureFlagTableBase] - # The SQLAlchemy session used for connecting to and querying the database - _session: Session - @inject def __init__( - self, feature_flag: type[FeatureFlagTableBase], session: Session, logger: Logger + self, + feature_flag: type[FeatureFlagTableBase], + scoped_session: ScopedSession, + logger: Logger, ) -> None: self._feature_flag = feature_flag - self._session = session + self._scoped_session = scoped_session super().__init__(logger) @override @@ -103,20 +104,21 @@ def set_feature_is_enabled(self, name: str, is_enabled: bool) -> FeatureFlagChan raise ValueError("`name` parameter is required and cannot be empty.") feature_flag: FeatureFlagTableBase - try: - feature_flag = ( - self._session.query(self._feature_flag) - .filter(self._feature_flag.name == name) - .one() - ) - except NoResultFound as e: - raise LookupError( - f"The feature flag `{name}` does not exist. It must be created before being accessed." - ) from e - - old_enabled_value = cast(bool | None, feature_flag.enabled) - feature_flag.enabled = is_enabled - self._session.commit() + with self._scoped_session() as session: + try: + feature_flag = ( + session.query(self._feature_flag) + .filter(self._feature_flag.name == name) + .one() + ) + except NoResultFound as e: + raise LookupError( + f"The feature flag `{name}` does not exist. It must be created before being accessed." + ) from e + + old_enabled_value = cast(bool | None, feature_flag.enabled) + feature_flag.enabled = is_enabled + session.commit() _ = super().set_feature_is_enabled(name, is_enabled) return FeatureFlagChange( @@ -149,11 +151,12 @@ def feature_is_enabled( if check_cache and super().feature_is_cached(name): return super().feature_is_enabled(name, default) - feature_flag = ( - self._session.query(self._feature_flag) - .filter(self._feature_flag.name == name) - .one_or_none() - ) + with self._scoped_session() as session: + feature_flag = ( + session.query(self._feature_flag) + .filter(self._feature_flag.name == name) + .one_or_none() + ) if feature_flag is None: self._logger.warning( @@ -192,14 +195,15 @@ def get_feature_flags( If `names` is `None` this sequence contains _all_ feature flags in the database. Otherwise, the list is filtered. """ db_feature_flags: list[FeatureFlagTableBase] - if names is None: - db_feature_flags = self._session.query(self._feature_flag).all() - else: - db_feature_flags = ( - self._session.query(self._feature_flag) - .filter(cast(Column[String], self._feature_flag.name).in_(names)) - .all() - ) + with self._scoped_session() as session: + if names is None: + db_feature_flags = session.query(self._feature_flag).all() + else: + db_feature_flags = ( + session.query(self._feature_flag) + .filter(cast(Column[String], self._feature_flag.name).in_(names)) + .all() + ) feature_flags = tuple( self._create_feature_flag( diff --git a/src/web/Ligare/web/middleware/feature_flags/__init__.py b/src/web/Ligare/web/middleware/feature_flags/__init__.py index d9c1e158..e792d67d 100644 --- a/src/web/Ligare/web/middleware/feature_flags/__init__.py +++ b/src/web/Ligare/web/middleware/feature_flags/__init__.py @@ -6,6 +6,8 @@ from connexion import FlaskApp, request from flask import Blueprint, Flask, abort from injector import Binder, Injector, Module, inject, provider, singleton +from Ligare.database.dependency_injection import ScopedSessionModule +from Ligare.database.types import MetaBase from Ligare.platform.feature_flag.caching_feature_flag_router import ( CachingFeatureFlagRouter, ) @@ -77,25 +79,45 @@ def _provide_feature_flag_router( return cast(FeatureFlagRouter[FeatureFlag], injector.get(self._t_feature_flag)) -class DBFeatureFlagRouterModule(FeatureFlagRouterModule[DBFeatureFlag]): - def __init__(self) -> None: - super().__init__(DBFeatureFlagRouter) +class DBFeatureFlagRouterModule: + class _DBFeatureFlagRouterModule(FeatureFlagRouterModule[DBFeatureFlag]): + _feature_flag_table: type[FeatureFlagTableBase] + _bases: list[MetaBase | type[MetaBase]] | None = None - @singleton - @provider - def _provide_db_feature_flag_router( - self, injector: Injector - ) -> FeatureFlagRouter[DBFeatureFlag]: - return cast( - FeatureFlagRouter[DBFeatureFlag], injector.get(self._t_feature_flag) - ) + def __init__(self) -> None: + super().__init__(DBFeatureFlagRouter) - @singleton - @provider - def _provide_db_feature_flag_router_table_base(self) -> type[FeatureFlagTableBase]: - # FeatureFlagTable is a FeatureFlagTableBase provided through - # SQLAlchemy's declarative meta API - return cast(type[FeatureFlagTableBase], FeatureFlagTable) + @override + def configure(self, binder: Binder) -> None: + binder.install(ScopedSessionModule(self._bases)) + + @singleton + @provider + def _provide_db_feature_flag_router( + self, injector: Injector + ) -> FeatureFlagRouter[DBFeatureFlag]: + return cast( + FeatureFlagRouter[DBFeatureFlag], injector.get(self._t_feature_flag) + ) + + @singleton + @provider + def _provide_db_feature_flag_router_table_base( + self, + ) -> type[FeatureFlagTableBase]: + # FeatureFlagTable is a FeatureFlagTableBase provided through + # SQLAlchemy's declarative meta API + # return cast(type[FeatureFlagTableBase], FeatureFlagTable) + return self._feature_flag_table + + def __new__( + cls, + feature_flag_table: type[FeatureFlagTableBase], + bases: list[MetaBase | type[MetaBase]] | None = None, + ) -> "type[DBFeatureFlagRouterModule._DBFeatureFlagRouterModule]": + cls._DBFeatureFlagRouterModule._feature_flag_table = feature_flag_table + cls._DBFeatureFlagRouterModule._bases = bases + return cls._DBFeatureFlagRouterModule class CachingFeatureFlagRouterModule(FeatureFlagRouterModule[CachingFeatureFlag]):