diff --git a/datagateway_api/common/icat/backend.py b/datagateway_api/common/icat/backend.py index f7af97eb..d4167546 100644 --- a/datagateway_api/common/icat/backend.py +++ b/datagateway_api/common/icat/backend.py @@ -1,13 +1,12 @@ import logging -import icat.client from icat.exception import ICATSessionError from datagateway_api.common.backend import Backend -from datagateway_api.common.config import config from datagateway_api.common.exceptions import AuthenticationError from datagateway_api.common.helpers import queries_records from datagateway_api.common.icat.helpers import ( + create_client, create_entities, delete_entity_by_id, get_count_with_filters, @@ -36,21 +35,11 @@ class PythonICATBackend(Backend): """ def __init__(self): - # Client object is created here as well as in login() to avoid uncaught - # exceptions where the object is None. This could happen where a user tries to - # use an endpoint before logging in. Also helps to give a bit of certainty to - # what's stored here - self.client = icat.client.Client( - config.get_icat_url(), checkCert=config.get_icat_check_cert(), - ) + pass def login(self, credentials): log.info("Logging in to get session ID") - # Client object is re-created here so session IDs aren't overwritten in the - # database - self.client = icat.client.Client( - config.get_icat_url(), checkCert=config.get_icat_check_cert(), - ) + client = create_client() # Syntax for Python ICAT login_details = { @@ -58,108 +47,109 @@ def login(self, credentials): "password": credentials["password"], } try: - session_id = self.client.login(credentials["mechanism"], login_details) + session_id = client.login(credentials["mechanism"], login_details) return session_id except ICATSessionError: raise AuthenticationError("User credentials are incorrect") @requires_session_id - def get_session_details(self, session_id): + def get_session_details(self, session_id, **kwargs): log.info("Getting session details for session: %s", session_id) - self.client.sessionId = session_id - return get_session_details_helper(self.client) + client = kwargs["client"] if kwargs["client"] else create_client() + return get_session_details_helper(client) @requires_session_id - def refresh(self, session_id): + def refresh(self, session_id, **kwargs): log.info("Refreshing session: %s", session_id) - self.client.sessionId = session_id - return refresh_client_session(self.client) + client = kwargs["client"] if kwargs["client"] else create_client() + return refresh_client_session(client) @requires_session_id @queries_records - def logout(self, session_id): - self.client.sessionId = session_id - return logout_icat_client(self.client) + def logout(self, session_id, **kwargs): + client = kwargs["client"] if kwargs["client"] else create_client() + return logout_icat_client(client) @requires_session_id @queries_records - def get_with_filters(self, session_id, entity_type, filters): - self.client.sessionId = session_id - return get_entity_with_filters(self.client, entity_type, filters) + def get_with_filters(self, session_id, entity_type, filters, **kwargs): + client = kwargs["client"] if kwargs["client"] else create_client() + return get_entity_with_filters(client, entity_type, filters) @requires_session_id @queries_records - def create(self, session_id, entity_type, data): - self.client.sessionId = session_id - return create_entities(self.client, entity_type, data) + def create(self, session_id, entity_type, data, **kwargs): + client = kwargs["client"] if kwargs["client"] else create_client() + return create_entities(client, entity_type, data) @requires_session_id @queries_records - def update(self, session_id, entity_type, data): - self.client.sessionId = session_id - return update_entities(self.client, entity_type, data) + def update(self, session_id, entity_type, data, **kwargs): + client = kwargs["client"] if kwargs["client"] else create_client() + return update_entities(client, entity_type, data) @requires_session_id @queries_records - def get_one_with_filters(self, session_id, entity_type, filters): - self.client.sessionId = session_id - return get_first_result_with_filters(self.client, entity_type, filters) + def get_one_with_filters(self, session_id, entity_type, filters, **kwargs): + client = kwargs["client"] if kwargs["client"] else create_client() + return get_first_result_with_filters(client, entity_type, filters) @requires_session_id @queries_records - def count_with_filters(self, session_id, entity_type, filters): - self.client.sessionId = session_id - return get_count_with_filters(self.client, entity_type, filters) + def count_with_filters(self, session_id, entity_type, filters, **kwargs): + client = kwargs["client"] if kwargs["client"] else create_client() + return get_count_with_filters(client, entity_type, filters) @requires_session_id @queries_records - def get_with_id(self, session_id, entity_type, id_): - return get_entity_by_id(self.client, entity_type, id_, True) + def get_with_id(self, session_id, entity_type, id_, **kwargs): + client = kwargs["client"] if kwargs["client"] else create_client() + return get_entity_by_id(client, entity_type, id_, True) @requires_session_id @queries_records - def delete_with_id(self, session_id, entity_type, id_): - return delete_entity_by_id(self.client, entity_type, id_) + def delete_with_id(self, session_id, entity_type, id_, **kwargs): + client = kwargs["client"] if kwargs["client"] else create_client() + return delete_entity_by_id(client, entity_type, id_) @requires_session_id @queries_records - def update_with_id(self, session_id, entity_type, id_, data): - return update_entity_by_id(self.client, entity_type, id_, data) + def update_with_id(self, session_id, entity_type, id_, data, **kwargs): + client = kwargs["client"] if kwargs["client"] else create_client() + return update_entity_by_id(client, entity_type, id_, data) @requires_session_id @queries_records def get_facility_cycles_for_instrument_with_filters( - self, session_id, instrument_id, filters, + self, session_id, instrument_id, filters, **kwargs, ): - self.client.sessionId = session_id - return get_facility_cycles_for_instrument(self.client, instrument_id, filters) + client = kwargs["client"] if kwargs["client"] else create_client() + return get_facility_cycles_for_instrument(client, instrument_id, filters) @requires_session_id @queries_records def get_facility_cycles_for_instrument_count_with_filters( - self, session_id, instrument_id, filters, + self, session_id, instrument_id, filters, **kwargs, ): - self.client.sessionId = session_id - return get_facility_cycles_for_instrument_count( - self.client, instrument_id, filters, - ) + client = kwargs["client"] if kwargs["client"] else create_client() + return get_facility_cycles_for_instrument_count(client, instrument_id, filters) @requires_session_id @queries_records def get_investigations_for_instrument_in_facility_cycle_with_filters( - self, session_id, instrument_id, facilitycycle_id, filters, + self, session_id, instrument_id, facilitycycle_id, filters, **kwargs, ): - self.client.sessionId = session_id + client = kwargs["client"] if kwargs["client"] else create_client() return get_investigations_for_instrument_in_facility_cycle( - self.client, instrument_id, facilitycycle_id, filters, + client, instrument_id, facilitycycle_id, filters, ) @requires_session_id @queries_records def get_investigation_count_for_instrument_facility_cycle_with_filters( - self, session_id, instrument_id, facilitycycle_id, filters, + self, session_id, instrument_id, facilitycycle_id, filters, **kwargs, ): - self.client.sessionId = session_id + client = kwargs["client"] if kwargs["client"] else create_client() return get_investigations_for_instrument_in_facility_cycle_count( - self.client, instrument_id, facilitycycle_id, filters, + client, instrument_id, facilitycycle_id, filters, ) diff --git a/datagateway_api/common/icat/helpers.py b/datagateway_api/common/icat/helpers.py index 71835aa4..b3297fb7 100644 --- a/datagateway_api/common/icat/helpers.py +++ b/datagateway_api/common/icat/helpers.py @@ -3,6 +3,7 @@ import logging +import icat.client from icat.entities import getTypeMap from icat.exception import ( ICATInternalError, @@ -13,6 +14,7 @@ ICATValidationError, ) +from datagateway_api.common.config import config from datagateway_api.common.date_handler import DateHandler from datagateway_api.common.exceptions import ( AuthenticationError, @@ -37,6 +39,15 @@ def requires_session_id(method): using the API. The API call runs and an ICATSessionError may be raised due to an expired session, invalid session ID etc. + The session ID from the request is set here, so there is no requirement for a user + to use the login endpoint, they can go straight into using the API so long as they + have a valid session ID (be it created from this API, or from an alternative such as + scigateway-auth). + + This assumes the session ID is the second argument of the function where this + decorator is applied, which is reasonable to assume considering the current method + signatures of all the endpoints. + :param method: The method for the backend operation :raises AuthenticationError: If a valid session_id is not provided with the request """ @@ -44,8 +55,11 @@ def requires_session_id(method): @wraps(method) def wrapper_requires_session(*args, **kwargs): try: + client = create_client() + client.sessionId = args[1] + # Client object put into kwargs so it can be accessed by backend functions + kwargs["client"] = client - client = args[0].client # Find out if session has expired session_time = client.getRemainingMinutes() log.info("Session time: %d", session_time) @@ -59,6 +73,13 @@ def wrapper_requires_session(*args, **kwargs): return wrapper_requires_session +def create_client(): + client = icat.client.Client( + config.get_icat_url(), checkCert=config.get_icat_check_cert(), + ) + return client + + def get_session_details_helper(client): """ Retrieve details regarding the current session within `client` diff --git a/test/test_helpers.py b/test/test_helpers.py index e5e35244..059cea1a 100644 --- a/test/test_helpers.py +++ b/test/test_helpers.py @@ -1,3 +1,4 @@ +from datetime import datetime, timedelta from unittest import TestCase from sqlalchemy.exc import IntegrityError @@ -65,6 +66,9 @@ def setUp(self): self.bad_credentials_header = {"Authorization": "Bearer BadTest"} session = SESSION() session.ID = "Test" + session.EXPIREDATETIME = datetime.now() + timedelta(hours=1) + session.username = "Test User" + insert_row_into_table(SESSION, session) def tearDown(self):