diff --git a/common/database_helpers.py b/common/database_helpers.py index c9f9c0b2..78ecd870 100644 --- a/common/database_helpers.py +++ b/common/database_helpers.py @@ -2,36 +2,18 @@ import logging from abc import ABC, abstractmethod -from sqlalchemy import create_engine, asc, desc -from sqlalchemy.orm import sessionmaker +from sqlalchemy import asc, desc -from common.constants import Constants from common.exceptions import MissingRecordError, BadFilterError, BadRequestError +from common.session_manager import session_manager log = logging.getLogger() -class SessionManager(object): - _session = None - - @staticmethod - def get_icat_db_session(): - """ - Checks if a session exists, if it does it returns the session if not a new one is created - :return: ICAT DB session - """ - log.info(" Getting ICAT DB session") - if SessionManager._session is None: - engine = create_engine(Constants.DATABASE_URL) - Session = sessionmaker(bind=engine) - SessionManager._session = Session() - return SessionManager._session - - class Query(ABC): @abstractmethod def __init__(self, table): - self.session = SessionManager.get_icat_db_session() + self.session = session_manager.get_icat_db_session() self.table = table self.base_query = self.session.query(table) self.is_limited = False diff --git a/common/helpers.py b/common/helpers.py index eafd8596..3650027d 100644 --- a/common/helpers.py +++ b/common/helpers.py @@ -6,13 +6,13 @@ from flask_restful import reqparse from sqlalchemy.exc import IntegrityError -from common.database_helpers import SessionManager from common.exceptions import MissingRecordError, BadFilterError, AuthenticationError, BadRequestError from common.models.db_models import SESSION - +from common.session_manager import session_manager log = logging.getLogger() + def requires_session_id(method): """ Decorator for endpoint resources that makes sure a valid session_id is provided in requests to that endpoint @@ -20,16 +20,18 @@ def requires_session_id(method): :returns a 403, "Forbidden" if a valid session_id is not provided with the request """ log.info("") + @wraps(method) def wrapper_requires_session(*args, **kwargs): log.info(" Authenticating consumer") try: - session = SessionManager.get_icat_db_session() + session = session_manager.get_icat_db_session() query = session.query(SESSION).filter( SESSION.ID == get_session_id_from_auth_header()).first() if query is not None: log.info(" Closing DB session") session.close() + session.close() log.info(" Consumer authenticated") return method(*args, **kwargs) else: @@ -38,6 +40,7 @@ def wrapper_requires_session(*args, **kwargs): return "Forbidden", 403 except AuthenticationError: return "Forbidden", 403 + return wrapper_requires_session