Skip to content

Commit

Permalink
Merge branch 'master' into feature/test-multiple-backends-#150
Browse files Browse the repository at this point in the history
  • Loading branch information
MRichards99 committed Nov 13, 2020
2 parents 55397e5 + 8e3e161 commit 9625cc1
Show file tree
Hide file tree
Showing 3 changed files with 75 additions and 60 deletions.
108 changes: 49 additions & 59 deletions datagateway_api/common/icat/backend.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,12 @@
import logging

import icat.client
from icat.exception import ICATSessionError

from datagateway_api.common.backend import Backend
from datagateway_api.common.config import config
from datagateway_api.common.exceptions import AuthenticationError
from datagateway_api.common.helpers import queries_records
from datagateway_api.common.icat.helpers import (
create_client,
create_entities,
delete_entity_by_id,
get_count_with_filters,
Expand Down Expand Up @@ -36,130 +35,121 @@ class PythonICATBackend(Backend):
"""

def __init__(self):
# Client object is created here as well as in login() to avoid uncaught
# exceptions where the object is None. This could happen where a user tries to
# use an endpoint before logging in. Also helps to give a bit of certainty to
# what's stored here
self.client = icat.client.Client(
config.get_icat_url(), checkCert=config.get_icat_check_cert(),
)
pass

def login(self, credentials):
log.info("Logging in to get session ID")
# Client object is re-created here so session IDs aren't overwritten in the
# database
self.client = icat.client.Client(
config.get_icat_url(), checkCert=config.get_icat_check_cert(),
)
client = create_client()

# Syntax for Python ICAT
login_details = {
"username": credentials["username"],
"password": credentials["password"],
}
try:
session_id = self.client.login(credentials["mechanism"], login_details)
session_id = client.login(credentials["mechanism"], login_details)
return session_id
except ICATSessionError:
raise AuthenticationError("User credentials are incorrect")

@requires_session_id
def get_session_details(self, session_id):
def get_session_details(self, session_id, **kwargs):
log.info("Getting session details for session: %s", session_id)
self.client.sessionId = session_id
return get_session_details_helper(self.client)
client = kwargs["client"] if kwargs["client"] else create_client()
return get_session_details_helper(client)

@requires_session_id
def refresh(self, session_id):
def refresh(self, session_id, **kwargs):
log.info("Refreshing session: %s", session_id)
self.client.sessionId = session_id
return refresh_client_session(self.client)
client = kwargs["client"] if kwargs["client"] else create_client()
return refresh_client_session(client)

@requires_session_id
@queries_records
def logout(self, session_id):
self.client.sessionId = session_id
return logout_icat_client(self.client)
def logout(self, session_id, **kwargs):
client = kwargs["client"] if kwargs["client"] else create_client()
return logout_icat_client(client)

@requires_session_id
@queries_records
def get_with_filters(self, session_id, entity_type, filters):
self.client.sessionId = session_id
return get_entity_with_filters(self.client, entity_type, filters)
def get_with_filters(self, session_id, entity_type, filters, **kwargs):
client = kwargs["client"] if kwargs["client"] else create_client()
return get_entity_with_filters(client, entity_type, filters)

@requires_session_id
@queries_records
def create(self, session_id, entity_type, data):
self.client.sessionId = session_id
return create_entities(self.client, entity_type, data)
def create(self, session_id, entity_type, data, **kwargs):
client = kwargs["client"] if kwargs["client"] else create_client()
return create_entities(client, entity_type, data)

@requires_session_id
@queries_records
def update(self, session_id, entity_type, data):
self.client.sessionId = session_id
return update_entities(self.client, entity_type, data)
def update(self, session_id, entity_type, data, **kwargs):
client = kwargs["client"] if kwargs["client"] else create_client()
return update_entities(client, entity_type, data)

@requires_session_id
@queries_records
def get_one_with_filters(self, session_id, entity_type, filters):
self.client.sessionId = session_id
return get_first_result_with_filters(self.client, entity_type, filters)
def get_one_with_filters(self, session_id, entity_type, filters, **kwargs):
client = kwargs["client"] if kwargs["client"] else create_client()
return get_first_result_with_filters(client, entity_type, filters)

@requires_session_id
@queries_records
def count_with_filters(self, session_id, entity_type, filters):
self.client.sessionId = session_id
return get_count_with_filters(self.client, entity_type, filters)
def count_with_filters(self, session_id, entity_type, filters, **kwargs):
client = kwargs["client"] if kwargs["client"] else create_client()
return get_count_with_filters(client, entity_type, filters)

@requires_session_id
@queries_records
def get_with_id(self, session_id, entity_type, id_):
return get_entity_by_id(self.client, entity_type, id_, True)
def get_with_id(self, session_id, entity_type, id_, **kwargs):
client = kwargs["client"] if kwargs["client"] else create_client()
return get_entity_by_id(client, entity_type, id_, True)

@requires_session_id
@queries_records
def delete_with_id(self, session_id, entity_type, id_):
return delete_entity_by_id(self.client, entity_type, id_)
def delete_with_id(self, session_id, entity_type, id_, **kwargs):
client = kwargs["client"] if kwargs["client"] else create_client()
return delete_entity_by_id(client, entity_type, id_)

@requires_session_id
@queries_records
def update_with_id(self, session_id, entity_type, id_, data):
return update_entity_by_id(self.client, entity_type, id_, data)
def update_with_id(self, session_id, entity_type, id_, data, **kwargs):
client = kwargs["client"] if kwargs["client"] else create_client()
return update_entity_by_id(client, entity_type, id_, data)

@requires_session_id
@queries_records
def get_facility_cycles_for_instrument_with_filters(
self, session_id, instrument_id, filters,
self, session_id, instrument_id, filters, **kwargs,
):
self.client.sessionId = session_id
return get_facility_cycles_for_instrument(self.client, instrument_id, filters)
client = kwargs["client"] if kwargs["client"] else create_client()
return get_facility_cycles_for_instrument(client, instrument_id, filters)

@requires_session_id
@queries_records
def get_facility_cycles_for_instrument_count_with_filters(
self, session_id, instrument_id, filters,
self, session_id, instrument_id, filters, **kwargs,
):
self.client.sessionId = session_id
return get_facility_cycles_for_instrument_count(
self.client, instrument_id, filters,
)
client = kwargs["client"] if kwargs["client"] else create_client()
return get_facility_cycles_for_instrument_count(client, instrument_id, filters)

@requires_session_id
@queries_records
def get_investigations_for_instrument_in_facility_cycle_with_filters(
self, session_id, instrument_id, facilitycycle_id, filters,
self, session_id, instrument_id, facilitycycle_id, filters, **kwargs,
):
self.client.sessionId = session_id
client = kwargs["client"] if kwargs["client"] else create_client()
return get_investigations_for_instrument_in_facility_cycle(
self.client, instrument_id, facilitycycle_id, filters,
client, instrument_id, facilitycycle_id, filters,
)

@requires_session_id
@queries_records
def get_investigation_count_for_instrument_facility_cycle_with_filters(
self, session_id, instrument_id, facilitycycle_id, filters,
self, session_id, instrument_id, facilitycycle_id, filters, **kwargs,
):
self.client.sessionId = session_id
client = kwargs["client"] if kwargs["client"] else create_client()
return get_investigations_for_instrument_in_facility_cycle_count(
self.client, instrument_id, facilitycycle_id, filters,
client, instrument_id, facilitycycle_id, filters,
)
23 changes: 22 additions & 1 deletion datagateway_api/common/icat/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import logging


import icat.client
from icat.entities import getTypeMap
from icat.exception import (
ICATInternalError,
Expand All @@ -13,6 +14,7 @@
ICATValidationError,
)

from datagateway_api.common.config import config
from datagateway_api.common.date_handler import DateHandler
from datagateway_api.common.exceptions import (
AuthenticationError,
Expand All @@ -37,15 +39,27 @@ def requires_session_id(method):
using the API. The API call runs and an ICATSessionError may be raised due to an
expired session, invalid session ID etc.
The session ID from the request is set here, so there is no requirement for a user
to use the login endpoint, they can go straight into using the API so long as they
have a valid session ID (be it created from this API, or from an alternative such as
scigateway-auth).
This assumes the session ID is the second argument of the function where this
decorator is applied, which is reasonable to assume considering the current method
signatures of all the endpoints.
:param method: The method for the backend operation
:raises AuthenticationError: If a valid session_id is not provided with the request
"""

@wraps(method)
def wrapper_requires_session(*args, **kwargs):
try:
client = create_client()
client.sessionId = args[1]
# Client object put into kwargs so it can be accessed by backend functions
kwargs["client"] = client

client = args[0].client
# Find out if session has expired
session_time = client.getRemainingMinutes()
log.info("Session time: %d", session_time)
Expand All @@ -59,6 +73,13 @@ def wrapper_requires_session(*args, **kwargs):
return wrapper_requires_session


def create_client():
client = icat.client.Client(
config.get_icat_url(), checkCert=config.get_icat_check_cert(),
)
return client


def get_session_details_helper(client):
"""
Retrieve details regarding the current session within `client`
Expand Down
4 changes: 4 additions & 0 deletions test/test_helpers.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from datetime import datetime, timedelta
from unittest import TestCase

from sqlalchemy.exc import IntegrityError
Expand Down Expand Up @@ -65,6 +66,9 @@ def setUp(self):
self.bad_credentials_header = {"Authorization": "Bearer BadTest"}
session = SESSION()
session.ID = "Test"
session.EXPIREDATETIME = datetime.now() + timedelta(hours=1)
session.username = "Test User"

insert_row_into_table(SESSION, session)

def tearDown(self):
Expand Down

0 comments on commit 9625cc1

Please sign in to comment.