Skip to content

Commit

Permalink
Merge pull request #156 from ral-facilities/feature/fix-session-handl…
Browse files Browse the repository at this point in the history
…ing-#135

Fix session handling for ICAT backend
  • Loading branch information
MRichards99 authored Nov 2, 2020
2 parents e419d8c + f1fd838 commit bf64e9b
Show file tree
Hide file tree
Showing 3 changed files with 68 additions and 48 deletions.
87 changes: 41 additions & 46 deletions common/icat/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
get_session_details_helper,
logout_icat_client,
refresh_client_session,
create_client,
get_entity_by_id,
update_entity_by_id,
delete_entity_by_id,
Expand All @@ -29,115 +30,109 @@ class PythonICATBackend(Backend):
"""

def __init__(self):
# Client object is created here as well as in login() to avoid uncaught
# exceptions where the object is None. This could happen where a user tries to
# use an endpoint before logging in. Also helps to give a bit of certainty to
# what's stored here
self.client = icat.client.Client(
config.get_icat_url(), checkCert=config.get_icat_check_cert()
)
pass

def login(self, credentials):
# Client object is re-created here so session IDs aren't overwritten in the
# database
self.client = icat.client.Client(
config.get_icat_url(), checkCert=config.get_icat_check_cert()
)
client = create_client()

# Syntax for Python ICAT
login_details = {
"username": credentials["username"],
"password": credentials["password"],
}
try:
session_id = self.client.login(credentials["mechanism"], login_details)
session_id = client.login(credentials["mechanism"], login_details)
return session_id
except ICATSessionError:
raise AuthenticationError("User credentials are incorrect")

@requires_session_id
def get_session_details(self, session_id):
self.client.sessionId = session_id
return get_session_details_helper(self.client)
def get_session_details(self, session_id, **kwargs):
client = kwargs["client"] if kwargs["client"] else create_client()
return get_session_details_helper(client)

@requires_session_id
def refresh(self, session_id):
self.client.sessionId = session_id
return refresh_client_session(self.client)
def refresh(self, session_id, **kwargs):
client = kwargs["client"] if kwargs["client"] else create_client()
return refresh_client_session(client)

@requires_session_id
@queries_records
def logout(self, session_id):
self.client.sessionId = session_id
return logout_icat_client(self.client)
def logout(self, session_id, **kwargs):
client = kwargs["client"] if kwargs["client"] else create_client()
return logout_icat_client(client)

@requires_session_id
@queries_records
def get_with_filters(self, session_id, table, filters):
return get_entity_with_filters(self.client, table.__name__, filters)
def get_with_filters(self, session_id, table, filters, **kwargs):
client = kwargs["client"] if kwargs["client"] else create_client()
return get_entity_with_filters(client, table.__name__, filters)

@requires_session_id
@queries_records
def create(self, session_id, table, data):
pass
def create(self, session_id, table, data, **kwargs):
client = kwargs["client"] if kwargs["client"] else create_client()

@requires_session_id
@queries_records
def update(self, session_id, table, data):
pass
def update(self, session_id, table, data, **kwargs):
client = kwargs["client"] if kwargs["client"] else create_client()

@requires_session_id
@queries_records
def get_one_with_filters(self, session_id, table, filters):
pass
def get_one_with_filters(self, session_id, table, filters, **kwargs):
client = kwargs["client"] if kwargs["client"] else create_client()

@requires_session_id
@queries_records
def count_with_filters(self, session_id, table, filters):
pass
def count_with_filters(self, session_id, table, filters, **kwargs):
client = kwargs["client"] if kwargs["client"] else create_client()

@requires_session_id
@queries_records
def get_with_id(self, session_id, table, id_):
return get_entity_by_id(self.client, table.__name__, id_, True)
def get_with_id(self, session_id, table, id_, **kwargs):
client = kwargs["client"] if kwargs["client"] else create_client()
return get_entity_by_id(client, table.__name__, id_, True)

@requires_session_id
@queries_records
def delete_with_id(self, session_id, table, id_):
return delete_entity_by_id(self.client, table.__name__, id_)
def delete_with_id(self, session_id, table, id_, **kwargs):
client = kwargs["client"] if kwargs["client"] else create_client()
return delete_entity_by_id(client, table.__name__, id_)

@requires_session_id
@queries_records
def update_with_id(self, session_id, table, id_, data):
return update_entity_by_id(self.client, table.__name__, id_, data)
def update_with_id(self, session_id, table, id_, data, **kwargs):
client = kwargs["client"] if kwargs["client"] else create_client()
return update_entity_by_id(client, table.__name__, id_, data)

@requires_session_id
@queries_records
def get_instrument_facilitycycles_with_filters(
self, session_id, instrument_id, filters
self, session_id, instrument_id, filters, **kwargs
):
pass
client = kwargs["client"] if kwargs["client"] else create_client()

@requires_session_id
@queries_records
def count_instrument_facilitycycles_with_filters(
self, session_id, instrument_id, filters
self, session_id, instrument_id, filters, **kwargs
):
pass
client = kwargs["client"] if kwargs["client"] else create_client()
# return get_facility_cycles_for_instrument_count(instrument_id, filters)

@requires_session_id
@queries_records
def get_instrument_facilitycycle_investigations_with_filters(
self, session_id, instrument_id, facilitycycle_id, filters
self, session_id, instrument_id, facilitycycle_id, filters, **kwargs
):
pass
client = kwargs["client"] if kwargs["client"] else create_client()
# return get_investigations_for_instrument_in_facility_cycle(instrument_id, facilitycycle_id, filters)

@requires_session_id
@queries_records
def count_instrument_facilitycycles_investigations_with_filters(
self, session_id, instrument_id, facilitycycle_id, filters
self, session_id, instrument_id, facilitycycle_id, filters, **kwargs
):
pass
client = kwargs["client"] if kwargs["client"] else create_client()
# return get_investigations_for_instrument_in_facility_cycle_count(instrument_id, facilitycycle_id, filters)
25 changes: 23 additions & 2 deletions common/icat/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@
)
from common.icat.query import ICATQuery

import icat.client
from common.config import config

log = logging.getLogger()

Expand All @@ -28,16 +30,28 @@ def requires_session_id(method):
Decorator for Python ICAT backend methods that looks out for session errors when
using the API. The API call runs and an ICATSessionError may be raised due to an
expired session, invalid session ID etc.
The session ID from the request is set here, so there is no requirement for a user
to use the login endpoint, they can go straight into using the API so long as they
have a valid session ID (be it created from this API, or from an alternative such as
scigateway-auth).
This assumes the session ID is the second argument of the function where this
decorator is applied, which is reasonable to assume considering the current method
signatures of all the endpoints.
:param method: The method for the backend operation
:raises AuthenticationError: If a valid session_id is not provided with the request
"""

@wraps(method)
def wrapper_requires_session(*args, **kwargs):
try:
client = create_client()
client.sessionId = args[1]
# Client object put into kwargs so it can be accessed by backend functions
kwargs["client"] = client

client = args[0].client
# Find out if session has expired
session_time = client.getRemainingMinutes()
log.info("Session time: %d", session_time)
Expand All @@ -51,6 +65,13 @@ def wrapper_requires_session(*args, **kwargs):
return wrapper_requires_session


def create_client():
client = icat.client.Client(
config.get_icat_url(), checkCert=config.get_icat_check_cert()
)
return client


def get_session_details_helper(client):
"""
Retrieve details regarding the current session within `client`
Expand Down
4 changes: 4 additions & 0 deletions test/test_helpers.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from unittest import TestCase

from sqlalchemy.exc import IntegrityError
from datetime import datetime, timedelta

from common.database.helpers import (
delete_row_by_id,
Expand Down Expand Up @@ -63,6 +64,9 @@ def setUp(self):
self.bad_credentials_header = {"Authorization": "Bearer BadTest"}
session = SESSION()
session.ID = "Test"
session.EXPIREDATETIME = datetime.now() + timedelta(hours=1)
session.username = "Test User"

insert_row_into_table(SESSION, session)

def tearDown(self):
Expand Down

0 comments on commit bf64e9b

Please sign in to comment.