Skip to content

Commit

Permalink
Use SessionManager
Browse files Browse the repository at this point in the history
  • Loading branch information
keiranjprice101 committed Jul 29, 2019
1 parent 0e7dfa4 commit 719aed7
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 12 deletions.
20 changes: 10 additions & 10 deletions common/database_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,23 +15,23 @@ class SessionManager(object):
_session = None

@staticmethod
def get_icat_db_session():
"""
def get_icat_db_session():
"""
Checks if a session exists, if it does it returns the session if not a new one is created
:return: ICAT DB session
"""
log.info(" Getting ICAT DB session")
"""
log.info(" Getting ICAT DB session")
if SessionManager._session is None:
engine = create_engine(Constants.DATABASE_URL)
Session = sessionmaker(bind=engine)
engine = create_engine(Constants.DATABASE_URL)
Session = sessionmaker(bind=engine)
SessionManager._session = Session()
return SessionManager._session


class Query(ABC):
@abstractmethod
def __init__(self, table):
self.session = get_icat_db_session()
self.session = SessionManager.get_icat_db_session()
self.table = table
self.base_query = self.session.query(table)
self.is_limited = False
Expand All @@ -53,7 +53,7 @@ def commit_changes(self):
class ReadQuery(Query):

def __init__(self, table):
super.__init__(table)
super().__init__(table)
self.include_related_entities = False

def execute_query(self):
Expand Down Expand Up @@ -152,7 +152,7 @@ def apply_filter(self, query):
class SkipFilter(QueryFilter):
def __init__(self, skip_value):
self.skip_value = skip_value

def apply_filter(self, query):
query.base_query = query.base_query.offset(self.skip_value)

Expand Down Expand Up @@ -196,6 +196,7 @@ def get_query_filter(filter):
else:
raise BadFilterError(f" Bad filter: {filter}")


def insert_row_into_table(table, row):
"""
Insert the given row into its table
Expand Down Expand Up @@ -275,7 +276,6 @@ def get_rows_by_filter(table, filters):
return list(map(lambda x: x.to_dict(), results))



def get_filtered_row_count(table, filters):
"""
returns the count of the rows that match a given filter in a given table
Expand Down
4 changes: 2 additions & 2 deletions common/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from flask_restful import reqparse
from sqlalchemy.exc import IntegrityError

from common.database_helpers import get_icat_db_session
from common.database_helpers import SessionManager
from common.exceptions import MissingRecordError, BadFilterError, AuthenticationError, BadRequestError
from common.models.db_models import SESSION

Expand All @@ -24,7 +24,7 @@ def requires_session_id(method):
def wrapper_requires_session(*args, **kwargs):
log.info(" Authenticating consumer")
try:
session = get_icat_db_session()
session = SessionManager.get_icat_db_session()
query = session.query(SESSION).filter(
SESSION.ID == get_session_id_from_auth_header()).first()
if query is not None:
Expand Down

0 comments on commit 719aed7

Please sign in to comment.