From 3c202a87ba83958c7f04b272a9792285e482ed19 Mon Sep 17 00:00:00 2001 From: Matthew Richards Date: Mon, 15 Mar 2021 11:55:43 +0000 Subject: [PATCH] #209: Make login() use client cache - Extra attention should be paid to the flushing of session ID on the client object to previous users being logged out the next time backend.login() is called --- datagateway_api/common/icat/backend.py | 10 +++++++++- datagateway_api/common/icat/helpers.py | 8 ++++++-- 2 files changed, 15 insertions(+), 3 deletions(-) diff --git a/datagateway_api/common/icat/backend.py b/datagateway_api/common/icat/backend.py index 2770f190..6b436c24 100644 --- a/datagateway_api/common/icat/backend.py +++ b/datagateway_api/common/icat/backend.py @@ -9,6 +9,7 @@ create_client, create_entities, delete_entity_by_id, + get_cached_client, get_count_with_filters, get_entity_by_id, get_entity_with_filters, @@ -39,7 +40,9 @@ def __init__(self): def login(self, credentials): log.info("Logging in to get session ID") - client = create_client() + # There is no session ID required for this endpoint, a client object will be + # fetched from cache with a blank `sessionId` attribute + client = get_cached_client(None) # Syntax for Python ICAT login_details = { @@ -48,6 +51,11 @@ def login(self, credentials): } try: session_id = client.login(credentials["mechanism"], login_details) + # Flushing client's session ID so the session ID returned in this request + # won't be logged out next time `client.login()` is used in this function. + # `login()` calls `self.logout()` if `sessionId` is set + client.sessionId = None + return session_id except ICATSessionError: raise AuthenticationError("User credentials are incorrect") diff --git a/datagateway_api/common/icat/helpers.py b/datagateway_api/common/icat/helpers.py index 51d9d7f2..7cc5bb1d 100644 --- a/datagateway_api/common/icat/helpers.py +++ b/datagateway_api/common/icat/helpers.py @@ -77,9 +77,13 @@ def get_cached_client(session_id): """ TODO - Add docstring """ - log.debug(f"Caching, session ID: {session_id}") client = create_client() - client.sessionId = session_id + + # `session_id` of None suggests this function is being called from an endpoint that + # doesn't use the `requires_session_id` decorator (e.g. POST /sessions) + log.info("Caching, session ID: %s", session_id) + if session_id: + client.sessionId = session_id return client