diff --git a/common/icat/backend.py b/common/icat/backend.py index 0bc6f1a2..9a07cb6d 100644 --- a/common/icat/backend.py +++ b/common/icat/backend.py @@ -10,6 +10,7 @@ get_session_details_helper, logout_icat_client, refresh_client_session, + create_client, get_entity_by_id, update_entity_by_id, delete_entity_by_id, @@ -33,20 +34,10 @@ class PythonICATBackend(Backend): """ def __init__(self): - # Client object is created here as well as in login() to avoid uncaught - # exceptions where the object is None. This could happen where a user tries to - # use an endpoint before logging in. Also helps to give a bit of certainty to - # what's stored here - self.client = icat.client.Client( - config.get_icat_url(), checkCert=config.get_icat_check_cert() - ) + pass def login(self, credentials): - # Client object is re-created here so session IDs aren't overwritten in the - # database - self.client = icat.client.Client( - config.get_icat_url(), checkCert=config.get_icat_check_cert() - ) + client = create_client() # Syntax for Python ICAT login_details = { @@ -54,99 +45,102 @@ def login(self, credentials): "password": credentials["password"], } try: - session_id = self.client.login(credentials["mechanism"], login_details) + session_id = client.login(credentials["mechanism"], login_details) return session_id except ICATSessionError: raise AuthenticationError("User credentials are incorrect") @requires_session_id - def get_session_details(self, session_id): - self.client.sessionId = session_id - return get_session_details_helper(self.client) + def get_session_details(self, session_id, **kwargs): + client = kwargs["client"] if kwargs["client"] else create_client() + return get_session_details_helper(client) @requires_session_id - def refresh(self, session_id): - self.client.sessionId = session_id - return refresh_client_session(self.client) + def refresh(self, session_id, **kwargs): + client = kwargs["client"] if kwargs["client"] else create_client() + return refresh_client_session(client) @requires_session_id @queries_records - def logout(self, session_id): - self.client.sessionId = session_id - return logout_icat_client(self.client) + def logout(self, session_id, **kwargs): + client = kwargs["client"] if kwargs["client"] else create_client() + return logout_icat_client(client) @requires_session_id @queries_records - def get_with_filters(self, session_id, table, filters): - self.client.sessionId = session_id - return get_entity_with_filters(self.client, table.__name__, filters) + def get_with_filters(self, session_id, table, filters, **kwargs): + client = kwargs["client"] if kwargs["client"] else create_client() + return get_entity_with_filters(client, table.__name__, filters) @requires_session_id @queries_records - def create(self, session_id, table, data): - self.client.sessionId = session_id - return create_entities(self.client, table.__name__, data) + def create(self, session_id, table, data, **kwargs): + client = kwargs["client"] if kwargs["client"] else create_client() + return update_entities(client, table.__name__, data) @requires_session_id @queries_records - def update(self, session_id, table, data): - self.client.sessionId = session_id - return update_entities(self.client, table.__name__, data) + def update(self, session_id, table, data, **kwargs): + client = kwargs["client"] if kwargs["client"] else create_client() + return update_entities(client, table.__name__, data) @requires_session_id @queries_records - def get_one_with_filters(self, session_id, table, filters): - self.client.sessionId = session_id - return get_first_result_with_filters(self.client, table.__name__, filters) + def get_one_with_filters(self, session_id, table, filters, **kwargs): + client = kwargs["client"] if kwargs["client"] else create_client() + return get_first_result_with_filters(client, table.__name__, filters) @requires_session_id @queries_records - def count_with_filters(self, session_id, table, filters): - self.client.sessionId = session_id - return get_count_with_filters(self.client, table.__name__, filters) + def count_with_filters(self, session_id, table, filters, **kwargs): + client = kwargs["client"] if kwargs["client"] else create_client() + return get_count_with_filters(client, table.__name__, filters) @requires_session_id @queries_records - def get_with_id(self, session_id, table, id_): - return get_entity_by_id(self.client, table.__name__, id_, True) + def get_with_id(self, session_id, table, id_, **kwargs): + client = kwargs["client"] if kwargs["client"] else create_client() + return get_entity_by_id(client, table.__name__, id_, True) @requires_session_id @queries_records - def delete_with_id(self, session_id, table, id_): - return delete_entity_by_id(self.client, table.__name__, id_) + def delete_with_id(self, session_id, table, id_, **kwargs): + client = kwargs["client"] if kwargs["client"] else create_client() + return delete_entity_by_id(client, table.__name__, id_) @requires_session_id @queries_records - def update_with_id(self, session_id, table, id_, data): - return update_entity_by_id(self.client, table.__name__, id_, data) + def update_with_id(self, session_id, table, id_, data, **kwargs): + client = kwargs["client"] if kwargs["client"] else create_client() + return update_entity_by_id(client, table.__name__, id_, data) @requires_session_id @queries_records def get_instrument_facilitycycles_with_filters( - self, session_id, instrument_id, filters + self, session_id, instrument_id, filters, **kwargs ): - pass + client = kwargs["client"] if kwargs["client"] else create_client() @requires_session_id @queries_records def count_instrument_facilitycycles_with_filters( - self, session_id, instrument_id, filters + self, session_id, instrument_id, filters, **kwargs ): - pass + client = kwargs["client"] if kwargs["client"] else create_client() # return get_facility_cycles_for_instrument_count(instrument_id, filters) @requires_session_id @queries_records def get_instrument_facilitycycle_investigations_with_filters( - self, session_id, instrument_id, facilitycycle_id, filters + self, session_id, instrument_id, facilitycycle_id, filters, **kwargs ): - pass + client = kwargs["client"] if kwargs["client"] else create_client() # return get_investigations_for_instrument_in_facility_cycle(instrument_id, facilitycycle_id, filters) @requires_session_id @queries_records def count_instrument_facilitycycles_investigations_with_filters( - self, session_id, instrument_id, facilitycycle_id, filters + self, session_id, instrument_id, facilitycycle_id, filters, **kwargs ): - pass + client = kwargs["client"] if kwargs["client"] else create_client() # return get_investigations_for_instrument_in_facility_cycle_count(instrument_id, facilitycycle_id, filters) diff --git a/common/icat/helpers.py b/common/icat/helpers.py index af39edb2..c51af7ca 100644 --- a/common/icat/helpers.py +++ b/common/icat/helpers.py @@ -25,6 +25,9 @@ from common.icat.filters import PythonICATLimitFilter, PythonICATWhereFilter from common.icat.query import ICATQuery +import icat.client +from common.config import config + log = logging.getLogger() @@ -33,7 +36,16 @@ def requires_session_id(method): Decorator for Python ICAT backend methods that looks out for session errors when using the API. The API call runs and an ICATSessionError may be raised due to an expired session, invalid session ID etc. - + + The session ID from the request is set here, so there is no requirement for a user + to use the login endpoint, they can go straight into using the API so long as they + have a valid session ID (be it created from this API, or from an alternative such as + scigateway-auth). + + This assumes the session ID is the second argument of the function where this + decorator is applied, which is reasonable to assume considering the current method + signatures of all the endpoints. + :param method: The method for the backend operation :raises AuthenticationError: If a valid session_id is not provided with the request """ @@ -41,8 +53,11 @@ def requires_session_id(method): @wraps(method) def wrapper_requires_session(*args, **kwargs): try: + client = create_client() + client.sessionId = args[1] + # Client object put into kwargs so it can be accessed by backend functions + kwargs["client"] = client - client = args[0].client # Find out if session has expired session_time = client.getRemainingMinutes() log.info("Session time: %d", session_time) @@ -56,6 +71,13 @@ def wrapper_requires_session(*args, **kwargs): return wrapper_requires_session +def create_client(): + client = icat.client.Client( + config.get_icat_url(), checkCert=config.get_icat_check_cert() + ) + return client + + def get_session_details_helper(client): """ Retrieve details regarding the current session within `client` diff --git a/test/test_helpers.py b/test/test_helpers.py index ce4d517b..79efb864 100644 --- a/test/test_helpers.py +++ b/test/test_helpers.py @@ -1,6 +1,7 @@ from unittest import TestCase from sqlalchemy.exc import IntegrityError +from datetime import datetime, timedelta from common.database.helpers import ( delete_row_by_id, @@ -63,6 +64,9 @@ def setUp(self): self.bad_credentials_header = {"Authorization": "Bearer BadTest"} session = SESSION() session.ID = "Test" + session.EXPIREDATETIME = datetime.now() + timedelta(hours=1) + session.username = "Test User" + insert_row_into_table(SESSION, session) def tearDown(self):