Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix session handling for ICAT backend #156

Merged
merged 6 commits into from
Nov 2, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
87 changes: 41 additions & 46 deletions common/icat/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
get_session_details_helper,
logout_icat_client,
refresh_client_session,
create_client,
get_entity_by_id,
update_entity_by_id,
delete_entity_by_id,
Expand All @@ -29,115 +30,109 @@ class PythonICATBackend(Backend):
"""

def __init__(self):
# Client object is created here as well as in login() to avoid uncaught
# exceptions where the object is None. This could happen where a user tries to
# use an endpoint before logging in. Also helps to give a bit of certainty to
# what's stored here
self.client = icat.client.Client(
config.get_icat_url(), checkCert=config.get_icat_check_cert()
)
pass

def login(self, credentials):
# Client object is re-created here so session IDs aren't overwritten in the
# database
self.client = icat.client.Client(
config.get_icat_url(), checkCert=config.get_icat_check_cert()
)
client = create_client()

# Syntax for Python ICAT
login_details = {
"username": credentials["username"],
"password": credentials["password"],
}
try:
session_id = self.client.login(credentials["mechanism"], login_details)
session_id = client.login(credentials["mechanism"], login_details)
return session_id
except ICATSessionError:
raise AuthenticationError("User credentials are incorrect")

@requires_session_id
def get_session_details(self, session_id):
self.client.sessionId = session_id
return get_session_details_helper(self.client)
def get_session_details(self, session_id, **kwargs):
client = kwargs["client"] if kwargs["client"] else create_client()
return get_session_details_helper(client)

@requires_session_id
def refresh(self, session_id):
self.client.sessionId = session_id
return refresh_client_session(self.client)
def refresh(self, session_id, **kwargs):
client = kwargs["client"] if kwargs["client"] else create_client()
return refresh_client_session(client)

@requires_session_id
@queries_records
def logout(self, session_id):
self.client.sessionId = session_id
return logout_icat_client(self.client)
def logout(self, session_id, **kwargs):
client = kwargs["client"] if kwargs["client"] else create_client()
return logout_icat_client(client)

@requires_session_id
@queries_records
def get_with_filters(self, session_id, table, filters):
return get_entity_with_filters(self.client, table.__name__, filters)
def get_with_filters(self, session_id, table, filters, **kwargs):
client = kwargs["client"] if kwargs["client"] else create_client()
return get_entity_with_filters(client, table.__name__, filters)

@requires_session_id
@queries_records
def create(self, session_id, table, data):
pass
def create(self, session_id, table, data, **kwargs):
client = kwargs["client"] if kwargs["client"] else create_client()

@requires_session_id
@queries_records
def update(self, session_id, table, data):
pass
def update(self, session_id, table, data, **kwargs):
client = kwargs["client"] if kwargs["client"] else create_client()

@requires_session_id
@queries_records
def get_one_with_filters(self, session_id, table, filters):
pass
def get_one_with_filters(self, session_id, table, filters, **kwargs):
client = kwargs["client"] if kwargs["client"] else create_client()

@requires_session_id
@queries_records
def count_with_filters(self, session_id, table, filters):
pass
def count_with_filters(self, session_id, table, filters, **kwargs):
client = kwargs["client"] if kwargs["client"] else create_client()

@requires_session_id
@queries_records
def get_with_id(self, session_id, table, id_):
return get_entity_by_id(self.client, table.__name__, id_, True)
def get_with_id(self, session_id, table, id_, **kwargs):
client = kwargs["client"] if kwargs["client"] else create_client()
return get_entity_by_id(client, table.__name__, id_, True)

@requires_session_id
@queries_records
def delete_with_id(self, session_id, table, id_):
return delete_entity_by_id(self.client, table.__name__, id_)
def delete_with_id(self, session_id, table, id_, **kwargs):
client = kwargs["client"] if kwargs["client"] else create_client()
return delete_entity_by_id(client, table.__name__, id_)

@requires_session_id
@queries_records
def update_with_id(self, session_id, table, id_, data):
return update_entity_by_id(self.client, table.__name__, id_, data)
def update_with_id(self, session_id, table, id_, data, **kwargs):
client = kwargs["client"] if kwargs["client"] else create_client()
return update_entity_by_id(client, table.__name__, id_, data)

@requires_session_id
@queries_records
def get_instrument_facilitycycles_with_filters(
self, session_id, instrument_id, filters
self, session_id, instrument_id, filters, **kwargs
):
pass
client = kwargs["client"] if kwargs["client"] else create_client()

@requires_session_id
@queries_records
def count_instrument_facilitycycles_with_filters(
self, session_id, instrument_id, filters
self, session_id, instrument_id, filters, **kwargs
):
pass
client = kwargs["client"] if kwargs["client"] else create_client()
# return get_facility_cycles_for_instrument_count(instrument_id, filters)

@requires_session_id
@queries_records
def get_instrument_facilitycycle_investigations_with_filters(
self, session_id, instrument_id, facilitycycle_id, filters
self, session_id, instrument_id, facilitycycle_id, filters, **kwargs
):
pass
client = kwargs["client"] if kwargs["client"] else create_client()
# return get_investigations_for_instrument_in_facility_cycle(instrument_id, facilitycycle_id, filters)

@requires_session_id
@queries_records
def count_instrument_facilitycycles_investigations_with_filters(
self, session_id, instrument_id, facilitycycle_id, filters
self, session_id, instrument_id, facilitycycle_id, filters, **kwargs
):
pass
client = kwargs["client"] if kwargs["client"] else create_client()
# return get_investigations_for_instrument_in_facility_cycle_count(instrument_id, facilitycycle_id, filters)
25 changes: 23 additions & 2 deletions common/icat/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@
PythonICATOrderFilter,
)

import icat.client
from common.config import config

log = logging.getLogger()

Expand All @@ -28,16 +30,28 @@ def requires_session_id(method):
Decorator for Python ICAT backend methods that looks out for session errors when
using the API. The API call runs and an ICATSessionError may be raised due to an
expired session, invalid session ID etc.


The session ID from the request is set here, so there is no requirement for a user
to use the login endpoint, they can go straight into using the API so long as they
have a valid session ID (be it created from this API, or from an alternative such as
scigateway-auth).

This assumes the session ID is the second argument of the function where this
decorator is applied, which is reasonable to assume considering the current method
signatures of all the endpoints.

:param method: The method for the backend operation
:raises AuthenticationError: If a valid session_id is not provided with the request
"""

@wraps(method)
def wrapper_requires_session(*args, **kwargs):
try:
client = create_client()
client.sessionId = args[1]
# Client object put into kwargs so it can be accessed by backend functions
kwargs["client"] = client

client = args[0].client
# Find out if session has expired
session_time = client.getRemainingMinutes()
log.info("Session time: %d", session_time)
Expand All @@ -51,6 +65,13 @@ def wrapper_requires_session(*args, **kwargs):
return wrapper_requires_session


def create_client():
client = icat.client.Client(
config.get_icat_url(), checkCert=config.get_icat_check_cert()
)
return client


def get_session_details_helper(client):
"""
Retrieve details regarding the current session within `client`
Expand Down
4 changes: 4 additions & 0 deletions test/test_helpers.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from unittest import TestCase

from sqlalchemy.exc import IntegrityError
from datetime import datetime, timedelta

from common.database.helpers import (
delete_row_by_id,
Expand Down Expand Up @@ -63,6 +64,9 @@ def setUp(self):
self.bad_credentials_header = {"Authorization": "Bearer BadTest"}
session = SESSION()
session.ID = "Test"
session.EXPIREDATETIME = datetime.now() + timedelta(hours=1)
session.username = "Test User"

insert_row_into_table(SESSION, session)

def tearDown(self):
Expand Down