From 719aed7cdb0304c3a43dfedc25f58262493999c8 Mon Sep 17 00:00:00 2001 From: Keiran Price Date: Mon, 29 Jul 2019 09:28:43 +0100 Subject: [PATCH] Use SessionManager --- common/database_helpers.py | 20 ++++++++++---------- common/helpers.py | 4 ++-- 2 files changed, 12 insertions(+), 12 deletions(-) diff --git a/common/database_helpers.py b/common/database_helpers.py index a8366cbb..7a215238 100644 --- a/common/database_helpers.py +++ b/common/database_helpers.py @@ -15,15 +15,15 @@ class SessionManager(object): _session = None @staticmethod -def get_icat_db_session(): - """ + def get_icat_db_session(): + """ Checks if a session exists, if it does it returns the session if not a new one is created :return: ICAT DB session - """ - log.info(" Getting ICAT DB session") + """ + log.info(" Getting ICAT DB session") if SessionManager._session is None: - engine = create_engine(Constants.DATABASE_URL) - Session = sessionmaker(bind=engine) + engine = create_engine(Constants.DATABASE_URL) + Session = sessionmaker(bind=engine) SessionManager._session = Session() return SessionManager._session @@ -31,7 +31,7 @@ def get_icat_db_session(): class Query(ABC): @abstractmethod def __init__(self, table): - self.session = get_icat_db_session() + self.session = SessionManager.get_icat_db_session() self.table = table self.base_query = self.session.query(table) self.is_limited = False @@ -53,7 +53,7 @@ def commit_changes(self): class ReadQuery(Query): def __init__(self, table): - super.__init__(table) + super().__init__(table) self.include_related_entities = False def execute_query(self): @@ -152,7 +152,7 @@ def apply_filter(self, query): class SkipFilter(QueryFilter): def __init__(self, skip_value): self.skip_value = skip_value - + def apply_filter(self, query): query.base_query = query.base_query.offset(self.skip_value) @@ -196,6 +196,7 @@ def get_query_filter(filter): else: raise BadFilterError(f" Bad filter: {filter}") + def insert_row_into_table(table, row): """ Insert the given row into its table @@ -275,7 +276,6 @@ def get_rows_by_filter(table, filters): return list(map(lambda x: x.to_dict(), results)) - def get_filtered_row_count(table, filters): """ returns the count of the rows that match a given filter in a given table diff --git a/common/helpers.py b/common/helpers.py index 47062dc1..eafd8596 100644 --- a/common/helpers.py +++ b/common/helpers.py @@ -6,7 +6,7 @@ from flask_restful import reqparse from sqlalchemy.exc import IntegrityError -from common.database_helpers import get_icat_db_session +from common.database_helpers import SessionManager from common.exceptions import MissingRecordError, BadFilterError, AuthenticationError, BadRequestError from common.models.db_models import SESSION @@ -24,7 +24,7 @@ def requires_session_id(method): def wrapper_requires_session(*args, **kwargs): log.info(" Authenticating consumer") try: - session = get_icat_db_session() + session = SessionManager.get_icat_db_session() query = session.query(SESSION).filter( SESSION.ID == get_session_id_from_auth_header()).first() if query is not None: