diff --git a/common/helpers.py b/common/helpers.py index 797bb1d1..21cb76f2 100644 --- a/common/helpers.py +++ b/common/helpers.py @@ -6,7 +6,7 @@ from sqlalchemy.exc import IntegrityError from common.database_helpers import get_icat_db_session -from common.exceptions import MissingRecordError, BadFilterError +from common.exceptions import MissingRecordError, BadFilterError, AuthenticationError from common.models.db_models import SESSION @@ -19,12 +19,15 @@ def requires_session_id(method): @wraps(method) def wrapper_requires_session(*args, **kwargs): - session = get_icat_db_session() - query = session.query(SESSION).filter( - SESSION.ID == get_session_id_from_auth_header()).first() - if query is not None: - return method(*args, **kwargs) - else: + try: + session = get_icat_db_session() + query = session.query(SESSION).filter( + SESSION.ID == get_session_id_from_auth_header()).first() + if query is not None: + return method(*args, **kwargs) + else: + return "Forbidden", 403 + except AuthenticationError: return "Forbidden", 403 return wrapper_requires_session @@ -63,9 +66,12 @@ def get_session_id_from_auth_header(): parser = reqparse.RequestParser() parser.add_argument("Authorization", location="headers") args = parser.parse_args() - if args["Authorization"] is not None: - return args["Authorization"] - return "" + auth_header = args["Authorization"].split(" ") if args["Authorization"] is not None else "" + if auth_header == "": + return "" + if len(auth_header) != 2 or auth_header[0] != "Bearer": + raise AuthenticationError() + return auth_header[1] def is_valid_json(string):